Add optimized functions for decrypting a single AES block

This commit is contained in:
Alex Barney 2019-11-24 19:54:29 -06:00
parent abce62dd4f
commit 99522b748e
3 changed files with 156 additions and 16 deletions

View File

@ -1,4 +1,4 @@
#if NETCOREAPP #if HAS_INTRINSICS
using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
@ -18,6 +18,10 @@ namespace LibHac.Crypto.Detail
private Vector128<byte> _roundKeys; private Vector128<byte> _roundKeys;
// An Initialize method is used instead of a constructor because it prevents the runtime
// from zeroing out the structure's memory when creating it.
// When processing a single block, doing this can increase performance by 20-40%
// depending on the context.
public void Initialize(ReadOnlySpan<byte> key, bool isDecrypting) public void Initialize(ReadOnlySpan<byte> key, bool isDecrypting)
{ {
Debug.Assert(key.Length == Aes.KeySize128); Debug.Assert(key.Length == Aes.KeySize128);
@ -183,7 +187,8 @@ namespace LibHac.Crypto.Detail
// When inlining this function, RyuJIT will almost make the // When inlining this function, RyuJIT will almost make the
// generated code the same as if it were manually inlined // generated code the same as if it were manually inlined
[MethodImpl(MethodImplOptions.AggressiveInlining)] [MethodImpl(MethodImplOptions.AggressiveInlining)]
public readonly void EncryptBlocks8(Vector128<byte> in0, public readonly void EncryptBlocks8(
Vector128<byte> in0,
Vector128<byte> in1, Vector128<byte> in1,
Vector128<byte> in2, Vector128<byte> in2,
Vector128<byte> in3, Vector128<byte> in3,
@ -198,8 +203,7 @@ namespace LibHac.Crypto.Detail
out Vector128<byte> out4, out Vector128<byte> out4,
out Vector128<byte> out5, out Vector128<byte> out5,
out Vector128<byte> out6, out Vector128<byte> out6,
out Vector128<byte> out7 out Vector128<byte> out7)
)
{ {
ReadOnlySpan<Vector128<byte>> keys = RoundKeys; ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
@ -331,8 +335,7 @@ namespace LibHac.Crypto.Detail
out Vector128<byte> out4, out Vector128<byte> out4,
out Vector128<byte> out5, out Vector128<byte> out5,
out Vector128<byte> out6, out Vector128<byte> out6,
out Vector128<byte> out7 out Vector128<byte> out7)
)
{ {
ReadOnlySpan<Vector128<byte>> keys = RoundKeys; ReadOnlySpan<Vector128<byte>> keys = RoundKeys;
@ -447,6 +450,71 @@ namespace LibHac.Crypto.Detail
out7 = AesNi.DecryptLast(b7, key); out7 = AesNi.DecryptLast(b7, key);
} }
public static Vector128<byte> EncryptBlock(Vector128<byte> input, Vector128<byte> key)
{
Vector128<byte> curKey = key;
Vector128<byte> b = Sse2.Xor(input, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x01));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x02));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x04));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x08));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x10));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x20));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x40));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x80));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x1b));
b = AesNi.Encrypt(b, curKey);
curKey = KeyExpansion(curKey, AesNi.KeygenAssist(curKey, 0x36));
return AesNi.EncryptLast(b, curKey);
}
public static Vector128<byte> DecryptBlock(Vector128<byte> input, Vector128<byte> key)
{
Vector128<byte> key0 = key;
Vector128<byte> key1 = KeyExpansion(key0, AesNi.KeygenAssist(key0, 0x01));
Vector128<byte> key2 = KeyExpansion(key1, AesNi.KeygenAssist(key1, 0x02));
Vector128<byte> key3 = KeyExpansion(key2, AesNi.KeygenAssist(key2, 0x04));
Vector128<byte> key4 = KeyExpansion(key3, AesNi.KeygenAssist(key3, 0x08));
Vector128<byte> key5 = KeyExpansion(key4, AesNi.KeygenAssist(key4, 0x10));
Vector128<byte> key6 = KeyExpansion(key5, AesNi.KeygenAssist(key5, 0x20));
Vector128<byte> key7 = KeyExpansion(key6, AesNi.KeygenAssist(key6, 0x40));
Vector128<byte> key8 = KeyExpansion(key7, AesNi.KeygenAssist(key7, 0x80));
Vector128<byte> key9 = KeyExpansion(key8, AesNi.KeygenAssist(key8, 0x1b));
Vector128<byte> key10 = KeyExpansion(key9, AesNi.KeygenAssist(key9, 0x36));
Vector128<byte> b = input;
b = Sse2.Xor(b, key10);
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key9));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key8));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key7));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key6));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key5));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key4));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key3));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key2));
b = AesNi.Decrypt(b, AesNi.InverseMixColumns(key1));
return AesNi.DecryptLast(b, key0);
}
private void KeyExpansion(ReadOnlySpan<byte> key, bool isDecrypting) private void KeyExpansion(ReadOnlySpan<byte> key, bool isDecrypting)
{ {
Span<Vector128<byte>> roundKeys = MemoryMarshal.CreateSpan(ref _roundKeys, RoundKeyCount); Span<Vector128<byte>> roundKeys = MemoryMarshal.CreateSpan(ref _roundKeys, RoundKeyCount);
@ -486,10 +554,15 @@ namespace LibHac.Crypto.Detail
if (isDecrypting) if (isDecrypting)
{ {
for (int i = 1; i < 10; i++) roundKeys[1] = AesNi.InverseMixColumns(roundKeys[1]);
{ roundKeys[2] = AesNi.InverseMixColumns(roundKeys[2]);
roundKeys[i] = AesNi.InverseMixColumns(roundKeys[i]); roundKeys[3] = AesNi.InverseMixColumns(roundKeys[3]);
} roundKeys[4] = AesNi.InverseMixColumns(roundKeys[4]);
roundKeys[5] = AesNi.InverseMixColumns(roundKeys[5]);
roundKeys[6] = AesNi.InverseMixColumns(roundKeys[6]);
roundKeys[7] = AesNi.InverseMixColumns(roundKeys[7]);
roundKeys[8] = AesNi.InverseMixColumns(roundKeys[8]);
roundKeys[9] = AesNi.InverseMixColumns(roundKeys[9]);
} }
} }

