From 4502595bc2518eecf934110e9393b11bf0c2f75a Mon Sep 17 00:00:00 2001
From: ReinUsesLisp <reinuseslisp@airmail.cc>
Date: Sun, 9 May 2021 18:03:01 -0300
Subject: [PATCH] glasm: Initial GLASM fp64 support

---
 .../backend/glasm/emit_context.h              |  7 ++
 .../backend/glasm/emit_glasm.cpp              | 17 +++--
 .../glasm/emit_glasm_bitwise_conversion.cpp   |  8 +++
 .../glasm/emit_glasm_floating_point.cpp       | 16 ++---
 .../backend/glasm/emit_glasm_instructions.h   | 12 ++--
 .../backend/glasm/emit_glasm_memory.cpp       | 10 +--
 .../glasm/emit_glasm_not_implemented.cpp      |  8 ---
 .../backend/glasm/reg_alloc.cpp               | 66 +++++++++++++------
 .../backend/glasm/reg_alloc.h                 | 63 ++++++++++++++++--
 9 files changed, 152 insertions(+), 55 deletions(-)

diff --git a/src/shader_recompiler/backend/glasm/emit_context.h b/src/shader_recompiler/backend/glasm/emit_context.h
index a59acbf6c..37663c1c8 100644
--- a/src/shader_recompiler/backend/glasm/emit_context.h
+++ b/src/shader_recompiler/backend/glasm/emit_context.h
@@ -29,6 +29,13 @@ public:
         code += '\n';
     }
 
