diff --git a/fairmq/ofi/Context.cxx b/fairmq/ofi/Context.cxx index e2ff045b..4f64b535 100644 --- a/fairmq/ofi/Context.cxx +++ b/fairmq/ofi/Context.cxx @@ -31,11 +31,12 @@ namespace ofi using namespace std; -Context::Context(int numberIoThreads) +Context::Context(FairMQTransportFactory& receiveFactory, int numberIoThreads) : fOfiInfo(nullptr) , fOfiFabric(nullptr) , fOfiDomain(nullptr) , fIoWork(fIoContext) + , fReceiveFactory(receiveFactory) { InitThreadPool(numberIoThreads); } @@ -176,6 +177,11 @@ auto Context::VerifyAddress(const std::string& address) -> Address return addr; } +auto Context::MakeReceiveMessage(size_t size) -> MessagePtr +{ + return fReceiveFactory.CreateMessage(size); +} + } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */ diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index fcb3d7f7..a45e11af 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -10,6 +10,7 @@ #define FAIR_MQ_OFI_CONTEXT_H #include +#include #include #include @@ -44,7 +45,7 @@ enum class Direction : bool { Receive, Transmit }; class Context { public: - Context(int numberIoThreads = 1); + Context(FairMQTransportFactory& receiveFactory, int numberIoThreads = 1); ~Context(); // auto CreateOfiEndpoint() -> fid_ep*; @@ -66,6 +67,7 @@ class Context auto GetDomain() const -> const asiofi::domain& { return *fOfiDomain; } auto Interrupt() -> void { LOG(debug) << "OFI transport: Interrupted (NOOP - not implemented)."; } auto Resume() -> void { LOG(debug) << "OFI transport: Resumed (NOOP - not implemented)."; } + auto MakeReceiveMessage(size_t size) -> MessagePtr; private: std::unique_ptr fOfiInfo; @@ -74,6 +76,7 @@ class Context boost::asio::io_context fIoContext; boost::asio::io_context::work fIoWork; std::vector fThreadPool; + FairMQTransportFactory& fReceiveFactory; auto InitThreadPool(int numberIoThreads) -> void; auto InitOfi(ConnectionType type, Address address) -> void; diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 
60e37cc0..547dbfe3 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -49,8 +49,11 @@ Socket::Socket(Context& context, const string& type, const string& name, const s , fControlEndpoint(fIoStrand.context(), ZMQ_PAIR) , fSndTimeout(100) , fRcvTimeout(100) - , fQueue1(fIoStrand.context()) - , fQueue2(fIoStrand.context()) + , fSendQueueWrite(fIoStrand.context(), ZMQ_PUSH) + , fSendQueueRead(fIoStrand.context(), ZMQ_PULL) + , fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH) + , fRecvQueueRead(fIoStrand.context(), ZMQ_PULL) + , fSentCount(0) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; @@ -63,17 +66,28 @@ Socket::Socket(Context& context, const string& type, const string& name, const s // Setup internal queue auto hashed_id = std::hash()(fId); - auto queue_id = tools::ToString("inproc://QUEUE", hashed_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Binding Q1: " << queue_id; - fQueue1.bind(queue_id); - LOG(debug) << "OFI transport (" << fId << "): " << "Connecting Q2: " << queue_id; - fQueue2.connect(queue_id); - azmq::socket::snd_hwm send_max(100); - azmq::socket::rcv_hwm recv_max(100); - fQueue1.set_option(send_max); - fQueue1.set_option(recv_max); - fQueue2.set_option(send_max); - fQueue2.set_option(recv_max); + auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id; + fSendQueueRead.bind(queue_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id; + fSendQueueWrite.connect(queue_id); + queue_id = tools::ToString("inproc://RXQUEUE", hashed_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id; + fRecvQueueRead.bind(queue_id); + LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id; + fRecvQueueWrite.connect(queue_id); + + // TODO wire this up with config + azmq::socket::snd_hwm 
send_max(10); + azmq::socket::rcv_hwm recv_max(10); + // NOTE(review): libzmq documents HWM options as applying to subsequent bind/connect calls only - confirm they take effect when set after the queues above are already bound/connected. + fSendQueueRead.set_option(send_max); + fSendQueueRead.set_option(recv_max); + fSendQueueWrite.set_option(send_max); + fSendQueueWrite.set_option(recv_max); + // Apply the same limits to both ends of the receive queue. + fRecvQueueRead.set_option(send_max); + fRecvQueueRead.set_option(recv_max); + fRecvQueueWrite.set_option(send_max); + fRecvQueueWrite.set_option(recv_max); fControlEndpoint.set_option(send_max); fControlEndpoint.set_option(recv_max); } @@ -90,7 +104,8 @@ try { fLocalDataAddr = addr; BindDataEndpoint(); - boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); + // boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); + boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); return true; } @@ -116,6 +131,9 @@ auto Socket::Connect(const string& address) -> bool ReceiveDataAddressAnnouncement(); ConnectDataEndpoint(); + + // boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); + boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); } auto Socket::BindControlEndpoint(Context::Address address) -> void @@ -225,11 +243,13 @@ auto Socket::AnnounceDataAddress() -> void auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int { - LOG(debug) << "OFI transport (" << fId << "): ENTER Send: size=" << msg->GetSize(); + LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize(); MessagePtr* msgptr(new std::unique_ptr(std::move(msg))); try { - auto res = fQueue1.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0); + ++fSentCount; + LOG(info) << fSentCount; + auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0); LOG(debug) << "OFI transport (" << fId << "): LEAVE Send"; return res; @@ -244,16 +264,44 @@ auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int } } -auto Socket::Receive(MessagePtr& msg, const int timeout) -> int { return 0; 
/*ReceiveImpl(msg, 0, timeout);*/ } +auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int +{ + LOG(debug) << "OFI transport (" << fId << "): ENTER Receive"; + + try { + azmq::message zmsg; + auto recv = fRecvQueueRead.receive(zmsg); + + size_t size(0); + if (recv > 0) { + msg = std::move(*(static_cast(zmsg.buffer().data()))); + size = msg->GetSize(); + } + + fBytesRx += size; + fMessagesRx++; + + LOG(debug) << "OFI transport (" << fId << "): LEAVE Receive"; + return size; + } catch (const std::exception& e) { + LOG(error) << e.what(); + return -1; + } catch (const boost::system::error_code& e) { + LOG(error) << e; + return -1; + } +} + auto Socket::Send(std::vector& msgVec, const int timeout) -> int64_t { return SendImpl(msgVec, 0, timeout); } auto Socket::Receive(std::vector& msgVec, const int timeout) -> int64_t { return ReceiveImpl(msgVec, 0, timeout); } auto Socket::SendQueueReader() -> void { - fQueue2.async_receive(boost::asio::bind_executor( + fSendQueueRead.async_receive(boost::asio::bind_executor( fIoStrand, [&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) { if (!ec) { + --fSentCount; OnSend(zmsg, bytes_transferred); } })); @@ -266,7 +314,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void MessagePtr msg(std::move(*(static_cast(zmsg.buffer().data())))); auto size = msg->GetSize(); - LOG(debug) << "OFI transport (" << fId << "): >>>>> OnSend: size=" << size; + LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize(); // Create and send control message auto pb = MakeControlMessage(); @@ -284,7 +332,9 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void { - LOG(debug) << "OFI transport (" << fId << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred; + LOG(debug) << "OFI transport (" << fId + << "): 
ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred + << ",data=" << msg->GetData() << ",size=" << msg->GetSize(); assert(bytes_transferred == sizeof(PostBuffer)); auto size = msg->GetSize(); @@ -302,77 +352,92 @@ auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> v // received, size_ack=" << size_ack; boost::asio::mutable_buffer buffer(msg->GetData(), size); - asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send); + // asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send); + // auto desc = mr.desc(); - fDataEndpoint->send(buffer, mr.desc(), [&, mr2 = std::move(mr)](boost::asio::mutable_buffer) { - LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; - fBytesTx += size; - fMessagesTx++; - }); + fDataEndpoint->send( + buffer, + // desc, + [&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable { + LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent"; + fBytesTx += size; + fMessagesTx++; + }); } + boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this)); LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent"; } -auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int -try { - LOG(debug) << "OFI transport (" << fId << "): ENTER ReceiveImpl"; - // Receive and process control message - azmq::message ctrl; - auto recv = fControlEndpoint.receive(ctrl); - assert(recv == sizeof(PostBuffer)); (void)recv; - auto pb(static_cast(ctrl.data())); +auto Socket::RecvControlQueueReader() -> void +{ + fControlEndpoint.async_receive(boost::asio::bind_executor( + fIoStrand, + [&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) { + if (!ec) { + OnRecvControl(zmsg, bytes_transferred); + } + })); +} + +auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> void +{ + LOG(debug) << "OFI 
transport (" << fId + << "): ENTER OnRecvControl: bytes_transferred=" << bytes_transferred; + + assert(bytes_transferred == sizeof(PostBuffer)); + auto pb(static_cast(zmsg.data())); assert(pb->type == ControlMessageType::PostBuffer); auto size = pb->size; - LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control message received, size=" << size; + LOG(debug) << "OFI transport (" << fId << "): OnRecvControl: PostBuffer.size=" << size; // Receive data if (size) { - msg->Rebuild(size); + auto msg = fContext.MakeReceiveMessage(size); boost::asio::mutable_buffer buffer(msg->GetData(), size); - asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv); + // asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv); + // auto msg33 = fContext.MakeReceiveMessage(size); + // boost::asio::mutable_buffer buffer33(msg33->GetData(), size); + // asiofi::memory_region mr33(fContext.GetDomain(), buffer33, asiofi::mr::access::recv); + // auto desc = mr.desc(); - std::mutex m; - std::condition_variable cv; - bool completed(false); - - fDataEndpoint->recv(buffer, mr.desc(), [&](boost::asio::mutable_buffer) { - { - std::unique_lock lk(m); - completed = true; - } - cv.notify_one(); - } - ); + fDataEndpoint->recv( + buffer, + // desc, + [&, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable { + MessagePtr* msgptr(new std::unique_ptr(std::move(msg2))); + fRecvQueueWrite.async_send( + azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))), + [&](const boost::system::error_code& ec, size_t bytes_transferred2) { + if (!ec) { + LOG(debug) << "OFI transport (" << fId + << "): <<<<< Data buffer received, bytes_transferred2=" + << bytes_transferred2; + } + }); + }); // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted"; - auto ack = MakeControlMessage(); - ack.size = size; - auto sent = fControlEndpoint.send(boost::asio::buffer(ack)); - assert(sent == 
sizeof(PostBuffer)); (void)sent; + // auto ack = MakeControlMessage(); + // ack.size = size; + // auto sent = fControlEndpoint.send(boost::asio::buffer(ack)); + // assert(sent == sizeof(PostBuffer)); (void)sent; // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control Ack sent"; - - { - std::unique_lock lk(m); - cv.wait(lk, [&](){ return completed; }); - } - // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data received"; + } else { + fRecvQueueWrite.async_send( + azmq::message(boost::asio::const_buffer(nullptr, 0)), + [&](const boost::system::error_code& ec, size_t bytes_transferred2) { + if (!ec) { + LOG(debug) << "OFI transport (" << fId + << "): <<<<< Data buffer received, bytes_transferred2=" + << bytes_transferred2; + } + }); } - fBytesRx += size; - fMessagesRx++; + boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this)); - // LOG(debug) << "OFI transport (" << fId << "): EXIT ReceiveImpl"; - return size; -} -catch (const SilentSocketError& e) -{ - return -2; -} -catch (const std::exception& e) -{ - LOG(error) << e.what(); - return -1; + LOG(debug) << "OFI transport (" << fId << "): LEAVE OnRecvControl"; } auto Socket::SendImpl(vector& /*msgVec*/, const int /*flags*/, const int /*timeout*/) -> int64_t @@ -548,14 +613,14 @@ auto Socket::ReceiveImpl(vector& /*msgVec*/, const int /*flags auto Socket::Close() -> void {} -auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void +auto Socket::SetOption(const string& /*option*/, const void* /*value*/, size_t /*valueSize*/) -> void { // if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { // throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))}; // } } -auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void +auto Socket::GetOption(const string& /*option*/, void* /*value*/, size_t* /*valueSize*/) -> void { // if 
(zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) { // throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))}; diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 3a3c4aeb..05947201 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -19,7 +19,6 @@ #include #include // unique_ptr #include -class FairMQTransportFactory; namespace fair { @@ -37,7 +36,7 @@ namespace ofi class Socket final : public fair::mq::Socket { public: - Socket(Context& factory, const std::string& type, const std::string& name, const std::string& id = "", FairMQTransportFactory* fac); + Socket(Context& context, const std::string& type, const std::string& name, const std::string& id = ""); Socket(const Socket&) = delete; Socket operator=(const Socket&) = delete; @@ -93,11 +92,16 @@ class Socket final : public fair::mq::Socket mutable azmq::socket fControlEndpoint; int fSndTimeout; int fRcvTimeout; - azmq::pair_socket fQueue1, fQueue2; + azmq::socket fSendQueueWrite, fSendQueueRead; + azmq::socket fRecvQueueWrite, fRecvQueueRead; + std::atomic fSentCount; auto SendQueueReader() -> void; auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void; auto OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void; + auto RecvControlQueueReader() -> void; + auto OnRecvControl(azmq::message& msg, size_t bytes_transferred) -> void; + auto OnReceive() -> void; auto ReceiveImpl(MessagePtr& msg, const int flags, const int timeout) -> int; auto SendImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index 694e706b..7f76c794 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -24,13 +24,12 @@ namespace ofi using namespace std; TransportFactory::TransportFactory(const string& id, const 
FairMQProgOptions* /*config*/) -try : FairMQTransportFactory{id} +try : FairMQTransportFactory(id) + , fContext(*this, 1) { LOG(debug) << "OFI transport: Using AZMQ & " << "asiofi (" << fContext.GetAsiofiVersion() << ")"; -} -catch (ContextError& e) -{ +} catch (ContextError& e) { throw TransportFactoryError{e.what()}; }