diff --git a/fairmq/ofi/Message.cxx b/fairmq/ofi/Message.cxx index 7209fb0a..8411a682 100644 --- a/fairmq/ofi/Message.cxx +++ b/fairmq/ofi/Message.cxx @@ -75,7 +75,9 @@ auto Message::Rebuild() -> void if (fFreeFunction) { fFreeFunction(fData, fHint); } else { - fPmr->deallocate(fData, fSize); + if (fData) { + fPmr->deallocate(fData, fSize); + } } fData = nullptr; fInitialSize = 0; @@ -89,7 +91,9 @@ auto Message::Rebuild(const size_t size) -> void if (fFreeFunction) { fFreeFunction(fData, fHint); } else { - fPmr->deallocate(fData, fSize); + if (fData) { + fPmr->deallocate(fData, fSize); + } } if (size) { fData = fPmr->allocate(size); @@ -108,7 +112,9 @@ auto Message::Rebuild(void* /*data*/, const size_t size, fairmq_free_fn* ffn, vo if (fFreeFunction) { fFreeFunction(fData, fHint); } else { - fPmr->deallocate(fData, fSize); + if (fData) { + fPmr->deallocate(fData, fSize); + } } if (size) { fData = fPmr->allocate(size); @@ -152,9 +158,11 @@ auto Message::Copy(const fair::mq::Message& /*msg*/) -> void Message::~Message() { if (fFreeFunction) { - fFreeFunction(fData, fHint); + fFreeFunction(fData, fHint); } else { - fPmr->deallocate(fData, fSize); + if (fData) { + fPmr->deallocate(fData, fSize); + } } } diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 595c61d6..709dfe80 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -17,14 +17,14 @@ #include #include #include -#include -#include -#include #include #include #include #include +#include +#include + namespace fair { namespace mq @@ -98,8 +98,6 @@ try { fLocalDataAddr = addr; BindDataEndpoint(); - AnnounceDataAddress(); - return true; } catch (const SilentSocketError& e) @@ -143,17 +141,33 @@ auto Socket::BindDataEndpoint() -> void assert(!fPassiveDataEndpoint); assert(!fDataEndpoint); + std::mutex m; + std::condition_variable cv; + bool completed(false); + fPassiveDataEndpoint = fContext.MakeOfiPassiveEndpoint(fLocalDataAddr); fPassiveDataEndpoint->listen([&](fid_t /*handle*/, asiofi::info&& info) { LOG(debug) << "OFI transport (" << fId << "): data band connection request received. Accepting ..."; fDataEndpoint = fContext.MakeOfiConnectedEndpoint(info); fDataEndpoint->enable(); fDataEndpoint->accept([&]() { - LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; + { + std::unique_lock lk(m); + completed = true; + } + cv.notify_one(); }); }); LOG(debug) << "OFI transport (" << fId << "): data band bound to " << fLocalDataAddr; + + AnnounceDataAddress(); + + { + std::unique_lock lk(m); + cv.wait(lk, [&](){ return completed; }); + } + LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; } auto Socket::ConnectControlSocket(Context::Address address) -> void @@ -295,69 +309,31 @@ auto Socket::ReceiveControlMessage() -> CtrlMsgPtr } } -// auto Socket::WaitForControlPeer() -> void -// { - // assert(fWaitingForControlPeer); -// - // First frame in message contains event number and value - // zmq_msg_t msg; - // zmq_msg_init(&msg); - // if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) - // throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); -// - // uint8_t* data = (uint8_t*) zmq_msg_data(&msg); - // uint16_t event = *(uint16_t*)(data); - // int value = *(uint32_t *)(data + 2); -// - // Second frame in message contains event address - // zmq_msg_init(&msg); - // if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) - // throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); -// - // if (event == ZMQ_EVENT_ACCEPTED) { - // string localAddress = string(static_cast(zmq_msg_data(&msg)), zmq_msg_size(&msg)); - // sockaddr_in remoteAddr; - // socklen_t addrSize = sizeof(sockaddr_in); - // int ret = getpeername(value, (sockaddr*)&remoteAddr, &addrSize); - // if (ret != 0) - // throw SocketError(tools::ToString("Failed retrieving remote address, reason: ", strerror(errno))); - // string remoteIp(inet_ntoa(remoteAddr.sin_addr)); - // int remotePort = ntohs(remoteAddr.sin_port); - // LOG(debug) << "Accepted control peer connection from " << remoteIp << ":" << remotePort; - // } else if (event == ZMQ_EVENT_CONNECTED) { - // LOG(debug) << "Connected successfully to control peer"; - // } else { - // LOG(debug) << "Unknown monitor event received: " << event << ". Ignoring."; - // } -// - // fWaitingForControlPeer = false; -// } - auto Socket::Send(MessagePtr& msg, const int timeout) -> int { return SendImpl(msg, 0, timeout); } auto Socket::Receive(MessagePtr& msg, const int timeout) -> int { return ReceiveImpl(msg, 0, timeout); } auto Socket::Send(std::vector& msgVec, const int timeout) -> int64_t { return SendImpl(msgVec, 0, timeout); } auto Socket::Receive(std::vector& msgVec, const int timeout) -> int64_t { return ReceiveImpl(msgVec, 0, timeout); } -auto Socket::TrySend(MessagePtr& msg) -> int { return SendImpl(msg, ZMQ_DONTWAIT, 0); } -auto Socket::TryReceive(MessagePtr& msg) -> int { return ReceiveImpl(msg, ZMQ_DONTWAIT, 0); } -auto Socket::TrySend(std::vector& msgVec) -> int64_t { return SendImpl(msgVec, ZMQ_DONTWAIT, 0); } -auto Socket::TryReceive(std::vector& msgVec) -> int64_t { return ReceiveImpl(msgVec, ZMQ_DONTWAIT, 0); } - -#include -#include - auto Socket::SendImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int try { auto size = msg->GetSize(); - LOG(debug) << "OFI transport (" << fId << "): ENTER SendImpl"; + // LOG(debug) << "OFI transport (" << fId << "): ENTER SendImpl"; // Create and send control message auto pb = MakeControlMessage(&fCtrlMemPool); pb->size = size; SendControlMessage(StaticUniquePtrUpcast(std::move(pb))); - LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Control message sent, size=" << size; + // LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Control message sent, size=" << size; + // LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: msg->GetData()=" << msg->GetData() << ",msg->GetSize()=" << msg->GetSize(); if (size) { + // Receive ack + auto ack = StaticUniquePtrDowncast(ReceiveControlMessage()); + assert(ack.get()); + auto size_ack = ack->size; + assert(size == size_ack); + // LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Control ack received, size_ack=" << size_ack; + boost::asio::mutable_buffer buffer(msg->GetData(), size); asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send); @@ -374,7 +350,7 @@ try { completed = true; } cv.notify_one(); - LOG(debug) << "OFI transport (" << fId << "): > SendImpl: Data buffer sent"; + // LOG(debug) << "OFI transport (" << fId << "): > SendImpl: Data buffer sent"; } ); @@ -382,14 +358,14 @@ try { std::unique_lock lk(m); cv.wait(lk, [&](){ return completed; }); } - LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Data send buffer posted"; + // LOG(debug) << "OFI transport (" << fId << "): >>>>> SendImpl: Data send buffer posted"; } msg.reset(nullptr); fBytesTx += size; fMessagesTx++; - LOG(debug) << "OFI transport (" << fId << "): LEAVE SendImpl"; + // LOG(debug) << "OFI transport (" << fId << "): LEAVE SendImpl"; return size; } catch (const SilentSocketError& e) @@ -404,12 +380,12 @@ catch (const std::exception& e) auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int try { - LOG(debug) << "OFI transport (" << fId << "): ENTER ReceiveImpl"; + // LOG(debug) << "OFI transport (" << fId << "): ENTER ReceiveImpl"; // Receive and process control message auto pb = StaticUniquePtrDowncast(ReceiveControlMessage()); assert(pb.get()); auto size = pb->size; - LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control message received, size=" << size; + // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control message received, size=" << size; // Receive data if (size) { @@ -429,19 +405,24 @@ try { cv.notify_one(); } ); - LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted"; + // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted"; + + auto ack = MakeControlMessage(&fCtrlMemPool); + ack->size = size; + SendControlMessage(StaticUniquePtrUpcast(std::move(ack))); + // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control Ack sent"; { std::unique_lock lk(m); cv.wait(lk, [&](){ return completed; }); } - LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data received"; + // LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data received"; } fBytesRx += size; fMessagesRx++; - LOG(debug) << "OFI transport (" << fId << "): EXIT ReceiveImpl"; + // LOG(debug) << "OFI transport (" << fId << "): EXIT ReceiveImpl"; return size; } catch (const SilentSocketError& e) @@ -658,6 +639,14 @@ int Socket::GetLinger() const return value; } +void Socket::SetLinger(const int value) +{ + if (zmq_setsockopt(fControlSocket, ZMQ_LINGER, &value, sizeof(value)) < 0) { + throw SocketError(tools::ToString("failed setting ZMQ_LINGER, reason: ", zmq_strerror(errno))); + } +} + + void Socket::SetSndBufSize(const int value) { if (zmq_setsockopt(fControlSocket, ZMQ_SNDHWM, &value, sizeof(value)) < 0) { diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index affab53e..1f1f76d1 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -51,11 +51,6 @@ class Socket final : public fair::mq::Socket auto Send(std::vector& msgVec, int timeout = 0) -> int64_t override; auto Receive(std::vector& msgVec, int timeout = 0) -> int64_t override; - auto TrySend(MessagePtr& msg) -> int override; - auto TryReceive(MessagePtr& msg) -> int override; - auto TrySend(std::vector& msgVec) -> int64_t override; - auto TryReceive(std::vector& msgVec) -> int64_t override; - auto GetSocket() const -> void* { return fControlSocket; } void SetLinger(const int value) override; diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index 07331d24..8957f26f 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -34,12 +34,12 @@ catch (ContextError& e) throw TransportFactoryError{e.what()}; } -auto TransportFactory::CreateMessage() const -> MessagePtr +auto TransportFactory::CreateMessage() -> MessagePtr { return MessagePtr{new Message(&fMemoryResource)}; } -auto TransportFactory::CreateMessage(const size_t size) const -> MessagePtr +auto TransportFactory::CreateMessage(const size_t size) -> MessagePtr { return MessagePtr{new Message(&fMemoryResource, size)}; } @@ -47,7 +47,7 @@ auto TransportFactory::CreateMessage(const size_t size) const -> MessagePtr auto TransportFactory::CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, - void* hint) const -> MessagePtr + void* hint) -> MessagePtr { return MessagePtr{new Message(&fMemoryResource, data, size, ffn, hint)}; } @@ -55,7 +55,7 @@ auto TransportFactory::CreateMessage(void* data, auto TransportFactory::CreateMessage(UnmanagedRegionPtr& region, void* data, const size_t size, - void* hint) const -> MessagePtr + void* hint) -> MessagePtr { return MessagePtr{new Message(&fMemoryResource, region, data, size, hint)}; } @@ -71,7 +71,7 @@ auto TransportFactory::CreatePoller(const vector& channels) const // return PollerPtr{new Poller(channels)}; } -auto TransportFactory::CreatePoller(const vector& channels) const -> PollerPtr +auto TransportFactory::CreatePoller(const vector& channels) const -> PollerPtr { throw runtime_error{"Not yet implemented (Poller)."}; // return PollerPtr{new Poller(channels)}; diff --git a/fairmq/ofi/TransportFactory.h b/fairmq/ofi/TransportFactory.h index b69eaab1..3a3a8d45 100644 --- a/fairmq/ofi/TransportFactory.h +++ b/fairmq/ofi/TransportFactory.h @@ -35,15 +35,15 @@ class TransportFactory final : public FairMQTransportFactory TransportFactory(const TransportFactory&) = delete; TransportFactory operator=(const TransportFactory&) = delete; - auto CreateMessage() const -> MessagePtr override; - auto CreateMessage(const std::size_t size) const -> MessagePtr override; - auto CreateMessage(void* data, const std::size_t size, fairmq_free_fn* ffn, void* hint = nullptr) const -> MessagePtr override; - auto CreateMessage(UnmanagedRegionPtr& region, void* data, const std::size_t size, void* hint = nullptr) const -> MessagePtr override; + auto CreateMessage() -> MessagePtr override; + auto CreateMessage(const std::size_t size) -> MessagePtr override; + auto CreateMessage(void* data, const std::size_t size, fairmq_free_fn* ffn, void* hint = nullptr) -> MessagePtr override; + auto CreateMessage(UnmanagedRegionPtr& region, void* data, const std::size_t size, void* hint = nullptr) -> MessagePtr override; auto CreateSocket(const std::string& type, const std::string& name) -> SocketPtr override; auto CreatePoller(const std::vector& channels) const -> PollerPtr override; - auto CreatePoller(const std::vector& channels) const -> PollerPtr override; + auto CreatePoller(const std::vector& channels) const -> PollerPtr override; auto CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const -> PollerPtr override; auto CreateUnmanagedRegion(const size_t size, FairMQRegionCallback callback = nullptr) const -> UnmanagedRegionPtr override;