+    template <typename... Args> // First "{}" receives the long register defined for inst
+    void LongAdd(const char* format_str, IR::Inst& inst, Args&&... args) {
+        code += fmt::format(format_str, reg_alloc.LongDefine(inst), std::forward<Args>(args)...);
+        // TODO: Remove this
+        code += '\n';
+    }
+
     template <typename... Args>
     void Add(const char* format_str, Args&&... args) {
         code += fmt::format(format_str, std::forward<Args>(args)...);
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm.cpp b/src/shader_recompiler/backend/glasm/emit_glasm.cpp
index 842ec157d..9db6eb4a0 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm.cpp
@@ -42,7 +42,11 @@ template <bool scalar>
 struct RegWrapper {
     RegWrapper(EmitContext& ctx, Value value)
         : reg_alloc{ctx.reg_alloc}, allocated{value.type != Type::Register} {
-        reg = allocated ? reg_alloc.AllocReg() : Register{value};
+        if (allocated) {
+            reg = value.type == Type::F64 ? reg_alloc.AllocLongReg() : reg_alloc.AllocReg();
+        } else {
+            reg = Register{value};
+        }
         switch (value.type) {
         case Type::Register:
             break;
@@ -55,6 +59,9 @@ struct RegWrapper {
         case Type::F32:
             ctx.Add("MOV.F {}.x,{};", reg, value.imm_f32);
             break;
+        case Type::F64:
+            ctx.Add("MOV.F64 {}.x,{};", reg, value.imm_f64);
+            break;
         }
     }
     ~RegWrapper() {
@@ -162,10 +169,12 @@ std::string EmitGLASM(const Profile&, IR::Program& program, Bindings&) {
     for (size_t index = 0; index < ctx.reg_alloc.NumUsedRegisters(); ++index) {
         header += fmt::format("R{},", index);
     }
-    header += "RC;";
-    if (!program.info.storage_buffers_descriptors.empty()) {
-        header += "LONG TEMP LC;";
+    header += "RC;"
+              "LONG TEMP ";
+    for (size_t index = 0; index < ctx.reg_alloc.NumUsedLongRegisters(); ++index) {
+        header += fmt::format("D{},", index);
     }
+    header += "DC;";
     ctx.code.insert(0, header);
     ctx.code += "END";
     return ctx.code;
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_bitwise_conversion.cpp b/src/shader_recompiler/backend/glasm/emit_glasm_bitwise_conversion.cpp
index 918d82375..eb6140954 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_bitwise_conversion.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_bitwise_conversion.cpp
@@ -72,4 +72,12 @@ void EmitUnpackHalf2x16(EmitContext& ctx, IR::Inst& inst, Register value) {
     ctx.Add("UP2H {}.xy,{}.x;", inst, value);
 }
 
+void EmitPackDouble2x32(EmitContext& ctx, IR::Inst& inst, Register value) {
+    ctx.LongAdd("PK64 {}.x,{};", inst, value); // LongAdd defines a long register for inst
+}
+
+void EmitUnpackDouble2x32(EmitContext& ctx, IR::Inst& inst, Register value) {
+    ctx.Add("UP64 {}.xy,{}.x;", inst, value); // writes .xy of the destination from value.x
+}
+
 } // namespace Shader::Backend::GLASM
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_floating_point.cpp b/src/shader_recompiler/backend/glasm/emit_glasm_floating_point.cpp
index fed6503c6..2b9a210aa 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_floating_point.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_floating_point.cpp
@@ -10,7 +10,8 @@
 
 namespace Shader::Backend::GLASM {
 
-void EmitFPAbs16([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] Register value) {
+void EmitFPAbs16([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] IR::Inst& inst,
+                 [[maybe_unused]] Register value) {
     throw NotImplementedException("GLASM instruction");
 }
 
@@ -18,8 +19,8 @@ void EmitFPAbs32(EmitContext& ctx, IR::Inst& inst, ScalarF32 value) {
     ctx.Add("MOV.F {}.x,|{}|;", inst, value);
 }
 
-void EmitFPAbs64([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] Register value) {
-    throw NotImplementedException("GLASM instruction");
+void EmitFPAbs64(EmitContext& ctx, IR::Inst& inst, ScalarF64 value) {
+    ctx.LongAdd("MOV.F64 {}.x,|{}|;", inst, value); // operand is wrapped in |...|
 }
 
 void EmitFPAdd16([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] IR::Inst& inst,
@@ -31,9 +32,8 @@ void EmitFPAdd32(EmitContext& ctx, IR::Inst& inst, ScalarF32 a, ScalarF32 b) {
     ctx.Add("ADD.F {}.x,{},{};", inst, a, b);
 }
 
-void EmitFPAdd64([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] IR::Inst& inst,
-                 [[maybe_unused]] Register a, [[maybe_unused]] Register b) {
-    throw NotImplementedException("GLASM instruction");
+void EmitFPAdd64(EmitContext& ctx, IR::Inst& inst, ScalarF64 a, ScalarF64 b) {
+    ctx.LongAdd("ADD.F64 {}.x,{},{};", inst, a, b); // result lands in inst's long register
 }
 
 void EmitFPFma16([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] IR::Inst& inst,
@@ -94,8 +94,8 @@ void EmitFPNeg32(EmitContext& ctx, IR::Inst& inst, ScalarRegister value) {
     ctx.Add("MOV.F {}.x,-{};", inst, value);
 }
 
-void EmitFPNeg64([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] Register value) {
-    throw NotImplementedException("GLASM instruction");
+void EmitFPNeg64(EmitContext& ctx, IR::Inst& inst, Register value) {
+    ctx.LongAdd("MOV.F64 {}.x,-{};", inst, value); // operand is emitted with a '-' prefix
 }
 
 void EmitFPSin([[maybe_unused]] EmitContext& ctx, [[maybe_unused]] ScalarF32 value) {
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h b/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
index cb1067dc9..ab1e08215 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_instructions.h
@@ -202,20 +202,20 @@ void EmitPackFloat2x16(EmitContext& ctx, Register value);
 void EmitUnpackFloat2x16(EmitContext& ctx, Register value);
 void EmitPackHalf2x16(EmitContext& ctx, IR::Inst& inst, Register value);
 void EmitUnpackHalf2x16(EmitContext& ctx, IR::Inst& inst, Register value);
-void EmitPackDouble2x32(EmitContext& ctx, Register value);
-void EmitUnpackDouble2x32(EmitContext& ctx, Register value);
+void EmitPackDouble2x32(EmitContext& ctx, IR::Inst& inst, Register value);
+void EmitUnpackDouble2x32(EmitContext& ctx, IR::Inst& inst, Register value);
 void EmitGetZeroFromOp(EmitContext& ctx);
 void EmitGetSignFromOp(EmitContext& ctx);
 void EmitGetCarryFromOp(EmitContext& ctx);
 void EmitGetOverflowFromOp(EmitContext& ctx);
 void EmitGetSparseFromOp(EmitContext& ctx);
 void EmitGetInBoundsFromOp(EmitContext& ctx);
-void EmitFPAbs16(EmitContext& ctx, Register value);
+void EmitFPAbs16(EmitContext& ctx, IR::Inst& inst, Register value);
 void EmitFPAbs32(EmitContext& ctx, IR::Inst& inst, ScalarF32 value);
-void EmitFPAbs64(EmitContext& ctx, Register value);
+void EmitFPAbs64(EmitContext& ctx, IR::Inst& inst, ScalarF64 value);
 void EmitFPAdd16(EmitContext& ctx, IR::Inst& inst, Register a, Register b);
 void EmitFPAdd32(EmitContext& ctx, IR::Inst& inst, ScalarF32 a, ScalarF32 b);
-void EmitFPAdd64(EmitContext& ctx, IR::Inst& inst, Register a, Register b);
+void EmitFPAdd64(EmitContext& ctx, IR::Inst& inst, ScalarF64 a, ScalarF64 b);
 void EmitFPFma16(EmitContext& ctx, IR::Inst& inst, Register a, Register b, Register c);
 void EmitFPFma32(EmitContext& ctx, IR::Inst& inst, ScalarF32 a, ScalarF32 b, ScalarF32 c);
 void EmitFPFma64(EmitContext& ctx, IR::Inst& inst, Register a, Register b, Register c);
@@ -228,7 +228,7 @@ void EmitFPMul32(EmitContext& ctx, IR::Inst& inst, ScalarF32 a, ScalarF32 b);
 void EmitFPMul64(EmitContext& ctx, IR::Inst& inst, Register a, Register b);
 void EmitFPNeg16(EmitContext& ctx, Register value);
 void EmitFPNeg32(EmitContext& ctx, IR::Inst& inst, ScalarRegister value);
-void EmitFPNeg64(EmitContext& ctx, Register value);
+void EmitFPNeg64(EmitContext& ctx, IR::Inst& inst, Register value);
 void EmitFPSin(EmitContext& ctx, ScalarF32 value);
 void EmitFPCos(EmitContext& ctx, ScalarF32 value);
 void EmitFPExp2(EmitContext& ctx, ScalarF32 value);
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_memory.cpp b/src/shader_recompiler/backend/glasm/emit_glasm_memory.cpp
index 8ef0f7c17..0c6a6e1c8 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_memory.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_memory.cpp
@@ -17,9 +17,9 @@ void StorageOp(EmitContext& ctx, const IR::Value& binding, ScalarU32 offset,
     // address = c[binding].xy
     // length  = c[binding].z
     const u32 sb_binding{binding.U32()};
-    ctx.Add("PK64.U LC,c[{}];"           // pointer = address
-            "CVT.U64.U32 LC.z,{};"       // offset = uint64_t(offset)
-            "ADD.U64 LC.x,LC.x,LC.z;"    // pointer += offset
+    ctx.Add("PK64.U DC,c[{}];"           // pointer = address
+            "CVT.U64.U32 DC.z,{};"       // offset = uint64_t(offset)
+            "ADD.U64 DC.x,DC.x,DC.z;"    // pointer += offset
             "SLT.U.CC RC.x,{},c[{}].z;", // cc = offset < length
             sb_binding, offset, offset, sb_binding);
     if (else_expr.empty()) {
@@ -32,13 +32,13 @@ void StorageOp(EmitContext& ctx, const IR::Value& binding, ScalarU32 offset,
 template <typename ValueType>
 void Store(EmitContext& ctx, const IR::Value& binding, ScalarU32 offset, ValueType value,
            std::string_view size) {
-    StorageOp(ctx, binding, offset, fmt::format("STORE.{} {},LC.x;", size, value));
+    StorageOp(ctx, binding, offset, fmt::format("STORE.{} {},DC.x;", size, value));
 }
 
 void Load(EmitContext& ctx, IR::Inst& inst, const IR::Value& binding, ScalarU32 offset,
           std::string_view size) {
     const Register ret{ctx.reg_alloc.Define(inst)};
-    StorageOp(ctx, binding, offset, fmt::format("STORE.{} {},LC.x;", size, ret),
+    StorageOp(ctx, binding, offset, fmt::format("LOAD.{} {},DC.x;", size, ret), // read to ret
               fmt::format("MOV.U {},{{0,0,0,0}};", ret));
 }
 } // Anonymous namespace
diff --git a/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp b/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp
index 03464524e..f3baf33af 100644
--- a/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp
+++ b/src/shader_recompiler/backend/glasm/emit_glasm_not_implemented.cpp
@@ -281,14 +281,6 @@ void EmitSelectF64(EmitContext& ctx, ScalarS32 cond, Register true_value, Regist
     NotImplemented();
 }
 
-void EmitPackDouble2x32(EmitContext& ctx, Register value) {
-    NotImplemented();
-}
-
-void EmitUnpackDouble2x32(EmitContext& ctx, Register value) {
-    NotImplemented();
-}
-
 void EmitGetZeroFromOp(EmitContext& ctx) {
     NotImplemented();
 }
diff --git a/src/shader_recompiler/backend/glasm/reg_alloc.cpp b/src/shader_recompiler/backend/glasm/reg_alloc.cpp
index 030b48d83..82b627500 100644
--- a/src/shader_recompiler/backend/glasm/reg_alloc.cpp
+++ b/src/shader_recompiler/backend/glasm/reg_alloc.cpp
@@ -14,12 +14,11 @@
 namespace Shader::Backend::GLASM {
 
 Register RegAlloc::Define(IR::Inst& inst) {
-    const Id id{Alloc()};
-    inst.SetDefinition<Id>(id);
-    Register ret;
-    ret.type = Type::Register;
-    ret.id = id;
-    return ret;
+    return Define(inst, false); // is_long = false: formatted as R<index>
+}
+
+Register RegAlloc::LongDefine(IR::Inst& inst) {
+    return Define(inst, true); // is_long = true: formatted as D<index>
 }
 
 Value RegAlloc::Consume(const IR::Value& value) {
@@ -40,6 +39,10 @@ Value RegAlloc::Consume(const IR::Value& value) {
         ret.type = Type::F32;
         ret.imm_f32 = value.F32();
         break;
+    case IR::Type::F64:
+        ret.type = Type::F64;
+        ret.imm_f64 = value.F64();
+        break;
     default:
         throw NotImplementedException("Immediate type {}", value.Type());
     }
@@ -49,7 +52,14 @@ Value RegAlloc::Consume(const IR::Value& value) {
 Register RegAlloc::AllocReg() {
     Register ret;
     ret.type = Type::Register;
-    ret.id = Alloc();
+    ret.id = Alloc(false); // not long: tracked in register_use
+    return ret;
+}
+
+Register RegAlloc::AllocLongReg() {
+    Register ret;
+    ret.type = Type::Register; // long-ness is carried by id.is_long, not by the type tag
+    ret.id = Alloc(true); // long: tracked in long_register_use
     return ret;
 }
 
@@ -57,6 +67,15 @@ void RegAlloc::FreeReg(Register reg) {
     Free(reg.id);
 }
 
+Register RegAlloc::Define(IR::Inst& inst, bool is_long) { // shared by Define and LongDefine
+    const Id id{Alloc(is_long)};
+    inst.SetDefinition<Id>(id); // stored on the instruction; Consume(inst) reads it back
+    Register ret;
+    ret.type = Type::Register;
+    ret.id = id;
+    return ret;
+}
+
 Value RegAlloc::Consume(IR::Inst& inst) {
     const Id id{inst.Definition<Id>()};
     inst.DestructiveRemoveUsage();
@@ -69,18 +88,23 @@ Value RegAlloc::Consume(IR::Inst& inst) {
     return ret;
 }
 
-Id RegAlloc::Alloc() {
-    for (size_t reg = 0; reg < NUM_REGS; ++reg) {
-        if (register_use[reg]) {
-            continue;
+Id RegAlloc::Alloc(bool is_long) { // returns the lowest free index of the selected bank
+    size_t& num_regs{is_long ? num_used_long_registers : num_used_registers}; // high-water mark
+    std::bitset<NUM_REGS>& use{is_long ? long_register_use : register_use}; // in-use flags
+    if (num_used_registers + num_used_long_registers < NUM_REGS) { // both banks share one budget
+        for (size_t reg = 0; reg < NUM_REGS; ++reg) {
+            if (use[reg]) {
+                continue;
+            }
+            num_regs = std::max(num_regs, reg + 1); // highest index + 1 handed out so far
+            use[reg] = true;
+            Id ret{};
+            ret.index.Assign(static_cast<u32>(reg));
+            ret.is_long.Assign(is_long ? 1 : 0); // tells Free() which bitset to clear
+            ret.is_spill.Assign(0);
+            ret.is_condition_code.Assign(0);
+            return ret;
         }
-        num_used_registers = std::max(num_used_registers, reg + 1);
-        register_use[reg] = true;
-        Id ret{};
-        ret.index.Assign(static_cast<u32>(reg));
-        ret.is_spill.Assign(0);
-        ret.is_condition_code.Assign(0);
-        return ret;
     }
     throw NotImplementedException("Register spilling");
 }
@@ -89,7 +113,11 @@ void RegAlloc::Free(Id id) {
     if (id.is_spill != 0) {
         throw NotImplementedException("Free spill");
     }
-    register_use[id.index] = false;
+    if (id.is_long != 0) {
+        long_register_use[id.index] = false;
+    } else {
+        register_use[id.index] = false;
+    }
 }
 
 } // namespace Shader::Backend::GLASM
diff --git a/src/shader_recompiler/backend/glasm/reg_alloc.h b/src/shader_recompiler/backend/glasm/reg_alloc.h
index 6a238afa9..f1899eae1 100644
--- a/src/shader_recompiler/backend/glasm/reg_alloc.h
+++ b/src/shader_recompiler/backend/glasm/reg_alloc.h
@@ -27,12 +27,14 @@ enum class Type : u32 {
     U32,
     S32,
     F32,
+    F64,
 };
 
 struct Id {
     union {
         u32 raw;
-        BitField<0, 30, u32> index;
+        BitField<0, 29, u32> index;
+        BitField<29, 1, u32> is_long;
         BitField<30, 1, u32> is_spill;
         BitField<31, 1, u32> is_condition_code;
     };
@@ -53,6 +55,7 @@ struct Value {
         u32 imm_u32;
         s32 imm_s32;
         f32 imm_f32;
+        f64 imm_f64;
     };
 
     bool operator==(const Value& rhs) const noexcept {
@@ -68,6 +71,8 @@ struct Value {
             return imm_s32 == rhs.imm_s32;
         case Type::F32:
             return Common::BitCast<u32>(imm_f32) == Common::BitCast<u32>(rhs.imm_f32);
+        case Type::F64:
+            return Common::BitCast<u64>(imm_f64) == Common::BitCast<u64>(rhs.imm_f64);
         }
         return false;
     }
@@ -80,6 +85,7 @@ struct ScalarRegister : Value {};
 struct ScalarU32 : Value {};
 struct ScalarS32 : Value {};
 struct ScalarF32 : Value {};
+struct ScalarF64 : Value {};
 
 class RegAlloc {
 public:
@@ -87,9 +93,13 @@ public:
 
     Register Define(IR::Inst& inst);
 
+    Register LongDefine(IR::Inst& inst);
+
     Value Consume(const IR::Value& value);
 
-    Register AllocReg();
+    [[nodiscard]] Register AllocReg();
+
+    [[nodiscard]] Register AllocLongReg();
 
     void FreeReg(Register reg);
 
@@ -97,19 +107,27 @@ public:
         return num_used_registers;
     }
 
+    [[nodiscard]] size_t NumUsedLongRegisters() const noexcept {
+        return num_used_long_registers;
+    }
+
 private:
     static constexpr size_t NUM_REGS = 4096;
     static constexpr size_t NUM_ELEMENTS = 4;
 
+    Register Define(IR::Inst& inst, bool is_long);
+
     Value Consume(IR::Inst& inst);
 
-    Id Alloc();
+    Id Alloc(bool is_long);
 
     void Free(Id id);
 
     EmitContext& ctx;
     size_t num_used_registers{};
+    size_t num_used_long_registers{};
     std::bitset<NUM_REGS> register_use{};
+    std::bitset<NUM_REGS> long_register_use{};
 };
 
 template <bool scalar, typename FormatContext>
@@ -121,9 +139,17 @@ auto FormatTo(FormatContext& ctx, Id id) {
         throw NotImplementedException("Spill emission");
     }
     if constexpr (scalar) {
-        return fmt::format_to(ctx.out(), "R{}.x", id.index.Value());
+        if (id.is_long != 0) {
+            return fmt::format_to(ctx.out(), "D{}.x", id.index.Value());
+        } else {
+            return fmt::format_to(ctx.out(), "R{}.x", id.index.Value());
+        }
     } else {
-        return fmt::format_to(ctx.out(), "R{}", id.index.Value());
+        if (id.is_long != 0) {
+            return fmt::format_to(ctx.out(), "D{}", id.index.Value());
+        } else {
+            return fmt::format_to(ctx.out(), "R{}", id.index.Value());
+        }
     }
 }
 
@@ -184,6 +210,8 @@ struct fmt::formatter<Shader::Backend::GLASM::ScalarU32> {
             return fmt::format_to(ctx.out(), "{}", static_cast<u32>(value.imm_s32));
         case Shader::Backend::GLASM::Type::F32:
             return fmt::format_to(ctx.out(), "{}", Common::BitCast<u32>(value.imm_f32));
+        case Shader::Backend::GLASM::Type::F64:
+            break;
         }
         throw Shader::InvalidArgument("Invalid value type {}", value.type);
     }
@@ -205,6 +233,8 @@ struct fmt::formatter<Shader::Backend::GLASM::ScalarS32> {
             return fmt::format_to(ctx.out(), "{}", value.imm_s32);
         case Shader::Backend::GLASM::Type::F32:
             return fmt::format_to(ctx.out(), "{}", Common::BitCast<s32>(value.imm_f32));
+        case Shader::Backend::GLASM::Type::F64:
+            break;
         }
         throw Shader::InvalidArgument("Invalid value type {}", value.type);
     }
@@ -226,6 +256,29 @@ struct fmt::formatter<Shader::Backend::GLASM::ScalarF32> {
             return fmt::format_to(ctx.out(), "{}", Common::BitCast<s32>(value.imm_s32));
         case Shader::Backend::GLASM::Type::F32:
             return fmt::format_to(ctx.out(), "{}", value.imm_f32);
+        case Shader::Backend::GLASM::Type::F64:
+            break;
+        }
+        throw Shader::InvalidArgument("Invalid value type {}", value.type);
+    }
+};
+
+template <>
+struct fmt::formatter<Shader::Backend::GLASM::ScalarF64> {
+    constexpr auto parse(format_parse_context& ctx) {
+        return ctx.begin();
+    }
+    template <typename FormatContext>
+    auto format(const Shader::Backend::GLASM::ScalarF64& value, FormatContext& ctx) {
+        switch (value.type) {
+        case Shader::Backend::GLASM::Type::Register:
+            return Shader::Backend::GLASM::FormatTo<true>(ctx, value.id); // scalar (.x) form
+        case Shader::Backend::GLASM::Type::U32:
+        case Shader::Backend::GLASM::Type::S32:
+        case Shader::Backend::GLASM::Type::F32:
+            break; // 32-bit immediates fall through to the throw below
+        case Shader::Backend::GLASM::Type::F64:
+            return fmt::format_to(ctx.out(), "{}", value.imm_f64); // fmt:: avoids relying on ADL
         }
         throw Shader::InvalidArgument("Invalid value type {}", value.type);
     }