mirror of
https://github.com/FairRootGroup/FairMQ.git
synced 2025-10-13 08:41:16 +00:00
Implement shmem msg zero-copy
This commit is contained in:
parent
5d980723d8
commit
1b3f38b6f1
|
@ -47,6 +47,10 @@ struct Message
|
||||||
TransportFactory* GetTransport() { return fTransport; }
|
TransportFactory* GetTransport() { return fTransport; }
|
||||||
void SetTransport(TransportFactory* transport) { fTransport = transport; }
|
void SetTransport(TransportFactory* transport) { fTransport = transport; }
|
||||||
|
|
||||||
|
/// Copy the message buffer from another message
|
||||||
|
/// Transport may choose not to physically copy the buffer, but to share it across the messages.
|
||||||
|
/// Modifying the buffer after a call to Copy() is undefined behaviour.
|
||||||
|
/// @param msg message to copy the buffer from.
|
||||||
virtual void Copy(const Message& msg) = 0;
|
virtual void Copy(const Message& msg) = 0;
|
||||||
|
|
||||||
virtual ~Message() = default;
|
virtual ~Message() = default;
|
||||||
|
|
|
@ -146,9 +146,10 @@ struct MetaHeader
|
||||||
{
|
{
|
||||||
size_t fSize;
|
size_t fSize;
|
||||||
size_t fHint;
|
size_t fHint;
|
||||||
uint16_t fRegionId;
|
|
||||||
uint16_t fSegmentId;
|
|
||||||
boost::interprocess::managed_shared_memory::handle_t fHandle;
|
boost::interprocess::managed_shared_memory::handle_t fHandle;
|
||||||
|
mutable boost::interprocess::managed_shared_memory::handle_t fShared;
|
||||||
|
uint16_t fRegionId;
|
||||||
|
mutable uint16_t fSegmentId;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef FAIRMQ_DEBUG_MODE
|
#ifdef FAIRMQ_DEBUG_MODE
|
||||||
|
@ -271,22 +272,22 @@ struct SegmentHandleFromAddress : public boost::static_visitor<boost::interproce
|
||||||
const void* ptr;
|
const void* ptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SegmentAddressFromHandle : public boost::static_visitor<void*>
|
struct SegmentAddressFromHandle : public boost::static_visitor<char*>
|
||||||
{
|
{
|
||||||
SegmentAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t _handle) : handle(_handle) {}
|
SegmentAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t _handle) : handle(_handle) {}
|
||||||
|
|
||||||
template<typename S>
|
template<typename S>
|
||||||
void* operator()(S& s) const { return s.get_address_from_handle(handle); }
|
char* operator()(S& s) const { return reinterpret_cast<char*>(s.get_address_from_handle(handle)); }
|
||||||
|
|
||||||
const boost::interprocess::managed_shared_memory::handle_t handle;
|
const boost::interprocess::managed_shared_memory::handle_t handle;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SegmentAllocate : public boost::static_visitor<void*>
|
struct SegmentAllocate : public boost::static_visitor<char*>
|
||||||
{
|
{
|
||||||
SegmentAllocate(const size_t _size) : size(_size) {}
|
SegmentAllocate(const size_t _size) : size(_size) {}
|
||||||
|
|
||||||
template<typename S>
|
template<typename S>
|
||||||
void* operator()(S& s) const { return s.allocate(size); }
|
char* operator()(S& s) const { return reinterpret_cast<char*>(s.allocate(size)); }
|
||||||
|
|
||||||
const size_t size;
|
const size_t size;
|
||||||
};
|
};
|
||||||
|
@ -322,12 +323,12 @@ struct SegmentBufferShrink : public boost::static_visitor<char*>
|
||||||
|
|
||||||
struct SegmentDeallocate : public boost::static_visitor<>
|
struct SegmentDeallocate : public boost::static_visitor<>
|
||||||
{
|
{
|
||||||
SegmentDeallocate(void* _ptr) : ptr(_ptr) {}
|
SegmentDeallocate(char* _ptr) : ptr(_ptr) {}
|
||||||
|
|
||||||
template<typename S>
|
template<typename S>
|
||||||
void operator()(S& s) const { return s.deallocate(ptr); }
|
void operator()(S& s) const { return s.deallocate(ptr); }
|
||||||
|
|
||||||
void* ptr;
|
char* ptr;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace fair::mq::shmem
|
} // namespace fair::mq::shmem
|
||||||
|
|
|
@ -52,29 +52,77 @@
|
||||||
|
|
||||||
#include <unistd.h> // getuid
|
#include <unistd.h> // getuid
|
||||||
#include <sys/types.h> // getuid
|
#include <sys/types.h> // getuid
|
||||||
|
|
||||||
#include <sys/mman.h> // mlock
|
#include <sys/mman.h> // mlock
|
||||||
|
|
||||||
namespace fair::mq::shmem
|
namespace fair::mq::shmem
|
||||||
{
|
{
|
||||||
|
|
||||||
struct ShmPtr
|
// ShmHeader stores user buffer alignment and the reference count in the following structure:
|
||||||
|
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
|
||||||
|
// The alignment of Hdr depends on the alignment of std::atomic and is stored in the first entry
|
||||||
|
struct ShmHeader
|
||||||
{
|
{
|
||||||
explicit ShmPtr(char* rPtr)
|
struct Hdr
|
||||||
: realPtr(rPtr)
|
|
||||||
{}
|
|
||||||
|
|
||||||
char* RealPtr()
|
|
||||||
{
|
{
|
||||||
return realPtr;
|
uint16_t userOffset;
|
||||||
|
std::atomic<uint16_t> refCount;
|
||||||
|
};
|
||||||
|
|
||||||
|
static Hdr* HdrPtr(char* ptr)
|
||||||
|
{
|
||||||
|
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
|
||||||
|
// ^
|
||||||
|
return reinterpret_cast<Hdr*>(ptr + sizeof(uint16_t) + *(reinterpret_cast<uint16_t*>(ptr)));
|
||||||
}
|
}
|
||||||
|
|
||||||
char* UserPtr()
|
static uint16_t HdrPartSize() // [HdrOffset(uint16_t)][Hdr alignment][Hdr]
|
||||||
{
|
{
|
||||||
return realPtr + sizeof(uint16_t) + *(reinterpret_cast<uint16_t*>(realPtr));
|
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
|
||||||
|
// <--------------------------------------->
|
||||||
|
return sizeof(uint16_t) + alignof(Hdr) + sizeof(Hdr);
|
||||||
}
|
}
|
||||||
|
|
||||||
char* realPtr;
|
static std::atomic<uint16_t>& RefCountPtr(char* ptr)
|
||||||
|
{
|
||||||
|
// get the ref count ptr from the Hdr
|
||||||
|
return HdrPtr(ptr)->refCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
static char* UserPtr(char* ptr)
|
||||||
|
{
|
||||||
|
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
|
||||||
|
// ^
|
||||||
|
return ptr + HdrPartSize() + HdrPtr(ptr)->userOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
static uint16_t RefCount(char* ptr) { return RefCountPtr(ptr).load(); }
|
||||||
|
static uint16_t IncrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_add(1); }
|
||||||
|
static uint16_t DecrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_sub(1); }
|
||||||
|
|
||||||
|
static size_t FullSize(size_t size, size_t alignment)
|
||||||
|
{
|
||||||
|
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
|
||||||
|
// <--------------------------------------------------------------------------->
|
||||||
|
return HdrPartSize() + alignment + size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void Construct(char* ptr, size_t alignment)
|
||||||
|
{
|
||||||
|
// place the Hdr in the aligned location, fill it and store its offset to HdrOffset
|
||||||
|
|
||||||
|
// the address alignment should be at least 2
|
||||||
|
assert(reinterpret_cast<uintptr_t>(ptr) % 2 == 0);
|
||||||
|
|
||||||
|
// offset to the beginning of the Hdr. store it in the beginning
|
||||||
|
uint16_t hdrOffset = alignof(Hdr) - ((reinterpret_cast<uintptr_t>(ptr) + sizeof(uint16_t)) % alignof(Hdr));
|
||||||
|
memcpy(ptr, &hdrOffset, sizeof(hdrOffset));
|
||||||
|
|
||||||
|
// offset to the beginning of the user buffer, store in Hdr together with the ref count
|
||||||
|
uint16_t userOffset = alignment - ((reinterpret_cast<uintptr_t>(ptr) + HdrPartSize()) % alignment);
|
||||||
|
new(ptr + sizeof(uint16_t) + hdrOffset) Hdr{ userOffset, std::atomic<uint16_t>(1) };
|
||||||
|
}
|
||||||
|
|
||||||
|
static void Destruct(char* ptr) { RefCountPtr(ptr).~atomic(); }
|
||||||
};
|
};
|
||||||
|
|
||||||
class Manager
|
class Manager
|
||||||
|
@ -635,44 +683,35 @@ class Manager
|
||||||
{
|
{
|
||||||
return boost::apply_visitor(SegmentHandleFromAddress(ptr), fSegments.at(segmentId));
|
return boost::apply_visitor(SegmentHandleFromAddress(ptr), fSegments.at(segmentId));
|
||||||
}
|
}
|
||||||
void* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const
|
char* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const
|
||||||
{
|
{
|
||||||
return boost::apply_visitor(SegmentAddressFromHandle(handle), fSegments.at(segmentId));
|
return boost::apply_visitor(SegmentAddressFromHandle(handle), fSegments.at(segmentId));
|
||||||
}
|
}
|
||||||
|
|
||||||
ShmPtr Allocate(size_t size, size_t alignment = 0)
|
char* Allocate(size_t size, size_t alignment = 0)
|
||||||
{
|
{
|
||||||
alignment = std::max(alignment, alignof(std::max_align_t));
|
alignment = std::max(alignment, alignof(std::max_align_t));
|
||||||
|
|
||||||
char* ptr = nullptr;
|
char* ptr = nullptr;
|
||||||
// [offset(uint16_t)][alignment][buffer]
|
size_t fullSize = ShmHeader::FullSize(size, alignment);
|
||||||
size_t fullSize = sizeof(uint16_t) + alignment + size;
|
|
||||||
// tools::RateLimiter rateLimiter(20);
|
|
||||||
|
|
||||||
while (ptr == nullptr) {
|
while (ptr == nullptr) {
|
||||||
try {
|
try {
|
||||||
// boost::interprocess::managed_shared_memory::size_type actualSize = size;
|
|
||||||
// char* hint = 0; // unused for boost::interprocess::allocate_new
|
|
||||||
// ptr = fSegments.at(fSegmentId).allocation_command<char>(boost::interprocess::allocate_new, size, actualSize, hint);
|
|
||||||
size_t segmentSize = boost::apply_visitor(SegmentSize(), fSegments.at(fSegmentId));
|
size_t segmentSize = boost::apply_visitor(SegmentSize(), fSegments.at(fSegmentId));
|
||||||
if (fullSize > segmentSize) {
|
if (fullSize > segmentSize) {
|
||||||
throw MessageBadAlloc(tools::ToString("Requested message size (", fullSize, ") exceeds segment size (", segmentSize, ")"));
|
throw MessageBadAlloc(tools::ToString("Requested message size (", fullSize, ") exceeds segment size (", segmentSize, ")"));
|
||||||
}
|
}
|
||||||
|
|
||||||
ptr = reinterpret_cast<char*>(boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId)));
|
ptr = boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId));
|
||||||
assert(reinterpret_cast<uintptr_t>(ptr) % 2 == 0);
|
ShmHeader::Construct(ptr, alignment);
|
||||||
uint16_t offset = 0;
|
|
||||||
offset = alignment - ((reinterpret_cast<uintptr_t>(ptr) + sizeof(uint16_t)) % alignment);
|
|
||||||
std::memcpy(ptr, &offset, sizeof(offset));
|
|
||||||
} catch (boost::interprocess::bad_alloc& ba) {
|
} catch (boost::interprocess::bad_alloc& ba) {
|
||||||
// LOG(warn) << "Shared memory full...";
|
// LOG(warn) << "Shared memory full...";
|
||||||
if (ThrowingOnBadAlloc()) {
|
if (ThrowingOnBadAlloc()) {
|
||||||
throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId))));
|
throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId))));
|
||||||
}
|
}
|
||||||
// rateLimiter.maybe_sleep();
|
|
||||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||||
if (Interrupted()) {
|
if (Interrupted()) {
|
||||||
return ShmPtr(ptr);
|
throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId))));
|
||||||
} else {
|
} else {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -684,18 +723,20 @@ class Manager
|
||||||
(*fMsgDebug).emplace(fSegmentId, fShmVoidAlloc);
|
(*fMsgDebug).emplace(fSegmentId, fShmVoidAlloc);
|
||||||
}
|
}
|
||||||
(*fMsgDebug).at(fSegmentId).emplace(
|
(*fMsgDebug).at(fSegmentId).emplace(
|
||||||
static_cast<size_t>(GetHandleFromAddress(ShmPtr(ptr).UserPtr(), fSegmentId)),
|
static_cast<size_t>(GetHandleFromAddress(ShmHeader::UserPtr(ptr), fSegmentId)),
|
||||||
MsgDebug(getpid(), size, std::chrono::system_clock::now().time_since_epoch().count())
|
MsgDebug(getpid(), size, std::chrono::system_clock::now().time_since_epoch().count())
|
||||||
);
|
);
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
return ShmPtr(ptr);
|
return ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Deallocate(boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId)
|
void Deallocate(boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId)
|
||||||
{
|
{
|
||||||
boost::apply_visitor(SegmentDeallocate(GetAddressFromHandle(handle, segmentId)), fSegments.at(segmentId));
|
char* ptr = GetAddressFromHandle(handle, segmentId);
|
||||||
|
ShmHeader::Destruct(ptr);
|
||||||
|
boost::apply_visitor(SegmentDeallocate(ptr), fSegments.at(segmentId));
|
||||||
#ifdef FAIRMQ_DEBUG_MODE
|
#ifdef FAIRMQ_DEBUG_MODE
|
||||||
boost::interprocess::scoped_lock<boost::interprocess::named_mutex> lock(fShmMtx);
|
boost::interprocess::scoped_lock<boost::interprocess::named_mutex> lock(fShmMtx);
|
||||||
DecrementShmMsgCounter(segmentId);
|
DecrementShmMsgCounter(segmentId);
|
||||||
|
|
|
@ -38,7 +38,7 @@ class Message final : public fair::mq::Message
|
||||||
: fair::mq::Message(factory)
|
: fair::mq::Message(factory)
|
||||||
, fManager(manager)
|
, fManager(manager)
|
||||||
, fQueued(false)
|
, fQueued(false)
|
||||||
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
|
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
|
||||||
, fRegionPtr(nullptr)
|
, fRegionPtr(nullptr)
|
||||||
, fLocalPtr(nullptr)
|
, fLocalPtr(nullptr)
|
||||||
{
|
{
|
||||||
|
@ -49,7 +49,7 @@ class Message final : public fair::mq::Message
|
||||||
: fair::mq::Message(factory)
|
: fair::mq::Message(factory)
|
||||||
, fManager(manager)
|
, fManager(manager)
|
||||||
, fQueued(false)
|
, fQueued(false)
|
||||||
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
|
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
|
||||||
, fAlignment(alignment.alignment)
|
, fAlignment(alignment.alignment)
|
||||||
, fRegionPtr(nullptr)
|
, fRegionPtr(nullptr)
|
||||||
, fLocalPtr(nullptr)
|
, fLocalPtr(nullptr)
|
||||||
|
@ -61,7 +61,7 @@ class Message final : public fair::mq::Message
|
||||||
: fair::mq::Message(factory)
|
: fair::mq::Message(factory)
|
||||||
, fManager(manager)
|
, fManager(manager)
|
||||||
, fQueued(false)
|
, fQueued(false)
|
||||||
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
|
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
|
||||||
, fRegionPtr(nullptr)
|
, fRegionPtr(nullptr)
|
||||||
, fLocalPtr(nullptr)
|
, fLocalPtr(nullptr)
|
||||||
{
|
{
|
||||||
|
@ -73,7 +73,7 @@ class Message final : public fair::mq::Message
|
||||||
: fair::mq::Message(factory)
|
: fair::mq::Message(factory)
|
||||||
, fManager(manager)
|
, fManager(manager)
|
||||||
, fQueued(false)
|
, fQueued(false)
|
||||||
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
|
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
|
||||||
, fAlignment(alignment.alignment)
|
, fAlignment(alignment.alignment)
|
||||||
, fRegionPtr(nullptr)
|
, fRegionPtr(nullptr)
|
||||||
, fLocalPtr(nullptr)
|
, fLocalPtr(nullptr)
|
||||||
|
@ -86,7 +86,7 @@ class Message final : public fair::mq::Message
|
||||||
: fair::mq::Message(factory)
|
: fair::mq::Message(factory)
|
||||||
, fManager(manager)
|
, fManager(manager)
|
||||||
, fQueued(false)
|
, fQueued(false)
|
||||||
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
|
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
|
||||||
, fRegionPtr(nullptr)
|
, fRegionPtr(nullptr)
|
||||||
, fLocalPtr(nullptr)
|
, fLocalPtr(nullptr)
|
||||||
{
|
{
|
||||||
|
@ -105,7 +105,7 @@ class Message final : public fair::mq::Message
|
||||||
: fair::mq::Message(factory)
|
: fair::mq::Message(factory)
|
||||||
, fManager(manager)
|
, fManager(manager)
|
||||||
, fQueued(false)
|
, fQueued(false)
|
||||||
, fMeta{size, reinterpret_cast<size_t>(hint), static_cast<UnmanagedRegion*>(region.get())->fRegionId, fManager.GetSegmentId(), -1}
|
, fMeta{size, reinterpret_cast<size_t>(hint), -1, -1, static_cast<UnmanagedRegion*>(region.get())->fRegionId, fManager.GetSegmentId()}
|
||||||
, fRegionPtr(nullptr)
|
, fRegionPtr(nullptr)
|
||||||
, fLocalPtr(static_cast<char*>(data))
|
, fLocalPtr(static_cast<char*>(data))
|
||||||
{
|
{
|
||||||
|
@ -187,8 +187,7 @@ class Message final : public fair::mq::Message
|
||||||
if (fMeta.fRegionId == 0) {
|
if (fMeta.fRegionId == 0) {
|
||||||
if (fMeta.fSize > 0) {
|
if (fMeta.fSize > 0) {
|
||||||
fManager.GetSegment(fMeta.fSegmentId);
|
fManager.GetSegment(fMeta.fSegmentId);
|
||||||
ShmPtr shmPtr(reinterpret_cast<char*>(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)));
|
fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
|
||||||
fLocalPtr = shmPtr.UserPtr();
|
|
||||||
} else {
|
} else {
|
||||||
fLocalPtr = nullptr;
|
fLocalPtr = nullptr;
|
||||||
}
|
}
|
||||||
|
@ -218,8 +217,8 @@ class Message final : public fair::mq::Message
|
||||||
} else if (newSize <= fMeta.fSize) {
|
} else if (newSize <= fMeta.fSize) {
|
||||||
try {
|
try {
|
||||||
try {
|
try {
|
||||||
ShmPtr shmPtr(fManager.ShrinkInPlace(newSize, static_cast<char*>(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)), fMeta.fSegmentId));
|
char* ptr = fManager.ShrinkInPlace(newSize, fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId), fMeta.fSegmentId);
|
||||||
fLocalPtr = shmPtr.UserPtr();
|
fLocalPtr = ShmHeader::UserPtr(ptr);
|
||||||
fMeta.fSize = newSize;
|
fMeta.fSize = newSize;
|
||||||
return true;
|
return true;
|
||||||
} catch (boost::interprocess::bad_alloc& e) {
|
} catch (boost::interprocess::bad_alloc& e) {
|
||||||
|
@ -227,17 +226,12 @@ class Message final : public fair::mq::Message
|
||||||
// unused size >= 1000000 bytes: reallocate fully
|
// unused size >= 1000000 bytes: reallocate fully
|
||||||
// unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction
|
// unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction
|
||||||
if (fMeta.fSize - newSize >= 1000000) {
|
if (fMeta.fSize - newSize >= 1000000) {
|
||||||
ShmPtr shmPtr = fManager.Allocate(newSize, fAlignment);
|
char* ptr = fManager.Allocate(newSize, fAlignment);
|
||||||
if (shmPtr.RealPtr()) {
|
char* userPtr = ShmHeader::UserPtr(ptr);
|
||||||
char* userPtr = shmPtr.UserPtr();
|
std::memcpy(userPtr, fLocalPtr, newSize);
|
||||||
std::memcpy(userPtr, fLocalPtr, newSize);
|
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
|
||||||
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
|
fLocalPtr = userPtr;
|
||||||
fLocalPtr = userPtr;
|
fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
|
||||||
fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId);
|
|
||||||
} else {
|
|
||||||
LOG(debug) << "could not set used size: " << e.what();
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
fMeta.fSize = newSize;
|
fMeta.fSize = newSize;
|
||||||
return true;
|
return true;
|
||||||
|
@ -254,33 +248,65 @@ class Message final : public fair::mq::Message
|
||||||
|
|
||||||
Transport GetType() const override { return fair::mq::Transport::SHM; }
|
Transport GetType() const override { return fair::mq::Transport::SHM; }
|
||||||
|
|
||||||
void Copy(const fair::mq::Message& msg) override
|
uint16_t GetRefCount() const
|
||||||
{
|
{
|
||||||
if (fMeta.fHandle < 0) {
|
if (fMeta.fHandle < 0) {
|
||||||
boost::interprocess::managed_shared_memory::handle_t otherHandle = static_cast<const Message&>(msg).fMeta.fHandle;
|
return 1;
|
||||||
if (otherHandle) {
|
}
|
||||||
if (InitializeChunk(msg.GetSize())) {
|
|
||||||
std::memcpy(GetData(), msg.GetData(), msg.GetSize());
|
if (fMeta.fRegionId == 0) { // managed segment
|
||||||
}
|
fManager.GetSegment(fMeta.fSegmentId);
|
||||||
|
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
|
||||||
|
} else { // unmanaged region
|
||||||
|
if (fMeta.fShared < 0) { // UR msg is not yet shared
|
||||||
|
return 1;
|
||||||
} else {
|
} else {
|
||||||
LOG(error) << "copy fail: source message not initialized!";
|
fManager.GetSegment(fMeta.fSegmentId);
|
||||||
|
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
LOG(error) << "copy fail: target message already initialized!";
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~Message() override
|
void Copy(const fair::mq::Message& other) override
|
||||||
{
|
{
|
||||||
try {
|
const Message& otherMsg = static_cast<const Message&>(other);
|
||||||
|
if (otherMsg.fMeta.fHandle < 0) {
|
||||||
|
// if the other message is not initialized, close this one too and return
|
||||||
CloseMessage();
|
CloseMessage();
|
||||||
} catch(SharedMemoryError& sme) {
|
return;
|
||||||
LOG(error) << "error closing message: " << sme.what();
|
}
|
||||||
} catch(boost::interprocess::lock_exception& le) {
|
|
||||||
LOG(error) << "error closing message: " << le.what();
|
if (fMeta.fHandle >= 0) {
|
||||||
|
// if this msg is already initialized, close it first
|
||||||
|
CloseMessage();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (otherMsg.fMeta.fRegionId == 0) { // managed segment
|
||||||
|
fMeta = otherMsg.fMeta;
|
||||||
|
fManager.GetSegment(fMeta.fSegmentId);
|
||||||
|
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
|
||||||
|
} else { // unmanaged region
|
||||||
|
if (otherMsg.fMeta.fShared < 0) { // if UR msg is not yet shared
|
||||||
|
// TODO: minimize the size to 0 and don't create extra space for user buffer alignment
|
||||||
|
char* ptr = fManager.Allocate(2, 0);
|
||||||
|
// point the fShared in the unmanaged region message to the refCount holder
|
||||||
|
otherMsg.fMeta.fShared = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
|
||||||
|
// the message needs to be able to locate in which segment the refCount is stored
|
||||||
|
otherMsg.fMeta.fSegmentId = fMeta.fSegmentId;
|
||||||
|
// point this message to the same content as the unmanaged region message
|
||||||
|
fMeta = otherMsg.fMeta;
|
||||||
|
// increment the refCount
|
||||||
|
ShmHeader::IncrementRefCount(ptr);
|
||||||
|
} else { // if the UR msg is already shared
|
||||||
|
fMeta = otherMsg.fMeta;
|
||||||
|
fManager.GetSegment(fMeta.fSegmentId);
|
||||||
|
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
~Message() override { CloseMessage(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Manager& fManager;
|
Manager& fManager;
|
||||||
bool fQueued;
|
bool fQueued;
|
||||||
|
@ -291,44 +317,70 @@ class Message final : public fair::mq::Message
|
||||||
|
|
||||||
char* InitializeChunk(const size_t size, size_t alignment = 0)
|
char* InitializeChunk(const size_t size, size_t alignment = 0)
|
||||||
{
|
{
|
||||||
ShmPtr shmPtr = fManager.Allocate(size, alignment);
|
if (size == 0) {
|
||||||
if (shmPtr.RealPtr()) {
|
fMeta.fSize = 0;
|
||||||
fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId);
|
return fLocalPtr;
|
||||||
fMeta.fSize = size;
|
|
||||||
fLocalPtr = shmPtr.UserPtr();
|
|
||||||
}
|
}
|
||||||
|
char* ptr = fManager.Allocate(size, alignment);
|
||||||
|
fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
|
||||||
|
fMeta.fSize = size;
|
||||||
|
fLocalPtr = ShmHeader::UserPtr(ptr);
|
||||||
return fLocalPtr;
|
return fLocalPtr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Deallocate()
|
void Deallocate()
|
||||||
{
|
{
|
||||||
if (fMeta.fHandle >= 0 && !fQueued) {
|
if (fMeta.fHandle >= 0 && !fQueued) {
|
||||||
if (fMeta.fRegionId == 0) {
|
if (fMeta.fRegionId == 0) { // managed segment
|
||||||
fManager.GetSegment(fMeta.fSegmentId);
|
fManager.GetSegment(fMeta.fSegmentId);
|
||||||
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
|
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
|
||||||
fMeta.fHandle = -1;
|
if (refCount == 1) {
|
||||||
} else {
|
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
|
||||||
if (!fRegionPtr) {
|
|
||||||
fRegionPtr = fManager.GetRegion(fMeta.fRegionId);
|
|
||||||
}
|
}
|
||||||
|
} else { // unmanaged region
|
||||||
if (fRegionPtr) {
|
if (fMeta.fShared >= 0) {
|
||||||
fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint});
|
// make sure segment is initialized in this transport
|
||||||
|
fManager.GetSegment(fMeta.fSegmentId);
|
||||||
|
// release unmanaged region block if ref count is one
|
||||||
|
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
|
||||||
|
if (refCount == 1) {
|
||||||
|
fManager.Deallocate(fMeta.fShared, fMeta.fSegmentId);
|
||||||
|
ReleaseUnmanagedRegionBlock();
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack";
|
ReleaseUnmanagedRegionBlock();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fMeta.fHandle = -1;
|
||||||
fLocalPtr = nullptr;
|
fLocalPtr = nullptr;
|
||||||
fMeta.fSize = 0;
|
fMeta.fSize = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ReleaseUnmanagedRegionBlock()
|
||||||
|
{
|
||||||
|
if (!fRegionPtr) {
|
||||||
|
fRegionPtr = fManager.GetRegion(fMeta.fRegionId);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (fRegionPtr) {
|
||||||
|
fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint});
|
||||||
|
} else {
|
||||||
|
LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void CloseMessage()
|
void CloseMessage()
|
||||||
{
|
{
|
||||||
Deallocate();
|
try {
|
||||||
fAlignment = 0;
|
Deallocate();
|
||||||
|
fAlignment = 0;
|
||||||
fManager.DecrementMsgCounter();
|
fManager.DecrementMsgCounter();
|
||||||
|
} catch(SharedMemoryError& sme) {
|
||||||
|
LOG(error) << "error closing message: " << sme.what();
|
||||||
|
} catch(boost::interprocess::lock_exception& le) {
|
||||||
|
LOG(error) << "error closing message: " << le.what();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ add_testsuite(Message
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/runner.cxx
|
${CMAKE_CURRENT_BINARY_DIR}/runner.cxx
|
||||||
message/_message.cxx
|
message/_message.cxx
|
||||||
|
|
||||||
LINKS FairMQ
|
LINKS FairMQ PicoSHA2
|
||||||
INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}
|
INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/message
|
${CMAKE_CURRENT_SOURCE_DIR}/message
|
||||||
${CMAKE_CURRENT_BINARY_DIR}
|
${CMAKE_CURRENT_BINARY_DIR}
|
||||||
|
|
|
@ -6,19 +6,23 @@
|
||||||
* copied verbatim in the file "LICENSE" *
|
* copied verbatim in the file "LICENSE" *
|
||||||
********************************************************************************/
|
********************************************************************************/
|
||||||
|
|
||||||
#include <array>
|
|
||||||
#include <cassert>
|
|
||||||
#include <cstdint>
|
|
||||||
#include <fairlogger/Logger.h>
|
#include <fairlogger/Logger.h>
|
||||||
#include <fairmq/Channel.h>
|
#include <fairmq/Channel.h>
|
||||||
#include <fairmq/ProgOptions.h>
|
#include <fairmq/ProgOptions.h>
|
||||||
#include <fairmq/TransportFactory.h>
|
#include <fairmq/tools/Semaphore.h>
|
||||||
#include <fairmq/tools/Strings.h>
|
#include <fairmq/tools/Strings.h>
|
||||||
#include <fairmq/tools/Unique.h>
|
#include <fairmq/tools/Unique.h>
|
||||||
|
#include <fairmq/TransportFactory.h>
|
||||||
|
#include <fairmq/shmem/Message.h>
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <array>
|
||||||
|
#include <cassert>
|
||||||
|
#include <cstdint>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
|
||||||
#include <string_view>
|
#include <string_view>
|
||||||
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
|
@ -190,7 +194,6 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
|
||||||
push.Bind(address);
|
push.Bind(address);
|
||||||
pull.Connect(address);
|
pull.Connect(address);
|
||||||
|
|
||||||
|
|
||||||
{
|
{
|
||||||
auto outMsg(push.NewMessage());
|
auto outMsg(push.NewMessage());
|
||||||
ASSERT_EQ(outMsg->GetData(), nullptr);
|
ASSERT_EQ(outMsg->GetData(), nullptr);
|
||||||
|
@ -227,6 +230,129 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
|
||||||
|
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
|
||||||
|
auto ZeroCopy() -> void
|
||||||
|
{
|
||||||
|
ProgOptions config;
|
||||||
|
config.SetProperty<string>("session", tools::Uuid());
|
||||||
|
auto factory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config));
|
||||||
|
|
||||||
|
unique_ptr<string> str(make_unique<string>("asdf"));
|
||||||
|
const size_t size = 2;
|
||||||
|
MessagePtr original(factory->CreateMessage(size));
|
||||||
|
memcpy(original->GetData(), "AB", size);
|
||||||
|
{
|
||||||
|
MessagePtr copy(factory->CreateMessage());
|
||||||
|
copy->Copy(*original);
|
||||||
|
EXPECT_EQ(original->GetSize(), copy->GetSize());
|
||||||
|
EXPECT_EQ(original->GetData(), copy->GetData());
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 2);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*copy).GetRefCount(), 2);
|
||||||
|
|
||||||
|
// buffer must be still intact
|
||||||
|
ASSERT_EQ(AsStringView(*original)[0], 'A');
|
||||||
|
ASSERT_EQ(AsStringView(*original)[1], 'B');
|
||||||
|
ASSERT_EQ(AsStringView(*copy)[0], 'A');
|
||||||
|
ASSERT_EQ(AsStringView(*copy)[1], 'B');
|
||||||
|
}
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
|
||||||
|
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
|
||||||
|
auto ZeroCopyFromUnmanaged(string const& address) -> void
|
||||||
|
{
|
||||||
|
ProgOptions config1;
|
||||||
|
ProgOptions config2;
|
||||||
|
string session(tools::Uuid());
|
||||||
|
config1.SetProperty<string>("session", session);
|
||||||
|
config2.SetProperty<string>("session", session);
|
||||||
|
// ref counts should be accessible across different segments
|
||||||
|
config2.SetProperty<uint16_t>("shm-segment-id", 2);
|
||||||
|
auto factory1(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config1));
|
||||||
|
auto factory2(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config2));
|
||||||
|
|
||||||
|
const size_t msgSize{100};
|
||||||
|
const size_t regionSize{1000000};
|
||||||
|
tools::Semaphore blocker;
|
||||||
|
|
||||||
|
auto region = factory1->CreateUnmanagedRegion(regionSize, [&blocker](void*, size_t, void*) {
|
||||||
|
blocker.Signal();
|
||||||
|
});
|
||||||
|
|
||||||
|
{
|
||||||
|
FairMQChannel push("Push", "push", factory1);
|
||||||
|
FairMQChannel pull("Pull", "pull", factory2);
|
||||||
|
|
||||||
|
push.Bind(address);
|
||||||
|
pull.Connect(address);
|
||||||
|
|
||||||
|
const size_t offset = 100;
|
||||||
|
auto msg1(push.NewMessage(region, static_cast<char*>(region->GetData()), msgSize, nullptr));
|
||||||
|
auto msg2(push.NewMessage(region, static_cast<char*>(region->GetData()) + offset, msgSize, nullptr));
|
||||||
|
const size_t contentSize = 2;
|
||||||
|
memcpy(msg1->GetData(), "AB", contentSize);
|
||||||
|
memcpy(msg2->GetData(), "CD", contentSize);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
|
||||||
|
|
||||||
|
{
|
||||||
|
auto copyFromOriginal(push.NewMessage());
|
||||||
|
copyFromOriginal->Copy(*msg1);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromOriginal).GetRefCount());
|
||||||
|
{
|
||||||
|
auto copyFromCopy(push.NewMessage());
|
||||||
|
copyFromCopy->Copy(*copyFromOriginal);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 3);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromCopy).GetRefCount());
|
||||||
|
|
||||||
|
EXPECT_EQ(msg1->GetSize(), copyFromOriginal->GetSize());
|
||||||
|
EXPECT_EQ(msg1->GetData(), copyFromOriginal->GetData());
|
||||||
|
EXPECT_EQ(msg1->GetSize(), copyFromCopy->GetSize());
|
||||||
|
EXPECT_EQ(msg1->GetData(), copyFromCopy->GetData());
|
||||||
|
EXPECT_EQ(copyFromOriginal->GetSize(), copyFromCopy->GetSize());
|
||||||
|
EXPECT_EQ(copyFromOriginal->GetData(), copyFromCopy->GetData());
|
||||||
|
|
||||||
|
// messing with the ref count should not have affected the user buffer
|
||||||
|
ASSERT_EQ(AsStringView(*msg1)[0], 'A');
|
||||||
|
ASSERT_EQ(AsStringView(*msg1)[1], 'B');
|
||||||
|
|
||||||
|
push.Send(copyFromCopy);
|
||||||
|
push.Send(msg2);
|
||||||
|
|
||||||
|
auto incomingCopiedMsg(pull.NewMessage());
|
||||||
|
auto incomingOriginalMsg(pull.NewMessage());
|
||||||
|
pull.Receive(incomingCopiedMsg);
|
||||||
|
pull.Receive(incomingOriginalMsg);
|
||||||
|
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingCopiedMsg).GetRefCount(), 3);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
|
||||||
|
|
||||||
|
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[0], 'A');
|
||||||
|
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[1], 'B');
|
||||||
|
|
||||||
|
{
|
||||||
|
// copying on a different segment should work
|
||||||
|
auto copyFromIncoming(pull.NewMessage());
|
||||||
|
copyFromIncoming->Copy(*incomingOriginalMsg);
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*copyFromIncoming).GetRefCount(), 2);
|
||||||
|
|
||||||
|
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[0], 'C');
|
||||||
|
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[1], 'D');
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
blocker.Wait();
|
||||||
|
blocker.Wait();
|
||||||
|
}
|
||||||
|
|
||||||
TEST(Resize, zeromq) // NOLINT
|
TEST(Resize, zeromq) // NOLINT
|
||||||
{
|
{
|
||||||
RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize");
|
RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize");
|
||||||
|
@ -267,4 +393,14 @@ TEST(EmptyMessage, shmem) // NOLINT
|
||||||
EmptyMessage("shmem", "ipc://test_empty_message");
|
EmptyMessage("shmem", "ipc://test_empty_message");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test case "ZeroCopy.shmem": runs the ZeroCopy() scenario.
TEST(ZeroCopy, shmem) // NOLINT
{
    ZeroCopy();
}
|
||||||
|
|
||||||
|
// Test case "ZeroCopyFromUnmanaged.shmem": runs the unmanaged-region copy
// scenario on a dedicated ipc address.
TEST(ZeroCopyFromUnmanaged, shmem) // NOLINT
{
    const string address("ipc://test_zerocopy_unmanaged");
    ZeroCopyFromUnmanaged(address);
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
@ -199,7 +199,6 @@ void RegionCallbacks(const string& transport, const string& _address)
|
||||||
});
|
});
|
||||||
ptr2 = region2->GetData();
|
ptr2 = region2->GetData();
|
||||||
|
|
||||||
|
|
||||||
{
|
{
|
||||||
FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get()));
|
FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get()));
|
||||||
FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));
|
FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));
|
||||||
|
|
Loading…
Reference in New Issue
Block a user