import os, json, math, random, argparse, torch
from dataclasses import dataclass
from typing import List, Dict
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from datasets import Dataset
import torch.nn.functional as F

# Instruction header prepended to every multiple-choice prompt.
ABCD_INSTRUCTION = "Answer the following question with A, B, C, or D.\n\n"
# Option letter -> 0-based option index (the class index used for the answer CE).
LETTER_TO_IDX = dict(zip("ABCD", range(4)))
@dataclass
class QAItem:
    """One multiple-choice question.

    Constructed positionally as ``QAItem(question, options, answer)``.
    """
    question: str  # question text; stripped when the prompt is formatted
    options: List[str]  # answer options, rendered in order as "A. ...", "B. ...", ...
    answer: str  # ground-truth option letter (e.g. "A"), looked up in LETTER_TO_IDX for the CE target

def _answer_to_letter(answer):
    """Normalise a dataset ``answer`` field to an upper-case option letter.

    Integers are treated as 0-based option indices (0 -> "A", 1 -> "B", ...);
    non-empty strings are stripped and upper-cased.

    Raises:
        ValueError: for an index outside 0..25, an empty/blank string, or any
            other type.
    """
    if isinstance(answer, int):
        if not 0 <= answer < 26:
            raise ValueError(f"answer index out of range: {answer!r}")
        return chr(ord("A") + answer)
    if isinstance(answer, str) and answer.strip():
        return answer.strip().upper()
    raise ValueError(f"unsupported answer value: {answer!r}")


def load_qa_arrow(path: str):
    """Load multiple-choice QA items from a single ``.arrow`` file.

    Each record must provide ``question``, ``choices`` and ``answer`` fields;
    ``answer`` may be a 0-based option index or an option letter.

    Args:
        path: path to the ``.arrow`` file.

    Returns:
        list of ``QAItem`` whose ``answer`` is an upper-case option letter.

    Raises:
        ValueError: if a record's answer cannot be turned into a letter.
            (Previously a string answer raised NameError or silently reused
            the previous record's letter.)
    """
    ds = Dataset.from_file(path)  # load the .arrow file directly
    return [
        QAItem(row["question"], row["choices"], _answer_to_letter(row["answer"]))
        for row in ds
    ]

def format_mc_prompt(tok, item: QAItem):
    """Render *item* as an "Answer:"-terminated prompt and return its token ids.

    Layout: instruction header, stripped question, one "X. option" line per
    option (letters from "A" upward), a blank line, then "Answer:".
    Special tokens are not added.
    """
    choices_block = "\n".join(
        f"{chr(ord('A') + idx)}. {option}" for idx, option in enumerate(item.options)
    )
    prompt = f"{ABCD_INSTRUCTION}{item.question.strip()}\n{choices_block}\n\nAnswer:"
    return tok(prompt, add_special_tokens=False)["input_ids"]

def pick_label_id(tok, letter: str):
    """Return the first token id the tokenizer yields for *letter*.

    The letter is encoded on its own (no special tokens, no leading space); if it
    splits into several sub-tokens only the first is returned. An empty encoding
    raises IndexError.
    """
    token_ids = tok(letter, add_special_tokens=False)["input_ids"]
    return token_ids[0]

def collate_ids(batch_ids: List[List[int]], pad_id: int):
    """Right-pad a batch of token-id lists into tensors.

    Args:
        batch_ids: one list of token ids per example.
        pad_id: id written into padded positions.

    Returns:
        (ids, attn): two [B, max_len] long tensors; ``attn`` is 1 on real tokens
        and 0 on padding. Padding is always appended on the right.
    """
    seqs, masks = [], []
    for token_list in batch_ids:
        seqs.append(torch.tensor(token_list, dtype=torch.long))
        masks.append(torch.ones(len(token_list), dtype=torch.long))
    padded_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_id)
    padded_mask = pad_sequence(masks, batch_first=True, padding_value=0)
    return padded_ids, padded_mask

