Set pointer to factory also when receiving multi-part

This commit is contained in:
mkrzewic 2018-11-27 13:23:42 +01:00 committed by Alexey Rybalchenko
parent 25fcf13985
commit cc0c525e0d
18 changed files with 45 additions and 29 deletions

View File

@ -14,11 +14,13 @@
#include <memory> #include <memory>
#include "FairMQMessage.h" #include "FairMQMessage.h"
class FairMQTransportFactory;
class FairMQSocket class FairMQSocket
{ {
public: public:
FairMQSocket() {} FairMQSocket() {}
FairMQSocket(FairMQTransportFactory* fac): fTransport(fac) {}
virtual std::string GetId() = 0; virtual std::string GetId() = 0;
@ -51,7 +53,13 @@ class FairMQSocket
virtual unsigned long GetMessagesTx() const = 0; virtual unsigned long GetMessagesTx() const = 0;
virtual unsigned long GetMessagesRx() const = 0; virtual unsigned long GetMessagesRx() const = 0;
FairMQTransportFactory* GetTransport() { return fTransport; }
void SetTransport(FairMQTransportFactory* transport) { fTransport=transport; }
virtual ~FairMQSocket() {}; virtual ~FairMQSocket() {};
private:
FairMQTransportFactory* fTransport{nullptr};
}; };
using FairMQSocketPtr = std::unique_ptr<FairMQSocket>; using FairMQSocketPtr = std::unique_ptr<FairMQSocket>;

View File

@ -62,7 +62,7 @@ class FairMQTransportFactory
virtual FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& unmanagedRegion, void* data, const size_t size, void* hint = 0) = 0; virtual FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& unmanagedRegion, void* data, const size_t size, void* hint = 0) = 0;
/// Create a socket /// Create a socket
virtual FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const = 0; virtual FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) = 0;
/// Create a poller for a single channel (all subchannels) /// Create a poller for a single channel (all subchannels)
virtual FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const = 0; virtual FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const = 0;

View File

@ -32,8 +32,9 @@ using namespace fair::mq;
atomic<bool> FairMQSocketNN::fInterrupted(false); atomic<bool> FairMQSocketNN::fInterrupted(false);
FairMQSocketNN::FairMQSocketNN(const string& type, const string& name, const string& id /*= ""*/) FairMQSocketNN::FairMQSocketNN(const string& type, const string& name, const string& id /*= ""*/, FairMQTransportFactory* fac /*=nullptr*/)
: fSocket(-1) : FairMQSocket{fac}
, fSocket(-1)
, fId(id + "." + name + "." + type) , fId(id + "." + name + "." + type)
, fBytesTx(0) , fBytesTx(0)
, fBytesRx(0) , fBytesRx(0)
@ -368,7 +369,7 @@ int64_t FairMQSocketNN::Receive(vector<FairMQMessagePtr>& msgVec, const int time
object.convert(buf); object.convert(buf);
// get the single message size // get the single message size
size_t size = buf.size() * sizeof(char); size_t size = buf.size() * sizeof(char);
FairMQMessagePtr part(new FairMQMessageNN(size)); FairMQMessagePtr part(new FairMQMessageNN(size, GetTransport()));
static_cast<FairMQMessageNN*>(part.get())->fReceiving = true; static_cast<FairMQMessageNN*>(part.get())->fReceiving = true;
memcpy(part->GetData(), buf.data(), size); memcpy(part->GetData(), buf.data(), size);
msgVec.push_back(move(part)); msgVec.push_back(move(part));

View File

@ -14,11 +14,12 @@
#include "FairMQSocket.h" #include "FairMQSocket.h"
#include "FairMQMessage.h" #include "FairMQMessage.h"
class FairMQTransportFactory;
class FairMQSocketNN final : public FairMQSocket class FairMQSocketNN final : public FairMQSocket
{ {
public: public:
FairMQSocketNN(const std::string& type, const std::string& name, const std::string& id = ""); FairMQSocketNN(const std::string& type, const std::string& name, const std::string& id = "", FairMQTransportFactory* fac = nullptr);
FairMQSocketNN(const FairMQSocketNN&) = delete; FairMQSocketNN(const FairMQSocketNN&) = delete;
FairMQSocketNN operator=(const FairMQSocketNN&) = delete; FairMQSocketNN operator=(const FairMQSocketNN&) = delete;

View File

@ -43,9 +43,9 @@ FairMQMessagePtr FairMQTransportFactoryNN::CreateMessage(FairMQUnmanagedRegionPt
return unique_ptr<FairMQMessage>(new FairMQMessageNN(region, data, size, hint, this)); return unique_ptr<FairMQMessage>(new FairMQMessageNN(region, data, size, hint, this));
} }
FairMQSocketPtr FairMQTransportFactoryNN::CreateSocket(const string& type, const string& name) const FairMQSocketPtr FairMQTransportFactoryNN::CreateSocket(const string& type, const string& name)
{ {
unique_ptr<FairMQSocket> socket(new FairMQSocketNN(type, name, GetId())); unique_ptr<FairMQSocket> socket(new FairMQSocketNN(type, name, GetId(), this));
fSockets.push_back(socket.get()); fSockets.push_back(socket.get());
return socket; return socket;
} }

View File

@ -30,7 +30,7 @@ class FairMQTransportFactoryNN final : public FairMQTransportFactory
FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override;
FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override; FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override;
FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const override; FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) override;
FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const override; FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const override;
FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel*>& channels) const override; FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel*>& channels) const override;

