import pandas as pd
import os
import sys
import numpy as np
import argparse
sys.path.append(os.getcwd())
from utils.parse_jsonString import parse_probabilities

def calculate_rms_calibration_error(predictions, ground_truths):
    """Compute the RMS error of the probability assigned to the true label.

    For each sample the squared error is ``(p - 1) ** 2``, where ``p`` is the
    probability the prediction gives to the ground-truth option. A sample
    whose prediction does not contain the ground-truth option counts as
    ``p = 0`` (squared error 1).

    Args:
        predictions: List of predictions, one ``{option: probability}``
            mapping per sample.
        ground_truths: List of true labels, aligned with ``predictions``.

    Returns:
        The root of the mean squared error, or ``None`` when ``predictions``
        is empty or its length differs from ``ground_truths``.
    """
    if not predictions:
        return None
    if len(predictions) != len(ground_truths):
        return None

    # One squared error per sample; a missing label is (0 - 1)^2 = 1.
    per_sample = [
        (probs[label] - 1) ** 2 if label in probs else 1
        for probs, label in zip(predictions, ground_truths)
    ]

    if not per_sample:
        return None
    return np.sqrt(np.mean(per_sample))

def _index_first_positions(df, key_columns=('journal', 'id')):
    """Map each key tuple of *df* to the position of its first occurrence."""
    positions = {}
    for pos, key in enumerate(zip(*(df[col] for col in key_columns))):
        # A missing key component must never match anything (NaN != NaN), so
        # such rows are left out of the lookup table entirely.
        if any(pd.isna(part) for part in key):
            continue
        positions.setdefault(key, pos)
    return positions


def compare_results(csv_path1, csv_path2, output_path=None):
    """Match predictions in one CSV against ground truth in another.

    Every row of the second file is looked up in the first file by its
    ``(journal, id)`` pair. For each match the ``response`` column of the
    first file is parsed with ``parse_probabilities`` and the top-1 / top-2
    options are compared with the ``ground_truth`` column of the second
    file. The matched rows of the first file are written to ``output_path``
    in the order they were matched.

    Rows that cannot be matched, parsed or otherwise processed are reported
    on stdout and skipped, so one bad row does not abort the comparison.

    Args:
        csv_path1: Path of the CSV holding the model responses; must contain
            the columns ``journal``, ``id`` and ``response``.
        csv_path2: Path of the CSV holding the reference rows; must contain
            the columns ``journal``, ``id`` and ``ground_truth``.
        output_path: Where to write the matched rows. Defaults to
            ``<csv_path1 without extension>_matched.csv``.

    Returns:
        A tuple ``(matched, top1_correct, top2_correct)``: the number of rows
        of the second file found in the first file, and how many of those
        had the ground truth as the top-1 / within the top-2 predictions.
    """
    df1 = pd.read_csv(csv_path1)
    df2 = pd.read_csv(csv_path2)

    # One pass over df1 builds an O(1) lookup instead of scanning the whole
    # frame for every row of df2. When a key is duplicated the first row wins,
    # both for the saved output and for the parsed response.
    first_position = _index_first_positions(df1)

    if output_path is None:
        output_path = os.path.splitext(csv_path1)[0] + '_matched.csv'
    # os.makedirs('') raises FileNotFoundError, so only create the directory
    # when the output path actually has a directory component.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    matched_positions = []
    total1 = 0
    correct = 0
    top2_correct = 0

    for _, row2 in df2.iterrows():
        # Deliberately broad: the parser is external and any failure on a
        # single row should be reported and skipped, not abort the run.
        try:
            journal = row2['journal']
            question_id = row2['id']

            pos = first_position.get((journal, question_id))
            if pos is None:
                print(f"警告: 在第一个文件中找不到 journal={journal}, question_id={question_id}")
                continue
            total1 += 1
            matched_positions.append(pos)

            probabilities = parse_probabilities(df1['response'].iat[pos])
            if not probabilities:
                print(f"警告: 无法解析 journal={journal}, question_id={question_id} 的response")
                continue

            ground_truth = row2['ground_truth']

            # Options ordered from most to least probable.
            ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
            top1 = ranked[0][0]
            top2 = [option for option, _ in ranked[:2]]

            if top1 == ground_truth:
                correct += 1
            if ground_truth in top2:
                top2_correct += 1

        except Exception as e:
            print(f"处理行时出错: {e}")
            continue

    if matched_positions:
        # Positions (not index labels) were recorded, so .iloc is correct.
        df1.iloc[matched_positions].to_csv(output_path, index=False)
        print(f"\n匹配到的记录已保存到: {output_path}")
        print(f"保存了 {len(matched_positions)} 条记录")

    return total1, correct, top2_correct

def main():
    """Parse the command-line arguments and run ``compare_results``.

    ``--csv1`` and ``--csv2`` are required. ``--output_path`` is optional;
    when omitted, ``compare_results`` derives the output file name from
    ``--csv1``.
    """
    parser = argparse.ArgumentParser(
        description='Match the rows of two result CSV files on (journal, id) '
                    'and save the matched rows of the first file.'
    )
    parser.add_argument('--csv1', type=str, required=True,
                        help='CSV file with the model responses')
    parser.add_argument('--csv2', type=str, required=True,
                        help='CSV file with the ground-truth rows')
    parser.add_argument('--output_path', type=str, default=None,
                        help='output CSV path (default: <csv1>_matched.csv)')
    args = parser.parse_args()
    compare_results(args.csv1, args.csv2, output_path=args.output_path)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()