Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 118 additions & 25 deletions vortex-array/src/stats/rewrite/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use vortex_error::VortexResult;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTableExt;
use crate::aggregate_fn::EmptyOptions as AggregateEmptyOptions;
use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
use crate::aggregate_fn::fns::all_non_null::AllNonNull;
use crate::aggregate_fn::fns::all_null::AllNull;
use crate::dtype::DType;
Expand Down Expand Up @@ -79,16 +80,25 @@ impl StatsRewriteRule for BinaryStatsRewrite {
let left = min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b));
let right = min(rhs).zip(max(lhs)).map(|(a, b)| gt(a, b));
or_collect(left.into_iter().chain(right))
.map(|value_predicate| with_nan_predicate(lhs, rhs, value_predicate))
}
Operator::NotEq => min(lhs).zip(max(rhs)).zip(max(lhs).zip(min(rhs))).map(
|((min_lhs, max_rhs), (max_lhs, min_rhs))| {
and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs))
with_nan_predicate(lhs, rhs, and(eq(min_lhs, max_rhs), eq(max_lhs, min_rhs)))
},
),
Operator::Gt => max(lhs).zip(min(rhs)).map(|(a, b)| lt_eq(a, b)),
Operator::Gte => max(lhs).zip(min(rhs)).map(|(a, b)| lt(a, b)),
Operator::Lt => min(lhs).zip(max(rhs)).map(|(a, b)| gt_eq(a, b)),
Operator::Lte => min(lhs).zip(max(rhs)).map(|(a, b)| gt(a, b)),
Operator::Gt => max(lhs)
.zip(min(rhs))
.map(|(a, b)| with_nan_predicate(lhs, rhs, lt_eq(a, b))),
Operator::Gte => max(lhs)
.zip(min(rhs))
.map(|(a, b)| with_nan_predicate(lhs, rhs, lt(a, b))),
Operator::Lt => min(lhs)
.zip(max(rhs))
.map(|(a, b)| with_nan_predicate(lhs, rhs, gt_eq(a, b))),
Operator::Lte => min(lhs)
.zip(max(rhs))
.map(|(a, b)| with_nan_predicate(lhs, rhs, gt(a, b))),
Operator::And => {
let lhs_falsifier = ctx.falsify(lhs)?;
let rhs_falsifier = ctx.falsify(rhs)?;
Expand Down Expand Up @@ -335,7 +345,8 @@ impl StatsRewriteRule for ListContainsStatsRewrite {
lt(value_max.clone(), lit(value.clone())),
gt(value_min.clone(), lit(value.clone())),
)
})))
}))
.map(|value_predicate| with_all_non_nan_predicate([needle], value_predicate)))
}
}

Expand Down Expand Up @@ -379,11 +390,42 @@ fn all_non_null(expr: &Expression) -> Expression {
stat_fn(expr.clone(), AllNonNull.bind(AggregateEmptyOptions))
}

// Min/max do not order NaN values, so comparison rewrites are only sound when every
// candidate value is known to be non-NaN. Literal non-NaN values and casts to
// non-float types are already proven, so they do not need a stat expression.
fn all_non_nan_stat(expr: &Expression) -> Option<Expression> {
if let Some(scalar) = expr.as_opt::<Literal>() {
let value = scalar.as_primitive_opt()?;
return value.is_nan().then(|| lit(false));
}

if let Some(dtype) = expr.as_opt::<Cast>() {
if !has_nans(dtype) {
return None;
}

return all_non_nan_stat(expr.child(0));
}

Some(stat_fn(expr.clone(), AllNonNan.bind(AggregateEmptyOptions)))
}

fn has_nans(dtype: &DType) -> bool {
matches!(dtype, DType::Primitive(ptype, _) if ptype.is_float())
}

fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
if let Some(literal) = literal_stat(expr, stat) {
return Some(literal);
}

// `literal_stat` handled every stat that is defined for literals. If it returned
// `None`, the requested stat is not meaningful for this literal, such as
// `NaNCount` over a non-float value, so do not manufacture `stat(literal, ...)`.
if expr.is::<Literal>() {
Comment thread
gatesn marked this conversation as resolved.
return None;
}

if let Some(dtype) = expr.as_opt::<Cast>() {
return cast_stat(expr.child(0), dtype, stat);
}
Expand All @@ -392,17 +434,46 @@ fn stat_expr(expr: &Expression, stat: Stat) -> Option<Expression> {
.map(|aggregate_fn| stat_fn(expr.clone(), aggregate_fn))
}

