diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index 98edb141..4ad042f3 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -9,6 +9,8 @@ #ifndef FAIR_MQ_OFI_CONTEXT_H #define FAIR_MQ_OFI_CONTEXT_H +#include + #include #include #include @@ -63,6 +65,9 @@ class Context static auto ConvertAddress(Address address) -> sockaddr_in; static auto ConvertAddress(sockaddr_in address) -> Address; static auto VerifyAddress(const std::string& address) -> Address; + auto GetDomain() const -> const asiofi::domain& { return *fOfiDomain; } + auto Interrupt() -> void { LOG(debug) << "OFI transport: Interrupted (NOOP - not implemented)."; } + auto Resume() -> void { LOG(debug) << "OFI transport: Resumed (NOOP - not implemented)."; } private: void* fZmqContext; diff --git a/fairmq/ofi/ControlMessages.h b/fairmq/ofi/ControlMessages.h index 32ff4d31..987c3400 100644 --- a/fairmq/ofi/ControlMessages.h +++ b/fairmq/ofi/ControlMessages.h @@ -56,6 +56,8 @@ auto MakeControlMessage(A* pmr, Args&& ... args) -> CtrlMsgPtr if (std::is_same::value) { raw_ptr->type = ControlMessageType::DataAddressAnnouncement; + } else if (std::is_same::value) { + raw_ptr->type = ControlMessageType::PostBuffer; } return {raw_ptr, [=](T* p) { pmr->deallocate(p, sizeof(T)); }}; diff --git a/fairmq/ofi/Message.cxx b/fairmq/ofi/Message.cxx index 3a470702..7209fb0a 100644 --- a/fairmq/ofi/Message.cxx +++ b/fairmq/ofi/Message.cxx @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -23,38 +24,48 @@ namespace ofi using namespace std; -Message::Message() +Message::Message(boost::container::pmr::memory_resource* pmr) : fInitialSize(0) , fSize(0) , fData(nullptr) , fFreeFunction(nullptr) , fHint(nullptr) + , fPmr(pmr) { } -Message::Message(const size_t size) +Message::Message(boost::container::pmr::memory_resource* pmr, const size_t size) : fInitialSize(size) , fSize(size) , fData(nullptr) , fFreeFunction(nullptr) , fHint(nullptr) + , fPmr(pmr) { if (size) { - fData = malloc(size); + fData = fPmr->allocate(size); assert(fData); } } -Message::Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) +Message::Message(boost::container::pmr::memory_resource* pmr, + void* data, + const size_t size, + fairmq_free_fn* ffn, + void* hint) : fInitialSize(size) , fSize(size) , fData(data) , fFreeFunction(ffn) , fHint(hint) -{ -} + , fPmr(pmr) +{} -Message::Message(FairMQUnmanagedRegionPtr& /*region*/, void* /*data*/, const size_t /*size*/, void* /*hint*/) +Message::Message(boost::container::pmr::memory_resource* /*pmr*/, + FairMQUnmanagedRegionPtr& /*region*/, + void* /*data*/, + const size_t /*size*/, + void* /*hint*/) { throw MessageError{"Not yet implemented."}; } @@ -62,9 +73,9 @@ Message::Message(FairMQUnmanagedRegionPtr& /*region*/, void* /*data*/, const siz auto Message::Rebuild() -> void { if (fFreeFunction) { - fFreeFunction(fData, fHint); + fFreeFunction(fData, fHint); } else { - free(fData); + fPmr->deallocate(fData, fSize); } fData = nullptr; fInitialSize = 0; @@ -78,10 +89,10 @@ auto Message::Rebuild(const size_t size) -> void if (fFreeFunction) { fFreeFunction(fData, fHint); } else { - free(fData); + fPmr->deallocate(fData, fSize); } if (size) { - fData = malloc(size); + fData = fPmr->allocate(size); assert(fData); } else { fData = nullptr; @@ -97,10 +108,10 @@ auto Message::Rebuild(void* /*data*/, const size_t size, fairmq_free_fn* ffn, vo if (fFreeFunction) { fFreeFunction(fData, fHint); } else { - free(fData); + fPmr->deallocate(fData, fSize); } if (size) { - fData = malloc(size); + fData = fPmr->allocate(size); assert(fData); } else { fData = nullptr; @@ -143,7 +154,7 @@ Message::~Message() if (fFreeFunction) { fFreeFunction(fData, fHint); } else { - free(fData); + fPmr->deallocate(fData, fSize); } } diff --git a/fairmq/ofi/Message.h b/fairmq/ofi/Message.h index 7c933f4b..5a79926b 100644 --- a/fairmq/ofi/Message.h +++ b/fairmq/ofi/Message.h @@ -12,10 +12,10 @@ #include #include -#include - -#include // size_t +#include #include +#include // size_t +#include namespace fair { @@ -33,10 +33,18 @@ namespace ofi class Message final : public fair::mq::Message { public: - Message(); - Message(const size_t size); - Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr); - Message(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0); + Message(boost::container::pmr::memory_resource* pmr); + Message(boost::container::pmr::memory_resource* pmr, const size_t size); + Message(boost::container::pmr::memory_resource* pmr, + void* data, + const size_t size, + fairmq_free_fn* ffn, + void* hint = nullptr); + Message(boost::container::pmr::memory_resource* pmr, + FairMQUnmanagedRegionPtr& region, + void* data, + const size_t size, + void* hint = 0); Message(const Message&) = delete; Message operator=(const Message&) = delete; @@ -62,6 +70,7 @@ class Message final : public fair::mq::Message void* fData; fairmq_free_fn* fFreeFunction; void* fHint; + boost::container::pmr::memory_resource* fPmr; }; /* class Message */ } /* namespace ofi */ diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 1496e642..595c61d6 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -13,6 +13,8 @@ #include #include +#include +#include #include #include #include @@ -43,7 +45,6 @@ Socket::Socket(Context& context, const string& type, const string& name, const s , fMessagesTx(0) , fMessagesRx(0) , fContext(context) - , fWaitingForControlPeer(false) , fIoStrand(fContext.GetIoContext()) , fSndTimeout(100) , fRcvTimeout(100) @@ -230,6 +231,14 @@ auto Socket::SendControlMessage(CtrlMsgPtr ctrl) -> void std::memcpy(zmq_msg_data(msg), ctrl.get(), sizeof(DataAddressAnnouncement)); } break; + case ControlMessageType::PostBuffer: + { + auto ret = zmq_msg_init_size(msg, sizeof(PostBuffer)); + (void)ret; + assert(ret == 0); + std::memcpy(zmq_msg_data(msg), ctrl.get(), sizeof(PostBuffer)); + } + break; default: throw SocketError(tools::ToString("Cannot send control message of unknown type.")); } @@ -274,6 +283,13 @@ auto Socket::ReceiveControlMessage() -> CtrlMsgPtr // LOG(debug) << "Received control message: " << ctrl->DebugString(); return StaticUniquePtrUpcast(std::move(daa)); } + case ControlMessageType::PostBuffer: { + assert(msg_size == sizeof(PostBuffer)); + auto pb = MakeControlMessage(&fCtrlMemPool); + std::memcpy(pb.get(), msg_data, sizeof(PostBuffer)); + // LOG(debug) << "Received control message: " << ctrl->DebugString(); + return StaticUniquePtrUpcast(std::move(pb)); + } default: throw SocketError(tools::ToString("Received control message of unknown type.")); } @@ -327,43 +343,53 @@ auto Socket::TryReceive(MessagePtr& msg) -> int { return ReceiveImpl(msg, ZMQ_DO auto Socket::TrySend(std::vector& msgVec) -> int64_t { return SendImpl(msgVec, ZMQ_DONTWAIT, 0); } auto Socket::TryReceive(std::vector& msgVec) -> int64_t { return ReceiveImpl(msgVec, ZMQ_DONTWAIT, 0); } +#include +#include + auto Socket::SendImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int try { auto size = msg->GetSize(); + LOG(debug) << "OFI transport (" << fId << "): ENTER SendImpl"; - this_thread::sleep_for(std::chrono::seconds(10)); // Create and send control message - // auto ctrl = tools::make_unique(); - // auto buf = tools::make_unique(); - // buf->set_size(size); - // ctrl->set_allocated_post_buffer(buf.release()); - // assert(ctrl->IsInitialized()); - // SendControlMessage(move(ctrl)); + auto pb = MakeControlMessage(&fCtrlMemPool); + pb->size = size; + SendControlMessage(StaticUniquePtrUpcast(std::move(pb))); + LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Control message sent, size=" << size; if (size) { - // Receive and process control message - // auto ctrl2 = ReceiveControlMessage(); - // assert(ctrl2->has_post_buffer_acknowledgement()); - // assert(ctrl2->post_buffer_acknowledgement().size() == size); + boost::asio::mutable_buffer buffer(msg->GetData(), size); + asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send); - // Send data - // fi_context ctx; - // auto ret = fi_send(fDataEndpoint, msg->GetData(), size, nullptr, fRemoteDataAddr, &ctx); - // if (ret < 0) - // throw SocketError(tools::ToString("Failed posting ofi send buffer, reason: ", fi_strerror(ret))); - } + std::mutex m; + std::condition_variable cv; + bool completed(false); - if (size) { - // fi_cq_err_entry cqEntry; - // auto ret = fi_cq_sread(fDataCompletionQueueTx, &cqEntry, 1, nullptr, -1); - // if (ret != 1) - // throw SocketError(tools::ToString("Failed reading ofi tx completion queue event, reason: ", fi_strerror(ret))); + fDataEndpoint->send( + buffer, + mr.desc(), + [&](boost::asio::mutable_buffer) { + { + std::unique_lock lk(m); + completed = true; + } + cv.notify_one(); + LOG(debug) << "OFI transport (" << fId << "): > SendImpl: Data buffer sent"; + } + ); + + { + std::unique_lock lk(m); + cv.wait(lk, [&](){ return completed; }); + } + LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Data send buffer posted"; } msg.reset(nullptr); fBytesTx += size; fMessagesTx++; + LOG(debug) << "OFI transport (" << fId << "): LEAVE SendImpl"; return size; } catch (const SilentSocketError& e) @@ -376,52 +402,47 @@ catch (const std::exception& e) return -1; } -auto Socket::ReceiveImpl(FairMQMessagePtr& /*msg*/, const int /*flags*/, const int /*timeout*/) -> int +auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int try { - this_thread::sleep_for(std::chrono::seconds(10)); - if (fWaitingForControlPeer) { - WaitForControlPeer(); - // AnnounceDataAddress(); - // ProcessDataAddressAnnouncement(ReceiveControlMessage()); - } - + LOG(debug) << "OFI transport (" << fId << "): ENTER ReceiveImpl"; // Receive and process control message - // auto ctrl = ReceiveControlMessage(); - // assert(ctrl->has_post_buffer()); - // auto postBuffer = ctrl->post_buffer(); - // auto size = postBuffer.size(); + auto pb = StaticUniquePtrDowncast(ReceiveControlMessage()); + assert(pb.get()); + auto size = pb->size; + LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control message received, size=" << size; // Receive data - // if (size) { - // fi_context ctx; - // msg->Rebuild(size); - // auto buf = msg->GetData(); - // auto size2 = msg->GetSize(); - // auto ret = fi_recv(fDataEndpoint, buf, size2, nullptr, fRemoteDataAddr, &ctx); - // if (ret < 0) - // throw SocketError(tools::ToString("Failed posting ofi receive buffer, reason: ", fi_strerror(ret))); + if (size) { + msg->Rebuild(size); + boost::asio::mutable_buffer buffer(msg->GetData(), size); + asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv); - // Create and send control message - // auto ctrl2 = tools::make_unique(); - // auto ack = tools::make_unique(); - // ack->set_size(msg->GetSize()); - // ctrl2->set_allocated_post_buffer_acknowledgement(ack.release()); - // assert(ctrl2->IsInitialized()); - // SendControlMessage(move(ctrl2)); + std::mutex m; + std::condition_variable cv; + bool completed(false); - // fi_cq_err_entry cqEntry; - // ret = fi_cq_sread(fDataCompletionQueueRx, &cqEntry, 1, nullptr, -1); - // if (ret != 1) - // throw SocketError(tools::ToString("Failed reading ofi rx completion queue event, reason: ", fi_strerror(ret))); - // assert(cqEntry.len == size2); - // assert(cqEntry.buf == buf); - // } + fDataEndpoint->recv(buffer, mr.desc(), [&](boost::asio::mutable_buffer) { + { + std::unique_lock lk(m); + completed = true; + } + cv.notify_one(); + } + ); + LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted"; - // fBytesRx += size; + { + std::unique_lock lk(m); + cv.wait(lk, [&](){ return completed; }); + } + LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data received"; + } + + fBytesRx += size; fMessagesRx++; - // return size; - return 0; + LOG(debug) << "OFI transport (" << fId << "): EXIT ReceiveImpl"; + return size; } catch (const SilentSocketError& e) { diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index c164bf07..07331d24 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -36,22 +36,28 @@ catch (ContextError& e) auto TransportFactory::CreateMessage() const -> MessagePtr { - return MessagePtr{new Message()}; + return MessagePtr{new Message(&fMemoryResource)}; } auto TransportFactory::CreateMessage(const size_t size) const -> MessagePtr { - return MessagePtr{new Message(size)}; + return MessagePtr{new Message(&fMemoryResource, size)}; } -auto TransportFactory::CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) const -> MessagePtr +auto TransportFactory::CreateMessage(void* data, + const size_t size, + fairmq_free_fn* ffn, + void* hint) const -> MessagePtr { - return MessagePtr{new Message(data, size, ffn, hint)}; + return MessagePtr{new Message(&fMemoryResource, data, size, ffn, hint)}; } -auto TransportFactory::CreateMessage(UnmanagedRegionPtr& region, void* data, const size_t size, void* hint) const -> MessagePtr +auto TransportFactory::CreateMessage(UnmanagedRegionPtr& region, + void* data, + const size_t size, + void* hint) const -> MessagePtr { - return MessagePtr{new Message(region, data, size, hint)}; + return MessagePtr{new Message(&fMemoryResource, region, data, size, hint)}; } auto TransportFactory::CreateSocket(const string& type, const string& name) -> SocketPtr diff --git a/fairmq/ofi/TransportFactory.h b/fairmq/ofi/TransportFactory.h index ef0f915c..b69eaab1 100644 --- a/fairmq/ofi/TransportFactory.h +++ b/fairmq/ofi/TransportFactory.h @@ -13,6 +13,8 @@ #include #include +#include + namespace fair { namespace mq @@ -48,12 +50,13 @@ class TransportFactory final : public FairMQTransportFactory auto GetType() const -> Transport override; - void Interrupt() override {} - void Resume() override {} + void Interrupt() override { fContext.Interrupt(); } + void Resume() override { fContext.Resume(); } void Reset() override {} private: mutable Context fContext; + asiofi::allocated_pool_resource fMemoryResource; }; /* class TransportFactory */ } /* namespace ofi */