import torch
import numpy as np
import re


import numpy as np
from joblib import wrap_non_picklable_objects
from functools import partial
__all__ = ['make_function']

from scipy.optimize import minimize


class _Function(object):
    def __init__(self, function, name, arity):
        self.function = function
        self.name = name
        self.arity = arity

    def __call__(self, *args):
        return self.function(*args)


def make_function(*, function, name, arity, wrap=True):
    """Create a validated :class:`_Function` from a user-supplied callable.

    The callable is probed with length-10 vectors of ones, zeros and negative
    ones. It must accept ``arity`` arrays, return an array of shape ``(10,)``
    for the ones probe, and stay finite on the zeros and negatives probes.

    Parameters
    ----------
    function : callable
        Vectorised function of ``arity`` numpy arrays. For plain Python
        functions the declared positional-argument count must equal
        ``arity``. Callables without ``__code__`` (e.g. ``functools.partial``)
        are validated only through the probe calls.
    name : str
        Token name for the primitive.
    arity : int
        Number of array arguments the function takes.
    wrap : bool, default True
        Wrap ``function`` with joblib's ``wrap_non_picklable_objects``.

    Returns
    -------
    _Function

    Raises
    ------
    ValueError
        If an argument has the wrong type or any probe fails.
    """
    if not isinstance(arity, int):
        raise ValueError('arity must be an int, got %s' % type(arity))
    if not isinstance(function, np.ufunc):
        # Only plain Python functions expose a static argument count; other
        # callables fall through to the probe calls below.
        code = getattr(function, '__code__', None)
        if code is not None and code.co_argcount != arity:
            raise ValueError('arity %d does not match required number of '
                             'function arguments of %d.'
                             % (arity, code.co_argcount))
    if not isinstance(name, str):
        raise ValueError('name must be a string, got %s' % type(name))
    if not isinstance(wrap, bool):
        raise ValueError('wrap must be an bool, got %s' % type(wrap))

    # Evaluate the "ones" probe once and reuse the result for every check.
    ones = [np.ones(10) for _ in range(arity)]
    try:
        probe = function(*ones)
    except (ValueError, TypeError) as err:
        raise ValueError('supplied function %s does not support arity of %d.'
                         % (name, arity)) from err
    if not hasattr(probe, 'shape'):
        raise ValueError('supplied function %s does not return a numpy array.'
                         % name)
    if probe.shape != (10,):
        raise ValueError('supplied function %s does not return same shape as '
                         'input vectors.' % name)

    zeros = [np.zeros(10) for _ in range(arity)]
    if not np.all(np.isfinite(function(*zeros))):
        raise ValueError('supplied function %s does not have closure against '
                         'zeros in argument vectors.' % name)
    negatives = [-1 * np.ones(10) for _ in range(arity)]
    if not np.all(np.isfinite(function(*negatives))):
        raise ValueError('supplied function %s does not have closure against '
                         'negatives in argument vectors.' % name)

    if wrap:
        function = wrap_non_picklable_objects(function)
    return _Function(function=function, name=name, arity=arity)


def _protected_division(x1, x2):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x2) > 0.001, np.divide(x1, x2), 1.)


def _protected_sqrt(x1):
    return np.sqrt(np.abs(x1))


def _protected_log(x1):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, np.log(np.abs(x1)), 0.)


def _protected_inverse(x1):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, 1. / x1, 0.)


def _sigmoid(x1):
    with np.errstate(over='ignore', under='ignore'):
        return 1 / (1 + np.exp(-x1))


