from __future__ import annotations

import torch 
import gpytorch
import copy 

from math import pi
from gpytorch.means import Mean
from gpytorch.kernels import Kernel
from gpytorch.likelihoods import (
    Likelihood,
    GaussianLikelihood, 
    DirichletClassificationLikelihood, 
    MultitaskGaussianLikelihood
)
from gpytorch.distributions import (
    MultitaskMultivariateNormal, 
    MultivariateNormal
)


class GPModel(gpytorch.models.ExactGP):
    """Exact GP built from caller-supplied mean, kernel and likelihood modules.

    Besides ``forward`` (used by gpytorch for training and inference), the class
    exposes helpers that evaluate the *prior* directly from ``mean_fn`` and
    ``kernel_fn``: mean/covariance, marginal variance and the kernel operator.
    All of these helpers run under ``torch.no_grad()``.
    """

    def __init__(
        self, 
        mean_fn: Mean,
        kernel_fn: Kernel,
        likelihood: Likelihood,
        train_x: torch.Tensor, 
        train_y: torch.Tensor
    ):
        """
        Args:
            mean_fn: prior mean module.
            kernel_fn: prior covariance module.
            likelihood: observation likelihood. A ``MultitaskGaussianLikelihood``
                makes ``forward`` return a multitask distribution.
            train_x: training inputs.
            train_y: training targets.
        """
        super().__init__(train_x, train_y, likelihood)
        self.mean_fn = mean_fn
        self.kernel_fn = kernel_fn
        self.likelihood = likelihood

    def forward(
        self, 
        x: torch.Tensor
    ):
        """Return the prior distribution over the latent function at ``x``.

        For a ``MultitaskGaussianLikelihood`` the batched MVN (one batch entry per
        task) is wrapped into a ``MultitaskMultivariateNormal``; otherwise a plain
        ``MultivariateNormal`` is returned.
        """
        mean_x = self.mean_fn(x) 
        covar_x = self.kernel_fn(x)

        if isinstance(self.likelihood, MultitaskGaussianLikelihood):
            return MultitaskMultivariateNormal.from_batch_mvn(
                MultivariateNormal(mean_x, covar_x) # type: ignore
            )
        else:
            return MultivariateNormal(mean_x, covar_x) # type: ignore
    
    @torch.no_grad()
    def prior_predictive(
        self,
        x: torch.Tensor, 
        jitter: float = 1e-10
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the prior mean and the dense prior covariance at ``x``.

        Used in MAP estimation. ``jitter`` is added to the covariance via
        ``add_jitter`` for numerical stability before densifying it.
        """
        # NOTE(review): shape follows mean_fn's batch_shape, e.g. (n,) for an
        # unbatched mean or (batch, n) for a batched one — confirm for new means.
        mean_x = self.mean_fn(x)
        covar_x = self.kernel_fn(x, diag=False).add_jitter(jitter).to_dense() # type: ignore
        return mean_x, covar_x # type: ignore
    
    @torch.no_grad()
    def marginal_variance(
        self, 
        x: torch.Tensor
    ) -> torch.Tensor:
        """Return the prior marginal variances at ``x`` (the kernel diagonal) as a dense tensor."""
        covar_x = self.kernel_fn(x, diag=True).to_dense()
        return covar_x 

    @torch.no_grad()
    def kernel_linop(
        self, 
        x: torch.Tensor, 
        output_idx: int = -1,
        diag: bool = False
    ) -> gpytorch.lazy.LazyTensor | torch.Tensor:
        """Return the prior kernel evaluated at ``x`` (used in Hessian computation).

        Args:
            x: input locations.
            output_idx: when ``>= 0``, select that entry of the kernel's leading
                (batch / output) dimension; a negative value returns the whole
                kernel.
            diag: when True, return only the kernel diagonal.

        Returns:
            The lazily evaluated kernel operator, or a tensor when ``diag=True``.
        """
        covar = self.kernel_fn(x, diag=diag)
        if output_idx < 0:
            return covar
        # Index only the leading dimension: the full kernel is (batch, n, n) but
        # the diagonal is (batch, n), so a fixed `[idx, :, :]` fails for diag=True.
        return covar[output_idx]


def _noise_variance(likelihood: Likelihood) -> float | torch.Tensor | None:
    """Return the observation-noise variance held by ``likelihood``.

    Checks ``likelihood.noise_covar.noise`` first, then ``likelihood.task_noises``
    and finally ``likelihood.noise``, because not every gpytorch likelihood
    exposes ``noise_covar``. NOTE(review): MultitaskGaussianLikelihood appears
    not to; confirm against the installed gpytorch version.

    Returns a Python float when the noise is a single value, a detached copy of
    the tensor otherwise, and None when none of the attributes is available.
    """
    noise_covar = getattr(likelihood, "noise_covar", None)
    candidates = (
        (noise_covar, "noise"),
        (likelihood, "task_noises"),
        (likelihood, "noise"),
    )
    for owner, attr in candidates:
        if owner is None:
            continue
        # getattr with a default also absorbs properties that raise
        # AttributeError when the underlying raw parameter was not registered.
        value = getattr(owner, attr, None)
        if value is not None:
            value = value.detach()
            return value.item() if value.numel() == 1 else value.clone()
    return None


def optimize_prior_parameters(
    model: GPModel, 
    likelihood: Likelihood,
    train_x: torch.Tensor,
    train_y: torch.Tensor,
    n_steps: int = 1,
    val_frequency: int = 10,
    verbose: bool = False,
    progress_bar: bool = False,
) -> tuple[GPModel, float | torch.Tensor | None]:
    """
    Optimize the GP prior parameters of the model using the marginal likelihood.

    Runs ``n_steps`` L-BFGS steps (strong-Wolfe line search) on the negative
    exact marginal log-likelihood, updating the parameters of ``model`` in place.
    The model and likelihood are left in training mode.

    Args:
        model: GP model whose parameters are optimized.
        likelihood: likelihood used by the marginal log-likelihood.
        train_x: training inputs.
        train_y: training targets. They are flattened unless ``likelihood`` is a
            ``DirichletClassificationLikelihood``.
        n_steps: number of optimizer steps.
        val_frequency: print the loss every ``val_frequency`` steps when
            ``verbose``; a value ``<= 0`` disables the per-step printout.
        verbose: print the training loss and the final scalar hyperparameters.
        progress_bar: accepted for API compatibility; currently unused.

    Returns:
        ``(model, noise_var)``. ``noise_var`` is the likelihood's noise variance:
        a float when it is a single value, a tensor when it is per-task or
        per-point, or None if the likelihood exposes no recognised noise
        attribute.
    """
    if not isinstance(likelihood, DirichletClassificationLikelihood):
        # The Dirichlet likelihood's transformed targets keep their own shape;
        # every other likelihood is trained against flat targets.
        train_y = train_y.reshape(-1)
    
    model.train()
    likelihood.train()

    optimizer = torch.optim.LBFGS(model.parameters(), line_search_fn="strong_wolfe", tolerance_change=1e-7)
    
    # Loss function for the GP
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    
    def closure():
        # LBFGS may re-evaluate the closure without grad during its line search.
        if torch.is_grad_enabled():
            optimizer.zero_grad()
        output = model(train_x)
        # .sum() reduces batched (e.g. per-class) likelihood terms to a scalar.
        loss = -mll(output, train_y).sum() # type: ignore
        if loss.requires_grad:
            loss.backward()
        return loss
        
    # Run the optimization
    for i in range(n_steps):
        loss = optimizer.step(closure)
        if verbose and val_frequency > 0 and i % val_frequency == 0:
            print(f"{i}/{n_steps} - {loss.detach()}")

    if verbose:
        # Report every scalar hyperparameter, mapped through its constraint
        # (if one is registered) so the printed value is the effective one.
        for param_name, param in model.named_parameters():
            constraint = model.constraint_for_parameter_name(param_name)
            if constraint is not None:
                constrained_value = constraint.transform(param.data)
                if constrained_value.numel() == 1:
                    print(f"{param_name::<40} {constrained_value.item():.4f} (constrained)", flush=True)
            elif param.numel() == 1:
                print(f"{param_name::<40} {param.item():.4f} (unconstrained)", flush=True)

    noise_var = _noise_variance(model.likelihood)

    return model, noise_var

    
if __name__ == "__main__":
    # Smoke test: fit the prior hyperparameters of a single-output regression,
    # a two-task regression and a two-class Dirichlet classification model on
    # the same toy 1-D inputs, then print the shapes returned by the helpers.
    torch.manual_seed(0)  # reproducible toy noise

    dtype = torch.float32
    # Data: two disjoint clusters of 50 inputs each
    x1 = torch.linspace(-1, -0.5, 50).reshape(-1, 1).to(dtype)
    x2 = torch.linspace(0.5, 1, 50).reshape(-1, 1).to(dtype)
    train_X = torch.cat([x1, x2], dim=0).to(dtype)
    n_train = train_X.shape[0]
    # Noise-free signal shared by all three demos, shape (n_train, 1)
    signal = torch.sin(2 * pi * train_X)

    #################################################################################
    # Test GP prior (single output, Gaussian likelihood)
    #################################################################################
    print("Standard GP prior...")
    train_Y = signal + torch.normal(0, 0.1, (n_train, 1)).to(dtype)

    # Prior
    kernel = gpytorch.kernels.ScaleKernel(
        gpytorch.kernels.RBFKernel(
            ard_num_dims=1, 
            active_dims=(0,),
            lengthscale_constraint=gpytorch.constraints.Interval(1e-3, 1e3)
        )
    )
    mean = gpytorch.means.ConstantMean()
    
    # Initialize GP model
    likelihood = GaussianLikelihood()
    likelihood.noise = 0.0001
    prior = GPModel(mean, kernel, likelihood, train_X, train_Y)

    # Prior maximum likelihood estimation (returns (model, noise_var))
    prior, _ = optimize_prior_parameters(prior, likelihood, train_X, train_Y, verbose=True)
    prior.eval()
    likelihood.eval()
    
    mean_x, cov_x = prior.prior_predictive(train_X)
    print("mean_x.shape, cov_x.shape", mean_x.shape, cov_x.shape)
    kernel_linop = prior.kernel_linop(train_X)
    print("kernel_linop.shape", kernel_linop.shape)

    #################################################################################
    # Test Multitask GP prior (two tasks, multitask Gaussian likelihood)
    #################################################################################
    print("Multitask GP model")
    # signal (n, 1) broadcasts against the (n, 2) noise -> two noisy tasks
    train_Y = signal + torch.normal(0, 0.1, (n_train, 2)).to(dtype)
    
    # Prior: one mean/kernel per task via batch_shape
    kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=(0,), batch_shape=torch.Size([2])))
    mean = gpytorch.means.ConstantMean(batch_shape=torch.Size([2]))

    # Initialize GP model
    likelihood = MultitaskGaussianLikelihood(num_tasks=2, has_global_noise=False)
    prior = GPModel(mean, kernel, likelihood, train_X, train_Y)

    # Prior maximum likelihood estimation (returns (model, noise_var))
    prior, _ = optimize_prior_parameters(prior, likelihood, train_X, train_Y, verbose=True)
    prior.eval()
    likelihood.eval()

    mean_x, cov_x = prior.prior_predictive(train_X)
    print("mean_x.shape, cov_x.shape", mean_x.shape, cov_x.shape)
    kernel_linop = prior.kernel_linop(train_X, output_idx=0)
    print("kernel_linop.shape output_idx=0", kernel_linop.shape)
    # NOTE: prior.kernel_linop(train_X, output_idx=1) is reported to raise inside
    # the library for this batched kernel, so it is not exercised here.
    kernel_linop = prior.kernel_linop(train_X)
    print("kernel_linop.shape", kernel_linop.shape)

    #################################################################################
    # Test Dirichlet classification GP prior (two classes)
    #################################################################################
    print("Multitask Classification GP model")
    # Class labels must be non-negative indices. Rounding sin(.) + noise can give
    # -1 (invalid) and is not guaranteed to yield exactly two classes, so
    # threshold at 0 instead. This gives classes {0, 1}, matching batch_shape=2.
    noisy_signal = signal + torch.normal(0, 0.1, (n_train, 1)).to(dtype)
    train_Y = (noisy_signal > 0).long().reshape(-1)
    
    # Prior: one mean/kernel per class via batch_shape
    kernel = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=1, active_dims=(0,), batch_shape=torch.Size([2])))
    mean = gpytorch.means.ConstantMean(batch_shape=torch.Size([2]))

    # Initialize GP model on the likelihood's transformed regression targets
    likelihood = DirichletClassificationLikelihood(train_Y, learn_additional_noise=True)
    prior = GPModel(mean, kernel, likelihood, train_X, likelihood.transformed_targets)

    # Prior maximum likelihood estimation (returns (model, noise_var))
    prior, _ = optimize_prior_parameters(prior, likelihood, train_X, likelihood.transformed_targets, verbose=True)
    prior.eval()
    likelihood.eval()

    mean_x, cov_x = prior.prior_predictive(train_X)
    print("mean_x.shape, cov_x.shape", mean_x.shape, cov_x.shape)
    kernel_linop = prior.kernel_linop(train_X, output_idx=0)
    print("kernel_linop.shape output_idx=0", kernel_linop.shape)
    # NOTE: prior.kernel_linop(train_X, output_idx=1) is reported to raise inside
    # the library for this batched kernel, so it is not exercised here.
    kernel_linop = prior.kernel_linop(train_X)
    print("kernel_linop.shape", kernel_linop.shape)





