From a94af8ea62abb481b356813be2a3dd7aabf69c7f Mon Sep 17 00:00:00 2001
From: Wunk <wunkolo@gmail.com>
Date: Tue, 11 Jul 2023 09:21:37 -0700
Subject: [PATCH] shader_jit: Add optimizations up to `x86-64-v4` (#6668)

---
 .../shader/shader_jit_x64_compiler.cpp        | 220 +++++++++++++-----
 1 file changed, 157 insertions(+), 63 deletions(-)

diff --git a/src/video_core/shader/shader_jit_x64_compiler.cpp b/src/video_core/shader/shader_jit_x64_compiler.cpp
index 85681ab83..6a30d4e23 100644
--- a/src/video_core/shader/shader_jit_x64_compiler.cpp
+++ b/src/video_core/shader/shader_jit_x64_compiler.cpp
@@ -338,15 +338,39 @@ void JitShader::Compile_SanitizedMul(Xmm src1, Xmm src2, Xmm scratch) {
     // where neither source was, this NaN was generated by a 0 * inf multiplication, and so the
     // result should be transformed to 0 to match PICA fp rules.
 
+    if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL | Cpu::tAVX512DQ)) {
+        vmulps(scratch, src1, src2);
+
+        // Mask of any NaN values found in the result
+        const Xbyak::Opmask zero_mask = k1;
+        vcmpunordps(zero_mask, scratch, scratch);
+
+        // Mask of any non-NaN inputs producing NaN results
+        vcmpordps(zero_mask | zero_mask, src1, src2);
+
+        knotb(zero_mask, zero_mask);
+        vmovaps(src1 | zero_mask | T_z, scratch);
+
+        return;
+    }
+
     // Set scratch to mask of (src1 != NaN and src2 != NaN)
-    movaps(scratch, src1);
-    cmpordps(scratch, src2);
+    if (host_caps.has(Cpu::tAVX)) {
+        vcmpordps(scratch, src1, src2);
+    } else {
+        movaps(scratch, src1);
+        cmpordps(scratch, src2);
+    }
 
     mulps(src1, src2);
 
     // Set src2 to mask of (result == NaN)
-    movaps(src2, src1);
-    cmpunordps(src2, src2);
+    if (host_caps.has(Cpu::tAVX)) {
+        vcmpunordps(src2, src2, src1);
+    } else {
+        movaps(src2, src1);
+        cmpunordps(src2, src2);
+    }
 
     // Clear components where scratch != src2 (i.e. if result is NaN where neither source was NaN)
     xorps(scratch, src2);
@@ -406,13 +430,20 @@ void JitShader::Compile_DP3(Instruction instr) {
 
     Compile_SanitizedMul(SRC1, SRC2, SCRATCH);
 
-    movaps(SRC2, SRC1);
-    shufps(SRC2, SRC2, _MM_SHUFFLE(1, 1, 1, 1));
+    if (host_caps.has(Cpu::tAVX)) {
+        vshufps(SRC3, SRC1, SRC1, _MM_SHUFFLE(2, 2, 2, 2));
+        vshufps(SRC2, SRC1, SRC1, _MM_SHUFFLE(1, 1, 1, 1));
+        vshufps(SRC1, SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0));
+    } else {
+        movaps(SRC2, SRC1);
+        shufps(SRC2, SRC2, _MM_SHUFFLE(1, 1, 1, 1));
 
-    movaps(SRC3, SRC1);
-    shufps(SRC3, SRC3, _MM_SHUFFLE(2, 2, 2, 2));
+        movaps(SRC3, SRC1);
+        shufps(SRC3, SRC3, _MM_SHUFFLE(2, 2, 2, 2));
+
+        shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0));
+    }
 
-    shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0));
     addps(SRC1, SRC2);
     addps(SRC1, SRC3);
 
