import math
from abc import abstractmethod

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
from torch.nn.utils.parametrizations import _Orthogonal

from constants import ATTR_WEIGHT
from fast_hadamard_transform import hadamard_transform

MATRIX_INITS = ("identity", "orthogonal", "hadamard", "random")


def init_matrix(
    size: int,
    init: str,
    device: torch.device = None,
    dtype: torch.dtype = None,
):
    """Create a square ``(size, size)`` matrix using one of the ``MATRIX_INITS`` schemes.

    Args:
        size: Number of rows (and columns) of the matrix.
        init: One of ``MATRIX_INITS``:
            ``"identity"``   -- the identity matrix;
            ``"orthogonal"`` -- a random orthogonal matrix (``nn.init.orthogonal_``);
            ``"hadamard"``   -- ``hadamard_transform`` applied to the identity with
                                scale ``1 / sqrt(size)``;
            ``"random"``     -- standard-normal noise added to the identity.
        device: Device of the returned tensor (``None`` = torch default).
        dtype: Dtype of the returned tensor (``None`` = torch default dtype).

    Returns:
        A ``(size, size)`` tensor on ``device`` with dtype ``dtype``.

    Raises:
        ValueError: If ``init`` is not one of ``MATRIX_INITS``.
    """
    if init not in MATRIX_INITS:
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError(f"Invalid matrix initialization {init}")

    if init == "identity":
        m = torch.eye(size, device=device, dtype=dtype)
    elif init == "orthogonal":
        # ``nn.init.orthogonal_`` relies on a QR factorization that is only
        # implemented for float32/float64. Build the matrix in a supported dtype
        # and cast afterwards so half-precision (bf16 / fp16) requests also work.
        target_dtype = dtype if dtype is not None else torch.get_default_dtype()
        if target_dtype in (torch.float32, torch.float64):
            qr_dtype = target_dtype
        else:
            qr_dtype = torch.float32
        m = torch.empty(size, size, device=device, dtype=qr_dtype)
        nn.init.orthogonal_(m)
        m = m.to(dtype=target_dtype)
    elif init == "hadamard":
        m = torch.eye(size, device=device, dtype=dtype)
        m = hadamard_transform(m, scale=1.0 / math.sqrt(size))
    else:  # "random"
        m = torch.randn(size, size, device=device, dtype=dtype) + torch.eye(
            size, device=device, dtype=dtype
        )
    return m


