import os
import csv
import argparse
from tqdm import tqdm
import pandas as pd

def find_missing_embeddings(base_path, output_csv):
    """
    找出在Story目录中存在但在Story_Embedding_ViTH目录中不存在的issue文件
    
    Args:
        base_path: 基础路径，包含Story和Story_Embedding_ViTH目录
        output_csv: 输出CSV文件路径
    """
    # 构建Story和Embedding目录路径
    story_dir = os.path.join(base_path, "Story")
    embedding_dir = os.path.join(base_path, "Story_Embeddings_ViTH14")
    
    # 检查目录是否存在
    if not os.path.exists(story_dir):
        raise ValueError(f"Story目录不存在: {story_dir}")
    if not os.path.exists(embedding_dir):
        raise ValueError(f"Story_Embedding_ViTH目录不存在: {embedding_dir}")
    
    # 存储缺失的文件路径
    missing_files = []
    
    # 遍历Story目录
    print("正在扫描Story目录...")
    for journal in tqdm(os.listdir(story_dir)):
        journal_path = os.path.join(story_dir, journal)
        
        # 跳过非目录
        if not os.path.isdir(journal_path):
            continue
        
        # 检查对应的embedding目录是否存在
        embedding_journal_path = os.path.join(embedding_dir, journal)
        if not os.path.exists(embedding_journal_path):
            # 如果整个期刊目录不存在，添加所有story文件
            for story_file in os.listdir(journal_path):
                if story_file.endswith('.txt'):
                    story_file_path = os.path.join(journal_path, story_file)
                    missing_files.append({
                        'journal': journal,
                        'issue': story_file,
                        'story_path': story_file_path,
                        'reason': 'missing_journal_dir'
                    })
            continue
        
        # 遍历Story中的文件
        for story_file in os.listdir(journal_path):
            if not story_file.endswith('.txt'):
                continue
                
            story_file_path = os.path.join(journal_path, story_file)
            
            # 构建对应的embedding文件路径 (.txt -> .pt)
            embedding_file = story_file.replace('.txt', '.pt')
            embedding_file_path = os.path.join(embedding_journal_path, embedding_file)
            
            # 检查embedding文件是否存在
            if not os.path.exists(embedding_file_path):
                missing_files.append({
                    'journal': journal,
                    'issue': story_file,
                    'story_path': story_file_path,
                    'reason': 'missing_embedding_file'
                })
    
    # 保存结果到CSV
    print(f"找到 {len(missing_files)} 个缺失的embedding文件")
    
    if missing_files:
        # 确保输出目录存在
        output_dir = os.path.dirname(output_csv)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        # 写入CSV
        with open(output_csv, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['journal', 'issue', 'story_path', 'reason']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            
            writer.writeheader()
            for file_info in missing_files:
                writer.writerow(file_info)
        
        print(f"结果已保存到: {output_csv}")
    else:
        print("没有找到缺失的embedding文件")

def find_missing_cover(base_path, output_csv):
    """
    找出在Story目录中存在但在Story_Cover目录中不存在的issue文件
    """
    # 构建Story和Embedding目录路径
    story_dir = os.path.join(base_path, "Story")
    embedding_dir = os.path.join(base_path, "Cover")
    
    # 检查目录是否存在
    if not os.path.exists(story_dir):
        raise ValueError(f"Story目录不存在: {story_dir}")
    if not os.path.exists(embedding_dir):
        raise ValueError(f"Story_Embedding_ViTH目录不存在: {embedding_dir}")
    
    # 存储缺失的文件路径
    missing_files = []
    
    # 遍历Story目录
    print("正在扫描Story目录...")
    for journal in tqdm(os.listdir(story_dir)):
        journal_path = os.path.join(story_dir, journal)
        
        # 跳过非目录
        if not os.path.isdir(journal_path):
            continue
        
        # 检查对应的embedding目录是否存在
        embedding_journal_path = os.path.join(embedding_dir, journal)
        if not os.path.exists(embedding_journal_path):
            # 如果整个期刊目录不存在，添加所有story文件
            for story_file in os.listdir(journal_path):
                if story_file.endswith('.txt'):
                    story_file_path = os.path.join(journal_path, story_file)
                    missing_files.append({
                        'journal': journal,
                        'issue': story_file,
                        'story_path': story_file_path,
                        'reason': 'missing_journal_dir'
                    })
            continue
        
        # 遍历Story中的文件
        for story_file in os.listdir(journal_path):
            if not story_file.endswith('.txt'):
                continue
                
            story_file_path = os.path.join(journal_path, story_file)
            
            # 构建对应的embedding文件路径 (.txt -> .pt)
            embedding_file = story_file.replace('.txt', '.png')
            embedding_file_path = os.path.join(embedding_journal_path, embedding_file)
            
            # 检查embedding文件是否存在
            if not os.path.exists(embedding_file_path):
                missing_files.append({
                    'journal': journal,
                    'issue': story_file,
                    'story_path': story_file_path,
                    'reason': 'missing_embedding_file'
                })
    
    # 保存结果到CSV
    print(f"找到 {len(missing_files)} 个缺失的embedding文件")
    
    if missing_files:
        # 确保输出目录存在
        output_dir = os.path.dirname(output_csv)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)
        
        # 写入CSV
        with open(output_csv, 'w', newline='', encoding='utf-8') as csvfile:
            fieldnames = ['journal', 'issue', 'story_path', 'reason']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            
            writer.writeheader()
            for file_info in missing_files:
                writer.writerow(file_info)
        
        print(f"结果已保存到: {output_csv}")
    else:
        print("没有找到缺失的embedding文件")

