From aeab9e5407694642bd38bde2d03f63ee971d5e9c Mon Sep 17 00:00:00 2001 From: Alexey Rybalchenko Date: Mon, 22 Jun 2020 14:28:28 +0200 Subject: [PATCH] Socket.h: refactor to reduce duplicate code --- fairmq/shmem/Socket.h | 110 +++++++++++++------------------ fairmq/zeromq/Socket.h | 142 +++++++++++++++++------------------------ 2 files changed, 101 insertions(+), 151 deletions(-) diff --git a/fairmq/shmem/Socket.h b/fairmq/shmem/Socket.h index 1919db1b..df48a990 100644 --- a/fairmq/shmem/Socket.h +++ b/fairmq/shmem/Socket.h @@ -55,8 +55,7 @@ class Socket final : public fair::mq::Socket , fBytesRx(0) , fMessagesTx(0) , fMessagesRx(0) - , fSndTimeout(100) - , fRcvTimeout(100) + , fTimeout(100) { assert(context); fSocket = zmq_socket(context, GetConstant(type)); @@ -77,11 +76,11 @@ class Socket final : public fair::mq::Socket LOG(error) << "Failed setting ZMQ_LINGER socket option, reason: " << zmq_strerror(errno); } - if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) { + if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &fTimeout, sizeof(fTimeout)) != 0) { LOG(error) << "Failed setting ZMQ_SNDTIMEO socket option, reason: " << zmq_strerror(errno); } - if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) { + if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &fTimeout, sizeof(fTimeout)) != 0) { LOG(error) << "Failed setting ZMQ_RCVTIMEO socket option, reason: " << zmq_strerror(errno); } @@ -129,6 +128,35 @@ class Socket final : public fair::mq::Socket return true; } + bool ShouldRetry(int flags, int timeout, int& elapsed) const + { + if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { + if (timeout > 0) { + elapsed += fTimeout; + if (elapsed >= timeout) { + return false; + } + } + return true; + } else { + return false; + } + } + + int HandleErrors() const + { + if (zmq_errno() == ETERM) { + LOG(debug) << "Terminating socket " << fId; + return -1; + } else if (zmq_errno() == EINTR) { + LOG(debug) << "Transfer interrupted by system call"; + return -1; + } else { + LOG(error) << "Failed transfer on socket " << fId << ", reason: " << zmq_strerror(errno); + return -1; + } + } + int Send(MessagePtr& msg, const int timeout = -1) override { int flags = 0; @@ -150,26 +178,13 @@ class Socket final : public fair::mq::Socket fBytesTx += size; return size; } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fSndTimeout; - if (elapsed >= timeout) { - return -2; - } - } + if (ShouldRetry(flags, timeout, elapsed)) { continue; } else { return -2; } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Send interrupted by system call"; - return nbytes; - }else { - LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; + } else { + return HandleErrors(); } } @@ -206,26 +221,13 @@ class Socket final : public fair::mq::Socket ++fMessagesRx; return size; } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fRcvTimeout; - if (elapsed >= timeout) { - return -2; - } - } + if (ShouldRetry(flags, timeout, elapsed)) { continue; } else { return -2; } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Receive interrupted by system call"; - return nbytes; - }else { - LOG(error) << "Failed receiving on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; + } else { + return HandleErrors(); } } } @@ -268,26 +270,13 @@ class Socket final : public fair::mq::Socket return totalSize; } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fSndTimeout; - if (elapsed >= timeout) { - return -2; - } - } + if (ShouldRetry(flags, timeout, elapsed)) { continue; } else { return -2; } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Send interrupted by system call"; - return nbytes; - }else { - LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; + } else { + return HandleErrors(); } } @@ -335,23 +324,13 @@ class Socket final : public fair::mq::Socket return totalSize; } else if (zmq_errno() == EAGAIN) { - if (!fManager.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fRcvTimeout; - if (elapsed >= timeout) { - return -2; - } - } + if (ShouldRetry(flags, timeout, elapsed)) { continue; } else { return -2; } - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Receive interrupted by system call"; - return nbytes; } else { - LOG(error) << "Failed receiving on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno) << ", nbytes = " << nbytes; - return nbytes; + return HandleErrors(); } } @@ -521,8 +500,7 @@ class Socket final : public fair::mq::Socket std::atomic fMessagesTx; std::atomic fMessagesRx; - int fSndTimeout; - int fRcvTimeout; + int fTimeout; }; } diff --git a/fairmq/zeromq/Socket.h b/fairmq/zeromq/Socket.h index ca2454c2..453a1479 100644 --- a/fairmq/zeromq/Socket.h +++ b/fairmq/zeromq/Socket.h @@ -12,13 +12,15 @@ #include #include #include -#include #include #include #include -#include // unique_ptr + #include +#include +#include // unique_ptr + namespace fair { namespace mq { namespace zmq { @@ -35,8 +37,7 @@ class Socket final : public fair::mq::Socket , fBytesRx(0) , fMessagesTx(0) , fMessagesRx(0) - , fSndTimeout(100) - , fRcvTimeout(100) + , fTimeout(100) { if (fSocket == nullptr) { LOG(error) << "Failed creating socket " << fId << ", reason: " << zmq_strerror(errno); @@ -54,11 +55,11 @@ class Socket final : public fair::mq::Socket LOG(error) << "Failed setting ZMQ_LINGER socket option, reason: " << zmq_strerror(errno); } - if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &fSndTimeout, sizeof(fSndTimeout)) != 0) { + if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &fTimeout, sizeof(fTimeout)) != 0) { LOG(error) << "Failed setting ZMQ_SNDTIMEO socket option, reason: " << zmq_strerror(errno); } - if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &fRcvTimeout, sizeof(fRcvTimeout)) != 0) { + if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &fTimeout, sizeof(fTimeout)) != 0) { LOG(error) << "Failed setting ZMQ_RCVTIMEO socket option, reason: " << zmq_strerror(errno); } @@ -78,13 +79,12 @@ class Socket final : public fair::mq::Socket bool Bind(const std::string& address) override { - // LOG(info) << "bind socket " << fId << " on " << address; + // LOG(debug) << "Binding socket " << fId << " on " << address; if (zmq_bind(fSocket, address.c_str()) != 0) { if (errno == EADDRINUSE) { // do not print error in this case, this is handled by FairMQDevice in case no - // connection could be established after trying a number of random ports from a - // range. + // connection could be established after trying a number of random ports from a range. return false; } LOG(error) << "Failed binding socket " << fId << ", address: " << address << ", reason: " << zmq_strerror(errno); @@ -96,7 +96,7 @@ class Socket final : public fair::mq::Socket bool Connect(const std::string& address) override { - // LOG(info) << "connect socket " << fId << " on " << address; + // LOG(debug) << "Connecting socket " << fId << " on " << address; if (zmq_connect(fSocket, address.c_str()) != 0) { LOG(error) << "Failed connecting socket " << fId << ", address: " << address << ", reason: " << zmq_strerror(errno); @@ -106,6 +106,35 @@ class Socket final : public fair::mq::Socket return true; } + bool ShouldRetry(int flags, int timeout, int& elapsed) const + { + if (!fCtx.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { + if (timeout > 0) { + elapsed += fTimeout; + if (elapsed >= timeout) { + return false; + } + } + return true; + } else { + return false; + } + } + + int HandleErrors() const + { + if (zmq_errno() == ETERM) { + LOG(debug) << "Terminating socket " << fId; + return -1; + } else if (zmq_errno() == EINTR) { + LOG(debug) << "Transfer interrupted by system call"; + return -1; + } else { + LOG(error) << "Failed transfer on socket " << fId << ", errno: " << errno << ", reason: " << zmq_strerror(errno); + return -1; + } + } + int Send(MessagePtr& msg, const int timeout = -1) override { int flags = 0; @@ -121,29 +150,15 @@ class Socket final : public fair::mq::Socket if (nbytes >= 0) { fBytesTx += nbytes; ++fMessagesTx; - return nbytes; } else if (zmq_errno() == EAGAIN) { - if (!fCtx.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fSndTimeout; - if (elapsed >= timeout) { - return -2; - } - } + if (ShouldRetry(flags, timeout, elapsed)) { continue; } else { return -2; } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Send interrupted by system call"; - return nbytes; } else { - LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; + return HandleErrors(); } } } @@ -163,26 +178,13 @@ class Socket final : public fair::mq::Socket ++fMessagesRx; return nbytes; } else if (zmq_errno() == EAGAIN) { - if (!fCtx.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fRcvTimeout; - if (elapsed >= timeout) { - return -2; - } - } + if (ShouldRetry(flags, timeout, elapsed)) { continue; } else { return -2; } - } else if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Receive interrupted by system call"; - return nbytes; } else { - LOG(error) << "Failed receiving on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; + return HandleErrors(); } } } @@ -210,32 +212,15 @@ class Socket final : public fair::mq::Socket int nbytes = zmq_msg_send(static_cast(msgVec[i].get())->GetMessage(), fSocket, (i < vecSize - 1) ? ZMQ_SNDMORE | flags : flags); if (nbytes >= 0) { totalSize += nbytes; - } else { - // according to ZMQ docs, this can only occur for the first part - if (zmq_errno() == EAGAIN) { - if (!fCtx.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fSndTimeout; - if (elapsed >= timeout) { - return -2; - } - } - repeat = true; - break; - } else { - return -2; - } - } - if (zmq_errno() == ETERM) { - LOG(info) << "terminating socket " << fId; - return -1; - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Receive interrupted by system call"; - return nbytes; + } else if (zmq_errno() == EAGAIN) { + if (ShouldRetry(flags, timeout, elapsed)) { + repeat = true; + break; } else { - LOG(error) << "Failed sending on socket " << fId << ", reason: " << zmq_strerror(errno); - return nbytes; + return -2; } + } else { + return HandleErrors(); } } @@ -243,16 +228,14 @@ class Socket final : public fair::mq::Socket continue; } - // store statistics on how many messages have been sent (handle all parts as a - // single message) + // store statistics on how many messages have been sent (handle all parts as a single message) ++fMessagesTx; fBytesTx += totalSize; return totalSize; } - } // If there's only one part, send it as a regular message - else if (vecSize == 1) { + } else if (vecSize == 1) { // If there's only one part, send it as a regular message return Send(msgVec.back(), timeout); - } else { // if the vector is empty, something might be wrong + } else { // if the vector is empty, something might be wrong LOG(warn) << "Will not send empty vector"; return -1; } @@ -279,23 +262,14 @@ class Socket final : public fair::mq::Socket msgVec.push_back(move(part)); totalSize += nbytes; } else if (zmq_errno() == EAGAIN) { - if (!fCtx.Interrupted() && ((flags & ZMQ_DONTWAIT) == 0)) { - if (timeout > 0) { - elapsed += fRcvTimeout; - if (elapsed >= timeout) { - return -2; - } - } + if (ShouldRetry(flags, timeout, elapsed)) { repeat = true; break; } else { return -2; } - } else if (zmq_errno() == EINTR) { - LOG(debug) << "Receive interrupted by system call"; - return nbytes; } else { - return nbytes; + return HandleErrors(); } size_t moreSize = sizeof(more); @@ -306,8 +280,7 @@ class Socket final : public fair::mq::Socket continue; } - // store statistics on how many messages have been received (handle all parts as a - // single message) + // store statistics on how many messages have been received (handle all parts as a single message) ++fMessagesRx; fBytesRx += totalSize; return totalSize; @@ -475,8 +448,7 @@ class Socket final : public fair::mq::Socket std::atomic fMessagesTx; std::atomic fMessagesRx; - int fSndTimeout; - int fRcvTimeout; + int fTimeout; }; } // namespace zmq