import os
import warnings

import numpy as np
import pandas as pd
from tqdm import tqdm

def load_logp(csv_path):
    """Read the ``logp`` column of a CSV file as a Python list.

    Args:
        csv_path: Path to a CSV file that has a ``logp`` column.

    Returns:
        The values of the ``logp`` column, in file order.

    Raises:
        FileNotFoundError: If *csv_path* does not exist.
        KeyError: If the file has no ``logp`` column.
    """
    df = pd.read_csv(csv_path)
    return df["logp"].tolist()


def process_margin_only(index, student_dir):
    """Compute the mean-logp margin between a chosen and a rejected response.

    Reads ``token_vis_{index:03d}_chosen.csv`` and
    ``token_vis_{index:03d}_rejected.csv`` from *student_dir* and computes
    ``mean(chosen logp) - mean(rejected logp)``.

    Args:
        index: Sample index; formatted as a zero-padded 3-digit number in the
            file names.
        student_dir: Directory containing the per-sample CSV files.

    Returns:
        ``{"index": index, "margin": margin}`` on success, or ``None`` when the
        sample cannot be scored. Missing files are skipped silently. Other
        failures (unreadable file, missing ``logp`` column, empty ``logp``
        column, non-numeric data) emit a ``UserWarning`` so they are not
        hidden.
    """
    sp_c = os.path.join(student_dir, f"token_vis_{index:03d}_chosen.csv")
    sp_r = os.path.join(student_dir, f"token_vis_{index:03d}_rejected.csv")

    try:
        lp_sc = load_logp(sp_c)
        lp_sr = load_logp(sp_r)

        if not lp_sc or not lp_sr:
            # np.mean([]) is NaN; a NaN margin would be counted as
            # "incorrect" by the accuracy metric yet ignored by the mean.
            warnings.warn(f"Skipping sample {index}: empty logp column")
            return None

        # Mean logp difference (margin): positive when the chosen response
        # has a higher mean logp than the rejected one.
        margin = float(np.mean(lp_sc) - np.mean(lp_sr))
    except FileNotFoundError:
        # Expected: callers probe a fixed index range that may exceed the
        # number of exported samples.
        return None
    except Exception as e:
        # Best-effort per sample, but surface the cause instead of hiding it.
        warnings.warn(f"Skipping sample {index}: {type(e).__name__}: {e}")
        return None

    return {"index": index, "margin": margin}

def analyze_margin_only(student_dir, max_samples=1000, save_path=None):
    """Score samples ``0 .. max_samples-1`` and print margin statistics.

    Each index is scored with ``process_margin_only``; indices for which it
    returns ``None`` are skipped.

    Args:
        student_dir: Directory containing the per-sample
            ``token_vis_{index:03d}_{chosen,rejected}.csv`` files.
        max_samples: Number of indices to probe, starting at 0.
        save_path: If truthy, the per-sample results are written there as CSV
            (without the DataFrame index).

    Returns:
        A DataFrame with columns ``index`` and ``margin``; it is empty (but
        still has both columns) when no sample could be scored.
    """
    results = []
    for i in tqdm(range(max_samples)):
        result = process_margin_only(i, student_dir)
        if result is not None:
            results.append(result)

    # Fixing the columns keeps the saved CSV readable (it still has a header)
    # and lets callers index df["margin"] even when nothing was scored.
    df = pd.DataFrame(results, columns=["index", "margin"])

    if save_path:
        df.to_csv(save_path, index=False)
        print(f"✅ Saved: {save_path}")

    if df.empty:
        print(f"⚠️ No scorable samples found in {student_dir}")
        return df

    num_samples = len(df)
    # Fraction of samples whose chosen response has the higher mean logp.
    margin_accuracy = (df["margin"] > 0).mean()

    print("\n📊 Margin-Only Statistics")
    print(f"→ # Samples       : {num_samples}")
    print(f"→ Mean Margin     : {df['margin'].mean():.4f}")
    print(f"→ Margin Accuracy : {margin_accuracy:.4f}")

    return df

if __name__ == "__main__":
    # Root folder with one sub-directory of per-sample CSVs per student model.
    base_dir = "token_visualizations"

    # Other runs evaluated previously (swap in as needed): "TPKD_sum",
    # "TPKD_margin", "TPKD_max", "TPKD_expectation", "dpo_teacher", "dpkd".
    student_names = ["dckd_exp"]

    for student in student_names:
        student_dir = os.path.join(base_dir, student)
        output_csv = os.path.join(base_dir, f"{student}_margin_only.csv")
        print(f"\n🔍 Evaluating Margin Only: {student}")
        analyze_margin_only(
            student_dir,
            max_samples=1000,
            save_path=output_csv,
        )