/*
* Copyright (c) Atmosphère-NX
*
* This program is free software; you can redistribute it and/or modify it
* under the terms and conditions of the GNU General Public License,
* version 2, as published by the Free Software Foundation.
*
* This program is distributed in the hope it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
* more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see .
*/
#pragma once
#include
#include
namespace ams::fssystem {
/* ACCURATE_TO_VERSION: Unknown */
class BucketTree {
NON_COPYABLE(BucketTree);
NON_MOVEABLE(BucketTree);
public:
static constexpr u32 Magic = util::FourCC<'B','K','T','R'>::Code;
static constexpr u32 Version = 1;
static constexpr size_t NodeSizeMin = 1_KB;
static constexpr size_t NodeSizeMax = 512_KB;
public:
class Visitor;
struct Header {
u32 magic;
u32 version;
s32 entry_count;
s32 reserved;
void Format(s32 entry_count);
Result Verify() const;
};
static_assert(util::is_pod::value);
static_assert(sizeof(Header) == 0x10);
struct NodeHeader {
s32 index;
s32 count;
s64 offset;
Result Verify(s32 node_index, size_t node_size, size_t entry_size) const;
};
static_assert(util::is_pod::value);
static_assert(sizeof(NodeHeader) == 0x10);
struct Offsets {
s64 start_offset;
s64 end_offset;
constexpr bool IsInclude(s64 offset) const {
return this->start_offset <= offset && offset < this->end_offset;
}
constexpr bool IsInclude(s64 offset, s64 size) const {
return size > 0 && this->start_offset <= offset && size <= (this->end_offset - offset);
}
};
static_assert(util::is_pod::value);
static_assert(sizeof(Offsets) == 0x10);
struct OffsetCache {
Offsets offsets;
os::SdkMutex mutex;
bool is_initialized;
constexpr OffsetCache() : offsets{ -1, -1 }, mutex(), is_initialized(false) { /* ... */ }
};
class ContinuousReadingInfo {
private:
size_t m_read_size;
s32 m_skip_count;
bool m_done;
public:
constexpr ContinuousReadingInfo() : m_read_size(), m_skip_count(), m_done() { /* ... */ }
constexpr void Reset() { m_read_size = 0; m_skip_count = 0; m_done = false; }
constexpr void SetSkipCount(s32 count) { AMS_ASSERT(count >= 0); m_skip_count = count; }
constexpr s32 GetSkipCount() const { return m_skip_count; }
constexpr bool CheckNeedScan() { return (--m_skip_count) <= 0; }
constexpr void Done() { m_read_size = 0; m_done = true; }
constexpr bool IsDone() const { return m_done; }
constexpr void SetReadSize(size_t size) { m_read_size = size; }
constexpr size_t GetReadSize() const { return m_read_size; }
constexpr bool CanDo() const { return m_read_size > 0; }
};
using IAllocator = MemoryResource;
private:
class NodeBuffer {
NON_COPYABLE(NodeBuffer);
private:
IAllocator *m_allocator;
void *m_header;
public:
NodeBuffer() : m_allocator(), m_header() { /* ... */ }
~NodeBuffer() {
AMS_ASSERT(m_header == nullptr);
}
NodeBuffer(NodeBuffer &&rhs) : m_allocator(rhs.m_allocator), m_header(rhs.m_allocator) {
rhs.m_allocator = nullptr;
rhs.m_header = nullptr;
}
NodeBuffer &operator=(NodeBuffer &&rhs) {
if (this != std::addressof(rhs)) {
AMS_ASSERT(m_header == nullptr);
m_allocator = rhs.m_allocator;
m_header = rhs.m_header;
rhs.m_allocator = nullptr;
rhs.m_header = nullptr;
}
return *this;
}
bool Allocate(IAllocator *allocator, size_t node_size) {
AMS_ASSERT(m_header == nullptr);
m_allocator = allocator;
m_header = allocator->Allocate(node_size, sizeof(s64));
AMS_ASSERT(util::IsAligned(m_header, sizeof(s64)));
return m_header != nullptr;
}
void Free(size_t node_size) {
if (m_header) {
m_allocator->Deallocate(m_header, node_size);
m_header = nullptr;
}
m_allocator = nullptr;
}
void FillZero(size_t node_size) const {
if (m_header) {
std::memset(m_header, 0, node_size);
}
}
NodeHeader *Get() const {
return reinterpret_cast(m_header);
}
NodeHeader *operator->() const { return this->Get(); }
template
T *Get() const {
static_assert(util::is_pod::value);
static_assert(sizeof(T) == sizeof(NodeHeader));
return reinterpret_cast(m_header);
}
IAllocator *GetAllocator() const {
return m_allocator;
}
};
private:
static constexpr s32 GetEntryCount(size_t node_size, size_t entry_size) {
return static_cast((node_size - sizeof(NodeHeader)) / entry_size);
}
static constexpr s32 GetOffsetCount(size_t node_size) {
return static_cast((node_size - sizeof(NodeHeader)) / sizeof(s64));
}
static constexpr s32 GetEntrySetCount(size_t node_size, size_t entry_size, s32 entry_count) {
const s32 entry_count_per_node = GetEntryCount(node_size, entry_size);
return util::DivideUp(entry_count, entry_count_per_node);
}
static constexpr s32 GetNodeL2Count(size_t node_size, size_t entry_size, s32 entry_count) {
const s32 offset_count_per_node = GetOffsetCount(node_size);
const s32 entry_set_count = GetEntrySetCount(node_size, entry_size, entry_count);
if (entry_set_count <= offset_count_per_node) {
return 0;
}
const s32 node_l2_count = util::DivideUp(entry_set_count, offset_count_per_node);
AMS_ABORT_UNLESS(node_l2_count <= offset_count_per_node);
return util::DivideUp(entry_set_count - (offset_count_per_node - (node_l2_count - 1)), offset_count_per_node);
}
public:
static constexpr s64 QueryHeaderStorageSize() { return sizeof(Header); }
static constexpr s64 QueryNodeStorageSize(size_t node_size, size_t entry_size, s32 entry_count) {
AMS_ASSERT(entry_size >= sizeof(s64));
AMS_ASSERT(node_size >= entry_size + sizeof(NodeHeader));
AMS_ASSERT(NodeSizeMin <= node_size && node_size <= NodeSizeMax);
AMS_ASSERT(util::IsPowerOfTwo(node_size));
AMS_ASSERT(entry_count >= 0);
if (entry_count <= 0) {
return 0;
}
return (1 + GetNodeL2Count(node_size, entry_size, entry_count)) * static_cast(node_size);
}
static constexpr s64 QueryEntryStorageSize(size_t node_size, size_t entry_size, s32 entry_count) {
AMS_ASSERT(entry_size >= sizeof(s64));
AMS_ASSERT(node_size >= entry_size + sizeof(NodeHeader));
AMS_ASSERT(NodeSizeMin <= node_size && node_size <= NodeSizeMax);
AMS_ASSERT(util::IsPowerOfTwo(node_size));
AMS_ASSERT(entry_count >= 0);
if (entry_count <= 0) {
return 0;
}
return GetEntrySetCount(node_size, entry_size, entry_count) * static_cast(node_size);
}
private:
mutable fs::SubStorage m_node_storage;
mutable fs::SubStorage m_entry_storage;
NodeBuffer m_node_l1;
size_t m_node_size;
size_t m_entry_size;
s32 m_entry_count;
s32 m_offset_count;
s32 m_entry_set_count;
OffsetCache m_offset_cache;
public:
BucketTree() : m_node_storage(), m_entry_storage(), m_node_l1(), m_node_size(), m_entry_size(), m_entry_count(), m_offset_count(), m_entry_set_count(), m_offset_cache() { /* ... */ }
~BucketTree() { this->Finalize(); }
Result Initialize(IAllocator *allocator, fs::SubStorage node_storage, fs::SubStorage entry_storage, size_t node_size, size_t entry_size, s32 entry_count);
void Initialize(size_t node_size, s64 end_offset);
void Finalize();
bool IsInitialized() const { return m_node_size > 0; }
bool IsEmpty() const { return m_entry_size == 0; }
Result Find(Visitor *visitor, s64 virtual_address);
Result InvalidateCache();
s32 GetEntryCount() const { return m_entry_count; }
IAllocator *GetAllocator() const { return m_node_l1.GetAllocator(); }
Result GetOffsets(Offsets *out) {
/* Ensure we have an offset cache. */
R_TRY(this->EnsureOffsetCache());
/* Set the output. */
*out = m_offset_cache.offsets;
R_SUCCEED();
}
private:
template
struct ContinuousReadingParam {
s64 offset;
size_t size;
NodeHeader entry_set;
s32 entry_index;
Offsets offsets;
EntryType entry;
};
private:
template
Result ScanContinuousReading(ContinuousReadingInfo *out_info, const ContinuousReadingParam ¶m) const;
bool IsExistL2() const { return m_offset_count < m_entry_set_count; }
bool IsExistOffsetL2OnL1() const { return this->IsExistL2() && m_node_l1->count < m_offset_count; }
s64 GetEntrySetIndex(s32 node_index, s32 offset_index) const {
return (m_offset_count - m_node_l1->count) + (m_offset_count * node_index) + offset_index;
}
Result EnsureOffsetCache();
};
/* ACCURATE_TO_VERSION: Unknown */
class BucketTree::Visitor {
NON_COPYABLE(Visitor);
NON_MOVEABLE(Visitor);
private:
friend class BucketTree;
union EntrySetHeader {
NodeHeader header;
struct Info {
s32 index;
s32 count;
s64 end;
s64 start;
} info;
static_assert(util::is_pod::value);
};
static_assert(util::is_pod::value);
private:
const BucketTree *m_tree;
BucketTree::Offsets m_offsets;
void *m_entry;
s32 m_entry_index;
s32 m_entry_set_count;
EntrySetHeader m_entry_set;
public:
constexpr Visitor() : m_tree(), m_entry(), m_entry_index(-1), m_entry_set_count(), m_entry_set{} { /* ... */ }
~Visitor() {
if (m_entry != nullptr) {
m_tree->GetAllocator()->Deallocate(m_entry, m_tree->m_entry_size);
m_tree = nullptr;
m_entry = nullptr;
}
}
bool IsValid() const { return m_entry_index >= 0; }
bool CanMoveNext() const { return this->IsValid() && (m_entry_index + 1 < m_entry_set.info.count || m_entry_set.info.index + 1 < m_entry_set_count); }
bool CanMovePrevious() const { return this->IsValid() && (m_entry_index > 0 || m_entry_set.info.index > 0); }
Result MoveNext();
Result MovePrevious();
template
Result ScanContinuousReading(ContinuousReadingInfo *out_info, s64 offset, size_t size) const;
const void *Get() const { AMS_ASSERT(this->IsValid()); return m_entry; }
template
const T *Get() const { AMS_ASSERT(this->IsValid()); return reinterpret_cast(m_entry); }
const BucketTree *GetTree() const { return m_tree; }
private:
Result Initialize(const BucketTree *tree, const BucketTree::Offsets &offsets);
Result Find(s64 virtual_address);
Result FindEntrySet(s32 *out_index, s64 virtual_address, s32 node_index);
Result FindEntrySetWithBuffer(s32 *out_index, s64 virtual_address, s32 node_index, char *buffer);
Result FindEntrySetWithoutBuffer(s32 *out_index, s64 virtual_address, s32 node_index);
Result FindEntry(s64 virtual_address, s32 entry_set_index);
Result FindEntryWithBuffer(s64 virtual_address, s32 entry_set_index, char *buffer);
Result FindEntryWithoutBuffer(s64 virtual_address, s32 entry_set_index);
};
}