From 096128aa5dd366b88892c2ac58aafc50e7376949 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 15 May 2026 00:27:14 +0100 Subject: [PATCH] Teach zoned pruning to lower StatFn Signed-off-by: Nicholas Gates --- vortex-array/src/stats/rewrite/builtins.rs | 143 +++++++-- vortex-layout/public-api.lock | 2 - vortex-layout/src/layouts/zoned/pruning.rs | 25 +- vortex-layout/src/layouts/zoned/zone_map.rs | 320 ++++++++++++++++++-- 4 files changed, 419 insertions(+), 71 deletions(-) diff --git a/vortex-array/src/stats/rewrite/builtins.rs b/vortex-array/src/stats/rewrite/builtins.rs index abe1a805e0d..c67906a0f2d 100644 --- a/vortex-array/src/stats/rewrite/builtins.rs +++ b/vortex-array/src/stats/rewrite/builtins.rs @@ -6,6 +6,7 @@ use vortex_error::VortexResult; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTableExt; use crate::aggregate_fn::EmptyOptions as AggregateEmptyOptions; +use crate::aggregate_fn::fns::all_non_nan::AllNonNan; use crate::aggregate_fn::fns::all_non_null::AllNonNull; use crate::aggregate_fn::fns::all_null::AllNull; use crate::dtype::DType; @@ -79,16 +80,25 @@ impl StatsRewriteRule for BinaryStatsRewrite { let left = min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b)); let right = min(rhs).zip(max(lhs)).map(|(a, b)| gt(a, b)); or_collect(left.into_iter().chain(right)) + .map(|value_predicate| with_nan_predicate(lhs, rhs, value_predicate)) } Operator::NotEq => min(lhs).zip(max(rhs)).zip(max(lhs).zip(min(rhs))).map( |((min_lhs, max_rhs), (max_lhs, min_rhs))| { - and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)) + with_nan_predicate(lhs, rhs, and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs))) }, ), - Operator::Gt => max(lhs).zip(min(rhs)).map(|(a, b)| lt_eq(a, b)), - Operator::Gte => max(lhs).zip(min(rhs)).map(|(a, b)| lt(a, b)), - Operator::Lt => min(lhs).zip(max(rhs)).map(|(a, b)| gt_eq(a, b)), - Operator::Lte => min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b)), + Operator::Gt => max(lhs) + .zip(min(rhs)) + .map(|(a, b)| with_nan_predicate(lhs, rhs, lt_eq(a, b))), + Operator::Gte => max(lhs) + .zip(min(rhs)) + .map(|(a, b)| with_nan_predicate(lhs, rhs, lt(a, b))), + Operator::Lt => min(lhs) + .zip(max(rhs)) + .map(|(a, b)| with_nan_predicate(lhs, rhs, gt_eq(a, b))), + Operator::Lte => min(lhs) + .zip(max(rhs)) + .map(|(a, b)| with_nan_predicate(lhs, rhs, gt(a, b))), Operator::And => { let lhs_falsifier = ctx.falsify(lhs)?; let rhs_falsifier = ctx.falsify(rhs)?; @@ -335,7 +345,8 @@ impl StatsRewriteRule for ListContainsStatsRewrite { lt(value_max.clone(), lit(value.clone())), gt(value_min.clone(), lit(value.clone())), ) - }))) + })) + .map(|value_predicate| with_all_non_nan_predicate([needle], value_predicate))) } } @@ -379,11 +390,42 @@ fn all_non_null(expr: &Expression) -> Expression { stat_fn(expr.clone(), AllNonNull.bind(AggregateEmptyOptions)) } +// Min/max do not order NaN values, so comparison rewrites are only sound when every +// candidate value is known to be non-NaN. Literal non-NaN values and casts to +// non-float types are already proven, so they do not need a stat expression. +fn all_non_nan_stat(expr: &Expression) -> Option { + if let Some(scalar) = expr.as_opt::() { + let value = scalar.as_primitive_opt()?; + return value.is_nan().then(|| lit(false)); + } + + if let Some(dtype) = expr.as_opt::() { + if !has_nans(dtype) { + return None; + } + + return all_non_nan_stat(expr.child(0)); + } + + Some(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions))) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) +} + fn stat_expr(expr: &Expression, stat: Stat) -> Option { if let Some(literal) = literal_stat(expr, stat) { return Some(literal); } + // `literal_stat` handled every stat that is defined for literals. If it returned + // `None`, the requested stat is not meaningful for this literal, such as + // `NaNCount` over a non-float value, so do not manufacture `stat(literal, ...)`. + if expr.is::() { + return None; + } + if let Some(dtype) = expr.as_opt::() { return cast_stat(expr.child(0), dtype, stat); } @@ -392,17 +434,46 @@ fn stat_expr(expr: &Expression, stat: Stat) -> Option { .map(|aggregate_fn| stat_fn(expr.clone(), aggregate_fn)) } +fn with_nan_predicate( + lhs: &Expression, + rhs: &Expression, + value_predicate: Expression, +) -> Expression { + with_all_non_nan_predicate([lhs, rhs], value_predicate) +} + +fn with_all_non_nan_predicate<'a>( + exprs: impl IntoIterator, + value_predicate: Expression, +) -> Expression { + let nan_predicate = and_collect(exprs.into_iter().filter_map(all_non_nan_stat)); + + match nan_predicate { + Some(nan_check) => and(nan_check, value_predicate), + // No possible NaN-bearing expression remains, so the value predicate is + // already guarded. + None => value_predicate, + } +} + fn literal_stat(expr: &Expression, stat: Stat) -> Option { let scalar = expr.as_opt::()?; match stat { Stat::Min | Stat::Max => Some(lit(scalar.clone())), Stat::NullCount => Some(lit(if scalar.is_null() { 1u64 } else { 0u64 })), + Stat::NaNCount => { + let value = scalar.as_primitive_opt()?; + if !value.ptype().is_float() { + return None; + } + + Some(lit(if value.is_nan() { 1u64 } else { 0u64 })) + } Stat::IsConstant | Stat::IsSorted | Stat::IsStrictSorted | Stat::Sum - | Stat::UncompressedSizeInBytes - | Stat::NaNCount => None, + | Stat::UncompressedSizeInBytes => None, } } @@ -431,6 +502,9 @@ mod tests { use super::all_non_null; use super::all_null; use crate::aggregate_fn::AggregateFnRef; + use crate::aggregate_fn::AggregateFnVTableExt; + use crate::aggregate_fn::EmptyOptions as AggregateEmptyOptions; + use crate::aggregate_fn::fns::all_non_nan::AllNonNan; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -471,20 +545,30 @@ mod tests { StatFn.new_expr(StatOptions::new(aggregate_fn), [expr]) } + fn nan_free(expr: Expression) -> Expression { + stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions)) + } + #[test] fn rewrites_comparison_falsifier() -> VortexResult<()> { let expr = gt(col("a"), lit(10)); assert_eq!( expr.falsify(&SESSION)?, - Some(lt_eq(stat(col("a"), Stat::Max), lit(10))) + Some(and( + nan_free(col("a")), + lt_eq(stat(col("a"), Stat::Max), lit(10)), + )) ); let expr = eq(col("a"), col("b")); assert_eq!( expr.falsify(&SESSION)?, - Some(or( - gt(stat(col("a"), Stat::Min), stat(col("b"), Stat::Max)), - gt(stat(col("b"), Stat::Min), stat(col("a"), Stat::Max)), + Some(and( + and(nan_free(col("a")), nan_free(col("b"))), + or( + gt(stat(col("a"), Stat::Min), stat(col("b"), Stat::Max)), + gt(stat(col("b"), Stat::Min), stat(col("a"), Stat::Max)), + ), )) ); Ok(()) @@ -496,8 +580,14 @@ mod tests { assert_eq!( expr.falsify(&SESSION)?, Some(or( - lt_eq(stat(col("a"), Stat::Max), lit(10)), - gt_eq(stat(col("a"), Stat::Min), lit(50)), + and( + nan_free(col("a")), + lt_eq(stat(col("a"), Stat::Max), lit(10)), + ), + and( + nan_free(col("a")), + gt_eq(stat(col("a"), Stat::Min), lit(50)), + ), )) ); Ok(()) @@ -518,8 +608,8 @@ mod tests { assert_eq!( expr.falsify(&SESSION)?, Some(or( - gt(lit(10), stat(col("a"), Stat::Max)), - gt(stat(col("a"), Stat::Min), lit(50)), + and(nan_free(col("a")), gt(lit(10), stat(col("a"), Stat::Max)),), + and(nan_free(col("a")), gt(stat(col("a"), Stat::Min), lit(50)),), )) ); Ok(()) @@ -597,20 +687,23 @@ mod tests { assert_eq!( expr.falsify(&SESSION)?, Some(and( + nan_free(col("a")), and( - or( - lt(stat(col("a"), Stat::Max), lit(1i32)), - gt(stat(col("a"), Stat::Min), lit(1i32)), + and( + or( + lt(stat(col("a"), Stat::Max), lit(1i32)), + gt(stat(col("a"), Stat::Min), lit(1i32)), + ), + or( + lt(stat(col("a"), Stat::Max), lit(2i32)), + gt(stat(col("a"), Stat::Min), lit(2i32)), + ), ), or( - lt(stat(col("a"), Stat::Max), lit(2i32)), - gt(stat(col("a"), Stat::Min), lit(2i32)), + lt(stat(col("a"), Stat::Max), lit(3i32)), + gt(stat(col("a"), Stat::Min), lit(3i32)), ), ), - or( - lt(stat(col("a"), Stat::Max), lit(3i32)), - gt(stat(col("a"), Stat::Min), lit(3i32)), - ), )) ); Ok(()) diff --git a/vortex-layout/public-api.lock b/vortex-layout/public-api.lock index 1c6ae70dd6f..0e67a96e38e 100644 --- a/vortex-layout/public-api.lock +++ b/vortex-layout/public-api.lock @@ -798,8 +798,6 @@ impl vortex_layout::layouts::zoned::zone_map::ZoneMap pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::dtype_for_stats_table(&vortex_array::dtype::DType, &[vortex_array::expr::stats::Stat]) -> vortex_array::dtype::DType -pub unsafe fn vortex_layout::layouts::zoned::zone_map::ZoneMap::new_unchecked(vortex_array::arrays::struct_::vtable::StructArray, u64, u64) -> Self - pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::prune(&self, &vortex_array::expr::expression::Expression, &vortex_session::VortexSession) -> vortex_error::VortexResult pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::try_new(vortex_array::dtype::DType, vortex_array::arrays::struct_::vtable::StructArray, alloc::sync::Arc<[vortex_array::expr::stats::Stat]>, u64, u64) -> vortex_error::VortexResult diff --git a/vortex-layout/src/layouts/zoned/pruning.rs b/vortex-layout/src/layouts/zoned/pruning.rs index 630a7d3ad2e..c10193ab6ef 100644 --- a/vortex-layout/src/layouts/zoned/pruning.rs +++ b/vortex-layout/src/layouts/zoned/pruning.rs @@ -16,12 +16,9 @@ use tracing::trace; use vortex_array::MaskFuture; use vortex_array::VortexSessionExecute; use vortex_array::arrays::StructArray; -use vortex_array::dtype::FieldPath; -use vortex_array::dtype::FieldPathSet; +use vortex_array::dtype::DType; use vortex_array::expr::Expression; -use vortex_array::expr::pruning::checked_pruning_expr; use vortex_array::expr::root; -use vortex_array::expr::stats::Stat; use vortex_array::scalar_fn::fns::dynamic::DynamicExprUpdates; use vortex_error::SharedVortexResult; use vortex_error::VortexExpect; @@ -43,7 +40,7 @@ pub(super) struct PruningState { zone_count: usize, row_count: u64, zone_len: u64, - present_stats: Arc<[Stat]>, + dtype: DType, lazy_children: Arc, session: VortexSession, @@ -62,7 +59,7 @@ impl PruningState { zone_count: layout.nzones(), row_count: layout.row_count(), zone_len: layout.zone_len() as u64, - present_stats: Arc::clone(layout.present_stats()), + dtype: layout.dtype().clone(), lazy_children, session, pruning_result: Default::default(), @@ -119,13 +116,12 @@ impl PruningState { self.pruning_predicates .entry(expr.clone()) .or_default() - .get_or_init(move || { - let available_stats = FieldPathSet::from_iter( - self.present_stats - .iter() - .map(|stat| FieldPath::from_name(stat.name())), - ); - checked_pruning_expr(&expr, &available_stats).map(|(expr, _)| expr) + .get_or_init(move || match expr.falsify(&self.session) { + Ok(predicate) => predicate, + Err(error) => { + trace!(%expr, %error, "failed to construct stats rewrite predicate"); + None + } }) .clone() } @@ -147,13 +143,14 @@ impl PruningState { let session = self.session.clone(); let zone_len = self.zone_len; let row_count = self.row_count; + let dtype = self.dtype.clone(); async move { let mut ctx = session.create_execution_ctx(); let zones_array = zones_eval.await?.execute::(&mut ctx)?; // SAFETY: zoned layout validation ensures the zones child matches the expected // stats-table schema for `present_stats`. - Ok(unsafe { ZoneMap::new_unchecked(zones_array, zone_len, row_count) }) + Ok(unsafe { ZoneMap::new_unchecked(dtype, zones_array, zone_len, row_count) }) } .map_err(Arc::new) .boxed() diff --git a/vortex-layout/src/layouts/zoned/zone_map.rs b/vortex-layout/src/layouts/zoned/zone_map.rs index 16360ff287d..8a892a49d1d 100644 --- a/vortex-layout/src/layouts/zoned/zone_map.rs +++ b/vortex-layout/src/layouts/zoned/zone_map.rs @@ -8,18 +8,38 @@ use std::sync::Arc; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::aggregate_fn::fns::all_nan::AllNan; +use vortex_array::aggregate_fn::fns::all_non_nan::AllNonNan; +use vortex_array::aggregate_fn::fns::all_non_null::AllNonNull; +use vortex_array::aggregate_fn::fns::all_null::AllNull; +use vortex_array::aggregate_fn::fns::nan_count::NanCount; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; +use vortex_array::arrays::struct_::StructArrayExt; use vortex_array::dtype::DType; +use vortex_array::dtype::Nullability; use vortex_array::expr::Expression; +use vortex_array::expr::eq; +use vortex_array::expr::get_item; +use vortex_array::expr::is_root; +use vortex_array::expr::lit; +use vortex_array::expr::root; use vortex_array::expr::stats::Stat; +use vortex_array::expr::traversal::NodeExt; +use vortex_array::expr::traversal::Transformed; +use vortex_array::scalar::Scalar; +use vortex_array::scalar_fn::EmptyOptions; +use vortex_array::scalar_fn::ScalarFnVTableExt; +use vortex_array::scalar_fn::fns::stat::StatFn; +use vortex_array::scalar_fn::internal::row_count::RowCount; use vortex_array::scalar_fn::internal::row_count::contains_row_count; use vortex_array::scalar_fn::internal::row_count::substitute_row_count; use vortex_array::validity::Validity; use vortex_buffer::buffer; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_mask::Mask; use vortex_runend::RunEnd; use vortex_session::VortexSession; @@ -32,6 +52,8 @@ use crate::layouts::zoned::schema::stats_table_dtype; /// Note that it's possible for the zone map to have no statistics. #[derive(Clone)] pub struct ZoneMap { + // The dtype of the data column this zone map describes. + column_dtype: DType, // The struct array backing the zone map array: StructArray, // The length of each zone in the zone map. @@ -56,16 +78,17 @@ impl ZoneMap { } // SAFETY: We checked that the array matches the expected stats-table schema. - Ok(unsafe { Self::new_unchecked(array, zone_len, row_count) }) + Ok(unsafe { Self::new_unchecked(column_dtype, array, zone_len, row_count) }) } - /// Creates [`ZoneMap`] without validating return array against expected stats. - /// - /// # Safety - /// - /// Assumes that the input struct array has the correct statistics as fields. Or in other words, - pub unsafe fn new_unchecked(array: StructArray, zone_len: u64, row_count: u64) -> Self { + pub(super) unsafe fn new_unchecked( + column_dtype: DType, + array: StructArray, + zone_len: u64, + row_count: u64, + ) -> Self { Self { + column_dtype, array, zone_len, row_count, @@ -82,21 +105,23 @@ impl ZoneMap { /// Apply a pruning predicate to this zone map. /// - /// `predicate` should be the result of converting a filter with - /// [`checked_pruning_expr`]. The returned mask has one value per zone, where + /// `predicate` should be a stats rewrite expression such as the result of + /// [`Expression::falsify`]. The returned mask has one value per zone, where /// `true` means the zone cannot contain matching rows and can be skipped. /// /// If the predicate contains [`row_count`][vortex_array::scalar_fn::internal::row_count] /// placeholders, they are replaced after [`ArrayRef::apply`] with per-zone /// counts derived from `zone_len` and `row_count`. Uniform zones use a /// [`ConstantArray`]; a short final zone uses a run-end encoded array. - /// - /// [`checked_pruning_expr`]: vortex_array::expr::pruning::checked_pruning_expr + /// `row_count` is a layout property rather than a stored stats field, and the + /// final zone may be shorter than the nominal zone length, so it is materialized + /// only after the predicate has been lowered to the zone-map table. pub fn prune(&self, predicate: &Expression, session: &VortexSession) -> VortexResult { let mut ctx = session.create_execution_ctx(); let num_zones = self.array.len(); + let predicate = self.lower_stats(predicate.clone())?; - let applied = self.array.clone().into_array().apply(predicate)?; + let applied = self.array.clone().into_array().apply(&predicate)?; if num_zones == 0 || !contains_row_count(&applied) { return applied.execute::(&mut ctx); @@ -106,6 +131,105 @@ impl ZoneMap { let substituted = substitute_row_count(applied, &row_count_array)?; substituted.execute::(&mut ctx) } + + fn lower_stats(&self, predicate: Expression) -> VortexResult { + predicate + .transform_down(|expr| { + if expr.is::() { + return self.lower_stat_fn(expr).map(Transformed::yes); + } + + Ok(Transformed::no(expr)) + }) + .map(Transformed::into_inner) + } + + fn lower_stat_fn(&self, expr: Expression) -> VortexResult { + let options = expr.as_::(); + let input = expr.child(0); + let input_dtype = input.return_dtype(&self.column_dtype)?; + let input_is_root = is_root(input); + + if options.aggregate_fn().is::() { + if !has_nans(&input_dtype) { + return Ok(lit(false)); + } + if !input_is_root { + return Ok(null_expr(DType::Bool(Nullability::NonNullable))); + } + return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, row_count_expr())); + } + + if options.aggregate_fn().is::() { + if !has_nans(&input_dtype) { + return Ok(lit(true)); + } + if !input_is_root { + return Ok(null_expr(DType::Bool(Nullability::NonNullable))); + } + return Ok(eq(self.stat_field_expr(Stat::NaNCount)?, lit(0u64))); + } + + if options.aggregate_fn().is::() && !has_nans(&input_dtype) { + return Ok(lit(0u64)); + } + + let return_dtype = options + .aggregate_fn() + .return_dtype(&input_dtype) + .ok_or_else(|| { + vortex_err!( + "Aggregate function {} does not support input dtype {}", + options.aggregate_fn(), + input_dtype + ) + })?; + + if !input_is_root { + return Ok(null_expr(return_dtype)); + } + + if options.aggregate_fn().is::() { + return Ok(eq(self.stat_field_expr(Stat::NullCount)?, row_count_expr())); + } + + if options.aggregate_fn().is::() { + return Ok(eq(self.stat_field_expr(Stat::NullCount)?, lit(0u64))); + } + + let Some(stat) = Stat::from_aggregate_fn(options.aggregate_fn()) else { + return Ok(null_expr(return_dtype)); + }; + + self.stat_field_expr(stat) + } + + fn stat_field_expr(&self, stat: Stat) -> VortexResult { + if self.array.unmasked_field_by_name_opt(stat.name()).is_some() { + return Ok(get_item(stat.name(), root())); + } + + let Some(dtype) = stat.dtype(&self.column_dtype) else { + vortex_bail!( + "Stat {} does not support column dtype {}", + stat, + self.column_dtype + ); + }; + Ok(null_expr(dtype)) + } +} + +fn row_count_expr() -> Expression { + RowCount.new_expr(EmptyOptions, []) +} + +fn null_expr(dtype: DType) -> Expression { + lit(Scalar::null(dtype.as_nullable())) +} + +fn has_nans(dtype: &DType) -> bool { + matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float()) } /// Build per-zone row counts for a zone map. @@ -145,17 +269,19 @@ mod tests { use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::StructArray; use vortex_array::assert_arrays_eq; - use vortex_array::dtype::FieldPath; - use vortex_array::dtype::FieldPathSet; + use vortex_array::dtype::FieldNames; use vortex_array::dtype::PType; use vortex_array::expr::gt; use vortex_array::expr::gt_eq; use vortex_array::expr::is_not_null; use vortex_array::expr::lit; use vortex_array::expr::lt; - use vortex_array::expr::pruning::checked_pruning_expr; use vortex_array::expr::root; use vortex_array::expr::stats::Stat; + use vortex_array::stats::all_nan; + use vortex_array::stats::all_non_nan; + use vortex_array::stats::all_non_null; + use vortex_array::stats::all_null; use vortex_array::validity::Validity; use vortex_buffer::buffer; @@ -164,12 +290,6 @@ mod tests { #[test] fn test_zone_map_prunes() { - // All stats that are known at pruning time. - let stats = FieldPathSet::from_iter([ - FieldPath::from_iter([Stat::Min.name().into()]), - FieldPath::from_iter([Stat::Max.name().into()]), - ]); - // Construct a zone map with 3 zones: // // +----------+----------+ @@ -211,7 +331,7 @@ mod tests { // A >= 6 // => A.max < 6 let expr = gt_eq(root(), lit(6i32)); - let (pruning_expr, _) = checked_pruning_expr(&expr, &stats).unwrap(); + let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap(); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), @@ -221,7 +341,7 @@ mod tests { // A > 5 // => A.max <= 5 let expr = gt(root(), lit(5i32)); - let (pruning_expr, _) = checked_pruning_expr(&expr, &stats).unwrap(); + let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap(); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); assert_arrays_eq!( mask.into_array(), @@ -231,7 +351,7 @@ mod tests { // A < 2 // => A.min >= 2 let expr = lt(root(), lit(2i32)); - let (pruning_expr, _) = checked_pruning_expr(&expr, &stats).unwrap(); + let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap(); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); } @@ -251,10 +371,8 @@ mod tests { ) .unwrap(); - let available_stats = - FieldPathSet::from_iter([FieldPath::from_iter([Stat::NullCount.name().into()])]); let expr = is_not_null(root()); - let (pruning_expr, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap(); let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); assert_arrays_eq!( @@ -263,6 +381,150 @@ mod tests { ); } + #[test] + fn all_null_stat_fn_lowers_to_null_count_and_row_count() { + let zone_map = ZoneMap::try_new( + PType::U64.into(), + StructArray::from_fields(&[( + "null_count", + PrimitiveArray::new(buffer![0u64, 4, 2], Validity::AllValid).into_array(), + )]) + .unwrap(), + Arc::new([Stat::NullCount]), + 4, + 10, + ) + .unwrap(); + + let mask = zone_map.prune(&all_null(root()), &SESSION).unwrap(); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, true, true])); + } + + #[test] + fn all_non_null_stat_fn_lowers_to_null_count() { + let zone_map = ZoneMap::try_new( + PType::U64.into(), + StructArray::from_fields(&[( + "null_count", + PrimitiveArray::new(buffer![0u64, 4, 2], Validity::AllValid).into_array(), + )]) + .unwrap(), + Arc::new([Stat::NullCount]), + 4, + 10, + ) + .unwrap(); + + let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([true, false, false]) + ); + } + + #[test] + fn non_float_nan_stat_fns_lower_to_constants() { + let zone_map = ZoneMap::try_new( + PType::I32.into(), + StructArray::try_new(FieldNames::empty(), vec![], 2, Validity::NonNullable).unwrap(), + Arc::new([]), + 4, + 8, + ) + .unwrap(); + + let mask = zone_map.prune(&all_nan(root()), &SESSION).unwrap(); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([false, false])); + + let mask = zone_map.prune(&all_non_nan(root()), &SESSION).unwrap(); + assert_arrays_eq!(mask.into_array(), BoolArray::from_iter([true, true])); + } + + #[test] + fn unavailable_stat_fn_lowers_to_unknown_mask() { + let zone_map = ZoneMap::try_new( + PType::U64.into(), + StructArray::try_new(FieldNames::empty(), vec![], 3, Validity::NonNullable).unwrap(), + Arc::new([]), + 4, + 10, + ) + .unwrap(); + + let mask = zone_map.prune(&all_non_null(root()), &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); + + let expr = gt(root(), lit(5u64)); + let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap(); + let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); + } + + #[test] + fn float_min_max_stat_fn_requires_nan_count() { + let zone_map = ZoneMap::try_new( + PType::F32.into(), + StructArray::from_fields(&[ + ( + "max", + PrimitiveArray::new(buffer![5.0f32, 6.0, 7.0], Validity::AllValid).into_array(), + ), + ( + "max_is_truncated", + BoolArray::from_iter([false, false, false]).into_array(), + ), + ]) + .unwrap(), + Arc::new([Stat::Max]), + 4, + 12, + ) + .unwrap(); + + let expr = gt(root(), lit(5.0f32)); + let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap(); + let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([false, false, false]) + ); + + let zone_map = ZoneMap::try_new( + PType::F32.into(), + StructArray::from_fields(&[ + ( + "max", + PrimitiveArray::new(buffer![5.0f32, 6.0, 7.0], Validity::AllValid).into_array(), + ), + ( + "max_is_truncated", + BoolArray::from_iter([false, false, false]).into_array(), + ), + ( + "nan_count", + PrimitiveArray::new(buffer![0u64, 0, 0], Validity::AllValid).into_array(), + ), + ]) + .unwrap(), + Arc::new([Stat::Max, Stat::NaNCount]), + 4, + 12, + ) + .unwrap(); + + let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap(); + assert_arrays_eq!( + mask.into_array(), + BoolArray::from_iter([true, false, false]) + ); + } + #[test] fn row_count_prunes_all_null_uniform_zones() { let zone_map = ZoneMap::try_new( @@ -278,10 +540,8 @@ mod tests { ) .unwrap(); - let available_stats = - FieldPathSet::from_iter([FieldPath::from_iter([Stat::NullCount.name().into()])]); let expr = is_not_null(root()); - let (pruning_expr, _) = checked_pruning_expr(&expr, &available_stats).unwrap(); + let pruning_expr = expr.falsify(&SESSION).unwrap().unwrap(); // All three zones have length 4 (total rows = 12). let mask = zone_map.prune(&pruning_expr, &SESSION).unwrap();