import torch
import numpy as np

from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood

from botorch.utils.gp_sampling import get_gp_samples
from botorch.models import SingleTaskGP

from utils.postprocessing_utils import pickle2data
from scipy.interpolate import RegularGridInterpolator


def within_model_wrapper_2D_from_file(name, return_max=False):
    """Load a results file and wrap its ``'f_t(x)'`` grid as a 2-D, time-indexed objective.

    This is a thin wrapper: the file is read with ``pickle2data`` and the resulting
    mapping is handed to ``within_model_wrapper_2D``, so the interpolation logic lives
    in a single place instead of being duplicated here.

    Args:
        name: Identifier/path passed unchanged to ``pickle2data``.
        return_max: If True, return the per-time-step maximum of the grid instead of
            the model callable.

    Returns:
        Whatever ``within_model_wrapper_2D`` returns for the loaded data: the callable
        ``within_model(x, t)`` or, when ``return_max`` is True, the per-time-step maxima.
    """
    # NOTE(review): pickle2data presumably unpickles the file — only load trusted results.
    return within_model_wrapper_2D(pickle2data(name), return_max=return_max)


def within_model_wrapper_2D(data, return_max=False):
    """Wrap ``data['f_t(x)']`` sampled on a 100x100 grid over ``[0, 1]^2`` as ``f(x, t)``.

    Equivalent to ``within_model_wrapper_from_data`` with the fixed grid
    ``np.linspace(0, 1, 100)`` on both axes; delegating avoids keeping a second copy of
    the interpolation code in sync.

    Args:
        data: Mapping with key ``'f_t(x)'``, an array sliced as ``[:, :, t]``.
        return_max: If True, return the per-time-step maximum over the grid instead of
            the model callable.

    Returns:
        The callable ``within_model(x, t)`` (1-based time index ``t``), or the
        per-time-step maxima when ``return_max`` is True.
    """
    return within_model_wrapper_from_data(data, np.linspace(0, 1, 100), return_max=return_max)


