import torch
import numpy as np
import re
from functools import partial
import numpy as np
from joblib import wrap_non_picklable_objects

__all__ = ['make_function']

from scipy.optimize import minimize


class _Function(object):
    def __init__(self, function, name, arity):
        self.function = function
        self.name = name
        self.arity = arity

    def __call__(self, *args):
        return self.function(*args)


def make_function(*, function, name, arity, wrap=True):
    """Validate a user-supplied primitive and wrap it in a ``_Function``.

    The primitive is smoke-tested on ``arity`` length-10 vectors of ones,
    zeros and minus ones. It must accept that many arguments and return a
    numpy array of shape ``(10,)``. Its output must stay finite for the
    zero and negative inputs.

    Parameters
    ----------
    function : callable
        Vectorised function taking ``arity`` numpy arrays.
    name : str
        Name stored on the resulting ``_Function``.
    arity : int
        Number of arguments ``function`` takes.
    wrap : bool, default True
        When true, ``function`` is passed through joblib's
        ``wrap_non_picklable_objects`` before being stored.

    Returns
    -------
    _Function

    Raises
    ------
    ValueError
        If an argument has the wrong type, or if ``function`` fails any of
        the smoke tests above.
    """
    if not isinstance(arity, int):
        raise ValueError('arity must be an int, got %s' % type(arity))
    if not isinstance(function, np.ufunc):
        # Plain Python functions expose their positional-argument count.
        # Other callables (builtins, functools.partial, ...) do not, and are
        # validated only by the smoke tests below.
        code = getattr(function, '__code__', None)
        if code is not None and code.co_argcount != arity:
            raise ValueError('arity %d does not match required number of '
                             'function arguments of %d.'
                             % (arity, code.co_argcount))
    if not isinstance(name, str):
        raise ValueError('name must be a string, got %s' % type(name))
    if not isinstance(wrap, bool):
        raise ValueError('wrap must be an bool, got %s' % type(wrap))

    # Evaluate once on ones and reuse the result for all shape checks.
    ones = [np.ones(10) for _ in range(arity)]
    try:
        result = function(*ones)
    except (ValueError, TypeError) as err:
        raise ValueError('supplied function %s does not support arity of %d.'
                         % (name, arity)) from err
    if not hasattr(result, 'shape'):
        raise ValueError('supplied function %s does not return a numpy array.'
                         % name)
    if result.shape != (10,):
        raise ValueError('supplied function %s does not return same shape as '
                         'input vectors.' % name)

    # Closure checks: the primitive must stay finite on zeros and negatives.
    for fill, label in ((0.0, 'zeros'), (-1.0, 'negatives')):
        args = [np.full(10, fill) for _ in range(arity)]
        if not np.all(np.isfinite(function(*args))):
            raise ValueError('supplied function %s does not have closure '
                             'against %s in argument vectors.' % (name, label))

    if wrap:
        function = wrap_non_picklable_objects(function)
    return _Function(function=function, name=name, arity=arity)

def _protected_division(x1, x2):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x2) > 0.001, np.divide(x1, x2), 1.)


def _protected_sqrt(x1):
    return np.sqrt(np.abs(x1))


def _protected_log(x1):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, np.log(np.abs(x1)), 0.)


def _protected_inverse(x1):
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, 1. / x1, 0.)


def _sigmoid(x1):
    with np.errstate(over='ignore', under='ignore'):
        return 1 / (1 + np.exp(-x1))




# Built-in primitives. The numeric suffix of each module-level name is the
# primitive's arity.
add2 = _Function(function=np.add, name='add', arity=2)
sub2 = _Function(function=np.subtract, name='sub', arity=2)
mul2 = _Function(function=np.multiply, name='mul', arity=2)
div2 = _Function(function=_protected_division, name='div', arity=2)
sqrt1 = _Function(function=_protected_sqrt, name='sqrt', arity=1)
log1 = _Function(function=_protected_log, name='log', arity=1)
neg1 = _Function(function=np.negative, name='neg', arity=1)
inv1 = _Function(function=_protected_inverse, name='inv', arity=1)
abs1 = _Function(function=np.abs, name='abs', arity=1)
max2 = _Function(function=np.maximum, name='max', arity=2)
min2 = _Function(function=np.minimum, name='min', arity=2)
sin1 = _Function(function=np.sin, name='sin', arity=1)
cos1 = _Function(function=np.cos, name='cos', arity=1)
tan1 = _Function(function=np.tan, name='tan', arity=1)
sig1 = _Function(function=_sigmoid, name='sig', arity=1)
id1 = _Function(function=np.copy, name='id', arity=1)

