Skip to content

Commit de5feb5

Browse files
authored
[fix]: fix the thread leak when suspend (#280)
* fix thread pool leak * fix the race condition in concurrency operation * prevent Self-suppression not permitted
1 parent d1cd716 commit de5feb5

File tree

12 files changed

+135
-57
lines changed

12 files changed

+135
-57
lines changed

examples/src/main/java/software/amazon/lambda/durable/examples/callback/WaitForCallbackFailedExample.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import software.amazon.lambda.durable.config.WaitForCallbackConfig;
1010
import software.amazon.lambda.durable.examples.types.ApprovalRequest;
1111
import software.amazon.lambda.durable.exception.SerDesException;
12+
import software.amazon.lambda.durable.execution.SuspendExecutionException;
1213
import software.amazon.lambda.durable.serde.JacksonSerDes;
1314

1415
public class WaitForCallbackFailedExample extends DurableHandler<ApprovalRequest, String> {
@@ -31,6 +32,9 @@ public String handleRequest(ApprovalRequest input, DurableContext context) {
3132
.serDes(new FailedSerDes())
3233
.build())
3334
.build());
35+
} catch (SuspendExecutionException e) {
36+
// not to swallow the SuspendExecutionException
37+
throw e;
3438
} catch (Exception ex) {
3539
return ex.getClass().getSimpleName() + ":" + ex.getMessage();
3640
}

examples/src/main/java/software/amazon/lambda/durable/examples/parallel/DeserializationFailedParallelExample.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import software.amazon.lambda.durable.config.ParallelBranchConfig;
1111
import software.amazon.lambda.durable.config.ParallelConfig;
1212
import software.amazon.lambda.durable.exception.SerDesException;
13+
import software.amazon.lambda.durable.execution.SuspendExecutionException;
1314
import software.amazon.lambda.durable.serde.JacksonSerDes;
1415