View File

@ -31,8 +31,9 @@ namespace ofi
using namespace std; using namespace std;
Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/) Socket::Socket(Context& context, const string& type, const string& name, const string& id /*= ""*/, FairMQTransportFactory* fac)
: fDataEndpoint(nullptr) : FairMQSocket{fac}
, fDataEndpoint(nullptr)
, fDataCompletionQueueTx(nullptr) , fDataCompletionQueueTx(nullptr)
, fDataCompletionQueueRx(nullptr) , fDataCompletionQueueRx(nullptr)
, fId(id + "." + name + "." + type) , fId(id + "." + name + "." + type)
@ -515,7 +516,7 @@ auto Socket::ReceiveImpl(vector<FairMQMessagePtr>& msgVec, const int flags, cons
// //
// do // do
// { // {
// FairMQMessagePtr part(new FairMQMessageSHM(fManager)); // FairMQMessagePtr part(new FairMQMessageSHM(fManager, GetTransport()));
// zmq_msg_t* msgPtr = static_cast<FairMQMessageSHM*>(part.get())->GetMessage(); // zmq_msg_t* msgPtr = static_cast<FairMQMessageSHM*>(part.get())->GetMessage();
// //
// int nbytes = zmq_msg_recv(msgPtr, fSocket, flags); // int nbytes = zmq_msg_recv(msgPtr, fSocket, flags);

View File

@ -18,6 +18,7 @@
#include <memory> // unique_ptr #include <memory> // unique_ptr
#include <netinet/in.h> #include <netinet/in.h>
#include <rdma/fabric.h> #include <rdma/fabric.h>
class FairMQTransportFactory;
namespace fair namespace fair
{ {
@ -35,7 +36,7 @@ namespace ofi
class Socket final : public fair::mq::Socket class Socket final : public fair::mq::Socket
{ {
public: public:
Socket(Context& factory, const std::string& type, const std::string& name, const std::string& id = ""); Socket(Context& factory, const std::string& type, const std::string& name, const std::string& id = "", FairMQTransportFactory* fac);
Socket(const Socket&) = delete; Socket(const Socket&) = delete;
Socket operator=(const Socket&) = delete; Socket operator=(const Socket&) = delete;

View File

@ -56,9 +56,9 @@ auto TransportFactory::CreateMessage(UnmanagedRegionPtr& region, void* data, con
return MessagePtr{new Message(region, data, size, hint)}; return MessagePtr{new Message(region, data, size, hint)};
} }
auto TransportFactory::CreateSocket(const string& type, const string& name) const -> SocketPtr auto TransportFactory::CreateSocket(const string& type, const string& name) -> SocketPtr
{ {
return SocketPtr{new Socket(fContext, type, name, GetId())}; return SocketPtr{new Socket(fContext, type, name, GetId(), this)};
} }
auto TransportFactory::CreatePoller(const vector<FairMQChannel>& channels) const -> PollerPtr auto TransportFactory::CreatePoller(const vector<FairMQChannel>& channels) const -> PollerPtr

View File

@ -38,7 +38,7 @@ class TransportFactory final : public FairMQTransportFactory
auto CreateMessage(void* data, const std::size_t size, fairmq_free_fn* ffn, void* hint = nullptr) const -> MessagePtr override; auto CreateMessage(void* data, const std::size_t size, fairmq_free_fn* ffn, void* hint = nullptr) const -> MessagePtr override;
auto CreateMessage(UnmanagedRegionPtr& region, void* data, const std::size_t size, void* hint = nullptr) const -> MessagePtr override; auto CreateMessage(UnmanagedRegionPtr& region, void* data, const std::size_t size, void* hint = nullptr) const -> MessagePtr override;
auto CreateSocket(const std::string& type, const std::string& name) const -> SocketPtr override; auto CreateSocket(const std::string& type, const std::string& name) -> SocketPtr override;
auto CreatePoller(const std::vector<FairMQChannel>& channels) const -> PollerPtr override; auto CreatePoller(const std::vector<FairMQChannel>& channels) const -> PollerPtr override;
auto CreatePoller(const std::vector<const FairMQChannel*>& channels) const -> PollerPtr override; auto CreatePoller(const std::vector<const FairMQChannel*>& channels) const -> PollerPtr override;

View File

@ -23,8 +23,9 @@ using namespace fair::mq;
atomic<bool> FairMQSocketSHM::fInterrupted(false); atomic<bool> FairMQSocketSHM::fInterrupted(false);
FairMQSocketSHM::FairMQSocketSHM(Manager& manager, const string& type, const string& name, const string& id /*= ""*/, void* context) FairMQSocketSHM::FairMQSocketSHM(Manager& manager, const string& type, const string& name, const string& id /*= ""*/, void* context, FairMQTransportFactory* fac /*=nullptr*/)
: fSocket(nullptr) : FairMQSocket{fac}
, fSocket(nullptr)
, fManager(manager) , fManager(manager)
, fId(id + "." + name + "." + type) , fId(id + "." + name + "." + type)
, fBytesTx(0) , fBytesTx(0)
@ -377,7 +378,7 @@ int64_t FairMQSocketSHM::Receive(vector<FairMQMessagePtr>& msgVec, const int tim
MetaHeader metaHeader; MetaHeader metaHeader;
memcpy(&metaHeader, &hdrVec[m], sizeof(MetaHeader)); memcpy(&metaHeader, &hdrVec[m], sizeof(MetaHeader));
msgVec.emplace_back(fair::mq::tools::make_unique<FairMQMessageSHM>(fManager)); msgVec.emplace_back(fair::mq::tools::make_unique<FairMQMessageSHM>(fManager, GetTransport()));
FairMQMessageSHM* msg = static_cast<FairMQMessageSHM*>(msgVec.back().get()); FairMQMessageSHM* msg = static_cast<FairMQMessageSHM*>(msgVec.back().get());
MetaHeader* msgHdr = static_cast<MetaHeader*>(zmq_msg_data(msg->GetMessage())); MetaHeader* msgHdr = static_cast<MetaHeader*>(zmq_msg_data(msg->GetMessage()));

View File

@ -15,11 +15,12 @@
#include <atomic> #include <atomic>
#include <memory> // unique_ptr #include <memory> // unique_ptr
class FairMQTransportFactory;
class FairMQSocketSHM final : public FairMQSocket class FairMQSocketSHM final : public FairMQSocket
{ {
public: public:
FairMQSocketSHM(fair::mq::shmem::Manager& manager, const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr); FairMQSocketSHM(fair::mq::shmem::Manager& manager, const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr, FairMQTransportFactory* fac = nullptr);
FairMQSocketSHM(const FairMQSocketSHM&) = delete; FairMQSocketSHM(const FairMQSocketSHM&) = delete;
FairMQSocketSHM operator=(const FairMQSocketSHM&) = delete; FairMQSocketSHM operator=(const FairMQSocketSHM&) = delete;

View File

@ -233,10 +233,10 @@ FairMQMessagePtr FairMQTransportFactorySHM::CreateMessage(FairMQUnmanagedRegionP
return unique_ptr<FairMQMessage>(new FairMQMessageSHM(*fManager, region, data, size, hint, this)); return unique_ptr<FairMQMessage>(new FairMQMessageSHM(*fManager, region, data, size, hint, this));
} }
FairMQSocketPtr FairMQTransportFactorySHM::CreateSocket(const string& type, const string& name) const FairMQSocketPtr FairMQTransportFactorySHM::CreateSocket(const string& type, const string& name)
{ {
assert(fContext); assert(fContext);
return unique_ptr<FairMQSocket>(new FairMQSocketSHM(*fManager, type, name, GetId(), fContext)); return unique_ptr<FairMQSocket>(new FairMQSocketSHM(*fManager, type, name, GetId(), fContext, this));
} }
FairMQPollerPtr FairMQTransportFactorySHM::CreatePoller(const vector<FairMQChannel>& channels) const FairMQPollerPtr FairMQTransportFactorySHM::CreatePoller(const vector<FairMQChannel>& channels) const

View File

@ -38,7 +38,7 @@ class FairMQTransportFactorySHM final : public FairMQTransportFactory
FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override;
FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override; FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override;
FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const override; FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) override;
FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const override; FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const override;
FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel*>& channels) const override; FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel*>& channels) const override;

View File

@ -20,8 +20,9 @@ using namespace fair::mq;
atomic<bool> FairMQSocketZMQ::fInterrupted(false); atomic<bool> FairMQSocketZMQ::fInterrupted(false);
FairMQSocketZMQ::FairMQSocketZMQ(const string& type, const string& name, const string& id /*= ""*/, void* context) FairMQSocketZMQ::FairMQSocketZMQ(const string& type, const string& name, const string& id /*= ""*/, void* context, FairMQTransportFactory* fac)
: fSocket(nullptr) : FairMQSocket{fac}
, fSocket(nullptr)
, fId(id + "." + name + "." + type) , fId(id + "." + name + "." + type)
, fBytesTx(0) , fBytesTx(0)
, fBytesRx(0) , fBytesRx(0)
@ -314,7 +315,7 @@ int64_t FairMQSocketZMQ::Receive(vector<FairMQMessagePtr>& msgVec, const int tim
do do
{ {
unique_ptr<FairMQMessage> part(new FairMQMessageZMQ()); unique_ptr<FairMQMessage> part(new FairMQMessageZMQ(GetTransport()));
int nbytes = zmq_msg_recv(static_cast<FairMQMessageZMQ*>(part.get())->GetMessage(), fSocket, flags); int nbytes = zmq_msg_recv(static_cast<FairMQMessageZMQ*>(part.get())->GetMessage(), fSocket, flags);
if (nbytes >= 0) if (nbytes >= 0)

View File

@ -15,11 +15,12 @@
#include "FairMQSocket.h" #include "FairMQSocket.h"
#include "FairMQMessage.h" #include "FairMQMessage.h"
class FairMQTransportFactory;
class FairMQSocketZMQ final : public FairMQSocket class FairMQSocketZMQ final : public FairMQSocket
{ {
public: public:
FairMQSocketZMQ(const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr); FairMQSocketZMQ(const std::string& type, const std::string& name, const std::string& id = "", void* context = nullptr, FairMQTransportFactory* factory = nullptr);
FairMQSocketZMQ(const FairMQSocketZMQ&) = delete; FairMQSocketZMQ(const FairMQSocketZMQ&) = delete;
FairMQSocketZMQ operator=(const FairMQSocketZMQ&) = delete; FairMQSocketZMQ operator=(const FairMQSocketZMQ&) = delete;

View File

@ -69,10 +69,10 @@ FairMQMessagePtr FairMQTransportFactoryZMQ::CreateMessage(FairMQUnmanagedRegionP
return unique_ptr<FairMQMessage>(new FairMQMessageZMQ(region, data, size, hint, this)); return unique_ptr<FairMQMessage>(new FairMQMessageZMQ(region, data, size, hint, this));
} }
FairMQSocketPtr FairMQTransportFactoryZMQ::CreateSocket(const string& type, const string& name) const FairMQSocketPtr FairMQTransportFactoryZMQ::CreateSocket(const string& type, const string& name)
{ {
assert(fContext); assert(fContext);
return unique_ptr<FairMQSocket>(new FairMQSocketZMQ(type, name, GetId(), fContext)); return unique_ptr<FairMQSocket>(new FairMQSocketZMQ(type, name, GetId(), fContext, this));
} }
FairMQPollerPtr FairMQTransportFactoryZMQ::CreatePoller(const vector<FairMQChannel>& channels) const FairMQPollerPtr FairMQTransportFactoryZMQ::CreatePoller(const vector<FairMQChannel>& channels) const

View File

@ -39,7 +39,7 @@ class FairMQTransportFactoryZMQ final : public FairMQTransportFactory
FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override; FairMQMessagePtr CreateMessage(void* data, const size_t size, fairmq_free_fn* ffn, void* hint = nullptr) override;
FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override; FairMQMessagePtr CreateMessage(FairMQUnmanagedRegionPtr& region, void* data, const size_t size, void* hint = 0) override;
FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) const override; FairMQSocketPtr CreateSocket(const std::string& type, const std::string& name) override;
FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const override; FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel>& channels) const override;
FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel*>& channels) const override; FairMQPollerPtr CreatePoller(const std::vector<FairMQChannel*>& channels) const override;