from utils import load_single_dataset, save_dataset
import datasets
import torch
from tqdm import tqdm
from statistics import mean
import argparse
import numpy as np
from sklearn.metrics import roc_curve


from typing import List, Tuple
from sklearn.metrics import roc_curve, accuracy_score, f1_score
from scipy.stats import pearsonr
import numpy as np

def evaluate_scores(positive_scores: List[float], negative_scores: List[float]) -> Tuple[float, float, float]:
    """Measure how well a scalar score separates positive from negative samples.

    Positive samples get label 1 and negative samples get label 0. The function
    computes the Pearson correlation between score and label (the point-biserial
    correlation). It then picks the decision threshold that maximises Youden's J
    statistic (TPR - FPR) on the ROC curve, and reports accuracy and F1 of the
    resulting hard classifier.

    The score direction is chosen from the sign of the correlation. If scores
    are negatively correlated with the label, "lower score => positive" is used.
    The ROC curve is computed on the oriented scores, so the chosen threshold
    is optimal for the direction actually used to classify.

    Args:
        positive_scores: scores of samples whose true label is positive.
        negative_scores: scores of samples whose true label is negative.

    Returns:
        ``(corr, acc, f1)`` as floats. ``corr`` may be NaN when the scores are
        constant, in which case the natural orientation (higher => positive)
        is used.
    """
    # 1. Build flat score / label arrays. list() guards against numpy-array
    #    inputs, where "+" would add elementwise instead of concatenating.
    scores = np.asarray(list(positive_scores) + list(negative_scores), dtype=float)
    labels = np.array([1] * len(positive_scores) + [0] * len(negative_scores))

    # 2. Pearson correlation between score and binary label.
    corr, _ = pearsonr(scores, labels)

    # 3. Orient scores so that larger always means "more positive". roc_curve
    #    assumes that direction. Picking a threshold on the un-flipped ROC and
    #    then classifying with "<" gives a meaningless threshold (typically +inf,
    #    i.e. everything predicted positive).
    oriented = -scores if corr < 0 else scores

    # 4. Best threshold by Youden's J statistic on the oriented scores.
    fpr, tpr, thresholds = roc_curve(labels, oriented)
    best_threshold = thresholds[np.argmax(tpr - fpr)]
    predictions = (oriented >= best_threshold).astype(int)

    # 5. Accuracy and F1 of the thresholded classifier.
    acc = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions)

    return float(corr), float(acc), float(f1)


