import os
import re
import json
import math
import argparse
import random
import hashlib
import time
from dataclasses import dataclass, asdict
from datasets import load_dataset
from typing import List, Dict, Tuple
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import glob

import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerFast, LogitsProcessor, LogitsProcessorList

from tqdm import tqdm
from typing import Dict, List, Tuple, Optional

from peft import LoraConfig, get_peft_model

def get_eos_list(tok):
    """
    Collect the token ids that should terminate generation for ``tok`` (best effort).

    Sources, in order:
      1. ``tok.eos_token_id`` when it is not None;
      2. ``<|eot_id|>``, ``<|im_end|>``, ``<|endoftext|>``: each symbol is looked up with
         ``convert_tokens_to_ids`` (kept if the result is an int different from
         ``unk_token_id``) and with ``encode`` (kept only if it encodes to exactly one id);
      3. the ids 151645 / 151643 when "Qwen" appears in ``tok.name_or_path`` or in the
         tokenizer's class name;
      4. an ``"eot_id"`` entry of ``tok.special_tokens_map``: either an int id, or a token
         string resolved through ``convert_tokens_to_ids``.
    Exceptions raised by the tokenizer while probing a single symbol are ignored.

    Returns:
        List of distinct ids in unspecified order. It never contains ``None`` and is empty
        when nothing could be resolved.
    """
    ids = set()
    if getattr(tok, "eos_token_id", None) is not None:
        ids.add(tok.eos_token_id)

    target_tokens = ["<|eot_id|>", "<|im_end|>", "<|endoftext|>"]

    for sym in target_tokens:
        try:
            tid = tok.convert_tokens_to_ids(sym)
            # Unknown symbols come back as the unk id (or None for tokenizers without unk).
            if isinstance(tid, int) and tid != getattr(tok, "unk_token_id", -999):
                ids.add(tid)
        except Exception:
            pass

        try:
            tids = tok.encode(sym, add_special_tokens=False)
            # A symbol that splits into several pieces is not a single registered token.
            if len(tids) == 1:
                ids.add(tids[0])
        except Exception:
            pass

    name_or_path = getattr(tok, "name_or_path", "") or ""
    if "Qwen" in name_or_path or "Qwen" in type(tok).__name__:
        # Hard-coded terminator ids; presumably Qwen2-family <|im_end|> / <|endoftext|>
        # (verify for other Qwen vocabularies).
        ids.add(151645)
        ids.add(151643)

    special_map = getattr(tok, "special_tokens_map", None)
    if isinstance(special_map, dict):
        eot = special_map.get("eot_id", None)
        if isinstance(eot, str):
            # HF special_tokens_map stores token *strings*; resolve to an id.
            try:
                resolved = tok.convert_tokens_to_ids(eot)
            except Exception:
                resolved = None
            if isinstance(resolved, int) and resolved != getattr(tok, "unk_token_id", -999):
                ids.add(resolved)
        elif isinstance(eot, int):
            ids.add(eot)

    # Every add() above is guarded against None. An empty result therefore means that
    # nothing resolved (the old `[tok.eos_token_id]` fallback could only yield [None]).
    return [i for i in ids if i is not None]

def resolve_stop_ids(tok) -> List[int]:
    """
    Return the sorted, de-duplicated token ids that should stop generation for ``tok``.

    The result is the union of:
      - every non-None id reported by ``get_eos_list(tok)``;
      - ``<|...|>`` symbols found in ``tok.chat_template`` (when it is a non-empty string)
        whose lowercase form contains "im_end", "eot", "endoftext" or "eos", or equals
        "<|end|>". These are resolved with ``convert_tokens_to_ids``, and ids equal to
        ``unk_token_id`` or non-int results are dropped;
      - ``tok.eos_token_id`` when it is not None.
    """
    stop_ids = {int(x) for x in (get_eos_list(tok) or []) if x is not None}

    template = getattr(tok, "chat_template", None)
    if isinstance(template, str) and template:
        terminator_markers = ("im_end", "eot", "endoftext", "eos")
        for symbol in set(re.findall(r"<\|[^|>]+\|>", template)):
            lowered = symbol.lower()
            looks_terminal = lowered == "<|end|>" or any(m in lowered for m in terminator_markers)
            if not looks_terminal:
                continue
            try:
                token_id = tok.convert_tokens_to_ids(symbol)
                if isinstance(token_id, int) and token_id != getattr(tok, "unk_token_id", None):
                    stop_ids.add(int(token_id))
            except Exception:
                # Best effort: a symbol the tokenizer cannot resolve is simply skipped.
                pass

    if getattr(tok, "eos_token_id", None) is not None:
        stop_ids.add(int(tok.eos_token_id))

    # Sorted so callers get a stable ordering.
    return sorted(stop_ids)

