import json
import random
import re

def shuffle_options(question, answer, chosen_answer):
    """Shuffle the A-D options embedded in a question and relabel the answers.

    The option object is located inside ``question`` in one of two ways:

    1. A strict textual match of
       ``{"A": "...", "B": "...", "C": "...", "D": "..."}``. The four captured
       values are treated as raw text and written back verbatim, so escapes
       already present in the text are neither decoded nor re-encoded.
    2. Otherwise, the first single-line ``{...}`` span that parses as JSON and
       contains the keys A-D. It is re-serialised with ``json.dumps``
       (non-ASCII characters kept readable, keys other than A-D preserved).

    Only the located span is rewritten; the rest of ``question`` is untouched.

    Args:
        question: Prompt text containing the option object.
        answer: Label ("A"-"D") of the correct option before shuffling.
        chosen_answer: Response text. Every ``[[X]]`` tag with X in A-D is
            remapped to the label that option carries after the shuffle, so
            tags that mention other options keep pointing at the same content.

    Returns:
        A tuple ``(new_question, new_answer, new_chosen_answer)``. If no
        option object can be found, or any error occurs (for example
        ``answer`` is not one of A-D), the three inputs are returned
        unchanged and a message is printed.
    """
    labels = ("A", "B", "C", "D")
    try:
        options_pattern = r'{"A": "(.*?)", "B": "(.*?)", "C": "(.*?)", "D": "(.*?)"}'
        options_match = re.search(options_pattern, question)

        # parsed_options stays None on the strict path; on the JSON fallback
        # path it holds the full decoded object (possibly with extra keys).
        parsed_options = None
        span = None
        if options_match:
            span = options_match.span()
            original_options = dict(zip(labels, options_match.groups()))
        else:
            for candidate in re.finditer(r'{.*?}', question):
                try:
                    decoded = json.loads(candidate.group(0))
                except json.JSONDecodeError:
                    continue
                if all(label in decoded for label in labels):
                    parsed_options = decoded
                    original_options = {label: decoded[label] for label in labels}
                    span = candidate.span()
                    break

        if span is None:
            print(f"无法提取选项，保持原样: {question[:100]}...")
            return question, answer, chosen_answer

        # Shuffle labels rather than contents: this yields an explicit
        # old-label -> new-label permutation that stays unambiguous even when
        # two options share the same text.
        source_labels = list(labels)
        random.shuffle(source_labels)
        new_options = {
            new: original_options[old] for new, old in zip(labels, source_labels)
        }
        old_to_new = {old: new for new, old in zip(labels, source_labels)}

        # A label outside A-D raises KeyError here and is reported by the
        # handler below, leaving the inputs untouched.
        new_answer = old_to_new[answer]

        if parsed_options is None:
            # Strict path: rebuild exactly the matched textual form from the
            # raw captures so nothing gets double-escaped.
            options_text = "{" + ", ".join(
                f'"{label}": "{new_options[label]}"' for label in labels
            ) + "}"
        else:
            options_text = json.dumps(
                {**parsed_options, **new_options}, ensure_ascii=False
            )

        # Splice by position instead of re.sub with a string template, which
        # would interpret backslashes in the replacement text.
        start, end = span
        new_question = question[:start] + options_text + question[end:]

        new_chosen_answer = re.sub(
            r'\[\[([A-D])\]\]',
            lambda tag: f'[[{old_to_new[tag.group(1)]}]]',
            chosen_answer,
        )

        return new_question, new_answer, new_chosen_answer

    except Exception as e:
        # Best-effort contract: never raise, fall back to the original triple.
        print(f"处理选项时出错: {str(e)}")
        return question, answer, chosen_answer

