From a0ad3ef9498e64d5d0e1a72b9c6e739e05af3ebc Mon Sep 17 00:00:00 2001
From: Michael Scire <SciresM@gmail.com>
Date: Wed, 9 Oct 2024 11:36:17 -0700
Subject: [PATCH] kern/svc: update WaitForAddress to support 64-bit WaitIfEqual

---
 .../arm64/kern_userspace_memory_access.hpp    |  1 +
 .../mesosphere/kern_k_address_arbiter.hpp     | 11 +--
 .../include/mesosphere/kern_k_process.hpp     |  2 +-
 .../arm64/kern_userspace_memory_access_asm.s  | 14 ++++
 .../arm64/svc/kern_svc_address_arbiter_asm.s  | 67 +++++++++++++++++++
 .../source/arch/arm64/svc/kern_svc_tables.cpp |  8 +++
 .../source/kern_k_address_arbiter.cpp         | 51 ++++++++++++++
 .../source/svc/kern_svc_address_arbiter.cpp   | 17 +++--
 .../svc/svc_stratosphere_shims.hpp            |  2 +-
 .../vapours/svc/svc_definition_macro.hpp      |  2 +-
 .../include/vapours/svc/svc_types_common.hpp  |  1 +
 11 files changed, 163 insertions(+), 13 deletions(-)
 create mode 100644 libraries/libmesosphere/source/arch/arm64/svc/kern_svc_address_arbiter_asm.s

diff --git a/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_userspace_memory_access.hpp b/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_userspace_memory_access.hpp
index 202be5f51..9e0c1844e 100644
--- a/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_userspace_memory_access.hpp
+++ b/libraries/libmesosphere/include/mesosphere/arch/arm64/kern_userspace_memory_access.hpp
@@ -25,6 +25,7 @@ namespace ams::kern::arch::arm64 {
             static bool CopyMemoryFromUser(void *dst, const void *src, size_t size);
             static bool CopyMemoryFromUserAligned32Bit(void *dst, const void *src, size_t size);
             static bool CopyMemoryFromUserAligned64Bit(void *dst, const void *src, size_t size);
+            static bool CopyMemoryFromUserSize64Bit(void *dst, const void *src);
             static bool CopyMemoryFromUserSize32Bit(void *dst, const void *src);
             static s32  CopyStringFromUser(void *dst, const void *src, size_t size);
 
diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_address_arbiter.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_address_arbiter.hpp
index 4066de80d..f4ef74e6b 100644
--- a/libraries/libmesosphere/include/mesosphere/kern_k_address_arbiter.hpp
+++ b/libraries/libmesosphere/include/mesosphere/kern_k_address_arbiter.hpp
@@ -39,14 +39,16 @@ namespace ams::kern {
                 }
             }
 
-            Result WaitForAddress(uintptr_t addr, ams::svc::ArbitrationType type, s32 value, s64 timeout) {
+            Result WaitForAddress(uintptr_t addr, ams::svc::ArbitrationType type, s64 value, s64 timeout) {
                 switch (type) {
                     case ams::svc::ArbitrationType_WaitIfLessThan:
-                        R_RETURN(this->WaitIfLessThan(addr, value, false, timeout));
+                        R_RETURN(this->WaitIfLessThan(addr, static_cast<s32>(value), false, timeout));
                     case ams::svc::ArbitrationType_DecrementAndWaitIfLessThan:
-                        R_RETURN(this->WaitIfLessThan(addr, value, true, timeout));
+                        R_RETURN(this->WaitIfLessThan(addr, static_cast<s32>(value), true, timeout));
                     case ams::svc::ArbitrationType_WaitIfEqual:
-                        R_RETURN(this->WaitIfEqual(addr, value, timeout));
+                        R_RETURN(this->WaitIfEqual(addr, static_cast<s32>(value), timeout));
+                    case ams::svc::ArbitrationType_WaitIfEqual64:
+                        R_RETURN(this->WaitIfEqual64(addr, value, timeout));
                     MESOSPHERE_UNREACHABLE_DEFAULT_CASE();
                 }
             }
@@ -56,6 +58,7 @@ namespace ams::kern {
             Result SignalAndModifyByWaitingCountIfEqual(uintptr_t addr, s32 value, s32 count);
             Result WaitIfLessThan(uintptr_t addr, s32 value, bool decrement, s64 timeout);
             Result WaitIfEqual(uintptr_t addr, s32 value, s64 timeout);
+            Result WaitIfEqual64(uintptr_t addr, s64 value, s64 timeout);
     };
 
 }
