Implement parallel ofi::Socket::Receive

This commit is contained in:
Dennis Klein 2018-11-20 12:45:46 +01:00 committed by Dennis Klein
parent 46e2420547
commit 8e7cfacd78
5 changed files with 160 additions and 83 deletions

View File

@ -31,11 +31,12 @@ namespace ofi
using namespace std;
Context::Context(int numberIoThreads)
Context::Context(FairMQTransportFactory& receiveFactory, int numberIoThreads)
: fOfiInfo(nullptr)
, fOfiFabric(nullptr)
, fOfiDomain(nullptr)
, fIoWork(fIoContext)
, fReceiveFactory(receiveFactory)
{
InitThreadPool(numberIoThreads);
}
@ -176,6 +177,11 @@ auto Context::VerifyAddress(const std::string& address) -> Address
return addr;
}
/// Creates a message of the requested size through the receive factory
/// that was handed to this Context at construction.
///
/// @param size  size argument forwarded unchanged to the factory
/// @return      the message produced by fReceiveFactory.CreateMessage(size)
auto Context::MakeReceiveMessage(size_t size) -> MessagePtr
{
    auto message = fReceiveFactory.CreateMessage(size);
    return message;
}
} /* namespace ofi */
} /* namespace mq */
} /* namespace fair */

View File