def balance_sft_dataset(input_file, output_file):
    """Randomise option order across an SFT dataset to even out answer labels.

    Loads a JSON list of records from ``input_file``. For each record whose
    ``conversations`` list has at least two turns and whose second turn
    contains an ``[[X]]`` tag (X in A-D), the options in the first turn are
    shuffled with ``shuffle_options`` and both turns' ``value`` fields are
    updated. Records that do not meet those conditions, or that raise during
    processing, are written out unchanged. The result is saved to
    ``output_file`` as indented UTF-8 JSON.

    Note that balance comes only from random shuffling; no exact label quota
    is enforced. The printed distribution counts processed records only.

    Args:
        input_file: Path of the source JSON dataset.
        output_file: Path the processed dataset is written to.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    balanced_data = []
    answer_distribution = {"A": 0, "B": 0, "C": 0, "D": 0}
    answer_pattern = re.compile(r'\[\[([A-D])\]\]')
    total = len(data)

    for i, item in enumerate(data):
        try:
            conversations = item["conversations"]
            if len(conversations) < 2:
                balanced_data.append(item)
                continue

            human_msg = conversations[0]["value"]
            gpt_msg = conversations[1]["value"]

            # The first [[X]] tag in the reply is taken as the current answer.
            answer_match = answer_pattern.search(gpt_msg)
            if not answer_match:
                balanced_data.append(item)
                continue

            current_answer = answer_match.group(1)

            new_human_msg, new_answer, new_gpt_msg = shuffle_options(
                human_msg, current_answer, gpt_msg
            )

            # Update only the two "value" fields on copies, so later turns and
            # any extra per-message fields survive and the input is not mutated.
            new_conversations = list(conversations)
            new_conversations[0] = {**conversations[0], "value": new_human_msg}
            new_conversations[1] = {**conversations[1], "value": new_gpt_msg}
            new_item = dict(item)
            new_item["conversations"] = new_conversations

            # Count before appending: if this lookup fails, the except branch
            # appends the original exactly once instead of duplicating it.
            answer_distribution[new_answer] += 1

        except Exception as e:
            print(f"处理第 {i+1} 条数据时出错: {str(e)}")
            balanced_data.append(item)  # keep the original record on error
        else:
            # Reached only when the try body ran to completion (no continue).
            balanced_data.append(new_item)
            if (i+1) % 20 == 0:
                print(f"已处理 {i+1}/{total} 条数据")

    print(f"均衡后答案分布: {answer_distribution}")

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(balanced_data, f, ensure_ascii=False, indent=2)

    print(f"平衡后的数据已保存至 {output_file}")

def add_bias_warning(input_file, output_file):
    """Append an anti-position-bias notice to the first turn of every record.

    Reads a JSON list of records from ``input_file``; for each record whose
    ``conversations`` list is non-empty, the notice text is appended to the
    ``value`` of the first turn. The result is written to ``output_file`` as
    indented UTF-8 JSON and a confirmation line is printed.

    Args:
        input_file: Path of the source JSON dataset.
        output_file: Path the annotated dataset is written to.
    """
    notice = "".join([
        "\n\n注意：选择答案时请基于图像内容而非选项顺序。",
        "不要偏向于特定选项(如A、B、C、D)，",
        "而应根据图像内容选择正确答案。",
    ])

    with open(input_file, 'r', encoding='utf-8') as src:
        records = json.load(src)

    for record in records:
        turns = record["conversations"]
        if len(turns) < 1:
            continue
        opening_turn = turns[0]
        opening_turn["value"] = opening_turn["value"] + notice

    with open(output_file, 'w', encoding='utf-8') as dst:
        json.dump(records, dst, ensure_ascii=False, indent=2)

    print(f"已添加偏差警告，数据保存至 {output_file}")

# Usage example: balance options, add the bias warning, then split train/val.
if __name__ == "__main__":
    base_dir = "/Users/wad3/Downloads/Research/AutoBench-V-private/document/basic_understanding/sft_data"

    # Seed before any random draw so the option shuffling performed by
    # balance_sft_dataset is reproducible, not only the train/val split.
    random.seed(42)

    # Balance the option distribution.
    input_file = f"{base_dir}/all_sft_data.json"
    balanced_file = f"{base_dir}/balanced_all_sft_data.json"
    balance_sft_dataset(input_file, balanced_file)

    # Add the anti-bias warning.
    final_file = f"{base_dir}/balanced_with_warning_sft_data.json"
    add_bias_warning(balanced_file, final_file)

    # Build the training and validation sets.
    with open(final_file, 'r', encoding='utf-8') as f:
        final_data = json.load(f)

    # Re-seed so the split does not depend on how many random draws the
    # option shuffling consumed; this keeps the split permutation stable.
    random.seed(42)
    random.shuffle(final_data)

    # Hold out the first 10% (rounded down) as validation data.
    val_ratio = 0.1
    val_size = int(len(final_data) * val_ratio)
    train_data = final_data[val_size:]
    val_data = final_data[:val_size]

    # Save the training and validation sets.
    with open(f"{base_dir}/train_sft_data.json", 'w', encoding='utf-8') as f:
        json.dump(train_data, f, ensure_ascii=False, indent=2)

    with open(f"{base_dir}/val_sft_data.json", 'w', encoding='utf-8') as f:
        json.dump(val_data, f, ensure_ascii=False, indent=2)

    print(f"数据集已分割：训练集 {len(train_data)} 条，验证集 {len(val_data)} 条")