diff --git a/fairmq/ofi/Context.cxx b/fairmq/ofi/Context.cxx index b9bb927f..186b54ed 100644 --- a/fairmq/ofi/Context.cxx +++ b/fairmq/ofi/Context.cxx @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -79,26 +80,16 @@ auto Context::GetOfiApiVersion() const -> std::string return tools::ToString(FI_MAJOR(ofi_version), ".", FI_MINOR(ofi_version)); } -auto Context::InitOfi(ConnectionType type, std::string address) -> void +auto Context::InitOfi(ConnectionType type, std::string addr) -> void { - // Parse address - string protocol, ip; - unsigned int port = 0; - regex address_regex("^([a-z]+)://([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+):([0-9]+).*"); - smatch address_result; - if (regex_match(address, address_result, address_regex)) { - protocol = address_result[1]; - ip = address_result[2]; - port = stoul(address_result[3]); - LOG(debug) << "Parsed '" << protocol << "', '" << ip << "', '" << port << "' fields from '" << address << "'"; - } - if (protocol != "tcp") throw ContextError{"Wrong protocol: Supplied address must be in format tcp://ip:port"}; + auto addr2 = ConvertAddress(addr); + if (addr2.Protocol != "tcp") + throw ContextError{"Wrong protocol: Supplied address must be in format tcp://ip:port"}; if (!fOfiInfo) { sockaddr_in* sa = static_cast(malloc(sizeof(sockaddr_in))); - inet_pton(AF_INET, ip.c_str(), &(sa->sin_addr)); - sa->sin_port = htons(port); - sa->sin_family = AF_INET; + auto sa2 = ConvertAddress(addr2); + memcpy(sa, &sa2, sizeof(sockaddr_in)); // Prepare fi_getinfo query unique_ptr ofi_hints(fi_allocinfo(), fi_freeinfo); @@ -124,7 +115,7 @@ auto Context::InitOfi(ConnectionType type, std::string address) -> void if (!fOfiInfo) throw ContextError{"Could not find any ofi compatible fabric."}; // for(auto cursor{ofi_info}; cursor->next != nullptr; cursor = cursor->next) { - // LOG(debug) << fi_tostr(fOfiInfo, FI_TYPE_INFO); + LOG(debug) << fi_tostr(fOfiInfo, FI_TYPE_INFO); // } // } else { @@ -224,6 +215,45 @@ auto Context::CreateOfiCompletionQueue(Direction dir) -> fid_cq* return cq; } +auto Context::InsertAddressVector(sockaddr_in address) -> fi_addr_t +{ + fi_addr_t mappedAddress; + auto ret = fi_av_insert(fOfiAddressVector, &address, 1, &mappedAddress, 0, nullptr); + if (ret != FI_SUCCESS) + throw ContextError{tools::ToString("Failed to insert address into ofi address vector, reason: ", fi_strerror(ret))}; + + return ret; +} + +auto Context::ConvertAddress(std::string address) -> Address +{ + string protocol, ip; + unsigned int port = 0; + regex address_regex("^([a-z]+)://([0-9]+\\.[0-9]+\\.[0-9]+\\.[0-9]+):([0-9]+).*"); + smatch address_result; + if (regex_match(address, address_result, address_regex)) { + protocol = address_result[1]; + ip = address_result[2]; + port = stoul(address_result[3]); + // LOG(debug) << "Parsed '" << protocol << "', '" << ip << "', '" << port << "' fields from '" << address << "'"; + } else { + throw ContextError(tools::ToString("Wrong format: Address must be in format prot://ip:port")); + } + + return { protocol, ip, port }; +} + +auto Context::ConvertAddress(Address address) -> sockaddr_in +{ + sockaddr_in sa; + if (inet_pton(AF_INET, address.Ip.c_str(), &(sa.sin_addr)) != 1) + throw ContextError(tools::ToString("Failed to convert given IP '", address.Ip, "' to struct in_addr, reason: ", strerror(errno))); + sa.sin_port = htons(address.Port); + sa.sin_family = AF_INET; + + return sa; +} + } /* namespace ofi */ } /* namespace mq */ } /* namespace fair */ diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index 7dbd4d19..04f26f2e 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -10,6 +10,7 @@ #define FAIR_MQ_OFI_CONTEXT_H #include +#include #include #include #include @@ -43,6 +44,14 @@ class Context auto GetZmqVersion() const -> std::string; auto GetOfiApiVersion() const -> std::string; auto GetZmqContext() const -> void* { return fZmqContext; } + auto InsertAddressVector(sockaddr_in address) -> fi_addr_t; + struct Address { + std::string Protocol; + std::string Ip; + unsigned int Port; + }; + static auto ConvertAddress(std::string address) -> Address; + static auto ConvertAddress(Address address) -> sockaddr_in; private: void* fZmqContext; diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index 781e0c73..4b6f6eac 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -11,10 +11,14 @@ #include #include +#include +#include +#include #include #include -#include #include +#include +#include #include namespace fair @@ -31,43 +35,56 @@ Socket::Socket(Context& context, const string& type, const string& name, const s , fDataCompletionQueueTx(nullptr) , fDataCompletionQueueRx(nullptr) , fId(id + "." + name + "." + type) + , fMetaSocket(nullptr) + , fMonitorSocket(nullptr) , fSndTimeout(100) , fRcvTimeout(100) , fContext(context) + , fWaitingForRemoteConnect(false) { if (type != "pair") { throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; } else { - fMetaSocket = zmq_socket(fContext.GetZmqContext(), GetConstant(type)); + fMetaSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); - if (fMetaSocket == nullptr) { - throw SocketError{tools::ToString("Failed creating socket ", fId, ", reason: ", zmq_strerror(errno))}; - } + if (fMetaSocket == nullptr) + throw SocketError{tools::ToString("Failed creating zmq meta socket ", fId, ", reason: ", zmq_strerror(errno))}; - if (zmq_setsockopt(fMetaSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) { + if (zmq_setsockopt(fMetaSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_IDENTITY socket option, reason: ", zmq_strerror(errno))}; - } // Tell socket to try and send/receive outstanding messages for milliseconds before terminating. // Default value for ZeroMQ is -1, which is to wait forever. int linger = 1000; - if (zmq_setsockopt(fMetaSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) { + if (zmq_setsockopt(fMetaSocket, ZMQ_LINGER, &linger, sizeof(linger)) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_LINGER socket option, reason: ", zmq_strerror(errno))}; - } - if (zmq_setsockopt(fMetaSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) { + if (zmq_setsockopt(fMetaSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_SNDTIMEO socket option, reason: ", zmq_strerror(errno))}; - } - if (zmq_setsockopt(fMetaSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) { + if (zmq_setsockopt(fMetaSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) throw SocketError{tools::ToString("Failed setting ZMQ_RCVTIMEO socket option, reason: ", zmq_strerror(errno))}; - } + + fMonitorSocket = zmq_socket(fContext.GetZmqContext(), ZMQ_PAIR); + + if (fMonitorSocket == nullptr) + throw SocketError{tools::ToString("Failed creating zmq monitor socket ", fId, ", reason: ", zmq_strerror(errno))}; + + auto mon_addr = tools::ToString("inproc://", fId); + if (zmq_socket_monitor(fMetaSocket, mon_addr.c_str(), ZMQ_EVENT_ACCEPTED) < 0) + throw SocketError{tools::ToString("Failed setting up monitor on meta socket, reason: ", zmq_strerror(errno))}; + + if (zmq_connect(fMonitorSocket, mon_addr.c_str()) != 0) + throw SocketError{tools::ToString("Failed connecting monitor socket to meta socket, reason: ", zmq_strerror(errno))}; } } auto Socket::Bind(const string& address) -> bool { - // TODO handle verbs:// + auto addr2 = fContext.ConvertAddress(address); + if (addr2.Protocol != "tcp") + throw SocketError("Wrong protocol: Supported protocols are: tcp"); + if (zmq_bind(fMetaSocket, address.c_str()) != 0) { if (errno == EADDRINUSE) { // do not print error in this case, this is handled by FairMQDevice @@ -87,19 +104,26 @@ auto Socket::Bind(const string& address) -> bool return false; } + fWaitingForRemoteConnect = true; + return true; } auto Socket::Connect(const string& address) -> void { - // TODO handle verbs:// + auto addr2 = fContext.ConvertAddress(address); + if (addr2.Protocol != "tcp") + throw SocketError("Wrong protocol: Supported protocols are: tcp"); + if (zmq_connect(fMetaSocket, address.c_str()) != 0) { - throw SocketError{tools::ToString("Failed connecting socket ", fId, ", reason: ", zmq_strerror(errno))}; + throw SocketError(tools::ToString("Failed connecting socket ", fId, ", reason: ", zmq_strerror(errno))); } fContext.InitOfi(ConnectionType::Connect, address); InitDataEndpoint(); + + fRemoteAddr = fContext.InsertAddressVector(fContext.ConvertAddress(addr2)); } auto Socket::InitDataEndpoint() -> void @@ -129,6 +153,42 @@ auto Socket::InitDataEndpoint() -> void } } +auto Socket::WaitForRemoteConnect() -> void +{ + assert(fWaitingForRemoteConnect); + + // First frame in message contains event number and value + zmq_msg_t msg; + zmq_msg_init(&msg); + if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) + throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); + + uint8_t* data = (uint8_t*) zmq_msg_data(&msg); + uint16_t event = *(uint16_t*)(data); + int value = *(uint32_t *)(data + 2); + + // Second frame in message contains event address + zmq_msg_init(&msg); + if (zmq_msg_recv(&msg, fMonitorSocket, 0) == -1) + throw SocketError(tools::ToString("Failed to get monitor event, reason: ", zmq_strerror(errno))); + + string localAddress = string(static_cast(zmq_msg_data(&msg)), zmq_msg_size(&msg)); + + assert(event == ZMQ_EVENT_ACCEPTED); // we only subscribed for this event + + sockaddr_in remoteAddr; + socklen_t addrSize = sizeof(sockaddr_in); + int ret = getpeername(value, (sockaddr*)&remoteAddr, &addrSize); + if (ret != 0) + throw SocketError(tools::ToString("Failed retrieving peer address, reason: ", strerror(errno))); + string remoteIp(inet_ntoa(remoteAddr.sin_addr)); + int remotePort = ntohs(remoteAddr.sin_port); + LOG(debug) << "peer connected from " << remoteIp << ":" << remotePort << " at " << localAddress; + + fRemoteAddr = fContext.InsertAddressVector(remoteAddr); + fWaitingForRemoteConnect = false; +} + auto Socket::Send(MessagePtr& msg, const int timeout) -> int { return SendImpl(msg, 0, timeout); } auto Socket::Receive(MessagePtr& msg, const int timeout) -> int { return ReceiveImpl(msg, 0, timeout); } auto Socket::Send(std::vector& msgVec, const int timeout) -> int64_t { return SendImpl(msgVec, 0, timeout); } @@ -141,9 +201,18 @@ auto Socket::TryReceive(std::vector& msgVec) -> int64_t { return Rec auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int { - void* metadata = malloc(sizeof(size_t)); + if (fWaitingForRemoteConnect) { + try { + WaitForRemoteConnect(); + } catch (const std::exception& e) { + LOG(error) << e.what(); + return -1; + } + } + + // void* metadata = malloc(sizeof(size_t)); - auto ret = zmq_send(fMetaSocket, &metadata, sizeof(int), flags); + auto ret = zmq_send(fMetaSocket, nullptr, 0, flags); if (ret == EAGAIN) { return -2; } else if (ret < 0) { @@ -341,26 +410,28 @@ auto Socket::ReceiveImpl(vector& msgVec, const int flags, cons auto Socket::Close() -> void { - if (zmq_close(fMetaSocket) != 0) { - throw SocketError{tools::ToString("Failed closing zmq socket, reason: ", zmq_strerror(errno))}; - } + if (zmq_close(fMetaSocket) != 0) + throw SocketError(tools::ToString("Failed closing zmq meta socket, reason: ", zmq_strerror(errno))); + + if (zmq_close(fMonitorSocket) != 0) + throw SocketError(tools::ToString("Failed closing zmq monitor socket, reason: ", zmq_strerror(errno))); if (fDataEndpoint) { auto ret = fi_close(&fDataEndpoint->fid); if (ret != FI_SUCCESS) - LOG(error) << "Failed closing ofi endpoint, reason: " << fi_strerror(ret); + throw SocketError(tools::ToString("Failed closing ofi endpoint, reason: ", fi_strerror(ret))); } if (fDataCompletionQueueTx) { auto ret = fi_close(&fDataCompletionQueueTx->fid); if (ret != FI_SUCCESS) - LOG(error) << "Failed closing ofi transmit completion queue, reason: " << fi_strerror(ret); + throw SocketError(tools::ToString("Failed closing ofi transmit completion queue, reason: ", fi_strerror(ret))); } if (fDataCompletionQueueRx) { auto ret = fi_close(&fDataCompletionQueueRx->fid); if (ret != FI_SUCCESS) - LOG(error) << "Failed closing ofi receive completion queue, reason: " << fi_strerror(ret); + throw SocketError(tools::ToString("Failed closing ofi receive completion queue, reason: ", fi_strerror(ret))); } } @@ -540,7 +611,11 @@ auto Socket::GetConstant(const string& constant) -> int Socket::~Socket() { - Close(); // NOLINT(clang-analyzer-optin.cplusplus.VirtualCall) + try { + Close(); // NOLINT(clang-analyzer-optin.cplusplus.VirtualCall) + } catch (SocketError& e) { + LOG(error) << e.what(); + } } } /* namespace ofi */ diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index 77820ca3..44d12ad1 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -75,6 +75,7 @@ class Socket : public fair::mq::Socket private: void* fMetaSocket; + void* fMonitorSocket; fid_ep* fDataEndpoint; fid_cq* fDataCompletionQueueTx; fid_cq* fDataCompletionQueueRx; @@ -84,6 +85,8 @@ class Socket : public fair::mq::Socket std::atomic fMessagesTx; std::atomic fMessagesRx; Context& fContext; + fi_addr_t fRemoteAddr; + bool fWaitingForRemoteConnect; int fSndTimeout; int fRcvTimeout; @@ -94,6 +97,7 @@ class Socket : public fair::mq::Socket auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; auto InitDataEndpoint() -> void; + auto WaitForRemoteConnect() -> void; }; /* class Socket */ } /* namespace ofi */