def delete_files_from_csv(csv_path, dry_run=True):
    """
    从CSV文件中读取story_path列，并删除这些文件
    
    Args:
        csv_path: CSV文件路径
        dry_run: 如果为True，只打印要删除的文件而不实际删除
        
    Returns:
        tuple: (成功删除的文件数, 失败的文件数)
    """
    # 检查CSV文件是否存在
    if not os.path.exists(csv_path):
        print(f"错误: CSV文件不存在: {csv_path}")
        return 0, 0
    
    # 读取CSV文件
    try:
        # 使用pandas读取CSV
        df = pd.read_csv(csv_path)
        
        # 检查是否包含story_path列
        if 'story_path' not in df.columns:
            print("错误: CSV文件必须包含'story_path'列")
            return 0, 0
        
        # 获取所有存在的文件路径
        file_paths = [path for path in df['story_path'] if path and os.path.exists(path)]
        
    except Exception as e:
        print(f"读取CSV文件时出错: {str(e)}")
        return 0, 0
    
    print(f"从CSV文件中读取了 {len(file_paths)} 个有效文件路径")
    
    # 如果是dry run模式，只打印文件路径
    if dry_run:
        print("=== 干运行模式 - 以下文件将被删除 ===")
        for path in file_paths[:10]:  # 只显示前10个
            print(f"将删除: {path}")
        if len(file_paths) > 10:
            print(f"... 以及其他 {len(file_paths) - 10} 个文件")
        print(f"总计 {len(file_paths)} 个文件")
        print("要实际删除文件，请将dry_run参数设置为False")
        return 0, 0
    
    # 删除文件
    success_count = 0
    failure_count = 0
    
    print("正在删除文件...")
    for file_path in tqdm(file_paths):
        try:
            os.remove(file_path)
            success_count += 1
        except Exception as e:
            print(f"删除文件时出错 {file_path}: {str(e)}")
            failure_count += 1
    
    print(f"删除完成! 成功: {success_count}, 失败: {failure_count}")
    return success_count, failure_count

def main():
    parser = argparse.ArgumentParser(description="找出在Story目录中存在但在Story_Embedding_ViTH目录中不存在的issue文件")
    parser.add_argument("--base_path", help="基础路径，包含Story和Story_Embedding_ViTH目录")
    parser.add_argument("--output_path", default="missing_embeddings.csv", help="输出CSV文件路径")
    parser.add_argument("--execute", action="store_true", help="是否为执行模式")
    
    args = parser.parse_args()
    
    try:
        if not args.execute:
            # find_missing_embeddings(args.base_path, args.output_path)
            find_missing_cover(args.base_path, args.output_path)
        else:
            delete_files_from_csv(args.base_path, dry_run=False)
    except Exception as e:
        print(f"执行出错: {str(e)}")

if __name__ == "__main__":
    main()
    