FairMQ: Implement ofi address exchange

Control messages are encoded with protobuf.
This commit is contained in:
Dennis Klein 2018-03-05 23:58:31 +01:00 committed by Mohammad Al-Turany
parent df5d5d4086
commit 5b3a5b9709
7 changed files with 323 additions and 117 deletions

View File

@ -199,10 +199,28 @@ configure_file(${CMAKE_SOURCE_DIR}/fairmq/options/startConfigExample.sh.in
${CMAKE_BINARY_DIR}/bin/startConfigExample.sh) ${CMAKE_BINARY_DIR}/bin/startConfigExample.sh)
########################
# compile protobuffers #
########################
add_custom_target(mkofibuilddir COMMAND ${CMAKE_COMMAND} -E make_directory ${CMAKE_CURRENT_BINARY_DIR}/ofi)
add_custom_command(
OUTPUT
${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h
${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} -I=${CMAKE_CURRENT_SOURCE_DIR}/ofi --cpp_out=${CMAKE_CURRENT_BINARY_DIR}/ofi Control.proto
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS mkofibuilddir
)
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h PROPERTIES GENERATED TRUE)
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc PROPERTIES GENERATED TRUE)
################################# #################################
# define libFairMQ build target # # define libFairMQ build target #
################################# #################################
add_library(FairMQ SHARED add_library(FairMQ SHARED
${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h
${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc
${FAIRMQ_SOURCE_FILES} ${FAIRMQ_SOURCE_FILES}
${FAIRMQ_HEADER_FILES} # for IDE integration ${FAIRMQ_HEADER_FILES} # for IDE integration
) )
@ -216,6 +234,7 @@ target_include_directories(FairMQ
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/logger> $<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/logger>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}> $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}> $<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}>
$<BUILD_INTERFACE:${CMAKE_BINARY_DIR}>
$<INSTALL_INTERFACE:include/fairmq> $<INSTALL_INTERFACE:include/fairmq>
$<INSTALL_INTERFACE:include> $<INSTALL_INTERFACE:include>
) )
@ -243,6 +262,7 @@ target_link_libraries(FairMQ
PRIVATE # only libFairMQ links against private dependencies PRIVATE # only libFairMQ links against private dependencies
ZeroMQ ZeroMQ
OFI::libfabric OFI::libfabric
protobuf::libprotobuf
Msgpack Msgpack
$<$<BOOL:${NANOMSG_FOUND}>:nanomsg> $<$<BOOL:${NANOMSG_FOUND}>:nanomsg>
) )

View File

