Skip to content

Commit 1284d59

Browse files
committed
improve the ThreadLocal approach
1 parent 0d61467 commit 1284d59

File tree

2 files changed

+162
-81
lines changed

2 files changed

+162
-81
lines changed

driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackLoop.java

Lines changed: 73 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121

2222
import java.util.function.Supplier;
2323

24-
import static com.mongodb.assertions.Assertions.assertTrue;
25-
2624
/**
2725
* A decorator that implements automatic repeating of an {@link AsyncCallbackRunnable}.
2826
* {@link AsyncCallbackLoop} may execute the original asynchronous function multiple times sequentially,
@@ -53,66 +51,103 @@ public AsyncCallbackLoop(final LoopState state, final AsyncCallbackRunnable body
5351

5452
@Override
5553
public void run(final SingleResultCallback<Void> callback) {
56-
body.initiateIteration(false, new ReusableLoopCallback(callback));
54+
body.run(false, new ReusableLoopCallback(callback));
5755
}
5856

5957
private static final class Body {
60-
private final AsyncCallbackRunnable wrapped;
58+
private final AsyncCallbackRunnable body;
6159
private final LoopState state;
62-
private final ThreadLocal<Boolean> iterationIsExecutingSynchronously;
63-
private final ThreadLocal<Status> status;
60+
private final ThreadLocal<SameThreadDetectionStatus> sameThreadDetector;
6461

65-
private enum Status {
66-
ITERATION_INITIATED,
67-
LAST_ITERATION_COMPLETED,
68-
ANOTHER_ITERATION_NEEDED
62+
private enum SameThreadDetectionStatus {
63+
NEGATIVE,
64+
PROBING,
65+
POSITIVE
6966
}
7067

7168
private Body(final LoopState state, final AsyncCallbackRunnable body) {
72-
this.wrapped = body;
69+
this.body = body;
7370
this.state = state;
74-
iterationIsExecutingSynchronously = ThreadLocal.withInitial(() -> false);
75-
status = ThreadLocal.withInitial(() -> Status.ITERATION_INITIATED);
71+
sameThreadDetector = ThreadLocal.withInitial(() -> SameThreadDetectionStatus.NEGATIVE);
7672
}
7773

7874
/**
79-
* Invoking this method initiates a new iteration of the loop. An iteration may be executed either
80-
* synchronously or asynchronously with the execution of this method:
81-
*
75+
* Initiates a new iteration of the loop by invoking
76+
* {@link #body}{@code .}{@link AsyncCallbackRunnable#run(SingleResultCallback) run}.
77+
* The initiated iteration may be executed either synchronously or asynchronously with the method that initiated it:
8278
* <ul>
83-
* <li>synchronous execution: iteration completes before (in the happens-before order) the method completes;</li>
84-
* <li>asynchronous execution: the aforementioned relation does not exist.</li>
79+
* <li>synchronous execution—completion of the initiated iteration happens-before the method completion;</li>
80+
* <li>asynchronous executionthe aforementioned relation does not exist.</li>
8581
* </ul>
8682
*
83+
* <p>If another iteration is needed, it is initiated from the callback passed to
84+
* {@link #body}{@code .}{@link AsyncCallbackRunnable#run(SingleResultCallback) run}
85+
* by invoking {@link #run(boolean, ReusableLoopCallback)}.
86+
* Completing the initiated iteration is {@linkplain SingleResultCallback#onResult(Object, Throwable) invoking} the callback.
87+
* Thus, it is guaranteed that all iterations are executed sequentially with each other
88+
* (that is, completion of one iteration happens-before initiation of the next one)
89+
* regardless of them being executed synchronously or asynchronously with the method that initiated them.
90+
*
91+
* <p>Initiating any but the {@linkplain LoopState#isFirstIteration() first} iteration is done using trampolining,
92+
* which allows us to do it iteratively rather than recursively, if iterations are executed synchronously,
93+
* and ensures stack usage does not increase with the number of iterations.
94+
*
8795
* @return {@code true} iff it is known that another iteration must be initiated.
88-
* Such information is available to this method only if the iteration it initiated has completed synchronously.
96+
* This information is used only for trampolining, and is available only if the iteration executed synchronously.
97+
*
98+
* <p>It is impossible to detect whether an iteration is executed synchronously.
99+
* It is, however, possible to detect whether an iteration is executed in the same thread as the method that initiated it,
100+
* and we use it as proxy indicator of synchronous execution. Unfortunately, this means we do not support and behave incorrectly
101+
* if an iteration is executed synchronously but in a thread different from the one in which the method that
102+
* initiated the iteration was invoked.
103+
*
104+
* <p>The above limitation should not be a problem in practice:
105+
* <ul>
106+
* <li>the only way to execute an iteration synchronously but in a different thread is to block the thread that
107+
* initiated the iteration by waiting for completion of the iteration by that other thread;</li>
108+
* <li>blocking a thread is forbidden in asynchronous code, and we do not do it;</li>
109+
* <li>therefore, we would not have an iteration that is executed synchronously but in a different thread.</li>
110+
* </ul>
89111
*/
90-
Status initiateIteration(final boolean trampolining, final ReusableLoopCallback callback) {
91-
iterationIsExecutingSynchronously.set(true);
92-
wrapped.run((r, t) -> {
93-
boolean localIterationIsExecutingSynchronously = iterationIsExecutingSynchronously.get();
112+
boolean run(final boolean trampolining, final ReusableLoopCallback callback) {
113+
// The `trampoliningResult` variable must be used only if the initiated iteration is executed synchronously with
114+
// the current method, which must be detected separately.
115+
//
116+
// It may be tempting to detect whether the iteration was executed synchronously by reading from the variable
117+
// and observing a write that is part of the callback execution. However, if the iteration is executed asynchronously with
118+
// the current method, then the aforementioned conflicting write and read actions are not ordered by
119+
// the happens-before relation, the execution contains a data race and the read is allowed to observe the write.
120+
// If such observation happens when the iteration is executed asynchronously, then we have a false positive.
121+
// Furthermore, depending on the nature of the value read, it may not be trustworthy.
122+
boolean[] trampoliningResult = {false};
123+
sameThreadDetector.set(SameThreadDetectionStatus.PROBING);
124+
body.run((r, t) -> {
94125
if (callback.onResult(state, r, t)) {
95-
status.set(Status.LAST_ITERATION_COMPLETED);
126+
// If we are trampolining, then here we bounce up, trampolining completes and so is the whole loop;
127+
// otherwise, the whole loop simply completes.
96128
return;
97129
}
98-
if (trampolining && localIterationIsExecutingSynchronously) {
99-
// bounce
100-
status.set(Status.ANOTHER_ITERATION_NEEDED);
101-
return;
130+
if (trampolining) {
131+
boolean sameThread = sameThreadDetector.get().equals(SameThreadDetectionStatus.PROBING);
132+
if (sameThread) {
133+
// Bounce up if we are trampolining and the iteration was executed synchronously;
134+
// otherwise proceed to begin trampolining.
135+
sameThreadDetector.set(SameThreadDetectionStatus.POSITIVE);
136+
trampoliningResult[0] = true;
137+
return;
138+
} else {
139+
sameThreadDetector.remove();
140+
}
102141
}
103-
Status localStatus;
104-
do {
105-
localStatus = initiateIteration(true, callback);
106-
} while (localStatus.equals(Status.ANOTHER_ITERATION_NEEDED));
107-
status.set(localStatus);
108-
109-
// VAKOTODO remove thread-locals if executed asynchronously
142+
boolean anotherIterationNeeded;
143+
do { // trampolining
144+
anotherIterationNeeded = run(true, callback);
145+
} while (anotherIterationNeeded);
110146
});
111147
try {
112-
return status.get();
148+
return sameThreadDetector.get().equals(SameThreadDetectionStatus.POSITIVE) && trampoliningResult[0];
113149
} finally {
114-
status.remove();
115-
iterationIsExecutingSynchronously.remove();
150+
sameThreadDetector.remove();
116151
}
117152
}
118153
}

driver-core/src/test/unit/com/mongodb/internal/async/VakoTest.java

Lines changed: 89 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import com.mongodb.internal.async.function.LoopState;
2020
import com.mongodb.internal.time.StartTime;
2121
import com.mongodb.lang.Nullable;
22+
import org.junit.jupiter.api.AfterAll;
23+
import org.junit.jupiter.api.BeforeAll;
2224
import org.junit.jupiter.params.ParameterizedTest;
2325
import org.junit.jupiter.params.provider.CsvSource;
2426

@@ -28,18 +30,32 @@
2830
import java.time.Duration;
2931
import java.util.Objects;
3032
import java.util.concurrent.CompletableFuture;
31-
import java.util.concurrent.ForkJoinPool;
33+
import java.util.concurrent.Executors;
34+
import java.util.concurrent.ScheduledExecutorService;
3235
import java.util.concurrent.TimeUnit;
3336

3437
import static com.mongodb.internal.async.AsyncRunnable.beginAsync;
3538

3639
class VakoTest {
40+
private static ScheduledExecutorService executor;
41+
42+
@BeforeAll
43+
static void beforeAll() {
44+
executor = Executors.newScheduledThreadPool(2);
45+
}
46+
47+
@AfterAll
48+
static void afterAll() throws InterruptedException {
49+
executor.shutdownNow();
50+
com.mongodb.assertions.Assertions.assertTrue(executor.awaitTermination(1, TimeUnit.MINUTES));
51+
}
52+
3753
@ParameterizedTest
3854
@CsvSource({
3955
"10"
4056
})
4157
void asyncCallbackLoop(final int iterations) throws Exception {
42-
System.err.printf("baselineStackDepth=%d%n", Thread.currentThread().getStackTrace().length);
58+
System.err.printf("baselineStackDepth=%d%n%n", Thread.currentThread().getStackTrace().length);
4359
CompletableFuture<Void> join = new CompletableFuture<>();
4460
LoopState loopState = new LoopState();
4561
new AsyncCallbackLoop(loopState, c -> {
@@ -51,78 +67,92 @@ void asyncCallbackLoop(final int iterations) throws Exception {
5167
}).run((r, t) -> {
5268
System.err.printf("test callback completed callStackDepth=%d, r=%s, t=%s%n",
5369
Thread.currentThread().getStackTrace().length, r, exceptionToString(t));
54-
if (t != null) {
55-
join.completeExceptionally(t);
56-
} else {
57-
join.complete(r);
58-
}
70+
complete(join, r, t);
5971
});
6072
join.get();
6173
System.err.printf("%n%nDONE%n%n");
6274
}
6375

76+
private enum IterationExecutionType {
77+
SYNC_SAME_THREAD,
78+
SYNC_DIFFERENT_THREAD,
79+
ASYNC,
80+
}
81+
6482
@ParameterizedTest()
6583
@CsvSource({
66-
"0, false, 0, 10",
67-
"0, true, 4, 10",
68-
"4, true, 0, 10"
84+
"10, 0, SYNC_SAME_THREAD, 0",
85+
// "10, 0, SYNC_DIFFERENT_THREAD, 0",
86+
"10, 0, ASYNC, 4",
87+
"10, 4, ASYNC, 0"
6988
})
7089
void testThenRunDoWhileLoop(
71-
final int blockInAsyncMethodTotalSeconds,
72-
final boolean asyncExecution,
73-
final int delayAsyncExecutionTotalSeconds,
74-
final int counterInitialValue) throws Exception {
75-
Duration blockInAsyncMethodTotalDuration = Duration.ofSeconds(blockInAsyncMethodTotalSeconds);
76-
com.mongodb.assertions.Assertions.assertTrue(asyncExecution || delayAsyncExecutionTotalSeconds == 0);
90+
final int counterInitialValue,
91+
final int blockSyncPartOfIterationTotalSeconds,
92+
final IterationExecutionType executionType,
93+
final int delayAsyncExecutionTotalSeconds) throws Exception {
94+
System.err.printf("baselineStackDepth=%d%n%n", Thread.currentThread().getStackTrace().length);
95+
Duration blockSyncPartOfIterationTotalDuration = Duration.ofSeconds(blockSyncPartOfIterationTotalSeconds);
96+
com.mongodb.assertions.Assertions.assertTrue(
97+
executionType.equals(IterationExecutionType.ASYNC) || delayAsyncExecutionTotalSeconds == 0);
7798
Duration delayAsyncExecutionTotalDuration = Duration.ofSeconds(delayAsyncExecutionTotalSeconds);
7899
StartTime start = StartTime.now();
79-
System.err.printf("baselineStackDepth=%d%n", Thread.currentThread().getStackTrace().length);
80100
CompletableFuture<Void> join = new CompletableFuture<>();
81-
asyncMethod1(blockInAsyncMethodTotalDuration, asyncExecution, delayAsyncExecutionTotalDuration, new Counter(counterInitialValue),
101+
asyncLoop(new Counter(counterInitialValue), blockSyncPartOfIterationTotalDuration, executionType, delayAsyncExecutionTotalDuration,
82102
(r, t) -> {
83-
System.err.printf("TEST callback completed callStackDepth=%s, r=%s, t=%s%n",
103+
System.err.printf("test callback completed callStackDepth=%s, r=%s, t=%s%n",
84104
Thread.currentThread().getStackTrace().length, r, exceptionToString(t));
85-
if (t != null) {
86-
join.completeExceptionally(t);
87-
} else {
88-
join.complete(r);
89-
}
105+
complete(join, r, t);
90106
});
91-
System.err.printf("asyncMethod1 returned in %s%n", start.elapsed());
107+
System.err.printf("\tasyncLoop returned in %s%n", start.elapsed());
92108
join.get();
93109
System.err.printf("%n%nDONE%n%n");
94110
}
95111

96-
private static void asyncMethod1(
97-
final Duration blockInAsyncMethodTotalDuration,
98-
final boolean asyncExecution,
99-
final Duration delayAsyncExecutionTotalDuration,
112+
private static void asyncLoop(
100113
final Counter counter,
114+
final Duration blockSyncPartOfIterationTotalDuration,
115+
final IterationExecutionType executionType,
116+
final Duration delayAsyncExecutionTotalDuration,
101117
final SingleResultCallback<Void> callback) {
102118
beginAsync().thenRunDoWhileLoop(c -> {
103-
sleep(blockInAsyncMethodTotalDuration.dividedBy(counter.initial()));
119+
sleep(blockSyncPartOfIterationTotalDuration.dividedBy(counter.initial()));
104120
StartTime start = StartTime.now();
105-
asyncMethod2(asyncExecution, delayAsyncExecutionTotalDuration, counter, c);
106-
System.err.printf("asyncMethod2 returned in %s%n", start.elapsed());
121+
asyncPartOfIteration(counter, executionType, delayAsyncExecutionTotalDuration, c);
122+
System.err.printf("\tasyncPartOfIteration returned in %s%n", start.elapsed());
107123
}, () -> !counter.done()).finish(callback);
108124
}
109125

110-
private static void asyncMethod2(
111-
final boolean asyncExecution,
112-
final Duration delayAsyncExecutionTotalDuration,
126+
private static void asyncPartOfIteration(
113127
final Counter counter,
128+
final IterationExecutionType executionType,
129+
final Duration delayAsyncExecutionTotalDuration,
114130
final SingleResultCallback<Void> callback) {
115-
Runnable action = () -> {
116-
sleep(delayAsyncExecutionTotalDuration.dividedBy(counter.initial()));
131+
Runnable asyncPartOfIteration = () -> {
117132
counter.countDown();
118133
StartTime start = StartTime.now();
119134
callback.complete(callback);
120-
System.err.printf("asyncMethod2 callback.complete returned in %s%n", start.elapsed());
135+
System.err.printf("\tasyncPartOfIteration callback.complete returned in %s%n", start.elapsed());
121136
};
122-
if (asyncExecution) {
123-
ForkJoinPool.commonPool().execute(action);
124-
} else {
125-
action.run();
137+
switch (executionType) {
138+
case SYNC_SAME_THREAD: {
139+
asyncPartOfIteration.run();
140+
break;
141+
}
142+
case SYNC_DIFFERENT_THREAD: {
143+
Thread guaranteedDifferentThread = new Thread(asyncPartOfIteration);
144+
guaranteedDifferentThread.start();
145+
join(guaranteedDifferentThread);
146+
break;
147+
}
148+
case ASYNC: {
149+
executor.schedule(asyncPartOfIteration,
150+
delayAsyncExecutionTotalDuration.dividedBy(counter.initial()).toNanos(), TimeUnit.NANOSECONDS);
151+
break;
152+
}
153+
default: {
154+
com.mongodb.assertions.Assertions.fail(executionType.toString());
155+
}
126156
}
127157
}
128158

@@ -170,6 +200,23 @@ private static String exceptionToString(@Nullable final Throwable t) {
170200
}
171201
}
172202

203+
private static <T> void complete(final CompletableFuture<T> future, @Nullable final T result, @Nullable final Throwable t) {
204+
if (t != null) {
205+
future.completeExceptionally(t);
206+
} else {
207+
future.complete(result);
208+
}
209+
}
210+
211+
private static void join(final Thread thread) {
212+
try {
213+
thread.join();
214+
} catch (InterruptedException e) {
215+
Thread.currentThread().interrupt();
216+
throw new RuntimeException(e);
217+
}
218+
}
219+
173220
private static void sleep(final Duration duration) {
174221
if (duration.isZero()) {
175222
return;
@@ -179,7 +226,6 @@ private static void sleep(final Duration duration) {
179226
long durationMsPartFromNsPart = TimeUnit.MILLISECONDS.convert(duration.getNano(), TimeUnit.NANOSECONDS);
180227
long sleepMs = TimeUnit.MILLISECONDS.convert(duration.getSeconds(), TimeUnit.SECONDS) + durationMsPartFromNsPart;
181228
int sleepNs = Math.toIntExact(durationNsPart - TimeUnit.NANOSECONDS.convert(durationMsPartFromNsPart, TimeUnit.MILLISECONDS));
182-
System.err.printf("sleeping for %d ms %d ns%n", sleepMs, sleepNs);
183229
Thread.sleep(sleepMs, sleepNs);
184230
} catch (InterruptedException e) {
185231
Thread.currentThread().interrupt();

0 commit comments

Comments
 (0)