import os
import json
import argparse
import torch
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from tqdm import tqdm
def mask_top_tokens_random(tokens, top_n, tokenizer, mask_token="[MASK]"):
    """Replace up to ``top_n`` randomly chosen tokens with ``mask_token``.

    Positions are drawn without replacement from NumPy's global random state,
    so results are reproducible after ``np.random.seed``.

    Args:
        tokens: List of token strings; it is not modified.
        top_n: Number of tokens to mask (capped at ``len(tokens)``).
        tokenizer: Object providing ``convert_tokens_to_string``.
        mask_token: Replacement string for the selected positions.

    Returns:
        The string produced by ``tokenizer.convert_tokens_to_string`` for the
        masked token list.
    """
    total = len(tokens)
    picked = np.random.choice(total, size=min(top_n, total), replace=False)
    chosen = set(picked.tolist())
    masked = [mask_token if pos in chosen else tok for pos, tok in enumerate(tokens)]
    return tokenizer.convert_tokens_to_string(masked)
def mask_top_tokens_margin(model, tokens, input_ids, top_n, tokenizer, mask_token="[MASK]"):
    """Mask the tokens that were predicted with the largest top-1/top-2 logit margin.

    The model is run once on ``input_ids``. For every position ``t >= 1`` the
    margin is the difference between the two largest logits at position
    ``t - 1`` (the distribution that predicts token ``t``). ``tokens[i]`` is
    treated as the token predicted at position ``i``, i.e. it pairs with
    ``input_ids[i + 1]``.

    Args:
        model: Callable taking a ``(1, seq)`` id tensor and returning an object
            whose ``.logits`` is indexed as ``[batch, position, vocab]``.
        tokens: List of token strings to mask; it is not modified.
        input_ids: 1-D tensor of token ids fed to the model.
        top_n: Number of tokens to mask. Values <= 0 mask nothing.
        tokenizer: Object providing ``convert_tokens_to_string``.
        mask_token: Replacement string for the selected positions.

    Returns:
        The string produced by ``tokenizer.convert_tokens_to_string`` for the
        masked token list.
    """
    masked_tokens = list(tokens)

    # Only positions that have both a predicting logit row and a token to mask
    # are candidates; this keeps every selected index inside ``masked_tokens``.
    n_positions = min(len(input_ids) - 1, len(masked_tokens))
    n = min(top_n, n_positions)
    if n <= 0:
        # Nothing to mask: skip the forward pass entirely.
        return tokenizer.convert_tokens_to_string(masked_tokens)

    with torch.no_grad():
        logits = model(input_ids.unsqueeze(0)).logits[0]

    # Only the two largest logits per position are needed, so a batched top-k
    # replaces a full vocabulary sort per position.
    top2 = torch.topk(logits[:n_positions], k=2, dim=-1).values
    margins = (top2[:, 0] - top2[:, 1]).tolist()

    for i in np.argsort(margins)[::-1][:n]:
        masked_tokens[i] = mask_token
    return tokenizer.convert_tokens_to_string(masked_tokens)
