diff --git a/fairmq/shmem/Common.h b/fairmq/shmem/Common.h index b17c3882..21666d71 100644 --- a/fairmq/shmem/Common.h +++ b/fairmq/shmem/Common.h @@ -339,4 +339,13 @@ struct SegmentBufferShrink } // namespace fair::mq::shmem +namespace fair::mq { class TransportFactory; } + +namespace fair::mq::shmem { +// Resolve a MetaHeader (received over a side channel) to the local data pointer. +// The caller is responsible for ensuring the backing buffer remains alive for the +// duration of access; FairMQ provides no refcount protection for this path. +char* GetDataAddressFromHandle(fair::mq::TransportFactory& factory, const MetaHeader& meta); +} // namespace fair::mq::shmem + #endif /* FAIR_MQ_SHMEM_COMMON_H_ */ diff --git a/fairmq/shmem/Manager.cxx b/fairmq/shmem/Manager.cxx index efdbf094..083e6c0e 100644 --- a/fairmq/shmem/Manager.cxx +++ b/fairmq/shmem/Manager.cxx @@ -7,6 +7,7 @@ ********************************************************************************/ #include "Manager.h" +#include "TransportFactory.h" // Needed to compile-firewall the header because it // interferes with the header. So, let's factor @@ -51,4 +52,12 @@ bool Manager::SpawnShmMonitor(const std::string& id) return true; } +char* GetDataAddressFromHandle(fair::mq::TransportFactory& factory, const MetaHeader& meta) +{ + if (factory.GetType() != fair::mq::Transport::SHM) { + throw SharedMemoryError("GetDataAddressFromHandle called on a non-shmem transport"); + } + return static_cast(factory).GetDataAddressFromHandle(meta); +} + } // namespace fair::mq::shmem diff --git a/fairmq/shmem/Manager.h b/fairmq/shmem/Manager.h index 72de47b8..60cee0df 100644 --- a/fairmq/shmem/Manager.h +++ b/fairmq/shmem/Manager.h @@ -776,6 +776,30 @@ class Manager auto GetMetadataMsgSize() const noexcept { return fMetadataMsgSize; } + // Resolve a MetaHeader (received over a side channel) to the local data pointer. + // The caller is responsible for ensuring the backing buffer remains alive for the + // duration of access; FairMQ provides no refcount protection for this path. + char* GetDataAddressFromHandle(const MetaHeader& meta) + { + if (meta.fManaged) { + if (meta.fSize == 0) { + return nullptr; + } + GetSegment(meta.fSegmentId); + auto it = fSegments.find(meta.fSegmentId); + if (it == fSegments.end()) { + throw SharedMemoryError(tools::ToString("GetDataAddressFromHandle: cannot open segment with id ", meta.fSegmentId)); + } + return ShmHeader::UserPtr(GetAddressFromHandle(meta.fHandle, meta.fSegmentId)); + } else { + UnmanagedRegion* region = GetRegionFromCache(meta.fRegionId); + if (!region) { + throw SharedMemoryError(tools::ToString("GetDataAddressFromHandle: cannot get unmanaged region with id ", meta.fRegionId)); + } + return reinterpret_cast(region->GetData()) + meta.fHandle; + } + } + ~Manager() { fRegionsGen += 1; // signal TL cache invalidation diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index 9c75b249..c2fdc1e1 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -167,6 +167,11 @@ class Message final : public fair::mq::Message } } + MetaHeader GetMeta() const + { + return {fSize, fHint, fHandle, fShared, fRegionId, fSegmentId, fManaged}; + } + void* GetData() const override { if (!fLocalPtr) { diff --git a/fairmq/shmem/TransportFactory.h b/fairmq/shmem/TransportFactory.h index 60850245..9e78b041 100644 --- a/fairmq/shmem/TransportFactory.h +++ b/fairmq/shmem/TransportFactory.h @@ -201,6 +201,8 @@ class TransportFactory final : public fair::mq::TransportFactory void Resume() override { fManager->Resume(); } void Reset() override { fManager->Reset(); } + char* GetDataAddressFromHandle(const MetaHeader& meta) { return fManager->GetDataAddressFromHandle(meta); } + ~TransportFactory() override { LOG(debug) << "Destroying Shared Memory transport..."; diff --git a/test/message/_message.cxx b/test/message/_message.cxx index 1757df3d..53158985 100644 --- a/test/message/_message.cxx +++ b/test/message/_message.cxx @@ -453,6 +453,46 @@ TEST(EmptyMessage, shmem_expanded_metadata) // NOLINT EmptyMessage("shmem", "ipc://test_empty_message", true); } +// GetMeta() + GetDataAddressFromHandle() round-trip: simulates a consumer device that +// receives message metadata over a side channel and resolves it to a local pointer. +// Uses a second factory on the same session to exercise the open_only segment attach path. +auto SideChannel(bool expandedShmMetadata = false) -> void +{ + const string session = tools::Uuid(); + + ProgOptions producerConfig; + producerConfig.SetProperty("session", session); + producerConfig.SetProperty("shm-segment-size", 100000000); + producerConfig.SetProperty("shm-monitor", true); + if (expandedShmMetadata) { + producerConfig.SetProperty("shm-metadata-msg-size", 2048); + } + auto producerFactory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &producerConfig)); + + ProgOptions consumerConfig; + consumerConfig.SetProperty("session", session); + consumerConfig.SetProperty("shm-segment-size", 100000000); + consumerConfig.SetProperty("shm-monitor", true); + auto consumerFactory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &consumerConfig)); + + const size_t size = 2; + MessagePtr msg(producerFactory->CreateMessage(size)); + memcpy(msg->GetData(), "AB", size); + + // producer side: extract metadata to send over the side channel + const auto meta = static_cast(*msg).GetMeta(); + EXPECT_EQ(meta.fSize, size); + EXPECT_TRUE(meta.fManaged); + + // consumer side: resolve via a different factory — exercises open_only segment attach. + // The virtual address differs between factory mappings; only the content must match. + char* ptr = shmem::GetDataAddressFromHandle(*consumerFactory, meta); + + ASSERT_NE(ptr, nullptr); + EXPECT_EQ(ptr[0], 'A'); + EXPECT_EQ(ptr[1], 'B'); +} + TEST(ZeroCopy, shmem) // NOLINT { ZeroCopy(); @@ -463,6 +503,93 @@ TEST(ZeroCopy, shmem_expanded_metadata) // NOLINT ZeroCopy(true); } +// Uses a second factory on the same session to exercise the remote region lookup path. +auto SideChannelUnmanaged() -> void +{ + const string session = tools::Uuid(); + + ProgOptions producerConfig; + producerConfig.SetProperty("session", session); + producerConfig.SetProperty("shm-segment-size", 100000000); + producerConfig.SetProperty("shm-monitor", true); + auto producerFactory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &producerConfig)); + + ProgOptions consumerConfig; + consumerConfig.SetProperty("session", session); + consumerConfig.SetProperty("shm-segment-size", 100000000); + consumerConfig.SetProperty("shm-monitor", true); + auto consumerFactory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &consumerConfig)); + + const size_t regionSize = 1000000; + tools::Semaphore blocker; + auto region = producerFactory->CreateUnmanagedRegion(regionSize, [&blocker](void*, size_t, void*) { + blocker.Signal(); + }); + + const size_t size = 2; + auto msg(producerFactory->CreateMessage(region, static_cast(region->GetData()), size, nullptr)); + memcpy(msg->GetData(), "AB", size); + + const auto meta = static_cast(*msg).GetMeta(); + EXPECT_EQ(meta.fSize, size); + EXPECT_FALSE(meta.fManaged); + + // consumer resolves via its own factory — exercises remote region lookup. + // The virtual address differs between factory mappings; only the content must match. + char* ptr = shmem::GetDataAddressFromHandle(*consumerFactory, meta); + + ASSERT_NE(ptr, nullptr); + EXPECT_EQ(ptr[0], 'A'); + EXPECT_EQ(ptr[1], 'B'); + + msg.reset(); + blocker.Wait(); +} + +auto SideChannelErrors() -> void +{ + ProgOptions config; + config.SetProperty("session", tools::Uuid()); + config.SetProperty("shm-segment-size", 100000000); + config.SetProperty("shm-monitor", true); + auto shmFactory(TransportFactory::CreateTransportFactory("shmem", tools::Uuid(), &config)); + auto zmqFactory(TransportFactory::CreateTransportFactory("zeromq", tools::Uuid(), &config)); + + // non-shmem factory must throw + shmem::MetaHeader meta{}; + meta.fManaged = true; + EXPECT_THROW(shmem::GetDataAddressFromHandle(*zmqFactory, meta), shmem::SharedMemoryError); + + // bad segment id must throw + meta.fSize = 1; + meta.fSegmentId = 999; + EXPECT_THROW(shmem::GetDataAddressFromHandle(*shmFactory, meta), shmem::SharedMemoryError); + + // zero-size managed message must return nullptr (mirrors Message::GetData()) + meta.fSize = 0; + EXPECT_EQ(shmem::GetDataAddressFromHandle(*shmFactory, meta), nullptr); +} + +TEST(SideChannel, shmem) // NOLINT +{ + SideChannel(); +} + +TEST(SideChannel, shmem_expanded_metadata) // NOLINT +{ + SideChannel(true); +} + +TEST(SideChannel, shmem_unmanaged) // NOLINT +{ + SideChannelUnmanaged(); +} + +TEST(SideChannel, errors) // NOLINT +{ + SideChannelErrors(); +} + TEST(ZeroCopyFromUnmanaged, shmem) // NOLINT { ZeroCopyFromUnmanaged("ipc://test_zerocopy_unmanaged", false, 10000000);