""""
Old version. See simulators/simulators for the actual simulators
""""

import gpytorch
import torch
import pickle
import numpy as np
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader

from utils.transformations import transform
from models.approximate_gp import ApproximateGPModel


def oracle_simulator(args, x):
    """
    Mimic an oracle as a simulator. Here the simulators are simply functions.

    The simulator selected by ``args.simulator`` supplies a mean and a standard deviation for
    every input; a label is then sampled as ``y = mean + sigma * eps`` with ``eps ~ N(0, 1)``.

    :param args: arguments; only ``args.simulator`` (the simulator name) is read
    :param x: input parameters (scalar, sequence or np.ndarray)
    :returns: sampled labels, mean and standard deviation
    :raises NotImplementedError: if ``args.simulator`` is not a known simulator name
    """
    # Accept scalars / lists / arrays alike; ndarray input is used as-is (no copy).
    x = np.asarray(x)

    if args.simulator == "homo1":
        mean = np.sin(6 * x)
        sigma = 0.2 * np.ones_like(x)
    elif args.simulator == "homo2":
        mean = np.power(x, 3) - 6 * np.power(x, 2) + 4 * x + 12
        sigma = 1 * np.ones_like(x)
    elif args.simulator == "homo3":
        mean = np.power(x, 3) - 6 * np.power(x, 2) + 4 * x + 12 + 2 * np.sin(6 * x)
        sigma = 0.4 * np.ones_like(x)
    elif args.simulator == "homo4":
        mean = np.power(x, 3) - 6 * np.power(x, 2) + 4 * x + 12 + 0.5 * np.sin(24 * x)
        sigma = 0.02 * np.ones_like(x)
    elif args.simulator == "hetero1":
        mean = np.sin(6 * x)
        sigma = 0.2 * np.power(x / 3, 3)
    elif args.simulator == "hetero2":
        mean = np.power(x, 3) - 6 * np.power(x, 2) + 4 * x + 12
        sigma = 0.2 * np.power(x - 2.5, 3)
    elif args.simulator == "hetero3":
        fun = np.power(x, 3) - 6 * np.power(x, 2) + 4 * x + 12
        mean = fun
        sigma = -fun * 0.2
    elif args.simulator == "hetero4":
        # Work on the shifted input x + 0.5; the mean is that shifted input itself.
        x = x + 0.5
        # Noise scale is 0.2 * |p(x)| where the degree-9 polynomial p is negative, and 0 where p > 0.
        factor_eps = np.array(np.power(x, 9) - 24.3 * np.power(x, 8) + 251.05 * np.power(x, 7)
                              - 1436.025 * np.power(x, 6) + 4957.375 * np.power(x, 5)
                              - 10535.325 * np.power(x, 4) + 13402.575 * np.power(x, 3)
                              - 9289.35 * np.power(x, 2) + 2673 * x)
        factor_eps[factor_eps > 0] = 0
        factor_eps = np.abs(factor_eps)
        mean = x
        sigma = factor_eps * 0.2
    elif args.simulator == "motorcycle":
        mean, sigma = motorcycle_simulator(x)
    elif args.simulator == "autompg":
        # auto_mpg_simulator returns tensors; convert to numpy so the noise step below is uniform.
        mean, sigma = auto_mpg_simulator(x)
        mean = mean.numpy()
        sigma = sigma.numpy()
    elif args.simulator == "gramacy1d":
        mean, sigma = gramacy_and_lee_1d(x)
    elif args.simulator == "gramacy1d_wrong_gp":
        mean, sigma = gramacy_and_lee_1d_wrong_gp(x)
    elif args.simulator == "gramacy2d":
        mean, sigma = gramacy_and_lee_2d(x)
    elif args.simulator == "gramacy6d":
        mean, sigma = gramacy_and_lee_6d(x)
    elif args.simulator == "higdon1d":
        mean, sigma = higdon1d(x)
    elif args.simulator == "mbml_homo_0":
        mean, sigma = mbml_fun_homo(x, noise=0)
    elif args.simulator == "mbml_homo_01":
        mean, sigma = mbml_fun_homo(x, noise=0.01)
    elif args.simulator == "mbml_hetero":
        # NOTE(review): this is identical to "mbml_homo_0" (homoscedastic, zero noise);
        # mbml_fun_hetero looks like the intended target — confirm before relying on it.
        mean, sigma = mbml_fun_homo(x, noise=0)
    else:
        raise NotImplementedError(f'Simulator {args.simulator} is not defined.')

    # One standard-normal draw per mean entry. A scalar mean still gets a 1-element draw (so y has
    # shape (1,)); otherwise the draw matches the mean's full shape, so column-vector inputs of shape
    # (n, 1) stay (n, 1) instead of broadcasting against an (n,) draw into an (n, n) result.
    size = 1 if np.ndim(mean) == 0 else np.shape(mean)
    eps = np.random.normal(loc=0, scale=1, size=size)

    y = mean + sigma * eps
    return y, mean, sigma


