diff --git a/src/LibHac/Os/ILockable.cs b/src/LibHac/Os/ILockable.cs index 9f86b444..6000607b 100644 --- a/src/LibHac/Os/ILockable.cs +++ b/src/LibHac/Os/ILockable.cs @@ -10,4 +10,11 @@ { bool TryLock(); } + + public interface ISharedMutex : ILockable + { + void LockShared(); + bool TryLockShared(); + void UnlockShared(); + } } diff --git a/src/LibHac/Os/Impl/ReaderWriterLockImpl-os.net.cs b/src/LibHac/Os/Impl/ReaderWriterLockImpl-os.net.cs new file mode 100644 index 00000000..ae06a2df --- /dev/null +++ b/src/LibHac/Os/Impl/ReaderWriterLockImpl-os.net.cs @@ -0,0 +1,212 @@ +using System; +using LibHac.Diag; + +namespace LibHac.Os.Impl +{ + internal static partial class ReaderWriterLockImpl + { + public static void AcquireReadLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock) + { + ref InternalCriticalSection cs = ref GetLockCount(ref rwLock).Cs; + using ScopedLock lk = ScopedLock.Lock(ref cs); + + // If we already own the lock, no additional action is needed + if (rwLock.OwnerThread == Environment.CurrentManagedThreadId) + { + Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1); + } + // Otherwise we might need to block until we can acquire the read lock + else + { + // Wait until there aren't any writers or waiting writers + while (GetWriteLocked(in GetLockCount(ref rwLock)) == 1 || + GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0) + { + IncReadLockWaiterCount(ref GetLockCount(ref rwLock)); + rwLock.CvReadLockWaiter.Wait(ref cs); + DecReadLockWaiterCount(ref GetLockCount(ref rwLock)); + } + + Assert.True(GetWriteLockCount(in rwLock) == 0); + Assert.True(rwLock.OwnerThread == 0); + } + + IncReadLockCount(ref GetLockCount(ref rwLock)); + } + + public static bool TryAcquireReadLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock) + { + using ScopedLock lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs); + + // Acquire the lock if we already have write access + if (rwLock.OwnerThread == Environment.CurrentManagedThreadId) + { + Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1); + + IncReadLockCount(ref GetLockCount(ref rwLock)); + return true; + } + + // Fail to acquire if there are any writers or waiting writers + if (GetWriteLocked(in GetLockCount(ref rwLock)) == 1 || + GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0) + { + return false; + } + + // Otherwise acquire the lock + Assert.True(GetWriteLockCount(in rwLock) == 0); + Assert.True(rwLock.OwnerThread == 0); + + IncReadLockCount(ref GetLockCount(ref rwLock)); + return true; + } + + public static void ReleaseReadLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock) + { + using ScopedLock lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs); + + Assert.True(GetReadLockCount(in GetLockCount(ref rwLock)) > 0); + DecReadLockWaiterCount(ref GetLockCount(ref rwLock)); + + // If we own the lock, check if we need to release ownership and signal any waiting threads + if (rwLock.OwnerThread == Environment.CurrentManagedThreadId) + { + Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1); + + // Return if we still hold any locks + if (GetWriteLockCount(in rwLock) != 0 || GetReadLockCount(in GetLockCount(ref rwLock)) != 0) + { + return; + } + + // We don't hold any more locks. Release our ownership of the lock + rwLock.OwnerThread = 0; + ClearWriteLocked(ref GetLockCount(ref rwLock)); + + // Signal the next writer if any are waiting + if (GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0) + { + rwLock.CvWriteLockWaiter.Signal(); + } + // Otherwise signal any waiting readers + else if (GetReadLockWaiterCount(in GetLockCount(ref rwLock)) != 0) + { + rwLock.CvReadLockWaiter.Broadcast(); + } + } + // Otherwise we need to signal the next writer if we were the only reader + else + { + Assert.True(GetWriteLockCount(in rwLock) == 0); + Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 0); + Assert.True(rwLock.OwnerThread == 0); + + // Signal the next writer if no readers are left + if (GetReadLockCount(in GetLockCount(ref rwLock)) == 0 && + GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0) + { + rwLock.CvWriteLockWaiter.Signal(); + } + } + } + + public static void AcquireWriteLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock) + { + ref InternalCriticalSection cs = ref GetLockCount(ref rwLock).Cs; + using ScopedLock lk = ScopedLock.Lock(ref cs); + + int currentThread = Environment.CurrentManagedThreadId; + + // Increase the write lock count if we already own the lock + if (rwLock.OwnerThread == currentThread) + { + Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1); + + IncWriteLockCount(ref rwLock); + return; + } + + // Otherwise wait until there aren't any readers or writers + while (GetReadLockCount(in GetLockCount(ref rwLock)) != 0 || + GetWriteLocked(in GetLockCount(ref rwLock)) == 1) + { + IncWriteLockWaiterCount(ref GetLockCount(ref rwLock)); + rwLock.CvWriteLockWaiter.Wait(ref cs); + DecWriteLockWaiterCount(ref GetLockCount(ref rwLock)); + } + + Assert.True(GetWriteLockCount(in rwLock) == 0); + Assert.True(rwLock.OwnerThread == 0); + + // Acquire the lock + IncWriteLockCount(ref rwLock); + SetWriteLocked(ref GetLockCount(ref rwLock)); + rwLock.OwnerThread = currentThread; + } + + public static bool TryAcquireWriteLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock) + { + using ScopedLock lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs); + + int currentThread = Environment.CurrentManagedThreadId; + + // Acquire the lock if we already have write access + if (rwLock.OwnerThread == currentThread) + { + Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) == 1); + + IncWriteLockCount(ref rwLock); + return true; + } + + // Fail to acquire if there are any readers or writers + if (GetReadLockCount(in GetLockCount(ref rwLock)) != 0 || + GetWriteLocked(in GetLockCount(ref rwLock)) == 1) + { + return false; + } + + // Otherwise acquire the lock + Assert.True(GetWriteLockCount(in rwLock) == 0); + Assert.True(rwLock.OwnerThread == 0); + + IncWriteLockCount(ref rwLock); + SetWriteLocked(ref GetLockCount(ref rwLock)); + rwLock.OwnerThread = currentThread; + return true; + } + + public static void ReleaseWriteLockImpl(this OsStateImpl os, ref ReaderWriterLockType rwLock) + { + using ScopedLock lk = ScopedLock.Lock(ref GetLockCount(ref rwLock).Cs); + + Assert.True(GetWriteLockCount(in rwLock) > 0); + Assert.True(GetWriteLocked(in GetLockCount(ref rwLock)) != 0); + Assert.True(rwLock.OwnerThread == Environment.CurrentManagedThreadId); + + DecWriteLockCount(ref rwLock); + + // Return if we still hold any locks + if (GetWriteLockCount(in rwLock) != 0 || GetReadLockCount(in GetLockCountRo(in rwLock)) != 0) + { + return; + } + + // We don't hold any more locks. Release our ownership of the lock + rwLock.OwnerThread = 0; + ClearWriteLocked(ref GetLockCount(ref rwLock)); + + // Signal the next writer if any are waiting + if (GetWriteLockWaiterCount(in GetLockCount(ref rwLock)) != 0) + { + rwLock.CvWriteLockWaiter.Signal(); + } + // Otherwise signal any waiting readers + else if (GetReadLockWaiterCount(in GetLockCount(ref rwLock)) != 0) + { + rwLock.CvReadLockWaiter.Broadcast(); + } + } + } +} diff --git a/src/LibHac/Os/Impl/ReaderWriterLockImpl.cs b/src/LibHac/Os/Impl/ReaderWriterLockImpl.cs new file mode 100644 index 00000000..318ff295 --- /dev/null +++ b/src/LibHac/Os/Impl/ReaderWriterLockImpl.cs @@ -0,0 +1,128 @@ +using LibHac.Diag; + +namespace LibHac.Os.Impl +{ + internal static partial class ReaderWriterLockImpl + { + public static void ClearReadLockCount(ref ReaderWriterLockType.LockCountType lc) + { + lc.Counter.ReadLockCount = 0; + } + + public static void ClearWriteLocked(ref ReaderWriterLockType.LockCountType lc) + { + lc.Counter.WriteLocked = 0; + } + + public static void ClearReadLockWaiterCount(ref ReaderWriterLockType.LockCountType lc) + { + lc.Counter.ReadLockWaiterCount = 0; + } + + public static void ClearWriteLockWaiterCount(ref ReaderWriterLockType.LockCountType lc) + { + lc.Counter.WriteLockWaiterCount = 0; + } + + public static void ClearWriteLockCount(ref ReaderWriterLockType rwLock) + { + rwLock.LockCount.WriteLockCount = 0; + } + + public static ref ReaderWriterLockType.LockCountType GetLockCount(ref ReaderWriterLockType rwLock) + { + return ref rwLock.LockCount; + } + + public static ref readonly ReaderWriterLockType.LockCountType GetLockCountRo(in ReaderWriterLockType rwLock) + { + return ref rwLock.LockCount; + } + + public static uint GetReadLockCount(in ReaderWriterLockType.LockCountType lc) + { + return lc.Counter.ReadLockCount; + } + + public static uint GetWriteLocked(in ReaderWriterLockType.LockCountType lc) + { + return lc.Counter.WriteLocked; + } + + public static uint GetReadLockWaiterCount(in ReaderWriterLockType.LockCountType lc) + { + return lc.Counter.ReadLockWaiterCount; + } + + public static uint GetWriteLockWaiterCount(in ReaderWriterLockType.LockCountType lc) + { + return lc.Counter.WriteLockWaiterCount; + } + + public static uint GetWriteLockCount(in ReaderWriterLockType rwLock) + { + return rwLock.LockCount.WriteLockCount; + } + + public static void IncReadLockCount(ref ReaderWriterLockType.LockCountType lc) + { + uint readLockCount = lc.Counter.ReadLockCount; + Assert.True(readLockCount < ReaderWriterLock.ReaderWriterLockCountMax); + lc.Counter.ReadLockCount = readLockCount + 1; + } + + public static void DecReadLockCount(ref ReaderWriterLockType.LockCountType lc) + { + uint readLockCount = lc.Counter.ReadLockCount; + Assert.True(readLockCount > 0); + lc.Counter.ReadLockCount = readLockCount - 1; + } + + public static void IncReadLockWaiterCount(ref ReaderWriterLockType.LockCountType lc) + { + uint readLockWaiterCount = lc.Counter.ReadLockWaiterCount; + Assert.True(readLockWaiterCount < ReaderWriterLock.ReadWriteLockWaiterCountMax); + lc.Counter.ReadLockWaiterCount = readLockWaiterCount + 1; + } + + public static void DecReadLockWaiterCount(ref ReaderWriterLockType.LockCountType lc) + { + uint readLockWaiterCount = lc.Counter.ReadLockWaiterCount; + Assert.True(readLockWaiterCount > 0); + lc.Counter.ReadLockWaiterCount = readLockWaiterCount - 1; + } + + public static void IncWriteLockWaiterCount(ref ReaderWriterLockType.LockCountType lc) + { + uint writeLockWaiterCount = lc.Counter.WriteLockWaiterCount; + Assert.True(writeLockWaiterCount < ReaderWriterLock.ReadWriteLockWaiterCountMax); + lc.Counter.WriteLockWaiterCount = writeLockWaiterCount + 1; + } + + public static void DecWriteLockWaiterCount(ref ReaderWriterLockType.LockCountType lc) + { + uint writeLockWaiterCount = lc.Counter.WriteLockWaiterCount; + Assert.True(writeLockWaiterCount > 0); + lc.Counter.WriteLockWaiterCount = writeLockWaiterCount - 1; + } + + public static void IncWriteLockCount(ref ReaderWriterLockType rwLock) + { + uint writeLockCount = rwLock.LockCount.WriteLockCount; + Assert.True(writeLockCount < ReaderWriterLock.ReaderWriterLockCountMax); + rwLock.LockCount.WriteLockCount = writeLockCount + 1; + } + + public static void DecWriteLockCount(ref ReaderWriterLockType rwLock) + { + uint writeLockCount = rwLock.LockCount.WriteLockCount; + Assert.True(writeLockCount > 0); + rwLock.LockCount.WriteLockCount = writeLockCount - 1; + } + + public static void SetWriteLocked(ref ReaderWriterLockType.LockCountType lc) + { + lc.Counter.WriteLocked = 1; + } + } +} diff --git a/src/LibHac/Os/OsState.cs b/src/LibHac/Os/OsState.cs index 3b7466c9..d96fac4a 100644 --- a/src/LibHac/Os/OsState.cs +++ b/src/LibHac/Os/OsState.cs @@ -5,7 +5,8 @@ namespace LibHac.Os { public class OsState : IDisposable { - private HorizonClient Hos { get; } + public OsStateImpl Impl => new OsStateImpl(this); + internal HorizonClient Hos { get; } internal OsResourceManager ResourceManager { get; } // Todo: Use configuration object if/when more options are added @@ -25,4 +26,13 @@ namespace LibHac.Os ResourceManager.Dispose(); } } + + // Functions in the nn::os::detail namespace use this struct. + public readonly struct OsStateImpl + { + internal readonly OsState Os; + internal HorizonClient Hos => Os.Hos; + + internal OsStateImpl(OsState parent) => Os = parent; + } } diff --git a/src/LibHac/Os/ReaderWriterLock.cs b/src/LibHac/Os/ReaderWriterLock.cs new file mode 100644 index 00000000..c77a4067 --- /dev/null +++ b/src/LibHac/Os/ReaderWriterLock.cs @@ -0,0 +1,195 @@ +using System; +using LibHac.Diag; +using LibHac.Os.Impl; + +namespace LibHac.Os +{ + public static class ReaderWriterLockApi + { + public static void InitializeReaderWriterLock(this OsState os, ref ReaderWriterLockType rwLock) + { + // Create objects. + ReaderWriterLockImpl.GetLockCount(ref rwLock).Cs.Initialize(); + rwLock.CvReadLockWaiter.Initialize(); + rwLock.CvWriteLockWaiter.Initialize(); + + // Set member variables. + ReaderWriterLockImpl.ClearReadLockCount(ref ReaderWriterLockImpl.GetLockCount(ref rwLock)); + ReaderWriterLockImpl.ClearWriteLocked(ref ReaderWriterLockImpl.GetLockCount(ref rwLock)); + ReaderWriterLockImpl.ClearReadLockWaiterCount(ref ReaderWriterLockImpl.GetLockCount(ref rwLock)); + ReaderWriterLockImpl.ClearWriteLockWaiterCount(ref ReaderWriterLockImpl.GetLockCount(ref rwLock)); + ReaderWriterLockImpl.ClearWriteLockCount(ref rwLock); + rwLock.OwnerThread = 0; + + // Mark initialized. + rwLock.LockState = ReaderWriterLockType.State.Initialized; + } + + public static void FinalizeReaderWriterLock(this OsState os, ref ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + + // Don't allow finalizing a locked lock. + Assert.True(ReaderWriterLockImpl.GetReadLockCount(in ReaderWriterLockImpl.GetLockCount(ref rwLock)) == 0); + Assert.True(ReaderWriterLockImpl.GetWriteLocked(in ReaderWriterLockImpl.GetLockCount(ref rwLock)) == 0); + + // Mark not initialized. + rwLock.LockState = ReaderWriterLockType.State.NotInitialized; + + // Destroy objects. + ReaderWriterLockImpl.GetLockCount(ref rwLock).Cs.FinalizeObject(); + } + + public static void AcquireReadLock(this OsState os, ref ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + os.Impl.AcquireReadLockImpl(ref rwLock); + } + + public static bool TryAcquireReadLock(this OsState os, ref ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + return os.Impl.TryAcquireReadLockImpl(ref rwLock); + } + + public static void ReleaseReadLock(this OsState os, ref ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + os.Impl.ReleaseReadLockImpl(ref rwLock); + } + + public static void AcquireWriteLock(this OsState os, ref ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + os.Impl.AcquireWriteLockImpl(ref rwLock); + } + + public static bool TryAcquireWriteLock(this OsState os, ref ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + return os.Impl.TryAcquireWriteLockImpl(ref rwLock); + } + + public static void ReleaseWriteLock(this OsState os, ref ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + os.Impl.ReleaseWriteLockImpl(ref rwLock); + } + + public static bool IsReadLockHeld(this OsState os, in ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + return ReaderWriterLockImpl.GetReadLockCount(in ReaderWriterLockImpl.GetLockCountRo(in rwLock)) != 0; + + } + + // Todo: Use Horizon thread APIs + public static bool IsWriteLockHeldByCurrentThread(this OsState os, in ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + return rwLock.OwnerThread == Environment.CurrentManagedThreadId && + ReaderWriterLockImpl.GetWriteLockCount(in rwLock) != 0; + } + + public static bool IsReaderWriterLockOwnerThread(this OsState os, in ReaderWriterLockType rwLock) + { + Assert.True(rwLock.LockState == ReaderWriterLockType.State.Initialized); + return rwLock.OwnerThread == Environment.CurrentManagedThreadId; + } + } + + public class ReaderWriterLock : ISharedMutex + { + public const int ReaderWriterLockCountMax = (1 << 15) - 1; + public const int ReadWriteLockWaiterCountMax = (1 << 8) - 1; + + private readonly OsState _os; + private ReaderWriterLockType _rwLock; + + public ReaderWriterLock(OsState os) + { + _os = os; + _os.InitializeReaderWriterLock(ref _rwLock); + } + + public void AcquireReadLock() + { + _os.AcquireReadLock(ref _rwLock); + } + + public bool TryAcquireReadLock() + { + return _os.TryAcquireReadLock(ref _rwLock); + } + + public void ReleaseReadLock() + { + _os.ReleaseReadLock(ref _rwLock); + } + + public void AcquireWriteLock() + { + _os.AcquireWriteLock(ref _rwLock); + } + + public bool TryAcquireWriteLock() + { + return _os.TryAcquireWriteLock(ref _rwLock); + } + + public void ReleaseWriteLock() + { + _os.ReleaseWriteLock(ref _rwLock); + } + + public bool IsReadLockHeld() + { + return _os.IsReadLockHeld(in _rwLock); + } + + public bool IsWriteLockHeldByCurrentThread() + { + return _os.IsWriteLockHeldByCurrentThread(in _rwLock); + } + + public bool IsLockOwner() + { + return _os.IsReaderWriterLockOwnerThread(in _rwLock); + } + + public void LockShared() + { + AcquireReadLock(); + } + + public bool TryLockShared() + { + return TryAcquireReadLock(); + } + + public void UnlockShared() + { + ReleaseReadLock(); + } + + public void Lock() + { + AcquireWriteLock(); + } + + public bool TryLock() + { + return TryAcquireWriteLock(); + } + + public void Unlock() + { + ReleaseWriteLock(); + } + + public ref ReaderWriterLockType GetBase() + { + return ref _rwLock; + } + } +} diff --git a/src/LibHac/Os/ReaderWriterLockTypes.cs b/src/LibHac/Os/ReaderWriterLockTypes.cs new file mode 100644 index 00000000..f3eba2e6 --- /dev/null +++ b/src/LibHac/Os/ReaderWriterLockTypes.cs @@ -0,0 +1,64 @@ +using System.Runtime.CompilerServices; +using LibHac.Os.Impl; + +namespace LibHac.Os +{ + public struct ReaderWriterLockType + { + internal LockCountType LockCount; + internal State LockState; + internal int OwnerThread; + internal InternalConditionVariable CvReadLockWaiter; + internal InternalConditionVariable CvWriteLockWaiter; + + public enum State + { + NotInitialized, + Initialized + } + + public struct LockCountType + { + public InternalCriticalSection Cs; + public ReaderWriterLockCounter Counter; + public uint WriteLockCount; + } + + public struct ReaderWriterLockCounter + { + private uint _counter; + + public uint ReadLockCount + { + readonly get => GetBitsValue(_counter, 0, 15); + set => _counter = SetBitsValue(value, 0, 15); + } + + public uint WriteLocked + { + readonly get => GetBitsValue(_counter, 15, 1); + set => _counter = SetBitsValue(value, 15, 1); + } + + public uint ReadLockWaiterCount + { + readonly get => GetBitsValue(_counter, 16, 8); + set => _counter = SetBitsValue(value, 16, 8); + } + + public uint WriteLockWaiterCount + { + readonly get => GetBitsValue(_counter, 24, 8); + set => _counter = SetBitsValue(value, 24, 8); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint GetBitsValue(uint value, int bitsOffset, int bitsCount) => + (value >> bitsOffset) & ~(~default(uint) << bitsCount); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static uint SetBitsValue(uint value, int bitsOffset, int bitsCount) => + (value & ~(~default(uint) << bitsCount)) << bitsOffset; + } + } +}