Refactor shm::Message to contain sorted members of MetaHeader

Move the members of MetaHeader flat into shmem::Message and sort them by
size to reduce the size of the class.
This commit is contained in:
Alexey Rybalchenko 2023-10-17 11:36:16 +02:00
parent f092b94c96
commit 1b7532a520
2 changed files with 153 additions and 77 deletions

View File

@ -38,10 +38,17 @@ class Message final : public fair::mq::Message
Message(Manager& manager, fair::mq::TransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true}
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
, fSize(0)
, fHint(0)
, fHandle(-1)
, fShared(-1)
, fRegionId(0)
, fSegmentId(fManager.GetSegmentId())
, fAlignment(0)
, fManaged(true)
, fQueued(false)
{
fManager.IncrementMsgCounter();
}
@ -49,11 +56,17 @@ class Message final : public fair::mq::Message
Message(Manager& manager, Alignment alignment, fair::mq::TransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
, fSize(0)
, fHint(0)
, fHandle(-1)
, fShared(-1)
, fRegionId(0)
, fSegmentId(fManager.GetSegmentId())
, fAlignment(alignment.alignment)
, fManaged(true)
, fQueued(false)
{
fManager.IncrementMsgCounter();
}
@ -61,10 +74,17 @@ class Message final : public fair::mq::Message
Message(Manager& manager, const size_t size, fair::mq::TransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true}
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
, fSize(0)
, fHint(0)
, fHandle(-1)
, fShared(-1)
, fRegionId(0)
, fSegmentId(fManager.GetSegmentId())
, fAlignment(0)
, fManaged(true)
, fQueued(false)
{
InitializeChunk(size);
fManager.IncrementMsgCounter();
@ -73,11 +93,17 @@ class Message final : public fair::mq::Message
Message(Manager& manager, const size_t size, Alignment alignment, fair::mq::TransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true}
, fAlignment(alignment.alignment)
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
, fSize(0)
, fHint(0)
, fHandle(-1)
, fShared(-1)
, fRegionId(0)
, fSegmentId(fManager.GetSegmentId())
, fAlignment(alignment.alignment)
, fManaged(true)
, fQueued(false)
{
InitializeChunk(size, fAlignment);
fManager.IncrementMsgCounter();
@ -86,10 +112,17 @@ class Message final : public fair::mq::Message
Message(Manager& manager, void* data, const size_t size, fair::mq::FreeFn* ffn, void* hint = nullptr, fair::mq::TransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{0, 0, -1, -1, 0, fManager.GetSegmentId(), true}
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
, fSize(0)
, fHint(0)
, fHandle(-1)
, fShared(-1)
, fRegionId(0)
, fSegmentId(fManager.GetSegmentId())
, fAlignment(0)
, fManaged(true)
, fQueued(false)
{
if (InitializeChunk(size)) {
std::memcpy(fLocalPtr, data, size);
@ -105,10 +138,17 @@ class Message final : public fair::mq::Message
Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, fair::mq::TransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{size, reinterpret_cast<size_t>(hint), -1, -1, static_cast<UnmanagedRegionImpl*>(region.get())->fRegionId, fManager.GetSegmentId(), false}
, fRegionPtr(nullptr)
, fLocalPtr(static_cast<char*>(data))
, fSize(size)
, fHint(reinterpret_cast<size_t>(hint))
, fHandle(-1)
, fShared(-1)
, fRegionId(static_cast<UnmanagedRegionImpl*>(region.get())->fRegionId)
, fSegmentId(fManager.GetSegmentId())
, fAlignment(0)
, fManaged(false)
, fQueued(false)
{
if (region->GetType() != GetType()) {
LOG(error) << "region type (" << region->GetType() << ") does not match message type (" << GetType() << ")";
@ -117,7 +157,7 @@ class Message final : public fair::mq::Message
if (reinterpret_cast<const char*>(data) >= reinterpret_cast<const char*>(region->GetData()) &&
reinterpret_cast<const char*>(data) <= reinterpret_cast<const char*>(region->GetData()) + region->GetSize()) {
fMeta.fHandle = (boost::interprocess::managed_shared_memory::handle_t)(reinterpret_cast<const char*>(data) - reinterpret_cast<const char*>(region->GetData()));
fHandle = (boost::interprocess::managed_shared_memory::handle_t)(reinterpret_cast<const char*>(data) - reinterpret_cast<const char*>(region->GetData()));
} else {
LOG(error) << "trying to create region message with data from outside the region";
throw TransportError("trying to create region message with data from outside the region");
@ -128,10 +168,17 @@ class Message final : public fair::mq::Message
Message(Manager& manager, MetaHeader& hdr, fair::mq::TransportFactory* factory = nullptr)
: fair::mq::Message(factory)
, fManager(manager)
, fQueued(false)
, fMeta{hdr}
, fRegionPtr(nullptr)
, fLocalPtr(nullptr)
, fSize(hdr.fSize)
, fHint(hdr.fHint)
, fHandle(hdr.fHandle)
, fShared(hdr.fShared)
, fRegionId(hdr.fRegionId)
, fSegmentId(hdr.fSegmentId)
, fAlignment(0)
, fManaged(hdr.fManaged)
, fQueued(false)
{
fManager.IncrementMsgCounter();
}
@ -187,17 +234,17 @@ class Message final : public fair::mq::Message
void* GetData() const override
{
if (!fLocalPtr) {
if (fMeta.fManaged) {
if (fMeta.fSize > 0) {
fManager.GetSegment(fMeta.fSegmentId);
fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
if (fManaged) {
if (fSize > 0) {
fManager.GetSegment(fSegmentId);
fLocalPtr = ShmHeader::UserPtr(fManager.GetAddressFromHandle(fHandle, fSegmentId));
} else {
fLocalPtr = nullptr;
}
} else {
fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId);
fRegionPtr = fManager.GetRegionFromCache(fRegionId);
if (fRegionPtr) {
fLocalPtr = reinterpret_cast<char*>(fRegionPtr->GetData()) + fMeta.fHandle;
fLocalPtr = reinterpret_cast<char*>(fRegionPtr->GetData()) + fHandle;
} else {
// LOG(warn) << "could not get pointer from a region message";
fLocalPtr = nullptr;
@ -208,37 +255,37 @@ class Message final : public fair::mq::Message
return static_cast<void*>(fLocalPtr);
}
size_t GetSize() const override { return fMeta.fSize; }
size_t GetSize() const override { return fSize; }
bool SetUsedSize(size_t newSize) override
{
if (newSize == fMeta.fSize) {
if (newSize == fSize) {
return true;
} else if (newSize == 0) {
Deallocate();
return true;
} else if (newSize <= fMeta.fSize) {
} else if (newSize <= fSize) {
try {
try {
char* oldPtr = fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId);
char* oldPtr = fManager.GetAddressFromHandle(fHandle, fSegmentId);
uint16_t userOffset = ShmHeader::UserOffset(oldPtr);
char* ptr = fManager.ShrinkInPlace(userOffset + newSize, oldPtr, fMeta.fSegmentId);
char* ptr = fManager.ShrinkInPlace(userOffset + newSize, oldPtr, fSegmentId);
fLocalPtr = ShmHeader::UserPtr(ptr);
fMeta.fSize = newSize;
fSize = newSize;
return true;
} catch (boost::interprocess::bad_alloc& e) {
// if shrinking fails (can happen due to boost alignment requirements):
// unused size >= 1000000 bytes: reallocate fully
// unused size < 1000000 bytes: simply reset the size and keep the rest of the buffer until message destruction
if (fMeta.fSize - newSize >= 1000000) {
if (fSize - newSize >= 1000000) {
char* ptr = fManager.Allocate(newSize, fAlignment);
char* userPtr = ShmHeader::UserPtr(ptr);
std::memcpy(userPtr, fLocalPtr, newSize);
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
fManager.Deallocate(fHandle, fSegmentId);
fLocalPtr = userPtr;
fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
fHandle = fManager.GetHandleFromAddress(ptr, fSegmentId);
}
fMeta.fSize = newSize;
fSize = newSize;
return true;
}
} catch (boost::interprocess::interprocess_exception& e) {
@ -255,100 +302,123 @@ class Message final : public fair::mq::Message
uint16_t GetRefCount() const
{
if (fMeta.fHandle < 0) {
if (fHandle < 0) {
return 1;
}
if (fMeta.fManaged) { // managed segment
fManager.GetSegment(fMeta.fSegmentId);
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
if (fManaged) { // managed segment
fManager.GetSegment(fSegmentId);
return ShmHeader::RefCount(fManager.GetAddressFromHandle(fHandle, fSegmentId));
}
if (fMeta.fShared < 0) { // UR msg is not yet shared
if (fShared < 0) { // UR msg is not yet shared
return 1;
}
fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId);
fRegionPtr = fManager.GetRegionFromCache(fRegionId);
if (!fRegionPtr) {
throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fMeta.fRegionId));
throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fRegionId));
}
return fRegionPtr->GetRefCountAddressFromHandle(fMeta.fShared)->Get();
return fRegionPtr->GetRefCountAddressFromHandle(fShared)->Get();
}
void Copy(const fair::mq::Message& other) override
{
const Message& otherMsg = static_cast<const Message&>(other);
// if the other message is not initialized, close this one too and return
if (otherMsg.fMeta.fHandle < 0) {
if (otherMsg.fHandle < 0) {
CloseMessage();
return;
}
// if this msg is already initialized, close it first
if (fMeta.fHandle >= 0) {
if (fHandle >= 0) {
CloseMessage();
}
// increment ref count
if (otherMsg.fMeta.fManaged) { // msg in managed segment
fManager.GetSegment(otherMsg.fMeta.fSegmentId);
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(otherMsg.fMeta.fHandle, otherMsg.fMeta.fSegmentId));
if (otherMsg.fManaged) { // msg in managed segment
fManager.GetSegment(otherMsg.fSegmentId);
ShmHeader::IncrementRefCount(fManager.GetAddressFromHandle(otherMsg.fHandle, otherMsg.fSegmentId));
} else { // msg in unmanaged region
fRegionPtr = fManager.GetRegionFromCache(otherMsg.fMeta.fRegionId);
fRegionPtr = fManager.GetRegionFromCache(otherMsg.fRegionId);
if (!fRegionPtr) {
throw TransportError(tools::ToString("Cannot get unmanaged region with id ", otherMsg.fMeta.fRegionId));
throw TransportError(tools::ToString("Cannot get unmanaged region with id ", otherMsg.fRegionId));
}
if (otherMsg.fMeta.fShared < 0) {
if (otherMsg.fShared < 0) {
// UR msg not yet shared, create the reference counting object with count 2
otherMsg.fMeta.fShared = fRegionPtr->HandleFromAddress(&(fRegionPtr->MakeRefCount(2)));
otherMsg.fShared = fRegionPtr->HandleFromAddress(&(fRegionPtr->MakeRefCount(2)));
} else {
fRegionPtr->GetRefCountAddressFromHandle(otherMsg.fMeta.fShared)->Increment();
fRegionPtr->GetRefCountAddressFromHandle(otherMsg.fShared)->Increment();
}
}
// copy meta data
fMeta = otherMsg.fMeta;
fSize = otherMsg.fSize;
fHint = otherMsg.fHint;
fHandle = otherMsg.fHandle;
fShared = otherMsg.fShared;
fRegionId = otherMsg.fRegionId;
fSegmentId = otherMsg.fSegmentId;
fManaged = otherMsg.fManaged;
}
~Message() override { CloseMessage(); }
private:
Manager& fManager;
bool fQueued;
MetaHeader fMeta;
size_t fAlignment;
mutable UnmanagedRegion* fRegionPtr;
mutable char* fLocalPtr;
size_t fSize; // size of the shm buffer
size_t fHint; // user-defined value, given by the user on message creation and returned to the user on "buffer no longer needed"-callbacks
boost::interprocess::managed_shared_memory::handle_t fHandle; // handle to shm buffer, convertible to shm buffer ptr
mutable boost::interprocess::managed_shared_memory::handle_t fShared; // handle to the buffer storing the ref count for shared buffers
uint16_t fRegionId; // id of the unmanaged region
mutable uint16_t fSegmentId; // id of the managed segment
size_t fAlignment;
bool fManaged; // true = managed segment, false = unmanaged region
bool fQueued;
void SetMeta(const MetaHeader& meta)
{
fSize = meta.fSize;
fHint = meta.fHint;
fHandle = meta.fHandle;
fShared = meta.fShared;
fRegionId = meta.fRegionId;
fSegmentId = meta.fSegmentId;
fManaged = meta.fManaged;
}
char* InitializeChunk(const size_t size, size_t alignment = 0)
{
if (size == 0) {
fMeta.fSize = 0;
fSize = 0;
return fLocalPtr;
}
char* ptr = fManager.Allocate(size, alignment);
fMeta.fHandle = fManager.GetHandleFromAddress(ptr, fMeta.fSegmentId);
fMeta.fSize = size;
fHandle = fManager.GetHandleFromAddress(ptr, fSegmentId);
fSize = size;
fLocalPtr = ShmHeader::UserPtr(ptr);
return fLocalPtr;
}
void Deallocate()
{
if (fMeta.fHandle >= 0 && !fQueued) {
if (fMeta.fManaged) { // managed segment
fManager.GetSegment(fMeta.fSegmentId);
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fMeta.fHandle, fMeta.fSegmentId));
if (fHandle >= 0 && !fQueued) {
if (fManaged) { // managed segment
fManager.GetSegment(fSegmentId);
uint16_t refCount = ShmHeader::DecrementRefCount(fManager.GetAddressFromHandle(fHandle, fSegmentId));
if (refCount == 1) {
fManager.Deallocate(fMeta.fHandle, fMeta.fSegmentId);
fManager.Deallocate(fHandle, fSegmentId);
}
} else { // unmanaged region
if (fMeta.fShared >= 0) {
fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId);
if (fShared >= 0) {
fRegionPtr = fManager.GetRegionFromCache(fRegionId);
if (!fRegionPtr) {
throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fMeta.fRegionId));
throw TransportError(tools::ToString("Cannot get unmanaged region with id ", fRegionId));
}
uint16_t refCount = fRegionPtr->GetRefCountAddressFromHandle(fMeta.fShared)->Decrement();
uint16_t refCount = fRegionPtr->GetRefCountAddressFromHandle(fShared)->Decrement();
if (refCount == 1) {
fRegionPtr->RemoveRefCount(*(fRegionPtr->GetRefCountAddressFromHandle(fMeta.fShared)));
fRegionPtr->RemoveRefCount(*(fRegionPtr->GetRefCountAddressFromHandle(fShared)));
ReleaseUnmanagedRegionBlock();
}
} else {
@ -356,21 +426,21 @@ class Message final : public fair::mq::Message
}
}
}
fMeta.fHandle = -1;
fHandle = -1;
fLocalPtr = nullptr;
fMeta.fSize = 0;
fSize = 0;
}
void ReleaseUnmanagedRegionBlock()
{
if (!fRegionPtr) {
fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId);
fRegionPtr = fManager.GetRegionFromCache(fRegionId);
}
if (fRegionPtr) {
fRegionPtr->ReleaseBlock({fMeta.fHandle, fMeta.fSize, fMeta.fHint});
fRegionPtr->ReleaseBlock({fHandle, fSize, fHint});
} else {
LOG(warn) << "region ack queue for id " << fMeta.fRegionId << " no longer exist. Not sending ack";
LOG(warn) << "region ack queue for id " << fRegionId << " no longer exist. Not sending ack";
}
}

View File

@ -129,9 +129,11 @@ class Socket final : public fair::mq::Socket
}
int elapsed = 0;
MetaHeader meta{ shmMsg->fSize, shmMsg->fHint, shmMsg->fHandle, shmMsg->fShared, shmMsg->fRegionId, shmMsg->fSegmentId, shmMsg->fManaged };
// meta msg format: | MetaHeader | padded to fMetadataMsgSize |
zmq::ZMsg zmqMsg(std::max(fMetadataMsgSize, sizeof(MetaHeader)));
std::memcpy(zmqMsg.Data(), &(shmMsg->fMeta), sizeof(MetaHeader));
std::memcpy(zmqMsg.Data(), &meta, sizeof(MetaHeader));
while (true) {
int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags);
@ -167,7 +169,8 @@ class Socket final : public fair::mq::Socket
while (true) {
Message* shmMsg = static_cast<Message*>(msg.get());
int nbytes = zmq_recv(fSocket, &(shmMsg->fMeta), sizeof(MetaHeader), flags);
MetaHeader meta;
int nbytes = zmq_recv(fSocket, &meta, sizeof(MetaHeader), flags);
if (nbytes > 0) {
// check for number of received messages. must be 1
if (static_cast<std::size_t>(nbytes) < sizeof(MetaHeader)) {
@ -177,6 +180,8 @@ class Socket final : public fair::mq::Socket
"Expected minimum size of ", sizeof(MetaHeader), " bytes, received ", nbytes));
}
shmMsg->SetMeta(meta);
size_t size = shmMsg->GetSize();
fBytesRx += size;
++fMessagesRx;
@ -218,7 +223,8 @@ class Socket final : public fair::mq::Socket
}
assertm(dynamic_cast<shmem::Message*>(msgPtr), "given mq::Message is a shmem::Message"); // NOLINT
auto shmMsg = static_cast<shmem::Message*>(msgPtr); // NOLINT(cppcoreguidelines-pro-type-static-cast-downcast)
std::memcpy(metas++, &(shmMsg->fMeta), sizeof(MetaHeader));
MetaHeader meta{ shmMsg->fSize, shmMsg->fHint, shmMsg->fHandle, shmMsg->fShared, shmMsg->fRegionId, shmMsg->fSegmentId, shmMsg->fManaged };
std::memcpy(metas++, &meta, sizeof(MetaHeader));
}
while (true) {
@ -230,7 +236,7 @@ class Socket final : public fair::mq::Socket
for (auto& msg : msgVec) {
Message* shmMsg = static_cast<Message*>(msg.get());
shmMsg->fQueued = true;
totalSize += shmMsg->fMeta.fSize;
totalSize += shmMsg->fSize;
}
// store statistics on how many messages have been sent