def plot_distributions(pos_scores, neg_scores, filename="dist_log.png"):
    """Plot KDE curves of positive and negative scores and save them to an image.

    The y-axis uses a logarithmic scale with base 1.1, as in the original
    plot settings.

    Args:
        pos_scores: score values of positive samples.
        neg_scores: score values of negative samples.
        filename: output image path passed to ``savefig``.
    """
    # Imported lazily so the rest of the script works without plotting libraries.
    import matplotlib.pyplot as plt
    import seaborn as sns

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        # KDE curves. `fill=` replaces the `shade=` keyword, which seaborn
        # deprecated in 0.11 and scheduled for removal.
        sns.kdeplot(pos_scores, label="Positive", fill=True, color="green", linewidth=2, ax=ax)
        sns.kdeplot(neg_scores, label="Negative", fill=True, color="red", linewidth=2, ax=ax)

        # Logarithmic y-axis.
        ax.set_yscale("log", base=1.1)

        # Labels and decoration.
        ax.set_title("Score Distributions (Log Scale)")
        ax.set_xlabel("Score")
        ax.set_ylabel("Density (log)")
        ax.legend()
        ax.grid(True, which='both', linestyle='--', linewidth=0.5)

        # Save the figure.
        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` element-wise.

    Evaluated as ``exp(-logaddexp(0, -x))``, which is mathematically identical
    but cannot overflow. The naive form evaluates ``exp(1000)`` for
    ``x = -1000`` and emits an overflow RuntimeWarning.

    Args:
        x: real scalar or array-like.

    Returns:
        NumPy scalar or array of the same shape, with values in [0, 1].
    """
    # logaddexp(0, -x) == log(1 + exp(-x)), computed without overflow.
    return np.exp(-np.logaddexp(0.0, np.negative(x)))



def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this evaluation script."""
    parser = argparse.ArgumentParser(
        description='Evaluate how well DPO-vs-reference log-prob ratios separate '
                    'positive from negative samples.'
    )
    parser.add_argument('--dpo_ds', required=True,
                        help='dataset holding the DPO-model token log-probs and a "score" column')
    parser.add_argument('--ref_ds', required=True,
                        help='dataset holding the reference-model token log-probs, row-aligned with --dpo_ds')
    parser.add_argument('--dpo_logp_name', required=True,
                        help='column in --dpo_ds with per-token log-probs')
    parser.add_argument('--ref_logp_name', required=True,
                        help='column in --ref_ds with per-token log-probs')
    # NOTE(review): accepted for CLI compatibility but not read anywhere below;
    # both seq_mean and seq_sum are always reported.
    parser.add_argument('--seq_reward_cal', required=False, default="mean")
    return parser.parse_args()


def _collect_scores(
    dpo_ds: datasets.Dataset,
    ref_ds: datasets.Dataset,
    dpo_logp_name: str,
    ref_logp_name: str,
) -> dict:
    """Build per-metric (positive, negative) score lists from row-aligned datasets.

    Row ``i`` of ``dpo_ds`` is paired with row ``i`` of ``ref_ds``. The per-token
    score is ``dpo_logp - ref_logp``. Rows with ``row["score"] > 0`` go to the
    positive list; all others go to the negative list.

    Token-level metrics contribute one value per token. Sequence-level metrics
    contribute one value per row.

    Raises:
        ValueError: if the two datasets have different row counts.
    """
    import warnings

    if len(dpo_ds) != len(ref_ds):
        raise ValueError(
            f"--dpo_ds has {len(dpo_ds)} rows but --ref_ds has {len(ref_ds)}; "
            "the datasets must be row-aligned"
        )

    # metric name -> (scores of positive samples, scores of negative samples)
    scores = {
        "tok_logpratio": ([], []),
        "tok_cumsum": ([], []),
        "tok_cumsummean": ([], []),
        "tok_sigmoidcumsummean": ([], []),
        "seq_mean": ([], []),
        "seq_sum": ([], []),
    }

    # Read the reference column directly instead of add_column() onto dpo_ds.
    # This also works when both log-prob columns share the same name.
    ref_column = ref_ds[ref_logp_name]
    n_truncated = 0
    n_empty = 0
    for row, ref_logps in tqdm(zip(dpo_ds, ref_column), total=len(dpo_ds)):
        dpo_logps = row[dpo_logp_name]
        # Pair tokens up to the shorter sequence (as zip() did before), but
        # keep count so misalignment is reported instead of silently ignored.
        n_tok = min(len(dpo_logps), len(ref_logps))
        if len(dpo_logps) != len(ref_logps):
            n_truncated += 1
        if n_tok == 0:
            # An empty response would give NaN mean/sum and poison pearsonr.
            n_empty += 1
            continue

        tok_scores = (np.asarray(dpo_logps[:n_tok], dtype=float)
                      - np.asarray(ref_logps[:n_tok], dtype=float))
        cumulative_sum = np.cumsum(tok_scores)
        # Running mean of the log-ratio up to each token position.
        cumulative_summean = cumulative_sum / np.arange(1, n_tok + 1)

        bucket = 0 if row["score"] > 0 else 1
        token_metrics = {
            "tok_logpratio": tok_scores,
            "tok_cumsum": cumulative_sum,
            "tok_cumsummean": cumulative_summean,
            "tok_sigmoidcumsummean": sigmoid(cumulative_summean),
        }
        for name, values in token_metrics.items():
            scores[name][bucket].extend(values.tolist())
        scores["seq_mean"][bucket].append(float(np.mean(tok_scores)))
        scores["seq_sum"][bucket].append(float(np.sum(tok_scores)))

    if n_truncated:
        warnings.warn(f"{n_truncated} rows had DPO/reference token sequences of "
                      "different lengths; they were truncated to the shorter one")
    if n_empty:
        warnings.warn(f"{n_empty} rows had no tokens and were skipped")
    return scores


def main() -> None:
    """Load both datasets, collect the scores and print one metric line per score type."""
    args = _parse_args()
    dpo_ds: datasets.Dataset = load_single_dataset(args.dpo_ds)
    ref_ds: datasets.Dataset = load_single_dataset(args.ref_ds)

    scores = _collect_scores(dpo_ds, ref_ds, args.dpo_logp_name, args.ref_logp_name)
    for k, (pos, neg) in scores.items():
        corr, acc, f1 = evaluate_scores(pos, neg)
        print(k, f"corr: {corr}, acc: {acc}, f1: {f1}")


if __name__ == "__main__":
    main()

