import re

import torch
from pmlb import fetch_data
from sklearn.model_selection import train_test_split
import numpy as np
from constrnew import const




def split_expression(expression):
    """Tokenise a call-syntax expression string such as ``"add(x1,cos(x2))"``.

    All spaces are removed first; the string is then cut at every
    parenthesis and comma and the empty pieces are discarded, leaving the
    operator and operand names in prefix order.
    """
    compact = "".join(expression.split(" "))
    pieces = re.split(r'[(),]\s*', compact)
    return list(filter(None, pieces))

def simplify_expression(expression):
    """Remove every identity wrapper ``id(...)`` from a call-syntax expression.

    ``"add(id(x1),x2)"`` becomes ``"add(x1,x2)"``. Nested wrappers such as
    ``id(id(x1))`` are unwrapped completely. ``id(`` is only recognised at
    the start of a name, so identifiers that merely end in "id" (e.g.
    ``sigmoid(``) are left untouched.

    Raises:
        ValueError: if an ``id(`` has no matching closing parenthesis
            (previously this looped forever).
    """
    result = []
    i = 0
    n = len(expression)
    while i < n:
        prev = expression[i - 1] if i else ""
        # Only treat "id(" as the wrapper when it is not the tail of a longer name.
        starts_name = not (prev.isalnum() or prev == "_")
        if starts_name and expression.startswith("id(", i):
            start = i + 3
            depth = 1
            j = start
            while j < n and depth:
                if expression[j] == "(":
                    depth += 1
                elif expression[j] == ")":
                    depth -= 1
                j += 1
            if depth:
                raise ValueError(
                    f"unbalanced 'id(' at position {i} in {expression!r}"
                )
            # j is one past the matching ')'; recurse so nested id(...) is unwrapped too.
            result.append(simplify_expression(expression[start:j - 1]))
            i = j
        else:
            result.append(expression[i])
            i += 1

    return "".join(result)
class Node:
    """A node of an expression tree: one token plus its operand subtrees."""

    def __init__(self, val):
        self.val = val  # token: operator name, variable name, or numeric constant
        self.children = []  # operand nodes, in prefix (left-to-right) order

    def __repr__(self):
        return f"Node({self.val!r}, children={self.children!r})"


# Number of operands each recognised operator takes; any other token is a leaf.
_ARITY = {
    'add': 2, 'sub': 2, 'mul': 2, 'div': 2,
    'sin': 1, 'sig': 1, 'sqrt': 1, 'cos': 1, 'log': 1,
}


def build_tree(traversal):
    """Build an expression tree from a prefix (Polish-notation) token list.

    String tokens listed in ``_ARITY`` become internal nodes with that many
    children; every other token (variable names, numeric constants) becomes
    a leaf.

    Args:
        traversal: prefix token list. It is consumed: the tokens forming
            the returned tree are removed from its front, and any trailing
            tokens are left in place.

    Returns:
        The root ``Node``, or ``None`` if ``traversal`` is empty. If the
        list runs out mid-expression, the missing operands are ``None``.
    """
    if not traversal:
        return None

    pos = 0

    def _build():
        nonlocal pos
        if pos >= len(traversal):
            return None
        val = traversal[pos]
        pos += 1
        node = Node(val)
        arity = _ARITY.get(val, 0) if isinstance(val, str) else 0
        for _ in range(arity):
            node.children.append(_build())
        return node

    root = _build()
    # Keep the "consumes the list" contract with one O(n) delete instead of repeated pop(0).
    del traversal[:pos]
    return root

def split_tree(root):
    """Flatten the top-level ``add``/``sub`` structure of a tree into signed terms.

    Every maximal subtree whose root is neither ``add`` nor ``sub`` is a
    term. Signs are propagated down the tree: operands after the first
    operand of a ``sub`` flip the sign of everything beneath them, so
    ``sub(A, sub(B, C))`` yields A:+, B:-, C:+. (The previous
    reverse-the-flags approach was only correct for left-deep chains.)

    Returns:
        (subtrees, opflag): the term nodes in left-to-right order, and a
        parallel list where ``True`` means the term is added and ``False``
        means it is subtracted.
    """
    subtrees = []
    opflag = []

    def _split(node, positive):
        if node.val == 'add':
            for child in node.children:
                _split(child, positive)
        elif node.val == 'sub':
            # The minuend keeps the current sign; the subtrahend flips it.
            for idx, child in enumerate(node.children):
                _split(child, positive if idx == 0 else not positive)
        else:
            subtrees.append(node)
            opflag.append(positive)

    _split(root, True)
    return subtrees, opflag

def tree_to_prefix(node):
    """Serialise an expression tree back into a prefix token list.

    Nodes are visited depth-first, each node's ``val`` emitted before its
    children; falsy entries (e.g. ``None`` placeholders for missing
    operands) contribute nothing.
    """
    tokens = []
    pending = [node]
    while pending:
        current = pending.pop()
        if not current:
            continue
        tokens.append(current.val)
        # Push children right-to-left so the leftmost child is visited next.
        pending.extend(reversed(current.children))
    return tokens

