diff --git a/src/backend/x64/emit_x64_floating_point.cpp b/src/backend/x64/emit_x64_floating_point.cpp
index 8543bc94..8386338d 100644
--- a/src/backend/x64/emit_x64_floating_point.cpp
+++ b/src/backend/x64/emit_x64_floating_point.cpp
@@ -946,52 +946,54 @@ template<size_t fsize>
 static void EmitFPRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
     using FPT = mp::unsigned_integer_of_size<fsize>;

-    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
-        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if constexpr (fsize != 16) {
+        if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
+            auto args = ctx.reg_alloc.GetArgumentInfo(inst);

-        Xbyak::Label end, fallback;
+            Xbyak::Label end, fallback;

-        const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
-        const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
-        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
+            const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();

-        code.vmovaps(result, code.MConst(xword, FP::FPValue<FPT, false, 0, 3>()));
-        FCODE(vfnmadd231s)(result, operand1, operand2);
+            code.vmovaps(result, code.MConst(xword, FP::FPValue<FPT, false, 0, 3>()));
+            FCODE(vfnmadd231s)(result, operand1, operand2);

-        // Detect if the intermediate result is infinity or NaN or nearly an infinity.
-        // Why do we need to care about infinities? This is because x86 doesn't allow us
-        // to fuse the divide-by-two with the rest of the FMA operation. Therefore the
-        // intermediate value may overflow and we would like to handle this case.
-        const Xbyak::Reg32 tmp = ctx.reg_alloc.ScratchGpr().cvt32();
-        code.vpextrw(tmp, result, fsize == 32 ? 1 : 3);
-        code.and_(tmp.cvt16(), fsize == 32 ? 0x7f80 : 0x7ff0);
-        code.cmp(tmp.cvt16(), fsize == 32 ? 0x7f00 : 0x7fe0);
-        ctx.reg_alloc.Release(tmp);
+            // Detect if the intermediate result is infinity or NaN or nearly an infinity.
+            // Why do we need to care about infinities? This is because x86 doesn't allow us
+            // to fuse the divide-by-two with the rest of the FMA operation. Therefore the
+            // intermediate value may overflow and we would like to handle this case.
+            const Xbyak::Reg32 tmp = ctx.reg_alloc.ScratchGpr().cvt32();
+            code.vpextrw(tmp, result, fsize == 32 ? 1 : 3);
+            code.and_(tmp.cvt16(), fsize == 32 ? 0x7f80 : 0x7ff0);
+            code.cmp(tmp.cvt16(), fsize == 32 ? 0x7f00 : 0x7fe0);
+            ctx.reg_alloc.Release(tmp);

-        code.jae(fallback, code.T_NEAR);
+            code.jae(fallback, code.T_NEAR);

-        FCODE(vmuls)(result, result, code.MConst(xword, FP::FPValue<FPT, false, -1, 1>()));
-        code.L(end);
+            FCODE(vmuls)(result, result, code.MConst(xword, FP::FPValue<FPT, false, -1, 1>()));
+            code.L(end);

-        code.SwitchToFarCode();
-        code.L(fallback);
+            code.SwitchToFarCode();
+            code.L(fallback);

-        code.sub(rsp, 8);
-        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.movq(code.ABI_PARAM1, operand1);
-        code.movq(code.ABI_PARAM2, operand2);
-        code.mov(code.ABI_PARAM3.cvt32(), ctx.FPCR().Value());
-        code.lea(code.ABI_PARAM4, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
-        code.CallFunction(&FP::FPRSqrtStepFused<FPT>);
-        code.movq(result, code.ABI_RETURN);
-        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.add(rsp, 8);
+            code.sub(rsp, 8);
+            ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.movq(code.ABI_PARAM1, operand1);
+            code.movq(code.ABI_PARAM2, operand2);
+            code.mov(code.ABI_PARAM3.cvt32(), ctx.FPCR().Value());
+            code.lea(code.ABI_PARAM4, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
+            code.CallFunction(&FP::FPRSqrtStepFused<FPT>);
+            code.movq(result, code.ABI_RETURN);
+            ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.add(rsp, 8);

-        code.jmp(end, code.T_NEAR);
-        code.SwitchToNearCode();
+            code.jmp(end, code.T_NEAR);
+            code.SwitchToNearCode();

-        ctx.reg_alloc.DefineValue(inst, result);
-        return;
+            ctx.reg_alloc.DefineValue(inst, result);
+            return;
+        }
     }

     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
@@ -1001,6 +1003,10 @@ static void EmitFPRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst*
     code.CallFunction(&FP::FPRSqrtStepFused<FPT>);
 }

+void EmitX64::EmitFPRSqrtStepFused16(EmitContext& ctx, IR::Inst* inst) {
+    EmitFPRSqrtStepFused<16>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPRSqrtStepFused32(EmitContext& ctx, IR::Inst* inst) {
     EmitFPRSqrtStepFused<32>(code, ctx, inst);
 }
diff --git a/src/backend/x64/emit_x64_vector_floating_point.cpp b/src/backend/x64/emit_x64_vector_floating_point.cpp
index 7023ce35..5d31418b 100644
--- a/src/backend/x64/emit_x64_vector_floating_point.cpp
+++ b/src/backend/x64/emit_x64_vector_floating_point.cpp
@@ -1273,51 +1273,57 @@ static void EmitRSqrtStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* in
         }
     };

-    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
-        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    if constexpr (fsize != 16) {
+        if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
+            auto args = ctx.reg_alloc.GetArgumentInfo(inst);

-        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
-        const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
-        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
-        const Xbyak::Xmm mask = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
+            const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm mask = ctx.reg_alloc.ScratchXmm();

-        Xbyak::Label end, fallback;
+            Xbyak::Label end, fallback;

-        code.vmovaps(result, GetVectorOf<fsize, false, 0, 3>(code));
-        FCODE(vfnmadd231p)(result, operand1, operand2);
+            code.vmovaps(result, GetVectorOf<fsize, false, 0, 3>(code));
+            FCODE(vfnmadd231p)(result, operand1, operand2);

-        // An explanation for this is given in EmitFPRSqrtStepFused.
-        code.vmovaps(mask, GetVectorOf<fsize, fsize == 32 ? 0x7f000000 : 0x7fe0000000000000>(code));
-        FCODE(vandp)(tmp, result, mask);
-        if constexpr (fsize == 32) {
-            code.vpcmpeqd(tmp, tmp, mask);
-        } else {
-            code.vpcmpeqq(tmp, tmp, mask);
+            // An explanation for this is given in EmitFPRSqrtStepFused.
+            code.vmovaps(mask, GetVectorOf<fsize, fsize == 32 ? 0x7f000000 : 0x7fe0000000000000>(code));
+            FCODE(vandp)(tmp, result, mask);
+            if constexpr (fsize == 32) {
+                code.vpcmpeqd(tmp, tmp, mask);
+            } else {
+                code.vpcmpeqq(tmp, tmp, mask);
+            }
+            code.ptest(tmp, tmp);
+            code.jnz(fallback, code.T_NEAR);
+
+            FCODE(vmulp)(result, result, GetVectorOf<fsize, false, -1, 1>(code));
+            code.L(end);
+
+            code.SwitchToFarCode();
+            code.L(fallback);
+            code.sub(rsp, 8);
+            ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            EmitThreeOpFallbackWithoutRegAlloc(code, ctx, result, operand1, operand2, fallback_fn);
+            ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+            code.add(rsp, 8);
+            code.jmp(end, code.T_NEAR);
+            code.SwitchToNearCode();
+
+            ctx.reg_alloc.DefineValue(inst, result);
+            return;
         }
-        code.ptest(tmp, tmp);
-        code.jnz(fallback, code.T_NEAR);
-
-        FCODE(vmulp)(result, result, GetVectorOf<fsize, false, -1, 1>(code));
-        code.L(end);
-
-        code.SwitchToFarCode();
-        code.L(fallback);
-        code.sub(rsp, 8);
-        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        EmitThreeOpFallbackWithoutRegAlloc(code, ctx, result, operand1, operand2, fallback_fn);
-        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-        code.add(rsp, 8);
-        code.jmp(end, code.T_NEAR);
-        code.SwitchToNearCode();
-
-        ctx.reg_alloc.DefineValue(inst, result);
-        return;
     }

     EmitThreeOpFallback(code, ctx, inst, fallback_fn);
 }

+void EmitX64::EmitFPVectorRSqrtStepFused16(EmitContext& ctx, IR::Inst* inst) {
+    EmitRSqrtStepFused<16>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPVectorRSqrtStepFused32(EmitContext& ctx, IR::Inst* inst) {
     EmitRSqrtStepFused<32>(code, ctx, inst);
 }
diff --git a/src/common/fp/op/FPRSqrtStepFused.cpp b/src/common/fp/op/FPRSqrtStepFused.cpp
index e3ecf2ae..84a193b4 100644
--- a/src/common/fp/op/FPRSqrtStepFused.cpp
+++ b/src/common/fp/op/FPRSqrtStepFused.cpp
@@ -19,9 +19,9 @@ template<typename FPT>
 FPT FPRSqrtStepFused(FPT op1, FPT op2, FPCR fpcr, FPSR& fpsr) {
     op1 = FPNeg(op1);

-    const auto [type1, sign1, value1] = FPUnpack<FPT>(op1, fpcr, fpsr);
-    const auto [type2, sign2, value2] = FPUnpack<FPT>(op2, fpcr, fpsr);
-
+    const auto [type1, sign1, value1] = FPUnpack<FPT>(op1, fpcr, fpsr);
+    const auto [type2, sign2, value2] = FPUnpack<FPT>(op2, fpcr, fpsr);
+
     if (const auto maybe_nan = FPProcessNaNs<FPT>(type1, type2, op1, op2, fpcr, fpsr)) {
         return *maybe_nan;
     }
@@ -37,7 +37,7 @@ FPT FPRSqrtStepFused(FPT op1, FPT op2, FPCR fpcr, FPSR& fpsr) {
     }

     if (inf1 || inf2) {
-        return FPInfo<FPT>::Infinity(sign1 != sign2);
+        return FPT(FPInfo<FPT>::Infinity(sign1 != sign2));
     }

     // result_value = (3.0 + (value1 * value2)) / 2.0
@@ -45,11 +45,12 @@ FPT FPRSqrtStepFused(FPT op1, FPT op2, FPCR fpcr, FPSR& fpsr) {
     result_value.exponent--;

     if (result_value.mantissa == 0) {
-        return FPInfo<FPT>::Zero(fpcr.RMode() == RoundingMode::TowardsMinusInfinity);
+        return FPT(FPInfo<FPT>::Zero(fpcr.RMode() == RoundingMode::TowardsMinusInfinity));
     }

     return FPRound<FPT>(result_value, fpcr, fpsr);
 }

+template u16 FPRSqrtStepFused<u16>(u16 op1, u16 op2, FPCR fpcr, FPSR& fpsr);
 template u32 FPRSqrtStepFused<u32>(u32 op1, u32 op2, FPCR fpcr, FPSR& fpsr);
 template u64 FPRSqrtStepFused<u64>(u64 op1, u64 op2, FPCR fpcr, FPSR& fpsr);
diff --git a/src/frontend/A64/decoder/a64.inc b/src/frontend/A64/decoder/a64.inc
index de1aa01e..a58a2e92 100644
--- a/src/frontend/A64/decoder/a64.inc
+++ b/src/frontend/A64/decoder/a64.inc
@@ -386,7 +386,7 @@ INST(FMULX_vec_2,           "FMULX",                    "01011
 INST(FCMEQ_reg_2,           "FCMEQ (register)",         "010111100z1mmmmm111001nnnnnddddd")
 INST(FRECPS_1,              "FRECPS",                   "01011110010mmmmm001111nnnnnddddd")
 INST(FRECPS_2,              "FRECPS",                   "010111100z1mmmmm111111nnnnnddddd")
-//INST(FRSQRTS_1,           "FRSQRTS",                  "01011110110mmmmm001111nnnnnddddd")
+INST(FRSQRTS_1,             "FRSQRTS",                  "01011110110mmmmm001111nnnnnddddd")
 INST(FRSQRTS_2,             "FRSQRTS",                  "010111101z1mmmmm111111nnnnnddddd")
 //INST(FCMGE_reg_1,         "FCMGE (register)",         "01111110010mmmmm001001nnnnnddddd")
 INST(FCMGE_reg_2,           "FCMGE (register)",         "011111100z1mmmmm111001nnnnnddddd")
@@ -576,7 +576,7 @@ INST(INS_elt,               "INS (element)",            "01101
 //INST(FMULX_vec_3,         "FMULX",                    "0Q001110010mmmmm000111nnnnnddddd")
 //INST(FCMEQ_reg_3,         "FCMEQ (register)",         "0Q001110010mmmmm001001nnnnnddddd")
 INST(FRECPS_3,              "FRECPS",                   "0Q001110010mmmmm001111nnnnnddddd")
-//INST(FRSQRTS_3,           "FRSQRTS",                  "0Q001110110mmmmm001111nnnnnddddd")
+INST(FRSQRTS_3,             "FRSQRTS",                  "0Q001110110mmmmm001111nnnnnddddd")
 //INST(FCMGE_reg_3,         "FCMGE (register)",         "0Q101110010mmmmm001001nnnnnddddd")
 //INST(FACGE_3,             "FACGE",                    "0Q101110010mmmmm001011nnnnnddddd")
 //INST(FABD_3,              "FABD",                     "0Q101110110mmmmm000101nnnnnddddd")
diff --git a/src/frontend/A64/translate/impl/simd_scalar_three_same.cpp b/src/frontend/A64/translate/impl/simd_scalar_three_same.cpp
index f4eaa6e1..68c15735 100644
--- a/src/frontend/A64/translate/impl/simd_scalar_three_same.cpp
+++ b/src/frontend/A64/translate/impl/simd_scalar_three_same.cpp
@@ -316,6 +316,17 @@ bool TranslatorVisitor::FRECPS_2(bool sz, Vec Vm, Vec Vn, Vec Vd) {
     return true;
 }

+bool TranslatorVisitor::FRSQRTS_1(Vec Vm, Vec Vn, Vec Vd) {
+    const size_t esize = 16;
+
+    const IR::U16 operand1 = V_scalar(esize, Vn);
+    const IR::U16 operand2 = V_scalar(esize, Vm);
+    const IR::U16 result = ir.FPRSqrtStepFused(operand1, operand2);
+
+    V_scalar(esize, Vd, result);
+    return true;
+}
+
 bool TranslatorVisitor::FRSQRTS_2(bool sz, Vec Vm, Vec Vn, Vec Vd) {
     const size_t esize = sz ? 64 : 32;

diff --git a/src/frontend/A64/translate/impl/simd_three_same.cpp b/src/frontend/A64/translate/impl/simd_three_same.cpp
index da8939d5..a22eaa6d 100644
--- a/src/frontend/A64/translate/impl/simd_three_same.cpp
+++ b/src/frontend/A64/translate/impl/simd_three_same.cpp
@@ -965,6 +965,18 @@ bool TranslatorVisitor::FRECPS_4(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
     return true;
 }

+bool TranslatorVisitor::FRSQRTS_3(bool Q, Vec Vm, Vec Vn, Vec Vd) {
+    const size_t esize = 16;
+    const size_t datasize = Q ? 128 : 64;
+
+    const IR::U128 operand1 = V(datasize, Vn);
+    const IR::U128 operand2 = V(datasize, Vm);
+    const IR::U128 result = ir.FPVectorRSqrtStepFused(esize, operand1, operand2);
+
+    V(datasize, Vd, result);
+    return true;
+}
+
 bool TranslatorVisitor::FRSQRTS_4(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
     if (sz && !Q) {
         return ReservedValue();
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index e9c655ab..6f9b8715 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -1997,11 +1997,20 @@ U16U32U64 IREmitter::FPRSqrtEstimate(const U16U32U64& a) {
     }
 }

-U32U64 IREmitter::FPRSqrtStepFused(const U32U64& a, const U32U64& b) {
-    if (a.GetType() == Type::U32) {
+U16U32U64 IREmitter::FPRSqrtStepFused(const U16U32U64& a, const U16U32U64& b) {
+    ASSERT(a.GetType() == b.GetType());
+
+    switch (a.GetType()) {
+    case Type::U16:
+        return Inst<U16>(Opcode::FPRSqrtStepFused16, a, b);
+    case Type::U32:
         return Inst<U32>(Opcode::FPRSqrtStepFused32, a, b);
+    case Type::U64:
+        return Inst<U64>(Opcode::FPRSqrtStepFused64, a, b);
+    default:
+        UNREACHABLE();
+        return U16U32U64{};
     }
-    return Inst<U64>(Opcode::FPRSqrtStepFused64, a, b);
 }

 U32U64 IREmitter::FPSqrt(const U32U64& a) {
@@ -2335,6 +2344,8 @@ U128 IREmitter::FPVectorRSqrtEstimate(size_t esize, const U128& a) {

 U128 IREmitter::FPVectorRSqrtStepFused(size_t esize, const U128& a, const U128& b) {
     switch (esize) {
+    case 16:
+        return Inst<U128>(Opcode::FPVectorRSqrtStepFused16, a, b);
     case 32:
         return Inst<U128>(Opcode::FPVectorRSqrtStepFused32, a, b);
     case 64:
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index 6fe44dab..0b80d924 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -310,7 +310,7 @@ public:
     U16U32U64 FPRecipStepFused(const U16U32U64& a, const U16U32U64& b);
     U16U32U64 FPRoundInt(const U16U32U64& a, FP::RoundingMode rounding, bool exact);
     U16U32U64 FPRSqrtEstimate(const U16U32U64& a);
-    U32U64 FPRSqrtStepFused(const U32U64& a, const U32U64& b);
+    U16U32U64 FPRSqrtStepFused(const U16U32U64& a, const U16U32U64& b);
     U32U64 FPSqrt(const U32U64& a);
     U32U64 FPSub(const U32U64& a, const U32U64& b, bool fpcr_controlled);
     U16 FPDoubleToHalf(const U64& a, FP::RoundingMode rounding);
diff --git a/src/frontend/ir/microinstruction.cpp b/src/frontend/ir/microinstruction.cpp
index 3577fee7..a6b6bee7 100644
--- a/src/frontend/ir/microinstruction.cpp
+++ b/src/frontend/ir/microinstruction.cpp
@@ -287,6 +287,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
     case Opcode::FPRSqrtEstimate16:
     case Opcode::FPRSqrtEstimate32:
     case Opcode::FPRSqrtEstimate64:
+    case Opcode::FPRSqrtStepFused16:
     case Opcode::FPRSqrtStepFused32:
     case Opcode::FPRSqrtStepFused64:
     case Opcode::FPSqrt32:
@@ -350,6 +351,7 @@ bool Inst::ReadsFromAndWritesToFPSRCumulativeExceptionBits() const {
     case Opcode::FPVectorRSqrtEstimate16:
     case Opcode::FPVectorRSqrtEstimate32:
     case Opcode::FPVectorRSqrtEstimate64:
+    case Opcode::FPVectorRSqrtStepFused16:
     case Opcode::FPVectorRSqrtStepFused32:
     case Opcode::FPVectorRSqrtStepFused64:
     case Opcode::FPVectorSqrt32:
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index 0adab017..fbe6c303 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -506,6 +506,7 @@ OPCODE(FPRoundInt64,                                        U64,            U64,
 OPCODE(FPRSqrtEstimate16,                                   U16,            U16                                             )
 OPCODE(FPRSqrtEstimate32,                                   U32,            U32                                             )
 OPCODE(FPRSqrtEstimate64,                                   U64,            U64                                             )
+OPCODE(FPRSqrtStepFused16,                                  U16,            U16,            U16                             )
 OPCODE(FPRSqrtStepFused32,                                  U32,            U32,            U32                             )
 OPCODE(FPRSqrtStepFused64,                                  U64,            U64,            U64                             )
 OPCODE(FPSqrt32,                                            U32,            U32                                             )
@@ -585,6 +586,7 @@ OPCODE(FPVectorRoundInt64,                                  U128,           U128
 OPCODE(FPVectorRSqrtEstimate16,                             U128,           U128                                            )
 OPCODE(FPVectorRSqrtEstimate32,                             U128,           U128                                            )
 OPCODE(FPVectorRSqrtEstimate64,                             U128,           U128                                            )
+OPCODE(FPVectorRSqrtStepFused16,                            U128,           U128,           U128                            )
 OPCODE(FPVectorRSqrtStepFused32,                            U128,           U128,           U128                            )
 OPCODE(FPVectorRSqrtStepFused64,                            U128,           U128,           U128                            )
 OPCODE(FPVectorSqrt32,                                      U128,           U128                                            )
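A note on the operation being wired up (illustration only, not part of the patch): the new FPRSqrtStepFused16/FPVectorRSqrtStepFused16 opcodes compute the FRSQRTS step value (3 - op1*op2) / 2, the same "result_value" formula visible in FPRSqrtStepFused.cpp above once op1 has been negated, and guest code multiplies that factor into its current estimate to refine 1/sqrt(x). The sketch below shows that iteration with plain IEEE doubles; it is not dynarmic code, rsqrt_step is a hypothetical stand-in for FP::FPRSqrtStepFused, and FPCR rounding, FPSR cumulative exception bits and ARM NaN handling are deliberately ignored, keeping only the zero-times-infinity special case.

#include <cmath>
#include <cstdio>

// One FRSQRTS-style refinement factor: (3 - a*b) / 2.
// ARM returns +1.5 when the operands are 0 and infinity (in either order)
// instead of letting 0 * inf produce a NaN, so the iteration stays usable.
static double rsqrt_step(double a, double b) {
    if ((a == 0.0 && std::isinf(b)) || (std::isinf(a) && b == 0.0)) {
        return 1.5;
    }
    return (3.0 - a * b) / 2.0;
}

int main() {
    const double x = 2.0;
    double est = 0.7;  // rough seed, the role FRSQRTE plays on real hardware
    for (int i = 0; i < 4; ++i) {
        // Newton-Raphson step: est <- est * (3 - x*est*est) / 2
        est *= rsqrt_step(x * est, est);
    }
    std::printf("approx 1/sqrt(2) = %.12f, reference = %.12f\n", est, 1.0 / std::sqrt(x));
    return 0;
}

The x86 fast path in the patch computes this same step as one fused negate-multiply-add against the constant 3.0 followed by a multiply by 0.5; since the divide-by-two cannot be fused into the FMA, the emitted code falls back to the FP::FPRSqrtStepFused soft-float helper whenever the intermediate 3 - a*b is NaN, infinite, or nearly infinite, as the comment in EmitFPRSqrtStepFused explains.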