import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils import check_random_state
from .genetic import SymbolicRegressor
import numpy as np
import pandas as pd
import torch
from pmlb import fetch_data
from sklearn.model_selection import train_test_split
import ast

import sympy as sp
import re

from ..constr2 import const


def make_hashable(obj):
    """Return ``obj`` with every (nested) list replaced by a tuple.

    Anything that is not a ``list`` is handed back unchanged, so the result
    is usable as a dict key whenever the non-list leaves are hashable.
    """
    if not isinstance(obj, list):
        return obj
    return tuple(map(make_hashable, obj))


def in_str(preorder1, preorder, x_train_with_ones, y2_cache):
    """Check whether ``preorder1`` evaluates identically to any entry of ``preorder``.

    Args:
        preorder1: Expression passed to ``const`` together with
            ``x_train_with_ones`` (presumably a pre-order token list --
            confirm against ``const``).
        preorder: Collection of candidate expressions. An entry may be a
            list, an ``int`` (treated as a constant broadcast over the first
            column of the data) or any other single token (wrapped in a
            one-element list before evaluation).
        x_train_with_ones: Data handed to ``const``; must support
            ``[:, 0]`` indexing.
        y2_cache: Dict keyed by ``make_hashable(entry)``. It is updated in
            place with a flattened tuple of every freshly evaluated entry.

    Returns:
        ``(found, y2_cache)`` where ``found`` is True as soon as one entry
        yields an all-zero difference with the values of ``preorder1``.
    """
    if len(preorder) == 0:
        return False, y2_cache

    for candidate in preorder:
        cache_key = make_hashable(candidate)
        # Re-evaluated on every pass on purpose: ``const`` is opaque here, so
        # the original call sequence is kept as is.
        reference = const(preorder1, x_train_with_ones)

        if cache_key in y2_cache:
            # NOTE(review): a cache hit yields the flattened tuple stored
            # below, while a miss compares the raw evaluation result --
            # confirm both broadcast the same way against ``reference``.
            values = y2_cache[cache_key]
        else:
            if isinstance(candidate, list):
                values = const(candidate, x_train_with_ones)
            elif isinstance(candidate, int):
                values = np.full_like(x_train_with_ones[:, 0], candidate)
            else:
                values = const([candidate], x_train_with_ones)
            y2_cache[cache_key] = tuple(values.flatten())

        difference = reference - values
        if np.all(difference == 0):
            return True, y2_cache

    return False, y2_cache


def is_all_zeros_or_all_ones(tensor):
    """Tell whether every element of ``tensor`` is 0, or every element is 1.

    Returns the zero-dimensional boolean tensor produced by ``torch.all``
    (truthy when either condition holds), not a plain Python ``bool``.
    """
    only_zeros = torch.all(tensor == 0)
    only_ones = torch.all(tensor == 1)
    if only_zeros:
        return only_zeros
    return only_ones

def find_subtrees(expr, min_height, max_height):
    """Collect the sub-expressions of ``expr`` whose height lies in a range.

    ``expr`` is parsed with ``ast.parse(..., mode='eval')``. Height is
    defined as 1 for a leaf (name or constant), ``1 + max(argument heights)``
    for a function call and ``1 + operand height`` for a unary operation.

    Args:
        expr: Expression string made of nested function calls, names,
            constants and unary minus, e.g. ``"add(X0, mul(X1, 2))"``.
        min_height: Smallest height (inclusive) a subtree must have.
        max_height: Largest height (inclusive) a subtree may have.

    Returns:
        List of subtree strings in pre-order (a node before its children).
        Call arguments are rendered joined by ``", "``.

    Raises:
        SyntaxError: If ``expr`` is not a valid Python expression.
        ValueError: If a subtree that has to be rendered contains a node
            type or unary operator other than the ones listed above.
    """

    def get_height(node):
        if isinstance(node, ast.Call):
            # default=0 gives an argument-less call such as ``f()`` height 1
            # instead of letting max() raise on an empty sequence.
            return 1 + max((get_height(arg) for arg in node.args), default=0)
        if isinstance(node, ast.UnaryOp):
            return 1 + get_height(node.operand)
        return 1

    def get_subtrees(node):
        subtrees = []
        if min_height <= get_height(node) <= max_height:
            subtrees.append(ast_to_string(node))
        # Only calls and unary operations have children that are visited.
        if isinstance(node, ast.Call):
            for arg in node.args:
                subtrees.extend(get_subtrees(arg))
        elif isinstance(node, ast.UnaryOp):
            subtrees.extend(get_subtrees(node.operand))
        return subtrees

    def ast_to_string(node):
        if isinstance(node, ast.Call):
            func_name = ast_to_string(node.func)
            args = ", ".join(ast_to_string(arg) for arg in node.args)
            return f"{func_name}({args})"
        if isinstance(node, ast.Name):
            return node.id
        if isinstance(node, ast.Constant):
            return str(node.value)
        # Legacy numeric node emitted by old parsers. It is matched by class
        # name because touching the deprecated ``ast.Num`` alias warns on
        # recent Python versions and fails where the alias has been removed,
        # which used to break every unary-minus subtree reaching this point.
        if type(node).__name__ == "Num":
            return str(node.n)
        if isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                return f"-{ast_to_string(node.operand)}"
            raise ValueError(f"Unsupported unary operator: {type(node.op)}")
        raise ValueError(f"Unsupported node type: {type(node)}")

    tree = ast.parse(expr, mode='eval')
    return get_subtrees(tree.body)


