From 096bafddd84b1e5cc0a5c3ec692448cda10ec777 Mon Sep 17 00:00:00 2001 From: yangtao555 Date: Fri, 26 Jun 2026 17:15:21 +0800 Subject: [PATCH 1/2] [fix](nereids) Gate aggregate parent shuffle reuse by NDV stats --- .../properties/RequestPropertyDeriver.java | 65 +++++------ .../RequestPropertyDeriverTest.java | 108 +++++++++++++++++- 2 files changed, 130 insertions(+), 43 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java index 70f2b51665b740..2df8e1809bf1e4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java @@ -31,7 +31,6 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.OrderExpression; -import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.DistributeType; import org.apache.doris.nereids.trees.plans.Plan; @@ -74,16 +73,12 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; -import com.google.common.collect.Sets; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.stream.Collectors; /** @@ -509,29 +504,34 @@ public Void visitPhysicalHashAggregate(PhysicalHashAggregate agg addRequestPropertyToChildren(PhysicalProperties.GATHER); return null; } - List groupByExprIds = agg.getGroupByExpressions().stream() - .filter(SlotReference.class::isInstance) - .map(SlotReference.class::cast) - .map(SlotReference::getExprId) - .collect(Collectors.toList()); + List groupByExprIds = new ArrayList<>(); + Map groupByExprIdToExpr = Maps.newHashMap(); + for (Expression groupByExpr : agg.getGroupByExpressions()) { + if (groupByExpr instanceof SlotReference) { + ExprId groupByExprId = ((SlotReference) groupByExpr).getExprId(); + groupByExprIds.add(groupByExprId); + groupByExprIdToExpr.put(groupByExprId, groupByExpr); + } + } DistributionSpec parentDist = requestPropertyFromParent.getDistributionSpec(); if (parentDist instanceof DistributionSpecHash) { DistributionSpecHash distributionRequestFromParent = (DistributionSpecHash) parentDist; List parentHashExprIds = distributionRequestFromParent.getOrderedShuffledColumns(); - Set intersectIdSet = Sets.intersection(new HashSet<>(parentHashExprIds), - new HashSet<>(groupByExprIds)); - if (!intersectIdSet.isEmpty() && intersectIdSet.size() < groupByExprIds.size()) { - List intersectIdList = new ArrayList<>(); - for (ExprId exprId : parentHashExprIds) { - if (!intersectIdSet.contains(exprId)) { - continue; - } - intersectIdList.add(exprId); - } - if (shouldUseParent(intersectIdList, agg, context)) { - addRequestPropertyToChildren( - PhysicalProperties.createHash(intersectIdList, ShuffleType.REQUIRE)); + List parentHashExprIdsInGroupBy = new ArrayList<>(); + List parentHashExprsInGroupBy = new ArrayList<>(); + for (ExprId parentHashExprId : parentHashExprIds) { + Expression parentHashExpr = groupByExprIdToExpr.get(parentHashExprId); + if (parentHashExpr == null) { + continue; } + parentHashExprIdsInGroupBy.add(parentHashExprId); + parentHashExprsInGroupBy.add(parentHashExpr); + } + if (!parentHashExprIdsInGroupBy.isEmpty() + && parentHashExprIdsInGroupBy.size() < groupByExprIds.size() + && shouldUseParent(parentHashExprsInGroupBy, agg, context)) { + addRequestPropertyToChildren( + PhysicalProperties.createHash(parentHashExprIdsInGroupBy, ShuffleType.REQUIRE)); } } addRequestPropertyToChildren(PhysicalProperties.createHash(groupByExprIds, ShuffleType.REQUIRE)); @@ -547,35 +547,24 @@ public Void visitPhysicalBucketedHashAggregate( return null; } - private boolean shouldUseParent(List parentHashExprIds, PhysicalHashAggregate agg, + private boolean shouldUseParent(List parentHashExprs, PhysicalHashAggregate agg, PlanContext context) { if (!context.getConnectContext().getSessionVariable().aggShuffleUseParentKey) { return false; } Optional groupExpression = agg.getGroupExpression(); if (!groupExpression.isPresent()) { - return true; + return false; } if (agg.hasSourceRepeat()) { return false; } Statistics aggChildStats = groupExpression.get().childStatistics(0); if (aggChildStats == null) { - return true; - } - List aggChildOutput = agg.child().getOutput(); - Map exprIdSlotMap = new HashMap<>(); - for (Slot slot : aggChildOutput) { - exprIdSlotMap.put(slot.getExprId(), slot); - } - List parentHashExprs = new ArrayList<>(parentHashExprIds.size()); - for (ExprId exprId : parentHashExprIds) { - if (exprIdSlotMap.containsKey(exprId)) { - parentHashExprs.add(exprIdSlotMap.get(exprId)); - } + return false; } if (AggregateUtils.hasUnknownStatistics(parentHashExprs, aggChildStats)) { - return true; + return false; } double combinedNdv = StatsCalculator.estimateGroupByRowCount(parentHashExprs, aggChildStats); return combinedNdv > AggregateUtils.LOW_NDV_THRESHOLD; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java index 27d2c1f4ca19af..837b7339de6741 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java @@ -52,10 +52,14 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow; import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.AggregateUtils; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.SessionVariable; +import org.apache.doris.statistics.ColumnStatisticBuilder; +import org.apache.doris.statistics.Statistics; +import org.apache.doris.statistics.StatisticsBuilder; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -411,7 +415,7 @@ void testAggregateWithAggShuffleUseParentKeyDisabled() { } @Test - void testAggregateWithAggShuffleUseParentKeyEnabled() { + void testAggregateWithAggShuffleUseParentKeyEnabledAndUnknownStats() { // Create ConnectContext with aggShuffleUseParentKey = true (default value) ConnectContext testConnectContext = MemoTestUtils.createConnectContext(); testConnectContext.getSessionVariable().aggShuffleUseParentKey = true; @@ -446,14 +450,108 @@ public org.apache.doris.statistics.Statistics childStatistics(int idx) { List> actual = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); - // When aggShuffleUseParentKey is true, shouldUseParent may return true - // If shouldUseParent returns true, it will add parent key (key1) first, then all groupByExpressions (key1, key2) - Assertions.assertEquals(2, actual.size(), "Should have at least one property request"); + List> expected = Lists.newArrayList(); + expected.add(Lists.newArrayList(PhysicalProperties.createHash( + Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE))); + Assertions.assertEquals(expected, actual); + } + + @Test + void testAggregateWithAggShuffleUseParentKeyEnabledAndLowNdvStats() { + ConnectContext testConnectContext = MemoTestUtils.createConnectContext(); + testConnectContext.getSessionVariable().aggShuffleUseParentKey = true; + testConnectContext.getSessionVariable().setBeNumberForTest(3); + + SlotReference key1 = new SlotReference(new ExprId(0), "col1", IntegerType.INSTANCE, true, ImmutableList.of()); + SlotReference key2 = new SlotReference(new ExprId(1), "col2", IntegerType.INSTANCE, true, ImmutableList.of()); + GroupPlan childPlan = new GroupPlan(new Group(GroupId.createGenerator().getNextId(), + new GroupExpression(new LogicalOneRowRelation(new RelationId(6), ImmutableList.of(key1, key2))) + .getPlan().getLogicalProperties())); + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( + Lists.newArrayList(key1, key2), + Lists.newArrayList(key1, key2), + new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT), + true, + logicalProperties, + false, + childPlan + ); + Statistics childStats = new StatisticsBuilder() + .setRowCount(10000) + .putColumnStatistics(key1, + new ColumnStatisticBuilder(10000).setNdv(AggregateUtils.LOW_NDV_THRESHOLD).build()) + .build(); + GroupExpression groupExpression = new GroupExpression(aggregate) { + @Override + public Statistics childStatistics(int idx) { + return childStats; + } + }; + new Group(null, groupExpression, null); + + PhysicalProperties parentProperties = PhysicalProperties.createHash( + Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE); + + Mockito.when(jobContext.getRequiredProperties()).thenReturn(parentProperties); + + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(testConnectContext, jobContext); + List> actual + = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); + + List> expected = Lists.newArrayList(); + expected.add(Lists.newArrayList(PhysicalProperties.createHash( + Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE))); + Assertions.assertEquals(expected, actual); + } + + @Test + void testAggregateWithAggShuffleUseParentKeyEnabledAndHighNdvStats() { + ConnectContext testConnectContext = MemoTestUtils.createConnectContext(); + testConnectContext.getSessionVariable().aggShuffleUseParentKey = true; + testConnectContext.getSessionVariable().setBeNumberForTest(3); + + SlotReference key1 = new SlotReference(new ExprId(0), "col1", IntegerType.INSTANCE, true, ImmutableList.of()); + SlotReference key2 = new SlotReference(new ExprId(1), "col2", IntegerType.INSTANCE, true, ImmutableList.of()); + GroupPlan childPlan = new GroupPlan(new Group(GroupId.createGenerator().getNextId(), + new GroupExpression(new LogicalOneRowRelation(new RelationId(6), ImmutableList.of(key1, key2))) + .getPlan().getLogicalProperties())); + PhysicalHashAggregate aggregate = new PhysicalHashAggregate<>( + Lists.newArrayList(key1, key2), + Lists.newArrayList(key1, key2), + new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT), + true, + logicalProperties, + false, + childPlan + ); + Statistics childStats = new StatisticsBuilder() + .setRowCount(10000) + .putColumnStatistics(key1, new ColumnStatisticBuilder(10000).setNdv(2000).build()) + .build(); + GroupExpression groupExpression = new GroupExpression(aggregate) { + @Override + public Statistics childStatistics(int idx) { + return childStats; + } + }; + new Group(null, groupExpression, null); + + PhysicalProperties parentProperties = PhysicalProperties.createHash( + Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE); + + Mockito.when(jobContext.getRequiredProperties()).thenReturn(parentProperties); + + RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(testConnectContext, jobContext); + List> actual + = requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression); + PhysicalProperties parentProp = PhysicalProperties.createHash( Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE); PhysicalProperties aggProp = PhysicalProperties.createHash( Lists.newArrayList(key1.getExprId(), key2.getExprId()), ShuffleType.REQUIRE); - Assertions.assertTrue(actual.contains(ImmutableList.of(aggProp)) && actual.contains(ImmutableList.of(parentProp))); + Assertions.assertEquals(2, actual.size()); + Assertions.assertTrue(actual.contains(ImmutableList.of(parentProp))); + Assertions.assertTrue(actual.contains(ImmutableList.of(aggProp))); } @Test From 2de423e6fa25661ce9af7e98830e7c9f94268d0f Mon Sep 17 00:00:00 2001 From: yangtao555 Date: Fri, 26 Jun 2026 17:55:27 +0800 Subject: [PATCH 2/2] Fix test --- .../rules/rewrite/SplitMultiDistinctTest.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java index 94c60586feec96..dc71bfe6ad24e1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SplitMultiDistinctTest.java @@ -184,12 +184,16 @@ void countMultiColumnsWithGby() { physicalHashJoin( physicalProject( physicalHashAggregate( - physicalHashAggregate( - physicalDistribute(any())))), + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute(any())))))), physicalProject( physicalHashAggregate( - physicalHashAggregate( - physicalDistribute(any())))) + physicalDistribute( + physicalHashAggregate( + physicalHashAggregate( + physicalDistribute(any())))))) ).when(join -> join.getJoinType() == JoinType.INNER_JOIN && join.getHashJoinConjuncts().get(0) instanceof NullSafeEqual )