diff --git a/fairmq/FairMQMessage.h b/fairmq/FairMQMessage.h index a7b7be9b..1341c0f9 100644 --- a/fairmq/FairMQMessage.h +++ b/fairmq/FairMQMessage.h @@ -29,14 +29,11 @@ class FairMQMessage virtual void Rebuild(const size_t size) = 0; virtual void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) = 0; - virtual void* GetMessage() = 0; virtual void* GetData() = 0; virtual size_t GetSize() const = 0; virtual bool SetUsedSize(const size_t size) = 0; - virtual void SetMessage(void* data, size_t size) = 0; - virtual FairMQ::Transport GetType() const = 0; virtual void Copy(const std::unique_ptr& msg) = 0; diff --git a/fairmq/nanomsg/FairMQMessageNN.cxx b/fairmq/nanomsg/FairMQMessageNN.cxx index ef2effd9..77246ae4 100644 --- a/fairmq/nanomsg/FairMQMessageNN.cxx +++ b/fairmq/nanomsg/FairMQMessageNN.cxx @@ -94,13 +94,13 @@ FairMQMessageNN::FairMQMessageNN(FairMQUnmanagedRegionPtr& region, void* data, c void FairMQMessageNN::Rebuild() { - Clear(); + CloseMessage(); fReceiving = false; } void FairMQMessageNN::Rebuild(const size_t size) { - Clear(); + CloseMessage(); fMessage = nn_allocmsg(size, 0); if (!fMessage) { @@ -112,7 +112,7 @@ void FairMQMessageNN::Rebuild(const size_t size) void FairMQMessageNN::Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) { - Clear(); + CloseMessage(); fMessage = nn_allocmsg(size, 0); if (!fMessage) { @@ -173,7 +173,7 @@ FairMQ::Transport FairMQMessageNN::GetType() const return fTransportType; } -void FairMQMessageNN::Copy(const unique_ptr& msg) +void FairMQMessageNN::Copy(const FairMQMessagePtr& msg) { if (fMessage) { @@ -192,12 +192,12 @@ void FairMQMessageNN::Copy(const unique_ptr& msg) } else { - memcpy(fMessage, msg->GetMessage(), size); + memcpy(fMessage, static_cast(msg.get())->GetMessage(), size); fSize = size; } } -void FairMQMessageNN::Clear() +void FairMQMessageNN::CloseMessage() { if (nn_freemsg(fMessage) < 0) { @@ -214,15 +214,6 @@ FairMQMessageNN::~FairMQMessageNN() { if (fReceiving) { - int rc = nn_freemsg(fMessage); - if (rc < 0) - { - LOG(ERROR) << "failed freeing message, reason: " << nn_strerror(errno); - } - else - { - fMessage = nullptr; - fSize = 0; - } + CloseMessage(); } } diff --git a/fairmq/nanomsg/FairMQMessageNN.h b/fairmq/nanomsg/FairMQMessageNN.h index 6e7c3684..97737ae1 100644 --- a/fairmq/nanomsg/FairMQMessageNN.h +++ b/fairmq/nanomsg/FairMQMessageNN.h @@ -22,8 +22,12 @@ #include "FairMQMessage.h" #include "FairMQUnmanagedRegion.h" +class FairMQSocketNN; + class FairMQMessageNN : public FairMQMessage { + friend class FairMQSocketNN; + public: FairMQMessageNN(); FairMQMessageNN(const size_t size); @@ -37,22 +41,17 @@ class FairMQMessageNN : public FairMQMessage void Rebuild(const size_t size) override; void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; - void* GetMessage() override; void* GetData() override; size_t GetSize() const override; bool SetUsedSize(const size_t size) override; - void SetMessage(void* data, const size_t size) override; - FairMQ::Transport GetType() const override; - void Copy(const std::unique_ptr& msg) override; + void Copy(const FairMQMessagePtr& msg) override; ~FairMQMessageNN() override; - friend class FairMQSocketNN; - private: void* fMessage; size_t fSize; @@ -60,7 +59,9 @@ class FairMQMessageNN : public FairMQMessage FairMQUnmanagedRegion* fRegionPtr; static FairMQ::Transport fTransportType; - void Clear(); + void* GetMessage(); + void CloseMessage(); + void SetMessage(void* data, const size_t size); }; #endif /* FAIRMQMESSAGENN_H_ */ diff --git a/fairmq/nanomsg/FairMQSocketNN.cxx b/fairmq/nanomsg/FairMQSocketNN.cxx index 4403af7b..11fae63f 100644 --- a/fairmq/nanomsg/FairMQSocketNN.cxx +++ b/fairmq/nanomsg/FairMQSocketNN.cxx @@ -125,18 +125,20 @@ int FairMQSocketNN::Send(FairMQMessagePtr& msg, const int flags) { int nbytes = -1; + FairMQMessageNN* msgPtr = static_cast(msg.get()); + void* bufPtr = msgPtr->GetMessage(); + while (true) { - void* ptr = msg->GetMessage(); - if (static_cast(msg.get())->fRegionPtr == nullptr) + if (msgPtr->fRegionPtr == nullptr) { - nbytes = nn_send(fSocket, &ptr, NN_MSG, flags); + nbytes = nn_send(fSocket, &bufPtr, NN_MSG, flags); } else { - nbytes = nn_send(fSocket, ptr, msg->GetSize(), flags); + nbytes = nn_send(fSocket, bufPtr, msg->GetSize(), flags); // nn_send copies the data, safe to call region callback here - static_cast(static_cast(msg.get())->fRegionPtr)->fCallback(msg->GetMessage(), msg->GetSize()); + static_cast(msgPtr->fRegionPtr)->fCallback(bufPtr, msg->GetSize()); } if (nbytes >= 0) @@ -183,6 +185,8 @@ int FairMQSocketNN::Receive(FairMQMessagePtr& msg, const int flags) { int nbytes = -1; + FairMQMessageNN* msgPtr = static_cast(msg.get()); + while (true) { void* ptr = nullptr; @@ -191,8 +195,8 @@ int FairMQSocketNN::Receive(FairMQMessagePtr& msg, const int flags) { fBytesRx += nbytes; ++fMessagesRx; - msg->SetMessage(ptr, nbytes); - static_cast(msg.get())->fReceiving = true; + msgPtr->SetMessage(ptr, nbytes); + msgPtr->fReceiving = true; return nbytes; } #if NN_VERSION_CURRENT>2 // backwards-compatibility with nanomsg version<=0.6 @@ -227,7 +231,7 @@ int FairMQSocketNN::Receive(FairMQMessagePtr& msg, const int flags) } } -int64_t FairMQSocketNN::Send(vector>& msgVec, const int flags) +int64_t FairMQSocketNN::Send(vector& msgVec, const int flags) { const unsigned int vecSize = msgVec.size(); #ifdef MSGPACK_FOUND @@ -240,13 +244,15 @@ int64_t FairMQSocketNN::Send(vector>& msgVec, const in // pack all parts into a single msgpack simple buffer for (unsigned int i = 0; i < vecSize; ++i) { - static_cast(msgVec[i].get())->fReceiving = false; + FairMQMessageNN* partPtr = static_cast(msgVec[i].get()); + + partPtr->fReceiving = false; packer.pack_bin(msgVec[i]->GetSize()); packer.pack_bin_body(static_cast(msgVec[i]->GetData()), msgVec[i]->GetSize()); // call region callback - if (static_cast(msgVec[i].get())->fRegionPtr) + if (partPtr->fRegionPtr) { - static_cast(static_cast(msgVec[i].get())->fRegionPtr)->fCallback(msgVec[i]->GetMessage(), msgVec[i]->GetSize()); + static_cast(partPtr->fRegionPtr)->fCallback(partPtr->GetMessage(), msgVec[i]->GetSize()); } } @@ -297,7 +303,7 @@ int64_t FairMQSocketNN::Send(vector>& msgVec, const in #endif /*MSGPACK_FOUND*/ } -int64_t FairMQSocketNN::Receive(vector>& msgVec, const int flags) +int64_t FairMQSocketNN::Receive(vector& msgVec, const int flags) { #ifdef MSGPACK_FOUND // Warn if the vector is filled before Receive() and empty it. @@ -334,7 +340,7 @@ int64_t FairMQSocketNN::Receive(vector>& msgVec, const object.convert(buf); // get the single message size size_t size = buf.size() * sizeof(char); - unique_ptr part(new FairMQMessageNN(size)); + FairMQMessagePtr part(new FairMQMessageNN(size)); static_cast(part.get())->fReceiving = true; memcpy(part->GetData(), buf.data(), size); msgVec.push_back(move(part)); diff --git a/fairmq/shmem/FairMQMessageSHM.cxx b/fairmq/shmem/FairMQMessageSHM.cxx index af3bc2b3..c735588c 100644 --- a/fairmq/shmem/FairMQMessageSHM.cxx +++ b/fairmq/shmem/FairMQMessageSHM.cxx @@ -197,7 +197,7 @@ void FairMQMessageSHM::Rebuild(void* data, const size_t size, fairmq_free_fn* ff } } -void* FairMQMessageSHM::GetMessage() +zmq_msg_t* FairMQMessageSHM::GetMessage() { return &fMessage; } @@ -269,11 +269,6 @@ bool FairMQMessageSHM::SetUsedSize(const size_t size) } } -void FairMQMessageSHM::SetMessage(void*, const size_t) -{ - // dummy method to comply with the interface. functionality not allowed in zeromq. -} - FairMQ::Transport FairMQMessageSHM::GetType() const { return fTransportType; diff --git a/fairmq/shmem/FairMQMessageSHM.h b/fairmq/shmem/FairMQMessageSHM.h index 4259ca10..3da6301e 100644 --- a/fairmq/shmem/FairMQMessageSHM.h +++ b/fairmq/shmem/FairMQMessageSHM.h @@ -20,6 +20,8 @@ #include // size_t #include +class FairMQSocketSHM; + class FairMQMessageSHM : public FairMQMessage { friend class FairMQSocketSHM; @@ -33,25 +35,18 @@ class FairMQMessageSHM : public FairMQMessage FairMQMessageSHM(const FairMQMessageSHM&) = delete; FairMQMessageSHM operator=(const FairMQMessageSHM&) = delete; - bool InitializeChunk(const size_t size); - void Rebuild() override; void Rebuild(const size_t size) override; void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; - void* GetMessage() override; void* GetData() override; size_t GetSize() const override; bool SetUsedSize(const size_t size) override; - void SetMessage(void* data, const size_t size) override; - FairMQ::Transport GetType() const override; - void Copy(const std::unique_ptr& msg) override; - - void CloseMessage(); + void Copy(const FairMQMessagePtr& msg) override; ~FairMQMessageSHM() override; @@ -67,6 +62,10 @@ class FairMQMessageSHM : public FairMQMessage boost::interprocess::managed_shared_memory::handle_t fHandle; size_t fSize; char* fLocalPtr; + + bool InitializeChunk(const size_t size); + zmq_msg_t* GetMessage(); + void CloseMessage(); }; #endif /* FAIRMQMESSAGESHM_H_ */ diff --git a/fairmq/shmem/FairMQSocketSHM.cxx b/fairmq/shmem/FairMQSocketSHM.cxx index dc5062e9..a1a43885 100644 --- a/fairmq/shmem/FairMQSocketSHM.cxx +++ b/fairmq/shmem/FairMQSocketSHM.cxx @@ -114,7 +114,7 @@ int FairMQSocketSHM::Send(FairMQMessagePtr& msg, const int flags) int nbytes = -1; while (true && !fInterrupted) { - nbytes = zmq_msg_send(static_cast(msg->GetMessage()), fSocket, flags); + nbytes = zmq_msg_send(static_cast(msg.get())->GetMessage(), fSocket, flags); if (nbytes == 0) { return nbytes; @@ -158,7 +158,7 @@ int FairMQSocketSHM::Send(FairMQMessagePtr& msg, const int flags) int FairMQSocketSHM::Receive(FairMQMessagePtr& msg, const int flags) { int nbytes = -1; - zmq_msg_t* msgPtr = static_cast(msg->GetMessage()); + zmq_msg_t* msgPtr = static_cast(msg.get())->GetMessage(); while (true) { nbytes = zmq_msg_recv(msgPtr, fSocket, flags); @@ -221,7 +221,7 @@ int64_t FairMQSocketSHM::Send(vector& msgVec, const int flags) { for (unsigned int i = 0; i < vecSize; ++i) { - nbytes = zmq_msg_send(static_cast(msgVec[i]->GetMessage()), + nbytes = zmq_msg_send(static_cast(msgVec[i].get())->GetMessage(), fSocket, (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); if (nbytes >= 0) @@ -302,7 +302,7 @@ int64_t FairMQSocketSHM::Receive(vector& msgVec, const int fla do { FairMQMessagePtr part(new FairMQMessageSHM(fManager)); - zmq_msg_t* msgPtr = static_cast(part->GetMessage()); + zmq_msg_t* msgPtr = static_cast(part.get())->GetMessage(); int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); if (nbytes == 0) diff --git a/fairmq/zeromq/FairMQMessageZMQ.cxx b/fairmq/zeromq/FairMQMessageZMQ.cxx index 84a0c98f..90bfa4f9 100644 --- a/fairmq/zeromq/FairMQMessageZMQ.cxx +++ b/fairmq/zeromq/FairMQMessageZMQ.cxx @@ -112,7 +112,7 @@ void FairMQMessageZMQ::Rebuild(void* data, const size_t size, fairmq_free_fn* ff } } -void* FairMQMessageZMQ::GetMessage() +zmq_msg_t* FairMQMessageZMQ::GetMessage() { if (!fViewMsg) { @@ -190,11 +190,6 @@ void FairMQMessageZMQ::ApplyUsedSize() } } -void FairMQMessageZMQ::SetMessage(void*, const size_t) -{ - // dummy method to comply with the interface. functionality not allowed in zeromq. -} - FairMQ::Transport FairMQMessageZMQ::GetType() const { return fTransportType; @@ -202,18 +197,19 @@ FairMQ::Transport FairMQMessageZMQ::GetType() const void FairMQMessageZMQ::Copy(const FairMQMessagePtr& msg) { + FairMQMessageZMQ* msgPtr = static_cast(msg.get()); // Shares the message buffer between msg and this fMsg. - if (zmq_msg_copy(fMsg.get(), static_cast(msg->GetMessage())) != 0) + if (zmq_msg_copy(fMsg.get(), msgPtr->GetMessage()) != 0) { LOG(ERROR) << "failed copying message, reason: " << zmq_strerror(errno); return; } // if the target message has been resized, apply same to this message also - if (static_cast(msg.get())->fUsedSizeModified) + if (msgPtr->fUsedSizeModified) { fUsedSizeModified = true; - fUsedSize = static_cast(msg.get())->fUsedSize; + fUsedSize = msgPtr->fUsedSize; } } diff --git a/fairmq/zeromq/FairMQMessageZMQ.h b/fairmq/zeromq/FairMQMessageZMQ.h index ef9011d0..918a698b 100644 --- a/fairmq/zeromq/FairMQMessageZMQ.h +++ b/fairmq/zeromq/FairMQMessageZMQ.h @@ -24,8 +24,12 @@ #include "FairMQMessage.h" #include "FairMQUnmanagedRegion.h" +class FairMQSocketZMQ; + class FairMQMessageZMQ : public FairMQMessage { + friend class FairMQSocketZMQ; + public: FairMQMessageZMQ(); FairMQMessageZMQ(const size_t size); @@ -36,21 +40,15 @@ class FairMQMessageZMQ : public FairMQMessage void Rebuild(const size_t size) override; void Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; - void* GetMessage() override; void* GetData() override; size_t GetSize() const override; bool SetUsedSize(const size_t size) override; void ApplyUsedSize(); - void SetMessage(void* data, const size_t size) override; - - FairMQ::Transport GetType() const override; - void Copy(const std::unique_ptr& msg) override; - - void CloseMessage(); + void Copy(const FairMQMessagePtr& msg) override; ~FairMQMessageZMQ() override; @@ -60,6 +58,9 @@ class FairMQMessageZMQ : public FairMQMessage std::unique_ptr fMsg; std::unique_ptr fViewMsg; // view on a subset of fMsg (treating it as user buffer) static FairMQ::Transport fTransportType; + + zmq_msg_t* GetMessage(); + void CloseMessage(); }; #endif /* FAIRMQMESSAGEZMQ_H_ */ diff --git a/fairmq/zeromq/FairMQSocketZMQ.cxx b/fairmq/zeromq/FairMQSocketZMQ.cxx index 213dc67b..308ad27f 100644 --- a/fairmq/zeromq/FairMQSocketZMQ.cxx +++ b/fairmq/zeromq/FairMQSocketZMQ.cxx @@ -119,7 +119,7 @@ int FairMQSocketZMQ::Send(FairMQMessagePtr& msg, const int flags) while (true) { - nbytes = zmq_msg_send(static_cast(msg->GetMessage()), fSocket, flags); + nbytes = zmq_msg_send(static_cast(msg.get())->GetMessage(), fSocket, flags); if (nbytes >= 0) { fBytesTx += nbytes; @@ -157,7 +157,7 @@ int FairMQSocketZMQ::Receive(FairMQMessagePtr& msg, const int flags) while (true) { - nbytes = zmq_msg_recv(static_cast(msg->GetMessage()), fSocket, flags); + nbytes = zmq_msg_recv(static_cast(msg.get())->GetMessage(), fSocket, flags); if (nbytes >= 0) { fBytesRx += nbytes; @@ -209,7 +209,7 @@ int64_t FairMQSocketZMQ::Send(vector& msgVec, const int flags) { static_cast(msgVec[i].get())->ApplyUsedSize(); - nbytes = zmq_msg_send(static_cast(msgVec[i]->GetMessage()), + nbytes = zmq_msg_send(static_cast(msgVec[i].get())->GetMessage(), fSocket, (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); if (nbytes >= 0) @@ -279,7 +279,7 @@ int64_t FairMQSocketZMQ::Receive(vector& msgVec, const int fla { unique_ptr part(new FairMQMessageZMQ()); - int nbytes = zmq_msg_recv(static_cast(part->GetMessage()), fSocket, flags); + int nbytes = zmq_msg_recv(static_cast(part.get())->GetMessage(), fSocket, flags); if (nbytes >= 0) { msgVec.push_back(move(part));