diff --git a/fairmq/shmem/Manager.h b/fairmq/shmem/Manager.h index ba90ca14..ee87e73e 100644 --- a/fairmq/shmem/Manager.h +++ b/fairmq/shmem/Manager.h @@ -396,7 +396,23 @@ class Manager const uint16_t id = cfg.id.value(); std::lock_guard lock(fLocalRegionsMtx); - auto& region = fRegions[id] = std::make_unique(fShmId, size, true, cfg); + + UnmanagedRegion* region = nullptr; + + auto it = fRegions.find(id); + if (it != fRegions.end()) { + region = it->second.get(); + if (region->fControlling) { + LOG(error) << "Unmanaged Region with id " << id << " already exists. Only unique IDs per session are allowed."; + throw TransportError(tools::ToString("Unmanaged Region with id ", id, " already exists. Only unique IDs per session are allowed.")); + } + + LOG(debug) << "Unmanaged region (view) already present, promoting to controller"; + region->BecomeController(cfg); + } else { + auto res = fRegions.emplace(id, std::make_unique(fShmId, size, true, cfg)); + region = res.first->second.get(); + } // LOG(debug) << "Created region with id '" << id << "', path: '" << cfg.path << "', flags: '" << cfg.creationFlags << "'"; // start ack receiver only if a callback has been provided. @@ -406,7 +422,7 @@ class Manager region->StartAckSender(); region->StartAckReceiver(); } - result.first = region.get(); + result.first = region; result.second = id; } fRegionsGen += 1; // signal TL cache invalidation diff --git a/fairmq/shmem/UnmanagedRegion.h b/fairmq/shmem/UnmanagedRegion.h index 2981cf39..18edcfa7 100644 --- a/fairmq/shmem/UnmanagedRegion.h +++ b/fairmq/shmem/UnmanagedRegion.h @@ -146,7 +146,7 @@ struct UnmanagedRegion LOG(debug) << "Successfully zeroed free memory of region " << id << "."; } - if (fControlling) { + if (fControlling && created) { Register(shmId, cfg); } @@ -160,6 +160,13 @@ struct UnmanagedRegion UnmanagedRegion& operator=(const UnmanagedRegion&) = delete; UnmanagedRegion& operator=(UnmanagedRegion&&) = delete; + void BecomeController(RegionConfig& cfg) + { + fControlling = true; + fLinger = cfg.linger; + fRemoveOnDestruction = cfg.removeOnDestruction; + } + void Zero() { memset(fRegion.get_address(), 0x00, fRegion.get_size()); @@ -263,10 +270,14 @@ struct UnmanagedRegion EventCounter* eventCounter = mngSegment.find_or_construct(unique_instance)(0); - bool newShmRegionCreated = shmRegions->emplace(cfg.id.value(), RegionInfo(cfg.path.c_str(), cfg.creationFlags, cfg.userFlags, cfg.size, alloc)).second; - if (newShmRegionCreated) { - (eventCounter->fCount)++; + auto it = shmRegions->find(cfg.id.value()); + if (it != shmRegions->end()) { + LOG(error) << "Unmanaged Region with id " << cfg.id.value() << " has already been registered. Only unique IDs per session are allowed."; + throw TransportError(tools::ToString("Unmanaged Region with id ", cfg.id.value(), " has already been registered. Only unique IDs per session are allowed.")); } + + shmRegions->emplace(cfg.id.value(), RegionInfo(cfg.path.c_str(), cfg.creationFlags, cfg.userFlags, cfg.size, alloc)).second; + (eventCounter->fCount)++; } void SetCallbacks(RegionCallback callback, RegionBulkCallback bulkCallback)