FairMQ: Implement blocking ofi::Socket Send/Receive with FI_MSG

Completion events are not yet working.
This commit is contained in:
Dennis Klein 2018-03-07 00:14:13 +01:00 committed by Mohammad Al-Turany
parent 8f5b888314
commit 144aa912d7
11 changed files with 189 additions and 32 deletions

View File

@ -209,7 +209,7 @@ add_custom_command(
${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc ${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} -I=${CMAKE_CURRENT_SOURCE_DIR}/ofi --cpp_out=${CMAKE_CURRENT_BINARY_DIR}/ofi Control.proto COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} -I=${CMAKE_CURRENT_SOURCE_DIR}/ofi --cpp_out=${CMAKE_CURRENT_BINARY_DIR}/ofi Control.proto
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS mkofibuilddir DEPENDS mkofibuilddir ${CMAKE_CURRENT_SOURCE_DIR}/ofi/Control.proto
) )
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h PROPERTIES GENERATED TRUE) set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.h PROPERTIES GENERATED TRUE)
set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc PROPERTIES GENERATED TRUE) set_source_files_properties(${CMAKE_CURRENT_BINARY_DIR}/ofi/Control.pb.cc PROPERTIES GENERATED TRUE)

View File

@ -11,6 +11,7 @@
#include <FairMQLogger.h> #include <FairMQLogger.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <boost/version.hpp>
#include <cstring> #include <cstring>
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <memory> #include <memory>
@ -41,18 +42,37 @@ Context::Context(int numberIoThreads)
, fOfiAddressVector(nullptr) , fOfiAddressVector(nullptr)
, fOfiEventQueue(nullptr) , fOfiEventQueue(nullptr)
, fZmqContext(zmq_ctx_new()) , fZmqContext(zmq_ctx_new())
, fIoWork(fIoContext)
{ {
if (!fZmqContext) if (!fZmqContext)
throw ContextError{tools::ToString("Failed creating zmq context, reason: ", zmq_strerror(errno))}; throw ContextError{tools::ToString("Failed creating zmq context, reason: ", zmq_strerror(errno))};
GOOGLE_PROTOBUF_VERIFY_VERSION; GOOGLE_PROTOBUF_VERIFY_VERSION;
InitThreadPool(numberIoThreads);
}
auto Context::InitThreadPool(int numberIoThreads) -> void
{
assert(numberIoThreads > 0);
for (int i = 1; i <= numberIoThreads; ++i) {
fThreadPool.emplace_back([&, i, numberIoThreads]{
LOG(debug) << "I/O thread #" << i << "/" << numberIoThreads << " started";
fIoContext.run();
LOG(debug) << "I/O thread #" << i << "/" << numberIoThreads << " stopped";
});
}
} }
Context::~Context() Context::~Context()
{ {
if (zmq_ctx_term(fZmqContext) != 0) { fIoContext.stop();
for (auto& thread : fThreadPool)
thread.join();
if (zmq_ctx_term(fZmqContext) != 0)
LOG(error) << "Failed closing zmq context, reason: " << zmq_strerror(errno); LOG(error) << "Failed closing zmq context, reason: " << zmq_strerror(errno);
}
if (fOfiEventQueue) { if (fOfiEventQueue) {
auto ret = fi_close(&fOfiEventQueue->fid); auto ret = fi_close(&fOfiEventQueue->fid);
@ -97,6 +117,11 @@ auto Context::GetPbVersion() const -> string
return google::protobuf::internal::VersionString(GOOGLE_PROTOBUF_VERSION); return google::protobuf::internal::VersionString(GOOGLE_PROTOBUF_VERSION);
} }
auto Context::GetBoostVersion() const -> std::string
{
return tools::ToString(BOOST_VERSION / 100000, ".", BOOST_VERSION / 100 % 1000, ".", BOOST_VERSION % 100);
}
auto Context::InitOfi(ConnectionType type, std::string addr) -> void auto Context::InitOfi(ConnectionType type, std::string addr) -> void
{ {
auto addr2 = ConvertAddress(addr); auto addr2 = ConvertAddress(addr);

View File

@ -9,12 +9,15 @@
#ifndef FAIR_MQ_OFI_CONTEXT_H #ifndef FAIR_MQ_OFI_CONTEXT_H
#define FAIR_MQ_OFI_CONTEXT_H #define FAIR_MQ_OFI_CONTEXT_H
#include <boost/asio.hpp>
#include <memory> #include <memory>
#include <netinet/in.h> #include <netinet/in.h>
#include <ostream> #include <ostream>
#include <rdma/fabric.h> #include <rdma/fabric.h>
#include <string>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <thread>
#include <vector>
namespace fair namespace fair
{ {
@ -35,7 +38,7 @@ enum class Direction : bool { Receive, Transmit };
class Context class Context
{ {
public: public:
Context(int numberIoThreads = 1); Context(int numberIoThreads = 2);
~Context(); ~Context();
auto InitOfi(ConnectionType type, std::string address) -> void; auto InitOfi(ConnectionType type, std::string address) -> void;
@ -44,7 +47,9 @@ class Context
auto GetZmqVersion() const -> std::string; auto GetZmqVersion() const -> std::string;
auto GetOfiApiVersion() const -> std::string; auto GetOfiApiVersion() const -> std::string;
auto GetPbVersion() const -> std::string; auto GetPbVersion() const -> std::string;
auto GetBoostVersion() const -> std::string;
auto GetZmqContext() const -> void* { return fZmqContext; } auto GetZmqContext() const -> void* { return fZmqContext; }
auto GetIoContext() -> boost::asio::io_service& { return fIoContext; }
auto InsertAddressVector(sockaddr_in address) -> fi_addr_t; auto InsertAddressVector(sockaddr_in address) -> fi_addr_t;
struct Address { struct Address {
std::string Protocol; std::string Protocol;
@ -64,11 +69,15 @@ class Context
fid_domain* fOfiDomain; fid_domain* fOfiDomain;
fid_av* fOfiAddressVector; fid_av* fOfiAddressVector;
fid_eq* fOfiEventQueue; fid_eq* fOfiEventQueue;
boost::asio::io_service fIoContext;
boost::asio::io_service::work fIoWork;
std::vector<std::thread> fThreadPool;
auto OpenOfiFabric() -> void; auto OpenOfiFabric() -> void;
auto OpenOfiEventQueue() -> void; auto OpenOfiEventQueue() -> void;
auto OpenOfiDomain() -> void; auto OpenOfiDomain() -> void;
auto OpenOfiAddressVector() -> void; auto OpenOfiAddressVector() -> void;
auto InitThreadPool(int numberIoThreads) -> void;
}; /* class Context */ }; /* class Context */
struct ContextError : std::runtime_error { using std::runtime_error::runtime_error; }; struct ContextError : std::runtime_error { using std::runtime_error::runtime_error; };

View File

@ -8,8 +8,18 @@ message DataAddressAnnouncement {
uint32 port = 2; // in_port_t from <netinet/in.h> uint32 port = 2; // in_port_t from <netinet/in.h>
} }
message PostBuffer {
uint64 size = 1; // buffer size (size_t)
}
message PostBufferAcknowledgement {
uint64 size = 1; // size_t
}
message ControlMessage { message ControlMessage {
oneof type { oneof type {
DataAddressAnnouncement data_address_announcement = 1; DataAddressAnnouncement data_address_announcement = 1;
PostBuffer post_buffer = 2;
PostBufferAcknowledgement post_buffer_acknowledgement = 3;
} }
} }

View File

@ -23,6 +23,11 @@ namespace ofi
using namespace std; using namespace std;
Message::Message() Message::Message()
: fInitialSize(0)
, fSize(0)
, fData(nullptr)
, fFreeFunction(nullptr)
, fHint(nullptr)
{ {
} }
@ -30,12 +35,18 @@ Message::Message(const size_t size)
: fInitialSize(size) : fInitialSize(size)
, fSize(size) , fSize(size)
, fData(nullptr) , fData(nullptr)
, fFreeFunction(nullptr)
, fHint(nullptr)
{ {
} }
Message::Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) Message::Message(void* data, const size_t size, fairmq_free_fn* ffn, void* hint)
: fInitialSize(size)
, fSize(size)
, fData(data)
, fFreeFunction(ffn)
, fHint(hint)
{ {
throw MessageError{"Not yet implemented."};
} }
Message::Message(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint) Message::Message(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint)
@ -45,17 +56,48 @@ Message::Message(FairMQUnmanagedRegionPtr& region, void* data, const size_t size
auto Message::Rebuild() -> void auto Message::Rebuild() -> void
{ {
throw MessageError{"Not implemented."}; if (fFreeFunction) {
fFreeFunction(fData, fHint);
} else {
free(fData);
}
fData = nullptr;
fInitialSize = 0;
fSize = 0;
fFreeFunction = nullptr;
fHint = nullptr;
} }
auto Message::Rebuild(const size_t size) -> void auto Message::Rebuild(const size_t size) -> void
{ {
throw MessageError{"Not implemented."}; if (fFreeFunction) {
fFreeFunction(fData, fHint);
fData = nullptr;
fData = malloc(size);
} else {
fData = realloc(fData, size);
}
assert(fData);
fInitialSize = size;
fSize = size;
fFreeFunction = nullptr;
fHint = nullptr;
} }
auto Message::Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) -> void auto Message::Rebuild(void* data, const size_t size, fairmq_free_fn* ffn, void* hint) -> void
{ {
throw MessageError{"Not implemented."}; if (fFreeFunction) {
fFreeFunction(fData, fHint);
fData = nullptr;
fData = malloc(size);
} else {
fData = realloc(fData, size);
}
assert(fData);
fInitialSize = size;
fSize = size;
fFreeFunction = ffn;
fHint = hint;
} }
auto Message::GetData() const -> void* auto Message::GetData() const -> void*
@ -91,6 +133,11 @@ auto Message::Copy(const fair::mq::MessagePtr& msg) -> void
Message::~Message() Message::~Message()
{ {
if (fFreeFunction) {
fFreeFunction(fData, fHint);
} else {
free(fData);
}
} }
} /* namespace ofi */ } /* namespace ofi */

View File

@ -61,6 +61,8 @@ class Message : public fair::mq::Message
size_t fInitialSize; size_t fInitialSize;
size_t fSize; size_t fSize;
void* fData; void* fData;
fairmq_free_fn* fFreeFunction;
void* fHint;
}; /* class Message */ }; /* class Message */
} /* namespace ofi */ } /* namespace ofi */

