Do not share ofi context across sockets

This commit is contained in:
Dennis Klein 2018-11-26 02:07:48 +01:00 committed by Dennis Klein
parent b31ab1cc48
commit a08a34acd5
4 changed files with 64 additions and 52 deletions

View File

@ -34,10 +34,7 @@ using namespace std;
Context::Context(FairMQTransportFactory& sendFactory, Context::Context(FairMQTransportFactory& sendFactory,
FairMQTransportFactory& receiveFactory, FairMQTransportFactory& receiveFactory,
int numberIoThreads) int numberIoThreads)
: fOfiInfo(nullptr) : fIoWork(fIoContext)
, fOfiFabric(nullptr)
, fOfiDomain(nullptr)
, fIoWork(fIoContext)
, fReceiveFactory(receiveFactory) , fReceiveFactory(receiveFactory)
, fSendFactory(sendFactory) , fSendFactory(sendFactory)
{ {
@ -69,27 +66,6 @@ auto Context::GetAsiofiVersion() const -> string
return ASIOFI_VERSION; return ASIOFI_VERSION;
} }
/**
 * Create the OFI info, fabric and domain objects held by this context.
 *
 * Only the first call does any work: once fOfiInfo is set, later calls
 * return immediately. The provider hint is "sockets" for the "tcp" protocol
 * and "verbs" for the "verbs" protocol; no provider is set for any other
 * protocol string.
 *
 * @param addr address whose protocol, IP and port are used for the lookup
 */
