diff options
Diffstat (limited to 'methods')
-rw-r--r-- | methods/connect.cc | 79 | ||||
-rw-r--r-- | methods/http.cc | 6 |
2 files changed, 54 insertions, 31 deletions
diff --git a/methods/connect.cc b/methods/connect.cc index fdcf965f8..61968efe0 100644 --- a/methods/connect.cc +++ b/methods/connect.cc @@ -808,6 +808,7 @@ struct TlsFd : public MethodFd gnutls_session_t session; gnutls_certificate_credentials_t credentials; std::string hostname; + unsigned long Timeout; int Fd() APT_OVERRIDE { return UnderlyingFd->Fd(); } @@ -820,9 +821,56 @@ struct TlsFd : public MethodFd return HandleError(gnutls_record_send(session, buf, count)); } + ssize_t DoTLSHandshake() + { + int err; + // Do the handshake. Our socket is non-blocking, so we need to call WaitFd() + // accordingly. + do + { + err = gnutls_handshake(session); + if ((err == GNUTLS_E_INTERRUPTED || err == GNUTLS_E_AGAIN) && + WaitFd(this->Fd(), gnutls_record_get_direction(session) == 1, Timeout) == false) + { + _error->Errno("select", "Could not wait for server fd"); + return err; + } + } while (err < 0 && gnutls_error_is_fatal(err) == 0); + + if (err < 0) + { + // Print reason why validation failed. + if (err == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) + { + gnutls_datum_t txt; + auto type = gnutls_certificate_type_get(session); + auto status = gnutls_session_get_verify_cert_status(session); + if (gnutls_certificate_verification_status_print(status, type, &txt, 0) == 0) + { + _error->Error("Certificate verification failed: %s", txt.data); + } + gnutls_free(txt.data); + } + _error->Error("Could not handshake: %s", gnutls_strerror(err)); + } + return err; + } + template <typename T> T HandleError(T err) { + // Server may request re-handshake if client certificates need to be provided + // based on resource requested + if (err == GNUTLS_E_REHANDSHAKE) + { + int rc = DoTLSHandshake(); + // Only reset err if DoTLSHandshake() fails. + // Otherwise, we want to follow the original error path and set errno to EAGAIN + // so that the request is retried. + if (rc < 0) + err = rc; + } + if (err < 0 && gnutls_error_is_fatal(err)) errno = EIO; else if (err < 0) @@ -859,6 +907,7 @@ ResultState UnwrapTLS(std::string Host, std::unique_ptr<MethodFd> &Fd, tlsFd->hostname = Host; tlsFd->UnderlyingFd = MethodFd::FromFd(-1); // For now + tlsFd->Timeout = Timeout; if ((err = gnutls_init(&tlsFd->session, GNUTLS_CLIENT | GNUTLS_NONBLOCK)) < 0) { @@ -992,37 +1041,11 @@ ResultState UnwrapTLS(std::string Host, std::unique_ptr<MethodFd> &Fd, tlsFd->UnderlyingFd = std::move(Fd); Fd.reset(tlsFd); - // Do the handshake. Our socket is non-blocking, so we need to call WaitFd() - // accordingly. - do - { - err = gnutls_handshake(tlsFd->session); - if ((err == GNUTLS_E_INTERRUPTED || err == GNUTLS_E_AGAIN) && - WaitFd(Fd->Fd(), gnutls_record_get_direction(tlsFd->session) == 1, Timeout) == false) - { - _error->Errno("select", "Could not wait for server fd"); - return ResultState::TRANSIENT_ERROR; - } - } while (err < 0 && gnutls_error_is_fatal(err) == 0); + // Do the handshake. + err = tlsFd->DoTLSHandshake(); if (err < 0) - { - // Print reason why validation failed. - if (err == GNUTLS_E_CERTIFICATE_VERIFICATION_ERROR) - { - gnutls_datum_t txt; - auto type = gnutls_certificate_type_get(tlsFd->session); - auto status = gnutls_session_get_verify_cert_status(tlsFd->session); - if (gnutls_certificate_verification_status_print(status, - type, &txt, 0) == 0) - { - _error->Error("Certificate verification failed: %s", txt.data); - } - gnutls_free(txt.data); - } - _error->Error("Could not handshake: %s", gnutls_strerror(err)); return ResultState::FATAL_ERROR; - } return ResultState::SUCCESSFUL; } diff --git a/methods/http.cc b/methods/http.cc index d3e16bba3..a4d187189 100644 --- a/methods/http.cc +++ b/methods/http.cc @@ -320,14 +320,14 @@ static ResultState UnwrapHTTPConnect(std::string Host, int Port, URI Proxy, std: std::string ProperHost; if (Host.find(':') != std::string::npos) - ProperHost = '[' + Proxy.Host + ']'; + ProperHost = '[' + Host + ']'; else - ProperHost = Proxy.Host; + ProperHost = Host; // Build the connect Req << "CONNECT " << Host << ":" << std::to_string(Port) << " HTTP/1.1\r\n"; if (Proxy.Port != 0) - Req << "Host: " << ProperHost << ":" << std::to_string(Proxy.Port) << "\r\n"; + Req << "Host: " << ProperHost << ":" << std::to_string(Port) << "\r\n"; else Req << "Host: " << ProperHost << "\r\n"; |