@ -12,6 +12,7 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <cstring> #include <cstring>
#include <google/protobuf/stubs/common.h>
#include <memory> #include <memory>
#include <netinet/in.h> #include <netinet/in.h>
#include <rdma/fabric.h> #include <rdma/fabric.h>
@ -36,10 +37,15 @@ using namespace std;
Context::Context(int numberIoThreads) Context::Context(int numberIoThreads)
: fOfiDomain(nullptr) : fOfiDomain(nullptr)
, fOfiFabric(nullptr) , fOfiFabric(nullptr)
, fOfiInfo(nullptr)
, fOfiAddressVector(nullptr)
, fOfiEventQueue(nullptr)
, fZmqContext(zmq_ctx_new()) , fZmqContext(zmq_ctx_new())
{ {
if (!fZmqContext) if (!fZmqContext)
throw ContextError{tools::ToString("Failed creating zmq context, reason: ", zmq_strerror(errno))}; throw ContextError{tools::ToString("Failed creating zmq context, reason: ", zmq_strerror(errno))};
GOOGLE_PROTOBUF_VERIFY_VERSION;
} }
Context::~Context() Context::~Context()
@ -48,6 +54,12 @@ Context::~Context()
LOG(error) << "Failed closing zmq context, reason: " << zmq_strerror(errno); LOG(error) << "Failed closing zmq context, reason: " << zmq_strerror(errno);
} }
if (fOfiEventQueue) {
auto ret = fi_close(&fOfiEventQueue->fid);
if (ret != FI_SUCCESS)
LOG(error) << "Failed closing ofi event queue, reason: " << fi_strerror(ret);
}
if (fOfiAddressVector) { if (fOfiAddressVector) {
auto ret = fi_close(&fOfiAddressVector->fid); auto ret = fi_close(&fOfiAddressVector->fid);
if (ret != FI_SUCCESS) if (ret != FI_SUCCESS)
@ -67,19 +79,24 @@ Context::~Context()
} }
} }
auto Context::GetZmqVersion() const -> std::string auto Context::GetZmqVersion() const -> string
{ {
int major, minor, patch; int major, minor, patch;
zmq_version(&major, &minor, &patch); zmq_version(&major, &minor, &patch);
return tools::ToString(major, ".", minor, ".", patch); return tools::ToString(major, ".", minor, ".", patch);
} }
auto Context::GetOfiApiVersion() const -> std::string auto Context::GetOfiApiVersion() const -> string
{ {
auto ofi_version{fi_version()}; auto ofi_version{fi_version()};
return tools::ToString(FI_MAJOR(ofi_version), ".", FI_MINOR(ofi_version)); return tools::ToString(FI_MAJOR(ofi_version), ".", FI_MINOR(ofi_version));
} }
auto Context::GetPbVersion() const -> string
{
return google::protobuf::internal::VersionString(GOOGLE_PROTOBUF_VERSION);
}
auto Context::InitOfi(ConnectionType type, std::string addr) -> void auto Context::InitOfi(ConnectionType type, std::string addr) -> void
{ {
auto addr2 = ConvertAddress(addr); auto addr2 = ConvertAddress(addr);
@ -93,7 +110,7 @@ auto Context::InitOfi(ConnectionType type, std::string addr) -> void
// Prepare fi_getinfo query // Prepare fi_getinfo query
unique_ptr<fi_info, void(*)(fi_info*)> ofi_hints(fi_allocinfo(), fi_freeinfo); unique_ptr<fi_info, void(*)(fi_info*)> ofi_hints(fi_allocinfo(), fi_freeinfo);
ofi_hints->caps = FI_MSG | FI_SOURCE; ofi_hints->caps = FI_MSG | FI_RMA;
ofi_hints->mode = FI_ASYNC_IOV; ofi_hints->mode = FI_ASYNC_IOV;
ofi_hints->addr_format = FI_SOCKADDR_IN; ofi_hints->addr_format = FI_SOCKADDR_IN;
ofi_hints->fabric_attr->prov_name = strdup("sockets"); ofi_hints->fabric_attr->prov_name = strdup("sockets");
@ -105,17 +122,17 @@ auto Context::InitOfi(ConnectionType type, std::string addr) -> void
// ofi_hints->src_addr = sa; // ofi_hints->src_addr = sa;
// ofi_hints->src_addrlen = sizeof(sockaddr_in); // ofi_hints->src_addrlen = sizeof(sockaddr_in);
// } else { // } else {
ofi_hints->dest_addr = sa; // ofi_hints->dest_addr = sa;
ofi_hints->dest_addrlen = sizeof(sockaddr_in); // ofi_hints->dest_addrlen = sizeof(sockaddr_in);
// } // }
// Query fi_getinfo for fabric to use // Query fi_getinfo for fabric to use
auto res = fi_getinfo(FI_VERSION(1, 5), nullptr, nullptr, 0, ofi_hints.get(), &fOfiInfo); auto res = fi_getinfo(FI_VERSION(1, 5), strdup(addr2.Ip.c_str()), 0, 0, ofi_hints.get(), &fOfiInfo);
if (res != 0) throw ContextError{tools::ToString("Failed querying fi_getinfo, reason: ", fi_strerror(res))}; if (res != 0) throw ContextError{tools::ToString("Failed querying fi_getinfo, reason: ", fi_strerror(res))};
if (!fOfiInfo) throw ContextError{"Could not find any ofi compatible fabric."}; if (!fOfiInfo) throw ContextError{"Could not find any ofi compatible fabric."};
// for(auto cursor{ofi_info}; cursor->next != nullptr; cursor = cursor->next) { // for(auto cursor{ofi_info}; cursor->next != nullptr; cursor = cursor->next) {
LOG(debug) << fi_tostr(fOfiInfo, FI_TYPE_INFO); // LOG(debug) << fi_tostr(fOfiInfo, FI_TYPE_INFO);
// } // }
// //
} else { } else {
@ -123,6 +140,7 @@ auto Context::InitOfi(ConnectionType type, std::string addr) -> void
} }
OpenOfiFabric(); OpenOfiFabric();
OpenOfiEventQueue();
OpenOfiDomain(); OpenOfiDomain();
OpenOfiAddressVector(); OpenOfiAddressVector();
} }
@ -154,23 +172,39 @@ auto Context::OpenOfiDomain() -> void
} }
} }
auto Context::OpenOfiEventQueue() -> void
{
fi_eq_attr eqAttr = {100, 0, FI_WAIT_UNSPEC, 0, nullptr};
// size_t size; [> # entries for EQ <]
// uint64_t flags; [> operation flags <]
// enum fi_wait_obj wait_obj; [> requested wait object <]
// int signaling_vector; [> interrupt affinity <]
// struct fid_wait *wait_set; [> optional wait set <]
auto ret = fi_eq_open(fOfiFabric, &eqAttr, &fOfiEventQueue, nullptr);
if (ret != FI_SUCCESS)
throw ContextError{tools::ToString("Failed opening ofi event queue, reason: ", fi_strerror(ret))};
}
auto Context::OpenOfiAddressVector() -> void auto Context::OpenOfiAddressVector() -> void
{ {
if (!fOfiAddressVector) { if (!fOfiAddressVector) {
assert(fOfiDomain); assert(fOfiDomain);
fi_av_attr attr = {fOfiInfo->domain_attr->av_type, 0, 1000, 0, nullptr, nullptr, 0}; fi_av_attr attr = {fOfiInfo->domain_attr->av_type, 0, 1000, 0, nullptr, nullptr, 0};
// struct fi_av_attr { // enum fi_av_type type; [> type of AV <]
// enum fi_av_type type; [> type of AV <] // int rx_ctx_bits; [> address bits to identify rx ctx <]
// int rx_ctx_bits; [> address bits to identify rx ctx <] // size_t count; [> # entries for AV <]
// size_t count; [> # entries for AV <] // size_t ep_per_node; [> # endpoints per fabric address <]
// size_t ep_per_node; [> # endpoints per fabric address <] // const char *name; [> system name of AV <]
// const char *name; [> system name of AV <] // void *map_addr; [> base mmap address <]
// void *map_addr; [> base mmap address <] // uint64_t flags; [> operation flags <]
// uint64_t flags; [> operation flags <]
// };
auto ret = fi_av_open(fOfiDomain, &attr, &fOfiAddressVector, nullptr); auto ret = fi_av_open(fOfiDomain, &attr, &fOfiAddressVector, nullptr);
if (ret != FI_SUCCESS) if (ret != FI_SUCCESS)
throw ContextError{tools::ToString("Failed opening ofi address vector, reason: ", fi_strerror(ret))}; throw ContextError{tools::ToString("Failed opening ofi address vector, reason: ", fi_strerror(ret))};
assert(fOfiEventQueue);
ret = fi_av_bind(fOfiAddressVector, &fOfiEventQueue->fid, 0);
if (ret != FI_SUCCESS)
throw ContextError{tools::ToString("Failed binding ofi event queue to address vector, reason: ", fi_strerror(ret))};
} else { } else {
LOG(debug) << "Ofi address vector already opened. Skipping."; LOG(debug) << "Ofi address vector already opened. Skipping.";
} }
@ -185,6 +219,11 @@ auto Context::CreateOfiEndpoint() -> fid_ep*
if (ret != FI_SUCCESS) if (ret != FI_SUCCESS)
throw ContextError{tools::ToString("Failed creating ofi endpoint, reason: ", fi_strerror(ret))}; throw ContextError{tools::ToString("Failed creating ofi endpoint, reason: ", fi_strerror(ret))};
assert(fOfiEventQueue);
ret = fi_ep_bind(ep, &fOfiEventQueue->fid, 0);
if (ret != FI_SUCCESS)
throw ContextError{tools::ToString("Failed binding ofi address vector to ofi endpoint, reason: ", fi_strerror(ret))};
assert(fOfiAddressVector); assert(fOfiAddressVector);
ret = fi_ep_bind(ep, &fOfiAddressVector->fid, 0); ret = fi_ep_bind(ep, &fOfiAddressVector->fid, 0);
if (ret != FI_SUCCESS) if (ret != FI_SUCCESS)
@ -254,6 +293,21 @@ auto Context::ConvertAddress(Address address) -> sockaddr_in
return sa; return sa;
} }
auto Context::ConvertAddress(sockaddr_in address) -> Address
{
return {"tcp", inet_ntoa(address.sin_addr), ntohs(address.sin_port)};
}
auto Context::VerifyAddress(const std::string& address) -> Address
{
auto addr = ConvertAddress(address);
if (addr.Protocol != "tcp")
throw ContextError("Wrong protocol: Supported protocols are: tcp");
return addr;
}
} /* namespace ofi */ } /* namespace ofi */
} /* namespace mq */ } /* namespace mq */
} /* namespace fair */ } /* namespace fair */