fn with_nan_predicate(
lhs: &Expression,
rhs: &Expression,
value_predicate: Expression,
) -> Expression {
with_all_non_nan_predicate([lhs, rhs], value_predicate)
}

fn with_all_non_nan_predicate<'a>(
Comment thread
gatesn marked this conversation as resolved.
exprs: impl IntoIterator<Item = &'a Expression>,
value_predicate: Expression,
) -> Expression {
let nan_predicate = and_collect(exprs.into_iter().filter_map(all_non_nan_stat));

match nan_predicate {
Some(nan_check) => and(nan_check, value_predicate),
// No possible NaN-bearing expression remains, so the value predicate is
// already guarded.
None => value_predicate,
}
}

fn literal_stat(expr: &Expression, stat: Stat) -> Option<Expression> {
let scalar = expr.as_opt::<Literal>()?;
match stat {
Stat::Min | Stat::Max => Some(lit(scalar.clone())),
Stat::NullCount => Some(lit(if scalar.is_null() { 1u64 } else { 0u64 })),
Stat::NaNCount => {
let value = scalar.as_primitive_opt()?;
if !value.ptype().is_float() {
return None;
}

Some(lit(if value.is_nan() { 1u64 } else { 0u64 }))
}
Stat::IsConstant
| Stat::IsSorted
| Stat::IsStrictSorted
| Stat::Sum
| Stat::UncompressedSizeInBytes
| Stat::NaNCount => None,
| Stat::UncompressedSizeInBytes => None,
}
}