def load_olympiadbench_split(
    lang: str = "en",
    text_only: bool = True,
    subject: str = "math",
    test_size: float = 0.1,
    seed: int = 42,
):
    """Load ``lmms-lab/OlympiadBench`` and split it into train/test subsets.

    Args:
        lang: a value starting with "en" (case-insensitive) loads the ``test_en`` split;
            anything else loads ``test_cn``.
        text_only: keep only examples whose ``images`` field is None or an empty list.
        subject: a value starting with "math" keeps math problems; anything else keeps
            physics. Matching is a substring test on the first non-empty of the
            ``subject`` / ``source`` / ``category`` fields.
        test_size: fraction of examples placed in the test subset.
        seed: RNG seed for the split.

    Examples with a missing or blank ``final_answer`` are always dropped.

    Returns:
        ``(train_dataset, test_dataset)``.
    """
    split_name = "test_en" if lang.lower().startswith("en") else "test_cn"
    ds = load_dataset("lmms-lab/OlympiadBench", split=split_name)

    if text_only:
        def _no_images(ex):
            images = ex.get("images")
            return images is None or (isinstance(images, list) and len(images) == 0)

        ds = ds.filter(_no_images)

    def _has_answer(ex):
        answer = ex.get("final_answer")
        if answer is None:
            return False
        blank_string = isinstance(answer, str) and not answer.strip()
        return not blank_string

    ds = ds.filter(_has_answer)

    wanted = "math" if subject.lower().startswith("math") else "physics"

    def _is_subject(ex):
        label = ex.get("subject") or ex.get("source") or ex.get("category") or ""
        return wanted in str(label).lower()

    ds = ds.filter(_is_subject)

    stratify_col = next(
        (col for col in ("subject", "source", "category") if col in ds.column_names),
        None,
    )

    split_kwargs = {"test_size": test_size, "seed": seed}
    try:
        if stratify_col is None:
            parts = ds.train_test_split(**split_kwargs)
        else:
            parts = ds.train_test_split(stratify_by_column=stratify_col, **split_kwargs)
    except Exception:
        # Presumably raised when the column cannot be used for stratification
        # (e.g. it is not a ClassLabel); fall back to a plain random split.
        parts = ds.train_test_split(**split_kwargs)

    train_dataset, test_dataset = parts["train"], parts["test"]

    print(f"Dataset loaded. Train: {len(train_dataset)}, Test: {len(test_dataset)} (Filtered empty answers)")

    return train_dataset, test_dataset

