import lightning as L
import torch
from numpy.typing import NDArray
from torch import nn

from pdisvae import kl


class LitBTCVI(L.LightningModule):
    """Lightning module training an encoder/decoder pair with a KL-regularised loss.

    Each training step minimises::

        reconstruction_loss + kl_loss + extra_beta[current_epoch] * partial_correlation

    where ``reconstruction_loss`` is the negative mean of ``decoder.log_prob``
    and ``kl_loss`` / ``partial_correlation`` are obtained from ``kl_normal``
    (how they are computed depends on ``kl_normal.prior``).

    Args:
        encoder: Module whose call returns ``(mean, log_std)`` for the latent
            code and which exposes ``sample(mean, log_std)``.
        decoder: Module whose call maps latents ``z`` to a predicted mean and
            which exposes ``log_prob(pred_mean, x)``.
        kl_normal: KL helper; its ``prior`` attribute must be ``"logcosh"`` or
            ``"normal"``.
        extra_beta: Extra weight on the partial-correlation term, indexed by
            ``self.current_epoch`` -- it therefore needs at least one entry per
            training epoch.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        kl_normal: kl.KLNormal,
        extra_beta: NDArray,
        learning_rate: float = 1e-3,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.kl_normal = kl_normal
        self.extra_beta = extra_beta
        self.learning_rate = learning_rate

        # The modules are already registered as submodules; keep them out of
        # the saved hyperparameters.
        self.save_hyperparameters(ignore=["encoder", "decoder", "kl_normal"])

    def training_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Compute and log the training loss for one batch.

        Args:
            batch: Sequence whose first element is the model input ``x``.
            batch_idx: Index of the batch (unused; required by Lightning).

        Returns:
            The scalar loss to backpropagate.

        Raises:
            ValueError: If ``kl_normal.prior`` is neither ``"logcosh"`` nor
                ``"normal"``.
            IndexError: If ``extra_beta`` has no entry for the current epoch.
        """
        x = batch[0]
        # Call the modules (rather than ``.forward``) so registered hooks run.
        z_pred_mean, z_pred_log_std = self.encoder(x)  # (batch_size, n_components)

        z = self.encoder.sample(z_pred_mean, z_pred_log_std)

        x_pred_mean = self.decoder(z)
        reconstruction_loss = -self.decoder.log_prob(x_pred_mean, x).mean()

        prior = self.kl_normal.prior
        if prior == "logcosh":
            # KL is taken as the sum of the three terms returned by
            # ``decomposed``; the partial-correlation term is also reused below.
            index_code_mutual_information, partial_correlation, dimension_wise_kl = (
                self.kl_normal.decomposed(z_pred_mean, z, z_pred_log_std)
            )
            kl_loss = (
                index_code_mutual_information + partial_correlation + dimension_wise_kl
            )
        elif prior == "normal":
            partial_correlation = self.kl_normal.partial_correlation(
                z_pred_mean, z, z_pred_log_std
            )
            kl_loss = self.kl_normal.analytical(z_pred_mean, z_pred_log_std).mean()
        else:
            # Without this branch the locals above would be unbound and the
            # step would fail with an obscure UnboundLocalError.
            raise ValueError(
                f"Unsupported kl_normal.prior {prior!r}; expected 'logcosh' or 'normal'."
            )

        epoch = self.current_epoch
        if epoch >= len(self.extra_beta):
            raise IndexError(
                f"extra_beta has {len(self.extra_beta)} entries but training reached "
                f"epoch {epoch}; provide one weight per training epoch."
            )
        # Cast to a Python float so the numpy scalar never drives tensor
        # arithmetic or dtype promotion.
        extra_partial_correlation_loss = (
            float(self.extra_beta[epoch]) * partial_correlation
        )
        loss = reconstruction_loss + kl_loss + extra_partial_correlation_loss

        # Log detached tensors instead of calling ``.item()`` so logging does not
        # force a device-to-host synchronisation on every step.
        self.log_dict(
            {
                "reconstruction loss": reconstruction_loss.detach(),
                "kl loss": kl_loss.detach(),
                "partial correlation": partial_correlation.detach(),
                "loss": loss.detach(),
            },
        )
        return loss

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
