diff --git a/fairmq/CMakeLists.txt b/fairmq/CMakeLists.txt index d159ae12..73563b4b 100644 --- a/fairmq/CMakeLists.txt +++ b/fairmq/CMakeLists.txt @@ -83,6 +83,7 @@ set(FAIRMQ_HEADER_FILES devices/FairMQProxy.h devices/FairMQSink.h devices/FairMQSplitter.h + ofi/Poller.h ofi/Socket.h ofi/TransportFactory.h options/FairMQParser.h @@ -148,6 +149,7 @@ set(FAIRMQ_SOURCE_FILES devices/FairMQProxy.cxx # devices/FairMQSink.cxx devices/FairMQSplitter.cxx + ofi/Poller.cxx ofi/Socket.cxx ofi/TransportFactory.cxx options/FairMQParser.cxx diff --git a/fairmq/FairMQPoller.h b/fairmq/FairMQPoller.h index a78f1e94..1494f068 100644 --- a/fairmq/FairMQPoller.h +++ b/fairmq/FairMQPoller.h @@ -33,6 +33,8 @@ namespace mq using PollerPtr = std::unique_ptr; +struct PollerError : std::runtime_error { using std::runtime_error::runtime_error; }; + } /* namespace mq */ } /* namespace fair */ diff --git a/fairmq/ofi/Poller.cxx b/fairmq/ofi/Poller.cxx new file mode 100644 index 00000000..9af0f96b --- /dev/null +++ b/fairmq/ofi/Poller.cxx @@ -0,0 +1,179 @@ +/******************************************************************************** + * Copyright (C) 2018 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * + * * + * This software is distributed under the terms of the * + * GNU Lesser General Public Licence (LGPL) version 3, * + * copied verbatim in the file "LICENSE" * + ********************************************************************************/ + +#include +#include +#include + +#include + +namespace fair +{ +namespace mq +{ +namespace ofi +{ + +using namespace std; + +Poller::Poller(const vector& channels) +{ + fNumItems = channels.size(); + fItems = new zmq_pollitem_t[fNumItems]; + + for (int i = 0; i < fNumItems; ++i) { + fItems[i].socket = channels.at(i).GetSocket().GetSocket(); + fItems[i].fd = 0; + fItems[i].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt(channels.at(i).GetSocket().GetSocket(), ZMQ_TYPE, &type, &size); + + SetItemEvents(fItems[i], type); + } +} + +Poller::Poller(const vector& channels) +{ + fNumItems = channels.size(); + fItems = new zmq_pollitem_t[fNumItems]; + + for (int i = 0; i < fNumItems; ++i) { + fItems[i].socket = channels.at(i)->GetSocket().GetSocket(); + fItems[i].fd = 0; + fItems[i].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt(channels.at(i)->GetSocket().GetSocket(), ZMQ_TYPE, &type, &size); + + SetItemEvents(fItems[i], type); + } +} + +Poller::Poller(const unordered_map>& channelsMap, const vector& channelList) +{ + int offset = 0; + + try { + // calculate offsets and the total size of the poll item set + for (string channel : channelList) { + fOffsetMap[channel] = offset; + offset += channelsMap.at(channel).size(); + fNumItems += channelsMap.at(channel).size(); + } + + fItems = new zmq_pollitem_t[fNumItems]; + + int index = 0; + for (string channel : channelList) { + for (unsigned int i = 0; i < channelsMap.at(channel).size(); ++i) { + index = fOffsetMap[channel] + i; + + fItems[index].socket = channelsMap.at(channel).at(i).GetSocket().GetSocket(); + fItems[index].fd = 0; + fItems[index].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt(channelsMap.at(channel).at(i).GetSocket().GetSocket(), ZMQ_TYPE, &type, &size); + + SetItemEvents(fItems[index], type); + } + } + } + catch (const std::out_of_range& oor) { + throw PollerError{tools::ToString("At least one of the provided channel keys for poller initialization is invalid. ", + "Out of range error: ", oor.what())}; + } +} + +Poller::Poller(const FairMQSocket& cmdSocket, const FairMQSocket& dataSocket) + : fNumItems{2} +{ + fItems = new zmq_pollitem_t[fNumItems]; + + fItems[0].socket = cmdSocket.GetSocket(); + fItems[0].fd = 0; + fItems[0].events = ZMQ_POLLIN; + fItems[0].revents = 0; + + fItems[1].socket = dataSocket.GetSocket(); + fItems[1].fd = 0; + fItems[1].revents = 0; + + int type = 0; + size_t size = sizeof(type); + zmq_getsockopt(dataSocket.GetSocket(), ZMQ_TYPE, &type, &size); + + SetItemEvents(fItems[1], type); +} + +auto Poller::SetItemEvents(zmq_pollitem_t& item, const int type) -> void +{ + if (type == ZMQ_PAIR) { + item.events = ZMQ_POLLIN|ZMQ_POLLOUT; + } else { + throw PollerError{"Invalid poller configuration."}; + } +} + +auto Poller::Poll(const int timeout) -> void +{ + if (zmq_poll(fItems, fNumItems, timeout) < 0) { + if (errno == ETERM) { + LOG(debug) << "polling exited, reason: " << zmq_strerror(errno); + } else { + throw PollerError{tools::ToString("Polling failed, reason: ", zmq_strerror(errno))}; + } + } +} + +auto Poller::CheckInput(const int index) -> bool +{ + return fItems[index].revents & ZMQ_POLLIN; +} + +auto Poller::CheckOutput(const int index) -> bool +{ + return fItems[index].revents & ZMQ_POLLOUT; +} + +auto Poller::CheckInput(const string channelKey, const int index) -> bool +{ + try { + return fItems[fOffsetMap.at(channelKey) + index].revents & ZMQ_POLLIN; + } catch (const std::out_of_range& oor) { + throw PollerError{tools::ToString( + "Invalid channel key: '", channelKey, "', ", + "Out of range error: ", oor.what() + )}; + } +} + +auto Poller::CheckOutput(const string channelKey, const int index) -> bool +{ + try { + return fItems[fOffsetMap.at(channelKey) + index].revents & ZMQ_POLLOUT; + } catch (const std::out_of_range& oor) { + throw PollerError{tools::ToString( + "Invalid channel key: '", channelKey, "', ", + "Out of range error: ", oor.what() + )}; + } +} + +Poller::~Poller() +{ + delete[] fItems; +} + +} /* namespace ofi */ +} /* namespace mq */ +} /* namespace fair */ diff --git a/fairmq/ofi/Poller.h b/fairmq/ofi/Poller.h new file mode 100644 index 00000000..9409953b --- /dev/null +++ b/fairmq/ofi/Poller.h @@ -0,0 +1,72 @@ +/******************************************************************************** + * Copyright (C) 2018 GSI Helmholtzzentrum fuer Schwerionenforschung GmbH * + * * + * This software is distributed under the terms of the * + * GNU Lesser General Public Licence (LGPL) version 3, * + * copied verbatim in the file "LICENSE" * + ********************************************************************************/ + +#ifndef FAIR_MQ_OFI_POLLER_H +#define FAIR_MQ_OFI_POLLER_H + +#include +#include +#include + +#include +#include + +#include + +namespace fair +{ +namespace mq +{ +namespace ofi +{ + +class TransportFactory; + +/** + * @class Poller Poller.h + * @brief + * + * @todo TODO insert long description + */ +class Poller : public FairMQPoller +{ + friend class FairMQChannel; + friend class TransportFactory; + + public: + Poller(const std::vector& channels); + Poller(const std::vector& channels); + Poller(const std::unordered_map>& channelsMap, const std::vector& channelList); + + Poller(const Poller&) = delete; + Poller operator=(const Poller&) = delete; + + auto SetItemEvents(zmq_pollitem_t& item, const int type) -> void; + + auto Poll(const int timeout) -> void override; + auto CheckInput(const int index) -> bool override; + auto CheckOutput(const int index) -> bool override; + auto CheckInput(const std::string channelKey, const int index) -> bool override; + auto CheckOutput(const std::string channelKey, const int index) -> bool override; + + ~Poller() override; + + private: + Poller(const FairMQSocket& cmdSocket, const FairMQSocket& dataSocket); + + zmq_pollitem_t* fItems; + int fNumItems; + + std::unordered_map fOffsetMap; +}; /* class Poller */ + +} /* namespace ofi */ +} /* namespace mq */ +} /* namespace fair */ + +#endif /* FAIR_MQ_OFI_POLLER_H */ diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index cd86d9e3..31468a63 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -6,8 +6,9 @@ * copied verbatim in the file "LICENSE" * ********************************************************************************/ -#include +#include #include +#include #include #include // OFI libfabric @@ -63,22 +64,22 @@ TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* co auto TransportFactory::CreateMessage() const -> MessagePtr { - throw runtime_error{"Not yet implemented."}; + throw runtime_error{"Not yet implemented Msg1."}; } auto TransportFactory::CreateMessage(const size_t size) const -> MessagePtr { - throw runtime_error{"Not yet implemented."}; + throw runtime_error{"Not yet implemented Msg2."}; } auto TransportFactory::CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) const -> MessagePtr { - throw runtime_error{"Not yet implemented."}; + throw runtime_error{"Not yet implemented Msg3."}; } auto TransportFactory::CreateMessage(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint) const -> MessagePtr { - throw runtime_error{"Not yet implemented."}; + throw runtime_error{"Not yet implemented Msg4."}; } auto TransportFactory::CreateSocket(const string& type, const string& name) const -> SocketPtr @@ -89,27 +90,27 @@ auto TransportFactory::CreateSocket(const string& type, const string& name) cons auto TransportFactory::CreatePoller(const vector& channels) const -> PollerPtr { - throw runtime_error{"Not yet implemented."}; + return unique_ptr(new Poller(channels)); } auto TransportFactory::CreatePoller(const vector& channels) const -> PollerPtr { - throw runtime_error{"Not yet implemented."}; + return unique_ptr(new Poller(channels)); } auto TransportFactory::CreatePoller(const unordered_map>& channelsMap, const vector& channelList) const -> PollerPtr { - throw runtime_error{"Not yet implemented."}; + return unique_ptr(new Poller(channelsMap, channelList)); } auto TransportFactory::CreatePoller(const FairMQSocket& cmdSocket, const FairMQSocket& dataSocket) const -> PollerPtr { - throw runtime_error{"Not yet implemented."}; + return unique_ptr(new Poller(cmdSocket, dataSocket)); } auto TransportFactory::CreateUnmanagedRegion(const size_t size, FairMQRegionCallback callback) const -> UnmanagedRegionPtr { - throw runtime_error{"Not yet implemented."}; + throw runtime_error{"Not yet implemented UMR."}; } auto TransportFactory::GetType() const -> Transport