# Name -> primitive lookup table, keyed by each primitive's own ``name``
# so keys and stored names cannot drift apart.
_function_map = {primitive.name: primitive
                 for primitive in (add2, sub2, mul2, div2, sqrt1, log1,
                                   abs1, neg1, inv1, max2, min2, sin1,
                                   cos1, tan1, sig1, id1)}


class ConstOptimizer(object):
    """Base class for strategies that optimise an expression's constants.

    Construction keyword arguments are kept verbatim in ``self.kwargs`` for
    subclasses to use. Subclasses override ``__call__(f, x0)``, where ``f``
    is the objective and ``x0`` the initial guess.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, f, x0):
        # Abstract: concrete optimisers must provide this.
        raise NotImplementedError


class Dummy(ConstOptimizer):
    """No-op optimiser: returns the initial guess untouched."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, f, x0):
        # The objective ``f`` is intentionally never evaluated.
        return x0


class ScipyMinimize(ConstOptimizer):
    """Constant optimiser backed by ``scipy.optimize.minimize``.

    All construction keyword arguments are forwarded to ``minimize``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __call__(self, f, x0):
        """Minimise ``f`` from ``x0`` and return scipy's ``OptimizeResult``."""
        # Only numpy divide-by-zero warnings raised while scipy evaluates
        # ``f`` are silenced here.
        with np.errstate(divide='ignore'):
            return minimize(f, x0, **self.kwargs)


def make_const_optimizer(name, **kwargs):
    """Build a constant optimiser by registry name.

    Parameters
    ----------
    name : str
        ``"scipy"`` selects ``ScipyMinimize``, which returns scipy's
        ``OptimizeResult``. ``"dummy"`` selects ``Dummy``, which returns the
        initial guess unchanged (not an ``OptimizeResult``).
    **kwargs
        Forwarded to the optimiser's constructor.

    Raises
    ------
    KeyError
        If ``name`` is not registered. The message lists the valid names.
    """
    const_optimizers = {
        "scipy": ScipyMinimize,
        # Previously defined but unreachable through this factory.
        "dummy": Dummy,
    }

    try:
        optimizer_cls = const_optimizers[name]
    except KeyError:
        raise KeyError('unknown const optimizer %r; expected one of %s'
                       % (name, sorted(const_optimizers))) from None
    return optimizer_cls(**kwargs)


def _substitute_constants(expr, values):
    """Return a copy of ``expr`` with each ``'x0'`` replaced, left to right,
    by the next item of ``values``.

    Raises ValueError if ``values`` runs out before the placeholders do.
    """
    supply = iter(values)
    substituted = []
    for node in expr:
        if isinstance(node, str) and node == 'x0':
            try:
                node = next(supply)
            except StopIteration:
                raise ValueError('expression has more x0 placeholders than '
                                 'supplied constants') from None
        substituted.append(node)
    return substituted


def _resolve_terminal(node, X):
    """Turn an expression terminal into an operand array for a primitive.

    ``'xk'`` selects column ``k - 1`` of ``X``. A float becomes a constant
    vector of length ``X.shape[0]``. Anything else (e.g. an intermediate
    result array) is returned unchanged.
    """
    if isinstance(node, str):
        if not node.startswith('x'):
            raise ValueError('unknown symbol %r in expression' % node)
        try:
            column = int(node[1:]) - 1
        except ValueError:
            raise ValueError('malformed variable name %r' % node) from None
        if not 0 <= column < X.shape[1]:
            raise ValueError('variable %r is out of range for %d input '
                             'columns' % (node, X.shape[1]))
        return X[:, column]
    if isinstance(node, float):
        return np.repeat(node, X.shape[0])
    return node