# Built-in primitives: each pairs a numpy-compatible callable with the token
# name used in programs and the number of array arguments it takes.
add2 = _Function(np.add, 'add', 2)
sub2 = _Function(np.subtract, 'sub', 2)
mul2 = _Function(np.multiply, 'mul', 2)
div2 = _Function(_protected_division, 'div', 2)
sqrt1 = _Function(_protected_sqrt, 'sqrt', 1)
log1 = _Function(_protected_log, 'log', 1)
neg1 = _Function(np.negative, 'neg', 1)
inv1 = _Function(_protected_inverse, 'inv', 1)
abs1 = _Function(np.abs, 'abs', 1)
max2 = _Function(np.maximum, 'max', 2)
min2 = _Function(np.minimum, 'min', 2)
sin1 = _Function(np.sin, 'sin', 1)
cos1 = _Function(np.cos, 'cos', 1)
tan1 = _Function(np.tan, 'tan', 1)
sig1 = _Function(_sigmoid, 'sig', 1)
id1 = _Function(np.copy, 'id', 1)


# Token name -> primitive, keyed by each primitive's own ``name``.
_function_map = {
    primitive.name: primitive
    for primitive in (add2, sub2, mul2, div2, sqrt1, log1, abs1, neg1, inv1,
                      max2, min2, sin1, cos1, tan1, sig1, id1)
}

class ConstOptimizer:
    """Interface for optimisers that fit a program's numeric constants.

    Constructor keyword arguments are kept on ``self.kwargs`` for subclasses.
    Subclasses implement ``__call__(f, x0)``; the base class raises
    ``NotImplementedError``.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, f, x0):
        raise NotImplementedError()


class Dummy(ConstOptimizer):
    """No-op optimiser: returns the initial guess unchanged."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, f, x0):
        # The objective is never evaluated.
        initial_guess = x0
        return initial_guess


class ScipyMinimize(ConstOptimizer):
    """Fit constants with ``scipy.optimize.minimize``.

    Constructor keyword arguments (e.g. ``method``, ``options``) are passed
    through to ``minimize`` on every call.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, f, x0):
        """Minimise ``f`` from ``x0`` and return scipy's full result object."""
        # Silence numpy divide-by-zero warnings raised while scipy evaluates f.
        with np.errstate(divide='ignore'):
            outcome = minimize(f, x0, **self.kwargs)
        return outcome
def make_const_optimizer(name, **kwargs):
    """Instantiate a constant optimiser by registry name.

    Parameters
    ----------
    name : str
        ``"scipy"`` for :class:`ScipyMinimize`, or ``"dummy"`` for the no-op
        :class:`Dummy`.
    **kwargs
        Forwarded to the optimiser constructor. For ``"scipy"`` they become
        keyword arguments of ``scipy.optimize.minimize``.

    Raises
    ------
    KeyError
        If ``name`` is not a registered optimiser.
    """
    const_optimizers = {
        "scipy": ScipyMinimize,
        "dummy": Dummy,
    }
    try:
        optimizer_cls = const_optimizers[name]
    except KeyError:
        raise KeyError('unknown const optimizer %r; expected one of %s'
                       % (name, sorted(const_optimizers))) from None
    return optimizer_cls(**kwargs)
    
