import math
from abc import abstractmethod

import torch
import torch.nn as nn

from quantization.transforms.hadamard import matmul_hadUt, matmul_hadU

from .matrix import (
    GeneralMatrix,
    OrthogonalMatrix,
    LearnableInvMatrix,
    LearnableQRMatrix,
    LearnableKroneckerMatrix,
    init_matrix
)
from quantization.utils.common_utils import filter_kwarg_dict

from fast_hadamard_transform import hadamard_transform


# Registry of matrix parametrization classes, keyed by the string accepted in
# the ``parametrization`` argument of ``LearnedTransform``.
MATRIX_PARAMETRIZATIONS = {
    "general": GeneralMatrix,
    "orthogonal": OrthogonalMatrix,
    "learnable_inv": LearnableInvMatrix,
    "learnable_qr": LearnableQRMatrix,
    "learnable_kronecker": LearnableKroneckerMatrix,
}


class BaseTransform(nn.Module):
    """Common interface shared by every transform in this module.

    NOTE: ``nn.Module`` does not use ``ABCMeta`` as its metaclass, so the
    ``@abstractmethod`` markers below are informational only. This class (and
    any subclass that forgets an override) can still be instantiated, and the
    default method bodies simply return ``None``.
    """

    def __init__(self, *args, **kwargs):
        # Any positional/keyword arguments are accepted and ignored here.
        super().__init__()

    @abstractmethod
    def forward(self, x: torch.Tensor, inv_t: bool = False, dim: int = -1):
        """Transform ``x`` along axis ``dim``.

        ``inv_t`` requests the subclass's alternate (presumably
        inverse-transpose) variant; the exact meaning is subclass-defined.
        """

    @abstractmethod
    def remove_parametrizations(self) -> None:
        """Strip parametrization wrappers, for subclasses that have any."""


class IdentityTransform(BaseTransform):
    """No-op transform: ``forward`` hands back its input untouched."""

    def __init__(self, *args, **kwargs):
        # Arbitrary arguments are accepted and discarded, presumably so this
        # class can be built with the same kwargs as the other transforms.
        super().__init__()

    def forward(self, x: torch.Tensor, inv_t: bool = False, dim: int = -1):
        # Neither the direction (``inv_t``) nor the axis (``dim``) matters
        # for the identity map; the very same tensor object is returned.
        return x

    def remove_parametrizations(self) -> None:
        """Nothing to remove: this transform owns no parameters."""