# -----------------------------
# Chat prompt formatting (Qwen, Gemma, and other HF chat templates)
# -----------------------------
class ChatFormatter:
    """Render questions into chat-template prompts (and token tensors) for a HF tokenizer.

    The system/user text depends on ``data_name``:
      * contains "mmlu_pro"  -> academic QA prompt asking for "The answer is (X)";
      * contains "humaneval" -> code-only prompt;
      * otherwise            -> math prompt asking for a boxed final answer.
    If the tokenizer name contains "gemma", the system text is prepended to a single user
    message instead of being sent as a separate system message. Presumably this is because
    Gemma chat templates reject a system role; verify.
    """

    def __init__(self, tokenizer_name: str):
        self.tokenizer_name = tokenizer_name
        if "Qwen" in tokenizer_name:
            # Try the fast tokenizer class directly; fall back to AutoTokenizer on any failure.
            try:
                self.tokenizer = PreTrainedTokenizerFast.from_pretrained(
                    tokenizer_name,
                    trust_remote_code=True,
                    use_fast=True
                )
            except Exception:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    tokenizer_name,
                    trust_remote_code=True,
                    use_fast=True
                )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name, trust_remote_code=True, use_fast=True
            )

    @staticmethod
    def _prompt_parts(question: str, data_name: str) -> Tuple[str, str]:
        """Return ``(system_content, user_content)`` for ``question`` given the dataset name."""
        if "mmlu_pro" in data_name:
            system_content = "You are a helpful assistant. You are solving high-level academic problems."
            user_content = (
                f"{question}\n"
                "Please reason step by step to find the correct answer. "
                "At the end of your response, output the final answer in the format: 'The answer is ({answer})'."
            )
        elif "humaneval" in data_name:
            system_content = "You write Python code. Output only Python code. No markdown, no explanation."
            user_content = question
        else:
            system_content = "You are a helpful assistant. Please answer the math problem step by step, and put the final answer in the format \\boxed{answer}."
            user_content = question
        return system_content, user_content

    def _messages(self, question: str, data_name: str) -> List[Dict[str, str]]:
        """Build the chat message list shared by ``build_prompt`` and ``build_inputs``."""
        system_content, user_content = self._prompt_parts(question, data_name)
        if "gemma" in self.tokenizer_name:
            # NOTE(review): system and user text are concatenated with no separator
            # (e.g. "...academic problems.<question>"). This is kept so that prompts stay
            # byte-identical to earlier runs; consider inserting "\n\n" for fresh experiments.
            return [{"role": "user", "content": system_content + user_content}]
        return [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ]

    def build_prompt(self, question: str, data_name: str) -> str:
        """Return the chat-formatted prompt string, with the generation prompt appended."""
        return self.tokenizer.apply_chat_template(
            self._messages(question, data_name), tokenize=False, add_generation_prompt=True
        )

    def build_inputs(self, question: str, data_name: str):
        """Tokenize ``build_prompt(question, data_name)`` into PyTorch tensors.

        ``add_special_tokens=False`` stops the tokenizer from adding its own special tokens
        on top of whatever the chat template already rendered into the text.
        """
        prompt = self.build_prompt(question, data_name)
        return self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)

