// blob: dc6df2ed50adabbb66deba30f552b653977cd2b5 [file] [log] [blame]
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* Copyright (c) 2021-2022 NVIDIA Corporation
*
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* https://llvm.org/LICENSE.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
// The original idea is taken from libunifex and adapted to stdexec.
#include "../stdexec/execution.hpp"
#include "any_sender_of.hpp"
#include "inline_scheduler.hpp"
#include <exception>
#include <tuple>
#include <type_traits>
#include <utility>
namespace exec
{
namespace __at_coro_exit
{
using namespace stdexec;
// Scheduler type stored in the cleanup coroutine's promise and returned from
// its environment's get_scheduler query. It is the `any_scheduler<>` nested in
// the `any_sender<>` of an `any_receiver_ref` declared with the completion
// signatures set_error(std::exception_ptr) and set_stopped().
// NOTE(review): presumably a type-erased scheduler (see any_sender_of.hpp) --
// confirm there which scheduler types it can hold.
using __any_scheduler = //
any_receiver_ref< //
completion_signatures<set_error_t(std::exception_ptr),
set_stopped_t()>> //
::any_sender<>::any_scheduler<>;
struct __die_on_stop_t
{
template <class _Receiver>
struct __receiver_id
{
struct __t
{
using is_receiver = void;
using __id = __receiver_id;
_Receiver __receiver_;
template <__one_of<set_value_t, set_error_t> _Tag,
__decays_to<__t> _Self, class... _Args>
requires __callable<_Tag, _Receiver, _Args...>
friend void tag_invoke(_Tag, _Self&& __self,
_Args&&... __args) noexcept
{
_Tag{}((_Receiver&&)__self.__receiver_, (_Args&&)__args...);
}
template <same_as<set_stopped_t> _Tag>
[[noreturn]] friend void tag_invoke(_Tag, __t&&) noexcept
{
std::terminate();
}
friend env_of_t<_Receiver> tag_invoke(get_env_t,
const __t& __self) noexcept
{
return get_env(__self.__receiver_);
}
};
};
template <class _Rec>
using __receiver = __t<__receiver_id<_Rec>>;
template <class _Sender>
struct __sender_id
{
template <class _Env>
using __completion_signatures = //
__mapply<__remove<set_stopped_t(), __q<completion_signatures>>,
completion_signatures_of_t<_Sender, _Env>>;
struct __t
{
using __id = __sender_id;
using is_sender = void;
_Sender __sender_;
template <receiver _Receiver>
requires sender_to<_Sender, __receiver<_Receiver>>
friend connect_result_t<_Sender, __receiver<_Receiver>>
tag_invoke(connect_t, __t&& __self, _Receiver&& __rcvr) noexcept
{
return stdexec::connect(
(_Sender&&)__self.__sender_,
__receiver<_Receiver>{(_Receiver&&)__rcvr});
}
template <__decays_to<__t> _Self, class _Env>
friend auto tag_invoke(get_completion_signatures_t, _Self&&, _Env&&)
-> dependent_completion_signatures<_Env>;
template <__decays_to<__t> _Self, class _Env>
friend auto tag_invoke(get_completion_signatures_t, _Self&&, _Env&&)
-> __completion_signatures<_Env>
requires true;
friend env_of_t<_Sender> tag_invoke(get_env_t,
const __t& __self) noexcept
{
return get_env(__self.__sender_);
}
};
};
template <class _Sender>
using __sender = __t<__sender_id<__decay_t<_Sender>>>;
template <sender _Sender>
__sender<_Sender> operator()(_Sender&& __sndr) const
noexcept(__nothrow_decay_copyable<_Sender>)
{
return __sender<_Sender>{(_Sender&&)__sndr};
}
template <class _Value>
_Value&& operator()(_Value&& __value) const noexcept
{
return (_Value&&)__value;
}
};
inline constexpr __die_on_stop_t __die_on_stop;
// Satisfied by a promise type on which `continuation()` is callable with a
// result convertible to `__continuation_handle<>`, and on which
// `set_continuation(handle)` is a valid expression.
template <class _Promise>
concept __has_continuation =
    requires(_Promise& __p, __continuation_handle<> __next) {
        { __p.continuation() } -> convertible_to<__continuation_handle<>>;
        __p.set_continuation(__next);
    };
// Awaitable, move-only handle to a cleanup coroutine.
//
// The coroutine is created suspended (initial_suspend() returns
// suspend_always). When the task is awaited from a parent coroutine whose
// promise satisfies __has_continuation, await_suspend():
//   * copies the parent's scheduler into the cleanup promise,
//   * makes the parent's current continuation the cleanup coroutine's
//     continuation, and
//   * installs the cleanup coroutine as the parent's new continuation,
// then returns false so the parent keeps running. From that point on the
// coroutine frame is released by __final_awaitable, not by this object.
// NOTE(review): presumably the parent resumes its continuation when it
// finishes, which is what eventually runs the cleanup coroutine -- confirm
// against the parent task type.
//
// _Ts... are the (value) types of the arguments held by the coroutine;
// await_resume() gives the parent references to them.
template <class... _Ts>
class [[nodiscard]] __task
{
    struct __promise;