def compute_logp_and_value(model, tokenizer, text, beta=1.0, gamma=0.1):
    """Compute per-token log-probabilities and soft-value differences for ``text``.

    The text is tokenized and passed through ``model`` once. For every token
    after the first one this returns:

    * its log-probability under the distribution predicted at the previous
      position (``log_softmax`` of the logits, independent of ``beta``);
    * ``gamma * V[t] - V[t - 1]`` where ``V[t] = beta * logsumexp(logits[t] / beta)``.

    Args:
        model: Causal LM; called with the tokenizer output as keyword
            arguments and expected to expose ``.device`` and return ``.logits``.
        tokenizer: Callable tokenizer that also provides ``convert_ids_to_tokens``.
        text: Input string.
        beta: Temperature of the soft value ``V``; must be non-zero.
        gamma: Weight applied to ``V[t]`` in the value difference. The default
            of 0.1 reproduces the previously hard-coded constant.

    Returns:
        ``(tokens, logps, delta_vs)``: three lists of equal length
        ``len(input_ids) - 1`` describing tokens 1..end (the first token has no
        predicting position and is dropped).
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    input_ids = inputs.input_ids[0]
    with torch.no_grad():
        # Upcast so softmax / logsumexp stay accurate when the model runs in fp16.
        logits = model(**inputs).logits[0].float()

    targets = input_ids[1:].to(logits.device)

    # Row t-1 of the logits predicts token t, so drop the last row and look up
    # each realised token. Normalising turns the raw logit into a log-probability.
    log_probs = torch.log_softmax(logits[:-1], dim=-1)
    logps = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1).tolist()

    values = [beta * v for v in torch.logsumexp(logits / beta, dim=-1).tolist()]
    delta_vs = [gamma * cur - prev for prev, cur in zip(values, values[1:])]

    tokens = tokenizer.convert_ids_to_tokens(targets.tolist())
    return tokens, logps, delta_vs

def mask_top_tokens(tokens, values, top_n, tokenizer, mask_token="[MASK]"):
    """Replace the ``top_n`` highest-scoring tokens with ``mask_token``.

    Args:
        tokens: List of token strings; it is not modified.
        values: One score per token; larger scores are masked first.
        top_n: Number of tokens to mask.
        tokenizer: Object providing ``convert_tokens_to_string``.
        mask_token: Replacement string for the selected positions.

    Returns:
        The string produced by ``tokenizer.convert_tokens_to_string`` for the
        masked token list.
    """
    # Ascending argsort reversed gives positions ordered from highest score down.
    ranked_positions = np.argsort(values)[::-1]
    result = list(tokens)
    for position in ranked_positions[:top_n]:
        result[position] = mask_token
    return tokenizer.convert_tokens_to_string(result)

def score_with_rm(rm_model, rm_tokenizer, prompt):
    """Return the reward model's scalar score for ``prompt``.

    The prompt is tokenized with truncation to 1024 tokens, moved to the
    reward model's device, and scored without tracking gradients. The logits
    are squeezed to a single element and returned as a Python float.
    """
    encoded = rm_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
    encoded = encoded.to(rm_model.device)
    with torch.no_grad():
        reward_logits = rm_model(**encoded).logits
    return reward_logits.squeeze().item()

def format_prompt(instruction, input_text, answer):
    """Wrap an instruction, input and answer in a single ``[INST]`` prompt string."""
    sections = (
        f"Instruction: {instruction}",
        f"Input: {input_text}",
        f"Answer: {answer}",
    )
    return "<s>[INST] " + "\n".join(sections) + " [/INST]"

def process_sample(entry, causal_model, causal_tokenizer, rm_model, rm_tokenizer, beta, top_n):
    """Measure the reward-model score drop caused by masking answer tokens.

    For the "chosen" dialog of ``entry`` the first user turn (question) and the
    first assistant turn (answer) are extracted. The answer tokens are then
    masked according to four criteria -- "logp", "delta_v", "random" and
    "margin" -- and the reward model scores the original answer and each
    masked variant.

    Args:
        entry: Mapping with a "chosen" list of ``{"role", "content"}`` turns.
        causal_model: Causal LM used to rank tokens.
        causal_tokenizer: Tokenizer matching ``causal_model``.
        rm_model: Reward model used for scoring.
        rm_tokenizer: Tokenizer matching ``rm_model``.
        beta: Temperature forwarded to ``compute_logp_and_value``.
        top_n: Number of answer tokens to mask per criterion.

    Returns:
        A list with one dict per criterion (keys: question, kind, criterion,
        original_score, masked_score, drop, original_answer, masked_answer).
        The list is empty when the sample is skipped (missing question or
        answer, answer longer than 700 characters, or no answer tokens).
    """
    results = []

    instruction = "Evaluate the following answer."
    # NOTE(review): the question is never shown to the reward model (the
    # "Input:" field is empty) -- confirm this is intended.
    input_text = ""
    for kind in ["chosen"]:
        dialog = entry.get(kind) or []
        question = next((turn["content"] for turn in dialog if turn["role"] == "user"), "")
        answer = next((turn["content"] for turn in dialog if turn["role"] == "assistant"), "")
        if not question or not answer or len(answer) > 700:
            continue

        question_text = question.strip()
        prompt = question_text + " A: " + answer.strip()
        tokens, logps, delta_vs = compute_logp_and_value(causal_model, causal_tokenizer, prompt, beta=beta)

        # The prefix must be built from the same stripped question as the
        # prompt, otherwise the answer offset is shifted by whitespace tokens.
        # ``tokens`` omits the first input id, hence the -1.
        prefix_ids = causal_tokenizer(question_text + " A:", return_tensors="pt").input_ids[0]
        answer_start = len(prefix_ids) - 1

        answer_tokens = tokens[answer_start:]
        answer_logps = logps[answer_start:]
        answer_dvs = delta_vs[answer_start:]
        if not answer_tokens:
            # Nothing to mask for this sample.
            continue

        # Reward-model score of the unmasked answer.
        full_prompt = format_prompt(instruction, input_text, answer)
        original_score = score_with_rm(rm_model, rm_tokenizer, full_prompt)

        # Build every masked variant exactly once per sample. The random and
        # margin maskings do not depend on the logp / delta_v rankings, so
        # they are computed outside that loop.
        masked_variants = [
            (criterion, mask_top_tokens(answer_tokens, value_list, top_n, causal_tokenizer))
            for criterion, value_list in (("logp", answer_logps), ("delta_v", answer_dvs))
        ]

        # Random masking baseline.
        masked_variants.append(
            ("random", mask_top_tokens_random(answer_tokens, top_n, causal_tokenizer))
        )

        # Margin masking. NOTE(review): only the answer slice of the ids is
        # passed on, so margins are computed without the question as context
        # -- confirm this is intended.
        answer_input_ids = causal_tokenizer(prompt, return_tensors="pt").input_ids[0].to(causal_model.device)
        masked_variants.append(
            (
                "margin",
                mask_top_tokens_margin(
                    causal_model, answer_tokens, answer_input_ids[answer_start:], top_n, causal_tokenizer
                ),
            )
        )

        # Score each masked answer and record the drop relative to the original.
        for criterion, masked_answer in masked_variants:
            masked_prompt = format_prompt(instruction, input_text, masked_answer)
            masked_score = score_with_rm(rm_model, rm_tokenizer, masked_prompt)
            results.append({
                "question": question,
                "kind": kind,
                "criterion": criterion,
                "original_score": original_score,
                "masked_score": masked_score,
                "drop": original_score - masked_score,
                "original_answer": answer,
                "masked_answer": masked_answer,
            })
    return results

def main(args):
    """Run the masking experiment and report the reward-score drops.

    Loads the causal LM and the reward model, processes up to
    ``args.max_samples`` entries of the ``argilla/dpo-mix-7k`` test split,
    writes all result records to ``<output_dir>/masked_rm_results.jsonl`` and
    prints summary statistics of the score drop for each masking criterion.
    """
    # Half precision when CUDA is available, full precision otherwise.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Load the models.
    causal_tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    causal_model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map="auto",
        torch_dtype=dtype,
    )
    causal_model = causal_model.eval()

    rm_tokenizer = AutoTokenizer.from_pretrained(args.rm_path, use_fast=False)
    rm_model = AutoModelForSequenceClassification.from_pretrained(
        args.rm_path,
        device_map="auto",
        torch_dtype=dtype,
    )
    rm_model = rm_model.eval()

    dataset = load_dataset("argilla/dpo-mix-7k", split="test")
    os.makedirs(args.output_dir, exist_ok=True)

    n_samples = min(args.max_samples, len(dataset))
    all_results = []
    for i in tqdm(range(n_samples)):
        all_results.extend(
            process_sample(
                dataset[i], causal_model, causal_tokenizer, rm_model, rm_tokenizer, args.beta, args.top_n
            )
        )

    output_file = os.path.join(args.output_dir, "masked_rm_results.jsonl")
    with open(output_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in all_results)
    print(f"Saved results to {output_file}")

    def summarize(name, drops):
        # Print mean / std / max / min of the drops, or a warning if there are none.
        if not drops:
            print(f"⚠️ No drops recorded for {name}.")
            return
        drops_np = np.array(drops)
        print(f"\n📉 {name} masking results (top {args.top_n} tokens):")
        print(f"- Mean drop: {drops_np.mean():.4f}")
        print(f"- Std  drop: {drops_np.std():.4f}")
        print(f"- Max  drop: {drops_np.max():.4f}")
        print(f"- Min  drop: {drops_np.min():.4f}")

    # Report the average drop per criterion.
    drops_by_criterion = {}
    for r in all_results:
        drops_by_criterion.setdefault(r["criterion"], []).append(r["drop"])

    for name, criterion in (
        ("LogP", "logp"),
        ("ΔV", "delta_v"),
        ("Random", "random"),
        ("Margin", "margin"),
    ):
        summarize(name, drops_by_criterion.get(criterion, []))
if __name__ == "__main__":
    # Command-line entry point: parse the options, echo the key settings, run.
    cli = argparse.ArgumentParser()
    cli.add_argument("--model_path", type=str, required=True)
    cli.add_argument("--rm_path", type=str, default="sfairXC/FsfairX-LLaMA3-RM-v0.1")
    cli.add_argument("--output_dir", type=str, default="masked_results")
    cli.add_argument("--beta", type=float, default=1)
    cli.add_argument("--top_n", type=int, default=5)
    cli.add_argument("--max_samples", type=int, default=100)
    args = cli.parse_args()
    for option in ("beta", "top_n", "max_samples"):
        print(f"{option}:", getattr(args, option))
    main(args)