From 047e238d09589c2f47376cab3c91bb8806c811cc Mon Sep 17 00:00:00 2001
From: SachinVin <sachinvinayak2000@gmail.com>
Date: Sun, 24 Apr 2022 00:30:36 +0530
Subject: [PATCH] shader_jit: Compile nested loops

Also pass `T_NEAR` explicitly to the `jnz` emitted by Compile_BREAKC instead of relying on the default jump type.
---
 .../shader/shader_jit_x64_compiler.cpp        | 148 ++++++++++++++++--
 .../shader/shader_jit_x64_compiler.cpp        |  30 ++--
 .../shader/shader_jit_x64_compiler.h          |   8 +-
 3 files changed, 155 insertions(+), 31 deletions(-)

diff --git a/src/tests/video_core/shader/shader_jit_x64_compiler.cpp b/src/tests/video_core/shader/shader_jit_x64_compiler.cpp
index d4576acac..85eab4428 100644
--- a/src/tests/video_core/shader/shader_jit_x64_compiler.cpp
+++ b/src/tests/video_core/shader/shader_jit_x64_compiler.cpp
@@ -7,28 +7,28 @@
 #include <memory>
 #include <catch2/catch.hpp>
 #include <nihstro/inline_assembly.h>
+#include "video_core/shader/shader_interpreter.h"
 #include "video_core/shader/shader_jit_x64_compiler.h"
 
 using float24 = Pica::float24;
 using JitShader = Pica::Shader::JitShader;
+using ShaderInterpreter = Pica::Shader::InterpreterEngine;
 
 using DestRegister = nihstro::DestRegister;
 using OpCode = nihstro::OpCode;
 using SourceRegister = nihstro::SourceRegister;
+using Type = nihstro::InlineAsm::Type;
 