# Module-level numeric constants. None of them is referenced by the code
# visible in this part of the file.
# EPSILON has the same value as the 0.001 cut-off that _protected_division
# and _protected_log hard-code.
EPSILON = 0.001
# NOTE(review): presumably an upper bound for exponent arguments -- confirm
# where (and whether) it is used.
EXP_THRESHOLD = 80.
# NOTE(review): presumably a finite stand-in for infinity -- confirm.
INF = 1e6


def _protected_division(x1, x2):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x2) > 0.001, np.divide(x1, x2), 1.)


def _protected_sqrt(x1):
    return np.sqrt(np.abs(x1))


def _protected_log(x1):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, np.log(np.abs(x1)), 0.)


def _sigmoid(x1):
    with np.errstate(over='ignore', under='ignore'):
        return 1 / (1 + np.exp(-x1))


def _id(x1):
    with np.errstate(over='ignore', under='ignore'):
        return x1


def split_expression(expression):
    """Break a call-style expression string into its tokens.

    Spaces are removed first, then the text is split on ``(``, ``)`` and
    ``,``; empty pieces produced by adjacent delimiters are dropped.
    """
    compact = expression.replace(" ", "")
    pieces = re.split(r'[(),]\s*', compact)
    return list(filter(None, pieces))


def trans(expression, X_feature):
    """Tokenise ``expression`` and replace every unrecognised token by ``'x0'``.

    A token is kept when it is one of the operator names below or a variable
    name ``x1`` .. ``x{X_feature + 1}``; anything else (for instance a
    numeric literal) becomes ``'x0'``.

    NOTE(review): the variable range holds one more name than ``X_feature``
    -- presumably deliberate; confirm against the callers.

    Returns:
        List of tokens in the order produced by ``split_expression``.
    """
    tokens = split_expression(expression)
    recognised = {f'x{idx}' for idx in range(1, X_feature + 2)}
    recognised.update(('add', 'sub', 'mul', 'div', 'sin', 'cos', 'sig', 'log',
                       'sqrt', 'id'))
    return [token if token in recognised else 'x0' for token in tokens]


def cul(expr, X):
    """Evaluate the expression string ``expr`` on the columns of ``X``.

    Column ``i`` of ``X`` is exposed to the expression as ``x{i + 1}`` (a
    column slice ``X[:, i:i + 1]``), together with the primitives ``add``,
    ``sub``, ``mul``, ``div``, ``log``, ``sqrt``, ``sin``, ``cos``, ``sig``
    and ``id``.

    Args:
        expr: Python expression string using the names listed above.
        X: Two-dimensional NumPy array (indexed as ``X[:, i:i + 1]``).

    Returns:
        A torch tensor holding the evaluated values, with every NaN and
        infinite entry replaced by 1. The tensor owns its memory, so ``X``
        is never modified.
    """
    namespace = {}
    for col in range(X.shape[1]):
        name = f'x{col + 1}'
        column = X[:, col:col + 1]
        # Kept for backward compatibility: the columns have always been
        # published as module globals too.
        # NOTE(review): eval() below only reads ``namespace``, so this write
        # looks redundant -- confirm nothing else relies on it.
        globals()[name] = column
        namespace[name] = column
    namespace.update({
        'add': np.add,
        'sub': np.subtract,
        'mul': np.multiply,
        'log': _protected_log,
        'sin': np.sin,
        'cos': np.cos,
        'sqrt': _protected_sqrt,
        'sig': _sigmoid,
        'div': _protected_division,
        'id': _id
    })
    # NOTE(security): eval() executes arbitrary Python; ``expr`` must only
    # ever come from trusted, internally generated expressions.
    result = eval(expr, namespace)
    # torch.from_numpy shares memory with its argument, and expressions such
    # as ``id(x1)`` evaluate to a view of X. Clone first so the in-place
    # NaN/inf replacement below cannot overwrite the caller's data.
    result_tensor = torch.from_numpy(result).clone()
    result_tensor[torch.isnan(result_tensor)] = 1
    result_tensor[torch.isinf(result_tensor)] = 1
    return result_tensor

