from __future__ import annotations

import warnings

warnings.filterwarnings("ignore")

import math
from typing import Any, Callable, Optional, Union

import torch
from botorch.models.model import Model
from botorch.acquisition.objective import PosteriorTransform
from botorch.posteriors import TorchPosterior
from laplace import BaseLaplace, Laplace, Likelihood, PredType, PriorStructure
from laplace.marglik_training import marglik_training
from torch import Tensor
from torch.distributions import MultivariateNormal, Normal
from torch.nn import Linear, Module, MSELoss, ReLU, Sequential
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, TensorDataset

from utils.configs import LaplaceConfig
from bayesopt.surrogates.utils import DataScaler

class FixedFeatLaplace(Model):  # type: ignore[misc]
    """BoTorch surrogate model backed by a Laplace-approximated neural network.

    Unless an already-fitted ``laplace`` object is supplied, a fresh network
    from ``initialize_nn`` is trained on the scaled training data and wrapped
    in a Laplace approximation. This happens either post hoc or through online
    marginal-likelihood training, depending on ``laplace_config.marglik_mode``.
    Predictions are mapped back to the original target scale with
    ``DataScaler.inverse_transform_y``.

    Args:
        train_X: Training inputs. Rows are observations (new ones are
            appended along dim 0 in ``condition_on_observations``). The last
            dim is the input dimension used by the default network.
        train_Y: Training targets. The last dim is the number of outputs.
        initialize_nn: Zero-argument factory returning a fresh, untrained
            network. Defaults to an MLP with two hidden layers of 50 ReLU
            units.
        hess_factorization: Hessian structure passed to ``laplace``
            (e.g. ``"kron"``).
        laplace: An already-fitted Laplace object. If given, no training is
            performed at construction time.
        laplace_config: Training and Laplace hyperparameters. Defaults to
            ``LaplaceConfig()``.
        device: ``"cpu"`` or ``"cuda"``.
        dtype: Name of a ``torch`` dtype (e.g. ``"float64"``). The network,
            the training data and the test inputs are cast to it.
        normalize_x: Whether inputs go through ``DataScaler.transform_x``
            before training and prediction.

    Raises:
        ValueError: If ``device`` is not ``"cpu"`` or ``"cuda"``.
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        initialize_nn: Optional[Callable[[], Module]] = None,
        hess_factorization: str = "kron",
        laplace: Optional[BaseLaplace] = None,
        laplace_config: Optional[LaplaceConfig] = None,
        device: str = "cpu",
        dtype: str = "float64",
        normalize_x: bool = False,
    ):
        super().__init__()

        if device not in ["cpu", "cuda"]:
            raise ValueError("Only 'cpu' and 'cuda' devices are supported.")

        # Raw (unscaled) data is kept. Scaling is applied on the fly.
        self.train_X: Tensor = train_X
        self.train_Y: Tensor = train_Y
        self.device: str = device
        self.dtype: str = dtype
        self.normalize_x: bool = normalize_x
        self.data_scaler = DataScaler(train_X, train_Y)

        self.initialize_nn: Callable[[], Module]
        if initialize_nn is None:
            self.initialize_nn = lambda: Sequential(
                Linear(train_X.shape[-1], 50),
                ReLU(),
                Linear(50, 50),
                ReLU(),
                Linear(50, train_Y.shape[-1]),
            )
        else:
            self.initialize_nn = initialize_nn

        self.laplace_config: LaplaceConfig
        self.hess_factorization = hess_factorization
        if laplace_config is None:
            self.laplace_config = LaplaceConfig()
        else:
            self.laplace_config = laplace_config

        self.noise_var: Optional[float] = self.laplace_config.noise_var

        # Either train a fresh Laplace approximation or reuse the given one.
        self.laplace: BaseLaplace
        if laplace is None:
            if normalize_x:
                train_X = self.data_scaler.transform_x(train_X)
            train_Y = self.data_scaler.transform_y(train_Y)
            self._train_model(self._get_train_loader(train_X, train_Y))
        else:
            self.laplace = laplace

    def posterior(
        self,
        X: Tensor,
        output_indices: Optional[list[int]] = None,
        observation_noise: Union[bool, Tensor] = False,
        posterior_transform: Optional[PosteriorTransform] = None,
        **kwargs: Any,
    ) -> TorchPosterior:
        """Return the predictive posterior at ``X`` in the original target scale.

        ``output_indices`` and ``observation_noise`` are accepted for BoTorch
        API compatibility but are not used.

        Args:
            X: Test inputs in the original (un-normalized) input space.
            output_indices: Ignored.
            observation_noise: Ignored. The returned posterior does not add
                observation noise.
            posterior_transform: Optional transform applied to the posterior.

        Returns:
            A ``TorchPosterior`` wrapping a ``Normal`` or a
            ``MultivariateNormal``, depending on the rank of the predicted
            covariance.
        """
        # get_prediction already normalizes X (when normalize_x is set) and
        # maps predictions back to the original target scale. Repeating
        # either step here would apply the transform twice.
        mean_y, var_y = self.get_prediction(X, use_test_loader=False, joint=False)

        if var_y.dim() != 3:  # Single objective
            # Normal is parameterized by the standard deviation, not the variance.
            post_pred = TorchPosterior(Normal(mean_y, var_y.squeeze(-1).sqrt()))  # type: ignore
        else:  # Multi objective
            post_pred = TorchPosterior(MultivariateNormal(mean_y, var_y))  # type: ignore

        if hasattr(self, "outcome_transform"):
            post_pred = self.outcome_transform.untransform_posterior(post_pred)

        if posterior_transform is not None:
            post_pred = posterior_transform(post_pred)

        return post_pred

    def condition_on_observations(
        self, X: Tensor, Y: Tensor, **kwargs: Any
    ) -> FixedFeatLaplace:
        """Append ``(X, Y)`` to the training data, retrain, and return a new model.

        NOTE: this also updates ``self`` in place. Its ``train_X``/``train_Y``,
        ``data_scaler`` and ``laplace`` are replaced by the extended data and
        the retrained approximation. The returned model shares that same
        ``laplace`` object.

        Args:
            X: New inputs, concatenated to ``train_X`` along dim 0.
            Y: New targets, concatenated to ``train_Y`` along dim 0.

        Returns:
            A ``FixedFeatLaplace`` built on the extended data and the
            retrained Laplace approximation.
        """
        self.train_X = torch.cat([self.train_X, X], dim=0)
        self.train_Y = torch.cat([self.train_Y, Y], dim=0)

        # The scaler is refit so that the new points influence the scaling.
        self.data_scaler = DataScaler(self.train_X, self.train_Y)
        train_Y = self.data_scaler.transform_y(self.train_Y)
        if self.normalize_x:
            train_X = self.data_scaler.transform_x(self.train_X)
        else:
            train_X = self.train_X

        # Retrain from scratch on the extended dataset.
        self._train_model(self._get_train_loader(train_X, train_Y))

        return FixedFeatLaplace(
            # Raw data plus the retrained approximation. Passing `laplace`
            # prevents the constructor from training a second time.
            train_X=self.train_X,
            train_Y=self.train_Y,
            initialize_nn=self.initialize_nn,
            hess_factorization=self.hess_factorization,
            laplace=self.laplace,
            laplace_config=self.laplace_config,
            device=self.device,
            dtype=self.dtype,
            normalize_x=self.normalize_x,
        )

    def get_prediction(
        self, test_X: Tensor, joint: bool = True, use_test_loader: bool = False
    ) -> tuple[Tensor, Tensor]:
        """
        Batched Laplace (GLM) prediction in the original target scale.

        Args:
        -----
        test_X:
            Array of shape `(batch_shape, feature_dim)` in the original
            (un-normalized) input space. It is normalized here when
            `normalize_x` is set.

        joint:
            Whether to do joint predictions (like in GP).

        use_test_loader:
            Set to True if your test_X is large. Predictions are then made in
            mini-batches of `laplace_config.batch_size`.
            NOTE(review): with `joint=True`, each batch's covariance only
            covers that batch. Concatenating them does not give a joint
            covariance over all of test_X, and it fails when the last batch
            is smaller.

        Returns:
        --------
        mean_y:
            Tensor of shape `(batch_shape, num_tasks)`.

        cov_y:
            Tensor of shape `(batch_shape*num_tasks, batch_shape*num_tasks)`
            if joint is True. Otherwise, `(batch_shape, num_tasks, num_tasks)`.
            With `use_test_loader=True`, both outputs are additionally
            `.squeeze()`d.

        Raises:
        -------
        AttributeError:
            If no Laplace approximation has been fitted yet.
        """
        if getattr(self, "laplace", None) is None:
            raise AttributeError("Train your model first before making prediction!")

        if self.normalize_x:
            test_X = self.data_scaler.transform_x(test_X)
        dtype = self._torch_dtype()
        test_X = test_X.to(dtype)

        if not use_test_loader:
            mean_y, cov_y = self.laplace(  # type: ignore
                test_X.to(self.device),
                pred_type=PredType.GLM,
                joint=joint,  # type: ignore
            )
            mean_y, cov_y = self.data_scaler.inverse_transform_y(mean_y, cov_y)
            return mean_y, cov_y

        test_loader = DataLoader(
            TensorDataset(test_X),
            batch_size=self.laplace_config.batch_size,
        )

        means: list[Tensor] = []
        covs: list[Tensor] = []
        for (X_batch,) in test_loader:
            X_batch = X_batch.to(device=self.device, dtype=dtype)
            _mean_y, _cov_y = self.laplace(  # type: ignore
                X_batch,
                pred_type=PredType.GLM,
                joint=joint,  # type: ignore
            )
            # Map each batch back to the original scale before collecting it.
            _mean_y, _cov_y = self.data_scaler.inverse_transform_y(_mean_y, _cov_y)
            means.append(_mean_y)
            covs.append(_cov_y)

        mean_y = torch.cat(means, dim=0).squeeze()
        cov_y = torch.cat(covs, dim=0).squeeze()
        return mean_y, cov_y

    @property
    def num_outputs(self) -> int:
        """The number of outputs of the model."""
        return self.train_Y.shape[-1]

    def _torch_dtype(self) -> torch.dtype:
        """Resolve ``self.dtype`` (e.g. ``"float64"``) to a ``torch.dtype``.

        ``Tensor.to`` interprets a plain string as a *device*, so the stored
        dtype name must be converted before it is used for casting.

        Raises:
            ValueError: If ``self.dtype`` does not name a ``torch`` dtype.
        """
        if isinstance(self.dtype, torch.dtype):
            return self.dtype
        resolved = getattr(torch, str(self.dtype), None)
        if not isinstance(resolved, torch.dtype):
            raise ValueError(f"Unsupported dtype: {self.dtype!r}")
        return resolved

    def _init_net(self) -> Module:
        """Build a freshly initialized network on ``self.device`` in ``self.dtype``."""
        return self.initialize_nn().to(device=self.device, dtype=self._torch_dtype())

    def _train_model(self, train_loader: DataLoader[tuple[Tensor, ...]]) -> None:
        """Fit ``self.laplace`` on ``train_loader``, post hoc or online.

        With ``marglik_mode == "posthoc"`` the network is trained first and
        the Laplace hyperparameters are tuned afterwards. Otherwise,
        ``laplace.marglik_training`` optimizes network weights and
        hyperparameters jointly.
        """
        cfg: LaplaceConfig = self.laplace_config

        if cfg.marglik_mode == "posthoc":
            self._posthoc_laplace(train_loader)
            return

        # Online marginal-likelihood training on a freshly initialized net.
        la, _, _, _ = marglik_training(
            model=self._init_net(),
            train_loader=train_loader,
            likelihood=Likelihood.REGRESSION,
            hessian_structure=self.hess_factorization,
            prior_structure=cfg.prior_prec_structure,
            n_epochs=cfg.n_epochs,
            backend=cfg.hessian_backend,
            optimizer_kwargs={"lr": cfg.lr},
            scheduler_cls=CosineAnnealingLR,
            # The scheduler horizon covers every optimizer step of training.
            scheduler_kwargs={"T_max": cfg.n_epochs * len(train_loader)},
            marglik_frequency=cfg.online_marglik_freq,
        )
        self.laplace = la

    def _posthoc_laplace(self, train_loader: DataLoader[tuple[Tensor, ...]]) -> None:
        """Train the network with MSE, then fit and tune a post-hoc Laplace approximation.

        After fitting, the prior precision and the observation-noise std are
        tuned by maximizing the Laplace log marginal likelihood with Adam for
        ``cfg.posthoc_marglik_iters`` steps.
        """
        cfg = self.laplace_config
        dtype = self._torch_dtype()

        # Fresh weights, placed on the same device/dtype as the training batches.
        net = self._init_net()
        optimizer = Adam(net.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
        # The scheduler is stepped once per mini-batch, hence epochs * batches.
        scheduler = CosineAnnealingLR(optimizer, cfg.n_epochs * len(train_loader))
        loss_func = MSELoss()

        for epoch in range(cfg.n_epochs):
            epoch_loss = 0.0
            for x, y in train_loader:
                x, y = x.to(self.device), y.to(self.device)
                optimizer.zero_grad()
                loss = loss_func(net(x), y)
                loss.backward()
                optimizer.step()
                scheduler.step()
                epoch_loss += loss.item()
            if epoch % cfg.val_frequency == 0:
                print(f"Epoch {epoch+1}/{cfg.n_epochs} -- Loss: {epoch_loss/len(train_loader)}", flush=True)

        net.eval()
        self.laplace = Laplace(
            net,
            Likelihood.REGRESSION,
            subset_of_weights=cfg.subset_of_weights,
            hessian_structure=self.hess_factorization,
            backend=cfg.hessian_backend,
        )
        self.laplace.fit(train_loader)

        # Number of prior-precision parameters for each supported structure.
        prior_prec_shapes: dict[str, int] = {
            PriorStructure.SCALAR: 1,
            PriorStructure.LAYERWISE: self.laplace.n_layers,
            PriorStructure.DIAG: self.laplace.n_params,
        }

        # Tune prior precision and observation noise in log space.
        log_prior = torch.ones(
            prior_prec_shapes[cfg.prior_prec_structure],
            requires_grad=True,
            device=self.device,
            dtype=dtype,
        )
        # log(std) = 0.5 * log(variance). NOTE: cfg.noise_var must be a
        # positive float here (math.log fails on None or non-positive values).
        log_sigma = torch.tensor(
            [0.5 * math.log(cfg.noise_var)],
            requires_grad=True,
            device=self.device,
            dtype=dtype,
        )
        hyper_optimizer = Adam([log_prior, log_sigma], lr=1e-2)

        for _ in range(cfg.posthoc_marglik_iters):
            hyper_optimizer.zero_grad()
            neg_marglik = -self.laplace.log_marginal_likelihood(
                log_prior.exp(), log_sigma.exp()
            )
            neg_marglik.backward()
            hyper_optimizer.step()

        self.laplace.prior_precision = log_prior.detach().exp()
        self.laplace.sigma_noise = log_sigma.detach().exp()

        print("Prior precision", self.laplace.prior_precision)
        print("Sigma noise", self.laplace.sigma_noise)

    def _get_train_loader(
        self,
        train_X: Tensor,
        train_Y: Tensor,
    ) -> DataLoader[tuple[Tensor, ...]]:
        """Wrap (already scaled) training data in a shuffled mini-batch loader.

        Both tensors are cast to the model dtype so that they match the
        network built by ``_init_net``.
        """
        dtype = self._torch_dtype()
        return DataLoader(
            TensorDataset(train_X.to(dtype), train_Y.to(dtype)),
            batch_size=self.laplace_config.batch_size,
            shuffle=True,
        )
