import pandas as pd
import os
import argparse

def extract_volume_number(question_id):
    """
    从question_id中提取卷号（下划线前的数字）
    例如: "227_2" -> 227
    """
    try:
        return int(question_id.split('_')[0])
    except:
        return 0

from PIL import Image
import numpy as np

def check_white_patches(image_path: str, white_threshold: int = 240, patch_ratio: float = 0.9) -> bool:
    """
    检查图片是否有大片白色区域
    
    Args:
        image_path: 图片路径
        white_threshold: 判定为白色的像素阈值（0-255），默认240
        patch_ratio: 白色区域占比的阈值，默认0.75（75%）
        
    Returns:
        bool: True如果有大片白色区域，False如果图片正常
    """
    try:
        # 打开图片
        img = Image.open(image_path)
        
        # 转换为RGB模式（处理可能的RGBA图片）
        if img.mode == 'RGBA':
            img = img.convert('RGB')
        
        # 转换为numpy数组
        img_array = np.array(img)
        
        # 计算每个像素点RGB值的平均值
        pixel_means = np.mean(img_array, axis=2)
        
        # 计算白色像素的数量（所有通道都接近白色）
        white_pixels = np.sum(pixel_means > white_threshold)
        total_pixels = pixel_means.size
        
        # 计算白色像素的比例
        white_ratio = white_pixels / total_pixels
        
        # 如果白色像素比例超过阈值，认为有大片白色
        if white_ratio > patch_ratio:
            # print(f"发现大片白色区域在图片 {image_path}，白色比例: {white_ratio:.2%}")
            return True
        
        return False
        
    except Exception as e:
        print(f"处理图片 {image_path} 时出错: {e}")
        return True  # 如果出错，认为图片有问题

def check_all_images_in_row(row) -> bool:
    """
    检查一行数据中的所有图片是否都有问题
    
    Args:
        row: DataFrame的一行数据
        
    Returns:
        bool: True如果所有图片都有问题，False如果至少有一张图片正常
    """
    # 检查cover图片
    # cover_image = row['cover_image']
    # if 'png' in cover_image:
    #     if not check_white_patches(cover_image):
    #         return False
        
    # # 检查所有选项图片
    # for opt in ['A', 'B', 'C', 'D']:
    #     option_path = row[f'option_{opt}_path']
    #     if not '.png' in option_path:
    #         continue
    #     if not check_white_patches(option_path):
    #         return False
    # 获取ground truth选项
    ground_truth = row['answer']
    
    # 检查ground truth对应的图片
    option_path = row[f'option_{ground_truth}_path']
    if '.png' in option_path:
        if check_white_patches(option_path):
            # print(f"删除记录: journal={row['journal']}, id={row['id']}, ground truth图片有大片白色区域")
            return True
    
    # 如果所有图片都有问题，打印信息
    # print(f"删除记录: journal={row['journal']}, id={row['id']}, 所有图片都有大片白色区域")
    return False

def filter_latest_data(input_path, output_path, min_questions=10, random_seed=42):
    """
    对每个期刊按volume排序后筛选最新一期数据，如果最新一期题目数量少于min_questions则选择最新两期
    同时检查并删除所有图片都有大片白色区域的记录
    """
    # 读取数据集
    print(f"正在读取数据集: {input_path}")
    df = pd.read_csv(input_path)
    
    # 添加volume列
    df['volume'] = df['id'].apply(extract_volume_number)
    
    # 存储每个期刊的选择信息
    journals_info = []
    
    # 按期刊分组处理
    all_filtered = []
    for journal in df['journal'].unique():
        print(f"\n处理期刊: {journal}")
        journal_df = df[df['journal'] == journal].copy()
        
        # 按volume排序（升序，选最老的）
        journal_df = journal_df.sort_values('volume', ascending=True)
        
        # 检查每一期的图片质量
        valid_volumes = []
        for volume in journal_df['volume'].unique():
            volume_df = journal_df[journal_df['volume'] == volume]
            
            # 检查并删除所有图片都有问题的记录
            valid_rows = []
            for _, row in volume_df.iterrows():
                if not check_all_images_in_row(row):
                    valid_rows.append(row)
            
            if len(valid_rows) >= min_questions:
                valid_volumes.append(volume)
                selected_df = pd.DataFrame(valid_rows)
                break  # 找到第一个符合条件的期就停止
        
        if not valid_volumes:
            print(f"警告: {journal} 没有找到符合条件的期")
            continue
            
        # 存储期刊信息
        journals_info.append({
            "journal": journal,
            "total_questions": len(journal_df),
            "selected_questions": len(selected_df),
            "volume_min": int(selected_df['volume'].min()),
            "volume_max": int(selected_df['volume'].max()),
            "selected_volumes": ','.join(map(str, valid_volumes)),
            "selected_volumes_count": len(valid_volumes)
        })
        
        print(f"  总数据量: {len(journal_df)}")
        print(f"  选择数据量: {len(selected_df)}")
        print(f"  选择的volume: {valid_volumes}")
        
        all_filtered.append(selected_df)
    
    # 合并所有筛选后的数据
    filtered_df = pd.concat(all_filtered, ignore_index=True)
    
    # 删除辅助列
    filtered_df = filtered_df.drop('volume', axis=1)
    
    # 随机打乱顺序
    filtered_df = filtered_df.sample(frac=1, random_state=random_seed).reset_index(drop=True)
    
    # 保存结果
    filtered_df.to_csv(output_path, index=False)
    info_path = os.path.splitext(output_path)[0] + '_info.csv'
    pd.DataFrame(journals_info).to_csv(info_path, index=False)
    
    print(f"\n筛选完成！")
    print(f"原始数据集大小: {len(df)}")
    print(f"筛选后数据集大小: {len(filtered_df)}")
    print(f"结果已保存到: {output_path}")
    print(f"期刊信息已保存到: {info_path}")
    
    return filtered_df, pd.DataFrame(journals_info)

def main():
    parser = argparse.ArgumentParser(description="按volume排序筛选每个期刊最新一期数据（如果题目数量少于阈值则选择最新两期）并随机打乱")
    parser.add_argument('--input', type=str, required=True, help="输入CSV文件路径")
    parser.add_argument('--output', type=str, help="输出CSV文件路径")
    parser.add_argument('--min_questions', type=int, default=12, help="最少题目数量，如果最新一期少于这个数量则选择两期，默认为4")
    parser.add_argument('--seed', type=int, default=42, help="随机种子，默认为42")
    
    args = parser.parse_args()
    
    # 如果没有指定输出路径，在输入文件名后添加_filtered
    if args.output is None:
        input_name = os.path.splitext(args.input)[0]
        args.output = f"{input_name}_filtered.csv"
    
    # 确保输出目录存在
    os.makedirs(os.path.dirname(args.output), exist_ok=True)
    
    # 处理数据
    filter_latest_data(args.input, args.output, args.min_questions, args.seed)
if __name__ == "__main__":
    main()