Implement shmem msg zero-copy

This commit is contained in:
Alexey Rybalchenko 2021-07-14 10:46:12 +02:00
parent 5d980723d8
commit 1b3f38b6f1
7 changed files with 333 additions and 100 deletions

View File

@ -47,6 +47,10 @@ struct Message
TransportFactory* GetTransport() { return fTransport; }
void SetTransport(TransportFactory* transport) { fTransport = transport; }
/// Copy the message buffer from another message
/// Transport may choose not to physically copy the buffer, but to share across the messages.
/// Modifying the buffer after a call to Copy() is undefined behaviour.
/// @param msg message to copy the buffer from.
virtual void Copy(const Message& msg) = 0;
virtual ~Message() = default;

View File

@ -146,9 +146,10 @@ struct MetaHeader
{
// Message metadata record. The member lines below appear in both their pre- and post-change
// positions because this view interleaves the two versions of the struct.
// payload size in bytes
size_t fSize;
// user-supplied hint pointer of an unmanaged region message, stored as an integer
size_t fHint;
uint16_t fRegionId;
uint16_t fSegmentId;
// handle of the message buffer; a negative value marks an uninitialized message
boost::interprocess::managed_shared_memory::handle_t fHandle;
// handle of the ref count holder of an unmanaged region message; negative until the message is
// shared for the first time. mutable: Message::Copy() sets it on the const source message.
mutable boost::interprocess::managed_shared_memory::handle_t fShared;
// 0 denotes a buffer in a managed segment, any other value an unmanaged region
uint16_t fRegionId;
// segment in which the buffer (managed) or the ref count holder (unmanaged) lives.
// mutable: Message::Copy() sets it on the const source message.
mutable uint16_t fSegmentId;
};
#ifdef FAIRMQ_DEBUG_MODE
@ -271,22 +272,22 @@ struct SegmentHandleFromAddress : public boost::static_visitor<boost::interproce
const void* ptr;
};
// Visitor that asks the visited segment for the address belonging to a stored handle.
// The char* flavour returns the address already cast to char*, ready for pointer arithmetic.
struct SegmentAddressFromHandle : public boost::static_visitor<void*>
struct SegmentAddressFromHandle : public boost::static_visitor<char*>
{
SegmentAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t _handle) : handle(_handle) {}
template<typename S>
void* operator()(S& s) const { return s.get_address_from_handle(handle); }
char* operator()(S& s) const { return reinterpret_cast<char*>(s.get_address_from_handle(handle)); }
// handle to resolve in the visited segment
const boost::interprocess::managed_shared_memory::handle_t handle;
};
// Visitor that allocates `size` bytes from the visited segment via its allocate() member.
// The char* flavour returns the allocation already cast to char*.
struct SegmentAllocate : public boost::static_visitor<void*>
struct SegmentAllocate : public boost::static_visitor<char*>
{
SegmentAllocate(const size_t _size) : size(_size) {}
template<typename S>
void* operator()(S& s) const { return s.allocate(size); }
char* operator()(S& s) const { return reinterpret_cast<char*>(s.allocate(size)); }
// number of bytes to request from the segment
const size_t size;
};
@ -322,12 +323,12 @@ struct SegmentBufferShrink : public boost::static_visitor<char*>
// Visitor that hands `ptr` back to the visited segment via its deallocate() member.
struct SegmentDeallocate : public boost::static_visitor<>
{
SegmentDeallocate(void* _ptr) : ptr(_ptr) {}
SegmentDeallocate(char* _ptr) : ptr(_ptr) {}
template<typename S>
void operator()(S& s) const { return s.deallocate(ptr); }
// start of the allocation to release
void* ptr;
char* ptr;
};
} // namespace fair::mq::shmem

View File

