From 92320604a93c1678b2275aca66cd3251e9c134d6 Mon Sep 17 00:00:00 2001 From: Dennis Klein Date: Tue, 27 Feb 2018 20:38:26 +0100 Subject: [PATCH] FairMQ: Add ofi address vector Translates between native and ofi addresses. Needed for unconnected endpoints. --- fairmq/FairMQMessage.h | 2 +- fairmq/ofi/Context.cxx | 78 ++++++++++++++++++++++++++++++++++-------- fairmq/ofi/Context.h | 5 ++- fairmq/ofi/Message.cxx | 10 +++--- fairmq/ofi/Message.h | 4 ++- fairmq/ofi/Socket.cxx | 53 +++++++++++++--------------- fairmq/ofi/Socket.h | 2 ++ 7 files changed, 104 insertions(+), 50 deletions(-) diff --git a/fairmq/FairMQMessage.h b/fairmq/FairMQMessage.h index 78431d96..69998217 100644 --- a/fairmq/FairMQMessage.h +++ b/fairmq/FairMQMessage.h @@ -33,7 +33,7 @@ class FairMQMessage virtual void Copy(const std::unique_ptr& msg) __attribute__((deprecated("Use 'Copy(const FairMQMessage& msg)'"))) = 0; virtual void Copy(const FairMQMessage& msg) = 0; - virtual ~FairMQMessage() noexcept(false) {}; + virtual ~FairMQMessage() {}; }; using FairMQMessagePtr = std::unique_ptr; diff --git a/fairmq/ofi/Context.cxx b/fairmq/ofi/Context.cxx index 9924cbaa..b9bb927f 100644 --- a/fairmq/ofi/Context.cxx +++ b/fairmq/ofi/Context.cxx @@ -12,6 +12,7 @@ #include #include +#include #include #include #include @@ -19,6 +20,7 @@ #include #include #include +#include #include namespace fair @@ -45,10 +47,10 @@ Context::~Context() LOG(error) << "Failed closing zmq context, reason: " << zmq_strerror(errno); } - if (fOfiFabric) { - auto ret = fi_close(&fOfiFabric->fid); + if (fOfiAddressVector) { + auto ret = fi_close(&fOfiAddressVector->fid); if (ret != FI_SUCCESS) - LOG(error) << "Failed closing ofi fabric, reason: " << fi_strerror(ret); + LOG(error) << "Failed closing ofi address vector, reason: " << fi_strerror(ret); } if (fOfiDomain) { @@ -56,6 +58,12 @@ Context::~Context() if (ret != FI_SUCCESS) LOG(error) << "Failed closing ofi domain, reason: " << fi_strerror(ret); } + + if (fOfiFabric) { + auto ret = fi_close(&fOfiFabric->fid); + if (ret != FI_SUCCESS) + LOG(error) << "Failed closing ofi fabric, reason: " << fi_strerror(ret); + } } auto Context::GetZmqVersion() const -> std::string @@ -89,23 +97,26 @@ auto Context::InitOfi(ConnectionType type, std::string address) -> void if (!fOfiInfo) { sockaddr_in* sa = static_cast(malloc(sizeof(sockaddr_in))); inet_pton(AF_INET, ip.c_str(), &(sa->sin_addr)); - sa->sin_port = port; + sa->sin_port = htons(port); sa->sin_family = AF_INET; // Prepare fi_getinfo query unique_ptr ofi_hints(fi_allocinfo(), fi_freeinfo); - ofi_hints->caps = FI_MSG; + ofi_hints->caps = FI_MSG | FI_SOURCE; ofi_hints->mode = FI_ASYNC_IOV; ofi_hints->addr_format = FI_SOCKADDR_IN; ofi_hints->fabric_attr->prov_name = strdup("sockets"); ofi_hints->ep_attr->type = FI_EP_RDM; - if (type == ConnectionType::Bind) { - ofi_hints->src_addr = sa; - ofi_hints->src_addrlen = sizeof(sockaddr_in); - } else { + ofi_hints->domain_attr->threading = FI_THREAD_SAFE; + ofi_hints->domain_attr->control_progress = FI_PROGRESS_AUTO; + ofi_hints->domain_attr->data_progress = FI_PROGRESS_AUTO; + // if (type == ConnectionType::Bind) { + // ofi_hints->src_addr = sa; + // ofi_hints->src_addrlen = sizeof(sockaddr_in); + // } else { ofi_hints->dest_addr = sa; ofi_hints->dest_addrlen = sizeof(sockaddr_in); - } + // } // Query fi_getinfo for fabric to use auto res = fi_getinfo(FI_VERSION(1, 5), nullptr, nullptr, 0, ofi_hints.get(), &fOfiInfo); @@ -113,7 +124,7 @@ auto Context::InitOfi(ConnectionType type, std::string address) -> void if (!fOfiInfo) throw ContextError{"Could not find any ofi compatible fabric."}; // for(auto cursor{ofi_info}; cursor->next != nullptr; cursor = cursor->next) { - // LOG(debug) << fi_tostr(cursor, FI_TYPE_INFO); + // LOG(debug) << fi_tostr(fOfiInfo, FI_TYPE_INFO); // } // } else { @@ -122,12 +133,14 @@ auto Context::InitOfi(ConnectionType type, std::string address) -> void OpenOfiFabric(); OpenOfiDomain(); + OpenOfiAddressVector(); } auto Context::OpenOfiFabric() -> void { if (!fOfiFabric) { - auto ret = fi_fabric(fOfiInfo->fabric_attr, &fOfiFabric, NULL); + assert(fOfiInfo); + auto ret = fi_fabric(fOfiInfo->fabric_attr, &fOfiFabric, nullptr); if (ret != FI_SUCCESS) throw ContextError{tools::ToString("Failed opening ofi fabric, reason: ", fi_strerror(ret))}; } else { @@ -140,7 +153,9 @@ auto Context::OpenOfiFabric() -> void auto Context::OpenOfiDomain() -> void { if (!fOfiDomain) { - auto ret = fi_domain(fOfiFabric, fOfiInfo, &fOfiDomain, NULL); + assert(fOfiInfo); + assert(fOfiFabric); + auto ret = fi_domain(fOfiFabric, fOfiInfo, &fOfiDomain, nullptr); if (ret != FI_SUCCESS) throw ContextError{tools::ToString("Failed opening ofi domain, reason: ", fi_strerror(ret))}; } else { @@ -148,19 +163,54 @@ auto Context::OpenOfiDomain() -> void } } +auto Context::OpenOfiAddressVector() -> void +{ + if (!fOfiAddressVector) { + assert(fOfiDomain); + fi_av_attr attr = {fOfiInfo->domain_attr->av_type, 0, 1000, 0, nullptr, nullptr, 0}; +// struct fi_av_attr { +// enum fi_av_type type; [> type of AV <] +// int rx_ctx_bits; [> address bits to identify rx ctx <] +// size_t count; [> # entries for AV <] +// size_t ep_per_node; [> # endpoints per fabric address <] +// const char *name; [> system name of AV <] +// void *map_addr; [> base mmap address <] +// uint64_t flags; [> operation flags <] +// }; + auto ret = fi_av_open(fOfiDomain, &attr, &fOfiAddressVector, nullptr); + if (ret != FI_SUCCESS) + throw ContextError{tools::ToString("Failed opening ofi address vector, reason: ", fi_strerror(ret))}; + } else { + LOG(debug) << "Ofi address vector already opened. Skipping."; + } +} + auto Context::CreateOfiEndpoint() -> fid_ep* { + assert(fOfiDomain); + assert(fOfiInfo); fid_ep* ep = nullptr; auto ret = fi_endpoint(fOfiDomain, fOfiInfo, &ep, nullptr); if (ret != FI_SUCCESS) throw ContextError{tools::ToString("Failed creating ofi endpoint, reason: ", fi_strerror(ret))}; + + assert(fOfiAddressVector); + ret = fi_ep_bind(ep, &fOfiAddressVector->fid, 0); + if (ret != FI_SUCCESS) + throw ContextError{tools::ToString("Failed binding ofi address vector to ofi endpoint, reason: ", fi_strerror(ret))}; + return ep; } -auto Context::CreateOfiCompletionQueue() -> fid_cq* +auto Context::CreateOfiCompletionQueue(Direction dir) -> fid_cq* { fid_cq* cq = nullptr; fi_cq_attr attr = {0, 0, FI_CQ_FORMAT_DATA, FI_WAIT_UNSPEC, 0, FI_CQ_COND_NONE, nullptr}; + if (dir == Direction::Receive) { + attr.size = fOfiInfo->rx_attr->size; + } else { + attr.size = fOfiInfo->tx_attr->size; + } // size_t size; [> # entries for CQ <] // uint64_t flags; [> operation flags <] // enum fi_cq_format format; [> completion format <] diff --git a/fairmq/ofi/Context.h b/fairmq/ofi/Context.h index 75c0c184..7dbd4d19 100644 --- a/fairmq/ofi/Context.h +++ b/fairmq/ofi/Context.h @@ -22,6 +22,7 @@ namespace ofi { enum class ConnectionType : bool { Bind, Connect }; +enum class Direction : bool { Receive, Transmit }; /** * @class Context Context.h @@ -38,7 +39,7 @@ class Context /// Deferred Ofi initialization auto InitOfi(ConnectionType type, std::string address) -> void; auto CreateOfiEndpoint() -> fid_ep*; - auto CreateOfiCompletionQueue() -> fid_cq*; + auto CreateOfiCompletionQueue(Direction dir) -> fid_cq*; auto GetZmqVersion() const -> std::string; auto GetOfiApiVersion() const -> std::string; auto GetZmqContext() const -> void* { return fZmqContext; } @@ -48,9 +49,11 @@ class Context fi_info* fOfiInfo; fid_fabric* fOfiFabric; fid_domain* fOfiDomain; + fid_av* fOfiAddressVector; auto OpenOfiFabric() -> void; auto OpenOfiDomain() -> void; + auto OpenOfiAddressVector() -> void; }; /* class Context */ struct ContextError : std::runtime_error { using std::runtime_error::runtime_error; }; diff --git a/fairmq/ofi/Message.cxx b/fairmq/ofi/Message.cxx index dbc93b6a..571be3b2 100644 --- a/fairmq/ofi/Message.cxx +++ b/fairmq/ofi/Message.cxx @@ -10,6 +10,7 @@ #include #include +#include #include namespace fair @@ -26,9 +27,10 @@ Message::Message() } Message::Message(const size_t size) - : fSize{size} + : fInitialSize(size) + , fSize(size) + , fData(nullptr) { - throw MessageError{"Not yet implemented."}; } Message::Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) @@ -58,7 +60,7 @@ auto Message::Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* auto Message::GetData() const -> void* { - throw MessageError{"Not implemented."}; + return fData; } auto Message::GetSize() const -> size_t @@ -87,7 +89,7 @@ auto Message::Copy(const fair::mq::MessagePtr& msg) -> void throw MessageError{"Not yet implemented."}; } -Message::~Message() noexcept(false) +Message::~Message() { } diff --git a/fairmq/ofi/Message.h b/fairmq/ofi/Message.h index efd63d2b..a50f8914 100644 --- a/fairmq/ofi/Message.h +++ b/fairmq/ofi/Message.h @@ -55,10 +55,12 @@ class Message : public fair::mq::Message auto Copy(const fair::mq::Message& msg) -> void override; auto Copy(const fair::mq::MessagePtr& msg) -> void override; - ~Message() noexcept(false) override; + ~Message() override; private: + size_t fInitialSize; size_t fSize; + void* fData; }; /* class Message */ } /* namespace ofi */ diff --git a/fairmq/ofi/Socket.cxx b/fairmq/ofi/Socket.cxx index a0df61eb..781e0c73 100644 --- a/fairmq/ofi/Socket.cxx +++ b/fairmq/ofi/Socket.cxx @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -27,6 +28,8 @@ using namespace std; Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/) : fDataEndpoint(nullptr) + , fDataCompletionQueueTx(nullptr) + , fDataCompletionQueueRx(nullptr) , fId(id + "." + name + "." + type) , fSndTimeout(100) , fRcvTimeout(100) @@ -64,6 +67,7 @@ Socket::Socket(Context& context, const string& type, const string& name, const s auto Socket::Bind(const string& address) -> bool { + // TODO handle verbs:// if (zmq_bind(fMetaSocket, address.c_str()) != 0) { if (errno == EADDRINUSE) { // do not print error in this case, this is handled by FairMQDevice @@ -76,28 +80,11 @@ auto Socket::Bind(const string& address) -> bool fContext.InitOfi(ConnectionType::Bind, address); - if (!fDataEndpoint) { - try { - fDataEndpoint = fContext.CreateOfiEndpoint(); - } catch (ContextError& e) { - LOG(error) << "Failed creating ofi endpoint for " << address << ", reason: " << e.what(); - } - - if (!fDataCompletionQueueTx) - fDataCompletionQueueTx = fContext.CreateOfiCompletionQueue(); - auto ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueTx->fid, FI_TRANSMIT); - if (ret != FI_SUCCESS) - LOG(error) << "Failed binding ofi transmit completion queue to endpoint, reason: " << fi_strerror(ret); - - if (!fDataCompletionQueueRx) - fDataCompletionQueueRx = fContext.CreateOfiCompletionQueue(); - ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueRx->fid, FI_RECV); - if (ret != FI_SUCCESS) - LOG(error) << "Failed binding ofi receive completion queue to endpoint, reason: " << fi_strerror(ret); - - ret = fi_enable(fDataEndpoint); - if (ret != FI_SUCCESS) - LOG(error) << "Failed opening ofi fabric, reason: " << fi_strerror(ret); + try { + InitDataEndpoint(); + } catch (SocketError& e) { + LOG(error) << e.what(); + return false; } return true; @@ -105,34 +92,40 @@ auto Socket::Bind(const string& address) -> bool auto Socket::Connect(const string& address) -> void { + // TODO handle verbs:// if (zmq_connect(fMetaSocket, address.c_str()) != 0) { throw SocketError{tools::ToString("Failed connecting socket ", fId, ", reason: ", zmq_strerror(errno))}; } fContext.InitOfi(ConnectionType::Connect, address); + InitDataEndpoint(); +} + +auto Socket::InitDataEndpoint() -> void +{ if (!fDataEndpoint) { try { fDataEndpoint = fContext.CreateOfiEndpoint(); } catch (ContextError& e) { - throw SocketError(tools::ToString("Failed creating ofi endpoint for ", address, ", reason: ", e.what())); + throw SocketError(tools::ToString("Failed creating ofi endpoint, reason: ", e.what())); } if (!fDataCompletionQueueTx) - fDataCompletionQueueTx = fContext.CreateOfiCompletionQueue(); + fDataCompletionQueueTx = fContext.CreateOfiCompletionQueue(Direction::Transmit); auto ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueTx->fid, FI_TRANSMIT); if (ret != FI_SUCCESS) - LOG(error) << "Failed binding ofi transmit completion queue to endpoint, reason: " << fi_strerror(ret); + throw SocketError(tools::ToString("Failed binding ofi transmit completion queue to endpoint, reason: ", fi_strerror(ret))); if (!fDataCompletionQueueRx) - fDataCompletionQueueRx = fContext.CreateOfiCompletionQueue(); + fDataCompletionQueueRx = fContext.CreateOfiCompletionQueue(Direction::Receive); ret = fi_ep_bind(fDataEndpoint, &fDataCompletionQueueRx->fid, FI_RECV); if (ret != FI_SUCCESS) - LOG(error) << "Failed binding ofi receive completion queue to endpoint, reason: " << fi_strerror(ret); + throw SocketError(tools::ToString("Failed binding ofi receive completion queue to endpoint, reason: ", fi_strerror(ret))); ret = fi_enable(fDataEndpoint); if (ret != FI_SUCCESS) - throw SocketError{tools::ToString("Failed opening ofi fabric, reason: ", fi_strerror(ret))}; + throw SocketError(tools::ToString("Failed opening ofi fabric, reason: ", fi_strerror(ret))); } } @@ -148,7 +141,8 @@ auto Socket::TryReceive(std::vector& msgVec) -> int64_t { return Rec auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) -> int { - auto metadata = new int; + void* metadata = malloc(sizeof(size_t)); + auto ret = zmq_send(fMetaSocket, &metadata, sizeof(int), flags); if (ret == EAGAIN) { return -2; @@ -156,6 +150,7 @@ auto Socket::SendImpl(FairMQMessagePtr& msg, const int flags, const int timeout) LOG(error) << "Failed sending meta message on socket " << fId << ", reason: " << zmq_strerror(errno); return -1; } else { + // auto ret2 = fi_send(fDataEndpoint, msg->GetData(), msg->GetSize(), nullptr, fi_addr_t dest_addr, nullptr); return ret; } } diff --git a/fairmq/ofi/Socket.h b/fairmq/ofi/Socket.h index dcf8b264..77820ca3 100644 --- a/fairmq/ofi/Socket.h +++ b/fairmq/ofi/Socket.h @@ -92,6 +92,8 @@ class Socket : public fair::mq::Socket auto ReceiveImpl(MessagePtr& msg, const int flags, const int timeout) -> int; auto SendImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; auto ReceiveImpl(std::vector& msgVec, const int flags, const int timeout) -> int64_t; + + auto InitDataEndpoint() -> void; }; /* class Socket */ } /* namespace ofi */