class LearnedTransform(BaseTransform):
    """
    A learnable invertible transformation matrix.

    The transform is backed either by a single matrix module
    (``num_blocks == 1``) or by ``num_blocks`` independent matrix modules,
    which are assembled into one block-diagonal matrix on every call. The
    parameters are learnable (``requires_grad=True``) unless frozen with
    :meth:`freeze`.

    Args:
        block_size: Size of each block matrix.
        num_blocks: Number of transformation blocks; must be >= 1. Total output
            size will be ``num_blocks * block_size``.
        init: Initialization strategy for the matrix (e.g. 'orthogonal',
            'identity', 'hadamard', 'xavier_normal'); forwarded to the
            parametrization class and, with ``block_diag_init``, to
            ``init_matrix``.
        parametrization: Key into ``MATRIX_PARAMETRIZATIONS``: 'general',
            'orthogonal', 'learnable_inv', 'learnable_qr' or
            'learnable_kronecker'. An unknown key raises ``KeyError``.
        block_diag_init: Single-block mode only. If True, the matrix is
            initialized as a block-diagonal matrix made of
            ``divide_num_blocks`` blocks of side ``divide_block_size``.
        divide_num_blocks: Number of diagonal blocks used by ``block_diag_init``.
        divide_block_size: Side of each diagonal block used by ``block_diag_init``.
        device: Device to place the matrix on.
        dtype: Data type of the matrix.
        add_rand_noise: With ``block_diag_init``, add Gaussian noise (std 1e-3)
            to the entries outside the diagonal blocks of the initial matrix.

    Raises:
        ValueError: if ``num_blocks < 1``.
    """

    def __init__(
        self,
        block_size: int,
        num_blocks: int,
        init: str = "orthogonal",
        parametrization: str = "general",
        block_diag_init=False,
        divide_num_blocks=1,
        divide_block_size=32,
        device: torch.device = None,
        dtype: torch.dtype = None,
        add_rand_noise: bool = False,
    ):
        super().__init__()

        # Every other method branches on ``num_blocks == 1``; values < 1 would
        # build a single matrix here and then crash later iterating ``blocks``.
        if num_blocks < 1:
            raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

        self.block_size = block_size
        self.num_blocks = num_blocks
        self.parametrization = parametrization
        matrix_cls = MATRIX_PARAMETRIZATIONS[parametrization]

        if num_blocks > 1:
            # Multi-block mode: one independent matrix module per diagonal block.
            self.blocks = nn.ModuleList([
                matrix_cls(block_size, init, device, dtype)
                for _ in range(num_blocks)
            ])
            self.matrix = None
        else:
            # Single-block mode. ``blocks`` is always defined (as None) so the
            # attribute exists whether or not ``block_diag_init`` is used.
            self.blocks = None
            if block_diag_init:
                mat_init = self._block_diag_init_tensor(
                    init, divide_num_blocks, divide_block_size, device, add_rand_noise
                )
                self.matrix = matrix_cls(
                    block_size, init, device, dtype, init_matrix_tensor=mat_init
                )
            else:
                self.matrix = matrix_cls(block_size, init, device, dtype)

    @staticmethod
    def _block_diag_init_tensor(
        init, divide_num_blocks, divide_block_size, device, add_rand_noise
    ) -> torch.Tensor:
        """Build the float32 block-diagonal initial matrix for ``block_diag_init``.

        NOTE(review): nothing checks that
        ``divide_num_blocks * divide_block_size == block_size``; presumably the
        parametrization class rejects a mismatched ``init_matrix_tensor`` — confirm.
        """
        block_matrices = [
            init_matrix(divide_block_size, init, device, dtype=torch.float32)
            for _ in range(divide_num_blocks)
        ]
        with torch.no_grad():
            mat = torch.block_diag(*block_matrices)
            if add_rand_noise:
                # Mask is True on the diagonal blocks and False elsewhere.
                mask = torch.block_diag(
                    *[torch.ones_like(b, dtype=torch.bool) for b in block_matrices]
                )
                sigma = 1e-3  # small compared to typical block values
                noise = torch.randn_like(mat) * sigma
                # Only the off-block entries receive Gaussian noise.
                mat = mat + noise.masked_fill(mask, 0)
        return mat

    def _matrix_modules(self) -> list:
        """Return the matrix module(s) backing this transform, in block order."""
        return [self.matrix] if self.num_blocks == 1 else list(self.blocks)

    def _get_block_diagonal_matrix(self, inv_t: bool = False) -> torch.Tensor:
        """Return the full transform matrix, or its ``inv_t`` variant.

        Single-block mode returns the matrix module's output directly;
        multi-block mode assembles the per-block outputs with
        ``torch.block_diag``.
        """
        if self.num_blocks == 1:
            return self.matrix.inv_t() if inv_t else self.matrix()
        return torch.block_diag(
            *(block.inv_t() if inv_t else block() for block in self.blocks)
        )

    def forward(self, x: torch.Tensor, inv_t: bool = False, dim: int = -1):
        """Apply the learned transformation.

        Contracts axis ``dim`` of ``x`` with axis 0 of the transform matrix.

        NOTE(review): ``torch.tensordot`` places the result's new axis last, so
        when ``dim`` is not the last axis the transformed axis ends up at the
        end of the output (``FullHadamardTransform`` moves it back instead).
        This is left as-is because callers may depend on this layout — confirm.
        """
        t = self._get_block_diagonal_matrix(inv_t)
        return torch.tensordot(x, t, dims=((dim,), (0,)))

    def _set_requires_grad(self, requires_grad: bool = True) -> None:
        """Set requires_grad for all parameters in the transformation matrix."""
        for module in self._matrix_modules():
            for param in module.parameters():
                param.requires_grad_(requires_grad)

    def freeze(self) -> None:
        """Freeze the transformation (disable gradient computation)."""
        self._set_requires_grad(False)

    def unfreeze(self) -> None:
        """Unfreeze the transformation (enable gradient computation)."""
        self._set_requires_grad(True)

    def reset_cache(self) -> None:
        """Clear cached matrices for all blocks. Call at the start of each forward pass."""
        for module in self._matrix_modules():
            # Not every parametrization class necessarily implements caching.
            if hasattr(module, "reset_cache"):
                module.reset_cache()

    def construct_weight(self) -> None:
        """Pre-compose all matrices. Call at the start of training iteration."""
        for module in self._matrix_modules():
            if hasattr(module, "forward"):
                # Result discarded; the call is made for its side effects
                # (presumably populating a cache inside the matrix module).
                module()

    def remove_parametrizations(self) -> None:
        """Remove parametrizations from all matrices."""
        for module in self._matrix_modules():
            module.remove_parametrizations()


