import torch
import os
import marlin
from ..core.quantize import Quantizer

class MarlinLinear(torch.nn.Module):
    """Linear layer whose 4-bit weights are executed by the Marlin kernel.

    The weight is packed once at construction time via ``marlin.Layer.pack``.
    ``forward`` runs ``marlin.mul`` and then optionally adds:

    * a rank-1 term ``x.sum(-1, keepdim=True) @ u`` (skipped when
      ``BTMOE_DISABLE_MARLIN_U=1``),
    * a low-rank term ``(x @ V) @ U`` (skipped when ``BTMOE_DISABLE_UV=1``),
    * ``bias``.

    Set ``BTMOE_DEBUG_SHAPES=1`` to print per-call debug messages.
    """

    def __init__(
        self, W: torch.Tensor, scales: torch.Tensor, u=None, bias=None, groupsize=-1, U=None, V=None
    ):
        """Pack ``W`` into Marlin's format.

        Args:
            W: Weight with ``m, n = W.shape`` where ``m`` is ``in_features`` and
                ``n`` is ``out_features``. It is cast to half and transposed
                into a temporary ``torch.nn.Linear`` that ``marlin.Layer.pack``
                consumes.
            scales: Quantization scales; passed transposed to ``pack``.
            u: Optional rank-1 correction tensor, stored as a frozen Parameter.
            bias: Optional bias, stored in half precision.
            groupsize: Marlin group size; ``-1`` treats all ``m`` input
                features as a single group.
            U, V: Optional low-rank factors, stored as frozen Parameters.
        """
        super().__init__()

        m, n = W.shape
        device = W.device
        _linear = torch.nn.Linear(m, n)
        _linear.weight.data = W.half().t()

        # -1 means "one group spanning the whole input dimension".
        effective_groupsize = m if (groupsize == -1) else groupsize

        _layer = marlin.Layer(m, n, groupsize=groupsize)
        _layer.k = m
        _layer.n = n
        _layer.groupsize = effective_groupsize
        # Pre-size the packed weight / scale buffers that pack() fills in.
        _layer.B = torch.empty((m // 16, n * 16 // 8), dtype=torch.int, device=device)
        _layer.s = torch.empty(
            (m // effective_groupsize, n), dtype=torch.half, device=device
        )
        _layer.pack(_linear, scales.t())

        self.bias = bias.half() if (bias is not None) else None
        self.Wq_packed = _layer.B.clone()
        self.scales = _layer.s.clone()
        # Scratch buffer handed to marlin.mul on every call.
        self.workspace_fp = torch.zeros(n // 128 * 16, device=device)
        self.in_features = m
        self.out_features = n
        self.group_size = effective_groupsize
        self.axis = 1
        self.device = device
        self.compute_dtype = torch.float16
        self.u = torch.nn.Parameter(u, requires_grad=False) if (u is not None) else None
        self.U = torch.nn.Parameter(U, requires_grad=False) if (U is not None) else None
        self.V = torch.nn.Parameter(V, requires_grad=False) if (V is not None) else None
        self.name = "MarlinLinear"

        # The temporary Linear / marlin.Layer are only needed for packing.
        del _linear, _layer
        torch.cuda.empty_cache()

    @torch.no_grad()
    def matmul(self, x):
        """Return ``marlin.mul`` applied to ``x`` with the packed weights.

        ``x`` is moved to ``self.device`` and flattened to 2-D for the kernel.
        The result has shape ``x.shape[:-1] + (out_features,)`` and ``x``'s
        dtype, and is allocated on ``x.device``.
        """
        out = torch.empty(
            x.shape[:-1] + (self.scales.shape[1],), dtype=x.dtype, device=x.device
        )
        marlin.mul(
            x.to(self.device).view((-1, x.shape[-1])),
            self.Wq_packed,
            out.view((-1, out.shape[-1])),
            self.scales,
            self.workspace_fp,
        )
        return out

    def _add_uv_compensation(self, x, out, debug):
        """Return ``out + (x @ V) @ U``, or ``out`` unchanged if it cannot be applied.

        This is deliberately best-effort: any failure or shape mismatch skips
        the term (reported only when ``debug`` is true).
        """
        try:
            # Expected orientation: V is (in_features, rank), U is (rank, out_features).
            # Transpose either factor if it was stored the other way round.
            V_matrix = self.V
            U_matrix = self.U
            if V_matrix.shape[0] != x.shape[-1]:
                V_matrix = V_matrix.t()
            if U_matrix.shape[-1] != out.shape[-1]:
                U_matrix = U_matrix.t()
            # Match the activations' dtype/device. Otherwise a factor stored in
            # another precision makes matmul raise and the term is silently lost.
            V_matrix = V_matrix.to(device=x.device, dtype=x.dtype)
            U_matrix = U_matrix.to(device=x.device, dtype=x.dtype)
            uv_compensation = (x @ V_matrix) @ U_matrix
        except Exception as e:
            if debug:
                print(f"[Marlin UV] compensation failed: {e}; skipping UV compensation")
            return out

        if uv_compensation.shape != out.shape:
            if debug:
                print(f"[Marlin UV] shape mismatch after correction: out={tuple(out.shape)}, uv_comp={tuple(uv_compensation.shape)}; skipping")
            return out
        return out + uv_compensation

    @torch.jit.ignore
    def forward(self, x):
        """Compute the Marlin matmul plus the optional u / UV / bias terms."""
        # Per-call logging is opt-in; unconditional prints here flood stdout
        # and slow down every forward pass.
        debug = os.getenv("BTMOE_DEBUG_SHAPES", "0") == "1"
        if debug:
            print("MarlinLinear forward -> matmul")
        out = self.matmul(x)

        if self.u is not None and os.getenv("BTMOE_DISABLE_MARLIN_U", "0") != "1":
            out += torch.matmul(x.sum(axis=-1, keepdim=True), self.u)

        if self.U is not None and self.V is not None and os.getenv("BTMOE_DISABLE_UV", "0") != "1":
            if debug:
                print("[DEBUG] Adding UV compensation for", self.name)
            out = self._add_uv_compensation(x, out, debug)

        if self.bias is not None:
            out += self.bias

        return out


# Intended for 4-bit, axis=1 layers with group_size in {-1/None, 64, 128}.
def patch_btmoe_to_marlin(layer, patch_params):
    """Swap a 4-bit quantized linear layer for a ``MarlinLinear`` when possible.

    ``layer`` is either the quantized layer itself or a wrapper exposing it as
    ``layer.linear_layer``. The packed weights are unpacked with
    ``Quantizer.unpack[meta["packing"]]``, and the dequantization
    ``(W_q - z) * s`` is split into two parts:

    * ``(W_q - z_shift) * s`` is packed by Marlin;
    * ``(z_shift - z) * s`` becomes the correction ``u``, which ``MarlinLinear``
      adds as ``x.sum(-1, keepdim=True) @ u``.

    Optional low-rank factors ``U``/``V`` found on the quantized layer are
    forwarded unchanged.

    Args:
        layer: Quantized layer, or a wrapper with a ``linear_layer`` attribute.
        patch_params: Unused here (presumably kept so the signature matches
            other patch hooks).

    Returns:
        ``layer`` with its quantized module replaced by a ``MarlinLinear``.
        Returns ``layer`` unchanged if the config is unsupported, the shapes
        don't line up, or ``MarlinLinear`` construction fails.
    """
    if marlin is None:
        return layer

    z_shift = 8.0
    hqq_layer = layer.linear_layer if hasattr(layer, "linear_layer") else layer

    # Check config support: 4-bit, axis=1, and groupsize either per-tensor (-1/None) or 64/128
    axis = hqq_layer.meta.get("axis", 1)
    gs = hqq_layer.meta.get("group_size", None)
    nbits = hqq_layer.meta.get("nbits", None)
    if (nbits != 4) or (axis == 0) or (gs not in (None, -1, 64, 128)):
        print("Skipping marlin conversion for", hqq_layer.name)
        return layer

    # Minimal alignment/divisibility checks. Also record the expected
    # (in_features, out_features) layout so the unpacked weight can be validated.
    expected_shape = None
    shape = hqq_layer.meta.get("shape", None)
    if shape is not None and len(shape) >= 2:
        n_out = int(shape[0])
        m_in = int(shape[1])
        # Marlin kernels expect k (m_in) multiple of 16; n multiple of 8
        if (m_in % 16) != 0 or (n_out % 8) != 0:
            print(f"Skipping marlin conversion for {hqq_layer.name}: m%16 or n%8 not satisfied (m={m_in}, n={n_out})")
            return layer
        if gs in (64, 128) and (m_in % int(gs)) != 0:
            print(f"Skipping marlin conversion for {hqq_layer.name}: m({m_in}) % gs({gs}) != 0")
            return layer
        expected_shape = (m_in, n_out)

    W_r = Quantizer.unpack[hqq_layer.meta["packing"]](
        hqq_layer.W_q, dtype=hqq_layer.compute_dtype
    ).t()
    # MarlinLinear takes the weight as (in_features, out_features). If the
    # unpacked tensor has any other layout (e.g. still grouped rows), converting
    # it would silently build a layer with the wrong in/out sizes.
    if expected_shape is not None and tuple(W_r.shape) != expected_shape:
        print(
            f"Skipping marlin conversion for {hqq_layer.name}: unpacked weight shape "
            f"{tuple(W_r.shape)} != expected (in, out) {expected_shape}"
        )
        return layer

    z = hqq_layer.meta["zero"]
    s = hqq_layer.meta["scale"].t()
    W_r = (W_r - z_shift) * s

    # Fold the zero point into u: (W_q - z) * s = (W_q - z_shift) * s + (z_shift - z) * s
    if isinstance(z, torch.Tensor):  # also covers torch.nn.Parameter
        z = z.t()
        u = (s * (-z + z_shift)).view([1, -1])
    elif isinstance(z, (int, float)) and z != z_shift:
        # A scalar zero point other than z_shift still needs its correction.
        # Dropping it would offset every output by (z_shift - z) * s * sum(x).
        u = (s * (z_shift - z)).view([1, -1])
    else:
        u = None

    # MarlinLinear adds x.sum(-1, keepdim=True) @ u to an (..., out_features)
    # output, so u needs one entry per output channel (or one broadcast value).
    n_cols = W_r.shape[1]
    if u is not None and u.shape[-1] not in (1, n_cols):
        print(
            f"Skipping marlin conversion for {hqq_layer.name}: zero-point correction "
            f"width {u.shape[-1]} does not match out_features {n_cols}"
        )
        return layer

    U = getattr(hqq_layer, "U", None)
    V = getattr(hqq_layer, "V", None)

    # Pass effective groupsize to Marlin so group-wise (64/128) is handled by the kernel
    groupsize = -1 if gs in (None, -1) else int(gs)

    try:
        marlin_layer = MarlinLinear(W_r, s, u=u, bias=hqq_layer.bias, groupsize=groupsize, U=U, V=V)
    except Exception as e:
        print(f"[ERROR] MarlinLinear creation failed for {hqq_layer.name}: {e}")
        print(f"  W_r.shape={W_r.shape}, s.shape={s.shape}, groupsize={groupsize}")
        return layer

    # Drop the quantized tensors/metadata before swapping in the Marlin layer.
    if hasattr(layer, "linear_layer"):
        del layer.linear_layer.W_q
        del layer.linear_layer.meta
        del layer.linear_layer
        layer.linear_layer = marlin_layer
    else:
        del hqq_layer.W_q
        del hqq_layer.meta
        del hqq_layer
        layer = marlin_layer

    torch.cuda.empty_cache()

    return layer