//
// Copyright (c) 2025 Vinnie Falco (vinnie.falco@gmail.com)
//
// Distributed under the Boost Software License, Version 1.0. (See accompanying
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
//
// Official repository: https://github.com/cppalliance/capy
//

#include "src/ex/detail/strand_queue.hpp"
#include <boost/capy/ex/detail/strand_service.hpp>
#include <atomic>
#include <coroutine>
#include <cstddef>
#include <exception>
#include <mutex>
#include <new>
#include <thread>
#include <utility>

18  
namespace boost {
18  
namespace boost {
19  
namespace capy {
19  
namespace capy {
20  
namespace detail {
20  
namespace detail {
21  

21  

22  
//----------------------------------------------------------
22  
//----------------------------------------------------------
23  

23  

24  
/** Implementation state for a strand.
24  
/** Implementation state for a strand.
25  

25  

26  
    Each strand_impl provides serialization for coroutines
26  
    Each strand_impl provides serialization for coroutines
27  
    dispatched through strands that share it.
27  
    dispatched through strands that share it.
28  
*/
28  
*/
29  
// Sentinel stored in cached_frame_ after shutdown to prevent
29  
// Sentinel stored in cached_frame_ after shutdown to prevent
30  
// in-flight invokers from repopulating a freed cache slot.
30  
// in-flight invokers from repopulating a freed cache slot.
31  
inline void* const kCacheClosed = reinterpret_cast<void*>(1);
31  
inline void* const kCacheClosed = reinterpret_cast<void*>(1);
32  

32  

33  
struct strand_impl
33  
struct strand_impl
34  
{
34  
{
35  
    std::mutex mutex_;
35  
    std::mutex mutex_;
36  
    strand_queue pending_;
36  
    strand_queue pending_;
37  
    bool locked_ = false;
37  
    bool locked_ = false;
38  
    std::atomic<std::thread::id> dispatch_thread_{};
38  
    std::atomic<std::thread::id> dispatch_thread_{};
39  
    std::atomic<void*> cached_frame_{nullptr};
39  
    std::atomic<void*> cached_frame_{nullptr};
40  
};
40  
};
41  

41  

42  
//----------------------------------------------------------
42  
//----------------------------------------------------------
43  

43  

44  
/** Invoker coroutine for strand dispatch.
44  
/** Invoker coroutine for strand dispatch.
45  

45  

46  
    Uses custom allocator to recycle frame - one allocation
46  
    Uses custom allocator to recycle frame - one allocation
47  
    per strand_impl lifetime, stored in trailer for recovery.
47  
    per strand_impl lifetime, stored in trailer for recovery.
48  
*/
48  
*/
49  
struct strand_invoker
49  
struct strand_invoker
50  
{
50  
{
51  
    struct promise_type
51  
    struct promise_type
52  
    {
52  
    {
53  
        void* operator new(std::size_t n, strand_impl& impl)
53  
        void* operator new(std::size_t n, strand_impl& impl)
54  
        {
54  
        {
55  
            constexpr auto A = alignof(strand_impl*);
55  
            constexpr auto A = alignof(strand_impl*);
56  
            std::size_t padded = (n + A - 1) & ~(A - 1);
56  
            std::size_t padded = (n + A - 1) & ~(A - 1);
57  
            std::size_t total = padded + sizeof(strand_impl*);
57  
            std::size_t total = padded + sizeof(strand_impl*);
58  

58  

59  
            void* p = impl.cached_frame_.exchange(
59  
            void* p = impl.cached_frame_.exchange(
60  
                nullptr, std::memory_order_acquire);
60  
                nullptr, std::memory_order_acquire);
61  
            if(!p || p == kCacheClosed)
61  
            if(!p || p == kCacheClosed)
62  
                p = ::operator new(total);
62  
                p = ::operator new(total);
63  

63  

64  
            // Trailer lets delete recover impl
64  
            // Trailer lets delete recover impl
65  
            *reinterpret_cast<strand_impl**>(
65  
            *reinterpret_cast<strand_impl**>(
66  
                static_cast<char*>(p) + padded) = &impl;
66  
                static_cast<char*>(p) + padded) = &impl;
67  
            return p;
67  
            return p;
68  
        }
68  
        }
69  

69  

70  
        void operator delete(void* p, std::size_t n) noexcept
70  
        void operator delete(void* p, std::size_t n) noexcept
71  
        {
71  
        {
72  
            constexpr auto A = alignof(strand_impl*);
72  
            constexpr auto A = alignof(strand_impl*);
73  
            std::size_t padded = (n + A - 1) & ~(A - 1);
73  
            std::size_t padded = (n + A - 1) & ~(A - 1);
74  

74  

75  
            auto* impl = *reinterpret_cast<strand_impl**>(
75  
            auto* impl = *reinterpret_cast<strand_impl**>(
76  
                static_cast<char*>(p) + padded);
76  
                static_cast<char*>(p) + padded);
77  

77  

78  
            void* expected = nullptr;
78  
            void* expected = nullptr;
79  
            if(!impl->cached_frame_.compare_exchange_strong(
79  
            if(!impl->cached_frame_.compare_exchange_strong(
80  
                expected, p, std::memory_order_release))
80  
                expected, p, std::memory_order_release))
81  
                ::operator delete(p);
81  
                ::operator delete(p);
82  
        }
82  
        }
83  

83  

84  
        strand_invoker get_return_object() noexcept
84  
        strand_invoker get_return_object() noexcept
85  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
85  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
86  

86  

87  
        std::suspend_always initial_suspend() noexcept { return {}; }
87  
        std::suspend_always initial_suspend() noexcept { return {}; }
88  
        std::suspend_never final_suspend() noexcept { return {}; }
88  
        std::suspend_never final_suspend() noexcept { return {}; }
89  
        void return_void() noexcept {}
89  
        void return_void() noexcept {}
90  
        void unhandled_exception() { std::terminate(); }
90  
        void unhandled_exception() { std::terminate(); }
91  
    };
91  
    };
92  

92  

93  
    std::coroutine_handle<promise_type> h_;
93  
    std::coroutine_handle<promise_type> h_;
94  
};
94  
};
95  

95  

96  
//----------------------------------------------------------
96  
//----------------------------------------------------------
97  

97  

98  
/** Concrete implementation of strand_service.
98  
/** Concrete implementation of strand_service.
99  

99  

100  
    Holds the fixed pool of strand_impl objects.
100  
    Holds the fixed pool of strand_impl objects.
101  
*/
101  
*/
102  
class strand_service_impl : public strand_service
102  
class strand_service_impl : public strand_service
103  
{
103  
{
104  
    static constexpr std::size_t num_impls = 211;
104  
    static constexpr std::size_t num_impls = 211;
105  

105  

106  
    strand_impl impls_[num_impls];
106  
    strand_impl impls_[num_impls];
107  
    std::size_t salt_ = 0;
107  
    std::size_t salt_ = 0;
108  
    std::mutex mutex_;
108  
    std::mutex mutex_;
109  

109  

110  
public:
110  
public:
111  
    explicit
111  
    explicit
112  
    strand_service_impl(execution_context&)
112  
    strand_service_impl(execution_context&)
113  
    {
113  
    {
114  
    }
114  
    }
115  

115  

116  
    strand_impl*
116  
    strand_impl*
117  
    get_implementation() override
117  
    get_implementation() override
118  
    {
118  
    {
119  
        std::lock_guard<std::mutex> lock(mutex_);
119  
        std::lock_guard<std::mutex> lock(mutex_);
120  
        std::size_t index = salt_++;
120  
        std::size_t index = salt_++;
121  
        index = index % num_impls;
121  
        index = index % num_impls;
122  
        return &impls_[index];
122  
        return &impls_[index];
123  
    }
123  
    }
124  

124  

125  
protected:
125  
protected:
126  
    void
126  
    void
127  
    shutdown() override
127  
    shutdown() override
128  
    {
128  
    {
129  
        for(std::size_t i = 0; i < num_impls; ++i)
129  
        for(std::size_t i = 0; i < num_impls; ++i)
130  
        {
130  
        {
131  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
131  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
132  
            impls_[i].locked_ = true;
132  
            impls_[i].locked_ = true;
133  

133  

134  
            void* p = impls_[i].cached_frame_.exchange(
134  
            void* p = impls_[i].cached_frame_.exchange(
135  
                kCacheClosed, std::memory_order_acquire);
135  
                kCacheClosed, std::memory_order_acquire);
136  
            if(p)
136  
            if(p)
137  
                ::operator delete(p);
137  
                ::operator delete(p);
138  
        }
138  
        }
139  
    }
139  
    }
140  

140  

141  
private:
141  
private:
142  
    static bool
142  
    static bool
143  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
143  
    enqueue(strand_impl& impl, std::coroutine_handle<> h)
144  
    {
144  
    {
145  
        std::lock_guard<std::mutex> lock(impl.mutex_);
145  
        std::lock_guard<std::mutex> lock(impl.mutex_);
146  
        impl.pending_.push(h);
146  
        impl.pending_.push(h);
147  
        if(!impl.locked_)
147  
        if(!impl.locked_)
148  
        {
148  
        {
149  
            impl.locked_ = true;
149  
            impl.locked_ = true;
150  
            return true;
150  
            return true;
151  
        }
151  
        }
152  
        return false;
152  
        return false;
153  
    }
153  
    }
154  

154  

155  
    static void
155  
    static void
156  
    dispatch_pending(strand_impl& impl)
156  
    dispatch_pending(strand_impl& impl)
157  
    {
157  
    {
158  
        strand_queue::taken_batch batch;
158  
        strand_queue::taken_batch batch;
159  
        {
159  
        {
160  
            std::lock_guard<std::mutex> lock(impl.mutex_);
160  
            std::lock_guard<std::mutex> lock(impl.mutex_);
161  
            batch = impl.pending_.take_all();
161  
            batch = impl.pending_.take_all();
162  
        }
162  
        }
163  
        impl.pending_.dispatch_batch(batch);
163  
        impl.pending_.dispatch_batch(batch);
164  
    }
164  
    }
165  

165  

166  
    static bool
166  
    static bool
167  
    try_unlock(strand_impl& impl)
167  
    try_unlock(strand_impl& impl)
168  
    {
168  
    {
169  
        std::lock_guard<std::mutex> lock(impl.mutex_);
169  
        std::lock_guard<std::mutex> lock(impl.mutex_);
170  
        if(impl.pending_.empty())
170  
        if(impl.pending_.empty())
171  
        {
171  
        {
172  
            impl.locked_ = false;
172  
            impl.locked_ = false;
173  
            return true;
173  
            return true;
174  
        }
174  
        }
175  
        return false;
175  
        return false;
176  
    }
176  
    }
177  

177  

178  
    static void
178  
    static void
179  
    set_dispatch_thread(strand_impl& impl) noexcept
179  
    set_dispatch_thread(strand_impl& impl) noexcept
180  
    {
180  
    {
181  
        impl.dispatch_thread_.store(std::this_thread::get_id());
181  
        impl.dispatch_thread_.store(std::this_thread::get_id());
182  
    }
182  
    }
183  

183  

184  
    static void
184  
    static void
185  
    clear_dispatch_thread(strand_impl& impl) noexcept
185  
    clear_dispatch_thread(strand_impl& impl) noexcept
186  
    {
186  
    {
187  
        impl.dispatch_thread_.store(std::thread::id{});
187  
        impl.dispatch_thread_.store(std::thread::id{});
188  
    }
188  
    }
189  

189  

190  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
190  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
191  
    // (repost after each batch to let other work run) - explore if starvation observed.
191  
    // (repost after each batch to let other work run) - explore if starvation observed.
192  
    static strand_invoker
192  
    static strand_invoker
193  
    make_invoker(strand_impl& impl)
193  
    make_invoker(strand_impl& impl)
194  
    {
194  
    {
195  
        strand_impl* p = &impl;
195  
        strand_impl* p = &impl;
196  
        for(;;)
196  
        for(;;)
197  
        {
197  
        {
198  
            set_dispatch_thread(*p);
198  
            set_dispatch_thread(*p);
199  
            dispatch_pending(*p);
199  
            dispatch_pending(*p);
200  
            if(try_unlock(*p))
200  
            if(try_unlock(*p))
201  
            {
201  
            {
202  
                clear_dispatch_thread(*p);
202  
                clear_dispatch_thread(*p);
203  
                co_return;
203  
                co_return;
204  
            }
204  
            }
205  
        }
205  
        }
206  
    }
206  
    }
207  

207  

208  
    friend class strand_service;
208  
    friend class strand_service;
209  
};
209  
};
210  

210  

211  
//----------------------------------------------------------
211  
//----------------------------------------------------------
212  

212  

213  
strand_service::
213  
strand_service::
214  
strand_service()
214  
strand_service()
215  
    : service()
215  
    : service()
216  
{
216  
{
217  
}
217  
}
218  

218  

219  
strand_service::
219  
strand_service::
220  
~strand_service() = default;
220  
~strand_service() = default;
221  

221  

222  
bool
222  
bool
223  
strand_service::
223  
strand_service::
224  
running_in_this_thread(strand_impl& impl) noexcept
224  
running_in_this_thread(strand_impl& impl) noexcept
225  
{
225  
{
226  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
226  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
227  
}
227  
}
228  

228  

229  
std::coroutine_handle<>
229  
std::coroutine_handle<>
230  
strand_service::
230  
strand_service::
231  
dispatch(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
231  
dispatch(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
232  
{
232  
{
233  
    if(running_in_this_thread(impl))
233  
    if(running_in_this_thread(impl))
234  
        return h;
234  
        return h;
235  

235  

236  
    if(strand_service_impl::enqueue(impl, h))
236  
    if(strand_service_impl::enqueue(impl, h))
237  
        ex.post(strand_service_impl::make_invoker(impl).h_);
237  
        ex.post(strand_service_impl::make_invoker(impl).h_);
238  
    return std::noop_coroutine();
238  
    return std::noop_coroutine();
239  
}
239  
}
240  

240  

241  
void
241  
void
242  
strand_service::
242  
strand_service::
243  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
243  
post(strand_impl& impl, executor_ref ex, std::coroutine_handle<> h)
244  
{
244  
{
245  
    if(strand_service_impl::enqueue(impl, h))
245  
    if(strand_service_impl::enqueue(impl, h))
246  
        ex.post(strand_service_impl::make_invoker(impl).h_);
246  
        ex.post(strand_service_impl::make_invoker(impl).h_);
247  
}
247  
}
248  

248  

249  
strand_service&
249  
strand_service&
250  
get_strand_service(execution_context& ctx)
250  
get_strand_service(execution_context& ctx)
251  
{
251  
{
252  
    return ctx.use_service<strand_service_impl>();
252  
    return ctx.use_service<strand_service_impl>();
253  
}
253  
}
254  

254  

255  
} // namespace detail
255  
} // namespace detail
256  
} // namespace capy
256  
} // namespace capy
257  
} // namespace boost
257  
} // namespace boost