From 684e711b8b7a7698341a49c2f30e2ee00d43345f Mon Sep 17 00:00:00 2001 From: Alexey Rybalchenko Date: Thu, 19 Dec 2019 15:06:14 +0100 Subject: [PATCH] Shmem: track number of message objects, throw if non-zero at reset --- fairmq/shmem/Message.cxx | 8 +++++++ fairmq/shmem/Socket.cxx | 1 + fairmq/shmem/TransportFactory.cxx | 10 ++++++++ fairmq/shmem/TransportFactory.h | 6 ++++- test/device/_error_state.cxx | 39 +++++++++++++++++++++++++++++++ 5 files changed, 63 insertions(+), 1 deletion(-) diff --git a/fairmq/shmem/Message.cxx b/fairmq/shmem/Message.cxx index 56c985f0..a2622bde 100644 --- a/fairmq/shmem/Message.cxx +++ b/fairmq/shmem/Message.cxx @@ -9,6 +9,7 @@ #include "Region.h" #include "Message.h" #include "UnmanagedRegion.h" +#include "TransportFactory.h" #include @@ -39,6 +40,7 @@ Message::Message(Manager& manager, FairMQTransportFactory* factory) , fRegionPtr(nullptr) , fLocalPtr(nullptr) { + static_cast(GetTransport())->IncrementMsgCounter(); } Message::Message(Manager& manager, const size_t size, FairMQTransportFactory* factory) @@ -50,6 +52,7 @@ Message::Message(Manager& manager, const size_t size, FairMQTransportFactory* fa , fLocalPtr(nullptr) { InitializeChunk(size); + static_cast(GetTransport())->IncrementMsgCounter(); } Message::Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* factory) @@ -60,6 +63,7 @@ Message::Message(Manager& manager, MetaHeader& hdr, FairMQTransportFactory* fact , fRegionPtr(nullptr) , fLocalPtr(nullptr) { + static_cast(GetTransport())->IncrementMsgCounter(); } Message::Message(Manager& manager, void* data, const size_t size, fairmq_free_fn* ffn, void* hint, FairMQTransportFactory* factory) @@ -78,6 +82,7 @@ Message::Message(Manager& manager, void* data, const size_t size, fairmq_free_fn free(data); } } + static_cast(GetTransport())->IncrementMsgCounter(); } Message::Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const size_t size, void* hint, FairMQTransportFactory* factory) @@ -95,6 +100,7 @@ Message::Message(Manager& manager, UnmanagedRegionPtr& region, void* data, const LOG(error) << "trying to create region message with data from outside the region"; throw runtime_error("trying to create region message with data from outside the region"); } + static_cast(GetTransport())->IncrementMsgCounter(); } bool Message::InitializeChunk(const size_t size) @@ -225,6 +231,8 @@ void Message::CloseMessage() } } } + + static_cast(GetTransport())->DecrementMsgCounter(); } } diff --git a/fairmq/shmem/Socket.cxx b/fairmq/shmem/Socket.cxx index 18eb8760..835f1d25 100644 --- a/fairmq/shmem/Socket.cxx +++ b/fairmq/shmem/Socket.cxx @@ -10,6 +10,7 @@ #include "Socket.h" #include "Message.h" #include "UnmanagedRegion.h" +#include "TransportFactory.h" #include #include diff --git a/fairmq/shmem/TransportFactory.cxx b/fairmq/shmem/TransportFactory.cxx index 13129e98..8502dcb7 100644 --- a/fairmq/shmem/TransportFactory.cxx +++ b/fairmq/shmem/TransportFactory.cxx @@ -47,6 +47,7 @@ TransportFactory::TransportFactory(const string& id, const ProgOptions* config) , fManager(nullptr) , fHeartbeatThread() , fSendHeartbeats(true) + , fMsgCounter(0) { int major, minor, patch; zmq_version(&major, &minor, &patch); @@ -168,6 +169,15 @@ Transport TransportFactory::GetType() const return fTransportType; } +void TransportFactory::Reset() +{ + if (fMsgCounter.load() != 0) { + LOG(error) << "Message counter during Reset expected to be 0, found: " << fMsgCounter.load(); + throw MessageError(tools::ToString("Message counter during Reset expected to be 0, found: ", fMsgCounter.load())); + } +} + + TransportFactory::~TransportFactory() { LOG(debug) << "Destroying Shared Memory transport..."; diff --git a/fairmq/shmem/TransportFactory.h b/fairmq/shmem/TransportFactory.h index 6e20e905..2864004e 100644 --- a/fairmq/shmem/TransportFactory.h +++ b/fairmq/shmem/TransportFactory.h @@ -55,7 +55,10 @@ class TransportFactory final : public fair::mq::TransportFactory void Interrupt() override { Socket::Interrupt(); } void Resume() override { Socket::Resume(); } - void Reset() override {} + void Reset() override; + + void IncrementMsgCounter() { ++fMsgCounter; } + void DecrementMsgCounter() { --fMsgCounter; } ~TransportFactory() override; @@ -69,6 +72,7 @@ class TransportFactory final : public fair::mq::TransportFactory std::unique_ptr fManager; std::thread fHeartbeatThread; std::atomic fSendHeartbeats; + std::atomic fMsgCounter; }; } // namespace shmem diff --git a/test/device/_error_state.cxx b/test/device/_error_state.cxx index 0051ad1b..61a06f02 100644 --- a/test/device/_error_state.cxx +++ b/test/device/_error_state.cxx @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -23,6 +24,39 @@ using namespace std; using namespace fair::mq::test; using namespace fair::mq::tools; +class BadDevice : public FairMQDevice +{ + public: + BadDevice() + { + fDeviceThread = thread([&](){ + EXPECT_THROW(RunStateMachine(), fair::mq::MessageError); + }); + + SetTransport("shmem"); + + ChangeState(fair::mq::Transition::InitDevice); + WaitForState(fair::mq::State::InitializingDevice); + ChangeState(fair::mq::Transition::CompleteInit); + WaitForState(fair::mq::State::Initialized); + + parts.AddPart(NewMessage()); + } + + ~BadDevice() + { + ChangeState(fair::mq::Transition::ResetDevice); + + if (fDeviceThread.joinable()) { + fDeviceThread.join(); + } + } + + private: + thread fDeviceThread; + FairMQParts parts; +}; + void RunErrorStateIn(const string& state, const string& control, const string& input = "") { size_t session{fair::mq::tools::UuidHash()}; @@ -118,4 +152,9 @@ TEST(ErrorState, interactive_InReset) EXPECT_EXIT(RunErrorStateIn("Reset", "interactive", "q"), ::testing::ExitedWithCode(1), ""); } +TEST(ErrorState, OrphanMessages) +{ + BadDevice badDevice; +} + } // namespace