import pandas as pd
import argparse
# Option letters whose `option_<letter>_embedding_name` columns are inspected.
_OPTION_LETTERS = ('A', 'B', 'C', 'D')

# Known embedding families, in match priority order: the first substring that
# occurs in a (lower-cased) embedding name decides its bucket.
_EMBEDDING_FAMILIES = ('multibv1', 'sbert', 'allv2', 'ave')


def _empty_embedding_counts():
    """Return a fresh counter dict with one zeroed entry per bucket."""
    counts = {family: 0 for family in _EMBEDDING_FAMILIES}
    counts['other'] = 0
    return counts


def _classify_embedding(embedding_name):
    """Map one embedding-name cell to its bucket key, or None if the cell is missing."""
    if pd.isna(embedding_name):
        return None
    # str() guards against columns that pandas parsed as numeric, where the
    # cell has no .lower() method.
    lowered = str(embedding_name).lower()
    for family in _EMBEDDING_FAMILIES:
        if family in lowered:
            return family
    return 'other'


def analyze_embeddings(dataset_path: str, results_path: str):
    """
    Count, per embedding model, how many options exist and how many wrong answers picked them.

    The dataset CSV is read for the columns `id`, `journal` and
    `option_{A,B,C,D}_embedding_name`; the results CSV is read for
    `question_id`, `journal`, `answer` and `judge`. Rows of the results file
    with `judge == 0` are treated as wrong predictions. Each wrong prediction
    is matched to the first dataset row with the same journal and id, and the
    embedding name of the answered option is charged with one error.

    Embedding names are bucketed case-insensitively by substring into
    'multibv1', 'sbert', 'allv2', 'ave' (checked in that order) or 'other'.
    Missing names are skipped, as are wrong predictions whose answer is not
    one of A/B/C/D or that have no matching dataset row.

    Args:
        dataset_path: Path to the original dataset CSV.
        results_path: Path to the experiment results CSV.

    Returns:
        A tuple `(total_embeddings, wrong_embedding_stats)` of dicts keyed by
        bucket name: the number of options per bucket over the whole dataset,
        and the number of wrong predictions per bucket.

    Raises:
        Exception: Any error raised while loading or processing the files is
            reported on stdout and then re-raised unchanged.
    """
    try:
        # Load both CSV files.
        dataset = pd.read_csv(dataset_path)
        results = pd.read_csv(results_path)

        # Compare ids as strings so the two files agree regardless of inferred dtype.
        dataset['id'] = dataset['id'].astype(str)
        results['question_id'] = results['question_id'].astype(str)

        # Count how many options each embedding bucket contributes to the dataset.
        total_embeddings = _empty_embedding_counts()
        for option in _OPTION_LETTERS:
            for embedding_name in dataset[f'option_{option}_embedding_name']:
                bucket = _classify_embedding(embedding_name)
                if bucket is not None:
                    total_embeddings[bucket] += 1

        # Select the wrong predictions and report the headline numbers.
        wrong_predictions = results[results['judge'] == 0]
        print(f"总预测数: {len(results)}")
        print(f"错误预测数: {len(wrong_predictions)}")

        # Charge each wrong prediction to the bucket of the option that was chosen.
        wrong_embedding_stats = _empty_embedding_counts()
        for _, wrong_pred in wrong_predictions.iterrows():
            wrong_answer = wrong_pred['answer']
            original_data = dataset[
                (dataset['journal'] == wrong_pred['journal']) &
                (dataset['id'] == wrong_pred['question_id'])
            ]

            if original_data.empty or wrong_answer not in _OPTION_LETTERS:
                continue

            # Only the first matching dataset row is consulted.
            chosen_name = original_data[f'option_{wrong_answer}_embedding_name'].iloc[0]
            bucket = _classify_embedding(chosen_name)
            if bucket is not None:
                wrong_embedding_stats[bucket] += 1

        # Print the per-bucket summary.
        print("\n各模型统计:")
        for model in total_embeddings:
            print(f"{model.upper()}: 错误数 {wrong_embedding_stats[model]}, 选项总数 {total_embeddings[model]}")

        return total_embeddings, wrong_embedding_stats

    except Exception as e:
        print(f"分析过程中出错: {e}")
        raise

def analyze_experiments(dataset_path: str, experiments_csv: str, task: str, output_path: str):
    """
    Run the embedding analysis for every experiment of one task and save the results as CSV.

    The experiments CSV is filtered on its `task` column; for each remaining
    row the `model` column is used as the experiment name and the `path`
    column as the results file handed to `analyze_embeddings`. One record is
    produced per embedding bucket with a positive option count. An experiment
    that raises is reported on stdout and skipped.

    Args:
        dataset_path: Path to the original dataset CSV.
        experiments_csv: Path to the CSV listing the experiments.
        task: Name of the task whose experiments are analysed.
        output_path: Path of the CSV file the collected records are written to.

    Returns:
        The DataFrame of collected records, or None when no experiment matches
        the task or no record was produced (nothing is written in that case).

    Raises:
        Exception: Errors outside the per-experiment loop are reported on
            stdout and re-raised unchanged.
    """
    try:
        # Load the experiment list and keep only the requested task.
        experiment_table = pd.read_csv(experiments_csv)
        selected = experiment_table[experiment_table['task'] == task]

        if selected.empty:
            print(f"没有找到task '{task}'的实验")
            return

        # One dict per (experiment, embedding bucket) pair.
        records = []

        for _, row in selected.iterrows():
            experiment_name = row['model']
            results_path = row['path']

            print(f"\n=== 分析实验: {experiment_name} ===")
            try:
                totals, wrongs = analyze_embeddings(dataset_path, results_path)

                for model, total in totals.items():
                    wrong = wrongs[model]
                    # Buckets with no options cannot yield a ratio; skip them.
                    if total <= 0:
                        continue
                    label = model.upper()
                    correct = total - wrong
                    accuracy = correct / total * 100
                    records.append({
                        'experiment': experiment_name,
                        'model_type': label,
                        'total_samples': total,
                        'wrong_predictions': wrong,
                        'correct_predictions': correct,
                        'accuracy': accuracy
                    })
                    print(f"{label}: 正确率 {accuracy:.2f}% ({correct}/{total})")

            except Exception as e:
                # A failing experiment must not abort the remaining ones.
                print(f"处理实验 {experiment_name} 时出错: {e}")
                continue

        if not records:
            return None

        # Persist the collected records.
        results_df = pd.DataFrame(records)
        results_df.to_csv(output_path, index=False)
        print(f"\n结果已保存到: {output_path}")

        # Print mean/min/max accuracy per embedding bucket across experiments.
        print("\n=== 汇总统计 ===")
        print("\n每个模型类型的平均正确率:")
        accuracy_by_type = results_df.groupby('model_type')['accuracy']
        model_stats = accuracy_by_type.agg(['mean', 'min', 'max'])
        print(model_stats)

        return results_df

    except Exception as e:
        print(f"分析过程中出错: {e}")
        raise

if __name__ == "__main__":
    # Command-line entry point: four required string options, forwarded
    # positionally to analyze_experiments.
    cli = argparse.ArgumentParser()
    for flag, description in (
        ("--dataset", "数据集路径"),
        ("--experiments", "实验列表CSV文件路径"),
        ("--task", "要分析的任务名称"),
        ("--output", "结果保存的CSV文件路径"),
    ):
        cli.add_argument(flag, type=str, required=True, help=description)

    options = cli.parse_args()
    analyze_experiments(
        options.dataset,
        options.experiments,
        options.task,
        options.output,
    )