diff options
Diffstat (limited to 'methods')
-rw-r--r-- | methods/connect.cc | 240 |
1 files changed, 184 insertions, 56 deletions
diff --git a/methods/connect.cc b/methods/connect.cc index 1354fe97b..334a1d3f3 100644 --- a/methods/connect.cc +++ b/methods/connect.cc @@ -23,6 +23,7 @@ #include <gnutls/gnutls.h> #include <gnutls/x509.h> +#include <list> #include <set> #include <sstream> #include <string> @@ -35,6 +36,7 @@ #include <netdb.h> #include <arpa/inet.h> #include <netinet/in.h> +#include <sys/select.h> #include <sys/socket.h> #include "aptmethod.h" @@ -112,16 +114,54 @@ std::unique_ptr<MethodFd> MethodFd::FromFd(int iFd) // DoConnect - Attempt a connect operation /*{{{*/ // --------------------------------------------------------------------- /* This helper function attempts a connection to a single address. */ -static ResultState DoConnect(struct addrinfo *Addr, std::string const &Host, - unsigned long TimeOut, std::unique_ptr<MethodFd> &Fd, aptMethod *Owner) +struct Connection { - // Show a status indicator + struct addrinfo *Addr; + std::string Host; + aptMethod *Owner; + std::unique_ptr<FdFd> Fd; char Name[NI_MAXHOST]; char Service[NI_MAXSERV]; - Fd.reset(new FdFd()); - Name[0] = 0; - Service[0] = 0; + Connection(struct addrinfo *Addr, std::string const &Host, aptMethod *Owner) : Addr(Addr), Host(Host), Owner(Owner), Fd(new FdFd()), Name{0}, Service{0} + { + } + + // Allow moving values, but not connections. + Connection(Connection &&Conn) = default; + Connection(const Connection &Conn) = delete; + Connection &operator=(const Connection &) = delete; + Connection &operator=(Connection &&Conn) = default; + + ~Connection() + { + if (Fd != nullptr) + { + Fd->Close(); + } + } + + std::unique_ptr<MethodFd> Take() + { + /* Store the IP we are using.. If something goes + wrong this will get tacked onto the end of the error message */ + std::stringstream ss; + ioprintf(ss, _("[IP: %s %s]"), Name, Service); + Owner->SetIP(ss.str()); + Owner->Status(_("Connected to %s (%s)"), Host.c_str(), Name); + _error->Discard(); + Owner->SetFailReason(""); + LastUsed = Addr; + return std::move(Fd); + } + + ResultState DoConnect(); + + ResultState CheckError(); +}; + +ResultState Connection::DoConnect() +{ getnameinfo(Addr->ai_addr,Addr->ai_addrlen, Name,sizeof(Name),Service,sizeof(Service), NI_NUMERICHOST|NI_NUMERICSERV); @@ -130,15 +170,6 @@ static ResultState DoConnect(struct addrinfo *Addr, std::string const &Host, // if that addr did timeout before, we do not try it again if(bad_addr.find(std::string(Name)) != bad_addr.end()) return ResultState::TRANSIENT_ERROR; - - /* If this is an IP rotation store the IP we are using.. If something goes - wrong this will get tacked onto the end of the error message */ - if (LastHostAddr->ai_next != 0) - { - std::stringstream ss; - ioprintf(ss, _("[IP: %s %s]"),Name,Service); - Owner->SetIP(ss.str()); - } // Get a socket if ((static_cast<FdFd *>(Fd.get())->fd = socket(Addr->ai_family, Addr->ai_socktype, @@ -159,18 +190,11 @@ static ResultState DoConnect(struct addrinfo *Addr, std::string const &Host, return ResultState::TRANSIENT_ERROR; } - /* This implements a timeout for connect by opening the connection - nonblocking */ - if (WaitFd(Fd->Fd(), true, TimeOut) == false) - { - bad_addr.insert(bad_addr.begin(), std::string(Name)); - Owner->SetFailReason("Timeout"); - _error->Error(_("Could not connect to %s:%s (%s), " - "connection timed out"), - Host.c_str(), Service, Name); - return ResultState::TRANSIENT_ERROR; - } + return ResultState::SUCCESSFUL; +} +ResultState Connection::CheckError() +{ // Check the socket for an error condition unsigned int Err; unsigned int Len = sizeof(Err); @@ -198,6 +222,126 @@ static ResultState DoConnect(struct addrinfo *Addr, std::string const &Host, return ResultState::SUCCESSFUL; } /*}}}*/ +// Order the given host names returned by getaddrinfo() /*{{{*/ +static std::vector<struct addrinfo *> OrderAddresses(struct addrinfo *CurHost) +{ + std::vector<struct addrinfo *> preferredAddrs; + std::vector<struct addrinfo *> otherAddrs; + std::vector<struct addrinfo *> allAddrs; + + // Partition addresses into preferred and other address families + while (CurHost != 0) + { + if (preferredAddrs.empty() || CurHost->ai_family == preferredAddrs[0]->ai_family) + preferredAddrs.push_back(CurHost); + else + otherAddrs.push_back(CurHost); + + // Ignore UNIX domain sockets + do + { + CurHost = CurHost->ai_next; + } while (CurHost != 0 && CurHost->ai_family == AF_UNIX); + + /* If we reached the end of the search list then wrap around to the + start */ + if (CurHost == 0 && LastUsed != 0) + CurHost = LastHostAddr; + + // Reached the end of the search cycle + if (CurHost == LastUsed) + break; + } + + // Build a new address vector alternating between preferred and other + for (auto prefIter = preferredAddrs.cbegin(), otherIter = otherAddrs.cbegin(); + prefIter != preferredAddrs.end() || otherIter != otherAddrs.end();) + { + if (prefIter != preferredAddrs.end()) + allAddrs.push_back(*prefIter++); + if (otherIter != otherAddrs.end()) + allAddrs.push_back(*otherIter++); + } + + return std::move(allAddrs); +} + /*}}}*/ +// Check for errors and report them /*{{{*/ +static ResultState WaitAndCheckErrors(std::list<Connection> &Conns, std::unique_ptr<MethodFd> &Fd, long TimeoutMsec, bool ReportTimeout) +{ + // The last error detected + ResultState Result = ResultState::TRANSIENT_ERROR; + + struct timeval tv = { + // Split our millisecond timeout into seconds and microseconds + .tv_sec = TimeoutMsec / 1000, + .tv_usec = (TimeoutMsec % 1000) * 1000, + }; + + // We will return once we have no more connections, a time out, or + // a success. + while (!Conns.empty()) + { + fd_set Set; + int nfds = -1; + + FD_ZERO(&Set); + + for (auto &Conn : Conns) + { + int fd = Conn.Fd->Fd(); + FD_SET(fd, &Set); + nfds = std::max(nfds, fd); + } + + { + int Res; + do + { + Res = select(nfds + 1, 0, &Set, 0, (TimeoutMsec != 0 ? &tv : 0)); + } while (Res < 0 && errno == EINTR); + + if (Res == 0) + { + if (ReportTimeout) + { + for (auto &Conn : Conns) + { + Conn.Owner->SetFailReason("Timeout"); + _error->Error(_("Could not connect to %s:%s (%s), " + "connection timed out"), + Conn.Host.c_str(), Conn.Service, Conn.Name); + } + } + return ResultState::TRANSIENT_ERROR; + } + } + + // iterate over connections, remove failed ones, and return if + // there was a successful one. + for (auto ConnI = Conns.begin(); ConnI != Conns.end();) + { + if (!FD_ISSET(ConnI->Fd->Fd(), &Set)) + { + ConnI++; + continue; + } + + Result = ConnI->CheckError(); + if (Result == ResultState::SUCCESSFUL) + { + Fd = ConnI->Take(); + return Result; + } + + // Connection failed. Erase it and continue to next position + ConnI = Conns.erase(ConnI); + } + } + + return Result; +} + /*}}}*/ // Connect to a given Hostname /*{{{*/ static ResultState ConnectToHostname(std::string const &Host, int const Port, const char *const Service, int DefPort, std::unique_ptr<MethodFd> &Fd, @@ -301,39 +445,23 @@ static ResultState ConnectToHostname(std::string const &Host, int const Port, } // When we have an IP rotation stay with the last IP. - struct addrinfo *CurHost = LastHostAddr; - if (LastUsed != 0) - CurHost = LastUsed; - - while (CurHost != 0) + auto Addresses = OrderAddresses(LastUsed != nullptr ? LastUsed : LastHostAddr); + std::list<Connection> Conns; + + for (auto Addr : Addresses) { - auto const result = DoConnect(CurHost, Host, TimeOut, Fd, Owner); - if (result == ResultState::SUCCESSFUL) - { - LastUsed = CurHost; - return result; - } - Fd->Close(); + Connection Conn(Addr, Host, Owner); + if (Conn.DoConnect() != ResultState::SUCCESSFUL) + continue; - // Ignore UNIX domain sockets - do - { - CurHost = CurHost->ai_next; - } - while (CurHost != 0 && CurHost->ai_family == AF_UNIX); + Conns.push_back(std::move(Conn)); - /* If we reached the end of the search list then wrap around to the - start */ - if (CurHost == 0 && LastUsed != 0) - CurHost = LastHostAddr; - - // Reached the end of the search cycle - if (CurHost == LastUsed) - break; - - if (CurHost != 0) - _error->Discard(); - } + if (WaitAndCheckErrors(Conns, Fd, Owner->ConfigFindI("ConnectionAttemptDelayMsec", 250), false) == ResultState::SUCCESSFUL) + return ResultState::SUCCESSFUL; + } + + if (WaitAndCheckErrors(Conns, Fd, TimeOut * 1000, true) == ResultState::SUCCESSFUL) + return ResultState::SUCCESSFUL; if (_error->PendingError() == true) return ResultState::FATAL_ERROR; |