class Program(object):
    """Expression stored as a flat list of tokens in prefix (Polish) order.

    Token kinds:

    * Operator tokens are those in ``_UNARY_OPERATORS`` / ``_BINARY_OPERATORS``.
    * ``'x0'`` marks a constant fitted by :meth:`const_optimize`.
    * Any other ``'x<k>'`` string reads column ``k - 1`` of the feature matrix.
    * Float tokens are literal constants repeated for every sample.
    """

    # Operator vocabulary shared by evaluation and infix printing.
    # NOTE(review): _function_map also defines abs/neg/inv/tan/max/min, which
    # are treated as terminals here -- confirm this restriction is intended.
    _UNARY_OPERATORS = frozenset({'sin', 'cos', 'sig', 'log', 'sqrt', 'id'})
    _BINARY_OPERATORS = frozenset({'add', 'sub', 'mul', 'div'})

    def __init__(self):
        self.program = []

        self.raw_fitness_ = None
        self.fitness_ = None
        self.infix_str = ""
        self.reward = 0
        self.actual_value = None
        self.r2 = 0
        self.consts = []
        self.actual_program = []
        kwargs = {'method': 'L-BFGS-B', 'options': {'gtol': 0.001}}
        self.cc = make_const_optimizer("scipy", **kwargs)

    def build_program(self, op, var1, var2=None):
        """Append the term ``op(var1[, var2])`` to the program.

        The first call makes the term the whole program. Later calls join the
        new term onto the existing expression with a top-level ``'add'``.
        ``var1``/``var2`` may be single tokens or token lists already in
        prefix order. ``var2`` is skipped when it is falsy.
        """
        if self.program:
            self.program.insert(0, "add")
        self.program.append(op)
        self._append_operand(var1)
        if var2:
            self._append_operand(var2)

    def _append_operand(self, operand):
        """Extend the program by a token list, or append a single token."""
        if isinstance(operand, list):
            self.program += operand
        else:
            self.program.append(operand)

    def str(self):
        """Return the infix string cached by :meth:`get_infix_str`."""
        return self.infix_str

    @classmethod
    def _arity(cls, token):
        """Return 2 for binary operators, 1 for unary operators, 0 for terminals."""
        if token in cls._BINARY_OPERATORS:
            return 2
        if token in cls._UNARY_OPERATORS:
            return 1
        return 0

    def get_infix_str(self):
        """Render ``self.program`` as nested calls into ``self.infix_str``.

        For example, ``['add', 'x1', 'x2']`` becomes ``'add(x1,x2)'``. The
        program is scanned right to left, so each operator finds its operands
        already rendered on the stack. An empty program yields ``""``.
        """
        stack = []
        for token in reversed(self.program):
            arity = self._arity(token)
            if arity == 0:
                stack.append(token)
                continue
            # Operands were pushed in reverse order; restore left-to-right.
            operands = stack[-arity:][::-1]
            del stack[-arity:]
            if arity == 1:
                stack.append("%s(%s)" % (token, operands[0]))
            else:
                stack.append("%s(%s,%s)" % (token, operands[0], operands[1]))
        self.infix_str = stack[0] if stack else ""

    @staticmethod
    def _terminal_value(token, X):
        """Resolve one operand to an array or value; None when unresolvable.

        ``'x<k>'`` strings become ``X[:, k - 1]``. A bad index prints
        ``"error"`` and yields None, and other strings also yield None.
        Floats are repeated once per row of ``X``. Anything else, such as an
        already-computed intermediate array, is returned as is.
        """
        if isinstance(token, str):
            if token.startswith('x'):
                try:
                    return X[:, int(token[1:]) - 1]
                except (ValueError, IndexError):
                    print("error")
            return None
        if isinstance(token, float):
            return np.repeat(token, X.shape[0])
        return token

    @classmethod
    def _evaluate_prefix(cls, expr, X):
        """Evaluate the first complete subtree at the head of token list ``expr``.

        Returns ``(value, last)``, where ``last`` is the index of the subtree's
        final token. Returns ``(None, None)`` when ``expr`` is empty or ends
        before the subtree is complete. A lone terminal at the head is its own
        subtree.
        """
        if not expr:
            return None, None
        if cls._arity(expr[0]) == 0:
            return cls._terminal_value(expr[0], X), 0
        frames = []  # each frame: [operator, operand, ...]
        for pos, token in enumerate(expr):
            if cls._arity(token):
                frames.append([token])
            else:
                frames[-1].append(token)
            # Reduce every frame that now holds all of its operands.
            while len(frames[-1]) == 1 + cls._arity(frames[-1][0]):
                op, *operands = frames.pop()
                values = [cls._terminal_value(o, X) for o in operands]
                value = _function_map[op](*[v for v in values if v is not None])
                if not frames:
                    return value, pos
                frames[-1].append(value)
        return None, None

    @staticmethod
    def _substitute_consts(tokens, consts):
        """Return a copy of ``tokens`` with successive ``'x0'`` slots set from ``consts``."""
        expr = list(tokens)
        k = 0
        for pos, token in enumerate(expr):
            if token == 'x0':
                expr[pos] = consts[k]  # IndexError if consts is too short
                k += 1
        return expr

    @staticmethod
    def _rmse(target, prediction):
        """Root-mean-square error between ``target`` and ``prediction``."""
        return np.sqrt(np.mean((target - prediction) ** 2))

    @staticmethod
    def _r2_score(output, target):
        """Coefficient of determination, clamped to 0 when non-positive, NaN or inf."""
        y_mean = torch.mean(target)
        r2 = 1 - torch.sum((output - target) ** 2) / torch.sum((target - y_mean) ** 2)
        if torch.isnan(r2) or torch.isinf(r2) or r2 <= 0:
            return 0
        return r2.item()

    def _score(self, y_predict, y_train, mu, std):
        """Set ``reward``, ``actual_value`` and ``r2`` by comparing ``y_predict * std + mu`` to ``y_train``."""
        y_predict_tensor = torch.tensor(y_predict, dtype=torch.float32).unsqueeze(1)
        sigma_targ = y_train.std()
        mse = torch.mean((y_predict_tensor * std + mu - y_train) ** 2)
        rmse = torch.sqrt(mse)
        nrmse = (1 / sigma_targ) * rmse
        self.reward = 1 / (1 + nrmse)
        self.actual_value = y_predict_tensor
        self.r2 = self._r2_score(self.actual_value * std + mu, y_train)

    def const_optimize(self, X_train, y_train, mu, std):
        """Fit the program's ``'x0'`` constants with ``self.cc`` and score the result.

        Constants are seeded from ``self.consts``; any remaining slots default
        to 1.0 and surplus stored values are ignored. ``self.cc(objective, x0)``
        minimises the RMSE between the raw program output and
        ``y_train.squeeze(1)``. It may return an object with an ``x``
        attribute (scipy's result) or the constant vector itself.

        Side effects:

        * ``self.consts`` receives the fitted constants.
        * ``self.actual_program`` receives the program with the constants
          substituted.
        * ``reward``, ``actual_value`` and ``r2`` are set from
          ``output * std + mu``.

        Args:
            X_train: torch tensor of features; token ``'x<k>'`` reads column ``k - 1``.
            y_train: torch tensor of targets with a size-1 axis 1 (it is
                squeezed on axis 1 for fitting).
            mu, std: affine rescaling applied to predictions before scoring.

        Returns:
            ``(actual_program, r2, infix_str)`` when R^2 is exactly 1.0,
            otherwise ``(None, None, None)``.
        """
        self.actual_program = []
        X = X_train.numpy()
        y = y_train.squeeze(1).numpy()

        n_consts = sum(1 for token in self.program if token == 'x0')
        seed = list(self.consts)[:n_consts]
        x0 = seed + [1.0] * (n_consts - len(seed))

        if n_consts:
            def objective(params):
                expr = self._substitute_consts(self.program, params)
                prediction, _ = self._evaluate_prefix(expr, X)
                return self._rmse(y, prediction)

            result = self.cc(objective, x0)
            # ScipyMinimize returns an OptimizeResult; accept a bare vector too.
            self.consts = getattr(result, 'x', result)
            consts = self.consts
        else:
            consts = x0

        expr = self._substitute_consts(self.program, consts)
        y_predict, _ = self._evaluate_prefix(expr, X)
        self._score(y_predict, y_train, mu, std)
        self.actual_program = expr
        if self.r2 == 1.0:
            return self.actual_program, self.r2, self.infix_str
        return None, None, None

    def const_sub(self, X_train, mm=6):
        """Evaluate each complete subtree of ``actual_program[mm - 1:]``.

        The tail is consumed left to right. The subtree at its head is
        evaluated and stored as a float32 tensor with a trailing singleton
        axis, then removed, until the tail is exhausted. Unresolvable lone
        terminals are skipped. Evaluation stops at a truncated subtree.
        """
        X = X_train.numpy()
        results = []
        remaining = list(self.actual_program)[mm - 1:]
        while remaining:
            value, last = self._evaluate_prefix(remaining, X)
            if last is None:
                break  # incomplete subtree: nothing more can be evaluated
            remaining = remaining[last + 1:]
            if value is None:
                continue
            results.append(torch.tensor(value, dtype=torch.float32).unsqueeze(1))
        return results