def compose(subtree_setset, opflag):
    """Assemble signed terms into one prefix expression with a free coefficient per term.

    For terms T0..Tn-1 the result is, in prefix form,
    ``add(...add(mul(c0, T0), mul(c1, T1))..., mul(cn, 1.0))``: one binary
    ``add`` per term, each term scaled by a coefficient, plus a trailing
    bias ``mul(cn, 1.0)``. A coefficient is the placeholder ``'x0'`` when
    ``opflag[i]`` is truthy and ``mul('x0', -1)`` otherwise. With no terms
    the result is just ``mul('x0', 1.0)``.

    Args:
        subtree_setset: list of prefix token lists, one per term.
        opflag: per-term sign flags, indexed in parallel with
            ``subtree_setset``.

    Returns:
        (expr, x0_compose): the prefix token list, and the ordinal
        positions -- counting all ``'x0'`` tokens of ``expr`` left to
        right -- of the coefficient/bias placeholders introduced here.
        ``'x0'`` tokens already inside a term are kept in ``expr`` and
        counted, but not listed. (Previously they were dropped from
        ``expr``, which corrupted the expression.)
    """
    # n terms + bias need n binary 'add' nodes, all leading in prefix form.
    expr = ['add'] * len(subtree_setset)
    x0_compose = []
    n_x0 = 0

    for idx, term in enumerate(subtree_setset):
        expr.append('mul')
        x0_compose.append(n_x0)
        n_x0 += 1
        expr += ['x0'] if opflag[idx] else ['mul', 'x0', -1]
        for token in term:
            expr.append(token)
            if token == 'x0':
                n_x0 += 1  # a constant placeholder that belongs to the term itself

    # Trailing bias term mul(x0, 1.0).
    x0_compose.append(n_x0)
    expr += ['mul', 'x0', 1.0]
    return expr, x0_compose

def compose2(subtree_setset, opflag):
    """Join signed terms into a prefix expression with fixed +1/-1 coefficients.

    Produces ``add(...add(mul(s0, T0), mul(s1, T1))...)`` where ``si`` is
    ``1`` when ``opflag[i]`` is truthy and ``-1`` otherwise. No free
    coefficients or bias term are added. An empty term list yields ``[]``.
    """
    joiners = ['add'] * max(len(subtree_setset) - 1, 0)
    body = []
    for position, term in enumerate(subtree_setset):
        body.extend(('mul', 1 if opflag[position] else -1))
        body.extend(term)
    return joiners + body
