import numpy as np
import torch
import gpytorch
from matplotlib import pyplot as plt

from src.methods.models.layers.scalers.standardizer import Standardizer, DeStandardizer
from src.methods.models.layers.scalers.normalizer import Normalizer, DeNormalizer

from src.methods.models.layers.surrogates.baselines import Baseline, NullBaseline

from src.utils.strings import *
from src.utils.probabilities import estimate_kernel


class GaussianProcess(gpytorch.models.ExactGP):
    """Exact GP regressor with fixed per-sample noise and optional extras.

    On top of ``gpytorch.models.ExactGP`` this class adds:

    * incremental sample collection (``add_samples`` /
      ``add_smoothed_samples``) with an optional bounded pool,
    * automatic re-training once enough new samples were collected
      (``auto_train``),
    * optional importance-weighted smoothing of the targets before training
      (``_smooth_targets``),
    * optional input normalization / target standardization that is refreshed
      at every training run (``_standardize_train_data``),
    * optional damping of the kernel by a learnable function of an externally
      supplied "distance noise" matrix (``forward``).
    """

    def __init__(self, likelihood: gpytorch.likelihoods.FixedNoiseGaussianLikelihood,
                 train_inputs: np.ndarray | torch.Tensor, train_targets: np.ndarray | torch.Tensor,
                 auto_training_delta: int | None = None, length_scale_prior: gpytorch.priors.Prior | None = None,
                 training_multiplier: float = 1.0, reset_kernel: bool = True, default_train_epochs: int = 50,
                 epochs_decay: float = 1.0, lr: float = 0.1, standardize: bool = True, apply_smoothing: bool = True,
                 input_min: np.ndarray | None = None, input_max: np.ndarray | None = None,
                 target_mean: float | None = None, target_std: float | None = None, pool_size: int | None = None,
                 scale_correlation_by_noise: bool = True, device: torch.device = torch.device("cpu"),
                 verbose: bool = False, name: str = "default"):
        """Create the model and build its mean / kernel modules.

        Args:
            likelihood: Likelihood handed to ``ExactGP``; its ``noise`` is
                overwritten by ``launch_train``.
            train_inputs: Initial training inputs (NumPy arrays are converted
                with ``torch.Tensor``).
            train_targets: Initial training targets (same conversion).
            auto_training_delta: Number of not-yet-trained samples required
                for ``auto_train`` to fire; ``None`` disables auto-training.
            length_scale_prior: Prior passed to the RBF kernel lengthscale.
            training_multiplier: Factor applied to ``auto_training_delta``
                after every automatic training run.
            reset_kernel: Rebuild mean / kernel modules before each training.
            default_train_epochs: Epochs used by ``auto_train``.
            epochs_decay: Factor applied to ``default_train_epochs`` after
                every automatic training run (floored at 10 epochs).
            lr: Learning rate of the Adam optimizer used in ``launch_train``.
            standardize: Re-normalize inputs / standardize targets when
                training.
            apply_smoothing: Enable the smoothed-sample buffers and the target
                smoothing step in ``launch_train``.
            input_min: Fixed lower bound for input normalization (otherwise
                taken from the data).
            input_max: Fixed upper bound for input normalization (otherwise
                taken from the data).
            target_mean: Fixed mean for target standardization (otherwise
                taken from the data).
            target_std: Fixed std for target standardization (otherwise taken
                from the data).
            pool_size: Maximum number of retained samples; ``None`` keeps all.
            scale_correlation_by_noise: Divide the kernel by a learnable
                function of the distance-noise matrix in ``forward``.
            device: Device the training tensors are moved to in
                ``launch_train``.
            verbose: Print loss / lengthscale / noise at every epoch.
            name: Free-form identifier; only stored.
        """
        train_inputs = self._as_tensor(train_inputs)
        train_targets = self._as_tensor(train_targets)

        super().__init__(train_inputs, train_targets, likelihood)

        # Per-sample noises supplied by callers, and the smoothed noises
        # produced by ``_smooth_targets``.
        self._prior_noises: list[float] = []
        self._target_noises: list[float] = []

        self._auto_training_delta = auto_training_delta
        self._length_scale_prior = length_scale_prior
        self._training_multiplier = training_multiplier
        self._reset_kernel = reset_kernel
        self._default_train_epochs = default_train_epochs
        self._epochs_decay = epochs_decay
        self._lr = lr
        self._standardize = standardize
        self._smooth = apply_smoothing
        self._pool_size = pool_size
        self._scale_correlation_by_noise = scale_correlation_by_noise
        self._device = device
        self._verbose = verbose
        self._name = name

        self._input_min = input_min
        self._input_max = input_max
        self._target_mean = target_mean
        self._target_std = target_std

        self._trained = False
        self._untrained_samples_count = 0

        # Scalers are created lazily by ``_standardize_train_data``.
        self._input_normalizer = None
        self._input_denormalizer = None
        self._output_standardizer = None
        self._output_destandardizer = None

        self._distance_noise_matrix = None

        if apply_smoothing:
            self._sampled_inputs = []
            self._sampled_targets = []
            self._distributions = []
            # The initial training inputs stay as the first chunk of the
            # inputs that ``_smooth_targets`` concatenates.
            self._tmp_new_x = [train_inputs]
        else:
            self._sampled_inputs = None
            self._sampled_targets = None
            self._distributions = None
            self._tmp_new_x = None

        self._build_modules()

    @staticmethod
    def _as_tensor(data):
        """Convert NumPy arrays with ``torch.Tensor``; return anything else unchanged."""
        if isinstance(data, np.ndarray):
            return torch.Tensor(data)
        return data

    @staticmethod
    def _baseline_values(baseline, xs, reference: float) -> torch.Tensor:
        """Evaluate ``baseline(x, reference)`` for every ``x`` in ``xs`` and stack the results."""
        references = torch.Tensor(np.ones(shape=(len(xs))) * reference)
        values = [baseline(inp, ref).numpy() for inp, ref in zip(xs, references)]
        return torch.Tensor(np.array(values))

    @property
    def is_empty(self) -> bool:
        """Whether the model currently holds no training inputs."""
        return len(self.train_inputs[0]) == 0

    @property
    def is_trained(self) -> bool:
        """Whether ``launch_train`` completed at least once."""
        return self._trained

    @property
    def is_trainable(self) -> bool:
        """Whether enough untrained samples were collected for ``auto_train`` to fire."""
        return self._auto_training_delta is not None and self._untrained_samples_count >= self._auto_training_delta

    def set_target_mean(self, target_mean: float) -> None:
        """Fix the mean used for target standardization at the next training."""
        self._target_mean = target_mean

    def set_target_std(self, target_std: float) -> None:
        """Fix the std used for target standardization at the next training."""
        self._target_std = target_std

    def set_input_min(self, input_min: np.ndarray) -> None:
        """Fix the lower bound used for input normalization at the next training."""
        self._input_min = input_min

    def set_input_max(self, input_max: np.ndarray) -> None:
        """Fix the upper bound used for input normalization at the next training."""
        self._input_max = input_max

    def set_distance_noise_matrix(self, distance_noise_matrix: torch.Tensor) -> None:
        """Store the matrix that ``forward`` uses to damp the kernel.

        It must be set before the model is called whenever
        ``scale_correlation_by_noise`` is enabled, because ``forward``
        multiplies and concatenates it without a ``None`` check.
        """
        self._distance_noise_matrix = distance_noise_matrix

    def add_noise(self, noise: float) -> None:
        """Append a single prior noise value without adding a sample."""
        self._prior_noises.append(noise)

    def auto_train(self) -> bool:
        """Train with the default epoch budget if enough new samples arrived.

        After a run the untrained-sample counter is reset, the trigger
        threshold is scaled by ``training_multiplier`` and the epoch budget by
        ``epochs_decay`` (never dropping below 10 epochs).

        Returns:
            ``True`` if a training run was performed, ``False`` otherwise.
        """
        if not self.is_trainable:
            return False

        self.launch_train(self._default_train_epochs)
        self._untrained_samples_count = 0
        self._auto_training_delta = int(self._auto_training_delta * self._training_multiplier)
        self._default_train_epochs = max(int(self._default_train_epochs * self._epochs_decay), 10)

        return True

    def add_samples(self, x: torch.Tensor | np.ndarray, y: torch.Tensor | np.ndarray, noises: list[float]) -> None:
        """Append raw samples to the training data (non-smoothed mode).

        New samples are passed through the current input normalizer / target
        standardizer (when they exist) so they match the scale of the stored
        training data. When ``pool_size`` is set only the most recent
        ``pool_size`` samples are kept.

        Args:
            x: New inputs, one row per sample.
            y: New targets, one per sample.
            noises: Prior noise of each new sample.
        """
        assert len(x) == len(y) == len(noises)

        x = self._as_tensor(x)
        y = self._as_tensor(y)

        inputs = self.train_inputs[0]
        targets = self.train_targets

        if self._input_normalizer:
            x = self._input_normalizer(x)
        if self._output_standardizer:
            y = self._output_standardizer(y)

        new_x = torch.concat([inputs, x])
        new_y = torch.concat([targets, y])

        self._prior_noises.extend(noises)

        if self._pool_size is not None and len(new_x) > self._pool_size:
            new_x = new_x[-self._pool_size:]
            new_y = new_y[-self._pool_size:]
            # Trim the noises together with the data; otherwise the list
            # outgrows the training set and stops lining up with the retained
            # samples in ``launch_train`` and ``forward``.
            self._prior_noises = self._prior_noises[-self._pool_size:]

        self.set_train_data(new_x, new_y, strict=False)

        self._untrained_samples_count += len(x)

    def add_smoothed_samples(self, x: torch.Tensor | np.ndarray, y: torch.Tensor | np.ndarray,
                             sample_x: list[torch.Tensor | np.ndarray],
                             distributions: list[torch.distributions.Distribution],
                             noises: list[float]) -> None:
        """Buffer samples for the smoothing step performed at training time.

        Nothing is written to the training data here; ``_smooth_targets``
        rebuilds it from the buffers when ``launch_train`` runs. When
        ``pool_size`` is exceeded, the buffers keep the most recent
        ``pool_size`` entries, while the input and noise lists additionally
        keep their first element.

        Args:
            x: Inputs that become training inputs, one row per sample.
            y: Observed targets, one per sample.
            sample_x: Sampled inputs, one entry per sample.
            distributions: Distribution associated with each sample.
            noises: Prior noise of each sample.
        """
        assert self._smooth
        assert len(x) == len(y) == len(sample_x) == len(distributions) == len(noises)

        x = self._as_tensor(x)
        y = self._as_tensor(y)

        self._sampled_inputs.extend(sample_x)
        self._distributions.extend(distributions)
        self._sampled_targets.extend(y)
        self._tmp_new_x.extend(torch.unsqueeze(inp, dim=0) for inp in x)

        self._untrained_samples_count += len(x)

        self._prior_noises.extend(noises)

        if self._pool_size is not None and len(self._sampled_inputs) > self._pool_size:
            self._sampled_inputs = self._sampled_inputs[-self._pool_size:]
            self._distributions = self._distributions[-self._pool_size:]
            self._sampled_targets = self._sampled_targets[-self._pool_size:]
            # Element 0 of these two lists is kept in addition to the last
            # ``pool_size`` entries.
            self._tmp_new_x = [self._tmp_new_x[0]] + self._tmp_new_x[-self._pool_size:]
            self._prior_noises = [self._prior_noises[0]] + self._prior_noises[-self._pool_size:]

    def forward(self, x):
        """Return the GP prior at ``x``, optionally damped by the noise matrix.

        With ``scale_correlation_by_noise`` the kernel matrix is divided
        element-wise by ``1 + exp(scale) * D``, where ``D`` is the stored
        distance-noise matrix. In eval mode, when ``x`` holds more rows than
        the training set, ``D`` is extended by exactly ONE row and column built
        from the training prior noises (with 0 on the new diagonal entry).
        NOTE(review): this path therefore appears to support a single test
        point per call only; more test points give mismatching shapes.
        """
        mean_x = self._mean_module(x)
        covar_x = self._covar_module(x)

        if self._scale_correlation_by_noise:

            train_samples = len(self.train_inputs[0])
            test_samples = x.shape[0] - train_samples

            if not self.training and test_samples > 0:

                priors = self._prior_noises[0:train_samples]
                prior_noise_row = torch.Tensor([priors])
                prior_noise_col = torch.Tensor([priors + [0.0]]).transpose(0, 1)

                distance_noise_matrix = torch.cat([self._distance_noise_matrix, prior_noise_row], dim=0)
                distance_noise_matrix = torch.cat([distance_noise_matrix, prior_noise_col], dim=1)

            else:
                distance_noise_matrix = self._distance_noise_matrix

            noise_correlation = 1.0 + (self._noise_correlation_scale.exp() * distance_noise_matrix)
            covar_x = covar_x / noise_correlation

        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    def launch_train(self, epochs: int) -> None:
        """Fit the hyper-parameters by maximizing the exact marginal log-likelihood.

        Steps: optionally rebuild the modules, load the prior noises into the
        likelihood, optionally smooth the targets and re-standardize the data,
        then run ``epochs`` Adam steps.

        Args:
            epochs: Number of optimization steps.
        """
        if self._reset_kernel:
            self._build_modules()

        self.likelihood.noise = torch.Tensor(self._prior_noises)

        self.train(True)
        self.likelihood.train(True)
        if self._scale_correlation_by_noise:
            self._noise_correlation_scale.requires_grad = True

        mll_loss = gpytorch.mlls.ExactMarginalLogLikelihood(self.likelihood, self)
        optimizer = torch.optim.Adam(self.parameters(), lr=self._lr)

        with torch.no_grad():
            if self._smooth:
                self._smooth_targets()
                # With noise-scaled correlation (or no smoothed noises) the
                # likelihood noise is zeroed; otherwise the smoothed noises
                # are used.
                if self._scale_correlation_by_noise or len(self._target_noises) == 0:
                    self.likelihood.noise = torch.zeros(len(self.train_inputs[0]))
                else:
                    self.likelihood.noise = torch.Tensor(self._target_noises)
            if self._standardize:
                self._standardize_train_data()

        # The training data does not change during optimization, so it is
        # moved to the target device once instead of at every epoch.
        _x = self.train_inputs[0].to(self._device)
        _y = self.train_targets.to(self._device)

        for epoch in range(epochs):

            optimizer.zero_grad()

            output = self(_x)
            loss = - mll_loss(output, _y)

            loss.backward()

            optimizer.step()

            if self._verbose:
                print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
                    epoch + 1, epochs, loss.item(),
                    self._covar_module.base_kernel.lengthscale.item(),
                    self.likelihood.noise[-1].item()
                ))

        self._trained = True

    def predict(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Predict at ``x`` (given in the original, un-normalized input scale).

        The model and likelihood are switched to eval mode; zero observation
        noise is used for the test points.

        Args:
            x: Test inputs.

        Returns:
            ``(mean, std, de_std)`` where ``mean`` is passed through the
            target de-standardizer when one exists, ``std`` is the predictive
            standard deviation in the model's scale and ``de_std`` is ``std``
            passed through the de-standardizer (equal to ``std`` when there is
            none). NOTE(review): whether the de-standardizer also shifts by the
            mean is not visible here — confirm that applying it to a standard
            deviation is intended.
        """
        test_noises = torch.zeros(len(x))

        self.eval()
        self.likelihood.eval()
        if self._scale_correlation_by_noise:
            self._noise_correlation_scale.requires_grad = False

        if self._input_normalizer:
            x = self._input_normalizer(x)

        with gpytorch.settings.fast_pred_var():
            distribution = self.likelihood(self(x), noise=test_noises)

        output = distribution.mean
        std = torch.sqrt(distribution.variance)

        if self._output_destandardizer:
            output = self._output_destandardizer(output)
            de_std = self._output_destandardizer(std)
        else:
            de_std = std

        return output, std, de_std

    def plot(self, regrets_params: dict | None = None, baseline: Baseline | None = None) -> None:
        """Plot training data, predictive mean and confidence band (1-D inputs only).

        Returns without plotting when the model has no training data or the
        inputs are not one-dimensional.

        Args:
            regrets_params: Optional dict with keys ``"x"``, ``"y"``,
                ``"cost"``, ``"solver"``, ``"params"`` and ``"mm"``. When
                given, a regret curve computed with the solver is drawn in
                red.
            baseline: Optional baseline whose values are added to the
                predictions, the confidence band and the training targets
                (ignored when it is a ``NullBaseline``).
        """
        if self.is_empty or self.train_inputs[0][0].numpy().shape[0] != 1:
            return

        self.eval()
        self.likelihood.eval()

        test_x_norm = torch.linspace(0, 1, 100)
        test_x = self._input_denormalizer(test_x_norm) if self._input_denormalizer else test_x_norm

        train_x_norm = self.train_inputs[0]
        train_y_std = self.train_targets

        train_x = self._input_denormalizer(train_x_norm) if self._input_denormalizer else train_x_norm
        train_y = self._output_destandardizer(train_y_std) if self._output_destandardizer else train_y_std

        # Mean and confidence region are read inside ``no_grad`` so the lazily
        # evaluated covariance does not carry autograd history into the
        # ``.numpy()`` conversions below.
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            pred_distribution = self.likelihood(self(test_x_norm))
            predictions = pred_distribution.mean
            lower, upper = pred_distribution.confidence_region()

        if self._output_destandardizer:
            predictions = self._output_destandardizer(predictions)
            lower = self._output_destandardizer(lower)
            upper = self._output_destandardizer(upper)

        if baseline is not None and type(baseline) is not NullBaseline:
            with torch.no_grad():
                # The baseline's second argument is the first training input
                # for every point, for both the test grid and the training
                # set.
                reference = float(train_x[0])

                priors_test = self._baseline_values(baseline, test_x, reference)
                predictions = predictions + priors_test
                lower = lower + priors_test
                upper = upper + priors_test

                priors_train = self._baseline_values(baseline, train_x, reference)
                train_y = train_y + priors_train

        if regrets_params is not None:
            x_i = regrets_params["x"]
            y_i = regrets_params["y"]
            optimal_cost_i = regrets_params["cost"]
            solver = regrets_params["solver"]
            problem_params = regrets_params["params"]
            mm = regrets_params["mm"]
            regrets = []
            for inp in test_x:
                sol_hat_i, _ = solver.solve(x_i, inp.detach().numpy(), problem_params)
                metrics_i = solver.compute_metrics(y_i, sol_hat_i, problem_params)
                cost_hat_i = metrics_i[TOTAL_COST]
                regrets.append(mm * (cost_hat_i - optimal_cost_i))
        else:
            regrets = None

        f, ax = plt.subplots(1, 1, figsize=(4, 3))

        # Plot training data as black stars
        ax.plot(train_x.numpy(), train_y.numpy(), 'k*')
        # Plot predictive means as blue line
        ax.plot(test_x.numpy(), predictions.numpy(), 'b')
        # Shade between the lower and upper confidence bounds
        ax.fill_between(test_x.numpy(), lower.numpy(), upper.numpy(), alpha=0.5)

        if regrets is not None:
            ax.plot(test_x.numpy(), np.array(regrets), 'r')

        ax.legend(['Observed Data', 'Mean', 'Confidence'])

        plt.show()

    def _build_modules(self) -> None:
        """(Re)create the mean module, the kernel and the noise-correlation scale."""
        self._mean_module = gpytorch.means.ConstantMean()
        # Alternative: self._mean_module = gpytorch.means.ZeroMean()
        self._covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(lengthscale_prior=self._length_scale_prior))
        # Alternative kernel:
        # gpytorch.kernels.ScaleKernel(gpytorch.kernels.PiecewisePolynomialKernel(lengthscale_prior=self._length_scale_prior))

        if self._scale_correlation_by_noise:
            self._noise_correlation_scale = torch.nn.Parameter(torch.Tensor([1.0]), requires_grad=True)
            # Assigning an ``nn.Parameter`` attribute already registers it with
            # ``torch.nn.Module``; this explicit registration adds a second
            # key for the same tensor and is kept so existing ``state_dict``
            # keys stay stable.
            self.register_parameter("noise correlation scale", self._noise_correlation_scale)
        else:
            self._noise_correlation_scale = None

    def _smooth_targets(self) -> None:
        """Rebuild the training data from the smoothed-sample buffers.

        For every buffered distribution, the sampled targets (and prior
        noises) are averaged with normalized weights ``f / g``: ``f`` is that
        distribution's density at the sampled inputs and ``g`` is the mean of
        the ``estimate_kernel`` rows of all other distributions. A leading
        target and noise of ``0.0`` is paired with the first chunk of
        ``_tmp_new_x``.

        NOTE(review): with a single buffered distribution the leave-one-out
        count is 0 and the weights are undefined (division by zero); with
        empty buffers ``torch.cat`` raises. Callers should train only after
        at least two smoothed samples were added.
        """
        targets = [0.0]
        self._target_noises = [0.0]

        sampled_inputs = torch.cat(self._sampled_inputs)
        sampled_targets = torch.Tensor(self._sampled_targets)
        # Element 0 of the prior noises is skipped so the remaining entries
        # pair with the buffered samples.
        prior_noises = torch.Tensor(self._prior_noises[1:])

        g_x_global = estimate_kernel(self._distributions, sampled_inputs, reduce_mean=False)
        g_x_sum = torch.sum(g_x_global, dim=0)
        g_x_count = len(self._distributions) - 1
        # Alternative weighting that was considered:
        # weights_prior_probs = torch.clamp(1.0 - prior_noises * 10, min=0.0, max=1.0)

        for distribution, g_x_diff in zip(self._distributions, g_x_global):
            f_x = torch.exp(distribution.log_prob(sampled_inputs))
            # Leave-one-out mean over the other distributions' rows.
            g_x = (g_x_sum - g_x_diff) / g_x_count
            # Alternative: weights = (f_x / g_x) * weights_prior_probs
            weights = f_x / g_x
            norm_weights = weights / torch.sum(weights)
            targets.append(torch.matmul(norm_weights, sampled_targets))
            self._target_noises.append(torch.matmul(norm_weights, prior_noises))

        new_x = torch.cat(self._tmp_new_x)
        targets = torch.Tensor(targets)

        self.set_train_data(new_x, targets, strict=False)

    def _standardize_train_data(self) -> None:
        """Refresh the scalers and rescale the stored training data.

        In non-smoothed mode the stored data is first mapped back to the
        original scale with the previous scalers (in smoothed mode it was just
        rebuilt in the original scale). Inputs are normalized with the
        configured or data-derived min / max. Targets are standardized with
        the configured or data-derived mean / std, unless the std is zero or
        not finite; in that case the targets stay raw and the target scalers
        are cleared.
        """
        if not self._standardize:
            return

        inputs = self.train_inputs[0]
        targets = self.train_targets

        if self._input_denormalizer and not self._smooth:
            inputs = self._input_denormalizer(inputs)
        if self._output_destandardizer and not self._smooth:
            targets = self._output_destandardizer(targets)

        x_min = torch.min(inputs, dim=0)[0] if self._input_min is None else self._input_min
        x_max = torch.max(inputs, dim=0)[0] if self._input_max is None else self._input_max
        self._input_normalizer = Normalizer(x_min, x_max)
        self._input_denormalizer = DeNormalizer(x_min, x_max)
        new_x = self._input_normalizer(inputs)

        y_mean = torch.mean(targets, dim=0) if self._target_mean is None else self._target_mean
        y_std = torch.std(targets, dim=0) if self._target_std is None else self._target_std

        # ``torch.std`` of a single target is NaN, which would turn every
        # target into NaN, so a non-finite std is treated like a zero std.
        std_value = float(y_std)
        if std_value != 0.0 and np.isfinite(std_value):
            self._output_standardizer = Standardizer(y_mean, y_std)
            self._output_destandardizer = DeStandardizer(y_mean, y_std)
            new_y = self._output_standardizer(targets)
        else:
            # The targets stay in their original scale, so scalers from an
            # earlier run must not be applied to later samples or predictions.
            self._output_standardizer = None
            self._output_destandardizer = None
            new_y = targets

        self.set_train_data(new_x, new_y, strict=False)