# Lookup tables for motorcycle_simulator: mean and standard deviation predicted by a heteroscedastic
# GP (inducing points) at the integer inputs 0, 1, ..., 100 (entry i holds the value for x = i).
# Built once at import time and marked read-only because every call shares them.
_MOTORCYCLE_GP_MEAN = np.array([
    0.5107, 0.5032, 0.4939, 0.4843, 0.4765, 0.4718, 0.4713, 0.4748, 0.4805, 0.4856,
    0.4870, 0.4824, 0.4724, 0.4604, 0.4523, 0.4543, 0.4695, 0.4953, 0.5216, 0.5319,
    0.5067, 0.4284, 0.2868, 0.0825, -0.1722, -0.4559, -0.7433, -1.0112, -1.2435, -1.4339,
    -1.5851, -1.7049, -1.8015, -1.8788, -1.9333, -1.9549, -1.9298, -1.8453, -1.6947, -1.4804,
    -1.2140, -0.9140, -0.6013, -0.2950, -0.0080, 0.2536, 0.4900, 0.7046, 0.9006, 1.0778,
    1.2317, 1.3542, 1.4362, 1.4710, 1.4573, 1.4012, 1.3147, 1.2139, 1.1143, 1.0277,
    0.9589, 0.9056, 0.8605, 0.8146, 0.7605, 0.6959, 0.6241, 0.5533, 0.4935, 0.4538,
    0.4395, 0.4501, 0.4795, 0.5173, 0.5507, 0.5678, 0.5600, 0.5241, 0.4632, 0.3869,
    0.3092, 0.2456, 0.2098, 0.2103, 0.2482, 0.3164, 0.4015, 0.4868, 0.5568, 0.6006,
    0.6147, 0.6033, 0.5769, 0.5487, 0.5311, 0.5324, 0.5549, 0.5950, 0.6441, 0.6919,
    0.7285,
])
_MOTORCYCLE_GP_STDDEV = np.array([
    0.0477, 0.0400, 0.0463, 0.0514, 0.0529, 0.0502, 0.0438, 0.0366, 0.0330, 0.0328, 0.0327,
    0.0318, 0.0315, 0.0331, 0.0363, 0.0433, 0.0560, 0.0702, 0.0780, 0.0733, 0.0615, 0.0799,
    0.1468, 0.2393, 0.3408, 0.4382, 0.5196, 0.5747, 0.5959, 0.5800, 0.5295, 0.4540, 0.3712,
    0.3067, 0.2822, 0.2915, 0.3074, 0.3135, 0.3156, 0.3360, 0.3899, 0.4658, 0.5385, 0.5873,
    0.6045, 0.5975, 0.5860, 0.5901, 0.6133, 0.6377, 0.6380, 0.5951, 0.5022, 0.3669, 0.2227,
    0.1980, 0.3483, 0.5319, 0.6956, 0.8186, 0.8886, 0.8987, 0.8470, 0.7373, 0.5789, 0.3876,
    0.1898, 0.1067, 0.2472, 0.3799, 0.4620, 0.4847, 0.4492, 0.3643, 0.2463, 0.1253, 0.1068,
    0.2006, 0.2817, 0.3230, 0.3195, 0.2769, 0.2092, 0.1377, 0.0911, 0.0856, 0.0919, 0.0940,
    0.1116, 0.1568, 0.2117, 0.2548, 0.2730, 0.2623, 0.2282, 0.1857, 0.1545, 0.1445, 0.1445,
    0.1437, 0.1542,
])
_MOTORCYCLE_GP_MEAN.flags.writeable = False
_MOTORCYCLE_GP_STDDEV.flags.writeable = False


