import itertools
import math
import re

from sympy import simplify


def generate_func_set(num_x, num_op, show_funx=False, op_sets=('+', '-', '*', '/'), rep=True):
    """Enumerate distinct expressions using all ``num_x`` variables and ``num_op`` operators.

    Every operator sequence of length ``num_op`` drawn from ``op_sets`` is
    combined with every variable order of length ``num_op + 1`` over
    ``x0 .. x{num_x-1}`` that uses each variable at least once. When ``rep`` is
    truthy, variables may repeat and every parenthesization of the flat
    expression is also tried. Otherwise variables are distinct and only the flat
    expression is used. Candidates are de-duplicated by their
    ``sympy.simplify`` form.

    e.g. ``num_x=3, num_op=2`` produces op lists such as ``['+', '-']`` and
    variable orders such as ``[0, 1, 2]``.

    Args:
        num_x: Number of distinct variables ``x0 .. x{num_x-1}``.
        num_op: Number of binary operators per expression.
        show_funx: Unused; kept for backward compatibility.
        op_sets: Operator symbols to draw from.
        rep: Allow repeated variables and try all parenthesizations.

    Returns:
        ``(op_lists, x_lists, exprs)``:
        - the operator sequences;
        - the variable orders;
        - the unique expression strings (spaces removed).

        A candidate is dropped when its simplified form contains none of
        ``op_sets``, or fewer than ``num_x`` distinct variables. If the
        simplified form contains a bare integer token (e.g. ``x0**2``), the
        unsimplified candidate string is kept instead of the simplified one.
    """
    # Non-operator tokens (variables, numbers, names) of an expression string.
    operand_re = re.compile(r"[^-+*/()\s]+")

    def bracket(sub):
        # Parenthesize a sub-expression that combines more than one variable.
        return '(' + sub + ')' if sub.count('x') > 1 else sub

    def parenthesize(tokens, lo, hi):
        # Every fully-bracketed form of tokens[lo..hi]. Variables sit at even
        # offsets from ``lo`` and operators at odd offsets; each operator is
        # tried in turn as the top-level split.
        if lo > hi:
            return []
        forms = []
        for i in range(lo + 1, hi, 2):
            # Bracket each operand exactly once per split so parentheses never
            # nest redundantly.
            lefts = [bracket(s) for s in parenthesize(tokens, lo, i - 1)]
            rights = [bracket(s) for s in parenthesize(tokens, i + 1, hi)]
            forms.extend(a + tokens[i] + b for a in lefts for b in rights)
        # A lone variable has no operator to split on.
        return forms or [tokens[lo]]

    # Distinct operator sequences of length num_op, in op_sets order.
    op_lists = []
    seen_ops = set()
    for ops in itertools.product(op_sets, repeat=num_op):
        if ops not in seen_ops:
            seen_ops.add(ops)
            op_lists.append(list(ops))

    # Variable orders of length num_op + 1 that use every one of the num_x variables.
    var_ids = range(num_x)
    if rep:
        orders = itertools.product(var_ids, repeat=num_op + 1)
    else:
        orders = itertools.permutations(var_ids, num_op + 1)
    x_lists = [list(order) for order in orders if len(set(order)) == num_x]

    exprs = []
    seen_forms = set()  # simplified sympy expressions already considered
    for op_list in op_lists:
        for x_list in x_lists:
            names = [f"x{v}" for v in x_list]
            tokens = [names[0]]
            for op, name in zip(op_list, names[1:]):
                tokens += [op, name]
            flat = ''.join(tokens)

            candidates = [flat]
            if rep:
                candidates += parenthesize(tokens, 0, len(tokens) - 1)

            for cand in candidates:
                symp = simplify(cand)
                text = str(symp)
                # Out of function space: simplified to something with no operator
                # (e.g. "x0-x0" -> 0).
                if not any(op in text for op in op_sets):
                    continue
                if symp in seen_forms:
                    continue
                seen_forms.add(symp)

                operands = operand_re.findall(text)
                # Every variable must survive simplification.
                if len({w for w in operands if 'x' in w}) < num_x:
                    continue
                # Constants are not allowed in generated formulas (they are only
                # added manually), so fall back to the unsimplified candidate.
                chosen = cand if any(w.isnumeric() for w in operands) else text
                exprs.append(chosen.replace(' ', ''))

    return op_lists, x_lists, exprs