View File

@ -11,6 +11,7 @@
#include <memory> #include <memory>
#include <netinet/in.h> #include <netinet/in.h>
#include <ostream>
#include <rdma/fabric.h> #include <rdma/fabric.h>
#include <string> #include <string>
#include <stdexcept> #include <stdexcept>
@ -37,21 +38,24 @@ class Context
Context(int numberIoThreads = 1); Context(int numberIoThreads = 1);
~Context(); ~Context();
/// Deferred Ofi initialization
auto InitOfi(ConnectionType type, std::string address) -> void; auto InitOfi(ConnectionType type, std::string address) -> void;
auto CreateOfiEndpoint() -> fid_ep*; auto CreateOfiEndpoint() -> fid_ep*;
auto CreateOfiCompletionQueue(Direction dir) -> fid_cq*; auto CreateOfiCompletionQueue(Direction dir) -> fid_cq*;
auto GetZmqVersion() const -> std::string; auto GetZmqVersion() const -> std::string;
auto GetOfiApiVersion() const -> std::string; auto GetOfiApiVersion() const -> std::string;
auto GetPbVersion() const -> std::string;
auto GetZmqContext() const -> void* { return fZmqContext; } auto GetZmqContext() const -> void* { return fZmqContext; }
auto InsertAddressVector(sockaddr_in address) -> fi_addr_t; auto InsertAddressVector(sockaddr_in address) -> fi_addr_t;
struct Address { struct Address {
std::string Protocol; std::string Protocol;
std::string Ip; std::string Ip;
unsigned int Port; unsigned int Port;
friend auto operator<<(std::ostream& os, const Address& a) -> std::ostream& { return os << a.Protocol << "://" << a.Ip << ":" << a.Port; }
}; };
static auto ConvertAddress(std::string address) -> Address; static auto ConvertAddress(std::string address) -> Address;
static auto ConvertAddress(Address address) -> sockaddr_in; static auto ConvertAddress(Address address) -> sockaddr_in;
static auto ConvertAddress(sockaddr_in address) -> Address;
static auto VerifyAddress(const std::string& address) -> Address;
private: private:
void* fZmqContext; void* fZmqContext;
@ -59,8 +63,10 @@ class Context
fid_fabric* fOfiFabric; fid_fabric* fOfiFabric;
fid_domain* fOfiDomain; fid_domain* fOfiDomain;
fid_av* fOfiAddressVector; fid_av* fOfiAddressVector;
fid_eq* fOfiEventQueue;
auto OpenOfiFabric() -> void; auto OpenOfiFabric() -> void;
auto OpenOfiEventQueue() -> void;
auto OpenOfiDomain() -> void; auto OpenOfiDomain() -> void;
auto OpenOfiAddressVector() -> void; auto OpenOfiAddressVector() -> void;
}; /* class Context */ }; /* class Context */