def motorcycle_simulator(x):
    """
    Mean and stddev lists are fitted with a heteroscedastic GP (inducing points) and tabulated at
    the integer inputs 0..100; inputs are looked up after truncation toward zero (e.g. 3.9 -> 3).

    :param x: non-transformed data points on the interval [0, 100] (scalar, sequence or array;
              any shape is flattened)
    :returns: the mean and the standard deviation for each input data point (in original space),
              as 1-D float arrays
    :raises IndexError: if a truncated input falls outside [0, 100]
    """
    # int() truncates toward zero, exactly like the previous per-element lookup.
    idx = np.array([int(v) for v in np.asarray(x).reshape(-1)], dtype=int)

    n_table = len(_MOTORCYCLE_GP_MEAN)
    # Reject negatives explicitly: Python/numpy negative indexing would otherwise silently
    # return values from the end of the table.
    if idx.size and (idx.min() < 0 or idx.max() >= n_table):
        raise IndexError(
            f'motorcycle_simulator expects inputs in [0, {n_table - 1}]; '
            f'got indices in [{idx.min()}, {idx.max()}]'
        )

    # Fancy indexing returns fresh (writeable) arrays, so callers cannot modify the tables.
    return _MOTORCYCLE_GP_MEAN[idx], _MOTORCYCLE_GP_STDDEV[idx]


def auto_mpg_simulator(x):
    """
    Heteroscedastic GP w/ inducing points and a separate length scale for each input dimension fitted to all the data
    in the UCI Auto-MPG data set.

    The model weights (``state_dicts/auto_mpg-model_state.pth``) and normalisation info
    (``state_dicts/auto_mpg-data.pkl``) are loaded from paths relative to the working directory on every call.

    :param x: non-transformed data points; a single point (1-D) or a batch of points (2-D)
    :returns: the mean and the standard deviation for each input data point (in original space), as torch tensors
    """
    # map_location='cpu' lets weights saved during a GPU run load on a CPU-only machine; the model built below
    # is never moved to a GPU, and the results are meant to be converted with .numpy(), which needs CPU tensors.
    state_dict = torch.load('state_dicts/auto_mpg-model_state.pth', map_location='cpu')
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted, locally produced files.
    with open('state_dicts/auto_mpg-data.pkl', 'rb') as fp:
        info = pickle.load(fp)

    # Rebuild the model architecture and restore the fitted weights.
    model = ApproximateGPModel(inducing_points=info['train_x'],
                               kernel=gpytorch.kernels.RBFKernel(ard_num_dims=7),
                               likelihood=gpytorch.likelihoods.GaussianLikelihood())
    model.load_state_dict(state_dict)
    model.eval()

    # Standardize inputs with the statistics stored alongside the model. torch.as_tensor also accepts tensor
    # input without the copy-construct warning that torch.tensor() emits for it.
    x, _, _ = transform(torch.as_tensor(x, dtype=torch.float32), info['mu_x'], info['sigma_x'], method='standardize')

    # Make sure to have a batch dimension
    x = x.view(1, -1) if len(x.shape) == 1 else x
    n_points = x.shape[0]

    mean_chunks = []
    stddev_chunks = []
    with torch.no_grad(), gpytorch.settings.fast_pred_var():
        # DataLoader rejects batch_size=0, so an empty input skips prediction altogether.
        if n_points > 0:
            loader = DataLoader(TensorDataset(x), batch_size=min(1000, n_points), shuffle=False)
            for (x_batch,) in loader:
                predictions = model.likelihood(model(x_batch))
                mean_chunks.append(predictions.mean)
                stddev_chunks.append(predictions.stddev)

    # Join the per-batch results with one tensor op instead of converting element by element.
    means = torch.cat(mean_chunks) if mean_chunks else torch.empty(0)
    stddevs = torch.cat(stddev_chunks) if stddev_chunks else torch.empty(0)

    # Transform predictions back to the original output scale. The stddev is transformed with mu=0,
    # presumably so that it is only rescaled by sigma_y and not shifted.
    mean = transform(means, info['mu_y'], info['sigma_y'], method='standardize', inverse=True)
    stddev = transform(stddevs, 0, info['sigma_y'], method='standardize', inverse=True)

    return mean, stddev


