Atmosphere/libraries/libstratosphere/source/fssystem/fssystem_integrity_verification_storage.cpp
2022-04-29 16:14:01 -07:00

501 lines
22 KiB
C++

/*
* Copyright (c) Atmosphère-NX
*
* This program is free software; you can redistribute it and/or modify it
* under the terms and conditions of the GNU General Public License,
* version 2, as published by the Free Software Foundation.
*
* This program is distributed in the hope it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include <stratosphere.hpp>
namespace ams::fssystem {
/* Binds the storage to its hash/data sub-storages and records block geometry, salt and behavior flags. */
void IntegrityVerificationStorage::Initialize(fs::SubStorage hs, fs::SubStorage ds, s64 verif_block_size, s64 upper_layer_verif_block_size, fs::IBufferManager *bm, fssystem::IHash256GeneratorFactory *hgf, const util::optional<fs::HashSalt> &salt, bool is_real_data, bool is_writable, bool allow_cleared_blocks) {
    /* A verification block must hold at least one hash, and both collaborators are required. */
    AMS_ASSERT(verif_block_size >= HashSize);
    AMS_ASSERT(bm != nullptr);
    AMS_ASSERT(hgf != nullptr);

    /* Attach the backing storages and the collaborators used for hashing and buffering. */
    m_hash_storage           = hs;
    m_data_storage           = ds;
    m_hash_generator_factory = hgf;
    m_buffer_manager         = bm;

    /* Record our block geometry; the block size is required to be a power of two. */
    m_verification_block_size  = verif_block_size;
    m_verification_block_order = ILog2(static_cast<u32>(verif_block_size));
    AMS_ASSERT(m_verification_block_size == (static_cast<s64>(1) << m_verification_block_order));

    /* Record the upper layer's block geometry, clamped so it is never smaller than a single hash. */
    const s64 upper_block_size = std::max(upper_layer_verif_block_size, HashSize);
    m_upper_layer_verification_block_size  = upper_block_size;
    m_upper_layer_verification_block_order = ILog2(static_cast<u32>(upper_block_size));
    AMS_ASSERT(m_upper_layer_verification_block_size == (static_cast<s64>(1) << m_upper_layer_verification_block_order));

    /* Debug-only sanity check: the hash storage must hold one hash per data block. */
    {
        s64 hash_size = 0;
        s64 data_size = 0;
        AMS_ASSERT(R_SUCCEEDED(m_hash_storage.GetSize(std::addressof(hash_size))));
        AMS_ASSERT(R_SUCCEEDED(m_data_storage.GetSize(std::addressof(data_size))));
        AMS_ASSERT(((hash_size / HashSize) * m_verification_block_size) >= data_size);
        AMS_UNUSED(hash_size, data_size);
    }

    /* Remember the optional salt and the behavior flags. */
    m_salt                 = salt;
    m_is_real_data         = is_real_data;
    m_is_writable          = is_writable;
    m_allow_cleared_blocks = allow_cleared_blocks;
}
/* Releases the storages bound by Initialize. Calling it again after a previous Finalize is a no-op. */
void IntegrityVerificationStorage::Finalize() {
    /* No buffer manager attached means there is nothing to release. */
    /* This covers the already-finalized case, and presumably the never-initialized case. */
    if (m_buffer_manager == nullptr) {
        return;
    }

    /* Drop our references to the backing storages and detach from the buffer manager. */
    m_hash_storage   = fs::SubStorage();
    m_data_storage   = fs::SubStorage();
    m_buffer_manager = nullptr;
}
/* Reads block-aligned data from the data storage and verifies every block against its stored signature.   */
/* Blocks that fail verification with an IntegrityVerificationStorageCorrupted result are zero-filled in    */
/* the output buffer, and verification continues. The corruption is then reported through the return value */
/* once the whole range has been processed. The one exception is a cleared block                           */
/* (ResultClearedRealDataVerificationFailed): it is zero-filled but NOT reported when                       */
/* m_allow_cleared_blocks is set. Any other failure (data read, signature read, hash generator creation,    */
/* invalid zero hash) zero-fills the entire buffer where applicable and is returned immediately.            */
Result IntegrityVerificationStorage::Read(s64 offset, void *buffer, size_t size) {
    /* Although we support zero-size reads, we expect non-zero sizes. */
    AMS_ASSERT(size != 0);

    /* Validate other preconditions. */
    AMS_ASSERT(util::IsAligned(offset, static_cast<size_t>(m_verification_block_size)));
    AMS_ASSERT(util::IsAligned(size,   static_cast<size_t>(m_verification_block_size)));

    /* Succeed if zero size. */
    R_SUCCEED_IF(size == 0);

    /* Validate arguments. */
    R_UNLESS(buffer != nullptr, fs::ResultNullptrArgument());

    /* Validate the offset. */
    s64 data_size;
    R_TRY(m_data_storage.GetSize(std::addressof(data_size)));
    R_UNLESS(offset <= data_size, fs::ResultInvalidOffset());

    /* Validate the access range; reads may extend into the padding of the final partial block. */
    R_TRY(IStorage::CheckAccessRange(offset, size, util::AlignUp(data_size, static_cast<size_t>(m_verification_block_size))));

    /* Determine the read extents. */
    size_t read_size = size;
    if (static_cast<s64>(offset + read_size) > data_size) {
        /* The tail of the final block lies past the end of the data storage; it is verified as zero padding. */
        s64 padding_offset  = data_size - offset;
        size_t padding_size = static_cast<size_t>(m_verification_block_size - (padding_offset & (m_verification_block_size - 1)));
        AMS_ASSERT(static_cast<s64>(padding_size) < m_verification_block_size);

        /* Clear the padding. */
        std::memset(static_cast<u8 *>(buffer) + padding_offset, 0, padding_size);

        /* Only the in-bounds portion is read from the data storage. */
        read_size = static_cast<size_t>(data_size - offset);
    }

    /* Perform the read, zero-filling the whole buffer if it fails. */
    {
        auto clear_guard = SCOPE_GUARD { std::memset(buffer, 0, size); };
        R_TRY(m_data_storage.Read(offset, buffer, read_size));
        clear_guard.Cancel();
    }

    /* Verification result reported after every block has been checked. */
    Result verify_hash_result = ResultSuccess();

    /* Create hash generator. */
    std::unique_ptr<IHash256Generator> generator = nullptr;
    R_TRY(m_hash_generator_factory->Create(std::addressof(generator)));

    /* Prepare to validate the signatures. The pooled buffer may be smaller than requested, */
    /* so signatures are processed in batches of at most buffer_count.                      */
    const auto signature_count = size >> m_verification_block_order;
    PooledBuffer signature_buffer(signature_count * sizeof(BlockHash), sizeof(BlockHash));
    const auto buffer_count = std::min(signature_count, signature_buffer.GetSize() / sizeof(BlockHash));

    size_t verified_count = 0;
    while (verified_count < signature_count) {
        /* Read the current batch of signatures. */
        const auto cur_count = std::min(buffer_count, signature_count - verified_count);
        auto cur_result = this->ReadBlockSignature(signature_buffer.GetBuffer(), signature_buffer.GetSize(), offset + (verified_count << m_verification_block_order), cur_count << m_verification_block_order);

        /* Temporarily increase our priority. */
        ScopedThreadPriorityChanger cp(+1, ScopedThreadPriorityChanger::Mode::Relative);

        /* Loop over each signature we read. */
        for (size_t i = 0; i < cur_count && R_SUCCEEDED(cur_result); ++i) {
            const auto verified_size = (verified_count + i) << m_verification_block_order;
            u8 *cur_buf = static_cast<u8 *>(buffer) + verified_size;
            cur_result = this->VerifyHash(cur_buf, reinterpret_cast<BlockHash *>(signature_buffer.GetBuffer()) + i, generator);

            /* If the data is corrupted, clear the corrupted parts. */
            if (fs::ResultIntegrityVerificationStorageCorrupted::Includes(cur_result)) {
                std::memset(cur_buf, 0, m_verification_block_size);

                /* Only a cleared block on a storage that allows cleared blocks may be ignored; */
                /* every other corruption (or a cleared block when not allowed) is reported.    */
                const bool is_allowed_cleared_block = fs::ResultClearedRealDataVerificationFailed::Includes(cur_result) && m_allow_cleared_blocks;
                if (!is_allowed_cleared_block) {
                    verify_hash_result = cur_result;
                }

                /* Keep verifying the remaining blocks. */
                cur_result = ResultSuccess();
            }
        }

        /* If we failed for a non-corruption reason, clear and return. */
        if (R_FAILED(cur_result)) {
            std::memset(buffer, 0, size);
            R_THROW(cur_result);
        }

        /* Advance. */
        verified_count += cur_count;
    }

    R_RETURN(verify_hash_result);
}
/* Writes block-aligned data and refreshes the matching block signatures in the hash storage.            */
/* Signatures are computed and written BEFORE the data. If a signature write fails, only the data for    */
/* blocks whose signatures were written successfully is written. The signature failure is then returned, */
/* unless the data write itself fails, in which case that error is returned instead.                     */
/* Bytes of the final block that lie past the end of the data storage must be zero padding (asserted).   */
/* They take part in that block's hash but are not written to the data storage.                          */
Result IntegrityVerificationStorage::Write(s64 offset, const void *buffer, size_t size) {
/* Succeed if zero size. */
R_SUCCEED_IF(size == 0);
/* Validate arguments. */
R_UNLESS(buffer != nullptr, fs::ResultNullptrArgument());
/* Check the offset/size. */
R_TRY(IStorage::CheckOffsetAndSize(offset, size));
/* Validate the offset; unlike Read, a write must start strictly inside the data storage. */
s64 data_size;
R_TRY(m_data_storage.GetSize(std::addressof(data_size)));
R_UNLESS(offset < data_size, fs::ResultInvalidOffset());
/* Validate the access range; the final partial block may be addressed up to its aligned end. */
R_TRY(IStorage::CheckAccessRange(offset, size, util::AlignUp(data_size, static_cast<size_t>(m_verification_block_size))));
/* Validate preconditions. */
AMS_ASSERT(util::IsAligned(offset, m_verification_block_size));
AMS_ASSERT(util::IsAligned(size, m_verification_block_size));
AMS_ASSERT(offset <= data_size);
AMS_ASSERT(static_cast<s64>(offset + size) < data_size + m_verification_block_size);
/* Validate that if writing past the end, all extra data is zero padding. */
/* (Debug-only check: the loop body is just an assertion.) */
if (static_cast<s64>(offset + size) > data_size) {
const u8 *padding_cur = static_cast<const u8 *>(buffer) + data_size - offset;
const u8 *padding_end = padding_cur + (offset + size - data_size);
while (padding_cur < padding_end) {
AMS_ASSERT((*padding_cur) == 0);
++padding_cur;
}
}
/* Determine the unpadded size to write. */
/* NOTE(review): since offset < data_size is enforced above, write_size should never be 0 here. */
auto write_size = size;
if (static_cast<s64>(offset + write_size) > data_size) {
write_size = static_cast<size_t>(data_size - offset);
R_SUCCEED_IF(write_size == 0);
}
/* Determine the size we're writing in blocks; the padded final block is hashed in full. */
const auto aligned_write_size = util::AlignUp(write_size, m_verification_block_size);
/* Write the updated block signatures. */
/* update_result holds the first signature-write failure; updated_count counts blocks whose signatures were written. */
Result update_result = ResultSuccess();
size_t updated_count = 0;
{
/* The pooled buffer may be smaller than requested, so signatures are handled in batches of at most buffer_count. */
const auto signature_count = aligned_write_size >> m_verification_block_order;
PooledBuffer signature_buffer(signature_count * sizeof(BlockHash), sizeof(BlockHash));
const auto buffer_count = std::min(signature_count, signature_buffer.GetSize() / sizeof(BlockHash));
/* Create hash generator. */
std::unique_ptr<IHash256Generator> generator = nullptr;
R_TRY(m_hash_generator_factory->Create(std::addressof(generator)));
while (updated_count < signature_count) {
const auto cur_count = std::min(buffer_count, signature_count - updated_count);
/* Calculate the hash with temporarily increased priority. */
{
ScopedThreadPriorityChanger cp(+1, ScopedThreadPriorityChanger::Mode::Relative);
for (size_t i = 0; i < cur_count; ++i) {
const auto updated_size = (updated_count + i) << m_verification_block_order;
/* Presumably the 3-argument CalcBlockHash overload hashes one full verification block; confirm in the header. */
this->CalcBlockHash(reinterpret_cast<BlockHash *>(signature_buffer.GetBuffer()) + i, reinterpret_cast<const u8 *>(buffer) + updated_size, generator);
}
}
/* Write the new block signatures; on failure, stop so that only already-signed data is written below. */
if (R_FAILED((update_result = this->WriteBlockSignature(signature_buffer.GetBuffer(), signature_buffer.GetSize(), offset + (updated_count << m_verification_block_order), cur_count << m_verification_block_order)))) {
break;
}
/* Advance. */
updated_count += cur_count;
}
}
/* Write the data, limited to the unpadded size and to the blocks whose signatures were updated. */
R_TRY(m_data_storage.Write(offset, buffer, std::min(write_size, updated_count << m_verification_block_order)));
R_RETURN(update_result);
}
/* Reports the logical size, which is exactly the size of the underlying data storage. */
Result IntegrityVerificationStorage::GetSize(s64 *out) {
    /* Signatures live in the separate hash storage, so they do not contribute to our size. */
    const Result result = m_data_storage.GetSize(out);
    R_RETURN(result);
}
/* Flushes the hash storage, then the data storage. */
Result IntegrityVerificationStorage::Flush() {
    /* Hash storage first; if this fails the data storage is not flushed. */
    R_TRY(m_hash_storage.Flush());

    /* Then the data storage, whose result is propagated as-is. */
    R_RETURN(m_data_storage.Flush());
}
/* Performs a range operation on the verified storage.                                                       */
/*   FillZero         - zeroes the signatures covering [offset, offset + size) in the hash storage            */
/*                      (the data storage is not touched).                                                   */
/*   DestroySignature - inverts every signature bit except the validation bit, so verification of the range */
/*                      fails afterwards.                                                                    */
/*   Invalidate       - forwarded to both storages; only permitted on read-only storages.                    */
/*   QueryRange       - forwarded to the data storage with the size clamped to the data end.                 */
/* Any other operation returns ResultUnsupportedOperateRangeForIntegrityVerificationStorage.                 */
Result IntegrityVerificationStorage::OperateRange(void *dst, size_t dst_size, fs::OperationId op_id, s64 offset, s64 size, const void *src, size_t src_size) {
/* Validate preconditions; every operation except Invalidate works on whole verification blocks. */
if (op_id != fs::OperationId::Invalidate) {
AMS_ASSERT(util::IsAligned(offset, static_cast<size_t>(m_verification_block_size)));
AMS_ASSERT(util::IsAligned(size, static_cast<size_t>(m_verification_block_size)));
}
switch (op_id) {
case fs::OperationId::FillZero:
{
/* FillZero should only be called for writable storages. */
AMS_ASSERT(m_is_writable);
/* Validate the range. */
s64 data_size = 0;
R_TRY(m_data_storage.GetSize(std::addressof(data_size)));
R_UNLESS(0 <= offset && offset <= data_size, fs::ResultInvalidOffset());
/* Determine the extents to clear: one HashSize entry per block, clamped to the end of the data. */
const auto sign_offset = (offset >> m_verification_block_order) * HashSize;
const auto sign_size = (std::min(size, data_size - offset) >> m_verification_block_order) * HashSize;
/* Allocate a work buffer, capped at (1 << (upper layer block order + 2)) bytes, i.e. four upper-layer blocks. */
/* NOTE(review): buf_size is 0 when sign_size is 0 (e.g. offset == data_size); whether a zero-byte */
/* MakeUnique yields nullptr (and thus an allocation-failure result) depends on the allocator - confirm. */
const auto buf_size = static_cast<size_t>(std::min(sign_size, static_cast<s64>(1) << (m_upper_layer_verification_block_order + 2)));
std::unique_ptr<char[], fs::impl::Deleter> buf = fs::impl::MakeUnique<char[]>(buf_size);
R_UNLESS(buf != nullptr, fs::ResultAllocationMemoryFailedInIntegrityVerificationStorageA());
/* Clear the work buffer. */
std::memset(buf.get(), 0, buf_size);
/* Overwrite the signature range with zeroes, one work-buffer-sized chunk at a time. */
auto remaining_size = sign_size;
while (remaining_size > 0) {
const auto cur_size = static_cast<size_t>(std::min(remaining_size, static_cast<s64>(buf_size)));
R_TRY(m_hash_storage.Write(sign_offset + sign_size - remaining_size, buf.get(), cur_size));
remaining_size -= cur_size;
}
R_SUCCEED();
}
case fs::OperationId::DestroySignature:
{
/* DestroySignature should only be called for save data. */
AMS_ASSERT(m_is_writable);
/* Validate the range. */
s64 data_size = 0;
R_TRY(m_data_storage.GetSize(std::addressof(data_size)));
R_UNLESS(0 <= offset && offset <= data_size, fs::ResultInvalidOffset());
/* Determine the extents to clear the signature for. */
const auto sign_offset = (offset >> m_verification_block_order) * HashSize;
const auto sign_size = (std::min(size, data_size - offset) >> m_verification_block_order) * HashSize;
/* Allocate a work buffer large enough for the whole signature range. */
std::unique_ptr<char[], fs::impl::Deleter> buf = fs::impl::MakeUnique<char[]>(sign_size);
R_UNLESS(buf != nullptr, fs::ResultAllocationMemoryFailedInIntegrityVerificationStorageB());
/* Read the existing signature. */
R_TRY(m_hash_storage.Read(sign_offset, buf.get(), sign_size));
/* Clear the signature. */
/* This flips all bits other than the verification bit: the last byte of each hash is XORed */
/* with 0x7F, preserving its top bit; every other byte is XORed with 0xFF.                  */
/* NOTE(review): `i` is an int compared against the s64 sign_size; fine for realistic sizes, */
/* but it would overflow for signature ranges above INT_MAX bytes.                           */
for (auto i = 0; i < sign_size; ++i) {
buf[i] ^= ((i + 1) % HashSize == 0 ? 0x7F : 0xFF);
}
/* Write the cleared signature. */
R_RETURN(m_hash_storage.Write(sign_offset, buf.get(), sign_size));
}
case fs::OperationId::Invalidate:
{
/* Cache invalidation is only allowed for read-only storages. */
R_UNLESS(!m_is_writable, fs::ResultUnsupportedOperateRangeForWritableIntegrityVerificationStorage());
/* Invalidate the whole hash storage, and the requested range of the data storage. */
R_TRY(m_hash_storage.OperateRange(op_id, 0, std::numeric_limits<s64>::max()));
R_TRY(m_data_storage.OperateRange(op_id, offset, size));
R_SUCCEED();
}
case fs::OperationId::QueryRange:
{
/* Validate the range. */
s64 data_size = 0;
R_TRY(m_data_storage.GetSize(std::addressof(data_size)));
R_UNLESS(0 <= offset && offset <= data_size, fs::ResultInvalidOffset());
/* Determine the real size to query, clamped to the end of the data storage. */
const auto actual_size = std::min(size, data_size - offset);
/* Query the data storage. */
R_RETURN(m_data_storage.OperateRange(dst, dst_size, op_id, offset, actual_size, src, src_size));
}
default:
R_THROW(fs::ResultUnsupportedOperateRangeForIntegrityVerificationStorage());
}
}
/* Computes the signature of one block of `block_size` bytes at `buffer` into `out`.        */
/* Writable storages hash the salt followed by the data when a salt is configured, and      */
/* always set the validation bit. Read-only storages use a plain, unsalted hash.            */
void IntegrityVerificationStorage::CalcBlockHash(BlockHash *out, const void *buffer, size_t block_size, std::unique_ptr<fssystem::IHash256Generator> &generator) const {
    /* The salt is only mixed in for writable storages, and only when one was provided. */
    const bool use_salt = m_is_writable && m_salt.has_value();

    if (use_salt) {
        /* Feed the salt, then the block, through the caller's reusable generator. */
        generator->Initialize();
        generator->Update(m_salt->value, sizeof(m_salt->value));
        generator->Update(buffer, block_size);
        generator->GetHash(out, sizeof(*out));
    } else {
        /* Unsalted: use the factory's one-shot hashing helper. */
        m_hash_generator_factory->GenerateHash(out, sizeof(*out), buffer, block_size);
    }

    /* Every hash computed by a writable storage is marked with the validation bit. */
    if (m_is_writable) {
        SetValidationBit(out);
    }
}
/* Reads the signatures covering the block-aligned data range [offset, offset + size) into `dst`. */
/* On any failure, the signature portion of `dst` is zero-filled.                                */
Result IntegrityVerificationStorage::ReadBlockSignature(void *dst, size_t dst_size, s64 offset, size_t size) {
    /* Validate preconditions: the data range must be block-aligned. */
    AMS_ASSERT(dst != nullptr);
    AMS_ASSERT(util::IsAligned(offset, static_cast<size_t>(m_verification_block_size)));
    AMS_ASSERT(util::IsAligned(size,   static_cast<size_t>(m_verification_block_size)));

    /* Map the data range onto the hash storage: each block owns one HashSize-byte entry. */
    const s64  first_block = offset >> m_verification_block_order;
    const auto block_count = size >> m_verification_block_order;
    const s64  sig_offset  = first_block * HashSize;
    const auto sig_size    = static_cast<size_t>(block_count * HashSize);
    AMS_ASSERT(dst_size >= sig_size);
    AMS_UNUSED(dst_size);

    /* Leave the destination zeroed rather than partially populated if anything below fails. */
    auto clear_guard = SCOPE_GUARD { std::memset(dst, 0, sig_size); };

    /* The requested signatures must lie entirely inside the hash storage. */
    s64 hash_storage_size;
    R_TRY(m_hash_storage.GetSize(std::addressof(hash_storage_size)));

    const s64  sig_end   = static_cast<s64>(sig_offset + sig_size);
    const bool in_bounds = sig_end <= hash_storage_size;
    AMS_ASSERT(in_bounds);
    R_UNLESS(in_bounds, fs::ResultOutOfRange());

    /* Fetch the signatures. */
    R_TRY(m_hash_storage.Read(sig_offset, dst, sig_size));

    /* Success: keep what we read. */
    clear_guard.Cancel();
    R_SUCCEED();
}
/* Writes the signatures in `src` for the block-aligned data range starting at `offset` and spanning `size` bytes. */
Result IntegrityVerificationStorage::WriteBlockSignature(const void *src, size_t src_size, s64 offset, size_t size) {
    /* Validate preconditions. */
    AMS_ASSERT(src != nullptr);
    AMS_ASSERT(util::IsAligned(offset, static_cast<size_t>(m_verification_block_size)));

    /* Each verification block owns one HashSize-byte entry in the hash storage. */
    const s64  sig_offset = (offset >> m_verification_block_order) * HashSize;
    const auto sig_size   = static_cast<size_t>((size >> m_verification_block_order) * HashSize);
    AMS_ASSERT(src_size >= sig_size);
    AMS_UNUSED(src_size);

    /* Persist the signatures, propagating the hash storage's result unchanged. */
    R_RETURN(m_hash_storage.Write(sig_offset, src, sig_size));
}
/* Verifies one block at `buf` against the stored signature `*hash`.                                  */
/* Writable storages first reject cleared signatures with ResultClearedRealDataVerificationFailed.    */
/* On a hash mismatch, `*hash` is zeroed and the result is ResultUnclearedRealDataVerificationFailed  */
/* for real data, or ResultNonRealDataVerificationFailed otherwise.                                   */
Result IntegrityVerificationStorage::VerifyHash(const void *buf, BlockHash *hash, std::unique_ptr<fssystem::IHash256Generator> &generator) {
    /* Validate preconditions. */
    AMS_ASSERT(buf != nullptr);
    AMS_ASSERT(hash != nullptr);

    /* The stored signature we compare against. */
    BlockHash &expected = *hash;

    /* On writable storages, detect blocks whose signature has been cleared. */
    if (m_is_writable) {
        bool cleared = false;
        R_TRY(this->IsCleared(std::addressof(cleared), expected));
        R_UNLESS(!cleared, fs::ResultClearedRealDataVerificationFailed());
    }

    /* Recompute the signature from the block contents. */
    BlockHash actual;
    this->CalcBlockHash(std::addressof(actual), buf, generator);

    /* Matching signatures mean the block is intact. */
    R_SUCCEED_IF(crypto::IsSameBytes(std::addressof(expected), std::addressof(actual), sizeof(BlockHash)));

    /* Mismatch: clear the caller-provided comparison hash, then report the failure kind. */
    std::memset(std::addressof(expected), 0, sizeof(expected));
    R_UNLESS(m_is_real_data, fs::ResultNonRealDataVerificationFailed());
    R_THROW(fs::ResultUnclearedRealDataVerificationFailed());
}
/* Determines whether `hash` denotes a cleared block on a writable storage.                          */
/* A hash with the validation bit set is not cleared. A hash without it must be all zeroes, in which */
/* case it is cleared; anything else fails with ResultInvalidZeroHash.                               */
Result IntegrityVerificationStorage::IsCleared(bool *is_cleared, const BlockHash &hash) {
    /* Validate preconditions; cleared blocks are only meaningful for writable storages. */
    AMS_ASSERT(is_cleared != nullptr);
    AMS_ASSERT(m_is_writable);

    /* Assume a live signature until shown otherwise. */
    *is_cleared = false;

    /* A hash carrying the validation bit is treated as a live (uncleared) signature. */
    R_SUCCEED_IF(IsValidationBit(std::addressof(hash)));

    /* Without the validation bit, the only acceptable hash is all zeroes. */
    for (const auto byte : hash.hash) {
        R_UNLESS(byte == 0, fs::ResultInvalidZeroHash());
    }

    *is_cleared = true;
    R_SUCCEED();
}
}