#!/usr/bin/env python3
"""
Convert a LoRA+Sampler safetensors (HF-style keys) to your custom fused layout:

- Attention:   (q_proj, k_proj, v_proj)  -->  attention.wqkv  (LoRA fused)
- MLP:         (gate_proj, up_proj)      -->  feed_forward.w13 (LoRA fused)
- Attention o: (o_proj)                  -->  attention.wo     (LoRA passthrough)
- MLP down:    (down_proj)               -->  feed_forward.w2  (LoRA passthrough)
- Sampler:     sampler.layers.*          -->  passthrough (unchanged keys)

LoRA fusion math:
ΔW_qkv = block_diag(B_q, B_k, B_v) @ concat([A_q; A_k; A_v])
ΔW_13  = block_diag(B_gate, B_up) @ concat([A_gate; A_up])

IMPORTANT (permutation):
- Your base conversion permutes Q and K rows for RoPE layout.
- Apply the **same row permutation to LoRA-B** of Q and K.
- LoRA-A is column-side (input), so **do not permute LoRA-A**.

This script outputs ONLY the mapped LoRA + sampler tensors as a single .pth.
"""

from __future__ import annotations

import re

from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Optional, Tuple

import torch
from safetensors.torch import load_file as load_safetensors_file

from Engine.models.base import ModelArgs


# ----------------------------- Parsing helpers -----------------------------

# Source-key patterns. Every pattern is anchored (^...$), so a key must match in full.
#
# Attention LoRA keys, e.g. "model.model.layers.3.self_attn.q_proj.lora_A.weight".
# Named groups: lid = layer index (digits), mod = projection name, ab = "A" or "B".
_LORA_ATT_PAT = re.compile(
    r"^model\.model\.layers\.(?P<lid>\d+)\.self_attn\.(?P<mod>q_proj|k_proj|v_proj|o_proj)\.lora_(?P<ab>A|B)\.weight$"
)
# MLP LoRA keys, e.g. "model.model.layers.3.mlp.gate_proj.lora_B.weight".
# Same named groups as the attention pattern.
_LORA_MLP_PAT = re.compile(
    r"^model\.model\.layers\.(?P<lid>\d+)\.mlp\.(?P<mod>gate_proj|up_proj|down_proj)\.lora_(?P<ab>A|B)\.weight$"
)
# Sampler keys, e.g. "sampler.layers.0.linear.weight". Only the `.weight` tensors of
# `linear` and `norm` entries match; any other sampler key does not match this pattern.
_SAMPLER_PAT = re.compile(
    r"^sampler\.layers\.(?P<lid>\d+)\.(?P<what>linear|norm)\.weight$"
)

@dataclass
class LoRAAB:
    """Holder for the two LoRA factors of a single projection.

    Both slots start empty and are filled independently as keys are read.
    """

    A: Optional[torch.Tensor] = None  # [r, in_features]
    B: Optional[torch.Tensor] = None  # [out_features, r]

    def ok(self) -> bool:
        """Return True only when both the A and the B factor have been set."""
        return not (self.A is None or self.B is None)

@dataclass
class LayerPack:
    """Per-layer container with one (initially empty) LoRAAB slot per projection."""

    # Attention projections; the converter fills these from self_attn.{q,k,v,o}_proj keys.
    q: LoRAAB = field(default_factory=LoRAAB)
    k: LoRAAB = field(default_factory=LoRAAB)
    v: LoRAAB = field(default_factory=LoRAAB)
    o: LoRAAB = field(default_factory=LoRAAB)
    # MLP projections; the converter fills these from mlp.{gate,up,down}_proj keys.
    gate: LoRAAB = field(default_factory=LoRAAB)
    up: LoRAAB = field(default_factory=LoRAAB)
    down: LoRAAB = field(default_factory=LoRAAB)


# ------------------------------- Permutations -------------------------------

