From 0b11ad9274474f9df79dfaffaccbe294ff8519d9 Mon Sep 17 00:00:00 2001 From: Alexey Rybalchenko Date: Wed, 7 Oct 2015 15:49:25 +0200 Subject: [PATCH] Fix the type mismatch in the multi-channel poller --- fairmq/FairMQTransportFactory.h | 3 +- fairmq/nanomsg/FairMQPollerNN.cxx | 48 ++++++++++++++++++-- fairmq/nanomsg/FairMQPollerNN.h | 2 +- fairmq/nanomsg/FairMQTransportFactoryNN.cxx | 2 +- fairmq/nanomsg/FairMQTransportFactoryNN.h | 2 +- fairmq/zeromq/FairMQPollerZMQ.cxx | 49 +++++++++++++++++++-- fairmq/zeromq/FairMQPollerZMQ.h | 2 +- fairmq/zeromq/FairMQTransportFactoryZMQ.cxx | 2 +- fairmq/zeromq/FairMQTransportFactoryZMQ.h | 2 +- 9 files changed, 99 insertions(+), 13 deletions(-) diff --git a/fairmq/FairMQTransportFactory.h b/fairmq/FairMQTransportFactory.h index 292d5df1..21a14e8d 100644 --- a/fairmq/FairMQTransportFactory.h +++ b/fairmq/FairMQTransportFactory.h @@ -17,6 +17,7 @@ #include #include +#include #include "FairMQMessage.h" #include "FairMQChannel.h" @@ -36,7 +37,7 @@ class FairMQTransportFactory virtual FairMQSocket* CreateSocket(const std::string& type, const std::string& name, int numIoThreads) = 0; virtual FairMQPoller* CreatePoller(const std::vector& channels) = 0; - virtual FairMQPoller* CreatePoller(std::map>& channelsMap, std::initializer_list channelList) = 0; + virtual FairMQPoller* CreatePoller(std::unordered_map>& channelsMap, std::initializer_list channelList) = 0; virtual FairMQPoller* CreatePoller(FairMQSocket& cmdSocket, FairMQSocket& dataSocket) = 0; virtual ~FairMQTransportFactory() {}; diff --git a/fairmq/nanomsg/FairMQPollerNN.cxx b/fairmq/nanomsg/FairMQPollerNN.cxx index 45918ccd..564ab214 100644 --- a/fairmq/nanomsg/FairMQPollerNN.cxx +++ b/fairmq/nanomsg/FairMQPollerNN.cxx @@ -34,11 +34,32 @@ FairMQPollerNN::FairMQPollerNN(const vector& channels) for (int i = 0; i < fNumItems; ++i) { items[i].fd = channels.at(i).fSocket->GetSocket(1); - items[i].events = NN_POLLIN; + + int type = 0; + size_t sz = sizeof(type); + nn_getsockopt(channels.at(i).fSocket->GetSocket(1), NN_SOL_SOCKET, NN_PROTOCOL, &type, &sz); + + if (type == NN_REQ || type == NN_REP || type == NN_PAIR) + { + items[i].events = NN_POLLIN|NN_POLLOUT; + } + else if (type == NN_PUSH || type == NN_PUB) + { + items[i].events = NN_POLLOUT; + } + else if (type == NN_PULL || type == NN_SUB) + { + items[i].events = NN_POLLIN; + } + else + { + LOG(ERROR) << "invalid poller configuration, exiting."; + exit(EXIT_FAILURE); + } } } -FairMQPollerNN::FairMQPollerNN(map>& channelsMap, initializer_list channelList) +FairMQPollerNN::FairMQPollerNN(unordered_map>& channelsMap, initializer_list channelList) : items() , fNumItems(0) , fOffsetMap() @@ -64,7 +85,28 @@ FairMQPollerNN::FairMQPollerNN(map>& channelsMap, { index = fOffsetMap[channel] + i; items[index].fd = channelsMap.at(channel).at(i).fSocket->GetSocket(1); - items[index].events = NN_POLLIN; + + int type = 0; + size_t sz = sizeof(type); + nn_getsockopt(channelsMap.at(channel).at(i).fSocket->GetSocket(1), NN_SOL_SOCKET, NN_PROTOCOL, &type, &sz); + + if (type == NN_REQ || type == NN_REP || type == NN_PAIR) + { + items[index].events = NN_POLLIN|NN_POLLOUT; + } + else if (type == NN_PUSH || type == NN_PUB) + { + items[index].events = NN_POLLOUT; + } + else if (type == NN_PULL || type == NN_SUB) + { + items[index].events = NN_POLLIN; + } + else + { + LOG(ERROR) << "invalid poller configuration, exiting."; + exit(EXIT_FAILURE); + } } } } diff --git a/fairmq/nanomsg/FairMQPollerNN.h b/fairmq/nanomsg/FairMQPollerNN.h index cf44762d..cb01382b 100644 --- a/fairmq/nanomsg/FairMQPollerNN.h +++ b/fairmq/nanomsg/FairMQPollerNN.h @@ -32,7 +32,7 @@ class FairMQPollerNN : public FairMQPoller public: FairMQPollerNN(const std::vector& channels); - FairMQPollerNN(std::map>& channelsMap, std::initializer_list channelList); + FairMQPollerNN(std::unordered_map>& channelsMap, std::initializer_list channelList); virtual void Poll(const int timeout); virtual bool CheckInput(const int index); diff --git a/fairmq/nanomsg/FairMQTransportFactoryNN.cxx b/fairmq/nanomsg/FairMQTransportFactoryNN.cxx index 1baa978d..8391f33c 100644 --- a/fairmq/nanomsg/FairMQTransportFactoryNN.cxx +++ b/fairmq/nanomsg/FairMQTransportFactoryNN.cxx @@ -46,7 +46,7 @@ FairMQPoller* FairMQTransportFactoryNN::CreatePoller(const vector return new FairMQPollerNN(channels); } -FairMQPoller* FairMQTransportFactoryNN::CreatePoller(std::map>& channelsMap, std::initializer_list channelList) +FairMQPoller* FairMQTransportFactoryNN::CreatePoller(std::unordered_map>& channelsMap, std::initializer_list channelList) { return new FairMQPollerNN(channelsMap, channelList); } diff --git a/fairmq/nanomsg/FairMQTransportFactoryNN.h b/fairmq/nanomsg/FairMQTransportFactoryNN.h index 04cb69c4..ee27845f 100644 --- a/fairmq/nanomsg/FairMQTransportFactoryNN.h +++ b/fairmq/nanomsg/FairMQTransportFactoryNN.h @@ -34,7 +34,7 @@ class FairMQTransportFactoryNN : public FairMQTransportFactory virtual FairMQSocket* CreateSocket(const std::string& type, const std::string& name, int numIoThreads); virtual FairMQPoller* CreatePoller(const std::vector& channels); - virtual FairMQPoller* CreatePoller(std::map>& channelsMap, std::initializer_list channelList); + virtual FairMQPoller* CreatePoller(std::unordered_map>& channelsMap, std::initializer_list channelList); virtual FairMQPoller* CreatePoller(FairMQSocket& cmdSocket, FairMQSocket& dataSocket); virtual ~FairMQTransportFactoryNN() {}; diff --git a/fairmq/zeromq/FairMQPollerZMQ.cxx b/fairmq/zeromq/FairMQPollerZMQ.cxx index c6859da3..1374d592 100644 --- a/fairmq/zeromq/FairMQPollerZMQ.cxx +++ b/fairmq/zeromq/FairMQPollerZMQ.cxx @@ -31,12 +31,33 @@ FairMQPollerZMQ::FairMQPollerZMQ(const vector& channels) { items[i].socket = channels.at(i).fSocket->GetSocket(); items[i].fd = 0; - items[i].events = ZMQ_POLLIN; items[i].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt (channels.at(i).fSocket->GetSocket(), ZMQ_TYPE, &type, &size); + + if (type == ZMQ_REQ || type == ZMQ_REP || type == ZMQ_PAIR || type == ZMQ_DEALER || type == ZMQ_ROUTER) + { + items[i].events = ZMQ_POLLIN|ZMQ_POLLOUT; + } + else if (type == ZMQ_PUSH || type == ZMQ_PUB || type == ZMQ_XPUB) + { + items[i].events = ZMQ_POLLOUT; + } + else if (type == ZMQ_PULL || type == ZMQ_SUB || type == ZMQ_XSUB) + { + items[i].events = ZMQ_POLLIN; + } + else + { + LOG(ERROR) << "invalid poller configuration, exiting."; + exit(EXIT_FAILURE); + } } } -FairMQPollerZMQ::FairMQPollerZMQ(map>& channelsMap, initializer_list channelList) +FairMQPollerZMQ::FairMQPollerZMQ(unordered_map>& channelsMap, initializer_list channelList) : items() , fNumItems(0) , fOffsetMap() @@ -61,10 +82,32 @@ FairMQPollerZMQ::FairMQPollerZMQ(map>& channelsMap for (int i = 0; i < channelsMap.at(channel).size(); ++i) { index = fOffsetMap[channel] + i; + items[index].socket = channelsMap.at(channel).at(i).fSocket->GetSocket(); items[index].fd = 0; - items[index].events = ZMQ_POLLIN; items[index].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt (channelsMap.at(channel).at(i).fSocket->GetSocket(), ZMQ_TYPE, &type, &size); + + if (type == ZMQ_REQ || type == ZMQ_REP || type == ZMQ_PAIR || type == ZMQ_DEALER || type == ZMQ_ROUTER) + { + items[index].events = ZMQ_POLLIN|ZMQ_POLLOUT; + } + else if (type == ZMQ_PUSH || type == ZMQ_PUB || type == ZMQ_XPUB) + { + items[index].events = ZMQ_POLLOUT; + } + else if (type == ZMQ_PULL || type == ZMQ_SUB || type == ZMQ_XSUB) + { + items[index].events = ZMQ_POLLIN; + } + else + { + LOG(ERROR) << "invalid poller configuration, exiting."; + exit(EXIT_FAILURE); + } } } } diff --git a/fairmq/zeromq/FairMQPollerZMQ.h b/fairmq/zeromq/FairMQPollerZMQ.h index 7a37128c..8a0d3c29 100644 --- a/fairmq/zeromq/FairMQPollerZMQ.h +++ b/fairmq/zeromq/FairMQPollerZMQ.h @@ -32,7 +32,7 @@ class FairMQPollerZMQ : public FairMQPoller public: FairMQPollerZMQ(const std::vector& channels); - FairMQPollerZMQ(std::map>& channelsMap, std::initializer_list channelList); + FairMQPollerZMQ(std::unordered_map>& channelsMap, std::initializer_list channelList); virtual void Poll(const int timeout); virtual bool CheckInput(const int index); diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx index b774d23b..34c7514d 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx @@ -50,7 +50,7 @@ FairMQPoller* FairMQTransportFactoryZMQ::CreatePoller(const vector>& channelsMap, std::initializer_list channelList) +FairMQPoller* FairMQTransportFactoryZMQ::CreatePoller(std::unordered_map>& channelsMap, std::initializer_list channelList) { return new FairMQPollerZMQ(channelsMap, channelList); } diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.h b/fairmq/zeromq/FairMQTransportFactoryZMQ.h index a2442aa5..220a6ff6 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.h +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.h @@ -35,7 +35,7 @@ class FairMQTransportFactoryZMQ : public FairMQTransportFactory virtual FairMQSocket* CreateSocket(const std::string& type, const std::string& name, int numIoThreads); virtual FairMQPoller* CreatePoller(const std::vector& channels); - virtual FairMQPoller* CreatePoller(std::map>& channelsMap, std::initializer_list channelList); + virtual FairMQPoller* CreatePoller(std::unordered_map>& channelsMap, std::initializer_list channelList); virtual FairMQPoller* CreatePoller(FairMQSocket& cmdSocket, FairMQSocket& dataSocket); virtual ~FairMQTransportFactoryZMQ() {};