import json
import random
import logging
import os
from tqdm import tqdm

class DPODataAgent:
    """
    Agent that turns reward-scored plans into DPO training data pairs.

    For each scored plan, one step taken from ``max_reward_steps`` is used as
    the positive ("chosen") sample and one step taken from
    ``min_reward_steps`` as the negative ("rejected") sample.

    Every method is a ``@staticmethod``; the class only serves as a namespace.
    """

    @staticmethod
    def read_json(file_path):
        """Load and return the parsed content of a UTF-8 JSON file."""
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    @staticmethod
    def _ensure_parent_dir(file_path):
        """Create the parent directory of ``file_path`` when it has one."""
        parent = os.path.dirname(file_path)
        # A bare filename has an empty dirname, and os.makedirs('') raises
        # FileNotFoundError, so only create a directory when one is named.
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def save_jsonl(data, file_path):
        """
        Write ``data`` as JSON Lines (one JSON document per line, UTF-8).

        Parameters:
        - data: iterable of JSON-serialisable items
        - file_path: destination path; missing parent directories are created
        """
        DPODataAgent._ensure_parent_dir(file_path)
        with open(file_path, 'w', encoding='utf-8') as f:
            for item in data:
                f.write(json.dumps(item, ensure_ascii=False) + '\n')

    @staticmethod
    def save_json(data, file_path):
        """
        Write ``data`` as a single indented JSON document (UTF-8).

        Parameters:
        - data: JSON-serialisable object
        - file_path: destination path; missing parent directories are created
        """
        DPODataAgent._ensure_parent_dir(file_path)
        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @staticmethod
    def extract_content_text(step, plan_type):
        """
        Render one plan step as plain text according to the plan type.

        Parameters:
        - step: the step to render, normally a dict
        - plan_type: 'Week', anything starting with 'Menu', 'Floor' or
          'Block'; any other value (including None) falls back to str(step)

        Returns:
        - 'Week' / 'Menu*': the 'events' entries (or, failing that, the
          'dishes' entries) joined with newlines; dict events are rendered
          as "<event>: <details>"
        - 'Floor': "<floor_id>: <purpose>" when 'purpose' is present
        - 'Block': "<block_id>: <use>" when 'use' is present
        - otherwise: str(step)
        """
        # Anything that is not a mapping cannot carry the expected keys.
        if not isinstance(step, dict):
            return str(step)

        plan_type = plan_type or ''

        if plan_type == 'Week' or plan_type.startswith('Menu'):
            # Weekly-plan / menu types
            if 'events' in step:
                events = step['events']
                if not isinstance(events, list):
                    # A single event value
                    return str(events)
                if all(isinstance(e, dict) for e in events):
                    # List of event dicts
                    return "\n".join(
                        f"{e.get('event', '')}: {e.get('details', '')}" for e in events
                    )
                # List of strings (or mixed items): stringify each entry so
                # that a stray non-string item cannot break the join.
                return "\n".join(str(e) for e in events)
            if 'dishes' in step:
                dishes = step['dishes']
                if isinstance(dishes, list):
                    return "\n".join(str(dish) for dish in dishes)
                return str(dishes)
            return str(step)

        if plan_type == 'Floor':
            # Floor type
            if 'purpose' in step:
                return f"{step.get('floor_id', '')}: {step['purpose']}"
            return str(step)

        if plan_type == 'Block':
            # Block type
            if 'use' in step:
                return f"{step.get('block_id', '')}: {step['use']}"
            return str(step)

        # Default case
        return str(step)

    @staticmethod
    def extract_dpo_pairs(reward_results, min_reward_diff=0.2, random_seed=42):
        """
        Build chosen/rejected pairs from reward-scored plans.

        A plan is used only when all of the following hold:
        - 'is_valid' is truthy or 'filter_response' equals 'YES'
          (case-insensitive);
        - it has a 'reward_stats' key;
        - at least one of the known plan keys holds a non-empty list;
        - 'max_reward_steps' and 'min_reward_steps' are both non-empty, the
          'reward' of their first entries differ, and that difference is at
          least ``min_reward_diff``.

        Parameters:
        - reward_results: iterable of plan dicts
        - min_reward_diff: minimum reward gap between the first max-reward
          step and the first min-reward step
        - random_seed: seed for picking one step out of each candidate list

        Returns:
        - list of dicts with 'id', 'prompt', 'type', 'chosen', 'rejected',
          their rewards/advantages as floats, plus any 'checks_once',
          'checks_range', 'checks_periodic' entries copied from the plan
        """
        # A private generator yields the same picks as seeding the module
        # level RNG with the same seed, without disturbing other users of
        # the global random state.
        rng = random.Random(random_seed)
        dpo_pairs = []
        valid_count = 0
        plan_keys = (
            "weekly_plan", "floor_plan", "block_plan",
            "filtered_weekly_plan", "filtered_floor_plan", "filtered_block_plan",
        )

        for item in tqdm(reward_results, desc="Extracting DPO pairs"):
            # Skip invalid plans; tolerate a null / non-string filter_response.
            filter_passed = str(item.get('filter_response') or '').upper() == 'YES'
            if not item.get('is_valid') and not filter_passed:
                continue
            # Skip plans without reward scores.
            if 'reward_stats' not in item:
                continue

            valid_count += 1

            plan_type = item.get('type', '')

            # The plan must actually contain a non-empty plan list.
            has_plan = any(
                isinstance(item.get(key), list) and item[key] for key in plan_keys
            )
            if not has_plan:
                continue

            # Candidate steps with the highest / lowest reward.
            max_steps = item.get('max_reward_steps', [])
            min_steps = item.get('min_reward_steps', [])
            if not max_steps or not min_steps:
                continue

            top_reward = max_steps[0].get('reward', 0)
            bottom_reward = min_steps[0].get('reward', 0)
            # Identical rewards carry no preference signal, whatever the
            # threshold is.
            if top_reward == bottom_reward:
                continue

            # Require a large enough reward gap.
            reward_diff = top_reward - bottom_reward
            if reward_diff < min_reward_diff:
                logging.debug(f"Skipping pair with small reward diff: {reward_diff}")
                continue

            # Pick one candidate from each list (chosen first, then rejected,
            # so the sequence of draws stays reproducible).
            chosen_step = rng.choice(max_steps)
            rejected_step = rng.choice(min_steps)

            dpo_sample = {
                'id': item.get('id', f"{plan_type}_{valid_count}"),
                'prompt': item.get('prompt', ''),
                'type': plan_type,
                'chosen': chosen_step,
                'rejected': rejected_step,
                'chosen_reward': float(chosen_step.get('reward', 0)),
                'rejected_reward': float(rejected_step.get('reward', 0)),
                'chosen_advantage': float(chosen_step.get('advantage', 0)),
                'rejected_advantage': float(rejected_step.get('advantage', 0))
            }

            # Carry over the task-specific check conditions when present.
            for check_key in ('checks_once', 'checks_range', 'checks_periodic'):
                if check_key in item:
                    dpo_sample[check_key] = item[check_key]

            dpo_pairs.append(dpo_sample)

        logging.info(f"Generated {len(dpo_pairs)} valid DPO pairs from {valid_count} valid plans")
        return dpo_pairs

    @staticmethod
    def format_for_training(dpo_pairs):
        """
        Convert detailed DPO pairs into the flat training format.

        Parameters:
        - dpo_pairs: list of dicts as produced by ``extract_dpo_pairs``

        Returns:
        - list of dicts with 'id', 'prompt', 'chosen', 'rejected', 'type',
          where 'chosen' / 'rejected' are the text rendering of the steps
        """
        training_data = []

        for pair in dpo_pairs:
            plan_type = pair['type']
            training_data.append({
                'id': pair['id'],
                'prompt': pair['prompt'],
                'chosen': DPODataAgent.extract_content_text(pair['chosen'], plan_type),
                'rejected': DPODataAgent.extract_content_text(pair['rejected'], plan_type),
                'type': plan_type
            })

        return training_data

    @staticmethod
    def generate_dpo_data(reward_file, output_file, format_type='training', min_reward_diff=0.2):
        """
        Main entry point: read reward results, build DPO data and save it.

        Parameters:
        - reward_file: path of the reward-result JSON file
        - output_file: output path; a trailing '.jsonl' is turned into
          '.json' because the data is always written as one JSON document
        - format_type: 'training' for the flat training format, any other
          value (e.g. 'pairs') for the detailed pairs
        - min_reward_diff: minimum reward gap; pairs below it are dropped

        Returns:
        - the generated list of DPO items, or [] when the reward file cannot
          be read or does not contain a JSON list
        """
        # Read the reward results; failures are logged and reported as [].
        try:
            reward_results = DPODataAgent.read_json(reward_file)
            logging.info(f"Loaded {len(reward_results)} examples from {reward_file}")
        except Exception as e:
            logging.error(f"Error reading reward file: {e}")
            return []

        if not isinstance(reward_results, list):
            logging.error(
                f"Error reading reward file: expected a JSON list, got {type(reward_results).__name__}"
            )
            return []

        # Extract the DPO pairs.
        dpo_pairs = DPODataAgent.extract_dpo_pairs(reward_results, min_reward_diff)

        # Output is JSON rather than JSONL. Only the file suffix is swapped,
        # so a '.jsonl' fragment elsewhere in the path is left untouched.
        if output_file.endswith('.jsonl'):
            output_file = output_file[:-len('.jsonl')] + '.json'

        if format_type == 'training':
            result = DPODataAgent.format_for_training(dpo_pairs)
            description = "training examples"
        else:
            result = dpo_pairs
            description = "detailed DPO pairs"

        DPODataAgent.save_json(result, output_file)
        logging.info(f"Saved {len(result)} {description} to {output_file}")
        return result

    @staticmethod
    async def async_generate_dpo_data(reward_file, output_file, format_type='training', min_reward_diff=0.2):
        """
        Coroutine wrapper around ``generate_dpo_data`` for async workflows.

        The work itself is performed synchronously inside this coroutine, so
        it does not yield to the event loop while running. Parameters and
        return value are those of ``generate_dpo_data``.
        """
        return DPODataAgent.generate_dpo_data(reward_file, output_file, format_type, min_reward_diff)