class LearnedAffineTransform(LearnedTransform):
    """
    A learnable invertible transform paired with a learnable bias vector.

    Construction arguments are identical to ``LearnedTransform``. In addition,
    a zero-initialized ``bias`` parameter of length ``block_size * num_blocks``
    is registered, and ``freeze``/``unfreeze`` (inherited, dispatching to the
    ``_set_requires_grad`` override below) toggle it together with the matrix.

    NOTE(review): ``forward`` is inherited unchanged, so this class never adds
    ``bias`` to its output; presumably callers apply it themselves — confirm.
    """

    def __init__(
        self,
        block_size: int,
        num_blocks: int,
        init: str = "orthogonal",
        parametrization: str = "general",
        block_diag_init=False,
        divide_num_blocks=1,
        divide_block_size=32,
        device: torch.device = None,
        dtype: torch.dtype = None,
        add_rand_noise: bool = False,
    ):
        super().__init__(
            block_size=block_size,
            num_blocks=num_blocks,
            init=init,
            parametrization=parametrization,
            block_diag_init=block_diag_init,
            divide_num_blocks=divide_num_blocks,
            divide_block_size=divide_block_size,
            device=device,
            dtype=dtype,
            add_rand_noise=add_rand_noise,
        )

        total_size = block_size * num_blocks
        initial_bias = torch.zeros(total_size, dtype=dtype, device=device)
        self.bias = nn.Parameter(initial_bias)
        # Not read anywhere in this class; presumably consumed by callers.
        self._forward_dtype = dtype

    def _set_requires_grad(self, requires_grad: bool = True) -> None:
        """Toggle gradient tracking for the matrix parameters and the bias."""
        super()._set_requires_grad(requires_grad)
        self.bias.requires_grad_(requires_grad)


class HadamardTransform(BaseTransform):
    """
    Group-wise orthonormal Hadamard transform along one dimension.

    The target dimension is split into consecutive groups of ``group_size``
    elements. Each group is transformed with
    ``fast_hadamard_transform.hadamard_transform`` scaled by
    ``1 / sqrt(group_size)``.

    Args:
        group_size: Elements per group; the size of the transformed dimension
            must be a multiple of it.
    """

    def __init__(self, group_size: int = 128):
        super().__init__()
        self.group_size = group_size
        self.scale = 1 / math.sqrt(self.group_size)

    def forward(self, x: torch.Tensor, inv_t: bool = False, dim: int = -1):
        """Apply the group-wise Hadamard transform to ``x`` along ``dim``.

        ``inv_t`` is accepted for interface compatibility but ignored: the
        scaled Hadamard matrix is treated as symmetric and orthonormal, i.e.
        its own inverse (and inverse transpose).

        Raises:
            ValueError: if ``x`` is 0-D, or ``x.shape[dim]`` is not a multiple
                of ``group_size``.
        """
        if x.ndim == 0:
            raise ValueError("HadamardTransform expects at least 1D tensor")
        size = x.shape[dim]
        if size % self.group_size != 0:
            # Without this check, a flat reshape would silently build groups
            # that straddle two slices of ``x`` whenever only numel divides.
            raise ValueError(
                f"HadamardTransform: size {size} of dim={dim} is not a multiple "
                f"of group_size={self.group_size}"
            )
        # Bring ``dim`` to the last axis so every row of the 2-D view is one
        # group from a single slice. ``reshape`` copies if the view is
        # non-contiguous, where ``view`` would raise.
        x_t = x.movedim(dim, -1)
        out = hadamard_transform(x_t.reshape(-1, self.group_size), scale=self.scale)
        return out.reshape(x_t.shape).movedim(-1, dim)

    def remove_parametrizations(self) -> None:
        pass


