diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_agent.cc | 75 |
1 files changed, 48 insertions, 27 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc index 9bed1ce1b7..fb66196b5b 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.cc +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -15,6 +15,9 @@ #include "tls_filter.h" #include "tls_parser.h" +// This is an internal header, used to get DTLS_1_3_DRAFT_VERSION. +#include "ssl3prot.h" + extern "C" { // This is not something that should make you happy. #include "libssl_internals.h" @@ -23,7 +26,7 @@ extern "C" { #define GTEST_HAS_RTTI 0 #include "gtest/gtest.h" #include "gtest_utils.h" -#include "scoped_ptrs.h" +#include "nss_scoped_ptrs.h" extern std::string g_working_dir_path; @@ -53,7 +56,7 @@ static const uint8_t kCannedTls13ServerHello[] = { 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08, - 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13}; + 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04}; TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var) : name_(nm), @@ -226,6 +229,7 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { bool TlsAgent::MaybeSetResumptionToken() { if (!resumption_token_.empty()) { + LOG("setting external resumption token"); SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(), resumption_token_.size()); @@ -583,6 +587,7 @@ void TlsAgent::CheckAuthType(SSLAuthType auth, // switch statement because default label is different. switch (auth) { case ssl_auth_rsa_sign: + case ssl_auth_rsa_pss: EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) << "authAlgorithm for RSA is always decrypt"; break; @@ -934,8 +939,8 @@ void TlsAgent::SendRecordDirect(const TlsRecord& record) { SendDirect(buf); } -static bool ErrorIsNonFatal(PRErrorCode code) { - return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ; +static bool ErrorIsFatal(PRErrorCode code) { + return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ; } void TlsAgent::SendData(size_t bytes, size_t blocksize) { @@ -975,7 +980,7 @@ bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, LOGV("Encrypting " << buf.len() << " bytes"); // Ensure that we are doing TLS 1.3. EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); - TlsRecordHeader header(variant_, expected_version_, kTlsApplicationDataType, + TlsRecordHeader header(variant_, expected_version_, ssl_ct_application_data, seq); DataBuffer padded = buf; padded.Write(padded.len(), ct, 1); @@ -994,28 +999,39 @@ bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, void TlsAgent::ReadBytes(size_t amount) { uint8_t block[16384]; - int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); - LOGV("ReadBytes " << rv); - int32_t err; + size_t remaining = amount; + while (remaining > 0) { + int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block))); + LOGV("ReadBytes " << rv); - if (rv >= 0) { - size_t count = static_cast<size_t>(rv); - for (size_t i = 0; i < count; ++i) { - ASSERT_EQ(recv_ctr_ & 0xff, block[i]); - recv_ctr_++; - } - } else { - err = PR_GetError(); - LOG("Read error " << PORT_ErrorToName(err) << ": " - << PORT_ErrorToString(err)); - if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { - error_code_ = err; - expect_readwrite_error_ = false; + if (rv > 0) { + size_t count = static_cast<size_t>(rv); + for (size_t i = 0; i < count; ++i) { + ASSERT_EQ(recv_ctr_ & 0xff, block[i]); + recv_ctr_++; + } + remaining -= rv; + } else { + PRErrorCode err = 0; + if (rv < 0) { + err = PR_GetError(); + LOG("Read error " << PORT_ErrorToName(err) << ": " + << PORT_ErrorToString(err)); + if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) { + error_code_ = err; + expect_readwrite_error_ = false; + } + } + if (err != 0 && ErrorIsFatal(err)) { + // If we hit a fatal error, we're done. + remaining = 0; + } + break; } } // If closed, then don't bother waiting around. - if (rv > 0 || (rv < 0 && ErrorIsNonFatal(err))) { + if (remaining) { LOGV("Re-arming"); Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, &TlsAgent::ReadableCallback); @@ -1104,7 +1120,7 @@ void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, if (variant == ssl_variant_stream) { index = out->Write(index, version, 2); } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && - type == kTlsApplicationDataType) { + type == ssl_ct_application_data) { uint32_t epoch = (sequence_number >> 48) & 0x3; uint32_t seqno = sequence_number & ((1ULL << 30) - 1); index = out->Write(index, (epoch << 30) | seqno, 4); @@ -1157,10 +1173,10 @@ void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len, DataBuffer* out) { size_t index = 0; - index = out->Write(index, kTlsHandshakeType, 1); // Content Type - index = out->Write(index, 3, 1); // Version high - index = out->Write(index, 1, 1); // Version low - index = out->Write(index, 4 + hs_len, 2); // Length + index = out->Write(index, ssl_ct_handshake, 1); // Content Type + index = out->Write(index, 3, 1); // Version high + index = out->Write(index, 1, 1); // Version low + index = out->Write(index, 4 + hs_len, 2); // Length index = out->Write(index, hs_type, 1); // Handshake record type. index = out->Write(index, hs_len, 3); // Handshake length @@ -1173,6 +1189,11 @@ DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() { DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello)); if (variant_ == ssl_variant_datagram) { sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2); + // The version should be at the end. + uint32_t v; + EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v)); + EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v); + sh.Write(sh.len() - 2, 0x7f00 | DTLS_1_3_DRAFT_VERSION, 2); } return sh; } |