import math
from utils.solis.helper import generate_func_set, generate_func_set_all, is_func

def _default_func_set(dataset, orig_nums):
    """Build the default candidate expression set for ``dataset``.

    Returns None when no default set applies: the dataset name contains
    neither "multiarith" nor "addsub", or it is addsub with a number count
    other than 2 or 3.
    """
    if "multiarith" in dataset:
        _, _, func_set = generate_func_set(
            num_x=3,
            num_op=2,
            show_funx=False,
            op_sets=['+', '-', '*', '/'],
            rep=True,
        )
        return func_set
    if "addsub" in dataset:
        # Decide on the arity first so only the needed set is generated.
        if len(orig_nums) == 3:
            _, _, func_set = generate_func_set_all(
                num_x=3,
                num_op=2,
                show_funx=False,
                op_sets=['+', '-'],
                rep=True,
            )
            return func_set
        if len(orig_nums) == 2:
            _, _, func_set = generate_func_set(
                num_x=2,
                num_op=1,
                show_funx=False,
                op_sets=['+', '-'],
                rep=True,
            )
            return func_set
    return None


def _select_expr(func_set, errors_cnt, losses_cnt):
    """Choose the best expression from ``func_set`` given its scores.

    Prefers the first candidate with the fewest errors among those whose
    error count is at most ``len(func_set)``. If none qualifies, falls back
    to the first candidate with the smallest accumulated loss. Returns None
    when ``func_set`` is empty.
    """
    # NOTE(review): the threshold is the number of candidates, not the number
    # of fp_results the error counts are measured against -- confirm intent.
    thresh = len(func_set)
    eligible = [k for k, cnt in enumerate(errors_cnt) if cnt <= thresh]
    if eligible:
        # min() returns the first minimal index, i.e. ties go to the earliest.
        return func_set[min(eligible, key=errors_cnt.__getitem__)]
    if losses_cnt:
        # Considers index 0 too (the original loop could never select it).
        return func_set[min(range(len(losses_cnt)), key=losses_cnt.__getitem__)]
    return None


def _calibrate(dataset, expr, orig_nums):
    """Evaluate ``expr`` with x0, x1, ... bound to ``orig_nums`` and round it.

    - "multiarith": round to 5 decimals, then return an int, rounding
      non-integral values up with ``math.ceil``.
    - "addsub": round to the largest number of decimal places appearing in
      ``str()`` of any original number (0 if none has a '.').
    - otherwise: the raw evaluation result.

    Any exception raised by evaluation or rounding propagates to the caller.
    """
    var_dict = {f"x{i}": num for i, num in enumerate(orig_nums)}
    # NOTE: eval runs arbitrary code -- expr must come from a trusted
    # candidate set, never from external/user input.
    value = eval(expr, var_dict)
    if "multiarith" in dataset:
        value = round(value, 5)
        value = int(value) if value == int(value) else math.ceil(value)
    elif "addsub" in dataset:
        decimals = max(
            (len(text.split('.')[-1]) for text in map(str, orig_nums) if '.' in text),
            default=0,
        )
        value = round(value, decimals)
    return value


def try_search(args, orig_nums, fp_results, func_set=None):
    """Select the expression that best explains ``fp_results`` and apply it.

    Every expression in ``func_set`` is scored against every entry of
    ``fp_results`` with ``is_func``. The score is the number of entries
    whose returned flag equals False, plus the accumulated absolute loss.
    The best expression (see ``_select_expr``) is then evaluated on
    ``orig_nums`` and rounded according to the dataset (see ``_calibrate``).

    Args:
        args: object whose ``dataset`` string chooses the default candidate
            set and the rounding rule (matched by the substrings
            "multiarith" / "addsub").
        orig_nums: numbers bound to x0, x1, ... when evaluating the chosen
            expression.
        fp_results: sequence of dicts with keys "fp_z" (a result value) and
            "fp_nums" (the numbers passed to ``is_func`` with it). It is
            iterated once per candidate, so it must be re-iterable.
        func_set: candidate expression strings; when None, a default set
            is generated from ``args.dataset``.

    Returns:
        ``(expression, calibrated_value)``, or ``(None, None)`` when no
        candidate set is available, it is empty, or evaluating/rounding
        the chosen expression raises.
    """
    if func_set is None:
        func_set = _default_func_set(args.dataset, orig_nums)
        if func_set is None:
            return None, None

    errors_cnt = [0] * len(func_set)
    losses_cnt = [0] * len(func_set)
    for k, expr in enumerate(func_set):
        for fp_result in fp_results:
            flag_, loss_ = is_func(
                expr,
                [str(fp_result["fp_z"])],
                [fp_result["fp_nums"]],
                return_loss=True,
            )
            # Explicit `== False` kept on purpose: is_func's flag type is not
            # visible here, and `not flag_` would also count None as an error.
            errors_cnt[k] += (flag_ == False)
            losses_cnt[k] += abs(loss_)

    expr_filter = _select_expr(func_set, errors_cnt, losses_cnt)
    if expr_filter is None:
        return None, None

    try:
        cali_pred = _calibrate(args.dataset, expr_filter, orig_nums)
    except Exception:
        # Best-effort: e.g. division by zero or a non-finite result.
        return None, None
    return expr_filter, cali_pred