class ACG_Router(nn.Module):
    """MLP head that maps each hidden state to a single scalar "criticality" score.

    ``embedding_dim`` and ``num_heads`` are accepted for interface compatibility but are
    not used by the current architecture.
    """

    def __init__(self, hidden_dim: int, embedding_dim: int, inner_dim: int = 256, num_heads: int = 2):
        super().__init__()
        # Layer order fixes the state_dict keys (criticality_net.0.* / criticality_net.3.*).
        # Keep it stable so that previously saved weights still load.
        layers = [
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(inner_dim, 1),
        ]
        self.criticality_net = nn.Sequential(*layers)

    def predict_criticality(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Map ``(..., hidden_dim)`` inputs to ``(...)`` raw scores (no sigmoid applied)."""
        scores = self.criticality_net(hidden_states)
        return scores.squeeze(dim=-1)

class TriggerSampleDataset(Dataset):
    """Labelled trigger samples read from ``<data_dir>/trigger_samples_*.jsonl``.

    Each kept JSON record must contain the keys ``"hidden"`` (feature vector, returned as a
    float32 tensor) and ``"y_trig"`` (label). Blank lines, unparseable lines and records
    missing either key are skipped. When both classes are present, negatives
    (``y_trig == 0``) are randomly downsampled to at most 4x the positives
    (``y_trig == 1``). The final list is shuffled with the global ``random`` RNG.
    """

    @staticmethod
    def _read_valid_records(path: str) -> list:
        """Return the records of one JSONL file that carry both ``hidden`` and ``y_trig``."""
        records = []
        with open(path, "r", encoding="utf-8") as fh:
            for raw in fh:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    rec = json.loads(raw)
                except Exception:
                    continue
                if "hidden" in rec and "y_trig" in rec:
                    records.append(rec)
        return records

    def __init__(self, data_dir: str):
        paths = sorted(glob.glob(os.path.join(data_dir, "trigger_samples_*.jsonl")))

        all_samples = []
        for path in paths:
            try:
                all_samples.extend(self._read_valid_records(path))
            except FileNotFoundError:
                # The file may disappear between glob() and open(); just skip it.
                continue

        pos_samples, neg_samples = [], []
        for sample in all_samples:
            label = int(sample.get("y_trig", 0))
            if label == 1:
                pos_samples.append(sample)
            elif label == 0:
                neg_samples.append(sample)

        max_neg_pos_ratio = 4
        if pos_samples and neg_samples:
            neg_cap = max_neg_pos_ratio * len(pos_samples)
            if len(neg_samples) > neg_cap:
                neg_samples = random.sample(neg_samples, neg_cap)
            self.samples = pos_samples + neg_samples
        else:
            # Single-class (or empty) data: keep everything, including non-0/1 labels.
            self.samples = all_samples

        random.shuffle(self.samples)
        print(
            f"[Header-Data] Loaded {len(all_samples)} raw trigger samples from {len(paths)} files, "
            f"pos={len(pos_samples)}, neg={len(neg_samples)}, "
            f"used={len(self.samples)} after downsampling."
        )

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        record = self.samples[idx]
        return {
            "hidden": torch.tensor(record["hidden"], dtype=torch.float32),
            "y_trig": torch.tensor(record["y_trig"], dtype=torch.float32),
        }

# -----------------------------
# ZIP-Distill 
# -----------------------------
def get_zip_token_ids(tokenizer, max_len, num_bins: int=10, avoid_ids=None, avoid_special: bool = True):
    """
    Select `num_bins` reserved token ids for ZIP, avoiding EOS/EOT and other special tokens.

    Ids are scanned from ``max_len - 1`` downward, stopping *before* ``tokenizer.vocab_size``.
    Only ids in the open interval ``(vocab_size, max_len)`` are candidates. Ids in
    ``avoid_ids`` are skipped, and so are ``tokenizer.all_special_ids`` when ``avoid_special``
    is true.

    NOTE(review): ``max_len`` presumably is the model's embedding-row count (larger than
    ``tokenizer.vocab_size``), which places the picked ids in the unused tail of the embedding
    table. Confirm against callers. When ``max_len - 1 <= vocab_size`` nothing is scanned
    and ``[]`` is returned.

    IMPORTANT:
    - Old version used the last `num_bins` ids in the vocab, which can *overlap* with Qwen's
      <|im_end|>/<|eot_id|> and break EOS stopping (especially if these ids are masked).

    Returns:
        Sorted list of the picked ids, or ``[]`` if fewer than ``num_bins`` candidates survive.

    Raises:
        ValueError: if the tokenizer lacks ``vocab_size`` or ``vocab_size < num_bins``.
    """
    vocab_size = getattr(tokenizer, "vocab_size", None)
    if vocab_size is None:
        raise ValueError("Tokenizer must have vocab_size attribute.")
    if vocab_size < num_bins:
        raise ValueError(f"vocab_size {vocab_size} < num_bins {num_bins}")

    excluded = {int(x) for x in (avoid_ids or []) if x is not None}
    if avoid_special:
        try:
            excluded |= {int(x) for x in getattr(tokenizer, "all_special_ids", []) if x is not None}
        except Exception:
            # Best effort: tokenizers without usable special-id info just skip this filter.
            pass

    wanted = int(num_bins)
    chosen = []
    for candidate in range(max_len - 1, int(vocab_size), -1):
        if candidate not in excluded:
            chosen.append(int(candidate))
            if len(chosen) >= wanted:
                break

    # Not enough free ids: signal "no ZIP ids" with an empty list rather than raising.
    if len(chosen) < wanted:
        return []
    return sorted(chosen)

class ZIPMaskLogitsProcessor(LogitsProcessor):
    """Logits processor that bans the reserved ZIP token ids during generation.

    Every column listed in ``zip_token_ids`` is set to ``-inf`` in place, for all rows of
    ``scores``, and the same tensor is returned. Decoding therefore never prefers those ids
    over any finite-scored token.
    """

    def __init__(self, zip_token_ids: List[int]):
        self.zip_token_ids = zip_token_ids

    def __call__(self, input_ids, scores):
        banned_columns = self.zip_token_ids
        scores[:, banned_columns] = float("-inf")
        return scores

class ZIPGroupDataset(Dataset):
    """Candidate groups (context + K candidate tokens + teacher probabilities) for ZIP distill.

    Each item is a dict with ``context_ids`` (left-truncated to ``max_len - 1`` tokens),
    ``cand_ids`` and ``teacher_prob_abs`` (aligned lists). ``num_bins`` is stored but not used
    by this class.
    """

    def __init__(self, data_glob, num_bins, max_len, expected_k: int = 16):
        """
        Supports TWO input formats under the same `data_glob`:

        (A) group-wise (preferred, already listwise-ready)
            {
              "context_ids": [...],        # "prefix_ids" / input_ids[:-1] accepted as fallback
              "cand_ids": [...],
              "teacher_prob_abs": [...],
              ... optional meta ...
            }

        (B) point-wise (legacy / reusable logs; one line per candidate)
            {
              "context_ids": [...],
              "action_id": <int>,          # or "cand_id" / "token_id" / "action"
              "teacher_prob_abs": <float>, # or teacher_abs / teacher_prob_raw / teacher_prob / teacher_prob_cond / is_ref
              "problem_id": <int>, "step": <int>   # or "group_id"
            }

        For (B), groups are reconstructed losslessly by (problem_id, step) / group_id. If
        neither is available, a hash of the last 256 context tokens is used instead.

        Records are skipped, not raised on, when they are not JSON objects, when they have
        no usable context, or when the action/teacher value is missing or malformed.
        ``expected_k`` (may be None) trims every group to at most that many candidates.
        """
        self.files = sorted(glob.glob(data_glob))
        self.num_bins = int(num_bins)
        self.max_len = int(max_len)
        self.expected_k = expected_k  # optional: trim groups to this K to match inference slm_multi_k
        self.groups = []

        for fp in self.files:
            self._load_file(fp)

    # ---- record field helpers ------------------------------------------------------

    @staticmethod
    def _group_key(r: dict) -> str:
        """Key identifying the (problem, step) a point-wise record belongs to."""
        if "group_id" in r and r["group_id"] is not None:
            return str(r["group_id"])
        pid = r.get("problem_id", r.get("pid", None))
        step = r.get("step", r.get("t", r.get("step_id", None)))
        if pid is not None and step is not None:
            return f"{pid}:{step}"
        # fallback: hash last tokens of context (should be stable within a step)
        ctx = r.get("context_ids", r.get("prefix_ids", r.get("ctx_ids", []))) or []
        tail = ctx[-256:] if isinstance(ctx, list) else []
        h = hashlib.md5((",".join(map(str, tail))).encode("utf-8")).hexdigest()
        return f"ctxhash:{h}"

    @staticmethod
    def _get_ctx(r: dict):
        """Context token ids: context_ids, else prefix_ids, else input_ids minus its last token."""
        if "context_ids" in r:
            return r["context_ids"]
        if "prefix_ids" in r:
            return r["prefix_ids"]
        # last resort: if logged as input_ids = [ctx, action]
        if "input_ids" in r and isinstance(r["input_ids"], list) and len(r["input_ids"]) >= 2:
            return r["input_ids"][:-1]
        return []

    @staticmethod
    def _get_action(r: dict):
        """Candidate token id from the first alias holding an int-convertible value, else None."""
        for k in ("action_id", "cand_id", "token_id", "action"):
            v = r.get(k)
            if v is None:
                continue
            try:
                return int(v)
            except (TypeError, ValueError, OverflowError):
                # malformed value: try the next alias instead of aborting the whole load
                continue
        return None

    @staticmethod
    def _get_teacher_abs(r: dict):
        """Return ``(teacher_prob, source_key)``, or ``(None, None)`` if no usable value exists."""
        for k in ("teacher_prob_abs", "teacher_abs", "teacher_prob_raw", "teacher_prob", "teacher_prob_cond"):
            v = r.get(k)
            if v is not None:
                try:
                    return float(v), k
                except (TypeError, ValueError, OverflowError):
                    pass
        if "is_ref" in r:
            try:
                return (1.0 if int(r["is_ref"]) == 1 else 0.0), "is_ref"
            except (TypeError, ValueError, OverflowError):
                pass
        return None, None

    def _trim_group_inplace(self, g: dict) -> None:
        """Align cand_ids/teacher_prob_abs to equal length, then cap at ``expected_k``."""
        if "cand_ids" in g and "teacher_prob_abs" in g:
            n = min(len(g["cand_ids"]), len(g["teacher_prob_abs"]))
            g["cand_ids"] = list(g["cand_ids"][:n])
            g["teacher_prob_abs"] = list(g["teacher_prob_abs"][:n])
            if self.expected_k is not None and n > int(self.expected_k):
                K = int(self.expected_k)
                g["cand_ids"] = g["cand_ids"][:K]
                g["teacher_prob_abs"] = g["teacher_prob_abs"][:K]

    # ---- loading -------------------------------------------------------------------

    def _load_file(self, fp: str) -> None:
        """Parse one JSONL file into ``self.groups``.

        Format-(A) groups are appended immediately. Format-(B) groups are appended in
        first-seen order at the end of the file.
        """
        # per-file buffer: point-wise groups are not merged across files
        buf = {}  # key -> {"context_ids":..., "cand_ids":[...], "teacher_prob_abs":[...], ...}
        with open(fp, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    r = json.loads(line)
                except ValueError:
                    continue
                if not isinstance(r, dict):
                    # e.g. a bare number/list line; the field lookups below need a dict
                    continue

                # (A) already grouped
                if isinstance(r.get("cand_ids"), list) and isinstance(r.get("teacher_prob_abs"), list):
                    if not r["cand_ids"]:
                        continue
                    if not isinstance(r.get("context_ids"), list):
                        # Normalize legacy context keys; without any context, __getitem__
                        # would otherwise fail later with KeyError.
                        ctx = self._get_ctx(r)
                        if not isinstance(ctx, list) or not ctx:
                            continue
                        r["context_ids"] = ctx
                    self._trim_group_inplace(r)
                    if r.get("cand_ids"):
                        self.groups.append(r)
                    continue

                # (B) point-wise: reconstruct
                act = self._get_action(r)
                t_abs, _src = self._get_teacher_abs(r)
                if act is None or t_abs is None:
                    continue

                key = self._group_key(r)
                ctx = self._get_ctx(r)
                if not ctx:
                    continue

                g = buf.get(key)
                if g is None:
                    g = {
                        "group_id": key,
                        "problem_id": r.get("problem_id", r.get("pid", None)),
                        "step": r.get("step", r.get("t", r.get("step_id", None))),
                        "context_ids": ctx,
                        "cand_ids": [],
                        "teacher_prob_abs": [],
                    }
                    buf[key] = g
                elif (
                    isinstance(ctx, list) and isinstance(g.get("context_ids"), list)
                    and len(ctx) > len(g["context_ids"])
                ):
                    # sanity: keep a stable context (prefer the longer one if mismatch)
                    g["context_ids"] = ctx

                g["cand_ids"].append(int(act))
                g["teacher_prob_abs"].append(float(t_abs))

        # flush file buffer into groups
        for g in buf.values():
            if not g["cand_ids"]:
                continue
            self._trim_group_inplace(g)
            if g.get("cand_ids"):
                self.groups.append(g)

    # ---- Dataset protocol ----------------------------------------------------------

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        r = self.groups[idx]
        ctx = r["context_ids"]
        # Leave room for the one candidate token appended by the collate function.
        max_prefix = max(self.max_len - 1, 0)
        if len(ctx) > max_prefix:
            # Explicit start index: `ctx[-max_prefix:]` would return the WHOLE list when
            # max_prefix == 0, because -0 == 0.
            ctx = ctx[len(ctx) - max_prefix:]
        return {
            "context_ids": ctx,
            "cand_ids": r["cand_ids"],
            "teacher_prob_abs": r["teacher_prob_abs"],
        }

def zip_group_collate_fn(batch, pad_token_id, num_bins):
    """
    Flatten a batch of candidate groups into ``B * K`` causal-LM rows.

    Row ``b * K + j`` is ``context_ids[b] + [cand_ids[b][j]]``, where ``K`` is the size of
    the largest group in the batch. Groups with fewer than ``K`` candidates get filler rows
    ``context_ids[b] + [pad_token_id]`` whose ``cand_mask`` entry is 0. Rows are
    right-padded with ``pad_token_id`` to a common length. ``attention_mask`` is 1 over the
    context plus the candidate/filler slot.

    Args:
        batch: items shaped like ``ZIPGroupDataset.__getitem__`` output (dicts with
            ``context_ids``, ``cand_ids``, ``teacher_prob_abs``).
        pad_token_id: int id used for padding and for filler slots.
        num_bins: unused here; kept for signature compatibility (bin targets are computed
            later from ``teacher_abs`` if needed).

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (long, ``[B*K, Lmax]``),
        ``teacher_abs`` and ``cand_mask`` (float32, ``[B, K]``), and the ints ``B`` and ``K``.

    Raises:
        ValueError: if ``batch`` is empty or no group in it has any candidate.
    """
    if not batch:
        raise ValueError("zip_group_collate_fn: empty batch")
    # assume K fixed (slm_multi_k), but still handle variable
    B = len(batch)
    K = max(len(ex["cand_ids"]) for ex in batch)
    if K == 0:
        raise ValueError("zip_group_collate_fn: no candidates in any group of the batch")

    # Every row of group b has length len(ctx_b) + 1, so the longest context fixes Lmax.
    Lmax = max(len(ex["context_ids"]) for ex in batch) + 1
    input_ids = torch.full((B * K, Lmax), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((B * K, Lmax), dtype=torch.long)
    teacher_abs = torch.zeros(B, K, dtype=torch.float32)
    cand_mask = torch.zeros(B, K, dtype=torch.float32)

    for b, ex in enumerate(batch):
        ctx = ex["context_ids"]
        cands = ex["cand_ids"]
        tabs = ex["teacher_prob_abs"]
        n_ctx, n_cand = len(ctx), len(cands)
        start = b * K

        # All K rows of a group share the context. The slot right after it keeps
        # pad_token_id unless a real candidate is written below (filler rows).
        input_ids[start:start + K, :n_ctx] = torch.tensor(list(ctx), dtype=torch.long)
        attention_mask[start:start + K, :n_ctx + 1] = 1

        if n_cand:
            input_ids[start:start + n_cand, n_ctx] = torch.tensor(
                [int(c) for c in cands], dtype=torch.long
            )
            teacher_abs[b, :n_cand] = torch.tensor(
                [float(tabs[j]) for j in range(n_cand)], dtype=torch.float32
            )
            cand_mask[b, :n_cand] = 1.0

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "teacher_abs": teacher_abs,
        "cand_mask": cand_mask,
        "B": B,
        "K": K,
    }

def add_lora_to_slm(model, r=64, alpha=128, dropout=0.1, target_modules=None):
    """Wrap ``model`` with a PEFT LoRA adapter configured for causal-LM training.

    Args:
        model: the base model to wrap.
        r: LoRA rank.
        alpha: LoRA ``lora_alpha`` scaling.
        dropout: dropout applied on the LoRA branch.
        target_modules: names of the modules to adapt. ``None`` (the default) keeps the
            original hard-coded set: attention ``q/k/v/o_proj`` plus MLP
            ``gate/up/down_proj``. Those are Llama/Qwen-style names, presumably matching the
            models used here. Pass a list for architectures that name their projections
            differently.

    ``lm_head`` is listed in ``modules_to_save``, so PEFT keeps a trainable full copy of it
    next to the adapters.

    Returns:
        The PEFT-wrapped model. A trainable-parameter summary is printed as a side effect.
    """
    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                          "gate_proj", "up_proj", "down_proj"]
    config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=list(target_modules),
        modules_to_save=["lm_head"]
    )
    model = get_peft_model(model, config)
    model.print_trainable_parameters()
    return model

