mirror of
https://github.com/FairRootGroup/FairMQ.git
synced 2025-10-15 09:31:45 +00:00
Implement shmem msg zero-copy
This commit is contained in:
committed by
Dennis Klein
parent
c57410b820
commit
bce380d871
@@ -89,7 +89,7 @@ add_testsuite(Message
|
||||
${CMAKE_CURRENT_BINARY_DIR}/runner.cxx
|
||||
message/_message.cxx
|
||||
|
||||
LINKS FairMQ
|
||||
LINKS FairMQ PicoSHA2
|
||||
INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/message
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
|
@@ -6,19 +6,23 @@
|
||||
* copied verbatim in the file "LICENSE" *
|
||||
********************************************************************************/
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <fairlogger/Logger.h>
|
||||
#include <fairmq/Channel.h>
|
||||
#include <fairmq/ProgOptions.h>
|
||||
#include <fairmq/TransportFactory.h>
|
||||
#include <fairmq/tools/Semaphore.h>
|
||||
#include <fairmq/tools/Strings.h>
|
||||
#include <fairmq/tools/Unique.h>
|
||||
#include <fairmq/TransportFactory.h>
|
||||
#include <fairmq/shmem/Message.h>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
namespace
|
||||
@@ -190,7 +194,6 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
|
||||
push.Bind(address);
|
||||
pull.Connect(address);
|
||||
|
||||
|
||||
{
|
||||
auto outMsg(push.NewMessage());
|
||||
ASSERT_EQ(outMsg->GetData(), nullptr);
|
||||
@@ -227,6 +230,129 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
|
||||
}
|
||||
}
|
||||
|
||||
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
|
||||
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
|
||||
auto ZeroCopy() -> void
|
||||
{
|
||||
ProgOptions config;
|
||||
config.SetProperty<string>("session", tools::Uuid());
|
||||
auto factory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config));
|
||||
|
||||
unique_ptr<string> str(make_unique<string>("asdf"));
|
||||
const size_t size = 2;
|
||||
MessagePtr original(factory->CreateMessage(size));
|
||||
memcpy(original->GetData(), "AB", size);
|
||||
{
|
||||
MessagePtr copy(factory->CreateMessage());
|
||||
copy->Copy(*original);
|
||||
EXPECT_EQ(original->GetSize(), copy->GetSize());
|
||||
EXPECT_EQ(original->GetData(), copy->GetData());
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 2);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*copy).GetRefCount(), 2);
|
||||
|
||||
// buffer must be still intact
|
||||
ASSERT_EQ(AsStringView(*original)[0], 'A');
|
||||
ASSERT_EQ(AsStringView(*original)[1], 'B');
|
||||
ASSERT_EQ(AsStringView(*copy)[0], 'A');
|
||||
ASSERT_EQ(AsStringView(*copy)[1], 'B');
|
||||
}
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 1);
|
||||
}
|
||||
|
||||
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
|
||||
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
|
||||
auto ZeroCopyFromUnmanaged(string const& address) -> void
|
||||
{
|
||||
ProgOptions config1;
|
||||
ProgOptions config2;
|
||||
string session(tools::Uuid());
|
||||
config1.SetProperty<string>("session", session);
|
||||
config2.SetProperty<string>("session", session);
|
||||
// ref counts should be accessible accross different segments
|
||||
config2.SetProperty<uint16_t>("shm-segment-id", 2);
|
||||
auto factory1(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config1));
|
||||
auto factory2(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config2));
|
||||
|
||||
const size_t msgSize{100};
|
||||
const size_t regionSize{1000000};
|
||||
tools::Semaphore blocker;
|
||||
|
||||
auto region = factory1->CreateUnmanagedRegion(regionSize, [&blocker](void*, size_t, void*) {
|
||||
blocker.Signal();
|
||||
});
|
||||
|
||||
{
|
||||
FairMQChannel push("Push", "push", factory1);
|
||||
FairMQChannel pull("Pull", "pull", factory2);
|
||||
|
||||
push.Bind(address);
|
||||
pull.Connect(address);
|
||||
|
||||
const size_t offset = 100;
|
||||
auto msg1(push.NewMessage(region, static_cast<char*>(region->GetData()), msgSize, nullptr));
|
||||
auto msg2(push.NewMessage(region, static_cast<char*>(region->GetData()) + offset, msgSize, nullptr));
|
||||
const size_t contentSize = 2;
|
||||
memcpy(msg1->GetData(), "AB", contentSize);
|
||||
memcpy(msg2->GetData(), "CD", contentSize);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
|
||||
|
||||
{
|
||||
auto copyFromOriginal(push.NewMessage());
|
||||
copyFromOriginal->Copy(*msg1);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromOriginal).GetRefCount());
|
||||
{
|
||||
auto copyFromCopy(push.NewMessage());
|
||||
copyFromCopy->Copy(*copyFromOriginal);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 3);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromCopy).GetRefCount());
|
||||
|
||||
EXPECT_EQ(msg1->GetSize(), copyFromOriginal->GetSize());
|
||||
EXPECT_EQ(msg1->GetData(), copyFromOriginal->GetData());
|
||||
EXPECT_EQ(msg1->GetSize(), copyFromCopy->GetSize());
|
||||
EXPECT_EQ(msg1->GetData(), copyFromCopy->GetData());
|
||||
EXPECT_EQ(copyFromOriginal->GetSize(), copyFromCopy->GetSize());
|
||||
EXPECT_EQ(copyFromOriginal->GetData(), copyFromCopy->GetData());
|
||||
|
||||
// messing with the ref count should not have affected the user buffer
|
||||
ASSERT_EQ(AsStringView(*msg1)[0], 'A');
|
||||
ASSERT_EQ(AsStringView(*msg1)[1], 'B');
|
||||
|
||||
push.Send(copyFromCopy);
|
||||
push.Send(msg2);
|
||||
|
||||
auto incomingCopiedMsg(pull.NewMessage());
|
||||
auto incomingOriginalMsg(pull.NewMessage());
|
||||
pull.Receive(incomingCopiedMsg);
|
||||
pull.Receive(incomingOriginalMsg);
|
||||
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingCopiedMsg).GetRefCount(), 3);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
|
||||
|
||||
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[0], 'A');
|
||||
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[1], 'B');
|
||||
|
||||
{
|
||||
// copying on a different segment should work
|
||||
auto copyFromIncoming(pull.NewMessage());
|
||||
copyFromIncoming->Copy(*incomingOriginalMsg);
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*copyFromIncoming).GetRefCount(), 2);
|
||||
|
||||
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[0], 'C');
|
||||
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[1], 'D');
|
||||
}
|
||||
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
|
||||
}
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
|
||||
}
|
||||
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
|
||||
}
|
||||
|
||||
blocker.Wait();
|
||||
blocker.Wait();
|
||||
}
|
||||
|
||||
TEST(Resize, zeromq) // NOLINT
|
||||
{
|
||||
RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize");
|
||||
@@ -267,4 +393,14 @@ TEST(EmptyMessage, shmem) // NOLINT
|
||||
EmptyMessage("shmem", "ipc://test_empty_message");
|
||||
}
|
||||
|
||||
TEST(ZeroCopy, shmem) // NOLINT
|
||||
{
|
||||
ZeroCopy();
|
||||
}
|
||||
|
||||
TEST(ZeroCopyFromUnmanaged, shmem) // NOLINT
|
||||
{
|
||||
ZeroCopyFromUnmanaged("ipc://test_zerocopy_unmanaged");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@@ -199,7 +199,6 @@ void RegionCallbacks(const string& transport, const string& _address)
|
||||
});
|
||||
ptr2 = region2->GetData();
|
||||
|
||||
|
||||
{
|
||||
FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get()));
|
||||
FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));
|
||||
|
Reference in New Issue
Block a user