from functools import partial

from pmlb import fetch_data
from sklearn.model_selection import train_test_split
import numpy as np
def _protected_division(x1, x2):
    """Return x1 / x2 where |x2| > 0.001, and 1.0 everywhere else."""
    # np.where evaluates both branches, so the raw division still runs on
    # near-zero denominators; silence the resulting warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x2) > 0.001, np.divide(x1, x2), 1.)


def _protected_sqrt(x1):
    """Return the square root of |x1|."""
    return np.sqrt(np.abs(x1))


def _protected_log(x1):
    """Return log(|x1|) where |x1| > 0.001, and 0.0 everywhere else."""
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, np.log(np.abs(x1)), 0.)


def _protected_inverse(x1):
    """Return 1 / x1 where |x1| > 0.001, and 0.0 everywhere else."""
    with np.errstate(divide='ignore', invalid='ignore'):
        # np.divide (rather than the `/` operator) so that a plain Python
        # scalar zero cannot raise ZeroDivisionError before np.where masks it.
        return np.where(np.abs(x1) > 0.001, np.divide(1., x1), 0.)


def _sigmoid(x1):
    """Return the logistic function 1 / (1 + exp(-x1))."""
    with np.errstate(over='ignore', under='ignore'):
        return 1 / (1 + np.exp(-x1))


def _identity(x1):
    """Return the argument unchanged."""
    return x1


# Operator name -> (callable, number of operands).  Only fixed-arity operators
# are listed because the prefix parser needs the operand count to know when an
# operator application is complete.
_OPERATORS = {
    'add': (np.add, 2),
    'sub': (np.subtract, 2),
    'mul': (np.multiply, 2),
    'div': (_protected_division, 2),
    'max': (np.maximum, 2),
    'min': (np.minimum, 2),
    'sqrt': (_protected_sqrt, 1),
    'log': (_protected_log, 1),
    'abs': (np.abs, 1),
    'neg': (np.negative, 1),
    'inv': (_protected_inverse, 1),
    'sin': (np.sin, 1),
    'cos': (np.cos, 1),
    'tan': (np.tan, 1),
    'sig': (_sigmoid, 1),
    'id': (_identity, 1),
}


def _resolve_terminal(token, X):
    """Turn one operand of the expression into a value usable by an operator.

    Non-string operands (numeric constants, already computed sub-results) are
    returned as they are.  A string starting with 'x' selects ``X[:, k]``
    where ``k`` is the integer after the 'x'.  Any other string must parse as
    a float constant.
    """
    if not isinstance(token, str):
        return token

    if token.startswith('x'):
        try:
            column = int(token[1:])
        except ValueError as err:
            raise ValueError(
                f"malformed variable token {token!r}; expected 'x<column index>'"
            ) from err
        try:
            return X[:, column]
        except IndexError as err:
            raise IndexError(
                f"cannot select column for variable token {token!r} "
                f"from x_train: {err}"
            ) from err

    try:
        return float(token)
    except ValueError as err:
        raise ValueError(f"unrecognised token {token!r} in expression") from err


def const(expr, x_train):
    """Evaluate a prefix-notation expression against the columns of x_train.

    Parameters
    ----------
    expr : sequence
        Tokens in prefix (operator-first) order, e.g.
        ``['mul', 'add', 'x0', 'x1', 'x0']`` for ``(x0 + x1) * x0``.
        A token is one of:

        * an operator name from ``_OPERATORS`` ('add', 'sub', 'mul', 'div',
          'max', 'min', 'sqrt', 'log', 'abs', 'neg', 'inv', 'sin', 'cos',
          'tan', 'sig', 'id');
        * a variable ``'x<k>'``, read as ``x_train[:, k]``;
        * a constant, either a non-string value used as-is or a string that
          parses as a float.
    x_train : array-like supporting ``x_train[:, k]``
        Data whose columns the variable tokens refer to.

    Returns
    -------
    The value of the expression.  Evaluation stops as soon as the root of the
    expression is complete, so any tokens after that point are ignored.
    ``None`` is returned when the tokens run out before the root operator has
    received all of its operands.

    Raises
    ------
    ValueError
        If ``expr`` is empty, or an operand that is actually used is a string
        that is neither a well-formed variable nor a float.
    IndexError
        If a variable token that is actually used cannot be read from
        ``x_train``.
    """
    if len(expr) == 0:
        raise ValueError("expr must contain at least one token")

    # Stack of partially applied operators; each frame is
    # [operator_name, operand, operand, ...].
    pending = []
    for token in expr:
        # The isinstance guard keeps unhashable operands (e.g. ndarrays) away
        # from the dict lookup.
        if isinstance(token, str) and token in _OPERATORS:
            pending.append([token])
            continue

        if not pending:
            # An operand with no enclosing operator can only be the very first
            # token: the whole expression is that single terminal.
            return _resolve_terminal(token, x_train)

        pending[-1].append(token)

        # Collapse every frame that now has all of its operands.  A finished
        # frame becomes an operand of its parent, which may finish in turn.
        while len(pending[-1]) - 1 == _OPERATORS[pending[-1][0]][1]:
            name, *operands = pending.pop()
            function = _OPERATORS[name][0]
            # Operands are resolved only now, at application time, so tokens
            # belonging to a never-completed operator are never inspected.
            value = function(
                *(_resolve_terminal(operand, x_train) for operand in operands)
            )
            if not pending:
                return value
            pending[-1].append(value)

    # Tokens exhausted while an operator was still waiting for operands.
    return None


