import os
import json
import glob
from typing import List, Dict, Tuple
import numpy as np
from tqdm import tqdm
from scipy import stats
import pandas as pd

def get_self_rank_ratio(json_path: str) -> Tuple[float, bool]:
    """
    Compute where a ranking file's own entry sits inside its ranked list.

    The "own" entry is the JSON file's base name with its trailing ``.json``
    extension swapped for ``.txt``. The JSON content is iterated in order and
    the first element of each entry (``entry[0]``) is compared to that name.

    Args:
        json_path: Path to the JSON ranking file.

    Returns:
        Tuple[float, bool]: ``(rank / total_entries, rank == 1)`` using the
        1-based position of the first matching entry. ``(1.0, False)`` (the
        worst case) is returned when the list is empty, when the own entry is
        absent, or when the file cannot be read or parsed.
    """
    try:
        file_name = os.path.basename(json_path)
        # Swap only the trailing extension; str.replace would also rewrite any
        # ".json" that happens to occur in the middle of the name.
        if file_name.endswith('.json'):
            self_name = file_name[:-len('.json')] + '.txt'
        else:
            self_name = file_name

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        total_entries = len(data)
        if total_entries == 0:
            return 1.0, False

        for rank, entry in enumerate(data, 1):
            if entry[0] == self_name:
                return rank / total_entries, rank == 1
    except (OSError, ValueError, TypeError, LookupError) as e:
        # Best effort: unreadable files, invalid JSON/encoding (ValueError) and
        # malformed entries (TypeError/LookupError) count as the worst case so
        # a single bad file does not abort a batch run.
        print(f"处理文件 {json_path} 时出错: {e}")
        return 1.0, False

    # Own entry not present in the ranking: worst case.
    return 1.0, False

def calculate_average_self_rank_ratio(base_paths: List[str]) -> Dict[str, float]:
    """
    Aggregate self-rank statistics over every JSON file under the given roots.

    For each base path, the ``image2text_given_ave`` sub-folder is scanned.
    Each directory directly inside it is treated as one journal and walked
    recursively; every ``.json`` file found is scored with
    ``get_self_rank_ratio``.

    Args:
        base_paths: Root directories that should contain an
            ``image2text_given_ave`` folder. Roots without it are reported
            and skipped.

    Returns:
        Dict: ``average_ratio``, ``median_ratio``, ``mode_ratio``,
        ``first_rank_ratio`` and ``total_files``; or a dict holding only an
        ``error`` message when no JSON file was found.
    """
    ratios = []
    top_hits = 0

    for base_path in base_paths:
        story_path = os.path.join(base_path, "image2text_given_ave")
        if not os.path.exists(story_path):
            print(f"错误: Cover文件夹不存在于 {base_path}")
            continue

        # One progress-bar step per entry of the results folder.
        for journal_name in tqdm(os.listdir(story_path), desc=f"处理 {base_path}"):
            journal_path = os.path.join(story_path, journal_name)
            if not os.path.isdir(journal_path):
                continue

            print(f"处理期刊: {journal_name}")

            for current_dir, _subdirs, filenames in os.walk(journal_path):
                for filename in filenames:
                    # Only JSON ranking files are scored.
                    if not filename.endswith('.json'):
                        continue
                    ratio, is_first = get_self_rank_ratio(os.path.join(current_dir, filename))
                    ratios.append(ratio)
                    top_hits += 1 if is_first else 0

    if not ratios:
        return {
            'error': '没有找到任何有效的JSON文件'
        }

    # Exactly one ratio is recorded per processed file.
    file_count = len(ratios)
    return {
        'average_ratio': np.mean(ratios),
        'median_ratio': np.median(ratios),
        'mode_ratio': stats.mode(ratios, keepdims=False).mode,  # most frequent ratio
        'first_rank_ratio': top_hits / file_count,
        'total_files': file_count
    }

def get_rank_among_options(json_path: str, options: List[str], self_name: str) -> bool:
    """
    Decide whether the file's own entry outranks every other given option.

    The JSON content is iterated in order and the first element of each entry
    (``entry[0]``) is taken as its name; a name's rank is the position of its
    first occurrence. Options that never appear in the ranking are treated as
    ranked last.

    Args:
        json_path: Path to the JSON ranking file.
        options: Candidate names to compare (the caller passes four).
        self_name: Name of the own entry. A trailing ``.json`` extension is
            swapped for ``.txt`` before the lookup.

    Returns:
        bool: True only when the own entry is one of ``options``, appears in
        the ranking, and sits strictly ahead of every other option. False
        otherwise, including when the file cannot be read or parsed.
    """
    try:
        # Swap only the trailing extension; str.replace would also rewrite any
        # ".json" that happens to occur in the middle of the name.
        if self_name.endswith('.json'):
            self_name = self_name[:-len('.json')] + '.txt'

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Record the best (first) position of each option; a later duplicate
        # must not overwrite an earlier, better rank.
        option_ranks = {}
        for rank, entry in enumerate(data):
            name = entry[0]
            if name in options:
                option_ranks.setdefault(name, rank)

        self_rank = option_ranks.get(self_name)
        if self_rank is None:
            # An entry that is not ranked at all can never be first; without
            # this guard all() over zero competitors would report True.
            return False

        return all(
            self_rank < option_ranks.get(opt, float('inf'))
            for opt in options
            if opt != self_name
        )
    except (OSError, ValueError, TypeError, LookupError) as e:
        # Best effort: unreadable files, invalid JSON/encoding (ValueError) and
        # malformed entries (TypeError/LookupError) count as an incorrect answer.
        print(f"处理文件 {json_path} 时出错: {e}")
        return False