def _bootstrap_subsets(data, num_splits, rng):
    """Draw ``num_splits`` bootstrap resamples of the rows of ``data``.

    Each resample is drawn with replacement using ``rng`` and has
    ``data.shape[0] // num_splits`` rows (at least one, so a tiny dataset
    never produces an empty subset).
    """
    total_samples = data.shape[0]
    subset_size = max(1, total_samples // num_splits)
    return [
        data[rng.choice(total_samples, size=subset_size, replace=True), :]
        for _ in range(num_splits)
    ]


def _coefficient_dispersion(terms, flags, datasets, n_features, mu, std):
    """Score how unstable the fitted coefficients of ``compose(terms, flags)`` are.

    The composed model is fitted with ``const`` on every subset in
    ``datasets``: feature columns ``[:n_features]``, target column
    ``n_features``. The coefficient placeholders reported by ``compose``
    are collected per subset. The score is the mean over coefficients of
    ``|variance / mean|`` across subsets. This is a variance-to-mean ratio,
    not the classic std/mean coefficient of variation. Lower means more
    stable; a coefficient whose mean is 0 makes the score inf/NaN.
    """
    model, coef_positions = compose(terms, flags)
    n_coef = len(terms) + 1  # one coefficient per term plus the bias
    rows = []
    for subset in datasets:
        xs = torch.tensor(subset[:, :n_features], dtype=torch.float32)
        ys = torch.tensor(subset[:, n_features], dtype=torch.float32).unsqueeze(1)
        fitted, _ = const(model, xs, ys, mu=mu, std=std)
        row = [0] * n_coef
        # compose's positions are ascending, so this keeps left-to-right order.
        usable = [k for k in coef_positions if k < len(fitted)]
        for slot, k in enumerate(usable):
            row[slot] = fitted[k]
        rows.append(row)
    coef = np.array(rows)
    mean = np.mean(coef, axis=0)
    variance = np.var(coef, axis=0)
    return np.mean(np.abs(variance / mean))


def fun(filename, random_state, infix, expression):
    """Prune redundant additive terms from a fitted symbolic-regression model.

    Steps:
      1. Load PMLB dataset ``filename`` and hold out 25% as a test split.
      2. Rebuild the model as a prefix token list. ``infix`` is stripped of
         ``id(...)`` wrappers and tokenised, and each ``'x0'`` placeholder
         is replaced by the next float taken from ``expression``.
      3. Split the model into its top-level signed ``add``/``sub`` terms.
         Duplicate terms are collapsed, keeping the sign of the first
         occurrence.
      4. Greedily remove one term per round:
         - For each candidate removal, refit the remaining terms (free
           coefficients plus bias, via ``compose``) on 10 bootstrap
           resamples and score coefficient instability.
         - Drop the term whose removal gives the most stable fit.
         - Refit the reduced model (``compose2``) on the training split.
         - Keep the removal only if the training score drops by less than
           ``tol``; otherwise stop.

    Args:
        filename: PMLB dataset name passed to ``fetch_data``.
        random_state: seed for the train/test split and the bootstrap
            resampling (an int, ``None``, or a ``np.random.RandomState``).
        infix: the model in call syntax, with constants written as ``x0``.
        expression: the same model in prefix form. Its float entries are
            used as the constant values. NOTE(review): assumes they appear
            in the same order as the ``x0`` placeholders in ``infix``.

    Returns:
        (best_expr, r2_test): the pruned model as a prefix token list, and
        the second value ``const`` returns for it on the test split
        (presumably R^2 -- confirm against ``const``).
    """
    X, y = fetch_data(filename, return_X_y=True, local_cache_dir="./datasets")
    x_train, x_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=random_state
    )
    X_train = torch.tensor(x_train, dtype=torch.float32)
    y_train = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
    X_test = torch.tensor(x_test, dtype=torch.float32)
    y_test = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
    mu = 0
    std = 1
    tol = 0.01  # largest tolerated drop in training score when removing a term
    num_splits = 10

    # Substitute the fitted constants for the 'x0' placeholders of the infix form.
    const0 = [tok for tok in expression if isinstance(tok, float)]
    expr = split_expression(simplify_expression(infix))
    n_used = 0
    for i, tok in enumerate(expr):
        if tok == 'x0':
            expr[i] = const0[n_used]
            n_used += 1

    _, best_r2 = const(expr, X_train, y_train, mu=mu, std=std)
    best_expr = expr

    # Deduplicate terms and their signs together so the two lists stay index-aligned.
    subtrees, signs = split_tree(build_tree(expr.copy()))
    subtree_set = []
    opflag = []
    for subtree, sign in zip(subtrees, signs):
        prefix = tree_to_prefix(subtree)
        if prefix not in subtree_set:
            subtree_set.append(prefix)
            opflag.append(sign)

    # Seed the bootstrap from random_state so runs are reproducible.
    rng = (
        random_state
        if isinstance(random_state, np.random.RandomState)
        else np.random.RandomState(random_state)
    )
    train_data = np.hstack((X_train.numpy(), y_train.numpy()))
    datasets = _bootstrap_subsets(train_data, num_splits, rng)
    n_features = x_train.shape[1]

    for _ in range(len(subtree_set) - 1):
        scores = [
            _coefficient_dispersion(
                [t for k, t in enumerate(subtree_set) if k != j],
                [f for k, f in enumerate(opflag) if k != j],
                datasets, n_features, mu, std,
            )
            for j in range(len(subtree_set))
        ]
        # argsort sorts NaN scores last, so a NaN candidate is chosen only if all are NaN.
        delete_id = np.argsort(scores)[0]
        kept_terms = [t for k, t in enumerate(subtree_set) if k != delete_id]
        kept_flags = [f for k, f in enumerate(opflag) if k != delete_id]
        candidate = compose2(kept_terms, kept_flags)

        _, r2 = const(candidate, X_train, y_train, mu=mu, std=std)
        if best_r2 - r2 < tol:
            best_r2, best_expr = r2, candidate
            subtree_set, opflag = kept_terms, kept_flags
        else:
            break

    _, r2_test = const(best_expr, X_test, y_test, mu=mu, std=std)
    return best_expr, r2_test

if __name__ == "__main__":
    # Example run: prune a fitted model for the PMLB "feynman_II_15_5" dataset.
    # Guarded so importing this module does not trigger a dataset download and fit.
    best_expr, r2_test = fun(
        filename="feynman_II_15_5",
        random_state=14423,
        infix="add(sub(sub(cos(x0),mul(x2,mul(x1,cos(x3)))),div(x0,x3)),div(x0,add(mul(x2,cos(x3)),mul(x1,cos(x3)))))",
        expression=['add', 'sub', 'sub', 'cos', 1.5737473488907576, 'mul', 'x2', 'mul', 'x1', 'cos', 'x3', 'div', -0.03794409024877454, 'x3', 'div', 0.0012612701211745549, 'add', 'mul', 'x2', 'cos', 'x3', 'mul', 'x1', 'cos', 'x3'],
    )
    print(best_expr)
    print(r2_test)

