diff --git a/fairmq/sdk/Topology.h b/fairmq/sdk/Topology.h index 46a8d530..93e993ed 100644 --- a/fairmq/sdk/Topology.h +++ b/fairmq/sdk/Topology.h @@ -319,6 +319,7 @@ class BasicTopology : public AsioBase AsioBase::GetAllocator(), std::move(handler)); fChangeStateTarget = expectedState.at(transition); + ResetTransitionedCount(fChangeStateTarget); fDDSSession.SendCommand(GetTransitionName(transition)); if (timeout > std::chrono::milliseconds(0)) { fChangeStateOpTimer.expires_after(timeout); @@ -394,6 +395,10 @@ class BasicTopology : public AsioBase auto StateEqualsTo(DeviceState state) const -> bool { return sdk::StateEqualsTo(GetCurrentState(), state); } private: + using TransitionedCount = unsigned int; + // using TransitionCounts = std::map; + + DDSSession fDDSSession; DDSTopology fDDSTopo; TopologyStateByTask fState; @@ -403,6 +408,7 @@ class BasicTopology : public AsioBase ChangeStateOp fChangeStateOp; asio::steady_timer fChangeStateOpTimer; DeviceState fChangeStateTarget; + TransitionedCount fTransitionedCount; static auto makeTopologyState(const DDSTopo& topo) -> TopologyStateByTask { @@ -422,6 +428,9 @@ class BasicTopology : public AsioBase DeviceStatus& task = fState.at(taskId); task.initialized = true; task.state = fair::mq::GetState(endState); + if (task.state == fChangeStateTarget) { + ++fTransitionedCount; + } LOG(debug) << "Updated state entry: taskId=" << taskId << ",state=" << state; TryChangeStateCompletion(); } catch (const std::exception& e) { @@ -432,17 +441,23 @@ class BasicTopology : public AsioBase /// call only under locked fMtx! auto TryChangeStateCompletion() -> void { - bool targetStateReached( - std::all_of(fState.cbegin(), fState.cend(), [&](TopologyStateByTask::value_type i) { - return (i.second.state == fChangeStateTarget) && i.second.initialized; - })); - - if (!fChangeStateOp.IsCompleted() && targetStateReached) { + if (!fChangeStateOp.IsCompleted() && fTransitionedCount == fState.size()) { fChangeStateOpTimer.cancel(); fChangeStateOp.Complete(MakeTopologyStateFromMap()); } } + /// call only under locked fMtx! + auto ResetTransitionedCount(DeviceState targetState) -> void + { + fTransitionedCount = 0; + for (const auto& s : fState) { + if (s.second.state == targetState) { + ++fTransitionedCount; + } + } + } + /// call only under locked fMtx! auto GetCurrentStateUnsafe() const -> TopologyState {