From a08a34acd59bccff1a257227fb6fe78d2352f758 Mon Sep 17 00:00:00 2001 From: Dennis Klein Date: Mon, 26 Nov 2018 02:07:48 +0100 Subject: [PATCH] Do not share ofi context across sockets --- fairmq/ofi/Context.cxx | 26 +-------------------- fairmq/ofi/Context.h | 29 ++++++++++++----------- fairmq/ofi/Socket.cxx | 53 ++++++++++++++++++++++++++++++++---------- fairmq/ofi/Socket.h | 8 +++++-- 4 files changed, 64 insertions(+), 52 deletions(-) diff --git a/fairmq/ofi/Context.cxx b/fairmq/ofi/Context.cxx index 83909d23..964af84e 100644 --- a/fairmq/ofi/Context.cxx +++ b/fairmq/ofi/Context.cxx @@ -34,10 +34,7 @@ using namespace std; Context::Context(FairMQTransportFactory& sendFactory, FairMQTransportFactory& receiveFactory, int numberIoThreads) - : fOfiInfo(nullptr) - , fOfiFabric(nullptr) - , fOfiDomain(nullptr) - , fIoWork(fIoContext) + : fIoWork(fIoContext) , fReceiveFactory(receiveFactory) , fSendFactory(sendFactory) { @@ -69,27 +66,6 @@ auto Context::GetAsiofiVersion() const -> string return ASIOFI_VERSION; } -auto Context::InitOfi(Address addr) -> void -{ - if (!fOfiInfo) { - assert(!fOfiFabric); - assert(!fOfiDomain); - - asiofi::hints hints; - if (addr.Protocol == "tcp") { - hints.set_provider("sockets"); - } else if (addr.Protocol == "verbs") { - hints.set_provider("verbs"); - } - fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), 0, hints); - // LOG(debug) << "OFI transport: " << *fOfiInfo; - - fOfiFabric = tools::make_unique(*fOfiInfo); - - fOfiDomain = tools::make_unique(*fOfiFabric); - } -} - auto Context::ConvertAddress(std::string address) -> Address { string protocol, ip; diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index 3dcea79a..90a91d8c 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -31,6 +31,22 @@ namespace mq namespace ofi { +enum class ConnectionType : bool { Bind, Connect }; + +struct Address { + std::string Protocol; + std::string Ip; + unsigned int Port; + friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream& + { + return os << a.Protocol << "://" << a.Ip << ":" << a.Port; + } + friend auto operator==(const Address& lhs, const Address& rhs) -> bool + { + return (lhs.Protocol == rhs.Protocol) && (lhs.Ip == rhs.Ip) && (lhs.Port == rhs.Port); + } +}; + /** * @class Context Context.h * @brief Transport-wide context @@ -47,16 +63,6 @@ class Context auto GetAsiofiVersion() const -> std::string; auto GetIoContext() -> boost::asio::io_context& { return fIoContext; } - struct Address { - std::string Protocol; - std::string Ip; - unsigned int Port; - friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream& { return os << a.Protocol << "://" << a.Ip << ":" << a.Port; } - }; - auto InitOfi(Address address) -> void; - auto GetOfiInfo() const -> const asiofi::info& { return *fOfiInfo; } - auto GetOfiFabric() const -> const asiofi::fabric& { return *fOfiFabric; } - auto GetOfiDomain() const -> const asiofi::domain& { return *fOfiDomain; } static auto ConvertAddress(std::string address) -> Address; static auto ConvertAddress(Address address) -> sockaddr_in; static auto ConvertAddress(sockaddr_in address) -> Address; @@ -67,9 +73,6 @@ class Context auto MakeSendMessage(size_t size) -> MessagePtr; private: - std::unique_ptr fOfiInfo; - std::unique_ptr fOfiFabric; - std::unique_ptr fOfiDomain; boost::asio::io_context fIoContext; boost::asio::io_context::work fIoWork; std::vector fThreadPool; diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 1cc37764..a28a9422 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -38,6 +38,9 @@ using namespace std; Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/) : fContext(context) + , fOfiInfo(nullptr) + , fOfiFabric(nullptr) + , fOfiDomain(nullptr) , fPassiveEndpoint(nullptr) , fDataEndpoint(nullptr) , fControlEndpoint(nullptr) @@ -87,6 +90,32 @@ Socket::Socket(Context& context, const string& type, const string& name, const s } } +auto Socket::InitOfi(Address addr) -> void +{ + if (!fOfiInfo) { + assert(!fOfiFabric); + assert(!fOfiDomain); + + asiofi::hints hints; + if (addr.Protocol == "tcp") { + hints.set_provider("sockets"); + } else if (addr.Protocol == "verbs") { + hints.set_provider("verbs"); + } + if (fRemoteAddr == addr) { + fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), 0, hints); + } else { + fOfiInfo = tools::make_unique(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), FI_SOURCE, hints); + } + + // LOG(debug) << "OFI transport: " << *fOfiInfo; + + fOfiFabric = tools::make_unique(*fOfiInfo); + + fOfiDomain = tools::make_unique(*fOfiFabric); + } +} + auto Socket::Bind(const string& addr) -> bool try { fLocalAddr = Context::VerifyAddress(addr); @@ -94,15 +123,13 @@ try { fNeedOfiMemoryRegistration = true; } - fContext.InitOfi(fLocalAddr); + InitOfi(fLocalAddr); - fPassiveEndpoint = tools::make_unique(fIoStrand.context(), fContext.GetOfiFabric()); + fPassiveEndpoint = tools::make_unique(fIoStrand.context(), *fOfiFabric); fPassiveEndpoint->set_local_address(Context::ConvertAddress(fLocalAddr)); BindControlEndpoint(); - BindDataEndpoint(); - return true; } // TODO catch the correct ofi error @@ -126,10 +153,12 @@ auto Socket::BindControlEndpoint() -> void LOG(debug) << "OFI transport (" << fId << "): control band connection request received. Accepting ..."; fControlEndpoint = tools::make_unique( - fIoStrand.context(), fContext.GetOfiDomain(), info); + fIoStrand.context(), *fOfiDomain, info); fControlEndpoint->enable(); fControlEndpoint->accept([&]() { LOG(debug) << "OFI transport (" << fId << "): control band connection accepted."; + + BindDataEndpoint(); }); }); @@ -144,7 +173,7 @@ auto Socket::BindDataEndpoint() -> void LOG(debug) << "OFI transport (" << fId << "): data band connection request received. Accepting ..."; fDataEndpoint = tools::make_unique( - fIoStrand.context(), fContext.GetOfiDomain(), info); + fIoStrand.context(), *fOfiDomain, info); fDataEndpoint->enable(); fDataEndpoint->accept([&]() { LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; @@ -164,7 +193,7 @@ auto Socket::Connect(const string& address) -> void fNeedOfiMemoryRegistration = true; } - fContext.InitOfi(fRemoteAddr); + InitOfi(fRemoteAddr); ConnectControlEndpoint(); @@ -183,7 +212,7 @@ auto Socket::ConnectControlEndpoint() -> void bool completed(false); fControlEndpoint = - tools::make_unique(fIoStrand.context(), fContext.GetOfiDomain()); + tools::make_unique(fIoStrand.context(), *fOfiDomain); fControlEndpoint->enable(); fControlEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() { @@ -213,7 +242,7 @@ auto Socket::ConnectDataEndpoint() -> void bool completed(false); fDataEndpoint = - tools::make_unique(fIoStrand.context(), fContext.GetOfiDomain()); + tools::make_unique(fIoStrand.context(), *fOfiDomain); fDataEndpoint->enable(); fDataEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() { @@ -353,7 +382,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void ctrl->size = size; auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer)); if (fNeedOfiMemoryRegistration) { - asiofi::memory_region mr(fContext.GetOfiDomain(), ctrl_msg, asiofi::mr::access::send); + asiofi::memory_region mr(*fOfiDomain, ctrl_msg, asiofi::mr::access::send); auto desc = mr.desc(); fControlEndpoint->send( ctrl_msg, desc, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { @@ -371,7 +400,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void boost::asio::mutable_buffer buffer(msg->GetData(), size); if (fNeedOfiMemoryRegistration) { - asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::send); + asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::send); auto desc = mr.desc(); fDataEndpoint->send(buffer, @@ -445,7 +474,7 @@ auto Socket::OnRecvControl(ofi::unique_ptr ctrl) -> void boost::asio::mutable_buffer buffer(msg->GetData(), size); if (fNeedOfiMemoryRegistration) { - asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::recv); + asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::recv); auto desc = mr.desc(); fDataEndpoint->recv( diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 01f22280..549d5f50 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -82,6 +82,9 @@ class Socket final : public fair::mq::Socket private: Context& fContext; + std::unique_ptr fOfiInfo; + std::unique_ptr fOfiFabric; + std::unique_ptr fOfiDomain; std::unique_ptr fPassiveEndpoint; std::unique_ptr fDataEndpoint, fControlEndpoint; std::string fId; @@ -89,8 +92,8 @@ class Socket final : public fair::mq::Socket std::atomic fBytesRx; std::atomic fMessagesTx; std::atomic fMessagesRx; - Context::Address fRemoteAddr; - Context::Address fLocalAddr; + Address fRemoteAddr; + Address fLocalAddr; boost::asio::io_service::strand fIoStrand; int fSndTimeout; int fRcvTimeout; @@ -111,6 +114,7 @@ class Socket final : public fair::mq::Socket // auto WaitForControlPeer() -> void; // auto AnnounceDataAddress() -> void; + auto InitOfi(Address addr) -> void; auto BindControlEndpoint() -> void; auto BindDataEndpoint() -> void; auto ConnectControlEndpoint() -> void;