class FullHadamardTransform(BaseTransform):
    """
    Orthonormal Hadamard transform U applied along a chosen dimension.

    Delegates to the helpers imported from ``quantization.transforms.hadamard``:
      - ``matmul_hadU``  : applies U
      - ``matmul_hadUt`` : applies U^T
    Both act on the last axis, so any other target axis is temporarily moved
    to the end.

    Notes:
      - Operates on the specified ``dim`` (default: -1).
      - Non-power-of-2 sizes work wherever ``matmul_hadU`` supports them
        (reportedly via ``get_hadK(n)`` in that module).
      - U is orthonormal, so U^T is its inverse.
      - NOTE(review): ``inv_t=True`` applies U^T (= U^-1). If ``inv_t`` is
        meant as the inverse *transpose* (U^-T = U), the two differ whenever
        U is not symmetric — confirm the intended convention.
    """

    def __init__(self, *args, **kwargs):
        # Arguments are accepted and ignored; this transform keeps no state.
        super().__init__()

    def forward(self, x: torch.Tensor, inv_t: bool = False, dim: int = -1):
        """Apply U (or U^T when ``inv_t``) to ``x`` along ``dim``.

        The transformed axis is returned to position ``dim``.

        Raises:
            ValueError: if ``x`` is 0-D or ``dim`` is out of range.
        """
        if x.ndim == 0:
            raise ValueError("FullHadamardTransform expects at least 1D tensor")

        # Resolve a negative axis, then bounds-check the result.
        if dim < 0:
            dim += x.ndim
        if dim < 0 or dim >= x.ndim:
            raise ValueError(f"Invalid dim={dim} for x.ndim={x.ndim}")

        apply_had = matmul_hadUt if inv_t else matmul_hadU

        if dim == x.ndim - 1:
            return apply_had(x.contiguous())

        # Target axis is not last: move it to the end, transform, move it back.
        moved = x.movedim(dim, -1).contiguous()
        return apply_had(moved).movedim(-1, dim)

    def remove_parametrizations(self) -> None:
        """No parametrizations to remove."""



# Registry mapping the ``transform_class`` names accepted by
# ``build_transform`` to their implementing classes.
TRANSFORMS = {
    "identity": IdentityTransform,
    "learned": LearnedTransform,
    "learned_affine": LearnedAffineTransform,
    "hadamard": HadamardTransform,
    "full_hadamard": FullHadamardTransform,
}


def build_transform(transform_class: str, **transform_kwargs) -> BaseTransform:
    """Instantiate the transform registered under ``transform_class``.

    ``transform_kwargs`` is passed through ``filter_kwarg_dict`` against the
    selected class's ``__init__`` (presumably dropping keys it does not
    accept), so one shared kwargs dict can serve every transform type.

    Raises:
        KeyError: if ``transform_class`` is not a key of ``TRANSFORMS``.
    """
    try:
        transform = TRANSFORMS[transform_class]
    except KeyError:
        # Same exception type as before, but name the valid choices.
        raise KeyError(
            f"Unknown transform_class {transform_class!r}; "
            f"expected one of {sorted(TRANSFORMS)}"
        ) from None
    return transform(**filter_kwarg_dict(transform.__init__, transform_kwargs))

def get_transform_matrix(
    transform_class: str,
    size: int,
    device: torch.device = None,
    dtype: torch.dtype = None,
) -> torch.Tensor:
    """Return the dense ``size x size`` matrix of a fixed (non-learned) transform.

    Supported ``transform_class`` values:
      - ``"hadamard"``: ``hadamard_transform`` applied to the identity with
        scale ``1 / sqrt(size)``, i.e. one full-size Hadamard matrix.
        NOTE(review): this is not block-diagonal over ``group_size`` the way
        ``HadamardTransform`` applies it — confirm callers expect that.
      - ``"identity"``: the identity matrix.

    Raises:
        NotImplementedError: for any other ``transform_class``.
    """
    if transform_class == "hadamard":
        # ``hadamard_transform`` is already imported at module level.
        eye = torch.eye(size, device=device, dtype=dtype)
        return hadamard_transform(eye, scale=1 / math.sqrt(size))
    if transform_class == "identity":
        return torch.eye(size, device=device, dtype=dtype)
    raise NotImplementedError(
        "get_transform_matrix is implemented only for Hadamard and Identity "
        f"transforms, got {transform_class!r}"
    )