def expend_expr(X_train, y_train, list, most, sub_list, sub_list_value, max_height, list1, list2,
                x_train_with_ones, y2_cache):
    """Grow ``sub_list`` with new sub-expressions found by a symbolic-regression run.

    A ``SymbolicRegressor`` is fitted on the data, the expressions of its
    ``mapping`` are visited in ascending order of their mapped value
    (presumably a fitness score -- confirm in ``SymbolicRegressor``), and
    their subtrees of height 2 .. ``max_height`` are considered one by one.
    A subtree is accepted when it mentions at least one variable, its token
    list is not already in ``sub_list``, ``in_str`` reports no match against
    ``list1``, ``list2`` or the subtrees accepted during this call, and its
    values are neither all zeros nor all ones.

    Args:
        X_train: Tensor of inputs; converted with ``.numpy()``.
        y_train: Tensor of targets; ``squeeze(1)`` is applied before
            ``.numpy()``.
        list: Expressions forwarded to the regressor as ``exprlist``; its
            length is used as the population size.
        most: Maximum number of subtrees to accept.
        sub_list: Token lists collected so far; extended in place.
        sub_list_value: Value tensors matching ``sub_list``; extended in place.
        max_height: Upper height bound forwarded to ``find_subtrees``.
        list1: First collection checked through ``in_str``.
        list2: Second collection checked through ``in_str``.
        x_train_with_ones: Data forwarded to ``in_str``.
        y2_cache: Evaluation cache threaded through every ``in_str`` call.

    Returns:
        ``(sub_list, sub_list_value, y2_cache)`` -- the very list objects
        that were passed in, plus the cache returned by the last ``in_str``
        call.
    """
    X = X_train.numpy()
    y = y_train.squeeze(1).numpy()
    X_feature = X.shape[1]

    def replace_subtree(subtree, X_feature):
        # Highest index first, so 'X1' can never rewrite the prefix of 'X10'.
        for idx in reversed(range(X_feature)):
            subtree = subtree.replace(f'X{idx}', f'x{idx + 1}')
        return subtree

    stringlist_save = [f'x{k}' for k in range(1, X_feature + 2)]
    # Aliases, not copies: the caller's lists are the ones being extended.
    sub_list_new = sub_list
    sub_list_new_value = sub_list_value

    est_gp = SymbolicRegressor(
        population_size=len(list),
        generations=20,
        stopping_criteria=0.01,
        p_crossover=0.5,
        p_subtree_mutation=0.1,
        p_hoist_mutation=0.1,
        p_point_mutation=0.3,
        verbose=0,
        function_set=('add', 'sub', 'mul', 'div', 'sin', 'cos', 'sig', 'log', 'sqrt', 'id'),
        metric='rmse',
        exprlist=list,
    )
    est_gp.fit(X, y)
    ranked_pairs = sorted(est_gp.mapping.items(), key=lambda pair: pair[1])
    sorted_mapping = dict(ranked_pairs)

    tmp_list = []
    count = 0
    min_height = 2
    for expr in sorted_mapping:
        if count == most:
            break

        for raw_subtree in find_subtrees(expr, min_height, max_height):
            subtree = replace_subtree(raw_subtree, X_feature)
            if count == most:
                break
            # Skip subtrees that mention no variable at all.
            if not any(name in subtree for name in stringlist_save):
                continue

            tran_expr = trans(subtree, X_feature)
            if tran_expr in sub_list_new:
                continue

            # Stop at the first collection that already holds an equivalent
            # expression; the cache is carried from one check to the next.
            already_known = False
            for known in (list1, list2, tmp_list):
                already_known, y2_cache = in_str(tran_expr.copy(), known.copy(),
                                                 x_train_with_ones, y2_cache)
                if already_known:
                    break
            if already_known:
                continue

            v = cul(subtree, X)
            if is_all_zeros_or_all_ones(v):
                continue

            sub_list_new.append(tran_expr)
            sub_list_new_value.append(v)
            count += 1
            tmp_list.append(tran_expr)

    return sub_list_new, sub_list_new_value, y2_cache