def gramacy_and_lee_1d(x):
    """
    Gramacy and Lee 2012 function: https://www.sfu.ca/~ssurjano/grlee12.html
    The original function has no noise. Here we add Gaussian noise with standard deviation 0.1,
    i.e. N(0, 0.1^2).

    :param x: non-transformed data points on the interval [0.5, 2.5] (scalar, sequence or np.ndarray)
    :returns: the mean and the standard deviation for each input data point (in original space)
    """
    # np.asarray lets plain lists work (list * float would raise a TypeError).
    x = np.asarray(x)

    mean = np.sin(10 * np.pi * x) / (2 * x) + np.power(x - 1, 4)
    stddev = 0.1 * np.ones_like(x)

    return mean, stddev


def gramacy_and_lee_1d_wrong_gp(x, path='state_dicts/'):
    """
    A GP wrongly fitted to 200 data points from the Gramacy and Lee 1d function.
    The GP only catches the overall trend and not the sine curve.

    :param x: non-transformed data points (scalar, sequence, np.ndarray or CPU tensor)
    :param path: directory containing ``gramacy1d_wrong_gp.pkl``
    :returns: the mean and the standard deviation for each input data point, as numpy arrays
    """
    import os
    from gpytorch.kernels import RBFKernel
    from utils.gp_utils import get_model, get_likelihood

    # NOTE(review): this function used to read an undefined global ``path`` (NameError on every call).
    # The default mirrors the 'state_dicts/' directory used by auto_mpg_simulator — confirm the pickle lives there.
    # pickle.load can execute arbitrary code; only load trusted, locally produced files.
    with open(os.path.join(path, 'gramacy1d_wrong_gp.pkl'), 'rb') as fp:
        p_dct = pickle.load(fp)

    # Recompute the transformation statistics from the stored training data and rebuild the fitted model.
    args = p_dct['args']
    train_x = torch.Tensor(p_dct['train_x'])
    train_y = torch.Tensor(p_dct['train_y'])
    train_x_trans, mu_x, sigma_x = transform(train_x, method=args.transformation_x)
    train_y_trans, mu_y, sigma_y = transform(train_y, method=args.transformation_y)
    model = get_model(args, train_x_trans, train_y_trans, RBFKernel(), likelihood=get_likelihood(args))
    model.load_state_dict(p_dct['model_state_dict'])

    # Go through numpy: torch.Tensor(<python float>) raises, and torch.Tensor(<python int>) allocates an
    # uninitialized tensor of that *size* instead of holding the value.
    x = torch.as_tensor(np.asarray(x, dtype=np.float32))
    x, _, _ = transform(x, mu_x, sigma_x, method=args.transformation_x)
    # Give a single point a batch dimension.
    x = x.reshape(1) if x.dim() == 0 else x
    predict_output = model.predict((x, None))

    # Detach both outputs so .numpy() works whether or not predict() ran under no_grad.
    return predict_output['mean'].detach().numpy(), predict_output['stddev'].detach().numpy()


