diff --git a/src/backend/x64/emit_x64_vector_floating_point.cpp b/src/backend/x64/emit_x64_vector_floating_point.cpp index 9662f825..a40bb79e 100644 --- a/src/backend/x64/emit_x64_vector_floating_point.cpp +++ b/src/backend/x64/emit_x64_vector_floating_point.cpp @@ -337,8 +337,12 @@ void EmitTwoOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* ins ctx.reg_alloc.DefineValue(inst, result); } +enum CheckInputNaN { + Yes, No, +}; + template class Indexer, typename Function> -void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn, typename NaNHandler::function_type nan_handler = NaNHandler::GetDefault()) { +void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn, CheckInputNaN check_input_nan = CheckInputNaN::No, typename NaNHandler::function_type nan_handler = NaNHandler::GetDefault()) { static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64"); auto args = ctx.reg_alloc.GetArgumentInfo(inst); @@ -371,15 +375,31 @@ void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* i const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]); const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm(); - code.movaps(nan_mask, xmm_b); code.movaps(result, xmm_a); - FCODE(cmpunordp)(nan_mask, xmm_a); + + if (check_input_nan == CheckInputNaN::Yes) { + if (code.HasAVX()) { + FCODE(vcmpunordp)(nan_mask, xmm_a, xmm_b); + } else { + code.movaps(nan_mask, xmm_b); + FCODE(cmpunordp)(nan_mask, xmm_a); + } + } + if constexpr (std::is_member_function_pointer_v) { (code.*fn)(result, xmm_b); } else { fn(result, xmm_b); } - FCODE(cmpunordp)(nan_mask, result); + + if (check_input_nan == CheckInputNaN::Yes) { + FCODE(cmpunordp)(nan_mask, result); + } else if (code.HasAVX()) { + FCODE(vcmpunordp)(nan_mask, result, result); + } else { + code.movaps(nan_mask, result); + FCODE(cmpunordp)(nan_mask, nan_mask); + } HandleNaNs(code, ctx, fpcr_controlled, {result, xmm_a, xmm_b}, nan_mask, nan_handler); @@ -951,7 +971,7 @@ static void EmitFPVectorMinMax(BlockOfCode& code, EmitContext& ctx, IR::Inst* in code.andnps(mask, eq); code.orps(result, mask); } - }); + }, CheckInputNaN::Yes); } void EmitX64::EmitFPVectorMax32(EmitContext& ctx, IR::Inst* inst) {