from functools import partial

from pmlb import fetch_data
from sklearn.model_selection import train_test_split
import numpy as np

from scipy.optimize import minimize

class ConstOptimizer:
    """Interface for routines that tune the free constants of an expression.

    Whatever keyword arguments are passed to the constructor are kept, as
    the same dict, on ``self.kwargs`` so that concrete subclasses can hand
    them on to their solver.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, f, x0):
        """Optimize objective ``f`` starting at ``x0``.

        Subclasses must override this; the base class always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError


class Dummy(ConstOptimizer):
    """Pass-through optimizer that leaves the starting point untouched.

    The constructor is inherited unchanged from ``ConstOptimizer``: keyword
    arguments are stored on ``self.kwargs`` but never used.
    """

    def __call__(self, f, x0):
        """Ignore ``f`` entirely and hand back ``x0`` itself."""
        return x0


class ScipyMinimize(ConstOptimizer):
    """Constant optimizer backed by ``scipy.optimize.minimize``.

    The keyword arguments captured at construction (for example ``method``
    or ``options``) are forwarded verbatim to ``minimize`` on every call.
    """

    def __call__(self, f, x0):
        """Minimize ``f`` from ``x0`` and return scipy's ``OptimizeResult``.

        The whole result object is returned, not just the optimum; the
        optimized parameters are available as its ``.x`` attribute.
        Divide-by-zero floating-point warnings are suppressed while solving.
        """
        with np.errstate(divide='ignore'):
            return minimize(f, x0, **self.kwargs)
def make_const_optimizer(name, **kwargs):
    """Instantiate a constant optimizer by its registry name.

    Parameters
    ----------
    name : str
        ``"scipy"`` selects ``ScipyMinimize``; ``"dummy"`` selects ``Dummy``.
    **kwargs
        Passed to the optimizer's constructor, which stores them on
        ``self.kwargs``.

    Returns
    -------
    ConstOptimizer
        A fresh optimizer instance. Note that ``Dummy.__call__`` returns the
        starting point itself rather than an ``OptimizeResult``, so it is not
        a drop-in replacement where the result's ``.x`` attribute is read.

    Raises
    ------
    KeyError
        If ``name`` is not registered; the message lists the valid names.
    """
    const_optimizers = {
        "dummy": Dummy,
        "scipy": ScipyMinimize,
    }
    try:
        optimizer_cls = const_optimizers[name]
    except KeyError:
        # Same exception type as before, but with an actionable message.
        raise KeyError('unknown constant optimizer %r; expected one of %s'
                       % (name, sorted(const_optimizers))) from None
    return optimizer_cls(**kwargs)
def _protected_division(x1, x2):
    """Divide element-wise, yielding 1.0 wherever ``|x2| <= 0.001``."""
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x2) > 0.001, np.divide(x1, x2), 1.)


def _protected_sqrt(x1):
    """Square root of ``|x1|`` so negative inputs stay finite."""
    return np.sqrt(np.abs(x1))


def _protected_log(x1):
    """Natural log of ``|x1|``, yielding 0.0 wherever ``|x1| <= 0.001``."""
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, np.log(np.abs(x1)), 0.)


def _protected_inverse(x1):
    """Reciprocal of ``x1``, yielding 0.0 wherever ``|x1| <= 0.001``."""
    with np.errstate(divide='ignore', invalid='ignore'):
        return np.where(np.abs(x1) > 0.001, 1. / x1, 0.)


def _sigmoid(x1):
    """Logistic function ``1 / (1 + exp(-x1))``."""
    with np.errstate(over='ignore', under='ignore'):
        return 1 / (1 + np.exp(-x1))


def _identity(x1):
    """Return the argument unchanged."""
    return x1


# Operator token -> (callable, arity). The arity drives the prefix-order
# parser in _evaluate, so this table is the single source of truth for which
# tokens are operators.
_OPERATORS = {
    'add': (np.add, 2),
    'sub': (np.subtract, 2),
    'mul': (np.multiply, 2),
    'div': (_protected_division, 2),
    'max': (np.maximum, 2),
    'min': (np.minimum, 2),
    'sqrt': (_protected_sqrt, 1),
    'log': (_protected_log, 1),
    'neg': (np.negative, 1),
    'inv': (_protected_inverse, 1),
    'abs': (np.abs, 1),
    'sin': (np.sin, 1),
    'cos': (np.cos, 1),
    'tan': (np.tan, 1),
    'sig': (_sigmoid, 1),
    'id': (_identity, 1),
}

# Token that stands for a free constant to be optimized.
_CONSTANT_TOKEN = 'x0'


def _variable_column(X, token):
    """Return the column of ``X`` named by a variable token ('x1' -> column 0)."""
    if token.startswith('x'):
        try:
            return X[:, int(token[1:]) - 1]
        except (ValueError, IndexError) as err:
            raise ValueError('invalid variable token %r' % (token,)) from err
    raise ValueError('unknown token %r in expression' % (token,))


