diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 59660f4f0404f..9c31db4bc883d 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use datafusion_common::{internal_err, not_impl_err, plan_err, DataFusionError, Result}; +use datafusion_common::{ + internal_err, not_impl_err, plan_err, Column, DataFusionError, Result, +}; use datafusion_expr::{ expr::Alias, Distinct, Expr, JoinConstraint, JoinType, LogicalPlan, Projection, }; @@ -368,37 +370,11 @@ impl Unparser<'_> { self.select_to_sql_recursively(input, query, select, relation) } LogicalPlan::Join(join) => { - match join.join_constraint { - JoinConstraint::On => {} - JoinConstraint::Using => { - return not_impl_err!( - "Unsupported join constraint: {:?}", - join.join_constraint - ) - } - } - - // parse filter if exists - let join_filter = match &join.filter { - Some(filter) => Some(self.expr_to_sql(filter)?), - None => None, - }; - - // map join.on to `l.a = r.a AND l.b = r.b AND ...` - let eq_op = ast::BinaryOperator::Eq; - let join_on = self.join_conditions_to_sql(&join.on, eq_op)?; - - // Merge `join_on` and `join_filter` - let join_expr = match (join_filter, join_on) { - (Some(filter), Some(on)) => Some(self.and_op_to_sql(filter, on)), - (Some(filter), None) => Some(filter), - (None, Some(on)) => Some(on), - (None, None) => None, - }; - let join_constraint = match join_expr { - Some(expr) => ast::JoinConstraint::On(expr), - None => ast::JoinConstraint::None, - }; + let join_constraint = self.join_constraint_to_sql( + join.join_constraint, + &join.on, + join.filter.as_ref(), + )?; let mut right_relation = RelationBuilder::default(); @@ -582,24 +558,108 @@ impl Unparser<'_> { } } - fn join_conditions_to_sql( + /// Convert the components of a USING clause to the USING AST. 
Returns + /// `None` if the conditions are not compatible with a USING expression, + /// e.g. non-column expressions or non-matching names. + fn join_using_to_sql( &self, - join_conditions: &Vec<(Expr, Expr)>, - eq_op: ast::BinaryOperator, - ) -> Result> { - // Only support AND conjunction for each binary expression in join conditions - let mut exprs: Vec = vec![]; + join_conditions: &[(Expr, Expr)], + ) -> Option { + let mut idents = Vec::with_capacity(join_conditions.len()); for (left, right) in join_conditions { - // Parse left + match (left, right) { + ( + Expr::Column(Column { + relation: _, + name: left_name, + }), + Expr::Column(Column { + relation: _, + name: right_name, + }), + ) if left_name == right_name => { + idents.push(self.new_ident_quoted_if_needs(left_name.to_string())); + } + // USING is only valid with matching column names; arbitrary expressions + // are not allowed + _ => return None, + } + } + Some(ast::JoinConstraint::Using(idents)) + } + + /// Convert a join constraint and associated conditions and filter to a SQL AST node + fn join_constraint_to_sql( + &self, + constraint: JoinConstraint, + conditions: &[(Expr, Expr)], + filter: Option<&Expr>, + ) -> Result { + match (constraint, conditions, filter) { + // No constraints + (JoinConstraint::On | JoinConstraint::Using, [], None) => { + Ok(ast::JoinConstraint::None) + } + + (JoinConstraint::Using, conditions, None) => { + match self.join_using_to_sql(conditions) { + Some(using) => Ok(using), + // As noted below, this should not be reachable from parsed SQL, + // but a user could create this; we "downgrade" to ON. + None => self.join_conditions_to_sql_on(conditions, None), + } + } + + // Two cases here: + // 1. Straightforward ON case, with possible equi-join conditions + // and additional filters + // 2. USING with additional filters; we "downgrade" to ON, because + // you can't use USING with arbitrary filters. 
(This should not + // be accessible from parsed SQL, but may have been a + // custom-built JOIN by a user.) + (JoinConstraint::On | JoinConstraint::Using, conditions, filter) => { + self.join_conditions_to_sql_on(conditions, filter) + } + } + } + + /// Convert a list of equi-join conditions and an optional filter to a SQL ON + /// AST node, with the equi-join conditions and the filter merged into a + /// single conditional expression + fn join_conditions_to_sql_on( + &self, + join_conditions: &[(Expr, Expr)], + filter: Option<&Expr>, + ) -> Result { + let mut condition = None; + // AND the join conditions together to create the overall condition + for (left, right) in join_conditions { + // Parse left and right let l = self.expr_to_sql(left)?; - // Parse right let r = self.expr_to_sql(right)?; - // AND with existing expression - exprs.push(self.binary_op_to_sql(l, r, eq_op.clone())); + let e = self.binary_op_to_sql(l, r, ast::BinaryOperator::Eq); + condition = match condition { + Some(expr) => Some(self.and_op_to_sql(expr, e)), + None => Some(e), + }; } - let join_expr: Option = - exprs.into_iter().reduce(|r, l| self.and_op_to_sql(r, l)); - Ok(join_expr) + + // Then AND the non-equijoin filter condition as well + condition = match (condition, filter) { + (Some(expr), Some(filter)) => { + Some(self.and_op_to_sql(expr, self.expr_to_sql(filter)?)) + } + (Some(expr), None) => Some(expr), + (None, Some(filter)) => Some(self.expr_to_sql(filter)?), + (None, None) => None, + }; + + let constraint = match condition { + Some(filter) => ast::JoinConstraint::On(filter), + None => ast::JoinConstraint::None, + }; + + Ok(constraint) } fn and_op_to_sql(&self, lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr { diff --git a/datafusion/sql/tests/cases/plan_to_sql.rs b/datafusion/sql/tests/cases/plan_to_sql.rs index aada560fd884a..a52333e54fac6 100644 --- a/datafusion/sql/tests/cases/plan_to_sql.rs +++ b/datafusion/sql/tests/cases/plan_to_sql.rs @@ -84,6 +84,7 @@ fn roundtrip_statement() -> 
Result<()> { "select 1;", "select 1 limit 0;", "select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id;", + "select ta.j1_id from j1 ta join (select 1 as j1_id) tb using (j1_id);", "select ta.j1_id from j1 ta join (select 1 as j1_id) tb on ta.j1_id = tb.j1_id where ta.j1_id > 1;", "select ta.j1_id from (select 1 as j1_id) ta;", "select ta.j1_id from j1 ta;", @@ -142,6 +143,7 @@ fn roundtrip_statement() -> Result<()> { r#"SELECT id, count(distinct id) over (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), sum(id) OVER (PARTITION BY first_name ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) from person"#, "SELECT id, sum(id) OVER (PARTITION BY first_name ROWS BETWEEN 5 PRECEDING AND 2 FOLLOWING) from person", + "WITH t1 AS (SELECT j1_id AS id, j1_string name FROM j1), t2 AS (SELECT j2_id AS id, j2_string name FROM j2) SELECT * FROM t1 JOIN t2 USING (id, name)", ]; // For each test sql string, we transform as follows: