import os
import pandas as pd
import numpy as np
from tqdm import tqdm

def load_logp(csv_path):
    """Read the CSV at ``csv_path`` and return its ``logp`` column as a list.

    Raises whatever ``pandas.read_csv`` raises for an unreadable file, and
    ``KeyError`` when the file has no ``logp`` column.
    """
    return pd.read_csv(csv_path)["logp"].tolist()

def softmax(logits):
    """Map a sequence of scores to exponentiated, (almost) normalised weights.

    The maximum score is subtracted before exponentiating so that large
    inputs do not overflow. The denominator carries an extra 1e-8, so the
    returned array sums to marginally less than 1.
    """
    scores = np.array(logits)
    shifted = scores - np.max(scores)
    weights = np.exp(shifted)
    total = np.sum(weights) + 1e-8
    return weights / total

def kl_divergence(p, q):
    """Return ``sum(p * (log p - log q))`` after adding 1e-8 to every entry.

    The epsilon keeps both logarithms finite when an entry is zero. The
    inputs are not renormalised after the epsilon is added.
    """
    eps = 1e-8
    p_smooth = np.array(p) + eps
    q_smooth = np.array(q) + eps
    log_ratio = np.log(p_smooth) - np.log(q_smooth)
    return np.sum(p_smooth * log_ratio)

def process_sample(index, student_dir, teacher_dir):
    """Compute the student margin and the two KL terms for one sample.

    Four CSV files are read, ``token_vis_{index:03d}_chosen.csv`` and
    ``token_vis_{index:03d}_rejected.csv`` from each of ``student_dir`` and
    ``teacher_dir``; the ``logp`` column of each is used.

    Args:
        index: Sample number used to build the file names.
        student_dir: Directory holding the student's CSV files.
        teacher_dir: Directory holding the teacher's CSV files.

    Returns:
        A dict with keys ``index``, ``margin``, ``kl_chosen`` and
        ``kl_rejected``, or ``None`` when the sample cannot be processed
        (missing file, missing column, empty sequence, ...).
    """
    import logging  # local import so this block does not depend on file-level imports

    def _csv(directory, kind):
        # Build the path of one per-sample CSV file.
        return os.path.join(directory, f"token_vis_{index:03d}_{kind}.csv")

    try:
        # Load the per-token log-probabilities.
        lp_sc = load_logp(_csv(student_dir, "chosen"))
        lp_sr = load_logp(_csv(student_dir, "rejected"))
        lp_tc = load_logp(_csv(teacher_dir, "chosen"))
        lp_tr = load_logp(_csv(teacher_dir, "rejected"))

        # An empty sequence would make the averages below divide by zero.
        if not (lp_sc and lp_sr and lp_tc and lp_tr):
            raise ValueError("empty logp sequence")

        # Margin: mean per-token logp of the student's chosen response minus
        # that of its rejected response (length-normalised).
        margin = float(np.sum(lp_sc) / len(lp_sc) - np.sum(lp_sr) / len(lp_sr))

        # KL divergence, teacher distribution first. Both sides are cut to
        # their common length; cutting only the teacher made a shorter
        # teacher sequence fail with a shape error (or broadcast silently
        # when it had exactly one entry).
        n_chosen = min(len(lp_tc), len(lp_sc))
        n_rejected = min(len(lp_tr), len(lp_sr))
        kl_chosen = float(
            kl_divergence(softmax(lp_tc[:n_chosen]), softmax(lp_sc[:n_chosen]))
        )
        kl_rejected = float(
            kl_divergence(softmax(lp_tr[:n_rejected]), softmax(lp_sr[:n_rejected]))
        )

        return {
            "index": index,
            "margin": margin,
            "kl_chosen": kl_chosen,
            "kl_rejected": kl_rejected,
        }

    except FileNotFoundError:
        # A missing file is treated as "no such sample" and skipped quietly.
        return None
    except Exception as exc:
        # Still best-effort, but report unexpected failures instead of
        # hiding them.
        logging.getLogger(__name__).warning("[%03d] skipped: %s", index, exc)
        return None
def analyze_student_vs_teacher(student_dir, teacher_dir, max_samples=1000, save_path=None):
    """Run ``process_sample`` for indices ``0 .. max_samples - 1`` and summarise.

    Samples for which ``process_sample`` returns a falsy value are left out.
    The remaining rows are put in a DataFrame, optionally written to
    ``save_path`` as CSV, and, when at least one row exists, summary
    statistics are printed.

    Args:
        student_dir: Directory with the student's per-sample CSV files.
        teacher_dir: Directory with the teacher's per-sample CSV files.
        max_samples: Number of sample indices to try.
        save_path: CSV output path; nothing is written when falsy.

    Returns:
        The DataFrame of per-sample results (possibly empty).
    """
    attempts = (
        process_sample(sample_idx, student_dir, teacher_dir)
        for sample_idx in tqdm(range(max_samples))
    )
    rows = [row for row in attempts if row]
    df = pd.DataFrame(rows)

    # Write the per-sample table to disk when a destination was given.
    if save_path:
        df.to_csv(save_path, index=False)
        print(f"✅ Saved: {save_path}")

    # Nothing to summarise without rows.
    if df.empty:
        return df

    num_samples = len(df)
    # Share of samples whose margin is strictly positive.
    margin_accuracy = (df["margin"] > 0).sum() / num_samples
    mean_margin = df["margin"].mean()
    mean_kl_chosen = df["kl_chosen"].mean()
    mean_kl_rejected = df["kl_rejected"].mean()

    print("\n📊 Overall Statistics")
    print(f"→ # Samples           : {num_samples}")
    print(f"→ Mean Margin         : {mean_margin:.4f}")
    print(f"→ Margin Accuracy     : {margin_accuracy:.4f}")
    print(f"→ Mean KL (chosen)    : {mean_kl_chosen:.4f}")
    print(f"→ Mean KL (rejected)  : {mean_kl_rejected:.4f}")

    return df
if __name__ == "__main__":
    # Script entry point: compare each listed student run against the teacher
    # run and write one result CSV per student into the base directory.
    base_dir = "token_visualizations"
    teacher_dir = os.path.join(base_dir, "dpo_teacher")

    # Alternative student sets, kept for reference (disabled):
    # student_names = ["DPO", "TPKD", "dckd","teacher","dpkd","TPKD_entropy","dpo_with_distil"]
    # student_names = ["teacher_dpo_train","TPKD_train","dckd_train","dpo_with_distil_train","dpkd_train","TPKD_entropy_train","DPO_train"]
    # student_names = ["tpkd_alpha_7e"]
    student_names = [
        "TPKD_margin",
        "TPKD_max",
        "TPKD_expectation",
        "TPKD_sum",
        "dpo_teacher",
        "dpkd",
    ]

    for student in student_names:
        # Each student's CSVs live in a sub-directory of base_dir
        # (disabled alternative: use the name itself as the directory).
        student_dir = os.path.join(base_dir, student)
        output_csv = os.path.join(base_dir, f"{student}_vs_teacher_results.csv")
        print(f"\n🔍 Processing: {student}")
        analyze_student_vs_teacher(
            student_dir,
            teacher_dir,
            max_samples=1000,
            save_path=output_csv,
        )