Do not share ofi context across sockets

This commit is contained in:
Dennis Klein 2018-11-26 02:07:48 +01:00 committed by Dennis Klein
parent b31ab1cc48
commit a08a34acd5
4 changed files with 64 additions and 52 deletions

View File

@ -34,10 +34,7 @@ using namespace std;
Context::Context(FairMQTransportFactory& sendFactory,
FairMQTransportFactory& receiveFactory,
int numberIoThreads)
: fOfiInfo(nullptr)
, fOfiFabric(nullptr)
, fOfiDomain(nullptr)
, fIoWork(fIoContext)
: fIoWork(fIoContext)
, fReceiveFactory(receiveFactory)
, fSendFactory(sendFactory)
{
@ -69,27 +66,6 @@ auto Context::GetAsiofiVersion() const -> string
return ASIOFI_VERSION;
}
/**
 * Lazily creates the OFI info, fabric and domain objects for the given address.
 *
 * Does nothing once fOfiInfo is set. The provider hint is derived from
 * addr.Protocol: "tcp" selects the "sockets" provider, "verbs" selects the
 * "verbs" provider, and any other protocol leaves the provider hint unset.
 *
 * All three objects are built in locals and stored in the members only after
 * every one of them exists. If the members were assigned one by one and a later
 * construction failed with an exception (NOTE(review): asiofi is not visible
 * here; presumably its constructors throw on error), fOfiInfo would stay set
 * while fOfiFabric/fOfiDomain remained null, and every subsequent call would
 * return early without ever completing the initialization.
 *
 * @param addr address whose Ip and Port are passed on as node and service
 */
