import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# def decomposed_weighted_sum(
#     self,
#     z_pred_mean: torch.Tensor,
#     z: torch.Tensor,
#     z_pred_log_std: torch.Tensor,
#     annealing: float = 1,
# ) -> torch.Tensor:
#     """Decomposed weighted sum KL divergence between a normal distribution and the normal prior distribution.

#     Parameters
#     ----------
#     z_pred_mean : torch.Tensor of shape (batch_size, n_components)
#         The predicted mean of the latent variable.
#     z : torch.Tensor of shape (batch_size, n_components)
#         The sampled latent variable.
#     z_pred_log_std : torch.Tensor of shape (batch_size, n_components)
#         The predicted log standard deviation of the latent variable.
#     annealing : float, optional
#         Annealing factor, by default 1.

#     Returns
#     -------
#     torch.Tensor
#         The decomposed weighted sum KL divergence.
#     """
#     index_code_mutual_information, partial_correlation, dimension_wise_kl = (
#         self.decomposed(z_pred_mean, z, z_pred_log_std)
#     )
#     # if self.prior = "normal":
#     if 0:
#         return (
#             self.analytical(z_pred_mean, z_pred_log_std).mean()
#             + (self.alpha - 1) * index_code_mutual_information
#             + (annealing * self.beta - 1) * partial_correlation
#             + (self.gamma - 1) * dimension_wise_kl
#         )
#     else:
#         return (
#             self.alpha * index_code_mutual_information
#             + annealing * self.beta * partial_correlation
#             + self.gamma * dimension_wise_kl
#         )


def variational_inference(
    encoder: nn.Module,
    decoder: nn.Module,
    encoder_optimizer: torch.optim.Optimizer,
    decoder_optimizer: torch.optim.Optimizer,
    dataloader: DataLoader,
    kl_normal: KLNormal,
    n_epochs: int = 1000,
    annealing: npt.NDArray | None = None,
    writer: SummaryWriter | None = None,
):
    """Variational inference for the encoder and decoder.

    Parameters
    ----------
    encoder : nn.Module
        The encoder.
    decoder : nn.Module
        The decoder.
    encoder_optimizer : torch.optim.Optimizer
        The optimizer for the encoder.
    decoder_optimizer : torch.optim.Optimizer
        The optimizer for the decoder.
    dataloader : DataLoader
        The dataloader.
    kl_normal : KLNormal
        The KL divergence between a normal distribution and the prior distribution.
    n_epochs : int, optional
        Number of epochs, by default 1000.
    annealing : npt.NDArray | None, optional
        Annealing schedule, by default None.
    writer : SummaryWriter | None, optional
        The tensorboard writer, by default None.

    Raises
    ------
    ValueError
        If the length of annealing not equals to n_epochs.
    """

    if writer is None:
        writer = SummaryWriter()

    if annealing is None:
        annealing = np.ones(n_epochs)
    else:
        if len(annealing) != n_epochs:
            raise ValueError("The length of annealing not equals to n_epochs.")

    for epoch in range(n_epochs):
        epoch_reconstruction_loss = 0
        epoch_index_code_mutual_information = 0
        epoch_partial_correlation = 0
        epoch_dimension_wise_kl = 0
        epoch_loss = 0
        for batch, (x,) in enumerate(dataloader):
            batch_size = x.shape[0]

            z_pred_mean, z_pred_log_std = encoder.forward(
                x
            )  # (batch_size, n_components)

            z = encoder.sample(z_pred_mean, z_pred_log_std)

            x_pred_mean = decoder.forward(z)
            reconstruction_loss = -decoder.log_prob(x_pred_mean, x).mean()
            index_code_mutual_information, partial_correlation, dimension_wise_kl = (
                kl_normal.decomposed(z_pred_mean, z, z_pred_log_std)
            )
            kl_loss = (
                kl_normal.alpha * index_code_mutual_information
                + annealing[epoch] * kl_normal.beta * partial_correlation
                + kl_normal.gamma * dimension_wise_kl
            )
            loss = reconstruction_loss + kl_loss

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()
            loss.backward()
            encoder_optimizer.step()
            decoder_optimizer.step()

            epoch_reconstruction_loss += reconstruction_loss.item()
            epoch_index_code_mutual_information += index_code_mutual_information.item()
            epoch_partial_correlation += partial_correlation.item()
            epoch_dimension_wise_kl += dimension_wise_kl.item()
            epoch_loss += loss.item()

        epoch_reconstruction_loss /= len(dataloader)
        epoch_index_code_mutual_information /= len(dataloader)
        epoch_partial_correlation /= len(dataloader)
        epoch_dimension_wise_kl /= len(dataloader)
        epoch_loss /= len(dataloader)
        writer.add_scalar("reconstruction loss", epoch_reconstruction_loss, epoch)
        writer.add_scalar(
            "index code mutual information", index_code_mutual_information, epoch
        )
        writer.add_scalar("partial correlation", partial_correlation, epoch)
        writer.add_scalar("dimension-wise KL divergence", dimension_wise_kl, epoch)
        writer.add_scalar("loss", epoch_loss, epoch)

    writer.flush()
    writer.close()