View File

@ -6,18 +6,19 @@ namespace hactoolnet
{ {
internal class MultiBenchmark internal class MultiBenchmark
{ {
public int RunsNeeded { get; set; } = 500; public int DefaultRunsNeeded { get; set; } = 500;
private List<BenchmarkItem> Benchmarks { get; } = new List<BenchmarkItem>(); private List<BenchmarkItem> Benchmarks { get; } = new List<BenchmarkItem>();
public void Register(string name, Action setupAction, Action runAction, Func<double, string> resultPrinter) public void Register(string name, Action setupAction, Action runAction, Func<double, string> resultPrinter, int runsNeeded = -1)
{ {
var benchmark = new BenchmarkItem var benchmark = new BenchmarkItem
{ {
Name = name, Name = name,
Setup = setupAction, Setup = setupAction,
Run = runAction, Run = runAction,
PrintResult = resultPrinter PrintResult = resultPrinter,
RunsNeeded = runsNeeded == -1 ? DefaultRunsNeeded : runsNeeded
}; };
Benchmarks.Add(benchmark); Benchmarks.Add(benchmark);
@ -40,7 +41,7 @@ namespace hactoolnet
int runsSinceLastBest = 0; int runsSinceLastBest = 0;
while (runsSinceLastBest < RunsNeeded) while (runsSinceLastBest < item.RunsNeeded)
{ {
runsSinceLastBest++; runsSinceLastBest++;
item.Setup(); item.Setup();
@ -64,6 +65,7 @@ namespace hactoolnet
private class BenchmarkItem private class BenchmarkItem
{ {
public string Name { get; set; } public string Name { get; set; }
public int RunsNeeded { get; set; }
public double Time { get; set; } public double Time { get; set; }
public string Result { get; set; } public string Result { get; set; }

View File

@ -6,6 +6,12 @@ using LibHac.Crypto;
using LibHac.Fs; using LibHac.Fs;
using LibHac.FsSystem; using LibHac.FsSystem;
#if NETCOREAPP
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
#endif
namespace hactoolnet namespace hactoolnet
{ {
internal static class ProcessBench internal static class ProcessBench
@ -16,6 +22,8 @@ namespace hactoolnet
private const int BlockSizeSeparate = 0x10; private const int BlockSizeSeparate = 0x10;
private const int BatchCipherBenchSize = 1024 * 1024; private const int BatchCipherBenchSize = 1024 * 1024;
// ReSharper disable once UnusedMember.Local
private const int SingleBlockCipherBenchSize = 1024 * 128;
private static void CopyBenchmark(IStorage src, IStorage dst, int iterations, string label, IProgressReport logger) private static void CopyBenchmark(IStorage src, IStorage dst, int iterations, string label, IProgressReport logger)
{ {
@ -173,7 +181,7 @@ namespace hactoolnet
logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s"); logger.LogMessage($"{label}{averageRate}/s, fastest run: {fastestRate}/s, slowest run: {slowestRate}/s");
} }
private static void RegisterAllCipherBenchmarks(MultiBenchmark bench) private static void RegisterAesSequentialBenchmarks(MultiBenchmark bench)
{ {
var input = new byte[BatchCipherBenchSize]; var input = new byte[BatchCipherBenchSize];
var output = new byte[BatchCipherBenchSize]; var output = new byte[BatchCipherBenchSize];
@ -220,6 +228,62 @@ namespace hactoolnet
} }
} }
// ReSharper disable once UnusedParameter.Local
private static void RegisterAesSingleBlockBenchmarks(MultiBenchmark bench)
{
#if NETCOREAPP
var input = new byte[SingleBlockCipherBenchSize];
var output = new byte[SingleBlockCipherBenchSize];
Func<double, string> resultPrinter = time => Util.GetBytesReadable((long)(SingleBlockCipherBenchSize / time)) + "/s";
bench.Register("AES single-block encrypt", () => { }, EncryptBlocks, resultPrinter);
bench.Register("AES single-block decrypt", () => { }, DecryptBlocks, resultPrinter);
bench.DefaultRunsNeeded = 1000;
void EncryptBlocks()
{
ref byte inBlock = ref MemoryMarshal.GetReference(input.AsSpan());
ref byte outBlock = ref MemoryMarshal.GetReference(output.AsSpan());
Vector128<byte> keyVec = Vector128<byte>.Zero;
ref byte end = ref Unsafe.Add(ref inBlock, input.Length);
while (Unsafe.IsAddressLessThan(ref inBlock, ref end))
{
var inputVec = Unsafe.ReadUnaligned<Vector128<byte>>(ref inBlock);
Vector128<byte> outputVec = LibHac.Crypto.Detail.AesCoreNi.EncryptBlock(inputVec, keyVec);
Unsafe.WriteUnaligned(ref outBlock, outputVec);
inBlock = ref Unsafe.Add(ref inBlock, Aes.BlockSize);
outBlock = ref Unsafe.Add(ref outBlock, Aes.BlockSize);
}
}
void DecryptBlocks()
{
ref byte inBlock = ref MemoryMarshal.GetReference(input.AsSpan());
ref byte outBlock = ref MemoryMarshal.GetReference(output.AsSpan());
Vector128<byte> keyVec = Vector128<byte>.Zero;
ref byte end = ref Unsafe.Add(ref inBlock, input.Length);
while (Unsafe.IsAddressLessThan(ref inBlock, ref end))
{
var inputVec = Unsafe.ReadUnaligned<Vector128<byte>>(ref inBlock);
Vector128<byte> outputVec = LibHac.Crypto.Detail.AesCoreNi.DecryptBlock(inputVec, keyVec);
Unsafe.WriteUnaligned(ref outBlock, outputVec);
inBlock = ref Unsafe.Add(ref inBlock, Aes.BlockSize);
outBlock = ref Unsafe.Add(ref outBlock, Aes.BlockSize);
}
}
#endif
}
private static void RunCipherBenchmark(Func<ICipher> cipherNet, Func<ICipher> cipherLibHac, private static void RunCipherBenchmark(Func<ICipher> cipherNet, Func<ICipher> cipherLibHac,
CipherTaskSeparate function, bool benchBlocked, string label, IProgressReport logger) CipherTaskSeparate function, bool benchBlocked, string label, IProgressReport logger)
{ {
@ -399,7 +463,8 @@ namespace hactoolnet
{ {
var bench = new MultiBenchmark(); var bench = new MultiBenchmark();
RegisterAllCipherBenchmarks(bench); RegisterAesSequentialBenchmarks(bench);
RegisterAesSingleBlockBenchmarks(bench);
bench.Run(); bench.Run();
break; break;