diff --git a/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp b/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp
index 3258d41e2..5a37e3a6b 100644
--- a/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp
+++ b/libraries/libmesosphere/include/mesosphere/kern_k_process.hpp
@@ -360,7 +360,7 @@ namespace ams::kern {
                 R_RETURN(m_address_arbiter.SignalToAddress(address, signal_type, value, count));
             }
 
-            Result WaitAddressArbiter(uintptr_t address, ams::svc::ArbitrationType arb_type, s32 value, s64 timeout) {
+            Result WaitAddressArbiter(uintptr_t address, ams::svc::ArbitrationType arb_type, s64 value, s64 timeout) {
                 R_RETURN(m_address_arbiter.WaitForAddress(address, arb_type, value, timeout));
             }
 
diff --git a/libraries/libmesosphere/source/arch/arm64/kern_userspace_memory_access_asm.s b/libraries/libmesosphere/source/arch/arm64/kern_userspace_memory_access_asm.s
index 5b1fd4a51..e697fde3c 100644
--- a/libraries/libmesosphere/source/arch/arm64/kern_userspace_memory_access_asm.s
+++ b/libraries/libmesosphere/source/arch/arm64/kern_userspace_memory_access_asm.s
@@ -126,6 +126,20 @@ _ZN3ams4kern4arch5arm6415UserspaceAccess30CopyMemoryFromUserAligned64BitEPvPKvm:
     mov     x0, #1
     ret
 
+/* ams::kern::arch::arm64::UserspaceAccess::CopyMemoryFromUserSize64Bit(void *dst, const void *src) */
+.section    .text._ZN3ams4kern4arch5arm6415UserspaceAccess27CopyMemoryFromUserSize64BitEPvPKv, "ax", %progbits
+.global     _ZN3ams4kern4arch5arm6415UserspaceAccess27CopyMemoryFromUserSize64BitEPvPKv
+.type       _ZN3ams4kern4arch5arm6415UserspaceAccess27CopyMemoryFromUserSize64BitEPvPKv, %function
+.balign 0x10
+_ZN3ams4kern4arch5arm6415UserspaceAccess27CopyMemoryFromUserSize64BitEPvPKv:
+    /* Read one u64 from src (x1) with an unprivileged load (ldtr), then write it to dst (x0). */
+    ldtr    x2, [x1]
+    str     x2, [x0]
+
+    /* The copy completed; return true (1) in x0. */
+    mov     x0, #1
+    ret
+    ret
+
 /* ams::kern::arch::arm64::UserspaceAccess::CopyMemoryFromUserSize32Bit(void *dst, const void *src) */
 .section    .text._ZN3ams4kern4arch5arm6415UserspaceAccess27CopyMemoryFromUserSize32BitEPvPKv, "ax", %progbits
 .global     _ZN3ams4kern4arch5arm6415UserspaceAccess27CopyMemoryFromUserSize32BitEPvPKv
diff --git a/libraries/libmesosphere/source/arch/arm64/svc/kern_svc_address_arbiter_asm.s b/libraries/libmesosphere/source/arch/arm64/svc/kern_svc_address_arbiter_asm.s
new file mode 100644
index 000000000..f82db4454
--- /dev/null
+++ b/libraries/libmesosphere/source/arch/arm64/svc/kern_svc_address_arbiter_asm.s
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) Atmosphère-NX
+ *
+ * This program is free software; you can redistribute it and/or modify it
+ * under the terms and conditions of the GNU General Public License,
+ * version 2, as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
+ * more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+/* ams::kern::svc::CallWaitForAddress64From32() */
+.section    .text._ZN3ams4kern3svc26CallWaitForAddress64From32Ev, "ax", %progbits
+.global     _ZN3ams4kern3svc26CallWaitForAddress64From32Ev
+.type       _ZN3ams4kern3svc26CallWaitForAddress64From32Ev, %function
+_ZN3ams4kern3svc26CallWaitForAddress64From32Ev:
+    /* Save LR (x30), plus x6/x7 so they can be restored unchanged after the handler call. */
+    str     x30, [sp, #-16]!
+    stp     x6, x7, [sp, #-16]!
+
+    /* Merge the split 64-bit arguments: x2 (value) |= x5 << 32, x3 (timeout_ns) |= x4 << 32. */
+    /* NOTE: This has to be manually implemented via asm, */
+    /* in order to avoid breaking ABI with pre-19.0.0. */
+    orr     x2, x2, x5, lsl#32
+    orr     x3, x3, x4, lsl#32
+
+    /* Invoke the svc handler with (address, arb_type, value, timeout_ns) in x0-x3. */
+    bl      _ZN3ams4kern3svc22WaitForAddress64From32ENS_3svc7AddressENS2_15ArbitrationTypeEll
+
+    /* Clear x1-x5; x0 (the handler's return value) is left as-is. */
+    mov     x1, xzr
+    mov     x2, xzr
+    mov     x3, xzr
+    mov     x4, xzr
+    mov     x5, xzr
+
+    ldp     x6, x7, [sp], #0x10
+    ldr     x30, [sp], #0x10
+    ret
+
+/* ams::kern::svc::CallWaitForAddress64() */
+.section    .text._ZN3ams4kern3svc20CallWaitForAddress64Ev, "ax", %progbits
+.global     _ZN3ams4kern3svc20CallWaitForAddress64Ev
+.type       _ZN3ams4kern3svc20CallWaitForAddress64Ev, %function
+_ZN3ams4kern3svc20CallWaitForAddress64Ev:
+    /* Save FP (x29) and LR (x30) across the handler call. */
+    stp     x29, x30, [sp, #-16]!
+
+    /* Invoke the 64-bit ABI svc handler; no argument registers are modified before the call. */
+    bl      _ZN3ams4kern3svc16WaitForAddress64ENS_3svc7AddressENS2_15ArbitrationTypeEll
+
+    /* Clear x1-x7; x0 (the handler's return value) is left as-is. */
+    mov     x1, xzr
+    mov     x2, xzr
+    mov     x3, xzr
+    mov     x4, xzr
+    mov     x5, xzr
+    mov     x6, xzr
+    mov     x7, xzr
+
+    ldp     x29, x30, [sp], #0x10
+    ret
diff --git a/libraries/libmesosphere/source/arch/arm64/svc/kern_svc_tables.cpp b/libraries/libmesosphere/source/arch/arm64/svc/kern_svc_tables.cpp
index cb9151823..cfc0cb7c5 100644
--- a/libraries/libmesosphere/source/arch/arm64/svc/kern_svc_tables.cpp
+++ b/libraries/libmesosphere/source/arch/arm64/svc/kern_svc_tables.cpp
@@ -36,6 +36,10 @@ namespace ams::kern::svc {
     /* Declare special prototype for (unsupported) CallCallSecureMonitor64From32. */
     void CallCallSecureMonitor64From32();
 
+    /* Declare special prototypes for WaitForAddress. */
+    void CallWaitForAddress64();
+    void CallWaitForAddress64From32();
+
     namespace {
 
         #ifndef MESOSPHERE_USE_STUBBED_SVC_TABLES
@@ -81,6 +85,8 @@ namespace ams::kern::svc {
 
             table[svc::SvcId_CallSecureMonitor]    = CallCallSecureMonitor64From32;
 
+            table[svc::SvcId_WaitForAddress]       = CallWaitForAddress64From32;
+
             return table;
         }();
 
@@ -97,6 +103,8 @@ namespace ams::kern::svc {
 
             table[svc::SvcId_ReturnFromException]  = CallReturnFromException64;
 
+            table[svc::SvcId_WaitForAddress]       = CallWaitForAddress64;
+
             return table;
         }();
 
diff --git a/libraries/libmesosphere/source/kern_k_address_arbiter.cpp b/libraries/libmesosphere/source/kern_k_address_arbiter.cpp
index c084c9053..66c0c1645 100644
--- a/libraries/libmesosphere/source/kern_k_address_arbiter.cpp
+++ b/libraries/libmesosphere/source/kern_k_address_arbiter.cpp
@@ -23,6 +23,10 @@ namespace ams::kern {
             return UserspaceAccess::CopyMemoryFromUserSize32Bit(out, GetVoidPointer(address));
         }
 
+        ALWAYS_INLINE bool ReadFromUser(s64 *out, KProcessAddress address) {
+            return UserspaceAccess::CopyMemoryFromUserSize64Bit(out, GetVoidPointer(address));
+        }
+
         ALWAYS_INLINE bool DecrementIfLessThan(s32 *out, KProcessAddress address, s32 value) {
             /* NOTE: If scheduler lock is not held here, interrupt disable is required. */
             /* KScopedInterruptDisable di; */
@@ -279,4 +283,51 @@ namespace ams::kern {
         R_RETURN(cur_thread->GetWaitResult());
     }
 
+    Result KAddressArbiter::WaitIfEqual64(uintptr_t addr, s64 value, s64 timeout) {
+        /* Set up the current thread, the timer slot, and the arbiter wait queue used below. */
+        KThread *cur_thread = GetCurrentThreadPointer();
+        KHardwareTimer *timer;
+        ThreadQueueImplForKAddressArbiter wait_queue(std::addressof(m_tree));
+
+        {
+            KScopedSchedulerLockAndSleep slp(std::addressof(timer), cur_thread, timeout);
+
+            /* If termination was requested for this thread, cancel the sleep and fail. */
+            if (cur_thread->IsTerminationRequested()) {
+                slp.CancelSleep();
+                R_THROW(svc::ResultTerminationRequested());
+            }
+
+            /* Read the 64-bit value at addr from userspace; a failed read is InvalidCurrentMemory. */
+            s64 user_value;
+            if (!ReadFromUser(std::addressof(user_value), addr)) {
+                slp.CancelSleep();
+                R_THROW(svc::ResultInvalidCurrentMemory());
+            }
+
+            /* Only wait if the userspace value equals the expected value; otherwise InvalidState. */
+            if (value != user_value) {
+                slp.CancelSleep();
+                R_THROW(svc::ResultInvalidState());
+            }
+
+            /* A zero timeout never blocks: the values matched, so report TimedOut right away. */
+            if (timeout == 0) {
+                slp.CancelSleep();
+                R_THROW(svc::ResultTimedOut());
+            }
+
+            /* Associate this thread with the arbiter tree and address, then insert it into the tree. */
+            cur_thread->SetAddressArbiter(std::addressof(m_tree), addr);
+            m_tree.insert(*cur_thread);
+
+            /* Hand the timer to the wait queue and begin waiting on it. */
+            wait_queue.SetHardwareTimer(timer);
+            cur_thread->BeginWait(std::addressof(wait_queue));
+        }
+
+        /* Return the wait result recorded on the current thread. */
+        R_RETURN(cur_thread->GetWaitResult());
+    }
+
 }
diff --git a/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp b/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp
index 68eb61e84..2b495ad45 100644
--- a/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp
+++ b/libraries/libmesosphere/source/svc/kern_svc_address_arbiter.cpp
@@ -41,17 +41,22 @@ namespace ams::kern::svc {
                 case ams::svc::ArbitrationType_WaitIfLessThan:
                 case ams::svc::ArbitrationType_DecrementAndWaitIfLessThan:
                 case ams::svc::ArbitrationType_WaitIfEqual:
+                case ams::svc::ArbitrationType_WaitIfEqual64:
                     return true;
                 default:
                     return false;
             }
         }
 
-        Result WaitForAddress(uintptr_t address, ams::svc::ArbitrationType arb_type, int32_t value, int64_t timeout_ns) {
+        Result WaitForAddress(uintptr_t address, ams::svc::ArbitrationType arb_type, int64_t value, int64_t timeout_ns) {
             /* Validate input. */
-            R_UNLESS(AMS_LIKELY(!IsKernelAddress(address)),     svc::ResultInvalidCurrentMemory());
-            R_UNLESS(util::IsAligned(address, sizeof(int32_t)), svc::ResultInvalidAddress());
-            R_UNLESS(IsValidArbitrationType(arb_type),          svc::ResultInvalidEnumValue());
+            R_UNLESS(AMS_LIKELY(!IsKernelAddress(address)),         svc::ResultInvalidCurrentMemory());
+            if (arb_type == ams::svc::ArbitrationType_WaitIfEqual64) {
+                R_UNLESS(util::IsAligned(address, sizeof(int64_t)), svc::ResultInvalidAddress());
+            } else {
+                R_UNLESS(util::IsAligned(address, sizeof(int32_t)), svc::ResultInvalidAddress());
+            }
+            R_UNLESS(IsValidArbitrationType(arb_type),              svc::ResultInvalidEnumValue());
 
             /* Convert timeout from nanoseconds to ticks. */
             s64 timeout;
@@ -85,7 +90,7 @@ namespace ams::kern::svc {
 
     /* =============================    64 ABI    ============================= */
 
-    Result WaitForAddress64(ams::svc::Address address, ams::svc::ArbitrationType arb_type, int32_t value, int64_t timeout_ns) {
+    Result WaitForAddress64(ams::svc::Address address, ams::svc::ArbitrationType arb_type, int64_t value, int64_t timeout_ns) {
         R_RETURN(WaitForAddress(address, arb_type, value, timeout_ns));
     }
 
@@ -95,7 +100,7 @@ namespace ams::kern::svc {
 
     /* ============================= 64From32 ABI ============================= */
 
-    Result WaitForAddress64From32(ams::svc::Address address, ams::svc::ArbitrationType arb_type, int32_t value, int64_t timeout_ns) {
+    Result WaitForAddress64From32(ams::svc::Address address, ams::svc::ArbitrationType arb_type, int64_t value, int64_t timeout_ns) {
         R_RETURN(WaitForAddress(address, arb_type, value, timeout_ns));
     }
 
diff --git a/libraries/libstratosphere/include/stratosphere/svc/svc_stratosphere_shims.hpp b/libraries/libstratosphere/include/stratosphere/svc/svc_stratosphere_shims.hpp
index b724ce5f1..d68a56f33 100644
--- a/libraries/libstratosphere/include/stratosphere/svc/svc_stratosphere_shims.hpp
+++ b/libraries/libstratosphere/include/stratosphere/svc/svc_stratosphere_shims.hpp
@@ -231,7 +231,7 @@
                     R_RETURN(::svcGetThreadContext3(reinterpret_cast<::ThreadContext *>(out_context.GetPointerUnsafe()), thread_handle));
                 }
 
-                ALWAYS_INLINE Result WaitForAddress(::ams::svc::Address address, ::ams::svc::ArbitrationType arb_type, int32_t value, int64_t timeout_ns) {
+                ALWAYS_INLINE Result WaitForAddress(::ams::svc::Address address, ::ams::svc::ArbitrationType arb_type, int64_t value, int64_t timeout_ns) {
                     R_RETURN(::svcWaitForAddress(reinterpret_cast<void *>(static_cast<uintptr_t>(address)), arb_type, value, timeout_ns));
                 }
 
diff --git a/libraries/libvapours/include/vapours/svc/svc_definition_macro.hpp b/libraries/libvapours/include/vapours/svc/svc_definition_macro.hpp
index 0f1c57783..7e4121aa3 100644
--- a/libraries/libvapours/include/vapours/svc/svc_definition_macro.hpp
+++ b/libraries/libvapours/include/vapours/svc/svc_definition_macro.hpp
@@ -67,7 +67,7 @@
     HANDLER(0x31, Result,  GetResourceLimitCurrentValue,   OUTPUT(int64_t, out_current_value), INPUT(::ams::svc::Handle, resource_limit_handle), INPUT(::ams::svc::LimitableResource, which))                                                                                                                          \
     HANDLER(0x32, Result,  SetThreadActivity,              INPUT(::ams::svc::Handle, thread_handle), INPUT(::ams::svc::ThreadActivity, thread_activity))                                                                                                                                                               \
     HANDLER(0x33, Result,  GetThreadContext3,              OUTPTR(::ams::svc::ThreadContext, out_context), INPUT(::ams::svc::Handle, thread_handle))                                                                                                                                                                   \
-    HANDLER(0x34, Result,  WaitForAddress,                 INPUT(::ams::svc::Address, address), INPUT(::ams::svc::ArbitrationType, arb_type), INPUT(int32_t, value), INPUT(int64_t, timeout_ns))                                                                                                                       \
+    HANDLER(0x34, Result,  WaitForAddress,                 INPUT(::ams::svc::Address, address), INPUT(::ams::svc::ArbitrationType, arb_type), INPUT(int64_t, value), INPUT(int64_t, timeout_ns))                                                                                                                       \
     HANDLER(0x35, Result,  SignalToAddress,                INPUT(::ams::svc::Address, address), INPUT(::ams::svc::SignalType, signal_type), INPUT(int32_t, value), INPUT(int32_t, count))                                                                                                                              \
     HANDLER(0x36, void,    SynchronizePreemptionState)                                                                                                                                                                                                                                                                 \
     HANDLER(0x37, Result,  GetResourceLimitPeakValue,      OUTPUT(int64_t, out_peak_value), INPUT(::ams::svc::Handle, resource_limit_handle), INPUT(::ams::svc::LimitableResource, which))                                                                                                                             \
diff --git a/libraries/libvapours/include/vapours/svc/svc_types_common.hpp b/libraries/libvapours/include/vapours/svc/svc_types_common.hpp
index 0247f59b9..21ab685d0 100644
--- a/libraries/libvapours/include/vapours/svc/svc_types_common.hpp
+++ b/libraries/libvapours/include/vapours/svc/svc_types_common.hpp
@@ -263,6 +263,7 @@ namespace ams::svc {
         ArbitrationType_WaitIfLessThan             = 0,
         ArbitrationType_DecrementAndWaitIfLessThan = 1,
         ArbitrationType_WaitIfEqual                = 2,
+        ArbitrationType_WaitIfEqual64              = 3,
     };
 
     enum YieldType : s64 {