def _evaluate_prefix(expr, X):
    """Evaluate a prefix-order expression on every row of ``X``.

    Operators are looked up in ``_function_map``, which also supplies their
    arity. All other nodes are resolved with ``_resolve_terminal``.
    """
    if len(expr) == 0:
        raise ValueError('cannot evaluate an empty expression')

    def is_operator(node):
        return isinstance(node, str) and node in _function_map

    # A prefix expression that starts with a terminal is that terminal alone.
    if not is_operator(expr[0]):
        if len(expr) != 1:
            raise ValueError('malformed expression: terminal %r is followed '
                             'by further nodes' % (expr[0],))
        return _resolve_terminal(expr[0], X)

    # Each pending entry is (primitive, operands collected so far).
    pending = []
    for node in expr:
        if is_operator(node):
            pending.append((_function_map[node], []))
        else:
            pending[-1][1].append(_resolve_terminal(node, X))

        # Apply every operator whose operands are all available. Each result
        # becomes the next operand of the operator beneath it on the stack.
        while len(pending[-1][1]) == pending[-1][0].arity:
            function, operands = pending.pop()
            result = function(*operands)
            if not pending:
                return result
            pending[-1][1].append(result)

    raise ValueError('expression is incomplete: some operators are missing '
                     'operands')


def _rmse(a, b):
    """Root-mean-square difference of two arrays."""
    return np.sqrt(np.mean((a - b) ** 2))


def _r2_score(output, target):
    """Coefficient of determination of ``output`` vs ``target`` (torch
    tensors), returned as a Python float."""
    residual = torch.sum((output - target) ** 2)
    total = torch.sum((target - torch.mean(target)) ** 2)
    return (1 - residual / total).item()


def const(expr, X_train, y_train, mu, std):
    """Fit the ``'x0'`` constants of a prefix expression and score the fit.

    Parameters
    ----------
    expr : list
        Expression in prefix order. Nodes are operator names from
        ``_function_map``, variables ``'x1'``, ``'x2'``, ... (1-based
        columns of ``X_train``), float literals, and ``'x0'`` placeholders
        for the constants to fit.
    X_train : torch.Tensor
        Inputs, converted with ``.numpy()``. Presumably a CPU tensor of
        shape (n_samples, n_features); TODO confirm with callers.
    y_train : torch.Tensor
        Targets; ``squeeze(1)`` implies shape (n_samples, 1).
    mu, std
        Applied as ``prediction * std + mu`` before computing R².

    Returns
    -------
    consts : list or numpy.ndarray
        ``[]`` when ``expr`` has no placeholders. Otherwise the optimised
        constants (``OptimizeResult.x``), in placeholder order.
    r2 : float
        R² of the rescaled prediction against ``y_train``.

    Raises
    ------
    ValueError
        If ``expr`` is empty, incomplete, or references an unknown symbol or
        an out-of-range variable.
    """
    X = X_train.numpy()
    y = y_train.squeeze(1).numpy()
    n_consts = sum(1 for node in expr if isinstance(node, str) and node == 'x0')

    if n_consts == 0:
        consts = []
    else:
        def objective(values):
            # NOTE(review): the constants are fitted against ``y`` on the raw
            # prediction, but the R² below rescales the prediction with
            # ``* std + mu``. The two agree only when mu == 0 and std == 1.
            # Confirm which scale y_train is on.
            prediction = _evaluate_prefix(_substitute_constants(expr, values), X)
            return _rmse(y, prediction)

        optimizer = make_const_optimizer('scipy', method='L-BFGS-B',
                                         options={'gtol': 0.001})
        consts = optimizer(objective, [1.0] * n_consts).x

    y_predict = _evaluate_prefix(_substitute_constants(expr, consts), X)
    y_predict_tensor = torch.tensor(y_predict, dtype=torch.float32).unsqueeze(1)
    return consts, _r2_score(y_predict_tensor * std + mu, y_train)