Implement shmem msg zero-copy

This commit is contained in:
Alexey Rybalchenko
2021-07-14 10:46:12 +02:00
committed by Dennis Klein
parent c57410b820
commit bce380d871
7 changed files with 333 additions and 100 deletions

View File

@@ -89,7 +89,7 @@ add_testsuite(Message
${CMAKE_CURRENT_BINARY_DIR}/runner.cxx
message/_message.cxx
LINKS FairMQ
LINKS FairMQ PicoSHA2
INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_SOURCE_DIR}/message
${CMAKE_CURRENT_BINARY_DIR}

View File

@@ -6,19 +6,23 @@
* copied verbatim in the file "LICENSE" *
********************************************************************************/
#include <array>
#include <cassert>
#include <cstdint>
#include <fairlogger/Logger.h>
#include <fairmq/Channel.h>
#include <fairmq/ProgOptions.h>
#include <fairmq/TransportFactory.h>
#include <fairmq/tools/Semaphore.h>
#include <fairmq/tools/Strings.h>
#include <fairmq/tools/Unique.h>
#include <fairmq/TransportFactory.h>
#include <fairmq/shmem/Message.h>
#include <gtest/gtest.h>
#include <array>
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <string>
#include <utility>
namespace
@@ -190,7 +194,6 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
push.Bind(address);
pull.Connect(address);
{
auto outMsg(push.NewMessage());
ASSERT_EQ(outMsg->GetData(), nullptr);
@@ -227,6 +230,129 @@ auto EmptyMessage(string const& transport, string const& _address) -> void
}
}
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
auto ZeroCopy() -> void
{
ProgOptions config;
config.SetProperty<string>("session", tools::Uuid());
auto factory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config));
unique_ptr<string> str(make_unique<string>("asdf"));
const size_t size = 2;
MessagePtr original(factory->CreateMessage(size));
memcpy(original->GetData(), "AB", size);
{
MessagePtr copy(factory->CreateMessage());
copy->Copy(*original);
EXPECT_EQ(original->GetSize(), copy->GetSize());
EXPECT_EQ(original->GetData(), copy->GetData());
EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 2);
EXPECT_EQ(static_cast<const shmem::Message&>(*copy).GetRefCount(), 2);
// buffer must be still intact
ASSERT_EQ(AsStringView(*original)[0], 'A');
ASSERT_EQ(AsStringView(*original)[1], 'B');
ASSERT_EQ(AsStringView(*copy)[0], 'A');
ASSERT_EQ(AsStringView(*copy)[1], 'B');
}
EXPECT_EQ(static_cast<const shmem::Message&>(*original).GetRefCount(), 1);
}
// The "zero copy" property of the Copy() method is an implementation detail and is not guaranteed.
// Currently it holds true for the shmem (across devices) and for zeromq (within same device) transports.
auto ZeroCopyFromUnmanaged(string const& address) -> void
{
ProgOptions config1;
ProgOptions config2;
string session(tools::Uuid());
config1.SetProperty<string>("session", session);
config2.SetProperty<string>("session", session);
// ref counts should be accessible accross different segments
config2.SetProperty<uint16_t>("shm-segment-id", 2);
auto factory1(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config1));
auto factory2(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config2));
const size_t msgSize{100};
const size_t regionSize{1000000};
tools::Semaphore blocker;
auto region = factory1->CreateUnmanagedRegion(regionSize, [&blocker](void*, size_t, void*) {
blocker.Signal();
});
{
FairMQChannel push("Push", "push", factory1);
FairMQChannel pull("Pull", "pull", factory2);
push.Bind(address);
pull.Connect(address);
const size_t offset = 100;
auto msg1(push.NewMessage(region, static_cast<char*>(region->GetData()), msgSize, nullptr));
auto msg2(push.NewMessage(region, static_cast<char*>(region->GetData()) + offset, msgSize, nullptr));
const size_t contentSize = 2;
memcpy(msg1->GetData(), "AB", contentSize);
memcpy(msg2->GetData(), "CD", contentSize);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
{
auto copyFromOriginal(push.NewMessage());
copyFromOriginal->Copy(*msg1);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromOriginal).GetRefCount());
{
auto copyFromCopy(push.NewMessage());
copyFromCopy->Copy(*copyFromOriginal);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 3);
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), static_cast<const shmem::Message&>(*copyFromCopy).GetRefCount());
EXPECT_EQ(msg1->GetSize(), copyFromOriginal->GetSize());
EXPECT_EQ(msg1->GetData(), copyFromOriginal->GetData());
EXPECT_EQ(msg1->GetSize(), copyFromCopy->GetSize());
EXPECT_EQ(msg1->GetData(), copyFromCopy->GetData());
EXPECT_EQ(copyFromOriginal->GetSize(), copyFromCopy->GetSize());
EXPECT_EQ(copyFromOriginal->GetData(), copyFromCopy->GetData());
// messing with the ref count should not have affected the user buffer
ASSERT_EQ(AsStringView(*msg1)[0], 'A');
ASSERT_EQ(AsStringView(*msg1)[1], 'B');
push.Send(copyFromCopy);
push.Send(msg2);
auto incomingCopiedMsg(pull.NewMessage());
auto incomingOriginalMsg(pull.NewMessage());
pull.Receive(incomingCopiedMsg);
pull.Receive(incomingOriginalMsg);
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingCopiedMsg).GetRefCount(), 3);
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[0], 'A');
ASSERT_EQ(AsStringView(*incomingCopiedMsg)[1], 'B');
{
// copying on a different segment should work
auto copyFromIncoming(pull.NewMessage());
copyFromIncoming->Copy(*incomingOriginalMsg);
EXPECT_EQ(static_cast<const shmem::Message&>(*copyFromIncoming).GetRefCount(), 2);
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[0], 'C');
ASSERT_EQ(AsStringView(*incomingOriginalMsg)[1], 'D');
}
EXPECT_EQ(static_cast<const shmem::Message&>(*incomingOriginalMsg).GetRefCount(), 1);
}
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 2);
}
EXPECT_EQ(static_cast<const shmem::Message&>(*msg1).GetRefCount(), 1);
}
blocker.Wait();
blocker.Wait();
}
TEST(Resize, zeromq) // NOLINT
{
RunPushPullWithMsgResize("zeromq", "ipc://test_message_resize");
@@ -267,4 +393,14 @@ TEST(EmptyMessage, shmem) // NOLINT
EmptyMessage("shmem", "ipc://test_empty_message");
}
TEST(ZeroCopy, shmem) // NOLINT
{
ZeroCopy();
}
TEST(ZeroCopyFromUnmanaged, shmem) // NOLINT
{
ZeroCopyFromUnmanaged("ipc://test_zerocopy_unmanaged");
}
} // namespace

View File

@@ -199,7 +199,6 @@ void RegionCallbacks(const string& transport, const string& _address)
});
ptr2 = region2->GetData();
{
FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get()));
FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get()));