import argparse
import os
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# The backward passes run in main() need autograd recording, so it is switched
# on explicitly at import time (module-level side effect).
torch.set_grad_enabled(True)

def parse_blocks(s):
    """Parse a comma-separated list of projection tags.

    Args:
        s: Either the word ``all`` or a comma-separated list drawn from
            ``q, k, v, o, up, down, gate``. Matching ignores letter case and
            surrounding whitespace.

    Returns:
        A list of canonical lower-case tags in the order they were requested,
        without duplicates. Unknown or empty entries are dropped, so the
        result may be empty.
    """
    full = ["q", "k", "v", "o", "up", "down", "gate"]
    if s.strip().lower() == "all":
        return full
    known = set(full)
    # Normalise before filtering so "Q" or " up " are not silently discarded.
    requested = (part.strip().lower() for part in s.split(","))
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(tag for tag in requested if tag in known))

def tag2mod(layer):
    """Collect the projection attributes present on ``layer`` under short tags.

    Looks for ``q_proj``/``k_proj``/``v_proj``/``o_proj`` on ``layer.self_attn``
    and ``up_proj``/``down_proj``/``gate_proj`` on ``layer.mlp``. Anything that
    is missing is simply left out of the result.

    Returns:
        A dict mapping the tags ``q, k, v, o, up, down, gate`` (in that
        insertion order, restricted to those found) to the attribute values.
    """
    layout = (
        ("self_attn", (("q", "q_proj"), ("k", "k_proj"), ("v", "v_proj"), ("o", "o_proj"))),
        ("mlp", (("up", "up_proj"), ("down", "down_proj"), ("gate", "gate_proj"))),
    )
    found = {}
    for parent_name, members in layout:
        if not hasattr(layer, parent_name):
            continue
        parent = getattr(layer, parent_name)
        for tag, attr_name in members:
            if hasattr(parent, attr_name):
                found[tag] = getattr(parent, attr_name)
    return found

def build_tokens(tokenizer, dataset, nsamples, seqlen, device):
    """Build an ``(nsamples, seqlen)`` tensor of token ids from the train split.

    All ``dataset["train"]["text"]`` entries are joined with blank lines and
    tokenized in a single call. The first ``nsamples * seqlen`` ids are kept;
    if the corpus is shorter, the tail is padded with the tokenizer's EOS id
    (or 0 when that id is missing/zero). The result is moved to ``device``.
    """
    needed = nsamples * seqlen
    corpus = "\n\n".join(dataset["train"]["text"])
    ids = tokenizer(corpus, return_tensors="pt")["input_ids"][:, :needed]

    shortfall = needed - ids.shape[1]
    if shortfall > 0:
        fill = tokenizer.eos_token_id or 0
        ids = torch.nn.functional.pad(ids, (0, shortfall), value=fill)

    # Split the single long row into consecutive fixed-length windows.
    return ids.view(1, nsamples, seqlen).squeeze(0).to(device)

def lm_ce_loss(model, input_ids, attn_mask):
    """Return the next-token cross-entropy of ``model`` on ``input_ids``.

    The model is called with ``use_cache=False``; the logits at position ``t``
    are scored against the token at position ``t + 1``, so the last logit and
    the first token are dropped. The loss uses ``cross_entropy``'s default
    reduction over all remaining positions.
    """
    logits = model(input_ids=input_ids, attention_mask=attn_mask, use_cache=False).logits
    vocab = logits.size(-1)
    preds = logits[:, :-1].reshape(-1, vocab)
    targets = input_ids[:, 1:].reshape(-1)
    return torch.nn.functional.cross_entropy(preds, targets)

