import os

# Route Hugging Face downloads through a mirror. This is deliberately done
# before the project imports below: presumably bert_eval loads models from the
# Hugging Face Hub, and HF clients may read HF_ENDPOINT at import time — TODO confirm.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
print("HF_ENDPOINT:", os.environ.get("HF_ENDPOINT"))
import json
from pathlib import Path
from collections import defaultdict
from pprint import pprint
from tqdm import tqdm
from bert_eval import semantic_evaluation
from split_en_zh import split_chinese_english

# When True, each reference/prediction pair is printed and the script pauses
# for manual inspection during the similarity pass.
debug = False


# Directory containing the "*results*.jsonl" files to analyze. The raw env var
# is checked first: wrapping it as Path(str(...)) would turn an unset variable
# into Path("None"), which is always truthy, so the check could never fire.
_log_dir_env = os.environ.get("LOG_DIR")
if not _log_dir_env:
    raise ValueError("LOG_DIR 环境变量未设置")
log_dir = Path(_log_dir_env)

# The evaluation language is inferred from a path component named "en" or "zh"
# ("en" wins if both are present).
parts = set(log_dir.parts)
if "en" in parts:
    lang = "en"
elif "zh" in parts:
    lang = "zh"
else:
    raise ValueError(f"log_dir: {log_dir} 不合法")