class SoftPrefix(torch.nn.Module):
    """Trainable continuous prompt of ``length`` vectors of size ``hidden_size``.

    Initialised from N(0, 1) scaled by 0.02; calling the module returns the raw
    [length, hidden_size] parameter (no batch dimension).
    """

    def __init__(self, hidden_size: int, length: int):
        super().__init__()
        initial = torch.randn(length, hidden_size).mul(0.02)
        self.prefix = torch.nn.Parameter(initial)

    def forward(self):
        # Shape [P, H]; callers add the batch dimension themselves.
        return self.prefix

def inject_soft_prefix(model, tok, soft_prefix, input_ids, attention_mask):
    """Prepend the soft prefix to the token embeddings of ``input_ids``.

    Args:
        model: model exposing ``get_input_embeddings()``; the lookup runs under
            ``no_grad`` so only the prefix receives gradients.
        tok: unused; kept for signature compatibility.
        soft_prefix: module whose call returns a [P, H] tensor.
        input_ids: [B, L] token ids.
        attention_mask: [B, L] mask for ``input_ids``.

    Returns:
        (inputs_embeds, attention_mask): [B, P+L, H] embeddings with the prefix
        first, and a [B, P+L] mask with ones for the P prefix positions.
    """
    with torch.no_grad():
        embeds = model.get_input_embeddings()(input_ids)  # [B, L, H]
    # Match the embedding table's device/dtype (e.g. fp16/bf16 models); `.to` is
    # differentiable, so gradients still reach the fp32 prefix parameter.
    prefix = soft_prefix().to(device=embeds.device, dtype=embeds.dtype)
    prefix = prefix.unsqueeze(0).expand(embeds.size(0), -1, -1)  # [B, P, H]
    inputs_embeds = torch.cat([prefix, embeds], dim=1)
    prefix_mask = attention_mask.new_ones((attention_mask.size(0), prefix.size(1)))
    return inputs_embeds, torch.cat([prefix_mask, attention_mask], dim=1)

def _new_tokens_only(sequences, prompt_len, max_new_tokens):
    """Strip the echoed prompt from ``generate`` output when it is present.

    ``generate`` called with ``inputs_embeds`` but no ``input_ids`` may return
    only the newly generated ids (there are no prompt ids to echo). Other setups
    return prompt + new ids. Output longer than ``max_new_tokens`` must contain
    the prompt, so its first ``prompt_len`` columns are dropped; otherwise it is
    returned unchanged. The rule is unambiguous whenever
    ``prompt_len >= max_new_tokens``.
    """
    if sequences.shape[1] > max_new_tokens:
        return sequences[:, prompt_len:]
    return sequences


@torch.no_grad()
def generate_teacher_continuation(teacher, tok, inputs_embeds, attention_mask, max_new_tokens=16):
    """Greedy-decode up to ``max_new_tokens`` tokens with the teacher; return only those ids.

    The teacher is temporarily moved to ``inputs_embeds.device`` and moved back
    to the device it started on afterwards, even if generation raises.

    Args:
        teacher: causal LM supporting ``generate(inputs_embeds=..., attention_mask=...)``.
        tok: unused; kept for signature compatibility.
        inputs_embeds: [B, T, H] prompt embeddings (soft prefix included).
        attention_mask: [B, T] mask for ``inputs_embeds``.
        max_new_tokens: generation budget.

    Returns:
        [B, N] tensor of generated token ids, N <= max_new_tokens.
    """
    home_device = next(teacher.parameters()).device
    teacher.to(inputs_embeds.device)
    try:
        cfg = GenerationConfig(max_new_tokens=max_new_tokens, do_sample=False)
        out = teacher.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, generation_config=cfg)
    finally:
        # Offload again even on failure so device memory is not left occupied.
        teacher.to(home_device)
    # The original `out[:, T:]` returned an empty tensor when only new tokens came back.
    return _new_tokens_only(out, attention_mask.shape[1], max_new_tokens)