def _load_question_options(questions_csv: str) -> Dict[Tuple[str, str], List[str]]:
    """Map each (journal, id) pair of the questions CSV to its four option IDs."""
    # Read the lookup keys as text: file-name stems are strings, so an ``id``
    # column parsed as integers would never match (and would lose leading zeros).
    questions = pd.read_csv(questions_csv, dtype={'journal': str, 'id': str})
    columns = ['journal', 'id'] + [f'option_{letter}_embedding_id' for letter in 'ABCD']

    options_by_question = {}
    for journal, question_id, *option_ids in questions[columns].itertuples(index=False, name=None):
        # setdefault keeps the first row when a (journal, id) pair is repeated.
        options_by_question.setdefault(
            (journal, question_id),
            [f"{option_id}" for option_id in option_ids],
        )
    return options_by_question


def calculate_accuracy(
    base_paths: List[str],
    questions_csv: str = '/home/ubuntu/scratch/mhjiang/CNS_cover/Data/Understanding/MAC_2025/text2image_given/full_dataset.csv',
    result_subdir: str = "text2image_given_qwen",
) -> Dict[str, float]:
    """
    Compute how often a question's own entry ranks first among its four options.

    For each base path, ``result_subdir`` is scanned. Each directory directly
    inside it is treated as one journal and walked recursively. Every
    ``<id>.json`` file whose ``(journal, id)`` pair exists in the questions
    CSV is judged with ``get_rank_among_options``; files without a matching
    CSV row are ignored and not counted.

    Args:
        base_paths: Root directories that should contain ``result_subdir``.
            Roots without it are reported and skipped.
        questions_csv: CSV file with the columns ``journal``, ``id`` and
            ``option_A_embedding_id`` .. ``option_D_embedding_id``.
        result_subdir: Name of the folder, under each base path, that holds
            the per-journal ranking JSON files.

    Returns:
        Dict: ``accuracy``, ``correct_count`` and ``total_files``; or a dict
        holding only an ``error`` message when no question file was matched.
    """
    # Build the lookup once instead of filtering the whole table per file.
    options_by_question = _load_question_options(questions_csv)

    correct_count = 0
    total_files = 0

    for base_path in base_paths:
        story_path = os.path.join(base_path, result_subdir)
        if not os.path.exists(story_path):
            print(f"错误: Cover文件夹不存在于 {base_path}")
            continue

        for journal_name in tqdm(os.listdir(story_path), desc=f"处理 {base_path}"):
            journal_path = os.path.join(story_path, journal_name)
            if not os.path.isdir(journal_path):
                continue

            print(f"处理期刊: {journal_name}")

            for root, _dirs, files in os.walk(journal_path):
                for json_file in files:
                    if not json_file.endswith('.json'):
                        continue

                    # The file stem is the question id.
                    question_id = json_file[:-len('.json')]
                    options = options_by_question.get((journal_name, question_id))
                    if options is None:
                        continue

                    file_path = os.path.join(root, json_file)
                    if get_rank_among_options(file_path, options, json_file):
                        correct_count += 1
                    total_files += 1

    if total_files == 0:
        return {
            'error': '没有找到任何有效的问题文件'
        }

    return {
        'accuracy': correct_count / total_files,
        'correct_count': correct_count,
        'total_files': total_files
    }

def main():
    """Run the accuracy evaluation over the four publisher folders and print a report."""
    # Publisher roots, relative to the current working directory.
    base_paths = ['Cell/', 'Nature/', 'Science/', 'ACS/']

    results = calculate_accuracy(base_paths)

    divider = "-" * 50
    print("\n准确率统计结果：")
    print(divider)
    if 'error' not in results:
        print(f"处理问题总数: {results['total_files']}")
        print(f"正确回答数: {results['correct_count']}")
        print(f"准确率: {results['accuracy']:.10f} ({results['correct_count']} / {results['total_files']})")
        print(f"准确率百分比: {results['accuracy']*100:.2f}%")
    else:
        # Nothing was evaluated; show the reason instead of the numbers.
        print(f"错误: {results['error']}")
    print(divider)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()