15
fairmq/ofi/Control.proto Normal file
View File

@ -0,0 +1,15 @@
syntax = "proto3";
option optimize_for = SPEED;
package fair.mq.ofi;
message DataAddressAnnouncement {
uint32 ipv4 = 1; // in_addr_t from <netinet/in.h>
uint32 port = 2; // in_port_t from <netinet/in.h>
}
message ControlMessage {
oneof type {
DataAddressAnnouncement data_address_announcement = 1;
}
}

View File

@ -16,6 +16,7 @@
#include <netinet/in.h> #include <netinet/in.h>
#include <rdma/fabric.h> #include <rdma/fabric.h>
#include <rdma/fi_endpoint.h> #include <rdma/fi_endpoint.h>
#include <rdma/fi_cm.h>
#include <sstream> #include <sstream>
#include <string.h> #include <string.h>
#include <sys/socket.h> #include <sys/socket.h>
@ -35,34 +36,34 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
, fDataCompletionQueueTx(nullptr) , fDataCompletionQueueTx(nullptr)
, fDataCompletionQueueRx(nullptr) , fDataCompletionQueueRx(nullptr)
, fId(id + "." + name + "." + type) , fId(id + "." + name + "." + type)
, fMetaSocket(nullptr) , fControlSocket(nullptr)
, fMonitorSocket(nullptr) , fMonitorSocket(nullptr)
, fSndTimeout(100) , fSndTimeout(100)
, fRcvTimeout(100) , fRcvTimeout(100)
, fContext(context) , fContext(context)
, fWaitingForRemoteConnect(false) , fWaitingForControlPeer(false)
{ {
if (type != "pair") { if (type != "pair") {
throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")};
} else { } else {
fMetaSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); fControlSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR);
if (fMetaSocket == nullptr) if (fControlSocket == nullptr)
throw SocketError{tools::ToString("Failed creating zmq meta socket ", fId, ", reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed creating zmq meta socket ", fId, ", reason: ", zmq_strerror(errno))};
if (zmq_setsockopt(fMetaSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) if (zmq_setsockopt(fControlSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0)
throw SocketError{tools::ToString("Failed setting ZMQ_IDENTITY socket option, reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed setting ZMQ_IDENTITY socket option, reason: ", zmq_strerror(errno))};
// Tell socket to try and send/receive outstanding messages for <linger> milliseconds before terminating. // Tell socket to try and send/receive outstanding messages for <linger> milliseconds before terminating.
// Default value for ZeroMQ is -1, which is to wait forever. // Default value for ZeroMQ is -1, which is to wait forever.
int linger = 1000; int linger = 1000;
if (zmq_setsockopt(fMetaSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) if (zmq_setsockopt(fControlSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0)
throw SocketError{tools::ToString("Failed setting ZMQ_LINGER socket option, reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed setting ZMQ_LINGER socket option, reason: ", zmq_strerror(errno))};
if (zmq_setsockopt(fMetaSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) if (zmq_setsockopt(fControlSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0)
throw SocketError{tools::ToString("Failed setting ZMQ_SNDTIMEO socket option, reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed setting ZMQ_SNDTIMEO socket option, reason: ", zmq_strerror(errno))};
if (zmq_setsockopt(fMetaSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) if (zmq_setsockopt(fControlSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0)
throw SocketError{tools::ToString("Failed setting ZMQ_RCVTIMEO socket option, reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed setting ZMQ_RCVTIMEO socket option, reason: ", zmq_strerror(errno))};
fMonitorSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); fMonitorSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR);
@ -71,7 +72,7 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
throw SocketError{tools::ToString("Failed creating zmq monitor socket ", fId, ", reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed creating zmq monitor socket ", fId, ", reason: ", zmq_strerror(errno))};
auto mon_addr = tools::ToString("inproc://", fId); auto mon_addr = tools::ToString("inproc://", fId);
if (zmq_socket_monitor(fMetaSocket, mon_addr.c_str(), ZMQ_EVENT_ACCEPTED) < 0) if (zmq_socket_monitor(fControlSocket, mon_addr.c_str(), ZMQ_EVENT_ACCEPTED | ZMQ_EVENT_CONNECTED) < 0)
throw SocketError{tools::ToString("Failed setting up monitor on meta socket, reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed setting up monitor on meta socket, reason: ", zmq_strerror(errno))};
if (zmq_connect(fMonitorSocket, mon_addr.c_str()) != 0) if (zmq_connect(fMonitorSocket, mon_addr.c_str()) != 0)
@ -80,50 +81,65 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
} }
auto Socket::Bind(const string& address) -> bool auto Socket::Bind(const string& address) -> bool
{ try {
auto addr2 = fContext.ConvertAddress(address); auto addr = Context::VerifyAddress(address);
if (addr2.Protocol != "tcp") BindControlSocket(addr);
throw SocketError("Wrong protocol: Supported protocols are: tcp");
if (zmq_bind(fMetaSocket, address.c_str()) != 0) {
if (errno == EADDRINUSE) {
// do not print error in this case, this is handled by FairMQDevice
// in case no connection could be established after trying a number of random ports from a range.
return false;
}
LOG(error) << "Failed binding socket " << fId << ", reason: " << zmq_strerror(errno);
return false;
}
fContext.InitOfi(ConnectionType::Bind, address); fContext.InitOfi(ConnectionType::Bind, address);
InitDataEndpoint();
try { fWaitingForControlPeer = true;
InitDataEndpoint();
} catch (SocketError& e) {
LOG(error) << e.what();
return false;
}
fWaitingForRemoteConnect = true;
return true; return true;
} }
catch (const SilentSocketError& e)
{
// do not print error in this case, this is handled by FairMQDevice
// in case no connection could be established after trying a number of random ports from a range.
return false;
}
catch (const SocketError& e)
{
LOG(error) << e.what();
return false;
}
auto Socket::Connect(const string& address) -> void auto Socket::Connect(const string& address) -> void
{ {
auto addr2 = fContext.ConvertAddress(address); auto addr = Context::VerifyAddress(address);
if (addr2.Protocol != "tcp") ConnectControlSocket(addr);
throw SocketError("Wrong protocol: Supported protocols are: tcp");
if (zmq_connect(fMetaSocket, address.c_str()) != 0) {
throw SocketError(tools::ToString("Failed connecting socket ", fId, ", reason: ", zmq_strerror(errno)));
}
fContext.InitOfi(ConnectionType::Connect, address); fContext.InitOfi(ConnectionType::Connect, address);
InitDataEndpoint(); InitDataEndpoint();
fWaitingForControlPeer = true;
}
fRemoteAddr = fContext.InsertAddressVector(fContext.ConvertAddress(addr2)); auto Socket::BindControlSocket(Context::Address address) -> void
{
auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port);
if (zmq_bind(fControlSocket, addr.c_str()) != 0) {
if (errno == EADDRINUSE) throw SilentSocketError("EADDRINUSE");
throw SocketError(tools::ToString("Failed binding control socket ", fId, ", reason: ", zmq_strerror(errno)));
}
}
auto Socket::ConnectControlSocket(Context::Address address) -> void
{
auto addr = tools::ToString("tcp://", address.Ip, ":", address.Port);
if (zmq_connect(fControlSocket, addr.c_str()) != 0)
throw SocketError(tools::ToString("Failed connecting control socket ", fId, ", reason: ", zmq_strerror(errno)));
}
auto Socket::ProcessDataAddressAnnouncement(std::unique_ptr<ControlMessage> ctrl) -> void
{
assert(ctrl->has_data_address_announcement());
auto daa = ctrl->data_address_announcement();
sockaddr_in remoteAddr;
remoteAddr.sin_family = AF_INET;
remoteAddr.sin_port = daa.port();
remoteAddr.sin_addr.s_addr = daa.ipv4();
LOG(debug) << Context::ConvertAddress(remoteAddr);
fRemoteDataAddr = fContext.InsertAddressVector(remoteAddr);
} }
auto Socket::InitDataEndpoint() -> void auto Socket::InitDataEndpoint() -> void
@ -149,13 +165,81 @@ auto Socket::InitDataEndpoint() -> void
ret = fi_enable(fDataEndpoint); ret = fi_enable(fDataEndpoint);
if (ret != FI_SUCCESS) if (ret != FI_SUCCESS)
throw SocketError(tools::ToString("Failed opening ofi fabric, reason: ", fi_strerror(ret))); throw SocketError(tools::ToString("Failed enabling ofi endpoint, reason: ", fi_strerror(ret)));
} }
} }
auto Socket::WaitForRemoteConnect() -> void void free_string(void* /*data*/, void* hint)
{ {
assert(fWaitingForRemoteConnect); delete static_cast<string*>(hint);
}
auto Socket::AnnounceDataAddress() -> void
try {
using namespace google::protobuf;
size_t addrlen = sizeof(sockaddr_in);
auto ret = fi_getname(&fDataEndpoint->fid, &fLocalDataAddr, &addrlen);
if (ret != FI_SUCCESS)
throw SocketError(tools::ToString("Failed retrieving native address from ofi endpoint, reason: ", fi_strerror(ret)));
assert(addrlen == sizeof(sockaddr_in));
LOG(debug) << "Address of local ofi endpoint in socket " << fId << ": " << Context::ConvertAddress(fLocalDataAddr);
// Create new control message
auto ctrl = tools::make_unique<ControlMessage>();
auto daa = tools::make_unique<DataAddressAnnouncement>();
// Fill data address announcement
daa->set_ipv4(fLocalDataAddr.sin_addr.s_addr);
daa->set_port(fLocalDataAddr.sin_port);
// Fill control message
ctrl->set_allocated_data_address_announcement(daa.release());
assert(ctrl->IsInitialized());
SendControlMessage(move(ctrl));
} catch (const SocketError& e) {
throw SocketError(tools::ToString("Failed to announce data address, reason: ", e.what()));
}
auto Socket::SendControlMessage(unique_ptr<ControlMessage> ctrl) -> void
{
assert(fControlSocket);
// Serialize
string* str = new string();
ctrl->SerializeToString(str);
zmq_msg_t msg;
auto ret = zmq_msg_init_data(&msg, const_cast<char*>(str->c_str()), str->length(), free_string, str);
assert(ret == 0);
// Send
if (zmq_msg_send(&msg, fControlSocket, 0) == -1)
throw SocketError(tools::ToString("Failed to send control message, reason: ", zmq_strerror(errno)));
}
auto Socket::ReceiveControlMessage() -> unique_ptr<ControlMessage>
{
assert(fControlSocket);
// Receive
zmq_msg_t msg;
auto ret = zmq_msg_init(&msg);
assert(ret == 0);
if (zmq_msg_recv(&msg, fControlSocket, 0) == -1)
throw SocketError(tools::ToString("Failed to receive control message, reason: ", zmq_strerror(errno)));
// Deserialize
auto ctrl = tools::make_unique<ControlMessage>();
ctrl->ParseFromArray(zmq_msg_data(&msg), zmq_msg_size(&msg));
return ctrl;
}
auto Socket::WaitForControlPeer() -> void
{
assert(fWaitingForControlPeer);
// First frame in message contains event number and value // First frame in message contains event number and value
zmq_msg_t msg; zmq_msg_t msg;
@ -172,21 +256,23 @@ auto Socket::WaitForRemoteConnect() -> void
if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1)
throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno)));
string localAddress = string(static_cast<char*>(zmq_msg_data(&msg)), zmq_msg_size(&msg)); if (event == ZMQ_EVENT_ACCEPTED) {
// string localAddress = string(static_cast<char*>(zmq_msg_data(&msg)), zmq_msg_size(&msg));
sockaddr_in remoteAddr;
socklen_t addrSize = sizeof(sockaddr_in);
int ret = getpeername(value, (sockaddr*)&remoteAddr, &addrSize);
if (ret != 0)
throw SocketError(tools::ToString("Failed retrieving remote address, reason: ", strerror(errno)));
string remoteIp(inet_ntoa(remoteAddr.sin_addr));
int remotePort = ntohs(remoteAddr.sin_port);
LOG(debug) << "Accepted control peer connection from " << remoteIp << ":" << remotePort;
} else if (event == ZMQ_EVENT_CONNECTED) {
LOG(debug) << "Connected successfully to control peer";
} else {
LOG(debug) << "Unknown monitor event received: " << event << ". Ignoring.";
}
assert(event == ZMQ_EVENT_ACCEPTED); // we only subscribed for this event fWaitingForControlPeer = false;
sockaddr_in remoteAddr;
socklen_t addrSize = sizeof(sockaddr_in);
int ret = getpeername(value, (sockaddr*)&remoteAddr, &addrSize);
if (ret != 0)
throw SocketError(tools::ToString("Failed retrieving peer address, reason: ", strerror(errno)));
string remoteIp(inet_ntoa(remoteAddr.sin_addr));
int remotePort = ntohs(remoteAddr.sin_port);
LOG(debug) << "peer connected from " << remoteIp << ":" << remotePort << " at " << localAddress;
fRemoteAddr = fContext.InsertAddressVector(remoteAddr);
fWaitingForRemoteConnect = false;
} }
auto Socket::Send(MessagePtr& msg, const int timeout) -> int { return SendImpl(msg, 0, timeout); } auto Socket::Send(MessagePtr& msg, const int timeout) -> int { return SendImpl(msg, 0, timeout); }
@ -200,41 +286,51 @@ auto Socket::TrySend(std::vector<MessagePtr>& msgVec) -> int64_t { return SendIm
auto Socket::TryReceive(std::vector<MessagePtr>& msgVec) -> int64_t { return ReceiveImpl(msgVec, ZMQ_DONTWAIT, 0); } auto Socket::TryReceive(std::vector<MessagePtr>& msgVec) -> int64_t { return ReceiveImpl(msgVec, ZMQ_DONTWAIT, 0); }
auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int
try {
if (fWaitingForControlPeer) {
WaitForControlPeer();
AnnounceDataAddress();
ProcessDataAddressAnnouncement(ReceiveControlMessage());
}
auto ret = zmq_send(fControlSocket, nullptr, 0, flags);
if (ret == EAGAIN) throw SilentSocketError("EAGAIN");
if (ret == -1) throw SocketError(tools::ToString("Failed sending control message on socket ", fId, ", reason: ", zmq_strerror(errno)));
return ret;
}
catch (const SilentSocketError& e)
{ {
if (fWaitingForRemoteConnect) { return -2;
try { }
WaitForRemoteConnect(); catch (const std::exception& e)
} catch (const std::exception& e) { {
LOG(error) << e.what(); LOG(error) << e.what();
return -1; return -1;
}
}
// void* metadata = malloc(sizeof(size_t));
auto ret = zmq_send(fMetaSocket, nullptr, 0, flags);
if (ret == EAGAIN) {
return -2;
} else if (ret < 0) {
LOG(error) << "Failed sending meta message on socket " << fId << ", reason: " << zmq_strerror(errno);
return -1;
} else {
// auto ret2 = fi_send(fDataEndpoint, msg->GetData(), msg->GetSize(), nullptr, fi_addr_t dest_addr, nullptr);
return ret;
}
} }
auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int
{ try {
auto ret = zmq_recv(fMetaSocket, nullptr, 0, flags); if (fWaitingForControlPeer) {
if (ret == EAGAIN) { WaitForControlPeer();
return -2; AnnounceDataAddress();
} else if (ret < 0) { ProcessDataAddressAnnouncement(ReceiveControlMessage());
LOG(error) << "Failed receiving meta message on socket " << fId << ", reason: " << zmq_strerror(errno);
return -1;
} else {
return ret;
} }
auto ret = zmq_recv(fControlSocket, nullptr, 0, flags);
if (ret == EAGAIN) throw SilentSocketError("EAGAIN");
if (ret == -1) throw SocketError(tools::ToString("Failed sending control message on socket ", fId, ", reason: ", zmq_strerror(errno)));
return ret;
}
catch (const SilentSocketError& e)
{
return -2;
}
catch (const std::exception& e)
{
LOG(error) << e.what();
return -1;
} }
auto Socket::SendImpl(vector<FairMQMessagePtr>& msgVec, const int flags, const int timeout) -> int64_t auto Socket::SendImpl(vector<FairMQMessagePtr>& msgVec, const int flags, const int timeout) -> int64_t
@ -410,7 +506,7 @@ auto Socket::ReceiveImpl(vector<FairMQMessagePtr>& msgVec, const int flags, cons
auto Socket::Close() -> void auto Socket::Close() -> void
{ {
if (zmq_close(fMetaSocket) != 0) if (zmq_close(fControlSocket) != 0)
throw SocketError(tools::ToString("Failed closing zmq meta socket, reason: ", zmq_strerror(errno))); throw SocketError(tools::ToString("Failed closing zmq meta socket, reason: ", zmq_strerror(errno)));
if (zmq_close(fMonitorSocket) != 0) if (zmq_close(fMonitorSocket) != 0)
@ -437,14 +533,14 @@ auto Socket::Close() -> void
auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void
{ {
if (zmq_setsockopt(fMetaSocket, GetConstant(option), value, valueSize) < 0) { if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))};
} }
} }
auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void
{ {
if (zmq_getsockopt(fMetaSocket, GetConstant(option), value, valueSize) < 0) { if (zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))}; throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))};
} }
} }

