diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 9be0941b5d575..3d494f41b3a6a 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -54,6 +54,8 @@ mod null_equality; pub mod parquet_config; pub mod parsers; pub mod pruning; +#[doc(hidden)] +pub mod recursive_schema; pub mod rounding; pub mod scalar; pub mod spans; diff --git a/datafusion/common/src/recursive_schema.rs b/datafusion/common/src/recursive_schema.rs new file mode 100644 index 0000000000000..e236de484aff8 --- /dev/null +++ b/datafusion/common/src/recursive_schema.rs @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Internal helpers for recursive CTE schema reconciliation. +//! +//! Recursive CTE work-table references and children must expose schemas that +//! are conservative for nullability, while preserving every other schema +//! dimension exactly. + +use std::sync::Arc; + +use arrow::datatypes::{FieldRef, Schema, SchemaRef}; + +use crate::{DFSchema, DFSchemaRef, DataFusionError, Result}; + +/// Return an Arrow schema with all fields marked nullable, preserving field and +/// schema metadata. 
+#[doc(hidden)] +pub fn make_schema_nullable(schema: &Schema) -> SchemaRef { + Arc::new(Schema::new_with_metadata( + schema + .fields() + .iter() + .map(|field| field.as_ref().clone().with_nullable(true)) + .collect::<Vec<_>>(), + schema.metadata().clone(), + )) +} + +/// Return a recursive query output schema that preserves `static_schema` except +/// for nullability widened by `recursive_schema`. +/// +/// This helper assumes recursive term expressions have already been coerced to +/// the static term's schema, and only reads field nullability from +/// `recursive_schema`. All other output schema dimensions come from +/// `static_schema`. +#[doc(hidden)] +pub fn recursive_query_output_schema( + static_schema: &DFSchema, + recursive_schema: &DFSchema, +) -> Result<DFSchemaRef> { + if static_schema.fields().len() != recursive_schema.fields().len() { + return Err(DataFusionError::Plan(format!( + "Non-recursive term and recursive term must have the same number of columns ({} != {})", + static_schema.fields().len(), + recursive_schema.fields().len() + ))); + } + + let fields = static_schema + .iter() + .zip(recursive_schema.fields()) + .map(|((qualifier, static_field), recursive_field)| { + ( + qualifier.cloned(), + static_field + .as_ref() + .clone() + .with_nullable( + static_field.is_nullable() || recursive_field.is_nullable(), + ) + .into(), + ) + }) + .collect::<Vec<_>>(); + + DFSchema::new_with_metadata(fields, static_schema.metadata().clone())? + .with_functional_dependencies(static_schema.functional_dependencies().clone()) + .map(DFSchemaRef::new) +} + +/// Reconcile `logical_schema` with an Arrow schema, but only when the Arrow +/// schema differs by being more nullable. Returns `Ok(None)` if any other +/// schema dimension differs, so callers can report their normal schema error. 
+#[doc(hidden)] +pub fn reconcile_dfschema_with_schema_nullability( + logical_schema: &DFSchema, + physical_schema: &Schema, +) -> Result<Option<DFSchema>> { + if logical_schema.metadata() != physical_schema.metadata() + || logical_schema.fields().len() != physical_schema.fields().len() + { + return Ok(None); + } + + widen_dfschema_nullability_with_fields( + logical_schema, + physical_schema.fields().iter(), + ) +} + +fn widen_dfschema_nullability_with_fields<'a>( + base_schema: &DFSchema, + widening_fields: impl Iterator<Item = &'a FieldRef>, +) -> Result<Option<DFSchema>> { + let mut widened_nullability = false; + let mut fields = Vec::with_capacity(base_schema.fields().len()); + + for ((qualifier, base_field), widening_field) in + base_schema.iter().zip(widening_fields) + { + if base_field.name() != widening_field.name() + || base_field.data_type() != widening_field.data_type() + || base_field.metadata() != widening_field.metadata() + { + return Ok(None); + } + + widened_nullability |= !base_field.is_nullable() && widening_field.is_nullable(); + fields.push(( + qualifier.cloned(), + base_field + .as_ref() + .clone() + .with_nullable(base_field.is_nullable() || widening_field.is_nullable()) + .into(), + )); + } + + if !widened_nullability { + return Ok(None); + } + + DFSchema::new_with_metadata(fields, base_schema.metadata().clone())? 
+ .with_functional_dependencies(base_schema.functional_dependencies().clone()) + .map(Some) +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use arrow::datatypes::{DataType, Field, Schema}; + + use crate::ToDFSchema as _; + + use super::*; + + #[test] + fn make_schema_nullable_preserves_metadata() { + let schema = Schema::new_with_metadata( + vec![ + Field::new("c1", DataType::Int32, false) + .with_metadata(HashMap::from([("field".into(), "value".into())])), + ], + HashMap::from([("schema".into(), "value".into())]), + ); + + let nullable = make_schema_nullable(&schema); + + assert!(nullable.field(0).is_nullable()); + assert_eq!(nullable.field(0).metadata(), schema.field(0).metadata()); + assert_eq!(nullable.metadata(), schema.metadata()); + } + + #[test] + fn recursive_output_schema_preserves_static_dimensions_and_widens_nullability() { + let static_schema = Schema::new_with_metadata( + vec![ + Field::new("anchor_name", DataType::Int32, false) + .with_metadata(HashMap::from([("field".into(), "value".into())])), + ], + HashMap::from([("schema".into(), "value".into())]), + ) + .to_dfschema_ref() + .unwrap(); + let recursive_schema = Schema::new(vec![Field::new( + "recursive_expr_name", + DataType::Int32, + true, + )]) + .to_dfschema_ref() + .unwrap(); + + let output = + recursive_query_output_schema(&static_schema, &recursive_schema).unwrap(); + + assert_eq!(output.field(0).name(), "anchor_name"); + assert_eq!(output.field(0).data_type(), &DataType::Int32); + assert_eq!( + output.field(0).metadata(), + static_schema.field(0).metadata() + ); + assert_eq!(output.metadata(), static_schema.metadata()); + assert!(output.field(0).is_nullable()); + } + + #[test] + fn reconciliation_only_widens_nullability() { + let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .to_dfschema_ref() + .unwrap(); + let physical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, true)]); + + let reconciled = + 
reconcile_dfschema_with_schema_nullability(&logical_schema, &physical_schema) + .unwrap() + .expect("nullability widening should reconcile"); + + assert!(reconciled.field(0).is_nullable()); + assert_eq!(reconciled.field(0).name(), "c1"); + assert_eq!(reconciled.field(0).data_type(), &DataType::Int32); + } + + #[test] + fn reconciliation_rejects_other_mismatches() { + let logical_schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]) + .to_dfschema_ref() + .unwrap(); + + let cases = [ + Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Int32, true), + ]), + Schema::new(vec![Field::new("different", DataType::Int32, true)]), + Schema::new(vec![Field::new("c1", DataType::Int64, true)]), + Schema::new(vec![ + Field::new("c1", DataType::Int32, true) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ]), + Schema::new(vec![Field::new("c1", DataType::Int32, true)]) + .with_metadata(HashMap::from([("key".into(), "value".into())])), + ]; + + for physical_schema in cases { + assert!( + reconcile_dfschema_with_schema_nullability( + &logical_schema, + &physical_schema, + ) + .unwrap() + .is_none(), + "should not reconcile unsupported mismatch: {physical_schema:?}" + ); + } + } +} diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 3b2c7a78e898e..07d91a407a0a5 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -66,6 +66,7 @@ use datafusion_common::Column; use datafusion_common::HashMap as DFHashMap; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::format::ExplainAnalyzeCategories; +use datafusion_common::recursive_schema::reconcile_dfschema_with_schema_nullability; use datafusion_common::tree_node::{ Transformed, TreeNode, TreeNodeRecursion, TreeNodeVisitor, }; @@ -115,6 +116,16 @@ use itertools::{Itertools, multiunzip}; use log::debug; use tokio::sync::Mutex; +/// Aggregate planning 
normally verifies that the physical input schema satisfies +/// the logical input schema exactly. Recursive CTEs are an exception only for +/// nullability widening: logical planning may conservatively expose nullable +/// recursive output after the aggregate's logical input schema was derived. +fn contains_recursive_query_input(plan: &LogicalPlan) -> bool { + plan.exists(|node| Ok(matches!(node, LogicalPlan::RecursiveQuery(_)))) + // Closure always returns Ok + .unwrap() +} + /// Physical query planner that converts a `LogicalPlan` to an /// `ExecutionPlan` suitable for execution. #[async_trait] @@ -987,6 +998,22 @@ impl DefaultPhysicalPlanner { let input_exec = children.one()?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); + let reconciled_logical_schema; + let logical_input_schema = if schema_satisfied_by( + logical_input_schema.inner(), + &physical_input_schema, + ) || !contains_recursive_query_input(input) + { + logical_input_schema + } else if let Some(schema) = reconcile_dfschema_with_schema_nullability( + logical_input_schema, + &physical_input_schema, + )? 
{ + reconciled_logical_schema = schema; + &reconciled_logical_schema + } else { + logical_input_schema + }; let physical_input_schema_from_logical = logical_input_schema.inner(); if !options.execution.skip_physical_aggregate_schema_check diff --git a/datafusion/core/tests/sql/explain_analyze.rs b/datafusion/core/tests/sql/explain_analyze.rs index b093563d9adda..fbbbee5a31e27 100644 --- a/datafusion/core/tests/sql/explain_analyze.rs +++ b/datafusion/core/tests/sql/explain_analyze.rs @@ -1014,7 +1014,7 @@ async fn parquet_recursive_projection_pushdown() -> Result<()> { SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] RecursiveQueryExec: name=number_series, is_distinct=false CoalescePartitionsExec - ProjectionExec: expr=[id@0 as id, 1 as level] + ProjectionExec: expr=[CAST(id@0 AS Int64) as id, CAST(1 AS Int64) as level] FilterExec: id@0 = 1 RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES), input_partitions=1 DataSourceExec: file_groups={1 group: [[TMP_DIR/hierarchy.parquet]]}, projection=[id], file_type=parquet, predicate=id@0 = 1, pruning_predicate=id_null_count@2 != row_count@3 AND id_min@0 <= 1 AND 1 <= id_max@1, required_guarantees=[id in (1)] diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 017a123eb035b..d99d4ea564cd6 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -53,6 +53,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema, SchemaRef}; use datafusion_common::display::ToStringifiedPlan; use datafusion_common::file_options::file_type::FileType; use datafusion_common::metadata::FieldMetadata; +use datafusion_common::recursive_schema::recursive_query_output_schema; use datafusion_common::{ Column, Constraints, DFSchema, DFSchemaRef, NullEquality, Result, ScalarValue, TableReference, ToDFSchema, UnnestOptions, exec_err, @@ -66,6 +67,20 @@ use indexmap::IndexSet; /// Default table name for unnamed 
table pub const UNNAMED_TABLE: &str = "?table?"; +fn plan_with_schema(plan: LogicalPlan, schema: DFSchemaRef) -> Result<LogicalPlan> { + match plan { + LogicalPlan::Projection(Projection { expr, input, .. }) => { + Projection::try_new_with_schema(expr, input, schema) + .map(LogicalPlan::Projection) + } + _ => { + let exprs = plan.schema().iter().map(Expr::from).collect(); + Projection::try_new_with_schema(exprs, Arc::new(plan), schema) + .map(LogicalPlan::Projection) + } + } +} + /// Options for [`LogicalPlanBuilder`] #[derive(Default, Debug, Clone)] pub struct LogicalPlanBuilderOptions { @@ -192,10 +207,19 @@ impl LogicalPlanBuilder { // Ensure that the recursive term has the same field types as the static term let coerced_recursive_term = coerce_plan_expr_for_schema(recursive_term, self.plan.schema())?; + let output_schema = recursive_query_output_schema( + self.plan.schema(), + coerced_recursive_term.schema(), + )?; + let static_term = plan_with_schema( + Arc::unwrap_or_clone(self.plan), + Arc::clone(&output_schema), + )?; + let recursive_term = plan_with_schema(coerced_recursive_term, output_schema)?; Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { name, - static_term: self.plan, - recursive_term: Arc::new(coerced_recursive_term), + static_term: Arc::new(static_term), + recursive_term: Arc::new(recursive_term), is_distinct, }))) } diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index db8b82fe87a14..56a0bd5570941 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -354,7 +354,8 @@ impl LogicalPlan { LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { - // we take the schema of the static term as the schema of the entire recursive query + // The static term is coerced to the recursive output schema when + // building a RecursiveQuery. 
static_term.schema() } } @@ -1080,12 +1081,9 @@ impl LogicalPlan { }) => { self.assert_no_expressions(expr)?; let (static_term, recursive_term) = self.only_two_inputs(inputs)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: *is_distinct, - })) + LogicalPlanBuilder::from(static_term) + .to_recursive_query(name.clone(), recursive_term, *is_distinct)? + .build() } LogicalPlan::Analyze(a) => { self.assert_no_expressions(expr)?; diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 8228e8e6f2ff0..4b282129357ab 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -56,8 +56,7 @@ use datafusion_datasource_json::file_format::{ #[cfg(feature = "parquet")] use datafusion_datasource_parquet::file_format::{ParquetFormat, ParquetFormatFactory}; use datafusion_expr::{ - AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, RecursiveQuery, SkipType, - TableSource, Unnest, + AggregateUDF, DmlStatement, FetchType, HigherOrderUDF, SkipType, TableSource, Unnest, }; use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, @@ -1085,12 +1084,13 @@ impl AsLogicalPlan for LogicalPlanNode { ))? .try_into_logical_plan(ctx, extension_codec)?; - Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { - name: recursive_query_node.name.clone(), - static_term: Arc::new(static_term), - recursive_term: Arc::new(recursive_term), - is_distinct: recursive_query_node.is_distinct, - })) + LogicalPlanBuilder::from(static_term) + .to_recursive_query( + recursive_query_node.name.clone(), + recursive_term, + recursive_query_node.is_distinct, + )? 
+ .build() } LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => { let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node; diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 18766d7056355..24e31d6ba5d03 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -20,8 +20,8 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - Result, not_impl_err, plan_err, - tree_node::{TreeNode, TreeNodeRecursion}, + Result, not_impl_err, plan_err, recursive_schema::make_schema_nullable, + tree_node::TreeNode, }; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource}; use sqlparser::ast::{Query, SetExpr, SetOperator, With}; @@ -133,9 +133,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> { // ---------- Step 2: Create a temporary relation ------------------ // Step 2.1: Create a table source for the temporary relation - let work_table_source = self - .context_provider - .create_cte_work_table(cte_name, Arc::clone(static_plan.schema().inner()))?; + let work_table_source = self.context_provider.create_cte_work_table( + cte_name, + make_schema_nullable(static_plan.schema().inner()), + )?; // Step 2.2: Create a temporary relation logical plan that will be used // as the input to the recursive term @@ -188,17 +189,9 @@ fn has_work_table_reference( plan: &LogicalPlan, work_table_source: &Arc<dyn TableSource>, ) -> bool { - let mut has_reference = false; - plan.apply(|node| { - if let LogicalPlan::TableScan(scan) = node - && Arc::ptr_eq(&scan.source, work_table_source) - { - has_reference = true; - return Ok(TreeNodeRecursion::Stop); - } - Ok(TreeNodeRecursion::Continue) + plan.exists(|node| { + Ok(matches!(node, LogicalPlan::TableScan(scan) if Arc::ptr_eq(&scan.source, work_table_source))) }) - // Closure always return Ok - .unwrap(); - has_reference + // Closure always returns Ok + .unwrap() } diff --git a/datafusion/sqllogictest/test_files/cte.slt 
b/datafusion/sqllogictest/test_files/cte.slt index d13e0d4f085e9..a906063ace47a 100644 --- a/datafusion/sqllogictest/test_files/cte.slt +++ b/datafusion/sqllogictest/test_files/cte.slt @@ -171,7 +171,7 @@ logical_plan 07)--------TableScan: nodes projection=[id] physical_plan 01)RecursiveQueryExec: name=nodes, is_distinct=false -02)--ProjectionExec: expr=[1 as id] +02)--ProjectionExec: expr=[CAST(1 AS Int64) as id] 03)----PlaceholderRowExec 04)--CoalescePartitionsExec 05)----ProjectionExec: expr=[id@0 + 1 as id] @@ -195,6 +195,21 @@ SELECT * FROM nodes 3 4 +# recursive self-reference must use conservative nullability even when the +# anchor term uses non-null literals. Otherwise optimizer nullability-based +# simplification can remove this semantically required IS NOT NULL guard. +query II rowsort +WITH RECURSIVE t(a, b) AS ( + SELECT 0 AS a, 0 AS b + UNION ALL + SELECT b AS a, CAST(NULL AS INT) AS b FROM t WHERE a IS NOT NULL +) +SELECT * FROM t +---- +0 0 +0 NULL +NULL NULL + # deduplicating recursive CTE with two variables works query II WITH RECURSIVE ranges AS ( @@ -699,7 +714,7 @@ WITH RECURSIVE region_sales AS ( SELECT s.salesperson_id AS salesperson_id, SUM(s.sale_amount) AS amount, - SUM(0) as level + 0 as level FROM sales s GROUP BY @@ -1079,7 +1094,7 @@ logical_plan 07)--------TableScan: numbers projection=[n] physical_plan 01)RecursiveQueryExec: name=numbers, is_distinct=false -02)--ProjectionExec: expr=[1 as n] +02)--ProjectionExec: expr=[CAST(1 AS Int64) as n] 03)----PlaceholderRowExec 04)--CoalescePartitionsExec 05)----ProjectionExec: expr=[n@0 + 1 as n] @@ -1104,7 +1119,7 @@ logical_plan 07)--------TableScan: numbers projection=[n] physical_plan 01)RecursiveQueryExec: name=numbers, is_distinct=false -02)--ProjectionExec: expr=[1 as n] +02)--ProjectionExec: expr=[CAST(1 AS Int64) as n] 03)----PlaceholderRowExec 04)--CoalescePartitionsExec 05)----ProjectionExec: expr=[n@0 + 1 as n] @@ -1161,7 +1176,7 @@ logical_plan physical_plan 01)GlobalLimitExec: 
skip=0, fetch=5 02)--RecursiveQueryExec: name=r, is_distinct=false -03)----ProjectionExec: expr=[0 as k, 0 as v] +03)----ProjectionExec: expr=[CAST(0 AS Int64) as k, CAST(0 AS Int64) as v] 04)------PlaceholderRowExec 05)----SortExec: TopK(fetch=1), expr=[v@1 ASC NULLS LAST], preserve_partitioning=[false] 06)------WorkTableExec: name=r diff --git a/datafusion/sqllogictest/test_files/explain_tree.slt b/datafusion/sqllogictest/test_files/explain_tree.slt index 46d01f39a920b..1a2c4ebab437f 100644 --- a/datafusion/sqllogictest/test_files/explain_tree.slt +++ b/datafusion/sqllogictest/test_files/explain_tree.slt @@ -1582,7 +1582,7 @@ physical_plan 04)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 05)│ ProjectionExec ││ CoalescePartitionsExec │ 06)│ -------------------- ││ │ -07)│ id: 1 ││ │ +07)│ id: CAST(1 AS Int64) ││ │ 08)└─────────────┬─────────────┘└─────────────┬─────────────┘ 09)┌─────────────┴─────────────┐┌─────────────┴─────────────┐ 10)│ PlaceholderRowExec ││ ProjectionExec │