import os
import math
import argparse
from pathlib import Path
from datetime import datetime
from collections import OrderedDict
from typing import Optional, List, Tuple, Dict
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import faiss
import hashlib

from data_utils import get_weight_quant_data
from eval_utils import evaluate

from quant import Quantizer
from gptq import GPTQ


def _hash_ids(x: torch.Tensor) -> str:
    return hashlib.md5(x.detach().to("cpu", dtype=torch.int32).numpy().tobytes()).hexdigest()


class RectifiedSigmoid(nn.Module):
    """Stretched and clipped sigmoid: ``clamp(sigmoid(x) * (zeta - gamma) + gamma, 0, 1)``.

    With ``gamma < 0`` and ``zeta > 1`` (the defaults), the sigmoid is stretched
    beyond [0, 1] and then clipped, so exact 0 and 1 are reachable for finite x.

    Args:
        gamma: lower stretch bound (default -0.1).
        zeta: upper stretch bound (default 1.1).
    """

    def __init__(self, gamma=-0.1, zeta=1.1):
        super().__init__()
        self.gamma = gamma
        self.zeta = zeta

    def forward(self, x):
        span = self.zeta - self.gamma
        stretched = torch.sigmoid(x) * span + self.gamma
        return stretched.clamp(0, 1)

    def inverse(self, y):
        """Return ``x`` such that the *unclipped* stretched sigmoid maps ``x`` to ``y``."""
        span = self.zeta - self.gamma
        return -torch.log(span / (y - self.gamma) - 1)