auto Context::InitOfi(Address addr) -> void
{
    // Already initialized by an earlier call - nothing to do.
    if (fOfiInfo) {
        return;
    }

    // Fabric and domain are only ever created together with the info object.
    assert(!fOfiFabric);
    assert(!fOfiDomain);

    asiofi::hints hints;
    if (addr.Protocol == "tcp") {
        hints.set_provider("sockets");
    } else if (addr.Protocol == "verbs") {
        hints.set_provider("verbs");
    }

    // The port is handed over as a decimal string.
    const auto service = std::to_string(addr.Port);
    fOfiInfo = tools::make_unique<asiofi::info>(addr.Ip.c_str(), service.c_str(), 0, hints);
    // LOG(debug) << "OFI transport: " << *fOfiInfo;

    // Each layer is built from the one created just before it.
    fOfiFabric = tools::make_unique<asiofi::fabric>(*fOfiInfo);
    fOfiDomain = tools::make_unique<asiofi::domain>(*fOfiFabric);
}
auto Context::ConvertAddress(std::string address) -> Address auto Context::ConvertAddress(std::string address) -> Address
{ {
string protocol, ip; string protocol, ip;

View File

@ -31,6 +31,22 @@ namespace mq
namespace ofi namespace ofi
{ {
/// Connection kind with two enumerators, Bind and Connect; the underlying type is bool.
/// NOTE(review): not referenced in the visible hunks - presumably distinguishes
/// the Socket::Bind path from the Socket::Connect path; confirm against callers.
enum class ConnectionType : bool { Bind, Connect };
/**
 * @struct Address
 * @brief Protocol / IP / port triple describing a transport address
 *
 * Plain aggregate; can still be brace-initialized as {protocol, ip, port}.
 */
struct Address {
    std::string Protocol;
    std::string Ip;
    // Defaulted so that a default-constructed Address never carries an
    // indeterminate port into operator== or operator<<.
    unsigned int Port = 0;

    /// Writes the address as "<Protocol>://<Ip>:<Port>".
    friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream&
    {
        return os << a.Protocol << "://" << a.Ip << ":" << a.Port;
    }

    /// Two addresses are equal when protocol, IP and port all match.
    friend auto operator==(const Address& lhs, const Address& rhs) -> bool
    {
        return (lhs.Protocol == rhs.Protocol) && (lhs.Ip == rhs.Ip) && (lhs.Port == rhs.Port);
    }
};
/** /**
* @class Context Context.h <fairmq/ofi/Context.h> * @class Context Context.h <fairmq/ofi/Context.h>
* @brief Transport-wide context * @brief Transport-wide context
@ -47,16 +63,6 @@ class Context
auto GetAsiofiVersion() const -> std::string; auto GetAsiofiVersion() const -> std::string;
auto GetIoContext() -> boost::asio::io_context& { return fIoContext; } auto GetIoContext() -> boost::asio::io_context& { return fIoContext; }
/// Protocol / IP / port triple describing a transport address.
struct Address {
    std::string Protocol;
    std::string Ip;
    // Defaulted so that a default-constructed Address never streams an
    // indeterminate port value.
    unsigned int Port = 0;

    /// Writes the address as "<Protocol>://<Ip>:<Port>".
    friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream&
    {
        return os << a.Protocol << "://" << a.Ip << ":" << a.Port;
    }
};
/// Creates the OFI info, fabric and domain objects; does nothing once fOfiInfo is already set.
auto InitOfi(Address address) -> void;
/// Returns the OFI info object. Dereferences fOfiInfo without a null check, so InitOfi must have run first.
auto GetOfiInfo() const -> const asiofi::info& { return *fOfiInfo; }
/// Returns the OFI fabric object. Dereferences fOfiFabric without a null check, so InitOfi must have run first.
auto GetOfiFabric() const -> const asiofi::fabric& { return *fOfiFabric; }
/// Returns the OFI domain object. Dereferences fOfiDomain without a null check, so InitOfi must have run first.
auto GetOfiDomain() const -> const asiofi::domain& { return *fOfiDomain; }
static auto ConvertAddress(std::string address) -> Address; static auto ConvertAddress(std::string address) -> Address;
static auto ConvertAddress(Address address) -> sockaddr_in; static auto ConvertAddress(Address address) -> sockaddr_in;
static auto ConvertAddress(sockaddr_in address) -> Address; static auto ConvertAddress(sockaddr_in address) -> Address;
@ -67,9 +73,6 @@ class Context
auto MakeSendMessage(size_t size) -> MessagePtr; auto MakeSendMessage(size_t size) -> MessagePtr;
private: private:
std::unique_ptr<asiofi::info> fOfiInfo;
std::unique_ptr<asiofi::fabric> fOfiFabric;
std::unique_ptr<asiofi::domain> fOfiDomain;
boost::asio::io_context fIoContext; boost::asio::io_context fIoContext;
boost::asio::io_context::work fIoWork; boost::asio::io_context::work fIoWork;
std::vector<std::thread> fThreadPool; std::vector<std::thread> fThreadPool;

View File

@ -38,6 +38,9 @@ using namespace std;
Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/) Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/)
: fContext(context) : fContext(context)
, fOfiInfo(nullptr)
, fOfiFabric(nullptr)
, fOfiDomain(nullptr)
, fPassiveEndpoint(nullptr) , fPassiveEndpoint(nullptr)
, fDataEndpoint(nullptr) , fDataEndpoint(nullptr)
, fControlEndpoint(nullptr) , fControlEndpoint(nullptr)
@ -87,6 +90,32 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
} }
} }
/**
 * Create the OFI info, fabric and domain objects owned by this socket.
 *
 * Only the first call does any work: once fOfiInfo is set, later calls
 * return immediately. The provider hint is "sockets" for the "tcp" protocol
 * and "verbs" for the "verbs" protocol; no provider is set for any other
 * protocol string. The info object is requested with flags 0 when addr
 * equals fRemoteAddr and with FI_SOURCE otherwise.
 *
 * @param addr address whose protocol, IP and port are used for the lookup
 */
auto Socket::InitOfi(Address addr) -> void
{
    // Already initialized by an earlier call - nothing to do.
    if (fOfiInfo) {
        return;
    }

    // Fabric and domain are only ever created together with the info object.
    assert(!fOfiFabric);
    assert(!fOfiDomain);

    asiofi::hints hints;
    if (addr.Protocol == "tcp") {
        hints.set_provider("sockets");
    } else if (addr.Protocol == "verbs") {
        hints.set_provider("verbs");
    }

    // The port is handed over as a decimal string.
    const auto service = std::to_string(addr.Port);

    const bool isRemote = (fRemoteAddr == addr);
    if (isRemote) {
        fOfiInfo = tools::make_unique<asiofi::info>(addr.Ip.c_str(), service.c_str(), 0, hints);
    } else {
        fOfiInfo = tools::make_unique<asiofi::info>(addr.Ip.c_str(), service.c_str(), FI_SOURCE, hints);
    }
    // LOG(debug) << "OFI transport: " << *fOfiInfo;

    // Each layer is built from the one created just before it.
    fOfiFabric = tools::make_unique<asiofi::fabric>(*fOfiInfo);
    fOfiDomain = tools::make_unique<asiofi::domain>(*fOfiFabric);
}
auto Socket::Bind(const string& addr) -> bool auto Socket::Bind(const string& addr) -> bool
try { try {
fLocalAddr = Context::VerifyAddress(addr); fLocalAddr = Context::VerifyAddress(addr);
@ -94,15 +123,13 @@ try {
fNeedOfiMemoryRegistration = true; fNeedOfiMemoryRegistration = true;
} }
fContext.InitOfi(fLocalAddr); InitOfi(fLocalAddr);
fPassiveEndpoint = tools::make_unique<asiofi::passive_endpoint>(fIoStrand.context(), fContext.GetOfiFabric()); fPassiveEndpoint = tools::make_unique<asiofi::passive_endpoint>(fIoStrand.context(), *fOfiFabric);
fPassiveEndpoint->set_local_address(Context::ConvertAddress(fLocalAddr)); fPassiveEndpoint->set_local_address(Context::ConvertAddress(fLocalAddr));
BindControlEndpoint(); BindControlEndpoint();
BindDataEndpoint();
return true; return true;
} }
// TODO catch the correct ofi error // TODO catch the correct ofi error
@ -126,10 +153,12 @@ auto Socket::BindControlEndpoint() -> void
LOG(debug) << "OFI transport (" << fId LOG(debug) << "OFI transport (" << fId
<< "): control band connection request received. Accepting ..."; << "): control band connection request received. Accepting ...";
fControlEndpoint = tools::make_unique<asiofi::connected_endpoint>( fControlEndpoint = tools::make_unique<asiofi::connected_endpoint>(
fIoStrand.context(), fContext.GetOfiDomain(), info); fIoStrand.context(), *fOfiDomain, info);
fControlEndpoint->enable(); fControlEndpoint->enable();
fControlEndpoint->accept([&]() { fControlEndpoint->accept([&]() {
LOG(debug) << "OFI transport (" << fId << "): control band connection accepted."; LOG(debug) << "OFI transport (" << fId << "): control band connection accepted.";
BindDataEndpoint();
}); });
}); });
@ -144,7 +173,7 @@ auto Socket::BindDataEndpoint() -> void
LOG(debug) << "OFI transport (" << fId LOG(debug) << "OFI transport (" << fId
<< "): data band connection request received. Accepting ..."; << "): data band connection request received. Accepting ...";
fDataEndpoint = tools::make_unique<asiofi::connected_endpoint>( fDataEndpoint = tools::make_unique<asiofi::connected_endpoint>(
fIoStrand.context(), fContext.GetOfiDomain(), info); fIoStrand.context(), *fOfiDomain, info);
fDataEndpoint->enable(); fDataEndpoint->enable();
fDataEndpoint->accept([&]() { fDataEndpoint->accept([&]() {
LOG(debug) << "OFI transport (" << fId << "): data band connection accepted."; LOG(debug) << "OFI transport (" << fId << "): data band connection accepted.";
@ -164,7 +193,7 @@ auto Socket::Connect(const string& address) -> void
fNeedOfiMemoryRegistration = true; fNeedOfiMemoryRegistration = true;
} }
fContext.InitOfi(fRemoteAddr); InitOfi(fRemoteAddr);
ConnectControlEndpoint(); ConnectControlEndpoint();
@ -183,7 +212,7 @@ auto Socket::ConnectControlEndpoint() -> void
bool completed(false); bool completed(false);
fControlEndpoint = fControlEndpoint =
tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), fContext.GetOfiDomain()); tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), *fOfiDomain);
fControlEndpoint->enable(); fControlEndpoint->enable();
fControlEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() { fControlEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() {
@ -213,7 +242,7 @@ auto Socket::ConnectDataEndpoint() -> void
bool completed(false); bool completed(false);
fDataEndpoint = fDataEndpoint =
tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), fContext.GetOfiDomain()); tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), *fOfiDomain);
fDataEndpoint->enable(); fDataEndpoint->enable();
fDataEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() { fDataEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() {
@ -353,7 +382,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void
ctrl->size = size; ctrl->size = size;
auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer)); auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer));
if (fNeedOfiMemoryRegistration) { if (fNeedOfiMemoryRegistration) {
asiofi::memory_region mr(fContext.GetOfiDomain(), ctrl_msg, asiofi::mr::access::send); asiofi::memory_region mr(*fOfiDomain, ctrl_msg, asiofi::mr::access::send);
auto desc = mr.desc(); auto desc = mr.desc();
fControlEndpoint->send( fControlEndpoint->send(
ctrl_msg, desc, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable { ctrl_msg, desc, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable {
@ -371,7 +400,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void
boost::asio::mutable_buffer buffer(msg->GetData(), size); boost::asio::mutable_buffer buffer(msg->GetData(), size);
if (fNeedOfiMemoryRegistration) { if (fNeedOfiMemoryRegistration) {
asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::send); asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::send);
auto desc = mr.desc(); auto desc = mr.desc();
fDataEndpoint->send(buffer, fDataEndpoint->send(buffer,
@ -445,7 +474,7 @@ auto Socket::OnRecvControl(ofi::unique_ptr<PostBuffer> ctrl) -> void
boost::asio::mutable_buffer buffer(msg->GetData(), size); boost::asio::mutable_buffer buffer(msg->GetData(), size);
if (fNeedOfiMemoryRegistration) { if (fNeedOfiMemoryRegistration) {
asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::recv); asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::recv);
auto desc = mr.desc(); auto desc = mr.desc();
fDataEndpoint->recv( fDataEndpoint->recv(

View File

@ -82,6 +82,9 @@ class Socket final : public fair::mq::Socket
private: private:
Context& fContext; Context& fContext;
std::unique_ptr<asiofi::info> fOfiInfo;
std::unique_ptr<asiofi::fabric> fOfiFabric;
std::unique_ptr<asiofi::domain> fOfiDomain;
std::unique_ptr<asiofi::passive_endpoint> fPassiveEndpoint; std::unique_ptr<asiofi::passive_endpoint> fPassiveEndpoint;
std::unique_ptr<asiofi::connected_endpoint> fDataEndpoint, fControlEndpoint; std::unique_ptr<asiofi::connected_endpoint> fDataEndpoint, fControlEndpoint;
std::string fId; std::string fId;
@ -89,8 +92,8 @@ class Socket final : public fair::mq::Socket
std::atomic<unsigned long> fBytesRx; std::atomic<unsigned long> fBytesRx;
std::atomic<unsigned long> fMessagesTx; std::atomic<unsigned long> fMessagesTx;
std::atomic<unsigned long> fMessagesRx; std::atomic<unsigned long> fMessagesRx;
Context::Address fRemoteAddr; Address fRemoteAddr;
Context::Address fLocalAddr; Address fLocalAddr;
boost::asio::io_service::strand fIoStrand; boost::asio::io_service::strand fIoStrand;
int fSndTimeout; int fSndTimeout;
int fRcvTimeout; int fRcvTimeout;
@ -111,6 +114,7 @@ class Socket final : public fair::mq::Socket
// auto WaitForControlPeer() -> void; // auto WaitForControlPeer() -> void;
// auto AnnounceDataAddress() -> void; // auto AnnounceDataAddress() -> void;
auto InitOfi(Address addr) -> void;
auto BindControlEndpoint() -> void; auto BindControlEndpoint() -> void;
auto BindDataEndpoint() -> void; auto BindDataEndpoint() -> void;
auto ConnectControlEndpoint() -> void; auto ConnectControlEndpoint() -> void;