@@ -589,9 +620,15 @@ void JitShader::Compile_MOV(Instruction instr) {
 void JitShader::Compile_RCP(Instruction instr) {
     Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
 
-    // TODO(bunnei): RCPSS is a pretty rough approximation, this might cause problems if Pica
-    // performs this operation more accurately. This should be checked on hardware.
-    rcpss(SRC1, SRC1);
+    if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
+        // Accurate to 14 bits of precisions rather than 12 bits of rcpss
+        vrcp14ss(SRC1, SRC1, SRC1);
+    } else {
+        // TODO(bunnei): RCPSS is a pretty rough approximation, this might cause problems if Pica
+        // performs this operation more accurately. This should be checked on hardware.
+        rcpss(SRC1, SRC1);
+    }
+
     shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX
 
     Compile_DestEnable(instr, SRC1);
@@ -600,9 +637,15 @@ void JitShader::Compile_RCP(Instruction instr) {
 void JitShader::Compile_RSQ(Instruction instr) {
     Compile_SwizzleSrc(instr, 1, instr.common.src1, SRC1);
 
-    // TODO(bunnei): RSQRTSS is a pretty rough approximation, this might cause problems if Pica
-    // performs this operation more accurately. This should be checked on hardware.
-    rsqrtss(SRC1, SRC1);
+    if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
+        // Accurate to 14 bits of precisions rather than 12 bits of rsqrtss
+        vrsqrt14ss(SRC1, SRC1, SRC1);
+    } else {
+        // TODO(bunnei): RSQRTSS is a pretty rough approximation, this might cause problems if Pica
+        // performs this operation more accurately. This should be checked on hardware.
+        rsqrtss(SRC1, SRC1);
+    }
+
     shufps(SRC1, SRC1, _MM_SHUFFLE(0, 0, 0, 0)); // XYWZ -> XXXX
 
     Compile_DestEnable(instr, SRC1);
@@ -1050,32 +1093,47 @@ Xbyak::Label JitShader::CompilePrelude_Log2() {
     jp(input_is_nan);
     jae(input_out_of_range);
 
-    // Split input
-    movd(eax, SRC1);
-    mov(edx, eax);
-    and_(eax, 0x7f800000);
-    and_(edx, 0x007fffff);
-    movss(SCRATCH, xword[rip + c0]); // Preload c0.
-    or_(edx, 0x3f800000);
-    movd(SRC1, edx);
-    // SRC1 now contains the mantissa of the input.
-    mulss(SCRATCH, SRC1);
-    shr(eax, 23);
-    sub(eax, 0x7f);
-    cvtsi2ss(SCRATCH2, eax);
-    // SCRATCH2 now contains the exponent of the input.
+    // Split input: SRC1=MANT[1,2) SCRATCH2=Exponent
+    if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
+        vgetexpss(SCRATCH2, SRC1, SRC1);
+        vgetmantss(SRC1, SRC1, SRC1, 0x0'0);
+    } else {
+        movd(eax, SRC1);
+        mov(edx, eax);
+        and_(eax, 0x7f800000);
+        and_(edx, 0x007fffff);
+        or_(edx, 0x3f800000);
+        movd(SRC1, edx);
+        // SRC1 now contains the mantissa of the input.
+        shr(eax, 23);
+        sub(eax, 0x7f);
+        cvtsi2ss(SCRATCH2, eax);
+        // SCRATCH2 now contains the exponent of the input.
+    }
+
+    movss(SCRATCH, xword[rip + c0]);
 
     // Complete computation of polynomial
-    addss(SCRATCH, xword[rip + c1]);
-    mulss(SCRATCH, SRC1);
-    addss(SCRATCH, xword[rip + c2]);
-    mulss(SCRATCH, SRC1);
-    addss(SCRATCH, xword[rip + c3]);
-    mulss(SCRATCH, SRC1);
-    subss(SRC1, ONE);
-    addss(SCRATCH, xword[rip + c4]);
-    mulss(SCRATCH, SRC1);
-    addss(SCRATCH2, SCRATCH);
+    if (host_caps.has(Cpu::tFMA)) {
+        vfmadd213ss(SCRATCH, SRC1, xword[rip + c1]);
+        vfmadd213ss(SCRATCH, SRC1, xword[rip + c2]);
+        vfmadd213ss(SCRATCH, SRC1, xword[rip + c3]);
+        vfmadd213ss(SCRATCH, SRC1, xword[rip + c4]);
+        subss(SRC1, ONE);
+        vfmadd231ss(SCRATCH2, SCRATCH, SRC1);
+    } else {
+        mulss(SCRATCH, SRC1);
+        addss(SCRATCH, xword[rip + c1]);
+        mulss(SCRATCH, SRC1);
+        addss(SCRATCH, xword[rip + c2]);
+        mulss(SCRATCH, SRC1);
+        addss(SCRATCH, xword[rip + c3]);
+        mulss(SCRATCH, SRC1);
+        subss(SRC1, ONE);
+        addss(SCRATCH, xword[rip + c4]);
+        mulss(SCRATCH, SRC1);
+        addss(SCRATCH2, SCRATCH);
+    }
 
     // Duplicate result across vector
     xorps(SRC1, SRC1); // break dependency chain
@@ -1122,33 +1180,69 @@ Xbyak::Label JitShader::CompilePrelude_Exp2() {
     // Handle edge cases
     ucomiss(SRC1, SRC1);
     jp(ret_label);
-    // Clamp to maximum range since we shift the value directly into the exponent.
-    minss(SRC1, xword[rip + input_max]);
-    maxss(SRC1, xword[rip + input_min]);
 
-    // Decompose input
-    movss(SCRATCH, SRC1);
-    movss(SCRATCH2, xword[rip + c0]); // Preload c0.
-    subss(SCRATCH, xword[rip + half]);
-    cvtss2si(eax, SCRATCH);
-    cvtsi2ss(SCRATCH, eax);
-    // SCRATCH now contains input rounded to the nearest integer.
-    add(eax, 0x7f);
-    subss(SRC1, SCRATCH);
-    // SRC1 contains input - round(input), which is in [-0.5, 0.5).
-    mulss(SCRATCH2, SRC1);
-    shl(eax, 23);
-    movd(SCRATCH, eax);
-    // SCRATCH contains 2^(round(input)).
+    // Decompose input:
+    // SCRATCH=2^round(input)
+    // SRC1=input-round(input) [-0.5, 0.5)
+    if (host_caps.has(Cpu::tAVX512F | Cpu::tAVX512VL)) {
+        // input - 0.5
+        vsubss(SCRATCH, SRC1, xword[rip + half]);
+
+        // trunc(input - 0.5)
+        vrndscaless(SCRATCH2, SCRATCH, SCRATCH, _MM_FROUND_TRUNC);
+
+        // SCRATCH = 1 * 2^(trunc(input - 0.5))
+        vscalefss(SCRATCH, ONE, SCRATCH2);
+
+        // SRC1 = input-trunc(input - 0.5)
+        vsubss(SRC1, SRC1, SCRATCH2);
+    } else {
+        // Clamp to maximum range since we shift the value directly into the exponent.
+        minss(SRC1, xword[rip + input_max]);
+        maxss(SRC1, xword[rip + input_min]);
+
+        if (host_caps.has(Cpu::tAVX)) {
+            vsubss(SCRATCH, SRC1, xword[rip + half]);
+        } else {
+            movss(SCRATCH, SRC1);
+            subss(SCRATCH, xword[rip + half]);
+        }
+
+        if (host_caps.has(Cpu::tSSE41)) {
+            roundss(SCRATCH, SCRATCH, _MM_FROUND_TRUNC);
+            cvtss2si(eax, SCRATCH);
+        } else {
+            cvtss2si(eax, SCRATCH);
+            cvtsi2ss(SCRATCH, eax);
+        }
+        // SCRATCH now contains input rounded to the nearest integer.
+        add(eax, 0x7f);
+        subss(SRC1, SCRATCH);
+        // SRC1 contains input - round(input), which is in [-0.5, 0.5).
+        shl(eax, 23);
+        movd(SCRATCH, eax);
+        // SCRATCH contains 2^(round(input)).
+    }
 
     // Complete computation of polynomial.
-    addss(SCRATCH2, xword[rip + c1]);
-    mulss(SCRATCH2, SRC1);
-    addss(SCRATCH2, xword[rip + c2]);
-    mulss(SCRATCH2, SRC1);
-    addss(SCRATCH2, xword[rip + c3]);
-    mulss(SRC1, SCRATCH2);
-    addss(SRC1, xword[rip + c4]);
+    movss(SCRATCH2, xword[rip + c0]);
+
+    if (host_caps.has(Cpu::tFMA)) {
+        vfmadd213ss(SCRATCH2, SRC1, xword[rip + c1]);
+        vfmadd213ss(SCRATCH2, SRC1, xword[rip + c2]);
+        vfmadd213ss(SCRATCH2, SRC1, xword[rip + c3]);
+        vfmadd213ss(SRC1, SCRATCH2, xword[rip + c4]);
+    } else {
+        mulss(SCRATCH2, SRC1);
+        addss(SCRATCH2, xword[rip + c1]);
+        mulss(SCRATCH2, SRC1);
+        addss(SCRATCH2, xword[rip + c2]);
+        mulss(SCRATCH2, SRC1);
+        addss(SCRATCH2, xword[rip + c3]);
+        mulss(SRC1, SCRATCH2);
+        addss(SRC1, xword[rip + c4]);
+    }
+
     mulss(SRC1, SCRATCH);
 
     // Duplicate result across vector