def per_row_affine_params(w: torch.Tensor, n_bits: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute asymmetric min/max quantization parameters for each row of ``w``.

    Args:
        w: 2-D weight matrix (the shape unpacking below rejects other ranks).
        n_bits: bit-width; the integer grid is ``[0, 2**n_bits - 1]``.

    Returns:
        ``(scale, zero)``, each of shape ``(rows, 1)``, with
        ``scale = (row_max - row_min) / max(1, 2**n_bits - 1)`` (rows whose range
        is zero get ``scale = 1``) and ``zero = round(-row_min / scale)``.
    """
    _rows, _cols = w.shape  # rank check only; the sizes themselves are unused
    levels = max(1, (1 << n_bits) - 1)
    lo = w.amin(dim=1, keepdim=True)
    hi = w.amax(dim=1, keepdim=True)
    scale = (hi - lo) / levels
    degenerate = scale == 0
    scale = torch.where(degenerate, torch.ones_like(scale), scale)
    zero = torch.round(-lo / scale)
    return scale.view(-1, 1), zero.view(-1, 1)


def find_linears_in_layer(layer: nn.Module) -> Dict[str, nn.Module]:
    """Map qualified submodule names to every linear-like module inside ``layer``.

    Matches ``nn.Linear``, ``transformers.Conv1D`` and ``nn.Conv2d``. ``layer``
    itself is included under the key ``""`` when it is one of those types, because
    ``named_modules()`` yields the root module first.
    """
    linear_types = (nn.Linear, transformers.Conv1D, nn.Conv2d)
    return {
        qualified: module
        for qualified, module in layer.named_modules()
        if isinstance(module, linear_types)
    }


def _gptq_model_type(model: nn.Module) -> str:
    """Classify ``model`` by attribute layout as ``"bloom"``, ``"opt"`` or ``"llama"``."""
    if not hasattr(model, "model"):
        return "bloom"
    if hasattr(model.model, "decoder"):
        return "opt"
    return "llama"


def _gptq_layers_and_prefix(model: nn.Module, model_type: str):
    """Return ``(decoder_layers, dotted_prefix)``; the prefix is used to build grid keys."""
    if model_type == "bloom":
        return model.transformer.h, "transformer.h"
    if model_type == "opt":
        return model.model.decoder.layers, "model.decoder.layers"
    return model.model.layers, "model.layers"


def _gptq_move_pre_layer_modules(model: nn.Module, model_type: str, target) -> None:
    """Move the modules that run before / around the decoder stack to ``target``."""
    if model_type == "bloom":
        tr = model.transformer
        tr.word_embeddings = tr.word_embeddings.to(target)
        tr.word_embeddings_layernorm = tr.word_embeddings_layernorm.to(target)
    elif model_type == "opt":
        dec = model.model.decoder
        dec.embed_tokens = dec.embed_tokens.to(target)
        dec.embed_positions = dec.embed_positions.to(target)
        if getattr(dec, "project_out", None):
            dec.project_out = dec.project_out.to(target)
        if getattr(dec, "project_in", None):
            dec.project_in = dec.project_in.to(target)
    else:
        inner = model.model
        inner.embed_tokens = inner.embed_tokens.to(target)
        if getattr(inner, "norm", None) is not None:
            inner.norm = inner.norm.to(target)
        if getattr(inner, "rotary_emb", None) is not None:
            inner.rotary_emb = inner.rotary_emb.to(target)


@torch.no_grad()
def gptq_sequential_collect_grid_opt(
    model: AutoModelForCausalLM,
    calib_loader: DataLoader,
    device: torch.device,
    wbits: int = 4,
    groupsize: int = 128,
    sym: bool = True,
    actorder: bool = False,
    percdamp: float = 0.01,
    static_groups: bool = False,
    blocksize: int = 128,
    write_quantized_weights: bool = False,
) -> Dict[str, dict]:
    """Run GPTQ layer by layer and collect the quantization grid of every linear.

    The inputs to decoder layer 0 are captured by temporarily wrapping it in a
    ``Catcher`` that records its input and aborts the forward pass. Each decoder
    layer is then processed in turn on ``device``:

    1. Forward hooks feed its linears' inputs/outputs to ``GPTQ.add_batch``.
    2. ``GPTQ.fasterquant_return_with_error`` is run and its result is recorded.
    3. The original weights are restored unless ``write_quantized_weights`` is set.
    4. The layer's outputs become the next layer's inputs.

    Only one decoder layer lives on ``device`` at a time.

    NOTE(review): one row of ``inps`` is written per forward call, so each
    calibration batch is presumably a single sequence of length ``model.seqlen``
    (batch size 1) - TODO confirm against the data loader.

    Args:
        model: OPT-, LLaMA- or BLOOM-style causal LM. It must expose ``seqlen``.
        calib_loader: iterable of id tensors, or of tuples whose first element is one.
        device: device the current layer is processed on.
        wbits, groupsize, sym, actorder, percdamp, static_groups, blocksize:
            forwarded to the GPTQ quantizer / ``fasterquant_return_with_error``.
        write_quantized_weights: keep GPTQ's quantized weights in the model
            instead of restoring the originals.

    Returns:
        Dict keyed by ``"<prefix>.<layer_idx>.<linear_name>"``, where prefix is
        ``model.layers``, ``model.decoder.layers`` or ``transformer.h``. Each value
        holds ``scale``, ``zero``, ``base_int`` and ``rest`` (CPU tensors), plus
        ``maxq``, ``groupsize`` and ``sym``.
    """
    dev = device
    model.eval()

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        model_type = _gptq_model_type(model)
        layers, key_prefix = _gptq_layers_and_prefix(model, model_type)
        _gptq_move_pre_layer_modules(model, model_type, dev)
        layers[0] = layers[0].to(dev)

        dtype = next(iter(model.parameters())).dtype

        # Loaders without __len__ (e.g. generators) are materialized once.
        try:
            nsamples = len(calib_loader)
            calib_batches = calib_loader
        except TypeError:
            calib_batches = list(calib_loader)
            nsamples = len(calib_batches)

        inps = torch.zeros(
            (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
        )
        cache = {"i": 0, "attention_mask": None, "position_ids": None}

        class Catcher(nn.Module):
            """Record the input of the wrapped layer, then abort the model forward."""

            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, inp, **kwargs):
                inps[cache["i"]] = inp
                cache["i"] += 1
                cache["attention_mask"] = kwargs.get("attention_mask")
                if model_type == "llama":
                    cache["position_ids"] = kwargs.get("position_ids")
                raise ValueError  # stop the forward; only layer-0 inputs are needed

        layers[0] = Catcher(layers[0])
        try:
            for batch in calib_batches:
                ids = batch[0] if isinstance(batch, (list, tuple)) else batch
                try:
                    model(ids.to(dev))
                except ValueError:
                    pass
        finally:
            # Always unwrap, even if an unexpected exception escaped the model.
            layers[0] = layers[0].module

        layers[0] = layers[0].cpu()
        torch.cuda.empty_cache()

        outs = torch.zeros_like(inps)
        attention_mask = cache["attention_mask"]
        position_ids = cache["position_ids"]

        def _run_layer(layer, src, dst):
            """Run ``layer`` on each captured sample of ``src``, writing into ``dst``."""
            for j in range(nsamples):
                if model_type == "llama":
                    dst[j] = layer(
                        src[j].unsqueeze(0),
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                    )[0]
                else:
                    dst[j] = layer(src[j].unsqueeze(0), attention_mask=attention_mask)[0]

        def _make_hook(g):
            """Forward hook that feeds a linear's input/output to its GPTQ object."""
            def _hook(_mod, inp, out):
                g.add_batch(inp[0].data, out.data)
            return _hook

        grids: Dict[str, dict] = {}

        for li in range(len(layers)):
            layer = layers[li].to(dev)
            subset = find_linears_in_layer(layer)

            gptqs: Dict[str, GPTQ] = {}
            origW: Dict[str, torch.Tensor] = {}
            for name, mod in subset.items():
                g = GPTQ(mod)
                g.quantizer = Quantizer()
                g.quantizer.configure(bits=wbits, perchannel=True, sym=sym, mse=False)
                gptqs[name] = g
                origW[name] = mod.weight.data.detach().clone()

            # Calibration pass: only the hooks' statistics matter here.
            handles = [mod.register_forward_hook(_make_hook(gptqs[name])) for name, mod in subset.items()]
            try:
                with torch.inference_mode():
                    _run_layer(layer, inps, outs)
            finally:
                for h in handles:
                    h.remove()

            for name, g in gptqs.items():
                print(name)
                ret = g.fasterquant_return_with_error(
                    blocksize=blocksize, percdamp=percdamp,
                    groupsize=groupsize, actorder=actorder, static_groups=static_groups,
                )
                grids[f"{key_prefix}.{li}.{name}"] = {
                    "scale": ret["scale"].cpu(),
                    "zero": ret["zero"].cpu(),
                    "maxq": ret["maxq"],
                    "groupsize": groupsize,
                    "sym": sym,
                    "base_int": ret["base_int"].cpu(),
                    "rest": ret["rest"].cpu(),
                }
                if not write_quantized_weights:
                    subset[name].weight.data.copy_(origW[name].to(subset[name].weight.dtype))
                g.free()

            del gptqs, origW

            # Propagation pass: these outputs become the next layer's inputs.
            with torch.inference_mode():
                _run_layer(layer, inps, outs)

            layers[li] = layer.cpu()
            del layer
            torch.cuda.empty_cache()
            inps, outs = outs, inps

        _gptq_move_pre_layer_modules(model, model_type, "cpu")
        return grids
    finally:
        model.config.use_cache = use_cache

class VQRoundLinear(nn.Module):
    """Drop-in ``nn.Linear`` replacement whose rounding is learnable and vector-quantized.

    The weight is ``scale * (q_int - zero)`` with
    ``q_int = clamp(base_int + r + zero, 0, 2**n_bits - 1)``, where ``r`` is a
    per-weight rounding offset in [0, 1]. Its pre-activation ``alpha`` is not
    stored per weight. Instead, ``alpha`` is split into blocks of ``D`` values and
    each block is replaced by one row of a k-means codebook (``codebook``, the only
    trainable tensor) chosen by ``indices``.

    * ``"soft"`` mode: ``r = RectifiedSigmoid(alpha)``.
    * ``"hard"`` mode: ``r = (alpha >= 0)``.

    Grid source:

    * If ``external_scale``, ``external_zero`` and ``external_base_int`` are all
      given (e.g. GPTQ grids), they and ``external_rest`` define the grid and the
      initial offsets.
    * Otherwise per-row min/max parameters are computed from ``base_linear``, with
      ``base_int = floor(W / scale)`` and ``rest = W / scale - base_int``.

    ``external_wq`` is accepted for interface compatibility and is not used.
    """

    def __init__(
        self,
        base_linear: nn.Linear,
        n_bits: int = 4,
        D: int = 8,
        K: int = 2**12,
        kmeans_iters: int = 30,
        gamma: float = -0.1,
        zeta: float = 1.1,
        external_scale: torch.Tensor = None,
        external_zero: torch.Tensor = None,
        external_wq: Optional[torch.Tensor] = None,
        external_base_int: torch.Tensor = None,
        external_rest: torch.Tensor = None,
    ):
        super().__init__()
        assert isinstance(base_linear, nn.Linear)

        self.in_features = base_linear.in_features
        self.out_features = base_linear.out_features
        if base_linear.bias is not None:
            self.register_buffer("bias", base_linear.bias.detach().clone())
        else:
            self.bias = None

        device = base_linear.weight.device
        w_dtype = base_linear.weight.dtype
        self.w_dtype = w_dtype

        self.n_bits = int(n_bits)

        if external_scale is not None and external_zero is not None and external_base_int is not None:
            self.register_buffer("scale", external_scale.to(device=device, dtype=w_dtype))
            self.register_buffer("zero", external_zero.to(device=device, dtype=w_dtype))

            self.register_buffer("base_int", external_base_int.to(device=device, dtype=torch.int16))
            rest = external_rest.to(device=device, dtype=w_dtype)
            rest = torch.clamp(rest, 0, 1)
        else:
            W = base_linear.weight.data.detach().clone()
            s, z = per_row_affine_params(W.float(), n_bits)
            scale = s.to(device=device, dtype=W.dtype)
            zero = z.to(device=device, dtype=W.dtype)
            self.register_buffer("scale", scale.to(device=device, dtype=w_dtype))
            self.register_buffer("zero", zero.to(device=device, dtype=w_dtype))
            self.register_buffer("base_int", torch.floor(W.float() / scale).to(device=device, dtype=torch.int16))
            rest = (W.float() / scale) - torch.floor(W.float() / scale)
            rest = torch.clamp(rest, 0, 1)

        rsig = RectifiedSigmoid(gamma, zeta).to(device)
        self.rsig = rsig

        # Pre-activations chosen so that rsig(alpha) reproduces `rest` (up to fp error).
        alpha_full = rsig.inverse(rest)
        base_shape = alpha_full.shape
        n = alpha_full.numel()
        L = (n + D - 1) // D
        pad = L * D - n
        if pad:
            alpha_vec = torch.cat(
                [alpha_full.reshape(-1), torch.zeros(pad, device=device, dtype=alpha_full.dtype)], dim=0
            )
        else:
            alpha_vec = alpha_full.reshape(-1)
        alpha_blocks = alpha_vec.view(L, D)

        # faiss k-means requires at least as many training points as clusters;
        # layers with fewer than K blocks get a correspondingly smaller codebook.
        k_eff = max(1, min(int(K), L))
        with torch.no_grad():
            codebook_init, indices = self._kmeans(alpha_blocks, K=k_eff, iters=kmeans_iters)

        # int16 halves index memory but would overflow (and then silently index
        # from the end of the codebook) for codebooks larger than 32768 entries.
        self.register_buffer("indices", indices.to(self._index_dtype(k_eff)))
        self.D = int(D)
        self.base_shape = base_shape
        self.n = n
        self.codebook = nn.Parameter(codebook_init.to(device=device, dtype=torch.float16).clone())
        self.quant_mode = "soft"

    @staticmethod
    def _index_dtype(k: int) -> torch.dtype:
        """Smallest dtype that can hold codebook indices ``0 .. k-1`` (int16 or int32)."""
        return torch.int16 if k <= torch.iinfo(torch.int16).max + 1 else torch.int32

    def set_quant_mode(self, mode: str):
        """Select ``"soft"`` (rectified-sigmoid) or ``"hard"`` (threshold) rounding."""
        assert mode in ("soft", "hard")
        self.quant_mode = mode

    def _reconstruct_alpha(self) -> torch.Tensor:
        """Expand the codebook via ``indices`` back to the full ``base_shape`` alpha tensor."""
        alpha_blocks = self.codebook[self.indices.int()]
        alpha_flat = alpha_blocks.view(-1)[: self.n]  # drop the zero padding of the last block
        return alpha_flat.view(self.base_shape)

    def _r(self, alpha: torch.Tensor) -> torch.Tensor:
        """Rounding offsets in [0, 1] for the current quant mode."""
        return (alpha >= 0).float() if self.quant_mode == "hard" else self.rsig(alpha)

    @torch.no_grad()
    def _kmeans(self, X: torch.Tensor, K: int, iters: int = 30):
        """Cluster the rows of ``X`` into ``K`` centroids with faiss.

        Returns ``(centroids, assignment)`` as tensors on ``X.device``. The data is
        handed to faiss as a float32 numpy array, converted once.
        """
        X_np = X.detach().float().cpu().contiguous().numpy()
        kmeans = faiss.Kmeans(X_np.shape[1], K, niter=iters, verbose=False, gpu=True)
        kmeans.train(X_np)
        centroids = torch.from_numpy(kmeans.centroids).to(X.device)
        _, assign = kmeans.index.search(X_np, 1)
        idx = torch.from_numpy(assign[:, 0]).to(X.device)
        return centroids, idx

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear map with the currently reconstructed (soft or hard) weight."""
        alpha = self._reconstruct_alpha()
        r = self._r(alpha)  # in [0, 1]
        base_int_f = self.base_int.to(dtype=self.scale.dtype)
        q_int = torch.clamp(base_int_f + r + self.zero, 0, 2**self.n_bits - 1)
        Wq = self.scale * (q_int - self.zero)
        Wq = Wq.to(self.w_dtype)
        return F.linear(x, Wq, self.bias)



def replace_linear_with_vqround_using_grids(
    model: nn.Module,
    grids: dict,
    n_bits=4, D=8, K=4096, kmeans_iters=30,
    skip_keywords: Optional[List[str]] = None
) -> int:
    """Replace every ``nn.Linear`` inside the decoder layers with a ``VQRoundLinear``.

    Decoder layers are found at ``model.model.decoder.layers`` (OPT),
    ``model.model.layers`` (LLaMA-style) or ``model.transformer.h`` (BLOOM), and
    each linear is named ``"<prefix>.<layer_idx>.<subname>"``. This is the same key
    format the GPTQ grid collector produces. If ``grids`` has an entry for that
    name, its ``scale``/``zero``/``base_int``/``rest`` initialize the new layer;
    otherwise the layer derives its own per-row grid.

    Args:
        model: the model to modify in place.
        grids: mapping from full linear name to a grid dict.
        n_bits, D, K, kmeans_iters: forwarded to ``VQRoundLinear``.
        skip_keywords: linears whose full name contains any of these substrings
            are left untouched. Defaults to
            ``["lm_head", "embed_tokens", "embed_positions"]``.

    Returns:
        Number of modules replaced.
    """
    if skip_keywords is None:
        skip_keywords = ["lm_head", "embed_tokens", "embed_positions"]

    if hasattr(model, "model") and hasattr(model.model, "decoder"):  # OPT
        layer_list = model.model.decoder.layers
        layer_prefix = "model.decoder.layers"
    elif hasattr(model, "model"):  # LLaMA-style
        layer_list = model.model.layers
        layer_prefix = "model.layers"
    else:  # BLOOM
        layer_list = model.transformer.h
        layer_prefix = "transformer.h"

    # Collect first, replace afterwards: mutating while iterating named_modules is unsafe.
    to_replace = []
    for li, layer in enumerate(layer_list):
        for subname, submod in layer.named_modules():
            if not isinstance(submod, nn.Linear):
                continue
            full_name = f"{layer_prefix}.{li}.{subname}"
            if any(kw in full_name for kw in skip_keywords):
                continue
            to_replace.append((full_name, submod))

    def get_parent(root: nn.Module, path: str):
        """Walk a dotted module path; return ``(parent_module, final_attribute_name)``."""
        parts = path.split(".")
        cur = root
        for p in parts[:-1]:
            if p.isdigit():
                cur = cur[int(p)]
            elif isinstance(cur, (nn.ModuleDict, dict)) and p in cur:
                cur = cur[p]
            else:
                cur = getattr(cur, p)
        return cur, parts[-1]

    def set_child(parent: nn.Module, child: str, new_mod: nn.Module):
        """Install ``new_mod`` under ``child`` via indexing, dict key or attribute."""
        if child.isdigit():
            parent[int(child)] = new_mod
        elif isinstance(parent, (nn.ModuleDict, dict)) and child in parent:
            parent[child] = new_mod
        else:
            setattr(parent, child, new_mod)

    cnt = 0
    for name, lin in to_replace:
        print(name, flush=True)
        grid = grids.get(name, None)
        parent, child = get_parent(model, name)
        if grid is not None:
            new_m = VQRoundLinear(
                base_linear=lin, n_bits=n_bits, D=D, K=K, kmeans_iters=kmeans_iters,
                external_scale=grid["scale"], external_zero=grid["zero"],
                external_base_int=grid["base_int"], external_rest=grid["rest"],
            )
        else:
            new_m = VQRoundLinear(base_linear=lin, n_bits=n_bits, D=D, K=K, kmeans_iters=kmeans_iters)
        set_child(parent, child, new_m)
        cnt += 1
    return cnt



def _round_reg_checkpointed(m, beta_now, step, rnd_loss, amp_dtype):
    def _fn(codebook: torch.Tensor):
        with torch.amp.autocast("cuda", dtype=amp_dtype):
            alpha_blocks = codebook[m.indices.int()]
            alpha_flat = alpha_blocks.view(-1)[: m.n]
            r_soft = m.rsig(alpha_flat)
            return rnd_loss(step, r_soft, override_b=beta_now)
    return checkpoint(_fn, m.codebook, use_reentrant=False)



class LinearTempDecay:
    """Piecewise-linear schedule from ``start_b`` to ``end_b``.

    The value stays at ``start_b`` until step ``rel_start_decay * t_max``. It then
    decays linearly to ``end_b`` at ``t_max`` and stays at ``end_b`` afterwards.
    """

    def __init__(self, t_max: int, rel_start_decay: float, start_b: float, end_b: float):
        self.t_max = int(t_max)
        self.start_decay = rel_start_decay * self.t_max
        self.start_b = float(start_b)
        self.end_b = float(end_b)

    def __call__(self, t: int) -> float:
        if t < self.start_decay:
            return self.start_b
        decay_len = max(1.0, self.t_max - self.start_decay)  # guard against a zero-length decay
        progress = (t - self.start_decay) / decay_len
        remaining = max(0.0, 1.0 - progress)
        return self.end_b + (self.start_b - self.end_b) * remaining


class RoundLoss(nn.Module):
    """Rounding regularizer ``sum(1 - |2 * sb - 1| ** b)`` with an annealed exponent ``b``.

    The loss is zero before ``max_count * warmup`` iterations. After that, ``b``
    follows a ``LinearTempDecay`` from ``b_range[0]`` to ``b_range[1]``, unless
    ``override_b`` is passed. The last exponent used is kept in ``self.b``.
    """

    def __init__(self, max_count: int, b_range: Tuple[float, float], decay_start: float, warmup: float, p_norm: float):
        super().__init__()
        self.loss_start = max_count * warmup
        decay_rel = warmup + (1 - warmup) * decay_start
        self.temp_decay = LinearTempDecay(
            max_count, rel_start_decay=decay_rel, start_b=b_range[0], end_b=b_range[1]
        )
        self.p_norm = p_norm  # stored only; forward does not use it
        self.b = 0.0

    def forward(self, iter_count: int, sb: torch.Tensor, override_b: Optional[float] = None) -> torch.Tensor:
        if iter_count < self.loss_start:
            return sb.new_zeros(())
        if override_b is None:
            self.b = self.temp_decay(iter_count)
        else:
            self.b = float(override_b)
        return (1.0 - (2.0 * sb - 1.0).abs().pow(self.b)).sum()


def _kd_loss_with_temperature(student_logits: torch.Tensor, teacher_logits: torch.Tensor, T: float = 1.0) -> torch.Tensor:
    if T == 1.0:
        log_p_s = F.log_softmax(student_logits, dim=-1)
        log_p_t = F.log_softmax(teacher_logits, dim=-1)
        return F.kl_div(log_p_s, log_p_t, log_target=True, reduction="batchmean")
    log_p_s = F.log_softmax(student_logits / T, dim=-1)
    log_p_t = F.log_softmax(teacher_logits / T, dim=-1)
    return F.kl_div(log_p_s, log_p_t, log_target=True, reduction="batchmean") * (T * T)


@torch.no_grad()
def build_teacher_cache(calib_data, teacher, device, T: float, save_dir: str):
    """Precompute teacher log-probabilities for every calibration sample and save them to one file.

    Each sample is an id tensor, or a list/tuple whose first element is one, and is
    run through ``teacher``. Logits are divided by ``T`` when ``T != 1``. The
    log-softmax is stored on CPU as float16, together with ``_hash_ids`` of the ids.

    Saved blob::

        {"type": "full_logp_single", "T": float(T),
         "items": [{"h": hash, "logp": tensor}, ...],
         "index": {hash: position_in_items}}

    For duplicate samples the index points at the last occurrence.

    Args:
        save_dir: despite its name, the path of the single file given to ``torch.save``.
    """
    teacher.eval()
    records = []
    for pos, sample in enumerate(calib_data):
        ids = sample[0] if isinstance(sample, (list, tuple)) else sample
        out_logits = teacher(ids.to(device)).logits
        if T != 1.0:
            out_logits = out_logits / T
        log_probs = F.log_softmax(out_logits, dim=-1).to(torch.float16).cpu()
        records.append({"h": _hash_ids(ids), "logp": log_probs})
        done = pos + 1
        if done % 10 == 0:
            print(f"[cache-onefile] {done}/{len(calib_data)} collected")
    lookup = {rec["h"]: i for i, rec in enumerate(records)}
    torch.save({"type": "full_logp_single", "T": float(T), "items": records, "index": lookup}, save_dir)
    print(f"[cache-onefile] saved {len(records)} items to single file: {save_dir}")


def collect_trainable_codebooks(model: nn.Module) -> Tuple[List[nn.Parameter], List[VQRoundLinear]]:
    """Prepare every VQRoundLinear codebook in ``model`` for training.

    Each codebook's storage is cast to float32 in place (the Parameter object
    itself is kept) and marked as requiring gradients.

    Returns:
        ``(codebook_parameters, vq_modules)``, both in ``model.modules()`` order.
    """
    vq_modules = [mod for mod in model.modules() if isinstance(mod, VQRoundLinear)]
    for vq in vq_modules:
        vq.codebook.data = vq.codebook.data.float()
        vq.codebook.requires_grad_(True)
    return [vq.codebook for vq in vq_modules], vq_modules


def freeze_stochastic_but_keep_train(model):
    """Disable every dropout-like mechanism while leaving ``model`` in training mode.

    The function:

    * sets known dropout / layerdrop / drop-path fields on ``model.config`` to 0.0;
    * puts the whole model in training mode;
    * zeroes ``p`` / ``dropout_prob`` / ``drop_prob`` on dropout-style modules, and
      any float attribute named like a dropout rate on any module;
    * switches BatchNorm layers to eval mode, so they use running statistics;
    * disables ``config.use_cache`` when the model has a config.
    """
    cfg = getattr(model, "config", None)
    if cfg is not None:
        for k in [
            "dropout", "attention_dropout", "activation_dropout",
            "hidden_dropout", "hidden_dropout_prob",
            "attention_dropout_prob", "attention_probs_dropout_prob",
            "embd_pdrop", "resid_pdrop", "attn_pdrop",
            "classifier_dropout", "summary_first_dropout",
            "ffn_dropout", "ff_dropout", "dropout_rate",
            "layerdrop", "encoder_layerdrop", "decoder_layerdrop",
            "drop_path_rate",
        ]:
            if hasattr(cfg, k):
                setattr(cfg, k, 0.0)

    # train() recurses into every submodule, so it must run BEFORE the loop below;
    # calling it afterwards would switch the BatchNorm layers back to training mode.
    model.train(True)

    dropout_types = (
        nn.Dropout, nn.Dropout1d, nn.Dropout2d, nn.Dropout3d,
        nn.AlphaDropout, nn.FeatureAlphaDropout,
    )
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for m in model.modules():
        if isinstance(m, dropout_types):
            m.p = 0.0
        if m.__class__.__name__ in {"StableDropout", "T5Dropout"}:
            if hasattr(m, "dropout_prob"):
                m.dropout_prob = 0.0
            if hasattr(m, "p"):
                m.p = 0.0
        if m.__class__.__name__ in {"StochasticDepth", "DropPath"}:
            if hasattr(m, "p"):
                m.p = 0.0
            if hasattr(m, "drop_prob"):
                m.drop_prob = 0.0
        if isinstance(m, batchnorm_types):
            m.eval()
        # Some attention implementations keep their dropout rate as a plain float.
        for attr in ["attention_dropout", "attn_dropout", "resid_dropout",
                     "hidden_dropout", "dropout_p", "dropout"]:
            if hasattr(m, attr) and isinstance(getattr(m, attr), float):
                setattr(m, attr, 0.0)

    if hasattr(model, "config"):
        model.config.use_cache = False


def set_quant_mode(model: nn.Module, mode: str):
    """Set the rounding mode (``"soft"`` or ``"hard"``) on every VQRoundLinear in ``model``."""
    vq_layers = (mod for mod in model.modules() if isinstance(mod, VQRoundLinear))
    for vq in vq_layers:
        vq.set_quant_mode(mode)


def train_e2e_kd(
    student: nn.Module,
    teacher: Optional[nn.Module],
    train_loader: DataLoader,
    steps: int,
    lr: float,
    device: torch.device,
    kd_T: float = 2.0,
    kd_alpha: float = 1.0,
    use_round_reg: bool = True,
    beta_hi: float = 10.0,
    beta_lo: float = 2.0,
    beta_hold_ratio: float = 0.1,
    round_weight: float = 0.01,
    log_interval: int = 50,
    use_kd: bool = True,
    teacher_cache: Optional[str] = None,
    verify_cache: bool = True
):
    """Train only the VQ codebooks of ``student`` end to end.

    Task loss, in order of precedence:

    1. ``teacher_cache`` of type ``"full_logp_single"``: KL against the cached
       teacher log-probs at the cache's temperature. Entries are looked up by
       ``_hash_ids`` of the batch, falling back to ``step % len(items)``; with
       ``verify_cache`` a mismatching fallback entry trips an assertion. A
       ``"full_logp_memmap"`` cache passes the type check but is not consumed
       here, so the branches below are used instead.
    2. ``teacher`` given and ``use_kd``: on-the-fly KD at temperature ``kd_T``.
    3. Otherwise: the student's causal-LM loss with ``labels=input_ids``.

    The task loss is scaled by ``kd_alpha`` in the KD branches. If
    ``use_round_reg`` is set, ``round_weight`` times the rounding regularizer of
    every codebook is added. Its exponent ``beta`` is held at ``beta_hi`` for the
    first ``beta_hold_ratio`` of training, then decays linearly to ``beta_lo`` at
    ``steps``.

    Returns:
        The lowest task loss seen at a logging step (``inf`` if no step was logged).
    """
    student.to(device)
    for p in student.parameters():
        p.requires_grad_(False)
    train_params, mods = collect_trainable_codebooks(student)
    assert len(train_params) > 0, "No VQ codebooks to train!"

    for name, param in student.named_parameters():
        print(f"  [param] {name} | shape = {param.shape} | requires_grad = {param.requires_grad}")
    optimizer = torch.optim.Adam(train_params, lr=lr, betas=(0.9, 0.99))
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    rnd_loss = RoundLoss(steps, b_range=(beta_hi, beta_lo), decay_start=0.0, warmup=0.1, p_norm=2.0)

    total_cb = sum(m.codebook.numel() for m in mods)
    print(f"[train] VQRoundLinear modules: {len(mods)} | total codebook params: {total_cb}")

    cache_items, cache_index, T_cache = None, None, 2.0
    if teacher_cache is not None:
        cache_blob = torch.load(teacher_cache, map_location="cpu")
        assert cache_blob.get("type") in ("full_logp_single", "full_logp_memmap"), "wrong cache type"
        T_cache = float(cache_blob.get("T", 2.0))
        if cache_blob["type"] == "full_logp_single":
            cache_items = cache_blob["items"]
            cache_index = cache_blob.get("index", None)
        print(f"[cache] loaded T={T_cache}; items={len(cache_items) if cache_items is not None else 'memmap'}; index={'ok' if cache_index else 'missing'}")

    # Loop-invariant schedule / precision settings.
    hold_steps = int(steps * beta_hold_ratio)
    decay_steps = max(1, steps - hold_steps)
    amp_dtype = (torch.bfloat16 if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) else torch.float16)

    data_iter = iter(train_loader)
    step, best = 0, float("inf")

    while step < steps:
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            batch = next(data_iter)

        raw_ids = batch[0] if isinstance(batch, (list, tuple)) else batch
        input_ids = raw_ids.to(device, non_blocking=True)

        # Hold beta_hi, then decay linearly from the END of the hold phase so the
        # schedule is continuous and reaches beta_lo at `steps`.
        if step <= hold_steps:
            beta_now = beta_hi
        else:
            beta_now = beta_hi - (beta_hi - beta_lo) * ((step - hold_steps) / decay_steps)

        with torch.amp.autocast("cuda", dtype=amp_dtype):
            if cache_items is not None:
                s_logits = student(input_ids).logits

                h = _hash_ids(raw_ids)
                if cache_index is not None and h in cache_index:
                    item = cache_items[cache_index[h]]
                    t_logp_cpu = item["logp"]
                else:
                    if step == 0:
                        print("[cache] no index or hash miss; falling back to sequential pointer (results may differ).")
                    item = cache_items[step % len(cache_items)]
                    t_logp_cpu = item["logp"] if isinstance(item, dict) else item
                if verify_cache and isinstance(item, dict) and "h" in item:
                    assert item["h"] == h, f"[cache] misaligned at step={step}"

                # Align on the trailing positions in case sequence lengths differ.
                T_s, T_t = s_logits.size(1), t_logp_cpu.size(1)
                Tm = min(T_s, T_t)
                s_logp_T = F.log_softmax(s_logits[:, -Tm:, :] / T_cache, dim=-1)
                t_logp = t_logp_cpu[:, -Tm:, :].to(device, dtype=s_logp_T.dtype, non_blocking=True)
                task_loss = F.kl_div(s_logp_T, t_logp, log_target=True, reduction="batchmean") * (T_cache * T_cache)
                task_loss = task_loss * kd_alpha
            else:
                if teacher is not None and use_kd:
                    with torch.no_grad():
                        t_logits = teacher(input_ids).logits
                    s_logits = student(input_ids).logits
                    task_loss = _kd_loss_with_temperature(s_logits, t_logits, T=kd_T) * kd_alpha
                else:
                    outputs = student(input_ids, labels=input_ids)
                    task_loss = outputs.loss

        if use_round_reg:
            rr = 0.0
            for m in mods:
                rr = rr + _round_reg_checkpointed(
                    m=m, beta_now=beta_now, step=step, rnd_loss=rnd_loss, amp_dtype=amp_dtype
                )
            total_loss = task_loss + round_weight * rr
        else:
            rr = torch.tensor(0.0, device=device)
            total_loss = task_loss

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        step += 1
        if step % log_interval == 0:
            print(f"[e2e] step {step}/{steps} | beta={beta_now:.2f} | task={task_loss.item():.4f} | round={rr.item():.4f}", flush=True)
            best = min(best, float(task_loss.item()))
    return best



def harden_and_export(student: nn.Module, save_dir: str):
    """Replace each VQRoundLinear with an ``nn.Linear`` holding its hard-rounded weight, then save.

    The hard weight is
    ``scale * (clamp(base_int + (alpha >= 0) + zero, 0, 2**n_bits - 1) - zero)``,
    cast to the layer's original dtype; the bias is copied over. An assertion
    checks that no VQRoundLinear remains before ``student.save_pretrained(save_dir)``
    is called. NOTE(review): this presumes ``student`` is a Hugging Face model
    (it needs ``save_pretrained``).
    """
    def _resolve_parent(root: nn.Module, dotted: str):
        *path, leaf = dotted.split(".")
        node = root
        for part in path:
            if part.isdigit():
                node = node[int(part)]
            elif isinstance(node, (nn.ModuleDict, dict)) and part in node:
                node = node[part]
            else:
                node = getattr(node, part)
        return node, leaf

    def _attach(parent: nn.Module, leaf: str, module: nn.Module):
        if leaf.isdigit():
            parent[int(leaf)] = module
        elif isinstance(parent, (nn.ModuleDict, dict)) and leaf in parent:
            parent[leaf] = module
        else:
            setattr(parent, leaf, module)

    # Snapshot the targets first; the module tree is modified below.
    targets = [(qual, mod) for qual, mod in student.named_modules() if isinstance(mod, VQRoundLinear)]

    with torch.no_grad():
        for qual, vq in targets:
            r_hard = (vq._reconstruct_alpha() >= 0).float()
            top = 2 ** vq.n_bits - 1
            q_int = torch.clamp(vq.base_int.to(dtype=vq.scale.dtype) + r_hard + vq.zero, 0, top)
            w_hard = (vq.scale * (q_int - vq.zero)).to(vq.w_dtype)

            has_bias = vq.bias is not None
            dense = nn.Linear(vq.in_features, vq.out_features, bias=has_bias,
                              device=w_hard.device, dtype=w_hard.dtype)
            dense.weight.copy_(w_hard)
            if has_bias:
                dense.bias.copy_(vq.bias.data)

            parent, leaf = _resolve_parent(student, qual)
            _attach(parent, leaf, dense)

    remaining = sum(isinstance(mod, VQRoundLinear) for mod in student.modules())
    print(f"[export] VQRoundLinear remaining: {remaining}")
    assert remaining == 0, "Some VQRoundLinear remain after harden!"

    os.makedirs(save_dir, exist_ok=True)
    student.save_pretrained(save_dir)
    print(f"[export] Saved hardened model to {save_dir}")



def main():
    """Command-line entry point for VQ-rounding quantization of a causal LM.

    Steps, in order:

    1. Parse arguments, create ``--cache_dir`` and load calibration data.
    2. With ``--loss kd`` and a ``--teacher_model``: reuse or build the on-disk
       teacher cache (the teacher is only loaded when the cache must be built),
       or keep the teacher in memory when ``--no_build_teacher_cache`` is given.
    3. Load the student. With ``--gptq_only``: run GPTQ writing quantized
       weights, evaluate, save to the export directory and return.
    4. With ``--pre_gptq``: load cached GPTQ grids, or collect and cache them.
    5. Wrap Linear layers as ``VQRoundLinear`` and train with ``train_e2e_kd``.
    6. Evaluate in "soft" and "hard" quant modes, then harden and export.
    """
    import gc

    parser = argparse.ArgumentParser()
    parser.add_argument("--cache_dir", default="./cache", type=str)
    parser.add_argument("--model_path", type=str, required=True)
    parser.add_argument("--teacher_model", type=str, default=None)
    parser.add_argument("--calib_data", type=str, default="c4", choices=["c4", "wikitext2"])
    parser.add_argument("--seqlen", type=int, default=2048)
    parser.add_argument("--nsamples", type=int, default=128)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--steps", type=int, default=5000)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--seed", type=int, default=0)

    parser.add_argument("--gptq_only", action="store_true", help="GPTQ baseline")
    parser.add_argument("--pre_gptq", action="store_true",
                        help="collect (or load cached) GPTQ grids and use them when wrapping Linear layers")
    parser.add_argument("--gptq_blocksize", type=int, default=128)
    parser.add_argument("--gptq_percdamp", type=float, default=0.01)
    parser.add_argument("--gptq_groupsize", type=int, default=-1)
    parser.add_argument("--gptq_sym", action="store_true", default=False)
    parser.add_argument("--gptq_actorder", action="store_true", default=False)
    parser.add_argument("--gptq_static_groups", action="store_true", default=False)

    parser.add_argument("--w_bits", type=int, default=4)
    parser.add_argument("--D", type=int, default=8)
    parser.add_argument("--K", type=int, default=4096)
    parser.add_argument("--kmeans_iters", type=int, default=100)
    parser.add_argument("--skip_keywords", nargs="*", default=["lm_head"])

    parser.add_argument("--kd_temperature", type=float, default=2.0)
    parser.add_argument("--kd_alpha", type=float, default=1.0)
    parser.add_argument("--use_round_reg", action="store_true", default=True)
    # A store_true option that defaults to True can never be switched off on the
    # command line; this flag is its off switch (same dest).
    parser.add_argument("--no_use_round_reg", dest="use_round_reg", action="store_false",
                        help="set use_round_reg=False (it defaults to True)")
    parser.add_argument("--round_weight", type=float, default=0.01)
    parser.add_argument("--beta_hi", type=float, default=20.0)
    parser.add_argument("--beta_lo", type=float, default=2.0)
    parser.add_argument("--beta_hold_ratio", type=float, default=0.1)
    parser.add_argument("--loss", type=str, default="kd", choices=["kd", "lm"])

    parser.add_argument("--build_teacher_cache", action="store_true", default=True)
    parser.add_argument("--no_build_teacher_cache", dest="build_teacher_cache", action="store_false",
                        help="keep the teacher model in memory instead of building/using the on-disk teacher cache")
    parser.add_argument("--teacher_cache", type=str, default=None)

    parser.add_argument("--export_dir", type=str, default=None)

    args = parser.parse_args()
    Path(args.cache_dir).mkdir(parents=True, exist_ok=True)
    # Path(...).name also copes with a trailing "/" (split('/')[-1] would give "").
    args.model_name = Path(args.model_path).name
    args.model_type = args.model_name.split('-')[0]
    print(args)
    torch.manual_seed(args.seed)
    on_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if on_cuda else "cpu")

    def load_causal_lm(name_or_path: str):
        # flash-attention 2 kernels need a CUDA GPU (and fp16/bf16); on CPU let
        # transformers choose its default attention instead of failing to load.
        return AutoModelForCausalLM.from_pretrained(
            name_or_path,
            torch_dtype=torch.float16 if on_cuda else None,
            attn_implementation="flash_attention_2" if on_cuda else None,
        )

    # NOTE(review): the tokenizer is not used by the code below.
    tok_name = args.teacher_model if args.teacher_model else args.model_path
    tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=True)
    if tokenizer.eos_token_id is None:
        tokenizer.add_special_tokens({"eos_token": "</s>"})
    tokenizer.pad_token = tokenizer.eos_token

    calib_data = get_weight_quant_data(args)

    teacher = None
    teacher_cache_path = args.teacher_cache
    use_kd = (args.loss == "kd") and (args.teacher_model is not None)
    if use_kd:
        if args.build_teacher_cache:
            if teacher_cache_path is None:
                # Key the cache on the teacher itself: the student's model_type
                # prefix (e.g. "Llama") is shared by every size of a model family.
                teacher_name = Path(args.teacher_model).name
                teacher_cache_path = os.path.join(
                    args.cache_dir,
                    f"teacher_full_logp_{teacher_name}_{args.calib_data}_ns{args.nsamples}_L{args.seqlen}_T{args.kd_temperature}.pt"
                )
            if os.path.exists(teacher_cache_path):
                print(f"[cache-onefile] teacher_cache {teacher_cache_path} already exists, skip building.")
            else:
                print(f"[cache-onefile] building teacher_cache to {teacher_cache_path} ...")
                # Load the teacher only when the cache really has to be built.
                teacher = load_causal_lm(args.teacher_model).to(device).eval()
                teacher.requires_grad_(False)
                build_teacher_cache(calib_data, teacher, device=device, T=args.kd_temperature, save_dir=teacher_cache_path)
                teacher = None
                gc.collect()
                torch.cuda.empty_cache()
        else:
            # Keep a frozen teacher in memory and pass it to training below;
            # presumably train_e2e_kd uses it when no cache path is given — confirm.
            teacher = load_causal_lm(args.teacher_model).to(device).eval()
            teacher.requires_grad_(False)

    student = load_causal_lm(args.model_path)
    student.seqlen = args.seqlen

    # Settings shared by both GPTQ invocations below.
    gptq_kwargs = dict(
        calib_loader=calib_data,
        device=device,
        wbits=args.w_bits,
        groupsize=args.gptq_groupsize if args.gptq_groupsize > 0 else -1,
        sym=args.gptq_sym,
        actorder=args.gptq_actorder,
        percdamp=args.gptq_percdamp,
        static_groups=args.gptq_static_groups,
        blocksize=args.gptq_blocksize,
    )

    if args.gptq_only:
        gptq_sequential_collect_grid_opt(student, write_quantized_weights=True, **gptq_kwargs)
        evaluate(student, args)

        exp_dir = args.export_dir
        if not exp_dir or str(exp_dir).strip() == "":
            ts = datetime.now().strftime("%Y%m%d-%H%M%S")
            exp_dir = os.path.join(
                "exports",
                f"{args.model_name}"
                f"_gptq_only"
                f"_gptq_w{args.w_bits}"
                f"_g{args.gptq_groupsize if args.gptq_groupsize>0 else 'full'}"
                f"_b{args.gptq_blocksize}_d{args.gptq_percdamp}"
                f"_sym{'Y' if args.gptq_sym else 'N'}_act{'Y' if args.gptq_actorder else 'N'}"
                f"_ns{args.nsamples}_L{args.seqlen}"
                f"_{ts}"
            )
        os.makedirs(exp_dir, exist_ok=True)
        student.save_pretrained(exp_dir)
        print(f"[export] GPTQ-only model saved to: {exp_dir}")
        return

    # The collected grids depend on every GPTQ and calibration setting, so all of
    # them are part of the cache key; the file lives under --cache_dir, which was
    # created above.
    grids_path = os.path.join(
        args.cache_dir,
        f"gptq_grids_{args.model_name}_wbits_{args.w_bits}"
        f"_g{args.gptq_groupsize if args.gptq_groupsize > 0 else 'full'}"
        f"_b{args.gptq_blocksize}_d{args.gptq_percdamp}"
        f"_sym{'Y' if args.gptq_sym else 'N'}_act{'Y' if args.gptq_actorder else 'N'}"
        f"_static{'Y' if args.gptq_static_groups else 'N'}"
        f"_{args.calib_data}_ns{args.nsamples}_L{args.seqlen}_s{args.seed}.pt",
    )
    grids = {}
    if args.pre_gptq:
        if os.path.exists(grids_path):
            print(f"[gptq] loading pre-collected GPTQ grids from {grids_path}")
            grids = torch.load(grids_path, map_location="cpu")
        else:
            print("[gptq] collecting GPTQ grids ...")
            grids = gptq_sequential_collect_grid_opt(student, write_quantized_weights=False, **gptq_kwargs)
            torch.save(grids, grids_path)

    n_replaced = replace_linear_with_vqround_using_grids(
        student, grids, n_bits=args.w_bits, D=args.D, K=args.K,
        kmeans_iters=args.kmeans_iters, skip_keywords=args.skip_keywords
    )
    print(f"[wrap] replaced {n_replaced} Linear -> VQRoundLinear (GPTQ grid)")

    del grids
    gc.collect()
    torch.cuda.empty_cache()
    # Evaluation is best-effort: a failure is reported but does not stop training.
    try:
        evaluate(student, args)
    except Exception as e:
        print(f"[eval] failed: {e}")

    freeze_stochastic_but_keep_train(student)
    # KV caching is switched off for training with gradient checkpointing and
    # restored before export so the saved config keeps the original setting.
    prev_use_cache = student.config.use_cache
    student.config.use_cache = False
    student.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    t0 = time.perf_counter()
    train_e2e_kd(
        student=student,
        teacher=teacher,
        train_loader=calib_data,
        steps=args.steps,
        lr=args.lr,
        device=device,
        kd_T=args.kd_temperature,
        kd_alpha=args.kd_alpha,
        use_round_reg=args.use_round_reg,
        beta_hi=args.beta_hi,
        beta_lo=args.beta_lo,
        beta_hold_ratio=args.beta_hold_ratio,
        round_weight=args.round_weight,
        log_interval=50,
        use_kd=(args.loss == "kd"),
        teacher_cache=teacher_cache_path,
        verify_cache=True
    )
    t1 = time.perf_counter()
    print(f"[train] time cost: {t1-t0:.1f} seconds")

    student.gradient_checkpointing_disable()
    for mode in ("soft", "hard"):
        set_quant_mode(student, mode)
        try:
            evaluate(student, args)
        except Exception as e:
            print(f"[eval] failed: {e}")

    student.config.use_cache = prev_use_cache

    exp_dir = args.export_dir
    if not exp_dir or str(exp_dir).strip() == "":
        ts = datetime.now().strftime("%Y%m%d-%H%M%S")
        exp_dir = os.path.join(
            "exports",
            f"{args.model_name}"
            f"_w{args.w_bits}_D{args.D}_K{args.K}"
            f"_ns{args.nsamples}_L{args.seqlen}"
            f"_lr{args.lr}_bs{args.batch_size}"
            f"_T{args.kd_temperature}_steps{args.steps}"
            f"_gptqGrid_sym{'Y' if args.gptq_sym else 'N'}_g{args.gptq_groupsize if args.gptq_groupsize>0 else 'full'}"
            f"_act{'Y' if args.gptq_actorder else 'N'}"
            f"_{ts}"
        )
    print(f"[export] export_dir = {exp_dir}")
    harden_and_export(student, exp_dir)


# Script entry point: run the pipeline only when this file is executed
# directly, never when it is imported. main() returns None, so the process
# exits with status 0 on normal completion.
if __name__ == "__main__":
    raise SystemExit(main())