View File

@ -42,6 +42,7 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
, fRcvTimeout(100) , fRcvTimeout(100)
, fContext(context) , fContext(context)
, fWaitingForControlPeer(false) , fWaitingForControlPeer(false)
, fIoStrand(fContext.GetIoContext())
{ {
if (type != "pair") { if (type != "pair") {
throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")}; throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")};
@ -206,6 +207,7 @@ try {
auto Socket::SendControlMessage(unique_ptr<ControlMessage> ctrl) -> void auto Socket::SendControlMessage(unique_ptr<ControlMessage> ctrl) -> void
{ {
assert(fControlSocket); assert(fControlSocket);
LOG(debug) << "About to send control message: " << ctrl->DebugString();
// Serialize // Serialize
string* str = new string(); string* str = new string();
@ -234,6 +236,7 @@ auto Socket::ReceiveControlMessage() -> unique_ptr<ControlMessage>
auto ctrl = tools::make_unique<ControlMessage>(); auto ctrl = tools::make_unique<ControlMessage>();
ctrl->ParseFromArray(zmq_msg_data(&msg), zmq_msg_size(&msg)); ctrl->ParseFromArray(zmq_msg_data(&msg), zmq_msg_size(&msg));
LOG(debug) << "Received control message: " << ctrl->DebugString();
return ctrl; return ctrl;
} }
@ -293,11 +296,36 @@ try {
ProcessDataAddressAnnouncement(ReceiveControlMessage()); ProcessDataAddressAnnouncement(ReceiveControlMessage());
} }
auto ret = zmq_send(fControlSocket, nullptr, 0, flags); auto size = msg->GetSize();
if (ret == EAGAIN) throw SilentSocketError("EAGAIN");
if (ret == -1) throw SocketError(tools::ToString("Failed sending control message on socket ", fId, ", reason: ", zmq_strerror(errno)));
return ret; // Create and send control message
auto ctrl = tools::make_unique<ControlMessage>();
auto buf = tools::make_unique<PostBuffer>();
buf->set_size(size);
ctrl->set_allocated_post_buffer(buf.release());
assert(ctrl->IsInitialized());
SendControlMessage(move(ctrl));
if (size) {
// Receive and process control message
auto ctrl2 = ReceiveControlMessage();
assert(ctrl2->has_post_buffer_acknowledgement());
assert(ctrl2->post_buffer_acknowledgement().size() == size);
// Send data
auto ret = fi_send(fDataEndpoint, msg->GetData(), size, nullptr, fRemoteDataAddr, nullptr);
if (ret != FI_SUCCESS)
throw SocketError(tools::ToString("Failed posting ofi send buffer, reason: ", fi_strerror(ret)));
fi_cq_err_entry cqEntry;
ret = fi_cq_sread(fDataCompletionQueueTx, &cqEntry, 1, nullptr, 1000);
if (ret != 1)
throw SocketError(tools::ToString("Failed reading ofi tx completion queue event, reason: ", fi_strerror(ret)));
}
// TODO free msg on tx completion?
return size;
} }
catch (const SilentSocketError& e) catch (const SilentSocketError& e)
{ {
@ -317,11 +345,35 @@ try {
ProcessDataAddressAnnouncement(ReceiveControlMessage()); ProcessDataAddressAnnouncement(ReceiveControlMessage());
} }
auto ret = zmq_recv(fControlSocket, nullptr, 0, flags); // Receive and process control message
if (ret == EAGAIN) throw SilentSocketError("EAGAIN"); auto ctrl = ReceiveControlMessage();
if (ret == -1) throw SocketError(tools::ToString("Failed sending control message on socket ", fId, ", reason: ", zmq_strerror(errno))); assert(ctrl->has_post_buffer());
auto postBuffer = ctrl->post_buffer();
auto size = postBuffer.size();
LOG(debug) << "Received post buffer control message with size: " << size;
return ret; // Receive data
if (size) {
msg->Rebuild(size);
auto ret = fi_recv(fDataEndpoint, msg->GetData(), msg->GetSize(), nullptr, fRemoteDataAddr, nullptr);
if (ret != FI_SUCCESS)
throw SocketError(tools::ToString("Failed posting ofi receive buffer, reason: ", fi_strerror(ret)));
// Create and send control message
auto ctrl2 = tools::make_unique<ControlMessage>();
auto ack = tools::make_unique<PostBufferAcknowledgement>();
ack->set_size(msg->GetSize());
ctrl2->set_allocated_post_buffer_acknowledgement(ack.release());
assert(ctrl2->IsInitialized());
SendControlMessage(move(ctrl2));
fi_cq_err_entry cqEntry;
ret = fi_cq_sread(fDataCompletionQueueRx, &cqEntry, 1, nullptr, 1000);
if (ret != 1)
throw SocketError(tools::ToString("Failed reading ofi rx completion queue event, reason: ", fi_strerror(ret)));
}
return size;
} }
catch (const SilentSocketError& e) catch (const SilentSocketError& e)
{ {

View File

@ -14,6 +14,7 @@
#include <fairmq/ofi/Context.h> #include <fairmq/ofi/Context.h>
#include <fairmq/ofi/Control.pb.h> #include <fairmq/ofi/Control.pb.h>
#include <boost/asio.hpp>
#include <memory> // unique_ptr #include <memory> // unique_ptr
#include <netinet/in.h> #include <netinet/in.h>
#include <rdma/fabric.h> #include <rdma/fabric.h>
@ -90,6 +91,7 @@ class Socket : public fair::mq::Socket
fi_addr_t fRemoteDataAddr; fi_addr_t fRemoteDataAddr;
sockaddr_in fLocalDataAddr; sockaddr_in fLocalDataAddr;
bool fWaitingForControlPeer; bool fWaitingForControlPeer;
boost::asio::io_service::strand fIoStrand;
int fSndTimeout; int fSndTimeout;
int fRcvTimeout; int fRcvTimeout;

View File

@ -28,9 +28,13 @@ try : FairMQTransportFactory{id}
{ {
LOG(debug) << "Transport: Using ZeroMQ (" << fContext.GetZmqVersion() << ") & " LOG(debug) << "Transport: Using ZeroMQ (" << fContext.GetZmqVersion() << ") & "
<< "OFI libfabric (API " << fContext.GetOfiApiVersion() << ") & " << "OFI libfabric (API " << fContext.GetOfiApiVersion() << ") & "
<< "Google Protobuf (" << fContext.GetPbVersion() << ")"; << "Google Protobuf (" << fContext.GetPbVersion() << ") & "
<< "Boost.Asio (" << fContext.GetBoostVersion() << ")";
}
catch (ContextError& e)
{
throw TransportFactoryError{e.what()};
} }
catch (ContextError& e) { throw TransportFactoryError{e.what()}; }
auto TransportFactory::CreateMessage() const -> MessagePtr auto TransportFactory::CreateMessage() const -> MessagePtr
{ {

View File

@ -34,19 +34,23 @@ class PairLeft : public FairMQDevice
// Simple empty message ping pong // Simple empty message ping pong
auto msg1{NewMessageFor("data", 0)}; auto msg1{NewMessageFor("data", 0)};
auto msg2{NewMessageFor("data", 0)};
auto msg3{NewMessageFor("data", 0)};
if (Send(msg1, "data") >= 0) counter++; if (Send(msg1, "data") >= 0) counter++;
auto msg2{NewMessageFor("data", 0)};
if (Receive(msg2, "data") >= 0) counter++; if (Receive(msg2, "data") >= 0) counter++;
if (Send(msg2, "data") >= 0) counter++; auto msg3{NewMessageFor("data", 0)};
if (Receive(msg3, "data") >= 0) counter++; if (Send(msg3, "data") >= 0) counter++;
auto msg4{NewMessageFor("data", 0)};
if (Receive(msg4, "data") >= 0) counter++;
if (counter == 4) LOG(info) << "Simple empty message ping pong successfull"; if (counter == 4) LOG(info) << "Simple empty message ping pong successfull";
// Simple message with short text data // Simple message with short text data
auto msg4{NewSimpleMessageFor("data", 0, "testdata1234")}; auto msg5{NewSimpleMessageFor("data", 0, "testdata1234")};
if (Send(msg4, "data") >= 0) counter++; LOG(info) << "Will send msg5";
if (Send(msg5, "data") >= 0) counter++;
LOG(info) << "Sent msg5";
if (counter == 5) LOG(info) << "Simple message with short text data successfull"; if (counter == 5) LOG(info) << "Simple message with short text data successfull";
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
assert(counter == 5); assert(counter == 5);
}; };
}; };