def _evaluate(expr, X, constants):
    """Evaluate the prefix-order token sequence ``expr`` over the rows of ``X``.

    Each ``'x0'`` token consumes the next entry of ``constants``, broadcast to
    one value per row. Constants are consumed when the operator that owns
    them has all its operands, i.e. in node-completion order. This can differ
    from the left-to-right order of ``'x0'`` in ``expr`` (e.g.
    ``['add', 'x0', 'mul', 'x0', 'x1']`` gives ``constants[0]`` to the inner
    ``x0``). That order is kept for compatibility with earlier results.
    NOTE(review): confirm callers map the constants back the same way.

    Raises ValueError for an empty or incomplete expression or an unknown
    token.
    """
    if len(expr) == 0:
        raise ValueError('cannot evaluate an empty expression')
    n_samples = X.shape[0]
    next_constant = 0

    def resolve(arg):
        """Turn a pending operand into a value (arrays/literals pass through)."""
        nonlocal next_constant
        if not isinstance(arg, str):
            return arg  # already-evaluated subtree or a literal number
        if arg == _CONSTANT_TOKEN:
            value = np.repeat(constants[next_constant], n_samples)
            next_constant += 1
            return value
        return _variable_column(X, arg)

    # A lone terminal is its own value.
    if expr[0] not in _OPERATORS:
        return resolve(expr[0])

    # Each pending entry is [operator_token, operand, operand, ...].
    pending = []
    for token in expr:
        if token in _OPERATORS:
            pending.append([token])
        else:
            pending[-1].append(token)
        # Collapse every operator that now has all of its operands; one new
        # token can complete several nested operators at once.
        while len(pending[-1]) - 1 == _OPERATORS[pending[-1][0]][1]:
            name, *args = pending.pop()
            function = _OPERATORS[name][0]
            value = function(*[resolve(arg) for arg in args])
            if not pending:
                return value  # root finished; trailing tokens are ignored
            pending[-1].append(value)
    raise ValueError('expression ended before its root operator had all '
                     'of its operands')


def _rmse(y_true, y_pred):
    """Root-mean-square error between two equally shaped arrays."""
    return np.sqrt(np.average((y_pred - y_true) ** 2))


def _r_squared(y_true, y_pred, eps=1e-15):
    """Coefficient of determination, ``1 - SS_res / SS_tot``.

    ``SS_tot`` is measured around the mean of ``y_true``. It is clamped below
    at ``eps`` so a constant target does not produce 0/0.
    """
    y_true = np.asarray(y_true).reshape(-1, 1)
    y_pred = np.asarray(y_pred).reshape(-1, 1)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return 1 - ss_res / max(ss_tot, eps)


def const(expr, x_train, y_train, mu, std):
    """Fit the free constants of a prefix-order expression and score the fit.

    Parameters
    ----------
    expr : sequence
        Expression tokens in prefix order. Operator tokens are the keys of
        ``_OPERATORS``. ``'x0'`` is a placeholder for a free constant.
        ``'x<i>'`` (i >= 1) selects column ``i - 1`` of ``x_train``. Any
        non-string token is used as a literal value.
    x_train : array_like
        2-D input samples, indexed as ``X[:, column]``.
    y_train : array_like
        Target values; flattened to 1-D.
    mu, std
        Rescaling applied as ``output * std + mu`` to the expression output
        before R² is computed.

    Returns
    -------
    constants : list or ndarray
        ``[]`` when ``expr`` has no ``'x0'`` token. Otherwise the optimized
        values (the ``.x`` of the scipy result), one per ``'x0'`` token.
    r2 : float
        R² of ``output * std + mu`` against ``y_train``.

    Raises
    ------
    ValueError
        If ``expr`` is empty, incomplete, or contains an unknown token.
    """
    X = np.asarray(x_train)
    # Flatten so a (n, 1) target cannot broadcast to (n, n) inside the RMSE.
    y = np.asarray(y_train).ravel()
    n_constants = sum(1 for token in expr if token == _CONSTANT_TOKEN)

    if n_constants == 0:
        constants = []
    else:
        def objective(candidate):
            # NOTE(review): this minimizes RMSE of the *unscaled* output
            # against y_train, while R² below scores ``output * std + mu``.
            # The two agree only when mu == 0 and std == 1; confirm which
            # scale y_train is supplied in.
            return _rmse(y, _evaluate(expr, X, candidate))

        optimizer = make_const_optimizer(
            "scipy", method='L-BFGS-B', options={'gtol': 0.001})
        constants = optimizer(objective, [1.0] * n_constants).x

    y_hat = _evaluate(expr, X, constants) * std + mu
    return constants, _r_squared(y, y_hat)


