/* * Copyright (c) 2018-2020 Atmosphère-NX * * This program is free software; you can redistribute it and/or modify it * under the terms and conditions of the GNU General Public License, * version 2, as published by the Free Software Foundation. * * This program is distributed in the hope it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for * more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #include #include "impl/mem_impl_platform.hpp" #include "impl/heap/mem_impl_heap_tls_heap_static.hpp" #include "impl/heap/mem_impl_heap_tls_heap_cache.hpp" #include "impl/heap/mem_impl_heap_tls_heap_central.hpp" namespace ams::mem { constexpr inline size_t DefaultAlignment = alignof(std::max_align_t); constexpr inline size_t MinimumAllocatorSize = 16_KB; namespace { void ThreadDestroy(uintptr_t arg) { if (arg) { reinterpret_cast(arg)->Finalize(); } } ALWAYS_INLINE impl::heap::CentralHeap *GetCentral(const impl::InternalCentralHeapStorage *storage) { return reinterpret_cast(const_cast(storage)); } ALWAYS_INLINE impl::heap::CentralHeap *GetCentral(const impl::InternalCentralHeapStorage &storage) { return GetCentral(std::addressof(storage)); } ALWAYS_INLINE void GetCache(impl::heap::CentralHeap *central, os::TlsSlot slot) { impl::heap::CachedHeap tmp_cache; if (central->MakeCache(std::addressof(tmp_cache))) { impl::heap::TlsHeapCache *cache = tmp_cache.Release(); os::SetTlsValue(slot, reinterpret_cast(cache)); } } struct InternalHash { size_t allocated_count; size_t allocated_size; crypto::Sha1Generator sha1; }; int InternalHashCallback(void *ptr, size_t size, void *user_data) { InternalHash *hash = reinterpret_cast(user_data); hash->sha1.Update(reinterpret_cast(std::addressof(ptr)), sizeof(ptr)); hash->sha1.Update(reinterpret_cast(std::addressof(size)), sizeof(size)); hash->allocated_count++; hash->allocated_size += size; return 1; } } StandardAllocator::StandardAllocator() : initialized(false), enable_thread_cache(false), unused(0) { static_assert(sizeof(impl::heap::CentralHeap) <= sizeof(this->central_heap_storage)); std::construct_at(GetCentral(this->central_heap_storage)); } StandardAllocator::StandardAllocator(void *mem, size_t size) : StandardAllocator() { this->Initialize(mem, size); } StandardAllocator::StandardAllocator(void *mem, size_t size, bool enable_cache) : StandardAllocator() { this->Initialize(mem, size, enable_cache); } void StandardAllocator::Initialize(void *mem, size_t size) { this->Initialize(mem, size, false); } void StandardAllocator::Initialize(void *mem, size_t size, bool enable_cache) { AMS_ABORT_UNLESS(!this->initialized); const uintptr_t aligned_start = util::AlignUp(reinterpret_cast(mem), impl::heap::TlsHeapStatic::PageSize); const uintptr_t aligned_end = util::AlignDown(reinterpret_cast(mem) + size, impl::heap::TlsHeapStatic::PageSize); const size_t aligned_size = aligned_end - aligned_start; if (mem == nullptr) { AMS_ABORT_UNLESS(os::IsVirtualAddressMemoryEnabled()); AMS_ABORT_UNLESS(GetCentral(this->central_heap_storage)->Initialize(nullptr, size, 0) == 0); } else { AMS_ABORT_UNLESS(aligned_start < aligned_end); AMS_ABORT_UNLESS(aligned_size >= MinimumAllocatorSize); AMS_ABORT_UNLESS(GetCentral(this->central_heap_storage)->Initialize(reinterpret_cast(aligned_start), aligned_size, 0) == 0); } this->enable_thread_cache = enable_cache; if (this->enable_thread_cache) { R_ABORT_UNLESS(os::AllocateTlsSlot(std::addressof(this->tls_slot), ThreadDestroy)); } this->initialized = true; } void StandardAllocator::Finalize() { AMS_ABORT_UNLESS(this->initialized); if (this->enable_thread_cache) { os::FreeTlsSlot(this->tls_slot); } GetCentral(this->central_heap_storage)->Finalize(); this->initialized = false; } void *StandardAllocator::Allocate(size_t size) { AMS_ASSERT(this->initialized); return this->Allocate(size, DefaultAlignment); } void *StandardAllocator::Allocate(size_t size, size_t alignment) { AMS_ASSERT(this->initialized); impl::heap::TlsHeapCache *heap_cache = nullptr; if (this->enable_thread_cache) { heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); if (!heap_cache) { GetCache(GetCentral(this->central_heap_storage), this->tls_slot); heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); } } void *ptr = nullptr; if (heap_cache) { ptr = heap_cache->Allocate(size, alignment); if (ptr) { return ptr; } impl::heap::CachedHeap cache; cache.Reset(heap_cache); cache.Query(impl::AllocQuery_FinalizeCache); os::SetTlsValue(this->tls_slot, 0); } return GetCentral(this->central_heap_storage)->Allocate(size, alignment); } void StandardAllocator::Free(void *ptr) { AMS_ASSERT(this->initialized); if (ptr == nullptr) { return; } if (this->enable_thread_cache) { impl::heap::TlsHeapCache *heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); if (heap_cache) { heap_cache->Free(ptr); return; } } const auto err = GetCentral(this->central_heap_storage)->Free(ptr); AMS_ASSERT(err == 0); AMS_UNUSED(err); } void *StandardAllocator::Reallocate(void *ptr, size_t new_size) { AMS_ASSERT(this->initialized); if (new_size > impl::MaxSize) { return nullptr; } if (ptr == nullptr) { return this->Allocate(new_size); } if (new_size == 0) { this->Free(ptr); return nullptr; } size_t aligned_new_size = util::AlignUp(new_size, DefaultAlignment); impl::heap::TlsHeapCache *heap_cache = nullptr; if (this->enable_thread_cache) { heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); if (!heap_cache) { GetCache(GetCentral(this->central_heap_storage), this->tls_slot); heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); } } void *p = nullptr; impl::errno_t err; if (heap_cache) { err = heap_cache->Reallocate(ptr, aligned_new_size, std::addressof(p)); } else { err = GetCentral(this->central_heap_storage)->Reallocate(ptr, aligned_new_size, std::addressof(p)); } if (err == 0) { return p; } else { return nullptr; } } size_t StandardAllocator::Shrink(void *ptr, size_t new_size) { AMS_ASSERT(this->initialized); if (this->enable_thread_cache) { impl::heap::TlsHeapCache *heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); if (heap_cache) { if (heap_cache->Shrink(ptr, new_size) == 0) { return heap_cache->GetAllocationSize(ptr); } else { return 0; } } } if (GetCentral(this->central_heap_storage)->Shrink(ptr, new_size) == 0) { return GetCentral(this->central_heap_storage)->GetAllocationSize(ptr); } else { return 0; } } void StandardAllocator::ClearThreadCache() const { if (this->enable_thread_cache) { impl::heap::TlsHeapCache *heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); impl::heap::CachedHeap cache; cache.Reset(heap_cache); cache.Query(impl::AllocQuery_ClearCache); cache.Release(); } } void StandardAllocator::CleanUpManagementArea() const { AMS_ASSERT(this->initialized); const auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_UnifyFreeList); AMS_ASSERT(err == 0); AMS_UNUSED(err); } size_t StandardAllocator::GetSizeOf(const void *ptr) const { AMS_ASSERT(this->initialized); if (!util::IsAligned(reinterpret_cast(ptr), DefaultAlignment)) { return 0; } impl::heap::TlsHeapCache *heap_cache = nullptr; if (this->enable_thread_cache) { heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); if (!heap_cache) { GetCache(GetCentral(this->central_heap_storage), this->tls_slot); heap_cache = reinterpret_cast(os::GetTlsValue(this->tls_slot)); } } if (heap_cache) { return heap_cache->GetAllocationSize(ptr); } else { return GetCentral(this->central_heap_storage)->GetAllocationSize(ptr); } } size_t StandardAllocator::GetTotalFreeSize() const { size_t size = 0; auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_FreeSizeMapped, std::addressof(size)); if (err != 0) { err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_FreeSize, std::addressof(size)); } AMS_ASSERT(err == 0); return size; } size_t StandardAllocator::GetAllocatableSize() const { size_t size = 0; auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_MaxAllocatableSizeMapped, std::addressof(size)); if (err != 0) { err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_MaxAllocatableSize, std::addressof(size)); } AMS_ASSERT(err == 0); return size; } void StandardAllocator::WalkAllocatedBlocks(WalkCallback callback, void *user_data) const { AMS_ASSERT(this->initialized); this->ClearThreadCache(); GetCentral(this->central_heap_storage)->WalkAllocatedPointers(callback, user_data); } void StandardAllocator::Dump() const { AMS_ASSERT(this->initialized); size_t tmp; auto err = GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_MaxAllocatableSizeMapped, std::addressof(tmp)); if (err == 0) { GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_Dump, impl::DumpMode_Spans | impl::DumpMode_Pointers, 1); } else { GetCentral(this->central_heap_storage)->Query(impl::AllocQuery_Dump, impl::DumpMode_All, 1); } } StandardAllocator::AllocatorHash StandardAllocator::Hash() const { AMS_ASSERT(this->initialized); AllocatorHash alloc_hash; { char temp_hash[crypto::Sha1Generator::HashSize]; InternalHash internal_hash; internal_hash.allocated_count = 0; internal_hash.allocated_size = 0; internal_hash.sha1.Initialize(); this->WalkAllocatedBlocks(InternalHashCallback, reinterpret_cast(std::addressof(internal_hash))); alloc_hash.allocated_count = internal_hash.allocated_count; alloc_hash.allocated_size = internal_hash.allocated_size; internal_hash.sha1.GetHash(temp_hash, sizeof(temp_hash)); std::memcpy(std::addressof(alloc_hash.hash), temp_hash, sizeof(alloc_hash.hash)); } return alloc_hash; } }