-static std::unique_ptr<JitShader> CompileShader(std::initializer_list<nihstro::InlineAsm> code) {
+static std::unique_ptr<Pica::Shader::ShaderSetup> CompileShaderSetup(
+    std::initializer_list<nihstro::InlineAsm> code) { // Assemble `code` into a new ShaderSetup
     const auto shbin = nihstro::InlineAsm::CompileToRawBinary(code);
 
-    std::array<u32, Pica::Shader::MAX_PROGRAM_CODE_LENGTH> program_code{};
-    std::array<u32, Pica::Shader::MAX_SWIZZLE_DATA_LENGTH> swizzle_data{};
+    auto shader = std::make_unique<Pica::Shader::ShaderSetup>(); // Receives code + swizzles
 
-    std::transform(shbin.program.begin(), shbin.program.end(), program_code.begin(),
+    std::transform(shbin.program.begin(), shbin.program.end(), shader->program_code.begin(),
                    [](const auto& x) { return x.hex; });
-    std::transform(shbin.swizzle_table.begin(), shbin.swizzle_table.end(), swizzle_data.begin(),
-                   [](const auto& x) { return x.hex; });
-
-    auto shader = std::make_unique<JitShader>();
-    shader->Compile(&program_code, &swizzle_data);
+    std::transform(shbin.swizzle_table.begin(), shbin.swizzle_table.end(),
+                   shader->swizzle_data.begin(), [](const auto& x) { return x.hex; });
 
     return shader;
 }
@@ -36,19 +36,32 @@ static std::unique_ptr<JitShader> CompileShader(std::initializer_list<nihstro::I
 class ShaderTest {
 public:
     explicit ShaderTest(std::initializer_list<nihstro::InlineAsm> code)
-        : shader(CompileShader(code)) {}
+        : shader_setup(CompileShaderSetup(code)) {
+        shader_jit.Compile(&shader_setup->program_code, &shader_setup->swizzle_data);
+    }
 
     float Run(float input) {
-        Pica::Shader::ShaderSetup shader_setup;
         Pica::Shader::UnitState shader_unit;
-
-        shader_unit.registers.input[0].x = float24::FromFloat32(input);
-        shader->Run(shader_setup, shader_unit, 0);
+        RunJit(shader_unit, input); // Run() keeps its old contract: JIT only, returns output[0].x
         return shader_unit.registers.output[0].x.ToFloat32();
     }
 
+    void RunJit(Pica::Shader::UnitState& shader_unit, float input) { // JIT engine, offset 0
+        shader_unit.registers.input[0].x = float24::FromFloat32(input);
+        shader_unit.registers.temporary[0].x = float24::FromFloat32(0); // Zero the accumulator
+        shader_jit.Run(*shader_setup, shader_unit, 0);
+    }
+
+    void RunInterpreter(Pica::Shader::UnitState& shader_unit, float input) { // Same inputs
+        shader_unit.registers.input[0].x = float24::FromFloat32(input);
+        shader_unit.registers.temporary[0].x = float24::FromFloat32(0); // Zero the accumulator
+        shader_interpreter.Run(*shader_setup, shader_unit);
+    }
+
 public:
-    std::unique_ptr<JitShader> shader;
+    JitShader shader_jit; // Compiled once, in the constructor
+    ShaderInterpreter shader_interpreter;
+    std::unique_ptr<Pica::Shader::ShaderSetup> shader_setup; // Shared by both engines
 };
 
 TEST_CASE("LG2", "[video_core][shader][shader_jit]") {
@@ -89,3 +102,108 @@ TEST_CASE("EX2", "[video_core][shader][shader_jit]") {
     REQUIRE(shader.Run(79.7262742773f) == Approx(1.e24f));
     REQUIRE(std::isinf(shader.Run(800.f)));
 }
+
+TEST_CASE("Nested Loop", "[video_core][shader][shader_jit]") {
+    const auto sh_input = SourceRegister::MakeInput(0);
+    const auto sh_temp = SourceRegister::MakeTemporary(0);
+    const auto sh_output = DestRegister::MakeOutput(0);
+
+    std::array<Common::Vec4<u8>, 2> loop_parms{Common::Vec4<u8>{4, 0, 1, 0},
+                                               Common::Vec4<u8>{4, 0, 1, 0}};
+
+    auto shader_test = ShaderTest({
+        // clang-format off
+        {OpCode::Id::LOOP, 0},
+            {OpCode::Id::LOOP, 1},
+                {OpCode::Id::ADD, sh_temp, sh_temp, sh_input},
+            {Type::EndLoop},
+        {Type::EndLoop},
+        {OpCode::Id::MOV, sh_output, sh_temp},
+        {OpCode::Id::END},
+        // clang-format on
+    });
+
+    shader_test.shader_setup->uniforms.i[0] = loop_parms[0]; // Used by the outer LOOP
+    shader_test.shader_setup->uniforms.i[1] = loop_parms[1]; // Used by the inner LOOP
+
+    const auto run_test_helper = [&shader_test](float input) { // JIT vs. interpreter
+        Pica::Shader::UnitState shader_unit_jit;
+        Pica::Shader::UnitState shader_unit_inter;
+        shader_test.RunJit(shader_unit_jit, input);
+        shader_test.RunInterpreter(shader_unit_inter, input);
+
+        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() ==
+                Approx(shader_unit_inter.registers.output[0].x.ToFloat32()));
+        REQUIRE(shader_unit_jit.address_registers[2] == shader_unit_inter.address_registers[2]);
+    };
+    {
+        // Sanity check against fixed values: the ADD is expected to run 25 times in total
+        Pica::Shader::UnitState shader_unit_jit;
+        shader_test.RunJit(shader_unit_jit, 1.0f);
+        REQUIRE(shader_unit_jit.address_registers[2] == 6);
+        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() == Approx(25.0f));
+
+        Pica::Shader::UnitState shader_unit_inter;
+        shader_test.RunInterpreter(shader_unit_inter, 2.0f);
+        REQUIRE(shader_unit_inter.address_registers[2] == 6);
+        REQUIRE(shader_unit_inter.registers.output[0].x.ToFloat32() == Approx(50.0f));
+    }
+    run_test_helper(-5.f);
+    run_test_helper(0.f);
+    run_test_helper(2.f);
+    run_test_helper(6.f);
+    run_test_helper(79.7262742773f);
+}
+
+TEST_CASE("Nested Loop Randomized", "[video_core][shader][shader_jit]") {
+    const auto sh_input = SourceRegister::MakeInput(0);
+    const auto sh_temp = SourceRegister::MakeTemporary(0);
+    const auto sh_output = DestRegister::MakeOutput(0);
+
+    auto shader_test = ShaderTest({
+        // clang-format off
+        {OpCode::Id::LOOP, 0},
+            {OpCode::Id::LOOP, 1},
+                {OpCode::Id::LOOP, 2},
+                    {OpCode::Id::LOOP, 3},
+                        {OpCode::Id::ADD, sh_temp, sh_temp, sh_input},
+                    {Type::EndLoop},
+                {Type::EndLoop},
+            {Type::EndLoop},
+        {Type::EndLoop},
+
+        {OpCode::Id::MOV, sh_output, sh_temp},
+        {OpCode::Id::END},
+        // clang-format on
+    });
+
+    const auto generate_loop_parms = [] { // Unsigned math: no int overflow at RAND_MAX
+        const u8 iterations = static_cast<u8>(1u + static_cast<unsigned>(rand()));
+        const u8 initial = static_cast<u8>(1u + static_cast<unsigned>(rand()));
+        const u8 increment = static_cast<u8>(1u + static_cast<unsigned>(rand()));
+
+        Common::Vec4<u8> loop_parm{iterations, initial, increment, 0};
+        return loop_parm;
+    };
+
+    const auto run_test_helper = [&shader_test](float input) { // JIT vs. interpreter
+        Pica::Shader::UnitState shader_unit_jit;
+        Pica::Shader::UnitState shader_unit_inter;
+        shader_test.RunJit(shader_unit_jit, input);
+        shader_test.RunInterpreter(shader_unit_inter, input);
+
+        REQUIRE(shader_unit_jit.registers.output[0].x.ToFloat32() ==
+                Approx(shader_unit_inter.registers.output[0].x.ToFloat32()));
+        REQUIRE(shader_unit_jit.address_registers[2] == shader_unit_inter.address_registers[2]);
+    };
+
+    srand(time(0)); // NOTE(review): the seed is never logged, so a failing run is hard to replay
+    for (int i = 0; i < 10; i++) {
+        shader_test.shader_setup->uniforms.i[0] = generate_loop_parms();
+        shader_test.shader_setup->uniforms.i[1] = generate_loop_parms();
+        shader_test.shader_setup->uniforms.i[2] = generate_loop_parms();
+        shader_test.shader_setup->uniforms.i[3] = generate_loop_parms();
+        float input = -(RAND_MAX / 2) + rand();
+        run_test_helper(input);
+    }
+}
diff --git a/src/video_core/shader/shader_jit_x64_compiler.cpp b/src/video_core/shader/shader_jit_x64_compiler.cpp
index 606762788..5753187e9 100644
--- a/src/video_core/shader/shader_jit_x64_compiler.cpp
+++ b/src/video_core/shader/shader_jit_x64_compiler.cpp
@@ -595,11 +595,11 @@ void JitShader::Compile_END(Instruction instr) {
 }
 
 void JitShader::Compile_BREAKC(Instruction instr) {
-    Compile_Assert(looping, "BREAKC must be inside a LOOP");
-    if (looping) {
+    Compile_Assert(loop_depth, "BREAKC must be inside a LOOP");
+    if (loop_depth) { // Non-zero only while Compile_LOOP is emitting a loop body
         Compile_EvaluateCondition(instr);
-        ASSERT(loop_break_label);
-        jnz(*loop_break_label);
+        ASSERT(!loop_break_labels.empty());
+        jnz(loop_break_labels.back(), T_NEAR); // Innermost loop's exit; label is bound later
     }
 }
 
@@ -725,9 +725,11 @@ void JitShader::Compile_IF(Instruction instr) {
 void JitShader::Compile_LOOP(Instruction instr) {
     Compile_Assert(instr.flow_control.dest_offset >= program_counter,
                    "Backwards loops not supported");
-    Compile_Assert(!looping, "Nested loops not supported");
-
-    looping = true;
+    if (loop_depth++) {
+        // LOOPCOUNT_REG is a "global", so we don't save it here.
+        push(LOOPINC.cvt64());
+        push(LOOPCOUNT.cvt64());
+    }
 
     // This decodes the fields from the integer uniform at index instr.flow_control.int_uniform_id.
     // The Y (LOOPCOUNT_REG) and Z (LOOPINC) component are kept multiplied by 16 (Left shifted by
@@ -746,16 +748,20 @@ void JitShader::Compile_LOOP(Instruction instr) {
     Label l_loop_start;
     L(l_loop_start);
 
-    loop_break_label = Xbyak::Label();
+    loop_break_labels.emplace_back(Xbyak::Label());
     Compile_Block(instr.flow_control.dest_offset + 1);
 
     add(LOOPCOUNT_REG, LOOPINC); // Increment LOOPCOUNT_REG by Z-component
     sub(LOOPCOUNT, 1);           // Increment loop count by 1
     jnz(l_loop_start);           // Loop if not equal
-    L(*loop_break_label);
-    loop_break_label.reset();
 
-    looping = false;
+    L(loop_break_labels.back());
+    loop_break_labels.pop_back();
+
+    if (--loop_depth) {
+        pop(LOOPCOUNT.cvt64());
+        pop(LOOPINC.cvt64());
+    }
 }
 
 void JitShader::Compile_JMP(Instruction instr) {
@@ -892,7 +898,7 @@ void JitShader::Compile(const std::array<u32, MAX_PROGRAM_CODE_LENGTH>* program_
     // Reset flow control state
     program = (CompiledShader*)getCurr();
     program_counter = 0;
-    looping = false;
+    loop_depth = 0;
     instruction_labels.fill(Xbyak::Label());
 
     // Find all `CALL` instructions and identify return locations
diff --git a/src/video_core/shader/shader_jit_x64_compiler.h b/src/video_core/shader/shader_jit_x64_compiler.h
index 507cd0ff3..573bdf8d3 100644
--- a/src/video_core/shader/shader_jit_x64_compiler.h
+++ b/src/video_core/shader/shader_jit_x64_compiler.h
@@ -120,15 +120,15 @@ private:
     /// Mapping of Pica VS instructions to pointers in the emitted code
     std::array<Xbyak::Label, MAX_PROGRAM_CODE_LENGTH> instruction_labels;
 
-    /// Label pointing to the end of the current LOOP block. Used by the BREAKC instruction to break
-    /// out of the loop.
-    std::optional<Xbyak::Label> loop_break_label;
+    /// Labels pointing to the end of each nested LOOP block. Used by the BREAKC instruction to
+    /// break out of a loop.
+    std::vector<Xbyak::Label> loop_break_labels;
 
     /// Offsets in code where a return needs to be inserted
     std::vector<unsigned> return_offsets;
 
     unsigned program_counter = 0; ///< Offset of the next instruction to decode
-    bool looping = false;         ///< True if compiling a loop, used to check for nested loops
+    u8 loop_depth = 0;            ///< Depth of the (nested) loops currently compiled
 
     using CompiledShader = void(const void* setup, void* state, const u8* start_addr);
     CompiledShader* program = nullptr;