mirror of
https://github.com/FairRootGroup/FairMQ.git
synced 2025-10-16 18:11:49 +00:00
zmq: implement alignment
This commit is contained in:
@@ -39,7 +39,9 @@ class FairMQMessage
|
||||
FairMQMessage(FairMQTransportFactory* factory) : fTransport(factory) {}
|
||||
|
||||
virtual void Rebuild() = 0;
|
||||
virtual void Rebuild(fair::mq::Alignment alignment) = 0;
|
||||
virtual void Rebuild(const size_t size) = 0;
|
||||
virtual void Rebuild(const size_t size, fair::mq::Alignment alignment) = 0;
|
||||
virtual void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) = 0;
|
||||
|
||||
virtual void* GetData() const = 0;
|
||||
|
@@ -110,6 +110,12 @@ auto Message::Rebuild() -> void
|
||||
fHint = nullptr;
|
||||
}
|
||||
|
||||
auto Message::Rebuild(Alignment /* alignment */) -> void
|
||||
{
|
||||
// TODO: implement alignment
|
||||
Rebuild();
|
||||
}
|
||||
|
||||
auto Message::Rebuild(const size_t size) -> void
|
||||
{
|
||||
if (fFreeFunction) {
|
||||
@@ -131,6 +137,12 @@ auto Message::Rebuild(const size_t size) -> void
|
||||
fHint = nullptr;
|
||||
}
|
||||
|
||||
auto Message::Rebuild(const size_t size, Alignment /* alignment */) -> void
|
||||
{
|
||||
// TODO: implement alignment
|
||||
Rebuild(size);
|
||||
}
|
||||
|
||||
auto Message::Rebuild(void* /*data*/, const size_t size, fairmq_free_fn* ffn, void* hint) -> void
|
||||
{
|
||||
if (fFreeFunction) {
|
||||
|
@@ -52,7 +52,9 @@ class Message final : public fair::mq::Message
|
||||
Message operator=(const Message&) = delete;
|
||||
|
||||
auto Rebuild() -> void override;
|
||||
auto Rebuild(Alignment alignment) -> void override;
|
||||
auto Rebuild(const size_t size) -> void override;
|
||||
auto Rebuild(const size_t size, Alignment alignment) -> void override;
|
||||
auto Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) -> void override;
|
||||
|
||||
auto GetData() const -> void* override;
|
||||
|
@@ -50,11 +50,12 @@ class Message final : public fair::mq::Message
|
||||
fManager.IncrementMsgCounter();
|
||||
}
|
||||
|
||||
Message(Manager& manager, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
|
||||
Message(Manager& manager, Alignment alignment, FairMQTransportFactory* factory = nullptr)
|
||||
: fair::mq::Message(factory)
|
||||
, fManager(manager)
|
||||
, fQueued(false)
|
||||
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
|
||||
, fAlignment(alignment.alignment)
|
||||
, fRegionPtr(nullptr)
|
||||
, fLocalPtr(nullptr)
|
||||
{
|
||||
@@ -78,10 +79,11 @@ class Message final : public fair::mq::Message
|
||||
, fManager(manager)
|
||||
, fQueued(false)
|
||||
, fMeta{0, 0, 0, fManager.GetSegmentId(), -1}
|
||||
, fAlignment(alignment.alignment)
|
||||
, fRegionPtr(nullptr)
|
||||
, fLocalPtr(nullptr)
|
||||
{
|
||||
InitializeChunk(size, static_cast<size_t>(alignment));
|
||||
InitializeChunk(size, fAlignment);
|
||||
fManager.IncrementMsgCounter();
|
||||
}
|
||||
|
||||
@@ -142,6 +144,13 @@ class Message final : public fair::mq::Message
|
||||
fQueued = false;
|
||||
}
|
||||
|
||||
void Rebuild(Alignment alignment) override
|
||||
{
|
||||
CloseMessage();
|
||||
fQueued = false;
|
||||
fAlignment = alignment.alignment;
|
||||
}
|
||||
|
||||
void Rebuild(const size_t size) override
|
||||
{
|
||||
CloseMessage();
|
||||
@@ -149,6 +158,14 @@ class Message final : public fair::mq::Message
|
||||
InitializeChunk(size);
|
||||
}
|
||||
|
||||
void Rebuild(const size_t size, Alignment alignment) override
|
||||
{
|
||||
CloseMessage();
|
||||
fQueued = false;
|
||||
fAlignment = alignment.alignment;
|
||||
InitializeChunk(size, fAlignment);
|
||||
}
|
||||
|
||||
void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override
|
||||
{
|
||||
CloseMessage();
|
||||
@@ -242,6 +259,7 @@ class Message final : public fair::mq::Message
|
||||
Manager& fManager;
|
||||
bool fQueued;
|
||||
MetaHeader fMeta;
|
||||
size_t fAlignment; // TODO: put this to debug mode
|
||||
mutable Region* fRegionPtr;
|
||||
mutable char* fLocalPtr;
|
||||
|
||||
@@ -276,8 +294,9 @@ class Message final : public fair::mq::Message
|
||||
}
|
||||
fLocalPtr = nullptr;
|
||||
fMeta.fSize = 0;
|
||||
fAlignment = 0;
|
||||
|
||||
fManager.DecrementMsgCounter();
|
||||
fManager.DecrementMsgCounter(); // TODO: put this to debug mode
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -18,8 +18,10 @@
|
||||
#include <zmq.h>
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdlib> // malloc
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <new> // bad_alloc
|
||||
#include <string>
|
||||
|
||||
namespace fair
|
||||
@@ -38,14 +40,17 @@ class Message final : public fair::mq::Message
|
||||
public:
|
||||
Message(FairMQTransportFactory* factory = nullptr)
|
||||
: fair::mq::Message(factory)
|
||||
, fAlignment(0)
|
||||
, fMsg(tools::make_unique<zmq_msg_t>())
|
||||
{
|
||||
if (zmq_msg_init(fMsg.get()) != 0) {
|
||||
LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno);
|
||||
}
|
||||
}
|
||||
Message(Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
|
||||
|
||||
Message(Alignment alignment, FairMQTransportFactory* factory = nullptr)
|
||||
: fair::mq::Message(factory)
|
||||
, fAlignment(alignment.alignment)
|
||||
, fMsg(tools::make_unique<zmq_msg_t>())
|
||||
{
|
||||
if (zmq_msg_init(fMsg.get()) != 0) {
|
||||
@@ -55,6 +60,7 @@ class Message final : public fair::mq::Message
|
||||
|
||||
Message(const size_t size, FairMQTransportFactory* factory = nullptr)
|
||||
: fair::mq::Message(factory)
|
||||
, fAlignment(0)
|
||||
, fMsg(tools::make_unique<zmq_msg_t>())
|
||||
{
|
||||
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
|
||||
@@ -62,17 +68,40 @@ class Message final : public fair::mq::Message
|
||||
}
|
||||
}
|
||||
|
||||
Message(const size_t size, Alignment /* alignment */, FairMQTransportFactory* factory = nullptr)
|
||||
static std::pair<void*, void*> AllocateAligned(size_t size, size_t alignment)
|
||||
{
|
||||
char* fullBufferPtr = static_cast<char*>(malloc(size + alignment));
|
||||
if (!fullBufferPtr) {
|
||||
LOG(error) << "failed to allocate buffer with provided size (" << size << ") and alignment (" << alignment << ").";
|
||||
throw std::bad_alloc();
|
||||
}
|
||||
|
||||
size_t offset = alignment - (reinterpret_cast<uintptr_t>(fullBufferPtr) % alignment);
|
||||
char* alignedPartPtr = fullBufferPtr + offset;
|
||||
|
||||
return {static_cast<void*>(fullBufferPtr), static_cast<void*>(alignedPartPtr)};
|
||||
}
|
||||
|
||||
Message(const size_t size, Alignment alignment, FairMQTransportFactory* factory = nullptr)
|
||||
: fair::mq::Message(factory)
|
||||
, fAlignment(alignment.alignment)
|
||||
, fMsg(tools::make_unique<zmq_msg_t>())
|
||||
{
|
||||
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
|
||||
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
|
||||
if (fAlignment != 0) {
|
||||
auto ptrs = AllocateAligned(size, fAlignment);
|
||||
if (zmq_msg_init_data(fMsg.get(), ptrs.second, size, [](void* /* data */, void* hint) { free(hint); }, ptrs.first) != 0) {
|
||||
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
|
||||
}
|
||||
} else {
|
||||
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
|
||||
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr, FairMQTransportFactory* factory = nullptr)
|
||||
: fair::mq::Message(factory)
|
||||
, fAlignment(0)
|
||||
, fMsg(tools::make_unique<zmq_msg_t>())
|
||||
{
|
||||
if (zmq_msg_init_data(fMsg.get(), data, size, ffn, hint) != 0) {
|
||||
@@ -82,6 +111,7 @@ class Message final : public fair::mq::Message
|
||||
|
||||
Message(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0, FairMQTransportFactory* factory = nullptr)
|
||||
: fair::mq::Message(factory)
|
||||
, fAlignment(0)
|
||||
, fMsg(tools::make_unique<zmq_msg_t>())
|
||||
{
|
||||
// FIXME: make this zero-copy:
|
||||
@@ -116,6 +146,16 @@ class Message final : public fair::mq::Message
|
||||
}
|
||||
}
|
||||
|
||||
void Rebuild(Alignment alignment) override
|
||||
{
|
||||
CloseMessage();
|
||||
fAlignment = alignment.alignment;
|
||||
fMsg = tools::make_unique<zmq_msg_t>();
|
||||
if (zmq_msg_init(fMsg.get()) != 0) {
|
||||
LOG(error) << "failed initializing message, reason: " << zmq_strerror(errno);
|
||||
}
|
||||
}
|
||||
|
||||
void Rebuild(const size_t size) override
|
||||
{
|
||||
CloseMessage();
|
||||
@@ -125,6 +165,24 @@ class Message final : public fair::mq::Message
|
||||
}
|
||||
}
|
||||
|
||||
void Rebuild(const size_t size, Alignment alignment) override
|
||||
{
|
||||
CloseMessage();
|
||||
fAlignment = alignment.alignment;
|
||||
fMsg = tools::make_unique<zmq_msg_t>();
|
||||
|
||||
if (fAlignment != 0) {
|
||||
auto ptrs = AllocateAligned(size, fAlignment);
|
||||
if (zmq_msg_init_data(fMsg.get(), ptrs.second, size, [](void* /* data */, void* hint) { free(hint); }, ptrs.first) != 0) {
|
||||
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
|
||||
}
|
||||
} else {
|
||||
if (zmq_msg_init_size(fMsg.get(), size) != 0) {
|
||||
LOG(error) << "failed initializing message with size, reason: " << zmq_strerror(errno);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override
|
||||
{
|
||||
CloseMessage();
|
||||
@@ -174,6 +232,23 @@ class Message final : public fair::mq::Message
|
||||
}
|
||||
}
|
||||
|
||||
void Realign()
|
||||
{
|
||||
// if alignment is provided
|
||||
if (fAlignment != 0) {
|
||||
void* data = GetData();
|
||||
size_t size = GetSize();
|
||||
// if buffer is valid && not already aligned with the given alignment
|
||||
if (data != nullptr && reinterpret_cast<uintptr_t>(GetData()) % fAlignment) {
|
||||
// create new aligned buffer
|
||||
auto ptrs = AllocateAligned(size, fAlignment);
|
||||
std::memcpy(ptrs.second, zmq_msg_data(fMsg.get()), size);
|
||||
// rebuild the message with the new buffer
|
||||
Rebuild(ptrs.second, size, [](void* /* buf */, void* hint) { free(hint); }, ptrs.first);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Transport GetType() const override { return Transport::ZMQ; }
|
||||
|
||||
void Copy(const fair::mq::Message& msg) override
|
||||
@@ -189,6 +264,7 @@ class Message final : public fair::mq::Message
|
||||
~Message() override { CloseMessage(); }
|
||||
|
||||
private:
|
||||
size_t fAlignment;
|
||||
std::unique_ptr<zmq_msg_t> fMsg;
|
||||
|
||||
zmq_msg_t* GetMessage() const { return fMsg.get(); }
|
||||
@@ -200,6 +276,7 @@ class Message final : public fair::mq::Message
|
||||
}
|
||||
// reset the message object to allow reuse in Rebuild
|
||||
fMsg.reset(nullptr);
|
||||
fAlignment = 0;
|
||||
}
|
||||
};
|
||||
|
||||
|
@@ -173,6 +173,7 @@ class Socket final : public fair::mq::Socket
|
||||
while (true) {
|
||||
int nbytes = zmq_msg_recv(static_cast<Message*>(msg.get())->GetMessage(), fSocket, flags);
|
||||
if (nbytes >= 0) {
|
||||
static_cast<Message*>(msg.get())->Realign();
|
||||
int64_t actualBytes = zmq_msg_size(static_cast<Message*>(msg.get())->GetMessage());
|
||||
fBytesRx += actualBytes;
|
||||
++fMessagesRx;
|
||||
@@ -261,6 +262,7 @@ class Socket final : public fair::mq::Socket
|
||||
|
||||
int nbytes = zmq_msg_recv(static_cast<Message*>(part.get())->GetMessage(), fSocket, flags);
|
||||
if (nbytes >= 0) {
|
||||
static_cast<Message*>(part.get())->Realign();
|
||||
msgVec.push_back(move(part));
|
||||
totalSize += nbytes;
|
||||
} else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
|
||||
|
Reference in New Issue
Block a user