import copy
from typing import Tuple

import numpy as np
import torch
from torch import Tensor
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader

from psgld import pSGLD
from loss import GammaNLLLoss, NLLLoss
from utils import set_seed
from network import BayesMLP, GammaVarMLP
from standard_trainer import StandardTrainer


class CUQDNNInitBayesTrainer:
    """Two-stage trainer for a Bayesian mean network and a variance network.

    Workflow implemented by this class:

    1. :meth:`configure_bayes_sampler` + :meth:`sample_posterior` run a pSGLD
       sampler on the mean network (``BayesMLP``) and store posterior
       samples as ``state_dict`` snapshots in ``self.mean_nets``.
    2. :meth:`bayes_predict` evaluates every stored sample and returns the
       posterior predictive mean and the across-sample (epistemic) variance.
    3. :meth:`configure_var_optimizer` + :meth:`var_train` fit the variance
       network (``GammaVarMLP``, which outputs ``alpha``/``beta`` with
       predicted variance ``alpha / beta``) to the squared residuals of the
       mean prediction. The epoch with the lowest negative log marginal
       likelihood is selected. With ``initialization=True`` the mean comes
       from a deterministic ``StandardTrainer`` instead.
    """

    def __init__(
        self,
        mean_net: BayesMLP,
        var_net: GammaVarMLP,
        device: torch.device = torch.device("cpu"),
        seed: int = 0,
    ) -> None:
        """Store the untrained networks and the device, and seed the RNGs.

        Parameters
        ----------
        mean_net : BayesMLP
            Untrained mean network. :meth:`configure_bayes_sampler` deep-copies
            it when no other network is given.
        var_net : GammaVarMLP
            Untrained variance network. :meth:`configure_var_optimizer`
            deep-copies it when no other network is given.
        device : torch.device, optional
            Device used for sampling and prediction, by default
            torch.device("cpu").
        seed : int, optional
            Seed forwarded to ``set_seed``, by default 0.
        """
        # set seed for all components
        self.seed = seed
        set_seed(seed=seed)

        # keep the untouched networks so every configure_* call starts fresh
        self.un_trained_mean_net = mean_net
        self.un_trained_var_net = var_net
        print("Mean Net: ", mean_net)
        print("Var Net: ", var_net)

        self.device = device

    def configure_var_optimizer(
        self,
        var_net: GammaVarMLP = None,
        optimizer_name: str = "Adam",
        lr: float = 1e-3,
        weight_decay: float = 1e-6,
    ) -> None:
        """Create ``self.var_net`` (a deep copy) and its ``self.optimizer``.

        Parameters
        ----------
        var_net : GammaVarMLP, optional
            Network to copy and train. If None, the untrained variance network
            given to ``__init__`` is copied instead.
        optimizer_name : str, optional
            Either "Adam" or "SGD", by default "Adam".
        lr : float, optional
            Learning rate, by default 1e-3.
        weight_decay : float, optional
            Weight decay, by default 1e-6.

        Raises
        ------
        ValueError
            If ``optimizer_name`` is not supported.
        """
        optimizers = {"Adam": torch.optim.Adam, "SGD": torch.optim.SGD}
        if optimizer_name not in optimizers:
            raise ValueError("Undefined optimizer")

        # take a copy so the caller's / untrained network is never mutated
        source = self.un_trained_var_net if var_net is None else var_net
        self.var_net = copy.deepcopy(source)
        self.optimizer = optimizers[optimizer_name](
            self.var_net.parameters(),
            lr=lr,
            weight_decay=weight_decay
        )

    def configure_bayes_sampler(
        self,
        mean_net: BayesMLP = None,
        inference_method: str = "pSGLD",
        lr: float = 1e-3,
        gamma: float = 0.9999,
    ) -> None:
        """Create ``self.mean_net`` (a deep copy), its sampler and LR scheduler.

        Parameters
        ----------
        mean_net : BayesMLP, optional
            Network to copy and sample. If None, the untrained mean network
            given to ``__init__`` is copied instead.
        inference_method : str, optional
            Only "pSGLD" is supported, by default "pSGLD".
        lr : float, optional
            Initial learning rate (step size), by default 1e-3.
        gamma : float, optional
            Per-epoch multiplicative decay of the learning rate
            (``ExponentialLR``), by default 0.9999.

        Raises
        ------
        ValueError
            If ``inference_method`` is not supported.
        """
        if inference_method != "pSGLD":
            raise ValueError("Undefined inference method")

        source = self.un_trained_mean_net if mean_net is None else mean_net
        self.mean_net = copy.deepcopy(source)
        self.sampler = pSGLD(self.mean_net.parameters(), lr=lr)

        # attribute name "schedular" (sic) is kept for backward compatibility
        self.schedular = ExponentialLR(self.sampler, gamma=gamma)

    def sample_posterior(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        var_best: torch.Tensor,
        num_epochs: int,
        mix_epochs: int,
        burn_in_epochs: int,
        batch_size: int = None,
        verbose: bool = True,
        print_iter: int = 10,
    ) -> None:
        """Run the sampler on the mean network and collect posterior samples.

        After the call, ``self.mean_nets`` holds deep copies of the mean
        network's ``state_dict``. ``self.log_likelihood`` and
        ``self.kl_values`` hold the matching negated epoch-summed NLL and
        prior losses.

        Parameters
        ----------
        x, y : torch.Tensor
            Training inputs and targets.
        var_best : torch.Tensor
            Per-sample variance passed to ``NLLLoss``. Presumably the
            aleatoric variance from a fitted variance model (TODO confirm).
        num_epochs : int
            Total number of sampling epochs.
        mix_epochs : int
            Once burn-in is over, a sample is kept at every epoch whose index
            is divisible by ``mix_epochs``. Must be positive.
        burn_in_epochs : int
            Epochs with index below this value are never stored.
        batch_size : int, optional
            Minibatch size. None means a single full batch.
        verbose : bool, optional
            Print the losses every ``print_iter`` epochs.
        print_iter : int, optional
            Printing interval, by default 10.

        Raises
        ------
        ValueError
            If ``mix_epochs`` is not positive (it is used as a modulus).
        """
        if mix_epochs <= 0:
            raise ValueError("mix_epochs must be a positive integer")
        if batch_size is None:
            batch_size = x.size(0)

        dataset = torch.utils.data.TensorDataset(x, y, var_best)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        num_batches = len(dataloader)

        self.mean_nets = []  # Store state_dicts instead of full models
        self.log_likelihood = []
        self.kl_values = []

        # bayes_predict leaves mean_net in eval mode; sample in training mode
        self.mean_net.train()

        for epoch in range(num_epochs):
            nll_loss_collection = 0.0
            kl_loss_collection = 0.0

            for X_batch, y_batch, var_batch in dataloader:
                X_batch = X_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                var_batch = var_batch.to(self.device)

                self.sampler.zero_grad()
                pred, prior_loss = self.mean_net(X_batch, Train=True)
                nll_loss = NLLLoss()(pred, y_batch, var_batch, num_batches)
                loss = nll_loss + prior_loss
                loss.backward()
                self.sampler.step()

                nll_loss_collection += nll_loss.detach()
                kl_loss_collection += prior_loss.detach()

            self.schedular.step()

            if verbose and (epoch + 1) % print_iter == 0:
                print(
                    f"Epoch {epoch+1}/{num_epochs}, NLL Loss: {nll_loss_collection.item():.3e}, Log prior loss: {kl_loss_collection.item():.3e}"
                )

            # thin the chain: keep one sample every mix_epochs after burn-in
            if epoch >= burn_in_epochs and epoch % mix_epochs == 0:
                self.mean_nets.append(
                    copy.deepcopy(self.mean_net.state_dict()))
                self.log_likelihood.append(-nll_loss_collection)
                self.kl_values.append(-kl_loss_collection)

    def var_train(
        self,
        x_train: Tensor,
        y_train: Tensor,
        num_epochs: int,
        batch_size: int,
        factor: float = 0.0,
        verbose: bool = False,
        print_iter: int = 100,
        penalty: float = 1.0,
        early_stopping: bool = False,
        early_stopping_iter: int = 100,
        early_stopping_tol: float = 1e-4,
        initialization: bool = False,
        initialized_model: StandardTrainer = None
    ):
        """Fit the variance network to the squared residuals of the mean.

        The training target is ``(y_train - mean)**2 - factor * epistemic``.
        Here ``mean`` and ``epistemic`` come either from the posterior
        samples (:meth:`bayes_predict`) or, with ``initialization=True``,
        from ``initialized_model.predict`` (with zero epistemic variance).

        Parameters
        ----------
        x_train, y_train : Tensor
            Training inputs and targets.
        num_epochs : int
            Number of training epochs (at least 1).
        batch_size : int
            Minibatch size for ``torch.split``; -1 means one full batch.
            Batches are taken in order, without shuffling, so they line up
            with the saved posterior predictive responses.
        factor : float, optional
            Weight of the epistemic variance subtracted from the squared
            residuals, by default 0.0.
        verbose : bool, optional
            Print progress every ``print_iter`` epochs.
        print_iter : int, optional
            Printing interval, by default 100.
        penalty : float, optional
            Weight of :meth:`residual_regularizer`, by default 1.0.
        early_stopping : bool, optional
            Stop after ``early_stopping_iter`` consecutive epochs without
            improvement of the negative log marginal likelihood. Only applies
            when ``initialization`` is False.
        early_stopping_iter : int, optional
            Patience in epochs, by default 100.
        early_stopping_tol : float, optional
            Minimum relative decrease counted as an improvement. It is used
            for best-epoch selection whether or not early stopping is on.
        initialization : bool, optional
            Use ``initialized_model`` as the mean predictor and return the
            final network, with no marginal-likelihood selection.
        initialized_model : StandardTrainer, optional
            Deterministic mean model, required when ``initialization`` is True.

        Returns
        -------
        tuple
            ``(best_var_net, best_var_epoch)``: a deep copy of the variance
            network at the selected epoch, and that epoch's index.

        Raises
        ------
        ValueError
            If ``num_epochs < 1``, or if ``initialization`` is True and
            ``initialized_model`` is not a ``StandardTrainer``.
        RuntimeError
            If ``initialization`` is False and no posterior samples exist.
        """
        if num_epochs < 1:
            raise ValueError("num_epochs must be at least 1")

        train_mean, train_epistemic = self._reference_prediction(
            x_train, initialization, initialized_model)

        # regression target of the Gamma NLL
        residuals = (y_train - train_mean)**2 - factor * train_epistemic

        x_batches = self._split_batches(x_train, batch_size)
        residual_batches = self._split_batches(residuals, batch_size)
        y_batches = self._split_batches(y_train, batch_size)
        num_batches = len(x_batches)

        self.train_loss_collection = torch.zeros(num_epochs)
        self.val_loss_collection = torch.zeros(num_epochs)
        self.nlog_mgks = torch.zeros(num_epochs)

        min_loss = float("inf")
        no_improvement = 0  # consecutive epochs without improvement

        for epoch in range(num_epochs):
            self.var_net.train()
            log_mgks_batch = 0.0
            num_sample_count = 0  # offset into the saved PPD responses

            for x_b, res_b, y_b in zip(x_batches, residual_batches, y_batches):
                self.optimizer.zero_grad()
                alpha, beta, prior_loss = self.var_net.forward(x_b, Train=True)
                var_pred = alpha / beta
                nll_loss = GammaNLLLoss(reduction="sum")(
                    residuals=res_b, alpha=alpha, beta=beta,
                    num_scale=num_batches)
                loss = nll_loss + prior_loss + penalty * self.residual_regularizer(
                    var_pred=var_pred, residuals=res_b)

                num_batch_sample = x_b.shape[0]
                if not initialization:
                    # log marginal likelihood of this batch under var_pred
                    log_marginal_likelihood = self.log_marginal_likelihood(
                        y_b,
                        ppd_responses=self.responses[
                            :, num_sample_count:num_sample_count + num_batch_sample, ...],
                        refinement="var",
                        var_best=var_pred)
                    log_mgks_batch += log_marginal_likelihood * num_batch_sample
                    num_sample_count += num_batch_sample

                loss.backward()
                self.optimizer.step()

            self.train_loss_collection[epoch] = loss.item()

            if initialization:
                if verbose and epoch % print_iter == 0:
                    print("Epoch/Total: %d/%d, Train Gamma NLL Loss: %.3e, prior loss:%.3e" %
                          (epoch, num_epochs, loss.item(), prior_loss.item()))
                continue

            self.nlog_mgks[epoch] = -log_mgks_batch
            current = self.nlog_mgks[epoch].item()

            if epoch > 0:
                relative_improvement = np.abs(current - min_loss) / np.abs(min_loss)
            else:
                relative_improvement = 1.0

            if relative_improvement > early_stopping_tol and current < min_loss:
                no_improvement = 0
                min_loss = current
                self.best_var_epoch = epoch
                self.best_var_net = copy.deepcopy(self.var_net)
            else:
                no_improvement += 1
                # only stop early when the caller asked for it
                if early_stopping and no_improvement >= early_stopping_iter:
                    break

            if verbose and epoch % print_iter == 0:
                print("Epoch/Total: %d/%d, Train Gamma NLL Loss: %.3e, NLogMarginal: %.3e, prior loss:%.3e" %
                      (epoch, num_epochs, loss.item(), current, prior_loss.item()))

        if initialization:
            # no marginal-likelihood signal: the final network is returned
            self.best_var_epoch = epoch
            self.best_var_net = copy.deepcopy(self.var_net)
        else:
            del self.responses
            print("Training is done, delete the temporary PPD responses to free memory")

        return self.best_var_net, self.best_var_epoch

    def _reference_prediction(self, x_train, initialization, initialized_model):
        """Return ``(mean, epistemic_variance)`` used to build var_train targets."""
        if initialization:
            if not isinstance(initialized_model, StandardTrainer):
                raise ValueError(
                    "Initialized model must be an instance of StandardTrainer")
            return initialized_model.predict(x_train), 0.0
        # save_ppd=True keeps per-sample responses for the marginal likelihood
        return self.bayes_predict(x_train, save_ppd=True)

    @staticmethod
    def _split_batches(tensor: Tensor, batch_size: int) -> Tuple[Tensor, ...]:
        """Split ``tensor`` along dim 0; ``batch_size == -1`` keeps one batch."""
        chunk = tensor.size(0) if batch_size == -1 else batch_size
        return torch.split(tensor, chunk)

    def residual_regularizer(self, var_pred, residuals, c=1.0):
        """Residual-based regularization to stabilize aleatoric variance.

        Returns ``mean(max(sqrt(residuals) / sqrt(var_pred + 1e-8) - c, 0)**2)``.
        This penalizes predicted standard deviations that are more than a
        factor ``c`` smaller than the observed residual magnitude.
        ``residuals`` can be negative when ``var_train`` subtracts an
        epistemic term, so they are clamped at 0 before the square root to
        avoid NaN.
        """
        residual_std = torch.sqrt(torch.clamp(residuals, min=0.0))
        residual_ratio = residual_std / torch.sqrt(var_pred + 1e-8)
        penalty = torch.clamp(residual_ratio - c, min=0) ** 2
        return penalty.mean()

    @staticmethod
    def _log_mean_exp(values: Tensor) -> Tensor:
        """Numerically stable ``log(mean(exp(values)))`` over all entries."""
        flat = values.flatten()
        return torch.logsumexp(flat, dim=0) - float(np.log(flat.numel()))

    def log_marginal_likelihood(self,
                                y: torch.Tensor,
                                var_best: torch.Tensor = None,
                                ppd_responses: torch.Tensor = None,
                                refinement: str = "mean",
                                ) -> float:
        """Estimate the log marginal likelihood by averaging over samples.

        Parameters
        ----------
        y : torch.Tensor
            Target values (ground truth). Used only for ``refinement="var"``.
        var_best : torch.Tensor, optional
            Predicted variance per data point. Required for
            ``refinement="var"``.
        ppd_responses : torch.Tensor, optional
            Posterior predictive responses stacked over samples along dim 0,
            as saved by :meth:`bayes_predict`. Required for
            ``refinement="var"``.
        refinement : str, optional
            "var": Gaussian log-likelihood of ``y`` per posterior sample
            (without the 2*pi constant) under ``var_best``.
            "mean": stored sampler log-likelihoods plus log priors.
            By default "mean".

        Returns
        -------
        float
            ``log(mean(exp(log_values)))`` over the posterior samples.

        Raises
        ------
        ValueError
            If ``ppd_responses`` is missing for "var", or if ``refinement``
            is unknown.
        """
        # the result is returned as a Python float, so no graph is needed
        with torch.no_grad():
            if refinement == "var":
                if ppd_responses is None:
                    raise ValueError("PPD responses are not available")
                residuals = ppd_responses - y.unsqueeze(0)
                var = var_best.unsqueeze(0)
                nlls = 0.5 * torch.sum(residuals**2 / var + torch.log(var), dim=-1)
                log_values = -torch.sum(nlls, dim=-1)
            elif refinement == "mean":
                # log-likelihoods plus log-priors (negated prior losses)
                log_values = torch.stack(
                    self.log_likelihood) + torch.stack(self.kl_values)
            else:
                raise ValueError(f"Unknown refinement method: {refinement!r}")

            return self._log_mean_exp(log_values).item()

    def bayes_predict(
        self,
        x: torch.Tensor,
        save_ppd: bool = False,
    ) -> Tuple[Tensor, Tensor]:
        """Predict the mean and variance of the output at the scaled data.

        Parameters
        ----------
        x : torch.Tensor
            Test data points; moved to ``self.device``.
        save_ppd : bool, optional
            Store the stacked per-sample responses in ``self.responses``
            (default is False).

        Returns
        -------
        Tuple[Tensor, Tensor]
            Mean and variance across posterior samples, in the scaled space.

        Raises
        ------
        RuntimeError
            If :meth:`sample_posterior` has not stored any samples.
        """
        if not getattr(self, "mean_nets", None):
            raise RuntimeError(
                "No posterior samples available. Run `sample_posterior` first.")

        self.mean_net.eval()

        # One scratch copy is reused for every sample. Each saved state_dict
        # is a full snapshot of mean_net, so loading it is equivalent to
        # making a fresh copy per sample, at much lower cost.
        sample_net = copy.deepcopy(self.mean_net)
        sample_net.to(self.device)
        sample_net.eval()
        x = x.to(self.device)

        responses = []
        with torch.no_grad():
            for state_dict in self.mean_nets:
                sample_net.load_state_dict(state_dict)
                responses.append(sample_net.forward(x, Train=False))

        # Shape: (num_samples, *prediction_shape)
        responses = torch.stack(responses)
        y_pred_mean = torch.mean(responses, dim=0)
        y_pred_var = torch.var(responses, dim=0)

        if save_ppd:
            self.responses = responses  # Save posterior predictive distributions
            print("PPD responses are saved with shape:", self.responses.shape)

        return y_pred_mean.detach(), y_pred_var.detach()
