From a4771d739ca5be6bf9152ba21863b95758a1051a Mon Sep 17 00:00:00 2001 From: Alexey Rybalchenko Date: Thu, 18 Aug 2022 21:40:56 +0200 Subject: [PATCH] fix(shm): race/deadlock in region locks --- examples/region/sampler.cxx | 2 + fairmq/shmem/Manager.h | 120 ++++++++++++++++++++++-------------- fairmq/shmem/Message.h | 4 +- 3 files changed, 77 insertions(+), 49 deletions(-) diff --git a/examples/region/sampler.cxx b/examples/region/sampler.cxx index ca745648..c1711df3 100644 --- a/examples/region/sampler.cxx +++ b/examples/region/sampler.cxx @@ -76,6 +76,8 @@ struct Sampler : fair::mq::Device void ResetTask() override { + // give some time for acks to be received + std::this_thread::sleep_for(std::chrono::milliseconds(250)); fRegion.reset(); { std::lock_guard lock(fMtx); diff --git a/fairmq/shmem/Manager.h b/fairmq/shmem/Manager.h index 66b963cd..21404369 100644 --- a/fairmq/shmem/Manager.h +++ b/fairmq/shmem/Manager.h @@ -397,12 +397,10 @@ class Manager UnmanagedRegion* region = nullptr; bool newRegionCreated = false; - { - std::lock_guard lock(fLocalRegionsMtx); - auto res = fRegions.emplace(id, std::make_unique(fShmId, size, false, cfg)); - newRegionCreated = res.second; - region = res.first->second.get(); - } + std::lock_guard lock(fLocalRegionsMtx); + auto res = fRegions.emplace(id, std::make_unique(fShmId, size, false, cfg)); + newRegionCreated = res.second; + region = res.first->second.get(); // LOG(debug) << "Created region with id '" << id << "', path: '" << cfg.path << "', flags: '" << cfg.creationFlags << "'"; if (!newRegionCreated) { @@ -429,7 +427,7 @@ class Manager } } - UnmanagedRegion* GetRegion(uint16_t id) + UnmanagedRegion* GetRegionFromCache(uint16_t id) { // NOTE: gcc optimizations. Prevent loading tls addresses many times in the fast path const auto &lTlCache = fTlRegionCache; @@ -443,41 +441,40 @@ class Manager } } - boost::interprocess::scoped_lock shmLock(*fShmMtx); // slow path: check invalidation if (lTlCacheGen != fRegionsGen) { fTlRegionCache.fRegionsTLCache.clear(); } - std::lock_guard lock(fLocalRegionsMtx); - auto* lRegion = GetRegionUnsafe(id, shmLock); + auto* lRegion = GetRegion(id); fTlRegionCache.fRegionsTLCache.emplace_back(std::make_tuple(lRegion, id, fShmId64)); fTlRegionCache.fRegionsTLCacheGen = fRegionsGen; return lRegion; } - UnmanagedRegion* GetRegionUnsafe(uint16_t id, boost::interprocess::scoped_lock& lockedShmLock) + UnmanagedRegion* GetRegion(uint16_t id) { + std::lock_guard lock(fLocalRegionsMtx); // remote region could actually be a local one if a message originates from this device (has been sent out and returned) auto it = fRegions.find(id); if (it != fRegions.end()) { return it->second.get(); } else { try { - // get region info - RegionInfo regionInfo = fShmRegions->at(id); - // safe to unlock now - no shm container accessed after this - lockedShmLock.unlock(); RegionConfig cfg; - cfg.id = id; - cfg.creationFlags = regionInfo.fCreationFlags; - cfg.path = regionInfo.fPath.c_str(); + // get region info + { + boost::interprocess::scoped_lock shmLock(*fShmMtx); + RegionInfo regionInfo = fShmRegions->at(id); + cfg.id = id; + cfg.creationFlags = regionInfo.fCreationFlags; + cfg.path = regionInfo.fPath.c_str(); + } // LOG(debug) << "Located remote region with id '" << id << "', path: '" << cfg.path << "', flags: '" << cfg.creationFlags << "'"; auto r = fRegions.emplace(id, std::make_unique(fShmId, 0, true, std::move(cfg))); r.first->second->InitializeQueues(); r.first->second->StartAckSender(); - lockedShmLock.lock(); return r.first->second.get(); } catch (std::out_of_range& oor) { LOG(error) << "Could not get remote region with id '" << id << "'. Does the region creator run with the same session id?"; @@ -493,10 +490,10 @@ class Manager void RemoveRegion(uint16_t id) { try { + boost::interprocess::scoped_lock shmLock(*fShmMtx); std::lock_guard lock(fLocalRegionsMtx); fRegions.at(id)->StopAcks(); { - boost::interprocess::scoped_lock shmLock(*fShmMtx); if (fRegions.at(id)->RemoveOnDestruction()) { fShmRegions->at(id).fDestroyed = true; (fEventCounter->fCount)++; @@ -512,44 +509,73 @@ class Manager std::vector GetRegionInfo() { std::vector result; - boost::interprocess::scoped_lock shmLock(*fShmMtx); + std::map regionCfgs; - for (const auto& e : *fShmSegments) { - // make sure any segments in the session are found - GetSegment(e.first); - try { + { + boost::interprocess::scoped_lock shmLock(*fShmMtx); + + for (const auto& [segmentId, segmentInfo] : *fShmSegments) { + // make sure any segments in the session are found + GetSegment(segmentId); + try { + fair::mq::RegionInfo info; + info.managed = true; + info.id = segmentId; + info.event = RegionEvent::created; + info.ptr = boost::apply_visitor(SegmentAddress(), fSegments.at(segmentId)); + info.size = boost::apply_visitor(SegmentSize(), fSegments.at(segmentId)); + result.push_back(info); + } catch (const std::out_of_range& oor) { + LOG(error) << "could not find segment with id " << segmentId; + LOG(error) << oor.what(); + } + } + + for (const auto& [regionId, regionInfo] : *fShmRegions) { fair::mq::RegionInfo info; - info.managed = true; - info.id = e.first; - info.event = RegionEvent::created; - info.ptr = boost::apply_visitor(SegmentAddress(), fSegments.at(e.first)); - info.size = boost::apply_visitor(SegmentSize(), fSegments.at(e.first)); + info.managed = false; + info.id = regionId; + info.flags = regionInfo.fUserFlags; + info.event = regionInfo.fDestroyed ? RegionEvent::destroyed : RegionEvent::created; + if (info.event == RegionEvent::created) { + RegionConfig cfg; + cfg.id = info.id; + cfg.creationFlags = info.id; + cfg.path = regionInfo.fPath.c_str(); + regionCfgs.emplace(info.id, cfg); + // fill the ptr+size info after shmLock is released, to avoid constructing local region under it + } else { + info.ptr = nullptr; + info.size = 0; + } result.push_back(info); - } catch (const std::out_of_range& oor) { - LOG(error) << "could not find segment with id " << e.first; - LOG(error) << oor.what(); } } - for (const auto& e : *fShmRegions) { - fair::mq::RegionInfo info; - info.managed = false; - info.id = e.first; - info.flags = e.second.fUserFlags; - info.event = e.second.fDestroyed ? RegionEvent::destroyed : RegionEvent::created; - if (info.event == RegionEvent::created) { - auto region = GetRegionUnsafe(info.id, shmLock); - if (region) { + // do another iteration outside of shm lock, to fill ptr+size of unmanaged regions + for (auto& info : result) { + if (!info.managed && info.event == RegionEvent::created) { + auto cfgIt = regionCfgs.find(info.id); + if (cfgIt != regionCfgs.end()) { + UnmanagedRegion* region = nullptr; + std::lock_guard lock(fLocalRegionsMtx); + auto it = fRegions.find(info.id); + if (it != fRegions.end()) { + region = it->second.get(); + } else { + auto r = fRegions.emplace(cfgIt->first, std::make_unique(fShmId, 0, true, cfgIt->second)); + region = r.first->second.get(); + region->InitializeQueues(); + region->StartAckSender(); + } + info.ptr = region->GetData(); info.size = region->GetSize(); } else { - throw std::runtime_error(tools::ToString("GetRegionInfo() could not get region with id '", info.id, "'")); + info.ptr = nullptr; + info.size = 0; } - } else { - info.ptr = nullptr; - info.size = 0; } - result.push_back(info); } return result; diff --git a/fairmq/shmem/Message.h b/fairmq/shmem/Message.h index 37c8779f..2aa9e93d 100644 --- a/fairmq/shmem/Message.h +++ b/fairmq/shmem/Message.h @@ -195,7 +195,7 @@ class Message final : public fair::mq::Message fLocalPtr = nullptr; } } else { - fRegionPtr = fManager.GetRegion(fMeta.fRegionId); + fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId); if (fRegionPtr) { fLocalPtr = reinterpret_cast(fRegionPtr->GetData()) + fMeta.fHandle; } else { @@ -365,7 +365,7 @@ class Message final : public fair::mq::Message void ReleaseUnmanagedRegionBlock() { if (!fRegionPtr) { - fRegionPtr = fManager.GetRegion(fMeta.fRegionId); + fRegionPtr = fManager.GetRegionFromCache(fMeta.fRegionId); } if (fRegionPtr) {