import os
import json
from datetime import datetime

def remove_substring_from_json(input_path, output_path, substring):
    """从JSON文件中删除指定子字符串，并保存到新文件"""
    try:
        # 读取原始文件
        with open(input_path, 'r', encoding='utf-8') as file:
            content = file.read()
        
        # 如果文件为空，直接创建空的输出文件
        if not content.strip():
            with open(output_path, 'w', encoding='utf-8') as file:
                pass
            print(f"已创建空文件: {output_path}")
            return True
        
        # 删除所有出现的子字符串
        new_content = content.replace(substring, '')
        
        # 验证处理后的内容是否仍然是有效的JSON
        try:
            json_obj = json.loads(new_content)
        except json.JSONDecodeError as e:
            print(f"错误: 文件 {input_path} 处理后不再是有效的JSON: {e}")
            return False
        
        # 写入新文件（确保格式化输出）
        with open(output_path, 'w', encoding='utf-8') as file:
            json.dump(json_obj, file, ensure_ascii=False, indent=2)
        
        print(f"已保存到: {output_path}")
        return True
            
    except Exception as e:
        print(f"处理文件 {input_path} 时出错: {e}")
        return False

def process_directory(input_dir, output_dir, substring, recursive):
    """递归处理目录中的所有JSON文件"""
    modified_count = 0
    total_count = 0
    
    # 创建输出目录（如果不存在）
    os.makedirs(output_dir, exist_ok=True)
    
    if recursive:
        # 递归处理模式
        for root, _, files in os.walk(input_dir):
            for file in files:
                if file.endswith('.json'):
                    # 构建输入和输出文件路径
                    rel_path = os.path.relpath(os.path.join(root, file), input_dir)
                    input_file_path = os.path.join(input_dir, rel_path)
                    output_file_path = os.path.join(output_dir, rel_path)
                    
                    # 确保输出目录存在
                    output_file_dir = os.path.dirname(output_file_path)
                    os.makedirs(output_file_dir, exist_ok=True)
                    
                    # 处理文件
                    total_count += 1
                    if remove_substring_from_json(input_file_path, output_file_path, substring):
                        modified_count += 1
    else:
        # 非递归模式（只处理当前目录）
        for file in os.listdir(input_dir):
            if file.endswith('.json') and os.path.isfile(os.path.join(input_dir, file)):
                input_file_path = os.path.join(input_dir, file)
                output_file_path = os.path.join(output_dir, file)
                
                total_count += 1
                if remove_substring_from_json(input_file_path, output_file_path, substring):
                    modified_count += 1
    
    print(f"\n处理完成:")
    print(f"总JSON文件数: {total_count}")
    print(f"成功处理的文件数: {modified_count}")
    
    return modified_count, total_count

if __name__ == "__main__":
    # 直接在此处修改路径和参数
    INPUT_PATH = ""
    OUTPUT_PATH = ""
    SUBSTRING = "<s>[INST]"
    RECURSIVE = False  # 是否递归处理子目录（仅在INPUT_PATH是目录时有效）
    
    # 检查输入路径是否存在
    if not os.path.exists(INPUT_PATH):
        print(f"错误: 指定的输入路径不存在 - {INPUT_PATH}")
        exit(1)
    
    if os.path.isfile(INPUT_PATH) and INPUT_PATH.endswith('.json'):
        # 处理单个JSON文件
        print(f"处理单个文件: {INPUT_PATH}")
        
        # 确保输出目录存在
        output_dir = os.path.dirname(OUTPUT_PATH)
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir, exist_ok=True)
        
        success = remove_substring_from_json(INPUT_PATH, OUTPUT_PATH, SUBSTRING)
        print(f"处理结果: {'成功' if success else '失败'}")
    elif os.path.isdir(INPUT_PATH):
        # 处理目录
        print(f"处理目录: {INPUT_PATH}")
        modified, total = process_directory(INPUT_PATH, OUTPUT_PATH, SUBSTRING, RECURSIVE)
    else:
        print(f"错误: 指定的输入路径不是有效的JSON文件或目录 - {INPUT_PATH}")