summaryrefslogtreecommitdiff
path: root/security/nss/gtests/ssl_gtest/tls_agent.cc
diff options
context:
space:
mode:
Diffstat (limited to 'security/nss/gtests/ssl_gtest/tls_agent.cc')
-rw-r--r--security/nss/gtests/ssl_gtest/tls_agent.cc75
1 files changed, 48 insertions, 27 deletions
diff --git a/security/nss/gtests/ssl_gtest/tls_agent.cc b/security/nss/gtests/ssl_gtest/tls_agent.cc
index 9bed1ce1b7..fb66196b5b 100644
--- a/security/nss/gtests/ssl_gtest/tls_agent.cc
+++ b/security/nss/gtests/ssl_gtest/tls_agent.cc
@@ -15,6 +15,9 @@
#include "tls_filter.h"
#include "tls_parser.h"
+// This is an internal header, used to get DTLS_1_3_DRAFT_VERSION.
+#include "ssl3prot.h"
+
extern "C" {
// This is not something that should make you happy.
#include "libssl_internals.h"
@@ -23,7 +26,7 @@ extern "C" {
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
#include "gtest_utils.h"
-#include "scoped_ptrs.h"
+#include "nss_scoped_ptrs.h"
extern std::string g_working_dir_path;
@@ -53,7 +56,7 @@ static const uint8_t kCannedTls13ServerHello[] = {
0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf, 0x23, 0x17, 0x64, 0x23, 0x03,
0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65, 0x24, 0xa1, 0x6c, 0xa9,
0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a, 0xcb, 0xe3, 0x08,
- 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x7f, kD13};
+ 0x84, 0xae, 0x19, 0x00, 0x2b, 0x00, 0x02, 0x03, 0x04};
TlsAgent::TlsAgent(const std::string& nm, Role rl, SSLProtocolVariant var)
: name_(nm),
@@ -226,6 +229,7 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
bool TlsAgent::MaybeSetResumptionToken() {
if (!resumption_token_.empty()) {
+ LOG("setting external resumption token");
SECStatus rv = SSL_SetResumptionToken(ssl_fd(), resumption_token_.data(),
resumption_token_.size());
@@ -583,6 +587,7 @@ void TlsAgent::CheckAuthType(SSLAuthType auth,
// switch statement because default label is different.
switch (auth) {
case ssl_auth_rsa_sign:
+ case ssl_auth_rsa_pss:
EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
<< "authAlgorithm for RSA is always decrypt";
break;
@@ -934,8 +939,8 @@ void TlsAgent::SendRecordDirect(const TlsRecord& record) {
SendDirect(buf);
}
-static bool ErrorIsNonFatal(PRErrorCode code) {
- return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ;
+static bool ErrorIsFatal(PRErrorCode code) {
+ return code != PR_WOULD_BLOCK_ERROR && code != SSL_ERROR_RX_SHORT_DTLS_READ;
}
void TlsAgent::SendData(size_t bytes, size_t blocksize) {
@@ -975,7 +980,7 @@ bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
LOGV("Encrypting " << buf.len() << " bytes");
// Ensure that we are doing TLS 1.3.
EXPECT_GE(expected_version_, SSL_LIBRARY_VERSION_TLS_1_3);
- TlsRecordHeader header(variant_, expected_version_, kTlsApplicationDataType,
+ TlsRecordHeader header(variant_, expected_version_, ssl_ct_application_data,
seq);
DataBuffer padded = buf;
padded.Write(padded.len(), ct, 1);
@@ -994,28 +999,39 @@ bool TlsAgent::SendEncryptedRecord(const std::shared_ptr<TlsCipherSpec>& spec,
void TlsAgent::ReadBytes(size_t amount) {
uint8_t block[16384];
- int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block)));
- LOGV("ReadBytes " << rv);
- int32_t err;
+ size_t remaining = amount;
+ while (remaining > 0) {
+ int32_t rv = PR_Read(ssl_fd(), block, (std::min)(amount, sizeof(block)));
+ LOGV("ReadBytes " << rv);
- if (rv >= 0) {
- size_t count = static_cast<size_t>(rv);
- for (size_t i = 0; i < count; ++i) {
- ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
- recv_ctr_++;
- }
- } else {
- err = PR_GetError();
- LOG("Read error " << PORT_ErrorToName(err) << ": "
- << PORT_ErrorToString(err));
- if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
- error_code_ = err;
- expect_readwrite_error_ = false;
+ if (rv > 0) {
+ size_t count = static_cast<size_t>(rv);
+ for (size_t i = 0; i < count; ++i) {
+ ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
+ recv_ctr_++;
+ }
+ remaining -= rv;
+ } else {
+ PRErrorCode err = 0;
+ if (rv < 0) {
+ err = PR_GetError();
+ LOG("Read error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
+ error_code_ = err;
+ expect_readwrite_error_ = false;
+ }
+ }
+ if (err != 0 && ErrorIsFatal(err)) {
+ // If we hit a fatal error, we're done.
+ remaining = 0;
+ }
+ break;
}
}
// If closed, then don't bother waiting around.
- if (rv > 0 || (rv < 0 && ErrorIsNonFatal(err))) {
+ if (remaining) {
LOGV("Re-arming");
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
&TlsAgent::ReadableCallback);
@@ -1104,7 +1120,7 @@ void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
if (variant == ssl_variant_stream) {
index = out->Write(index, version, 2);
} else if (version >= SSL_LIBRARY_VERSION_TLS_1_3 &&
- type == kTlsApplicationDataType) {
+ type == ssl_ct_application_data) {
uint32_t epoch = (sequence_number >> 48) & 0x3;
uint32_t seqno = sequence_number & ((1ULL << 30) - 1);
index = out->Write(index, (epoch << 30) | seqno, 4);
@@ -1157,10 +1173,10 @@ void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
size_t hs_len,
DataBuffer* out) {
size_t index = 0;
- index = out->Write(index, kTlsHandshakeType, 1); // Content Type
- index = out->Write(index, 3, 1); // Version high
- index = out->Write(index, 1, 1); // Version low
- index = out->Write(index, 4 + hs_len, 2); // Length
+ index = out->Write(index, ssl_ct_handshake, 1); // Content Type
+ index = out->Write(index, 3, 1); // Version high
+ index = out->Write(index, 1, 1); // Version low
+ index = out->Write(index, 4 + hs_len, 2); // Length
index = out->Write(index, hs_type, 1); // Handshake record type.
index = out->Write(index, hs_len, 3); // Handshake length
@@ -1173,6 +1189,11 @@ DataBuffer TlsAgentTestBase::MakeCannedTls13ServerHello() {
DataBuffer sh(kCannedTls13ServerHello, sizeof(kCannedTls13ServerHello));
if (variant_ == ssl_variant_datagram) {
sh.Write(0, SSL_LIBRARY_VERSION_DTLS_1_2_WIRE, 2);
+ // The version should be at the end.
+ uint32_t v;
+ EXPECT_TRUE(sh.Read(sh.len() - 2, 2, &v));
+ EXPECT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_3), v);
+ sh.Write(sh.len() - 2, 0x7f00 | DTLS_1_3_DRAFT_VERSION, 2);
}
return sh;
}