diff options
author | Andrea Marchesini <amarchesini@mozilla.com> | 2020-07-29 10:52:30 +0000 |
---|---|---|
committer | Moonchild <moonchild@palemoon.org> | 2020-07-29 14:28:45 +0000 |
commit | 3323a8eda3d4cbb29909f6761cee96cc5e164cf2 (patch) | |
tree | e2b69428455e26be4e48d75c99e407825b2a4f99 | |
parent | 71293f0fe4db15d4e0309488337ab4669e9c55b3 (diff) | |
download | uxp-3323a8eda3d4cbb29909f6761cee96cc5e164cf2.tar.gz |
[xpcom] Make Base64 compatible with ReadSegments() with small buffers.
-rw-r--r-- | netwerk/test/gtest/TestBase64Stream.cpp | 95 | ||||
-rw-r--r-- | netwerk/test/gtest/moz.build | 1 | ||||
-rw-r--r-- | xpcom/io/Base64.cpp | 32 |
3 files changed, 121 insertions, 7 deletions
diff --git a/netwerk/test/gtest/TestBase64Stream.cpp b/netwerk/test/gtest/TestBase64Stream.cpp new file mode 100644 index 0000000000..37f5cb824e --- /dev/null +++ b/netwerk/test/gtest/TestBase64Stream.cpp @@ -0,0 +1,95 @@ +/* -*- Mode: C++; tab-width: 2; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ +/* This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include "gtest/gtest.h" +#include "mozilla/Base64.h" +#include "nsIInputStream.h" + +namespace mozilla { +namespace net { + +// An input stream whose ReadSegments method calls aWriter with writes of size +// aStep from the provided aInput in order to test edge-cases related to small +// buffers. +class TestStream final : public nsIInputStream { + public: + NS_DECL_ISUPPORTS; + + TestStream(const nsACString& aInput, uint32_t aStep) + : mInput(aInput), mStep(aStep) {} + + NS_IMETHOD Close() override { MOZ_CRASH("This should not be called"); } + + NS_IMETHOD Available(uint64_t* aLength) override { + *aLength = mInput.Length() - mPos; + return NS_OK; + } + + NS_IMETHOD Read(char* aBuffer, uint32_t aCount, + uint32_t* aReadCount) override { + MOZ_CRASH("This should not be called"); + } + + NS_IMETHOD ReadSegments(nsWriteSegmentFun aWriter, void* aClosure, + uint32_t aCount, uint32_t* aResult) override { + *aResult = 0; + + if (mPos == mInput.Length()) { + return NS_OK; + } + + while (aCount > 0) { + uint32_t amt = std::min(mStep, (uint32_t)(mInput.Length() - mPos)); + + uint32_t read = 0; + nsresult rv = + aWriter(this, aClosure, mInput.get() + mPos, *aResult, amt, &read); + if (NS_WARN_IF(NS_FAILED(rv))) { + return rv; + } + + *aResult += read; + aCount -= read; + mPos += read; + } + + return NS_OK; + } + + NS_IMETHOD IsNonBlocking(bool* aNonBlocking) override { + *aNonBlocking = true; + return NS_OK; + } + + private: + ~TestStream() = default; + + nsCString mInput; + const uint32_t mStep; + uint32_t mPos = 0; +}; + +NS_IMPL_ISUPPORTS(TestStream, nsIInputStream) + +// Test the base64 encoder with writer buffer sizes between 1 byte and the +// entire length of "Hello World!" in order to exercise various edge cases. +TEST(TestBase64Stream, Run) +{ + nsCString input; + input.AssignLiteral("Hello World!"); + + for (uint32_t step = 1; step <= input.Length(); ++step) { + RefPtr<TestStream> ts = new TestStream(input, step); + + nsAutoString encodedData; + nsresult rv = Base64EncodeInputStream(ts, encodedData, input.Length()); + ASSERT_TRUE(NS_SUCCEEDED(rv)); + + EXPECT_TRUE(encodedData.EqualsLiteral("SGVsbG8gV29ybGQh")); + } +} + +} // namespace net +} // namespace mozilla diff --git a/netwerk/test/gtest/moz.build b/netwerk/test/gtest/moz.build index 6e6c801521..e463feb651 100644 --- a/netwerk/test/gtest/moz.build +++ b/netwerk/test/gtest/moz.build @@ -5,6 +5,7 @@ # file, You can obtain one at http://mozilla.org/MPL/2.0/. UNIFIED_SOURCES += [ + 'TestBase64Stream.cpp', 'TestProtocolProxyService.cpp', 'TestStandardURL.cpp', ] diff --git a/xpcom/io/Base64.cpp b/xpcom/io/Base64.cpp index 911c0595ac..b9fa7baf83 100644 --- a/xpcom/io/Base64.cpp +++ b/xpcom/io/Base64.cpp @@ -108,30 +108,51 @@ EncodeInputStream_Encoder(nsIInputStream* aStream, EncodeInputStream_State<T>* state = static_cast<EncodeInputStream_State<T>*>(aClosure); + // We always consume all data. + *aWriteCount = aCount; + // If we have any data left from last time, encode it now. uint32_t countRemaining = aCount; const unsigned char* src = (const unsigned char*)aFromSegment; if (state->charsOnStack) { + MOZ_ASSERT(state->charsOnStack == 1 || state->charsOnStack == 2); + + // Not enough data to compose a triple. + if (state->charsOnStack == 1 && countRemaining == 1) { + state->charsOnStack = 2; + state->c[1] = src[0]; + return NS_OK; + } + + uint32_t consumed = 0; unsigned char firstSet[4]; if (state->charsOnStack == 1) { firstSet[0] = state->c[0]; firstSet[1] = src[0]; - firstSet[2] = (countRemaining > 1) ? src[1] : '\0'; + firstSet[2] = src[1]; firstSet[3] = '\0'; + consumed = 2; } else /* state->charsOnStack == 2 */ { firstSet[0] = state->c[0]; firstSet[1] = state->c[1]; firstSet[2] = src[0]; firstSet[3] = '\0'; + consumed = 1; } + Encode(firstSet, 3, state->buffer); state->buffer += 4; - countRemaining -= (3 - state->charsOnStack); - src += (3 - state->charsOnStack); + countRemaining -= consumed; + src += consumed; state->charsOnStack = 0; + + // Bail if there is nothing left. + if (!countRemaining) { + return NS_OK; + } } - // Encode the bulk of the + // Encode as many full triplets as possible. uint32_t encodeLength = countRemaining - countRemaining % 3; MOZ_ASSERT(encodeLength % 3 == 0, "Should have an exact number of triplets!"); @@ -140,9 +161,6 @@ EncodeInputStream_Encoder(nsIInputStream* aStream, src += encodeLength; countRemaining -= encodeLength; - // We must consume all data, so if there's some data left stash it - *aWriteCount = aCount; - if (countRemaining) { // We should never have a full triplet left at this point. MOZ_ASSERT(countRemaining < 3, "We should have encoded more!"); |