from __future__ import annotations

import warnings

# NOTE(review): with no category/module argument this installs a filter that
# ignores every warning, process-wide, as soon as this module is imported.
warnings.filterwarnings("ignore")
import math
from typing import Any, Callable, Optional, Union
import gpytorch
import torch
import numpy as np
from botorch.models.model import Model
from botorch.acquisition.objective import PosteriorTransform
from botorch.posteriors import TorchPosterior
from torch import Tensor
from torch.distributions import MultivariateNormal, Normal
from torch.nn import Linear, Module, ReLU, Sequential
from torch.utils.data import DataLoader, TensorDataset
import torch.utils.data as data_utils
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.kernels import ProductKernel, AdditiveKernel, ScaleKernel

from utils.configs import FSPLaplaceConfig

from bayesopt.surrogates.fsplaplace_utils.prior import (
    GPModel, 
    optimize_prior_parameters
)

from bayesopt.surrogates.fsplaplace_utils.fsplaplace_model import (
    FSPLaplace, 
    fit_fsplaplace, 
    MAPRandomDataset
)

from bayesopt.surrogates.utils import DataScaler


class FixedFeatFSPLaplace(Model):  # type: ignore[misc]
    """BoTorch ``Model`` backed by a neural network with an FSP-Laplace posterior.

    On construction (unless a fitted ``laplace`` object is supplied) the class

    1. builds a ``GPModel`` prior from ``mean_fn`` / ``kernel_fn`` and tunes it
       with ``optimize_prior_parameters``,
    2. trains a neural network with ``fit_fsplaplace`` using that prior, and
    3. fits an ``FSPLaplace`` approximation around the trained network.

    Targets are always passed through ``DataScaler.transform_y`` before
    training; inputs are passed through ``DataScaler.transform_x`` only when
    ``normalize_x`` is true.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        candidate_X: Tensor,
        mean_fn: gpytorch.means.Mean,
        kernel_fn: gpytorch.kernels.Kernel,
        initialize_nn: Optional[Callable[[], Module]] = None,
        laplace: Optional[FSPLaplace] = None,
        laplace_config: Optional[FSPLaplaceConfig] = None,
        device: torch.device | str = "cpu",
        dtype: str | torch.dtype = "float64",
        normalize_x: bool = False,
    ):
        """Build the prior and, unless ``laplace`` is given, train the model.

        Args:
            train_X: Training inputs; the last dimension is the feature dimension.
            train_Y: Training targets; the last dimension is the number of outputs.
            candidate_X: Candidate inputs, used to derive context points.
            mean_fn: GPyTorch mean module of the GP prior.
            kernel_fn: GPyTorch kernel of the GP prior. Its hyperparameters are
                re-initialised in place by ``init_kernel_params``.
            initialize_nn: Zero-argument factory returning a fresh network.
                Defaults to a two-hidden-layer ReLU MLP of width 50.
            laplace: Already fitted ``FSPLaplace`` object; when given, no
                training is performed.
            laplace_config: Hyperparameters; defaults to ``FSPLaplaceConfig()``.
            device: ``"cpu"`` or ``"cuda"`` (or the equivalent ``torch.device``).
            dtype: ``"float32"`` / ``"float64"`` or a ``torch.dtype``. Any string
                other than ``"float32"`` selects ``torch.float64``.
            normalize_x: Whether inputs are transformed with the data scaler.

        Raises:
            ValueError: If ``device`` is neither ``"cpu"`` nor ``"cuda"``.
        """
        super().__init__()
        # The signature allows torch.device, but the membership test below is
        # written for the string names only.
        if isinstance(device, torch.device):
            device = str(device)
        if device not in ("cpu", "cuda"):
            raise ValueError("Only 'cpu' and 'cuda' devices are supported.")

        # Resolve the dtype once: ``Tensor.to`` / ``Module.to`` treat a string
        # argument as a device name, so the raw string must never reach them.
        self.dtype = self._resolve_dtype(dtype)
        self.device: str = device

        self.train_X: Tensor = train_X.to(device).to(self.dtype)
        self.train_Y: Tensor = train_Y.to(device).to(self.dtype)
        self.candidate_X: Tensor = candidate_X.to(self.dtype).to(device)
        self.normalize_x = normalize_x
        # NOTE(review): the scaler is built from the tensors exactly as passed
        # in (before the device/dtype cast) — confirm this is what DataScaler
        # expects.
        self.data_scaler = DataScaler(train_X, train_Y)

        self.laplace_config: FSPLaplaceConfig
        if laplace_config is None:
            self.laplace_config = FSPLaplaceConfig()
        else:
            self.laplace_config = laplace_config

        self.noise_var = self.laplace_config.noise_var

        # Everything below (prior, network, Laplace fit) sees scaled data.
        if self.normalize_x:
            train_X = self.data_scaler.transform_x(train_X)
        train_Y = self.data_scaler.transform_y(train_Y)
        # Keep the data handed to the prior and the train loader consistent
        # with the configured device and dtype.
        train_X = train_X.to(device).to(self.dtype)
        train_Y = train_Y.to(device).to(self.dtype)

        # Initialize the GP prior
        likelihood = GaussianLikelihood().to(device).to(self.dtype)
        likelihood.noise = self.noise_var
        self.mean_fn = mean_fn.to(device).to(self.dtype)
        self.kernel_fn = kernel_fn.to(device).to(self.dtype)
        feature_dim = train_X.shape[-1]
        self.init_kernel_params(self.kernel_fn, feature_dim)
        self.prior = GPModel(
            self.mean_fn,
            self.kernel_fn,
            likelihood,
            train_X,
            train_Y,
        ).to(self.device).to(self.dtype)

        self.initialize_nn: Callable[[], Module]
        if initialize_nn is None:
            in_features = train_X.shape[-1]
            out_features = train_Y.shape[-1]

            def _default_nn() -> Module:
                return Sequential(
                    Linear(in_features, 50),
                    ReLU(),
                    Linear(50, 50),
                    ReLU(),
                    Linear(50, out_features),
                )

            self.initialize_nn = _default_nn
        else:
            self.initialize_nn = initialize_nn

        # Initialize Laplace
        self.laplace: FSPLaplace
        if laplace is None:
            self._train_model(self._get_train_loader(train_X, train_Y))
        else:
            self.laplace = laplace

    @staticmethod
    def _resolve_dtype(dtype: str | torch.dtype) -> torch.dtype:
        """Map a dtype name or ``torch.dtype`` to a ``torch.dtype``.

        A ``torch.dtype`` is returned unchanged. For strings, ``"float32"``
        selects ``torch.float32`` and anything else selects ``torch.float64``.
        """
        if isinstance(dtype, torch.dtype):
            return dtype
        return torch.float32 if dtype == "float32" else torch.float64

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[list[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        posterior_transform: Optional[PosteriorTransform] = None,
        **kwargs: Any,
    ) -> TorchPosterior:
        """Return the predictive posterior at ``X`` as a ``TorchPosterior``.

        Args:
            X: Query inputs in the original (unscaled) input space.
            output_indices: Accepted for interface compatibility; not used.
            observation_noise: Accepted for interface compatibility; not used.
            posterior_transform: Optional transform applied to the posterior
                before it is returned.
            **kwargs: Ignored.

        Returns:
            A ``TorchPosterior`` wrapping a ``Normal`` when the second value
            returned by ``get_prediction`` is not 3-dimensional, and a
            ``MultivariateNormal`` otherwise.
        """
        # Input normalisation and target un-scaling are both performed inside
        # ``get_prediction``; repeating them here would apply each twice.
        X = X.to(self.dtype).to(self.device)
        mean_y, var_y = self.get_prediction(X, use_test_loader=False, joint=False)

        if var_y.dim() != 3:  # Single objective
            # NOTE(review): ``Normal`` takes a standard deviation as its second
            # argument; if ``var_y`` holds variances this needs a ``sqrt``.
            # Confirm against what ``FSPLaplace.__call__`` returns for
            # ``joint=False``.
            post_pred = TorchPosterior(Normal(mean_y, var_y.squeeze(-1)))  # type: ignore
        else:  # Multi objective
            post_pred = TorchPosterior(MultivariateNormal(mean_y, var_y))  # type: ignore

        if hasattr(self, "outcome_transform"):
            post_pred = self.outcome_transform.untransform_posterior(post_pred)

        if posterior_transform is not None:
            post_pred = posterior_transform(post_pred)

        return post_pred

    def condition_on_observations(
        self, X: Tensor, Y: Tensor, **kwargs: Any
    ) -> FixedFeatFSPLaplace:
        """Append ``(X, Y)`` to the training data and return a retrained model.

        The new observations are also stored on ``self`` (``self.train_X`` and
        ``self.train_Y`` are replaced by the extended tensors). The returned
        model is constructed without a ``laplace`` argument, so it is trained
        from scratch on the extended data.

        Args:
            X: New inputs, concatenated to ``self.train_X`` along dim 0.
            Y: New targets, concatenated to ``self.train_Y`` along dim 0.
            **kwargs: Ignored.
        """
        # Cast before concatenating so observations arriving on another device
        # or in another dtype do not make ``torch.cat`` fail.
        target = {"device": self.device, "dtype": self.dtype}
        self.train_X = torch.cat(
            [self.train_X.to(**target), X.to(**target)], dim=0
        )
        self.train_Y = torch.cat(
            [self.train_Y.to(**target), Y.to(**target)], dim=0
        )

        return FixedFeatFSPLaplace(
            # Replace the dataset & retrained BNN
            train_X=self.train_X,  # Important!
            train_Y=self.train_Y,  # Important!
            candidate_X=self.candidate_X,
            mean_fn=self.mean_fn,
            kernel_fn=self.kernel_fn,
            initialize_nn=self.initialize_nn,
            laplace_config=self.laplace_config,
            device=self.device,
            dtype=self.dtype,
            normalize_x=self.normalize_x,
        )

    def get_prediction(
        self, test_X: Tensor, joint: bool = True, use_test_loader: bool = False
    ) -> tuple[Tensor, Tensor]:
        """Predict at ``test_X`` and map the result back to the target scale.

        Args:
            test_X: Query inputs in the original (unscaled) input space.
            joint: Forwarded to the ``FSPLaplace`` call.
            use_test_loader: If true, predict in mini-batches of
                ``laplace_config.batch_size`` and concatenate the results.

        Returns:
            The pair returned by ``DataScaler.inverse_transform_y`` applied to
            the Laplace prediction. In the batched branch both tensors are
            additionally passed through ``squeeze()``.

        Raises:
            AttributeError: If no Laplace approximation is available.
        """
        if getattr(self, "laplace", None) is None:
            raise AttributeError("Train your model first before making prediction!")

        if self.normalize_x:
            test_X = self.data_scaler.transform_x(test_X)
        test_X = test_X.to(self.dtype)

        if not use_test_loader:
            mean_y, cov_y = self.laplace(test_X.to(self.device), joint=joint)
            mean_y, cov_y = self.data_scaler.inverse_transform_y(mean_y, cov_y)
            return mean_y, cov_y

        test_loader = DataLoader(
            TensorDataset(test_X),
            batch_size=self.laplace_config.batch_size,
        )

        # NOTE(review): per-batch results are concatenated along dim 0; with
        # ``joint=True`` this presumably does not yield a covariance across
        # batches — confirm before relying on it.
        batch_means, batch_covs = [], []
        for (X_batch,) in test_loader:
            X_batch = X_batch.to(self.device).to(self.dtype)
            _mean_y, _cov_y = self.laplace(X_batch, joint=joint)
            _mean_y, _cov_y = self.data_scaler.inverse_transform_y(_mean_y, _cov_y)
            batch_means.append(_mean_y)
            batch_covs.append(_cov_y)

        # ``squeeze()`` without a dim removes every size-1 dimension.
        mean_y = torch.cat(batch_means, dim=0).squeeze()
        cov_y = torch.cat(batch_covs, dim=0).squeeze()
        return mean_y, cov_y

    @property
    def num_outputs(self) -> int:
        """The number of outputs of the model."""
        return self.train_Y.shape[-1]

    def _train_model(
        self,
        train_loader: DataLoader[tuple[Tensor, ...]]
    ) -> None:
        """Tune the GP prior, train the network and fit the Laplace approximation.

        Sets ``self.prior``, ``self.noise_var`` and ``self.laplace``.

        Args:
            train_loader: Loader over a ``TensorDataset`` holding exactly the
                (scaled) training inputs and targets.
        """
        cfg = self.laplace_config

        train_X, train_Y = train_loader.dataset.tensors

        noise_var = self.noise_var
        if isinstance(self.prior, GPModel):
            self.prior, noise_var = optimize_prior_parameters(
                self.prior,
                self.prior.likelihood,
                train_X,
                train_Y,
                n_steps=cfg.prior_n_steps,
                val_frequency=cfg.prior_val_frequency,
                verbose=False,
            )
            self.prior.eval()
            self.prior.likelihood.eval()

        model = self.initialize_nn().to(self.dtype).to(self.device)

        # Network training uses the "map" context points; the value returned
        # by ``fit_fsplaplace`` replaces the stored noise variance.
        context_loader = self._get_context_loader("map")
        self.noise_var = fit_fsplaplace(
            model,
            self.prior,
            likelihood="regression",
            train_loader=train_loader,
            context_loader=context_loader,
            device=self.device,
            dtype=self.dtype,
            n_outputs=self.num_outputs,
            noise_var=noise_var,
            lr=cfg.lr,
            n_epochs=cfg.n_epochs,
            val_frequency=cfg.val_frequency,
            early_stopping_patience=cfg.early_stopping_patience,
            jitter=cfg.jitter,
        )

        # The Laplace fit uses the "covariance" context points.
        context_loader = self._get_context_loader("covariance")
        sigma_noise = self.noise_var**0.5 * torch.ones(
            1, device=self.device, dtype=self.dtype
        )
        self.laplace = FSPLaplace(
            model,
            self.prior,
            likelihood="regression",
            sigma_noise=sigma_noise,
            params_sketch=cfg.params_sketch,
            params_sketch_dim=cfg.params_sketch_dim,
            max_rank=cfg.max_rank,
            dtype=self.dtype,
        )
        self.laplace.fit(
            train_loader,
            context_loader,
            n_chunks=cfg.n_chunks,
        )

    def _get_train_loader(
        self,
        train_X: Tensor,
        train_Y: Tensor,
    ) -> DataLoader[tuple[Tensor, ...]]:
        """Wrap the training tensors in a shuffling ``DataLoader``.

        The batch size is ``laplace_config.batch_size``.
        """
        return DataLoader(
            TensorDataset(train_X, train_Y),
            batch_size=self.laplace_config.batch_size,
            shuffle=True,
        )

    def _get_context_loader(self, mode: str):
        """Build the loader of context points for one training stage.

        Args:
            mode: ``"map"`` for network training or ``"covariance"`` for the
                Laplace fit.

        Returns:
            A ``DataLoader`` yielding context points. Which points are used is
            selected by ``laplace_config.map_context_points`` (mode ``"map"``)
            or ``laplace_config.cov_context_points`` (mode ``"covariance"``).

        Raises:
            ValueError: If ``mode`` or the configured context point type is
                not recognised.
        """
        cfg = self.laplace_config
        train_X = self.train_X
        candidate_X = self.candidate_X
        if self.normalize_x:
            # Put both sets of inputs in the space the model is trained in, so
            # the bounds below are not a mix of raw and normalised values.
            train_X = self.data_scaler.transform_x(self.train_X)
            candidate_X = self.data_scaler.transform_x(self.candidate_X)

        full_X = torch.cat([train_X, candidate_X], dim=0)
        input_size = self.train_X.shape[1:]
        x_mins = full_X.min(dim=0)[0]
        x_maxs = full_X.max(dim=0)[0]
        # Sampling box: the data range widened by half its width on each side.
        half_range = 0.5 * (x_maxs - x_mins)
        min_val = x_mins - half_range
        max_val = x_maxs + half_range

        if mode == "map":
            if cfg.map_context_points in ["uniform", "uniform_bitstring"]:
                dataset = MAPRandomDataset(
                    batch_size=cfg.context_points_batch_size,
                    input_size=input_size,
                    min_val=min_val,
                    max_val=max_val,
                    device=self.device,
                    dtype=self.dtype,
                    distribution=cfg.map_context_points,
                    generator_seed=cfg.generator_seed,
                )
                return DataLoader(
                    dataset,
                    batch_size=None,  # Batch size is handled by the dataset
                )
            if cfg.map_context_points == "bo_candidates":
                return DataLoader(
                    TensorDataset(candidate_X),
                    batch_size=cfg.context_points_batch_size,
                    shuffle=True,
                )
            raise ValueError(f"Invalid context point type: {cfg.map_context_points}")

        if mode == "covariance":
            if cfg.cov_context_points == "sobol":
                # ``math.prod`` yields a plain Python int for the engine.
                qmc = torch.quasirandom.SobolEngine(dimension=math.prod(input_size))
                context_X = (
                    qmc.draw(cfg.n_context_points_cov)
                    .reshape(-1, *input_size)
                    .to(self.device)
                    .to(self.dtype)
                )
                if x_mins.max() == 0. and x_maxs.min() == 1.:
                    # Every feature spans exactly [0, 1]: threshold the draws
                    # at 0.5 to obtain 0/1-valued points.
                    context_X = (context_X > 0.5).to(self.dtype)
                else:
                    context_X = context_X * (max_val - min_val) + min_val
            elif cfg.cov_context_points == "bo_candidates":
                context_X = candidate_X.to(self.device).to(self.dtype)
            else:
                raise ValueError(f"Invalid context point type: {cfg.cov_context_points}")
            return DataLoader(
                TensorDataset(context_X),
                batch_size=cfg.batch_size,
                shuffle=True,
            )

        raise ValueError(f"Invalid mode: {mode}")

    def init_kernel_params(self, kernel, feature_dim):
        """Recursively reset the hyperparameters of ``kernel`` in place.

        Sub-kernels of ``ProductKernel`` / ``AdditiveKernel`` and the base
        kernel of a ``ScaleKernel`` are visited first. On every visited kernel:
        the lengthscale (if ``has_lengthscale``) is set to
        ``sqrt(feature_dim)``, and ``outputscale`` / ``variance`` are set to 1
        when the kernel has a ``raw_outputscale`` / ``raw_variance`` attribute.

        Args:
            kernel: GPyTorch kernel to initialise.
            feature_dim: Input feature dimension.
        """
        # Check if the kernel is a composition kernel
        if isinstance(kernel, (ProductKernel, AdditiveKernel)):
            # Recursively initialize sub-kernels
            for sub_kernel in kernel.kernels:
                self.init_kernel_params(sub_kernel, feature_dim)
        elif isinstance(kernel, ScaleKernel):
            self.init_kernel_params(kernel.base_kernel, feature_dim)

        # For kernels with length scale, set it to sqrt(input_dim)
        if kernel.has_lengthscale:
            kernel.lengthscale = math.sqrt(feature_dim)

        # For kernels with outputscale, set it to 1
        if hasattr(kernel, 'raw_outputscale'):
            kernel.outputscale = 1.0

        # For kernels with variance, set it to 1
        if hasattr(kernel, 'raw_variance'):
            kernel.variance = 1.0

