mirror of
https://github.com/FairRootGroup/FairMQ.git
synced 2025-10-13 16:46:47 +00:00
Implement parallel ofi::Socket::Receive
This commit is contained in:
parent
46e2420547
commit
8e7cfacd78
|
@ -31,11 +31,12 @@ namespace ofi
|
|||
|
||||
using namespace std;
|
||||
|
||||
// Constructs the context: the OFI info/fabric/domain members are initialised
// to nullptr, fIoWork is constructed from fIoContext, a reference to the
// transport factory passed as receiveFactory is stored in fReceiveFactory, and
// InitThreadPool() is called with numberIoThreads.
// NOTE(review): this span is taken from a rendered diff, so two signature rows
// follow - presumably the first is the pre-change signature and the second its
// replacement (only the second declares receiveFactory, which the initializer
// list below uses). Confirm against the actual source file.
Context::Context(int numberIoThreads)
|
||||
Context::Context(FairMQTransportFactory& receiveFactory, int numberIoThreads)
|
||||
: fOfiInfo(nullptr)
|
||||
, fOfiFabric(nullptr)
|
||||
, fOfiDomain(nullptr)
|
||||
, fIoWork(fIoContext)
|
||||
, fReceiveFactory(receiveFactory)
|
||||
{
|
||||
InitThreadPool(numberIoThreads);
|
||||
}
|
||||
|
@ -176,6 +177,11 @@ auto Context::VerifyAddress(const std::string& address) -> Address
|
|||
return addr;
|
||||
}
|
||||
|
||||
// Creates a message of the requested size by forwarding to
// fReceiveFactory.CreateMessage(size) and returns the result unchanged.
// fReceiveFactory is the factory reference handed to the Context constructor.
auto Context::MakeReceiveMessage(size_t size) -> MessagePtr
|
||||
{
|
||||
return fReceiveFactory.CreateMessage(size);
|
||||
}
|
||||
|
||||
} /* namespace ofi */
|
||||
} /* namespace mq */
|
||||
} /* namespace fair */
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define FAIR_MQ_OFI_CONTEXT_H
|
||||
|
||||
#include <FairMQLogger.h>
|
||||
#include <FairMQTransportFactory.h>
|
||||
|
||||
#include <asiofi/connected_endpoint.hpp>
|
||||
#include <asiofi/domain.hpp>
|
||||
|
@ -44,7 +45,7 @@ enum class Direction : bool { Receive, Transmit };
|
|||
class Context
|
||||
{
|
||||
public:
|
||||
Context(int numberIoThreads = 1);
|
||||
Context(FairMQTransportFactory& receiveFactory, int numberIoThreads = 1);
|
||||
~Context();
|
||||
|
||||
// auto CreateOfiEndpoint() -> fid_ep*;
|
||||
|
@ -66,6 +67,7 @@ class Context
|
|||
auto GetDomain() const -> const asiofi::domain& { return *fOfiDomain; }
|
||||
auto Interrupt() -> void { LOG(debug) << "OFI transport: Interrupted (NOOP - not implemented)."; }
|
||||
auto Resume() -> void { LOG(debug) << "OFI transport: Resumed (NOOP - not implemented)."; }
|
||||
auto MakeReceiveMessage(size_t size) -> MessagePtr;
|
||||
|
||||
private:
|
||||
std::unique_ptr<asiofi::info> fOfiInfo;
|
||||
|
@ -74,6 +76,7 @@ class Context
|
|||
boost::asio::io_context fIoContext;
|
||||
boost::asio::io_context::work fIoWork;
|
||||
std::vector<std::thread> fThreadPool;
|
||||
FairMQTransportFactory& fReceiveFactory;
|
||||
|
||||
auto InitThreadPool(int numberIoThreads) -> void;
|
||||
auto InitOfi(ConnectionType type, Address address) -> void;
|
||||
|
|
|
@ -49,8 +49,11 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
|
|||
, fControlEndpoint(fIoStrand.context(), ZMQ_PAIR)
|
||||
, fSndTimeout(100)
|
||||
, fRcvTimeout(100)
|
||||
, fQueue1(fIoStrand.context())
|
||||
, fQueue2(fIoStrand.context())
|
||||
, fSendQueueWrite(fIoStrand.context(), ZMQ_PUSH)
|
||||
, fSendQueueRead(fIoStrand.context(), ZMQ_PULL)
|
||||
, fRecvQueueWrite(fIoStrand.context(), ZMQ_PUSH)
|
||||
, fRecvQueueRead(fIoStrand.context(), ZMQ_PULL)
|
||||
, fSentCount(0)
|
||||
{
|
||||
if (type != "pair") {
|
||||
throw SocketError{tools::ToString("Socket type '", type, "' not implemented for ofi transport.")};
|
||||
|
@ -63,17 +66,28 @@ Socket::Socket(Context& context, const string& type, const string& name, const s
|
|||
|
||||
// Setup internal queue
|
||||
auto hashed_id = std::hash<std::string>()(fId);
|
||||
auto queue_id = tools::ToString("inproc://QUEUE", hashed_id);
|
||||
LOG(debug) << "OFI transport (" << fId << "): " << "Binding Q1: " << queue_id;
|
||||
fQueue1.bind(queue_id);
|
||||
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting Q2: " << queue_id;
|
||||
fQueue2.connect(queue_id);
|
||||
azmq::socket::snd_hwm send_max(100);
|
||||
azmq::socket::rcv_hwm recv_max(100);
|
||||
fQueue1.set_option(send_max);
|
||||
fQueue1.set_option(recv_max);
|
||||
fQueue2.set_option(send_max);
|
||||
fQueue2.set_option(recv_max);
|
||||
auto queue_id = tools::ToString("inproc://TXQUEUE", hashed_id);
|
||||
LOG(debug) << "OFI transport (" << fId << "): " << "Binding SQR: " << queue_id;
|
||||
fSendQueueRead.bind(queue_id);
|
||||
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting SQW: " << queue_id;
|
||||
fSendQueueWrite.connect(queue_id);
|
||||
queue_id = tools::ToString("inproc://RXQUEUE", hashed_id);
|
||||
LOG(debug) << "OFI transport (" << fId << "): " << "Binding RQR: " << queue_id;
|
||||
fRecvQueueRead.bind(queue_id);
|
||||
LOG(debug) << "OFI transport (" << fId << "): " << "Connecting RQW: " << queue_id;
|
||||
fRecvQueueWrite.connect(queue_id);
|
||||
|
||||
// TODO wire this up with config
|
||||
azmq::socket::snd_hwm send_max(10);
|
||||
azmq::socket::rcv_hwm recv_max(10);
|
||||
fSendQueueRead.set_option(send_max);
|
||||
fSendQueueRead.set_option(recv_max);
|
||||
fSendQueueWrite.set_option(send_max);
|
||||
fSendQueueWrite.set_option(recv_max);
|
||||
fRecvQueueRead.set_option(send_max);
|
||||
fRecvQueueRead.set_option(recv_max);
|
||||
fSendQueueWrite.set_option(send_max);
|
||||
fSendQueueWrite.set_option(recv_max);
|
||||
fControlEndpoint.set_option(send_max);
|
||||
fControlEndpoint.set_option(recv_max);
|
||||
}
|
||||
|
@ -90,7 +104,8 @@ try {
|
|||
fLocalDataAddr = addr;
|
||||
BindDataEndpoint();
|
||||
|
||||
boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
|
||||
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
|
||||
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -116,6 +131,9 @@ auto Socket::Connect(const string& address) -> bool
|
|||
ReceiveDataAddressAnnouncement();
|
||||
|
||||
ConnectDataEndpoint();
|
||||
|
||||
// boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
|
||||
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
|
||||
}
|
||||
|
||||
auto Socket::BindControlEndpoint(Context::Address address) -> void
|
||||
|
@ -225,11 +243,13 @@ auto Socket::AnnounceDataAddress() -> void
|
|||
|
||||
auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int
|
||||
{
|
||||
LOG(debug) << "OFI transport (" << fId << "): ENTER Send: size=" << msg->GetSize();
|
||||
LOG(debug) << "OFI transport (" << fId << "): ENTER Send: data=" << msg->GetData() << ",size=" << msg->GetSize();
|
||||
|
||||
MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg)));
|
||||
try {
|
||||
auto res = fQueue1.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0);
|
||||
++fSentCount;
|
||||
LOG(info) << fSentCount;
|
||||
auto res = fSendQueueWrite.send(boost::asio::const_buffer(msgptr, sizeof(MessagePtr)), 0);
|
||||
|
||||
LOG(debug) << "OFI transport (" << fId << "): LEAVE Send";
|
||||
return res;
|
||||
|
@ -244,16 +264,44 @@ auto Socket::Send(MessagePtr& msg, const int /*timeout*/) -> int
|
|||
}
|
||||
}
|
||||
|
||||
auto Socket::Receive(MessagePtr& msg, const int timeout) -> int { return 0; /*ReceiveImpl(msg, 0, timeout);*/ }
|
||||
/// Takes one entry from the internal receive queue (fRecvQueueRead) and hands
/// the message it carries over to the caller.
///
/// A non-empty queue frame is expected to contain the bytes of a MessagePtr
/// (see the matching async_send on fRecvQueueWrite); ownership of the message
/// is moved out of the frame into `msg`. An empty frame leaves `msg` untouched
/// and counts as a zero-byte message.
///
/// The timeout argument is ignored.
///
/// @param msg receives the message when the frame carried one
/// @return size of the received message in bytes (0 for an empty frame),
///         -1 if the frame has an unexpected size or an exception was caught
///
/// NOTE(review): the MessagePtr allocated with `new` by the producing side is
/// not freed here; only the byte copy inside the queue frame is moved from.
/// Confirm who is meant to release that allocation.
auto Socket::Receive(MessagePtr& msg, const int /*timeout*/) -> int
{
    LOG(debug) << "OFI transport (" << fId << "): ENTER Receive";

    try {
        azmq::message zmsg;
        const auto recv = fRecvQueueRead.receive(zmsg);

        size_t size(0);
        if (recv > 0) {
            // The frame is reinterpreted as a MessagePtr below, so anything
            // but an exact size match must be rejected before touching it.
            if (recv != sizeof(MessagePtr)) {
                LOG(error) << "OFI transport (" << fId
                           << "): Receive: unexpected queue frame size " << recv;
                return -1;
            }
            msg = std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data())));
            size = msg->GetSize();
        }

        fBytesRx += size;
        fMessagesRx++;

        LOG(debug) << "OFI transport (" << fId << "): LEAVE Receive";
        // The interface returns int; make the narrowing from size_t explicit.
        return static_cast<int>(size);
    } catch (const std::exception& e) {
        LOG(error) << e.what();
        return -1;
    } catch (const boost::system::error_code& e) {
        LOG(error) << e;
        return -1;
    }
}
|
||||
|
||||
auto Socket::Send(std::vector<MessagePtr>& msgVec, const int timeout) -> int64_t { return SendImpl(msgVec, 0, timeout); }
|
||||
auto Socket::Receive(std::vector<MessagePtr>& msgVec, const int timeout) -> int64_t { return ReceiveImpl(msgVec, 0, timeout); }
|
||||
|
||||
auto Socket::SendQueueReader() -> void
|
||||
{
|
||||
fQueue2.async_receive(boost::asio::bind_executor(
|
||||
fSendQueueRead.async_receive(boost::asio::bind_executor(
|
||||
fIoStrand,
|
||||
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
|
||||
if (!ec) {
|
||||
--fSentCount;
|
||||
OnSend(zmsg, bytes_transferred);
|
||||
}
|
||||
}));
|
||||
|
@ -266,7 +314,7 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
|
|||
MessagePtr msg(std::move(*(static_cast<MessagePtr*>(zmsg.buffer().data()))));
|
||||
auto size = msg->GetSize();
|
||||
|
||||
LOG(debug) << "OFI transport (" << fId << "): >>>>> OnSend: size=" << size;
|
||||
LOG(debug) << "OFI transport (" << fId << "): OnSend: data=" << msg->GetData() << ",size=" << msg->GetSize();
|
||||
|
||||
// Create and send control message
|
||||
auto pb = MakeControlMessage<PostBuffer>();
|
||||
|
@ -284,7 +332,9 @@ auto Socket::OnSend(azmq::message& zmsg, size_t bytes_transferred) -> void
|
|||
|
||||
auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void
|
||||
{
|
||||
LOG(debug) << "OFI transport (" << fId << "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred;
|
||||
LOG(debug) << "OFI transport (" << fId
|
||||
<< "): ENTER OnControlMessageSent: bytes_transferred=" << bytes_transferred
|
||||
<< ",data=" << msg->GetData() << ",size=" << msg->GetSize();
|
||||
assert(bytes_transferred == sizeof(PostBuffer));
|
||||
|
||||
auto size = msg->GetSize();
|
||||
|
@ -302,77 +352,92 @@ auto Socket::OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> v
|
|||
// received, size_ack=" << size_ack;
|
||||
|
||||
boost::asio::mutable_buffer buffer(msg->GetData(), size);
|
||||
asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send);
|
||||
// asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::send);
|
||||
// auto desc = mr.desc();
|
||||
|
||||
fDataEndpoint->send(buffer, mr.desc(), [&, mr2 = std::move(mr)](boost::asio::mutable_buffer) {
|
||||
fDataEndpoint->send(
|
||||
buffer,
|
||||
// desc,
|
||||
[&, size, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable {
|
||||
LOG(debug) << "OFI transport (" << fId << "): >>>>> Data buffer sent";
|
||||
fBytesTx += size;
|
||||
fMessagesTx++;
|
||||
});
|
||||
}
|
||||
|
||||
boost::asio::post(fIoStrand, std::bind(&Socket::SendQueueReader, this));
|
||||
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnControlMessageSent";
|
||||
}
|
||||
|
||||
auto Socket::ReceiveImpl(FairMQMessagePtr& msg, const int /*flags*/, const int /*timeout*/) -> int
|
||||
try {
|
||||
LOG(debug) << "OFI transport (" << fId << "): ENTER ReceiveImpl";
|
||||
// Receive and process control message
|
||||
azmq::message ctrl;
|
||||
auto recv = fControlEndpoint.receive(ctrl);
|
||||
assert(recv == sizeof(PostBuffer)); (void)recv;
|
||||
auto pb(static_cast<const PostBuffer*>(ctrl.data()));
|
||||
auto Socket::RecvControlQueueReader() -> void
|
||||
{
|
||||
fControlEndpoint.async_receive(boost::asio::bind_executor(
|
||||
fIoStrand,
|
||||
[&](const boost::system::error_code& ec, azmq::message& zmsg, size_t bytes_transferred) {
|
||||
if (!ec) {
|
||||
OnRecvControl(zmsg, bytes_transferred);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
auto Socket::OnRecvControl(azmq::message& zmsg, size_t bytes_transferred) -> void
|
||||
{
|
||||
LOG(debug) << "OFI transport (" << fId
|
||||
<< "): ENTER OnRecvControl: bytes_transferred=" << bytes_transferred;
|
||||
|
||||
assert(bytes_transferred == sizeof(PostBuffer));
|
||||
auto pb(static_cast<const PostBuffer*>(zmsg.data()));
|
||||
assert(pb->type == ControlMessageType::PostBuffer);
|
||||
auto size = pb->size;
|
||||
LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control message received, size=" << size;
|
||||
LOG(debug) << "OFI transport (" << fId << "): OnRecvControl: PostBuffer.size=" << size;
|
||||
|
||||
// Receive data
|
||||
if (size) {
|
||||
msg->Rebuild(size);
|
||||
auto msg = fContext.MakeReceiveMessage(size);
|
||||
boost::asio::mutable_buffer buffer(msg->GetData(), size);
|
||||
asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv);
|
||||
// asiofi::memory_region mr(fContext.GetDomain(), buffer, asiofi::mr::access::recv);
|
||||
// auto msg33 = fContext.MakeReceiveMessage(size);
|
||||
// boost::asio::mutable_buffer buffer33(msg33->GetData(), size);
|
||||
// asiofi::memory_region mr33(fContext.GetDomain(), buffer33, asiofi::mr::access::recv);
|
||||
// auto desc = mr.desc();
|
||||
|
||||
std::mutex m;
|
||||
std::condition_variable cv;
|
||||
bool completed(false);
|
||||
|
||||
fDataEndpoint->recv(buffer, mr.desc(), [&](boost::asio::mutable_buffer) {
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(m);
|
||||
completed = true;
|
||||
fDataEndpoint->recv(
|
||||
buffer,
|
||||
// desc,
|
||||
[&, msg2 = std::move(msg)/*, mr2 = std::move(mr)*/](boost::asio::mutable_buffer) mutable {
|
||||
MessagePtr* msgptr(new std::unique_ptr<Message>(std::move(msg2)));
|
||||
fRecvQueueWrite.async_send(
|
||||
azmq::message(boost::asio::const_buffer(msgptr, sizeof(MessagePtr))),
|
||||
[&](const boost::system::error_code& ec, size_t bytes_transferred2) {
|
||||
if (!ec) {
|
||||
LOG(debug) << "OFI transport (" << fId
|
||||
<< "): <<<<< Data buffer received, bytes_transferred2="
|
||||
<< bytes_transferred2;
|
||||
}
|
||||
cv.notify_one();
|
||||
}
|
||||
);
|
||||
});
|
||||
});
|
||||
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data buffer posted";
|
||||
|
||||
auto ack = MakeControlMessage<PostBuffer>();
|
||||
ack.size = size;
|
||||
auto sent = fControlEndpoint.send(boost::asio::buffer(ack));
|
||||
assert(sent == sizeof(PostBuffer)); (void)sent;
|
||||
// auto ack = MakeControlMessage<PostBuffer>();
|
||||
// ack.size = size;
|
||||
// auto sent = fControlEndpoint.send(boost::asio::buffer(ack));
|
||||
// assert(sent == sizeof(PostBuffer)); (void)sent;
|
||||
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Control Ack sent";
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lk(m);
|
||||
cv.wait(lk, [&](){ return completed; });
|
||||
} else {
|
||||
fRecvQueueWrite.async_send(
|
||||
azmq::message(boost::asio::const_buffer(nullptr, 0)),
|
||||
[&](const boost::system::error_code& ec, size_t bytes_transferred2) {
|
||||
if (!ec) {
|
||||
LOG(debug) << "OFI transport (" << fId
|
||||
<< "): <<<<< Data buffer received, bytes_transferred2="
|
||||
<< bytes_transferred2;
|
||||
}
|
||||
// LOG(debug) << "OFI transport (" << fId << "): <<<<< ReceiveImpl: Data received";
|
||||
});
|
||||
}
|
||||
|
||||
fBytesRx += size;
|
||||
fMessagesRx++;
|
||||
boost::asio::post(fIoStrand, std::bind(&Socket::RecvControlQueueReader, this));
|
||||
|
||||
// LOG(debug) << "OFI transport (" << fId << "): EXIT ReceiveImpl";
|
||||
return size;
|
||||
}
|
||||
catch (const SilentSocketError& e)
|
||||
{
|
||||
return -2;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
LOG(error) << e.what();
|
||||
return -1;
|
||||
LOG(debug) << "OFI transport (" << fId << "): LEAVE OnRecvControl";
|
||||
}
|
||||
|
||||
auto Socket::SendImpl(vector<FairMQMessagePtr>& /*msgVec*/, const int /*flags*/, const int /*timeout*/) -> int64_t
|
||||
|
@ -548,14 +613,14 @@ auto Socket::ReceiveImpl(vector<FairMQMessagePtr>& /*msgVec*/, const int /*flags
|
|||
|
||||
auto Socket::Close() -> void {}
|
||||
|
||||
// Currently a no-op: the body contains only the commented-out
// zmq_setsockopt-based implementation, so the option, value and valueSize
// arguments are not used.
// NOTE(review): this span is taken from a rendered diff, so two signature rows
// follow - presumably the first is the pre-change signature and the second
// (with the parameter names commented out) its replacement.
auto Socket::SetOption(const string& option, const void* value, size_t valueSize) -> void
|
||||
auto Socket::SetOption(const string& /*option*/, const void* /*value*/, size_t /*valueSize*/) -> void
|
||||
{
|
||||
// if (zmq_setsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
|
||||
// throw SocketError{tools::ToString("Failed setting socket option, reason: ", zmq_strerror(errno))};
|
||||
// }
|
||||
}
|
||||
|
||||
auto Socket::GetOption(const string& option, void* value, size_t* valueSize) -> void
|
||||
auto Socket::GetOption(const string& /*option*/, void* /*value*/, size_t* /*valueSize*/) -> void
|
||||
{
|
||||
// if (zmq_getsockopt(fControlSocket, GetConstant(option), value, valueSize) < 0) {
|
||||
// throw SocketError{tools::ToString("Failed getting socket option, reason: ", zmq_strerror(errno))};
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
#include <boost/asio.hpp>
|
||||
#include <memory> // unique_ptr
|
||||
#include <netinet/in.h>
|
||||
class FairMQTransportFactory;
|
||||
|
||||
namespace fair
|
||||
{
|
||||
|
@ -37,7 +36,7 @@ namespace ofi
|
|||
class Socket final : public fair::mq::Socket
|
||||
{
|
||||
public:
|
||||
Socket(Context& factory, const std::string& type, const std::string& name, const std::string& id = "", FairMQTransportFactory* fac);
|
||||
Socket(Context& context, const std::string& type, const std::string& name, const std::string& id = "");
|
||||
Socket(const Socket&) = delete;
|
||||
Socket operator=(const Socket&) = delete;
|
||||
|
||||
|
@ -93,11 +92,16 @@ class Socket final : public fair::mq::Socket
|
|||
mutable azmq::socket fControlEndpoint;
|
||||
int fSndTimeout;
|
||||
int fRcvTimeout;
|
||||
azmq::pair_socket fQueue1, fQueue2;
|
||||
azmq::socket fSendQueueWrite, fSendQueueRead;
|
||||
azmq::socket fRecvQueueWrite, fRecvQueueRead;
|
||||
std::atomic<unsigned long> fSentCount;
|
||||
|
||||
auto SendQueueReader() -> void;
|
||||
auto OnSend(azmq::message& msg, size_t bytes_transferred) -> void;
|
||||
auto OnControlMessageSent(size_t bytes_transferred, MessagePtr msg) -> void;
|
||||
auto RecvControlQueueReader() -> void;
|
||||
auto OnRecvControl(azmq::message& msg, size_t bytes_transferred) -> void;
|
||||
auto OnReceive() -> void;
|
||||
auto ReceiveImpl(MessagePtr& msg, const int flags, const int timeout) -> int;
|
||||
auto SendImpl(std::vector<MessagePtr>& msgVec, const int flags, const int timeout) -> int64_t;
|
||||
auto ReceiveImpl(std::vector<MessagePtr>& msgVec, const int flags, const int timeout) -> int64_t;
|
||||
|
|
|
@ -24,13 +24,12 @@ namespace ofi
|
|||
using namespace std;
|
||||
|
||||
// Constructor written as a function-try-block: the FairMQTransportFactory base
// is initialised with id, fContext is constructed with (*this, 1), and a debug
// line reporting fContext.GetAsiofiVersion() is logged. A ContextError caught
// by the handler is rethrown as TransportFactoryError carrying the same
// what() text. The config argument is unnamed and unused.
// NOTE(review): this span is taken from a rendered diff; the base-initialiser
// row and the catch header each appear twice - presumably the pre-change row
// followed by its replacement.
TransportFactory::TransportFactory(const string& id, const FairMQProgOptions* /*config*/)
|
||||
try : FairMQTransportFactory{id}
|
||||
try : FairMQTransportFactory(id)
|
||||
, fContext(*this, 1)
|
||||
{
|
||||
LOG(debug) << "OFI transport: Using AZMQ & "
|
||||
<< "asiofi (" << fContext.GetAsiofiVersion() << ")";
|
||||
}
|
||||
catch (ContextError& e)
|
||||
{
|
||||
} catch (ContextError& e) {
|
||||
throw TransportFactoryError{e.what()};
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user