def gramacy_and_lee_2d(x):
    """
    Gramacy and Lee 2008 function: https://www.sfu.ca/~ssurjano/grlee08.html
    The original function has no noise. Here we add Gaussian noise N(0, 0.05^2).

    :param x: non-transformed data points on the interval [-2, 6] in each dimension; either an (n, 2)
              batch or a flat 1-D array/sequence, which is read as consecutive (x1, x2) pairs
    :returns: the mean and the standard deviation (both of shape (n,)) for each input data point
    """
    # np.asarray lets plain lists work (lists have no .reshape).
    x = np.asarray(x)

    # Make sure to have a batch dimension
    x = x.reshape(-1, 2) if x.ndim == 1 else x

    mean = x[:, 0] * np.exp(-np.power(x[:, 0], 2) - np.power(x[:, 1], 2))
    stddev = 0.05 * np.ones_like(mean)

    return mean, stddev


def gramacy_and_lee_6d(x):
    """
    Gramacy and Lee 2009 function: https://www.sfu.ca/~ssurjano/grlee09.html
    The original function noise: N(0,0.05^2)
    The last two variables x5 and x6 are inactive

    :param x: non-transformed data points on the interval [0, 1]; a single 6-d point or an (n, 6) batch
    :returns: the mean and the standard deviation (both of shape (n,)) for each input data point
    """
    x = np.asarray(x)

    # Make sure to have a batch dimension. (ndarray.view interprets its arguments as a dtype, so the
    # torch-style view(1, -1) raised for numpy input; reshape is the numpy equivalent.)
    x = x.reshape(1, -1) if x.ndim == 1 else x

    # Only columns 0-3 contribute; columns 4 and 5 (x5, x6) are ignored.
    mean = np.exp(np.sin(np.power(0.9 * (x[:, 0] + 0.48), 10))) + x[:, 1] * x[:, 2] + x[:, 3]
    # One stddev per data point (shaped like mean, not like x), so mean + stddev * eps broadcasts.
    stddev = 0.05 * np.ones_like(mean)

    return mean, stddev


def higdon1d(x):
    """
    Higdon (2002), Gramacy and Lee (2008) function: https://www.sfu.ca/~ssurjano/hig02grlee08.html
    The original function noise: N(0,0.1^2)

    :param x: non-transformed data points on the interval [0, 20] (scalar, sequence or np.ndarray)
    :returns: the mean and the standard deviation for each input data point (in original space)
    """
    x = np.asarray(x)
    # np.piecewise allocates its output with the input's dtype, so integer inputs (e.g. np.arange)
    # would truncate the sine branch to whole numbers. Evaluate non-inexact inputs in float64.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)

    mean = np.piecewise(x,
                        [x < 10, x >= 10],
                        [lambda t: np.sin(np.pi * t / 5) + 0.2 * np.cos(4 * np.pi * t / 5),
                         lambda t: t / 10 - 1])
    stddev = 0.1 * np.ones_like(x)

    return mean, stddev


def mbml_fun_homo(x, noise=0):
    """
    Homoscedastic test function for the MBML course: mean sin(4x) with one constant noise level.

    :param x: input points
    :param noise: standard deviation shared by every input point
    :returns: the mean ``sin(4x)`` and a standard deviation array shaped like ``x``
    """
    scaled_input = 4.0 * x
    return np.sin(scaled_input), np.ones_like(x) * noise


def mbml_fun_hetero(x, noise=0.1):
    """
    Function with heteroscedastic noise for the MBML course: mean sin(4x), noise scale growing with x**3.

    :param x: input points (scalar, sequence or np.ndarray)
    :param noise: base noise level multiplying x**3
    :returns: the mean ``sin(4x)`` and a per-point "standard deviation" shaped like ``x``
    """
    x = np.asarray(x)
    mean = np.sin(4.0 * x)
    # NOTE(review): this "stddev" includes a fresh standard-normal draw, so it is itself random and may be
    # negative; a caller that then adds sigma * eps multiplies two normal draws — confirm this is intended.
    # The draw uses x's full shape, so scalars work (len() raised) and column vectors stay (n, 1)
    # instead of broadcasting against an (n,) draw into (n, n).
    stddev = noise * np.ones_like(x) * x**3 * np.random.randn(*np.shape(x))
    return mean, stddev
