From 02a3b5ae2bc528f2b153a5036869a1eea7f8d550 Mon Sep 17 00:00:00 2001
From: B3n30 <benediktthomas@gmail.com>
Date: Wed, 8 Aug 2018 23:30:48 +0200
Subject: [PATCH] Service::SM: Wait till client is registered

---
 src/core/hle/service/sm/srv.cpp | 56 ++++++++++++++++++++++++++++++---
 1 file changed, 51 insertions(+), 5 deletions(-)

diff --git a/src/core/hle/service/sm/srv.cpp b/src/core/hle/service/sm/srv.cpp
index f459d3784..e311e79f2 100644
--- a/src/core/hle/service/sm/srv.cpp
+++ b/src/core/hle/service/sm/srv.cpp
@@ -3,7 +3,7 @@
 // Refer to the license.txt file included.
 
 #include <tuple>
-
+#include <unordered_map>
 #include "common/common_types.h"
 #include "common/logging/log.h"
 #include "core/hle/ipc.h"
@@ -11,10 +11,12 @@
 #include "core/hle/kernel/client_port.h"
 #include "core/hle/kernel/client_session.h"
 #include "core/hle/kernel/errors.h"
+#include "core/hle/kernel/event.h"
 #include "core/hle/kernel/hle_ipc.h"
 #include "core/hle/kernel/semaphore.h"
 #include "core/hle/kernel/server_port.h"
 #include "core/hle/kernel/server_session.h"
+#include "core/hle/lock.h"
 #include "core/hle/service/sm/sm.h"
 #include "core/hle/service/sm/srv.h"
 
@@ -23,6 +25,9 @@ namespace SM {
 
 constexpr int MAX_PENDING_NOTIFICATIONS = 16;
 
+static std::unordered_map<std::string, Kernel::SharedPtr<Kernel::Event>>
+    get_service_handle_delayed_map;
+
 /**
  * SRV::RegisterClient service function
  *  Inputs:
@@ -99,12 +104,47 @@ void SRV::GetServiceHandle(Kernel::HLERequestContext& ctx) {
 
     // TODO(yuriks): Permission checks go here
 
+    auto get_handle = [name, this, wait_until_available](Kernel::SharedPtr<Kernel::Thread> thread,
+                                                         Kernel::HLERequestContext& ctx,
+                                                         ThreadWakeupReason reason) {
+        LOG_ERROR(Service_SRV, "called service={} wakeup", name);
+        auto client_port = service_manager->GetServicePort(name);
+
+        auto session = client_port.Unwrap()->Connect();
+        if (session.Succeeded()) {
+            LOG_DEBUG(Service_SRV, "called service={} -> session={}", name,
+                      (*session)->GetObjectId());
+            IPC::RequestBuilder rb(ctx, 0x5, 1, 2);
+            rb.Push(session.Code());
+            rb.PushMoveObjects(std::move(session).Unwrap());
+        } else if (session.Code() == Kernel::ERR_MAX_CONNECTIONS_REACHED && wait_until_available) {
+            LOG_WARNING(Service_SRV, "called service={} -> ERR_MAX_CONNECTIONS_REACHED", name);
+            // TODO(Subv): Put the caller guest thread to sleep until this port becomes available
+            // again.
+            UNIMPLEMENTED_MSG("Unimplemented wait until port {} is available.", name);
+        } else {
+            LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name, session.Code().raw);
+            IPC::RequestBuilder rb(ctx, 0x5, 1, 0);
+            rb.Push(session.Code());
+        }
+    };
+
     auto client_port = service_manager->GetServicePort(name);
     if (client_port.Failed()) {
-        IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
-        rb.Push(client_port.Code());
-        LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name, client_port.Code().raw);
-        return;
+        if (wait_until_available) {
+            LOG_ERROR(Service_SRV, "called service={} delayed", name);
+            Kernel::SharedPtr<Kernel::Event> get_service_handle_event =
+                ctx.SleepClientThread(Kernel::GetCurrentThread(), "GetServiceHandle",
+                                      std::chrono::nanoseconds(-1), get_handle);
+            get_service_handle_delayed_map[name] = get_service_handle_event;
+            return;
+        } else {
+            IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
+            rb.Push(client_port.Code());
+            LOG_ERROR(Service_SRV, "called service={} -> error 0x{:08X}", name,
+                      client_port.Code().raw);
+            return;
+        }
     }
 
     auto session = client_port.Unwrap()->Connect();
@@ -192,6 +232,12 @@ void SRV::RegisterService(Kernel::HLERequestContext& ctx) {
 
     auto port = service_manager->RegisterService(name, max_sessions);
 
+    if (get_service_handle_delayed_map.find(name) != get_service_handle_delayed_map.end()) {
+        std::lock_guard<std::recursive_mutex> lock(HLE::g_hle_lock);
+        get_service_handle_delayed_map.at(name)->Signal();
+        get_service_handle_delayed_map.erase(name);
+    }
+
     if (port.Failed()) {
         IPC::RequestBuilder rb = rp.MakeBuilder(1, 0);
         rb.Push(port.Code());