diff --git a/examples/readout/Readout.h b/examples/readout/Readout.h index 3ac88667..e2424409 100644 --- a/examples/readout/Readout.h +++ b/examples/readout/Readout.h @@ -37,10 +37,10 @@ class Readout : public FairMQDevice fRegion = FairMQUnmanagedRegionPtr(NewUnmanagedRegionFor("rb", 0, 10000000, - [this](void* /*data*/, size_t /*size*/, void* /*hint*/) { // callback to be called when message buffers no longer needed by transport - --fNumUnackedMsgs; + [this](const std::vector& blocks) { // callback to be called when message buffers no longer needed by transport + fNumUnackedMsgs -= blocks.size(); if (fMaxIterations > 0) { - LOG(debug) << "Received ack"; + LOG(debug) << "Received " << blocks.size() << " acks"; } } )); diff --git a/examples/region/CMakeLists.txt b/examples/region/CMakeLists.txt index b233fc3b..7903aea3 100644 --- a/examples/region/CMakeLists.txt +++ b/examples/region/CMakeLists.txt @@ -32,10 +32,10 @@ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/fairmq-start-ex-region.sh.in ${CMAKE_ configure_file(${CMAKE_CURRENT_SOURCE_DIR}/test-ex-region.sh.in ${CMAKE_CURRENT_BINARY_DIR}/test-ex-region.sh) add_test(NAME Example.Region.zeromq COMMAND ${CMAKE_CURRENT_BINARY_DIR}/test-ex-region.sh zeromq) -set_tests_properties(Example.Region.zeromq PROPERTIES TIMEOUT "30" RUN_SERIAL true PASS_REGULAR_EXPRESSION "Received ack") +set_tests_properties(Example.Region.zeromq PROPERTIES TIMEOUT "30" RUN_SERIAL true PASS_REGULAR_EXPRESSION "Received [0-9*] acks") add_test(NAME Example.Region.shmem COMMAND ${CMAKE_CURRENT_BINARY_DIR}/test-ex-region.sh shmem) -set_tests_properties(Example.Region.shmem PROPERTIES TIMEOUT "30" RUN_SERIAL true PASS_REGULAR_EXPRESSION "Received ack") +set_tests_properties(Example.Region.shmem PROPERTIES TIMEOUT "30" RUN_SERIAL true PASS_REGULAR_EXPRESSION "Received [0-9*] acks") # install diff --git a/examples/region/Sampler.cxx b/examples/region/Sampler.cxx index c5e82366..416e2b55 100644 --- a/examples/region/Sampler.cxx +++ b/examples/region/Sampler.cxx @@ -46,10 +46,10 @@ void Sampler::InitTask() fRegion = FairMQUnmanagedRegionPtr(NewUnmanagedRegionFor("data", 0, 10000000, - [this](void* /*data*/, size_t /*size*/, void* /*hint*/) { // callback to be called when message buffers no longer needed by transport - --fNumUnackedMsgs; + [this](const std::vector& blocks) { // callback to be called when message buffers no longer needed by transport + fNumUnackedMsgs -= blocks.size(); if (fMaxIterations > 0) { - LOG(debug) << "Received ack"; + LOG(debug) << "Received " << blocks.size() << " acks"; } } )); diff --git a/fairmq/FairMQChannel.h b/fairmq/FairMQChannel.h index 6753699d..d7793cca 100644 --- a/fairmq/FairMQChannel.h +++ b/fairmq/FairMQChannel.h @@ -335,14 +335,10 @@ class FairMQChannel return Transport()->NewStaticMessage(data); } - FairMQUnmanagedRegionPtr NewUnmanagedRegion(const size_t size, FairMQRegionCallback callback = nullptr, const std::string& path = "", int flags = 0) + template + FairMQUnmanagedRegionPtr NewUnmanagedRegion(Args&&... args) { - return Transport()->CreateUnmanagedRegion(size, callback, path, flags); - } - - FairMQUnmanagedRegionPtr NewUnmanagedRegion(const size_t size, const int64_t userFlags, FairMQRegionCallback callback = nullptr, const std::string& path = "", int flags = 0) - { - return Transport()->CreateUnmanagedRegion(size, userFlags, callback, path, flags); + return Transport()->CreateUnmanagedRegion(std::forward(args)...); } static constexpr fair::mq::Transport DefaultTransportType = fair::mq::Transport::DEFAULT; diff --git a/fairmq/FairMQDevice.h b/fairmq/FairMQDevice.h index a7d8ffc1..29ee45ac 100644 --- a/fairmq/FairMQDevice.h +++ b/fairmq/FairMQDevice.h @@ -217,45 +217,17 @@ class FairMQDevice } // creates unamanaged region with the default device transport - FairMQUnmanagedRegionPtr NewUnmanagedRegion(const size_t size, - FairMQRegionCallback callback = nullptr, - const std::string& path = "", - int flags = 0) + template + FairMQUnmanagedRegionPtr NewUnmanagedRegion(Args&&... args) { - return Transport()->CreateUnmanagedRegion(size, callback, path, flags); - } - - // creates unamanaged region with the default device transport - FairMQUnmanagedRegionPtr NewUnmanagedRegion(const size_t size, - const int64_t userFlags, - FairMQRegionCallback callback = nullptr, - const std::string& path = "", - int flags = 0) - { - return Transport()->CreateUnmanagedRegion(size, userFlags, callback, path, flags); + return Transport()->CreateUnmanagedRegion(std::forward(args)...); } // creates unmanaged region with the transport of the specified channel - FairMQUnmanagedRegionPtr NewUnmanagedRegionFor(const std::string& channel, - int index, - const size_t size, - FairMQRegionCallback callback = nullptr, - const std::string& path = "", - int flags = 0) + template + FairMQUnmanagedRegionPtr NewUnmanagedRegionFor(const std::string& channel, int index, Args&&... args) { - return GetChannel(channel, index).NewUnmanagedRegion(size, callback, path, flags); - } - - // creates unmanaged region with the transport of the specified channel - FairMQUnmanagedRegionPtr NewUnmanagedRegionFor(const std::string& channel, - int index, - const size_t size, - const int64_t userFlags, - FairMQRegionCallback callback = nullptr, - const std::string& path = "", - int flags = 0) - { - return GetChannel(channel, index).NewUnmanagedRegion(size, userFlags, callback, path, flags); + return GetChannel(channel, index).NewUnmanagedRegion(std::forward(args)...); } template diff --git a/fairmq/FairMQTransportFactory.h b/fairmq/FairMQTransportFactory.h index 9073d62b..9d0719e0 100644 --- a/fairmq/FairMQTransportFactory.h +++ b/fairmq/FairMQTransportFactory.h @@ -85,6 +85,7 @@ class FairMQTransportFactory /// @param flags optional parameter to pass to the underlying transport /// @return pointer to UnmanagedRegion virtual FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, FairMQRegionCallback callback = nullptr, const std::string& path = "", int flags = 0) = 0; + virtual FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, FairMQRegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) = 0; /// @brief Create new UnmanagedRegion /// @param size size of the region /// @param userFlags flags to be stored with the region, have no effect on the transport, but can be retrieved from the region by the user @@ -93,6 +94,7 @@ class FairMQTransportFactory /// @param flags optional parameter to pass to the underlying transport /// @return pointer to UnmanagedRegion virtual FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, const int64_t userFlags, FairMQRegionCallback callback = nullptr, const std::string& path = "", int flags = 0) = 0; + virtual FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, const int64_t userFlags, FairMQRegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) = 0; /// @brief Subscribe to region events (creation, destruction, ...) /// @param callback the callback that is called when a region event occurs diff --git a/fairmq/FairMQUnmanagedRegion.h b/fairmq/FairMQUnmanagedRegion.h index 21b27b19..080dab60 100644 --- a/fairmq/FairMQUnmanagedRegion.h +++ b/fairmq/FairMQUnmanagedRegion.h @@ -13,6 +13,7 @@ #include // std::unique_ptr #include // std::function #include // std::ostream +#include class FairMQTransportFactory; @@ -48,14 +49,25 @@ struct FairMQRegionInfo FairMQRegionEvent event; }; +struct FairMQRegionBlock { + void* ptr; + size_t size; + void* hint; + + FairMQRegionBlock(void* p, size_t s, void* h) + : ptr(p), size(s), hint(h) + {} +}; + using FairMQRegionCallback = std::function; +using FairMQRegionBulkCallback = std::function&)>; using FairMQRegionEventCallback = std::function; class FairMQUnmanagedRegion { public: FairMQUnmanagedRegion() {} - FairMQUnmanagedRegion(FairMQTransportFactory* factory): fTransport(factory) {} + FairMQUnmanagedRegion(FairMQTransportFactory* factory) : fTransport(factory) {} virtual void* GetData() const = 0; virtual size_t GetSize() const = 0; @@ -92,9 +104,11 @@ namespace mq { using RegionCallback = FairMQRegionCallback; +using RegionBulkCallback = FairMQRegionBulkCallback; using RegionEventCallback = FairMQRegionEventCallback; using RegionEvent = FairMQRegionEvent; using RegionInfo = FairMQRegionInfo; +using RegionBlock = FairMQRegionBlock; using UnmanagedRegion = FairMQUnmanagedRegion; using UnmanagedRegionPtr = FairMQUnmanagedRegionPtr; diff --git a/fairmq/ofi/TransportFactory.cxx b/fairmq/ofi/TransportFactory.cxx index 918f76df..2e71c55d 100644 --- a/fairmq/ofi/TransportFactory.cxx +++ b/fairmq/ofi/TransportFactory.cxx @@ -90,11 +90,21 @@ auto TransportFactory::CreateUnmanagedRegion(const size_t /*size*/, FairMQRegion throw runtime_error{"Not yet implemented UMR."}; } +auto TransportFactory::CreateUnmanagedRegion(const size_t /*size*/, FairMQRegionBulkCallback /*callback*/, const std::string& /* path = "" */, int /* flags = 0 */) -> UnmanagedRegionPtr +{ + throw runtime_error{"Not yet implemented UMR."}; +} + auto TransportFactory::CreateUnmanagedRegion(const size_t /*size*/, const int64_t /*userFlags*/, FairMQRegionCallback /*callback*/, const std::string& /* path = "" */, int /* flags = 0 */) -> UnmanagedRegionPtr { throw runtime_error{"Not yet implemented UMR."}; } +auto TransportFactory::CreateUnmanagedRegion(const size_t /*size*/, const int64_t /*userFlags*/, FairMQRegionBulkCallback /*callback*/, const std::string& /* path = "" */, int /* flags = 0 */) -> UnmanagedRegionPtr +{ + throw runtime_error{"Not yet implemented UMR."}; +} + auto TransportFactory::GetType() const -> Transport { return Transport::OFI; diff --git a/fairmq/ofi/TransportFactory.h b/fairmq/ofi/TransportFactory.h index bee2a814..1fdc9bc2 100644 --- a/fairmq/ofi/TransportFactory.h +++ b/fairmq/ofi/TransportFactory.h @@ -47,7 +47,9 @@ class TransportFactory final : public FairMQTransportFactory auto CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const -> PollerPtr override; auto CreateUnmanagedRegion(const size_t size, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) -> UnmanagedRegionPtr override; + auto CreateUnmanagedRegion(const size_t size, RegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) -> UnmanagedRegionPtr override; auto CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) -> UnmanagedRegionPtr override; + auto CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) -> UnmanagedRegionPtr override; void SubscribeToRegionEvents(RegionEventCallback /* callback */) override { LOG(error) << "SubscribeToRegionEvents not yet implemented for OFI"; } bool SubscribedToRegionEvents() override { LOG(error) << "Region event subscriptions not yet implemented for OFI"; return false; } diff --git a/fairmq/shmem/Manager.cxx b/fairmq/shmem/Manager.cxx index c239f4a6..38e7de4e 100644 --- a/fairmq/shmem/Manager.cxx +++ b/fairmq/shmem/Manager.cxx @@ -101,7 +101,12 @@ void Manager::StartMonitor(const string& id) } } -pair Manager::CreateRegion(const size_t size, const int64_t userFlags, RegionCallback callback, const string& path /* = "" */, int flags /* = 0 */) +pair Manager::CreateRegion(const size_t size, + const int64_t userFlags, + RegionCallback callback, + RegionBulkCallback bulkCallback, + const string& path /* = "" */, + int flags /* = 0 */) { try { @@ -134,7 +139,7 @@ pair Manager::CreateRegion(const size_t size, co // create region info fRegionInfos->emplace(id, RegionInfo(path.c_str(), flags, userFlags, fShmVoidAlloc)); - auto r = fRegions.emplace(id, tools::make_unique(*this, id, size, false, callback, path, flags)); + auto r = fRegions.emplace(id, tools::make_unique(*this, id, size, false, callback, bulkCallback, path, flags)); // LOG(debug) << "Created region with id '" << id << "', path: '" << path << "', flags: '" << flags << "'"; r.first->second->StartReceivingAcks(); @@ -182,7 +187,7 @@ Region* Manager::GetRegionUnsafe(const uint64_t id) int flags = regionInfo.fFlags; // LOG(debug) << "Located remote region with id '" << id << "', path: '" << path << "', flags: '" << flags << "'"; - auto r = fRegions.emplace(id, tools::make_unique(*this, id, 0, true, nullptr, path, flags)); + auto r = fRegions.emplace(id, tools::make_unique(*this, id, 0, true, nullptr, nullptr, path, flags)); return r.first->second.get(); } catch (bie& e) { LOG(warn) << "Could not get remote region for id: " << id; diff --git a/fairmq/shmem/Manager.h b/fairmq/shmem/Manager.h index 95eef9df..03fc77c9 100644 --- a/fairmq/shmem/Manager.h +++ b/fairmq/shmem/Manager.h @@ -70,7 +70,12 @@ class Manager int IncrementDeviceCounter(); int DecrementDeviceCounter(); - std::pair CreateRegion(const size_t size, const int64_t userFlags, RegionCallback callback, const std::string& path = "", int flags = 0); + std::pair CreateRegion(const size_t size, + const int64_t userFlags, + RegionCallback callback, + RegionBulkCallback bulkCallback, + const std::string& path = "", + int flags = 0); Region* GetRegion(const uint64_t id); Region* GetRegionUnsafe(const uint64_t id); void RemoveRegion(const uint64_t id); diff --git a/fairmq/shmem/Region.cxx b/fairmq/shmem/Region.cxx index 7735cdad..a40ba57d 100644 --- a/fairmq/shmem/Region.cxx +++ b/fairmq/shmem/Region.cxx @@ -33,7 +33,7 @@ namespace mq namespace shmem { -Region::Region(Manager& manager, uint64_t id, uint64_t size, bool remote, RegionCallback callback, const string& path /* = "" */, int flags /* = 0 */) +Region::Region(Manager& manager, uint64_t id, uint64_t size, bool remote, RegionCallback callback, RegionBulkCallback bulkCallback, const string& path, int flags) : fManager(manager) , fRemote(remote) , fStop(false) @@ -46,6 +46,7 @@ Region::Region(Manager& manager, uint64_t id, uint64_t size, bool remote, Region , fReceiveAcksWorker() , fSendAcksWorker() , fCallback(callback) + , fBulkCallback(bulkCallback) { if (path != "") { fName = string(path + fName); @@ -110,14 +111,22 @@ void Region::ReceiveAcks() unsigned int priority; bipc::message_queue::size_type recvdSize; unique_ptr blocks = tools::make_unique(fAckBunchSize); + std::vector result; + result.reserve(fAckBunchSize); while (!fStop) { // end thread condition (should exist until region is destroyed) auto rcvTill = bpt::microsec_clock::universal_time() + bpt::milliseconds(500); while (fQueue->timed_receive(blocks.get(), fAckBunchSize * sizeof(RegionBlock), recvdSize, priority, rcvTill)) { // LOG(debug) << "received: " << block.fHandle << " " << block.fSize << " " << block.fMessageId; - if (fCallback) { - const auto numBlocks = recvdSize / sizeof(RegionBlock); + const auto numBlocks = recvdSize / sizeof(RegionBlock); + if (fBulkCallback) { + result.clear(); + for (size_t i = 0; i < numBlocks; i++) { + result.emplace_back(reinterpret_cast(fRegion.get_address()) + blocks[i].fHandle, blocks[i].fSize, reinterpret_cast(blocks[i].fHint)); + } + fBulkCallback(result); + } else if (fCallback) { for (size_t i = 0; i < numBlocks; i++) { fCallback(reinterpret_cast(fRegion.get_address()) + blocks[i].fHandle, blocks[i].fSize, reinterpret_cast(blocks[i].fHint)); } @@ -125,7 +134,7 @@ void Region::ReceiveAcks() } } // while !fStop - LOG(debug) << "receive ack worker for " << fName << " leaving."; + LOG(debug) << "ReceiveAcks() worker for " << fName << " leaving."; } void Region::ReleaseBlock(const RegionBlock &block) diff --git a/fairmq/shmem/Region.h b/fairmq/shmem/Region.h index 99d047b4..b15dfbcb 100644 --- a/fairmq/shmem/Region.h +++ b/fairmq/shmem/Region.h @@ -40,7 +40,7 @@ class Manager; struct Region { - Region(Manager& manager, uint64_t id, uint64_t size, bool remote, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0); + Region(Manager& manager, uint64_t id, uint64_t size, bool remote, RegionCallback callback, RegionBulkCallback bulkCallback, const std::string& path, int flags); Region() = delete; @@ -76,6 +76,7 @@ struct Region std::thread fReceiveAcksWorker; std::thread fSendAcksWorker; RegionCallback fCallback; + RegionBulkCallback fBulkCallback; }; } // namespace shmem diff --git a/fairmq/shmem/TransportFactory.cxx b/fairmq/shmem/TransportFactory.cxx index faeb5711..f885f4d2 100644 --- a/fairmq/shmem/TransportFactory.cxx +++ b/fairmq/shmem/TransportFactory.cxx @@ -161,12 +161,22 @@ PollerPtr TransportFactory::CreatePoller(const unordered_map(*fManager, size, callback, path, flags, this); + return tools::make_unique(*fManager, size, 0, callback, nullptr, path, flags, this); +} + +UnmanagedRegionPtr TransportFactory::CreateUnmanagedRegion(const size_t size, RegionBulkCallback bulkCallback, const std::string& path /* = "" */, int flags /* = 0 */) +{ + return tools::make_unique(*fManager, size, 0, nullptr, bulkCallback, path, flags, this); } UnmanagedRegionPtr TransportFactory::CreateUnmanagedRegion(const size_t size, const int64_t userFlags, RegionCallback callback, const std::string& path /* = "" */, int flags /* = 0 */) { - return tools::make_unique(*fManager, size, userFlags, callback, path, flags, this); + return tools::make_unique(*fManager, size, userFlags, callback, nullptr, path, flags, this); +} + +UnmanagedRegionPtr TransportFactory::CreateUnmanagedRegion(const size_t size, const int64_t userFlags, RegionBulkCallback bulkCallback, const std::string& path /* = "" */, int flags /* = 0 */) +{ + return tools::make_unique(*fManager, size, userFlags, nullptr, bulkCallback, path, flags, this); } void TransportFactory::SubscribeToRegionEvents(RegionEventCallback callback) diff --git a/fairmq/shmem/TransportFactory.h b/fairmq/shmem/TransportFactory.h index 991fe7da..9b331b3f 100644 --- a/fairmq/shmem/TransportFactory.h +++ b/fairmq/shmem/TransportFactory.h @@ -50,7 +50,9 @@ class TransportFactory final : public fair::mq::TransportFactory PollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const override; UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) override; + UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, RegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) override; UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionCallback callback = nullptr, const std::string& path = "", int flags = 0) override; + UnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, int64_t userFlags, RegionBulkCallback callback = nullptr, const std::string& path = "", int flags = 0) override; void SubscribeToRegionEvents(RegionEventCallback callback) override; bool SubscribedToRegionEvents() override; diff --git a/fairmq/shmem/UnmanagedRegion.h b/fairmq/shmem/UnmanagedRegion.h index cc16bbda..cc6a589b 100644 --- a/fairmq/shmem/UnmanagedRegion.h +++ b/fairmq/shmem/UnmanagedRegion.h @@ -36,17 +36,20 @@ class UnmanagedRegion final : public fair::mq::UnmanagedRegion friend class Socket; public: - UnmanagedRegion(Manager& manager, const size_t size, RegionCallback callback, const std::string& path = "", int flags = 0, FairMQTransportFactory* factory = nullptr) - : UnmanagedRegion(manager, size, 0, callback, path, flags, factory) - {} - - UnmanagedRegion(Manager& manager, const size_t size, const int64_t userFlags, RegionCallback callback, const std::string& path = "", int flags = 0, FairMQTransportFactory* factory = nullptr) + UnmanagedRegion(Manager& manager, + const size_t size, + const int64_t userFlags, + RegionCallback callback, + RegionBulkCallback bulkCallback, + const std::string& path = "", + int flags = 0, + FairMQTransportFactory* factory = nullptr) : FairMQUnmanagedRegion(factory) , fManager(manager) , fRegion(nullptr) , fRegionId(0) { - auto result = fManager.CreateRegion(size, userFlags, callback, path, flags); + auto result = fManager.CreateRegion(size, userFlags, callback, bulkCallback, path, flags); fRegion = result.first; fRegionId = result.second; } diff --git a/fairmq/zeromq/FairMQMessageZMQ.cxx b/fairmq/zeromq/FairMQMessageZMQ.cxx index f08088b7..07bf2a9f 100644 --- a/fairmq/zeromq/FairMQMessageZMQ.cxx +++ b/fairmq/zeromq/FairMQMessageZMQ.cxx @@ -81,7 +81,12 @@ FairMQMessageZMQ::FairMQMessageZMQ(FairMQUnmanagedRegionPtr& region, void* data, memcpy(zmq_msg_data(fMsg.get()), data, size); // call region callback - static_cast(region.get())->fCallback(data, size, hint); + auto ptr = static_cast(region.get()); + if (ptr->fBulkCallback) { + ptr->fBulkCallback({{data, size, hint}}); + } else if (ptr->fCallback) { + ptr->fCallback(data, size, hint); + } // if (zmq_msg_init_data(fMsg.get(), data, size, [](void*, void*){}, nullptr) != 0) // { diff --git a/fairmq/zeromq/FairMQSocketZMQ.cxx b/fairmq/zeromq/FairMQSocketZMQ.cxx index 66514f49..c14ec2d6 100644 --- a/fairmq/zeromq/FairMQSocketZMQ.cxx +++ b/fairmq/zeromq/FairMQSocketZMQ.cxx @@ -84,7 +84,7 @@ bool FairMQSocketZMQ::Bind(const string& address) // do not print error in this case, this is handled by FairMQDevice in case no connection could be established after trying a number of random ports from a range. return false; } - LOG(error) << "Failed binding socket " << fId << ", reason: " << zmq_strerror(errno); + LOG(error) << "Failed binding socket " << fId << ", address: " << address << ", reason: " << zmq_strerror(errno); return false; } @@ -97,7 +97,7 @@ bool FairMQSocketZMQ::Connect(const string& address) if (zmq_connect(fSocket, address.c_str()) != 0) { - LOG(error) << "Failed connecting socket " << fId << ", reason: " << zmq_strerror(errno); + LOG(error) << "Failed connecting socket " << fId << ", address: " << address << ", reason: " << zmq_strerror(errno); return false; } diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx index 371297c4..4d487a93 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.cxx @@ -94,19 +94,55 @@ FairMQPollerPtr FairMQTransportFactoryZMQ::CreatePoller(const unordered_map(new FairMQPollerZMQ(channelsMap, channelList)); } -FairMQUnmanagedRegionPtr FairMQTransportFactoryZMQ::CreateUnmanagedRegion(const size_t size, FairMQRegionCallback callback, const string& path /* = "" */, int flags /* = 0 */) +FairMQUnmanagedRegionPtr FairMQTransportFactoryZMQ::CreateUnmanagedRegion( + const size_t size, + FairMQRegionCallback callback, + const string& path /* = "" */, + int flags /* = 0 */) { - return CreateUnmanagedRegion(size, 0, callback, path, flags); + return CreateUnmanagedRegion(size, 0, callback, nullptr, path, flags); +} +FairMQUnmanagedRegionPtr FairMQTransportFactoryZMQ::CreateUnmanagedRegion( + const size_t size, + FairMQRegionBulkCallback bulkCallback, + const string& path /* = "" */, + int flags /* = 0 */) +{ + return CreateUnmanagedRegion(size, 0, nullptr, bulkCallback, path, flags); +} +FairMQUnmanagedRegionPtr FairMQTransportFactoryZMQ::CreateUnmanagedRegion( + const size_t size, + const int64_t userFlags, + FairMQRegionCallback callback, + const string& path /* = "" */, + int flags /* = 0 */) +{ + return CreateUnmanagedRegion(size, userFlags, callback, nullptr, path, flags); +} +FairMQUnmanagedRegionPtr FairMQTransportFactoryZMQ::CreateUnmanagedRegion( + const size_t size, + const int64_t userFlags, + FairMQRegionBulkCallback bulkCallback, + const string& path /* = "" */, + int flags /* = 0 */) +{ + return CreateUnmanagedRegion(size, userFlags, nullptr, bulkCallback, path, flags); } -FairMQUnmanagedRegionPtr FairMQTransportFactoryZMQ::CreateUnmanagedRegion(const size_t size, const int64_t userFlags, FairMQRegionCallback callback, const string& path /* = "" */, int flags /* = 0 */) +FairMQUnmanagedRegionPtr FairMQTransportFactoryZMQ::CreateUnmanagedRegion( + const size_t size, + const int64_t userFlags, + FairMQRegionCallback callback, + FairMQRegionBulkCallback bulkCallback, + const string& path /* = "" */, + int flags /* = 0 */) { unique_ptr ptr = nullptr; { lock_guard lock(fMtx); ++fRegionCounter; - ptr = unique_ptr(new FairMQUnmanagedRegionZMQ(fRegionCounter, size, userFlags, callback, path, flags, this)); + ptr = unique_ptr(new FairMQUnmanagedRegionZMQ(fRegionCounter, size, userFlags, callback, bulkCallback, path, flags, this)); auto zPtr = static_cast(ptr.get()); fRegionInfos.emplace_back(zPtr->GetId(), zPtr->GetData(), zPtr->GetSize(), zPtr->GetUserFlags(), fair::mq::RegionEvent::created); fRegionEvents.emplace(zPtr->GetId(), zPtr->GetData(), zPtr->GetSize(), zPtr->GetUserFlags(), fair::mq::RegionEvent::created); diff --git a/fairmq/zeromq/FairMQTransportFactoryZMQ.h b/fairmq/zeromq/FairMQTransportFactoryZMQ.h index 79c5b77e..40f7d8e4 100644 --- a/fairmq/zeromq/FairMQTransportFactoryZMQ.h +++ b/fairmq/zeromq/FairMQTransportFactoryZMQ.h @@ -49,7 +49,10 @@ class FairMQTransportFactoryZMQ final : public FairMQTransportFactory FairMQPollerPtr CreatePoller(const std::unordered_map>& channelsMap, const std::vector& channelList) const override; FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, FairMQRegionCallback callback, const std::string& path = "", int flags = 0) override; + FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, FairMQRegionBulkCallback bulkCallback, const std::string& path = "", int flags = 0) override; FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, const int64_t userFlags, FairMQRegionCallback callback, const std::string& path = "", int flags = 0) override; + FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, const int64_t userFlags, FairMQRegionBulkCallback bulkCallback, const std::string& path = "", int flags = 0) override; + FairMQUnmanagedRegionPtr CreateUnmanagedRegion(const size_t size, const int64_t userFlags, FairMQRegionCallback callback, FairMQRegionBulkCallback bulkCallback, const std::string& path = "", int flags = 0); void SubscribeToRegionEvents(FairMQRegionEventCallback callback) override; bool SubscribedToRegionEvents() override; diff --git a/fairmq/zeromq/FairMQUnmanagedRegionZMQ.cxx b/fairmq/zeromq/FairMQUnmanagedRegionZMQ.cxx index ec02fb4c..66a4ed4d 100644 --- a/fairmq/zeromq/FairMQUnmanagedRegionZMQ.cxx +++ b/fairmq/zeromq/FairMQUnmanagedRegionZMQ.cxx @@ -10,13 +10,21 @@ #include "FairMQTransportFactoryZMQ.h" #include "FairMQLogger.h" -FairMQUnmanagedRegionZMQ::FairMQUnmanagedRegionZMQ(uint64_t id, const size_t size, int64_t userFlags, FairMQRegionCallback callback, const std::string& /* path = "" */, int /* flags = 0 */, FairMQTransportFactory* factory /* = nullptr */) +FairMQUnmanagedRegionZMQ::FairMQUnmanagedRegionZMQ(uint64_t id, + const size_t size, + int64_t userFlags, + FairMQRegionCallback callback, + FairMQRegionBulkCallback bulkCallback, + const std::string& /* path = "" */, + int /* flags = 0 */, + FairMQTransportFactory* factory /* = nullptr */) : FairMQUnmanagedRegion(factory) , fId(id) , fBuffer(malloc(size)) , fSize(size) , fUserFlags(userFlags) , fCallback(callback) + , fBulkCallback(bulkCallback) {} void* FairMQUnmanagedRegionZMQ::GetData() const diff --git a/fairmq/zeromq/FairMQUnmanagedRegionZMQ.h b/fairmq/zeromq/FairMQUnmanagedRegionZMQ.h index c971d737..cda32214 100644 --- a/fairmq/zeromq/FairMQUnmanagedRegionZMQ.h +++ b/fairmq/zeromq/FairMQUnmanagedRegionZMQ.h @@ -21,7 +21,7 @@ class FairMQUnmanagedRegionZMQ final : public FairMQUnmanagedRegion friend class FairMQMessageZMQ; public: - FairMQUnmanagedRegionZMQ(uint64_t id, const size_t size, int64_t userFlags, FairMQRegionCallback callback, const std::string& /* path = "" */, int /* flags = 0 */, FairMQTransportFactory* factory = nullptr); + FairMQUnmanagedRegionZMQ(uint64_t id, const size_t size, int64_t userFlags, FairMQRegionCallback callback, FairMQRegionBulkCallback bulkCallback, const std::string& /* path = "" */, int /* flags = 0 */, FairMQTransportFactory* factory = nullptr); FairMQUnmanagedRegionZMQ(const FairMQUnmanagedRegionZMQ&) = delete; FairMQUnmanagedRegionZMQ operator=(const FairMQUnmanagedRegionZMQ&) = delete; @@ -39,6 +39,7 @@ class FairMQUnmanagedRegionZMQ final : public FairMQUnmanagedRegion size_t fSize; int64_t fUserFlags; FairMQRegionCallback fCallback; + FairMQRegionBulkCallback fBulkCallback; }; #endif /* FAIRMQUNMANAGEDREGIONZMQ_H_ */ diff --git a/test/region/_region.cxx b/test/region/_region.cxx index f7d8bd1f..59dbaf4f 100644 --- a/test/region/_region.cxx +++ b/test/region/_region.cxx @@ -32,7 +32,7 @@ void RegionEventSubscriptions(const string& transport) constexpr int size1 = 1000000; constexpr int size2 = 5000000; constexpr int64_t userFlags = 12345; - fair::mq::tools::SharedSemaphore blocker; + fair::mq::tools::Semaphore blocker; { auto region1 = factory->CreateUnmanagedRegion(size1, [](void*, size_t, void*) {}); @@ -90,6 +90,72 @@ void RegionEventSubscriptions(const string& transport) ASSERT_EQ(factory->SubscribedToRegionEvents(), false); } +void RegionCallbacks(const string& transport, const string& _address) +{ + size_t session(fair::mq::tools::UuidHash()); + std::string address(fair::mq::tools::ToString(_address, "_", transport)); + + fair::mq::ProgOptions config; + config.SetProperty("session", to_string(session)); + + auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config); + + unique_ptr intPtr1 = fair::mq::tools::make_unique(42); + unique_ptr intPtr2 = fair::mq::tools::make_unique(43); + fair::mq::tools::Semaphore blocker; + + FairMQChannel push("Push", "push", factory); + push.Bind(address); + + FairMQChannel pull("Pull", "pull", factory); + pull.Connect(address); + + void* ptr1 = nullptr; + size_t size1 = 100; + void* ptr2 = nullptr; + size_t size2 = 200; + + auto region1 = factory->CreateUnmanagedRegion(2000000, [&](void* ptr, size_t size, void* hint) { + ASSERT_EQ(ptr, ptr1); + ASSERT_EQ(size, size1); + ASSERT_EQ(hint, intPtr1.get()); + ASSERT_EQ(*static_cast(hint), 42); + blocker.Signal(); + }); + ptr1 = region1->GetData(); + + auto region2 = factory->CreateUnmanagedRegion(3000000, [&](const std::vector& blocks) { + ASSERT_EQ(blocks.size(), 1); + ASSERT_EQ(blocks.at(0).ptr, ptr2); + ASSERT_EQ(blocks.at(0).size, size2); + ASSERT_EQ(blocks.at(0).hint, intPtr2.get()); + ASSERT_EQ(*static_cast(blocks.at(0).hint), 43); + blocker.Signal(); + }); + ptr2 = region2->GetData(); + + + { + FairMQMessagePtr msg1out(push.NewMessage(region1, ptr1, size1, intPtr1.get())); + FairMQMessagePtr msg2out(push.NewMessage(region2, ptr2, size2, intPtr2.get())); + ASSERT_EQ(push.Send(msg1out), size1); + ASSERT_EQ(push.Send(msg2out), size2); + } + + { + FairMQMessagePtr msg1in(pull.NewMessage()); + FairMQMessagePtr msg2in(pull.NewMessage()); + ASSERT_EQ(pull.Receive(msg1in), size1); + ASSERT_EQ(pull.Receive(msg2in), size2); + } + + LOG(info) << "waiting for blockers..."; + blocker.Wait(); + LOG(info) << "1 done."; + blocker.Wait(); + LOG(info) << "2 done."; +} + TEST(EventSubscriptions, zeromq) { RegionEventSubscriptions("zeromq"); @@ -100,4 +166,14 @@ TEST(EventSubscriptions, shmem) RegionEventSubscriptions("shmem"); } +TEST(Callbacks, zeromq) +{ + RegionCallbacks("zeromq", "ipc://test_region_callbacks"); +} + +TEST(Callbacks, shmem) +{ + RegionCallbacks("shmem", "ipc://test_region_callbacks"); +} + } // namespace diff --git a/test/transport/_options.cxx b/test/transport/_options.cxx index 3a40e5ed..454eef0c 100644 --- a/test/transport/_options.cxx +++ b/test/transport/_options.cxx @@ -25,7 +25,7 @@ namespace using namespace std; -void CheckOldOptionInterface(FairMQChannel& channel, const string& option, const string& transport) +void CheckOldOptionInterface(FairMQChannel& channel, const string& option) { int value = 500; channel.GetSocket().SetOption(option, &value, sizeof(value)); @@ -44,11 +44,11 @@ void RunOptionsTest(const string& transport) auto factory = FairMQTransportFactory::CreateTransportFactory(transport, fair::mq::tools::Uuid(), &config); FairMQChannel channel("Push", "push", factory); - CheckOldOptionInterface(channel, "linger", transport); - CheckOldOptionInterface(channel, "snd-hwm", transport); - CheckOldOptionInterface(channel, "rcv-hwm", transport); - CheckOldOptionInterface(channel, "snd-size", transport); - CheckOldOptionInterface(channel, "rcv-size", transport); + CheckOldOptionInterface(channel, "linger"); + CheckOldOptionInterface(channel, "snd-hwm"); + CheckOldOptionInterface(channel, "rcv-hwm"); + CheckOldOptionInterface(channel, "snd-size"); + CheckOldOptionInterface(channel, "rcv-size"); channel.GetSocket().SetLinger(300); ASSERT_EQ(channel.GetSocket().GetLinger(), 300);