"""The linear encoder and decoder models."""

import torch
import torch.nn.functional as F
from sklearn.decomposition import FastICA
from torch import nn

from pdisvae.models.model import Encoder


class LinearDecoder(nn.Module):
    """The generating model p(x|z) that decodes the latent variable z to the observation x.

    The generative model is a linear model with Gaussian noise.
    I.e., p(x|z) = N(x|Wz + b, \\sigma^2 I).

    Parameters
    ----------
    obs_dim : int
        Dimension of the observation.
    n_components : int
        Number of components in the latent variable z.
        Aka the dimension of the latent variable z, i.e., latent_dim.
    """

    def __init__(self, obs_dim: int, n_components: int):
        super().__init__()

        self.obs_dim: int = obs_dim
        self.n_components: int = n_components

        self.log_std = nn.Parameter(-torch.ones((self.obs_dim)))
        self.mixing_and_bias = nn.Linear(self.n_components, self.obs_dim)

    @property
    def n_time_bins(self):
        """The number of time bins in the observation is the obs_dim, in terms of spatial ICA."""
        return self.obs_dim

    @n_time_bins.setter
    def n_time_bins(self, value):
        self.obs_dim = value

    @property
    def std(self) -> torch.Tensor:
        """The standard deviation of the Gaussian noise.

        Returns
        -------
        torch.Tensor of shape (obs_dim,)
            The standard deviation of the Gaussian noise.
        """
        return self.log_std.exp()

    @property
    def var(self) -> torch.Tensor:
        """The variance of the Gaussian noise.

        Returns
        -------
        torch.Tensor of shape (obs_dim,)
            The variance of the Gaussian noise.
        """
        return self.std**2

    def initialize(
        self,
        method: str | None = None,
        x: torch.Tensor | None = None,
        mixing: torch.Tensor | None = None,
        bias: torch.Tensor | None = None,
    ) -> None:
        """Initialize the parameters of the generative model.

        Parameters
        ----------
        method : str | None, optional
            Initialize by a method in ["ICA" | ], by default None
        x : torch.Tensor of shape (n_samples, obs_dim) | None, optional
            The provided dataset of shape (n_samples, obs_dim), by default None
        mixing : torch.Tensor of shape (obs_dim, n_components) | None, optional
            The initialization of the mixing, by default None
        bias : torch.Tensor of shape (obs_dim,) | None, optional
            The initialization of the bias, by default None
        """
        if method == "ICA":
            if x is None:
                raise ValueError("x must be provided when initializing by ICA.")
            fastica = FastICA(n_components=self.n_components)
            fastica.fit(x.detach().numpy())
            self.mixing_and_bias.weight.data = torch.from_numpy(fastica.mixing_).to(
                torch.float32
            )
            self.mixing_and_bias.bias.data = torch.from_numpy(fastica.mean_).to(
                torch.float32
            )
        else:
            if mixing is not None:
                self.mixing_and_bias.weight.data = mixing.detach().clone()
            else:
                self.mixing_and_bias.weight.data = torch.zeros(
                    (self.obs_dim, self.n_components)
                )
            if bias is not None:
                self.mixing_and_bias.bias.data = bias.detach().clone()
            else:
                self.mixing_and_bias.bias.data = torch.zeros(self.obs_dim)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """The forward pass of the generative model.

        Parameters
        ----------
        z : torch.Tensor of shape (*, n_components)
            The latent variable.

        Returns
        -------
        x_pred_mean : torch.Tensor of shape (*, obs_dim)
            The mean of the predicted observation.
        """
        return self.mixing_and_bias(z)

    def log_prob(self, x_pred_mean: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """The log probability of the observation given the predicted observation.

        Parameters
        ----------
        x_pred_mean : torch.Tensor of shape (*, obs_dim)
            The mean of the predicted observation.
        x : torch.Tensor of shape (*, obs_dim)
            The observation.

        Returns
        -------
        torch.Tensor of shape (*,)
            The log probability of the observation given the predicted observation.
        """
        return -F.gaussian_nll_loss(
            x_pred_mean, x, self.var.expand_as(x), full=True, reduction="none"
        ).sum(dim=-1)


class LinearEncoder(Encoder):
    """The inference model q(z|x) that infers the latent variable z from the observation x.

    The inference model is a linear model with Gaussian noise.
    I.e., q(z|x) = N(z|Wx + b, \\sigma^2 I).

    Here, we need to fix the log standard deviation of the Gaussian noise to be a constant,
    so that the partial/total correlations between datasets can be fairly compared.

    Parameters
    ----------
    obs_dim : int
        Dimension of the observation.
    n_components : int
        Number of components in the latent variable z; forwarded to ``Encoder``.
    n_total_samples : int | None, optional
        Forwarded unchanged to ``Encoder``, by default None.
    """

    def __init__(
        self,
        obs_dim: int,
        n_components: int,
        n_total_samples: int | None = None,
    ):
        super().__init__(n_components, n_total_samples)

        self.obs_dim = obs_dim

        # Linear map x -> W x + b; weight has shape (n_components, obs_dim).
        self.fc_mean = nn.Linear(obs_dim, self.n_components)

    @property
    def n_time_bins(self):
        """The number of time bins in the observation is the obs_dim, in terms of spatial ICA."""
        return self.obs_dim

    @n_time_bins.setter
    def n_time_bins(self, value):
        # NOTE: only the ``obs_dim`` attribute is updated here; ``fc_mean``
        # keeps its current shape until re-initialised.
        self.obs_dim = value

    def initialize(
        self,
        method: str | None = None,
        x: torch.Tensor | None = None,
        fc_mean_weight: torch.Tensor | None = None,
        fc_mean_bias: torch.Tensor | None = None,
    ) -> None:
        """Initialize the parameters of the inference model.

        With ``method == "ICA"`` a ``FastICA`` model is fitted on ``x`` and
        ``fc_mean`` is set to reproduce its unmixing transform
        ``components_ @ (x - mean_)``. For any other ``method`` value the
        provided weight / bias are copied in, and whichever of them is ``None``
        is replaced by zeros.

        The new values are always placed on the device and in the dtype of the
        current ``fc_mean.weight``.

        Parameters
        ----------
        method : str | None, optional
            Initialize by a method in ["ICA" | ], by default None.
        x : torch.Tensor of shape (n_samples, obs_dim) | None, optional
            The provided dataset of shape (n_samples, obs_dim), by default None.
        fc_mean_weight : torch.Tensor of shape (n_components, obs_dim) | None, optional
            The initialization of the mean weight, by default None.
        fc_mean_bias : torch.Tensor of shape (n_components,) | None, optional
            The initialization of the mean bias, by default None.

        Raises
        ------
        ValueError
            If ``method == "ICA"`` and ``x`` is not provided.
        """
        device, dtype = self.fc_mean.weight.device, self.fc_mean.weight.dtype

        if method == "ICA":
            if x is None:
                raise ValueError("x must be provided when initializing by ICA.")
            fastica = FastICA(n_components=self.n_components)
            # ``.cpu()`` is required before ``.numpy()`` for non-CPU tensors.
            fastica.fit(x.detach().cpu().numpy())
            unmixing = fastica.components_
            # FastICA recovers sources as components_ @ (x - mean_), hence the
            # equivalent affine bias is *minus* components_ @ mean_.
            offset = -(unmixing @ fastica.mean_)
            new_weight = torch.from_numpy(unmixing).to(device=device, dtype=dtype)
            new_bias = torch.from_numpy(offset).to(device=device, dtype=dtype)
        else:
            if fc_mean_weight is not None:
                # clone() first so the parameter never aliases the caller's tensor.
                new_weight = fc_mean_weight.detach().clone().to(
                    device=device, dtype=dtype
                )
            else:
                new_weight = torch.zeros(
                    (self.n_components, self.obs_dim), device=device, dtype=dtype
                )
            if fc_mean_bias is not None:
                new_bias = fc_mean_bias.detach().clone().to(device=device, dtype=dtype)
            else:
                new_bias = torch.zeros(self.n_components, device=device, dtype=dtype)

        # Assign through ``.data`` (rather than ``copy_``) so the parameter
        # objects are kept while their shape is allowed to change.
        self.fc_mean.weight.data = new_weight
        self.fc_mean.bias.data = new_bias

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the encoder.

        Parameters
        ----------
        x : torch.Tensor of shape (*, obs_dim)
            The observation.

        Returns
        -------
        z_pred_mean : torch.Tensor of shape (*, n_components)
            The mean of the predicted latent variable.
        z_pred_log_std : torch.Tensor of shape (*, n_components)
            The log standard deviation of the predicted latent variable.
        """
        z_pred_mean = self.fc_mean(x)
        # ``log_std`` is not defined in this class; presumably it is provided by
        # the ``Encoder`` base class with shape (n_components,) -- TODO confirm.
        # Plain ``expand_as`` broadcasts over any number of leading dims,
        # including an un-batched 1-D input.
        return z_pred_mean, self.log_std.expand_as(z_pred_mean)