def generate_func_set_all(num_x, num_op, show_funx=False, op_sets=('+', '-', '*', '/'), rep=True):
    """Enumerate distinct expressions with between 1 and ``num_op`` operators.

    Operator sequences of every length 1..``num_op`` are drawn from ``op_sets``.
    Each one is paired with the variable orders that are exactly one longer:
    - shorter orders (length 1..``num_op``) may use any subset of
      ``x0 .. x{num_x-1}``;
    - full-length orders (``num_op + 1``) must use all ``num_x`` variables.

    When ``rep`` is truthy, variables may repeat and every parenthesization is
    also tried. Otherwise variables are distinct and only the flat expression is
    used. Candidates are de-duplicated by their ``sympy.simplify`` form. Unlike
    ``generate_func_set``, a simplified expression is NOT required to keep every
    variable.

    Args:
        num_x: Number of distinct variables ``x0 .. x{num_x-1}``.
        num_op: Maximum number of binary operators per expression.
        show_funx: Unused; kept for backward compatibility.
        op_sets: Operator symbols to draw from.
        rep: Allow repeated variables and try all parenthesizations.

    Returns:
        ``(op_lists, x_lists, exprs)``:
        - operator sequences and variable orders, both in depth-first prefix
          order;
        - the unique expression strings (spaces removed).

        A candidate is dropped when its simplified form contains none of
        ``op_sets``. If the simplified form contains a bare integer token, the
        unsimplified candidate string is kept instead.
    """
    # Non-operator tokens (variables, numbers, names) of an expression string.
    operand_re = re.compile(r"[^-+*/()\s]+")

    def bracket(sub):
        # Parenthesize a sub-expression that combines more than one variable.
        return '(' + sub + ')' if sub.count('x') > 1 else sub

    def parenthesize(tokens, lo, hi):
        # Every fully-bracketed form of tokens[lo..hi]. Variables sit at even
        # offsets from ``lo`` and operators at odd offsets.
        if lo > hi:
            return []
        forms = []
        for i in range(lo + 1, hi, 2):
            # Bracket each operand exactly once per split so parentheses never
            # nest redundantly.
            lefts = [bracket(s) for s in parenthesize(tokens, lo, i - 1)]
            rights = [bracket(s) for s in parenthesize(tokens, i + 1, hi)]
            forms.extend(a + tokens[i] + b for a in lefts for b in rights)
        return forms or [tokens[lo]]

    op_lists = []
    seen_ops = set()

    def dfs_op(prefix):
        # Record every non-empty prefix (and [] itself when num_op == 0),
        # de-duplicated in case op_sets contains repeated symbols.
        if prefix or len(prefix) == num_op:
            key = tuple(prefix)
            if key not in seen_ops:
                seen_ops.add(key)
                op_lists.append(prefix)
        if len(prefix) < num_op:
            for op in op_sets:
                dfs_op(prefix + [op])

    x_lists = []
    full_len = num_op + 1

    def dfs_x(prefix):
        # Full-length orders must use all num_x variables; shorter non-empty
        # prefixes are kept unconditionally.
        if len(prefix) == full_len:
            if len(set(prefix)) == num_x:
                x_lists.append(prefix)
            return
        if prefix:
            # Each DFS path is visited once, so prefixes are already unique.
            x_lists.append(prefix)
        for v in range(num_x):
            if rep or v not in prefix:
                dfs_x(prefix + [v])

    # Generate all OP combinations, all X combinations.
    dfs_op([])
    dfs_x([])

    exprs = []
    seen_forms = set()  # simplified sympy expressions already considered
    for op_list in op_lists:
        for x_list in x_lists:
            if len(x_list) != len(op_list) + 1:
                continue
            names = [f"x{v}" for v in x_list]
            tokens = [names[0]]
            for op, name in zip(op_list, names[1:]):
                tokens += [op, name]
            flat = ''.join(tokens)

            candidates = [flat]
            if rep:
                candidates += parenthesize(tokens, 0, len(tokens) - 1)

            for cand in candidates:
                symp = simplify(cand)
                text = str(symp)
                # Out of function space: simplified to something with no operator.
                if not any(op in text for op in op_sets):
                    continue
                if symp in seen_forms:
                    continue
                seen_forms.add(symp)

                operands = operand_re.findall(text)
                # Constants are not allowed in generated formulas (they are only
                # added manually), so fall back to the unsimplified candidate.
                chosen = cand if any(w.isnumeric() for w in operands) else text
                exprs.append(chosen.replace(' ', ''))

    return op_lists, x_lists, exprs


def is_func(str_1, rets, vars_sets, prec=1e-1, return_loss=False):
    """Check whether expression ``str_1`` reproduces the targets ``rets``.

    For each sample ``vars_sets[k]``:
    - its values are bound by position to ``x0, x1, ...``;
    - ``str_1`` is evaluated with ``eval`` against those bindings;
    - the result is compared with ``rets[k]``.

    Args:
        str_1: Python expression string over ``x0, x1, ...``.
        rets: Per-sample targets. Each is either an int/float, or a string that
            ``eval``s to one. The sequence is not modified, so the same ``rets``
            can be reused across calls.
        vars_sets: Sequence of samples; each is a sequence of variable values.
        prec: Maximum allowed absolute error per sample.
        return_loss: If True, return ``(matched, loss)`` instead of a bool.

    Returns:
        ``matched``, or ``(matched, loss)`` when ``return_loss`` is set.
        On the first failing sample, ``loss`` is:
        - its absolute error;
        - ``1e20`` if evaluation raised or produced NaN;
        - ``abs(prediction)`` if the target is not an int/float.

        On success, ``loss`` is the absolute error of the last sample, or
        ``0.0`` for an empty ``vars_sets``.
    """
    def _result(matched, loss):
        return (matched, loss) if return_loss else matched

    loss = 0.0
    for k, var_set in enumerate(vars_sets):
        # Fresh namespace per sample so variables cannot leak between samples.
        var_dict = {f"x{i}": var for i, var in enumerate(var_set)}
        try:
            # SECURITY NOTE: eval executes arbitrary code; str_1 and string
            # targets must come from a trusted source.
            pred = eval(str_1, var_dict)
            target = rets[k]
            if isinstance(target, str):
                target = eval(target)
            if not isinstance(target, (int, float)):
                return _result(False, abs(pred))
            loss = abs(pred - target)
            invalid = math.isnan(loss)
        except Exception:
            # Division by zero, bad syntax, missing sample, etc.
            return _result(False, 1e20)
        # NaN compares False with everything, so it must be rejected explicitly.
        if invalid:
            return _result(False, 1e20)
        if loss > prec:
            return _result(False, loss)
    return _result(True, loss)