import argparse, json, torch, torch.nn as nn, torch.nn.functional as F, math, gc, os, time, re, sys, csv
from time import perf_counter
from statistics import mean, median
from contextlib import contextmanager
from typing import Optional, Dict, Any, List, Set, Tuple
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset


def _ensure_parent_dir(path: str):
    """Create the parent directory of ``path`` if it does not exist yet.

    A path with no directory component (e.g. ``"out.json"``) is a no-op, as is
    a parent that already exists.

    Raises:
        RuntimeError: if the directory cannot be created. The underlying
            exception is chained as ``__cause__`` so the original traceback
            (permission error, path component that is a file, ...) is kept.
    """
    try:
        d = os.path.dirname(path)
        if d and not os.path.exists(d):
            # exist_ok=True tolerates a concurrent creator between the check
            # above and this call.
            os.makedirs(d, exist_ok=True)
    except Exception as e:
        raise RuntimeError(f"Failed to create directory for {path}: {e}") from e


# When True, AddSVDCorrection recomputes B@x in every wrapped module instead of
# using the per-group B@x cache (toggled by set_abx_policy).
_FORCE_ABX_ALWAYS = False
# When True, ABX matmuls are timed with CUDA events and the times are added to
# _METRICS (toggled by set_profile_abx).
_PROFILE_ABX = False


def set_abx_policy(force_always: bool):
    """Select the module-wide B@x caching policy.

    Args:
        force_always: True makes every SVD-corrected module recompute B@x
            (cache bypassed); False lets grouped modules reuse the cached
            intermediate.
    """
    global _FORCE_ABX_ALWAYS
    _FORCE_ABX_ALWAYS = force_always


def set_profile_abx(enabled: bool):
    """Turn CUDA-event timing of the ABX matmuls on or off module-wide.

    Args:
        enabled: True to time B@x / A@r calls and accumulate into _METRICS.
    """
    global _PROFILE_ABX
    _PROFILE_ABX = enabled


class _Metrics:
    """Mutable collection of counters and timers used by the ABX benchmarks.

    Construction simply delegates to :meth:`reset`, so a fresh instance and a
    reset instance hold exactly the same zeroed fields.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every timer/counter and start a new, empty decode-step list."""
        # --- accumulated times (milliseconds) ---
        self.abx_kernel_ms: float = 0.0  # B@x + A@r combined
        self.abx_bx_ms: float = 0.0
        self.abx_ax_ms: float = 0.0
        self.h2d_total_ms: float = 0.0
        self.prefill_ms: float = 0.0
        self.decode_total_ms: float = 0.0
        self.decode_step_ms: List[float] = []

        # --- call / event counters ---
        self.abx_kernel_calls: int = 0
        self.abx_bx_calls: int = 0
        self.abx_ax_calls: int = 0
        self.h2d_events: int = 0

        # --- B@x cache bookkeeping ---
        self.bx_compute_count: int = 0
        self.bx_consume_count: int = 0
        self.bx_cache_hits: int = 0
        self.bx_cache_misses: int = 0
        self.bx_module_calls: int = 0


# Process-wide metrics sink: the ABX timers, the B@x group cache, the H2D timer
# and the measurement helpers in this module all update this one instance.
_METRICS = _Metrics()


@contextmanager
def _cuda_timer_enabled():
    """Bracket the managed block with a pair of CUDA timing events.

    Yields ``None`` and does nothing else while ABX profiling is off.
    Otherwise yields ``(start, end)``:
      - before the block, the device is synchronized and ``start`` is recorded;
      - after the block, ``end`` is recorded and the device synchronized again.

    ``start.elapsed_time(end)`` is therefore valid after a normal exit. If the
    block raises, ``end`` is never recorded.
    """
    if _PROFILE_ABX:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        yield start, end
        end.record()
        torch.cuda.synchronize()
    else:
        yield None


def _elapsed_ms(events):
    """Return the milliseconds between a ``(start, end)`` event pair as a float.

    ``None`` (what the timer yields when profiling is disabled) maps to 0.0.
    """
    if events is not None:
        start, end = events
        return float(start.elapsed_time(end))
    return 0.0


@contextmanager
def abx_timer_accumulate(attr_name: str):
    """Time the managed block and add the result to the ABX metrics.

    Args:
        attr_name: ``"abx_bx_ms"`` or ``"abx_ax_ms"``. The elapsed time is
            added to that field and to ``abx_kernel_ms``. Any other name is
            timed but nothing is recorded.

    Time is recorded only when profiling is enabled and the block completes
    normally. If the block raises, its end event was never recorded; querying
    it would raise a new error from inside cleanup and hide the original
    exception. Callers such as AddSVDCorrection inspect that original message,
    so it must propagate unchanged.
    """
    with _cuda_timer_enabled() as events:
        yield
    # Reached only after a normal exit from the managed block.
    if not _PROFILE_ABX:
        return
    ms = _elapsed_ms(events)
    if attr_name == "abx_bx_ms":
        _METRICS.abx_bx_ms += ms
    elif attr_name == "abx_ax_ms":
        _METRICS.abx_ax_ms += ms
    else:
        return
    _METRICS.abx_kernel_ms += ms


@contextmanager
def measure_h2d_time():
    """Time the managed block (meant for host-to-device copies) on the GPU.

    The device is synchronized around the block. The CUDA-event time is added
    to ``_METRICS.h2d_total_ms`` and ``_METRICS.h2d_events`` is incremented.
    Requires CUDA. Nothing is recorded if the block raises.
    """
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    yield
    end.record()
    torch.cuda.synchronize()
    _METRICS.h2d_total_ms += float(start.elapsed_time(end))
    _METRICS.h2d_events += 1


# Triton is skipped when the CUDA W4A16 backend is selected, either by the
# --use_cuda_w4a16 flag or by the DISABLE_TRITON / USE_CUDA_W4A16 env vars.
# The flag is read from raw sys.argv because this runs at import time, before
# argparse in main().
_argv_disable = any(arg == "--use_cuda_w4a16" for arg in sys.argv)
_DISABLE_TRITON = (
    _argv_disable
    or os.environ.get("DISABLE_TRITON", "").lower() in ("1", "true", "yes")
    or os.environ.get("USE_CUDA_W4A16", "").lower() in ("1", "true", "yes")
)
# HAS_TRITON gates whether the Triton kernel, quant_linear,
# TritonTrue4BitLinear and convert_to_triton_4bit get defined at all.
try:
    if not _DISABLE_TRITON:
        import triton
        import triton.language as tl

        HAS_TRITON = True
    else:
        HAS_TRITON = False
except Exception:
    # Only a failed import reaches this handler; deliberately disabling Triton
    # takes the else-branch above and prints nothing.
    HAS_TRITON = False
    print(
        "Triton is not available or disabled. If needed, install with 'pip install triton'."
    )

