from typing import Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.distributions.multivariate_normal import MultivariateNormal

from benchmarks import Benchmark


class PriorMean(nn.Module):
    """GP mean function given by the prior density of the simulator parameters.

    The mean ignores the observations ``x``; it only evaluates ``prior.log_prob`` at ``theta``
    and returns it either as a log-density or exponentiated to a density.
    """

    def __init__(self, prior: torch.distributions.Distribution, log_space: bool, device: str) -> None:
        """Store the prior and output configuration.

        Parameters
        ----------
        prior : torch.distributions.Distribution
            Distribution whose ``log_prob`` is evaluated (on CPU) at ``theta``.
        log_space : bool
            If True, ``forward`` returns log-probabilities; otherwise probabilities.
        device : str
            Device the resulting tensor is moved to.
        """
        super().__init__()
        self.prior = prior
        self.device = device
        self.log_space = log_space

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        """Evaluate the prior (log-)density at ``theta``; ``x`` is unused."""
        # The prior is evaluated on CPU, then the result is moved to the configured device.
        log_density = self.prior.log_prob(theta.cpu()).to(self.device)
        return log_density if self.log_space else log_density.exp()


class RBFKernel(nn.Module):
    """Squared-exponential (RBF) kernel over the rows of a batch.

    ``k(a, b) = variance * exp(-d(a, b))``. Here ``d`` is the squared Euclidean distance divided
    by ``2 * lengthscale**2``. With ``per_parameter_lengthscale=True`` each flattened coordinate
    gets its own lengthscale. The scaled squared differences are then averaged (not summed)
    over the coordinates.
    """

    def __init__(self, lengthscale: Union[float, Tensor], variance: float, per_parameter_lengthscale=False) -> None:
        """Store the kernel hyper-parameters.

        Parameters
        ----------
        lengthscale : float or Tensor
            Shared lengthscale or, when ``per_parameter_lengthscale`` is True, one lengthscale per
            flattened coordinate (a scalar is broadcast to all coordinates).
        variance : float
            Output scale; the diagonal of the Gram matrix equals this value.
        per_parameter_lengthscale : bool
            Whether to use coordinate-wise lengthscales.
        """
        super().__init__()
        self.lengthscale = lengthscale
        self.variance = variance
        self.per_parameter_lengthscale = per_parameter_lengthscale

    def forward(self, x: Tensor) -> Tensor:
        """Return the ``(n, n)`` Gram matrix of the ``n`` rows of ``x``.

        Trailing dimensions of ``x`` are flattened, so ``x`` of shape ``(n, ...)`` is treated as
        ``n`` vectors.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(x.shape[0], -1)

        if self.per_parameter_lengthscale:
            # as_tensor accepts a float as well as a Tensor; tensors keep their dtype.
            lengthscale = torch.as_tensor(self.lengthscale, device=x.device).view(-1)
            # (1, n, d) - (n, 1, d) -> (n, n, d) pairwise squared differences.
            squared_diff = (x.unsqueeze(0) - x.unsqueeze(1)) ** 2
            dist = (squared_diff / (2 * lengthscale ** 2)).mean(dim=2)
        else:
            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, computed for all pairs at once.
            xs = (x ** 2).sum(1)
            dist = -2 * torch.matmul(x, x.t())
            dist += xs.view(-1, 1) + xs.view(1, -1)
            # Cancellation in the expansion can yield tiny negative values; distances are >= 0.
            dist = dist.clamp_min(0) / (2 * self.lengthscale ** 2)

        # Negative exponent: similarity decays with distance, giving a PSD Gram matrix.
        return self.variance * torch.exp(-dist)

class AdditiveKernel(nn.Module):
    """Weighted sum of a kernel on the parameters and a kernel on the observations."""

    def __init__(self, theta_kernel: nn.Module, x_kernel: nn.Module, theta_coef:float, x_coef:float) -> None:
        """The constructor.

        Parameters
        ----------
        theta_kernel : nn.Module
            The kernel applied to the simulator's parameters.
        x_kernel : nn.Module
            The kernel applied to the observations.
        theta_coef : float
            The importance of the theta kernel.
        x_coef : float
            The importance of the x kernel.
        """
        super().__init__()
        self.theta_kernel = theta_kernel
        self.x_kernel = x_kernel
        self.theta_coef = theta_coef
        self.x_coef = x_coef

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        """Return ``theta_coef * theta_kernel(theta) + x_coef * x_kernel(x)``."""
        return self.theta_coef * self.theta_kernel(theta) + self.x_coef * self.x_kernel(x)
    
class MultiplicativeKernel(nn.Module):
    """Product of powered kernels on the parameters and on the observations.

    The covariance is ``theta_kernel(theta) ** theta_coef * x_kernel(x) ** x_coef`` (elementwise).
    """

    def __init__(self, theta_kernel: nn.Module, x_kernel: nn.Module, theta_coef:float, x_coef:float) -> None:
        """Register the two sub-kernels and their exponents.

        Parameters
        ----------
        theta_kernel : nn.Module
            Kernel evaluated on the simulator's parameters.
        x_kernel : nn.Module
            Kernel evaluated on the observations.
        theta_coef : float
            Exponent applied to the theta kernel's output.
        x_coef : float
            Exponent applied to the x kernel's output.
        """
        super().__init__()
        self.theta_kernel = theta_kernel
        self.x_kernel = x_kernel
        self.theta_coef = theta_coef
        self.x_coef = x_coef

    def forward(self, theta: Tensor, x: Tensor) -> Tensor:
        """Combine the two kernel outputs multiplicatively."""
        theta_factor = self.theta_kernel(theta) ** self.theta_coef
        x_factor = self.x_kernel(x) ** self.x_coef
        return theta_factor * x_factor


class GPPrior(nn.Module):
    """Gaussian-process prior over functions of ``(theta, x)``.

    The GP mean is the benchmark's prior density of ``theta`` (in log space or not, per
    ``config["gp_log_space"]``). The covariance combines an RBF kernel on ``theta`` with an RBF
    kernel on ``x``, either additively or multiplicatively.
    """

    def __init__(
        self,
        benchmark: Benchmark,
        config: dict,
        device: str,
        measurement_generator=None,

    ) -> None:
        """The constructor.

        Parameters
        ----------
        benchmark : Benchmark
            The benchmark for which the prior is created; its ``get_prior()`` provides the mean.
        config : dict
            The config. Keys read here: ``gp_x_variance``, ``gp_theta_variance``, ``x_coef``,
            ``theta_coef``, ``gp_log_space``, optionally ``kernel_type`` ("additive" or
            "multiplicative", default additive). Lengthscales come from either
            ``x_length_scale`` / ``theta_length_scale`` or, when ``automatic_kernel`` is truthy,
            from ``automatic_kernel_quantile`` and ``automatic_kernel_nb_samples``.
        device : str
            The device on which to run.
        measurement_generator : optional
            Object whose ``get(n)`` returns ``(theta, x)`` sample tensors. Required when
            ``config["automatic_kernel"]`` is truthy.

        Raises
        ------
        ValueError
            If automatic lengthscales are requested without a measurement generator or with
            fewer than two samples.
        NotImplementedError
            If ``config["kernel_type"]`` is neither "additive" nor "multiplicative".
        """
        super().__init__()

        self.benchmark = benchmark
        self.config = config
        self.device = device

        x_variance = self.config["gp_x_variance"]
        theta_variance = self.config["gp_theta_variance"]
        x_coef = self.config["x_coef"]
        theta_coef = self.config["theta_coef"]

        if self.config.get("automatic_kernel", False):
            if measurement_generator is None:
                raise ValueError("automatic_kernel requires a measurement_generator.")
            automatic_kernel_quantile = self.config["automatic_kernel_quantile"]

            theta_list, x_list = measurement_generator.get(self.config["automatic_kernel_nb_samples"])

            x_length_scale = self._quantile_length_scale(x_list, automatic_kernel_quantile, "x")
            theta_length_scale = self._quantile_length_scale(theta_list, automatic_kernel_quantile, "theta")

            # Automatic lengthscales are per-coordinate tensors.
            per_parameter_lengthscale = True
        else:
            x_length_scale = self.config["x_length_scale"]
            theta_length_scale = self.config["theta_length_scale"]

            per_parameter_lengthscale = False

        self.log_space = self.config["gp_log_space"]

        self.mean = PriorMean(self.benchmark.get_prior(), self.log_space, device)

        # Both kernel types (and the default) must honour per_parameter_lengthscale; otherwise
        # tensor lengthscales from the automatic heuristic would hit the scalar code path.
        theta_kernel = RBFKernel(theta_length_scale, theta_variance, per_parameter_lengthscale=per_parameter_lengthscale)
        x_kernel = RBFKernel(x_length_scale, x_variance, per_parameter_lengthscale=per_parameter_lengthscale)

        kernel_type = config.get("kernel_type", "additive")
        if kernel_type == "additive":
            self.kernel = AdditiveKernel(theta_kernel, x_kernel, theta_coef, x_coef)
        elif kernel_type == "multiplicative":
            self.kernel = MultiplicativeKernel(theta_kernel, x_kernel, theta_coef, x_coef)
        else:
            raise NotImplementedError("Kernel type not implemented.")

    @staticmethod
    def _quantile_length_scale(samples: Tensor, quantile: float, name: str) -> Tensor:
        """Per-coordinate lengthscale from a quantile of pairwise squared distances.

        For every unordered pair ``i < j`` of rows of ``samples``, the elementwise squared
        difference is computed. ``lengthscale = sqrt(q / 2)``, where ``q`` is the requested
        quantile over pairs. Near-zero lengthscales (``<= 1e-8``) are replaced by 1.
        """
        n_samples = samples.shape[0]
        if n_samples < 2:
            raise ValueError("automatic_kernel needs at least two samples.")

        # Upper-triangle indices enumerate pairs (i, j), i < j, in the same row-major order as a
        # nested "for i: for j > i" loop, but without a Python-level O(n^2) loop.
        rows, cols = torch.triu_indices(n_samples, n_samples, offset=1, device=samples.device)
        squared_distances = (samples[rows] - samples[cols]) ** 2

        squared_distances_quantile = torch.quantile(squared_distances, quantile, dim=0)
        print("{}_squared_distances_quantile = {}".format(name, squared_distances_quantile))
        length_scale = torch.sqrt(squared_distances_quantile / 2)
        # Constant coordinates would give a zero lengthscale (division by zero in the kernel).
        length_scale[length_scale <= 1e-8] = 1.
        print("{}_length_scale = {}".format(name, length_scale))
        return length_scale

    def create_distribution(
        self, theta: Tensor, x: Tensor
    ) -> torch.distributions.Distribution:
        """Create a distribution over functions evaluated at points (theta, x)

        Parameters
        ----------
        theta : Tensor
            The simulator's parameters associated to the data points.
        x : Tensor
            The observations associated to the data points.

        Returns
        -------
        torch.distributions.Distribution
            A distribution over function outputs at the specified datapoints.

        Raises
        ------
        RuntimeError
            If the covariance cannot be Cholesky-factorised even with unbounded diagonal jitter.
        """
        theta = theta.float()
        x = x.float()
        mean_vector = self.mean(theta, x)
        cov_matrix = self.kernel(theta, x)

        jitter = torch.eye(x.size(0), dtype=x.dtype, device=x.device) * 1e-6

        # Add diagonal jitter (applied exactly once per attempt), doubling it until the
        # covariance becomes numerically positive definite.
        multiplier = 1.
        while True:
            try:
                L = torch.linalg.cholesky(cov_matrix + multiplier * jitter, upper=False)
                break
            except RuntimeError as err:
                multiplier *= 2.
                if multiplier == float("inf"):
                    raise RuntimeError("Increase to inf jitter") from err

        distribution = MultivariateNormal(mean_vector, scale_tril=L)
        return distribution

    def sample_functions(self, theta: Tensor, x: Tensor, n_samples: int) -> Tensor:
        """Sample function outputs.

        Parameters
        ----------
        theta : Tensor
            The simulator's parameters associated to the data points for which to sample functions.
        x : Tensor
            The observations associated to the data points for which to sample functions.
        n_samples : int
            The number of functions to sample.

        Returns
        -------
        Tensor
            probabilities or log probabilities of shape [n_data, n_func, 1]
        """
        distribution = self.create_distribution(theta, x)
        # rsample gives [n_func, n_data]; transpose to [n_data, n_func] (reparameterised, so
        # gradients flow through the samples).
        outputs = distribution.rsample((n_samples,)).t()
        outputs = outputs.unsqueeze(dim=2)

        return outputs

    def functions_log_prob(self, theta: Tensor, x: Tensor, outputs: Tensor) -> Tensor:
        """Compute the log probability of functions

        Parameters
        ----------
        theta : Tensor[n_data, parameter_dim]
            The simulator's parameters associated to the data points at which the functions are evaluated.
        x : Tensor[n_data, observation_dim]
            The observations associated to the data points at which the functions are evaluated.
        outputs : Tensor[n_data, n_func]
            The output probabilities or log probabilities for each data point

        Returns
        -------
        Tensor[n_func]
            The log probabilities associated to each function.
        """
        distribution = self.create_distribution(theta, x)

        # The distribution's event dimension is n_data, so functions must be along dim 0.
        return distribution.log_prob(outputs.t())