  public:
    using promise_type = __promise;

    // Takes ownership of the given coroutine handle.
    explicit __task(__coro::coroutine_handle<__promise> __coro) noexcept :
        __coro_(__coro)
    {}

    // Transfers ownership; the moved-from task holds a null handle.
    __task(__task&& __that) noexcept :
        __coro_(std::exchange(__that.__coro_, {}))
    {}

    // Destroys the coroutine frame if this task still owns it, i.e. the task
    // was created but never awaited. Without this the suspended frame would
    // leak. After a move or after await_resume() the handle is null and this
    // is a no-op.
    ~__task()
    {
        if (__coro_)
        {
            __coro_.destroy();
        }
    }

    // Always go through await_suspend so the continuation chain is rewired.
    bool await_ready() const noexcept
    {
        return false;
    }

    // Splices the cleanup coroutine between the parent and the parent's
    // current continuation. Returns false: the parent is not suspended.
    template <__has_continuation _Promise>
    bool await_suspend(__coro::coroutine_handle<_Promise> __parent) noexcept
    {
        __coro_.promise().__scheduler_ =
            get_scheduler(get_env(__parent.promise()));
        __coro_.promise().set_continuation(__parent.promise().continuation());
        __parent.promise().set_continuation(__coro_);
        return false;
    }

    // Gives up ownership of the handle (the frame now belongs to the
    // continuation chain) and returns the tuple of references stored in the
    // promise.
    std::tuple<_Ts&...> await_resume() noexcept
    {
        return std::exchange(__coro_, {}).promise().__args_;
    }

  private:
    // Awaited at the cleanup coroutine's final suspend point: picks the
    // coroutine to resume next, destroys the cleanup frame, and transfers
    // control.
    struct __final_awaitable
    {
        static constexpr bool await_ready() noexcept
        {
            return false;
        }

        // Static on purpose of safety: it only touches locals after
        // __h.destroy(), never the awaitable object itself.
        static __coro::coroutine_handle<>
            await_suspend(__coro::coroutine_handle<__promise> __h) noexcept
        {
            __promise& __p = __h.promise();
            // If unhandled_stopped() was called on this promise, forward the
            // stop to the continuation; otherwise resume it normally.
            auto __coro = __p.__is_unhandled_stopped_
                              ? __p.continuation().unhandled_stopped()
                              : __p.continuation().handle();
            __h.destroy();
            return __coro;
        }

        void await_resume() const noexcept {}
    };

    // Environment of the cleanup promise; answers get_scheduler with the
    // scheduler stored in the promise.
    struct __env
    {
        const __promise& __promise_;

