import math

import numpy as np
import torch
import gpytorch
from src.models import TimeInvariantGP


def generate_2D_synthetic_data(parameters, seed, ):
    """Generate a synthetic (within-model) time-varying objective following Assumption 2.

    One independent GP prior sample ``s_t`` is drawn per time step on a regular
    spatial grid, and consecutive steps are mixed by the recursion

        f_0     = s_0
        f_{t+1} = sqrt(1 - eps) * f_t + sqrt(eps) * s_{t+1}

    where ``eps = parameters["forgetting_factor"]``.

    Args:
        parameters: dict with keys ``"dimension"``, ``"noise"``,
            ``"forgetting_factor"``, ``"lengthscale"``, ``"time_horizon"``,
            ``"resolution"`` and ``"compact_set"`` (a sequence of intervals whose
            first two entries are the lower/upper bound; presumably one interval
            per spatial dimension — confirm against callers).
        seed: passed to ``torch.manual_seed`` right before sampling, i.e. this
            reseeds torch's global RNG.

    Returns:
        dict with
          * ``'f_t(x)'``: objective values of shape
            ``(resolution,) * dimension + (time_horizon,)`` — for the 2D case
            ``(resolution, resolution, time_horizon)``;
          * ``'x'``: the flattened spatial grid, one row per point and
            ``dimension`` columns.
    """
    dims = parameters["dimension"]
    noise = parameters["noise"]
    forgetting_factor = parameters["forgetting_factor"]
    lengthscale = parameters["lengthscale"]
    time_horizon = parameters["time_horizon"]
    resolution = parameters["resolution"]
    compact_set = parameters["compact_set"]

    # Dummy single-point training data; only the prior (prior_mode below) is sampled.
    model = TimeInvariantGP(train_x=torch.ones(1, dims + 1),
                            train_y=torch.ones(1, 1), )
    # NOTE(review): samples are drawn from model(test_x), not model.likelihood(...),
    # so this noise setting presumably does not enter the sampled functions — confirm.
    model.likelihood.noise_covar.noise = noise
    model.spatial_kernel.lengthscale = lengthscale

    # Regular grid: one linspace per interval of the compact set.
    axes = [torch.linspace(interval[0], interval[1], resolution) for interval in compact_set]
    if dims > 1:
        # 'ij' is meshgrid's historical default; stacking on the last axis yields one
        # coordinate per column for any number of dimensions (dim=2 only worked in 2D).
        spatial_grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, dims)
    else:
        spatial_grid = axes[0].reshape(-1, dims)

    # Every grid point gets time input 1; the temporal evolution comes from the
    # recursion below, not from the GP's time input.
    test_t = torch.ones(spatial_grid.shape[0], dtype=torch.float).reshape(-1, 1)
    test_x = torch.cat((spatial_grid, test_t), dim=1)

    with gpytorch.settings.prior_mode(True):
        model.eval()
        model.likelihood.eval()
        torch.manual_seed(seed)
        # One independent prior draw per time step (row t is s_t).
        samples = model(test_x).rsample(sample_shape=torch.Size([time_horizon])).detach()

        base1 = math.sqrt(1 - forgetting_factor)
        base2 = math.sqrt(forgetting_factor)

        # Collect one column per time step and stack once, instead of re-concatenating
        # the growing matrix every step (which copies O(T^2) data).
        columns = [samples[0, :]]
        for innovation in samples[1:]:
            columns.append(base1 * columns[-1] + base2 * innovation)
        objective_function = torch.stack(columns, dim=1)

        data = {'f_t(x)': objective_function.reshape(*([resolution] * dims), time_horizon),
                'x': spatial_grid}

    return data


