diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index 46e25953..502034e9 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -38,10 +38,17 @@ class Message final : public fair::mq::Message Message(Manager& manager, fair::mq::TransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) - , fQueued(false) - , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true} , fRegionPtr(nullptr) , fLocalPtr(nullptr) + , fSize(0) + , fHint(0) + , fHandle(-1) + , fShared(-1) + , fRegionId(0) + , fSegmentId(fManager.GetSegmentId()) + , fAlignment(0) + , fManaged(true) + , fQueued(false) { fManager.IncrementMsgCounter(); } @@ -49,11 +56,17 @@ class Message final : public fair::mq::Message Message(Manager& manager, Alignment alignment, fair::mq::TransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) - , fQueued(false) - , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true} - , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) + , fSize(0) + , fHint(0) + , fHandle(-1) + , fShared(-1) + , fRegionId(0) + , fSegmentId(fManager.GetSegmentId()) + , fAlignment(alignment.alignment) + , fManaged(true) + , fQueued(false) { fManager.IncrementMsgCounter(); } @@ -61,10 +74,17 @@ class Message final : public fair::mq::Message Message(Manager& manager, const size_t size, fair::mq::TransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) - , fQueued(false) - , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true} , fRegionPtr(nullptr) , fLocalPtr(nullptr) + , fSize(0) + , fHint(0) + , fHandle(-1) + , fShared(-1) + , fRegionId(0) + , fSegmentId(fManager.GetSegmentId()) + , fAlignment(0) + , fManaged(true) + , fQueued(false) { InitializeChunk(size); fManager.IncrementMsgCounter(); @@ -73,11 +93,17 @@ class Message final : public fair::mq::Message Message(Manager& manager, const size_t size, Alignment alignment, fair::mq::TransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) - , fQueued(false) - , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true} - , fAlignment(alignment.alignment) , fRegionPtr(nullptr) , fLocalPtr(nullptr) + , fSize(0) + , fHint(0) + , fHandle(-1) + , fShared(-1) + , fRegionId(0) + , fSegmentId(fManager.GetSegmentId()) + , fAlignment(alignment.alignment) + , fManaged(true) + , fQueued(false) { InitializeChunk(size, fAlignment); fManager.IncrementMsgCounter(); @@ -86,10 +112,17 @@ class Message final : public fair::mq::Message Message(Manager& manager, void* data, const size_t size, fair::mq::FreeFn* ffn, void* hint = nullptr, fair::mq::TransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) - , fQueued(false) - , fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true} , fRegionPtr(nullptr) , fLocalPtr(nullptr) + , fSize(0) + , fHint(0) + , fHandle(-1) + , fShared(-1) + , fRegionId(0) + , fSegmentId(fManager.GetSegmentId()) + , fAlignment(0) + , fManaged(true) + , fQueued(false) { if (InitializeChunk(size)) { std::memcpy(fLocalPtr, data, size); @@ -105,10 +138,17 @@ class Message final : public fair::mq::Message Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, fair::mq::TransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) - , fQueued(false) - , fMeta{size, reinterpret_cast(hint), -1, -1, static_cast(region.get())->fRegionId, fManager.GetSegmentId(), false} , fRegionPtr(nullptr) , fLocalPtr(static_cast(data)) + , fSize(size) + , fHint(reinterpret_cast(hint)) + , fHandle(-1) + , fShared(-1) + , fRegionId(static_cast(region.get())->fRegionId) + , fSegmentId(fManager.GetSegmentId()) + , fAlignment(0) + , fManaged(false) + , fQueued(false) { if (region->GetType() != GetType()) { LOG(error) << "region type (" << region->GetType() << ") does not match message type (" << GetType() << ")"; @@ -117,7 +157,7 @@ class Message final : public fair::mq::Message if (reinterpret_cast(data) >= reinterpret_cast(region->GetData()) && reinterpret_cast(data) <= reinterpret_cast(region->GetData()) + region->GetSize()) { - fMeta.fHandle = (boost::interprocess::managed_shared_memory::handle_t)(reinterpret_cast(data) - reinterpret_cast(region->GetData())); + fHandle = (boost::interprocess::managed_shared_memory::handle_t)(reinterpret_cast(data) - reinterpret_cast(region->GetData())); } else { LOG(error) << "trying to create region message with data from outside the region"; throw TransportError("trying to create region message with data from outside the region"); @@ -128,10 +168,17 @@ class Message final : public fair::mq::Message Message(Manager& manager, MetaHeader& hdr, fair::mq::TransportFactory* factory = nullptr) : fair::mq::Message(factory) , fManager(manager) - , fQueued(false) - , fMeta{hdr} , fRegionPtr(nullptr) , fLocalPtr(nullptr) + , fSize(hdr.fSize) + , fHint(hdr.fHint) + , fHandle(hdr.fHandle) + , fShared(hdr.fShared) + , fRegionId(hdr.fRegionId) + , fSegmentId(hdr.fSegmentId) + , fAlignment(0) + , fManaged(hdr.fManaged) + , fQueued(false) { fManager.IncrementMsgCounter(); } @@ -187,17 +234,17 @@ class Message final : public fair::mq::Message void* GetData() const override { if (!fLocalPtr) { - if (fMeta.fManaged) { - if (fMeta.fSize > 0) { - fManager.GetSegment(fMeta.fSegmentId); - fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); + if (fManaged) { + if (fSize > 0) { + fManager.GetSegment(fSegmentId); + fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fHandle, fSegmentId)); } else { fLocalPtr = nullptr; } } else { - fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId); + fRegionPtr = fManager.GetRegionFromCache(fRegionId); if (fRegionPtr) { - fLocalPtr = reinterpret_cast(fRegionPtr->GetData()) + fMeta.fHandle; + fLocalPtr = reinterpret_cast(fRegionPtr->GetData()) + fHandle; } else { // LOG(warn) << "could not get pointer from a region message"; fLocalPtr = nullptr; @@ -208,37 +255,37 @@ class Message final : public fair::mq::Message return static_cast(fLocalPtr); } - size_t GetSize() const override { return fMeta.fSize; } + size_t GetSize() const override { return fSize; } bool SetUsedSize(size_t newSize) override { - if (newSize == fMeta.fSize) { + if (newSize == fSize) { return true; } else if (newSize == 0) { Deallocate(); return true; - } else if (newSize <= fMeta.fSize) { + } else if (newSize <= fSize) { try { try { - char* oldPtr = fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId); + char* oldPtr = fManager.GetAddressFromHandle(fHandle, fSegmentId); uint16_t userOffset = ShmHeader::UserOffset(oldPtr); - char* ptr = fManager.ShrinkInPlace(userOffset + newSize, oldPtr, fMeta.fSegmentId); + char* ptr = fManager.ShrinkInPlace(userOffset + newSize, oldPtr, fSegmentId); fLocalPtr = ShmHeader::UserPtr(ptr); - fMeta.fSize = newSize; + fSize = newSize; return true; } catch (boost::interprocess::bad_alloc& e) { // if shrinking fails (can happen due to boost alignment requirements): // unused size >= 1000000 bytes: reallocate fully // unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction - if (fMeta.fSize - newSize >= 1000000) { + if (fSize - newSize >= 1000000) { char* ptr = fManager.Allocate(newSize, fAlignment); char* userPtr = ShmHeader::UserPtr(ptr); std::memcpy(userPtr, fLocalPtr, newSize); - fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); + fManager.Deallocate(fHandle, fSegmentId); fLocalPtr = userPtr; - fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId); + fHandle = fManager.GetHandleFromAddress(ptr, fSegmentId); } - fMeta.fSize = newSize; + fSize = newSize; return true; } } catch (boost::interprocess::interprocess_exception& e) { @@ -255,100 +302,123 @@ class Message final : public fair::mq::Message uint16_t GetRefCount() const { - if (fMeta.fHandle < 0) { + if (fHandle < 0) { return 1; } - if (fMeta.fManaged) { // managed segment - fManager.GetSegment(fMeta.fSegmentId); - return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); + if (fManaged) { // managed segment + fManager.GetSegment(fSegmentId); + return ShmHeader::RefCount(fManager.GetAddressFromHandle(fHandle, fSegmentId)); } - if (fMeta.fShared < 0) { // UR msg is not yet shared + if (fShared < 0) { // UR msg is not yet shared return 1; } - fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId); + fRegionPtr = fManager.GetRegionFromCache(fRegionId); if (!fRegionPtr) { - throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fMeta.fRegionId)); + throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fRegionId)); } - return fRegionPtr->GetRefCountAddressFromHandle(fMeta.fShared)->Get(); + return fRegionPtr->GetRefCountAddressFromHandle(fShared)->Get(); } void Copy(const fair::mq::Message& other) override { const Message& otherMsg = static_cast(other); // if the other message is not initialized, close this one too and return - if (otherMsg.fMeta.fHandle < 0) { + if (otherMsg.fHandle < 0) { CloseMessage(); return; } // if this msg is already initialized, close it first - if (fMeta.fHandle >= 0) { + if (fHandle >= 0) { CloseMessage(); } // increment ref count - if (otherMsg.fMeta.fManaged) { // msg in managed segment - fManager.GetSegment(otherMsg.fMeta.fSegmentId); - ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(otherMsg.fMeta.fHandle, otherMsg.fMeta.fSegmentId)); + if (otherMsg.fManaged) { // msg in managed segment + fManager.GetSegment(otherMsg.fSegmentId); + ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(otherMsg.fHandle, otherMsg.fSegmentId)); } else { // msg in unmanaged region - fRegionPtr = fManager.GetRegionFromCache(otherMsg.fMeta.fRegionId); + fRegionPtr = fManager.GetRegionFromCache(otherMsg.fRegionId); if (!fRegionPtr) { - throw TransportError(tools::ToString("Cannot get unmanaged region with id ", otherMsg.fMeta.fRegionId)); + throw TransportError(tools::ToString("Cannot get unmanaged region with id ", otherMsg.fRegionId)); } - if (otherMsg.fMeta.fShared < 0) { + if (otherMsg.fShared < 0) { // UR msg not yet shared, create the reference counting object with count 2 - otherMsg.fMeta.fShared = fRegionPtr->HandleFromAddress(&(fRegionPtr->MakeRefCount(2))); + otherMsg.fShared = fRegionPtr->HandleFromAddress(&(fRegionPtr->MakeRefCount(2))); } else { - fRegionPtr->GetRefCountAddressFromHandle(otherMsg.fMeta.fShared)->Increment(); + fRegionPtr->GetRefCountAddressFromHandle(otherMsg.fShared)->Increment(); } } // copy meta data - fMeta = otherMsg.fMeta; + fSize = otherMsg.fSize; + fHint = otherMsg.fHint; + fHandle = otherMsg.fHandle; + fShared = otherMsg.fShared; + fRegionId = otherMsg.fRegionId; + fSegmentId = otherMsg.fSegmentId; + fManaged = otherMsg.fManaged; } ~Message() override { CloseMessage(); } private: Manager& fManager; - bool fQueued; - MetaHeader fMeta; - size_t fAlignment; mutable UnmanagedRegion* fRegionPtr; mutable char* fLocalPtr; + size_t fSize; // size of the shm buffer + size_t fHint; // user-defined value, given by the user on message creation and returned to the user on "buffer no longer needed"-callbacks + boost::interprocess::managed_shared_memory::handle_t fHandle; // handle to shm buffer, convertible to shm buffer ptr + mutable boost::interprocess::managed_shared_memory::handle_t fShared; // handle to the buffer storing the ref count for shared buffers + uint16_t fRegionId; // id of the unmanaged region + mutable uint16_t fSegmentId; // id of the managed segment + size_t fAlignment; + bool fManaged; // true = managed segment, false = unmanaged region + bool fQueued; + + void SetMeta(const MetaHeader& meta) + { + fSize = meta.fSize; + fHint = meta.fHint; + fHandle = meta.fHandle; + fShared = meta.fShared; + fRegionId = meta.fRegionId; + fSegmentId = meta.fSegmentId; + fManaged = meta.fManaged; + } char* InitializeChunk(const size_t size, size_t alignment = 0) { if (size == 0) { - fMeta.fSize = 0; + fSize = 0; return fLocalPtr; } char* ptr = fManager.Allocate(size, alignment); - fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId); - fMeta.fSize = size; + fHandle = fManager.GetHandleFromAddress(ptr, fSegmentId); + fSize = size; fLocalPtr = ShmHeader::UserPtr(ptr); return fLocalPtr; } void Deallocate() { - if (fMeta.fHandle >= 0 && !fQueued) { - if (fMeta.fManaged) { // managed segment - fManager.GetSegment(fMeta.fSegmentId); - uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId)); + if (fHandle >= 0 && !fQueued) { + if (fManaged) { // managed segment + fManager.GetSegment(fSegmentId); + uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fHandle, fSegmentId)); if (refCount == 1) { - fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId); + fManager.Deallocate(fHandle, fSegmentId); } } else { // unmanaged region - if (fMeta.fShared >= 0) { - fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId); + if (fShared >= 0) { + fRegionPtr = fManager.GetRegionFromCache(fRegionId); if (!fRegionPtr) { - throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fMeta.fRegionId)); + throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fRegionId)); } - uint16_t refCount = fRegionPtr->GetRefCountAddressFromHandle(fMeta.fShared)->Decrement(); + uint16_t refCount = fRegionPtr->GetRefCountAddressFromHandle(fShared)->Decrement(); if (refCount == 1) { - fRegionPtr->RemoveRefCount(*(fRegionPtr->GetRefCountAddressFromHandle(fMeta.fShared))); + fRegionPtr->RemoveRefCount(*(fRegionPtr->GetRefCountAddressFromHandle(fShared))); ReleaseUnmanagedRegionBlock(); } } else { @@ -356,21 +426,21 @@ class Message final : public fair::mq::Message } } } - fMeta.fHandle = -1; + fHandle = -1; fLocalPtr = nullptr; - fMeta.fSize = 0; + fSize = 0; } void ReleaseUnmanagedRegionBlock() { if (!fRegionPtr) { - fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId); + fRegionPtr = fManager.GetRegionFromCache(fRegionId); } if (fRegionPtr) { - fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint}); + fRegionPtr->ReleaseBlock({fHandle, fSize, fHint}); } else { - LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack"; + LOG(warn) << "region ack queue for id " << fRegionId << " no longer exist. Not sending ack"; } } diff --git a/fairmq/shmem/Socket.h b/fairmq/shmem/Socket.h index 83b99af3..5aa579e5 100644 --- a/fairmq/shmem/Socket.h +++ b/fairmq/shmem/Socket.h @@ -129,9 +129,11 @@ class Socket final : public fair::mq::Socket } int elapsed = 0; + MetaHeader meta{ shmMsg->fSize, shmMsg->fHint, shmMsg->fHandle, shmMsg->fShared, shmMsg->fRegionId, shmMsg->fSegmentId, shmMsg->fManaged }; + // meta msg format: | MetaHeader | padded to fMetadataMsgSize | zmq::ZMsg zmqMsg(std::max(fMetadataMsgSize, sizeof(MetaHeader))); - std::memcpy(zmqMsg.Data(), &(shmMsg->fMeta), sizeof(MetaHeader)); + std::memcpy(zmqMsg.Data(), &meta, sizeof(MetaHeader)); while (true) { int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags); @@ -167,7 +169,8 @@ class Socket final : public fair::mq::Socket while (true) { Message* shmMsg = static_cast(msg.get()); - int nbytes = zmq_recv(fSocket, &(shmMsg->fMeta), sizeof(MetaHeader), flags); + MetaHeader meta; + int nbytes = zmq_recv(fSocket, &meta, sizeof(MetaHeader), flags); if (nbytes > 0) { // check for number of received messages. must be 1 if (static_cast(nbytes) < sizeof(MetaHeader)) { @@ -177,6 +180,8 @@ class Socket final : public fair::mq::Socket "Expected minimum size of ", sizeof(MetaHeader), " bytes, received ", nbytes)); } + shmMsg->SetMeta(meta); + size_t size = shmMsg->GetSize(); fBytesRx += size; ++fMessagesRx; @@ -218,7 +223,8 @@ class Socket final : public fair::mq::Socket } assertm(dynamic_cast(msgPtr), "given mq::Message is a shmem::Message"); // NOLINT auto shmMsg = static_cast(msgPtr); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast) - std::memcpy(metas++, &(shmMsg->fMeta), sizeof(MetaHeader)); + MetaHeader meta{ shmMsg->fSize, shmMsg->fHint, shmMsg->fHandle, shmMsg->fShared, shmMsg->fRegionId, shmMsg->fSegmentId, shmMsg->fManaged }; + std::memcpy(metas++, &meta, sizeof(MetaHeader)); } while (true) { @@ -230,7 +236,7 @@ class Socket final : public fair::mq::Socket for (auto& msg : msgVec) { Message* shmMsg = static_cast(msg.get()); shmMsg->fQueued = true; - totalSize += shmMsg->fMeta.fSize; + totalSize += shmMsg->fSize; } // store statistics on how many messages have been sent