from __future__ import annotations

import math
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch

# Try to import fast hadamard transform for CUDA
try:
    from fast_hadamard_transform import (
        hadamard_transform,
        hadamard_transform_12N,
        hadamard_transform_20N,
        hadamard_transform_28N,
        hadamard_transform_40N,
    )
    HAS_FAST_HADAMARD = True
except ImportError:
    HAS_FAST_HADAMARD = False


class HadamardType(Enum):
    """Selects which side(s) of a weight matrix W get a Hadamard transform.

    The string values are the names accepted by ``HadamardConfig.from_args``.
    """
    # Identity: weights and Hessian are returned unchanged.
    NONE = "none"
    # Right-multiply over in_features; the input covariance is rotated to match.
    ROW = "row"           # W -> W @ Had, Sigma -> Had^T @ Sigma @ Had
    # Left-multiply over out_features; the input covariance is left untouched.
    COLUMN = "column"     # W -> Had @ W
    # Both sides, each with its own transform object (Had_row, Had_col).
    ROW_COLUMN = "row_column"  # Both: W -> Had_row @ W @ Had_col


def _is_power_of_2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0


def _factorize_for_hadamard(n: int) -> tuple:
    """
    Factorize n to find the best Hadamard decomposition.

    Returns (factor, power_of_2) where:
    - If n is power of 2: (1, n)
    - If n = k * 2^m where k in {12, 20, 28, 40}: (k, 2^m)
    - Otherwise: Uses Kronecker product decomposition

    Common LLM dimensions:
    - 4096 = 2^12 -> (1, 4096)
    - 8192 = 2^13 -> (1, 8192)
    - 5120 = 20 * 256 -> (20, 256)
    - 14336 = 7 * 2048 -> needs Kronecker: (7, 2048)
    - 28672 = 7 * 4096 -> needs Kronecker: (7, 4096)
    - 11008 = 43 * 256 -> needs Kronecker: (43, 256)
    """
    if _is_power_of_2(n):
        return (1, n)

    # Check special factors supported by fast_hadamard_transform
    for factor in [12, 20, 28, 40]:
        if n % factor == 0:
            rem = n // factor
            if _is_power_of_2(rem):
                return (factor, rem)

    # Find a factorization n = k * 2^m where k is small
    # Try to find largest power of 2 that divides n
    m = 0
    temp = n
    while temp % 2 == 0:
        temp //= 2
        m += 1

    power_of_2 = 1 << m
    factor = n // power_of_2
    return (factor, power_of_2)


def _hadamard_cpu_2d(x: torch.Tensor) -> torch.Tensor:
    """Apply Hadamard transform to last dimension using CPU (for fallback)."""
    m, n = x.shape
    out = x.clone()
    h = 1
    while h < n:
        v = out.view(m, -1, 2 * h)
        u = v[:, :, :h].clone()
        w = v[:, :, h:2 * h]
        v[:, :, :h] = u + w
        v[:, :, h:2 * h] = u - w
        h *= 2
    return out


