"""The encodeer, decoder and the parameterization of the mixing matrix for the LINT model."""

import torch
from torch import nn
from torch.nn.utils import parametrize

from pdisvae.models import linear


class DecomposedMixing(nn.Module):
    """The parameterization of the mixing matrix as ``convolved_history @ V``.

    The convolved history is fixed data; only the right weight V is meant to
    be learned. ``forward`` maps V to the mixing matrix and ``right_inverse``
    recovers a V from a given mixing matrix in the least-squares sense, which
    is the pair of methods ``torch.nn.utils.parametrize`` looks for.

    Parameters
    ----------
    convolved_history : torch.Tensor of shape (n_time_bins, n_neurons)
        The convolved history of the neurons. It is registered as a
        non-persistent buffer, so it follows the module through
        ``.to()`` / ``.cuda()`` / dtype casts but is not written to the
        ``state_dict``.
    """

    convolved_history: torch.Tensor

    def __init__(self, convolved_history: torch.Tensor):
        super().__init__()
        # A plain tensor attribute would be left behind when the module is
        # moved to another device or dtype, making the matmul in ``forward``
        # fail. ``persistent=False`` keeps the state_dict keys unchanged.
        self.register_buffer("convolved_history", convolved_history, persistent=False)

    def forward(self, right_weight: torch.Tensor) -> torch.Tensor:
        """The parameterization of the mixing matrix.

        Parameters
        ----------
        right_weight : torch.Tensor of shape (n_neurons, n_components)
            The V matrix in the decomposition of the neural connectivity matrix W = UV^T.

        Returns
        -------
        mixing : torch.Tensor of shape (n_time_bins, n_components)
            The mixing matrix, ``convolved_history @ right_weight``.
        """
        return self.convolved_history @ right_weight

    def right_inverse(self, mixing: torch.Tensor) -> torch.Tensor:
        """The right inverse of the mixing matrix.

        Solves ``convolved_history @ right_weight = mixing`` for
        ``right_weight`` with ``torch.linalg.lstsq``.

        Parameters
        ----------
        mixing : torch.Tensor of shape (n_time_bins, n_components)
            The mixing matrix.

        Returns
        -------
        right_weight : torch.Tensor of shape (n_neurons, n_components)
            The V matrix in the decomposition of the neural connectivity matrix W = UV^T.
        """
        return torch.linalg.lstsq(self.convolved_history, mixing).solution


class LintDecoder(linear.LinearDecoder):
    """Linear decoder of the LINT model, mapping the latent variable z to the observation x.

    Start from a Gaussian observation GLM
        p(x_t^T | x_{1:t-1}^T) = N(x_t | phi_t^T W^T + b^T, sigma^2 I)

    From the spatial perspective the same relationship can be written as
        p(X | Phi) = N(X | Phi W^T + 1 b^T, sigma^2 I)
    with X in R^{T x N}, Phi in R^{T x N}, W in R^{N x N}, b in R^{N x 1},
    sigma in R^{N x 1}, and 1 the all-ones vector in R^{T x 1}.
    W_{ij} is the weight from neuron j to neuron i.

    Decomposing the connectivity as W = U V^T turns this into
        p(X | Phi) = N(X | Phi V U^T + 1 b^T, sigma^2 I)
    where U is R^{N x K}, V is R^{N x K}, and K is the number of components.

    With Phi V in R^{T x K} playing the role of the mixing matrix, the model
    is an ICA whose K independent sources are the columns of U.

    TODO: add the section and paper reference for the derivation.

    Parameters
    ----------
    convolved_history : torch.Tensor of shape (n_time_bins, n_neurons) = (obs_dim, n_total_samples)
        The convolved history of the neurons.
    n_components : int
        Number of components in the latent variable z,
        i.e. the dimension of the latent variable (latent_dim).
    """

    def __init__(self, convolved_history: torch.Tensor, n_components: int):
        # Rows of the history give the observation dimension; columns give
        # the number of samples (neurons, in the spatial view).
        n_rows, n_cols = convolved_history.shape
        # Stored before the base-class constructor runs.
        self.n_total_samples = n_cols
        super().__init__(n_rows, n_components)

        # Keep a handle on the parametrization module and use the same
        # instance to re-express the linear weight as convolved_history @ V.
        # NOTE(review): ``mixing_and_bias`` is presumably created by the
        # base class constructor — confirm against ``linear.LinearDecoder``.
        mixing_parametrization = DecomposedMixing(convolved_history)
        self.decomposed_mixing = mixing_parametrization
        parametrize.register_parametrization(
            self.mixing_and_bias,
            "weight",
            mixing_parametrization,
        )

    @property
    def n_neurons(self):
        """Alias of ``n_total_samples``: in spatial ICA each sample is a neuron."""
        return self.n_total_samples

    @n_neurons.setter
    def n_neurons(self, value):
        # Writes go to the underlying ``n_total_samples`` attribute.
        self.n_total_samples = value
