
Commit e97580a

andrej-db authored and cloud-fan committed
[SPARK-50087][SQL][3.5] Robust handling of boolean expressions in CASE WHEN for MsSqlServer and future connectors
### What changes were proposed in this pull request?

This PR proposes to propagate the `isPredicate` info in `V2ExpressionBuilder` and to wrap the children of a CASE WHEN expression (only `Predicate`s) with `IIF(<>, 1, 0)` for MsSqlServer. This forces the branches to return an int instead of a boolean, because SqlServer cannot handle boolean expressions as a return type in CASE WHEN.

E.g.
```
CASE WHEN ... ELSE a = b END
```

Old behavior:
```
CASE WHEN ... ELSE a = b END = 1
```

New behavior: since a `= 1` is appended to the CASE WHEN for SqlServer, the THEN and ELSE blocks must return an int. The final expression therefore becomes:
```
CASE WHEN ... ELSE IIF(a = b, 1, 0) END = 1
```

### Why are the changes needed?

A user cannot query MsSqlServer data with CASE WHEN or IF clauses that return a boolean value.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added tests to MsSqlServerIntegrationSuite.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49115 from andrej-db/CASEWHENBackport.

Lead-authored-by: andrej-gobeljic_data <andrej.gobeljic@databricks.com>
Co-authored-by: Wenchen Fan <cloud0fan@gmail.com>
Co-authored-by: Andrej Gobeljić <andrej.gobeljic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
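To make the new behavior concrete, here is a minimal, hypothetical sketch of the user-facing scenario. The catalog name `mssql`, the connection URL and credentials, and the `employee` table are illustrative placeholders; the integration suite below registers an equivalent catalog against a Dockerized SQL Server.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative sketch only: a CASE WHEN predicate whose branches are boolean
// expressions, pushed down to SQL Server through the JDBC v2 table catalog.
val spark = SparkSession.builder()
  .appName("SPARK-50087-sketch")
  // Hypothetical catalog registration; URL and credentials are placeholders.
  .config("spark.sql.catalog.mssql",
    "org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog")
  .config("spark.sql.catalog.mssql.url",
    "jdbc:sqlserver://localhost:1433;databaseName=testdb;user=sa;password=<secret>")
  .getOrCreate()

// Previously the pushed-down SQL kept the boolean branch as-is
// (`... ELSE ("name" <> 'Wizard') END = 1`), which SQL Server rejects.
// With this change it is emitted as `... ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1`.
spark.sql(
  """SELECT * FROM mssql.employee
    |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END
    |""".stripMargin).show()
```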
1 parent bb953f9 commit e97580a

File tree

6 files changed: +147 −8 lines changed


connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala

Lines changed: 79 additions & 0 deletions
```diff
@@ -22,7 +22,11 @@ import java.sql.Connection
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.{SparkConf, SparkSQLFeatureNotSupportedException}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
+import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.jdbc.MsSQLServerDatabaseOnDocker
 import org.apache.spark.sql.types._
@@ -39,6 +43,17 @@ import org.apache.spark.tags.DockerTest
 @DockerTest
 class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest {
 
+  def getExternalEngineQuery(executedPlan: SparkPlan): String = {
+    getExternalEngineRdd(executedPlan).asInstanceOf[JDBCRDD].getExternalEngineQuery
+  }
+
+  def getExternalEngineRdd(executedPlan: SparkPlan): RDD[InternalRow] = {
+    val queryNode = executedPlan.collect { case r: RowDataSourceScanExec =>
+      r
+    }.head
+    queryNode.rdd
+  }
+
   override def excluded: Seq[String] = Seq(
     "simple scan with OFFSET",
     "simple scan with LIMIT and OFFSET",
@@ -137,4 +152,68 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JD
       "WHERE (dept > 1 AND ((name LIKE 'am%') = (name LIKE '%y')))")
     assert(df3.collect().length == 3)
   }
+
+  test("SPARK-50087: SqlServer handle booleans in CASE WHEN test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN name = 'Legolas' THEN name = 'Elf' ELSE NOT (name = 'Wizard') END
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1) """
+    )
+    // scalastyle:on
+    df.collect()
+  }
+
+  test("SPARK-50087: SqlServer handle booleans in CASE WHEN with always true test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN (name = 'Legolas') THEN (name = 'Elf') ELSE (1=1) END
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0) ELSE 1 END = 1) """
+    )
+    // scalastyle:on
+    df.collect()
+  }
+
+  test("SPARK-50087: SqlServer handle booleans in nested CASE WHEN test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN (name = 'Legolas') THEN
+          | CASE WHEN (name = 'Elf') THEN (name = 'Elrond') ELSE (name = 'Gandalf') END
+          | ELSE (name = 'Sauron') END
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT "dept","name","salary","bonus" FROM "employee" WHERE (CASE WHEN ("name" = 'Legolas') THEN IIF((CASE WHEN ("name" = 'Elf') THEN IIF(("name" = 'Elrond'), 1, 0) ELSE IIF(("name" = 'Gandalf'), 1, 0) END = 1), 1, 0) ELSE IIF(("name" = 'Sauron'), 1, 0) END = 1) """
+    )
+    // scalastyle:on
+    df.collect()
+  }
+
+  test("SPARK-50087: SqlServer handle non-booleans in nested CASE WHEN test") {
+    val df = sql(
+      s"""|SELECT * FROM $catalogName.employee
+          |WHERE CASE WHEN (name = 'Legolas') THEN
+          | CASE WHEN (name = 'Elf') THEN 'Elf' ELSE 'Wizard' END
+          | ELSE 'Sauron' END = name
+          |""".stripMargin
+    )
+
+    // scalastyle:off
+    assert(getExternalEngineQuery(df.queryExecution.executedPlan) ==
+      """SELECT "dept","name","salary","bonus" FROM "employee" WHERE ("name" IS NOT NULL) AND ((CASE WHEN "name" = 'Legolas' THEN CASE WHEN "name" = 'Elf' THEN 'Elf' ELSE 'Wizard' END ELSE 'Sauron' END) = "name") """
+    )
+    // scalastyle:on
+    df.collect()
+  }
 }
```

sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java

Lines changed: 1 addition & 1 deletion
```diff
@@ -290,7 +290,7 @@ protected String visitContains(String l, String r) {
     return l + " LIKE '%" + escapeSpecialCharsForLikePattern(value) + "%' ESCAPE '\\'";
   }
 
-  private String inputToSQL(Expression input) {
+  protected String inputToSQL(Expression input) {
     if (input.children().length > 1) {
       return "(" + build(input) + ")";
     } else {
```

sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala

Lines changed: 3 additions & 3 deletions
```diff
@@ -189,8 +189,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
     case _: BitwiseNot => generateExpressionWithName("~", expr, isPredicate)
     case caseWhen @ CaseWhen(branches, elseValue) =>
       val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
-      val values = branches.map(_._2).flatMap(generateExpression(_))
-      val elseExprOpt = elseValue.flatMap(generateExpression(_))
+      val values = branches.map(_._2).flatMap(generateExpression(_, isPredicate))
+      val elseExprOpt = elseValue.flatMap(generateExpression(_, isPredicate))
       if (conditions.length == branches.length && values.length == branches.length &&
         elseExprOpt.size == elseValue.size) {
         val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
@@ -356,7 +356,7 @@
       children: Seq[Expression],
       dataType: DataType,
       isPredicate: Boolean): Option[V2Expression] = {
-    val childrenExpressions = children.flatMap(generateExpression(_))
+    val childrenExpressions = children.flatMap(generateExpression(_, isPredicate))
     if (childrenExpressions.length == children.length) {
       if (isPredicate && dataType.isInstanceOf[BooleanType]) {
         Some(new V2Predicate(v2ExpressionName, childrenExpressions.toArray[V2Expression]))
```

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 30 additions & 1 deletion
```diff
@@ -157,7 +157,7 @@ object JDBCRDD extends Logging {
  * Both the driver code and the workers must be able to access the database; the driver
  * needs to fetch the schema while the workers need to fetch the data.
  */
-private[jdbc] class JDBCRDD(
+class JDBCRDD(
     sc: SparkContext,
     getConnection: Int => Connection,
     schema: StructType,
@@ -173,11 +173,40 @@ private[jdbc] class JDBCRDD(
     offset: Int)
   extends RDD[InternalRow](sc, Nil) {
 
+  private lazy val dialect = JdbcDialects.get(url)
+
+  def generateJdbcQuery(partition: Option[JDBCPartition]): String = {
+    // H2's JDBC driver does not support the setSchema() method. We pass a
+    // fully-qualified table name in the SELECT statement. I don't know how to
+    // talk about a table in a completely portable way.
+    var builder = dialect
+      .getJdbcSQLQueryBuilder(options)
+      .withPredicates(predicates, partition.getOrElse(JDBCPartition(whereClause = null, idx = 1)))
+      .withColumns(columns)
+      .withSortOrders(sortOrders)
+      .withLimit(limit)
+      .withOffset(offset)
+
+    groupByColumns.foreach { groupByKeys =>
+      builder = builder.withGroupByColumns(groupByKeys)
+    }
+
+    sample.foreach { tableSampleInfo =>
+      builder = builder.withTableSample(tableSampleInfo)
+    }
+
+    builder.build()
+  }
+
   /**
    * Retrieve the list of partitions corresponding to this RDD.
    */
   override def getPartitions: Array[Partition] = partitions
 
+  def getExternalEngineQuery: String = {
+    generateJdbcQuery(partition = None)
+  }
+
   /**
    * Runs the SQL query against the JDBC driver.
    *
```
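The new `getExternalEngineQuery` hook is what the integration test helper above relies on. A minimal usage sketch (assuming the query planned a single `RowDataSourceScanExec` backed by a `JDBCRDD`, as in the MsSqlServer tests):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.RowDataSourceScanExec
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD

// Returns the exact SQL string Spark would send to the external engine for the
// first JDBC scan in the executed plan; fails if the plan contains no such scan.
def externalEngineQuery(df: DataFrame): String = {
  val scan = df.queryExecution.executedPlan.collect {
    case r: RowDataSourceScanExec => r
  }.head
  scan.rdd.asInstanceOf[JDBCRDD].getExternalEngineQuery
}
```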

sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala

Lines changed: 13 additions & 0 deletions
```diff
@@ -38,6 +38,7 @@ import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
 import org.apache.spark.sql.connector.expressions.aggregate.AggregateFunc
+import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
@@ -279,6 +280,18 @@ abstract class JdbcDialect extends Serializable with Logging {
   }
 
   private[jdbc] class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
+    // Some dialects do not support boolean type and this convenient util function is
+    // provided to generate SQL string without boolean values.
+    protected def inputToSQLNoBool(input: Expression): String = input match {
+      case p: Predicate if p.name() == "ALWAYS_TRUE" => "1"
+      case p: Predicate if p.name() == "ALWAYS_FALSE" => "0"
+      case p: Predicate => predicateToIntSQL(inputToSQL(p))
+      case _ => super.inputToSQL(input)
+    }
+
+    protected def predicateToIntSQL(input: String): String =
+      "CASE WHEN " + input + " THEN 1 ELSE 0 END"
+
     override def visitLiteral(literal: Literal[_]): String = {
       Option(literal.value()).map(v =>
         compileValue(CatalystTypeConverters.convertToScala(v, literal.dataType())).toString)
```
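For intuition, a standalone sketch (not Spark's actual builder classes) of what these helpers emit: always-true and always-false predicates collapse to integer literals, and any other predicate is wrapped so that it evaluates to 1 or 0, with `CASE WHEN ... THEN 1 ELSE 0 END` as the default rendering:

```scala
// Standalone illustration of the boolean-to-int rendering added above; the
// names mirror the new JDBCSQLBuilder helpers but operate on plain SQL strings.
object NoBoolSketch {
  // Default rendering, matching JDBCSQLBuilder.predicateToIntSQL.
  def predicateToIntSQL(predicateSQL: String): String =
    "CASE WHEN " + predicateSQL + " THEN 1 ELSE 0 END"

  // Shape of inputToSQLNoBool for the three predicate cases, keyed by the
  // predicate name instead of a connector Expression for brevity.
  def inputToSQLNoBool(predicateName: String, renderedSQL: String): String =
    predicateName match {
      case "ALWAYS_TRUE"  => "1"
      case "ALWAYS_FALSE" => "0"
      case _              => predicateToIntSQL(renderedSQL)
    }

  def main(args: Array[String]): Unit = {
    // Prints: CASE WHEN ("name" = 'Elf') THEN 1 ELSE 0 END
    println(inputToSQLNoBool("=", """("name" = 'Elf')"""))
  }
}
```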

sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala

Lines changed: 21 additions & 3 deletions
```diff
@@ -66,6 +66,8 @@ private object MsSqlServerDialect extends JdbcDialect {
     supportedFunctions.contains(funcName)
 
   class MsSqlServerSQLBuilder extends JDBCSQLBuilder {
+    override protected def predicateToIntSQL(input: String): String =
+      "IIF(" + input + ", 1, 0)"
     override def visitSortOrder(
         sortKey: String, sortDirection: SortDirection, nullOrdering: NullOrdering): String = {
       (sortDirection, nullOrdering) match {
@@ -93,9 +95,25 @@
       // We shouldn't propagate these queries to MsSqlServer
       expr match {
         case e: Predicate => e.name() match {
-          case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">="
-            if e.children().exists(_.isInstanceOf[Predicate]) =>
-            super.visitUnexpectedExpr(expr)
+          case "=" | "<>" | "<=>" | "<" | "<=" | ">" | ">=" =>
+            val Array(l, r) = e.children().map(inputToSQLNoBool)
+            visitBinaryComparison(e.name(), l, r)
+          case "CASE_WHEN" =>
+            // Since MsSqlServer cannot handle boolean expressions inside
+            // a CASE WHEN, it is necessary to convert those to another
+            // CASE WHEN expression that will return 1 or 0 depending on
+            // the result.
+            // Example:
+            // In:  ... CASE WHEN a = b THEN c = d ... END
+            // Out: ... CASE WHEN a = b THEN CASE WHEN c = d THEN 1 ELSE 0 END ... END = 1
+            val stringArray = e.children().grouped(2).flatMap {
+              case Array(whenExpression, thenExpression) =>
+                Array(inputToSQL(whenExpression), inputToSQLNoBool(thenExpression))
+              case Array(elseExpression) =>
+                Array(inputToSQLNoBool(elseExpression))
+            }.toArray
+
+            visitCaseWhen(stringArray) + " = 1"
           case _ => super.build(expr)
         }
       case _ => super.build(expr)
```
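As a worked example of the CASE_WHEN branch above, a standalone sketch (plain strings rather than connector expressions, and it assumes every THEN/ELSE child is a predicate) of how the children — alternating WHEN/THEN pairs plus an optional trailing ELSE — are reassembled with the `IIF(..., 1, 0)` wrapping and the trailing `= 1`:

```scala
// Standalone illustration of the MsSqlServer CASE WHEN rewrite shown above,
// operating on already-rendered SQL fragments instead of V2 expressions.
object MsSqlServerCaseWhenSketch {
  // Mirrors MsSqlServerSQLBuilder.predicateToIntSQL.
  def predicateToIntSQL(predicateSQL: String): String = s"IIF($predicateSQL, 1, 0)"

  // children = Seq(when1, then1, when2, then2, ..., [else]); THEN/ELSE values
  // are assumed to be predicates here and are therefore always wrapped.
  def buildCaseWhen(children: Seq[String]): String = {
    val sb = new StringBuilder("CASE")
    children.grouped(2).foreach {
      case Seq(whenSQL, thenSQL) =>
        sb.append(s" WHEN $whenSQL THEN ${predicateToIntSQL(thenSQL)}")
      case Seq(elseSQL) =>
        sb.append(s" ELSE ${predicateToIntSQL(elseSQL)}")
    }
    sb.append(" END = 1").toString
  }

  def main(args: Array[String]): Unit = {
    // Prints: CASE WHEN ("name" = 'Legolas') THEN IIF(("name" = 'Elf'), 1, 0)
    //         ELSE IIF(("name" <> 'Wizard'), 1, 0) END = 1
    println(buildCaseWhen(Seq(
      """("name" = 'Legolas')""", """("name" = 'Elf')""", """("name" <> 'Wizard')""")))
  }
}
```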