def hidden_states_for_positions(model, inputs_embeds, attention_mask, take_last_k=16, layer_ids=None):
    """Stack selected layers' hidden states over the last ``take_last_k`` positions.

    Args:
        model: called with ``output_hidden_states=True``; must return ``.hidden_states``,
            a sequence of [B, T, H] tensors.
        inputs_embeds: [B, T, H] model inputs.
        attention_mask: [B, T] mask forwarded to the model.
        take_last_k: number of trailing positions to keep (all of them if T is smaller).
        layer_ids: indices into ``hidden_states``. None selects every entry except
            index 0 (the embedding output in Hugging Face models).

    Returns:
        Tensor of shape [B, n_layers, min(take_last_k, T), H].
    """
    hidden = model(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        output_hidden_states=True,
    ).hidden_states
    chosen = range(1, len(hidden)) if layer_ids is None else layer_ids
    seq_len = inputs_embeds.size(1)
    window = slice(max(0, seq_len - take_last_k), seq_len)
    return torch.stack([hidden[idx][:, window, :] for idx in chosen], dim=1)

def get_label_token_ids(tok, labels=(" A"," B"," C"," D")):
    """Return one token id per label string.

    Each label is encoded without special tokens, and the id of its last
    sub-token is used. For a label that encodes to a single token, this is
    simply that token.

    Raises:
        ValueError: if the tokenizer returns no ids for a label.
    """
    def _last_piece(text):
        pieces = tok.encode(text, add_special_tokens=False)
        if not pieces:
            raise ValueError(f"Tokenizer returns empty ids for {repr(text)}")
        return pieces[-1]

    return [_last_piece(label) for label in labels]

def _last_valid_positions(attention_mask):
    """Return, for each row of a [B, T] mask, the index of its last non-zero entry.

    Unlike ``mask.sum(1) - 1``, this is correct for any padding layout: right
    padding, left padding, or a soft prefix followed by left padding. A row
    with no non-zero entry maps to index 0.
    """
    positions = torch.arange(attention_mask.size(1), device=attention_mask.device).expand_as(attention_mask)
    masked = torch.where(attention_mask > 0, positions, torch.full_like(positions, -1))
    return masked.max(dim=1).values.clamp(min=0)


def cross_entropy_on_answer_logits(model, tok, inputs_embeds, attention_mask, answer_ids):
    """Cross-entropy of the correct option letter at the next-token position.

    Args:
        model: causal LM called with ``inputs_embeds``/``attention_mask`` and
            returning ``.logits`` of shape [B, T, V].
        tok: tokenizer used to look up the " A"/" B"/" C"/" D" token ids.
        inputs_embeds: [B, T, H] input embeddings.
        attention_mask: [B, T] mask; the prediction is read at each row's last
            non-zero position.
        answer_ids: B answer letters ("A".."D").

    Returns:
        Scalar mean cross-entropy over the four option logits.

    Raises:
        KeyError: if an answer is not one of "A", "B", "C", "D".
    """
    logits = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask).logits  # [B, T, V]
    device = logits.device
    batch_size = logits.size(0)

    # The next-token distribution sits at each row's last real token.
    last_pos = _last_valid_positions(attention_mask.to(device))  # [B]
    next_logits = logits[torch.arange(batch_size, device=device), last_pos, :]  # [B, V]

    # Keep only the four option-letter logits (leading-space variants preferred).
    abcd_ids = get_label_token_ids(tok, labels=(" A", " B", " C", " D"))
    abcd_idx = torch.tensor(abcd_ids, dtype=torch.long, device=device)  # [4]
    option_logits = next_logits.index_select(dim=1, index=abcd_idx)  # [B, 4]

    # Map letters to class indices 0..3.
    target = torch.tensor([LETTER_TO_IDX[a] for a in answer_ids], dtype=torch.long, device=device)  # [B]

    # Compute in fp32 so half-precision models do not lose precision in the log-softmax.
    return F.cross_entropy(option_logits.float(), target)

