"""The base model class for all models."""

from abc import abstractmethod

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Encoder(nn.Module):
    """Base encoder class.

    Parameters
    ----------
    n_components : int
        Number of components in the latent variable z.
        Aka the dimension of the latent variable z, i.e., latent_dim.
    n_total_samples : int | None, optional
        Number of samples of the whole dataset, by default None.
    """

    def __init__(self, n_components: int, n_total_samples: int | None = None) -> None:
        super().__init__()
        self.n_components = n_components
        # self.log_std = log_std
        # A fixed log_std can be (0.3789 ** 0.5).log()
        # see iclr2025/synthetic_linear_weak/data/data_generator.py
        # see https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html#scipy.stats.gaussian_kde
        # If not using the above fixed log_std, there are two options:
        # 1. The log_std (bandwidth) can be computed via Scotts factor on the whole dataset (n_total_samples)
        # 1 * n_total_samples ** (-1 / (d + 4)), which assumes the projected latent has unit data standard deviation
        # 2. The log_std (bandwidth) can be computed via Scotts factor on the batch (batch_size)
        # diag(var(z_pred_mean, axis=0)) ** 0.5 * batch_size ** (-1 / (d + 4))
        # then the reweights in inference.py::KLNormal::forward should be removed
        # We will just compute within batch partial/total correlation
        # The log_std can be dynamically determined by then encoder output z_pred_mean
        # Using the Scotts factor, the log_std (bandwidth) will be
        self.n_total_samples = n_total_samples
        if self.n_total_samples is None:
            self.log_std = nn.Parameter(torch.ones(n_components))
        else:
            self.log_std = nn.Parameter(
                np.log(self.n_total_samples ** (-1 / (n_components + 4)))
                * torch.ones(n_components)
            )

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder.

        Parameters
        ----------
        x : torch.Tensor
            The input data.

        Returns
        -------
        torch.Tensor
            The latent variable.
        """

    @torch.no_grad()
    def log_scaled_scott_factor(self, z_pred_mean: torch.Tensor) -> torch.Tensor:
        """Compute the Scotts factor for the log standard deviation.

        Parameters
        ----------
        z_pred_mean : torch.Tensor of shape (batch_size, n_components)
            The latent variable.

        Returns
        -------
        torch.Tensor of shape (n_components,)
            The log standard deviation.
        """
        scott_factor = z_pred_mean.shape[0] ** (-1 / (z_pred_mean.shape[1] + 4))
        return (
            z_pred_mean.std(dim=0) * scott_factor * torch.ones(self.n_components)
        ).log()

    @torch.no_grad()
    def compute_log_std(self, z_pred_mean: torch.Tensor) -> torch.Tensor:
        if self.log_std is None:
            return self.log_scaled_scott_factor(z_pred_mean)[None, :].expand_as(
                z_pred_mean
            )
        else:
            return self.log_std * torch.ones_like(z_pred_mean)

    def log_prob(
        self, z_pred_mean: torch.Tensor, z: torch.Tensor, z_pred_log_std: torch.Tensor
    ) -> torch.Tensor:
        """The log probability of the latent variable given the predicted latent variable.

        Parameters
        ----------
        z_pred_mean : torch.Tensor of shape (*, n_components)
            The mean of the predicted latent variable.
        z : torch.Tensor of shape (*, n_components)
            The latent variable.
        z_pred_log_std : torch.Tensor of shape (*, n_components)
            The log standard deviation of the predicted latent variable.

        Returns
        -------
        torch.Tensor of shape (*,)
            The log probability of the latent variable given the predicted latent variable.
        """
        return -F.gaussian_nll_loss(
            z_pred_mean, z, (z_pred_log_std.exp() ** 2), full=True, reduction="none"
        ).sum(dim=-1)

    def sample(
        self, z_pred_mean: torch.Tensor, z_pred_log_std: torch.Tensor
    ) -> torch.Tensor:
        """Sample from the predicted latent variable mean and log standard deviation.

        Parameters
        ----------
        z_pred_mean : torch.Tensor of shape (*, n_components)
            The predicted mean of the latent variable.
        z_pred_log_std : torch.Tensor of shape (*, n_components)
            The predicted log standard deviation of the latent variable.

        Returns
        -------
        z : torch.Tensor of shape (*, n_components)
            The sampled latent variable.
        """
        return (
            torch.randn_like(z_pred_mean, device=z_pred_mean.device)
            * z_pred_log_std.exp()
            + z_pred_mean
        )
