diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 9b26c6a2d..3ca5093dd 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -1,6 +1,7 @@ package io.substrait.dsl; import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; import io.substrait.expression.Expression.Cast; import io.substrait.expression.Expression.FailureBehavior; @@ -13,6 +14,7 @@ import io.substrait.expression.FieldReference; import io.substrait.expression.FunctionArg; import io.substrait.expression.FunctionOption; +import io.substrait.expression.StatisticalDistribution; import io.substrait.expression.WindowBound; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; @@ -1361,6 +1363,193 @@ public Aggregate.Measure sum0(Expression expr) { R.I64); } + /** + * Creates a population standard deviation aggregate measure for a specific field. + * + *

Computes the standard deviation using the population formula (n denominator), which + * considers all values in the dataset as the entire population. This is equivalent to SQL's + * STDDEV_POP function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing population standard deviation with + * distribution=POPULATION enum argument + */ + public Aggregate.Measure stddevPopulation(Rel input, int field) { + return stddevPopulation(fieldReference(input, field)); + } + + /** + * Creates a population standard deviation aggregate measure for an expression. + * + *

Computes the standard deviation using the population formula (n denominator), which + * considers all values in the dataset as the entire population. This is equivalent to SQL's + * STDDEV_POP function. + * + *

The measure is created with: + * + *

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing population standard deviation + */ + public Aggregate.Measure stddevPopulation(Expression expr) { + return statisticalAggregate(expr, "std_dev", StatisticalDistribution.POPULATION); + } + + /** + * Creates a sample standard deviation aggregate measure for a specific field. + * + *

Computes the standard deviation using the sample formula (n-1 denominator), which applies + * Bessel's correction for sample data. This is equivalent to SQL's STDDEV_SAMP or STDDEV + * function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing sample standard deviation with distribution=SAMPLE enum + * argument + */ + public Aggregate.Measure stddevSample(Rel input, int field) { + return stddevSample(fieldReference(input, field)); + } + + /** + * Creates a sample standard deviation aggregate measure for an expression. + * + *

Computes the standard deviation using the sample formula (n-1 denominator), which applies + * Bessel's correction for sample data. This is equivalent to SQL's STDDEV_SAMP or STDDEV + * function. + * + *

The measure is created with: + * + *

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing sample standard deviation + */ + public Aggregate.Measure stddevSample(Expression expr) { + return statisticalAggregate(expr, "std_dev", StatisticalDistribution.SAMPLE); + } + + /** + * Creates a population variance aggregate measure for a specific field. + * + *

Computes the variance using the population formula (n denominator), which considers all + * values in the dataset as the entire population. This is equivalent to SQL's VAR_POP function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing population variance with distribution=POPULATION enum + * argument + */ + public Aggregate.Measure variancePopulation(Rel input, int field) { + return variancePopulation(fieldReference(input, field)); + } + + /** + * Creates a population variance aggregate measure for an expression. + * + *

Computes the variance using the population formula (n denominator), which considers all + * values in the dataset as the entire population. This is equivalent to SQL's VAR_POP function. + * + *

The measure is created with: + * + *

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing population variance + */ + public Aggregate.Measure variancePopulation(Expression expr) { + return statisticalAggregate(expr, "variance", StatisticalDistribution.POPULATION); + } + + /** + * Creates a sample variance aggregate measure for a specific field. + * + *

Computes the variance using the sample formula (n-1 denominator), which applies Bessel's + * correction for sample data. This is equivalent to SQL's VAR_SAMP or VARIANCE function. + * + * @param input the input relation containing the field + * @param field the zero-based index of the field to aggregate + * @return an aggregate measure computing sample variance with distribution=SAMPLE enum argument + */ + public Aggregate.Measure varianceSample(Rel input, int field) { + return varianceSample(fieldReference(input, field)); + } + + /** + * Creates a sample variance aggregate measure for an expression. + * + *

Computes the variance using the sample formula (n-1 denominator), which applies Bessel's + * correction for sample data. This is equivalent to SQL's VAR_SAMP or VARIANCE function. + * + *

The measure is created with: + * + *

+ * + * @param expr the expression to aggregate (typically a numeric field reference) + * @return an aggregate measure computing sample variance + */ + public Aggregate.Measure varianceSample(Expression expr) { + return statisticalAggregate(expr, "variance", StatisticalDistribution.SAMPLE); + } + + /** + * Helper method to create statistical aggregate measures (std_dev, variance) with a {@code + * distribution} enum argument. + * + *

Uses the non-deprecated function signatures that carry the population/sample distinction as + * a leading {@code distribution} {@link EnumArg} (e.g. {@code std_dev:req_fp64}). + * + * @param expr the expression to aggregate + * @param functionName the Substrait function name ("std_dev" or "variance") + * @param distribution the distribution type (SAMPLE or POPULATION) + * @return an aggregate measure with the specified distribution argument + */ + private Aggregate.Measure statisticalAggregate( + Expression expr, String functionName, StatisticalDistribution distribution) { + String typeString = ToTypeString.apply(expr.getType()); + SimpleExtension.AggregateFunctionVariant declaration = + extensions.getAggregateFunction( + SimpleExtension.FunctionAnchor.of( + DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, + String.format("%s:req_%s", functionName, typeString))); + EnumArg distributionArg = + EnumArg.of((SimpleExtension.EnumArgument) declaration.args().get(0), distribution.name()); + return measure( + AggregateFunctionInvocation.builder() + .arguments(Arrays.asList(distributionArg, expr)) + .outputType(TypeCreator.asNullable(expr.getType())) + .declaration(declaration) + .aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT) + .invocation(Expression.AggregationInvocation.ALL) + .build()); + } + private Aggregate.Measure singleArgumentArithmeticAggregate( Expression expr, String functionName, Type outputType) { String typeString = ToTypeString.apply(expr.getType()); diff --git a/core/src/main/java/io/substrait/expression/StatisticalDistribution.java b/core/src/main/java/io/substrait/expression/StatisticalDistribution.java new file mode 100644 index 000000000..62d223024 --- /dev/null +++ b/core/src/main/java/io/substrait/expression/StatisticalDistribution.java @@ -0,0 +1,17 @@ +package io.substrait.expression; + +/** + * The {@code distribution} enum argument of the Substrait {@code std_dev} and {@code variance} + * aggregate functions. + * + *

Distinguishes between the sample (n-1 denominator, Bessel's correction) and population (n + * denominator) variants. The enum constant names match the Substrait extension's enum option names + * ({@code SAMPLE} / {@code POPULATION}), so {@link #name()} yields the value used to build an + * {@link EnumArg}. + */ +public enum StatisticalDistribution { + /** Sample distribution (uses the n-1 denominator, Bessel's correction). */ + SAMPLE, + /** Population distribution (uses the n denominator). */ + POPULATION +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java index 0d5d5bf0e..4d7ea1bb9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java +++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java @@ -29,6 +29,30 @@ public class AggregateFunctions { /** Substrait-specific AVG aggregate function (nullable return type). */ public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG); + /** + * Standard deviation (population) aggregate function. Maps to Substrait's std_dev function with + * distribution=POPULATION enum argument. + */ + public static SqlAggFunction STDDEV_POP = new SubstraitAvgAggFunction(SqlKind.STDDEV_POP); + + /** + * Standard deviation (sample) aggregate function. Maps to Substrait's std_dev function with + * distribution=SAMPLE enum argument. + */ + public static SqlAggFunction STDDEV_SAMP = new SubstraitAvgAggFunction(SqlKind.STDDEV_SAMP); + + /** + * Variance (population) aggregate function. Maps to Substrait's variance function with + * distribution=POPULATION enum argument. + */ + public static SqlAggFunction VAR_POP = new SubstraitAvgAggFunction(SqlKind.VAR_POP); + + /** + * Variance (sample) aggregate function. Maps to Substrait's variance function with + * distribution=SAMPLE enum argument. + */ + public static SqlAggFunction VAR_SAMP = new SubstraitAvgAggFunction(SqlKind.VAR_SAMP); + /** Substrait-specific SUM aggregate function (nullable return type). */ public static SqlAggFunction SUM = new SubstraitSumAggFunction(); @@ -42,18 +66,34 @@ public class AggregateFunctions { * @return optional containing Substrait equivalent if conversion applies */ public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) { - if (aggFunction instanceof SqlMinMaxAggFunction) { - SqlMinMaxAggFunction fun = (SqlMinMaxAggFunction) aggFunction; - return Optional.of( - fun.getKind() == SqlKind.MIN ? AggregateFunctions.MIN : AggregateFunctions.MAX); - } else if (aggFunction instanceof SqlAvgAggFunction) { - return Optional.of(AggregateFunctions.AVG); - } else if (aggFunction instanceof SqlSumAggFunction) { - return Optional.of(AggregateFunctions.SUM); - } else if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) { - return Optional.of(AggregateFunctions.SUM0); - } else { - return Optional.empty(); + // First check by SqlKind to handle all statistical functions + SqlKind kind = aggFunction.getKind(); + switch (kind) { + case MIN: + return Optional.of(AggregateFunctions.MIN); + case MAX: + return Optional.of(AggregateFunctions.MAX); + case AVG: + return Optional.of(AggregateFunctions.AVG); + case STDDEV_POP: + return Optional.of(AggregateFunctions.STDDEV_POP); + case STDDEV_SAMP: + return Optional.of(AggregateFunctions.STDDEV_SAMP); + case VAR_POP: + return Optional.of(AggregateFunctions.VAR_POP); + case VAR_SAMP: + return Optional.of(AggregateFunctions.VAR_SAMP); + case SUM: + case SUM0: + // Check instance type for SUM variants + if (aggFunction instanceof SqlSumEmptyIsZeroAggFunction) { + return Optional.of(AggregateFunctions.SUM0); + } else if (aggFunction instanceof SqlSumAggFunction) { + return Optional.of(AggregateFunctions.SUM); + } + return Optional.empty(); + default: + return Optional.empty(); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java index 705e02e3c..f8695db64 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java +++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java @@ -46,8 +46,11 @@ public static boolean isValidCalciteAggregate(Aggregate aggregate) { */ private static boolean isValidCalciteMeasure(Aggregate.Measure measure) { return - // all function arguments to measures must be field references - measure.getFunction().arguments().stream().allMatch(farg -> isSimpleFieldReference(farg)) + // all value (Expression) function arguments to measures must be field references; non-value + // arguments such as the std_dev/variance "distribution" enum argument are exempt + measure.getFunction().arguments().stream() + .filter(farg -> farg instanceof Expression) + .allMatch(farg -> isSimpleFieldReference(farg)) && // all sort fields must be field references measure.getFunction().sort().stream().allMatch(sf -> isSimpleFieldReference(sf.expr())) @@ -157,9 +160,9 @@ public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) { private Aggregate.Measure updateMeasure(Aggregate.Measure measure) { AggregateFunctionInvocation oldAggregateFunctionInvocation = measure.getFunction(); - List newFunctionArgs = + List newFunctionArgs = oldAggregateFunctionInvocation.arguments().stream() - .map(this::projectOutNonFieldReference) + .map(this::projectOutNonFieldReferenceArg) .collect(Collectors.toList()); List newSortFields = @@ -194,11 +197,13 @@ private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) { return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build(); } - private Expression projectOutNonFieldReference(FunctionArg farg) { + private FunctionArg projectOutNonFieldReferenceArg(FunctionArg farg) { if ((farg instanceof Expression)) { return projectOutNonFieldReference((Expression) farg); } else { - throw new IllegalArgumentException("cannot handle non-expression argument for aggregate"); + // Non-value arguments (e.g. the std_dev/variance "distribution" enum argument) are not + // field references and are passed through unchanged. + return farg; } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 52849f96f..2d42b9f35 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -384,8 +384,11 @@ public RelNode visit(Aggregate aggregate, Context context) throws RuntimeExcepti private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { List eArgs = measure.getFunction().arguments(); + // Only value (Expression) arguments map to Calcite aggregate operands. Enum arguments such as + // the std_dev/variance "distribution" are used to disambiguate the operator, not as operands. List arguments = - IntStream.range(0, measure.getFunction().arguments().size()) + IntStream.range(0, eArgs.size()) + .filter(i -> eArgs.get(i) instanceof Expression) .mapToObj( i -> eArgs @@ -398,7 +401,9 @@ private AggregateCall fromMeasure(Aggregate.Measure measure, Context context) { .collect(java.util.stream.Collectors.toList()); Optional operator = aggregateFunctionConverter.getSqlOperatorFromSubstraitFunc( - measure.getFunction().declaration().key(), measure.getFunction().outputType()); + measure.getFunction().declaration().key(), + measure.getFunction().outputType(), + measure.getFunction().arguments()); if (!operator.isPresent()) { throw new IllegalArgumentException( String.format( diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index bc4d0a9f2..94b41964f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -36,6 +36,7 @@ import io.substrait.relation.VirtualTableScan; import io.substrait.type.NamedStruct; import io.substrait.type.Type; +import io.substrait.type.TypeCreator; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -52,11 +53,14 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.logical.LogicalProject; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; @@ -84,6 +88,9 @@ public class SubstraitRelVisitor extends RelNodeVisitor { private Map fieldAccessDepthMap; + /** Rex builder for creating Rex expressions during conversion. */ + protected RexBuilder rexBuilder; + /** * Creates a new SubstraitRelVisitor with the specified type factory and extensions. * @@ -106,6 +113,7 @@ public SubstraitRelVisitor(ConverterProvider converterProvider) { this.typeConverter = converterProvider.getTypeConverter(); this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter(); this.rexExpressionConverter = converterProvider.getRexExpressionConverter(this); + this.rexBuilder = new RexBuilder(converterProvider.getTypeFactory()); } /** @@ -331,6 +339,16 @@ public Rel visit(org.apache.calcite.rel.core.Minus minus) { */ @Override public Rel visit(org.apache.calcite.rel.core.Aggregate aggregate) { + // Substrait's std_dev/variance functions only define fp32/fp64 signatures. If a statistical + // aggregate has a non-floating-point argument, rewrite the aggregate to cast that argument to + // fp64 and cast the result back to the type Calcite inferred, then convert the rewritten plan + // through the normal path. The rewrite is idempotent (fp32/fp64 arguments are left untouched), + // so it terminates when the converted plan is re-converted. + RelNode rewritten = castStatisticalAggregatesToFloatingPoint(aggregate); + if (rewritten != aggregate) { + return apply(rewritten); + } + Rel input = apply(aggregate.getInput()); Stream sets; if (aggregate.groupSets != null) { @@ -411,6 +429,28 @@ Aggregate.Grouping fromGroupSet(ImmutableBitSet bitSet, Rel input) { return Aggregate.Grouping.builder().addAllExpressions(references).build(); } + /** + * Converts a Calcite {@link AggregateCall} to a Substrait {@link Aggregate.Measure}. + * + *

This method handles the conversion of aggregate function calls from Calcite's representation + * to Substrait's format. For statistical aggregate functions (STDDEV_POP, STDDEV_SAMP, VAR_POP, + * VAR_SAMP), it automatically transforms the input relation by inserting a projection that casts + * the aggregate function's argument fields to DOUBLE (FP64) type, ensuring type compatibility + * with Substrait's statistical function requirements. Fields not referenced by the aggregate + * function are passed through unchanged. + * + *

The method also processes optional filter arguments (FILTER clauses) by converting them to + * Substrait's preMeasureFilter representation. + * + * @param input the input relational node providing data to the aggregate operation + * @param inputType the Substrait struct type representing the schema of the input relation + * @param call the Calcite aggregate call to convert, containing the aggregate function, + * arguments, and optional filter + * @return a Substrait {@link Aggregate.Measure} representing the aggregate function invocation + * with its configuration + * @throws UnsupportedOperationException if the aggregate function cannot be converted to a + * Substrait representation (no matching function binding found) + */ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCall call) { Optional invocation = aggregateFunctionConverter.convert( @@ -427,6 +467,141 @@ Aggregate.Measure fromAggCall(RelNode input, Type.Struct inputType, AggregateCal return builder.build(); } + private static boolean isStatisticalDistributionAggregate(SqlKind kind) { + return kind == SqlKind.STDDEV_POP + || kind == SqlKind.STDDEV_SAMP + || kind == SqlKind.VAR_POP + || kind == SqlKind.VAR_SAMP; + } + + private boolean isFloatingPoint(RelDataType type) { + Type substraitType = typeConverter.toSubstrait(type); + return TypeCreator.NULLABLE.FP32.equalsIgnoringNullability(substraitType) + || TypeCreator.NULLABLE.FP64.equalsIgnoringNullability(substraitType); + } + + /** + * Rewrites a Calcite aggregate so that statistical aggregate functions (STDDEV_POP, STDDEV_SAMP, + * VAR_POP, VAR_SAMP) with non-floating-point arguments operate on fp64, since Substrait's {@code + * std_dev} / {@code variance} functions only define fp32 and fp64 signatures. + * + *

For each statistical aggregate whose single argument is neither fp32 nor fp64 (e.g. an + * integer or decimal column), the rewrite: + * + *

    + *
  1. appends a {@code cast(arg AS fp64)} column to the aggregate's input (leaving the original + * column in place, so other aggregates over the same column are unaffected), + *
  2. re-points the statistical aggregate at the appended column (its return type is re-derived + * over fp64), and + *
  3. casts the aggregate's results back to the types Calcite originally inferred, via a + * projection on top, so the aggregate's output row type is preserved. + *
+ * + *

The rewrite is idempotent: fp32/fp64 arguments are left untouched, so converting the + * rewritten plan (whose statistical arguments are already fp64) produces no further rewrite and + * the recursion in {@link #visit(org.apache.calcite.rel.core.Aggregate)} terminates. If no + * argument needs casting, the aggregate is returned unchanged. + * + * @param aggregate the Calcite aggregate to inspect + * @return {@code aggregate} unchanged, or a {@link LogicalProject} wrapping a rewritten aggregate + */ + protected RelNode castStatisticalAggregatesToFloatingPoint( + org.apache.calcite.rel.core.Aggregate aggregate) { + RelNode input = aggregate.getInput(); + List calls = aggregate.getAggCallList(); + int inputFieldCount = input.getRowType().getFieldCount(); + + // fp64 cast expressions to append to the input, and the source field each one casts (for reuse) + List appendedCasts = new ArrayList<>(); + List appendedSourceFields = new ArrayList<>(); + // per call: the appended column its argument should be re-pointed at, or -1 if unchanged + List rewrittenArgColumns = new ArrayList<>(calls.size()); + + for (AggregateCall call : calls) { + int rewrittenArgColumn = -1; + if (isStatisticalDistributionAggregate(call.getAggregation().getKind()) + && call.getArgList().size() == 1) { + int argIndex = call.getArgList().get(0); + RelDataType argType = input.getRowType().getFieldList().get(argIndex).getType(); + if (!isFloatingPoint(argType)) { + int existing = appendedSourceFields.indexOf(argIndex); + if (existing >= 0) { + rewrittenArgColumn = inputFieldCount + existing; + } else { + RelDataType fp64 = + typeConverter.toCalcite( + rexBuilder.getTypeFactory(), Type.withNullability(argType.isNullable()).FP64); + appendedCasts.add(rexBuilder.makeCast(fp64, rexBuilder.makeInputRef(input, argIndex))); + appendedSourceFields.add(argIndex); + rewrittenArgColumn = inputFieldCount + appendedCasts.size() - 1; + } + } + } + rewrittenArgColumns.add(rewrittenArgColumn); + } + + if (appendedCasts.isEmpty()) { + return aggregate; + } + + // Extended input: all original columns (passthrough) followed by the appended fp64 casts. + List inputProjects = new ArrayList<>(inputFieldCount + appendedCasts.size()); + for (int i = 0; i < inputFieldCount; i++) { + inputProjects.add(rexBuilder.makeInputRef(input, i)); + } + inputProjects.addAll(appendedCasts); + RelNode extendedInput = + LogicalProject.create(input, Collections.emptyList(), inputProjects, (List) null); + + // Re-point the statistical calls at the appended fp64 columns (return type re-derived); leave + // all other calls unchanged. + List rewrittenCalls = new ArrayList<>(calls.size()); + for (int i = 0; i < calls.size(); i++) { + AggregateCall call = calls.get(i); + int rewrittenArgColumn = rewrittenArgColumns.get(i); + if (rewrittenArgColumn < 0) { + rewrittenCalls.add(call); + } else { + rewrittenCalls.add( + AggregateCall.create( + call.getAggregation(), + call.isDistinct(), + call.isApproximate(), + call.ignoreNulls(), + Collections.singletonList(rewrittenArgColumn), + call.filterArg, + call.distinctKeys, + call.getCollation(), + aggregate.getGroupCount(), + extendedInput, + /* type, null to re-derive over fp64 */ null, + call.getName())); + } + } + + org.apache.calcite.rel.core.Aggregate rewrittenAggregate = + aggregate.copy( + aggregate.getTraitSet(), + extendedInput, + aggregate.getGroupSet(), + aggregate.getGroupSets(), + rewrittenCalls); + + // Cast the (now fp64) statistical results back to the types Calcite originally inferred, + // preserving the aggregate's original output row type. Group keys and unaffected measures pass + // through unchanged. + RelDataType originalRowType = aggregate.getRowType(); + List outputProjects = new ArrayList<>(originalRowType.getFieldCount()); + for (int i = 0; i < originalRowType.getFieldCount(); i++) { + RelDataType targetType = originalRowType.getFieldList().get(i).getType(); + RexNode ref = rexBuilder.makeInputRef(rewrittenAggregate, i); + outputProjects.add( + ref.getType().equals(targetType) ? ref : rexBuilder.makeCast(targetType, ref)); + } + return LogicalProject.create( + rewrittenAggregate, Collections.emptyList(), outputProjects, originalRowType); + } + /** * Converts a Calcite {@link org.apache.calcite.rel.core.Match}. * diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java index 6f8918179..0b9b7e63d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/AggregateFunctionConverter.java @@ -5,11 +5,13 @@ import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; +import io.substrait.expression.StatisticalDistribution; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.AggregateFunctions; import io.substrait.isthmus.SubstraitRelVisitor; import io.substrait.isthmus.TypeConverter; import io.substrait.type.Type; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -22,6 +24,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexNode; import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.fun.SqlStdOperatorTable; /** @@ -76,9 +79,19 @@ public AggregateFunctionConverter( /** * Builds a Substrait aggregate invocation from the matched call and arguments. * - * @param call wrapped aggregate call - * @param function matched Substrait function variant - * @param arguments converted arguments + *

This method constructs an {@link AggregateFunctionInvocation} with appropriate configuration + * including sort fields and invocation type (DISTINCT or ALL). + * + *

Statistical Functions: For standard deviation and variance functions (STDDEV_POP, + * STDDEV_SAMP, VAR_POP, VAR_SAMP), the population/sample distinction is carried by a leading + * {@code distribution} {@link io.substrait.expression.EnumArg} argument. That argument is + * synthesized as an operand in {@link #convert} so that the generic matcher resolves the enum-arg + * function variant ({@code std_dev:req_*} / {@code variance:req_*}) and constructs the {@link + * io.substrait.expression.EnumArg}; no special handling is required here. + * + * @param call wrapped aggregate call containing the Calcite aggregate information + * @param function matched Substrait function variant from the extension catalog + * @param arguments converted function arguments * @param outputType result type of the invocation * @return aggregate function invocation */ @@ -100,6 +113,7 @@ protected AggregateFunctionInvocation generateBinding( agg.isDistinct() ? Expression.AggregationInvocation.DISTINCT : Expression.AggregationInvocation.ALL; + return ExpressionCreator.aggregateFunction( function, outputType, @@ -128,14 +142,50 @@ public Optional convert( if (m == null) { return Optional.empty(); } - if (!m.allowedArgCount(call.getArgList().size())) { + + // For statistical aggregates (std_dev/variance) the SAMPLE/POPULATION distinction is carried + // by a leading "distribution" enum argument. Synthesize it as an operand so the generic matcher + // resolves the enum-arg function variant and builds the EnumArg. + List leadingArgs = leadingEnumArgs(call); + if (!m.allowedArgCount(call.getArgList().size() + leadingArgs.size())) { return Optional.empty(); } - WrappedAggregateCall wrapped = new WrappedAggregateCall(call, input, rexBuilder, inputType); + WrappedAggregateCall wrapped = + new WrappedAggregateCall(call, leadingArgs, input, rexBuilder, inputType); return m.attemptMatch(wrapped, topLevelConverter); } + /** + * Computes the synthetic leading operands to prepend to a Calcite aggregate call before matching. + * + *

For standard deviation and variance functions, Substrait carries the population/sample + * distinction as a leading {@code distribution} enum argument, whereas Calcite encodes it in the + * operator's {@link SqlKind}. This returns the matching {@link StatisticalDistribution} flag so + * the generic matcher selects the {@code std_dev:req_*} / {@code variance:req_*} variant and + * constructs the corresponding {@link io.substrait.expression.EnumArg}. + * + * @param call the Calcite aggregate call + * @return the leading enum operands (a single distribution flag for statistical functions, empty + * otherwise) + */ + private List leadingEnumArgs(AggregateCall call) { + List leadingArgs = new ArrayList<>(); + switch (call.getAggregation().getKind()) { + case STDDEV_SAMP: + case VAR_SAMP: + leadingArgs.add(rexBuilder.makeFlag(StatisticalDistribution.SAMPLE)); + break; + case STDDEV_POP: + case VAR_POP: + leadingArgs.add(rexBuilder.makeFlag(StatisticalDistribution.POPULATION)); + break; + default: + break; + } + return leadingArgs; + } + /** * Resolves the appropriate function finder, applying Substrait-specific variants when needed. * @@ -160,6 +210,7 @@ protected FunctionFinder getFunctionFinder(AggregateCall call) { /** Lightweight wrapper around {@link AggregateCall} providing operands and type access. */ static class WrappedAggregateCall implements FunctionConverter.GenericCall { private final AggregateCall call; + private final List leadingArgs; private final RelNode input; private final RexBuilder rexBuilder; private final Type.Struct inputType; @@ -168,26 +219,36 @@ static class WrappedAggregateCall implements FunctionConverter.GenericCall { * Creates a new wrapped aggregate call. * * @param call underlying Calcite aggregate call + * @param leadingArgs synthetic operands (e.g. a {@code distribution} enum flag) prepended ahead + * of the field arguments during matching * @param input input relational node * @param rexBuilder Rex builder for operand construction * @param inputType Substrait input struct type */ private WrappedAggregateCall( - AggregateCall call, RelNode input, RexBuilder rexBuilder, Type.Struct inputType) { + AggregateCall call, + List leadingArgs, + RelNode input, + RexBuilder rexBuilder, + Type.Struct inputType) { this.call = call; + this.leadingArgs = leadingArgs; this.input = input; this.rexBuilder = rexBuilder; this.inputType = inputType; } /** - * Returns operands as input references over the argument list. + * Returns operands as the synthetic leading operands followed by input references over the + * argument list. * * @return stream of RexNode operands */ @Override public Stream getOperands() { - return call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r)); + return Stream.concat( + leadingArgs.stream(), + call.getArgList().stream().map(r -> rexBuilder.makeInputRef(input, r))); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java index 2f95ca276..2fef9f683 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/EnumConverter.java @@ -1,6 +1,7 @@ package io.substrait.isthmus.expression; import io.substrait.expression.EnumArg; +import io.substrait.expression.StatisticalDistribution; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; import io.substrait.extension.SimpleExtension.Argument; @@ -78,6 +79,20 @@ public class EnumConverter { calciteEnumMap.put( argAnchor(DefaultExtensionCatalog.FUNCTIONS_DATETIME, "extract:req_req_date", 1), ExtractIndexing.class); + + // std_dev and variance carry the SAMPLE/POPULATION distinction as a leading enum argument. + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "std_dev:req_fp32", 0), + StatisticalDistribution.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "std_dev:req_fp64", 0), + StatisticalDistribution.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "variance:req_fp32", 0), + StatisticalDistribution.class); + calciteEnumMap.put( + argAnchor(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, "variance:req_fp64", 0), + StatisticalDistribution.class); } private static Optional> constructValue( @@ -90,6 +105,10 @@ private static Optional> constructValue( return option.get().map(SqlTrimFunction.Flag::valueOf); } + if (cls.isAssignableFrom(StatisticalDistribution.class)) { + return option.get().map(StatisticalDistribution::valueOf); + } + // ExtractIndexing does not need to be converted here. Calcite // doesn't have the concept of the indexing. It's date // functions are all indexed from 1 diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java index 34655ccb5..3896fa897 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionConverter.java @@ -7,9 +7,11 @@ import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; import com.google.common.collect.Streams; +import io.substrait.expression.EnumArg; import io.substrait.expression.Expression; import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FunctionArg; +import io.substrait.expression.StatisticalDistribution; import io.substrait.extension.SimpleExtension; import io.substrait.extension.SimpleExtension.Argument; import io.substrait.function.ParameterizedType; @@ -41,6 +43,7 @@ import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -181,6 +184,27 @@ public FunctionConverter( * @return matching {@link SqlOperator}, or empty if none */ public Optional getSqlOperatorFromSubstraitFunc(String key, Type outputType) { + return getSqlOperatorFromSubstraitFunc(key, outputType, java.util.Collections.emptyList()); + } + + /** + * Converts a Substrait function to a Calcite {@link SqlOperator} (Substrait → Calcite direction). + * + *

Given a Substrait function key (e.g., "std_dev:req_fp64"), output type, and function + * arguments (which may include a {@code distribution} {@link io.substrait.expression.EnumArg}), + * this method finds the corresponding Calcite {@link SqlOperator}. When multiple operators match, + * the output type and the {@code distribution} enum argument are used to disambiguate. + * + *

For example, both STDDEV_POP and STDDEV_SAMP map to "std_dev:req_fp64", but differ in the + * {@code distribution} enum argument (POPULATION vs SAMPLE). + * + * @param key the Substrait function key (function name with type signature) + * @param outputType the expected output type + * @param arguments the function arguments (used to read the {@code distribution} enum argument) + * @return the matching {@link SqlOperator}, or empty if no match found + */ + public Optional getSqlOperatorFromSubstraitFunc( + String key, Type outputType, List arguments) { Map resolver = getTypeBasedResolver(); Collection operators = substraitFuncKeyToSqlOperatorMap.get(key); if (operators.isEmpty()) { @@ -192,27 +216,101 @@ public Optional getSqlOperatorFromSubstraitFunc(String key, Type ou return Optional.of(operators.iterator().next()); } - // at least 2 operators. Use output type to resolve SqlOperator. + // First, filter by output type to ensure type compatibility String outputTypeStr = outputType.accept(ToTypeString.INSTANCE); - List resolvedOperators = + List typeFilteredOperators = operators.stream() .filter( operator -> resolver.containsKey(operator) && resolver.get(operator).types().contains(outputTypeStr)) .collect(Collectors.toList()); + + // If type filtering resolved to a single operator, return it + if (typeFilteredOperators.size() == 1) { + return Optional.of(typeFilteredOperators.get(0)); + } + + // If still ambiguous and a distribution enum argument is present, disambiguate by it. + // Both the population and sample operators share one key (e.g. variance:req_fp32), since the + // SAMPLE/POPULATION value lives in the argument, not the signature. + Optional distribution = distributionArgument(arguments); + List resolvedOperators = typeFilteredOperators; + if (distribution.isPresent()) { + List candidates = + typeFilteredOperators.isEmpty() ? List.copyOf(operators) : typeFilteredOperators; + resolvedOperators = filterByDistribution(candidates, distribution.get()); + } + // only one SqlOperator is possible if (resolvedOperators.size() == 1) { return Optional.of(resolvedOperators.get(0)); } else if (resolvedOperators.size() > 1) { throw new IllegalStateException( String.format( - "Found %d SqlOperators: %s for ScalarFunction %s: ", + "Found %d SqlOperators: %s for function %s", resolvedOperators.size(), resolvedOperators, key)); } return Optional.empty(); } + /** + * Extracts the value of the {@code distribution} enum argument, if present. + * + *

This returns the value of the first {@link EnumArg} in the argument list. It assumes the + * only enum argument that disambiguates between operators sharing a key is the {@code + * distribution} argument of {@code std_dev}/{@code variance} — the only enum-argument aggregate + * functions currently mapped. {@link #filterByDistribution} rejects values it does not recognize. + * + * @param arguments the Substrait function arguments + * @return the distribution value (e.g. {@code SAMPLE} / {@code POPULATION}) if an {@link EnumArg} + * is present + */ + private static Optional distributionArgument(List arguments) { + if (arguments == null) { + return Optional.empty(); + } + return arguments.stream() + .filter(arg -> arg instanceof EnumArg) + .map(arg -> (EnumArg) arg) + .flatMap(arg -> arg.value().stream()) + .findFirst(); + } + + /** + * Filters SqlOperators based on the {@code distribution} enum argument. + * + *

For statistical functions like STDDEV and VAR, the {@code distribution} argument determines + * whether to use the population or sample variant: + * + *

+ * + * @param operators the list of candidate SqlOperators + * @param distributionValue the distribution value from the Substrait enum argument + * @return filtered list of SqlOperators matching the distribution + */ + private List filterByDistribution( + List operators, String distributionValue) { + return operators.stream() + .filter( + operator -> { + SqlKind kind = operator.getKind(); + // Match distribution value to SqlKind + if (StatisticalDistribution.POPULATION.name().equals(distributionValue)) { + return kind == SqlKind.STDDEV_POP || kind == SqlKind.VAR_POP; + } else if (StatisticalDistribution.SAMPLE.name().equals(distributionValue)) { + return kind == SqlKind.STDDEV_SAMP || kind == SqlKind.VAR_SAMP; + } + throw new IllegalArgumentException( + String.format( + "Unknown distribution value '%s' for operator %s", distributionValue, kind)); + }) + .collect(Collectors.toList()); + } + /** * Returns the resolver used to disambiguate Calcite operators by output type. * diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java index 85f44a312..4c7ca3023 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/FunctionMappings.java @@ -155,7 +155,19 @@ public class FunctionMappings { s(AggregateFunctions.SUM0, "sum0"), s(SqlStdOperatorTable.COUNT, "count"), s(SqlStdOperatorTable.APPROX_COUNT_DISTINCT, "approx_count_distinct"), - s(AggregateFunctions.AVG, "avg")) + s(AggregateFunctions.AVG, "avg"), + /* + * Substrait std_dev and variance functions use a leading 'distribution' enum + * argument (SAMPLE or POPULATION) to distinguish between population and sample + * variants. AggregateFunctionConverter synthesizes that argument based on the SqlKind. + * + * Note: Standard Calcite operators (SqlStdOperatorTable.STDDEV_SAMP, etc.) are + * automatically converted to these Substrait variants via toSubstraitAggVariant(). + */ + s(AggregateFunctions.STDDEV_POP, "std_dev"), + s(AggregateFunctions.STDDEV_SAMP, "std_dev"), + s(AggregateFunctions.VAR_POP, "variance"), + s(AggregateFunctions.VAR_SAMP, "variance")) .build(); /** Window function signatures (including supported aggregates) mapped to Substrait names. */ diff --git a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java index e3082f55b..04459ae7e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/AggregationFunctionsTest.java @@ -45,6 +45,14 @@ private Aggregate.Measure functionPicker(Rel input, int field, String fname) { return sb.sum0(input, field); case "avg": return sb.avg(input, field); + case "stddev_pop": + return sb.stddevPopulation(input, field); + case "stddev_samp": + return sb.stddevSample(input, field); + case "var_pop": + return sb.variancePopulation(input, field); + case "var_samp": + return sb.varianceSample(input, field); default: throw new UnsupportedOperationException( String.format("no function is associated with %s", fname)); @@ -54,14 +62,41 @@ private Aggregate.Measure functionPicker(Rel input, int field, String fname) { // Create one function call per numeric type column private List functions(Rel input, String fname) { // first column is for grouping, skip it + // Statistical functions (stddev_*, var_*) only support floating-point types in Substrait. + // This filtering ensures we only test with fp32 and fp64 types for these functions, + // avoiding type mismatch errors during round-trip conversion. + boolean isStatisticalFunction = fname.startsWith("stddev_") || fname.startsWith("var_"); return IntStream.range(1, tableTypes.size()) .boxed() + .filter( + index -> { + if (!isStatisticalFunction) { + return true; // All numeric types for non-statistical functions + } + // Only floating-point types for statistical functions + Type type = tableTypes.get(index); + return type.equals(R.FP32) + || type.equals(R.FP64) + || type.equals(N.FP32) + || type.equals(N.FP64); + }) .map(index -> functionPicker(input, index, fname)) .collect(Collectors.toList()); } @ParameterizedTest - @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) + @ValueSource( + strings = { + "max", + "min", + "sum", + "sum0", + "avg", + "stddev_pop", + "stddev_samp", + "var_pop", + "var_samp" + }) void emptyGrouping(String aggFunction) { Aggregate rel = sb.aggregate( @@ -70,7 +105,18 @@ void emptyGrouping(String aggFunction) { } @ParameterizedTest - @ValueSource(strings = {"max", "min", "sum", "sum0", "avg"}) + @ValueSource( + strings = { + "max", + "min", + "sum", + "sum0", + "avg", + "stddev_pop", + "stddev_samp", + "var_pop", + "var_samp" + }) void withGrouping(String aggFunction) { Aggregate rel = sb.aggregate( diff --git a/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java new file mode 100644 index 000000000..7dd03b5f5 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/StatisticalFunctionTest.java @@ -0,0 +1,114 @@ +package io.substrait.isthmus; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.expression.AggregateFunctionInvocation; +import io.substrait.expression.EnumArg; +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import io.substrait.isthmus.sql.SubstraitSqlToCalcite; +import io.substrait.plan.Plan; +import io.substrait.relation.Aggregate; +import io.substrait.relation.Rel; +import java.util.List; +import java.util.Optional; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.RelRoot; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +/** + * Verifies that the SQL statistical aggregates (STDDEV_POP/SAMP, VAR_POP/SAMP) map to the Substrait + * {@code std_dev} / {@code variance} functions using the non-deprecated enum-argument signatures + * ({@code std_dev:req_fp64} etc.), carrying the population/sample distinction as a {@code + * distribution} {@link EnumArg} rather than a function option. + */ +class StatisticalFunctionTest extends PlanTestBase { + + static final String CREATES = + "CREATE TABLE numbers (i8 TINYINT, i16 SMALLINT, i32 INT, i64 BIGINT, fp32 REAL, fp64 DOUBLE)"; + + @ParameterizedTest + @CsvSource({"STDDEV_POP", "STDDEV_SAMP", "VAR_POP", "VAR_SAMP"}) + void roundTrip(String fn) throws Exception { + assertFullRoundTrip(String.format("SELECT %s(fp32), %s(fp64) FROM numbers", fn, fn), CREATES); + } + + // Integer arguments are cast to fp64 (and the result cast back) since std_dev/variance only have + // fp32/fp64 signatures. This rewrite (castStatisticalAggregatesToFloatingPoint) inserts a cast + // projection that Calcite normalizes (project merge/column pruning) on the first round trip, so + // these use the identity-projection workaround, which asserts stability after normalization. + + @ParameterizedTest + @CsvSource({"STDDEV_POP", "STDDEV_SAMP", "VAR_POP", "VAR_SAMP"}) + void roundTripIntegerInput(String fn) throws Exception { + assertFullRoundTripWithIdentityProjectionWorkaround( + String.format("SELECT %s(i32) FROM numbers", fn), + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + } + + @Test + void roundTripIntegerInputSharedWithOtherAggregate() throws Exception { + // The integer column is shared by SUM (which must keep operating on the integer) and STDDEV_POP + // (which is cast to fp64); the cast must be appended, not applied in place. + assertFullRoundTripWithIdentityProjectionWorkaround( + "SELECT SUM(i32), STDDEV_POP(i32) FROM numbers", + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + } + + @Test + void roundTripIntegerInputWithGrouping() throws Exception { + assertFullRoundTripWithIdentityProjectionWorkaround( + "SELECT i8, VAR_POP(i32) FROM numbers GROUP BY i8", + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES)); + } + + @ParameterizedTest + @CsvSource({ + "STDDEV_POP, std_dev, POPULATION", + "STDDEV_SAMP, std_dev, SAMPLE", + "VAR_POP, variance, POPULATION", + "VAR_SAMP, variance, SAMPLE", + }) + void usesEnumArgSignature(String sqlFn, String substraitFn, String distribution) + throws Exception { + Prepare.CatalogReader catalog = + SubstraitCreateStatementParser.processCreateStatementsToCatalog(CREATES); + RelRoot calcite = + SubstraitSqlToCalcite.convertQuery( + String.format("SELECT %s(fp64) FROM numbers", sqlFn), + catalog, + converterProvider.getSqlOperatorTable()); + Plan.Root root = SubstraitRelVisitor.convert(calcite, converterProvider); + + AggregateFunctionInvocation function = firstMeasure(root.getInput()).getFunction(); + + // The non-deprecated enum-arg variant is used (note the "req" enum argument in the key). + assertEquals(substraitFn + ":req_fp64", function.declaration().key()); + + // The distribution is carried as a leading EnumArg, not as a function option. + List args = function.arguments(); + EnumArg distributionArg = assertInstanceOf(EnumArg.class, args.get(0)); + assertEquals(Optional.of(distribution), distributionArg.value()); + assertTrue(function.options().isEmpty(), "expected no function options"); + } + + /** Recursively finds the first {@link Aggregate} measure in the relation tree. */ + private static Aggregate.Measure firstMeasure(Rel rel) { + if (rel instanceof Aggregate) { + Aggregate aggregate = (Aggregate) rel; + if (!aggregate.getMeasures().isEmpty()) { + return aggregate.getMeasures().get(0); + } + } + for (Rel input : rel.getInputs()) { + Aggregate.Measure measure = firstMeasure(input); + if (measure != null) { + return measure; + } + } + return null; + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java index d39cc2084..142eb78bf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/AggregateFunctionConverterTest.java @@ -9,6 +9,8 @@ import io.substrait.isthmus.expression.FunctionConverter.FunctionFinder; import java.util.List; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.fun.SqlAvgAggFunction; import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction; import org.apache.calcite.sql.type.SqlTypeName; import org.junit.jupiter.api.Test; @@ -34,4 +36,84 @@ void testFunctionFinderMatch() { assertEquals("sum0", functionFinder.getSubstraitName()); assertEquals(AggregateFunctions.SUM0, functionFinder.getOperator()); } + + @Test + void testStddevPopFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.STDDEV_POP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("std_dev", functionFinder.getSubstraitName()); + assertEquals(AggregateFunctions.STDDEV_POP, functionFinder.getOperator()); + } + + @Test + void testStddevSampFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.STDDEV_SAMP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("std_dev", functionFinder.getSubstraitName()); + assertEquals(AggregateFunctions.STDDEV_SAMP, functionFinder.getOperator()); + } + + @Test + void testVarPopFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.VAR_POP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("variance", functionFinder.getSubstraitName()); + assertEquals(AggregateFunctions.VAR_POP, functionFinder.getOperator()); + } + + @Test + void testVarSampFunctionFinderMatch() { + AggregateFunctionConverter converter = + new AggregateFunctionConverter( + extensions.aggregateFunctions(), List.of(), typeFactory, TypeConverter.DEFAULT); + + FunctionFinder functionFinder = + converter.getFunctionFinder( + AggregateCall.create( + new SqlAvgAggFunction(SqlKind.VAR_SAMP), + false, + List.of(0), + -1, + typeFactory.createSqlType(SqlTypeName.DOUBLE), + null)); + assertNotNull(functionFinder); + assertEquals("variance", functionFinder.getSubstraitName()); + assertEquals(AggregateFunctions.VAR_SAMP, functionFinder.getOperator()); + } }