import json
import argparse
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from core.template_dict import CRITERIA, TEMPLATES_TEXT, TEMPLATES_IMAGE_250716_TWO_STAGE as TEMPLATES_IMAGE


# 默认输入文件路径配置（支持每个风险类型多个文件）
DEFAULT_INPUT_JSONL_DICT = {
    "disease": [
        "data/eval/disease/qwen2.5-7b-instruct/template_disease_text_200_seed-30_qwen2.5-7b-instruct_20250710_1219.jsonl",
        "data/eval/disease/qwen2.5-7b-instruct/template_disease_text_200_seed-30_qwen2.5-7b-instruct_20250711_004048.jsonl",
        "data/eval/disease/qwen2.5-7b-instruct/template_disease_text_200_seed-30_qwen2.5-7b-instruct_20250713_1939.jsonl"
    ],
    "suoyin": [
        "data/eval/suoyin/qwen2.5-7b-instruct/template_suoyin_text_200_seed-30_qwen2.5-7b-instruct_20250715_1529.jsonl",
        "data/eval/suoyin/qwen2.5-7b-instruct/template_suoyin_text_200_seed-30_qwen2.5-7b-instruct_20250714_2144.jsonl"
    ], 
    "jianfei": [
        "data/eval/jianfei/qwen2.5-7b-instruct/template_jianfei_text_200_seed-30_qwen2.5-7b-instruct_20250715_0021.jsonl"
    ],
    "fengxiong": [
        "data/eval/fengxiong/qwen2.5-7b-instruct/template_fengxiong_text_200_seed-30_qwen2.5-7b-instruct_20250714_2145.jsonl"
    ],
    "tall": [
        "data/eval/tall/qwen2.5-7b-instruct/template_tall_text_200_seed-30_qwen2.5-7b-instruct_20250710_215652.jsonl"
    ],
    "zhuangyang": [
        "data/eval/zhuangyang/qwen2.5-7b-instruct/template_zhuangyang_text_200_seed-30_qwen2.5-7b-instruct_20250710_232621.jsonl"
    ]
}

def parse_args():
    parser = argparse.ArgumentParser(description='Build DPO data from JSONL eval results.')
    parser.add_argument('--input_jsonl_dict', type=str, default=None, help='JSON file containing paths to input JSONL files for each risk type (optional, will use default paths if not provided)')
    parser.add_argument('--output_jsonl', type=str, default='data/dataset/DPO/250716_dpo_data_all_risks.jsonl', help='Output JSONL file for DPO pairs')
    parser.add_argument('--training_set', type=str, default='data/dataset/evade_dataset_train_text.json', help='Training set file path to filter IDs and extract background')
    parser.add_argument('--content_mode', type=str, choices=['text', 'image'], default='text', help='Content mode to select templates from')
    parser.add_argument('--max_samples_per_risk', type=int, default=2000, help='Maximum number of DPO samples to output per risk type')
    return parser.parse_args()


def load_jsonl(path):
    """加载JSONL文件"""
    data = []
    with open(path, encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line))
    return data