View File

@ -36,19 +36,21 @@ class PairRight : public FairMQDevice
// Simple empty message ping pong // Simple empty message ping pong
auto msg1{NewMessageFor("data", 0)}; auto msg1{NewMessageFor("data", 0)};
if (Receive(msg1, "data") >= 0) counter++; if (Receive(msg1, "data") >= 0) counter++;
if (Send(msg1, "data") >= 0) counter++;
auto msg2{NewMessageFor("data", 0)}; auto msg2{NewMessageFor("data", 0)};
if (Receive(msg2, "data") >= 0) counter++;
if (Send(msg2, "data") >= 0) counter++; if (Send(msg2, "data") >= 0) counter++;
auto msg3{NewMessageFor("data", 0)};
if (Receive(msg3, "data") >= 0) counter++;
auto msg4{NewMessageFor("data", 0)};
if (Send(msg4, "data") >= 0) counter++;
if (counter == 4) LOG(info) << "Simple empty message ping pong successfull"; if (counter == 4) LOG(info) << "Simple empty message ping pong successfull";
// Simple message with short text data // Simple message with short text data
auto msg3{NewMessageFor("data", 0)}; auto msg5{NewMessageFor("data", 0)};
auto ret = Receive(msg3, "data"); auto ret = Receive(msg5, "data");
if (ret > 0) { if (ret > 0) {
auto content = std::string{static_cast<char*>(msg3->GetData()), msg3->GetSize()}; auto content = std::string{static_cast<char*>(msg5->GetData()), msg5->GetSize()};
LOG(info) << ret << ", " << msg3->GetSize() << ", '" << content << "'"; LOG(info) << ret << ", " << msg5->GetSize() << ", '" << content << "'";
if (msg3->GetSize() == ret && content == "testdata1234") counter++; if (msg5->GetSize() == ret && content == "testdata1234") counter++;
} }
if (counter == 5) LOG(info) << "Simple message with short text data successfull"; if (counter == 5) LOG(info) << "Simple message with short text data successfull";