@ -52,29 +52,77 @@
#include <unistd.h> // getuid
#include <sys/types.h> // getuid
#include <sys/mman.h> // mlock
namespace fair::mq::shmem
{
struct ShmPtr
// ShmHeader stores user buffer alignment and the reference count in the following structure:
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
// The alignment of Hdr depends on the alignment of std::atomic and is stored in the first entry
struct ShmHeader
{
explicit ShmPtr(char* rPtr)
: realPtr(rPtr)
{}
char* RealPtr()
struct Hdr
{
return realPtr;
uint16_t userOffset;
std::atomic<uint16_t> refCount;
};
static Hdr* HdrPtr(char* ptr)
{
    // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
    //                                     ^
    // the leading uint16_t holds the distance from the end of that field to the Hdr
    const uint16_t hdrOffset = *(reinterpret_cast<uint16_t*>(ptr));
    char* const hdrAddress = ptr + sizeof(uint16_t) + hdrOffset;
    return reinterpret_cast<Hdr*>(hdrAddress);
}
char* UserPtr()
// Number of bytes reserved in front of the user-buffer padding.
// alignof(Hdr) is reserved as padding because Construct() stores a Hdr offset
// in the range [1, alignof(Hdr)].
static uint16_t HdrPartSize() // [HdrOffset(uint16_t)][Hdr alignment][Hdr]
{
return realPtr + sizeof(uint16_t) + *(reinterpret_cast<uint16_t*>(realPtr));
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
// <--------------------------------------->
return sizeof(uint16_t) + alignof(Hdr) + sizeof(Hdr);
}
char* realPtr;
static std::atomic<uint16_t>& RefCountPtr(char* ptr)
{
// get the ref count ptr from the Hdr
return HdrPtr(ptr)->refCount;
}
static char* UserPtr(char* ptr)
{
// [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
// ^
return ptr + HdrPartSize() + HdrPtr(ptr)->userOffset;
}
// Current reference count of the chunk that starts at ptr.
static uint16_t RefCount(char* ptr) { return RefCountPtr(ptr).load(); }
// Atomically adds one reference. Returns the count as it was BEFORE the increment (fetch_add).
static uint16_t IncrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_add(1); }
// Atomically drops one reference. Returns the count as it was BEFORE the decrement (fetch_sub),
// so a result of 1 means the caller has just released the last reference.
static uint16_t DecrementRefCount(char* ptr) { return RefCountPtr(ptr).fetch_sub(1); }
static size_t FullSize(size_t size, size_t alignment)
{
    // [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
    // <--------------------------------------------------------------------------->
    // header part + a full `alignment` of room for the user-buffer padding + payload
    const size_t headerPart = HdrPartSize();
    const size_t userPaddingRoom = alignment;
    return headerPart + userPaddingRoom + size;
}
static void Construct(char* ptr, size_t alignment)
{
    // Writes the bookkeeping header into a freshly allocated chunk:
    //   [HdrOffset(uint16_t)][Hdr alignment][Hdr][user buffer alignment][user buffer]
    // The reference count starts at 1.
    //
    // @param ptr       start of the chunk; must be at least 2-byte aligned
    // @param alignment requested user buffer alignment; must be non-zero and small enough
    //                  that the resulting user offset fits into the uint16_t Hdr field

    // the address alignment should be at least 2 (HdrOffset is read back as a uint16_t)
    assert(reinterpret_cast<uintptr_t>(ptr) % 2 == 0);
    // a zero alignment would be a modulo by zero below
    assert(alignment > 0);

    const uintptr_t address = reinterpret_cast<uintptr_t>(ptr);

    // offset to the beginning of the Hdr. store it in the beginning
    // (always in [1, alignof(Hdr)], which is the room HdrPartSize() reserves for it)
    const uint16_t hdrOffset = static_cast<uint16_t>(alignof(Hdr) - ((address + sizeof(uint16_t)) % alignof(Hdr)));
    memcpy(ptr, &hdrOffset, sizeof(hdrOffset));

    // offset to the beginning of the user buffer, store in Hdr together with the ref count
    const size_t userPadding = alignment - ((address + HdrPartSize()) % alignment);
    const uint16_t userOffset = static_cast<uint16_t>(userPadding);
    // refuse to silently truncate: a wrapped offset would yield a misaligned user pointer
    assert(userOffset == userPadding);

    new(ptr + sizeof(uint16_t) + hdrOffset) Hdr{ userOffset, std::atomic<uint16_t>(1) };
}
static void Destruct(char* ptr)
{
    // end the lifetime of the counter that Construct() placed into the chunk
    std::atomic<uint16_t>& counter = RefCountPtr(ptr);
    counter.~atomic();
}
};
class Manager
@ -635,44 +683,35 @@ class Manager
{
return boost::apply_visitor(SegmentHandleFromAddress(ptr), fSegments.at(segmentId));
}
// Resolves `handle` to an address inside the segment registered under `segmentId`.
// fSegments.at() is used for the lookup, so an unknown segmentId throws.
void* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const
char* GetAddressFromHandle(const boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId) const
{
return boost::apply_visitor(SegmentAddressFromHandle(handle), fSegments.at(segmentId));
}
ShmPtr Allocate(size_t size, size_t alignment = 0)
char* Allocate(size_t size, size_t alignment = 0)
{
alignment = std::max(alignment, alignof(std::max_align_t));
char* ptr = nullptr;
// [offset(uint16_t)][alignment][buffer]
size_t fullSize = sizeof(uint16_t) + alignment + size;
// tools::RateLimiter rateLimiter(20);
size_t fullSize = ShmHeader::FullSize(size, alignment);
while (ptr == nullptr) {
try {
// boost::interprocess::managed_shared_memory::size_type actualSize = size;
// char* hint = 0; // unused for boost::interprocess::allocate_new
// ptr = fSegments.at(fSegmentId).allocation_command<char>(boost::interprocess::allocate_new, size, actualSize, hint);
size_t segmentSize = boost::apply_visitor(SegmentSize(), fSegments.at(fSegmentId));
if (fullSize > segmentSize) {
throw MessageBadAlloc(tools::ToString("Requested message size (", fullSize, ") exceeds segment size (", segmentSize, ")"));
}
ptr = reinterpret_cast<char*>(boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId)));
assert(reinterpret_cast<uintptr_t>(ptr) % 2 == 0);
uint16_t offset = 0;
offset = alignment - ((reinterpret_cast<uintptr_t>(ptr) + sizeof(uint16_t)) % alignment);
std::memcpy(ptr, &offset, sizeof(offset));
ptr = boost::apply_visitor(SegmentAllocate{fullSize}, fSegments.at(fSegmentId));
ShmHeader::Construct(ptr, alignment);
} catch (boost::interprocess::bad_alloc& ba) {
// LOG(warn) << "Shared memory full...";
if (ThrowingOnBadAlloc()) {
throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId))));
}
// rateLimiter.maybe_sleep();
std::this_thread::sleep_for(std::chrono::milliseconds(50));
if (Interrupted()) {
return ShmPtr(ptr);
throw MessageBadAlloc(tools::ToString("shmem: could not create a message of size ", size, ", alignment: ", (alignment != 0) ? std::to_string(alignment) : "default", ", free memory: ", boost::apply_visitor(SegmentFreeMemory(), fSegments.at(fSegmentId))));
} else {
continue;
}
@ -684,18 +723,20 @@ class Manager
(*fMsgDebug).emplace(fSegmentId, fShmVoidAlloc);
}
(*fMsgDebug).at(fSegmentId).emplace(
static_cast<size_t>(GetHandleFromAddress(ShmPtr(ptr).UserPtr(), fSegmentId)),
static_cast<size_t>(GetHandleFromAddress(ShmHeader::UserPtr(ptr), fSegmentId)),
MsgDebug(getpid(), size, std::chrono::system_clock::now().time_since_epoch().count())
);
#endif
}
return ShmPtr(ptr);
return ptr;
}
void Deallocate(boost::interprocess::managed_shared_memory::handle_t handle, uint16_t segmentId)
{
boost::apply_visitor(SegmentDeallocate(GetAddressFromHandle(handle, segmentId)), fSegments.at(segmentId));
char* ptr = GetAddressFromHandle(handle, segmentId);
ShmHeader::Destruct(ptr);
boost::apply_visitor(SegmentDeallocate(ptr), fSegments.at(segmentId));
#ifdef FAIRMQ_DEBUG_MODE
boost::interprocess::scoped_lock<boost::interprocess::named_mutex> lock(fShmMtx);
DecrementShmMsgCounter(segmentId);

View File

@ -38,7 +38,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
{
@ -49,7 +49,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
@ -61,7 +61,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
{
@ -73,7 +73,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
@ -86,7 +86,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId()}
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
{
@ -105,7 +105,7 @@ class Message final : public fair::mq::Message
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{size, reinterpret_cast<size_t>(hint), static_cast<UnmanagedRegion*>(region.get())->fRegionId, fManager.GetSegmentId(), -1}
, fMeta{size, reinterpret_cast<size_t>(hint), -1, -1, static_cast<UnmanagedRegion*>(region.get())->fRegionId, fManager.GetSegmentId()}
, fRegionPtr(nullptr)
, fLocalPtr(static_cast<char*>(data))
{
@ -187,8 +187,7 @@ class Message final : public fair::mq::Message
if (fMeta.fRegionId == 0) {
if (fMeta.fSize > 0) {
fManager.GetSegment(fMeta.fSegmentId);
ShmPtr shmPtr(reinterpret_cast<char*>(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)));
fLocalPtr = shmPtr.UserPtr();
fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
} else {
fLocalPtr = nullptr;
}
@ -218,8 +217,8 @@ class Message final : public fair::mq::Message
} else if (newSize <= fMeta.fSize) {
try {
try {
ShmPtr shmPtr(fManager.ShrinkInPlace(newSize, static_cast<char*>(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)), fMeta.fSegmentId));
fLocalPtr = shmPtr.UserPtr();
char* ptr = fManager.ShrinkInPlace(newSize, fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId), fMeta.fSegmentId);
fLocalPtr = ShmHeader::UserPtr(ptr);
fMeta.fSize = newSize;
return true;
} catch (boost::interprocess::bad_alloc& e) {
@ -227,17 +226,12 @@ class Message final : public fair::mq::Message
// unused size >= 1000000 bytes: reallocate fully
// unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction
if (fMeta.fSize - newSize >= 1000000) {
ShmPtr shmPtr = fManager.Allocate(newSize, fAlignment);
if (shmPtr.RealPtr()) {
char* userPtr = shmPtr.UserPtr();
std::memcpy(userPtr, fLocalPtr, newSize);
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
fLocalPtr = userPtr;
fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId);
} else {
LOG(debug) << "could not set used size: " << e.what();
return false;
}
char* ptr = fManager.Allocate(newSize, fAlignment);
char* userPtr = ShmHeader::UserPtr(ptr);
std::memcpy(userPtr, fLocalPtr, newSize);
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
fLocalPtr = userPtr;
fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
}
fMeta.fSize = newSize;
return true;
@ -254,33 +248,65 @@ class Message final : public fair::mq::Message
Transport GetType() const override { return fair::mq::Transport::SHM; }
void Copy(const fair::mq::Message& msg) override
// Number of messages currently sharing this message's buffer.
// Reports 1 for an uninitialized message and for an unmanaged region message that has not been
// shared yet (no ref count holder exists for it at that point).
uint16_t GetRefCount() const
{
if (fMeta.fHandle < 0) {
boost::interprocess::managed_shared_memory::handle_t otherHandle = static_cast<const Message&>(msg).fMeta.fHandle;
if (otherHandle) {
if (InitializeChunk(msg.GetSize())) {
std::memcpy(GetData(), msg.GetData(), msg.GetSize());
}
return 1;
}
// managed buffers keep the counter in the header of the buffer itself (fHandle),
// shared unmanaged region messages keep it in a separate holder (fShared)
if (fMeta.fRegionId == 0) { // managed segment
fManager.GetSegment(fMeta.fSegmentId);
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
} else { // unmanaged region
if (fMeta.fShared < 0) { // UR msg is not yet shared
return 1;
} else {
LOG(error) << "copy fail: source message not initialized!";
fManager.GetSegment(fMeta.fSegmentId);
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
}
} else {
LOG(error) << "copy fail: target message already initialized!";
}
}
~Message() override
// Makes this message refer to the same buffer as `other` without copying the payload:
// the metadata is copied and the shared reference count is incremented.
// An uninitialized `other` closes this message; an initialized target is closed first.
// NOTE(review): copying a message onto itself looks unsafe - the target is closed (fHandle
// becomes -1) before the handle is used to increment the ref count. Confirm callers never do this.
void Copy(const fair::mq::Message& other) override
{
try {
const Message& otherMsg = static_cast<const Message&>(other);
if (otherMsg.fMeta.fHandle < 0) {
// if the other message is not initialized, close this one too and return
CloseMessage();
} catch(SharedMemoryError& sme) {
LOG(error) << "error closing message: " << sme.what();
} catch(boost::interprocess::lock_exception& le) {
LOG(error) << "error closing message: " << le.what();
return;
}
if (fMeta.fHandle >= 0) {
// if this msg is already initialized, close it first
CloseMessage();
}
if (otherMsg.fMeta.fRegionId == 0) { // managed segment
fMeta = otherMsg.fMeta;
fManager.GetSegment(fMeta.fSegmentId);
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
} else { // unmanaged region
if (otherMsg.fMeta.fShared < 0) { // if UR msg is not yet shared
// the unmanaged buffer has no header of its own, so a small managed chunk is allocated
// only to host the ref count (ShmHeader::Construct initializes it to 1)
// TODO: minimize the size to 0 and don't create extra space for user buffer alignment
char* ptr = fManager.Allocate(2, 0);
// point the fShared in the unmanaged region message to the refCount holder
// NOTE(review): this uses the target's fMeta.fSegmentId from before the assignment below;
// presumably it still equals the manager's own segment id used by Allocate() - confirm
otherMsg.fMeta.fShared = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
// the message needs to be able to locate in which segment the refCount is stored
otherMsg.fMeta.fSegmentId = fMeta.fSegmentId;
// point this message to the same content as the unmanaged region message
fMeta = otherMsg.fMeta;
// increment the refCount
ShmHeader::IncrementRefCount(ptr);
} else { // if the UR msg is already shared
fMeta = otherMsg.fMeta;
fManager.GetSegment(fMeta.fSegmentId);
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
}
}
}
~Message() override { CloseMessage(); }
private:
Manager& fManager;
bool fQueued;
@ -291,44 +317,70 @@ class Message final : public fair::mq::Message
// Allocates a buffer of `size` bytes with the requested alignment and points this message at it.
// A zero size allocates nothing: only fMeta.fSize is reset and the current fLocalPtr is returned.
// Returns the user-visible start of the buffer.
char* InitializeChunk(const size_t size, size_t alignment = 0)
{
ShmPtr shmPtr = fManager.Allocate(size, alignment);
if (shmPtr.RealPtr()) {
fMeta.fHandle = fManager.GetHandleFromAddress(shmPtr.RealPtr(), fMeta.fSegmentId);
fMeta.fSize = size;
fLocalPtr = shmPtr.UserPtr();
if (size == 0) {
fMeta.fSize = 0;
return fLocalPtr;
}
// the handle refers to the start of the whole chunk (header included), not to the user buffer
char* ptr = fManager.Allocate(size, alignment);
fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
fMeta.fSize = size;
fLocalPtr = ShmHeader::UserPtr(ptr);
return fLocalPtr;
}
// Drops this message's reference to its buffer and resets the local bookkeeping.
// Nothing is released for an uninitialized message or one that has been queued (fQueued).
// DecrementRefCount returns the count from before the decrement, so `refCount == 1` below means
// this message held the last reference.
// NOTE(review): fMeta.fShared is not reset here; confirm a closed message is never reused with a
// stale fShared.
void Deallocate()
{
if (fMeta.fHandle >= 0 && !fQueued) {
if (fMeta.fRegionId == 0) {
if (fMeta.fRegionId == 0) { // managed segment
fManager.GetSegment(fMeta.fSegmentId);
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
fMeta.fHandle = -1;
} else {
if (!fRegionPtr) {
fRegionPtr = fManager.GetRegion(fMeta.fRegionId);
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
if (refCount == 1) {
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
}
if (fRegionPtr) {
fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint});
} else { // unmanaged region
if (fMeta.fShared >= 0) {
// make sure segment is initialized in this transport
fManager.GetSegment(fMeta.fSegmentId);
// release unmanaged region block if ref count is one
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fShared, fMeta.fSegmentId));
if (refCount == 1) {
// last reference: free the ref count holder, then acknowledge the region block
fManager.Deallocate(fMeta.fShared, fMeta.fSegmentId);
ReleaseUnmanagedRegionBlock();
}
} else {
LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack";
// never shared, so there is no ref count holder: acknowledge the block directly
ReleaseUnmanagedRegionBlock();
}
}
}
fMeta.fHandle = -1;
fLocalPtr = nullptr;
fMeta.fSize = 0;
}
void ReleaseUnmanagedRegionBlock()
{
    // look the region up on first use; the result stays cached in fRegionPtr
    if (!fRegionPtr) {
        fRegionPtr = fManager.GetRegion(fMeta.fRegionId);
    }

    // without a region there is nobody to acknowledge the block to
    if (!fRegionPtr) {
        LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack";
        return;
    }

    // report the block (handle, size, hint) back to the region
    fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint});
}
// Releases the buffer reference, resets the alignment and decrements the manager's message counter.
// SharedMemoryError and boost::interprocess::lock_exception are logged instead of propagated;
// the destructor calls this function directly.
// If Deallocate() throws, the statements after it inside the try block are skipped.
void CloseMessage()
{
Deallocate();
fAlignment = 0;
fManager.DecrementMsgCounter();
try {
Deallocate();
fAlignment = 0;
fManager.DecrementMsgCounter();
} catch(SharedMemoryError& sme) {
LOG(error) << "error closing message: " << sme.what();
} catch(boost::interprocess::lock_exception& le) {
LOG(error) << "error closing message: " << le.what();
}
}
};

View File

@ -89,7 +89,7 @@ add_testsuite(Message
${CMAKE_CURRENT_BINARY_DIR}/runner.cxx
message/_message.cxx
LINKS FairMQ
LINKS FairMQ PicoSHA2
INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/message
${CMAKE_CURRENT_BINARY_DIR}

View File

@ -6,19 +6,23 @@
* copied verbatim in the file "LICENSE" *
********************************************************************************/
#include <array>
#include <cassert>
#include <cstdint>
#include <fairlogger/Logger.h>
#include <fairmq/Channel.h>
#include <fairmq/ProgOptions.h>
#include <fairmq/TransportFactory.h>
#include <fairmq/tools/Semaphore.h>
#include <fairmq/tools/Strings.h>
#include <fairmq/tools/Unique.h>
#include <fairmq/TransportFactory.h>
#include <fairmq/shmem/Message.h>
#include <gtest/gtest.h>
#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <string>
#include <utility>
namespace
@ -190,7 +194,6 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
push.Bind(address);
pull.Connect(address);
{
auto outMsg(push.NewMessage());
ASSERT_EQ(outMsg->GetData(), nullptr);
@ -227,6 +230,129 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
}
}
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
auto ZeroCopy() -> void
{
    // Checks Copy() on a managed-segment shmem message: the copy must share the buffer of the
    // original, and the reference count must follow the lifetime of the copy.
    ProgOptions config;
    config.SetProperty<string>("session", tools::Uuid());

    auto factory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config));

    const size_t size = 2;
    MessagePtr original(factory->CreateMessage(size));
    memcpy(original->GetData(), "AB", size);

    {
        MessagePtr copy(factory->CreateMessage());
        copy->Copy(*original);

        // same size and same address: the payload was shared, not duplicated
        EXPECT_EQ(original->GetSize(), copy->GetSize());
        EXPECT_EQ(original->GetData(), copy->GetData());
        EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 2);
        EXPECT_EQ(static_cast<const shmem::Message&>(*copy).GetRefCount(), 2);

        // buffer must be still intact
        ASSERT_EQ(AsStringView(*original)[0], 'A');
        ASSERT_EQ(AsStringView(*original)[1], 'B');
        ASSERT_EQ(AsStringView(*copy)[0], 'A');
        ASSERT_EQ(AsStringView(*copy)[1], 'B');
    }

    // the copy went out of scope, only the original holds the buffer now
    EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 1);
}
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
// Checks Copy() on unmanaged region messages: copies share the region buffer, the reference count
// follows the lifetime of every copy, and the count stays reachable from a transport that uses a
// different segment id.
auto ZeroCopyFromUnmanaged(string const& address) -> void
{
ProgOptions config1;
ProgOptions config2;
string session(tools::Uuid());
config1.SetProperty<string>("session", session);
config2.SetProperty<string>("session", session);
// ref counts should be accessible across different segments
config2.SetProperty<uint16_t>("shm-segment-id", 2);
auto factory1(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config1));
auto factory2(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config2));
const size_t msgSize{100};
const size_t regionSize{1000000};
tools::Semaphore blocker;
// the region callback signals the semaphore that is waited on at the end of the test
auto region = factory1->CreateUnmanagedRegion(regionSize, [&blocker](void*, size_t, void*) {
blocker.Signal();
});
{
FairMQChannel push("Push", "push", factory1);
FairMQChannel pull("Pull", "pull", factory2);
push.Bind(address);
pull.Connect(address);
// two non-overlapping blocks of the same region: msg1 at offset 0, msg2 at offset 100
const size_t offset = 100;
auto msg1(push.NewMessage(region, static_cast<char*>(region->GetData()), msgSize, nullptr));
auto msg2(push.NewMessage(region, static_cast<char*>(region->GetData()) + offset, msgSize, nullptr));
const size_t contentSize = 2;
memcpy(msg1->GetData(), "AB", contentSize);
memcpy(msg2->GetData(), "CD", contentSize);
// not shared yet
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
{
auto copyFromOriginal(push.NewMessage());
copyFromOriginal->Copy(*msg1);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromOriginal).GetRefCount());
{
// a copy of a copy must end up on the same counter as the original
auto copyFromCopy(push.NewMessage());
copyFromCopy->Copy(*copyFromOriginal);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 3);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromCopy).GetRefCount());
EXPECT_EQ(msg1->GetSize(), copyFromOriginal->GetSize());
EXPECT_EQ(msg1->GetData(), copyFromOriginal->GetData());
EXPECT_EQ(msg1->GetSize(), copyFromCopy->GetSize());
EXPECT_EQ(msg1->GetData(), copyFromCopy->GetData());
EXPECT_EQ(copyFromOriginal->GetSize(), copyFromCopy->GetSize());
EXPECT_EQ(copyFromOriginal->GetData(), copyFromCopy->GetData());
// messing with the ref count should not have affected the user buffer
ASSERT_EQ(AsStringView(*msg1)[0], 'A');
ASSERT_EQ(AsStringView(*msg1)[1], 'B');
// send order: the shared copy of msg1 first, then the never-shared msg2
push.Send(copyFromCopy);
push.Send(msg2);
auto incomingCopiedMsg(pull.NewMessage());
auto incomingOriginalMsg(pull.NewMessage());
// received in send order: incomingCopiedMsg is the msg1 copy, incomingOriginalMsg is msg2
pull.Receive(incomingCopiedMsg);
pull.Receive(incomingOriginalMsg);
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingCopiedMsg).GetRefCount(), 3);
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[0], 'A');
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[1], 'B');
{
// copying on a different segment should work
auto copyFromIncoming(pull.NewMessage());
copyFromIncoming->Copy(*incomingOriginalMsg);
EXPECT_EQ(static_cast<const shmem::Message&>(*copyFromIncoming).GetRefCount(), 2);
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[0], 'C');
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[1], 'D');
}
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
}
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
}
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
}
// presumably one signal per released region block (the msg1 block and the msg2 block) - confirm
blocker.Wait();
blocker.Wait();
}
TEST(Resize, zeromq) // NOLINT
{
RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize");
@ -267,4 +393,14 @@ TEST(EmptyMessage, shmem) // NOLINT
EmptyMessage("shmem", "ipc://test_empty_message");
}
// Runs the managed-segment zero-copy check (ZeroCopy creates its own shmem transport).
TEST(ZeroCopy, shmem) // NOLINT
{
ZeroCopy();
}
// Runs the unmanaged-region zero-copy check over a dedicated ipc address.
TEST(ZeroCopyFromUnmanaged, shmem) // NOLINT
{
ZeroCopyFromUnmanaged("ipc://test_zerocopy_unmanaged");
}
} // namespace

View File

@ -199,7 +199,6 @@ void RegionCallbacks(const string& transport, const string& _address)
});
ptr2 = region2->GetData();
{
FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get()));
FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));