        friend __any_scheduler tag_invoke(get_scheduler_t,
                                          __env __self) noexcept
        {
            return __self.__promise_.__scheduler_;
        }
    };

    struct __promise : with_awaitable_senders<__promise>
    {
        // Receives the coroutine's parameters: the action (unused here) and
        // the arguments, which are stored as a tuple of references.
        template <class _Action>
        explicit __promise(_Action&&, _Ts&... __ts) noexcept : __args_{__ts...}
        {}

        // Start suspended; the coroutine runs only once resumed.
        __coro::suspend_always initial_suspend() noexcept
        {
            return {};
        }

        __final_awaitable final_suspend() noexcept
        {
            return {};
        }

        void return_void() noexcept {}

        // An exception escaping the cleanup action is fatal.
        [[noreturn]] void unhandled_exception() noexcept
        {
            std::terminate();
        }

        // Records that a stop reached this promise and returns this
        // coroutine's own handle; __final_awaitable reads the flag.
        __coro::coroutine_handle<__promise> unhandled_stopped() noexcept
        {
            __is_unhandled_stopped_ = true;
            return __coro::coroutine_handle<__promise>::from_promise(*this);
        }

        __task get_return_object() noexcept
        {
            return __task(
                __coro::coroutine_handle<__promise>::from_promise(*this));
        }

        // Everything awaited in the cleanup coroutine is first passed through
        // __die_on_stop and then turned into an awaitable for this promise.
        template <class _Awaitable>
        decltype(auto) await_transform(_Awaitable&& __awaitable) noexcept
        {
            return as_awaitable(__die_on_stop((_Awaitable&&)__awaitable),
                                *this);
        }

        friend __env tag_invoke(get_env_t, const __promise& __self) noexcept
        {
            return {__self};
        }

        // Set by unhandled_stopped().
        bool __is_unhandled_stopped_{false};
        // References to the coroutine arguments, bound by the constructor.
        std::tuple<_Ts&...> __args_{};
        // Starts as inline_scheduler; overwritten with the parent's scheduler
        // in __task::await_suspend.
        __any_scheduler __scheduler_{inline_scheduler{}};
    };

    // Owned coroutine handle; null once moved from or awaited.
    __coro::coroutine_handle<__promise> __coro_;
};
// Function-object type behind `at_coroutine_exit`.
//
// operator()(action, args...) creates a cleanup coroutine that owns decayed
// copies of `action` and `args...` and, once resumed, awaits
// `action(args...)` with both passed as rvalues.
struct __at_coro_exit_t
{
  private:
    // The coroutine itself. Parameters are taken by value, so _Action and
    // _Ts... are deduced as decayed types and the copies are stored with the
    // coroutine.
    template <class _Action, class... _Ts>
    static __task<_Ts...> __impl(_Action __action, _Ts... __ts)
    {
        co_await ((_Action&&)__action)((_Ts&&)__ts...);
    }

  public:
    // Returns the task for the cleanup coroutine.
    //
    // The return type is spelled with decayed argument types because that is
    // what __impl deduces and returns. Using the forwarding-reference `_Ts`
    // directly would name a different __task specialization whenever an
    // lvalue or const argument is passed, making the return statement
    // ill-formed. For plain rvalue arguments the type is unchanged.
    template <class _Action, class... _Ts>
        requires __callable<__decay_t<_Action>, __decay_t<_Ts>...>
    __task<__decay_t<_Ts>...> operator()(_Action&& __action,
                                         _Ts&&... __ts) const
    {
        return __impl((_Action&&)__action, (_Ts&&)__ts...);
    }
};
} // namespace __at_coro_exit
// Public function object: `at_coroutine_exit(action, args...)` returns a
// [[nodiscard]] task wrapping a coroutine that awaits `action(args...)`.
// NOTE(review): intended to be co_awaited from the coroutine whose exit should
// trigger the action -- confirm usage against the library documentation.
inline constexpr __at_coro_exit::__at_coro_exit_t at_coroutine_exit{};
} // namespace exec