@ -10,6 +10,7 @@
#define FAIR_MQ_OFI_CONTEXT_H
#include <FairMQLogger.h>
#include <FairMQTransportFactory.h>
#include <asiofi/connected_endpoint.hpp>
#include <asiofi/domain.hpp>
@ -44,7 +45,7 @@ enum class Direction : bool { Receive, Transmit };
class Context
{
public:
Context(int numberIoThreads = 1);
Context(FairMQTransportFactory& receiveFactory, int numberIoThreads = 1);
~Context();
// auto CreateOfiEndpoint() -> fid_ep*;
@ -66,6 +67,7 @@ class Context
auto GetDomain() const -> const asiofi::domain& { return *fOfiDomain; }
auto Interrupt() -> void { LOG(debug) << "OFI transport: Interrupted (NOOP - not implemented)."; }
auto Resume() -> void { LOG(debug) << "OFI transport: Resumed (NOOP - not implemented)."; }
auto MakeReceiveMessage(size_t size) -> MessagePtr;
private:
std::unique_ptr<asiofi::info> fOfiInfo;
@ -74,6 +76,7 @@ class Context
boost::asio::io_context fIoContext;
boost::asio::io_context::work fIoWork;
std::vector<std::thread> fThreadPool;
FairMQTransportFactory& fReceiveFactory;
auto InitThreadPool(int numberIoThreads) -> void;
auto InitOfi(ConnectionType type, Address address) -> void;

View File

@ -49,8 +49,11 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
, fControlEndpoint(fIoStrand.context(), ZMQ_PAIR)
, fSndTimeout(100)
, fRcvTimeout(100)
, fQueue1(fIoStrand.context())
, fQueue2(fIoStrand.context())
, fSendQueueWrite(fIoStrand.context(), ZMQ_PUSH)
, fSendQueueRead(fIoStrand.context(), ZMQ_PULL)
, fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH)
, fRecvQueueRead(fIoStrand.context(), ZMQ_PULL)
, fSentCount(0)
{
if (type != "pair") {
throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")};
@ -63,17 +66,28 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
// Setup internal queue
auto hashed_id = std::hash<std::string>()(fId);
auto queue_id = tools::ToString("inproc://QUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding Q1: " << queue_id;
fQueue1.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting Q2: " << queue_id;
fQueue2.connect(queue_id);
azmq::socket::snd_hwm send_max(100);
azmq::socket::rcv_hwm recv_max(100);
fQueue1.set_option(send_max);
fQueue1.set_option(recv_max);
fQueue2.set_option(send_max);
fQueue2.set_option(recv_max);
auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id;
fSendQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id;
fSendQueueWrite.connect(queue_id);
queue_id = tools::ToString("inproc://RXQUEUE", hashed_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id;
fRecvQueueRead.bind(queue_id);
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id;
fRecvQueueWrite.connect(queue_id);
// TODO wire this up with config
azmq::socket::snd_hwm send_max(10);
azmq::socket::rcv_hwm recv_max(10);
fSendQueueRead.set_option(send_max);
fSendQueueRead.set_option(recv_max);
fSendQueueWrite.set_option(send_max);
fSendQueueWrite.set_option(recv_max);
fRecvQueueRead.set_option(send_max);
fRecvQueueRead.set_option(recv_max);
fSendQueueWrite.set_option(send_max);
fSendQueueWrite.set_option(recv_max);
fControlEndpoint.set_option(send_max);
fControlEndpoint.set_option(recv_max);
}
@ -90,7 +104,8 @@ try {
fLocalDataAddr = addr;
BindDataEndpoint();
boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
return true;
}
@ -116,6 +131,9 @@ auto Socket::Connect(const string& address) -> bool
ReceiveDataAddressAnnouncement();
ConnectDataEndpoint();
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
}
auto Socket::BindControlEndpoint(Context::Address address) -> void
@ -225,11 +243,13 @@ auto Socket::AnnounceDataAddress() -> void
auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int
{
LOG(debug) << "OFI transport (" << fId << "): ENTER Send: size=" << msg->GetSize();
LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize();
MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg)));
try {
auto res = fQueue1.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0);
++fSentCount;
LOG(info) << fSentCount;
auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0);
LOG(debug) << "OFI transport (" << fId << "): LEAVE Send";
return res;
@ -244,16 +264,44 @@ auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int
}
}
auto Socket::Receive(MessagePtr& msg, const int timeout) -> int { return 0; /*ReceiveImpl(msg, 0, timeout);*/ }
/// Takes the next entry from the internal receive queue and hands it to the caller.
///
/// A non-empty queue frame is expected to hold the raw bytes of one MessagePtr;
/// ownership of that message is moved into @p msg. An empty frame leaves @p msg
/// untouched and is reported as a zero-sized receive.
///
/// @param msg      destination that takes over the received message
/// @param timeout  currently unused; the read from the queue is a plain receive() call
/// @return         size of the received message in bytes (0 for an empty frame),
///                 or -1 if an error was caught or the frame had an unexpected size
auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int
{
    LOG(debug) << "OFI transport (" << fId << "): ENTER Receive";

    try {
        azmq::message zmsg;
        const auto recv = fRecvQueueRead.receive(zmsg);

        size_t size(0);
        if (recv > 0) {
            // The frame payload is reinterpreted as a MessagePtr below, so anything
            // that is not exactly one MessagePtr in size must not be touched.
            if (recv != sizeof(MessagePtr)) {
                LOG(error) << "OFI transport (" << fId
                           << "): Receive: unexpected queue frame size=" << recv;
                return -1;
            }
            msg = std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data())));
            size = msg->GetSize();
        }

        fBytesRx += size;
        fMessagesRx++;

        LOG(debug) << "OFI transport (" << fId << "): LEAVE Receive";
        return static_cast<int>(size);
    } catch (const std::exception& e) {
        LOG(error) << e.what();
        return -1;
    } catch (const boost::system::error_code& e) {
        LOG(error) << e;
        return -1;
    }
}
/// Multi-part send: delegates to SendImpl with the flags argument fixed to 0.
auto Socket::Send(std::vector<MessagePtr>& msgVec, const int timeout) -> int64_t
{
    return SendImpl(msgVec, 0, timeout);
}
/// Multi-part receive: delegates to ReceiveImpl with the flags argument fixed to 0.
auto Socket::Receive(std::vector<MessagePtr>& msgVec, const int timeout) -> int64_t
{
    return ReceiveImpl(msgVec, 0, timeout);
}
auto Socket::SendQueueReader() -> void
{
fQueue2.async_receive(boost::asio::bind_executor(
fSendQueueRead.async_receive(boost::asio::bind_executor(
fIoStrand,
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
if (!ec) {
--fSentCount;
OnSend(zmsg, bytes_transferred);
}
}));
@ -266,7 +314,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
MessagePtr msg(std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data()))));
auto size = msg->GetSize();
LOG(debug) << "OFI transport (" << fId << "): >>>>> OnSend: size=" << size;
LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize();
// Create and send control message
auto pb = MakeControlMessage<PostBuffer>();
@ -284,7 +332,9 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void
{
LOG(debug) << "OFI transport (" << fId << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred;
LOG(debug) << "OFI transport (" << fId
<< "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred
<< ",data=" << msg->GetData() << ",size=" << msg->GetSize();
assert(bytes_transferred == sizeof(PostBuffer));
auto size = msg->GetSize();
@ -302,77 +352,92 @@ auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> v
// received, size_ack=" << size_ack;
boost::asio::mutable_buffer buffer(msg->GetData(), size);
asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send);
// asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send);
// auto desc = mr.desc();
fDataEndpoint->send(buffer, mr.desc(), [&, mr2 = std::move(mr)](boost::asio::mutable_buffer) {
fDataEndpoint->send(
buffer,
// desc,
[&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable {
LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent";
fBytesTx += size;
fMessagesTx++;
});
}
boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent";
}
auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int
try {
LOG(debug) << "OFI transport (" << fId << "): ENTER ReceiveImpl";
// Receive and process control message
azmq::message ctrl;
auto recv = fControlEndpoint.receive(ctrl);
assert(recv == sizeof(PostBuffer)); (void)recv;
auto pb(static_cast<const PostBuffer*>(ctrl.data()));
auto Socket::RecvControlQueueReader() -> void
{
fControlEndpoint.async_receive(boost::asio::bind_executor(
fIoStrand,
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
if (!ec) {
OnRecvControl(zmsg, bytes_transferred);
}
}));
}
auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> void
{
LOG(debug) << "OFI transport (" << fId
<< "): ENTER OnRecvControl: bytes_transferred=" << bytes_transferred;
assert(bytes_transferred == sizeof(PostBuffer));
auto pb(static_cast<const PostBuffer*>(zmsg.data()));
assert(pb->type == ControlMessageType::PostBuffer);
auto size = pb->size;
LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control message received, size=" << size;
LOG(debug) << "OFI transport (" << fId << "): OnRecvControl: PostBuffer.size=" << size;
// Receive data
if (size) {
msg->Rebuild(size);
auto msg = fContext.MakeReceiveMessage(size);
boost::asio::mutable_buffer buffer(msg->GetData(), size);
asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv);
// asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv);
// auto msg33 = fContext.MakeReceiveMessage(size);
// boost::asio::mutable_buffer buffer33(msg33->GetData(), size);
// asiofi::memory_region mr33(fContext.GetDomain(), buffer33, asiofi::mr::access::recv);
// auto desc = mr.desc();
std::mutex m;
std::condition_variable cv;
bool completed(false);
fDataEndpoint->recv(buffer, mr.desc(), [&](boost::asio::mutable_buffer) {
{
std::unique_lock<std::mutex> lk(m);
completed = true;
fDataEndpoint->recv(
buffer,
// desc,
[&, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable {
MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg2)));
fRecvQueueWrite.async_send(
azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))),
[&](const boost::system::error_code& ec, size_t bytes_transferred2) {
if (!ec) {
LOG(debug) << "OFI transport (" << fId
<< "): <<<<< Data buffer received, bytes_transferred2="
<< bytes_transferred2;
}
cv.notify_one();
}
);
});
});
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted";
auto ack = MakeControlMessage<PostBuffer>();
ack.size = size;
auto sent = fControlEndpoint.send(boost::asio::buffer(ack));
assert(sent == sizeof(PostBuffer)); (void)sent;
// auto ack = MakeControlMessage<PostBuffer>();
// ack.size = size;
// auto sent = fControlEndpoint.send(boost::asio::buffer(ack));
// assert(sent == sizeof(PostBuffer)); (void)sent;
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control Ack sent";
{
std::unique_lock<std::mutex> lk(m);
cv.wait(lk, [&](){ return completed; });
} else {
fRecvQueueWrite.async_send(
azmq::message(boost::asio::const_buffer(nullptr, 0)),
[&](const boost::system::error_code& ec, size_t bytes_transferred2) {
if (!ec) {
LOG(debug) << "OFI transport (" << fId
<< "): <<<<< Data buffer received, bytes_transferred2="
<< bytes_transferred2;
}
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data received";
});
}
fBytesRx += size;
fMessagesRx++;
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
// LOG(debug) << "OFI transport (" << fId << "): EXIT ReceiveImpl";
return size;
}
catch (const SilentSocketError& e)
{
return -2;
}
catch (const std::exception& e)
{
LOG(error) << e.what();
return -1;
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnRecvControl";
}
auto Socket::SendImpl(vector<FairMQMessagePtr>& /*msgVec*/, const int /*flags*/, const int /*timeout*/) -> int64_t
@ -548,14 +613,14 @@ auto Socket::ReceiveImpl(vector<FairMQMessagePtr>& /*msgVec*/, const int /*flags
// Close: intentionally empty; this function performs no work.
auto Socket::Close() -> void {}
auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void
auto Socket::SetOption(const string& /*option*/, const void* /*value*/, size_t /*valueSize*/) -> void
{
// if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
// throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))};
// }
}
auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void
auto Socket::GetOption(const string& /*option*/, void* /*value*/, size_t* /*valueSize*/) -> void
{
// if (zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
// throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))};

View File

@ -19,7 +19,6 @@
#include <boost/asio.hpp>
#include <memory> // unique_ptr
#include <netinet/in.h>
class FairMQTransportFactory;
namespace fair
{
@ -37,7 +36,7 @@ namespace ofi
class Socket final : public fair::mq::Socket
{
public:
Socket(Context& factory, const std::string& type, const std::string& name, const std::string& id = "", FairMQTransportFactory* fac);
Socket(Context& context, const std::string& type, const std::string& name, const std::string& id = "");
Socket(const Socket&) = delete;
Socket operator=(const Socket&) = delete;
@ -93,11 +92,16 @@ class Socket final : public fair::mq::Socket
mutable azmq::socket fControlEndpoint;
int fSndTimeout;
int fRcvTimeout;
azmq::pair_socket fQueue1, fQueue2;
azmq::socket fSendQueueWrite, fSendQueueRead;
azmq::socket fRecvQueueWrite, fRecvQueueRead;
std::atomic<unsigned long> fSentCount;
auto SendQueueReader() -> void;
auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void;
auto OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void;
auto RecvControlQueueReader() -> void;
auto OnRecvControl(azmq::message& msg, size_t bytes_transferred) -> void;
auto OnReceive() -> void;
auto ReceiveImpl(MessagePtr& msg, const int flags, const int timeout) -> int;
auto SendImpl(std::vector<MessagePtr>& msgVec, const int flags, const int timeout) -> int64_t;
auto ReceiveImpl(std::vector<MessagePtr>& msgVec, const int flags, const int timeout) -> int64_t;

View File

@ -24,13 +24,12 @@ namespace ofi
using namespace std;
TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* /*config*/)
try : FairMQTransportFactory{id}
try : FairMQTransportFactory(id)
, fContext(*this, 1)
{
LOG(debug) << "OFI transport: Using AZMQ & "
<< "asiofi (" << fContext.GetAsiofiVersion() << ")";
}
catch (ContextError& e)
{
} catch (ContextError& e) {
throw TransportFactoryError{e.what()};
}