diff options
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r-- | security/nss/gtests/ssl_gtest/tls_agent.cc | 131 |
1 files changed, 82 insertions, 49 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc index 2f71caedb0..9bed1ce1b7 100644 --- a/security/nss/gtests/ssl_gtest/tls_agent.cc +++ b/security/nss/gtests/ssl_gtest/tls_agent.cc @@ -33,6 +33,7 @@ const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; const std::string TlsAgent::kClient = "client"; // both sign and encrypt const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger +const std::string TlsAgent::kRsa8192 = "rsa8192"; // biggest allowed const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt const std::string TlsAgent::kServerRsaSign = "rsa_sign"; const std::string TlsAgent::kServerRsaPss = "rsa_pss"; @@ -44,13 +45,22 @@ const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa"; const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa"; const std::string TlsAgent::kServerDsa = "dsa"; -TlsAgent::TlsAgent(const std::string& name, Role role, - SSLProtocolVariant variant) - : name_(name), - variant_(variant), - role_(role), +static const uint8_t kCannedTls13ServerHello[] = { + 0x03, 0x03, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, + 0xf0, 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, + 0xdf, 0xe5, 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, + 0x08, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, 0x24, + 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03, + 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9, + 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08, + 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13}; + +TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var) + : name_(nm), + variant_(var), + role_(rl), server_key_bits_(0), - adapter_(new DummyPrSocket(role_str(), variant)), + adapter_(new DummyPrSocket(role_str(), var)), ssl_fd_(nullptr), state_(STATE_INIT), timer_handle_(nullptr), @@ -103,11 +113,11 @@ TlsAgent::~TlsAgent() { } } -void TlsAgent::SetState(State state) { - if (state_ == state) return; +void TlsAgent::SetState(State s) { + if (state_ == s) return; - LOG("Changing state from " << state_ << " to " << state); - state_ = state; + LOG("Changing state from " << state_ << " to " << s); + state_ = s; } /*static*/ bool TlsAgent::LoadCertificate(const std::string& name, @@ -124,11 +134,11 @@ void TlsAgent::SetState(State state) { return true; } -bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits, +bool TlsAgent::ConfigServerCert(const std::string& id, bool updateKeyBits, const SSLExtraServerCertData* serverCertData) { ScopedCERTCertificate cert; ScopedSECKEYPrivateKey priv; - if (!TlsAgent::LoadCertificate(name, &cert, &priv)) { + if (!TlsAgent::LoadCertificate(id, &cert, &priv)) { return false; } @@ -175,6 +185,10 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { if (rv != SECSuccess) return false; } + ScopedCERTCertList anchors(CERT_NewCertList()); + rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); + if (rv != SECSuccess) return false; + if (role_ == SERVER) { EXPECT_TRUE(ConfigServerCert(name_, true)); @@ -182,10 +196,6 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) { EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; - ScopedCERTCertList anchors(CERT_NewCertList()); - rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get()); - if (rv != SECSuccess) return false; - rv = SSL_SetMaxEarlyDataSize(ssl_fd(), 1024); EXPECT_EQ(SECSuccess, rv); if (rv != SECSuccess) return false; @@ -246,6 +256,17 @@ void TlsAgent::SetupClientAuth() { reinterpret_cast<void*>(this))); } +void CheckCertReqAgainstDefaultCAs(const CERTDistNames* caNames) { + ScopedCERTDistNames expected(CERT_GetSSLCACerts(nullptr)); + + ASSERT_EQ(expected->nnames, caNames->nnames); + + for (size_t i = 0; i < static_cast<size_t>(expected->nnames); ++i) { + EXPECT_EQ(SECEqual, + SECITEM_CompareItem(&(expected->names[i]), &(caNames->names[i]))); + } +} + SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, CERTDistNames* caNames, CERTCertificate** clientCert, @@ -254,6 +275,9 @@ SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd())); EXPECT_TRUE(peerCert) << "Client should be able to see the server cert"; + // See bug 1457716 + // CheckCertReqAgainstDefaultCAs(caNames); + ScopedCERTCertificate cert; ScopedSECKEYPrivateKey priv; if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) { @@ -282,8 +306,8 @@ bool TlsAgent::GetPeerChainLength(size_t* count) { return true; } -void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) { - EXPECT_EQ(csinfo_.cipherSuite, cipher_suite); +void TlsAgent::CheckCipherSuite(uint16_t suite) { + EXPECT_EQ(csinfo_.cipherSuite, suite); } void TlsAgent::RequestClientAuth(bool requireAuth) { @@ -442,9 +466,7 @@ void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) { *maxver = vrange_.max; } -void TlsAgent::SetExpectedVersion(uint16_t version) { - expected_version_ = version; -} +void TlsAgent::SetExpectedVersion(uint16_t ver) { expected_version_ = ver; } void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; } @@ -491,10 +513,10 @@ void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes, EXPECT_EQ(i, configuredCount) << "schemes in use were all set"; } -void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group, +void TlsAgent::CheckKEA(SSLKEAType kea, SSLNamedGroup kea_group, size_t kea_size) const { EXPECT_EQ(STATE_CONNECTED, state_); - EXPECT_EQ(kea_type, info_.keaType); + EXPECT_EQ(kea, info_.keaType); if (kea_size == 0) { switch (kea_group) { case ssl_grp_ec_curve25519: @@ -515,7 +537,7 @@ void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group, case ssl_grp_ffdhe_custom: break; default: - if (kea_type == ssl_kea_rsa) { + if (kea == ssl_kea_rsa) { kea_size = server_key_bits_; } else { EXPECT_TRUE(false) << "need to update group sizes"; @@ -534,13 +556,13 @@ void TlsAgent::CheckOriginalKEA(SSLNamedGroup kea_group) const { } } -void TlsAgent::CheckAuthType(SSLAuthType auth_type, +void TlsAgent::CheckAuthType(SSLAuthType auth, SSLSignatureScheme sig_scheme) const { EXPECT_EQ(STATE_CONNECTED, state_); - EXPECT_EQ(auth_type, info_.authType); + EXPECT_EQ(auth, info_.authType); EXPECT_EQ(server_key_bits_, info_.authKeyBits); if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) { - switch (auth_type) { + switch (auth) { case ssl_auth_rsa_sign: sig_scheme = ssl_sig_rsa_pkcs1_sha1md5; break; @@ -558,9 +580,8 @@ void TlsAgent::CheckAuthType(SSLAuthType auth_type, } // Check authAlgorithm, which is the old value for authType. This is a second - // switch - // statement because default label is different. - switch (auth_type) { + // switch statement because default label is different. + switch (auth) { case ssl_auth_rsa_sign: EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm) << "authAlgorithm for RSA is always decrypt"; @@ -574,7 +595,7 @@ void TlsAgent::CheckAuthType(SSLAuthType auth_type, << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)"; break; default: - EXPECT_EQ(auth_type, csinfo_.authAlgorithm) + EXPECT_EQ(auth, csinfo_.authAlgorithm) << "authAlgorithm is (usually) the same as authType"; break; } @@ -593,22 +614,20 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; } void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { EXPECT_TRUE(EnsureTlsSetup()); - - SetOption(SSL_ENABLE_ALPN, PR_TRUE); EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len)); } void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, const std::string& expected) const { - SSLNextProtoState state; + SSLNextProtoState alpn_state; char chosen[10]; unsigned int chosen_len; - SECStatus rv = SSL_GetNextProto(ssl_fd(), &state, + SECStatus rv = SSL_GetNextProto(ssl_fd(), &alpn_state, reinterpret_cast<unsigned char*>(chosen), &chosen_len, sizeof(chosen)); EXPECT_EQ(SECSuccess, rv); - EXPECT_EQ(expected_state, state); - if (state == SSL_NEXT_PROTO_NO_SUPPORT) { + EXPECT_EQ(expected_state, alpn_state); + if (alpn_state == SSL_NEXT_PROTO_NO_SUPPORT) { EXPECT_EQ("", expected); } else { EXPECT_NE("", expected); @@ -840,10 +859,10 @@ void TlsAgent::CheckSecretsDestroyed() { ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd())); } -void TlsAgent::SetDowngradeCheckVersion(uint16_t version) { +void TlsAgent::SetDowngradeCheckVersion(uint16_t ver) { ASSERT_TRUE(EnsureTlsSetup()); - SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), version); + SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), ver); ASSERT_EQ(SECSuccess, rv); } @@ -920,9 +939,9 @@ static bool ErrorIsNonFatal(PRErrorCode code) { } void TlsAgent::SendData(size_t bytes, size_t blocksize) { - uint8_t block[4096]; + uint8_t block[16385]; // One larger than the maximum record size. - ASSERT_LT(blocksize, sizeof(block)); + ASSERT_LE(blocksize, sizeof(block)); while (bytes) { size_t tosend = std::min(blocksize, bytes); @@ -951,12 +970,13 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) { } bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec, - uint16_t wireVersion, uint64_t seq, - uint8_t ct, const DataBuffer& buf) { - LOGV("Writing " << buf.len() << " bytes"); - // Ensure we are a TLS 1.3 cipher agent. + uint64_t seq, uint8_t ct, + const DataBuffer& buf) { + LOGV("Encrypting " << buf.len() << " bytes"); + // Ensure that we are doing TLS 1.3. EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3); - TlsRecordHeader header(wireVersion, kTlsApplicationDataType, seq); + TlsRecordHeader header(variant_, expected_version_, kTlsApplicationDataType, + seq); DataBuffer padded = buf; padded.Write(padded.len(), ct, 1); DataBuffer ciphertext; @@ -1078,15 +1098,20 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type, uint16_t version, const uint8_t* buf, size_t len, DataBuffer* out, - uint64_t seq_num) { + uint64_t sequence_number) { size_t index = 0; index = out->Write(index, type, 1); if (variant == ssl_variant_stream) { index = out->Write(index, version, 2); + } else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 && + type == kTlsApplicationDataType) { + uint32_t epoch = (sequence_number >> 48) & 0x3; + uint32_t seqno = sequence_number & ((1ULL << 30) - 1); + index = out->Write(index, (epoch << 30) | seqno, 4); } else { index = out->Write(index, TlsVersionToDtlsVersion(version), 2); - index = out->Write(index, seq_num >> 32, 4); - index = out->Write(index, seq_num & PR_UINT32_MAX, 4); + index = out->Write(index, sequence_number >> 32, 4); + index = out->Write(index, sequence_number & PR_UINT32_MAX, 4); } index = out->Write(index, len, 2); out->Write(index, buf, len); @@ -1144,4 +1169,12 @@ void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type, } } +DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() { + DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello)); + if (variant_ == ssl_variant_datagram) { + sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2); + } + return sh; +} + } // namespace nss_test |