Expand Down Expand Up @@ -431,6 +502,9 @@ mod tests {
use super::all_non_null;
use super::all_null;
use crate::aggregate_fn::AggregateFnRef;
use crate::aggregate_fn::AggregateFnVTableExt;
use crate::aggregate_fn::EmptyOptions as AggregateEmptyOptions;
use crate::aggregate_fn::fns::all_non_nan::AllNonNan;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::PType;
Expand Down Expand Up @@ -471,20 +545,30 @@ mod tests {
StatFn.new_expr(StatOptions::new(aggregate_fn), [expr])
}

fn nan_free(expr: Expression) -> Expression {
stat_fn(expr, AllNonNan.bind(AggregateEmptyOptions))
}

#[test]
fn rewrites_comparison_falsifier() -> VortexResult<()> {
let expr = gt(col("a"), lit(10));
assert_eq!(
expr.falsify(&SESSION)?,
Some(lt_eq(stat(col("a"), Stat::Max), lit(10)))
Some(and(
nan_free(col("a")),
lt_eq(stat(col("a"), Stat::Max), lit(10)),
))
);

let expr = eq(col("a"), col("b"));
assert_eq!(
expr.falsify(&SESSION)?,
Some(or(
gt(stat(col("a"), Stat::Min), stat(col("b"), Stat::Max)),
gt(stat(col("b"), Stat::Min), stat(col("a"), Stat::Max)),
Some(and(
and(nan_free(col("a")), nan_free(col("b"))),
or(
gt(stat(col("a"), Stat::Min), stat(col("b"), Stat::Max)),
gt(stat(col("b"), Stat::Min), stat(col("a"), Stat::Max)),
),
))
);
Ok(())
Expand All @@ -496,8 +580,14 @@ mod tests {
assert_eq!(
expr.falsify(&SESSION)?,
Some(or(
lt_eq(stat(col("a"), Stat::Max), lit(10)),
gt_eq(stat(col("a"), Stat::Min), lit(50)),
and(
nan_free(col("a")),
lt_eq(stat(col("a"), Stat::Max), lit(10)),
),
and(
nan_free(col("a")),
gt_eq(stat(col("a"), Stat::Min), lit(50)),
),
))
);
Ok(())
Expand All @@ -518,8 +608,8 @@ mod tests {
assert_eq!(
expr.falsify(&SESSION)?,
Some(or(
gt(lit(10), stat(col("a"), Stat::Max)),
gt(stat(col("a"), Stat::Min), lit(50)),
and(nan_free(col("a")), gt(lit(10), stat(col("a"), Stat::Max)),),
and(nan_free(col("a")), gt(stat(col("a"), Stat::Min), lit(50)),),
))
);
Ok(())
Expand Down Expand Up @@ -597,20 +687,23 @@ mod tests {
assert_eq!(
expr.falsify(&SESSION)?,
Some(and(
nan_free(col("a")),
and(
or(
lt(stat(col("a"), Stat::Max), lit(1i32)),
gt(stat(col("a"), Stat::Min), lit(1i32)),
and(
or(
lt(stat(col("a"), Stat::Max), lit(1i32)),
gt(stat(col("a"), Stat::Min), lit(1i32)),
),
or(
lt(stat(col("a"), Stat::Max), lit(2i32)),
gt(stat(col("a"), Stat::Min), lit(2i32)),
),
),
or(
lt(stat(col("a"), Stat::Max), lit(2i32)),
gt(stat(col("a"), Stat::Min), lit(2i32)),
lt(stat(col("a"), Stat::Max), lit(3i32)),
gt(stat(col("a"), Stat::Min), lit(3i32)),
),
),
or(
lt(stat(col("a"), Stat::Max), lit(3i32)),
gt(stat(col("a"), Stat::Min), lit(3i32)),
),
))
);
Ok(())
Expand Down
2 changes: 0 additions & 2 deletions vortex-layout/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,6 @@ impl vortex_layout::layouts::zoned::zone_map::ZoneMap

pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::dtype_for_stats_table(&vortex_array::dtype::DType, &[vortex_array::expr::stats::Stat]) -> vortex_array::dtype::DType

pub unsafe fn vortex_layout::layouts::zoned::zone_map::ZoneMap::new_unchecked(vortex_array::arrays::struct_::vtable::StructArray, u64, u64) -> Self

pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::prune(&self, &vortex_array::expr::expression::Expression, &vortex_session::VortexSession) -> vortex_error::VortexResult<vortex_mask::Mask>

pub fn vortex_layout::layouts::zoned::zone_map::ZoneMap::try_new(vortex_array::dtype::DType, vortex_array::arrays::struct_::vtable::StructArray, alloc::sync::Arc<[vortex_array::expr::stats::Stat]>, u64, u64) -> vortex_error::VortexResult<Self>
Expand Down
25 changes: 11 additions & 14 deletions vortex-layout/src/layouts/zoned/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@ use tracing::trace;
use vortex_array::MaskFuture;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::StructArray;
use vortex_array::dtype::FieldPath;
use vortex_array::dtype::FieldPathSet;
use vortex_array::dtype::DType;
use vortex_array::expr::Expression;
use vortex_array::expr::pruning::checked_pruning_expr;
use vortex_array::expr::root;
use vortex_array::expr::stats::Stat;
use vortex_array::scalar_fn::fns::dynamic::DynamicExprUpdates;
use vortex_error::SharedVortexResult;
use vortex_error::VortexExpect;
Expand All @@ -43,7 +40,7 @@ pub(super) struct PruningState {
zone_count: usize,
row_count: u64,
zone_len: u64,
present_stats: Arc<[Stat]>,
dtype: DType,
lazy_children: Arc<LazyReaderChildren>,
session: VortexSession,

Expand All @@ -62,7 +59,7 @@ impl PruningState {
zone_count: layout.nzones(),
row_count: layout.row_count(),
zone_len: layout.zone_len() as u64,
present_stats: Arc::clone(layout.present_stats()),
dtype: layout.dtype().clone(),
lazy_children,
session,
pruning_result: Default::default(),
Expand Down Expand Up @@ -119,13 +116,12 @@ impl PruningState {
self.pruning_predicates
.entry(expr.clone())
.or_default()
.get_or_init(move || {
let available_stats = FieldPathSet::from_iter(
self.present_stats
.iter()
.map(|stat| FieldPath::from_name(stat.name())),
);
checked_pruning_expr(&expr, &available_stats).map(|(expr, _)| expr)
.get_or_init(move || match expr.falsify(&self.session) {
Ok(predicate) => predicate,
Err(error) => {
trace!(%expr, %error, "failed to construct stats rewrite predicate");
None
}
})
.clone()
}
Expand All @@ -147,13 +143,14 @@ impl PruningState {
let session = self.session.clone();
let zone_len = self.zone_len;
let row_count = self.row_count;
let dtype = self.dtype.clone();

async move {
let mut ctx = session.create_execution_ctx();
let zones_array = zones_eval.await?.execute::<StructArray>(&mut ctx)?;
// SAFETY: zoned layout validation ensures the zones child matches the expected
// stats-table schema for `present_stats`.
Ok(unsafe { ZoneMap::new_unchecked(zones_array, zone_len, row_count) })
Ok(unsafe { ZoneMap::new_unchecked(dtype, zones_array, zone_len, row_count) })
}
.map_err(Arc::new)
.boxed()
Expand Down
Loading
Loading