import torch

from log_mobius_scan import log_mobius_scan
from scan import mamba_scan as linear_scan



class KLA(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        qk_dim: int,
        v_dim: int,
        output_dim: int | None = None,
    ):
        """
        :param input_dim: Dimension of the input features
        :type input_dim: int
        :param qk_dim: Dimension of the query and key features
        :type qk_dim: int
        :param v_dim: Dimension of the value features (in ssm this is the state dimension)
        :type v_dim: int
        :param output_dim: Dimension of the output features. If None, defaults to input_dim
        :type output_dim: int | None = None
        """
        super(KLA, self).__init__()
        self.input_dim = input_dim
        self.qk_dim = qk_dim
        self.v_dim = v_dim
        self.output_dim = output_dim if output_dim is not None else input_dim

        # Input projection to q, k, v, p
        self.qkvp_proj = torch.nn.Linear(
            self.input_dim,
            2 * self.qk_dim + 2 * self.v_dim,
        )

        # A_log (v_dim, qk_dim) ssm gating parameter
        # NOTE: This is mamba initialization
        A_init = (
            torch.arange(1, self.qk_dim + 1, dtype=torch.float32)
            .unsqueeze(0)
            .repeat(self.v_dim, 1)
            .contiguous()
        )
        self.A_log = torch.nn.Parameter(torch.log(A_init))
        self.A_log._no_weight_decay = True

        # Sigma_log (v_dim, qk_dim) ssm process noise parameter
        # TODO: tune initialization
        Sigma_init = 0.1 * torch.ones((self.v_dim, self.qk_dim), dtype=torch.float32)
        self.Sigma_log = torch.nn.Parameter(torch.log(Sigma_init))
        self.Sigma_log._no_weight_decay = True

        # Output projection
        self.out_proj = torch.nn.Linear(
            self.v_dim,
            self.output_dim,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch_size, seq_len, input_dim)
        Returns:
            torch.Tensor: (batch_size, seq_len, output_dim)
        """
        # General variables creation

        # q: (batch_size, seq_len, qk_dim) Queries
        # k: (batch_size, seq_len, qk_dim) Keys
        # v: (batch_size, seq_len, v_dim) Values
        # p: (batch_size, seq_len, v_dim) Predicted observation precision of Values
        q, k, v, p = torch.split(
            self.qkvp_proj(x),
            [self.qk_dim, self.qk_dim, self.v_dim, self.v_dim],
            dim=-1,
        )

        # Enforce positivity on precision
        # In notes:
        # p_t = 1 / exp(W_p x_t)
        # which is equivalent to:
        # p_t = exp(-W_p x_t)
        # which considering the sign does not matter as long as it's learned, we can do:
        p = torch.exp(p)

        # A: (v_dim, qk_dim) SSM gating parameter (positive)
        A = torch.exp(self.A_log)

        #################################################################################
        # Precision SSM variables creation

        # Sigma: (v_dim, qk_dim) SSM process noise parameter (positive)
        Sigma = torch.exp(self.Sigma_log)

        # phi: (batch_size, seq_len, v_dim, qk_dim)
        # TODO: verify this is using tensor cores
        phi = torch.square(k).unsqueeze(-2) * p.unsqueeze(-1)

        # A_squared: (v_dim, qk_dim)
        A_squared = torch.square(A)

        # alpha: (batch_size, seq_len, v_dim, qk_dim) Precision SSM fractional numerator gating
        alpha = 1 + Sigma * phi

        # beta: (batch_size, seq_len, v_dim, qk_dim) Precision SSM fractional numerator input
        beta = A_squared * phi

        # gamma: (batch_size, seq_len, v_dim, qk_dim) Precision SSM fractional denominator gating
        gamma = Sigma

        # delta: (batch_size, seq_len, v_dim, qk_dim) Precision SSM fractional denominator input
        delta = A_squared

        #################################################################################
        # Information SSM variables creation

        # G: (v_dim, qk_dim) Information SSM gating
        G = 1.0 / A

        # pv: (batch_size, seq_len, v_dim) Precision-weighted values
        pv = p * v

        #kpv: (batch_size, seq_len, v_dim, qk_dim)
        kpv = k.unsqueeze(-2) * pv.unsqueeze(-1)

        #################################################################################
        # Precision SSM computation

        # Lambda: (batch_size, seq_len, v_dim, qk_dim) Precision SSM hidden state
        Lambda = log_mobius_scan(alpha.log(), beta.log(), gamma.log(), delta.log())

        #################################################################################
        # Information SSM computation

        # H: (batch_size, seq_len, v_dim, qk_dim) Information SSM hidden state
        H = linear_scan(G, kpv)

        ##############################################################################
        # Output computation

        # y: (batch_size, seq_len, v_dim) Information SSM output
        y = (q.unsqueeze(-2) * (Lambda * H)).sum(-1)

        # out: (batch_size, seq_len, output_dim) Final output
        out = self.out_proj(y)

        return out
