TLA Line data Source code
1 : //
2 : // Copyright (c) 2026 Steve Gerbino
3 : //
4 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 : //
7 : // Official repository: https://github.com/cppalliance/capy
8 : //
9 :
10 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 : #define BOOST_CAPY_WHEN_ALL_HPP
12 :
13 : #include <boost/capy/detail/config.hpp>
14 : #include <boost/capy/detail/void_to_monostate.hpp>
15 : #include <boost/capy/concept/executor.hpp>
16 : #include <boost/capy/concept/io_awaitable.hpp>
17 : #include <coroutine>
18 : #include <boost/capy/ex/io_env.hpp>
19 : #include <boost/capy/ex/frame_allocator.hpp>
20 : #include <boost/capy/task.hpp>
21 :
22 : #include <array>
23 : #include <atomic>
24 : #include <exception>
25 : #include <optional>
26 : #include <stop_token>
27 : #include <tuple>
28 : #include <type_traits>
29 : #include <utility>
30 :
31 : namespace boost {
32 : namespace capy {
33 :
34 : namespace detail {
35 :
/** Holds the result of a single task within when_all.

    The slot starts empty, is filled by `set()`, and is drained by
    `get()`.

    @tparam T The non-void result type produced by the task.
*/
template<typename T>
struct result_holder
{
    // Empty until set() stores a value.
    std::optional<T> value_;

    /** Store the task's result.

        The value is move-constructed in place rather than assigned, so
        `T` only has to be move-constructible. Types that are not
        assignable (for example `std::pair<const K, V>`) are supported.

        @param v The result to store.
    */
    void set(T v)
    {
        value_.emplace(std::move(v));
    }

    /** Move the stored result out of the holder.

        @pre `set()` has been called. The optional is dereferenced
        without checking that it holds a value.

        @return The stored value, moved out of the slot.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};
53 :
54 : /** Specialization for void tasks - returns monostate to preserve index mapping.
55 : */
56 : template<>
57 : struct result_holder<void>
58 : {
59 43 : std::monostate get() && { return {}; }
60 : };
61 :
62 : /** Shared state for when_all operation.
63 :
64 : @tparam Ts The result types of the tasks.
65 : */
66 : template<typename... Ts>
67 : struct when_all_state
68 : {
69 : static constexpr std::size_t task_count = sizeof...(Ts);
70 :
71 : // Completion tracking - when_all waits for all children
72 : std::atomic<std::size_t> remaining_count_;
73 :
74 : // Result storage in input order
75 : std::tuple<result_holder<Ts>...> results_;
76 :
77 : // Runner handles - destroyed in await_resume while allocator is valid
78 : std::array<std::coroutine_handle<>, task_count> runner_handles_{};
79 :
80 : // Exception storage - first error wins, others discarded
81 : std::atomic<bool> has_exception_{false};
82 : std::exception_ptr first_exception_;
83 :
84 : // Stop propagation - on error, request stop for siblings
85 : std::stop_source stop_source_;
86 :
87 : // Connects parent's stop_token to our stop_source
88 : struct stop_callback_fn
89 : {
90 : std::stop_source* source_;
91 4 : void operator()() const { source_->request_stop(); }
92 : };
93 : using stop_callback_t = std::stop_callback<stop_callback_fn>;
94 : std::optional<stop_callback_t> parent_stop_callback_;
95 :
96 : // Parent resumption
97 : std::coroutine_handle<> continuation_;
98 : io_env const* caller_env_ = nullptr;
99 :
100 61 : when_all_state()
101 61 : : remaining_count_(task_count)
102 : {
103 61 : }
104 :
105 : // Runners self-destruct in final_suspend. No destruction needed here.
106 :
107 : /** Capture an exception (first one wins).
108 : */
109 20 : void capture_exception(std::exception_ptr ep)
110 : {
111 20 : bool expected = false;
112 20 : if(has_exception_.compare_exchange_strong(
113 : expected, true, std::memory_order_relaxed))
114 17 : first_exception_ = ep;
115 20 : }
116 :
117 : };
118 :
119 : /** Wrapper coroutine that intercepts task completion.
120 :
121 : This runner awaits its assigned task and stores the result in
122 : the shared state, or captures the exception and requests stop.
123 : */
124 : template<typename T, typename... Ts>
125 : struct when_all_runner
126 : {
127 : struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
128 : {
129 : when_all_state<Ts...>* state_ = nullptr;
130 : io_env env_;
131 :
132 134 : when_all_runner get_return_object()
133 : {
134 134 : return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
135 : }
136 :
137 134 : std::suspend_always initial_suspend() noexcept
138 : {
139 134 : return {};
140 : }
141 :
142 134 : auto final_suspend() noexcept
143 : {
144 : struct awaiter
145 : {
146 : promise_type* p_;
147 :
148 134 : bool await_ready() const noexcept
149 : {
150 134 : return false;
151 : }
152 :
153 134 : auto await_suspend(std::coroutine_handle<> h) noexcept
154 : {
155 : // Extract everything needed before self-destruction.
156 134 : auto* state = p_->state_;
157 134 : auto* counter = &state->remaining_count_;
158 134 : auto* caller_env = state->caller_env_;
159 134 : auto cont = state->continuation_;
160 :
161 134 : h.destroy();
162 :
163 : // If last runner, dispatch parent for symmetric transfer.
164 134 : auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
165 134 : if(remaining == 1)
166 61 : return detail::symmetric_transfer(caller_env->executor.dispatch(cont));
167 73 : return detail::symmetric_transfer(std::noop_coroutine());
168 : }
169 :
170 MIS 0 : void await_resume() const noexcept
171 : {
172 0 : }
173 : };
174 HIT 134 : return awaiter{this};
175 : }
176 :
177 114 : void return_void()
178 : {
179 114 : }
180 :
181 20 : void unhandled_exception()
182 : {
183 20 : state_->capture_exception(std::current_exception());
184 : // Request stop for sibling tasks
185 20 : state_->stop_source_.request_stop();
186 20 : }
187 :
188 : template<class Awaitable>
189 : struct transform_awaiter
190 : {
191 : std::decay_t<Awaitable> a_;
192 : promise_type* p_;
193 :
194 134 : bool await_ready()
195 : {
196 134 : return a_.await_ready();
197 : }
198 :
199 134 : decltype(auto) await_resume()
200 : {
201 134 : return a_.await_resume();
202 : }
203 :
204 : template<class Promise>
205 133 : auto await_suspend(std::coroutine_handle<Promise> h)
206 : {
207 : using R = decltype(a_.await_suspend(h, &p_->env_));
208 : if constexpr (std::is_same_v<R, std::coroutine_handle<>>)
209 133 : return detail::symmetric_transfer(a_.await_suspend(h, &p_->env_));
210 : else
211 : return a_.await_suspend(h, &p_->env_);
212 : }
213 : };
214 :
215 : template<class Awaitable>
216 134 : auto await_transform(Awaitable&& a)
217 : {
218 : using A = std::decay_t<Awaitable>;
219 : if constexpr (IoAwaitable<A>)
220 : {
221 : return transform_awaiter<Awaitable>{
222 268 : std::forward<Awaitable>(a), this};
223 : }
224 : else
225 : {
226 : static_assert(sizeof(A) == 0, "requires IoAwaitable");
227 : }
228 134 : }
229 : };
230 :
231 : std::coroutine_handle<promise_type> h_;
232 :
233 134 : explicit when_all_runner(std::coroutine_handle<promise_type> h)
234 134 : : h_(h)
235 : {
236 134 : }
237 :
238 : // Enable move for all clang versions - some versions need it
239 : when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
240 :
241 : // Non-copyable
242 : when_all_runner(when_all_runner const&) = delete;
243 : when_all_runner& operator=(when_all_runner const&) = delete;
244 : when_all_runner& operator=(when_all_runner&&) = delete;
245 :
246 134 : auto release() noexcept
247 : {
248 134 : return std::exchange(h_, nullptr);
249 : }
250 : };
251 :
252 : /** Create a runner coroutine for a single awaitable.
253 :
254 : Awaitable is passed directly to ensure proper coroutine frame storage.
255 : */
256 : template<std::size_t Index, IoAwaitable Awaitable, typename... Ts>
257 : when_all_runner<awaitable_result_t<Awaitable>, Ts...>
258 134 : make_when_all_runner(Awaitable inner, when_all_state<Ts...>* state)
259 : {
260 : using T = awaitable_result_t<Awaitable>;
261 : if constexpr (std::is_void_v<T>)
262 : {
263 : co_await std::move(inner);
264 : }
265 : else
266 : {
267 : std::get<Index>(state->results_).set(co_await std::move(inner));
268 : }
269 268 : }
270 :
271 : /** Internal awaitable that launches all runner coroutines and waits.
272 :
273 : This awaitable is used inside the when_all coroutine to handle
274 : the concurrent execution of child awaitables.
275 : */
276 : template<IoAwaitable... Awaitables>
277 : class when_all_launcher
278 : {
279 : using state_type = when_all_state<awaitable_result_t<Awaitables>...>;
280 :
281 : std::tuple<Awaitables...>* awaitables_;
282 : state_type* state_;
283 :
284 : public:
285 61 : when_all_launcher(
286 : std::tuple<Awaitables...>* awaitables,
287 : state_type* state)
288 61 : : awaitables_(awaitables)
289 61 : , state_(state)
290 : {
291 61 : }
292 :
293 61 : bool await_ready() const noexcept
294 : {
295 61 : return sizeof...(Awaitables) == 0;
296 : }
297 :
298 61 : std::coroutine_handle<> await_suspend(std::coroutine_handle<> continuation, io_env const* caller_env)
299 : {
300 61 : state_->continuation_ = continuation;
301 61 : state_->caller_env_ = caller_env;
302 :
303 : // Forward parent's stop requests to children
304 61 : if(caller_env->stop_token.stop_possible())
305 : {
306 16 : state_->parent_stop_callback_.emplace(
307 8 : caller_env->stop_token,
308 8 : typename state_type::stop_callback_fn{&state_->stop_source_});
309 :
310 8 : if(caller_env->stop_token.stop_requested())
311 4 : state_->stop_source_.request_stop();
312 : }
313 :
314 : // CRITICAL: If the last task finishes synchronously then the parent
315 : // coroutine resumes, destroying its frame, and destroying this object
316 : // prior to the completion of await_suspend. Therefore, await_suspend
317 : // must ensure `this` cannot be referenced after calling `launch_one`
318 : // for the last time.
319 61 : auto token = state_->stop_source_.get_token();
320 62 : [&]<std::size_t... Is>(std::index_sequence<Is...>) {
321 61 : (..., launch_one<Is>(caller_env->executor, token));
322 61 : }(std::index_sequence_for<Awaitables...>{});
323 :
324 : // Let signal_completion() handle resumption
325 122 : return std::noop_coroutine();
326 61 : }
327 :
328 61 : void await_resume() const noexcept
329 : {
330 : // Results are extracted by the when_all coroutine from state
331 61 : }
332 :
333 : private:
334 : template<std::size_t I>
335 134 : void launch_one(executor_ref caller_ex, std::stop_token token)
336 : {
337 134 : auto runner = make_when_all_runner<I>(
338 134 : std::move(std::get<I>(*awaitables_)), state_);
339 :
340 134 : auto h = runner.release();
341 134 : h.promise().state_ = state_;
342 134 : h.promise().env_ = io_env{caller_ex, token, state_->caller_env_->frame_allocator};
343 :
344 134 : std::coroutine_handle<> ch{h};
345 134 : state_->runner_handles_[I] = ch;
346 134 : state_->caller_env_->executor.post(ch);
347 268 : }
348 : };
349 :
350 : /** Helper to extract a single result from state.
351 : This is a separate function to work around a GCC-11 ICE that occurs
352 : when using nested immediately-invoked lambdas with pack expansion.
353 : */
354 : template<std::size_t I, typename... Ts>
355 98 : auto extract_single_result(when_all_state<Ts...>& state)
356 : {
357 98 : return std::move(std::get<I>(state.results_)).get();
358 : }
359 :
360 : /** Extract all results from state as a tuple.
361 : */
362 : template<typename... Ts>
363 44 : auto extract_results(when_all_state<Ts...>& state)
364 : {
365 67 : return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
366 44 : return std::tuple(extract_single_result<Is>(state)...);
367 88 : }(std::index_sequence_for<Ts...>{});
368 : }
369 :
370 : } // namespace detail
371 :
372 : /** Compute the when_all result tuple type.
373 :
374 : Void-returning tasks contribute std::monostate to preserve the
375 : task-index-to-result-index mapping, matching when_any's approach.
376 :
377 : Example: when_all_result_t<int, void, string> = std::tuple<int, std::monostate, string>
378 : Example: when_all_result_t<void, void> = std::tuple<std::monostate, std::monostate>
379 : */
380 : template<typename... Ts>
381 : using when_all_result_t = std::tuple<void_to_monostate_t<Ts>...>;
382 :
383 : /** Execute multiple awaitables concurrently and collect their results.
384 :
385 : Launches all awaitables simultaneously and waits for all to complete
386 : before returning. Results are collected in input order. If any
387 : awaitable throws, cancellation is requested for siblings and the first
388 : exception is rethrown after all awaitables complete.
389 :
390 : @li All child awaitables run concurrently on the caller's executor
391 : @li Results are returned as a tuple in input order
392 : @li Void-returning awaitables contribute std::monostate to the
393 : result tuple, preserving the task-index-to-result-index mapping
394 : @li First exception wins; subsequent exceptions are discarded
395 : @li Stop is requested for siblings on first error
396 : @li Completes only after all children have finished
397 :
398 : @par Thread Safety
399 : The returned task must be awaited from a single execution context.
400 : Child awaitables execute concurrently but complete through the caller's
401 : executor.
402 :
403 : @param awaitables The awaitables to execute concurrently. Each must
404 : satisfy @ref IoAwaitable and is consumed (moved-from) when
405 : `when_all` is awaited.
406 :
407 : @return A task yielding a tuple of results in input order. Void tasks
408 : contribute std::monostate to preserve index correspondence.
409 :
410 : @par Example
411 :
412 : @code
413 : task<> example()
414 : {
415 : // Concurrent fetch, results collected in order
416 : auto [user, posts] = co_await when_all(
417 : fetch_user( id ), // task<User>
418 : fetch_posts( id ) // task<std::vector<Post>>
419 : );
420 :
421 : // Void awaitables contribute monostate
422 : auto [a, _, b] = co_await when_all(
423 : fetch_int(), // task<int>
424 : log_event( "start" ), // task<void> → monostate
425 : fetch_str() // task<string>
426 : );
427 : // a is int, _ is monostate, b is string
428 : }
429 : @endcode
430 :
431 : @see IoAwaitable, task
432 : */
433 : template<IoAwaitable... As>
434 61 : [[nodiscard]] auto when_all(As... awaitables)
435 : -> task<when_all_result_t<awaitable_result_t<As>...>>
436 : {
437 : // State is stored in the coroutine frame, using the frame allocator
438 : detail::when_all_state<awaitable_result_t<As>...> state;
439 :
440 : // Store awaitables in the frame
441 : std::tuple<As...> awaitable_tuple(std::move(awaitables)...);
442 :
443 : // Launch all awaitables and wait for completion
444 : co_await detail::when_all_launcher<As...>(&awaitable_tuple, &state);
445 :
446 : // Propagate first exception if any.
447 : // Safe without explicit acquire: capture_exception() is sequenced-before
448 : // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
449 : // last task's decrement that resumes this coroutine.
450 : if(state.first_exception_)
451 : std::rethrow_exception(state.first_exception_);
452 :
453 : co_return detail::extract_results(state);
454 122 : }
455 :
456 : } // namespace capy
457 : } // namespace boost
458 :
459 : #endif
|