def within_model_wrapper_from_data(data, resolution, return_max=False):
    """Wrap a gridded, time-indexed objective as a callable ``f(x, t)``.

    ``data['f_t(x)']`` is indexed as ``[i, j, t]``. The indices ``i`` and ``j`` follow
    the 1-D grid ``resolution`` along the first and second input coordinate. Values
    between grid nodes are linearly interpolated with
    ``scipy.interpolate.RegularGridInterpolator``.

    Args:
        data: Mapping containing the key ``'f_t(x)'``.
        resolution: Strictly ascending 1-D grid coordinates, shared by both axes.
        return_max: If True, return the per-time-step maximum over the grid instead of
            the model callable.

    Returns:
        ``within_model(x, t)``. For each row ``xi`` of ``x`` (a 2-D point) and entry
        ``ti`` of ``t`` (a 1-based time index, truncated to ``int``), it interpolates
        slice ``ti - 1``. The result is a float32 tensor with one row per query. If
        ``return_max`` is True, the function instead returns an ndarray holding the
        maximum of each time slice.

    Raises (when calling ``within_model``):
        IndexError: If a time index falls outside ``1..T``.
        ValueError: Raised by the interpolator when a point lies outside the grid.
    """
    model_x = resolution
    model_y = np.asarray(data['f_t(x)'])
    n_steps = model_y.shape[-1]
    # One interpolator per time slice, built lazily on first use instead of per query point.
    interpolators = {}

    def _interpolator(step):
        """Return (and cache) the interpolator for 0-based time slice ``step``."""
        if step not in interpolators:
            interpolators[step] = RegularGridInterpolator(
                (model_x, model_x), model_y[:, :, step], method='linear')
        return interpolators[step]

    def within_model(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        y_vec = []
        for xi, ti in zip(x, t):
            step = int(ti.item()) - 1
            # A negative index would silently wrap to the last slices (e.g. t=0 -> slice T-1).
            if not 0 <= step < n_steps:
                raise IndexError(f"time index {step + 1} outside valid range 1..{n_steps}")
            # detach/cpu so tensors that require grad or live on GPU can be converted.
            point = xi.detach().cpu().numpy().reshape(-1)
            y_vec.append(_interpolator(step)(point))
        return torch.tensor(np.asarray(y_vec), dtype=torch.float)

    if return_max:
        return np.max(model_y.reshape(-1, n_steps), axis=0)

    return within_model


def find_nearest(array, value):
    """Return the row of ``array`` closest to ``value`` in Euclidean distance.

    Args:
        array: 2-D array-like of candidate points, one point per row.
        value: Query point, broadcast against every row.

    Returns:
        ``(index, row)``: the position of the first closest row and that row.
    """
    candidates = np.asarray(array)
    distances = np.linalg.norm(candidates - value, axis=1)
    best = distances.argmin()
    return best, candidates[best]


def wrapper_temperature_data(data, sensor_coords, time_horizon=288, return_max=False):
    """Build an objective that reads temperatures from the sensor nearest to a query point.

    Args:
        data: Mapping from sensor id (as ``str``) to a record whose ``'temperature'``
            entry is a sequence indexed by time step.
        sensor_coords: Sequence ``(sensor_ids, xs, ys)``. Entry ``i`` of each part
            describes one sensor, with ``(xs[i], ys[i])`` as its location.
        time_horizon: Number of leading time steps used when ``return_max`` is True.
        return_max: If True, return per-time-step maxima instead of the model callable.

    Returns:
        ``temperature_model(x, t)``. Each row of ``x`` is snapped to the nearest sensor
        location, and ``t`` is used directly as an index into that sensor's temperature
        series. The result is a float32 tensor with one value per query.

        If ``return_max`` is True, the function instead returns
        ``(max_temperatures, max_temperatures_sensor_id)``. The second array holds the
        position of the maximising sensor in ``data``'s iteration order, not its id.

    Raises:
        ValueError: With ``return_max``, if ``data`` is empty or a sensor has fewer than
            ``time_horizon`` readings.
    """
    sensor_ids = sensor_coords[0]
    discrete_sample_set = np.stack((sensor_coords[1], sensor_coords[2]), axis=1)

    def temperature_model(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        y_vec = []
        for xi, ti in zip(x, t):
            # detach/cpu so tensors that require grad or live on GPU can be converted.
            idx, _ = find_nearest(discrete_sample_set, xi.detach().cpu().numpy())
            sensor_id = sensor_ids[idx]
            y_vec.append(np.asarray(data[str(sensor_id)]['temperature'])[int(ti.item())])
        return torch.tensor(np.asarray(y_vec), dtype=torch.float)

    if return_max:
        # Collect rows first and stack once (repeated np.concatenate is quadratic).
        rows = [np.asarray(record["temperature"][:time_horizon], dtype=float).reshape(-1)
                for record in data.values()]
        if any(row.shape[0] != time_horizon for row in rows):
            raise ValueError(f"every sensor needs at least {time_horizon} temperature readings")
        temperatures = np.stack(rows) if rows else np.empty((0, time_horizon))

        max_temperatures = np.max(temperatures, axis=0)
        max_temperatures_sensor_id = np.argmax(temperatures, axis=0)
        return max_temperatures, max_temperatures_sensor_id

    return temperature_model

def within_model_wrapper_temperature_data_empirical_kernel(data, sensor_coords, name=None, time_horizon=288,
                                                           return_max=False):
    """Build an objective whose inputs are sensor *indices* into ``sensor_coords[0]``.

    Args:
        data: Mapping from sensor id (as ``str``) to a record whose ``'temperature'``
            entry is a sequence indexed by time step.
        sensor_coords: Sequence whose first element is the list of sensor ids. Only
            that element is used here.
        name: Optional results identifier. When given, it is loaded with
            ``pickle2data`` and its ``'hyperparameter'`` entries are returned as well.
        time_horizon: Number of leading time steps used when ``return_max`` is True.
        return_max: If True, return per-time-step maxima instead of the model callable.

    Returns:
        If ``return_max`` is True, ``(max_temperatures, max_temperatures_sensor_id)``.
        The second array holds the position of the maximising sensor in ``data``'s
        iteration order.

        Otherwise the callable ``temperature_model(x, t)``. Each ``x`` entry holds a
        single number, truncated to ``int``, that indexes ``sensor_ids``; ``t`` is used
        directly as a time index. The callable returns a float32 tensor. When ``name``
        is given, the result is the tuple ``(temperature_model, hypers)``.

    Raises:
        ValueError: With ``return_max``, if ``data`` is empty or a sensor has fewer than
            ``time_horizon`` readings.
    """
    sensor_ids = sensor_coords[0]

    def temperature_model(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        y_vec = []
        for xi, ti in zip(x, t):
            # .item() replaces int(ndarray) (deprecated for ndim > 0 in NumPy) and also
            # works for tensors that require grad or live on GPU.
            sensor_id = sensor_ids[int(xi.item())]
            y_vec.append(np.asarray(data[str(sensor_id)]['temperature'])[int(ti.item())])
        return torch.tensor(np.asarray(y_vec), dtype=torch.float)

    if return_max:
        # Collect rows first and stack once (repeated np.concatenate is quadratic).
        rows = [np.asarray(record["temperature"][:time_horizon], dtype=float).reshape(-1)
                for record in data.values()]
        if any(row.shape[0] != time_horizon for row in rows):
            raise ValueError(f"every sensor needs at least {time_horizon} temperature readings")
        temperatures = np.stack(rows) if rows else np.empty((0, time_horizon))

        max_temperatures = np.max(temperatures, axis=0)
        max_temperatures_sensor_id = np.argmax(temperatures, axis=0)
        return max_temperatures, max_temperatures_sensor_id

    if name:
        # NOTE(review): pickle2data presumably unpickles the file — only load trusted results.
        results = pickle2data(name)
        hypers = np.asarray([ls.numpy() for ls in results['hyperparameter']])
        return temperature_model, hypers
    return temperature_model


def within_model_wrapper_ND(params, seed):
    """Build an N-dimensional, time-varying synthetic objective from GP prior samples.

    A ``SingleTaskGP`` with no training data is used as the prior. It has an RBF kernel
    whose noise and lengthscale come from ``params``. ``torch`` is seeded with ``seed``,
    then ``get_gp_samples`` draws ``params["time_horizon"]`` random-Fourier-feature
    sample functions ``g_1 .. g_T``. With ``eps = params["forgetting_factor"]``, the
    value at time ``t`` follows the recursion

        f_1 = g_1,    f_k = sqrt(1 - eps) * f_{k-1} + sqrt(eps) * g_k    (k = 2..t)

    Args:
        params: dict with keys "forgetting_factor", "dimension", "noise",
            "lengthscale", "time_horizon" and "num_rff_features".
        seed: Value passed to ``torch.manual_seed`` right before sampling.

    Returns:
        ``within_model_ND(x, t)``, which returns a float32 tensor with one value per
        row of ``x``. Time indices below 2 all yield ``f_1``.
    """
    # Read every setting up front so a missing key fails before any side effects.
    forgetting, dims, noise, lengthscale, horizon = (
        params[key]
        for key in ("forgetting_factor", "dimension", "noise", "lengthscale", "time_horizon")
    )

    # Prior-only GP: empty training set, hyperparameters set by hand below.
    model = SingleTaskGP(
        train_X=torch.empty((0, dims)),
        train_Y=torch.empty((0, 1)),
        covar_module=RBFKernel(),
        likelihood=GaussianLikelihood(),
    )
    model.likelihood.noise_covar.noise = noise
    model.covar_module.lengthscale = lengthscale

    num_rff = params["num_rff_features"]

    decay = np.sqrt(1 - forgetting)
    innovation = np.sqrt(forgetting)

    torch.manual_seed(seed)
    gp_samples = get_gp_samples(
        model=model,
        num_outputs=1,
        n_samples=horizon,
        num_rff_features=num_rff,
    )

    def within_model_ND(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        values = []
        for point, step in zip(x, t):
            step = int(step.item())
            # One forward pass gives every sample function evaluated at this point.
            draws = gp_samples(point.unsqueeze(0)).squeeze(-1)
            acc = draws[0, :]
            # Fold in samples 2..step; empty when step < 2.
            for k in range(1, step):
                acc *= decay
                acc += draws[k, :] * innovation
            values.append(acc.item())
        return torch.tensor(values, dtype=torch.float)

    return within_model_ND