View File

@ -12,8 +12,10 @@
#include <FairMQSocket.h> #include <FairMQSocket.h>
#include <FairMQMessage.h> #include <FairMQMessage.h>
#include <fairmq/ofi/Context.h> #include <fairmq/ofi/Context.h>
#include <fairmq/ofi/Control.pb.h>
#include <memory> // unique_ptr #include <memory> // unique_ptr
#include <netinet/in.h>
#include <rdma/fabric.h> #include <rdma/fabric.h>
namespace fair namespace fair
@ -51,7 +53,7 @@ class Socket : public fair::mq::Socket
auto TrySend(std::vector<MessagePtr>& msgVec) -> int64_t override; auto TrySend(std::vector<MessagePtr>& msgVec) -> int64_t override;
auto TryReceive(std::vector<MessagePtr>& msgVec) -> int64_t override; auto TryReceive(std::vector<MessagePtr>& msgVec) -> int64_t override;
auto GetSocket() const -> void* override { return fMetaSocket; } auto GetSocket() const -> void* override { return fControlSocket; }
auto GetSocket(int nothing) const -> int override { return -1; } auto GetSocket(int nothing) const -> int override { return -1; }
auto Close() -> void override; auto Close() -> void override;
@ -74,7 +76,7 @@ class Socket : public fair::mq::Socket
~Socket() override; ~Socket() override;
private: private:
void* fMetaSocket; void* fControlSocket;
void* fMonitorSocket; void* fMonitorSocket;
fid_ep* fDataEndpoint; fid_ep* fDataEndpoint;
fid_cq* fDataCompletionQueueTx; fid_cq* fDataCompletionQueueTx;
@ -85,8 +87,9 @@ class Socket : public fair::mq::Socket
std::atomic<unsigned long> fMessagesTx; std::atomic<unsigned long> fMessagesTx;
std::atomic<unsigned long> fMessagesRx; std::atomic<unsigned long> fMessagesRx;
Context& fContext; Context& fContext;
fi_addr_t fRemoteAddr; fi_addr_t fRemoteDataAddr;
bool fWaitingForRemoteConnect; sockaddr_in fLocalDataAddr;
bool fWaitingForControlPeer;
int fSndTimeout; int fSndTimeout;
int fRcvTimeout; int fRcvTimeout;
@ -97,9 +100,20 @@ class Socket : public fair::mq::Socket
auto ReceiveImpl(std::vector<MessagePtr>& msgVec, const int flags, const int timeout) -> int64_t; auto ReceiveImpl(std::vector<MessagePtr>& msgVec, const int flags, const int timeout) -> int64_t;
auto InitDataEndpoint() -> void; auto InitDataEndpoint() -> void;
auto WaitForRemoteConnect() -> void; auto WaitForControlPeer() -> void;
auto AnnounceDataAddress() -> void;
auto SendControlMessage(std::unique_ptr<ControlMessage> ctrl) -> void;
auto ReceiveControlMessage() -> std::unique_ptr<ControlMessage>;
auto ProcessDataAddressAnnouncement(std::unique_ptr<ControlMessage> ctrl) -> void;
auto ConnectControlSocket(Context::Address address) -> void;
auto BindControlSocket(Context::Address address) -> void;
}; /* class Socket */ }; /* class Socket */
// helper function to clean up the object holding the data after it is transported.
void free_string(void* /*data*/, void* hint);
struct SilentSocketError : SocketError { using SocketError::SocketError; };
} /* namespace ofi */ } /* namespace ofi */
} /* namespace mq */ } /* namespace mq */
} /* namespace fair */ } /* namespace fair */

View File

@ -27,7 +27,8 @@ TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* co
try : FairMQTransportFactory{id} try : FairMQTransportFactory{id}
{ {
LOG(debug) << "Transport: Using ZeroMQ (" << fContext.GetZmqVersion() << ") & " LOG(debug) << "Transport: Using ZeroMQ (" << fContext.GetZmqVersion() << ") & "
<< "OFI libfabric (API " << fContext.GetOfiApiVersion() << ")"; << "OFI libfabric (API " << fContext.GetOfiApiVersion() << ") & "
<< "Google Protobuf (" << fContext.GetPbVersion() << ")";
} }
catch (ContextError& e) { throw TransportFactoryError{e.what()}; } catch (ContextError& e) { throw TransportFactoryError{e.what()}; }