def generate_2D_synthetic_data_jump_in_forgetting(parameters, seed, factor, jump_time=150):
    """Generate a synthetic objective (Assumption 2) whose forgetting factor jumps once.

    Same recursion as ``generate_2D_synthetic_data``,

        f_0     = s_0
        f_{t+1} = sqrt(1 - eps_t) * f_t + sqrt(eps_t) * s_{t+1},

    but ``eps_t = forgetting_factor`` for loop steps ``t < jump_time`` and
    ``eps_t = forgetting_factor * factor`` afterwards. The scaled factor is
    therefore first used to produce column ``jump_time + 1``.

    Args:
        parameters: dict with keys ``"dimension"``, ``"noise"``,
            ``"forgetting_factor"``, ``"lengthscale"``, ``"time_horizon"``,
            ``"resolution"`` and ``"compact_set"``.
        seed: passed to ``torch.manual_seed`` (reseeds torch's global RNG).
        factor: multiplier applied to the forgetting factor after the jump.
            ``forgetting_factor * factor`` must lie in [0, 1], otherwise
            ``math.sqrt`` raises ``ValueError`` once the jump is reached.
        jump_time: loop step at which the factor switches (default 150, the
            previously hard-coded value).

    Returns:
        dict with ``'f_t(x)'`` of shape ``(resolution,) * dimension + (time_horizon,)``
        and ``'x'``, the flattened spatial grid (one row per point).
    """
    dims = parameters["dimension"]
    noise = parameters["noise"]
    forgetting_factor = parameters["forgetting_factor"]
    lengthscale = parameters["lengthscale"]
    time_horizon = parameters["time_horizon"]
    resolution = parameters["resolution"]
    compact_set = parameters["compact_set"]

    # Dummy single-point training data; only the prior (prior_mode below) is sampled.
    model = TimeInvariantGP(train_x=torch.ones(1, dims + 1),
                            train_y=torch.ones(1, 1), )
    model.likelihood.noise_covar.noise = noise
    model.spatial_kernel.lengthscale = lengthscale

    # Regular grid: one linspace per interval of the compact set.
    axes = [torch.linspace(interval[0], interval[1], resolution) for interval in compact_set]
    if dims > 1:
        # 'ij' is meshgrid's historical default; stacking on the last axis works for any dims.
        spatial_grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, dims)
    else:
        spatial_grid = axes[0].reshape(-1, dims)

    # Every grid point gets time input 1; temporal evolution comes from the recursion.
    test_t = torch.ones(spatial_grid.shape[0], dtype=torch.float).reshape(-1, 1)
    test_x = torch.cat((spatial_grid, test_t), dim=1)

    with gpytorch.settings.prior_mode(True):
        model.eval()
        model.likelihood.eval()
        torch.manual_seed(seed)
        # One independent prior draw per time step.
        samples = model(test_x).rsample(sample_shape=torch.Size([time_horizon])).detach()

        # Build columns in a list and stack once (avoids O(T^2) repeated concatenation).
        columns = [samples[0, :]]
        for step, innovation in enumerate(samples[1:]):
            run_forgetting_factor = forgetting_factor if step < jump_time else forgetting_factor * factor
            base1 = math.sqrt(1 - run_forgetting_factor)
            base2 = math.sqrt(run_forgetting_factor)
            columns.append(base1 * columns[-1] + base2 * innovation)
        objective_function = torch.stack(columns, dim=1)

        data = {'f_t(x)': objective_function.reshape(*([resolution] * dims), time_horizon),
                'x': spatial_grid}

    return data


def generate_2D_synthetic_data_sudden_change(parameters, seed, time_of_jump):
    """Generate a piecewise-constant synthetic objective with one abrupt change.

    Two independent GP prior samples ``s_0`` and ``s_1`` are drawn on a regular
    spatial grid. Column 0 is ``s_0``. Column ``t >= 1`` is ``s_0`` while
    ``t - 1 < time_of_jump``, i.e. for ``t <= time_of_jump``, and ``s_1``
    afterwards. No forgetting recursion is applied, and
    ``parameters["forgetting_factor"]`` is not read.

    Args:
        parameters: dict with keys ``"dimension"``, ``"noise"``,
            ``"lengthscale"``, ``"time_horizon"``, ``"resolution"`` and
            ``"compact_set"``.
        seed: passed to ``torch.manual_seed`` (reseeds torch's global RNG).
        time_of_jump: last time index that still shows the first sample.

    Returns:
        dict with ``'f_t(x)'`` of shape ``(resolution,) * dimension + (time_horizon,)``
        and ``'x'``, the flattened spatial grid (one row per point).
    """
    dims = parameters["dimension"]
    noise = parameters["noise"]
    lengthscale = parameters["lengthscale"]
    time_horizon = parameters["time_horizon"]
    resolution = parameters["resolution"]
    compact_set = parameters["compact_set"]

    # Dummy single-point training data; only the prior (prior_mode below) is sampled.
    model = TimeInvariantGP(train_x=torch.ones(1, dims + 1),
                            train_y=torch.ones(1, 1), )
    model.likelihood.noise_covar.noise = noise
    model.spatial_kernel.lengthscale = lengthscale

    # Regular grid: one linspace per interval of the compact set.
    axes = [torch.linspace(interval[0], interval[1], resolution) for interval in compact_set]
    if dims > 1:
        # 'ij' is meshgrid's historical default; stacking on the last axis works for any dims.
        spatial_grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, dims)
    else:
        spatial_grid = axes[0].reshape(-1, dims)

    # Every grid point gets time input 1.
    test_t = torch.ones(spatial_grid.shape[0], dtype=torch.float).reshape(-1, 1)
    test_x = torch.cat((spatial_grid, test_t), dim=1)

    with gpytorch.settings.prior_mode(True):
        model.eval()
        model.likelihood.eval()
        torch.manual_seed(seed)
        # Only two draws are needed: the pre-jump and the post-jump function.
        samples = model(test_x).rsample(sample_shape=torch.Size([2])).detach()

        # Build columns in a list and stack once (avoids O(T^2) repeated concatenation).
        columns = [samples[0, :]]
        for step in range(time_horizon - 1):
            columns.append(samples[0, :] if step < time_of_jump else samples[1, :])
        objective_function = torch.stack(columns, dim=1)

        data = {'f_t(x)': objective_function.reshape(*([resolution] * dims), time_horizon),
                'x': spatial_grid}

    return data