if HAS_TRITON:

    @triton.jit
    def quant_linear_kernel(
        x_ptr,
        qweight_ptr,
        qzeros_ptr,
        scales_ptr,
        bias_ptr,
        output_ptr,
        M,
        N,
        K,
        stride_xm,
        stride_xk,
        stride_qwm,
        stride_qwk,
        stride_qzm,
        stride_qzk,
        stride_sm,
        stride_sk,
        stride_om,
        stride_on,
        group_size: tl.constexpr,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        HAS_ZEROS: tl.constexpr,
        HAS_BIAS: tl.constexpr,
    ):
        """Tiled GEMM ``output = x @ dequant(qweight)^T (+ bias)`` for 4-bit weights.

        Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of the
        M x N output and loops over K in BLOCK_SIZE_K steps. For each step:
          - the weight nibble comes from packed uint8 ``qweight[n, k // 2]``,
            with even k in the low nibble;
          - it is dequantized as ``(nibble - zero) * scale``, using
            ``scales[n, k // group_size]``;
          - the zero point is unpacked from ``qzeros[n, group // 2]``, with
            even groups in the low nibble;
          - the products are accumulated in float32.

        Assumes BLOCK_SIZE_K is even (quant_linear passes 32) so the packed
        qweight pointer advances by whole bytes.
        """
        # Tiles are enumerated row-major: consecutive pids walk along N first.
        # (num_pid_m is computed but not used below.)
        pid = tl.program_id(axis=0)
        num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
        num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
        pid_m = pid // num_pid_n
        pid_n = pid % num_pid_n
        offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        offs_k = tl.arange(0, BLOCK_SIZE_K)
        x_ptrs = x_ptr + (offs_am[:, None] * stride_xm + offs_k[None, :] * stride_xk)
        # (BLOCK_SIZE_K, BLOCK_SIZE_N) pointer grid into the packed weights;
        # two consecutive k share one byte.
        qweight_ptrs = qweight_ptr + (
            offs_bn[None, :] * stride_qwm + (offs_k[:, None] // 2) * stride_qwk
        )
        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
            k_start = k * BLOCK_SIZE_K
            k_offs = k_start + offs_k
            x = tl.load(
                x_ptrs, mask=(offs_am[:, None] < M) & (k_offs[None, :] < K), other=0.0
            )
            packed_weights = tl.load(
                qweight_ptrs,
                mask=(k_offs[:, None] < K) & (offs_bn[None, :] < N),
                other=0,
            )
            # Even k -> low nibble, odd k -> high nibble.
            is_low_nibble = k_offs % 2 == 0
            nibbles = tl.where(
                is_low_nibble[:, None], packed_weights & 0x0F, packed_weights >> 4
            )
            group_id = k_offs // group_size
            scales_ptrs = scales_ptr + (
                offs_bn[None, :] * stride_sm + group_id[:, None] * stride_sk
            )
            scales = tl.load(
                scales_ptrs,
                mask=(k_offs[:, None] < K) & (offs_bn[None, :] < N),
                other=0.0,
            )
            if HAS_ZEROS:
                # Zero points are packed two groups per byte (even group low).
                zeros_group_id = group_id // 2
                qzeros_ptrs = qzeros_ptr + (
                    offs_bn[None, :] * stride_qzm + zeros_group_id[:, None] * stride_qzk
                )
                packed_zeros = tl.load(
                    qzeros_ptrs,
                    mask=(k_offs[:, None] < K) & (offs_bn[None, :] < N),
                    other=0,
                )
                is_low_zero_nibble = group_id % 2 == 0
                zeros = tl.where(
                    is_low_zero_nibble[:, None], packed_zeros & 0x0F, packed_zeros >> 4
                )
            else:
                # NOTE(review): fixed symmetric zero point of 8. `zeros.to(...)`
                # below is then applied to this plain int; confirm Triton accepts
                # that before relying on this path. quant_linear only takes it
                # when qzeros is None.
                zeros = 8
            dequant_weights = (
                nibbles.to(tl.float32) - zeros.to(tl.float32)
            ) * scales.to(tl.float32)
            accumulator += tl.dot(x.to(tl.float32), dequant_weights)
            # Advance along K: x by BLOCK_SIZE_K elements, qweight by half as
            # many packed bytes.
            x_ptrs += BLOCK_SIZE_K * stride_xk
            qweight_ptrs += (BLOCK_SIZE_K // 2) * stride_qwk
        if HAS_BIAS:
            bias = tl.load(bias_ptr + offs_bn, mask=offs_bn < N, other=0.0)
            accumulator = accumulator + bias[None, :]
        # Cast the fp32 accumulator to the output tensor's element type.
        c = accumulator.to(output_ptr.dtype.element_ty)
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        output_ptrs = (
            output_ptr + stride_om * offs_cm[:, None] + stride_on * offs_cn[None, :]
        )
        tl.store(output_ptrs, c, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))

    def quant_linear(x, qweight, qzeros, scales, bias, group_size):
        """Compute ``x @ W^T (+ bias)`` with 4-bit packed weights via Triton.

        Args:
            x: activations of shape ``(..., K)``. All leading dims (including
                none, for a 1-D input) are flattened into the GEMM's M dim.
            qweight: packed uint8 weights indexed ``[n, k // 2]``.
            qzeros: packed uint8 zero points indexed ``[n, group // 2]``, or
                None to use the kernel's fixed zero point.
            scales: per-(n, group) scales; ``scales.shape[0]`` defines N.
            bias: optional length-N bias, or None.
            group_size: number of input columns per quantization group.

        Returns:
            Tensor of shape ``(..., N)`` with ``x``'s dtype and device.
        """
        original_shape = x.shape
        K = original_shape[-1]
        # Flatten every leading dim into M; math.prod(()) == 1 covers 1-D x.
        M = math.prod(original_shape[:-1])
        N = scales.shape[0]
        x = x.reshape(M, K)
        output = torch.empty((M, N), device=x.device, dtype=x.dtype)
        grid = lambda META: (
            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )
        stride_qz0, stride_qz1 = (
            (qzeros.stride(0), qzeros.stride(1)) if qzeros is not None else (0, 0)
        )
        quant_linear_kernel[grid](
            x,
            qweight,
            qzeros,
            scales,
            bias,
            output,
            M,
            N,
            K,
            x.stride(0),
            x.stride(1),
            qweight.stride(0),
            qweight.stride(1),
            stride_qz0,
            stride_qz1,
            scales.stride(0),
            scales.stride(1),
            output.stride(0),
            output.stride(1),
            group_size=group_size,
            BLOCK_SIZE_M=64,
            BLOCK_SIZE_N=64,
            BLOCK_SIZE_K=32,
            HAS_ZEROS=(qzeros is not None),
            HAS_BIAS=(bias is not None),
            num_warps=4,
            num_stages=3,
        )
        return output.reshape(*original_shape[:-1], N)

    class TritonTrue4BitLinear(nn.Module):
        """4-bit weight-only linear layer whose forward runs ``quant_linear``.

        Weights are quantized per output row in groups of ``group_size`` input
        columns with an asymmetric scheme ``w ~= (q - z) * s``, where ``q`` and
        ``z`` are 4-bit integers. Buffers:

        * ``qweight`` (uint8, ``[out, ceil(in / 2)]``): two codes per byte,
          even input index in the low nibble.
        * ``qzeros`` (uint8, ``[out, ceil(n_groups / 2)]``): two zero points
          per byte, even group in the low nibble.
        * ``scales`` (float, ``[out, n_groups]``): one scale per (row, group).
        """

        def __init__(self, in_features, out_features, group_size=128, bias=False):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.group_size = group_size
            n_groups = math.ceil(in_features / group_size)
            # Round up so an odd in_features / odd group count still has room
            # for the final nibble; equals the previous floor for even sizes.
            self.register_buffer(
                "qweight",
                torch.empty((out_features, (in_features + 1) // 2), dtype=torch.uint8),
            )
            self.register_buffer(
                "qzeros",
                torch.empty((out_features, (n_groups + 1) // 2), dtype=torch.uint8),
            )
            self.register_buffer(
                "scales",
                torch.empty((out_features, n_groups), dtype=torch.float16),
            )
            self.bias = (
                nn.Parameter(torch.empty(out_features, dtype=torch.float16))
                if bias
                else None
            )

        def forward(self, x):
            return quant_linear(
                x, self.qweight, self.qzeros, self.scales, self.bias, self.group_size
            )

        @staticmethod
        def _pack_nibbles(codes):
            """Pack uint8 codes in [0, 15] pairwise along the last dim.

            Even positions go to the low nibble and odd positions to the high
            nibble. An odd-length last dim is first padded with one zero code.
            """
            if codes.shape[-1] % 2:
                pad = codes.new_zeros(codes.shape[:-1] + (1,))
                codes = torch.cat([codes, pad], dim=-1)
            return (codes[..., 1::2] << 4) | codes[..., 0::2]

        @classmethod
        def from_float(cls, linear_layer, group_size):
            """Quantize an ``nn.Linear`` into a new ``TritonTrue4BitLinear``.

            The input dim is zero-padded up to a multiple of ``group_size``
            for quantization; packed columns beyond ``in_features`` are dropped.
            The returned layer lives on the source weight's device, and its
            floating-point tensors use the source weight's dtype.
            """
            qlayer = cls(
                linear_layer.in_features,
                linear_layer.out_features,
                group_size,
                linear_layer.bias is not None,
            ).to(linear_layer.weight.device, dtype=linear_layer.weight.dtype)
            # Quantize in fp32: in fp16 the 1e-8 scale floor underflows to 0
            # and an all-zero group would produce 0/0 = NaN codes.
            W = linear_layer.weight.detach().to(torch.float32)
            O, I = W.shape
            if I % group_size != 0:
                W = F.pad(W, (0, group_size - (I % group_size)))
            I_p = W.shape[1]
            Wg = W.reshape(O, I_p // group_size, group_size)
            # Force 0 into each group's range so the zero point lands in
            # [0, 15]. Without this, an all-positive (or all-negative) group
            # gets a zero point < 0 (or > 15) that wraps when cast to uint8.
            minv = Wg.amin(dim=-1).clamp(max=0.0)
            maxv = Wg.amax(dim=-1).clamp(min=0.0)
            scales = ((maxv - minv) / 15.0).clamp(min=1e-8)
            zeros_f = (-minv / scales).round().clamp(0, 15)
            qv = (
                torch.round(Wg / scales.unsqueeze(-1) + zeros_f.unsqueeze(-1))
                .clamp(0, 15)
                .to(torch.uint8)
            )
            packed_w = cls._pack_nibbles(qv.reshape(O, I_p))
            qlayer.qweight.data.copy_(packed_w[:, : qlayer.qweight.shape[1]])
            qlayer.scales.data.copy_(scales.to(torch.float16))
            packed_z = cls._pack_nibbles(zeros_f.to(torch.uint8))
            qlayer.qzeros.data.copy_(packed_z[:, : qlayer.qzeros.shape[1]])
            if linear_layer.bias is not None:
                qlayer.bias.data.copy_(linear_layer.bias.data)
            return qlayer

    def convert_to_triton_4bit(model, group_size=128):
        """Swap targeted ``nn.Linear`` submodules for ``TritonTrue4BitLinear``.

        A Linear is targeted when its dotted module name ends with one of the
        projection names below. The model is modified in place and returned.
        """
        target_suffixes = (
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "out_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
            "fc1",
            "fc2",
        )
        # Resolve every target name first, then replace, so the module tree is
        # not modified while named_modules() is still walking it.
        names = [
            name
            for name, module in model.named_modules()
            if isinstance(module, nn.Linear) and name.endswith(target_suffixes)
        ]
        for name in names:
            parent, attr_name = get_parent_module(model, name)
            linear = getattr(parent, attr_name)
            setattr(
                parent, attr_name, TritonTrue4BitLinear.from_float(linear, group_size)
            )
        gc.collect()
        torch.cuda.empty_cache()
        return model


def _cuda_sync(device):
    """Wait for queued work on ``device`` when it is a CUDA ``torch.device``.

    Anything else, including plain device strings such as ``"cuda:0"``, is
    ignored.
    """
    is_cuda_device = isinstance(device, torch.device) and device.type == "cuda"
    if not is_cuda_device:
        return
    torch.cuda.synchronize(device)


@contextmanager
def temp_generation_overrides(model, **overrides):
    """Temporarily set attributes on ``model.generation_config``.

    Each keyword in ``overrides`` is set for the duration of the ``with``
    block. The previous state is restored afterwards, even on error:
      - attributes that existed get their old value back;
      - attributes that did not exist are removed again.

    Setting or restoring an individual attribute is best effort: a failure
    there is ignored, but only ``Exception`` is swallowed, so
    KeyboardInterrupt still propagates. A model without a
    ``generation_config`` is left untouched.
    """
    gen_cfg = getattr(model, "generation_config", None)
    if gen_cfg is None:
        yield
        return
    missing = object()
    saved = {k: getattr(gen_cfg, k, missing) for k in overrides}
    try:
        for k, v in overrides.items():
            try:
                setattr(gen_cfg, k, v)
            except Exception:
                pass
        yield
    finally:
        for k, old in saved.items():
            try:
                if old is missing:
                    delattr(gen_cfg, k)
                else:
                    setattr(gen_cfg, k, old)
            except Exception:
                pass


def _get_sequences_from_generate(output):
    """Return the token sequences from a ``generate()`` result.

    Objects exposing ``.sequences`` (the dict-style output requested at the
    call sites) yield that attribute; anything else is returned unchanged.
    """
    try:
        return output.sequences
    except AttributeError:
        return output


def get_parent_module(model, name):
    """Split a dotted submodule path into ``(parent_object, leaf_name)``.

    ``get_parent_module(m, "a.b.c")`` returns ``(m.a.b, "c")``. A name with no
    dots returns ``(model, name)``. A missing intermediate attribute raises
    ``AttributeError``.
    """
    *path, leaf = name.split(".")
    node = model
    for attr in path:
        node = getattr(node, attr)
    return node, leaf


def clear_group_cache():
    """Hook called before each measured generation / evaluation chunk.

    It currently does nothing and always returns None.
    """
    return None


class MiniGroupCache:
    """Single-slot cache that hands one tensor out a fixed number of times.

    The producer stores a tensor with ``set(r, uses)``. Each ``consume()``
    returns ``(r, True)`` and decrements the remaining uses, dropping the
    tensor once they reach zero. Otherwise it returns ``(None, False)``.
    Every consume is counted in ``_METRICS`` as a hit or a miss.
    """

    __slots__ = ("r", "valid", "uses_left")

    def __init__(self):
        self.r: Optional[torch.Tensor] = None
        self.valid: bool = False
        self.uses_left: int = 0

    def set(self, r: torch.Tensor, uses: int):
        """Store ``r`` for the next ``uses`` consumers."""
        self.r, self.valid, self.uses_left = r, True, uses

    def consume(self):
        """Take the cached tensor if available; return ``(tensor, hit)``."""
        _METRICS.bx_consume_count += 1
        if not (self.valid and self.uses_left > 0):
            _METRICS.bx_cache_misses += 1
            return None, False
        self.uses_left -= 1
        cached = self.r
        if self.uses_left == 0:
            # Last consumer: release the tensor.
            self.valid = False
            self.r = None
        _METRICS.bx_cache_hits += 1
        return cached, True

    def clear(self):
        """Drop any cached tensor and reset the use count."""
        self.r, self.valid, self.uses_left = None, False, 0


class AddSVDCorrection(nn.Module):
    """Wrap a linear layer and add a low-rank correction to its output.

    ``forward(x)`` returns ``inner(x) + alpha_svd * F.linear(F.linear(x, B_q), A_q)``.
    The correction is added in place to the inner layer's output.

    Grouped modules (``is_group``) can share the intermediate
    ``r = F.linear(x, B_q)`` through ``group_cache``:
      - role "q" computes ``r`` and leaves it for two consumers;
      - role "gate" computes ``r`` and leaves it for one consumer;
      - roles "k", "v" and "up" consume it, recomputing on a miss.

    When ``_FORCE_ABX_ALWAYS`` is set, the cache is bypassed and ``r`` is
    always recomputed. A RuntimeError whose message looks like a shape
    mismatch makes forward return the uncorrected inner output.

    Args:
        inner: the wrapped module; its output receives the correction.
        A_q, B_q: low-rank factors, stored as non-persistent fp16 buffers.
        role: one of "q", "k", "v", "gate", "up" or "solo".
        is_group: whether this module belongs to a shared-B group.
        group_cache: shared cache for the group, or None.
        alpha_svd: correction scale; 0.0 skips the correction entirely.
    """

    # Roles counted in the B@x reuse / FLOPs-saving statistics.
    _TRACKED_ROLES = ("q", "k", "v", "gate", "up")

    def __init__(
        self,
        inner: nn.Module,
        A_q: torch.Tensor,
        B_q: torch.Tensor,
        role: str,
        is_group: bool,
        group_cache: Optional[MiniGroupCache],
        alpha_svd: float = 1.0,
    ):
        super().__init__()
        self.inner = inner
        self.register_buffer("A_q", A_q.to(torch.float16), persistent=False)
        self.register_buffer("B_q", B_q.to(torch.float16), persistent=False)
        self.role = role
        self.is_group = is_group
        self.group_cache = group_cache
        self.alpha_svd = alpha_svd

    def _compute_bx(self, x: torch.Tensor, count_compute: bool) -> torch.Tensor:
        """Return ``F.linear(x, B_q)`` (timed) and update the B@x counters."""
        with abx_timer_accumulate("abx_bx_ms"):
            r = F.linear(x, self.B_q)
        _METRICS.abx_bx_calls += 1
        if count_compute:
            _METRICS.bx_compute_count += 1
        return r

    def _compute_ax(self, r: torch.Tensor) -> torch.Tensor:
        """Return ``F.linear(r, A_q)`` (timed) and update the A@r counter."""
        with abx_timer_accumulate("abx_ax_ms"):
            out = F.linear(r, self.A_q)
        _METRICS.abx_ax_calls += 1
        return out

    def _shared_or_computed_bx(self, x: torch.Tensor, count_compute: bool):
        """Get ``r`` from the group cache when this role consumes it, else compute it."""
        cache = self.group_cache
        if self.is_group and self.role in ("k", "v", "up"):
            if cache is not None:
                r, ok = cache.consume()
                if ok and r is not None:
                    return r
            return self._compute_bx(x, count_compute)
        r = self._compute_bx(x, count_compute)
        if self.is_group and self.role in ("q", "gate") and cache is not None:
            # "q" feeds two consumers (k, v); "gate" feeds one (up).
            cache.set(r, 2 if self.role == "q" else 1)
        return r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.inner(x)
        if self.alpha_svd == 0.0:
            return z
        tracked = self.is_group and self.role in self._TRACKED_ROLES
        try:
            if tracked:
                _METRICS.bx_module_calls += 1

            if _FORCE_ABX_ALWAYS:
                # Recompute-always baseline. Only tracked roles count toward
                # bx_compute_count so it stays comparable with bx_module_calls;
                # counting solo modules made the FLOPs-saving metric negative.
                r = self._compute_bx(x, count_compute=tracked)
                return z.add_(self._compute_ax(r), alpha=self.alpha_svd)

            r = self._shared_or_computed_bx(x, count_compute=tracked)
            svd_raw = self._compute_ax(r)
            if z.shape != svd_raw.shape:
                svd_raw = svd_raw.reshape(z.shape)
            return z.add_(svd_raw, alpha=self.alpha_svd)
        except RuntimeError as e:
            # Shape mismatches between correction and inner output are
            # tolerated by returning the uncorrected output.
            if "size of tensor" in str(e) or "invalid for input of size" in str(e):
                return z
            raise


def _role_from_suffix(sfx: str) -> str:
    """Map a projection-module name to its ABX cache role.

    Returns "q", "k", "v", "gate" or "up" when ``sfx`` ends with the matching
    ``*_proj`` name (checked in that order), and "solo" otherwise.
    """
    suffix_roles = (
        ("q_proj", "q"),
        ("k_proj", "k"),
        ("v_proj", "v"),
        ("gate_proj", "gate"),
        ("up_proj", "up"),
    )
    for suffix, role in suffix_roles:
        if sfx.endswith(suffix):
            return role
    return "solo"


def _collect_unit_order(bmap: Dict[str, str]) -> List[str]:
    """List the distinct SVD units referenced by ``bmap``, in first-seen order.

    A unit name is a B key with its ``.B_shared`` or ``.B`` suffix removed.
    Keys without either suffix are ignored.
    """
    units: Dict[str, None] = {}
    for bkey in bmap.values():
        for marker in (".B_shared", ".B"):
            if bkey.endswith(marker):
                units.setdefault(bkey[: -len(marker)], None)
                break
    return list(units)


def _resolve_quant_linear_types() -> Tuple[type, ...]:
    """Return the quantized linear classes eligible for an SVD wrapper.

    Includes ``CudaW4A16Linear`` when ``cuda_w4a16`` can be imported, and
    ``TritonTrue4BitLinear`` when it is defined. Falls back to
    ``(nn.Module,)``, which accepts any module, when neither is available.
    """
    types_list = []
    try:
        from cuda_w4a16.linear import CudaW4A16Linear

        types_list.append(CudaW4A16Linear)
    except Exception:
        pass
    try:
        types_list.append(TritonTrue4BitLinear)
    except Exception:
        # NameError: the class is only defined when HAS_TRITON is true.
        pass
    return tuple(types_list) if types_list else (nn.Module,)


def _patch_core(
    model, shared, bmap, alpha_svd: float, allowed_units: Optional[Set[str]] = None
) -> Tuple[int, int]:
    """Wrap the quantized linears named in ``bmap`` with ``AddSVDCorrection``.

    Args:
        model: model whose submodules are replaced in place.
        shared: mapping from factor keys to tensors.
            - Group units use ``"<unit>.B_shared"`` and ``"<unit>.<suffix>.A"``.
            - Solo units use ``"<unit>.B"`` and ``"<unit>.A"``.
        bmap: mapping ``"<module path>.weight" -> B key``. The B key's suffix
            (``.B_shared`` vs ``.B``) selects group vs solo.
        alpha_svd: scale applied to every correction.
        allowed_units: if given, only these units are patched; the rest are
            passed over without being counted as skipped.

    Returns:
        ``(patched_count, skipped_count)``. A summary line is also printed.
    """
    patched_count, skipped_count = 0, 0
    gkey2cache: Dict[str, MiniGroupCache] = {}
    # Resolved once, on first use. Resolving per module re-attempted the
    # cuda_w4a16 import every time, including repeated failing imports.
    valid_types: Optional[Tuple[type, ...]] = None
    for weight_name, bkey in tqdm(bmap.items(), desc="Patching SVD Correction"):
        module_name = weight_name.replace(".weight", "")
        is_group = bkey.endswith(".B_shared")

        if is_group:
            unit_key = bkey[: -len(".B_shared")]
        elif bkey.endswith(".B"):
            unit_key = bkey[: -len(".B")]
        else:
            skipped_count += 1
            continue

        if (allowed_units is not None) and (unit_key not in allowed_units):
            continue

        if is_group:
            module_suffix = module_name.split(".")[-1]
            a_key = f"{unit_key}.{module_suffix}.A"
            B_q = shared.get(f"{unit_key}.B_shared")
            role = _role_from_suffix(module_suffix)
            # All members of a group share one cache instance.
            cache = gkey2cache.setdefault(unit_key, MiniGroupCache())
        else:
            a_key = f"{unit_key}.A"
            B_q = shared.get(f"{unit_key}.B")
            role = "solo"
            cache = None

        A_q = shared.get(a_key)
        if A_q is None or B_q is None:
            skipped_count += 1
            continue

        try:
            parent, attr_name = get_parent_module(model, module_name)
            inner = getattr(parent, attr_name)
            if valid_types is None:
                valid_types = _resolve_quant_linear_types()
            if not isinstance(inner, valid_types):
                skipped_count += 1
                continue
            setattr(
                parent,
                attr_name,
                AddSVDCorrection(
                    inner, A_q, B_q, role, is_group, cache, alpha_svd=alpha_svd
                ),
            )
            patched_count += 1
        except AttributeError as e:
            print(f"AttributeError for {module_name}: {e}")
            skipped_count += 1
            continue

    print(
        f"SVD Correction Patching Summary: {patched_count} patched, {skipped_count} skipped"
    )
    return patched_count, skipped_count


def patch_svd_correction_wrappers(model, shared, bmap, alpha_svd=1.0):
    """Wrap every eligible quantized linear in ``model`` with its SVD correction.

    Delegates to ``_patch_core`` without a unit filter; its counts are printed
    there and discarded here. Returns ``model`` itself, modified in place.
    """
    _patched, _skipped = _patch_core(
        model, shared, bmap, alpha_svd=alpha_svd, allowed_units=None
    )
    return model


def patch_svd_correction_wrappers_partial(
    model, shared, bmap, restore_ratio: float = 0.5, alpha_svd: float = 1.0
):
    """Patch only the first ``floor(n_units * restore_ratio)`` SVD units.

    Units are ordered by first appearance in ``bmap``. A ratio yielding zero
    units patches nothing. Returns ``model`` itself, modified in place.
    """
    ordered_units = _collect_unit_order(bmap)
    total_units = len(ordered_units)
    n_keep = int(math.floor(total_units * restore_ratio))
    if n_keep > 0:
        allowed = set(ordered_units[:n_keep])
    else:
        allowed = set()
    print(
        f"[Partial Restore] units total={total_units}, restore_ratio={restore_ratio:.2f}, restored={len(allowed)}"
    )
    _patch_core(model, shared, bmap, alpha_svd=alpha_svd, allowed_units=allowed)
    return model


def _past_key_values_of(out):
    """Return ``past_key_values`` from a model output (attribute or mapping)."""
    if hasattr(out, "past_key_values"):
        return out.past_key_values
    return out.get("past_key_values", None)


def _timed_cuda_forward(model, **kwargs):
    """Run ``model(**kwargs)`` between device syncs; return ``(output, gpu_ms)``."""
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(True), torch.cuda.Event(True)
    start.record()
    out = model(**kwargs)
    end.record()
    torch.cuda.synchronize()
    return out, float(start.elapsed_time(end))


@torch.no_grad()
def measure_generation_metrics_precise(
    model, tokenizer, device, prompts: List[str], max_new_tokens=50
) -> Dict[str, Any]:
    """Greedy manual-decode latency benchmark with per-step CUDA-event timing.

    For each prompt:
      - one prefill forward is timed;
      - ``max_new_tokens`` single-token decode steps follow, starting from
        the token predicted by prefill and reusing the KV cache.

    Returned latency figures are averaged over prompts. The ABX shares and
    ``abx_ms_per_token`` use only ABX time spent during this call's decode
    steps, divided by this call's decode wall time / total decode steps.
    Cache statistics and the ``*_ms_total`` fields come from the cumulative
    ``_METRICS``. Requires CUDA. Returns ``{}`` when ``prompts`` is empty.
    """
    model.eval()
    results_all = []
    # Decode-phase totals for this call only; _METRICS is cumulative and its
    # ABX timers also include work done during prefill.
    decode_wall_ms = 0.0
    decode_bx_ms = 0.0
    decode_ax_ms = 0.0
    decode_steps = 0
    for prompt in prompts:
        _METRICS.decode_step_ms.clear()
        encoded = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=512
        )
        input_ids = encoded["input_ids"]
        attn = encoded.get("attention_mask", None)
        with measure_h2d_time():
            input_ids = input_ids.to(device)
            attn = attn.to(device) if attn is not None else None

        out, prefill_ms = _timed_cuda_forward(
            model, input_ids=input_ids, attention_mask=attn, use_cache=True
        )
        _METRICS.prefill_ms += prefill_ms
        past = _past_key_values_of(out)
        # The prompt's last token is already in the KV cache after prefill, so
        # decoding starts from the token prefill predicts instead of re-feeding it.
        last_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)

        bx_before, ax_before = _METRICS.abx_bx_ms, _METRICS.abx_ax_ms
        decode_start = perf_counter()
        step_times_ms = []
        for _ in range(max_new_tokens):
            out, step_ms = _timed_cuda_forward(
                model, input_ids=last_token, use_cache=True, past_key_values=past
            )
            step_times_ms.append(step_ms)
            _METRICS.decode_step_ms.append(step_ms)
            past = _past_key_values_of(out)
            last_token = torch.argmax(out.logits[:, -1, :], dim=-1, keepdim=True)
        prompt_decode_ms = (perf_counter() - decode_start) * 1000.0
        _METRICS.decode_total_ms += prompt_decode_ms
        decode_wall_ms += prompt_decode_ms
        decode_bx_ms += _METRICS.abx_bx_ms - bx_before
        decode_ax_ms += _METRICS.abx_ax_ms - ax_before
        decode_steps += len(step_times_ms)

        if step_times_ms:
            t = torch.tensor(step_times_ms)
            p50 = float(torch.quantile(t, 0.5).item())
            p95 = float(torch.quantile(t, 0.95).item())
        else:
            # max_new_tokens == 0: quantile of an empty tensor would raise.
            p50 = p95 = 0.0
        results_all.append(
            {
                "prefill_ms": prefill_ms,
                "decode_ms_token_p50": p50,
                "decode_ms_token_p95": p95,
            }
        )
    if not results_all:
        return {}
    prefill_ms = mean([x["prefill_ms"] for x in results_all])
    p50 = mean([x["decode_ms_token_p50"] for x in results_all])
    p95 = mean([x["decode_ms_token_p95"] for x in results_all])
    jitter_pct = ((p95 - p50) / max(1e-6, p50)) * 100.0
    bx_module_calls = max(1, _METRICS.bx_module_calls)
    reuse_rate_pct = (
        _METRICS.bx_cache_hits / max(1, _METRICS.bx_consume_count)
    ) * 100.0
    flops_saving_pct = (
        (bx_module_calls - _METRICS.bx_compute_count) / bx_module_calls
    ) * 100.0
    decode_abx_ms = decode_bx_ms + decode_ax_ms
    decode_denom = max(1e-6, decode_wall_ms)
    abx_share_pct = (decode_abx_ms / decode_denom) * 100.0
    abx_bx_share_pct = (decode_bx_ms / decode_denom) * 100.0
    abx_ax_share_pct = (decode_ax_ms / decode_denom) * 100.0
    abx_ms_per_token = decode_abx_ms / max(1, decode_steps)

    return {
        "prefill_ms": prefill_ms,
        "decode_ms_token_p50": p50,
        "decode_ms_token_p95": p95,
        "jitter_pct": jitter_pct,
        "bx_reuse_rate_pct": reuse_rate_pct,
        "abx_flops_saving_pct": flops_saving_pct,
        "abx_kernel_time_share_pct": abx_share_pct,
        "abx_bx_time_share_pct": abx_bx_share_pct,
        "abx_ax_time_share_pct": abx_ax_share_pct,
        "abx_ms_per_token": abx_ms_per_token,
        "abx_bx_ms_total": _METRICS.abx_bx_ms,
        "abx_ax_ms_total": _METRICS.abx_ax_ms,
        "h2d_ms_token": (
            (_METRICS.h2d_total_ms / max(1, _METRICS.h2d_events))
            if _METRICS.h2d_events
            else 0.0
        ),
    }


@torch.no_grad()
def measure_generation_metrics(
    model,
    tokenizer,
    device,
    prompts,
    max_new_tokens=50,
    do_sample=False,
    num_beams=1,
    temperature=1.0,
    top_p=1.0,
    repeats=1,
):
    """Benchmark ``model.generate()`` latency and throughput over ``prompts``.

    After one warm-up call, each prompt (repeated ``repeats`` times) is
    measured twice:
      - TTFB: wall time of a complete ``generate()`` call that produces a
        single token;
      - throughput: wall time of a ``max_new_tokens`` call, turned into
        tokens/s.

    Sampling kwargs are passed only when ``do_sample`` is true. Otherwise the
    model's generation_config temperature/top_p are temporarily set to 1.0.
    Per-prompt failures are printed and skipped.

    Returns:
        Dict of mean/median TTFB (ms), mean/median tokens/s, mean total time
        (s) and mean new-token count. Each value is 0 when nothing succeeded.
    """
    model.eval()
    pad_id = tokenizer.eos_token_id
    override_kwargs = {"temperature": 1.0, "top_p": 1.0} if not do_sample else {}
    with temp_generation_overrides(model, **override_kwargs):
        # Warm-up on the first prompt; a failure here (including an empty
        # prompt list) only prints a warning.
        try:
            warm = tokenizer(
                prompts[0], return_tensors="pt", truncation=True, max_length=512
            ).to(device)
            if warm["input_ids"].dim() == 1:
                warm["input_ids"] = warm["input_ids"].unsqueeze(0)
            if "attention_mask" in warm and warm["attention_mask"].dim() == 1:
                warm["attention_mask"] = warm["attention_mask"].unsqueeze(0)
            _ = model.generate(**warm, max_new_tokens=1, use_cache=True)
            _cuda_sync(device)
        except Exception as e:
            print(f"Warning: Warmup failed: {e}")
        ttfb_list_ms, tokps_list, total_times, total_new_tokens = [], [], [], []
        gen_kwargs = dict(
            do_sample=do_sample,
            num_beams=num_beams,
            pad_token_id=pad_id,
            use_cache=True,
            return_dict_in_generate=True,
        )
        if do_sample:
            gen_kwargs.update(dict(temperature=temperature, top_p=top_p))
        for _ in range(repeats):
            for prompt in prompts:
                clear_group_cache()
                inputs = tokenizer(
                    prompt, return_tensors="pt", truncation=True, max_length=512
                )
                # NOTE(review): measure_h2d_time uses CUDA events unconditionally,
                # so this path assumes a CUDA-capable environment.
                with measure_h2d_time():
                    inputs = inputs.to(device)
                try:
                    if inputs["input_ids"].dim() == 1:
                        inputs["input_ids"] = inputs["input_ids"].unsqueeze(0)
                    if (
                        "attention_mask" in inputs
                        and inputs["attention_mask"].dim() == 1
                    ):
                        inputs["attention_mask"] = inputs["attention_mask"].unsqueeze(0)
                    # NOTE(review): _cuda_sync only syncs for torch.device objects;
                    # if `device` is passed as a string these syncs are no-ops
                    # and the wall-clock timings below may not wait for the GPU.
                    _cuda_sync(device)
                    t0 = perf_counter()
                    model.generate(**inputs, max_new_tokens=1, **gen_kwargs)
                    _cuda_sync(device)
                    ttfb_list_ms.append((perf_counter() - t0) * 1000.0)
                    clear_group_cache()
                    _cuda_sync(device)
                    t1 = perf_counter()
                    outN = model.generate(
                        **inputs, max_new_tokens=max_new_tokens, **gen_kwargs
                    )
                    _cuda_sync(device)
                    t_total = perf_counter() - t1
                    # Tokens beyond the prompt length in the returned sequences;
                    # presumably fewer than max_new_tokens if generation stops early.
                    gen_tokens = (
                        _get_sequences_from_generate(outN).shape[1]
                        - inputs["input_ids"].shape[1]
                    )
                    tokps_list.append((gen_tokens / t_total) if t_total > 0 else 0.0)
                    total_times.append(t_total)
                    total_new_tokens.append(gen_tokens)
                except Exception as e:
                    print(
                        f"Warning: Generation measurement failed for prompt '{prompt[:30]}...': {e}"
                    )
                    continue
    return {
        "ttfb_ms_mean": mean(ttfb_list_ms) if ttfb_list_ms else 0,
        "ttfb_ms_median": median(ttfb_list_ms) if ttfb_list_ms else 0,
        "tok_s_mean": mean(tokps_list) if tokps_list else 0,
        "tok_s_median": median(tokps_list) if tokps_list else 0,
        "avg_total_time_s": mean(total_times) if total_times else 0,
        "avg_new_tokens": mean(total_new_tokens) if total_new_tokens else 0,
    }


@torch.no_grad()
def evaluate(model, tokenizer, device, model_name, chunk_len: int = 2048):
    """Compute WikiText-2 (raw, test split) perplexity over non-overlapping chunks.

    The whole split is joined and tokenized once, then scored in independent
    windows of ``chunk_len`` tokens. The first token of each window therefore
    has no left context. Windows shorter than two tokens are skipped.

    Args:
        model: causal LM whose output exposes ``.logits``.
        tokenizer: tokenizer matching ``model``.
        device: device each window is moved to.
        model_name: label used in progress and result messages.
        chunk_len: window length in tokens (default 2048, as before).

    Returns:
        ``(ppl, elapsed_seconds)``.

    Raises:
        ValueError: if no tokens could be scored (e.g. too-short corpus or a
            ``chunk_len`` below 2).
    """
    print(f"\n--- Evaluating {model_name} ---")
    model.eval()
    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join(ds["text"])
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    seq_len = input_ids.size(1)
    total_loss, total_tokens = 0.0, 0
    start_time = time.time()
    pbar = tqdm(range(0, seq_len, chunk_len), desc=f"PPL for {model_name}")
    for begin in pbar:
        clear_group_cache()
        end = min(begin + chunk_len, seq_len)
        if end - begin <= 1:
            continue
        input_batch = input_ids[:, begin:end].to(device)
        logits = model(input_batch).logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = input_batch[..., 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        # One host sync per chunk instead of two.
        loss_val = loss.item()
        ntok = shift_labels.numel()
        total_loss += loss_val * ntok
        total_tokens += ntok
        pbar.set_description(f"PPL for {model_name} (Loss: {loss_val:.4f})")
    if total_tokens == 0:
        raise ValueError(
            f"No tokens were scored for {model_name} (seq_len={seq_len}, chunk_len={chunk_len})"
        )
    ppl = math.exp(total_loss / total_tokens)
    elapsed = time.time() - start_time
    print(f"✅ Result for {model_name}: PPL = {ppl:.4f}, Time = {elapsed:.2f}s")
    return ppl, elapsed


@torch.no_grad()
def run_self_test(device_str: str = "cuda:0", iters: int = 16):
    """Smoke-test ABX timing on a tiny fp16 layer and print the results.

    Steps:
      - enable ABX profiling (left enabled afterwards);
      - wrap a bias-free 16x16 fp16 linear with a rank-4 solo correction;
      - reset ``_METRICS`` and run ``iters`` forward passes on one random
        32x16 batch;
      - print the accumulated B@x and A@r times and call counts.
    """
    print("🔧 Running ABX self-test...")

    set_profile_abx(True)
    device = torch.device(device_str)

    in_features, rank, out_features = 16, 4, 16
    base_layer = nn.Linear(in_features, out_features, bias=False).to(
        device, dtype=torch.float16
    )

    torch.manual_seed(0)
    B_q = torch.randn(rank, in_features, device=device, dtype=torch.float16)
    A_q = torch.randn(out_features, rank, device=device, dtype=torch.float16)

    wrapped = AddSVDCorrection(
        inner=base_layer,
        A_q=A_q,
        B_q=B_q,
        role="solo",
        is_group=False,
        group_cache=None,
        alpha_svd=1.0,
    ).to(device)

    batch = torch.randn(32, in_features, device=device, dtype=torch.float16)

    _METRICS.reset()
    for _ in range(iters):
        # .item() pulls a scalar back to the host after every iteration.
        wrapped(batch).norm().item()

    bx_ms = _METRICS.abx_bx_ms
    ax_ms = _METRICS.abx_ax_ms
    print(f"ABX(B@x)  total ms : {bx_ms:.3f}  (calls={_METRICS.abx_bx_calls})")
    print(f"ABX(A@r)  total ms : {ax_ms:.3f}  (calls={_METRICS.abx_ax_calls})")
    print(f"ABX total (ms)     : {bx_ms + ax_ms:.3f}")
    print("✅ Self-test finished.")


def _build_model_from_fp16(args):
    """Load the FP16 base model, restore the original weights, and quantize it.

    The HF causal-LM is first materialised on CPU in fp16. Its state dict is
    overwritten with the tensors stored at ``args.original_weights_path``. It is
    then converted to a 4-bit model and moved to ``args.device``:
    CUDA W4A16 when ``args.use_cuda_w4a16`` is set, otherwise the Triton
    4-bit path.

    Args:
        args: parsed CLI namespace. Reads ``device``, ``model_name``,
            ``trust_remote_code``, ``original_weights_path``,
            ``use_cuda_w4a16`` and ``group_size``.

    Returns:
        ``(model, method_label)``, where ``method_label`` is a human-readable
        name of the quantisation backend that was used.

    Raises:
        RuntimeError: if ``--use_cuda_w4a16`` was requested but the
            ``cuda_w4a16`` extension cannot be imported. The original error is
            chained as ``__cause__``.
    """
    device = torch.device(args.device)
    print(f"📥 Loading FP16 model: {args.model_name}")
    model_fp16 = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        torch_dtype=torch.float16,
        device_map="cpu",
        trust_remote_code=args.trust_remote_code,
    )
    print(f"📥 Loading original weights: {args.original_weights_path}")
    original_weights = torch.load(
        args.original_weights_path, map_location="cpu", weights_only=True
    )
    model_fp16.load_state_dict(original_weights)
    # load_state_dict copies the tensors into the model's parameters, so the
    # loaded dict can be released now instead of staying alive (and doubling
    # host memory) through quantisation.
    del original_weights

    if args.use_cuda_w4a16:
        print("🔄 Converting to CUDA W4A16...")
        try:
            from cuda_w4a16.linear import convert_to_cuda_w4a16
        except Exception as e:
            # Broad on purpose: a compiled extension can fail with OSError etc.,
            # not just ImportError. Chain the cause so the traceback survives.
            raise RuntimeError(f"Failed to import CUDA W4A16 module: {e}") from e
        model = convert_to_cuda_w4a16(model_fp16, group_size=args.group_size).to(device)
        method_label = "CUDA W4A16"
    else:
        print("🔄 Converting to Triton 4-bit...")
        model = convert_to_triton_4bit(model_fp16, group_size=args.group_size).to(
            device
        )
        method_label = "Triton 4-bit"

    del model_fp16
    gc.collect()
    torch.cuda.empty_cache()
    return model, method_label


# Generation-metric keys read from the legacy ``measure_generation_metrics``
# result, in summary/CSV column order.
_GEN_METRIC_KEYS = ("ttfb_ms_median", "tok_s_median")

# Keys read from the ``measure_generation_metrics_precise`` result, in
# summary/CSV column order.
_PRECISE_METRIC_KEYS = (
    "prefill_ms",
    "decode_ms_token_p50",
    "decode_ms_token_p95",
    "jitter_pct",
    "bx_reuse_rate_pct",
    "abx_flops_saving_pct",
    "abx_kernel_time_share_pct",
    "abx_bx_time_share_pct",
    "abx_ax_time_share_pct",
    "abx_ms_per_token",
    "h2d_ms_token",
)

# Column order of the metrics CSV written by ``main``.
_CSV_FIELDNAMES = [
    "tag",
    "model",
    "method",
    "ppl",
    "time_s",
    *_GEN_METRIC_KEYS,
    *_PRECISE_METRIC_KEYS,
]


def _extract_metrics(result: Dict[str, Any]) -> Dict[str, float]:
    """Flatten one scenario result into ``{metric_key: value}``.

    Missing measurements (a ``None`` sub-dict or an absent key) become NaN, so
    the summary table and CSV always have every column.
    """
    gm = result.get("generation_metrics") or {}
    pm = result.get("precise") or {}
    nan = float("nan")
    metrics = {k: gm.get(k, nan) for k in _GEN_METRIC_KEYS}
    metrics.update({k: pm.get(k, nan) for k in _PRECISE_METRIC_KEYS})
    return metrics


def _print_summary_row(tag: str, result: Dict[str, Any]) -> None:
    """Print one fixed-width row of the final summary table."""
    m = _extract_metrics(result)
    print(
        f"{tag:<32} | {result['ppl']:<8.4f} | {result['time']:<8.2f} | "
        f"{m['ttfb_ms_median']:<10.1f} | {m['tok_s_median']:<10.2f} | "
        f"{m['prefill_ms']:<12.1f} | "
        f"{m['decode_ms_token_p50']:<6.1f}/{m['decode_ms_token_p95']:<6.1f}       | "
        f"{m['jitter_pct']:<8.1f} | {m['bx_reuse_rate_pct']:<9.1f} | "
        f"{m['abx_flops_saving_pct']:<9.1f} | {m['abx_kernel_time_share_pct']:<10.1f} | "
        f"{m['abx_bx_time_share_pct']:<6.1f} | {m['abx_ax_time_share_pct']:<6.1f} | "
        f"{m['abx_ms_per_token']:<11.2f} | {m['h2d_ms_token']:<10.2f}"
    )


def _measure_generation(
    model, tokenizer, device, prompts, args, measure_label: str, fail_label: str
):
    """Collect legacy and precise generation metrics for one scenario.

    Resets the global ``_METRICS`` counters, then runs
    ``measure_generation_metrics`` (configured by the ``--gen_*`` CLI options)
    followed by ``measure_generation_metrics_precise``. Each measurement is
    best-effort: a failure is printed and reported as ``None``, so the remaining
    scenarios still run.

    Returns:
        ``(legacy, precise)``. Either element is ``None`` if that measurement
        raised.
    """
    print(f"Measuring generation metrics (legacy + precise) for {measure_label}...")
    _METRICS.reset()
    legacy = None
    precise = None
    try:
        legacy = measure_generation_metrics(
            model,
            tokenizer,
            device,
            prompts=prompts,
            max_new_tokens=args.gen_max_new_tokens,
            do_sample=args.gen_do_sample,
            num_beams=args.gen_num_beams,
            temperature=args.gen_temperature,
            top_p=args.gen_top_p,
            repeats=args.gen_repeats,
        )
    except Exception as e:
        print(f"Generation (legacy) failed for {fail_label}: {e}")
    try:
        precise = measure_generation_metrics_precise(
            model,
            tokenizer,
            device,
            prompts=prompts,
            max_new_tokens=args.gen_max_new_tokens,
        )
    except Exception as e:
        print(f"Precise generation failed for {fail_label}: {e}")
    return legacy, precise


def _run_scenario(
    model,
    tokenizer,
    device,
    prompts,
    args,
    eval_name: str,
    measure_label: str,
    fail_label: str,
) -> Dict[str, Any]:
    """Evaluate PPL and (unless ``--skip_gen``) generation metrics for one model.

    Returns:
        A dict with keys ``ppl``, ``time``, ``generation_metrics`` and
        ``precise``.
    """
    ppl, elapsed = evaluate(model, tokenizer, device, eval_name)
    gen_legacy, precise = None, None
    if not args.skip_gen:
        gen_legacy, precise = _measure_generation(
            model, tokenizer, device, prompts, args, measure_label, fail_label
        )
    return {
        "ppl": ppl,
        "time": elapsed,
        "generation_metrics": gen_legacy,
        "precise": precise,
    }


def _append_metrics_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Append ``rows`` to the CSV at ``path``, creating parent dirs as needed.

    A header row is written when the file is new *or empty*. Checking
    existence alone would leave a header-less CSV if the file had been
    pre-created (e.g. with ``touch``).

    Raises:
        RuntimeError: if the path cannot be opened or written. The underlying
            OS error is chained as ``__cause__``.
    """
    _ensure_parent_dir(path)
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    try:
        with open(path, "a", newline="") as f:
            writer = csv.DictWriter(f, fieldnames=_CSV_FIELDNAMES)
            if write_header:
                writer.writeheader()
            writer.writerows(rows)
    except FileNotFoundError as e:
        raise RuntimeError(
            f"CSV path not found even after mkdir: {path} (orig: {e})"
        ) from e
    except PermissionError as e:
        raise RuntimeError(
            f"No permission to write CSV at: {path} (orig: {e})"
        ) from e
    print(f"📝 Metrics appended to CSV: {path}")


def main():
    """CLI entry point: compare ABX always-recompute against BX caching.

    Depending on ``--cache_mode`` / ``--restore_ratio``, it evaluates up to
    three scenarios on the quantised model patched with SVD corrections
    (alpha=1.0):

    * ABX-ALWAYS: ``set_abx_policy(True)`` on the fully patched model.
    * BX-CACHE: ``set_abx_policy(False)`` on the same fully patched model.
    * BX-CACHE restore: a freshly built model patched via
      ``patch_svd_correction_wrappers_partial`` with ``--restore_ratio``.

    For each scenario it records perplexity and, unless ``--skip_gen`` is set,
    generation metrics. It then prints a summary table and optionally appends
    one CSV row per scenario to ``--metrics_csv``.
    ``--self_test`` runs ``run_self_test`` instead and returns.
    """
    p = argparse.ArgumentParser(
        description="Compare ABX always-recompute vs BX caching (full vs partial restore)."
    )
    p.add_argument("--model_name", required=True)
    p.add_argument("--shared_path", required=True)
    p.add_argument("--bmap_path", required=True)
    p.add_argument(
        "--original_weights_path",
        required=True,
        help="Path to original weights (from step1)",
    )
    p.add_argument("--device", default="cuda:0")
    p.add_argument("--trust_remote_code", action="store_true")
    p.add_argument("--group_size", type=int, default=128)
    p.add_argument(
        "--cache_mode",
        choices=["both", "cache", "no_cache"],
        default="both",
        help="Select which caching strategy to evaluate: both, cache-only, or no-cache.",
    )
    p.add_argument("--skip_gen", action="store_true")
    p.add_argument("--gen_max_new_tokens", type=int, default=50)
    p.add_argument("--gen_repeats", type=int, default=1)
    p.add_argument("--gen_do_sample", action="store_true")
    p.add_argument("--gen_num_beams", type=int, default=1)
    p.add_argument("--gen_temperature", type=float, default=1.0)
    p.add_argument("--gen_top_p", type=float, default=1.0)
    p.add_argument("--use_cuda_w4a16", action="store_true")
    p.add_argument(
        "--restore_ratio",
        type=float,
        default=0.5,
        help="Ratio of ALL units (groups + solo) to restore for partial scenario",
    )
    p.add_argument(
        "--metrics_csv",
        type=str,
        default="",
        help="CSV path to append results (3 rows)",
    )
    p.add_argument(
        "--profile_abx",
        action="store_true",
        help="Enable CUDA-event timing for A@B@X breakdown (Bx and Ax).",
    )
    p.add_argument(
        "--self_test",
        action="store_true",
        help="Run a lightweight ABX timing self-test and exit (no model load).",
    )
    args = p.parse_args()

    if args.self_test:
        run_self_test(device_str=args.device)
        return

    set_profile_abx(args.profile_abx)

    device = torch.device(args.device)
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, use_fast=True, trust_remote_code=args.trust_remote_code
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    default_prompts = [
        "Hello, my name is",
        "The quick brown fox",
        "In a shocking finding, scientists discovered that",
    ]

    print("🧩 Loading correction artifacts ...")
    shared_cpu = torch.load(args.shared_path, map_location="cpu", weights_only=True)
    with open(args.bmap_path, "r") as f:
        bmap = json.load(f)

    results: Dict[str, Any] = {}

    evaluate_no_cache = args.cache_mode in ("both", "no_cache")
    evaluate_cache = args.cache_mode in ("both", "cache")
    evaluate_partial = evaluate_cache and args.restore_ratio > 0

    if not (evaluate_no_cache or evaluate_cache):
        raise ValueError("cache_mode must enable at least one evaluation")

    # round() rather than int(): int(0.29 * 100) == 28 because of float error,
    # which would mislabel the scenario and disagree with the "{:.0f}" header.
    restore_pct = round(args.restore_ratio * 100)

    model_primary, method_label = _build_model_from_fp16(args)
    shared = {
        k: (v.to(device) if torch.is_tensor(v) else v) for k, v in shared_cpu.items()
    }
    model_primary = patch_svd_correction_wrappers(
        model_primary, shared, bmap, alpha_svd=1.0
    )

    if evaluate_no_cache:
        print("\n=== EVALUATION: ABX-ALWAYS  ===")
        set_abx_policy(True)
        results["abx_always"] = _run_scenario(
            model_primary,
            tokenizer,
            device,
            default_prompts,
            args,
            eval_name=f"{method_label} + SVD (ABX-ALWAYS)",
            measure_label="ABX-ALWAYS",
            fail_label="ABX-ALWAYS",
        )
    else:
        print(
            "\n[Step3] Skipping ABX-ALWAYS evaluation (--cache_mode excludes no_cache)"
        )

    if evaluate_cache:
        print("\n=== EVALUATION: BX-CACHE ===")
        set_abx_policy(False)
        results["bx_cache"] = _run_scenario(
            model_primary,
            tokenizer,
            device,
            default_prompts,
            args,
            eval_name=f"{method_label} + SVD (BX-CACHE)",
            measure_label="BX-CACHE",
            fail_label="BX-CACHE",
        )
    else:
        print("\n[Step3] Skipping BX-CACHE evaluation (--cache_mode excludes cache)")

    # Free the fully patched model before building a second one for the
    # partial-restore scenario.
    del shared, model_primary
    gc.collect()
    torch.cuda.empty_cache()

    if evaluate_partial:
        print(
            "\n=== EVALUATION: BX-CACHE (RESTORE {:.0f}% of ALL units, SVD α=1.0) ===".format(
                args.restore_ratio * 100
            )
        )
        model_partial, _ = _build_model_from_fp16(args)
        shared_dev = {
            k: (v.to(device) if torch.is_tensor(v) else v)
            for k, v in shared_cpu.items()
        }
        model_partial = patch_svd_correction_wrappers_partial(
            model_partial,
            shared_dev,
            bmap,
            restore_ratio=args.restore_ratio,
            alpha_svd=1.0,
        )
        set_abx_policy(False)
        results["bx_cache_restore"] = _run_scenario(
            model_partial,
            tokenizer,
            device,
            default_prompts,
            args,
            eval_name=f"{method_label} + SVD (BX-CACHE, restore {restore_pct}%)",
            measure_label="BX-CACHE (partial restore)",
            fail_label="BX-CACHE (partial)",
        )
        del shared_dev, model_partial
        gc.collect()
        torch.cuda.empty_cache()
    elif evaluate_cache:
        print("\n[Step3] Skipping partial BX-CACHE restore (--restore_ratio <= 0).")

    print(f"\n{'='*15} FINAL SUMMARY ({args.model_name} + SVD α=1.0) {'='*15}")
    print(f"Model: {args.model_name}")
    print("-" * 160)
    print(
        f"{'Setting':<32} | {'PPL':<8} | {'Time(s)':<8} | {'TTFB(ms)':<10} | {'tok/s':<10} | {'Prefill(ms)':<12} | {'Dec p50/p95(ms)':<20} | {'Jitter%':<8} | {'BxReuse%':<9} | {'ABxSave%':<9} | {'ABxShare%':<10} | {'Bx%':<6} | {'Ax%':<6} | {'ABx ms/tok':<11} | {'H2D ms/tok':<10}"
    )
    print("-" * 160)

    # (results key, summary-table tag, CSV tag) for every possible scenario.
    scenarios = [
        ("abx_always", "ABX-ALWAYS (no reuse)", "ABX-ALWAYS (no reuse)"),
        ("bx_cache", "BX-CACHE   (reuse)", "BX-CACHE (reuse)"),
        (
            "bx_cache_restore",
            f"BX-CACHE (restore{restore_pct}%)",
            f"BX-CACHE (restore{restore_pct}%)",
        ),
    ]
    evaluated = [
        (results[key], summary_tag, csv_tag)
        for key, summary_tag, csv_tag in scenarios
        if key in results
    ]

    for data, summary_tag, _ in evaluated:
        _print_summary_row(summary_tag, data)

    if args.metrics_csv:
        rows = [
            {
                "tag": csv_tag,
                "model": args.model_name,
                "method": method_label or "Unknown",
                "ppl": data["ppl"],
                "time_s": data["time"],
                **_extract_metrics(data),
            }
            for data, _, csv_tag in evaluated
        ]
        if rows:
            _append_metrics_csv(args.metrics_csv, rows)


# Module-level entry guard. It must not be nested inside main(); otherwise
# running the script would define main and exit without doing anything.
if __name__ == "__main__":
    main()
