diff --git a/fairmq/CMakeLists.txt b/fairmq/CMakeLists.txt index 2b1d3dc8..25d45883 100644 --- a/fairmq/CMakeLists.txt +++ b/fairmq/CMakeLists.txt @@ -111,6 +111,7 @@ EndIf(NANOMSG_FOUND) # manual install (globbing add not recommended) Set(FAIRMQHEADERS FairMQParts.h + FairMQTransports.h FairMQConfigPlugin.h FairMQControlPlugin.h runFairMQDevice.h diff --git a/fairmq/FairMQChannel.cxx b/fairmq/FairMQChannel.cxx index 58df6a2c..b4bedd2b 100644 --- a/fairmq/FairMQChannel.cxx +++ b/fairmq/FairMQChannel.cxx @@ -13,11 +13,11 @@ */ #include +#include // std::move #include // join/split #include "FairMQChannel.h" -#include "FairMQLogger.h" using namespace std; @@ -30,18 +30,21 @@ FairMQChannel::FairMQChannel() , fType("unspecified") , fMethod("unspecified") , fAddress("unspecified") + , fTransport("default") , fSndBufSize(1000) , fRcvBufSize(1000) , fSndKernelSize(0) , fRcvKernelSize(0) , fRateLogging(1) - , fChannelName("") + , fName("") , fIsValid(false) , fPoller(nullptr) - , fCmdSocket(nullptr) + , fChannelCmdSocket(nullptr) + , fTransportType(FairMQ::Transport::DEFAULT) , fTransportFactory(nullptr) , fNoBlockFlag(0) , fSndMoreFlag(0) + , fMultipart(false) { } @@ -50,18 +53,21 @@ FairMQChannel::FairMQChannel(const string& type, const string& method, const str , fType(type) , fMethod(method) , fAddress(address) + , fTransport("default") , fSndBufSize(1000) , fRcvBufSize(1000) , fSndKernelSize(0) , fRcvKernelSize(0) , fRateLogging(1) - , fChannelName("") + , fName("") , fIsValid(false) , fPoller(nullptr) - , fCmdSocket(nullptr) + , fChannelCmdSocket(nullptr) + , fTransportType(FairMQ::Transport::DEFAULT) , fTransportFactory(nullptr) , fNoBlockFlag(0) , fSndMoreFlag(0) + , fMultipart(false) { } @@ -70,18 +76,21 @@ FairMQChannel::FairMQChannel(const FairMQChannel& chan) , fType(chan.fType) , fMethod(chan.fMethod) , fAddress(chan.fAddress) + , fTransport(chan.fTransport) , fSndBufSize(chan.fSndBufSize) , fRcvBufSize(chan.fRcvBufSize) , fSndKernelSize(chan.fSndKernelSize) , fRcvKernelSize(chan.fRcvKernelSize) , fRateLogging(chan.fRateLogging) - , fChannelName(chan.fChannelName) + , fName(chan.fName) , fIsValid(false) , fPoller(nullptr) - , fCmdSocket(nullptr) + , fChannelCmdSocket(nullptr) + , fTransportType(FairMQ::Transport::DEFAULT) , fTransportFactory(nullptr) , fNoBlockFlag(chan.fNoBlockFlag) , fSndMoreFlag(chan.fSndMoreFlag) + , fMultipart(chan.fMultipart) {} FairMQChannel& FairMQChannel::operator=(const FairMQChannel& chan) @@ -89,16 +98,18 @@ FairMQChannel& FairMQChannel::operator=(const FairMQChannel& chan) fType = chan.fType; fMethod = chan.fMethod; fAddress = chan.fAddress; + fTransport = chan.fTransport; fSndBufSize = chan.fSndBufSize; fRcvBufSize = chan.fRcvBufSize; fSndKernelSize = chan.fSndKernelSize; fRcvKernelSize = chan.fRcvKernelSize; fRateLogging = chan.fRateLogging; fSocket = nullptr; - fChannelName = chan.fChannelName; + fName = chan.fName; fIsValid = false; fPoller = nullptr; - fCmdSocket = nullptr; + fChannelCmdSocket = nullptr; + fTransportType = FairMQ::Transport::DEFAULT; fTransportFactory = nullptr; fNoBlockFlag = chan.fNoBlockFlag; fSndMoreFlag = chan.fSndMoreFlag; @@ -108,13 +119,13 @@ FairMQChannel& FairMQChannel::operator=(const FairMQChannel& chan) string FairMQChannel::GetChannelName() const { - return fChannelName; + return fName; } string FairMQChannel::GetChannelPrefix() const { - string prefix = fChannelName; - return prefix.erase(fChannelName.rfind("[")); + string prefix = fName; + return prefix.erase(fName.rfind("[")); } string FairMQChannel::GetType() const @@ -159,6 +170,20 @@ string FairMQChannel::GetAddress() const } } +string FairMQChannel::GetTransport() const +{ + try + { + unique_lock lock(fChannelMutex); + return fTransport; + } + catch (exception& e) + { + LOG(ERROR) << "Exception caught in FairMQChannel::GetTransport: " << e.what(); + exit(EXIT_FAILURE); + } +} + int FairMQChannel::GetSndBufSize() const { try @@ -274,6 +299,21 @@ void FairMQChannel::UpdateAddress(const string& address) } } +void FairMQChannel::UpdateTransport(const string& transport) +{ + try + { + unique_lock lock(fChannelMutex); + fIsValid = false; + fTransport = transport; + } + catch (exception& e) + { + LOG(ERROR) << "Exception caught in FairMQChannel::UpdateTransport: " << e.what(); + exit(EXIT_FAILURE); + } +} + void FairMQChannel::UpdateSndBufSize(const int sndBufSize) { try @@ -370,7 +410,7 @@ bool FairMQChannel::ValidateChannel() unique_lock lock(fChannelMutex); stringstream ss; - ss << "Validating channel \"" << fChannelName << "\"... "; + ss << "Validating channel \"" << fName << "\"... "; if (fIsValid) { @@ -461,6 +501,17 @@ bool FairMQChannel::ValidateChannel() } } + // validate channel transport + // const string channelTransportNames[] = { "default", "zeromq", "nanomsg", "shmem" }; + // const set channelTransports(channelTransportNames, channelTransportNames + sizeof(channelTransportNames) / sizeof(string)); + if (FairMQ::TransportTypes.find(fTransport) == FairMQ::TransportTypes.end()) + { + ss << "INVALID"; + LOG(DEBUG) << ss.str(); + LOG(ERROR) << "Invalid channel transport: \"" << fTransport << "\""; + exit(EXIT_FAILURE); + } + // validate socket buffer size for sending if (fSndBufSize < 0) { @@ -518,19 +569,23 @@ bool FairMQChannel::ValidateChannel() } } -bool FairMQChannel::InitCommandInterface(shared_ptr factory, int numIoThreads) +void FairMQChannel::InitTransport(shared_ptr factory) { fTransportFactory = factory; + fTransportType = factory->GetType(); +} - fCmdSocket = fTransportFactory->CreateSocket("sub", "device-commands", numIoThreads, "internal"); - if (fCmdSocket) +bool FairMQChannel::InitCommandInterface(int numIoThreads) +{ + fChannelCmdSocket = fTransportFactory->CreateSocket("sub", "device-commands", numIoThreads, "internal"); + if (fChannelCmdSocket) { - fCmdSocket->Connect("inproc://commands"); + fChannelCmdSocket->Connect("inproc://commands"); - fNoBlockFlag = fCmdSocket->NOBLOCK; - fSndMoreFlag = fCmdSocket->SNDMORE; + fNoBlockFlag = fChannelCmdSocket->NOBLOCK; + fSndMoreFlag = fChannelCmdSocket->SNDMORE; - fPoller = fTransportFactory->CreatePoller(*fCmdSocket, *fSocket); + fPoller = fTransportFactory->CreatePoller(*fChannelCmdSocket, *fSocket); return true; } @@ -547,7 +602,19 @@ void FairMQChannel::ResetChannel() // TODO: implement channel resetting } -int FairMQChannel::Send(const unique_ptr& msg, int sndTimeoutInMs) const +int FairMQChannel::Send(unique_ptr& msg) const +{ + CheckCompatibility(msg); + return fSocket->Send(msg); +} + +int FairMQChannel::Receive(unique_ptr& msg) const +{ + CheckCompatibility(msg); + return fSocket->Receive(msg); +} + +int FairMQChannel::Send(unique_ptr& msg, int sndTimeoutInMs) const { fPoller->Poll(sndTimeoutInMs); @@ -562,13 +629,13 @@ int FairMQChannel::Send(const unique_ptr& msg, int sndTimeoutInMs if (fPoller->CheckOutput(1)) { - return fSocket->Send(msg.get(), 0); + return Send(msg); } return -2; } -int FairMQChannel::Receive(const unique_ptr& msg, int rcvTimeoutInMs) const +int FairMQChannel::Receive(unique_ptr& msg, int rcvTimeoutInMs) const { fPoller->Poll(rcvTimeoutInMs); @@ -583,13 +650,37 @@ int FairMQChannel::Receive(const unique_ptr& msg, int rcvTimeoutI if (fPoller->CheckInput(1)) { - return fSocket->Receive(msg.get(), 0); + return Receive(msg); } return -2; } -int64_t FairMQChannel::Send(const vector>& msgVec, int sndTimeoutInMs) const +int FairMQChannel::SendAsync(unique_ptr& msg) const +{ + CheckCompatibility(msg); + return fSocket->Send(msg, fNoBlockFlag); +} + +int FairMQChannel::ReceiveAsync(unique_ptr& msg) const +{ + CheckCompatibility(msg); + return fSocket->Receive(msg, fNoBlockFlag); +} + +int64_t FairMQChannel::Send(vector>& msgVec) const +{ + CheckCompatibility(msgVec); + return fSocket->Send(msgVec); +} + +int64_t FairMQChannel::Receive(vector>& msgVec) const +{ + CheckCompatibility(msgVec); + return fSocket->Receive(msgVec); +} + +int64_t FairMQChannel::Send(vector>& msgVec, int sndTimeoutInMs) const { fPoller->Poll(sndTimeoutInMs); @@ -604,7 +695,7 @@ int64_t FairMQChannel::Send(const vector>& msgVec, int if (fPoller->CheckOutput(1)) { - return fSocket->Send(msgVec); + return Send(msgVec); } return -2; @@ -625,155 +716,33 @@ int64_t FairMQChannel::Receive(vector>& msgVec, int rc if (fPoller->CheckInput(1)) { - return fSocket->Receive(msgVec); + return Receive(msgVec); } return -2; } -int FairMQChannel::Send(FairMQMessage* msg, const string& flag, int sndTimeoutInMs) const +int64_t FairMQChannel::SendAsync(vector>& msgVec) const { - if (flag == "") - { - fPoller->Poll(sndTimeoutInMs); - - if (fPoller->CheckInput(0)) - { - HandleUnblock(); - if (fInterrupted) - { - return -2; - - } - } - - if (fPoller->CheckOutput(1)) - { - return fSocket->Send(msg, flag); - } - - return -2; - } - else - { - return fSocket->Send(msg, flag); - } + CheckCompatibility(msgVec); + return fSocket->Send(msgVec, fNoBlockFlag); } -int FairMQChannel::Send(FairMQMessage* msg, const int flags, int sndTimeoutInMs) const +/// Receives a vector of messages in non-blocking mode. +/// +/// @param msgVec message vector reference +/// @return Number of bytes that have been received. If queue is empty, returns -2. +/// In case of errors, returns -1. +int64_t FairMQChannel::ReceiveAsync(vector>& msgVec) const { - if (flags == 0) - { - fPoller->Poll(sndTimeoutInMs); - - if (fPoller->CheckInput(0)) - { - HandleUnblock(); - if (fInterrupted) - { - return -2; - - } - } - - if (fPoller->CheckOutput(1)) - { - return fSocket->Send(msg, flags); - } - - return -2; - } - else - { - return fSocket->Send(msg, flags); - } -} - -int FairMQChannel::Receive(FairMQMessage* msg, const string& flag, int rcvTimeoutInMs) const -{ - if (flag == "") - { - fPoller->Poll(rcvTimeoutInMs); - - if (fPoller->CheckInput(0)) - { - HandleUnblock(); - if (fInterrupted) - { - return -2; - - } - } - - if (fPoller->CheckInput(1)) - { - return fSocket->Receive(msg, flag); - } - - return -2; - } - else - { - return fSocket->Receive(msg, flag); - } -} - -int FairMQChannel::Receive(FairMQMessage* msg, const int flags, int rcvTimeoutInMs) const -{ - if (flags == 0) - { - fPoller->Poll(rcvTimeoutInMs); - - if (fPoller->CheckInput(0)) - { - HandleUnblock(); - if (fInterrupted) - { - return -2; - - } - } - - if (fPoller->CheckInput(1)) - { - return fSocket->Receive(msg, flags); - } - - return -2; - } - else - { - return fSocket->Receive(msg, flags); - } -} - -bool FairMQChannel::ExpectsAnotherPart() const -{ - int64_t more = 0; - size_t more_size = sizeof more; - - if (fSocket) - { - fSocket->GetOption("rcv-more", &more, &more_size); - if (more) - { - return true; - } - else - { - return false; - } - } - else - { - return false; - } + CheckCompatibility(msgVec); + return fSocket->Receive(msgVec, fNoBlockFlag); } inline bool FairMQChannel::HandleUnblock() const { FairMQMessagePtr cmd(fTransportFactory->CreateMessage()); - if (fCmdSocket->Receive(cmd.get(), 0) >= 0) + if (fChannelCmdSocket->Receive(cmd) >= 0) { // LOG(DEBUG) << "unblocked"; } @@ -788,3 +757,51 @@ void FairMQChannel::Tokenize(vector& output, const string& input, const { boost::algorithm::split(output, input, boost::algorithm::is_any_of(delimiters)); } + +FairMQTransportFactory* FairMQChannel::Transport() +{ + return fTransportFactory.get(); +} + +bool FairMQChannel::CheckCompatibility(unique_ptr& msg) const +{ + if (fTransportType == msg->GetType()) + { + return true; + } + else + { + // LOG(WARN) << "Channel type does not match message type. Copying..."; + FairMQMessagePtr msgCopy(fTransportFactory->CreateMessage(msg->GetSize())); + memcpy(msgCopy->GetData(), msg->GetData(), msg->GetSize()); + msg = move(msgCopy); + return false; + } +} + +bool FairMQChannel::CheckCompatibility(vector>& msgVec) const +{ + if (msgVec.size() > 0) + { + if (fTransportType == msgVec.at(0)->GetType()) + { + return true; + } + else + { + // LOG(WARN) << "Channel type does not match message type. Copying..."; + vector> tempVec; + for (unsigned int i = 0; i < msgVec.size(); ++i) + { + tempVec.push_back(fTransportFactory->CreateMessage(msgVec[i]->GetSize())); + memcpy(tempVec[i]->GetData(), msgVec[i]->GetData(), msgVec[i]->GetSize()); + } + msgVec = move(tempVec); + return false; + } + } + else + { + return true; + } +} diff --git a/fairmq/FairMQChannel.h b/fairmq/FairMQChannel.h index 78fd7f45..a6c1f02a 100644 --- a/fairmq/FairMQChannel.h +++ b/fairmq/FairMQChannel.h @@ -17,12 +17,15 @@ #include #include // unique_ptr +#include #include #include #include "FairMQTransportFactory.h" #include "FairMQSocket.h" #include "FairMQPoller.h" +#include "FairMQTransports.h" +#include "FairMQLogger.h" class FairMQPoller; class FairMQTransportFactory; @@ -67,9 +70,13 @@ class FairMQChannel std::string GetMethod() const; /// Get socket address (e.g. "tcp://127.0.0.1:5555" or "ipc://abc") - /// @return Returns socket type (e.g. "tcp://127.0.0.1:5555" or "ipc://abc") + /// @return Returns socket address (e.g. "tcp://127.0.0.1:5555" or "ipc://abc") std::string GetAddress() const; + /// Get channel transport ("default", "zeromq", "nanomsg" or "shmem") + /// @return Returns channel transport (e.g. "default", "zeromq", "nanomsg" or "shmem") + std::string GetTransport() const; + /// Get socket send buffer size (in number of messages) /// @return Returns socket send buffer size (in number of messages) int GetSndBufSize() const; @@ -99,9 +106,13 @@ class FairMQChannel void UpdateMethod(const std::string& method); /// Set socket address - /// @param address Socket address (e.g. "tcp://127.0.0.1:5555" or "ipc://abc") + /// @param Socket address (e.g. "tcp://127.0.0.1:5555" or "ipc://abc") void UpdateAddress(const std::string& address); + /// Set channel transport + /// @param transport transport string ("default", "zeromq", "nanomsg" or "shmem") + void UpdateTransport(const std::string& transport); + /// Set socket send buffer size /// @param sndBufSize Socket send buffer size (in number of messages) void UpdateSndBufSize(const int sndBufSize); @@ -135,6 +146,9 @@ class FairMQChannel std::unique_ptr fSocket; + int Send(std::unique_ptr& msg) const; + int Receive(std::unique_ptr& msg) const; + /// Sends a message to the socket queue. /// @details Send method attempts to send a message by /// putting it in the output queue. If the queue is full or queueing is not possible @@ -143,7 +157,16 @@ class FairMQChannel /// @param msg Constant reference of unique_ptr to a FairMQMessage /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. /// In case of errors, returns -1. - int Send(const std::unique_ptr& msg, int sndTimeoutInMs = -1) const; + int Send(std::unique_ptr& msg, int sndTimeoutInMs) const; + + /// Receives a message from the socket queue. + /// @details Receive method attempts to receive a message from the input queue. + /// If the queue is empty the method blocks. + /// + /// @param msg Constant reference of unique_ptr to a FairMQMessage + /// @return Number of bytes that have been received. -2 If reading from the queue was not possible or timed out. + /// In case of errors, returns -1. + int Receive(std::unique_ptr& msg, int rcvTimeoutInMs) const; /// Sends a message in non-blocking mode. /// @details SendAsync method attempts to send a message without blocking by @@ -153,41 +176,31 @@ class FairMQChannel /// @return Number of bytes that have been queued. If queueing failed due to /// full queue or no connected peers (when binding), returns -2. /// In case of errors, returns -1. - inline int SendAsync(const std::unique_ptr& msg) const - { - return fSocket->Send(msg.get(), fNoBlockFlag); - } + int SendAsync(std::unique_ptr& msg) const; - /// Queues the current message as a part of a multi-part message - /// @details SendPart method queues the provided message as a part of a multi-part message. - /// The actual transfer over the network is initiated once final part has been queued with the Send() or SendAsync() methods. + /// Receives a message in non-blocking mode. /// /// @param msg Constant reference of unique_ptr to a FairMQMessage - /// @return Number of bytes that have been queued. -2 If queueing was not possible. + /// @return Number of bytes that have been received. If queue is empty, returns -2. /// In case of errors, returns -1. - inline int SendPart(const std::unique_ptr& msg) const - { - return fSocket->Send(msg.get(), fSndMoreFlag); - } + int ReceiveAsync(std::unique_ptr& msg) const; - /// Queues the current message as a part of a multi-part message without blocking - /// @details SendPart method queues the provided message as a part of a multi-part message without blocking. - /// The actual transfer over the network is initiated once final part has been queued with the Send() or SendAsync() methods. - /// - /// @param msg Constant reference of unique_ptr to a FairMQMessage - /// @return Number of bytes that have been queued. -2 If queueing was not possible. - /// In case of errors, returns -1. - inline int SendPartAsync(const std::unique_ptr& msg) const - { - return fSocket->Send(msg.get(), fSndMoreFlag|fNoBlockFlag); - } + int64_t Send(std::vector>& msgVec) const; + int64_t Receive(std::vector>& msgVec) const; /// Send a vector of messages /// /// @param msgVec message vector reference /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. /// In case of errors, returns -1. - int64_t Send(const std::vector>& msgVec, int sndTimeoutInMs = -1) const; + int64_t Send(std::vector>& msgVec, int sndTimeoutInMs) const; + + /// Receive a vector of messages + /// + /// @param msgVec message vector reference + /// @return Number of bytes that have been received. -2 If reading from the queue was not possible or timed out. + /// In case of errors, returns -1. + int64_t Receive(std::vector>& msgVec, int rcvTimeoutInMs) const; /// Sends a vector of message in non-blocking mode. /// @details SendAsync method attempts to send a vector of messages without blocking by @@ -196,82 +209,49 @@ class FairMQChannel /// @param msgVec message vector reference /// @return Number of bytes that have been queued. If queueing failed due to /// full queue or no connected peers (when binding), returns -2. In case of errors, returns -1. - inline int64_t SendAsync(const std::vector>& msgVec) const - { - return fSocket->Send(msgVec, fNoBlockFlag); - } - - /// Receives a message from the socket queue. - /// @details Receive method attempts to receive a message from the input queue. - /// If the queue is empty the method blocks. - /// - /// @param msg Constant reference of unique_ptr to a FairMQMessage - /// @return Number of bytes that have been received. -2 If reading from the queue was not possible or timed out. - /// In case of errors, returns -1. - int Receive(const std::unique_ptr& msg, int rcvTimeoutInMs = -1) const; - - /// Receives a message in non-blocking mode. - /// - /// @param msg Constant reference of unique_ptr to a FairMQMessage - /// @return Number of bytes that have been received. If queue is empty, returns -2. - /// In case of errors, returns -1. - inline int ReceiveAsync(const std::unique_ptr& msg) const - { - return fSocket->Receive(msg.get(), fNoBlockFlag); - } - - /// Receive a vector of messages - /// - /// @param msgVec message vector reference - /// @return Number of bytes that have been received. -2 If reading from the queue was not possible or timed out. - /// In case of errors, returns -1. - int64_t Receive(std::vector>& msgVec, int rcvTimeoutInMs = -1) const; + int64_t SendAsync(std::vector>& msgVec) const; /// Receives a vector of messages in non-blocking mode. /// /// @param msgVec message vector reference /// @return Number of bytes that have been received. If queue is empty, returns -2. /// In case of errors, returns -1. - inline int64_t ReceiveAsync(std::vector>& msgVec) const - { - return fSocket->Receive(msgVec, fNoBlockFlag); - } - - /// Checks if the socket is expecting to receive another part of a multipart message. - /// @return Return true if the socket expects another part of a multipart message and false otherwise. - bool ExpectsAnotherPart() const; - - // DEPRECATED socket method wrappers with raw pointers and flag checks - int Send(FairMQMessage* msg, const std::string& flag = "", int sndTimeoutInMs = -1) const; - int Send(FairMQMessage* msg, const int flags, int sndTimeoutInMs = -1) const; - int Receive(FairMQMessage* msg, const std::string& flag = "", int rcvTimeoutInMs = -1) const; - int Receive(FairMQMessage* msg, const int flags, int rcvTimeoutInMs = -1) const; + int64_t ReceiveAsync(std::vector>& msgVec) const; // TODO: this might go to some base utility library static void Tokenize(std::vector& output, const std::string& input, const std::string delimiters = ","); + FairMQTransportFactory* Transport(); + private: std::string fType; std::string fMethod; std::string fAddress; + std::string fTransport; int fSndBufSize; int fRcvBufSize; int fSndKernelSize; int fRcvKernelSize; int fRateLogging; - std::string fChannelName; + std::string fName; std::atomic fIsValid; - FairMQPollerPtr fPoller; - FairMQSocketPtr fCmdSocket; + FairMQPollerPtr fPoller; + FairMQSocketPtr fChannelCmdSocket; + + FairMQ::Transport fTransportType; std::shared_ptr fTransportFactory; int fNoBlockFlag; int fSndMoreFlag; - bool InitCommandInterface(std::shared_ptr factory, int numIoThreads); + bool CheckCompatibility(std::unique_ptr& msg) const; + bool CheckCompatibility(std::vector>& msgVec) const; + + void InitTransport(std::shared_ptr factory); + bool InitCommandInterface(int numIoThreads); bool HandleUnblock() const; @@ -282,6 +262,7 @@ class FairMQChannel static std::mutex fChannelMutex; static std::atomic fInterrupted; + bool fMultipart; }; #endif /* FAIRMQCHANNEL_H_ */ diff --git a/fairmq/FairMQDevice.cxx b/fairmq/FairMQDevice.cxx index 5ae23f98..b9e44bc3 100644 --- a/fairmq/FairMQDevice.cxx +++ b/fairmq/FairMQDevice.cxx @@ -13,7 +13,6 @@ */ #include -#include // std::sort() #include // catching system signals #include #include @@ -42,8 +41,8 @@ using namespace std; -// std::function and a wrapper to catch the signals -std::function sigHandler; +// function and a wrapper to catch the signals +function sigHandler; static void CallSignalHandler(int signal) { sigHandler(signal); @@ -54,13 +53,15 @@ FairMQDevice::FairMQDevice() , fConfig(nullptr) , fId() , fNetworkInterface() + , fDefaultTransport() , fMaxInitializationAttempts(120) , fNumIoThreads(1) , fPortRangeMin(22000) , fPortRangeMax(32000) , fLogIntervalInMs(1000) - , fCmdSocket(nullptr) , fTransportFactory(nullptr) + , fTransports() + , fDeviceCmdSockets() , fInitialValidationFinished(false) , fInitialValidationCondition() , fInitialValidationMutex() @@ -70,6 +71,10 @@ FairMQDevice::FairMQDevice() , fDataCallbacks(false) , fMsgInputs() , fMultipartInputs() + , fMultitransportInputs() + , fInputChannelKeys() + , fMultitransportMutex() + , fMultitransportProceed(false) { } @@ -78,8 +83,8 @@ void FairMQDevice::CatchSignals() if (!fCatchingSignals) { sigHandler = bind1st(mem_fun(&FairMQDevice::SignalHandler), this); - std::signal(SIGINT, CallSignalHandler); - std::signal(SIGTERM, CallSignalHandler); + signal(SIGINT, CallSignalHandler); + signal(SIGTERM, CallSignalHandler); fCatchingSignals = true; } } @@ -109,12 +114,12 @@ void FairMQDevice::SignalHandler(int signal) else { LOG(WARN) << "Repeated termination or bad initialization? Aborting."; - std::abort(); + abort(); // exit(EXIT_FAILURE); } } -void FairMQDevice::ConnectChannels(list& chans) +void FairMQDevice::AttachChannels(list& chans) { auto itr = chans.begin(); @@ -124,38 +129,12 @@ void FairMQDevice::ConnectChannels(list& chans) { if (AttachChannel(**itr)) { - (*itr)->InitCommandInterface(fTransportFactory, fNumIoThreads); + (*itr)->InitCommandInterface(fNumIoThreads); chans.erase(itr++); } else { - LOG(ERROR) << "failed to connect channel " << (*itr)->fChannelName; - ++itr; - } - } - else - { - ++itr; - } - } -} - -void FairMQDevice::BindChannels(list& chans) -{ - auto itr = chans.begin(); - - while (itr != chans.end()) - { - if ((*itr)->ValidateChannel()) - { - if (AttachChannel(**itr)) - { - (*itr)->InitCommandInterface(fTransportFactory, fNumIoThreads); - chans.erase(itr++); - } - else - { - LOG(ERROR) << "failed to bind channel " << (*itr)->fChannelName; + LOG(ERROR) << "failed to attach channel " << (*itr)->fName << " (" << (*itr)->fMethod << ")"; ++itr; } } @@ -174,10 +153,17 @@ void FairMQDevice::InitWrapper() exit(EXIT_FAILURE); } - if (!fCmdSocket) + if (fDeviceCmdSockets.empty()) { - fCmdSocket = fTransportFactory->CreateSocket("pub", "device-commands", fNumIoThreads, fId); - fCmdSocket->Bind("inproc://commands"); + auto p = fDeviceCmdSockets.emplace(fTransportFactory->GetType(), fTransportFactory->CreateSocket("pub", "device-commands", fNumIoThreads, fId)); + if (p.second) + { + p.first->second->Bind("inproc://commands"); + } + else + { + exit(EXIT_FAILURE); + } FairMQMessagePtr msg(fTransportFactory->CreateMessage()); msg->SetDeviceId(fId); @@ -192,12 +178,13 @@ void FairMQDevice::InitWrapper() { for (auto vi = (mi->second).begin(); vi != (mi->second).end(); ++vi) { + // set channel name: name + vector index + stringstream ss; + ss << mi->first << "[" << vi - (mi->second).begin() << "]"; + vi->fName = ss.str(); + if (vi->fMethod == "bind") { - // set channel name: name + vector index - stringstream ss; - ss << mi->first << "[" << vi - (mi->second).begin() << "]"; - vi->fChannelName = ss.str(); // if binding address is not specified, set it up to try getting it from the configured network interface if (vi->fAddress == "unspecified" || vi->fAddress == "") { @@ -208,19 +195,11 @@ void FairMQDevice::InitWrapper() } else if (vi->fMethod == "connect") { - // set channel name: name + vector index - stringstream ss; - ss << mi->first << "[" << vi - (mi->second).begin() << "]"; - vi->fChannelName = ss.str(); // fill the uninitialized list uninitializedConnectingChannels.push_back(&(*vi)); } else if (vi->fAddress.find_first_of("@+>") != string::npos) { - // set channel name: name + vector index - stringstream ss; - ss << mi->first << "[" << vi - (mi->second).begin() << "]"; - vi->fChannelName = ss.str(); // fill the uninitialized list uninitializedConnectingChannels.push_back(&(*vi)); } @@ -234,7 +213,7 @@ void FairMQDevice::InitWrapper() // Bind channels. Here one run is enough, because bind settings should be available locally // If necessary this could be handled in the same way as the connecting channels - BindChannels(uninitializedBindingChannels); + AttachChannels(uninitializedBindingChannels); // notify parent thread about completion of first validation. { lock_guard lock(fInitialValidationMutex); @@ -246,7 +225,7 @@ void FairMQDevice::InitWrapper() int numAttempts = 0; while (!uninitializedConnectingChannels.empty()) { - ConnectChannels(uninitializedConnectingChannels); + AttachChannels(uninitializedConnectingChannels); if (++numAttempts > fMaxInitializationAttempts) { LOG(ERROR) << "could not connect all channels after " << fMaxInitializationAttempts << " attempts"; @@ -256,7 +235,7 @@ void FairMQDevice::InitWrapper() if (numAttempts != 0) { - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + this_thread::sleep_for(chrono::milliseconds(1000)); } } @@ -275,44 +254,31 @@ void FairMQDevice::Init() { } -bool FairMQDevice::BindChannel(FairMQChannel& ch) -{ - LOG(DEBUG) << "Initializing channel " << ch.fChannelName << " (" << ch.fType << ")"; - // initialize the socket - ch.fSocket = fTransportFactory->CreateSocket(ch.fType, ch.fChannelName, fNumIoThreads, fId); - // set high water marks - ch.fSocket->SetOption("snd-hwm", &(ch.fSndBufSize), sizeof(ch.fSndBufSize)); - ch.fSocket->SetOption("rcv-hwm", &(ch.fRcvBufSize), sizeof(ch.fRcvBufSize)); - - LOG(DEBUG) << "Binding channel " << ch.fChannelName << " on " << ch.fAddress; - - return BindEndpoint(*ch.fSocket, ch.fAddress); -} - -bool FairMQDevice::ConnectChannel(FairMQChannel& ch) -{ - LOG(DEBUG) << "Initializing channel " << ch.fChannelName << " (" << ch.fType << ")"; - // initialize the socket - ch.fSocket = fTransportFactory->CreateSocket(ch.fType, ch.fChannelName, fNumIoThreads, fId); - // set high water marks - ch.fSocket->SetOption("snd-hwm", &(ch.fSndBufSize), sizeof(ch.fSndBufSize)); - ch.fSocket->SetOption("rcv-hwm", &(ch.fRcvBufSize), sizeof(ch.fRcvBufSize)); - // connect - LOG(DEBUG) << "Connecting channel " << ch.fChannelName << " to " << ch.fAddress; - ConnectEndpoint(*ch.fSocket, ch.fAddress); - return true; -} - bool FairMQDevice::AttachChannel(FairMQChannel& ch) { - std::vector endpoints; + if (!ch.fTransportFactory) + { + if (ch.fTransport == "default" || ch.fTransport == fDefaultTransport) + { + LOG(DEBUG) << ch.fName << ": using default transport"; + ch.InitTransport(fTransportFactory); + } + else + { + LOG(DEBUG) << ch.fName << ": channel transport (" << fDefaultTransport << ") overriden to " << ch.fTransport; + ch.InitTransport(AddTransport(ch.fTransport)); + } + ch.fTransportType = ch.fTransportFactory->GetType(); + } + + vector endpoints; FairMQChannel::Tokenize(endpoints, ch.fAddress); for (auto& endpoint : endpoints) { //(re-)init socket if (!ch.fSocket) { - ch.fSocket = fTransportFactory->CreateSocket(ch.fType, ch.fChannelName, fNumIoThreads, fId); + ch.fSocket = ch.fTransportFactory->CreateSocket(ch.fType, ch.fName, fNumIoThreads, fId); } // set high water marks @@ -332,7 +298,7 @@ bool FairMQDevice::AttachChannel(FairMQChannel& ch) // attach bool bind = (ch.fMethod == "bind"); bool connectionModifier = false; - std::string address = endpoint; + string address = endpoint; // check if the default fMethod is overridden by a modifier if (endpoint[0] == '+' || endpoint[0] == '>') @@ -368,7 +334,7 @@ bool FairMQDevice::AttachChannel(FairMQChannel& ch) } endpoint += address; - LOG(DEBUG) << "Attached channel " << ch.fChannelName << " to " << endpoint << (bind ? " (bind) " : " (connect) "); + LOG(DEBUG) << "Attached channel " << ch.fName << " to " << endpoint << (bind ? " (bind) " : " (connect) "); // after the book keeping is done, exit in case of errors if (!rc) @@ -383,34 +349,32 @@ bool FairMQDevice::AttachChannel(FairMQChannel& ch) return true; } -bool FairMQDevice::ConnectEndpoint(FairMQSocket& socket, std::string& endpoint) +bool FairMQDevice::ConnectEndpoint(FairMQSocket& socket, string& endpoint) { socket.Connect(endpoint); return true; } -bool FairMQDevice::BindEndpoint(FairMQSocket& socket, std::string& endpoint) +bool FairMQDevice::BindEndpoint(FairMQSocket& socket, string& endpoint) { // number of attempts when choosing a random port int maxAttempts = 1000; int numAttempts = 0; // initialize random generator - std::default_random_engine generator(std::chrono::system_clock::now().time_since_epoch().count()); - std::uniform_int_distribution randomPort(fPortRangeMin, fPortRangeMax); + default_random_engine generator(chrono::system_clock::now().time_since_epoch().count()); + uniform_int_distribution randomPort(fPortRangeMin, fPortRangeMax); // try to bind to the saved port. In case of failure, try random one. while (!socket.Bind(endpoint)) { - LOG(DEBUG) << "Could not bind to configured (TCP) port, trying random port in range " - << fPortRangeMin << "-" << fPortRangeMax; + LOG(DEBUG) << "Could not bind to configured (TCP) port, trying random port in range " << fPortRangeMin << "-" << fPortRangeMax; ++numAttempts; if (numAttempts > maxAttempts) { - LOG(ERROR) << "could not bind to any (TCP) port in the given range after " - << maxAttempts << " attempts"; + LOG(ERROR) << "could not bind to any (TCP) port in the given range after " << maxAttempts << " attempts"; return false; } @@ -453,7 +417,7 @@ void FairMQDevice::SortChannel(const string& name, const bool reindex) // set channel name: name + vector index stringstream ss; ss << name << "[" << vi - fChannels.at(name).begin() << "]"; - vi->fChannelName = ss.str(); + vi->fName = ss.str(); } } } @@ -469,7 +433,7 @@ void FairMQDevice::PrintChannel(const string& name) { for (auto vi = fChannels[name].begin(); vi != fChannels[name].end(); ++vi) { - LOG(INFO) << vi->fChannelName << ": " + LOG(INFO) << vi->fName << ": " << vi->fType << " | " << vi->fMethod << " | " << vi->fAddress << " | " @@ -488,104 +452,53 @@ void FairMQDevice::OnData(const string& channelName, InputMsgCallback callback) { fDataCallbacks = true; fMsgInputs.insert(make_pair(channelName, callback)); + + if (find(fInputChannelKeys.begin(), fInputChannelKeys.end(), channelName) == fInputChannelKeys.end()) + { + fInputChannelKeys.push_back(channelName); + } } void FairMQDevice::OnData(const string& channelName, InputMultipartCallback callback) { fDataCallbacks = true; fMultipartInputs.insert(make_pair(channelName, callback)); + + if (find(fInputChannelKeys.begin(), fInputChannelKeys.end(), channelName) == fInputChannelKeys.end()) + { + fInputChannelKeys.push_back(channelName); + } } void FairMQDevice::RunWrapper() { LOG(INFO) << "DEVICE: Running..."; - std::thread rateLogger(&FairMQDevice::LogSocketRates, this); + // start the rate logger thread + thread rateLogger(&FairMQDevice::LogSocketRates, this); + // notify channels to resume transfers FairMQChannel::fInterrupted = false; - fCmdSocket->Resume(); + for (auto& kv : fDeviceCmdSockets) + { + kv.second->Resume(); + } try { PreRun(); + // process either data callbacks or ConditionalRun/Run if (fDataCallbacks) { - bool exitingRunningCallback = false; - - vector inputChannelKeys; - for (const auto& i: fMsgInputs) + // if only one input channel, do lightweight handling without additional polling. + if (fInputChannelKeys.size() == 1 && fChannels.at(fInputChannelKeys.at(0)).size() == 1) { - inputChannelKeys.push_back(i.first); + HandleSingleChannelInput(); } - for (const auto& i: fMultipartInputs) + else // otherwise do full handling with polling { - inputChannelKeys.push_back(i.first); - } - - FairMQPollerPtr poller(fTransportFactory->CreatePoller(fChannels, inputChannelKeys)); - - while (CheckCurrentState(RUNNING) && !exitingRunningCallback) - { - poller->Poll(200); - - for (const auto& mi : fMsgInputs) - { - for (unsigned int i = 0; i < fChannels.at(mi.first).size(); ++i) - { - if (poller->CheckInput(mi.first, i)) - { - unique_ptr msg(NewMessage()); - - if (Receive(msg, mi.first, i) >= 0) - { - if (mi.second(msg, i) == false) - { - exitingRunningCallback = true; - break; - } - } - else - { - exitingRunningCallback = true; - break; - } - } - } - if (exitingRunningCallback) - { - break; - } - } - - for (const auto& mi : fMultipartInputs) - { - for (unsigned int i = 0; i < fChannels.at(mi.first).size(); ++i) - { - if (poller->CheckInput(mi.first, i)) - { - FairMQParts parts; - - if (Receive(parts, mi.first, i) >= 0) - { - if (mi.second(parts, i) == false) - { - exitingRunningCallback = true; - break; - } - } - else - { - exitingRunningCallback = true; - break; - } - } - } - if (exitingRunningCallback) - { - break; - } - } + HandleMultipleChannelInput(); } } else @@ -601,24 +514,213 @@ void FairMQDevice::RunWrapper() } catch (const out_of_range& oor) { - LOG(ERROR) << "Out of Range error: " << oor.what(); - LOG(ERROR) << "Incorrect channel name in the Run() or the configuration?"; - ChangeState(ERROR); + LOG(ERROR) << "out of range: " << oor.what(); + LOG(ERROR) << "incorrect/incomplete channel configuration?"; } + // if Run() exited and the state is still RUNNING, transition to READY. if (CheckCurrentState(RUNNING)) { ChangeState(internal_READY); } + rateLogger.join(); +} + +void FairMQDevice::HandleSingleChannelInput() +{ + bool proceed = true; + + if (fMsgInputs.size() > 0) + { + while (CheckCurrentState(RUNNING) && proceed) + { + proceed = HandleMsgInput(fInputChannelKeys.at(0), fMsgInputs.begin()->second, 0); + } + } + else if (fMultipartInputs.size() > 0) + { + while (CheckCurrentState(RUNNING) && proceed) + { + proceed = HandleMultipartInput(fInputChannelKeys.at(0), fMultipartInputs.begin()->second, 0); + } + } +} + +void FairMQDevice::HandleMultipleChannelInput() +{ + // check if more than one transport is used + fMultitransportInputs.clear(); + for (const auto& k : fInputChannelKeys) + { + FairMQ::Transport t = fChannels.at(k).at(0).fTransportType; + if (fMultitransportInputs.find(t) == fMultitransportInputs.end()) + { + fMultitransportInputs.insert(pair>(t, vector())); + fMultitransportInputs.at(t).push_back(k); + } + else + { + fMultitransportInputs.at(t).push_back(k); + } + } + + for (const auto& mi : fMsgInputs) + { + for (unsigned int i = 0; i < fChannels.at(mi.first).size(); ++i) + { + fChannels.at(mi.first).at(i).fMultipart = false; + } + } + + for (const auto& mi : fMultipartInputs) + { + for (unsigned int i = 0; i < fChannels.at(mi.first).size(); ++i) + { + fChannels.at(mi.first).at(i).fMultipart = true; + } + } + + // if more than one transport is used, handle poll of each in a separate thread + if (fMultitransportInputs.size() > 1) + { + HandleMultipleTransportInput(); + } + else // otherwise poll directly + { + bool proceed = true; + + FairMQPollerPtr poller(fChannels.at(fInputChannelKeys.at(0)).at(0).fTransportFactory->CreatePoller(fChannels, fInputChannelKeys)); + + while (CheckCurrentState(RUNNING) && proceed) + { + poller->Poll(200); + + // check which inputs are ready and call their data handlers if they are. + for (const auto& ch : fInputChannelKeys) + { + for (unsigned int i = 0; i < fChannels.at(ch).size(); ++i) + { + if (poller->CheckInput(ch, i)) + { + if (fChannels.at(ch).at(i).fMultipart) + { + proceed = HandleMultipartInput(ch, fMultipartInputs.at(ch), i); + } + else + { + proceed = HandleMsgInput(ch, fMsgInputs.at(ch), i); + } + + if (!proceed) + { + break; + } + } + } + if (!proceed) + { + break; + } + } + } + } +} + +void FairMQDevice::HandleMultipleTransportInput() +{ + vector threads; + + fMultitransportProceed = true; + + for (const auto& i : fMultitransportInputs) + { + threads.push_back(thread(&FairMQDevice::PollForTransport, this, fTransports.at(i.first).get(), i.second)); + } + + for (thread& t : threads) + { + t.join(); + } +} + +void FairMQDevice::PollForTransport(const FairMQTransportFactory* factory, const vector& channelKeys) +{ try { - rateLogger.join(); + FairMQPollerPtr poller(factory->CreatePoller(fChannels, channelKeys)); + + while (CheckCurrentState(RUNNING) && fMultitransportProceed) + { + poller->Poll(500); + + for (const auto& ch : channelKeys) + { + for (unsigned int i = 0; i < fChannels.at(ch).size(); ++i) + { + if (poller->CheckInput(ch, i)) + { + lock_guard lock(fMultitransportMutex); + + if (!fMultitransportProceed) + { + break; + } + + if (fChannels.at(ch).at(i).fMultipart) + { + fMultitransportProceed = HandleMultipartInput(ch, fMultipartInputs.at(ch), i); + } + else + { + fMultitransportProceed = HandleMsgInput(ch, fMsgInputs.at(ch), i); + } + + if (!fMultitransportProceed) + { + break; + } + } + } + if (!fMultitransportProceed) + { + break; + } + } + } } - catch (exception& e) + catch (std::exception& e) { - LOG(ERROR) << "Exception cought during Run(): " << e.what(); - exit(EXIT_FAILURE); + LOG(ERROR) << "FairMQDevice::PollForTransport() failed: " << e.what() << ", going to ERROR state."; + ChangeState(ERROR_FOUND); + } +} + +bool FairMQDevice::HandleMsgInput(const string& chName, const InputMsgCallback& callback, int i) const +{ + unique_ptr input(fChannels.at(chName).at(i).fTransportFactory->CreateMessage()); + + if (Receive(input, chName, i) >= 0) + { + return callback(input, 0); + } + else + { + return false; + } +} + +bool FairMQDevice::HandleMultipartInput(const string& chName, const InputMultipartCallback& callback, int i) const +{ + FairMQParts input; + + if (Receive(input, chName, i) >= 0) + { + return callback(input, 0); + } + else + { + return false; } } @@ -643,7 +745,7 @@ void FairMQDevice::Pause() { while (CheckCurrentState(PAUSED)) { - std::this_thread::sleep_for(std::chrono::milliseconds(500)); + this_thread::sleep_for(chrono::milliseconds(500)); LOG(DEBUG) << "paused..."; } LOG(DEBUG) << "Unpausing"; @@ -759,38 +861,96 @@ int FairMQDevice::GetProperty(const int key, const int default_ /*= 0*/) } } +// DEPRECATED, use the string version void FairMQDevice::SetTransport(FairMQTransportFactory* factory) { - fTransportFactory = shared_ptr(factory); + if (fTransports.empty()) + { + fTransportFactory = shared_ptr(factory); + pair> t(fTransportFactory->GetType(), fTransportFactory); + fTransports.insert(t); + } + else + { + LOG(ERROR) << "Transports container is not empty when setting transport. Setting twice?"; + ChangeState(ERROR_FOUND); + } +} + +shared_ptr FairMQDevice::AddTransport(const string& transport) +{ + unordered_map>::const_iterator i = fTransports.find(FairMQ::TransportTypes.at(transport)); + + if (i == fTransports.end()) + { + shared_ptr tr; + + if (transport == "zeromq") + { + tr = make_shared(); + } + else if (transport == "shmem") + { + tr = make_shared(); + } +#ifdef NANOMSG_FOUND + else if (transport == "nanomsg") + { + tr = make_shared(); + } +#endif + else + { + LOG(ERROR) << "Unavailable transport requested: " + << "\"" << transport << "\"" + << ". Available are: " + << "\"zeromq\"" + << "\"shmem\"" +#ifdef NANOMSG_FOUND + << ", \"nanomsg\"" +#endif + << ". Exiting."; + exit(EXIT_FAILURE); + } + + LOG(DEBUG) << "Adding '" << transport << "' transport to the device."; + + pair> trPair(FairMQ::TransportTypes.at(transport), tr); + fTransports.insert(trPair); + + auto p = fDeviceCmdSockets.emplace(tr->GetType(), tr->CreateSocket("pub", "device-commands", fNumIoThreads, fId)); + if (p.second) + { + p.first->second->Bind("inproc://commands"); + } + else + { + exit(EXIT_FAILURE); + } + + FairMQMessagePtr msg(tr->CreateMessage()); + msg->SetDeviceId(fId); + + return move(tr); + } + else + { + LOG(DEBUG) << "Reusing existing '" << transport << "' transport."; + return i->second; + } } void FairMQDevice::SetTransport(const string& transport) { - if (transport == "zeromq") + if (fTransports.empty()) { - fTransportFactory = make_shared(); + LOG(DEBUG) << "Requesting '" << transport << "' as default transport for the device"; + fTransportFactory = AddTransport(transport); } - else if (transport == "shmem") - { - fTransportFactory = make_shared(); - } -#ifdef NANOMSG_FOUND - else if (transport == "nanomsg") - { - fTransportFactory = make_shared(); - } -#endif else { - LOG(ERROR) << "Unavailable transport implementation requested: " - << "\"" << transport << "\"" - << ". Available are: " - << "\"zeromq\"" -#ifdef NANOMSG_FOUND - << ", \"nanomsg\"" -#endif - << ". Exiting."; - exit(EXIT_FAILURE); + LOG(ERROR) << "Transports container is not empty when setting transport. Setting default twice?"; + ChangeState(ERROR_FOUND); } } @@ -799,7 +959,8 @@ void FairMQDevice::SetConfig(FairMQProgOptions& config) LOG(DEBUG) << "PID: " << getpid(); fConfig = &config; fChannels = config.GetFairMQMap(); - SetTransport(config.GetValue("transport")); + fDefaultTransport = config.GetValue("transport"); + SetTransport(fDefaultTransport); fId = config.GetValue("id"); fNetworkInterface = config.GetValue("network-interface"); fNumIoThreads = config.GetValue("io-threads"); @@ -812,7 +973,6 @@ void FairMQDevice::LogSocketRates() timestamp_t msSinceLastLog; - int numFilteredSockets = 0; vector filteredSockets; vector filteredChannelNames; vector logIntervals; @@ -832,83 +992,87 @@ void FairMQDevice::LogSocketRates() stringstream ss; ss << mi.first << "[" << vi - (mi.second).begin() << "]"; filteredChannelNames.push_back(ss.str()); - ++numFilteredSockets; } } } - vector bytesIn(numFilteredSockets); - vector msgIn(numFilteredSockets); - vector bytesOut(numFilteredSockets); - vector msgOut(numFilteredSockets); + unsigned int numFilteredSockets = filteredSockets.size(); - vector bytesInNew(numFilteredSockets); - vector msgInNew(numFilteredSockets); - vector bytesOutNew(numFilteredSockets); - vector msgOutNew(numFilteredSockets); - - vector mbPerSecIn(numFilteredSockets); - vector msgPerSecIn(numFilteredSockets); - vector mbPerSecOut(numFilteredSockets); - vector msgPerSecOut(numFilteredSockets); - - int i = 0; - for (const auto& vi : filteredSockets) + if (numFilteredSockets > 0) { - bytesIn.at(i) = vi->GetBytesRx(); - bytesOut.at(i) = vi->GetBytesTx(); - msgIn.at(i) = vi->GetMessagesRx(); - msgOut.at(i) = vi->GetMessagesTx(); - ++i; - } + vector bytesIn(numFilteredSockets); + vector msgIn(numFilteredSockets); + vector bytesOut(numFilteredSockets); + vector msgOut(numFilteredSockets); - t0 = get_timestamp(); + vector bytesInNew(numFilteredSockets); + vector msgInNew(numFilteredSockets); + vector bytesOutNew(numFilteredSockets); + vector msgOutNew(numFilteredSockets); - while (CheckCurrentState(RUNNING)) - { - t1 = get_timestamp(); - - msSinceLastLog = (t1 - t0) / 1000.0L; - - i = 0; + vector mbPerSecIn(numFilteredSockets); + vector msgPerSecIn(numFilteredSockets); + vector mbPerSecOut(numFilteredSockets); + vector msgPerSecOut(numFilteredSockets); + int i = 0; for (const auto& vi : filteredSockets) { - intervalCounters.at(i)++; - - if (intervalCounters.at(i) == logIntervals.at(i)) - { - intervalCounters.at(i) = 0; - - bytesInNew.at(i) = vi->GetBytesRx(); - mbPerSecIn.at(i) = (static_cast(bytesInNew.at(i) - bytesIn.at(i)) / (1024. * 1024.)) / static_cast(msSinceLastLog) * 1000.; - bytesIn.at(i) = bytesInNew.at(i); - - msgInNew.at(i) = vi->GetMessagesRx(); - msgPerSecIn.at(i) = static_cast(msgInNew.at(i) - msgIn.at(i)) / static_cast(msSinceLastLog) * 1000.; - msgIn.at(i) = msgInNew.at(i); - - bytesOutNew.at(i) = vi->GetBytesTx(); - mbPerSecOut.at(i) = (static_cast(bytesOutNew.at(i) - bytesOut.at(i)) / (1024. * 1024.)) / static_cast(msSinceLastLog) * 1000.; - bytesOut.at(i) = bytesOutNew.at(i); - - msgOutNew.at(i) = vi->GetMessagesTx(); - msgPerSecOut.at(i) = static_cast(msgOutNew.at(i) - msgOut.at(i)) / static_cast(msSinceLastLog) * 1000.; - msgOut.at(i) = msgOutNew.at(i); - - LOG(DEBUG) << filteredChannelNames.at(i) << ": " - << "in: " << msgPerSecIn.at(i) << " msg (" << mbPerSecIn.at(i) << " MB), " - << "out: " << msgPerSecOut.at(i) << " msg (" << mbPerSecOut.at(i) << " MB)"; - } - + bytesIn.at(i) = vi->GetBytesRx(); + bytesOut.at(i) = vi->GetBytesTx(); + msgIn.at(i) = vi->GetMessagesRx(); + msgOut.at(i) = vi->GetMessagesTx(); ++i; } - t0 = t1; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + t0 = get_timestamp(); + + while (CheckCurrentState(RUNNING)) + { + t1 = get_timestamp(); + + msSinceLastLog = (t1 - t0) / 1000.0L; + + i = 0; + + for (const auto& vi : filteredSockets) + { + intervalCounters.at(i)++; + + if (intervalCounters.at(i) == logIntervals.at(i)) + { + intervalCounters.at(i) = 0; + + bytesInNew.at(i) = vi->GetBytesRx(); + mbPerSecIn.at(i) = (static_cast(bytesInNew.at(i) - bytesIn.at(i)) / (1000. * 1000.)) / static_cast(msSinceLastLog) * 1000.; + bytesIn.at(i) = bytesInNew.at(i); + + msgInNew.at(i) = vi->GetMessagesRx(); + msgPerSecIn.at(i) = static_cast(msgInNew.at(i) - msgIn.at(i)) / static_cast(msSinceLastLog) * 1000.; + msgIn.at(i) = msgInNew.at(i); + + bytesOutNew.at(i) = vi->GetBytesTx(); + mbPerSecOut.at(i) = (static_cast(bytesOutNew.at(i) - bytesOut.at(i)) / (1000. * 1000.)) / static_cast(msSinceLastLog) * 1000.; + bytesOut.at(i) = bytesOutNew.at(i); + + msgOutNew.at(i) = vi->GetMessagesTx(); + msgPerSecOut.at(i) = static_cast(msgOutNew.at(i) - msgOut.at(i)) / static_cast(msSinceLastLog) * 1000.; + msgOut.at(i) = msgOutNew.at(i); + + LOG(DEBUG) << filteredChannelNames.at(i) << ": " + << "in: " << msgPerSecIn.at(i) << " msg (" << mbPerSecIn.at(i) << " MB), " + << "out: " << msgPerSecOut.at(i) << " msg (" << mbPerSecOut.at(i) << " MB)"; + } + + ++i; + } + + t0 = t1; + this_thread::sleep_for(chrono::milliseconds(1000)); + } } - LOG(DEBUG) << "FairMQDevice::LogSocketRates() stopping"; + // LOG(DEBUG) << "FairMQDevice::LogSocketRates() stopping"; } void FairMQDevice::InteractiveStateLoop() @@ -1011,9 +1175,12 @@ void FairMQDevice::InteractiveStateLoop() void FairMQDevice::Unblock() { FairMQChannel::fInterrupted = true; - fCmdSocket->Interrupt(); - FairMQMessagePtr cmd(fTransportFactory->CreateMessage()); - fCmdSocket->Send(cmd.get(), 0); + for (auto& kv : fDeviceCmdSockets) + { + kv.second->Interrupt(); + FairMQMessagePtr cmd(fTransports.at(kv.first)->CreateMessage()); + kv.second->Send(cmd); + } } void FairMQDevice::ResetTaskWrapper() @@ -1047,8 +1214,8 @@ void FairMQDevice::Reset() vi.fPoller = nullptr; - vi.fCmdSocket->Close(); - vi.fCmdSocket = nullptr; + vi.fChannelCmdSocket->Close(); + vi.fChannelCmdSocket = nullptr; } } } @@ -1061,10 +1228,14 @@ bool FairMQDevice::Terminated() void FairMQDevice::Terminate() { // Termination signal has to be sent only once to any socket. - if (fCmdSocket) + for (auto& kv : fDeviceCmdSockets) { - fCmdSocket->Terminate(); + kv.second->Terminate(); } + // if (!fDeviceCmdSockets.empty()) + // { + // fDeviceCmdSockets[0]->Terminate(); + // } } void FairMQDevice::Shutdown() @@ -1081,18 +1252,23 @@ void FairMQDevice::Shutdown() { vi.fSocket->Close(); } - if (vi.fCmdSocket) + if (vi.fChannelCmdSocket) { - vi.fCmdSocket->Close(); + vi.fChannelCmdSocket->Close(); } } } - if (fCmdSocket) + for (auto& s : fDeviceCmdSockets) { - fCmdSocket->Close(); + s.second->Close(); } + // if (!fDeviceCmdSockets.empty()) + // { + // fDeviceCmdSockets[0]->Close(); + // } + LOG(DEBUG) << "Closed all sockets!"; } diff --git a/fairmq/FairMQDevice.h b/fairmq/FairMQDevice.h index 7401e24d..640ac76a 100644 --- a/fairmq/FairMQDevice.h +++ b/fairmq/FairMQDevice.h @@ -17,6 +17,7 @@ #include #include // unique_ptr +#include // std::sort() #include #include #include @@ -30,6 +31,7 @@ #include "FairMQConfigurable.h" #include "FairMQStateMachine.h" #include "FairMQTransportFactory.h" +#include "FairMQTransports.h" #include "FairMQSocket.h" #include "FairMQChannel.h" @@ -96,83 +98,103 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable Deserializer().Deserialize(msg, std::forward(data), std::forward(args)...); } + inline int Send(FairMQMessagePtr& msg, const std::string& chan, const int i = 0) const + { + return fChannels.at(chan).at(i).Send(msg); + } + + inline int Receive(FairMQMessagePtr& msg, const std::string& chan, const int i = 0) const + { + return fChannels.at(chan).at(i).Receive(msg); + } + /// Shorthand method to send `msg` on `chan` at index `i` /// @param msg message reference /// @param chan channel name /// @param i channel index /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. /// In case of errors, returns -1. - inline int Send(const FairMQMessagePtr& msg, const std::string& chan, const int i = 0, int sndTimeoutInMs = -1) const + inline int Send(FairMQMessagePtr& msg, const std::string& chan, const int i, int sndTimeoutInMs) const { return fChannels.at(chan).at(i).Send(msg, sndTimeoutInMs); } - /// Shorthand method to send `msg` on `chan` at index `i` without blocking - /// @param msg message reference - /// @param chan channel name - /// @param i channel index - /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. - /// In case of errors, returns -1. - inline int SendAsync(const FairMQMessagePtr& msg, const std::string& chan, const int i = 0) const - { - return fChannels.at(chan).at(i).SendAsync(msg); - } - - /// Shorthand method to send FairMQParts on `chan` at index `i` - /// @param parts parts reference - /// @param chan channel name - /// @param i channel index - /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. - /// In case of errors, returns -1. - inline int64_t Send(const FairMQParts& parts, const std::string& chan, const int i = 0, int sndTimeoutInMs = -1) const - { - return fChannels.at(chan).at(i).Send(parts.fParts, sndTimeoutInMs); - } - - /// Shorthand method to send FairMQParts on `chan` at index `i` without blocking - /// @param parts parts reference - /// @param chan channel name - /// @param i channel index - /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. - /// In case of errors, returns -1. - inline int64_t SendAsync(const FairMQParts& parts, const std::string& chan, const int i = 0) const - { - return fChannels.at(chan).at(i).SendAsync(parts.fParts); - } - /// Shorthand method to receive `msg` on `chan` at index `i` /// @param msg message reference /// @param chan channel name /// @param i channel index /// @return Number of bytes that have been received. -2 If reading from the queue was not possible or timed out. /// In case of errors, returns -1. - inline int Receive(const FairMQMessagePtr& msg, const std::string& chan, const int i = 0, int rcvTimeoutInMs = -1) const + inline int Receive(FairMQMessagePtr& msg, const std::string& chan, const int i, int rcvTimeoutInMs) const { return fChannels.at(chan).at(i).Receive(msg, rcvTimeoutInMs); } + /// Shorthand method to send `msg` on `chan` at index `i` without blocking + /// @param msg message reference + /// @param chan channel name + /// @param i channel index + /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. + /// In case of errors, returns -1. + inline int SendAsync(FairMQMessagePtr& msg, const std::string& chan, const int i = 0) const + { + return fChannels.at(chan).at(i).SendAsync(msg); + } + /// Shorthand method to receive `msg` on `chan` at index `i` without blocking /// @param msg message reference /// @param chan channel name /// @param i channel index /// @return Number of bytes that have been received. -2 If reading from the queue was not possible or timed out. /// In case of errors, returns -1. - inline int ReceiveAsync(const FairMQMessagePtr& msg, const std::string& chan, const int i = 0) const + inline int ReceiveAsync(FairMQMessagePtr& msg, const std::string& chan, const int i = 0) const { return fChannels.at(chan).at(i).ReceiveAsync(msg); } + inline int64_t Send(FairMQParts& parts, const std::string& chan, const int i = 0) const + { + return fChannels.at(chan).at(i).Send(parts.fParts); + } + + inline int64_t Receive(FairMQParts& parts, const std::string& chan, const int i = 0) const + { + return fChannels.at(chan).at(i).Receive(parts.fParts); + } + + /// Shorthand method to send FairMQParts on `chan` at index `i` + /// @param parts parts reference + /// @param chan channel name + /// @param i channel index + /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. + /// In case of errors, returns -1. + inline int64_t Send(FairMQParts& parts, const std::string& chan, const int i, int sndTimeoutInMs) const + { + return fChannels.at(chan).at(i).Send(parts.fParts, sndTimeoutInMs); + } + /// Shorthand method to receive FairMQParts on `chan` at index `i` /// @param parts parts reference /// @param chan channel name /// @param i channel index /// @return Number of bytes that have been received. -2 If reading from the queue was not possible or timed out. /// In case of errors, returns -1. - inline int64_t Receive(FairMQParts& parts, const std::string& chan, const int i = 0, int rcvTimeoutInMs = -1) const + inline int64_t Receive(FairMQParts& parts, const std::string& chan, const int i, int rcvTimeoutInMs) const { return fChannels.at(chan).at(i).Receive(parts.fParts, rcvTimeoutInMs); } + /// Shorthand method to send FairMQParts on `chan` at index `i` without blocking + /// @param parts parts reference + /// @param chan channel name + /// @param i channel index + /// @return Number of bytes that have been queued. -2 If queueing was not possible or timed out. + /// In case of errors, returns -1. + inline int64_t SendAsync(FairMQParts& parts, const std::string& chan, const int i = 0) const + { + return fChannels.at(chan).at(i).SendAsync(parts.fParts); + } + /// Shorthand method to receive FairMQParts on `chan` at index `i` without blocking /// @param parts parts reference /// @param chan channel name @@ -199,25 +221,25 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable return fTransportFactory->CreateMessage(size); } - template - static void FairMQSimpleMsgCleanup(void* /*data*/, void* hint) - { - delete static_cast(hint); - } - - static void FairMQNoCleanup(void* /*data*/, void* /*hint*/) - { - } - /// @brief Create new FairMQMessage with user provided buffer and size /// @param data pointer to user provided buffer /// @param size size of the user provided buffer - /// @param ffn optional callback, called when the message is transfered (and can be deleted) - /// @param hint optional helper pointer that can be used in the callback + /// @param ffn callback, called when the message is transfered (and can be deleted) + /// @param obj optional helper pointer that can be used in the callback /// @return pointer to FairMQMessage - inline FairMQMessagePtr NewMessage(void* data, int size, fairmq_free_fn* ffn, void* hint = nullptr) const + inline FairMQMessagePtr NewMessage(void* data, int size, fairmq_free_fn* ffn, void* obj = nullptr) const + { + return fTransportFactory->CreateMessage(data, size, ffn, obj); + } + + template + inline FairMQMessagePtr NewMessageFor(const std::string& channel, int index, Args&&... args) const + { + return fChannels.at(channel).at(index).fTransportFactory->CreateMessage(std::forward(args)...); + } + + static void FairMQNoCleanup(void* /*data*/, void* /*obj*/) { - return fTransportFactory->CreateMessage(data, size, ffn, hint); } template @@ -231,6 +253,23 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable return fTransportFactory->CreateMessage(const_cast(str.c_str()), str.length(), FairMQNoCleanup, nullptr); } + template + inline FairMQMessagePtr NewStaticMessageFor(const std::string& channel, int index, const T& data) const + { + return fChannels.at(channel).at(index).fTransportFactory->CreateMessage(data, sizeof(T), FairMQNoCleanup, nullptr); + } + + inline FairMQMessagePtr NewStaticMessageFor(const std::string& channel, int index, const std::string& str) const + { + return fChannels.at(channel).at(index).fTransportFactory->CreateMessage(const_cast(str.c_str()), str.length(), FairMQNoCleanup, nullptr); + } + + template + static void FairMQSimpleMsgCleanup(void* /*data*/, void* obj) + { + delete static_cast(obj); + } + template inline FairMQMessagePtr NewSimpleMessage(const T& data) const { @@ -295,8 +334,11 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable /// Configures the device with a transport factory (DEPRECATED) /// @param factory Pointer to the transport factory object void SetTransport(FairMQTransportFactory* factory); - /// Configures the device with a transport factory - /// @param transport Transport string ("zeromq"/"nanomsg") + /// Adds a transport to the device if it doesn't exist + /// @param transport Transport string ("zeromq"/"nanomsg"/"shmem") + std::shared_ptr AddTransport(const std::string& transport); + /// Sets the default transport for the device + /// @param transport Transport string ("zeromq"/"nanomsg"/"shmem") void SetTransport(const std::string& transport = "zeromq"); void SetConfig(FairMQProgOptions& config); @@ -317,6 +359,11 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable { return (static_cast(this)->*memberFunction)(msg, index); })); + + if (find(fInputChannelKeys.begin(), fInputChannelKeys.end(), channelName) == fInputChannelKeys.end()) + { + fInputChannelKeys.push_back(channelName); + } } void OnData(const std::string& channelName, InputMsgCallback); @@ -329,6 +376,11 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable { return (static_cast(this)->*memberFunction)(parts, index); })); + + if (find(fInputChannelKeys.begin(), fInputChannelKeys.end(), channelName) == fInputChannelKeys.end()) + { + fInputChannelKeys.push_back(channelName); + } } void OnData(const std::string& channelName, InputMultipartCallback); @@ -338,6 +390,7 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable protected: std::string fId; ///< Device ID std::string fNetworkInterface; ///< Network interface to use for dynamic binding + std::string fDefaultTransport; ///< Default transport for the device int fMaxInitializationAttempts; ///< Timeout for the initialization @@ -348,9 +401,9 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable int fLogIntervalInMs; ///< Interval for logging the socket transfer rates - FairMQSocketPtr fCmdSocket; ///< Socket used for the internal unblocking mechanism - std::shared_ptr fTransportFactory; ///< Transport factory + std::unordered_map> fTransports; ///< Container for transports + std::unordered_map fDeviceCmdSockets; ///< Sockets used for the internal unblocking mechanism /// Additional user initialization (can be overloaded in child classes). Prefer to use InitTask(). virtual void Init(); @@ -403,14 +456,8 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable /// Unblocks blocking channel send/receive calls void Unblock(); - /// Binds channel in the list - void BindChannels(std::list& chans); - /// Connects channel in the list - void ConnectChannels(std::list& chans); - /// Binds a single channel (used in InitWrapper) - bool BindChannel(FairMQChannel& ch); - /// Connects a single channel (used in InitWrapper) - bool ConnectChannel(FairMQChannel& ch); + /// Attach (bind/connect) channels in the list + void AttachChannels(std::list& chans); /// Sets up and connects/binds a socket to an endpoint /// return a string with the actual endpoint if it happens @@ -422,6 +469,14 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable /// to override default: prepend "@" to bind, "+" or ">" to connect endpoint. bool AttachChannel(FairMQChannel& ch); + void HandleSingleChannelInput(); + void HandleMultipleChannelInput(); + void HandleMultipleTransportInput(); + void PollForTransport(const FairMQTransportFactory* factory, const std::vector& channelKeys); + + bool HandleMsgInput(const std::string& chName, const InputMsgCallback& callback, int i) const; + bool HandleMultipartInput(const std::string& chName, const InputMultipartCallback& callback, int i) const; + /// Signal handler void SignalHandler(int signal); bool fCatchingSignals; @@ -432,6 +487,10 @@ class FairMQDevice : public FairMQStateMachine, public FairMQConfigurable bool fDataCallbacks; std::unordered_map fMsgInputs; std::unordered_map fMultipartInputs; + std::unordered_map> fMultitransportInputs; + std::vector fInputChannelKeys; + std::mutex fMultitransportMutex; + std::atomic fMultitransportProceed; }; #endif /* FAIRMQDEVICE_H_ */ diff --git a/fairmq/FairMQMessage.h b/fairmq/FairMQMessage.h index 71d56616..d34bc16e 100644 --- a/fairmq/FairMQMessage.h +++ b/fairmq/FairMQMessage.h @@ -18,6 +18,8 @@ #include // for size_t #include // unique_ptr +#include "FairMQTransports.h" + using fairmq_free_fn = void(void* data, void* hint); class FairMQMessage @@ -34,6 +36,8 @@ class FairMQMessage virtual void SetDeviceId(const std::string& deviceId) = 0; + virtual FairMQ::Transport GetType() const = 0; + virtual void Copy(const std::unique_ptr& msg) = 0; virtual ~FairMQMessage() {}; diff --git a/fairmq/FairMQSocket.h b/fairmq/FairMQSocket.h index c77241c2..36280323 100644 --- a/fairmq/FairMQSocket.h +++ b/fairmq/FairMQSocket.h @@ -40,12 +40,10 @@ class FairMQSocket virtual void Connect(const std::string& address) = 0; virtual bool Attach(const std::string& address, bool serverish = false); - virtual int Send(FairMQMessage* msg, const std::string& flag = "") = 0; - virtual int Send(FairMQMessage* msg, const int flags = 0) = 0; - virtual int64_t Send(const std::vector>& msgVec, const int flags = 0) = 0; + virtual int Send(FairMQMessagePtr& msg, const int flags = 0) = 0; + virtual int Receive(FairMQMessagePtr& msg, const int flags = 0) = 0; - virtual int Receive(FairMQMessage* msg, const std::string& flag = "") = 0; - virtual int Receive(FairMQMessage* msg, const int flags = 0) = 0; + virtual int64_t Send(std::vector>& msgVec, const int flags = 0) = 0; virtual int64_t Receive(std::vector>& msgVec, const int flags = 0) = 0; virtual void* GetSocket() const = 0; diff --git a/fairmq/FairMQStateMachine.h b/fairmq/FairMQStateMachine.h index a2a99a06..c2159e9f 100644 --- a/fairmq/FairMQStateMachine.h +++ b/fairmq/FairMQStateMachine.h @@ -99,7 +99,7 @@ struct FairMQFSM_ : public msmf::state_machine_def } template - void on_exit(Event const&, FSM& fsm) + void on_exit(Event const&, FSM& /*fsm*/) { LOG(STATE) << "Exiting FairMQ state machine"; } diff --git a/fairmq/FairMQTransportFactory.h b/fairmq/FairMQTransportFactory.h index c0957ab4..47069ff5 100644 --- a/fairmq/FairMQTransportFactory.h +++ b/fairmq/FairMQTransportFactory.h @@ -21,10 +21,10 @@ #include #include "FairMQMessage.h" -#include "FairMQChannel.h" #include "FairMQSocket.h" #include "FairMQPoller.h" #include "FairMQLogger.h" +#include "FairMQTransports.h" class FairMQChannel; @@ -41,6 +41,8 @@ class FairMQTransportFactory virtual FairMQPollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const = 0; virtual FairMQPollerPtr CreatePoller(const FairMQSocket& cmdSocket, const FairMQSocket& dataSocket) const = 0; + virtual FairMQ::Transport GetType() const = 0; + virtual ~FairMQTransportFactory() {}; }; diff --git a/fairmq/FairMQTransports.h b/fairmq/FairMQTransports.h new file mode 100644 index 00000000..2cfc89ab --- /dev/null +++ b/fairmq/FairMQTransports.h @@ -0,0 +1,47 @@ +/******************************************************************************** + * Copyright (C) 2014 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * + * * + * This software is distributed under the terms of the * + * GNU Lesser General Public Licence version 3 (LGPL) version 3, * + * copied verbatim in the file "LICENSE" * + ********************************************************************************/ +#ifndef FAIRMQTRANSPORTS_H_ +#define FAIRMQTRANSPORTS_H_ + +#include +#include + +namespace FairMQ +{ + +enum class Transport +{ + DEFAULT, + ZMQ, + NN, + SHM +}; + +static std::unordered_map TransportTypes { + { "default", Transport::DEFAULT }, + { "zeromq", Transport::ZMQ }, + { "nanomsg", Transport::NN }, + { "shmem", Transport::SHM } +}; + +} + +namespace std +{ + +template <> +struct hash +{ + size_t operator()(const FairMQ::Transport& v) const + { + return hash()(static_cast(v)); + } +}; + +} +#endif /* FAIRMQTRANSPORTS_H_ */ diff --git a/fairmq/devices/FairMQBenchmarkSampler.cxx b/fairmq/devices/FairMQBenchmarkSampler.cxx index 71593780..e077014b 100644 --- a/fairmq/devices/FairMQBenchmarkSampler.cxx +++ b/fairmq/devices/FairMQBenchmarkSampler.cxx @@ -51,10 +51,10 @@ void FairMQBenchmarkSampler::Run() uint64_t numSentMsgs = 0; - FairMQMessagePtr baseMsg(fTransportFactory->CreateMessage(fMsgSize)); - // store the channel reference to avoid traversing the map on every loop iteration - const FairMQChannel& dataOutChannel = fChannels.at(fOutChannelName).at(0); + FairMQChannel& dataOutChannel = fChannels.at(fOutChannelName).at(0); + + FairMQMessagePtr baseMsg(dataOutChannel.Transport()->CreateMessage(fMsgSize)); LOG(INFO) << "Starting the benchmark with message size of " << fMsgSize << " and number of messages " << fNumMsgs << "."; auto tStart = chrono::high_resolution_clock::now(); @@ -63,7 +63,7 @@ void FairMQBenchmarkSampler::Run() { if (fSameMessage) { - FairMQMessagePtr msg(fTransportFactory->CreateMessage()); + FairMQMessagePtr msg(dataOutChannel.Transport()->CreateMessage()); msg->Copy(baseMsg); if (dataOutChannel.Send(msg) >= 0) @@ -80,7 +80,7 @@ void FairMQBenchmarkSampler::Run() } else { - FairMQMessagePtr msg(fTransportFactory->CreateMessage(fMsgSize)); + FairMQMessagePtr msg(dataOutChannel.Transport()->CreateMessage(fMsgSize)); if (dataOutChannel.Send(msg) >= 0) { diff --git a/fairmq/devices/FairMQSink.cxx b/fairmq/devices/FairMQSink.cxx index 5547220d..6bb39c33 100644 --- a/fairmq/devices/FairMQSink.cxx +++ b/fairmq/devices/FairMQSink.cxx @@ -36,14 +36,14 @@ void FairMQSink::Run() { uint64_t numReceivedMsgs = 0; // store the channel reference to avoid traversing the map on every loop iteration - const FairMQChannel& dataInChannel = fChannels.at(fInChannelName).at(0); + FairMQChannel& dataInChannel = fChannels.at(fInChannelName).at(0); LOG(INFO) << "Starting the benchmark and expecting to receive " << fNumMsgs << " messages."; auto tStart = chrono::high_resolution_clock::now(); while (CheckCurrentState(RUNNING)) { - FairMQMessagePtr msg(fTransportFactory->CreateMessage()); + FairMQMessagePtr msg(dataInChannel.Transport()->CreateMessage()); if (dataInChannel.Receive(msg) >= 0) { diff --git a/fairmq/nanomsg/FairMQMessageNN.cxx b/fairmq/nanomsg/FairMQMessageNN.cxx index 605500af..2286f28d 100644 --- a/fairmq/nanomsg/FairMQMessageNN.cxx +++ b/fairmq/nanomsg/FairMQMessageNN.cxx @@ -22,6 +22,8 @@ using namespace std; +static FairMQ::Transport gTransportType = FairMQ::Transport::NN; + string FairMQMessageNN::fDeviceID = string(); FairMQMessageNN::FairMQMessageNN() @@ -145,6 +147,11 @@ void FairMQMessageNN::SetDeviceId(const string& deviceId) fDeviceID = deviceId; } +FairMQ::Transport FairMQMessageNN::GetType() const +{ + return gTransportType; +} + void FairMQMessageNN::Copy(const unique_ptr& msg) { if (fMessage) diff --git a/fairmq/nanomsg/FairMQMessageNN.h b/fairmq/nanomsg/FairMQMessageNN.h index e649e18f..57d08b4b 100644 --- a/fairmq/nanomsg/FairMQMessageNN.h +++ b/fairmq/nanomsg/FairMQMessageNN.h @@ -41,6 +41,8 @@ class FairMQMessageNN : public FairMQMessage virtual void SetDeviceId(const std::string& deviceId); + virtual FairMQ::Transport GetType() const; + virtual void Copy(const std::unique_ptr& msg); virtual ~FairMQMessageNN(); diff --git a/fairmq/nanomsg/FairMQPollerNN.cxx b/fairmq/nanomsg/FairMQPollerNN.cxx index 198be34c..1a2cd252 100644 --- a/fairmq/nanomsg/FairMQPollerNN.cxx +++ b/fairmq/nanomsg/FairMQPollerNN.cxx @@ -53,7 +53,7 @@ FairMQPollerNN::FairMQPollerNN(const vector& channels) } else { - LOG(ERROR) << "invalid poller configuration, exiting."; + LOG(ERROR) << "nanomsg: invalid poller configuration, exiting."; exit(EXIT_FAILURE); } } @@ -104,7 +104,7 @@ FairMQPollerNN::FairMQPollerNN(const unordered_map } else { - LOG(ERROR) << "invalid poller configuration, exiting."; + LOG(ERROR) << "nanomsg: invalid poller configuration, exiting."; exit(EXIT_FAILURE); } } @@ -112,8 +112,8 @@ FairMQPollerNN::FairMQPollerNN(const unordered_map } catch (const std::out_of_range& oor) { - LOG(ERROR) << "At least one of the provided channel keys for poller initialization is invalid"; - LOG(ERROR) << "Out of Range error: " << oor.what() << '\n'; + LOG(ERROR) << "nanomsg: at least one of the provided channel keys for poller initialization is invalid"; + LOG(ERROR) << "nanomsg: out of range error: " << oor.what() << '\n'; exit(EXIT_FAILURE); } } @@ -150,7 +150,7 @@ FairMQPollerNN::FairMQPollerNN(const FairMQSocket& cmdSocket, const FairMQSocket } else { - LOG(ERROR) << "invalid poller configuration, exiting."; + LOG(ERROR) << "nanomsg: invalid poller configuration, exiting."; exit(EXIT_FAILURE); } } @@ -161,11 +161,12 @@ void FairMQPollerNN::Poll(const int timeout) { if (errno == ETERM) { - LOG(DEBUG) << "polling exited, reason: " << nn_strerror(errno); + LOG(DEBUG) << "nanomsg: polling exited, reason: " << nn_strerror(errno); } else { - LOG(ERROR) << "polling failed, reason: " << nn_strerror(errno); + LOG(ERROR) << "nanomsg: polling failed, reason: " << nn_strerror(errno); + throw std::runtime_error("nanomsg: polling failed"); } } } @@ -203,8 +204,8 @@ bool FairMQPollerNN::CheckInput(const string channelKey, const int index) } catch (const std::out_of_range& oor) { - LOG(ERROR) << "Invalid channel key: \"" << channelKey << "\""; - LOG(ERROR) << "Out of Range error: " << oor.what() << '\n'; + LOG(ERROR) << "nanomsg: invalid channel key: \"" << channelKey << "\""; + LOG(ERROR) << "nanomsg: out of range error: " << oor.what() << '\n'; exit(EXIT_FAILURE); } } @@ -222,8 +223,8 @@ bool FairMQPollerNN::CheckOutput(const string channelKey, const int index) } catch (const std::out_of_range& oor) { - LOG(ERROR) << "Invalid channel key: \"" << channelKey << "\""; - LOG(ERROR) << "Out of Range error: " << oor.what() << '\n'; + LOG(ERROR) << "nanomsg: invalid channel key: \"" << channelKey << "\""; + LOG(ERROR) << "nanomsg: out of range error: " << oor.what() << '\n'; exit(EXIT_FAILURE); } } diff --git a/fairmq/nanomsg/FairMQSocketNN.cxx b/fairmq/nanomsg/FairMQSocketNN.cxx index bbd36bbd..c63d659d 100644 --- a/fairmq/nanomsg/FairMQSocketNN.cxx +++ b/fairmq/nanomsg/FairMQSocketNN.cxx @@ -29,6 +29,8 @@ using namespace std; +atomic FairMQSocketNN::fInterrupted(false); + FairMQSocketNN::FairMQSocketNN(const string& type, const string& name, const int numIoThreads, const string& id /*= ""*/) : FairMQSocket(0, 0, NN_DONTWAIT) , fSocket(-1) @@ -71,6 +73,18 @@ FairMQSocketNN::FairMQSocketNN(const string& type, const string& name, const int } } + int sndTimeout = 700; + if (nn_setsockopt(fSocket, NN_SOL_SOCKET, NN_SNDTIMEO, &sndTimeout, sizeof(sndTimeout)) != 0) + { + LOG(ERROR) << "Failed setting NN_SNDTIMEO socket option, reason: " << nn_strerror(errno); + } + + int rcvTimeout = 700; + if (nn_setsockopt(fSocket, NN_SOL_SOCKET, NN_RCVTIMEO, &rcvTimeout, sizeof(rcvTimeout)) != 0) + { + LOG(ERROR) << "Failed setting NN_RCVTIMEO socket option, reason: " << nn_strerror(errno); + } + #ifdef NN_RCVMAXSIZE int rcvSize = -1; nn_setsockopt(fSocket, NN_SOL_SOCKET, NN_RCVMAXSIZE, &rcvSize, sizeof(rcvSize)); @@ -108,105 +122,155 @@ void FairMQSocketNN::Connect(const string& address) } } -int FairMQSocketNN::Send(FairMQMessage* msg, const string& flag) +int FairMQSocketNN::Send(FairMQMessagePtr& msg, const int flags) { - return Send(msg, GetConstant(flag)); + int nbytes = -1; + + while (true) + { + void* ptr = msg->GetMessage(); + nbytes = nn_send(fSocket, &ptr, NN_MSG, flags); + if (nbytes >= 0) + { + fBytesTx += nbytes; + ++fMessagesTx; + static_cast(msg.get())->fReceiving = false; + + return nbytes; + } + else if (nn_errno() == ETIMEDOUT) + { + if (!fInterrupted && ((flags & NN_DONTWAIT) == 0)) + { + continue; + } + else + { + return -2; + } + } + else if (nn_errno() == EAGAIN) + { + return -2; + } + else if (nn_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + else + { + LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << nn_strerror(errno); + return nbytes; + } + } } -int FairMQSocketNN::Send(FairMQMessage* msg, const int flags) +int FairMQSocketNN::Receive(FairMQMessagePtr& msg, const int flags) { - void* ptr = msg->GetMessage(); - int nbytes = nn_send(fSocket, &ptr, NN_MSG, flags); - if (nbytes >= 0) + int nbytes = -1; + + while (true) { - fBytesTx += nbytes; - ++fMessagesTx; - static_cast(msg)->fReceiving = false; - return nbytes; + void* ptr = NULL; + nbytes = nn_recv(fSocket, &ptr, NN_MSG, flags); + if (nbytes >= 0) + { + fBytesRx += nbytes; + ++fMessagesRx; + msg->SetMessage(ptr, nbytes); + static_cast(msg.get())->fReceiving = true; + return nbytes; + } + else if (nn_errno() == ETIMEDOUT) + { + if (!fInterrupted && ((flags & NN_DONTWAIT) == 0)) + { + continue; + } + else + { + return -2; + } + } + else if (nn_errno() == EAGAIN) + { + return -2; + } + else if (nn_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + else + { + LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << nn_strerror(errno); + return nbytes; + } } - if (nn_errno() == EAGAIN) - { - return -2; - } - if (nn_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << nn_strerror(errno); - return nbytes; } -int64_t FairMQSocketNN::Send(const vector>& msgVec, const int flags) +int64_t FairMQSocketNN::Send(vector>& msgVec, const int flags) { #ifdef MSGPACK_FOUND + const unsigned int vecSize = msgVec.size(); + // create msgpack simple buffer msgpack::sbuffer sbuf; // create msgpack packer msgpack::packer packer(&sbuf); // pack all parts into a single msgpack simple buffer - for (unsigned int i = 0; i < msgVec.size(); ++i) + for (unsigned int i = 0; i < vecSize; ++i) { static_cast(msgVec[i].get())->fReceiving = false; packer.pack_bin(msgVec[i]->GetSize()); packer.pack_bin_body(static_cast(msgVec[i]->GetData()), msgVec[i]->GetSize()); } - int64_t nbytes = nn_send(fSocket, sbuf.data(), sbuf.size(), flags); - if (nbytes >= 0) + int64_t nbytes = -1; + + while (true) { - fBytesTx += nbytes; - ++fMessagesTx; - return nbytes; + nbytes = nn_send(fSocket, sbuf.data(), sbuf.size(), flags); + if (nbytes >= 0) + { + fBytesTx += nbytes; + ++fMessagesTx; + return nbytes; + } + else if (nn_errno() == ETIMEDOUT) + { + if (!fInterrupted && ((flags & NN_DONTWAIT) == 0)) + { + continue; + } + else + { + return -2; + } + } + else if (nn_errno() == EAGAIN) + { + return -2; + } + else if (nn_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + else + { + LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << nn_strerror(errno); + return nbytes; + } } - if (nn_errno() == EAGAIN) - { - return -2; - } - if (nn_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << nn_strerror(errno); - return nbytes; #else /*MSGPACK_FOUND*/ - LOG(ERROR) << "Cannot send message from vector of size " << msgVec.size() << " and flags " << flags << " with nanomsg multipart because MessagePack is not available."; + LOG(ERROR) << "Cannot send message from vector of size " << vecSize << " and flags " << flags << " with nanomsg multipart because MessagePack is not available."; exit(EXIT_FAILURE); #endif /*MSGPACK_FOUND*/ } -int FairMQSocketNN::Receive(FairMQMessage* msg, const string& flag) -{ - return Receive(msg, GetConstant(flag)); -} - -int FairMQSocketNN::Receive(FairMQMessage* msg, const int flags) -{ - void* ptr = NULL; - int nbytes = nn_recv(fSocket, &ptr, NN_MSG, flags); - if (nbytes >= 0) - { - fBytesRx += nbytes; - ++fMessagesRx; - msg->SetMessage(ptr, nbytes); - static_cast(msg)->fReceiving = true; - return nbytes; - } - if (nn_errno() == EAGAIN) - { - return -2; - } - if (nn_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << nn_strerror(errno); - return nbytes; -} - int64_t FairMQSocketNN::Receive(vector>& msgVec, const int flags) { #ifdef MSGPACK_FOUND @@ -217,51 +281,68 @@ int64_t FairMQSocketNN::Receive(vector>& msgVec, const msgVec.clear(); } - // pointer to point to received message buffer - char* ptr = NULL; - // receive the message into a buffer allocated by nanomsg and let ptr point to it - int nbytes = nn_recv(fSocket, &ptr, NN_MSG, flags); - if (nbytes >= 0) // if no errors or non-blocking timeouts + while (true) { - // store statistics on how many bytes received - fBytesRx += nbytes; - // store statistics on how many messages received (count messages instead of parts) - ++fMessagesRx; - - // offset to be used by msgpack to handle separate chunks - size_t offset = 0; - while (offset != static_cast(nbytes)) // continue until all parts have been read + // pointer to point to received message buffer + char* ptr = NULL; + // receive the message into a buffer allocated by nanomsg and let ptr point to it + int nbytes = nn_recv(fSocket, &ptr, NN_MSG, flags); + if (nbytes >= 0) // if no errors or non-blocking timeouts { - // vector of chars to hold blob (unlike char*/void* this type can be converted to by msgpack) - std::vector buf; + // store statistics on how many bytes received + fBytesRx += nbytes; + // store statistics on how many messages received (count messages instead of parts) + ++fMessagesRx; - // unpack and convert chunk - msgpack::unpacked result; - unpack(result, ptr, nbytes, offset); - msgpack::object object(result.get()); - object.convert(buf); - // get the single message size - size_t size = buf.size() * sizeof(char); - unique_ptr part(new FairMQMessageNN(size)); - static_cast(part.get())->fReceiving = true; - memcpy(part->GetData(), buf.data(), size); - msgVec.push_back(move(part)); + // offset to be used by msgpack to handle separate chunks + size_t offset = 0; + while (offset != static_cast(nbytes)) // continue until all parts have been read + { + // vector of chars to hold blob (unlike char*/void* this type can be converted to by msgpack) + std::vector buf; + + // unpack and convert chunk + msgpack::unpacked result; + unpack(result, ptr, nbytes, offset); + msgpack::object object(result.get()); + object.convert(buf); + // get the single message size + size_t size = buf.size() * sizeof(char); + unique_ptr part(new FairMQMessageNN(size)); + static_cast(part.get())->fReceiving = true; + memcpy(part->GetData(), buf.data(), size); + msgVec.push_back(move(part)); + } + + nn_freemsg(ptr); + return nbytes; + } + else if (nn_errno() == ETIMEDOUT) + { + if (!fInterrupted && ((flags & NN_DONTWAIT) == 0)) + { + continue; + } + else + { + return -2; + } + } + else if (nn_errno() == EAGAIN) + { + return -2; + } + else if (nn_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + else + { + LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << nn_strerror(errno); + return nbytes; } - - nn_freemsg(ptr); - return nbytes; } - if (nn_errno() == EAGAIN) - { - return -2; - } - if (nn_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << nn_strerror(errno); - return nbytes; #else /*MSGPACK_FOUND*/ LOG(ERROR) << "Cannot receive message into vector of size " << msgVec.size() << " and flags " << flags << " with nanomsg multipart because MessagePack is not available."; exit(EXIT_FAILURE); @@ -280,10 +361,12 @@ void FairMQSocketNN::Terminate() void FairMQSocketNN::Interrupt() { + fInterrupted = true; } void FairMQSocketNN::Resume() { + fInterrupted = false; } void* FairMQSocketNN::GetSocket() const diff --git a/fairmq/nanomsg/FairMQSocketNN.h b/fairmq/nanomsg/FairMQSocketNN.h index 181028d4..136f02f1 100644 --- a/fairmq/nanomsg/FairMQSocketNN.h +++ b/fairmq/nanomsg/FairMQSocketNN.h @@ -19,6 +19,7 @@ #include #include "FairMQSocket.h" +#include "FairMQMessage.h" class FairMQSocketNN : public FairMQSocket { @@ -32,12 +33,10 @@ class FairMQSocketNN : public FairMQSocket virtual bool Bind(const std::string& address); virtual void Connect(const std::string& address); - virtual int Send(FairMQMessage* msg, const std::string& flag = ""); - virtual int Send(FairMQMessage* msg, const int flags = 0); - virtual int64_t Send(const std::vector>& msgVec, const int flags = 0); + virtual int Send(FairMQMessagePtr& msg, const int flags = 0); + virtual int Receive(FairMQMessagePtr& msg, const int flags = 0); - virtual int Receive(FairMQMessage* msg, const std::string& flag = ""); - virtual int Receive(FairMQMessage* msg, const int flags = 0); + virtual int64_t Send(std::vector>& msgVec, const int flags = 0); virtual int64_t Receive(std::vector>& msgVec, const int flags = 0); virtual void* GetSocket() const; @@ -72,6 +71,7 @@ class FairMQSocketNN : public FairMQSocket std::atomic fBytesRx; std::atomic fMessagesTx; std::atomic fMessagesRx; + static std::atomic fInterrupted; }; #endif /* FAIRMQSOCKETNN_H_ */ diff --git a/fairmq/nanomsg/FairMQTransportFactoryNN.cxx b/fairmq/nanomsg/FairMQTransportFactoryNN.cxx index ce3ca56f..70164703 100644 --- a/fairmq/nanomsg/FairMQTransportFactoryNN.cxx +++ b/fairmq/nanomsg/FairMQTransportFactoryNN.cxx @@ -16,9 +16,11 @@ using namespace std; +static FairMQ::Transport gTransportType = FairMQ::Transport::NN; + FairMQTransportFactoryNN::FairMQTransportFactoryNN() { - LOG(INFO) << "Using nanomsg library"; + LOG(DEBUG) << "Transport: Using nanomsg library"; } FairMQMessagePtr FairMQTransportFactoryNN::CreateMessage() const @@ -55,3 +57,8 @@ FairMQPollerPtr FairMQTransportFactoryNN::CreatePoller(const FairMQSocket& cmdSo { return unique_ptr(new FairMQPollerNN(cmdSocket, dataSocket)); } + +FairMQ::Transport FairMQTransportFactoryNN::GetType() const +{ + return gTransportType; +} diff --git a/fairmq/nanomsg/FairMQTransportFactoryNN.h b/fairmq/nanomsg/FairMQTransportFactoryNN.h index f1dcd5ca..df24b2b3 100644 --- a/fairmq/nanomsg/FairMQTransportFactoryNN.h +++ b/fairmq/nanomsg/FairMQTransportFactoryNN.h @@ -16,6 +16,7 @@ #define FAIRMQTRANSPORTFACTORYNN_H_ #include +#include #include "FairMQTransportFactory.h" #include "FairMQMessageNN.h" @@ -37,6 +38,8 @@ class FairMQTransportFactoryNN : public FairMQTransportFactory virtual FairMQPollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const; virtual FairMQPollerPtr CreatePoller(const FairMQSocket& cmdSocket, const FairMQSocket& dataSocket) const; + virtual FairMQ::Transport GetType() const; + virtual ~FairMQTransportFactoryNN() {}; }; diff --git a/fairmq/options/FairMQParser.cxx b/fairmq/options/FairMQParser.cxx index 4fc152ae..5d6720ac 100644 --- a/fairmq/options/FairMQParser.cxx +++ b/fairmq/options/FairMQParser.cxx @@ -228,6 +228,7 @@ void ChannelParser(const boost::property_tree::ptree& tree, FairMQMap& channelMa commonChannel.UpdateType(q.second.get("type", commonChannel.GetType())); commonChannel.UpdateMethod(q.second.get("method", commonChannel.GetMethod())); commonChannel.UpdateAddress(q.second.get("address", commonChannel.GetAddress())); + commonChannel.UpdateTransport(q.second.get("transport", commonChannel.GetTransport())); commonChannel.UpdateSndBufSize(q.second.get("sndBufSize", commonChannel.GetSndBufSize())); commonChannel.UpdateRcvBufSize(q.second.get("rcvBufSize", commonChannel.GetRcvBufSize())); commonChannel.UpdateSndKernelSize(q.second.get("sndKernelSize", commonChannel.GetSndKernelSize())); @@ -246,6 +247,7 @@ void ChannelParser(const boost::property_tree::ptree& tree, FairMQMap& channelMa LOG(DEBUG) << "\ttype = " << commonChannel.GetType(); LOG(DEBUG) << "\tmethod = " << commonChannel.GetMethod(); LOG(DEBUG) << "\taddress = " << commonChannel.GetAddress(); + LOG(DEBUG) << "\ttransport = " << commonChannel.GetTransport(); LOG(DEBUG) << "\tsndBufSize = " << commonChannel.GetSndBufSize(); LOG(DEBUG) << "\trcvBufSize = " << commonChannel.GetRcvBufSize(); LOG(DEBUG) << "\tsndKernelSize = " << commonChannel.GetSndKernelSize(); @@ -289,6 +291,7 @@ void ChannelParser(const boost::property_tree::ptree& tree, FairMQMap& channelMa commonChannel.UpdateType(p.second.get("type", commonChannel.GetType())); commonChannel.UpdateMethod(p.second.get("method", commonChannel.GetMethod())); commonChannel.UpdateAddress(p.second.get("address", commonChannel.GetAddress())); + commonChannel.UpdateTransport(p.second.get("transport", commonChannel.GetTransport())); commonChannel.UpdateSndBufSize(p.second.get("sndBufSize", commonChannel.GetSndBufSize())); commonChannel.UpdateRcvBufSize(p.second.get("rcvBufSize", commonChannel.GetRcvBufSize())); commonChannel.UpdateSndKernelSize(p.second.get("sndKernelSize", commonChannel.GetSndKernelSize())); @@ -308,6 +311,7 @@ void ChannelParser(const boost::property_tree::ptree& tree, FairMQMap& channelMa LOG(DEBUG) << "\ttype = " << commonChannel.GetType(); LOG(DEBUG) << "\tmethod = " << commonChannel.GetMethod(); LOG(DEBUG) << "\taddress = " << commonChannel.GetAddress(); + LOG(DEBUG) << "\ttransport = " << commonChannel.GetTransport(); LOG(DEBUG) << "\tsndBufSize = " << commonChannel.GetSndBufSize(); LOG(DEBUG) << "\trcvBufSize = " << commonChannel.GetRcvBufSize(); LOG(DEBUG) << "\tsndKernelSize = " << commonChannel.GetSndKernelSize(); @@ -348,6 +352,7 @@ void SocketParser(const boost::property_tree::ptree& tree, vector channel.UpdateType(q.second.get("type", channel.GetType())); channel.UpdateMethod(q.second.get("method", channel.GetMethod())); channel.UpdateAddress(q.second.get("address", channel.GetAddress())); + channel.UpdateTransport(q.second.get("transport", channel.GetTransport())); channel.UpdateSndBufSize(q.second.get("sndBufSize", channel.GetSndBufSize())); channel.UpdateRcvBufSize(q.second.get("rcvBufSize", channel.GetRcvBufSize())); channel.UpdateSndKernelSize(q.second.get("sndKernelSize", channel.GetSndKernelSize())); @@ -358,6 +363,7 @@ void SocketParser(const boost::property_tree::ptree& tree, vector LOG(DEBUG) << "\ttype = " << channel.GetType(); LOG(DEBUG) << "\tmethod = " << channel.GetMethod(); LOG(DEBUG) << "\taddress = " << channel.GetAddress(); + LOG(DEBUG) << "\ttransport = " << channel.GetTransport(); LOG(DEBUG) << "\tsndBufSize = " << channel.GetSndBufSize(); LOG(DEBUG) << "\trcvBufSize = " << channel.GetRcvBufSize(); LOG(DEBUG) << "\tsndKernelSize = " << channel.GetSndKernelSize(); @@ -378,6 +384,7 @@ void SocketParser(const boost::property_tree::ptree& tree, vector channel.UpdateType(p.second.get("type", channel.GetType())); channel.UpdateMethod(p.second.get("method", channel.GetMethod())); channel.UpdateAddress(p.second.get("address", channel.GetAddress())); + channel.UpdateTransport(p.second.get("transport", channel.GetTransport())); channel.UpdateSndBufSize(p.second.get("sndBufSize", channel.GetSndBufSize())); channel.UpdateRcvBufSize(p.second.get("rcvBufSize", channel.GetRcvBufSize())); channel.UpdateSndKernelSize(p.second.get("sndKernelSize", channel.GetSndKernelSize())); @@ -388,6 +395,7 @@ void SocketParser(const boost::property_tree::ptree& tree, vector LOG(DEBUG) << "\ttype = " << channel.GetType(); LOG(DEBUG) << "\tmethod = " << channel.GetMethod(); LOG(DEBUG) << "\taddress = " << channel.GetAddress(); + LOG(DEBUG) << "\ttransport = " << channel.GetTransport(); LOG(DEBUG) << "\tsndBufSize = " << channel.GetSndBufSize(); LOG(DEBUG) << "\trcvBufSize = " << channel.GetRcvBufSize(); LOG(DEBUG) << "\tsndKernelSize = " << channel.GetSndKernelSize(); @@ -414,6 +422,7 @@ void SocketParser(const boost::property_tree::ptree& tree, vector LOG(DEBUG) << "\ttype = " << channel.GetType(); LOG(DEBUG) << "\tmethod = " << channel.GetMethod(); LOG(DEBUG) << "\taddress = " << channel.GetAddress(); + LOG(DEBUG) << "\ttransport = " << channel.GetTransport(); LOG(DEBUG) << "\tsndBufSize = " << channel.GetSndBufSize(); LOG(DEBUG) << "\trcvBufSize = " << channel.GetRcvBufSize(); LOG(DEBUG) << "\tsndKernelSize = " << channel.GetSndKernelSize(); diff --git a/fairmq/options/FairMQProgOptions.cxx b/fairmq/options/FairMQProgOptions.cxx index b228a076..91303348 100644 --- a/fairmq/options/FairMQProgOptions.cxx +++ b/fairmq/options/FairMQProgOptions.cxx @@ -233,6 +233,7 @@ void FairMQProgOptions::UpdateMQValues() string typeKey = p.first + "." + to_string(index) + ".type"; string methodKey = p.first + "." + to_string(index) + ".method"; string addressKey = p.first + "." + to_string(index) + ".address"; + string transportKey = p.first + "." + to_string(index) + ".transport"; string sndBufSizeKey = p.first + "." + to_string(index) + ".sndBufSize"; string rcvBufSizeKey = p.first + "." + to_string(index) + ".rcvBufSize"; string sndKernelSizeKey = p.first + "." + to_string(index) + ".sndKernelSize"; @@ -242,6 +243,7 @@ void FairMQProgOptions::UpdateMQValues() fMQKeyMap[typeKey] = make_tuple(p.first, index, "type"); fMQKeyMap[methodKey] = make_tuple(p.first, index, "method"); fMQKeyMap[addressKey] = make_tuple(p.first, index, "address"); + fMQKeyMap[transportKey] = make_tuple(p.first, index, "transport"); fMQKeyMap[sndBufSizeKey] = make_tuple(p.first, index, "sndBufSize"); fMQKeyMap[rcvBufSizeKey] = make_tuple(p.first, index, "rcvBufSize"); fMQKeyMap[sndKernelSizeKey] = make_tuple(p.first, index, "sndKernelSize"); @@ -251,6 +253,7 @@ void FairMQProgOptions::UpdateMQValues() UpdateVarMap(typeKey, channel.GetType()); UpdateVarMap(methodKey, channel.GetMethod()); UpdateVarMap(addressKey, channel.GetAddress()); + UpdateVarMap(transportKey, channel.GetTransport()); //UpdateVarMap(sndBufSizeKey, to_string(channel.GetSndBufSize()));// string API @@ -384,6 +387,12 @@ int FairMQProgOptions::UpdateChannelMap(const string& channelName, int index, co fFairMQMap.at(channelName).at(index).UpdateAddress(val); return 0; } + + if (member == "transport") + { + fFairMQMap.at(channelName).at(index).UpdateTransport(val); + return 0; + } else { //if we get there it means something is wrong diff --git a/fairmq/plugins/control/FairMQDDSControlPlugin.cxx b/fairmq/plugins/control/FairMQDDSControlPlugin.cxx index 9179eb50..ffbc7019 100644 --- a/fairmq/plugins/control/FairMQDDSControlPlugin.cxx +++ b/fairmq/plugins/control/FairMQDDSControlPlugin.cxx @@ -35,8 +35,8 @@ class FairMQControlPluginDDS return fInstance; } - static void ResetInstance() - { + static void ResetInstance() + { try { delete fInstance; @@ -47,7 +47,7 @@ class FairMQControlPluginDDS LOG(ERROR) << "Error: " << e.what() << endl; return; } - } + } void Init(FairMQDevice& device) { diff --git a/fairmq/run/startMQBenchmark.sh.in b/fairmq/run/startMQBenchmark.sh.in index af1a6729..f765007e 100755 --- a/fairmq/run/startMQBenchmark.sh.in +++ b/fairmq/run/startMQBenchmark.sh.in @@ -4,6 +4,10 @@ numMsgs="0" msgSize="1000000" transport="zeromq" sameMsg="true" +affinity="false" +affinitySamp="" +affinitySink="" + if [[ $1 =~ ^[0-9]+$ ]]; then msgSize=$1 @@ -21,17 +25,40 @@ if [[ $4 =~ ^[a-z]+$ ]]; then sameMsg=$4 fi -echo "Starting benchmark with message size of $msgSize bytes ($numMsgs messages) and $transport transport." -echo "Using $transport transport." +if [[ $5 =~ ^[a-z]+$ ]]; then + affinity=$5 +fi + + +echo "Starting benchmark with following settings:" + +echo "" +echo "message size: $msgSize bytes" if [ $numMsgs = 0 ]; then - echo "Unlimited number of messages." + echo "number of messages: unlimited" else - echo "Number of messages: $numMsgs." + echo "number of messages: $numMsgs" +fi + +echo "transport: $transport" + +if [ $sameMsg = "true" ]; then + echo "resend same message: yes, using Copy() method to resend the same message" +else + echo "resend same message: no, allocating each message separately" +fi + +if [ $affinity = "true" ]; then + affinitySamp="taskset -c 0" + affinitySink="taskset -c 1" + echo "affinity: assigning sampler to core 0, sink to core 1" +else + echo "" fi echo "" -echo "Usage: startBenchmark [message size=1000000] [number of messages=0] [transport=zeromq/nanomsg/shmem] [resend same message=true]" +echo "Usage: startBenchmark [message size=1000000] [number of messages=0] [transport=zeromq/nanomsg/shmem] [resend same message=true] [affinity=false]" SAMPLER="bsampler" SAMPLER+=" --id bsampler1" @@ -43,7 +70,7 @@ SAMPLER+=" --same-msg $sameMsg" # SAMPLER+=" --msg-rate 1000" SAMPLER+=" --num-msgs $numMsgs" SAMPLER+=" --mq-config @CMAKE_BINARY_DIR@/bin/config/benchmark.json" -xterm -geometry 80x23+0+0 -hold -e @CMAKE_BINARY_DIR@/bin/$SAMPLER & +xterm -geometry 90x23+0+0 -hold -e $affinitySamp @CMAKE_BINARY_DIR@/bin/$SAMPLER & SINK="sink" SINK+=" --id sink1" @@ -52,4 +79,4 @@ SINK+=" --id sink1" SINK+=" --transport $transport" SINK+=" --num-msgs $numMsgs" SINK+=" --mq-config @CMAKE_BINARY_DIR@/bin/config/benchmark.json" -xterm -geometry 80x23+500+0 -hold -e @CMAKE_BINARY_DIR@/bin/$SINK & +xterm -geometry 90x23+550+0 -hold -e $affinitySink @CMAKE_BINARY_DIR@/bin/$SINK & diff --git a/fairmq/shmem/FairMQContextSHM.cxx b/fairmq/shmem/FairMQContextSHM.cxx index 07e3f2b1..28e21dcb 100644 --- a/fairmq/shmem/FairMQContextSHM.cxx +++ b/fairmq/shmem/FairMQContextSHM.cxx @@ -7,6 +7,8 @@ ********************************************************************************/ #include +#include + #include #include "FairMQLogger.h" @@ -46,11 +48,11 @@ FairMQContextSHM::~FairMQContextSHM() if (boost::interprocess::shared_memory_object::remove("FairMQSharedMemory")) { - LOG(INFO) << "Successfully removed shared memory after the device has stopped."; + printf("Successfully removed shared memory after the device has stopped.\n"); } else { - LOG(INFO) << "Did not remove shared memory after the device stopped. Still in use?"; + printf("Did not remove shared memory after the device stopped. Already removed?\n"); } } diff --git a/fairmq/shmem/FairMQMessageSHM.cxx b/fairmq/shmem/FairMQMessageSHM.cxx index 33cb5b54..776a08d9 100644 --- a/fairmq/shmem/FairMQMessageSHM.cxx +++ b/fairmq/shmem/FairMQMessageSHM.cxx @@ -14,45 +14,60 @@ using namespace std; using namespace FairMQ::shmem; -uint64_t FairMQMessageSHM::fMessageID = 0; -string FairMQMessageSHM::fDeviceID = string(); +static FairMQ::Transport gTransportType = FairMQ::Transport::SHM; + +// uint64_t FairMQMessageSHM::fMessageID = 0; +// string FairMQMessageSHM::fDeviceID = string(); atomic FairMQMessageSHM::fInterrupted(false); FairMQMessageSHM::FairMQMessageSHM() : fMessage() - , fOwner(nullptr) - , fReceiving(false) + // , fOwner(nullptr) + // , fReceiving(false) , fQueued(false) + , fMetaCreated(false) + , fHandle() + , fChunkSize(0) + , fLocalPtr(nullptr) { if (zmq_msg_init(&fMessage) != 0) { LOG(ERROR) << "failed initializing message, reason: " << zmq_strerror(errno); } + fMetaCreated = true; } -void FairMQMessageSHM::StringDeleter(void* /*data*/, void* str) -{ - delete static_cast(str); -} +// void FairMQMessageSHM::StringDeleter(void* /*data*/, void* str) +// { +// delete static_cast(str); +// } FairMQMessageSHM::FairMQMessageSHM(const size_t size) : fMessage() - , fOwner(nullptr) - , fReceiving(false) + // , fOwner(nullptr) + // , fReceiving(false) , fQueued(false) + , fMetaCreated(false) + , fHandle() + , fChunkSize(0) + , fLocalPtr(nullptr) { InitializeChunk(size); } FairMQMessageSHM::FairMQMessageSHM(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) : fMessage() - , fOwner(nullptr) - , fReceiving(false) + // , fOwner(nullptr) + // , fReceiving(false) , fQueued(false) + , fMetaCreated(false) + , fHandle() + , fChunkSize(0) + , fLocalPtr(nullptr) { if (InitializeChunk(size)) { - memcpy(fOwner->fPtr->GetData(), data, size); + memcpy(fLocalPtr, data, size); if (ffn) { ffn(data, hint); @@ -66,66 +81,76 @@ FairMQMessageSHM::FairMQMessageSHM(void* data, const size_t size, fairmq_free_fn bool FairMQMessageSHM::InitializeChunk(const size_t size) { - string chunkID = fDeviceID + "c" + to_string(fMessageID); - string* ownerID = new string(fDeviceID + "o" + to_string(fMessageID)); + // string chunkID = fDeviceID + "c" + to_string(fMessageID); + // string* ownerID = new string(fDeviceID + "o" + to_string(fMessageID)); - bool success = false; - - while (!success) + while (!fHandle) { try { - fOwner = Manager::Instance().Segment()->construct(ownerID->c_str())( - make_managed_shared_ptr(Manager::Instance().Segment()->construct(chunkID.c_str())(size), - *(Manager::Instance().Segment()))); - success = true; + fLocalPtr = Manager::Instance().Segment()->allocate(size); + + // fOwner = Manager::Instance().Segment()->construct(ownerID->c_str())( + // make_managed_shared_ptr(Manager::Instance().Segment()->construct(chunkID.c_str())(size), + // *(Manager::Instance().Segment()))); } catch (bipc::bad_alloc& ba) { - LOG(WARN) << "Shared memory full..."; + // LOG(WARN) << "Shared memory full..."; this_thread::sleep_for(chrono::milliseconds(50)); if (fInterrupted) { - break; + return false; } else { continue; } } + fHandle = Manager::Instance().Segment()->get_handle_from_address(fLocalPtr); } - if (success) + fChunkSize = size; + + if (zmq_msg_init_size(&fMessage, sizeof(MetaHeader)) != 0) { - if (zmq_msg_init_data(&fMessage, const_cast(ownerID->c_str()), ownerID->length(), StringDeleter, ownerID) != 0) - { - LOG(ERROR) << "failed initializing meta message, reason: " << zmq_strerror(errno); - } - - ++fMessageID; + LOG(ERROR) << "failed initializing meta message, reason: " << zmq_strerror(errno); + return false; } + MetaHeader* metaPtr = new(zmq_msg_data(&fMessage)) MetaHeader(); + metaPtr->fSize = size; + metaPtr->fHandle = fHandle; - return success; + // if (zmq_msg_init_data(&fMessage, const_cast(ownerID->c_str()), ownerID->length(), StringDeleter, ownerID) != 0) + // { + // LOG(ERROR) << "failed initializing meta message, reason: " << zmq_strerror(errno); + // } + fMetaCreated = true; + + // ++fMessageID; + + return true; } void FairMQMessageSHM::Rebuild() { CloseMessage(); - fReceiving = false; + // fReceiving = false; fQueued = false; if (zmq_msg_init(&fMessage) != 0) { LOG(ERROR) << "failed initializing message, reason: " << zmq_strerror(errno); } + fMetaCreated = true; } void FairMQMessageSHM::Rebuild(const size_t size) { CloseMessage(); - fReceiving = false; + // fReceiving = false; fQueued = false; InitializeChunk(size); @@ -135,12 +160,12 @@ void FairMQMessageSHM::Rebuild(void* data, const size_t size, fairmq_free_fn* ff { CloseMessage(); - fReceiving = false; + // fReceiving = false; fQueued = false; if (InitializeChunk(size)) { - memcpy(fOwner->fPtr->GetData(), data, size); + memcpy(fLocalPtr, data, size); if (ffn) { ffn(data, hint); @@ -159,27 +184,42 @@ void* FairMQMessageSHM::GetMessage() void* FairMQMessageSHM::GetData() { - if (fOwner) + if (fLocalPtr) { - return fOwner->fPtr->GetData(); + return fLocalPtr; + } + else if (fHandle) + { + return Manager::Instance().Segment()->get_address_from_handle(fHandle); } else { - LOG(ERROR) << "Trying to get data of an empty shared memory message"; - exit(EXIT_FAILURE); + // LOG(ERROR) << "Trying to get data of an empty shared memory message"; + return nullptr; } + + // if (fOwner) + // { + // return fOwner->fPtr->GetData(); + // } + // else + // { + // LOG(ERROR) << "Trying to get data of an empty shared memory message"; + // exit(EXIT_FAILURE); + // } } size_t FairMQMessageSHM::GetSize() { - if (fOwner) - { - return fOwner->fPtr->GetSize(); - } - else - { - return 0; - } + return fChunkSize; + // if (fOwner) + // { + // return fOwner->fPtr->GetSize(); + // } + // else + // { + // return 0; + // } } void FairMQMessageSHM::SetMessage(void*, const size_t) @@ -187,21 +227,26 @@ void FairMQMessageSHM::SetMessage(void*, const size_t) // dummy method to comply with the interface. functionality not allowed in zeromq. } -void FairMQMessageSHM::SetDeviceId(const string& deviceId) +void FairMQMessageSHM::SetDeviceId(const string& /*deviceId*/) { - fDeviceID = deviceId; + // fDeviceID = deviceId; +} + +FairMQ::Transport FairMQMessageSHM::GetType() const +{ + return gTransportType; } void FairMQMessageSHM::Copy(const unique_ptr& msg) { - if (!fOwner) + if (!fHandle) { - FairMQ::shmem::ShPtrOwner* otherOwner = static_cast(msg.get())->fOwner; - if (otherOwner) + bipc::managed_shared_memory::handle_t otherHandle = static_cast(msg.get())->fHandle; + if (otherHandle) { - if (InitializeChunk(otherOwner->fPtr->GetSize())) + if (InitializeChunk(msg->GetSize())) { - memcpy(fOwner->fPtr->GetData(), otherOwner->fPtr->GetData(), otherOwner->fPtr->GetSize()); + memcpy(GetData(), msg->GetData(), msg->GetSize()); } } else @@ -266,31 +311,35 @@ void FairMQMessageSHM::Copy(const unique_ptr& msg) void FairMQMessageSHM::CloseMessage() { - if (fReceiving) - { - if (fOwner) + // if (fReceiving) + // { + // if (fOwner) + // { + // Manager::Instance().Segment()->destroy_ptr(fOwner); + // fOwner = nullptr; + // } + // else + // { + // LOG(ERROR) << "No shared pointer owner when closing a received message"; + // } + // } + // else + // { + if (fHandle && !fQueued) { - Manager::Instance().Segment()->destroy_ptr(fOwner); - fOwner = nullptr; + // LOG(WARN) << "Destroying unsent message"; + // Manager::Instance().Segment()->destroy_ptr(fHandle); + Manager::Instance().Segment()->deallocate(Manager::Instance().Segment()->get_address_from_handle(fHandle)); + fHandle = 0; } - else - { - LOG(ERROR) << "No shared pointer owner when closing a received message"; - } - } - else - { - if (fOwner && !fQueued) - { - LOG(WARN) << "Destroying unsent message"; - Manager::Instance().Segment()->destroy_ptr(fOwner); - fOwner = nullptr; - } - } + // } - if (zmq_msg_close(&fMessage) != 0) + if (fMetaCreated) { - LOG(ERROR) << "failed closing message, reason: " << zmq_strerror(errno); + if (zmq_msg_close(&fMessage) != 0) + { + LOG(ERROR) << "failed closing message, reason: " << zmq_strerror(errno); + } } } diff --git a/fairmq/shmem/FairMQMessageSHM.h b/fairmq/shmem/FairMQMessageSHM.h index fb6a0ffa..dfc55e14 100644 --- a/fairmq/shmem/FairMQMessageSHM.h +++ b/fairmq/shmem/FairMQMessageSHM.h @@ -42,22 +42,28 @@ class FairMQMessageSHM : public FairMQMessage virtual void SetDeviceId(const std::string& deviceId); + virtual FairMQ::Transport GetType() const; + virtual void Copy(const std::unique_ptr& msg); void CloseMessage(); virtual ~FairMQMessageSHM(); - static void StringDeleter(void* data, void* str); + // static void StringDeleter(void* data, void* str); private: zmq_msg_t fMessage; - FairMQ::shmem::ShPtrOwner* fOwner; - static uint64_t fMessageID; - static std::string fDeviceID; - bool fReceiving; + // FairMQ::shmem::ShPtrOwner* fOwner; + // static uint64_t fMessageID; + // static std::string fDeviceID; + // bool fReceiving; bool fQueued; + bool fMetaCreated; static std::atomic fInterrupted; + bipc::managed_shared_memory::handle_t fHandle; + size_t fChunkSize; + void* fLocalPtr; }; #endif /* FAIRMQMESSAGESHM_H_ */ diff --git a/fairmq/shmem/FairMQPollerSHM.cxx b/fairmq/shmem/FairMQPollerSHM.cxx index 926d9881..6d4eedf3 100644 --- a/fairmq/shmem/FairMQPollerSHM.cxx +++ b/fairmq/shmem/FairMQPollerSHM.cxx @@ -51,7 +51,7 @@ FairMQPollerSHM::FairMQPollerSHM(const vector& channels) } else { - LOG(ERROR) << "invalid poller configuration, exiting."; + LOG(ERROR) << "shmem: invalid poller configuration, exiting."; exit(EXIT_FAILURE); } } @@ -105,7 +105,7 @@ FairMQPollerSHM::FairMQPollerSHM(const unordered_mapallocate(size); - fHandle = Manager::Instance().Segment()->get_handle_from_address(ptr); - } - - ~Chunk() - { - Manager::Instance().Segment()->deallocate(Manager::Instance().Segment()->get_address_from_handle(fHandle)); - } - - // bipc::managed_shared_memory::handle_t GetHandle() const - // { - // return fHandle; - // } - - void* GetData() const - { - return Manager::Instance().Segment()->get_address_from_handle(fHandle); - } - - size_t GetSize() const - { - return fSize; - } - - private: + uint64_t fSize; bipc::managed_shared_memory::handle_t fHandle; - size_t fSize; }; -typedef bipc::managed_shared_ptr::type ShPtrType; +// class Chunk +// { +// public: +// Chunk() +// : fHandle() +// , fSize(0) +// { +// } -struct ShPtrOwner -{ - ShPtrOwner(const ShPtrType& other) - : fPtr(other) - {} +// Chunk(const size_t size) +// : fHandle() +// , fSize(size) +// { +// void* ptr = Manager::Instance().Segment()->allocate(size); +// fHandle = Manager::Instance().Segment()->get_handle_from_address(ptr); +// } - ShPtrOwner(const ShPtrOwner& other) - : fPtr(other.fPtr) - {} +// ~Chunk() +// { +// Manager::Instance().Segment()->deallocate(Manager::Instance().Segment()->get_address_from_handle(fHandle)); +// } - ShPtrType fPtr; -}; +// bipc::managed_shared_memory::handle_t GetHandle() const +// { +// return fHandle; +// } + +// void* GetData() const +// { +// return Manager::Instance().Segment()->get_address_from_handle(fHandle); +// } + +// size_t GetSize() const +// { +// return fSize; +// } + +// private: +// bipc::managed_shared_memory::handle_t fHandle; +// size_t fSize; +// }; + +// typedef bipc::managed_shared_ptr::type ShPtrType; + +// struct ShPtrOwner +// { +// ShPtrOwner(const ShPtrType& other) +// : fPtr(other) +// {} + +// ShPtrOwner(const ShPtrOwner& other) +// : fPtr(other.fPtr) +// {} + +// ShPtrType fPtr; +// }; } // namespace shmem diff --git a/fairmq/shmem/FairMQSocketSHM.cxx b/fairmq/shmem/FairMQSocketSHM.cxx index 639c62c9..ad07f71d 100644 --- a/fairmq/shmem/FairMQSocketSHM.cxx +++ b/fairmq/shmem/FairMQSocketSHM.cxx @@ -19,6 +19,7 @@ using namespace FairMQ::shmem; // Context to hold the ZeroMQ sockets unique_ptr FairMQSocketSHM::fContext; // = unique_ptr(new FairMQContextSHM(1)); bool FairMQSocketSHM::fContextInitialized = false; +atomic FairMQSocketSHM::fInterrupted(false); FairMQSocketSHM::FairMQSocketSHM(const string& type, const string& name, const int numIoThreads, const string& id /*= ""*/) : FairMQSocket(ZMQ_SNDMORE, ZMQ_RCVMORE, ZMQ_DONTWAIT) @@ -57,22 +58,22 @@ FairMQSocketSHM::FairMQSocketSHM(const string& type, const string& name, const i // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. // Default value for ZeroMQ is -1, which is to wait forever. - int linger = 500; + int linger = 1000; if (zmq_setsockopt(fSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) { LOG(ERROR) << "Failed setting ZMQ_LINGER socket option, reason: " << zmq_strerror(errno); } - int kernelSndSize = 10000; - if (zmq_setsockopt(fSocket, ZMQ_SNDBUF, &kernelSndSize, sizeof(kernelSndSize)) != 0) + int sndTimeout = 700; + if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &sndTimeout, sizeof(sndTimeout)) != 0) { - LOG(ERROR) << "Failed setting ZMQ_SNDBUF socket option, reason: " << zmq_strerror(errno); + LOG(ERROR) << "Failed setting ZMQ_SNDTIMEO socket option, reason: " << zmq_strerror(errno); } - int kernelRcvSize = 10000; - if (zmq_setsockopt(fSocket, ZMQ_RCVBUF, &kernelRcvSize, sizeof(kernelRcvSize)) != 0) + int rcvTimeout = 700; + if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &rcvTimeout, sizeof(rcvTimeout)) != 0) { - LOG(ERROR) << "Failed setting ZMQ_RCVBUF socket option, reason: " << zmq_strerror(errno); + LOG(ERROR) << "Failed setting ZMQ_RCVTIMEO socket option, reason: " << zmq_strerror(errno); } if (type == "sub") @@ -119,105 +120,183 @@ void FairMQSocketSHM::Connect(const string& address) } } -int FairMQSocketSHM::Send(FairMQMessage* msg, const string& flag) +int FairMQSocketSHM::Send(FairMQMessagePtr& msg, const int flags) { - return Send(msg, GetConstant(flag)); -} - -int FairMQSocketSHM::Send(FairMQMessage* msg, const int flags) -{ - int nbytes = zmq_msg_send(static_cast(msg->GetMessage()), fSocket, flags); - if (nbytes >= 0) + int nbytes = -1; + while (true && !fInterrupted) { - static_cast(msg)->fReceiving = false; - static_cast(msg)->fQueued = true; - size_t size = msg->GetSize(); - - fBytesTx += size; - ++fMessagesTx; - - return size; - } - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; -} - -int64_t FairMQSocketSHM::Send(const vector& msgVec, const int flags) -{ - // Sending vector typicaly handles more then one part - if (msgVec.size() > 1) - { - int64_t totalSize = 0; - - for (unsigned int i = 0; i < msgVec.size() - 1; ++i) + nbytes = zmq_msg_send(static_cast(msg->GetMessage()), fSocket, flags); + if (nbytes == 0) { - int nbytes = zmq_msg_send(static_cast(msgVec[i]->GetMessage()), fSocket, ZMQ_SNDMORE|flags); - if (nbytes >= 0) - { - static_cast(msgVec[i].get())->fReceiving = false; - static_cast(msgVec[i].get())->fQueued = true; - size_t size = msgVec[i]->GetSize(); + return nbytes; + } + else if (nbytes > 0) + { + // static_cast(msg.get())->fReceiving = false; + static_cast(msg.get())->fQueued = true; - totalSize += size; - fBytesTx += size; + size_t size = msg->GetSize(); + fBytesTx += size; + ++fMessagesTx; + + return size; + } + else if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) + { + continue; } else { - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; + return -2; } } - - int nbytes = zmq_msg_send(static_cast(msgVec.back()->GetMessage()), fSocket, flags); - if (nbytes >= 0) + else if (zmq_errno() == ETERM) { - static_cast(msgVec.back().get())->fReceiving = false; - static_cast(msgVec.back().get())->fQueued = true; - size_t size = msgVec.back()->GetSize(); - - totalSize += size; - fBytesTx += size; + LOG(INFO) << "terminating socket " << fId; + return -1; } else { - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); return nbytes; } + } - // store statistics on how many messages have been sent (handle all parts as a single message) - ++fMessagesTx; - return totalSize; - } // If there's only one part, send it as a regular message - else if (msgVec.size() == 1) + return -1; +} + +int FairMQSocketSHM::Receive(FairMQMessagePtr& msg, const int flags) +{ + int nbytes = -1; + zmq_msg_t* msgPtr = static_cast(msg->GetMessage()); + while (true) { - return Send(msgVec.back().get(), flags); + nbytes = zmq_msg_recv(msgPtr, fSocket, flags); + if (nbytes == 0) + { + ++fMessagesRx; + + return nbytes; + } + else if (nbytes > 0) + { + // string ownerID(static_cast(zmq_msg_data(msgPtr)), zmq_msg_size(msgPtr)); + // ShPtrOwner* owner = Manager::Instance().Segment()->find(ownerID.c_str()).first; + MetaHeader* hdr = static_cast(zmq_msg_data(msgPtr)); + size_t size = 0; + if (hdr->fHandle) + { + static_cast(msg.get())->fHandle = hdr->fHandle; + static_cast(msg.get())->fChunkSize = hdr->fSize; + // static_cast(msg.get())->fOwner = owner; + // static_cast(msg.get())->fReceiving = true; + size = msg->GetSize(); + + fBytesRx += size; + ++fMessagesRx; + + return size; + } + else + { + LOG(ERROR) << "Received meta data, but could not find corresponding chunk"; + return -1; + } + } + else if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) + { + continue; + } + else + { + return -2; + } + } + else if (zmq_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + else + { + LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << zmq_strerror(errno); + return nbytes; + } + } +} + +int64_t FairMQSocketSHM::Send(vector& msgVec, const int flags) +{ + const unsigned int vecSize = msgVec.size(); + + // Sending vector typicaly handles more then one part + if (vecSize > 1) + { + int64_t totalSize = 0; + int nbytes = -1; + bool repeat = false; + + while (true && !fInterrupted) + { + for (unsigned int i = 0; i < vecSize; ++i) + { + nbytes = zmq_msg_send(static_cast(msgVec[i]->GetMessage()), + fSocket, + (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); + if (nbytes >= 0) + { + static_cast(msgVec[i].get())->fQueued = true; + // static_cast(msgVec[i].get())->fReceiving = false; + // static_cast(msgVec[i].get())->fQueued = true; + size_t size = msgVec[i]->GetSize(); + + totalSize += size; + } + else + { + // according to ZMQ docs, this can only occur for the first part + if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) + { + repeat = true; + break; + } + else + { + return -2; + } + } + if (zmq_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); + return nbytes; + } + } + + if (repeat) + { + continue; + } + + // store statistics on how many messages have been sent (handle all parts as a single message) + ++fMessagesTx; + fBytesTx += totalSize; + return totalSize; + } + + return -1; + } // If there's only one part, send it as a regular message + else if (vecSize == 1) + { + return Send(msgVec.back(), flags); } else // if the vector is empty, something might be wrong { @@ -226,112 +305,91 @@ int64_t FairMQSocketSHM::Send(const vector& msgVec, const int } } -int FairMQSocketSHM::Receive(FairMQMessage* msg, const string& flag) -{ - return Receive(msg, GetConstant(flag)); -} - -int FairMQSocketSHM::Receive(FairMQMessage* msg, const int flags) -{ - zmq_msg_t* msgPtr = static_cast(msg->GetMessage()); - int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); - if (nbytes == 0) - { - ++fMessagesRx; - return nbytes; - } - else if (nbytes > 0) - { - string ownerID(static_cast(zmq_msg_data(msgPtr)), zmq_msg_size(msgPtr)); - ShPtrOwner* owner = Manager::Instance().Segment()->find(ownerID.c_str()).first; - size_t size = 0; - if (owner) - { - static_cast(msg)->fOwner = owner; - static_cast(msg)->fReceiving = true; - size = msg->GetSize(); - - fBytesRx += size; - ++fMessagesRx; - - return size; - } - else - { - LOG(ERROR) << "Received meta data, but could not find corresponding chunk"; - return -1; - } - } - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; -} - int64_t FairMQSocketSHM::Receive(vector& msgVec, const int flags) { - // Warn if the vector is filled before Receive() and empty it. - if (msgVec.size() > 0) - { - LOG(WARN) << "Message vector contains elements before Receive(), they will be deleted!"; - msgVec.clear(); - } - int64_t totalSize = 0; int64_t more = 0; + bool repeat = false; - do + while (true) { - FairMQMessagePtr part(new FairMQMessageSHM()); - zmq_msg_t* msgPtr = static_cast(part->GetMessage()); - - int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); - if (nbytes == 0) + // Warn if the vector is filled before Receive() and empty it. + if (msgVec.size() > 0) { - msgVec.push_back(move(part)); + LOG(WARN) << "Message vector contains elements before Receive(), they will be deleted!"; + msgVec.clear(); } - else if (nbytes > 0) + + totalSize = 0; + more = 0; + repeat = false; + + do { - string ownerID(static_cast(zmq_msg_data(msgPtr)), zmq_msg_size(msgPtr)); - ShPtrOwner* owner = Manager::Instance().Segment()->find(ownerID.c_str()).first; - size_t size = 0; - if (owner) + FairMQMessagePtr part(new FairMQMessageSHM()); + zmq_msg_t* msgPtr = static_cast(part->GetMessage()); + + int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); + if (nbytes == 0) { - static_cast(part.get())->fOwner = owner; - static_cast(part.get())->fReceiving = true; - size = part->GetSize(); - msgVec.push_back(move(part)); + } + else if (nbytes > 0) + { + // string ownerID(static_cast(zmq_msg_data(msgPtr)), zmq_msg_size(msgPtr)); + // ShPtrOwner* owner = Manager::Instance().Segment()->find(ownerID.c_str()).first; + MetaHeader* hdr = static_cast(zmq_msg_data(msgPtr)); + size_t size = 0; + if (hdr->fHandle) + { + static_cast(part.get())->fHandle = hdr->fHandle; + static_cast(part.get())->fChunkSize = hdr->fSize; + // static_cast(msg.get())->fOwner = owner; + // static_cast(msg.get())->fReceiving = true; + size = part->GetSize(); - fBytesRx += size; - totalSize += size; + msgVec.push_back(move(part)); + + totalSize += size; + } + else + { + LOG(ERROR) << "Received meta data, but could not find corresponding chunk"; + return -1; + } + } + else if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) + { + repeat = true; + break; + } + else + { + return -2; + } } else { - LOG(ERROR) << "Received meta data, but could not find corresponding chunk"; - return -1; + return nbytes; } + + size_t more_size = sizeof(more); + zmq_getsockopt(fSocket, ZMQ_RCVMORE, &more, &more_size); } - else + while (more); + + if (repeat) { - return nbytes; + continue; } - size_t more_size = sizeof(more); - zmq_getsockopt(fSocket, ZMQ_RCVMORE, &more, &more_size); + // store statistics on how many messages have been received (handle all parts as a single message) + ++fMessagesRx; + fBytesRx += totalSize; + return totalSize; } - while (more); - - // store statistics on how many messages have been received (handle all parts as a single message) - ++fMessagesRx; - return totalSize; } void FairMQSocketSHM::Close() @@ -362,11 +420,13 @@ void FairMQSocketSHM::Terminate() void FairMQSocketSHM::Interrupt() { FairMQMessageSHM::fInterrupted = true; + fInterrupted = true; } void FairMQSocketSHM::Resume() { FairMQMessageSHM::fInterrupted = false; + fInterrupted = false; } void* FairMQSocketSHM::GetSocket() const diff --git a/fairmq/shmem/FairMQSocketSHM.h b/fairmq/shmem/FairMQSocketSHM.h index 4be7aeb8..971b5dd6 100644 --- a/fairmq/shmem/FairMQSocketSHM.h +++ b/fairmq/shmem/FairMQSocketSHM.h @@ -13,6 +13,7 @@ #include // unique_ptr #include "FairMQSocket.h" +#include "FairMQMessage.h" #include "FairMQContextSHM.h" #include "FairMQShmManager.h" @@ -28,12 +29,10 @@ class FairMQSocketSHM : public FairMQSocket virtual bool Bind(const std::string& address); virtual void Connect(const std::string& address); - virtual int Send(FairMQMessage* msg, const std::string& flag = ""); - virtual int Send(FairMQMessage* msg, const int flags = 0); - virtual int64_t Send(const std::vector>& msgVec, const int flags = 0); + virtual int Send(FairMQMessagePtr& msg, const int flags = 0); + virtual int Receive(FairMQMessagePtr& msg, const int flags = 0); - virtual int Receive(FairMQMessage* msg, const std::string& flag = ""); - virtual int Receive(FairMQMessage* msg, const int flags = 0); + virtual int64_t Send(std::vector>& msgVec, const int flags = 0); virtual int64_t Receive(std::vector>& msgVec, const int flags = 0); virtual void* GetSocket() const; @@ -71,6 +70,7 @@ class FairMQSocketSHM : public FairMQSocket static std::unique_ptr fContext; static bool fContextInitialized; + static std::atomic fInterrupted; }; #endif /* FAIRMQSOCKETSHM_H_ */ diff --git a/fairmq/shmem/FairMQTransportFactorySHM.cxx b/fairmq/shmem/FairMQTransportFactorySHM.cxx index 3e2a194f..8537cf76 100644 --- a/fairmq/shmem/FairMQTransportFactorySHM.cxx +++ b/fairmq/shmem/FairMQTransportFactorySHM.cxx @@ -12,11 +12,13 @@ using namespace std; +static FairMQ::Transport gTransportType = FairMQ::Transport::SHM; + FairMQTransportFactorySHM::FairMQTransportFactorySHM() { int major, minor, patch; zmq_version(&major, &minor, &patch); - LOG(DEBUG) << "Using ZeroMQ (" << major << "." << minor << "." << patch << ") & " + LOG(DEBUG) << "Transport: Using ZeroMQ (" << major << "." << minor << "." << patch << ") & " << "boost::interprocess (" << (BOOST_VERSION / 100000) << "." << (BOOST_VERSION / 100 % 1000) << "." << (BOOST_VERSION % 100) << ")"; } @@ -54,3 +56,9 @@ FairMQPollerPtr FairMQTransportFactorySHM::CreatePoller(const FairMQSocket& cmdS { return unique_ptr(new FairMQPollerSHM(cmdSocket, dataSocket)); } + +FairMQ::Transport FairMQTransportFactorySHM::GetType() const +{ + return gTransportType; +} + diff --git a/fairmq/shmem/FairMQTransportFactorySHM.h b/fairmq/shmem/FairMQTransportFactorySHM.h index e0759ce3..c4ed9dfe 100644 --- a/fairmq/shmem/FairMQTransportFactorySHM.h +++ b/fairmq/shmem/FairMQTransportFactorySHM.h @@ -9,6 +9,7 @@ #define FAIRMQTRANSPORTFACTORYSHM_H_ #include +#include #include "FairMQTransportFactory.h" #include "FairMQContextSHM.h" @@ -31,6 +32,8 @@ class FairMQTransportFactorySHM : public FairMQTransportFactory virtual FairMQPollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const; virtual FairMQPollerPtr CreatePoller(const FairMQSocket& cmdSocket, const FairMQSocket& dataSocket) const; + virtual FairMQ::Transport GetType() const; + virtual ~FairMQTransportFactorySHM() {}; }; diff --git a/fairmq/zeromq/FairMQMessageZMQ.cxx b/fairmq/zeromq/FairMQMessageZMQ.cxx index cebf713e..af160738 100644 --- a/fairmq/zeromq/FairMQMessageZMQ.cxx +++ b/fairmq/zeromq/FairMQMessageZMQ.cxx @@ -20,6 +20,8 @@ using namespace std; +static FairMQ::Transport gTransportType = FairMQ::Transport::ZMQ; + string FairMQMessageZMQ::fDeviceID = string(); FairMQMessageZMQ::FairMQMessageZMQ() @@ -101,6 +103,11 @@ void FairMQMessageZMQ::SetDeviceId(const string& deviceId) fDeviceID = deviceId; } +FairMQ::Transport FairMQMessageZMQ::GetType() const +{ + return gTransportType; +} + void FairMQMessageZMQ::Copy(const unique_ptr& msg) { // Shares the message buffer between msg and this fMessage. diff --git a/fairmq/zeromq/FairMQMessageZMQ.h b/fairmq/zeromq/FairMQMessageZMQ.h index 0674aaa8..75ecf3eb 100644 --- a/fairmq/zeromq/FairMQMessageZMQ.h +++ b/fairmq/zeromq/FairMQMessageZMQ.h @@ -41,6 +41,8 @@ class FairMQMessageZMQ : public FairMQMessage virtual void SetDeviceId(const std::string& deviceId); + virtual FairMQ::Transport GetType() const; + virtual void Copy(const std::unique_ptr& msg); void CloseMessage(); diff --git a/fairmq/zeromq/FairMQPollerZMQ.cxx b/fairmq/zeromq/FairMQPollerZMQ.cxx index de50dcbf..ce63723c 100644 --- a/fairmq/zeromq/FairMQPollerZMQ.cxx +++ b/fairmq/zeromq/FairMQPollerZMQ.cxx @@ -51,7 +51,7 @@ FairMQPollerZMQ::FairMQPollerZMQ(const vector& channels) } else { - LOG(ERROR) << "invalid poller configuration, exiting."; + LOG(ERROR) << "zeromq: invalid poller configuration, exiting."; exit(EXIT_FAILURE); } } @@ -105,7 +105,7 @@ FairMQPollerZMQ::FairMQPollerZMQ(const unordered_map FairMQSocketZMQ::fContext = std::unique_ptr(new FairMQContextZMQ(1)); +unique_ptr FairMQSocketZMQ::fContext = unique_ptr(new FairMQContextZMQ(1)); +atomic FairMQSocketZMQ::fInterrupted(false); FairMQSocketZMQ::FairMQSocketZMQ(const string& type, const string& name, const int numIoThreads, const string& id /*= ""*/) : FairMQSocket(ZMQ_SNDMORE, ZMQ_RCVMORE, ZMQ_DONTWAIT) @@ -56,12 +57,24 @@ FairMQSocketZMQ::FairMQSocketZMQ(const string& type, const string& name, const i // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. // Default value for ZeroMQ is -1, which is to wait forever. - int linger = 500; + int linger = 1000; if (zmq_setsockopt(fSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) { LOG(ERROR) << "Failed setting ZMQ_LINGER socket option, reason: " << zmq_strerror(errno); } + int sndTimeout = 700; + if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &sndTimeout, sizeof(sndTimeout)) != 0) + { + LOG(ERROR) << "Failed setting ZMQ_SNDTIMEO socket option, reason: " << zmq_strerror(errno); + } + + int rcvTimeout = 700; + if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &rcvTimeout, sizeof(rcvTimeout)) != 0) + { + LOG(ERROR) << "Failed setting ZMQ_RCVTIMEO socket option, reason: " << zmq_strerror(errno); + } + if (type == "sub") { if (zmq_setsockopt(fSocket, ZMQ_SUBSCRIBE, NULL, 0) != 0) @@ -106,91 +119,145 @@ void FairMQSocketZMQ::Connect(const string& address) } } -int FairMQSocketZMQ::Send(FairMQMessage* msg, const string& flag) +int FairMQSocketZMQ::Send(FairMQMessagePtr& msg, const int flags) { - return Send(msg, GetConstant(flag)); -} + int nbytes = -1; -int FairMQSocketZMQ::Send(FairMQMessage* msg, const int flags) -{ - int nbytes = zmq_msg_send(static_cast(msg->GetMessage()), fSocket, flags); - if (nbytes >= 0) + while (true) { - fBytesTx += nbytes; - ++fMessagesTx; - return nbytes; - } - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; -} - -int64_t FairMQSocketZMQ::Send(const vector>& msgVec, const int flags) -{ - // Sending vector typicaly handles more then one part - if (msgVec.size() > 1) - { - int64_t totalSize = 0; - - for (unsigned int i = 0; i < msgVec.size() - 1; ++i) + nbytes = zmq_msg_send(static_cast(msg->GetMessage()), fSocket, flags); + if (nbytes >= 0) { - int nbytes = zmq_msg_send(static_cast(msgVec[i]->GetMessage()), fSocket, ZMQ_SNDMORE|flags); - if (nbytes >= 0) + fBytesTx += nbytes; + ++fMessagesTx; + + return nbytes; + } + else if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) { - totalSize += nbytes; + continue; } else { - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; + return -2; } } - - int nbytes = zmq_msg_send(static_cast(msgVec.back()->GetMessage()), fSocket, flags); - if (nbytes >= 0) + else if (zmq_errno() == ETERM) { - totalSize += nbytes; + LOG(INFO) << "terminating socket " << fId; + return -1; } else { - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); return nbytes; } + } +} - // store statistics on how many messages have been sent (handle all parts as a single message) - ++fMessagesTx; - fBytesTx += totalSize; - return totalSize; - } // If there's only one part, send it as a regular message - else if (msgVec.size() == 1) +int FairMQSocketZMQ::Receive(FairMQMessagePtr& msg, const int flags) +{ + int nbytes = -1; + while (true) { - return Send(msgVec.back().get(), flags); + nbytes = zmq_msg_recv(static_cast(msg->GetMessage()), fSocket, flags); + if (nbytes >= 0) + { + fBytesRx += nbytes; + ++fMessagesRx; + return nbytes; + } + else if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) + { + continue; + } + else + { + return -2; + } + } + else if (zmq_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + else + { + LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << zmq_strerror(errno); + return nbytes; + } + } +} + +int64_t FairMQSocketZMQ::Send(vector>& msgVec, const int flags) +{ + const unsigned int vecSize = msgVec.size(); + + // Sending vector typicaly handles more then one part + if (vecSize > 1) + { + int64_t totalSize = 0; + int nbytes = -1; + bool repeat = false; + + while (true) + { + totalSize = 0; + nbytes = -1; + repeat = false; + + for (unsigned int i = 0; i < vecSize; ++i) + { + nbytes = zmq_msg_send(static_cast(msgVec[i]->GetMessage()), + fSocket, + (i < vecSize - 1) ? ZMQ_SNDMORE|flags : flags); + if (nbytes >= 0) + { + totalSize += nbytes; + } + else + { + // according to ZMQ docs, this can only occur for the first part + if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) + { + repeat = true; + break; + } + else + { + return -2; + } + } + if (zmq_errno() == ETERM) + { + LOG(INFO) << "terminating socket " << fId; + return -1; + } + LOG(ERROR) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); + return nbytes; + } + } + + if (repeat) + { + continue; + } + + // store statistics on how many messages have been sent (handle all parts as a single message) + ++fMessagesTx; + fBytesTx += totalSize; + return totalSize; + } + } // If there's only one part, send it as a regular message + else if (vecSize == 1) + { + return Send(msgVec.back(), flags); } else // if the vector is empty, something might be wrong { @@ -199,69 +266,67 @@ int64_t FairMQSocketZMQ::Send(const vector>& msgVec, c } } -int FairMQSocketZMQ::Receive(FairMQMessage* msg, const string& flag) -{ - return Receive(msg, GetConstant(flag)); -} - -int FairMQSocketZMQ::Receive(FairMQMessage* msg, const int flags) -{ - int nbytes = zmq_msg_recv(static_cast(msg->GetMessage()), fSocket, flags); - if (nbytes >= 0) - { - fBytesRx += nbytes; - ++fMessagesRx; - return nbytes; - } - if (zmq_errno() == EAGAIN) - { - return -2; - } - if (zmq_errno() == ETERM) - { - LOG(INFO) << "terminating socket " << fId; - return -1; - } - LOG(ERROR) << "Failed receiving on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; -} - int64_t FairMQSocketZMQ::Receive(vector>& msgVec, const int flags) { - // Warn if the vector is filled before Receive() and empty it. - if (msgVec.size() > 0) - { - LOG(WARN) << "Message vector contains elements before Receive(), they will be deleted!"; - msgVec.clear(); - } - int64_t totalSize = 0; int64_t more = 0; + bool repeat = false; - do + while (true) { - unique_ptr part(new FairMQMessageZMQ()); - - int nbytes = zmq_msg_recv(static_cast(part->GetMessage()), fSocket, flags); - if (nbytes >= 0) + // Warn if the vector is filled before Receive() and empty it. + if (msgVec.size() > 0) { - msgVec.push_back(move(part)); - totalSize += nbytes; - } - else - { - return nbytes; + LOG(WARN) << "Message vector contains elements before Receive(), they will be deleted!"; + msgVec.clear(); } - size_t more_size = sizeof(more); - zmq_getsockopt(fSocket, ZMQ_RCVMORE, &more, &more_size); + totalSize = 0; + more = 0; + repeat = false; + + do + { + unique_ptr part(new FairMQMessageZMQ()); + + int nbytes = zmq_msg_recv(static_cast(part->GetMessage()), fSocket, flags); + if (nbytes >= 0) + { + msgVec.push_back(move(part)); + totalSize += nbytes; + } + else if (zmq_errno() == EAGAIN) + { + if (!fInterrupted && ((flags & ZMQ_DONTWAIT) == 0)) + { + repeat = true; + break; + } + else + { + return -2; + } + } + else + { + return nbytes; + } + + size_t more_size = sizeof(more); + zmq_getsockopt(fSocket, ZMQ_RCVMORE, &more, &more_size); + } + while (more); + + if (repeat) + { + continue; + } + + // store statistics on how many messages have been received (handle all parts as a single message) + ++fMessagesRx; + fBytesRx += totalSize; + return totalSize; } - while (more); - - // store statistics on how many messages have been received (handle all parts as a single message) - ++fMessagesRx; - fBytesRx += totalSize; - return totalSize; } void FairMQSocketZMQ::Close() @@ -291,10 +356,12 @@ void FairMQSocketZMQ::Terminate() void FairMQSocketZMQ::Interrupt() { + fInterrupted = true; } void FairMQSocketZMQ::Resume() { + fInterrupted = false; } void* FairMQSocketZMQ::GetSocket() const @@ -506,10 +573,6 @@ int FairMQSocketZMQ::GetConstant(const string& constant) if (constant == "linger") return ZMQ_LINGER; - if (constant == "no-block") - return ZMQ_DONTWAIT; - if (constant == "snd-more no-block") - return ZMQ_DONTWAIT|ZMQ_SNDMORE; return -1; } diff --git a/fairmq/zeromq/FairMQSocketZMQ.h b/fairmq/zeromq/FairMQSocketZMQ.h index d57e1b0c..6e6bf226 100644 --- a/fairmq/zeromq/FairMQSocketZMQ.h +++ b/fairmq/zeromq/FairMQSocketZMQ.h @@ -20,6 +20,7 @@ #include // unique_ptr #include "FairMQSocket.h" +#include "FairMQMessage.h" #include "FairMQContextZMQ.h" class FairMQSocketZMQ : public FairMQSocket @@ -34,12 +35,10 @@ class FairMQSocketZMQ : public FairMQSocket virtual bool Bind(const std::string& address); virtual void Connect(const std::string& address); - virtual int Send(FairMQMessage* msg, const std::string& flag = ""); - virtual int Send(FairMQMessage* msg, const int flags = 0); - virtual int64_t Send(const std::vector>& msgVec, const int flags = 0); + virtual int Send(FairMQMessagePtr& msg, const int flags = 0); + virtual int Receive(FairMQMessagePtr& msg, const int flags = 0); - virtual int Receive(FairMQMessage* msg, const std::string& flag = ""); - virtual int Receive(FairMQMessage* msg, const int flags = 0); + virtual int64_t Send(std::vector>& msgVec, const int flags = 0); virtual int64_t Receive(std::vector>& msgVec, const int flags = 0); virtual void* GetSocket() const; @@ -76,6 +75,7 @@ class FairMQSocketZMQ : public FairMQSocket std::atomic fMessagesRx; static std::unique_ptr fContext; + static std::atomic fInterrupted; }; #endif /* FAIRMQSOCKETZMQ_H_ */ diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx index 223ca28d..a423bf78 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx @@ -18,11 +18,13 @@ using namespace std; +static FairMQ::Transport gTransportType = FairMQ::Transport::ZMQ; + FairMQTransportFactoryZMQ::FairMQTransportFactoryZMQ() { int major, minor, patch; zmq_version(&major, &minor, &patch); - LOG(DEBUG) << "Using ZeroMQ library, version: " << major << "." << minor << "." << patch; + LOG(DEBUG) << "Transport: Using ZeroMQ library, version: " << major << "." << minor << "." << patch; } FairMQMessagePtr FairMQTransportFactoryZMQ::CreateMessage() const @@ -59,3 +61,8 @@ FairMQPollerPtr FairMQTransportFactoryZMQ::CreatePoller(const FairMQSocket& cmdS { return unique_ptr(new FairMQPollerZMQ(cmdSocket, dataSocket)); } + +FairMQ::Transport FairMQTransportFactoryZMQ::GetType() const +{ + return gTransportType; +} diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.h b/fairmq/zeromq/FairMQTransportFactoryZMQ.h index 98d03d1e..3e9f6ed9 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.h +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.h @@ -16,6 +16,7 @@ #define FAIRMQTRANSPORTFACTORYZMQ_H_ #include +#include #include "FairMQTransportFactory.h" #include "FairMQContextZMQ.h" @@ -38,6 +39,8 @@ class FairMQTransportFactoryZMQ : public FairMQTransportFactory virtual FairMQPollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const; virtual FairMQPollerPtr CreatePoller(const FairMQSocket& cmdSocket, const FairMQSocket& dataSocket) const; + virtual FairMQ::Transport GetType() const; + virtual ~FairMQTransportFactoryZMQ() {}; };