import os
from tqdm import tqdm
import logging
from collections import defaultdict
import json
import argparse
import numpy as np

def calculate_averages(datas,file_name):
    """
    Average the rank and similarity recorded for each txt id.

    Args:
        datas: Iterable of result collections (one per source). Every entry
            is read by index: entry[0] is the txt id, entry[1] the rank and
            entry[2] the similarity.
        file_name: Name of the file the data belongs to; it is only used in
            the message printed when an entry cannot be read.

    Returns:
        list: ``[txt_id, average_rank, average_similarity]`` rows, sorted by
        ascending average rank.
    """
    # txt id -> (ranks collected so far, similarities collected so far)
    collected = defaultdict(lambda: ([], []))

    for data in datas:
        for entry in data:
            try:
                # Read all three fields first so a malformed entry leaves
                # nothing behind in the accumulator.
                txt_id, rank, similarity = entry[0], entry[1], entry[2]
                ranks, sims = collected[txt_id]
                ranks.append(rank)
                sims.append(similarity)
            except Exception as e:
                # Best effort: report the bad entry and move on.
                print(f"处理文件 {file_name} 时出错: {e}")
                continue

    rows = [
        [txt_id, np.mean(ranks), np.mean(sims)]
        for txt_id, (ranks, sims) in collected.items()
    ]
    # Stable sort: ids with equal average rank keep first-seen order.
    return sorted(rows, key=lambda row: row[1])

def _load_json(path):
    """Parse the UTF-8 encoded JSON file at ``path`` and return its content."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def construct_ave_similarity(base_path, type = "image2text_given"):
    """
    Average the per-id rank/similarity results of three result folders.

    For every ``.json`` file found (recursively) under
    ``<base_path>/<type>_allv2/<journal>/``, the files with the same journal,
    relative subdirectory and name are loaded from ``<type>_multiBv1`` and
    ``<type>_sbert``. The three contents are passed to ``calculate_averages``
    and the result is written to the mirrored location under
    ``<base_path>/<type>_ave``.

    Args:
        base_path: Directory that contains the three input folders; the
            output folder is created inside it as well.
        type: Folder name prefix (for example ``"image2text_given"``). The
            name shadows the builtin but is kept for backward compatibility
            with keyword callers.

    Returns:
        An empty dict when ``<base_path>/<type>_allv2`` does not exist,
        otherwise ``None``. Progress and errors are reported with ``print``.
    """
    folder1 = type + '_allv2'
    folder2 = type + '_multiBv1'
    folder3 = type + '_sbert'
    output = type + '_ave'

    # Validate the input before creating anything, so a wrong base_path does
    # not leave an empty output folder behind.
    story_path = os.path.join(base_path, folder1)
    if not os.path.exists(story_path):
        print(f"错误: Story文件夹不存在于 {base_path}")
        return {}

    output_base_path = os.path.join(base_path, output)
    os.makedirs(output_base_path, exist_ok=True)

    total_files = 0
    # Each directory directly under the first input folder is one journal.
    for journal_name in tqdm(os.listdir(story_path)):
        journal_path = os.path.join(story_path, journal_name)
        if not os.path.isdir(journal_path):
            continue

        # Mirror the journal directory in the output tree.
        journal_output_dir = os.path.join(output_base_path, journal_name)
        os.makedirs(journal_output_dir, exist_ok=True)

        for root, _dirs, files in os.walk(journal_path):
            # Only JSON result files are processed.
            json_files = [f for f in files if f.endswith('.json')]
            if not json_files:
                continue

            rel_path = os.path.relpath(root, journal_path)
            # '' for the journal root keeps the joined paths free of './'.
            sub_dir = '' if rel_path == '.' else rel_path

            if sub_dir:
                output_dir = os.path.join(journal_output_dir, sub_dir)
                os.makedirs(output_dir, exist_ok=True)
            else:
                output_dir = journal_output_dir

            print(f"处理目录: {rel_path} (找到 {len(json_files)} 个txt文件)")

            for json_file in json_files:
                # Counts every attempted file, including ones that fail below.
                total_files += 1

                try:
                    # The sibling folders must be addressed with the same
                    # relative subdirectory as the file found by os.walk;
                    # otherwise nested files would resolve to the wrong path.
                    source_paths = [
                        os.path.join(root, json_file),
                        os.path.join(base_path, folder2, journal_name, sub_dir, json_file),
                        os.path.join(base_path, folder3, journal_name, sub_dir, json_file),
                    ]
                    datas = [_load_json(path) for path in source_paths]
                    results = calculate_averages(datas, json_file)
                    with open(os.path.join(output_dir, json_file), 'w', encoding='utf-8') as f:
                        json.dump(results, f, ensure_ascii=False, indent=4)
                except Exception as e:
                    # Deliberate best effort: a missing or malformed file must
                    # not abort the whole batch.
                    print(f"处理文件 {json_file} 时出错: {e}")
                    continue
    print(f"处理完成，共处理 {total_files} 个文件")

if __name__ == "__main__":
    # Command-line entry point. Both options are mandatory; the former
    # ``default=`` values could never take effect next to ``required=True``,
    # so they now live in the help text as examples instead.
    parser = argparse.ArgumentParser(
        description="average rank/similarity results across the allv2, multiBv1 and sbert folders"
    )
    parser.add_argument(
        '--base_path',
        type=str,
        required=True,
        help="base directory holding the input folders and the output folder, e.g. ./CNS_cover/Science",
    )
    parser.add_argument(
        '--type',
        type=str,
        required=True,
        help="folder name prefix, e.g. text2image_given",
    )
    args = parser.parse_args()
    construct_ave_similarity(args.base_path, args.type)