def generate_2D_synthetic_data_jump_in_forgetting_with_jump(parameters, seed, factor, jump_time=150):
    """Generate a synthetic objective whose forgetting factor and innovation both jump.

    ``time_horizon + 1`` independent GP prior samples ``s_0 .. s_T`` are drawn. The
    objective follows

        f_0     = s_0
        f_{t+1} = sqrt(1 - eps_t) * f_t + sqrt(eps_t) * n_t

    For loop steps ``t < jump_time``: ``eps_t = forgetting_factor`` and
    ``n_t = s_{t+1}``. Afterwards: ``eps_t = forgetting_factor * factor`` and
    ``n_t = s_T``, the extra draw, held fixed. After the jump the function
    therefore drifts towards a fixed target instead of a fresh random one.

    Args:
        parameters: dict with keys ``"dimension"``, ``"noise"``,
            ``"forgetting_factor"``, ``"lengthscale"``, ``"time_horizon"``,
            ``"resolution"`` and ``"compact_set"``.
        seed: passed to ``torch.manual_seed`` (reseeds torch's global RNG).
        factor: multiplier applied to the forgetting factor after the jump.
            ``forgetting_factor * factor`` must lie in [0, 1], otherwise
            ``math.sqrt`` raises ``ValueError`` once the jump is reached.
        jump_time: loop step at which both switches happen (default 150, the
            previously hard-coded value).

    Returns:
        dict with ``'f_t(x)'`` of shape ``(resolution,) * dimension + (time_horizon,)``
        and ``'x'``, the flattened spatial grid (one row per point).
    """
    dims = parameters["dimension"]
    noise = parameters["noise"]
    forgetting_factor = parameters["forgetting_factor"]
    lengthscale = parameters["lengthscale"]
    time_horizon = parameters["time_horizon"]
    resolution = parameters["resolution"]
    compact_set = parameters["compact_set"]

    # Dummy single-point training data; only the prior (prior_mode below) is sampled.
    model = TimeInvariantGP(train_x=torch.ones(1, dims + 1),
                            train_y=torch.ones(1, 1), )
    model.likelihood.noise_covar.noise = noise
    model.spatial_kernel.lengthscale = lengthscale

    # Regular grid: one linspace per interval of the compact set.
    axes = [torch.linspace(interval[0], interval[1], resolution) for interval in compact_set]
    if dims > 1:
        # 'ij' is meshgrid's historical default; stacking on the last axis works for any dims.
        spatial_grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, dims)
    else:
        spatial_grid = axes[0].reshape(-1, dims)

    # Every grid point gets time input 1; temporal evolution comes from the recursion.
    test_t = torch.ones(spatial_grid.shape[0], dtype=torch.float).reshape(-1, 1)
    test_x = torch.cat((spatial_grid, test_t), dim=1)

    with gpytorch.settings.prior_mode(True):
        model.eval()
        model.likelihood.eval()
        torch.manual_seed(seed)
        # time_horizon + 1 draws: the last one (samples[-1]) is the fixed post-jump innovation.
        samples = model(test_x).rsample(sample_shape=torch.Size([time_horizon + 1])).detach()

        # Build columns in a list and stack once (avoids O(T^2) repeated concatenation).
        columns = [samples[0, :]]
        for step in range(time_horizon - 1):
            before_jump = step < jump_time
            new_sample = samples[step + 1, :] if before_jump else samples[-1, :]
            run_forgetting_factor = forgetting_factor if before_jump else forgetting_factor * factor
            base1 = math.sqrt(1 - run_forgetting_factor)
            base2 = math.sqrt(run_forgetting_factor)
            columns.append(base1 * columns[-1] + base2 * new_sample)
        objective_function = torch.stack(columns, dim=1)

        data = {'f_t(x)': objective_function.reshape(*([resolution] * dims), time_horizon),
                'x': spatial_grid}

    return data