def main():
    """Accumulate per-row gradient statistics for selected projection weights.

    Parses the command line, loads a causal LM in bfloat16, and runs
    next-token cross-entropy backward passes over fixed-length windows of the
    dataset's train split. For every selected projection in every layer the
    weight gradient is reduced along dim 1 -- sum of squares with
    ``--emp_fisher``, sum of absolute values otherwise -- and accumulated in a
    float32 vector with ``weight.shape[0]`` entries. Each vector is finally
    divided by the number of processed samples and saved to
    ``<out_hess_dir>/<layer_index>_<tag>.pt``.

    Exits through ``argparse`` (status 2) when ``--nsamples`` < 1,
    ``--seqlen`` < 2, or ``--blocks`` selects no known tag.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model_dir", required=True)
    ap.add_argument("--out_hess_dir", required=True)
    ap.add_argument("--dataset", type=str, default="wikitext2")
    ap.add_argument("--nsamples", type=int, default=64)
    ap.add_argument("--seqlen", type=int, default=512)
    ap.add_argument("--device", type=str, default="cuda")
    ap.add_argument("--blocks", type=str, default="q,k,v,o,up,down,gate")
    ap.add_argument("--batch_size", type=int, default=1)
    ap.add_argument("--emp_fisher", action="store_true")
    args = ap.parse_args()

    # Fail fast on inputs that would otherwise yield NaNs or an empty run.
    if args.nsamples < 1:
        ap.error("--nsamples must be at least 1")
    if args.seqlen < 2:
        ap.error("--seqlen must be at least 2 (next-token loss needs a target)")
    blocks = set(parse_blocks(args.blocks))
    if not blocks:
        ap.error(f"--blocks selected no known projection: {args.blocks!r}")

    os.makedirs(args.out_hess_dir, exist_ok=True)
    tok = AutoTokenizer.from_pretrained(args.model_dir, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(args.model_dir, torch_dtype=torch.bfloat16, device_map=args.device)
    # NOTE(review): train mode is presumably what keeps gradient checkpointing
    # active; it would also enable dropout if the config has any -- confirm.
    model.train()
    model.config.use_cache = False
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    # Enable grads only on the selected projections (among those tag2mod
    # knows about) and remember them so the per-step loop need not rediscover
    # them. NOTE(review): parameters outside tag2mod (e.g. embeddings, norms)
    # keep their existing requires_grad setting; freezing them would need care
    # with checkpointing -- left unchanged here.
    tracked = {}
    diags = {}
    for i, layer in enumerate(model.model.layers):
        for tag, mod in tag2mod(layer).items():
            selected = tag in blocks
            mod.weight.requires_grad_(selected)
            if selected:
                tracked[(i, tag)] = mod
                diags[(i, tag)] = torch.zeros(mod.weight.shape[0], dtype=torch.float32, device="cpu")

    if args.dataset.lower() in ("wikitext2", "wikitext-2"):
        ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    else:
        ds = load_dataset(args.dataset)
    tokens = build_tokens(tok, ds, args.nsamples, args.seqlen, args.device)

    bs = max(1, int(args.batch_size))
    total = 0
    for start in tqdm(range(0, args.nsamples, bs), desc="hess_diag"):
        input_ids = tokens[start:start + bs]
        attn_mask = torch.ones_like(input_ids)
        model.zero_grad(set_to_none=True)
        loss = lm_ce_loss(model, input_ids, attn_mask)
        loss.backward()
        for key, mod in tracked.items():
            grad = mod.weight.grad
            if grad is None:
                continue
            # Upcast before reducing: squaring/summing in bfloat16 loses
            # precision. Reduce on the gradient's device so only the row
            # vector, not the full matrix, is copied to the CPU.
            g = grad.detach().float()
            row_stat = (g * g).sum(dim=1) if args.emp_fisher else g.abs().sum(dim=1)
            diags[key] += row_stat.cpu()
        total += input_ids.shape[0]
        del input_ids, attn_mask, loss
        torch.cuda.empty_cache()

    # NOTE(review): with batch_size > 1 each step contributes one gradient of
    # the batch-mean loss, yet the sum is divided by the sample count -- the
    # scale therefore depends on batch_size. Kept as-is for compatibility.
    for (i, tag), d in diags.items():
        d /= float(total)
        torch.save(d.contiguous(), os.path.join(args.out_hess_dir, f"{i}_{tag}.pt"))

if __name__ == "__main__":
    # Best-effort TF32 opt-in before running: set the allow_tf32 flags on the
    # CUDA matmul and cuDNN backends. Any exception raised while setting them
    # is deliberately ignored so the script still runs without these flags.
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    except Exception:
        pass
    main()