1516
/**
@@ -55,6 +56,8 @@ public String handleRequest(Input input, DurableContext context) {
5556
parallel.get();
5657
try {
5758
return future.get();
59+
} catch (SuspendExecutionException e) {
60+
throw e;
5861
} catch (Exception e) {
5962
return e.getMessage();
6063
}

examples/src/main/java/software/amazon/lambda/durable/examples/step/DeserializationFailureExample.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import software.amazon.lambda.durable.TypeToken;
99
import software.amazon.lambda.durable.config.StepConfig;
1010
import software.amazon.lambda.durable.exception.SerDesException;
11+
import software.amazon.lambda.durable.execution.SuspendExecutionException;
1112
import software.amazon.lambda.durable.serde.JacksonSerDes;
1213

1314
public class DeserializationFailureExample extends DurableHandler<String, String> {
@@ -22,6 +23,8 @@ public String handleRequest(String input, DurableContext context) {
2223
throw new RuntimeException("this is a test");
2324
},
2425
StepConfig.builder().serDes(new FailedSerDes()).build());
26+
} catch (SuspendExecutionException e) {
27+
throw e;
2528
} catch (Exception e) {
2629
context.wait("suspend and replay", Duration.ofSeconds(1));
2730
return e.getClass().getSimpleName() + ":" + e.getMessage();

sdk-integration-tests/src/test/java/software/amazon/lambda/durable/ParallelIntegrationTest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus;
1515
import software.amazon.lambda.durable.model.ExecutionStatus;
1616
import software.amazon.lambda.durable.testing.LocalDurableTestRunner;
17+
import software.amazon.lambda.durable.testing.TestOperation;
1718

1819
class ParallelIntegrationTest {
1920

@@ -598,7 +599,14 @@ void testParallelWithMinSuccessful_earlyTermination() {
598599
});
599600

600601
var result = runner.runUntilComplete("test");
601-
assertEquals(ExecutionStatus.SUCCEEDED, result.getStatus());
602+
assertEquals(
603+
ExecutionStatus.SUCCEEDED,
604+
result.getStatus(),
605+
String.join(
606+
" ",
607+
result.getOperations().stream()
608+
.map(TestOperation::toString)
609+
.toList()));
602610
}
603611

604612
@Test

sdk/src/main/java/software/amazon/lambda/durable/DurableConfig.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,20 @@ public final class DurableConfig {
7474
private static final String PROJECT_VERSION = getProjectVersion(VERSION_FILE);
7575
private static final String USER_AGENT_SUFFIX = "@aws/durable-execution-sdk-java/" + PROJECT_VERSION;
7676

77+
/**
78+
* A default ExecutorService for running user-defined operations. Uses a cached thread pool with daemon threads by
79+
* default.
80+
*
81+
* <p>This executor is used exclusively for user operations. Internal SDK coordination uses the
82+
* InternalExecutor::INSTANCE
83+
*/
84+
private static final ExecutorService DEFAULT_USER_THREAD_POOL = Executors.newCachedThreadPool(r -> {
85+
Thread t = new Thread(r);
86+
t.setName("durable-exec-" + t.getId());
87+
t.setDaemon(true);
88+
return t;
89+
});
90+
7791
private final DurableExecutionClient durableExecutionClient;
7892
private final SerDes serDes;
7993
private final ExecutorService executorService;
@@ -250,12 +264,7 @@ private static String getProjectVersion(String versionFile) {
250264
*/
251265
private static ExecutorService createDefaultExecutor() {
252266
logger.debug("Creating default ExecutorService");
253-
return Executors.newCachedThreadPool(r -> {
254-
Thread t = new Thread(r);
255-
t.setName("durable-exec-" + t.getId());
256-
t.setDaemon(true);
257-
return t;
258-
});
267+
return DEFAULT_USER_THREAD_POOL;
259268
}
260269

261270
/** Builder for DurableConfig. Provides fluent API for configuring SDK components. */

sdk/src/main/java/software/amazon/lambda/durable/execution/ExecutionManager.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,17 +280,23 @@ public static boolean isTerminalStatus(OperationStatus status) {
280280
* @param exception the unrecoverable exception that caused termination
281281
*/
282282
public void terminateExecution(UnrecoverableDurableExecutionException exception) {
283+
stopAllOperations(exception);
283284
executionExceptionFuture.completeExceptionally(exception);
284285
throw exception;
285286
}
286287

287288
/** Suspends the execution by completing the execution exception future with a {@link SuspendExecutionException}. */
288289
public void suspendExecution() {
289290
var ex = new SuspendExecutionException();
291+
stopAllOperations(ex);
290292
executionExceptionFuture.completeExceptionally(ex);
291293
throw ex;
292294
}
293295

296+
private void stopAllOperations(Exception cause) {
297+
registeredOperations.values().forEach(op -> op.getCompletionFuture().completeExceptionally(cause));
298+
}
299+
294300
/**
295301
* return a future that completes when userFuture completes successfully or the execution is terminated or
296302
* suspended.

sdk/src/main/java/software/amazon/lambda/durable/operation/BaseDurableOperation.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import software.amazon.lambda.durable.execution.ThreadType;
2323
import software.amazon.lambda.durable.model.OperationIdentifier;
2424
import software.amazon.lambda.durable.model.OperationSubType;
25+
import software.amazon.lambda.durable.util.ExceptionHelper;
2526

2627
/**
2728
* Base class for all durable operations (STEP, WAIT, etc.).
@@ -187,7 +188,7 @@ protected Operation waitForOperationCompletion() {
187188
// is between `isOperationCompleted` and `thenRun`.
188189
// If this operation is a branch/iteration of a ConcurrencyOperation (map or parallel), the branches/iterations
189190
// must be completed sequentially to avoid race conditions.
190-
synchronized (parentOperation == null ? completionFuture : parentOperation) {
191+
synchronized (parentOperation == null ? completionFuture : parentOperation.completionFuture) {
191192
if (!isOperationCompleted()) {
192193
// Operation not done yet
193194
logger.trace(
@@ -208,7 +209,11 @@ protected Operation waitForOperationCompletion() {
208209
}
209210

210211
// Block until operation completes. No-op if the future is already completed.
211-
completionFuture.join();
212+
try {
213+
completionFuture.join();
214+
} catch (Throwable throwable) {
215+
ExceptionHelper.sneakyThrow(ExceptionHelper.unwrapCompletableFuture(throwable));
216+
}
212217

213218
// Get result based on status
214219
var op = getOperation();
@@ -290,7 +295,7 @@ protected void markAlreadyCompleted() {
290295
private void markCompletionFutureCompleted() {
291296
// It's important that we synchronize access to the future, otherwise the processing could happen
292297
// on someone else's thread and cause a race condition.
293-
synchronized (parentOperation == null ? completionFuture : parentOperation) {
298+
synchronized (parentOperation == null ? completionFuture : parentOperation.completionFuture) {
294299
// Completing the future here will also run any other completion stages that have been attached
295300
// to the future. In our case, other contexts may have attached a function to reactivate themselves,
296301
// so they will definitely have a chance to reactivate before we finish completing and deactivating

sdk/src/main/java/software/amazon/lambda/durable/operation/ConcurrencyOperation.java

Lines changed: 67 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@
2121
import software.amazon.lambda.durable.TypeToken;
2222
import software.amazon.lambda.durable.config.RunInChildContextConfig;
2323
import software.amazon.lambda.durable.context.DurableContextImpl;
24+
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
2425
import software.amazon.lambda.durable.execution.OperationIdGenerator;
26+
import software.amazon.lambda.durable.execution.SuspendExecutionException;
2527
import software.amazon.lambda.durable.execution.ThreadType;
2628
import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus;
2729
import software.amazon.lambda.durable.model.OperationIdentifier;
2830
import software.amazon.lambda.durable.model.OperationSubType;
2931
import software.amazon.lambda.durable.serde.SerDes;
32+
import software.amazon.lambda.durable.util.ExceptionHelper;
3033

3134
/**
3235
* Abstract base class for concurrent execution of multiple child context operations.
@@ -143,7 +146,7 @@ protected <R> ChildContextOperation<R> enqueueItem(
143146
}
144147

145148
private void notifyConsumerThread() {
146-
synchronized (this) {
149+
synchronized (completionFuture) {
147150
consumerThreadListener.get().complete(null);
148151
}
149152
}
@@ -156,61 +159,80 @@ protected void executeItems() {
156159
AtomicInteger failedCount = new AtomicInteger(0);
157160

158161
Runnable consumer = () -> {
159-
while (true) {
160-
// Set a new future if it's completed so that it will be able to receive a notification of
161-
// new items when the thread is checking completion condition and processing
162-
// the queued items below.
163-
synchronized (this) {
164-
if (consumerThreadListener.get() != null
165-
&& consumerThreadListener.get().isDone()) {
166-
consumerThreadListener.set(new CompletableFuture<>());
162+
try {
163+
while (true) {
164+
// Set a new future if it's completed so that it will be able to receive a notification of
165+
// new items when the thread is checking completion condition and processing
166+
// the queued items below.
167+
synchronized (completionFuture) {
168+
if (consumerThreadListener.get() != null
169+
&& consumerThreadListener.get().isDone()) {
170+
consumerThreadListener.set(new CompletableFuture<>());
171+
}
167172
}
168-
}
169173

170-
// Process completion condition. Quit the loop if the condition is met.
171-
if (isOperationCompleted()) {
172-
return;
173-
}
174-
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
175-
if (completionStatus != null) {
176-
handleCompletion(completionStatus);
177-
return;
178-
}
174+
// Process completion condition. Quit the loop if the condition is met.
175+
if (isOperationCompleted()) {
176+
return;
177+
}
178+
var completionStatus = canComplete(succeededCount, failedCount, runningChildren);
179+
if (completionStatus != null) {
180+
handleCompletion(completionStatus);
181+
return;
182+
}
179183

180-
// process new items in the queue
181-
while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) {
182-
var next = pendingQueue.poll();
183-
runningChildren.add(next);
184-
logger.debug("Executing operation {}", next.getName());
185-
next.execute();
186-
}
184+
// process new items in the queue
185+
while (runningChildren.size() < maxConcurrency && !pendingQueue.isEmpty()) {
186+
var next = pendingQueue.poll();
187+
runningChildren.add(next);
188+
logger.debug("Executing operation {}", next.getName());
189+
next.execute();
190+
}
187191

188-
// If consumerThreadListener has been completed when processing above, waitForChildCompletion will
189-
// immediately return null and repeat the above again
190-
var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
191-
192-
// child may be null if the consumer thread is woken up due to new items added or completion condition
193-
// changed
194-
if (child != null) {
195-
if (runningChildren.contains(child)) {
196-
runningChildren.remove(child);
197-
onItemComplete(succeededCount, failedCount, (ChildContextOperation<?>) child);
198-
} else {
199-
throw new IllegalStateException("Unexpected completion: " + child);
192+
// If consumerThreadListener has been completed when processing above, waitForChildCompletion will
193+
// immediately return null and repeat the above again
194+
var child = waitForChildCompletion(succeededCount, failedCount, runningChildren);
195+
196+
// child may be null if the consumer thread is woken up due to new items added or completion
197+
// condition
198+
// changed
199+
if (child != null) {
200+
if (runningChildren.contains(child)) {
201+
runningChildren.remove(child);
202+
onItemComplete(succeededCount, failedCount, (ChildContextOperation<?>) child);
203+
} else {
204+
throw new IllegalStateException("Unexpected completion: " + child);
205+
}
200206
}
201207
}
208+
} catch (Throwable ex) {
209+
handleException(ex);
202210
}
203211
};
204212
// run consumer in the user thread pool, although it's not a real user thread
205213
runUserHandler(consumer, getOperationId(), ThreadType.CONTEXT);
206214
}
207215

216+
private void handleException(Throwable ex) {
217+
Throwable throwable = ExceptionHelper.unwrapCompletableFuture(ex);
218+
if (throwable instanceof SuspendExecutionException suspendExecutionException) {
219+
// Rethrow Error immediately — do not checkpoint
220+
throw suspendExecutionException;
221+
}
222+
if (throwable instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) {
223+
throw terminateExecution(unrecoverableDurableExecutionException);
224+
}
225+
226+
throw terminateExecutionWithIllegalDurableOperationException(
227+
String.format("Unexpected exception in concurrency operation: %s", throwable));
228+
}
229+
208230
private BaseDurableOperation waitForChildCompletion(
209231
AtomicInteger succeededCount, AtomicInteger failedCount, Set<BaseDurableOperation> runningChildren) {
210232
var threadContext = getCurrentThreadContext();
211233
CompletableFuture<Object> future;
212234

213-
synchronized (this) {
235+
synchronized (completionFuture) {
214236
// check again in synchronized block to prevent race conditions
215237
if (isOperationCompleted()) {
216238
return null;
@@ -238,7 +260,12 @@ private BaseDurableOperation waitForChildCompletion(
238260
executionManager.deregisterActiveThread(threadContext.threadId());
239261
}
240262
}
241-
return future.thenApply(o -> (BaseDurableOperation) o).join();
263+
try {
264+
return future.thenApply(o -> (BaseDurableOperation) o).join();
265+
} catch (Throwable throwable) {
266+
ExceptionHelper.sneakyThrow(ExceptionHelper.unwrapCompletableFuture(throwable));
267+
throw throwable;
268+
}
242269
}
243270

244271
/**

sdk/src/main/java/software/amazon/lambda/durable/operation/MapOperation.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
import software.amazon.lambda.durable.config.CompletionConfig;
1515
import software.amazon.lambda.durable.config.MapConfig;
1616
import software.amazon.lambda.durable.context.DurableContextImpl;
17+
import software.amazon.lambda.durable.exception.UnrecoverableDurableExecutionException;
18+
import software.amazon.lambda.durable.execution.SuspendExecutionException;
1719
import software.amazon.lambda.durable.model.ConcurrencyCompletionStatus;
1820
import software.amazon.lambda.durable.model.MapResult;
1921
import software.amazon.lambda.durable.model.OperationIdentifier;
2022
import software.amazon.lambda.durable.model.OperationSubType;
2123
import software.amazon.lambda.durable.serde.SerDes;
24+
import software.amazon.lambda.durable.util.ExceptionHelper;
2225

2326
/**
2427
* Executes a map operation: applies a function to each item in a collection concurrently, with each item running in its
@@ -153,8 +156,18 @@ protected void handleCompletion(ConcurrencyCompletionStatus concurrencyCompletio
153156
} else {
154157
try {
155158
resultItems.set(i, MapResult.MapResultItem.succeeded(branch.get()));
156-
} catch (Exception e) {
157-
resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(e)));
159+
} catch (Throwable exception) {
160+
Throwable throwable = ExceptionHelper.unwrapCompletableFuture(exception);
161+
if (throwable instanceof SuspendExecutionException suspendExecutionException) {
162+
// Rethrow Error immediately — do not checkpoint
163+
throw suspendExecutionException;
164+
}
165+
if (throwable
166+
instanceof UnrecoverableDurableExecutionException unrecoverableDurableExecutionException) {
167+
// terminate the execution and throw the exception if it's not recoverable
168+
throw terminateExecution(unrecoverableDurableExecutionException);
169+
}
170+
resultItems.set(i, MapResult.MapResultItem.failed(MapResult.MapError.of(throwable)));
158171
}
159172
}
160173
}

sdk/src/main/java/software/amazon/lambda/durable/operation/ParallelOperation.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ public ParallelResult get() {
113113
/** Calls {@link #get()} if not already called. Guarantees that the context is closed. */
114114
@Override
115115
public void close() {
116+
if (isJoined.get()) {
117+
return;
118+
}
116119
join();
117120
}
118121

0 commit comments

Comments
 (0)