From 191b3d41f6028c46a779aff7cf5880b4cebabc8e Mon Sep 17 00:00:00 2001 From: Alex Barney Date: Tue, 19 Nov 2019 21:05:33 -0500 Subject: [PATCH] Ensure crypto works when the input and output buffers are the same --- src/LibHac/Crypto/Aes.cs | 4 +++- src/LibHac/Crypto/Detail/AesCbcModeNi.cs | 5 +++-- src/LibHac/Crypto/Detail/AesXtsMode.cs | 12 +++++++----- src/LibHac/Crypto/Detail/AesXtsModeNi.cs | 4 ++-- tests/LibHac.Tests/CryptoTests/Common.cs | 10 ++++++---- 5 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/LibHac/Crypto/Aes.cs b/src/LibHac/Crypto/Aes.cs index 8286f255..52df2d8e 100644 --- a/src/LibHac/Crypto/Aes.cs +++ b/src/LibHac/Crypto/Aes.cs @@ -1,7 +1,9 @@ // ReSharper disable AssignmentIsFullyDiscarded using System; -using LibHac.Crypto.Detail; + #if HAS_INTRINSICS +using LibHac.Crypto.Detail; + using AesNi = System.Runtime.Intrinsics.X86.Aes; #endif diff --git a/src/LibHac/Crypto/Detail/AesCbcModeNi.cs b/src/LibHac/Crypto/Detail/AesCbcModeNi.cs index fd997d0f..8160991d 100644 --- a/src/LibHac/Crypto/Detail/AesCbcModeNi.cs +++ b/src/LibHac/Crypto/Detail/AesCbcModeNi.cs @@ -58,10 +58,11 @@ namespace LibHac.Crypto.Detail for (int i = 0; i < blockCount; i++) { - Vector128 decBeforeIv = _aesCore.DecryptBlock(inBlock); + Vector128 currentBlock = inBlock; + Vector128 decBeforeIv = _aesCore.DecryptBlock(currentBlock); outBlock = Sse2.Xor(decBeforeIv, iv); - iv = inBlock; + iv = currentBlock; inBlock = ref Unsafe.Add(ref inBlock, 1); outBlock = ref Unsafe.Add(ref outBlock, 1); diff --git a/src/LibHac/Crypto/Detail/AesXtsMode.cs b/src/LibHac/Crypto/Detail/AesXtsMode.cs index ccfd37db..2fad3d11 100644 --- a/src/LibHac/Crypto/Detail/AesXtsMode.cs +++ b/src/LibHac/Crypto/Detail/AesXtsMode.cs @@ -49,8 +49,8 @@ namespace LibHac.Crypto.Detail if (leftover != 0) { - ref Buffer16 inBlock = - ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); + Buffer16 inBlock = + Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); ref Buffer16 outBlock = ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(output)), blockCount); @@ -107,8 +107,8 @@ namespace LibHac.Crypto.Detail Buffer16 finalTweak = tweak; Gf128Mul(ref finalTweak); - ref Buffer16 inBlock = - ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); + Buffer16 inBlock = + Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), blockCount); ref Buffer16 outBlock = ref Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(output)), blockCount); @@ -120,7 +120,9 @@ namespace LibHac.Crypto.Detail XorBuffer(ref outBlock, ref tmp, ref finalTweak); ref Buffer16 finalOutBlock = ref Unsafe.Add(ref outBlock, 1); - ref Buffer16 finalInBlock = ref Unsafe.Add(ref inBlock, 1); + + Buffer16 finalInBlock = Unsafe.Add(ref Unsafe.As(ref MemoryMarshal.GetReference(input)), + blockCount + 1); for (int i = 0; i < leftover; i++) { diff --git a/src/LibHac/Crypto/Detail/AesXtsModeNi.cs b/src/LibHac/Crypto/Detail/AesXtsModeNi.cs index 914d9d56..fcfca434 100644 --- a/src/LibHac/Crypto/Detail/AesXtsModeNi.cs +++ b/src/LibHac/Crypto/Detail/AesXtsModeNi.cs @@ -108,7 +108,7 @@ namespace LibHac.Crypto.Detail var x = new Buffer16(); ref Buffer16 outBuf = ref Unsafe.As, Buffer16>(ref output); - ref Buffer16 nextInBuf = ref Unsafe.As, Buffer16>(ref Unsafe.Add(ref input, 1)); + Buffer16 nextInBuf = Unsafe.As, Buffer16>(ref Unsafe.Add(ref input, 1)); ref Buffer16 nextOutBuf = ref Unsafe.As, Buffer16>(ref Unsafe.Add(ref output, 1)); for (int i = 0; i < finalBlockLength; i++) @@ -134,7 +134,7 @@ namespace LibHac.Crypto.Detail var x = new Buffer16(); ref Buffer16 outBuf = ref Unsafe.As, Buffer16>(ref output); - ref Buffer16 inBuf = ref Unsafe.As, Buffer16>(ref input); + Buffer16 inBuf = Unsafe.As, Buffer16>(ref input); ref Buffer16 prevOutBuf = ref Unsafe.As, Buffer16>(ref prevOutBlock); for (int i = 0; i < finalBlockLength; i++) diff --git a/tests/LibHac.Tests/CryptoTests/Common.cs b/tests/LibHac.Tests/CryptoTests/Common.cs index 3c7f915c..318ee8e8 100644 --- a/tests/LibHac.Tests/CryptoTests/Common.cs +++ b/tests/LibHac.Tests/CryptoTests/Common.cs @@ -1,4 +1,5 @@ -using LibHac.Crypto; +using System; +using LibHac.Crypto; using Xunit; namespace LibHac.Tests.CryptoTests @@ -7,11 +8,12 @@ namespace LibHac.Tests.CryptoTests { internal static void CipherTestCore(byte[] inputData, byte[] expected, ICipher cipher) { - var outputData = new byte[expected.Length]; + var transformBuffer = new byte[inputData.Length]; + Buffer.BlockCopy(inputData, 0, transformBuffer, 0, inputData.Length); - cipher.Transform(inputData, outputData); + cipher.Transform(transformBuffer, transformBuffer); - Assert.Equal(expected, outputData); + Assert.Equal(expected, transformBuffer); } } }