def hadamard_transform_adaptive(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    """
    Apply an unnormalized Hadamard transform to the last dimension, then scale.

    Dispatch:
    - CUDA tensor with fast_hadamard_transform installed: the library kernel
      matching the dimension's factorization is used (2^m, or k * 2^m with
      k in {12, 20, 28, 40}).
    - Anything else (CPU tensors, or CUDA without the library): a pure-torch
      butterfly, which only exists for power-of-2 dimensions.

    The transform has +/-1 entries and is not normalized by itself; pass
    ``scale = 1 / sqrt(dim)`` to obtain an orthonormal transform.

    Args:
        x: Input tensor (..., dim); does not need to be contiguous
        scale: Scale factor multiplied into the output

    Returns:
        Transformed tensor with the same shape as x

    Raises:
        NotImplementedError: If no kernel exists for this dimension on this
            device (any non-power-of-2 dim on the fallback path, or a dim
            outside 2^m / {12, 20, 28, 40} * 2^m on the CUDA path).
    """
    dim = x.shape[-1]
    if dim == 0:
        # Nothing to transform; this also keeps dim == 0 away from the
        # factorization helper, which is only meaningful for positive sizes.
        return x * scale

    factor, power_of_2 = _factorize_for_hadamard(dim)

    if x.is_cuda and HAS_FAST_HADAMARD:
        # Fused CUDA kernels; each takes the scale directly.
        if factor == 1:
            return hadamard_transform(x, scale=scale)
        if factor == 12:
            return hadamard_transform_12N(x, scale=scale)
        if factor == 20:
            return hadamard_transform_20N(x, scale=scale)
        if factor == 28:
            return hadamard_transform_28N(x, scale=scale)
        if factor == 40:
            return hadamard_transform_40N(x, scale=scale)
        # A Kronecker split would need a Hadamard transform over the `factor`
        # axis, but the only generic one available is the power-of-2
        # butterfly, which cannot handle these factors.
        raise NotImplementedError(
            f"No Hadamard transform available for dim={dim} "
            f"(= {factor} * {power_of_2}); supported dims are 2^m and "
            f"k * 2^m with k in (12, 20, 28, 40)"
        )

    if factor != 1:
        # The pure-torch butterfly is defined for power-of-2 sizes only;
        # fail with a clear message rather than deep inside a view() call.
        raise NotImplementedError(
            f"No Hadamard transform available for dim={dim} "
            f"(= {factor} * {power_of_2}) without the fast_hadamard_transform "
            f"CUDA kernels; only power-of-2 dims are supported on this path"
        )

    # Pure-torch fallback. reshape (not view) so non-contiguous inputs work.
    orig_shape = x.shape
    out = _hadamard_cpu_2d(x.reshape(-1, dim))
    return (out * scale).reshape(orig_shape)


def _pow2_blocks(n: int):
    """Decompose n into sum of powers of 2 for block-diagonal Hadamard."""
    blocks = []
    r = int(n)
    while r > 0:
        b = 1 << (r.bit_length() - 1)
        blocks.append(b)
        r -= b
    return blocks


class BlockRandomHadamard:
    """
    Block-diagonal random Hadamard transform.

    The dimension n is split into power-of-2 blocks (largest first). Each
    block gets a random +/-1 sign flip followed by a Hadamard transform scaled
    by 1/sqrt(block size). Writing D = diag(signs) and S for the scaled
    block-diagonal Hadamard, the transform matrix is H = D @ S.

    The signs depend only on (n, seed), so two instances built with the same
    arguments represent the same transform.

    Methods:
        right(x): x @ H (or x @ H^T), applied to the last dimension
        right_(x): same as right, written back into x in place
        left(x): H @ x (or H^T @ x), applied to the first dimension
        bilinear(h): H^T @ h @ H (for Hessian/covariance)
    """

    def __init__(self, n: int, seed: int = 0, device=None, dtype=None):
        """
        Args:
            n: Size of the transformed dimension
            seed: Seed for the random sign vector
            device: Device on which the signs are stored (default "cpu")
            dtype: Dtype in which the signs are stored (default float32)
        """
        self.n = int(n)
        self.blocks = _pow2_blocks(self.n)

        # Signs are always drawn on CPU so a given seed yields the same
        # transform regardless of the target device.
        g = torch.Generator(device="cpu")
        g.manual_seed(int(seed))
        s = torch.randint(0, 2, (self.n,), generator=g, dtype=torch.int8)
        s = s.mul(2).sub(1)  # Convert 0,1 to -1,+1

        if device is None:
            device = "cpu"
        if dtype is None:
            dtype = torch.float32
        self.signs = s.to(device=device, dtype=dtype)

    def _apply_block_hadamard(self, x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
        """
        Apply the block-diagonal transform to the last dimension of x.

        Args:
            x: Tensor (..., n); may be non-contiguous. It is not modified.
            transpose: False computes x @ H, True computes x @ H^T

        Returns:
            New contiguous tensor with the same shape as x
        """
        s = self.signs.to(device=x.device, dtype=x.dtype)

        # One contiguous working copy: flattening with view() is then always
        # legal, even for transposed or sliced inputs.
        y = x.clone(memory_format=torch.contiguous_format).view(-1, self.n)

        off = 0
        for b in self.blocks:
            cols = slice(off, off + b)
            blk = y[:, cols]
            inv = 1.0 / math.sqrt(float(b))

            if not transpose:
                # x @ D @ S: flip signs first, then scaled Hadamard
                y[:, cols] = hadamard_transform_adaptive(blk * s[cols], scale=inv)
            else:
                # x @ S @ D: scaled Hadamard first, then flip signs
                y[:, cols] = hadamard_transform_adaptive(blk, scale=inv) * s[cols]

            off += b

        return y.view(x.shape)

    def right_(self, x: torch.Tensor, transpose: bool = False):
        """
        In-place apply Hadamard to last dimension: x @ H (or x @ H^T).

        Overwrites x with the result and returns None.

        Raises:
            ValueError: If the last dimension of x is not n
        """
        if x.shape[-1] != self.n:
            raise ValueError(f"right_ expects last dim {self.n}, got {x.shape[-1]}")

        result = self._apply_block_hadamard(x, transpose=transpose)
        x.copy_(result)

    def right(self, x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
        """
        Apply Hadamard to last dimension: x @ H (or x @ H^T).

        Raises:
            ValueError: If the last dimension of x is not n
        """
        if x.shape[-1] != self.n:
            raise ValueError(f"right expects last dim {self.n}, got {x.shape[-1]}")

        return self._apply_block_hadamard(x, transpose=transpose)

    def left(self, x: torch.Tensor, transpose: bool = False) -> torch.Tensor:
        """
        Apply Hadamard to first dimension: H @ x (or H^T @ x).

        For matrix x of shape (n, m), this transforms every column of x.

        Raises:
            ValueError: If the first dimension of x is not n
        """
        if x.shape[0] != self.n:
            raise ValueError(f"left expects first dim {self.n}, got {x.shape[0]}")

        # H @ x   = (x^T @ H^T)^T  -> right-multiply the transposed input by H^T
        # H^T @ x = (x^T @ H)^T    -> right-multiply the transposed input by H
        # transpose(0, -1) equals .T for matrices without relying on .T for
        # other ranks.
        xt = x.transpose(0, -1)
        result = self._apply_block_hadamard(xt, transpose=not transpose)
        return result.transpose(0, -1).contiguous()

    def bilinear(self, h: torch.Tensor) -> torch.Tensor:
        """
        Apply bilinear transform: H^T @ h @ H

        For Hessian/covariance matrix h of shape (n, n); leading batch
        dimensions (..., n, n) are also accepted.

        Raises:
            ValueError: If the last two dimensions of h are not (n, n)
        """
        if h.shape[-1] != self.n or h.shape[-2] != self.n:
            raise ValueError(f"bilinear expects ({self.n},{self.n}), got {tuple(h.shape[-2:])}")

        # First h @ H (right multiply over the last axis)
        out = self.right(h, transpose=False)
        # Then H^T @ (h @ H) = ((h @ H)^T @ H)^T, done on the last two axes so
        # that any batch dimensions are carried along.
        out = self.right(out.transpose(-1, -2), transpose=False)
        return out.transpose(-1, -2).contiguous()


@dataclass
class HadamardConfig:
    """Settings that control whether and how Hadamard transforms are used in quantization."""
    # Master switch; when False the other fields keep their defaults in from_args.
    enabled: bool = False
    # Which side(s) of the weight matrix are transformed.
    hadamard_type: HadamardType = HadamardType.ROW
    # Seed value carried for the transform's random signs.
    seed: int = 0

    @staticmethod
    def from_args(hadamard: bool, hadamard_type: str = "row", hadamard_seed: int = 0) -> "HadamardConfig":
        """
        Build a config from CLI-style arguments.

        Args:
            hadamard: Whether Hadamard transforms are requested at all
            hadamard_type: One of "none", "row", "column", "row_column"
                (case-insensitive); any other string falls back to ROW
            hadamard_seed: Seed stored in the config

        Returns:
            A disabled default config when ``hadamard`` is falsy, otherwise an
            enabled config with the parsed type and seed.
        """
        if hadamard:
            # The accepted CLI names are exactly the enum values, so a lookup
            # by value parses them; unknown names fall back to ROW.
            try:
                chosen = HadamardType(hadamard_type.lower())
            except ValueError:
                chosen = HadamardType.ROW
            return HadamardConfig(enabled=True, hadamard_type=chosen, seed=hadamard_seed)

        return HadamardConfig(enabled=False)


def apply_hadamard_to_weight(
    W: torch.Tensor,
    hadamard_type: HadamardType,
    had_row: Optional[BlockRandomHadamard] = None,
    had_col: Optional[BlockRandomHadamard] = None,
) -> torch.Tensor:
    """
    Rotate a weight matrix with the requested Hadamard transform(s).

    Args:
        W: Weight matrix (out_features, in_features)
        hadamard_type: Which side(s) to transform
        had_row: Transform over out_features (applied from the left);
            a seed-0 transform is built when omitted and needed
        had_col: Transform over in_features (applied from the right);
            a seed-0 transform is built when omitted and needed

    Returns:
        Transformed weight matrix (W itself for NONE)

    Transform types:
        - ROW: W @ H_col
        - COLUMN: H_row @ W
        - ROW_COLUMN: H_row @ W @ H_col (right side first, then left)
    """
    if hadamard_type == HadamardType.NONE:
        return W

    n_out, n_in = W.shape
    wants_right = hadamard_type in (HadamardType.ROW, HadamardType.ROW_COLUMN)
    wants_left = hadamard_type in (HadamardType.COLUMN, HadamardType.ROW_COLUMN)

    if wants_right:
        right_had = had_col
        if right_had is None:
            right_had = BlockRandomHadamard(n_in, seed=0, device=W.device, dtype=W.dtype)
        W = right_had.right(W)  # W @ H

    if wants_left:
        left_had = had_row
        if left_had is None:
            left_had = BlockRandomHadamard(n_out, seed=0, device=W.device, dtype=W.dtype)
        W = left_had.left(W)  # H @ W

    return W


def apply_hadamard_to_hessian(
    H: torch.Tensor,
    hadamard_type: HadamardType,
    had: Optional[BlockRandomHadamard] = None,
) -> torch.Tensor:
    """
    Rotate a Hessian/covariance matrix to match a weight-side Hadamard transform.

    When the weights are right-multiplied (W @ Had), the matching Hessian is
        Sigma_new = Had^T @ Sigma @ Had

    Args:
        H: Hessian/covariance matrix (n, n)
        hadamard_type: Type of Hadamard transform applied to the weights
        had: Transform over the n dimension; a seed-0 transform is built
            when omitted

    Returns:
        Transformed Hessian; H itself for NONE and COLUMN, since a
        column-side transform does not affect the input covariance.
    """
    untouched = (
        hadamard_type == HadamardType.NONE
        or hadamard_type == HadamardType.COLUMN
    )
    if untouched:
        return H

    transform = had
    if transform is None:
        size = H.shape[0]
        transform = BlockRandomHadamard(size, seed=0, device=H.device, dtype=H.dtype)

    # Had^T @ Sigma @ Had
    return transform.bilinear(H)


def inverse_hadamard_weight(
    W_hat: torch.Tensor,
    hadamard_type: HadamardType,
    had_row: Optional[BlockRandomHadamard] = None,
    had_col: Optional[BlockRandomHadamard] = None,
) -> torch.Tensor:
    """
    Undo the Hadamard transform(s) and return weights in the original space.

    The transposed transform is used as the inverse on each side.

    Args:
        W_hat: Transformed weight matrix (out_features, in_features)
        hadamard_type: Type of Hadamard transform that was applied
        had_row: Transform over out_features; a seed-0 transform is built
            when omitted and needed
        had_col: Transform over in_features; a seed-0 transform is built
            when omitted and needed

    Returns:
        Weight matrix in original space (W_hat itself for NONE)

    Inverse transforms:
        - ROW: W_hat @ H_col^T
        - COLUMN: H_row^T @ W_hat
        - ROW_COLUMN: H_row^T @ W_hat @ H_col^T (left side first, then right)
    """
    if hadamard_type == HadamardType.NONE:
        return W_hat

    n_out, n_in = W_hat.shape
    undo_left = hadamard_type in (HadamardType.COLUMN, HadamardType.ROW_COLUMN)
    undo_right = hadamard_type in (HadamardType.ROW, HadamardType.ROW_COLUMN)

    if undo_left:
        left_had = had_row
        if left_had is None:
            left_had = BlockRandomHadamard(n_out, seed=0, device=W_hat.device, dtype=W_hat.dtype)
        W_hat = left_had.left(W_hat, transpose=True)  # H^T @ W

    if undo_right:
        right_had = had_col
        if right_had is None:
            right_had = BlockRandomHadamard(n_in, seed=0, device=W_hat.device, dtype=W_hat.dtype)
        W_hat = right_had.right(W_hat, transpose=True)  # W @ H^T

    return W_hat
