From bd82513199c4b17d8467b6171f0dc1e487e5acff Mon Sep 17 00:00:00 2001
From: Lioncash <mathew1800@gmail.com>
Date: Sat, 13 Apr 2019 00:12:25 -0400
Subject: [PATCH] frontend/ir_emitter: Add half-precision opcode for FPMulAdd

---
 src/backend/x64/emit_x64_floating_point.cpp   | 80 ++++++++++---------
 .../impl/simd_scalar_x_indexed_element.cpp    |  2 +-
 src/frontend/ir/ir_emitter.cpp                | 13 ++-
 src/frontend/ir/ir_emitter.h                  |  2 +-
 src/frontend/ir/microinstruction.cpp          |  1 +
 src/frontend/ir/opcodes.inc                   |  1 +
 6 files changed, 57 insertions(+), 42 deletions(-)

diff --git a/src/backend/x64/emit_x64_floating_point.cpp b/src/backend/x64/emit_x64_floating_point.cpp
index 5df30955..3fd10ced 100644
--- a/src/backend/x64/emit_x64_floating_point.cpp
+++ b/src/backend/x64/emit_x64_floating_point.cpp
@@ -608,54 +608,56 @@ template<size_t fsize>
 static void EmitFPMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
     using FPT = mp::unsigned_integer_of_size<fsize>;
 
-    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA)) {
-        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if constexpr (fsize != 16) {
+        if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA)) {
+            auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
-        Xbyak::Label end, fallback;
+            Xbyak::Label end, fallback;
 
-        const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
-        const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
-        const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]);
-        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
+            const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
+            const Xbyak::Xmm operand3 = ctx.reg_alloc.UseXmm(args[2]);
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
 
-        code.movaps(result, operand1);
-        FCODE(vfmadd231s)(result, operand2, operand3);
+            code.movaps(result, operand1);
+            FCODE(vfmadd231s)(result, operand2, operand3);
 
-        code.movaps(tmp, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
-        code.andps(tmp, result);
-        FCODE(ucomis)(tmp, code.MConst(xword, fsize == 32 ? f32_smallest_normal : f64_smallest_normal));
-        code.jz(fallback, code.T_NEAR);
-        code.L(end);
+            code.movaps(tmp, code.MConst(xword, fsize == 32 ? f32_non_sign_mask : f64_non_sign_mask));
+            code.andps(tmp, result);
+            FCODE(ucomis)(tmp, code.MConst(xword, fsize == 32 ? f32_smallest_normal : f64_smallest_normal));
+            code.jz(fallback, code.T_NEAR);
+            code.L(end);
 
-        code.SwitchToFarCode();
-        code.L(fallback);
+            code.SwitchToFarCode();
+            code.L(fallback);
 
-        code.sub(rsp, 8);
-        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.movq(code.ABI_PARAM1, operand1);
-        code.movq(code.ABI_PARAM2, operand2);
-        code.movq(code.ABI_PARAM3, operand3);
-        code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR().Value());
+            code.sub(rsp, 8);
+            ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.movq(code.ABI_PARAM1, operand1);
+            code.movq(code.ABI_PARAM2, operand2);
+            code.movq(code.ABI_PARAM3, operand3);
+            code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR().Value());
 #ifdef _WIN32
-        code.sub(rsp, 16 + ABI_SHADOW_SPACE);
-        code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
-        code.mov(qword[rsp + ABI_SHADOW_SPACE], rax);
-        code.CallFunction(&FP::FPMulAdd<FPT>);
-        code.add(rsp, 16 + ABI_SHADOW_SPACE);
+            code.sub(rsp, 16 + ABI_SHADOW_SPACE);
+            code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
+            code.mov(qword[rsp + ABI_SHADOW_SPACE], rax);
+            code.CallFunction(&FP::FPMulAdd<FPT>);
+            code.add(rsp, 16 + ABI_SHADOW_SPACE);
 #else
-        code.lea(code.ABI_PARAM5, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
-        code.CallFunction(&FP::FPMulAdd<FPT>);
+            code.lea(code.ABI_PARAM5, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
+            code.CallFunction(&FP::FPMulAdd<FPT>);
 #endif
-        code.movq(result, code.ABI_RETURN);
-        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.add(rsp, 8);
+            code.movq(result, code.ABI_RETURN);
+            ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.add(rsp, 8);
 
-        code.jmp(end, code.T_NEAR);
-        code.SwitchToNearCode();
+            code.jmp(end, code.T_NEAR);
+            code.SwitchToNearCode();
 
-        ctx.reg_alloc.DefineValue(inst, result);
-        return;
+            ctx.reg_alloc.DefineValue(inst, result);
+            return;
+        }
     }
 
     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
