13#ifndef IMPACTX_TRACKED_VECTOR_H
14#define IMPACTX_TRACKED_VECTOR_H
50 static_assert(std::is_trivially_copyable<T>(),
"TrackedVector can only hold trivially copyable types");
57 if (*m_finalize_registered) {
return; }
58 *m_finalize_registered =
true;
60 std::weak_ptr<std::vector<T>> weak_host =
m_host;
61 std::weak_ptr<amrex::Gpu::DeviceVector<T>> weak_device = m_device;
62 std::weak_ptr<Status> weak_status =
m_status;
63 std::weak_ptr<bool> weak_registered = m_finalize_registered;
67 auto host = weak_host.lock();
68 auto device = weak_device.lock();
69 auto status = weak_status.lock();
86 if (
auto reg = weak_registered.lock()) {
97 :
m_host(std::make_shared<std::vector<T>>(a_size))
105 :
m_host(std::make_shared<std::vector<T>>(a_size, a_value))
113 :
m_host(std::make_shared<std::vector<T>>(a_initializer_list))
121 :
m_host(std::make_shared<std::vector<T>>(std::move(a_vector)))
129 :
m_host(std::make_shared<std::vector<T>>(*a_vector.
m_host))
133 *m_device = *a_vector.m_device;
140 std::swap(
m_status, a_vector.m_status);
141 std::swap(
m_host, a_vector.m_host);
143 std::swap(m_device, a_vector.m_device);
144 std::swap(m_finalize_registered, a_vector.m_finalize_registered);
149 if (
this != &a_vector) {
153 *m_device = *a_vector.m_device;
161 if (
this != &a_vector) {
162 std::swap(
m_host, a_vector.m_host);
163 std::swap(
m_status, a_vector.m_status);
165 std::swap(m_device, a_vector.m_device);
166 std::swap(m_finalize_registered, a_vector.m_finalize_registered);
187 [[nodiscard]] std::vector<T> &
202 [[nodiscard]] std::vector<T>
const &
221 throw std::runtime_error(
"TrackedVector::device() called before AMReX initialize/after AMReX finalize");
238 throw std::runtime_error(
"TrackedVector::device_const() called before AMReX initialize/after AMReX finalize");
248 [[nodiscard]] std::vector<T> &
252 [[nodiscard]] std::vector<T>
const &
270 m_device->shrink_to_fit();
281 throw std::runtime_error(
"TrackedVector::to_device() called outside of AMReX initialize/finalize");
283 auto const size =
m_host->size();
285 m_device->resize(size);
299 throw std::runtime_error(
"TrackedVector::to_host() called outside of AMReX initialize/finalize");
301 m_host->resize(m_device->size());
303 m_device->begin(), m_device->end(),
m_host->begin());
308 mutable std::shared_ptr<Status>
m_status = std::make_shared<Status>();
309 mutable std::shared_ptr<std::vector<T>>
m_host = std::make_shared<std::vector<T>>();
311 mutable std::shared_ptr<amrex::Gpu::DeviceVector<T>> m_device = std::make_shared<amrex::Gpu::DeviceVector<T>>();
312 mutable std::shared_ptr<bool> m_finalize_registered = std::make_shared<bool>(
false);
PODVector< T, ArenaAllocator< T > > DeviceVector
void copy(HostToDevice, InIter begin, InIter end, OutIter result) noexcept
static constexpr DeviceToHost deviceToHost
static constexpr HostToDevice hostToDevice
void ExecOnFinalize(std::function< void()>)
Definition alignment.H:23
std::size_t size_type
Definition TrackedVector.H:52
TrackedVector & operator=(TrackedVector const &a_vector)
Definition TrackedVector.H:148
std::vector< T > const & device_const() const
Definition TrackedVector.H:253
Status
Definition TrackedVector.H:174
@ host_dirty
device data needs an update
Definition TrackedVector.H:177
@ up_to_date
host and device data are in sync
Definition TrackedVector.H:175
@ device_dirty
host data needs an update
Definition TrackedVector.H:176
T value_type
Definition TrackedVector.H:51
void release_gpu()
Definition TrackedVector.H:263
TrackedVector(std::vector< T > a_vector)
Definition TrackedVector.H:120
std::shared_ptr< std::vector< T > > m_host
Definition TrackedVector.H:309
void to_host() const
Definition TrackedVector.H:296
std::vector< T > const & host_const() const
Definition TrackedVector.H:203
void to_device() const
Definition TrackedVector.H:278
TrackedVector(size_type a_size)
Definition TrackedVector.H:96
TrackedVector(TrackedVector &&a_vector) noexcept
Definition TrackedVector.H:138
Status status() const
Definition TrackedVector.H:180
std::shared_ptr< Status > m_status
Definition TrackedVector.H:308
void register_finalize()
Definition TrackedVector.H:55
std::vector< T > & device()
Definition TrackedVector.H:249
TrackedVector(TrackedVector const &a_vector)
Definition TrackedVector.H:128
std::vector< T > & host()
Definition TrackedVector.H:188
TrackedVector(size_type a_size, value_type const &a_value)
Definition TrackedVector.H:104
TrackedVector(std::initializer_list< T > a_initializer_list)
Definition TrackedVector.H:112