def perplexity_regularizer(model, tok, inputs_embeds, attention_mask, prefix_len):
    """Fluency regulariser: LM loss on the prompt tokens that follow the soft prefix.

    The content tokens (everything after the first ``prefix_len`` positions) are
    recovered as nearest-neighbour vocabulary ids. The model then predicts each
    content token from the full sequence *including* the soft prefix.

    Previously the loss was computed on re-embedded ids without the prefix. That
    version had no gradient w.r.t. the prefix, the only trainable parameter. Its
    mask also trained the last real token to predict padding.

    Args:
        model: causal LM exposing ``get_input_embeddings()`` and returning ``.logits``.
        tok: unused; kept for signature compatibility.
        inputs_embeds: [B, P+L, H] embeddings, soft prefix in the first P positions.
        attention_mask: [B, P+L] mask (non-zero = real token).
        prefix_len: P.

    Returns:
        Scalar mean cross-entropy over content positions where both the
        predicting token and the target token are real. Returns 0.0 when there
        is no such position (including L <= 1).
    """
    B, T, H = inputs_embeds.shape
    L = T - prefix_len
    if L <= 1:
        return torch.tensor(0.0, device=inputs_embeds.device)

    content_mask = attention_mask[:, prefix_len:]  # [B, L]
    with torch.no_grad():
        # Recover token ids for the content part (nearest vocabulary embedding).
        # Computed in fp32 because half-precision cdist is not supported everywhere.
        vocab_embeds = model.get_input_embeddings().weight  # [V, H]
        flat = inputs_embeds[:, prefix_len:, :].reshape(-1, H)  # [B*L, H]
        nearest = torch.cdist(flat.float(), vocab_embeds.to(flat.device).float()).argmin(dim=-1)
        content_ids = nearest.reshape(B, L)  # [B, L]

    # Forward the whole sequence so the loss depends on (and back-propagates to) the prefix.
    logits = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask).logits  # [B, T, V]
    # Logits at absolute position t predict token t+1. Content tokens 1..L-1 are
    # therefore predicted from positions prefix_len..T-2.
    pred_logits = logits[:, prefix_len:T - 1, :]  # [B, L-1, V]
    labels = content_ids[:, 1:].to(pred_logits.device)  # [B, L-1]
    # Both the predicting position and the target must be real tokens.
    valid = ((content_mask[:, :-1] > 0) & (content_mask[:, 1:] > 0)).to(pred_logits.device)
    n_valid = valid.sum()
    if n_valid == 0:
        return torch.tensor(0.0, device=inputs_embeds.device)

    losses = torch.nn.functional.cross_entropy(
        pred_logits.reshape(-1, pred_logits.size(-1)).float(),
        labels.reshape(-1),
        reduction="none",
    )  # [B*(L-1)]
    return (losses * valid.reshape(-1).float()).sum() / n_valid

def _iter_batches(items, bsz, steps):
    """Yield at most ``steps`` mini-batches (lists of at most ``bsz`` items) from ``items``.

    The first pass follows the order of ``items``. If ``steps`` exceeds the
    number of batches in one pass, a private copy is reshuffled with
    ``random.shuffle`` and iteration continues, so ``--steps`` optimisation
    steps are actually taken. ``items`` itself is never mutated. Nothing is
    yielded when ``items`` is empty.

    Raises:
        ValueError: if ``bsz`` is not positive.
    """
    if bsz <= 0:
        raise ValueError(f"bsz must be positive, got {bsz}")
    pool = list(items)
    produced = 0
    while pool and produced < steps:
        for start in range(0, len(pool), bsz):
            if produced >= steps:
                return
            yield pool[start:start + bsz]
            produced += 1
        if produced < steps:
            random.shuffle(pool)