def _permute_rows_for_rope(B: torch.Tensor, n_heads: int, head_dim: int) -> torch.Tensor:
    """
    Reorder the rows of a LoRA-B matrix with the per-head RoPE permutation.

    The rows (output dimension) are treated as ``n_heads`` groups of ``head_dim`` rows.
    Within each group the two halves are interleaved:

    view(out, r) -> [n_heads, 2, head_dim//2, r] --transpose(1,2)--> [n_heads, head_dim//2, 2, r]
                  -> reshape back to [out, r].

    For one head with head_dim=4 the row order 0,1,2,3 becomes 0,2,1,3. Columns (the
    rank dimension) are left untouched. This is meant to mirror the permutation the base
    Q/K weight conversion applies; that conversion lives outside this file.

    Args:
        B: LoRA-B matrix of shape [n_heads * head_dim, r].
        n_heads: number of heads the rows are grouped into.
        head_dim: rows per head; must be a positive even number.

    Returns:
        A tensor of the same shape as ``B`` with permuted rows.

    Raises:
        ValueError: if ``B`` is not 2-D, ``head_dim`` is not a positive even number, or
            the row count does not equal ``n_heads * head_dim``.
    """
    # Explicit raises instead of `assert`: asserts vanish under `python -O`, and an odd
    # head_dim would otherwise surface as an opaque shape error from `view`.
    if B.dim() != 2:
        raise ValueError(f"LoRA-B must be 2-D [out, r], got shape {tuple(B.shape)}")
    if head_dim <= 0 or head_dim % 2 != 0:
        raise ValueError(f"head_dim must be a positive even number, got {head_dim}")
    out, r = B.shape
    if out != n_heads * head_dim:
        raise ValueError(f"Unexpected LoRA-B rows: {out} vs {n_heads}*{head_dim}")
    return (
        B.view(n_heads, 2, head_dim // 2, r)
         .transpose(1, 2)
         .reshape(out, r)
    )


# ------------------------------- Fusion utils --------------------------------

def _fuse_qkv_lora(
    q: LoRAAB, k: LoRAAB, v: LoRAAB,
    *,
    n_head_q: int, n_head_k: int, head_dim: int,
    permute_qk_rows: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build fused (A_qkv, B_qkv) given per-proj LoRA (ΔW = B @ A).

    - A_qkv := stack([A_q; A_k; A_v])                 -> shape [r_q + r_k + r_v, in]
    - B_qkv := block_diag(B_q*, B_k*, B_v)            -> shape [out_q+out_k+out_v, r_q+r_k+r_v]
      where B_q*, B_k* are row-permuted via _permute_rows_for_rope if permute_qk_rows=True.

    With this layout B_qkv @ A_qkv equals the row-wise concatenation of the three
    per-projection products, because the off-diagonal blocks of B_qkv are zero.

    Args:
        q, k, v: per-projection LoRA pairs; each must have both A and B set.
        n_head_q: head count used to group the rows of B_q for the permutation.
        n_head_k: head count used to group the rows of B_k for the permutation.
        head_dim: rows per head for the permutation.
        permute_qk_rows: when False, B_q and B_k are used as-is.

    Returns:
        (A_qkv, B_qkv). B_qkv keeps the dtype of B_q and lives on the inputs' device.

    Raises:
        ValueError: if any of q/k/v is missing its A or B factor.
    """
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    missing = [name for name, ab in (("q", q), ("k", k), ("v", v)) if not ab.ok()]
    if missing:
        raise ValueError(
            f"Q/K/V LoRA must all be present for fused WQKV (missing: {missing})."
        )

    # A_fused: vertical stack (LoRA-A is input-side, so it is never permuted)
    A_qkv = torch.cat([q.A, k.A, v.A], dim=0)  # type: ignore

    # Optional row permutation on B for Q & K; V is always taken unchanged
    Bq, Bk, Bv = q.B, k.B, v.B  # type: ignore
    if permute_qk_rows:
        Bq = _permute_rows_for_rope(Bq, n_head_q, head_dim)  # type: ignore
        Bk = _permute_rows_for_rope(Bk, n_head_k, head_dim)  # type: ignore

    # torch.block_diag allocates on the inputs' device (a bare torch.zeros would use the
    # default device). The cast pins the result to B_q's dtype, as the fused B always was.
    B_qkv = torch.block_diag(Bq, Bk, Bv).to(dtype=Bq.dtype)  # type: ignore

    return A_qkv, B_qkv


def _fuse_w13_lora(gate: LoRAAB, up: LoRAAB) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build fused (A_13, B_13) for gate+up → W13.

    - A_13 := stack([A_gate; A_up])                   -> [r_g + r_u, in]
    - B_13 := block_diag(B_gate, B_up)                -> [out_gate+out_up, r_g+r_u]

    With this layout B_13 @ A_13 equals the row-wise concatenation of
    B_gate @ A_gate and B_up @ A_up, because the off-diagonal blocks are zero.

    Args:
        gate: LoRA pair of the gate projection; both A and B must be set.
        up: LoRA pair of the up projection; both A and B must be set.

    Returns:
        (A_13, B_13). B_13 keeps the dtype of B_gate and lives on the inputs' device.

    Raises:
        ValueError: if gate or up is missing its A or B factor.
    """
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    missing = [name for name, ab in (("gate", gate), ("up", up)) if not ab.ok()]
    if missing:
        raise ValueError(
            f"gate/up LoRA must both be present for fused W13 (missing: {missing})."
        )

    A_13 = torch.cat([gate.A, up.A], dim=0)  # type: ignore

    # torch.block_diag allocates on the inputs' device (a bare torch.zeros would use the
    # default device). The cast pins the result to B_gate's dtype, as the fused B always was.
    B_13 = torch.block_diag(gate.B, up.B).to(dtype=gate.B.dtype)  # type: ignore
    return A_13, B_13

def cleanup_original_files(checkpoint_dir: Path) -> None:
    """Delete the source checkpoint files found directly inside ``checkpoint_dir``.

    Removed, in this order: every ``*.bin`` file, every ``*.safetensors`` file, and
    every ``*.index.json`` file whose name contains "model". The search is not
    recursive, and each removal is announced on stdout.
    """
    print("Cleaning up original model files...")

    # (glob pattern, substring the file name must contain). An empty substring is
    # contained in every name, so the first two patterns are unconditional.
    removal_rules = (
        ("*.bin", ""),
        ("*.safetensors", ""),
        ("*.index.json", "model"),
    )
    for pattern, required_part in removal_rules:
        # Materialise the matches before deleting anything for this pattern.
        for candidate in list(checkpoint_dir.glob(pattern)):
            if required_part in candidate.name:
                print(f"Removing {candidate}")
                candidate.unlink()

    print("Cleanup completed!")

# ------------------------------- Main convert -------------------------------

def _resolve_head_shapes(
    out_q: int,
    out_k: int,
    n_head: Optional[int],
    n_kv_head: Optional[int],
    head_dim: Optional[int],
) -> Tuple[int, int, int]:
    """
    Fill in whichever of (n_head, n_kv_head, head_dim) is None from LoRA-B row counts.

    Relies on out_q == n_head * head_dim and out_k == n_kv_head * head_dim. At least one
    of the three values must be given; otherwise the split of rows into heads is ambiguous.

    Raises:
        ValueError: if nothing is known, or a row count is not divisible as required.
    """
    if head_dim is None:
        if n_head is not None:
            rows, heads = out_q, n_head
        elif n_kv_head is not None:
            rows, heads = out_k, n_kv_head
        else:
            raise ValueError(
                "Please provide n_head, n_kv_head, head_dim or initialize ModelArgs to infer them."
            )
        if heads <= 0 or rows % heads != 0:
            raise ValueError(
                f"Cannot infer head_dim: {rows} LoRA-B rows are not divisible by {heads} heads."
            )
        head_dim = rows // heads

    if head_dim <= 0:
        raise ValueError(f"head_dim must be positive, got {head_dim}")
    if n_head is None:
        if out_q % head_dim != 0:
            raise ValueError(
                f"Cannot infer n_head: {out_q} Q rows are not divisible by head_dim={head_dim}."
            )
        n_head = out_q // head_dim
    if n_kv_head is None:
        if out_k % head_dim != 0:
            raise ValueError(
                f"Cannot infer n_kv_head: {out_k} K rows are not divisible by head_dim={head_dim}."
            )
        n_kv_head = out_k // head_dim
    return n_head, n_kv_head, head_dim


@torch.inference_mode()
def convert_lora_and_sampler_fused(
    *,
    lora_path: Path,                 # path to LoRA+sampler model.safetensors
    out_path: Path,                  # output .pth containing ONLY mapped LoRA+sampler
    model_name: Optional[str] = None,
    n_head: Optional[int] = None,    # if None, inferred from B_q rows / head_dim
    n_kv_head: Optional[int] = None, # if None, inferred from B_k rows / head_dim
    head_dim: Optional[int] = None,  # if None, inferred from B_q rows / n_head
    permute_qk_rows: bool = True,    # apply RoPE row-permutation to LoRA-B for Q/K
    cleanup: bool = True,
) -> None:
    """
    Map HF LoRA + sampler keys into fused custom layout and save as .pth.

    Output keys:
      - layers.{L}.attention.wqkv.lora_{A,B}.weight
      - layers.{L}.attention.wo.lora_{A,B}.weight
      - layers.{L}.feed_forward.w13.lora_{A,B}.weight
      - layers.{L}.feed_forward.w2.lora_{A,B}.weight
      - sampler.layers.{L}.{linear|norm}.weight  (passthrough)

    Args:
        lora_path: safetensors file holding the LoRA and sampler tensors.
        out_path: destination file; parent directories are created if needed.
        model_name: if given, ModelArgs.from_name(model_name) supplies any of
            n_head / n_kv_head / head_dim that were not passed explicitly.
        n_head, n_kv_head, head_dim: attention shape used for the Q/K row permutation.
            Values still missing are inferred from the LoRA-B row counts, which needs
            at least one of the three to be known.
        permute_qk_rows: permute the rows of LoRA-B for Q and K (LoRA-A is never permuted).
        cleanup: after saving, call cleanup_original_files on the directory that
            contains ``lora_path``. WARNING: this deletes every ``*.bin`` and
            ``*.safetensors`` file in that directory, not just ``lora_path``.

    Notes:
      * The LoRA ranks of fused modules become
        r_qkv = r_q + r_k + r_v, r_13 = r_gate + r_up.
      * o_proj and down_proj LoRA are optional and copied through unchanged.

    Raises:
        KeyError: a layer has LoRA tensors but lacks a complete q/k/v or gate/up set.
        ValueError: head shapes cannot be determined, or no input key was recognised.
    """
    src: Dict[str, torch.Tensor] = load_safetensors_file(str(lora_path), device="cpu")

    # HF projection name -> LayerPack field name
    slot_for = {
        "q_proj": "q", "k_proj": "k", "v_proj": "v", "o_proj": "o",
        "gate_proj": "gate", "up_proj": "up", "down_proj": "down",
    }

    # Bucket by layer id
    per_layer: Dict[str, LayerPack] = {}
    sampler_out: Dict[str, torch.Tensor] = {}
    skipped = []

    for key, tensor in src.items():
        m = _LORA_ATT_PAT.match(key) or _LORA_MLP_PAT.match(key)
        if m:
            pack = per_layer.setdefault(m.group("lid"), LayerPack())
            # The "ab" group is exactly "A" or "B", i.e. the LoRAAB attribute name.
            setattr(getattr(pack, slot_for[m.group("mod")]), m.group("ab"), tensor)
        elif _SAMPLER_PAT.match(key):
            # passthrough for sampler.* keys (key is kept unchanged)
            sampler_out[key] = tensor
        else:
            skipped.append(key)

    if skipped:
        # Make dropped tensors visible instead of discarding them silently.
        print(f"[WARN] ignoring {len(skipped)} unrecognized key(s), e.g. {skipped[:5]}")

    # Optional: shape hints from ModelArgs (if available)
    if model_name is not None:
        cfg = ModelArgs.from_name(model_name)
        print(f"[INFO] Inferring shapes from ModelArgs for {model_name}")
        n_head = n_head or cfg.n_head
        n_kv_head = n_kv_head or cfg.n_local_heads
        head_dim = head_dim or cfg.head_dim

    # Build output dict
    out: Dict[str, torch.Tensor] = {}

    # Process every layer
    for lid, P in per_layer.items():
        # --- Attention: fused WQKV ---
        missing = [name for name, ab in (("q", P.q), ("k", P.k), ("v", P.v)) if not ab.ok()]
        if missing:
            raise KeyError(f"[layer {lid}] missing LoRA for: {missing}; cannot build fused WQKV.")

        # Once resolved the values are reused for all remaining layers.
        n_head, n_kv_head, head_dim = _resolve_head_shapes(
            P.q.B.shape[0], P.k.B.shape[0], n_head, n_kv_head, head_dim,  # type: ignore
        )
        A_qkv, B_qkv = _fuse_qkv_lora(
            P.q, P.k, P.v,
            n_head_q=n_head, n_head_k=n_kv_head, head_dim=head_dim,
            permute_qk_rows=permute_qk_rows,
        )
        out[f"layers.{lid}.attention.wqkv.lora_A.weight"] = A_qkv
        out[f"layers.{lid}.attention.wqkv.lora_B.weight"] = B_qkv

        # --- Attention: wo passthrough ---
        if P.o.ok():
            out[f"layers.{lid}.attention.wo.lora_A.weight"] = P.o.A  # type: ignore
            out[f"layers.{lid}.attention.wo.lora_B.weight"] = P.o.B  # type: ignore

        # --- MLP: fused W13 (gate + up) ---
        missing = [name for name, ab in (("gate", P.gate), ("up", P.up)) if not ab.ok()]
        if missing:
            raise KeyError(f"[layer {lid}] missing LoRA for: {missing}; cannot build fused W13.")
        A_13, B_13 = _fuse_w13_lora(P.gate, P.up)
        out[f"layers.{lid}.feed_forward.w13.lora_A.weight"] = A_13
        out[f"layers.{lid}.feed_forward.w13.lora_B.weight"] = B_13

        # --- MLP: down passthrough ---
        if P.down.ok():
            out[f"layers.{lid}.feed_forward.w2.lora_A.weight"] = P.down.A  # type: ignore
            out[f"layers.{lid}.feed_forward.w2.lora_B.weight"] = P.down.B  # type: ignore

    # Sampler passthrough
    out.update(sampler_out)

    # Refuse to write an empty checkpoint (and then delete the source) when the key
    # layout did not match at all.
    if not out:
        raise ValueError(f"No LoRA or sampler tensors recognised in {lora_path}; nothing to save.")

    # Save
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(out, out_path)
    print(f"[SAVE] fused LoRA+sampler -> {out_path}")

    if cleanup:
        # cleanup_original_files globs inside a directory; lora_path is a file, so
        # passing it directly matched nothing and the cleanup never ran.
        cleanup_original_files(lora_path.parent)

def _parse_bool_flag(value: str) -> bool:
    """
    Parse a command-line boolean such as "true"/"false"/"1"/"0" (case-insensitive).

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every non-empty
    string, so ``--permute_qk_rows False`` used to enable the permutation. The empty
    string still maps to False, which was the only way to get False before.

    Raises:
        ValueError: for unrecognised text; argparse reports it as a usage error.
    """
    text = value.strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--lora_ckpt_path", type=Path, required=True)
    parser.add_argument("--out_dir", type=Path, required=True)
    parser.add_argument("--model_name", type=str, required=True)
    # Explicit parser instead of type=bool so that "False"/"0"/"no" really disable it.
    parser.add_argument("--permute_qk_rows", type=_parse_bool_flag, default=True)
    parser.add_argument("--no_cleanup", action="store_true")
    args = parser.parse_args()

    convert_lora_and_sampler_fused(
        lora_path=args.lora_ckpt_path,
        out_path=args.out_dir / "model.pth",
        model_name=args.model_name,
        permute_qk_rows=args.permute_qk_rows,
        cleanup=not args.no_cleanup,
    )