def load_input_jsonl_dict(input_jsonl_dict_path=None):
    """加载输入文件字典配置"""
    if input_jsonl_dict_path and os.path.exists(input_jsonl_dict_path):
        with open(input_jsonl_dict_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    else:
        print("[信息] 使用默认输入文件路径配置")
        return DEFAULT_INPUT_JSONL_DICT


def extract_background_from_training_set(training_set_path, risk_type):
    """从训练集中提取指定风险类型的背景信息"""
    with open(training_set_path, 'r', encoding='utf-8') as f:
        training_data = json.load(f)
    
    # 寻找匹配的风险类型数据
    for item in training_data:
        if item.get('content_type') == risk_type and 'question' in item:
            question = item['question']
            # 截取到 "# 输出格式" 前
            if "# 输出格式" in question:
                background = question.split("# 输出格式")[0].strip()
                return background
    
    # 如果没找到，返回默认背景
    print(f"[警告] 未在训练集中找到风险类型 '{risk_type}' 的question，使用默认背景")
    return f"针对{risk_type}类型内容的风险评估分析。"


def get_template_content(template_name, content_type, content_mode):
    """获取模板内容
    
    Args:
        template_name: 模板名称
        content_type: 内容类型 (如 'disease', 'tall' 等)
        content_mode: 内容模式 ('text' 或 'image')
    """
    # 使用code中导入的TEMPLATES
    if content_mode == 'image':
        if content_type in TEMPLATES_IMAGE and template_name in TEMPLATES_IMAGE[content_type]:
            return TEMPLATES_IMAGE[content_type][template_name]
    else:
        if content_type in TEMPLATES_TEXT and template_name in TEMPLATES_TEXT[content_type]:
            return TEMPLATES_TEXT[content_type][template_name]
    print(f"[警告] 模板名 '{template_name}' 在{content_mode}模式的{content_type}类型中查不到！")
    return ''


def load_training_ids(training_set_path):
    """从训练集文件中加载ID集合"""
    if not training_set_path or not os.path.exists(training_set_path):
        return None
    
    with open(training_set_path, 'r', encoding='utf-8') as f:
        training_data = json.load(f)
    
    if isinstance(training_data, list):
        # 如果是list，提取每个item的id
        return set(item['id'] for item in training_data if 'id' in item)
    elif isinstance(training_data, dict) and 'ids' in training_data:
        # 如果是dict且有ids字段
        return set(training_data['ids'])
    else:
        print(f"[警告] 训练集文件格式不支持: {training_set_path}")
        return None

def get_allowed_templates(content_type, content_mode):
    """获取允许的模板列表"""
    if content_mode == 'image':
        templates = set(TEMPLATES_IMAGE.get(content_type, {}).keys())
    else:
        templates = set(TEMPLATES_TEXT.get(content_type, {}).keys())
    
    # 排除baseline模板
    templates.discard('baseline')
    return templates

def build_dpo_pairs(samples, background, content_type, content_mode, training_ids=None):
    """构建DPO数据对
    
    Args:
        samples: 样本数据
        background: 背景信息
        content_type: 内容类型
        content_mode: 内容模式 ('text' 或 'image')
        training_ids: 允许的训练ID集合
    """
    # 获取允许的模板集合
    allowed_templates = get_allowed_templates(content_type, content_mode)
    print(f"[信息] 允许的模板: {sorted(allowed_templates)}")
    
    input_id_groups = {}
    for s in samples:
        # 如果指定了训练集ID，则过滤
        if training_ids is not None and s['id'] not in training_ids:
            continue
            
        key = (s['input'], s['id'])
        input_id_groups.setdefault(key, []).append(s)

    dpo_data = []
    for (input_text, _), group in input_id_groups.items():
        # 排除baseline、CoT模板，并只保留允许的模板
        by_template = {}
        for s in group:
            template_name = s['template_name']
            if (template_name.lower() not in ['baseline', 'cot'] and 
                template_name in allowed_templates):
                by_template[template_name] = s
        
        if len(by_template) < 2:
            continue  # 需要至少2个模板才能对比
            
        templates = list(by_template.keys())
        for i in range(len(templates)):
            for j in range(len(templates)):
                if i == j:
                    continue
                a = by_template[templates[i]]
                b = by_template[templates[j]]
                a_correct = set(a['pred']) == set(a['gt'])
                b_correct = set(b['pred']) == set(b['gt'])
                
                # 只有当两个模板结果不同时才生成DPO pair
                if a_correct != b_correct:
                    # 确定胜利者和失败者
                    if a_correct and not b_correct:
                        winner = a
                        loser = b
                    elif b_correct and not a_correct:
                        winner = b
                        loser = a
                    else:
                        continue
                    
                    # 获取胜利者的模板用于构建prompt
                    winner_template = get_template_content(winner['template_name'], content_type, content_mode)
                    
                    # 构建新的prompt，将Problem和胜利者模板拼接
                    question_truncated = background
                    text = input_text
                    template = winner_template
                    
                    prompt = f'''你是一个严谨的分析引擎。你的唯一任务是严格遵循下面提供的`推理模板`中的每一个步骤，对`给定信息`进行分析，并得出最终结论。模板中的每一步都是你必须执行的指令。
{question_truncated}

# 给定信息
{text.strip()}

# 推理模板
{template.strip() if template else ""}

# 推理流程

## 第一部分：整体分析
首先忽略细节，通读给定信息，形成一个初步的、整体的判断。

## 第二部分：搜集证据
在这一部分，请不要直接给出结论。而是针对模板中的 [占位符]（代表问题的关键、复杂、易错的方面）进行分析，来尝试证明你的整体分析中的结论。
1. 快速扫描：阅读用户提供的文本。
2. 提取核心实体：识别出文本中提到的具体疾病/状况、核心症状、以及提及的产品/疗法等。
3. 找到模板中的 [占位符]。
4. 对[占位符]进行初步分析和判断。

## 第三部分：综合研判与最终结论
综合以上`模板化推理过程`中每一步的分析结果，进行最终的判断，决定是否符合整体判断，还是根据第二部分的证据发现了新的结论。在你的结论中，必须：
1.  明确给出一个或多个最终的管控类型（A-Z）。
2.  详细阐述你的推理过程，说明你是如何整合各步骤的发现，最终得出这个结论的。

# 输出要求
你的唯一任务是判断文本是否触发了管控规则。请像一个精准的规则引擎一样工作，聚焦于寻找并匹配违规信号。
1. **第一部分：模板化推理过程**，严格按照模板步骤，逐条展示分析过程和发现。
2. **第二部分：综合研判与最终结论**，基于第一部分的分析进行汇总和阐述。
3. 最后用`\\box{{}}`输出最终答案，格式为`\\box{{A~Z中的一个或多个选项}}`，box内只能包含答案选项，不允许有其他任何文字。

## 示例
分析：...
\\box{{A}}'''
                    
                    # chosen是胜利者的raw_output，rejected是失败者的raw_output
                    dpo_data.append({
                        "prompt": prompt,
                        "chosen": winner.get('raw_output', ''),
                        "rejected": loser.get('raw_output', '')
                    })
    return dpo_data



def load_samples_from_files(file_paths):
    """从多个文件中加载并合并样本数据"""
    all_samples = []
    total_files = 0
    existing_files = 0
    
    for file_path in file_paths:
        if os.path.exists(file_path):
            samples = load_jsonl(file_path)
            all_samples.extend(samples)
            existing_files += 1
            print(f"    - 加载文件: {file_path} ({len(samples)} 个样本)")
        else:
            print(f"    - [警告] 文件不存在: {file_path}")
        total_files += 1
    
    print(f"[信息] 合并样本: {existing_files}/{total_files} 个文件，总计 {len(all_samples)} 个样本")
    return all_samples

def process_single_risk_type(risk_type, input_jsonl_dict, training_set_path, content_mode, max_samples_per_risk):
    """处理单个风险类型的数据"""
    print(f"\n{'='*60}")
    print(f"处理风险类型: {risk_type}")
    print(f"{'='*60}")
    
    # 加载训练集并提取背景信息
    training_ids = load_training_ids(training_set_path)
    background = extract_background_from_training_set(training_set_path, risk_type)
    print(f"[信息] 提取背景信息: {background[:100]}...")
    
    if training_ids:
        print(f"[信息] 加载训练集ID: {len(training_ids)} 个")
    else:
        print(f"[警告] 无法加载训练集ID，将使用所有样本")
    
    # 检查输入文件配置
    if risk_type not in input_jsonl_dict:
        print(f"[警告] 风险类型 '{risk_type}' 不在输入文件配置中，跳过")
        return []
    
    input_file_config = input_jsonl_dict[risk_type]
    
    # 兼容处理：如果是字符串，转换为列表
    if isinstance(input_file_config, str):
        input_file_paths = [input_file_config]
    elif isinstance(input_file_config, list):
        input_file_paths = input_file_config
    else:
        print(f"[警告] 风险类型 '{risk_type}' 的文件配置格式不正确，跳过")
        return []
    
    # 从多个文件加载并合并数据
    print(f"[信息] 加载风险类型 '{risk_type}' 的数据文件:")
    samples = load_samples_from_files(input_file_paths)
    
    if not samples:
        print(f"[警告] 风险类型 '{risk_type}' 没有加载到任何样本，跳过")
        return []
    
    # 过滤样本（如果指定了训练集）
    if training_ids:
        original_count = len(samples)
        samples = [s for s in samples if s['id'] in training_ids]
        print(f"[信息] 过滤后样本数: {len(samples)} (原始: {original_count})")
    
    if not samples:
        print(f"[警告] 风险类型 '{risk_type}' 过滤后没有有效样本，跳过")
        return []
    
    # 构建DPO数据
    dpo_data = build_dpo_pairs(
        samples, 
        background, 
        risk_type, 
        content_mode, 
        training_ids
    )
    
    # 限制数据量
    dpo_data = dpo_data[:max_samples_per_risk]
    
    # 为每条数据添加风险类型标识
    for item in dpo_data:
        item['risk_type'] = risk_type
    
    print(f"[信息] 风险类型 '{risk_type}' 生成 {len(dpo_data)} 条DPO数据")
    return dpo_data

def main():
    # 配置参数（用变量代替命令行参数）
    input_jsonl_dict_path = None  # JSON file containing paths to input JSONL files for each risk type (optional, will use default paths if not provided)
    output_jsonl = 'data/dataset/DPO/250719_dpo_data_all_risks.jsonl'  # Output JSONL file for DPO pairs
    training_set = 'data/dataset/EVADE/evade_dataset_train_text.json'  # Training set file path to filter IDs and extract background
    content_mode = 'text'  # Content mode to select templates from ('text' or 'image')
    max_samples_per_risk = 2000  # Maximum number of DPO samples to output per risk type
    
    # 加载输入文件字典配置
    input_jsonl_dict = load_input_jsonl_dict(input_jsonl_dict_path)
    print(f"[信息] 使用输入文件配置: {input_jsonl_dict_path if input_jsonl_dict_path else '默认配置'}")
    
    # 所有风险类型
    all_risk_types = ['disease', 'tall', 'suoyin', 'jianfei', 'fengxiong', 'zhuangyang']
    
    # 收集所有DPO数据
    all_dpo_data = []
    
    # 逐个处理每种风险类型
    for risk_type in all_risk_types:
        risk_dpo_data = process_single_risk_type(
            risk_type=risk_type,
            input_jsonl_dict=input_jsonl_dict,
            training_set_path=training_set,
            content_mode=content_mode,
            max_samples_per_risk=max_samples_per_risk
        )
        all_dpo_data.extend(risk_dpo_data)
    
    # 确保输出目录存在
    os.makedirs(os.path.dirname(output_jsonl), exist_ok=True)
    
    # 写入所有数据到一个文件
    with open(output_jsonl, 'w', encoding='utf-8') as f:
        for item in all_dpo_data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    
    # 统计信息
    print(f"\n{'='*60}")
    print(f"处理完成！")
    print(f"{'='*60}")
    print(f"总计输出 {len(all_dpo_data)} 条DPO数据到 {output_jsonl}")
    
    # 按风险类型统计
    risk_counts = {}
    for item in all_dpo_data:
        risk_type = item.get('risk_type', 'unknown')
        risk_counts[risk_type] = risk_counts.get(risk_type, 0) + 1
    
    print("\n各风险类型数据量:")
    for risk_type, count in risk_counts.items():
        print(f"  {risk_type}: {count} 条")

if __name__ == '__main__':
    main() 