@@ -673,6 +675,10 @@ static void EmitFPMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
 #endif
 }
 
+void EmitX64::EmitFPMulAdd16(EmitContext& ctx, IR::Inst* inst) {
+    EmitFPMulAdd<16>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPMulAdd32(EmitContext& ctx, IR::Inst* inst) {
     EmitFPMulAdd<32>(code, ctx, inst);
 }
diff --git a/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp b/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp
index 693cf3d3..0d475c40 100644
--- a/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp
+++ b/src/frontend/A64/translate/impl/simd_scalar_x_indexed_element.cpp
@@ -36,7 +36,7 @@ bool MultiplyByElement(TranslatorVisitor& v, bool sz, Imm<1> L, Imm<1> M, Imm<4>
     const size_t esize = sz ? 64 : 32;
 
     const IR::U32U64 element = v.ir.VectorGetElement(esize, v.V(idxdsize, Vm), index);
-    const IR::U32U64 result = [&] {
+    const IR::U32U64 result = [&]() -> IR::U32U64 {
         IR::U32U64 operand1 = v.V_scalar(esize, Vn);
 
         if (extra_behavior == ExtraBehavior::None) {
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index 53b546c0..6d333b91 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -1867,13 +1867,20 @@ U32U64 IREmitter::FPMul(const U32U64& a, const U32U64& b, bool fpcr_controlled)
     }
 }
 
-U32U64 IREmitter::FPMulAdd(const U32U64& a, const U32U64& b, const U32U64& c, bool fpcr_controlled) {
+U16U32U64 IREmitter::FPMulAdd(const U16U32U64& a, const U16U32U64& b, const U16U32U64& c, bool fpcr_controlled) {
     ASSERT(fpcr_controlled);
     ASSERT(a.GetType() == b.GetType());
-    if (a.GetType() == Type::U32) {
+
+    switch (a.GetType()) {
+    case Type::U16:
+        return Inst<U16>(Opcode::FPMulAdd16, a, b, c);
+    case Type::U32:
         return Inst<U32>(Opcode::FPMulAdd32, a, b, c);
-    } else {
+    case Type::U64:
         return Inst<U64>(Opcode::FPMulAdd64, a, b, c);
+    default:
+        UNREACHABLE();
+        return U16U32U64{};
     }
 }
 
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index 291712ee..3ef31516 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -301,7 +301,7 @@ public:
     U32U64 FPMin(const U32U64& a, const U32U64& b, bool fpcr_controlled);
     U32U64 FPMinNumeric(const U32U64& a, const U32U64& b, bool fpcr_controlled);
     U32U64 FPMul(const U32U64& a, const U32U64& b, bool fpcr_controlled);
-    U32U64 FPMulAdd(const U32U64& addend, const U32U64& op1, const U32U64& op2, bool fpcr_controlled);
+    U16U32U64 FPMulAdd(const U16U32U64& addend, const U16U32U64& op1, const U16U32U64& op2, bool fpcr_controlled);
     U32U64 FPMulX(const U32U64& a, const U32U64& b);
     U16U32U64 FPNeg(const U16U32U64& a);
     U32U64 FPRecipEstimate(const U32U64& a);
diff --git a/src/frontend/ir/microinstruction.cpp b/src/frontend/ir/microinstruction.cpp
index 46d5fb87..d746c9e0 100644
--- a/src/frontend/ir/microinstruction.cpp
+++ b/src/frontend/ir/microinstruction.cpp
@@ -269,6 +269,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
     case Opcode::FPMinNumeric64:
     case Opcode::FPMul32:
     case Opcode::FPMul64:
+    case Opcode::FPMulAdd16:
     case Opcode::FPMulAdd32:
     case Opcode::FPMulAdd64:
     case Opcode::FPRecipEstimate32:
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index b5e8f691..3eaa41ea 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -479,6 +479,7 @@ OPCODE(FPMinNumeric32,                                      U32,            U32,
 OPCODE(FPMinNumeric64,                                      U64,            U64,            U64                                             )
 OPCODE(FPMul32,                                             U32,            U32,            U32                                             )
 OPCODE(FPMul64,                                             U64,            U64,            U64                                             )
+OPCODE(FPMulAdd16,                                          U16,            U16,            U16,            U16                             )
 OPCODE(FPMulAdd32,                                          U32,            U32,            U32,            U32                             )
 OPCODE(FPMulAdd64,                                          U64,            U64,            U64,            U64                             )
 OPCODE(FPMulX32,                                            U32,            U32,            U32                                             )