def generate_2D_synthetic_data_change_in_forgetting(parameters, seed, increase):
    """Generate a synthetic objective whose forgetting factor ramps linearly.

    Same recursion as ``generate_2D_synthetic_data``,

        f_0     = s_0
        f_{t+1} = sqrt(1 - eps_t) * f_t + sqrt(eps_t) * s_{t+1},

    with ``eps_t`` taken from ``np.linspace(forgetting_factor,
    forgetting_factor * increase, time_horizon)``.

    Args:
        parameters: dict with keys ``"dimension"``, ``"noise"``,
            ``"forgetting_factor"``, ``"lengthscale"``, ``"time_horizon"``,
            ``"resolution"`` and ``"compact_set"``.
        seed: passed to ``torch.manual_seed`` (reseeds torch's global RNG).
        increase: ratio between the last and first entry of the ramp. Every
            ramp value that is used must lie in [0, 1], otherwise ``math.sqrt``
            raises ``ValueError``.

    Returns:
        dict with ``'f_t(x)'`` of shape ``(resolution,) * dimension + (time_horizon,)``
        and ``'x'``, the flattened spatial grid (one row per point).
    """
    dims = parameters["dimension"]
    noise = parameters["noise"]
    forgetting_factor = parameters["forgetting_factor"]
    lengthscale = parameters["lengthscale"]
    time_horizon = parameters["time_horizon"]
    resolution = parameters["resolution"]
    compact_set = parameters["compact_set"]

    # Dummy single-point training data; only the prior (prior_mode below) is sampled.
    model = TimeInvariantGP(train_x=torch.ones(1, dims + 1),
                            train_y=torch.ones(1, 1), )
    model.likelihood.noise_covar.noise = noise
    model.spatial_kernel.lengthscale = lengthscale

    # Regular grid: one linspace per interval of the compact set.
    axes = [torch.linspace(interval[0], interval[1], resolution) for interval in compact_set]
    if dims > 1:
        # 'ij' is meshgrid's historical default; stacking on the last axis works for any dims.
        spatial_grid = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1).reshape(-1, dims)
    else:
        spatial_grid = axes[0].reshape(-1, dims)

    # Every grid point gets time input 1; temporal evolution comes from the recursion.
    test_t = torch.ones(spatial_grid.shape[0], dtype=torch.float).reshape(-1, 1)
    test_x = torch.cat((spatial_grid, test_t), dim=1)

    # NOTE(review): the ramp has time_horizon entries but only the first
    # time_horizon - 1 are consumed below, so forgetting_factor * increase itself
    # is never applied. Kept as-is to preserve existing results — confirm intent.
    array_forgetting_factors = np.linspace(forgetting_factor, forgetting_factor * increase, time_horizon)

    with gpytorch.settings.prior_mode(True):
        model.eval()
        model.likelihood.eval()
        torch.manual_seed(seed)
        # One independent prior draw per time step.
        samples = model(test_x).rsample(sample_shape=torch.Size([time_horizon])).detach()

        # Build columns in a list and stack once (avoids O(T^2) repeated concatenation).
        # zip stops after time_horizon - 1 steps, matching the original loop.
        columns = [samples[0, :]]
        for run_forgetting_factor, innovation in zip(array_forgetting_factors, samples[1:]):
            base1 = math.sqrt(1 - run_forgetting_factor)
            base2 = math.sqrt(run_forgetting_factor)
            columns.append(base1 * columns[-1] + base2 * innovation)
        objective_function = torch.stack(columns, dim=1)

        data = {'f_t(x)': objective_function.reshape(*([resolution] * dims), time_horizon),
                'x': spatial_grid}

    return data