auto Context::InitOfi(Address addr) -> void
{
    if (fOfiInfo) {
        return;
    }
    assert(!fOfiFabric);
    assert(!fOfiDomain);

    asiofi::hints hints;
    if (addr.Protocol == "tcp") {
        hints.set_provider("sockets");
    } else if (addr.Protocol == "verbs") {
        hints.set_provider("verbs");
    }

    auto info = tools::make_unique<asiofi::info>(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), 0, hints);
    // LOG(debug) << "OFI transport: " << *info;
    auto fabric = tools::make_unique<asiofi::fabric>(*info);
    auto domain = tools::make_unique<asiofi::domain>(*fabric);

    // Commit point: unique_ptr move assignment cannot throw.
    fOfiInfo = std::move(info);
    fOfiFabric = std::move(fabric);
    fOfiDomain = std::move(domain);
}
auto Context::ConvertAddress(std::string address) -> Address
{
string protocol, ip;

View File

@ -31,6 +31,22 @@ namespace mq
namespace ofi
{
/// Distinguishes the two connection modes, Bind and Connect (bool underlying type).
/// NOTE(review): presumably mirrors Socket::Bind / Socket::Connect; no use is visible here.
enum class ConnectionType : bool { Bind, Connect };
/**
 * Endpoint address split into protocol, IP and port.
 *
 * Remains an aggregate (from C++14 on), so brace initialization such as
 * Address{"tcp", "127.0.0.1", 5555} keeps working.
 */
struct Address {
    std::string Protocol;
    std::string Ip;
    // Zero-initialized so that comparing or printing a default-constructed
    // Address never reads an indeterminate value.
    unsigned int Port = 0;

    /// Writes the address as "<Protocol>://<Ip>:<Port>".
    friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream&
    {
        return os << a.Protocol << "://" << a.Ip << ":" << a.Port;
    }

    /// Two addresses are equal when protocol, IP and port all match.
    friend auto operator==(const Address& lhs, const Address& rhs) -> bool
    {
        return (lhs.Protocol == rhs.Protocol) && (lhs.Ip == rhs.Ip) && (lhs.Port == rhs.Port);
    }

    /// Negation of operator==.
    friend auto operator!=(const Address& lhs, const Address& rhs) -> bool
    {
        return !(lhs == rhs);
    }
};
/**
* @class Context Context.h <fairmq/ofi/Context.h>
* @brief Transport-wide context
@ -47,16 +63,6 @@ class Context
auto GetAsiofiVersion() const -> std::string;
/// Returns a mutable reference to the fIoContext member.
auto GetIoContext() -> boost::asio::io_context& { return fIoContext; }
/// Endpoint address split into protocol, IP and port.
struct Address {
    std::string Protocol;
    std::string Ip;
    unsigned int Port;

    /// Streams the address in the form "<Protocol>://<Ip>:<Port>".
    friend auto operator<<(std::ostream& out, const Address& address) -> std::ostream&
    {
        out << address.Protocol << "://" << address.Ip << ":" << address.Port;
        return out;
    }
};
auto InitOfi(Address address) -> void;
/// Returns the asiofi::info owned by fOfiInfo. The pointer is dereferenced
/// without a check, so it must already be set (InitOfi assigns it).
auto GetOfiInfo() const -> const asiofi::info& { return *fOfiInfo; }
/// Returns the asiofi::fabric owned by fOfiFabric. The pointer is dereferenced
/// without a check, so it must already be set (InitOfi assigns it).
auto GetOfiFabric() const -> const asiofi::fabric& { return *fOfiFabric; }
/// Returns the asiofi::domain owned by fOfiDomain. The pointer is dereferenced
/// without a check, so it must already be set (InitOfi assigns it).
auto GetOfiDomain() const -> const asiofi::domain& { return *fOfiDomain; }
static auto ConvertAddress(std::string address) -> Address;
static auto ConvertAddress(Address address) -> sockaddr_in;
static auto ConvertAddress(sockaddr_in address) -> Address;
@ -67,9 +73,6 @@ class Context
auto MakeSendMessage(size_t size) -> MessagePtr;
private:
std::unique_ptr<asiofi::info> fOfiInfo;
std::unique_ptr<asiofi::fabric> fOfiFabric;
std::unique_ptr<asiofi::domain> fOfiDomain;
boost::asio::io_context fIoContext;
boost::asio::io_context::work fIoWork;
std::vector<std::thread> fThreadPool;

View File

@ -38,6 +38,9 @@ using namespace std;
Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/)
: fContext(context)
, fOfiInfo(nullptr)
, fOfiFabric(nullptr)
, fOfiDomain(nullptr)
, fPassiveEndpoint(nullptr)
, fDataEndpoint(nullptr)
, fControlEndpoint(nullptr)
@ -87,6 +90,32 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
}
}
/**
 * Lazily creates this socket's OFI info, fabric and domain objects.
 *
 * Does nothing once fOfiInfo is set. The provider hint is derived from
 * addr.Protocol: "tcp" selects the "sockets" provider, "verbs" selects the
 * "verbs" provider, and any other protocol leaves the provider hint unset.
 *
 * When addr equals fRemoteAddr the info is requested with flags 0; for any
 * other address FI_SOURCE is passed (per fi_getinfo, FI_SOURCE marks the
 * node/service pair as the local source address).
 *
 * All three objects are built in locals and stored in the members only after
 * every one of them exists. If the members were assigned one by one and a later
 * construction failed with an exception (NOTE(review): asiofi is not visible
 * here; presumably its constructors throw on error), fOfiInfo would stay set
 * while fOfiFabric/fOfiDomain remained null; a later call would then return
 * early and callers would dereference the null fabric/domain pointers.
 *
 * @param addr address whose Ip and Port are passed on as node and service
 */
auto Socket::InitOfi(Address addr) -> void
{
    if (fOfiInfo) {
        return;
    }
    assert(!fOfiFabric);
    assert(!fOfiDomain);

    asiofi::hints hints;
    if (addr.Protocol == "tcp") {
        hints.set_provider("sockets");
    } else if (addr.Protocol == "verbs") {
        hints.set_provider("verbs");
    }

    // Flags 0 for the remote address, FI_SOURCE for any other address.
    const auto flags = (fRemoteAddr == addr) ? 0 : FI_SOURCE;

    auto info = tools::make_unique<asiofi::info>(addr.Ip.c_str(), std::to_string(addr.Port).c_str(), flags, hints);
    // LOG(debug) << "OFI transport: " << *info;
    auto fabric = tools::make_unique<asiofi::fabric>(*info);
    auto domain = tools::make_unique<asiofi::domain>(*fabric);

    // Commit point: unique_ptr move assignment cannot throw.
    fOfiInfo = std::move(info);
    fOfiFabric = std::move(fabric);
    fOfiDomain = std::move(domain);
}
auto Socket::Bind(const string& addr) -> bool
try {
fLocalAddr = Context::VerifyAddress(addr);
@ -94,15 +123,13 @@ try {
fNeedOfiMemoryRegistration = true;
}
fContext.InitOfi(fLocalAddr);
InitOfi(fLocalAddr);
fPassiveEndpoint = tools::make_unique<asiofi::passive_endpoint>(fIoStrand.context(), fContext.GetOfiFabric());
fPassiveEndpoint = tools::make_unique<asiofi::passive_endpoint>(fIoStrand.context(), *fOfiFabric);
fPassiveEndpoint->set_local_address(Context::ConvertAddress(fLocalAddr));
BindControlEndpoint();
BindDataEndpoint();
return true;
}
// TODO catch the correct ofi error
@ -126,10 +153,12 @@ auto Socket::BindControlEndpoint() -> void
LOG(debug) << "OFI transport (" << fId
<< "): control band connection request received. Accepting ...";
fControlEndpoint = tools::make_unique<asiofi::connected_endpoint>(
fIoStrand.context(), fContext.GetOfiDomain(), info);
fIoStrand.context(), *fOfiDomain, info);
fControlEndpoint->enable();
fControlEndpoint->accept([&]() {
LOG(debug) << "OFI transport (" << fId << "): control band connection accepted.";
BindDataEndpoint();
});
});
@ -144,7 +173,7 @@ auto Socket::BindDataEndpoint() -> void
LOG(debug) << "OFI transport (" << fId
<< "): data band connection request received. Accepting ...";
fDataEndpoint = tools::make_unique<asiofi::connected_endpoint>(
fIoStrand.context(), fContext.GetOfiDomain(), info);
fIoStrand.context(), *fOfiDomain, info);
fDataEndpoint->enable();
fDataEndpoint->accept([&]() {
LOG(debug) << "OFI transport (" << fId << "): data band connection accepted.";
@ -164,7 +193,7 @@ auto Socket::Connect(const string& address) -> void
fNeedOfiMemoryRegistration = true;
}
fContext.InitOfi(fRemoteAddr);
InitOfi(fRemoteAddr);
ConnectControlEndpoint();
@ -183,7 +212,7 @@ auto Socket::ConnectControlEndpoint() -> void
bool completed(false);
fControlEndpoint =
tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), fContext.GetOfiDomain());
tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), *fOfiDomain);
fControlEndpoint->enable();
fControlEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() {
@ -213,7 +242,7 @@ auto Socket::ConnectDataEndpoint() -> void
bool completed(false);
fDataEndpoint =
tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), fContext.GetOfiDomain());
tools::make_unique<asiofi::connected_endpoint>(fIoStrand.context(), *fOfiDomain);
fDataEndpoint->enable();
fDataEndpoint->connect(Context::ConvertAddress(fRemoteAddr), [&]() {
@ -353,7 +382,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void
ctrl->size = size;
auto ctrl_msg = boost::asio::mutable_buffer(ctrl.get(), sizeof(PostBuffer));
if (fNeedOfiMemoryRegistration) {
asiofi::memory_region mr(fContext.GetOfiDomain(), ctrl_msg, asiofi::mr::access::send);
asiofi::memory_region mr(*fOfiDomain, ctrl_msg, asiofi::mr::access::send);
auto desc = mr.desc();
fControlEndpoint->send(
ctrl_msg, desc, [&, ctrl2 = std::move(ctrl)](boost::asio::mutable_buffer) mutable {
@ -371,7 +400,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t /*bytes_transferred*/) -> void
boost::asio::mutable_buffer buffer(msg->GetData(), size);
if (fNeedOfiMemoryRegistration) {
asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::send);
asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::send);
auto desc = mr.desc();
fDataEndpoint->send(buffer,
@ -445,7 +474,7 @@ auto Socket::OnRecvControl(ofi::unique_ptr<PostBuffer> ctrl) -> void
boost::asio::mutable_buffer buffer(msg->GetData(), size);
if (fNeedOfiMemoryRegistration) {
asiofi::memory_region mr(fContext.GetOfiDomain(), buffer, asiofi::mr::access::recv);
asiofi::memory_region mr(*fOfiDomain, buffer, asiofi::mr::access::recv);
auto desc = mr.desc();
fDataEndpoint->recv(

View File

@ -82,6 +82,9 @@ class Socket final : public fair::mq::Socket
private:
Context& fContext;
std::unique_ptr<asiofi::info> fOfiInfo;
std::unique_ptr<asiofi::fabric> fOfiFabric;
std::unique_ptr<asiofi::domain> fOfiDomain;
std::unique_ptr<asiofi::passive_endpoint> fPassiveEndpoint;
std::unique_ptr<asiofi::connected_endpoint> fDataEndpoint, fControlEndpoint;
std::string fId;
@ -89,8 +92,8 @@ class Socket final : public fair::mq::Socket
std::atomic<unsigned long> fBytesRx;
std::atomic<unsigned long> fMessagesTx;
std::atomic<unsigned long> fMessagesRx;
Context::Address fRemoteAddr;
Context::Address fLocalAddr;
Address fRemoteAddr;
Address fLocalAddr;
boost::asio::io_service::strand fIoStrand;
int fSndTimeout;
int fRcvTimeout;
@ -111,6 +114,7 @@ class Socket final : public fair::mq::Socket
// auto WaitForControlPeer() -> void;
// auto AnnounceDataAddress() -> void;
auto InitOfi(Address addr) -> void;
auto BindControlEndpoint() -> void;
auto BindDataEndpoint() -> void;
auto ConnectControlEndpoint() -> void;