# Load every "*results*.jsonl" file in log_dir. Files are visited in sorted
# order because Path.glob yields entries in arbitrary, filesystem-dependent
# order, which would make the "first record wins" de-duplication below
# nondeterministic across machines.
result_all = []
for file in sorted(log_dir.glob("*results*.jsonl")):
    with open(file, "r", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
        result_all.extend(json.loads(line) for line in f if line.strip())

# Fail early with a clear message instead of a ZeroDivisionError in the report.
if not result_all:
    raise ValueError(f"log_dir: {log_dir} 中没有找到任何 *results*.jsonl 结果")


# Keep only the first record seen for each sample id. The id is taken as
# data["id"][0]; presumably data["id"] is a one-element list — verify against
# the producer of these files.
seen_ids = set()
result_all_clean = []
for result in result_all:
    current_id = result["data"]["id"][0]
    if current_id in seen_ids:
        continue
    result_all_clean.append(result)
    seen_ids.add(current_id)

result_all = result_all_clean
print("去重后数据样本数量：", len(result_all))
print("数据样本示例：")
pprint(result_all[0])


print("\n" + "=" * 80)
print("分析结果")
print("=" * 80)


def analyze_results(results):
    """Aggregate accuracy statistics over paired positive/negative answers.

    Each element of *results* must provide a ``"dataset"`` name and a
    ``"results"`` list of exactly two answers. Every answer provides
    ``"label"``, ``"predicted_label"``, ``"pos_neg_type"`` and ``"score"``.

    Per-answer accuracy is accumulated from the stored ``score`` field. A
    mismatch with ``label == predicted_label`` only triggers a warning. The
    positive/negative tallies, by contrast, recompute correctness from
    ``label == predicted_label``.

    Returns:
        A 4-tuple of tuples:
          1. per answer: (score_sum, answer_count, score_by_dataset, count_by_dataset)
          2. per sample: (samples_ok, sample_count, ok_by_dataset, count_by_dataset),
             where a sample is ok only if both its positive and negative answers
             are correct
          3. by type:    (pos_correct, pos_total, neg_correct, neg_total)
          4. breakdown:  (both_correct, only_pos_wrong, only_neg_wrong,
                          both_wrong, number_of_samples)

    Raises:
        ValueError: if a sample does not contain exactly two answers.
    """
    score_all = count_all = 0
    score_by_ds = defaultdict(int)
    count_by_ds = defaultdict(int)

    sample_score_all = sample_count_all = 0
    sample_score_by_ds = defaultdict(int)
    sample_count_by_ds = defaultdict(int)

    pos_correct = pos_total = 0
    neg_correct = neg_total = 0

    both_correct = both_wrong = only_pos_wrong = only_neg_wrong = 0

    for result in tqdm(results, desc="分析中", unit="样本"):
        dataset = result["dataset"]
        anss = result["results"]
        # Explicit check: an `assert` would be stripped under `python -O`.
        if len(anss) != 2:
            raise ValueError(f"每个样本应该有2个结果 (实际{len(anss)}个)")

        pos_ok = neg_ok = False

        for ans in anss:
            label, pred, ptype, score = ans["label"], ans["predicted_label"], ans["pos_neg_type"], ans["score"]
            is_correct = label == pred
            expected_score = 1 if is_correct else 0
            if score != expected_score:
                print(f"[警告] score字段不正确: id={result['data']['id']}, 期望{expected_score}, 实际{score}")

            # Per-answer totals use the stored score, overall and per dataset.
            score_all += score
            count_all += 1
            score_by_ds[dataset] += score
            count_by_ds[dataset] += 1

            # Positive/negative tallies use the recomputed correctness.
            if ptype == "positive":
                pos_total += 1
                if is_correct:
                    pos_correct += 1
                    pos_ok = True
            elif ptype == "negative":
                neg_total += 1
                if is_correct:
                    neg_correct += 1
                    neg_ok = True
            else:
                # Such answers can never make the sample "both correct"; surface it.
                print(f"[警告] 未知的pos_neg_type: id={result['data']['id']}, pos_neg_type={ptype!r}")

        # Sample level: correct only when both halves of the pair are correct.
        sample_ok = pos_ok and neg_ok
        sample_score_all += int(sample_ok)
        sample_count_all += 1
        sample_score_by_ds[dataset] += int(sample_ok)
        sample_count_by_ds[dataset] += 1

        # Four-way error breakdown.
        if sample_ok:
            both_correct += 1
        elif not pos_ok and not neg_ok:
            both_wrong += 1
        elif neg_ok:
            only_pos_wrong += 1
        else:
            only_neg_wrong += 1

    return (
        (score_all, count_all, score_by_ds, count_by_ds),
        (sample_score_all, sample_count_all, sample_score_by_ds, sample_count_by_ds),
        (pos_correct, pos_total, neg_correct, neg_total),
        (both_correct, only_pos_wrong, only_neg_wrong, both_wrong, len(results)),
    )



# Unpack the four groups of aggregates returned by analyze_results().
(
    (score_all, count_all, score_by_ds, count_by_ds),
    (sample_score_all, sample_count_all, sample_score_by_ds, sample_count_by_ds),
    (pos_correct, pos_total, neg_correct, neg_total),
    (both_correct, only_pos_wrong, only_neg_wrong, both_wrong, total_samples),
) = analyze_results(result_all)


def _fmt_ratio(num, den):
    """Format num/den with 4 decimals, or "N/A" when den is 0 (e.g. no positive answers)."""
    return f"{num / den:.4f}" if den else "N/A"


print("\n1. 逐结果分析：")
print(f"总体准确率: {_fmt_ratio(score_all, count_all)} ({score_all}/{count_all})")
for ds in sorted(score_by_ds):
    print(f"  {ds}: {_fmt_ratio(score_by_ds[ds], count_by_ds[ds])} ({score_by_ds[ds]}/{count_by_ds[ds]})")


print("\n2. 样本级别分析：")
print(f"样本级别总体准确率: {_fmt_ratio(sample_score_all, sample_count_all)} ({sample_score_all}/{sample_count_all})")
for ds in sorted(sample_score_by_ds):
    print(f"  {ds}: {_fmt_ratio(sample_score_by_ds[ds], sample_count_by_ds[ds])} ({sample_score_by_ds[ds]}/{sample_count_by_ds[ds]})")


print("\n3. Positive 和 Negative 分析：")
print(f"Positive准确率: {_fmt_ratio(pos_correct, pos_total)} ({pos_correct}/{pos_total})")
print(f"Negative准确率: {_fmt_ratio(neg_correct, neg_total)} ({neg_correct}/{neg_total})")


print("\n4. 样本级别错误分析：")
print(f"两个都正确: {both_correct} ({_fmt_ratio(both_correct, total_samples)})")
print(f"只有positive错误: {only_pos_wrong} ({_fmt_ratio(only_pos_wrong, total_samples)})")
print(f"只有negative错误: {only_neg_wrong} ({_fmt_ratio(only_neg_wrong, total_samples)})")
print(f"两个都错误: {both_wrong} ({_fmt_ratio(both_wrong, total_samples)})")
print(f"总样本数: {total_samples}")

print("\n" + "=" * 80)
print("分析完成")
print("=" * 80)





# Reference explanations: one JSON object per line with "id",
# "positive_response" and "negative_response" fields.
gt_path = "gpt4o_responses/geo_explain_test_0825.jsonl"
with open(gt_path, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline) instead of failing in json.loads.
    gt_all_data = [json.loads(line) for line in f if line.strip()]

# gt id -> reference texts for the current `lang` only. Each reference gets a
# verdict tag prefix: "[[1]]" for the positive response, "[[0]]" for the
# negative one.
gt_all_ids = {}
for gt_data in gt_all_data:
    gt_id = gt_data["id"]
    pos_gt = gt_data["positive_response"]
    neg_gt = gt_data["negative_response"]
    chinese_pos, english_pos = split_chinese_english(pos_gt)
    chinese_neg, english_neg = split_chinese_english(neg_gt)
    chinese_pos, english_pos = chinese_pos.strip(), english_pos.strip()
    chinese_neg, english_neg = chinese_neg.strip(), english_neg.strip()
    # Validate *before* adding the verdict tag. Once prefixed, the strings are
    # never empty, so a check placed after prefixing could never fail.
    if not (chinese_pos and english_pos):
        raise ValueError(f"pos_gt: {pos_gt}")
    if not (chinese_neg and english_neg):
        raise ValueError(f"neg_gt: {neg_gt}")
    if lang == "zh":
        gt_all_ids[gt_id] = {
            "chinese_pos": "[[1]]\n\n" + chinese_pos,
            "chinese_neg": "[[0]]\n\n" + chinese_neg,
        }
    else:
        gt_all_ids[gt_id] = {
            "english_pos": "[[1]]\n\n" + english_pos,
            "english_neg": "[[0]]\n\n" + english_neg,
        }

print("\n" + "=" * 80)
print("语义相似度分析")
print("=" * 80)


def _sbert_similarity(reference, prediction):
    """Return the Sentence-BERT similarity of *prediction* to *reference*, or None.

    None is returned when there is no prediction to score, when
    semantic_evaluation() returns a falsy result, or when the similarity value
    itself is None. Callers can therefore append non-None results directly.
    """
    if prediction is None:
        return None
    evaluation = semantic_evaluation(reference, prediction, method="sbert", lang=lang)
    if not evaluation:
        return None
    return evaluation["Sentence-BERT Similarity"]


pos_scores, neg_scores, all_scores = [], [], []

for result in tqdm(result_all, desc="计算语义相似度", unit="样本"):
    rid = result["data"]["id"][0]
    if rid not in gt_all_ids:
        print(f"[警告] id={rid} 不在gt_all_ids中，跳过")
        continue

    gt_data = gt_all_ids[rid]

    # Pair each model output with its reference by gold label: label 1 is
    # compared with the "[[1]]"-tagged positive reference, label 0 with the
    # "[[0]]"-tagged negative one.
    pred_pos = pred_neg = None
    for ans in result["results"]:
        if ans["label"] == 1:
            pred_pos = ans["output_text"]
        elif ans["label"] == 0:
            pred_neg = ans["output_text"]
    if pred_pos is None or pred_neg is None:
        print(f"[警告] id={rid} 缺少label为1或0的预测结果，缺失部分不计分")

    if lang == "zh":
        gt_pos, gt_neg = gt_data["chinese_pos"], gt_data["chinese_neg"]
    else:
        gt_pos, gt_neg = gt_data["english_pos"], gt_data["english_neg"]

    if debug:
        print(f"gt_pos: {gt_pos}")
        print(f"pred_pos: {pred_pos}")
        print(f"gt_neg: {gt_neg}")
        print(f"pred_neg: {pred_neg}")
        input("Press Enter to continue...")

    pos_sim = _sbert_similarity(gt_pos, pred_pos)
    neg_sim = _sbert_similarity(gt_neg, pred_neg)

    # Only real numbers are collected, so later sum()/mean computations are safe.
    if pos_sim is not None:
        pos_scores.append(pos_sim)
        all_scores.append(pos_sim)
    if neg_sim is not None:
        neg_scores.append(neg_sim)
        all_scores.append(neg_sim)


def calculate_stats(scores, name):
    """Print summary statistics and a histogram for a list of similarity scores.

    ``None`` entries are ignored. Scores are bucketed into five 0.2-wide bins
    over [0, 1]; the last bin is closed on the right. Cosine similarity can
    fall outside [0, 1] (e.g. negative values). Any such scores are reported on
    a separate line rather than silently vanishing from the histogram.

    Args:
        scores: iterable of numeric scores, possibly containing ``None``.
        name: label used in the printed header.

    Returns:
        A dict with ``count``, ``mean``, ``median``, ``min``, ``max``,
        ``buckets`` (``{(low, high): count}``) and ``out_of_range``. Returns
        ``None`` (and prints nothing) when there are no usable scores.
    """
    import statistics

    valid = [s for s in (scores or []) if s is not None]
    if not valid:
        return None

    n = len(valid)
    mean_score = sum(valid) / n
    median_score = statistics.median(valid)
    min_score = min(valid)
    max_score = max(valid)

    print(f"\n{name} 语义相似度统计:")
    print(f"  样本数量: {n}")
    print(f"  平均分数: {mean_score:.4f}")
    print(f"  中位数: {median_score:.4f}")
    print(f"  最低分数: {min_score:.4f}")
    print(f"  最高分数: {max_score:.4f}")

    # Histogram over [0, 1]; only the last bin includes its upper bound.
    score_ranges = [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)]
    buckets = {}
    for low, high in score_ranges:
        if high == 1.0:
            count = sum(1 for s in valid if low <= s <= high)
        else:
            count = sum(1 for s in valid if low <= s < high)
        buckets[(low, high)] = count
        print(f"  {low:.1f}-{high:.1f}: {count} 样本 ({count / n * 100:.1f}%)")

    # The bins are disjoint and cover [0, 1], so whatever is left lies outside it.
    out_of_range = n - sum(buckets.values())
    if out_of_range:
        print(f"  [0,1]区间外: {out_of_range} 样本 ({out_of_range / n * 100:.1f}%)")

    return {
        "count": n,
        "mean": mean_score,
        "median": median_score,
        "min": min_score,
        "max": max_score,
        "buckets": buckets,
        "out_of_range": out_of_range,
    }



calculate_stats(pos_scores, "Positive")
calculate_stats(neg_scores, "Negative")
calculate_stats(all_scores, "整体")


# Compare mean similarity of positive vs negative references. None entries are
# dropped first, the same way calculate_stats() treats them, so a missing
# similarity value cannot make sum() raise TypeError.
pos_valid = [s for s in pos_scores if s is not None]
neg_valid = [s for s in neg_scores if s is not None]
if pos_valid and neg_valid:
    pos_mean = sum(pos_valid) / len(pos_valid)
    neg_mean = sum(neg_valid) / len(neg_valid)
    diff = pos_mean - neg_mean

    print("\nPositive vs Negative 比较:")
    print(f"  Positive平均分数: {pos_mean:.4f}")
    print(f"  Negative平均分数: {neg_mean:.4f}")
    print(f"  差异 (Positive - Negative): {diff:.4f}")
    if diff > 0:
        print("  Positive语义相似度平均高于Negative")
    elif diff < 0:
        print("  Negative语义相似度平均高于Positive")
    else:
        print("  Positive和Negative语义相似度相当")

print("\n" + "=" * 80)
print("语义相似度分析完成")
print("=" * 80)