class BaseMatrix(nn.Module):
    """Common interface for the learnable transform matrices in this module.

    Subclasses return the matrix from ``forward()`` and its inverse transpose from
    ``inv_t()``. ``nn.Module`` does not use ``ABCMeta``, so ``@abstractmethod`` is
    not enforced at instantiation time. The default ``forward``/``inv_t`` therefore
    raise ``NotImplementedError`` instead of silently returning ``None``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @abstractmethod
    def forward(self) -> torch.Tensor:
        """Return the transform matrix."""
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")

    @abstractmethod
    def inv_t(self) -> torch.Tensor:
        """Return the inverse transpose of the transform matrix."""
        raise NotImplementedError(f"{type(self).__name__} must implement inv_t()")

    @abstractmethod
    def remove_parametrizations(self) -> None:
        """Bake any registered parametrizations into plain tensors.

        No-op by default: a matrix without parametrizations has nothing to remove.
        """
        return None


class GeneralMatrix(BaseMatrix):
    """Unconstrained learnable square matrix stored directly in ``weight``."""

    def __init__(
        self,
        size: int,
        init: str,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        super().__init__()
        self.weight = nn.Parameter(init_matrix(size, init, device, dtype))

    def forward(self) -> torch.Tensor:
        """Return the learnable matrix itself."""
        return self.weight

    def inv_t(self) -> torch.Tensor:
        """Return the transpose of the pseudo-inverse of ``weight``, in ``weight``'s dtype.

        The SVD-based pseudo-inverse is not available for float16/bfloat16 inputs,
        so half-precision weights are upcast to float32 for the computation and the
        result is cast back. Other dtypes are used as-is.
        """
        w = self.weight
        if w.dtype in (torch.float16, torch.bfloat16):
            return w.float().pinverse().T.to(dtype=w.dtype)
        return w.pinverse().T

    def remove_parametrizations(self) -> None:
        """No-op: ``weight`` is a plain parameter without parametrizations."""
        pass


class OrthogonalMatrix(BaseMatrix):
    """Square matrix kept orthogonal by a ``torch.nn.utils.parametrize`` parametrization.

    The underlying tensor is initialized with ``init_matrix`` and wrapped in PyTorch's
    private ``_Orthogonal`` parametrization, so ``self.weight`` is recomputed from it
    on every access.
    """

    def __init__(
        self,
        size: int,
        init: str,
        device: torch.device = None,
        dtype: torch.dtype = None,
    ):
        super().__init__()
        self.weight = nn.Parameter(init_matrix(size, init, device, dtype))
        # NOTE(review): ``_Orthogonal`` is private API and compares ``orthogonal_map``
        # against the ``_OrthMaps`` enum, so the plain string "cayley" may not select
        # the Cayley branch (it appears to fall through to the Householder path).
        # Verify against the installed torch version before "fixing" it: with the
        # real Cayley map and use_trivialization=False, ``right_inverse`` is not
        # implemented, and non-identity initial matrices would generally not be
        # reproduced.
        orth = _Orthogonal(self.weight, "cayley", use_trivialization=False)
        parametrize.register_parametrization(self, "weight", orth)

    def forward(self) -> torch.Tensor:
        """Return the (parametrized) orthogonal matrix."""
        return self.weight

    def inv_t(self) -> torch.Tensor:
        """Return the inverse transpose, which equals the matrix itself when it is orthogonal."""
        return self.weight

    def remove_parametrizations(self) -> None:
        """Replace the parametrized ``weight`` with its current value as a plain parameter.

        Uses the same attribute name the parametrization was registered under.
        """
        parametrize.remove_parametrizations(self, "weight", leave_parametrized=True)


class LearnableInvMatrix(BaseMatrix):
    """
    A learnable invertible matrix using an LU decomposition.

    The matrix is composed as

        W = P @ L @ (U_strict + diag(sign_s * exp(log_s)))

    where P is a fixed permutation matrix, L is unit lower triangular (the learnable
    ``l`` masked below the diagonal plus the identity), U_strict is the learnable
    ``u`` masked above the diagonal, and ``sign_s`` (fixed) / ``log_s`` (learnable)
    hold the sign and log-magnitude of U's diagonal.

    ``forward()`` and ``inv_t()`` cache their results until ``reset_cache()`` is
    called, so call it after the parameters change.
    """

    def __init__(
        self,
        size: int,
        init: str,
        device: torch.device,
        dtype: torch.dtype,
        init_matrix_tensor: torch.Tensor = None
    ):
        super().__init__()
        # Save the forward dtype (bf16 / fp16 / fp32) used by the model
        self._forward_dtype = dtype

        if init_matrix_tensor is not None:
            if tuple(init_matrix_tensor.shape) != (size, size):
                raise ValueError(
                    f"init_matrix_tensor must have shape ({size}, {size}), "
                    f"got {tuple(init_matrix_tensor.shape)}"
                )
            transform_matrix = init_matrix_tensor.to(device=device, dtype=torch.float32)
        else:
            transform_matrix = init_matrix(size, init, device, torch.float32)

        # LU decomposition A = P @ L @ U in float32 for numerical stability.
        # torch.linalg.lu replaces the deprecated torch.lu + torch.lu_unpack pair.
        P, L, U = torch.linalg.lu(transform_matrix)

        # Extract U's diagonal as sign and log-magnitude (epsilon guards log(0)),
        # then keep only the strictly upper-triangular part of U.
        s = torch.diag(U)
        sign_s = torch.sign(s)
        log_s = torch.log(torch.abs(s) + 1e-12)
        U = torch.triu(U, diagonal=1)

        # Strictly lower-triangular mask; its transpose masks the strictly upper part.
        l_mask = torch.tril(
            torch.ones(transform_matrix.shape, dtype=torch.float32, device=device),
            diagonal=-1,
        )
        eye = torch.eye(size, dtype=torch.float32, device=device)

        # Buffers stored in forward dtype to match model params
        self.register_buffer("p", P.to(device=device, dtype=dtype))
        self.register_buffer("sign_s", sign_s.to(device=device, dtype=dtype))
        self.register_buffer("l_mask", l_mask.to(device=device, dtype=dtype))
        self.register_buffer("eye", eye.to(device=device, dtype=dtype))

        # Parameters in forward dtype (train_transform_matrix also enforces fp32)
        self.log_s = nn.Parameter(log_s.to(device=device, dtype=dtype))
        self.l = nn.Parameter(L.to(device=device, dtype=dtype))
        self.u = nn.Parameter(U.to(device=device, dtype=dtype))

        self._cached_weight = None
        self._cached_inv_weight = None

    def _compose_weight(self) -> torch.Tensor:
        """Compose W in float32 for stability, then cast to forward dtype."""
        device = self.log_s.device

        # Promote to float32 for matrix algebra
        P = self.p.to(device=device, dtype=torch.float32)
        sign_s = self.sign_s.to(device=device, dtype=torch.float32)
        l_mask = self.l_mask.to(device=device, dtype=torch.float32)
        eye = self.eye.to(device=device, dtype=torch.float32)

        # Unit lower-triangular L: learnable strictly-lower part + ones on the diagonal.
        L = self.l.to(device=device, dtype=torch.float32) * l_mask + eye

        # Upper-triangular U: learnable strictly-upper part + signed exp(log_s) diagonal.
        diag_u = sign_s * torch.exp(self.log_s.to(device=device, dtype=torch.float32))
        U_strict = self.u.to(device=device, dtype=torch.float32) * l_mask.t()
        U = U_strict + torch.diag(diag_u)

        W = P @ (L @ U)
        return W.to(dtype=self._forward_dtype)

    def forward(self) -> torch.Tensor:
        """Return W (cached until ``reset_cache()``)."""
        if self._cached_weight is None:
            self._cached_weight = self._compose_weight()
        return self._cached_weight

    def inv_t(self) -> torch.Tensor:
        """Return W^{-T} (cached until ``reset_cache()``).

        The inverse is taken of the forward-dtype W, upcast to float32, and the
        result is cast back to the forward dtype.
        """
        if self._cached_inv_weight is None:
            W = self._cached_weight
            if W is None:
                W = self._compose_weight()
                self._cached_weight = W

            W_f32 = W.to(dtype=torch.float32)
            W_inv = torch.linalg.inv(W_f32)
            self._cached_inv_weight = W_inv.to(dtype=self._forward_dtype)

        return self._cached_inv_weight.T

    def reset_cache(self):
        """Drop the cached W and W^{-1}; call after parameter updates."""
        self._cached_weight = None
        self._cached_inv_weight = None

    def remove_parametrizations(self) -> None:
        """No-op: this module registers no parametrizations."""
        pass


class LearnableQRMatrix(BaseMatrix):
    """
    A matrix parametrized via a QR decomposition where Q is obtained from
    the exponential map of a skew-symmetric matrix.

    W = Q @ R
    Q = exp(A), with A = (skew_params - skew_params^T) / 2 skew-symmetric
    R = triu(upper_r, 1) + diag(exp(log_s)), upper triangular with positive diagonal

    At construction the initial matrix M is factored as M = Q0 @ R0 (diag(R0) >= 0),
    and A is initialized from Q0 by ``_log_orthogonal``, which is only a first-order
    approximation of log(Q0). In general the initial W therefore differs from M; it
    matches M (up to the 1e-12 epsilon) when Q0 is the identity.

    ``forward()`` and ``inv_t()`` cache their results; call ``reset_cache()`` after
    parameter updates.
    """

    def __init__(
            self,
            size: int,
            init: str,
            device: torch.device = None,
            dtype: torch.dtype = None,
            init_matrix_tensor: torch.Tensor = None
    ):
        super().__init__()

        self._forward_dtype = dtype

        # Initialize with a matrix and decompose it
        if init_matrix_tensor is not None:
            if tuple(init_matrix_tensor.shape) != (size, size):
                raise ValueError(
                    f"init_matrix_tensor must have shape ({size}, {size}), "
                    f"got {tuple(init_matrix_tensor.shape)}"
                )
            init_matrix_f32 = init_matrix_tensor.to(device=device, dtype=torch.float32)
        else:
            init_matrix_f32 = init_matrix(size, init, device, torch.float32)

        # Perform QR decomposition
        Q, R = torch.linalg.qr(init_matrix_f32)

        # Normalize to a non-negative R diagonal by flipping the sign of matching
        # columns of Q and rows of R (the product Q @ R is unchanged).
        d = torch.diag(R)
        sign_d = torch.sign(d)
        sign_d[sign_d == 0] = 1  # Handle zeros
        Q = Q * sign_d[None, :]
        R = R * sign_d[:, None]

        # Initial generator for Q (approximate log, see _log_orthogonal).
        A = self._log_orthogonal(Q)

        # Log of R's diagonal (epsilon guards log(0)) and its strictly upper part.
        log_diag_r = torch.log(torch.diag(R) + 1e-12)
        U_strict = torch.triu(R, diagonal=1)

        # Mask selecting the strictly upper-triangular part (diagonal excluded).
        u_mask = torch.triu(torch.ones(size, size, dtype=torch.float32, device=device), diagonal=1)

        # Register buffers
        self.register_buffer("u_mask", u_mask.to(dtype=dtype))

        # Learnable parameters
        self.skew_params = nn.Parameter(A.to(device=device, dtype=dtype))
        self.log_s = nn.Parameter(log_diag_r.to(device=device, dtype=dtype))
        self.upper_r = nn.Parameter(U_strict.to(device=device, dtype=dtype))

        self._cached_weight = None
        self._cached_inv_weight = None

    def _log_orthogonal(self, Q: torch.Tensor) -> torch.Tensor:
        """
        Return the skew-symmetric part of ``Q``, used as the initial generator A.

        ``(Q - Q^T) / 2`` is only a first-order approximation of ``log(Q)``: for
        ``Q = exp(A)`` it equals ``sinh(A)``, which is close to ``A`` only for small
        rotations, and it is exact for ``Q = I``. Also, ``exp`` of a real
        skew-symmetric matrix always has determinant +1, so a ``Q`` with
        determinant -1 cannot be represented by this parametrization.
        """
        A = (Q - Q.T) / 2
        return A

    def _exp_skew_symmetric(self, A: torch.Tensor) -> torch.Tensor:
        """
        Compute the matrix exponential of the skew-symmetric part of A.
        Returns an orthogonal matrix Q = exp((A - A^T) / 2).
        """
        # Ensure A is skew-symmetric
        A_skew = (A - A.T) / 2

        # Compute matrix exponential
        Q = torch.linalg.matrix_exp(A_skew)
        return Q

    def _compose_weight(self) -> torch.Tensor:
        """Compose W = Q @ R in float32 for stability, then cast to forward dtype."""
        # Convert to float32 for numerical stability
        A = self.skew_params.to(dtype=torch.float32)
        log_diag_r = self.log_s.to(dtype=torch.float32)
        upper_r = self.upper_r.to(dtype=torch.float32)
        u_mask = self.u_mask.to(dtype=torch.float32)

        # Compute Q from exponential map of skew-symmetric matrix
        Q = self._exp_skew_symmetric(A)

        # Compose R with positive diagonal
        diag_r = torch.exp(log_diag_r)
        R = upper_r * u_mask + torch.diag(diag_r)

        # Compute W = Q @ R
        W = Q @ R

        return W.to(dtype=self._forward_dtype)

    def forward(self) -> torch.Tensor:
        """Return W (cached until ``reset_cache()``)."""
        if self._cached_weight is None:
            self._cached_weight = self._compose_weight()
        return self._cached_weight

    def inv_t(self) -> torch.Tensor:
        """Return W^{-T} (cached until ``reset_cache()``)."""
        if self._cached_inv_weight is None:
            W = self._cached_weight
            if W is None:
                W = self._compose_weight()
                self._cached_weight = W

            # Generic inverse of the forward-dtype W, computed in float32.
            # (The Q @ R structure would also allow W^{-1} = R^{-1} @ Q^T.)
            W_f32 = W.to(dtype=torch.float32)
            W_inv = torch.linalg.inv(W_f32)
            self._cached_inv_weight = W_inv.to(dtype=self._forward_dtype)

        return self._cached_inv_weight.T

    def reset_cache(self):
        """Call this after parameter updates to invalidate cached values."""
        self._cached_weight = None
        self._cached_inv_weight = None

    def remove_parametrizations(self) -> None:
        """No-op: this module registers no parametrizations."""
        pass


class LearnableKroneckerMatrix(BaseMatrix):
    """
    A learnable matrix built from a Kronecker product and a per-column scale.

        W = kron(l, r) @ diag(log_s)

    where ``l`` is ``n_left x n_left``, ``r`` is ``n_right x n_right`` and
    ``n_left * n_right == size`` (see ``get_decompose_dim``). Despite its name,
    ``log_s`` holds the column scales directly (initialized to ones) and is not
    exponentiated. The attribute name is part of the state dict, so it is not
    renamed here.

    ``forward()`` and ``inv_t()`` cache their results until ``reset_cache()`` is
    called, so call it after the parameters change.
    """

    def __init__(
        self,
        size: int,
        init: str,
        device: torch.device,
        dtype: torch.dtype,
    ):
        super().__init__()
        # Save the forward dtype (bf16 / fp16 / fp32) used by the model
        self._forward_dtype = dtype
        n_left, n_right = self.get_decompose_dim(size)

        # init_matrix is asked for float32 directly; no extra cast is needed.
        left = init_matrix(n_left, init, device, torch.float32)
        right = init_matrix(n_right, init, device, torch.float32)
        col_scale = torch.ones(size, dtype=torch.float32)

        # Parameters in forward dtype (train_transform_matrix also enforces fp32)
        self.log_s = nn.Parameter(col_scale.to(device=device, dtype=dtype))
        self.l = nn.Parameter(left.to(device=device, dtype=dtype))
        self.r = nn.Parameter(right.to(device=device, dtype=dtype))

        self._cached_weight = None
        self._cached_inv_weight = None

    def get_decompose_dim(self, n):
        """Split ``n`` into a factor pair ``(n_left, n_right)`` with ``n_left * n_right == n``.

        Uses Fermat's method: the smallest ``a >= ceil(sqrt(n))`` for which
        ``a*a - n`` is a perfect square ``b*b`` gives ``n = (a - b) * (a + b)``.
        Fermat's method only finds factor pairs of equal parity, so it has no
        solution when ``n % 4 == 2``. In that case the divisor pair closest to
        ``sqrt(n)`` is returned instead.

        Returns:
            ``(n_left, n_right)`` with ``n_left <= n_right``.

        Raises:
            ValueError: If ``n`` is negative.
        """
        if n % 4 == 2:
            d = math.isqrt(n)
            while n % d:
                d -= 1
            return d, n // d

        a = math.isqrt(n)
        if a * a < n:
            a += 1
        while True:
            tmp = a * a - n
            b = math.isqrt(tmp)
            if b * b == tmp:
                return a - b, a + b
            a += 1

    def _compose_weight(self) -> torch.Tensor:
        """Compose W in float32 for stability, then cast to forward dtype."""
        device = self.log_s.device

        # Promote to float32 for matrix algebra
        c = self.log_s.to(device=device, dtype=torch.float32)
        left_mat = self.l.to(device=device, dtype=torch.float32)
        right_mat = self.r.to(device=device, dtype=torch.float32)

        # Same as kron(l, r) @ diag(c): scale column j by c[j] via broadcasting
        # instead of a dense size x size matmul.
        W = torch.kron(left_mat, right_mat) * c.unsqueeze(0)
        return W.to(dtype=self._forward_dtype)

    def forward(self) -> torch.Tensor:
        """Return W (cached until ``reset_cache()``)."""
        if self._cached_weight is None:
            self._cached_weight = self._compose_weight()
        return self._cached_weight

    def inv_t(self) -> torch.Tensor:
        """Return W^{-T} (cached until ``reset_cache()``).

        The inverse is taken of the forward-dtype W, upcast to float32, and the
        result is cast back to the forward dtype.
        """
        if self._cached_inv_weight is None:
            W = self._cached_weight
            if W is None:
                W = self._compose_weight()
                self._cached_weight = W

            W_f32 = W.to(dtype=torch.float32)
            W_inv = torch.linalg.inv(W_f32)
            self._cached_inv_weight = W_inv.to(dtype=self._forward_dtype)

        return self._cached_inv_weight.T

    def reset_cache(self):
        """Drop the cached W and W^{-1}; call after parameter updates."""
        self._cached_weight = None
        self._cached_inv_weight = None

    def remove_parametrizations(self) -> None:
        """No-op: this module registers no parametrizations."""
        pass

def l2norm_along_axis1(X: torch.Tensor) -> torch.Tensor:
    """Return the Euclidean (L2) norm of ``X`` reduced over dimension 1."""
    return X.norm(p=2, dim=1)

def sample_chi(d, rng=None, device='cpu'):
    """
    Sample ``d`` values from a Chi distribution with ``d`` degrees of freedom.

    Each sample is the L2 norm of one row of a ``(d, d)`` matrix of standard
    normal draws.

    Args:
        d (int): Degrees of freedom; also the number of samples returned.
        rng (np.random.RandomState or np.random.Generator, optional):
            NumPy RNG used to draw a seed for a dedicated ``torch.Generator``.
            If None, PyTorch's default (global) RNG is used.
        device (str or torch.device): Device on which the samples are generated
            ('cpu', 'cuda' or 'xpu').

    Returns:
        torch.Tensor: A 1D tensor of length ``d``; each entry is a sample from Chi(d).
    """
    if rng is None:
        # No external RNG: draw from PyTorch's default generator.
        normal_samples = torch.randn((d, d), device=device)
    else:
        # np.random.Generator exposes ``integers`` while the legacy RandomState
        # exposes ``randint``; both draw a seed from [0, 2**32 - 1).
        draw = rng.integers if hasattr(rng, "integers") else rng.randint
        seed = int(draw(0, 2**32 - 1))
        g = torch.Generator(device=device)
        g.manual_seed(seed)
        normal_samples = torch.randn((d, d), generator=g, device=device)

    # Row-wise L2 norm of d standard normals -> d samples from Chi(d).
    return torch.norm(normal_samples, dim=1)


if __name__ == '__main__':
    # Smoke test; fall back to CPU when CUDA is unavailable.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    size = 128
    qr_matrix = LearnableQRMatrix(size=size, init='orthogonal', device=device, dtype=torch.bfloat16)

    # Forward pass returns W = Q @ R
    W = qr_matrix()

    # Get inverse transpose
    W_inv_T = qr_matrix.inv_t()

    # Sanity check: W @ W^{-1} should be close to the identity (bf16 precision).
    residual = (W.float() @ W_inv_T.float().T - torch.eye(size, device=device)).abs().max()
    print(f"max |W @ W^-1 - I| = {residual.item():.3e}")

    # After optimizer step, reset cache
    qr_matrix.reset_cache()