From d0b45f6150a61d6705bc2c53c372de57f1fb110f Mon Sep 17 00:00:00 2001
From: MerryMage <MerryMage@users.noreply.github.com>
Date: Sun, 17 May 2020 16:59:56 +0100
Subject: [PATCH] A32: Implement ARMv8 VST{1-4} (multiple)

---
 src/frontend/A32/decoder/asimd.inc            |   2 +-
 .../impl/asimd_load_store_structures.cpp      | 164 ++++++++++--------
 .../A32/translate/impl/translate_arm.h        |   1 +
 src/frontend/ir/ir_emitter.cpp                |  30 +++-
 src/frontend/ir/ir_emitter.h                  |   3 +-
 src/frontend/ir/opcodes.inc                   |   2 +-
 6 files changed, 122 insertions(+), 80 deletions(-)

diff --git a/src/frontend/A32/decoder/asimd.inc b/src/frontend/A32/decoder/asimd.inc
index 9a8dc46f..41031b7f 100644
--- a/src/frontend/A32/decoder/asimd.inc
+++ b/src/frontend/A32/decoder/asimd.inc
@@ -121,7 +121,7 @@ INST(asimd_VBIF,            "VBIF",                     "111100110D11nnnndddd000
 //INST(asimd_VMOV_imm,        "VMOV (immediate)",         "1111001a1-000bcd----11100-11efgh") // ASIMD
 
 // Advanced SIMD load/store structures
-//INST(v8_VST_multiple,       "VST{1-4} (multiple)",      "111101000D00nnnnddddxxxxzzaammmm") // v8
+INST(v8_VST_multiple,       "VST{1-4} (multiple)",      "111101000D00nnnnddddxxxxzzaammmm") // v8
 INST(v8_VLD_multiple,       "VLD{1-4} (multiple)",      "111101000D10nnnnddddxxxxzzaammmm") // v8
 INST(arm_UDF,               "UNALLOCATED",              "111101000--0--------1011--------") // v8
 INST(arm_UDF,               "UNALLOCATED",              "111101000--0--------11----------") // v8
diff --git a/src/frontend/A32/translate/impl/asimd_load_store_structures.cpp b/src/frontend/A32/translate/impl/asimd_load_store_structures.cpp
index 5019ef74..9f13d0b2 100644
--- a/src/frontend/A32/translate/impl/asimd_load_store_structures.cpp
+++ b/src/frontend/A32/translate/impl/asimd_load_store_structures.cpp
@@ -5,107 +5,129 @@
 
 #include "frontend/A32/translate/impl/translate_arm.h"
 
+#include <optional>
+#include <tuple>
 #include "common/bit_util.h"
 
 namespace Dynarmic::A32 {
 
-static ExtReg ToExtRegD(size_t base, bool bit) {
+namespace {
+ExtReg ToExtReg(size_t base, bool bit) { // Returns the register ExtReg::D0 + base, offset by a further 16 when bit is set.
     return ExtReg::D0 + (base + (bit ? 16 : 0));
 }
 
-bool ArmTranslatorVisitor::v8_VLD_multiple(bool D, Reg n, size_t Vd, Imm<4> type, size_t size, size_t align, Reg m) {
-    size_t nelem, regs, inc;
+std::optional<std::tuple<size_t, size_t, size_t>> DecodeType(Imm<4> type, size_t size, size_t align) { // Maps the type field to {nelem, regs, inc}; std::nullopt marks a type/size/align combination that callers report as an undefined instruction.
     switch (type.ZeroExtend()) {
-    case 0b0111: // VLD1 A1
-        nelem = 1;
-        regs = 1;
-        inc = 0;
+    case 0b0111: // VST1 A1 / VLD1 A1
         if (Common::Bit<1>(align)) {
-            return UndefinedInstruction();
+            return std::nullopt; // rejected here; the caller decides how to report it
         }
-        break;
-    case 0b1010: // VLD1 A2
-        nelem = 1;
-        regs = 2;
-        inc = 0;
+        return std::tuple<size_t, size_t, size_t>{1, 1, 0}; // tuple order is {nelem, regs, inc}, matching the callers' structured bindings
+    case 0b1010: // VST1 A2 / VLD1 A2
         if (align == 0b11) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b0110: // VLD1 A3
-        nelem = 1;
-        regs = 3;
-        inc = 0;
+        return std::tuple<size_t, size_t, size_t>{1, 2, 0};
+    case 0b0110: // VST1 A3 / VLD1 A3
         if (Common::Bit<1>(align)) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b0010: // VLD1 A4
-        nelem = 1;
-        regs = 4;
-        inc = 0;
-        break;
-    case 0b1000: // VLD2 A1
-        nelem = 2;
-        regs = 1;
-        inc = 1;
+        return std::tuple<size_t, size_t, size_t>{1, 3, 0};
+    case 0b0010: // VST1 A4 / VLD1 A4
+        return std::tuple<size_t, size_t, size_t>{1, 4, 0}; // the only case with no size/align restriction
+    case 0b1000: // VST2 A1 / VLD2 A1
         if (size == 0b11 || align == 0b11) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b1001: // VLD2 A1
-        nelem = 2;
-        regs = 1;
-        inc = 2;
+        return std::tuple<size_t, size_t, size_t>{2, 1, 1};
+    case 0b1001: // VST2 A1 / VLD2 A1
         if (size == 0b11 || align == 0b11) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b0011: // VLD2 A2
-        nelem = 2;
-        regs = 2;
-        inc = 2;
+        return std::tuple<size_t, size_t, size_t>{2, 1, 2};
+    case 0b0011: // VST2 A2 / VLD2 A2
         if (size == 0b11) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b0100: // VLD3
-        nelem = 3;
-        regs = 1;
-        inc = 1;
+        return std::tuple<size_t, size_t, size_t>{2, 2, 2};
+    case 0b0100: // VST3 / VLD3
         if (size == 0b11 || Common::Bit<1>(align)) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b0101: // VLD3
-        nelem = 3;
-        regs = 1;
-        inc = 2;
+        return std::tuple<size_t, size_t, size_t>{3, 1, 1};
+    case 0b0101: // VST3 / VLD3
         if (size == 0b11 || Common::Bit<1>(align)) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b0000: // VLD4
-        nelem = 4;
-        regs = 1;
-        inc = 1;
+        return std::tuple<size_t, size_t, size_t>{3, 1, 2};
+    case 0b0000: // VST4 / VLD4
         if (size == 0b11) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    case 0b0001: // VLD4
-        nelem = 4;
-        regs = 1;
-        inc = 2;
+        return std::tuple<size_t, size_t, size_t>{4, 1, 1};
+    case 0b0001: // VST4 / VLD4
         if (size == 0b11) {
-            return UndefinedInstruction();
+            return std::nullopt;
         }
-        break;
-    default:
-        ASSERT_FALSE("Decode error");
+        return std::tuple<size_t, size_t, size_t>{4, 1, 2};
+    }
+    ASSERT_FALSE("Decode error"); // reached only when type matched none of the cases above
+}
+} // anonymous namespace
+
+bool ArmTranslatorVisitor::v8_VST_multiple(bool D, Reg n, size_t Vd, Imm<4> type, size_t size, size_t align, Reg m) { // Translates VST{1-4} (multiple): writes register elements to memory starting at the address in n, then optionally writes back to n.
+    const auto decoded_type = DecodeType(type, size, align);
+    if (!decoded_type) { // DecodeType rejected this type/size/align combination
+        return UndefinedInstruction();
+    }
+    const auto [nelem, regs, inc] = *decoded_type;
+
+    const ExtReg d = ToExtReg(Vd, D);
+    const size_t d_last = RegNumber(d) + inc * (nelem - 1); // register number used for the last structure element (i == nelem - 1) when r == 0
+    if (n == Reg::R15 || d_last + regs > 32) { // R15 as base, or the register list would run past register number 31
+        return UnpredictableInstruction();
     }
 
-    const ExtReg d = ToExtRegD(Vd, D);
+    [[maybe_unused]] const size_t alignment = align == 0 ? 1 : 4 << align; // computed from the align field but not consulted anywhere below
+    const size_t ebytes = static_cast<size_t>(1) << size; // element size in bytes
+    const size_t elements = 8 / ebytes; // number of elements in each 64-bit register value
+
+    const bool wback = m != Reg::R15; // m == R15 means no writeback to n
+    const bool register_index = m != Reg::R15 && m != Reg::R13; // m == R13 selects the fixed-size increment instead of adding register m
+
+    IR::U32 address = ir.GetRegister(n);
+    for (size_t r = 0; r < regs; r++) {
+        for (size_t e = 0; e < elements; e++) {
+            for (size_t i = 0; i < nelem; i++) { // innermost: element e of each of the nelem registers is written back-to-back
+                const ExtReg ext_reg = d + i * inc + r;
+                const IR::U64 shifted_element = ir.LogicalShiftRight(ir.GetExtendedRegister(ext_reg), ir.Imm8(static_cast<u8>(e * ebytes * 8))); // shift element e down to bit 0
+                const IR::UAny element = ir.LeastSignificant(8 * ebytes, shifted_element); // keep only the low 8 * ebytes bits
+                ir.WriteMemory(8 * ebytes, address, element);
+
+                address = ir.Add(address, ir.Imm32(static_cast<u32>(ebytes))); // advance by one element
+            }
+        }
+    }
+
+    if (wback) {
+        if (register_index) {
+            ir.SetRegister(n, ir.Add(ir.GetRegister(n), ir.GetRegister(m))); // n += m
+        } else {
+            ir.SetRegister(n, ir.Add(ir.GetRegister(n), ir.Imm32(static_cast<u32>(8 * nelem * regs)))); // n += total bytes written by the loops above
+        }
+    }
+
+    return true;
+}
+
+bool ArmTranslatorVisitor::v8_VLD_multiple(bool D, Reg n, size_t Vd, Imm<4> type, size_t size, size_t align, Reg m) {
+    const auto decoded_type = DecodeType(type, size, align);
+    if (!decoded_type) {
+        return UndefinedInstruction();
+    }
+    const auto [nelem, regs, inc] = *decoded_type;
+
+    const ExtReg d = ToExtReg(Vd, D);
     const size_t d_last = RegNumber(d) + inc * (nelem - 1);
     if (n == Reg::R15 || d_last + regs > 32) {
         return UnpredictableInstruction();
diff --git a/src/frontend/A32/translate/impl/translate_arm.h b/src/frontend/A32/translate/impl/translate_arm.h
index 48e89e14..64d053b9 100644
--- a/src/frontend/A32/translate/impl/translate_arm.h
+++ b/src/frontend/A32/translate/impl/translate_arm.h
@@ -440,6 +440,7 @@ struct ArmTranslatorVisitor final {
     bool asimd_VBIF(bool D, size_t Vn, size_t Vd, bool N, bool Q, bool M, size_t Vm);
 
     // Advanced SIMD load/store structures
+    bool v8_VST_multiple(bool D, Reg n, size_t Vd, Imm<4> type, size_t sz, size_t align, Reg m);
     bool v8_VLD_multiple(bool D, Reg n, size_t Vd, Imm<4> type, size_t sz, size_t align, Reg m);
 };
 
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index ec031998..5eff3492 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -41,14 +41,26 @@ U128 IREmitter::Pack2x64To1x128(const U64& lo, const U64& hi) {
     return Inst<U128>(Opcode::Pack2x64To1x128, lo, hi);
 }
 
-U32 IREmitter::LeastSignificantWord(const U64& value) {
-    return Inst<U32>(Opcode::LeastSignificantWord, value);
+UAny IREmitter::LeastSignificant(size_t bitsize, const U32U64& value) { // Returns the low bitsize bits of value; bitsize must be 8, 16, 32 or 64.
+    switch (bitsize) {
+    case 8:
+        return LeastSignificantByte(value);
+    case 16:
+        return LeastSignificantHalf(value);
+    case 32:
+        if (value.GetType() == Type::U32) { // already a U32: returned unchanged, no instruction emitted here
+            return value;
+        }
+        return LeastSignificantWord(value);
+    case 64:
+        ASSERT(value.GetType() == Type::U64); // a 64-bit result is only valid for a U64 input
+        return value;
+    }
+    ASSERT_FALSE("Invalid bitsize"); // reached only when bitsize matched no case above
 }
 
-ResultAndCarry<U32> IREmitter::MostSignificantWord(const U64& value) {
-    const auto result = Inst<U32>(Opcode::MostSignificantWord, value);
-    const auto carry_out = Inst<U1>(Opcode::GetCarryFromOp, result);
-    return {result, carry_out};
+U32 IREmitter::LeastSignificantWord(const U64& value) { // Emits a LeastSignificantWord instruction producing a U32 from a U64.
+    return Inst<U32>(Opcode::LeastSignificantWord, value);
 }
 
 U16 IREmitter::LeastSignificantHalf(U32U64 value) {
@@ -65,6 +77,12 @@ U8 IREmitter::LeastSignificantByte(U32U64 value) {
     return Inst<U8>(Opcode::LeastSignificantByte, value);
 }
 
+ResultAndCarry<U32> IREmitter::MostSignificantWord(const U64& value) { // Emits MostSignificantWord and returns its U32 result paired with a carry flag.
+    const auto result = Inst<U32>(Opcode::MostSignificantWord, value);
+    const auto carry_out = Inst<U1>(Opcode::GetCarryFromOp, result); // carry is obtained from the instruction emitted on the previous line
+    return {result, carry_out};
+}
+
 U1 IREmitter::MostSignificantBit(const U32& value) {
     return Inst<U1>(Opcode::MostSignificantBit, value);
 }
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index 8bd12591..6b8eec0d 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -87,10 +87,11 @@ public:
 
     U64 Pack2x32To1x64(const U32& lo, const U32& hi);
     U128 Pack2x64To1x128(const U64& lo, const U64& hi);
+    UAny LeastSignificant(size_t bitsize, const U32U64& value);
     U32 LeastSignificantWord(const U64& value);
-    ResultAndCarry<U32> MostSignificantWord(const U64& value);
     U16 LeastSignificantHalf(U32U64 value);
     U8 LeastSignificantByte(U32U64 value);
+    ResultAndCarry<U32> MostSignificantWord(const U64& value);
     U1 MostSignificantBit(const U32& value);
     U1 IsZero(const U32& value);
     U1 IsZero(const U64& value);
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index d6fc7a48..4a5ca362 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -94,9 +94,9 @@ OPCODE(NZCVFromPackedFlags,                                 NZCV,           U32
 OPCODE(Pack2x32To1x64,                                      U64,            U32,            U32                                             )
 OPCODE(Pack2x64To1x128,                                     U128,           U64,            U64                                             )
 OPCODE(LeastSignificantWord,                                U32,            U64                                                             )
-OPCODE(MostSignificantWord,                                 U32,            U64                                                             )
 OPCODE(LeastSignificantHalf,                                U16,            U32                                                             )
 OPCODE(LeastSignificantByte,                                U8,             U32                                                             )
+OPCODE(MostSignificantWord,                                 U32,            U64                                                             )
 OPCODE(MostSignificantBit,                                  U1,             U32                                                             )
 OPCODE(IsZero32,                                            U1,             U32                                                             )
 OPCODE(IsZero64,                                            U1,             U64                                                             )