def main():
    """Command-line entry point: optimise a soft prefix for a frozen causal LM.

    Each step formats a batch of multiple-choice prompts, prepends the soft
    prefix, and lets the frozen teacher greedily continue them for
    ``--match_tokens`` tokens. It then minimises
    ``alpha_rep * L_rep + beta_ppl * L_ppl + gamma_ce * L_ce`` w.r.t. the prefix only:

    * L_rep: MSE between student and teacher hidden states over the last
      ``--match_tokens`` positions of prompt + teacher continuation;
    * L_ppl: ``perplexity_regularizer`` on the prefixed prompt;
    * L_ce: cross-entropy of the answer letter right after "Answer:".

    Training takes ``--steps`` optimisation steps, cycling through the data
    (reshuffled each extra pass) when needed. Afterwards it writes
    ``soft_prefix.pt``, a nearest-token preview ``prefix.txt`` and
    ``final_predictions.txt`` into ``--save_dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--teacher", required=True)
    ap.add_argument("--qa_path", required=True)  # benign QA data only
    ap.add_argument("--bsz", type=int, required=True)
    ap.add_argument("--prefix_len", type=int, required=True)
    ap.add_argument("--steps", type=int, required=True)
    ap.add_argument("--lr", type=float, required=True)
    ap.add_argument("--layers", type=str, required=True)  # e.g. "4,8,12,20"; empty -> all transformer blocks
    ap.add_argument("--match_tokens", type=int, required=True)
    ap.add_argument("--alpha_rep", type=float, required=True)
    ap.add_argument("--beta_ppl", type=float, required=True)
    ap.add_argument("--gamma_ce", type=float, required=True)
    ap.add_argument("--save_dir", required=True)
    args = ap.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(args.model)
    tok.padding_side = "left"
    tok.truncation_side = "left"
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    # Both models are frozen; only the soft prefix is optimised.
    model = AutoModelForCausalLM.from_pretrained(args.model)
    teacher = AutoModelForCausalLM.from_pretrained(args.teacher)
    for p in model.parameters():
        p.requires_grad = False
    for p in teacher.parameters():
        p.requires_grad = False
    model.eval()
    teacher.eval()

    soft_prefix = SoftPrefix(model.config.hidden_size, args.prefix_len)
    opt = torch.optim.AdamW(soft_prefix.parameters(), lr=args.lr)

    # Layer selection: comma-separated hidden-state indices. An empty or
    # unparsable value becomes None (= every transformer block) rather than an
    # empty list, which would make torch.stack fail.
    layer_ids = [int(x) for x in args.layers.split(",") if x.strip().isdigit()] or None

    # Load the whole QA set (no subsampling) and shuffle it once.
    data = load_qa_arrow(args.qa_path)
    random.shuffle(data)

    # NOTE(review): during training the models and batches stay where
    # from_pretrained put them (no device_map is passed); `device` is only used
    # by the final prediction pass below.
    for step, batch in enumerate(_iter_batches(data, args.bsz, args.steps)):
        prompts = [format_mc_prompt(tok, it) for it in batch]
        answers = [it.answer for it in batch]

        # NOTE(review): collate_ids right-pads even though tok.padding_side is
        # "left", so in batched teacher generation shorter prompts are continued
        # after their padding — confirm this is acceptable.
        input_ids, attn = collate_ids(prompts, tok.pad_token_id)

        # Prepend the soft prefix.
        inputs_embeds, attn2 = inject_soft_prefix(model, tok, soft_prefix, input_ids, attn)

        # Teacher continuation (tmatch) defines the alignment window. It is a moving
        # target because it depends on the current prefix. (The callee already
        # runs under no_grad.)
        tmatch = generate_teacher_continuation(teacher, tok, inputs_embeds, attn2, max_new_tokens=args.match_tokens)

        # Embed tmatch and append it to the prompt.
        with torch.no_grad():
            tmatch_embeds = teacher.get_input_embeddings()(tmatch)  # [B, N, H]
        extended_inputs_embeds = torch.cat([inputs_embeds, tmatch_embeds], dim=1)  # [B, T+N, H]
        extended_attn = torch.cat([attn2, torch.ones(tmatch.size(), dtype=attn2.dtype, device=attn2.device)], dim=1)

        # Hidden states of both models on the extended sequence. Both depend on the
        # prefix, so gradients of L_rep flow through both models.
        Ha = hidden_states_for_positions(model, extended_inputs_embeds, extended_attn, take_last_k=args.match_tokens, layer_ids=layer_ids)
        Hm = hidden_states_for_positions(teacher, extended_inputs_embeds, extended_attn, take_last_k=args.match_tokens, layer_ids=layer_ids)

        # Representation-alignment loss (L2).
        L_rep = torch.nn.functional.mse_loss(Ha, Hm)

        # Fluency regulariser on the prompt part.
        L_ppl = perplexity_regularizer(model, tok, inputs_embeds, attn2, args.prefix_len)

        # CE: encourage the correct option letter right after "Answer:".
        L_ce = cross_entropy_on_answer_logits(model, tok, inputs_embeds, attn2, answers)

        loss = args.alpha_rep * L_rep + args.beta_ppl * L_ppl + args.gamma_ce * L_ce

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()

        if (step+1) % 10 == 0:
            print(f"[{step+1}] loss={loss.item():.4f}  rep={L_rep.item():.4f}  ce={L_ce.item():.4f}")

    os.makedirs(args.save_dir, exist_ok=True)
    torch.save(soft_prefix.state_dict(), os.path.join(args.save_dir, "soft_prefix.pt"))
    with torch.no_grad():
        # Map the soft prefix to its nearest vocabulary tokens; for readability only.
        emb = model.get_input_embeddings().weight  # [V, H]
        P = soft_prefix.prefix.data  # [P, H]
        nn_ids = torch.cdist(P.float().cpu(), emb.float().cpu()).argmin(dim=1).tolist()
        text = tok.decode(nn_ids)
    with open(os.path.join(args.save_dir, "prefix.txt"), "w", encoding="utf-8") as f:
        f.write(text)
    print(f"✅ Done. Saved to {args.save_dir}\nNearest-neighbor discrete prefix preview:\n{text}")

    # ================== Final Prediction and Save ==================
    print("🔍 Running final prediction on training examples...")
    pred_save_path = os.path.join(args.save_dir, "final_predictions.txt")

    model = model.to(device)
    soft_prefix = soft_prefix.to(device)
    model.eval()
    gen_cfg = GenerationConfig(max_new_tokens=1, do_sample=False)

    with open(pred_save_path, "w", encoding="utf-8") as f:
        for i, item in enumerate(data):
            input_ids, attn = collate_ids([format_mc_prompt(tok, item)], tok.pad_token_id)
            input_ids = input_ids.to(device)
            attn = attn.to(device)

            with torch.no_grad():
                inputs_embeds, attn2 = inject_soft_prefix(model, tok, soft_prefix, input_ids, attn)
                output = model.generate(inputs_embeds=inputs_embeds, attention_mask=attn2, generation_config=gen_cfg)
            # generate() with only inputs_embeds may return just the new token.
            # Slicing off attn2.shape[1] columns would then leave nothing. With
            # max_new_tokens=1 the generated token is the last column either way.
            pred_text = tok.decode(output[0, -1:], skip_special_tokens=True).strip()

            f.write(f"Example {i+1}\n")
            f.write("Normal:\n")
            f.write(f"- Question: {item.question.strip()}\n")
            f.write(f"- Options: {item.options}\n")
            f.write(f"- Ground Truth: {item.answer}\n")
            f.write(f"- Predicted: {pred_text}\n\n")

    print(f"📝 Final predictions saved to {pred_save_path}")

if __name__ == "__main__":
    # Run training, export and final prediction only when executed as a script.
    main()