#!/usr/bin/env python3
"""
ALFWorld Rollout Analysis Script

Analyze rollout details of ALFWorld training checkpoints and write detailed CSV reports (trajectory-level and step-level).

Usage:
    python analyze_rollout.py --merged_dir /path/to/merged_models/ckpt_name
"""

import os
import sys
import argparse
import logging
from typing import List, Dict, Tuple, Optional
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import ray

# Logging: INFO level, records formatted as "<time> - <LEVEL> - <message>"
# (basicConfig configures the root logger).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Make project-local packages importable. PROJECT_ROOT is obtained by applying
# dirname() three times to this file's absolute path, i.e. the parent of the
# parent of the directory containing this script. It is inserted at index 0 so
# it takes precedence over other sys.path entries. This must run before the
# project-local imports (agent_system, migpo) that follow.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, PROJECT_ROOT)

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Project-local modules (importable because PROJECT_ROOT was added to sys.path above)
from agent_system.environments.env_package.alfworld import build_alfworld_envs
from agent_system.environments.env_package.alfworld.projection import alfworld_projection
from agent_system.environments.env_manager import AlfWorldEnvironmentManager
from migpo.core_migpo import match_milestones, compute_segment_rewards
from migpo.milestone_loader import get_milestone_loader
from functools import partial
from omegaconf import OmegaConf


# ============================================================
# Configuration
# ============================================================
# ALFWorld environment config file; passed as the first argument to build_alfworld_envs in main().
ALFWORLD_CONFIG_PATH = os.path.join(PROJECT_ROOT, 'agent_system/environments/env_package/alfworld/configs/config_tw.yaml')

# Split mapping: user-facing --split name -> (eval_dataset env kwarg, is_train flag).
# main() unpacks this and forwards both values to build_alfworld_envs.
SPLIT_MAPPING = {
    'train': ('train', True),
    'valid_seen': ('eval_in_distribution', False),
    'valid_unseen': ('eval_out_of_distribution', False),
}


def get_gpu_count() -> int:
    """Return the number of GPUs listed in ``CUDA_VISIBLE_DEVICES``.

    Used as the default vLLM tensor-parallel size. Empty entries (e.g. from a
    trailing comma such as ``"0,1,"``) and whitespace-only entries are ignored.
    Falls back to 1 when the variable is unset, empty, or lists no devices, so
    the result is always a usable tensor-parallel size.
    """
    cuda_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    devices = [dev for dev in cuda_devices.split(',') if dev.strip()]
    return max(1, len(devices))


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Model selection: pass either ``--model_path`` (a single checkpoint) or
    ``--merged_dir`` (every ``step_*`` sub-directory, optionally filtered by
    ``--steps``). Which one is required is enforced later, in
    ``get_models_to_analyze``.

    Returns:
        argparse.Namespace with all options below.
    """
    parser = argparse.ArgumentParser(description='ALFWorld Rollout Analysis')
    add = parser.add_argument

    # Model location (one of --merged_dir / --model_path).
    add('--merged_dir', type=str, default=None,
        help='合并后模型的目录（分析多个checkpoint）')
    add('--model_path', type=str, default=None,
        help='直接指定模型路径（单个checkpoint分析）')
    add('--steps', type=str, default=None,
        help='指定要分析的step，逗号分隔，如"100,200,300"')

    # Dataset split.
    add('--split', type=str, default='train',
        choices=['train', 'valid_seen', 'valid_unseen'],
        help='数据集划分')

    # Output location and rollout/environment sizing.
    add('--output_dir', type=str, default='./output')
    for flag, int_default in (('--num_envs', 16), ('--group_size', 8),
                              ('--max_steps', 30), ('--seed', 0)):
        add(flag, type=int, default=int_default)
    add('--tensor_parallel_size', type=int, default=None,
        help='张量并行大小，默认为CUDA_VISIBLE_DEVICES的GPU数量')
    add('--gpu_memory_utilization', type=float, default=0.7)

    # Advantage estimator and its knobs. step_advantage_w weights the step term
    # for GiGPO/MiGPO; migpo_gamma is used by MiGPO only; migpo_threshold is
    # passed to both GiGPO and MiGPO.
    add('--algorithm', type=str, default='migpo', choices=['grpo', 'gigpo', 'migpo'])
    add('--step_advantage_w', type=float, default=10.0, help='GiGPO/MiGPO: step advantage权重')
    add('--migpo_gamma', type=float, default=0.9, help='MiGPO: milestone奖励衰减')
    add('--migpo_threshold', type=float, default=0.95, help='GiGPO/MiGPO: 相似度阈值')

    # Environment-manager history window.
    add('--history_length', type=int, default=10, help='历史长度')

    return parser.parse_args()


def get_checkpoint_steps(merged_dir: str) -> List[int]:
    """Return the sorted step numbers of the ``step_<N>`` checkpoint directories in *merged_dir*.

    Only directories whose name is exactly ``step_`` followed by decimal digits
    are counted. Other entries, e.g. ``step_final``, ``step_100_backup`` or a
    plain file named ``step_5``, are skipped instead of crashing ``int()``.

    Args:
        merged_dir: Directory containing merged checkpoints.

    Returns:
        Ascending list of step numbers.
    """
    prefix = 'step_'
    steps = []
    for name in os.listdir(merged_dir):
        if not name.startswith(prefix):
            continue
        suffix = name[len(prefix):]
        # isdecimal() accepts exactly the characters int() can parse as digits.
        if not suffix.isdecimal():
            continue
        if not os.path.isdir(os.path.join(merged_dir, name)):
            continue
        steps.append(int(suffix))
    return sorted(steps)


def get_models_to_analyze(args) -> List[Tuple[int, str]]:
    """Resolve which checkpoint(s) to analyze from the parsed CLI arguments.

    ``--model_path`` takes precedence over ``--merged_dir`` when both are set.
    For ``--model_path`` the step number is parsed from a ``step_<N>`` directory
    name. A trailing path separator is tolerated. Names that are not
    ``step_<digits>`` map to step 0. For ``--merged_dir`` every ``step_<N>``
    sub-directory is used, optionally restricted to the comma-separated
    ``--steps``. Requested steps that do not exist are logged as a warning.

    Returns:
        List of ``(step_num, model_path)`` tuples.

    Raises:
        ValueError: if neither ``--model_path`` nor ``--merged_dir`` is given,
            or if ``--steps`` contains a non-integer entry.
    """
    prefix = 'step_'
    if args.model_path:
        # normpath strips a trailing separator, which would otherwise make basename() empty.
        model_name = os.path.basename(os.path.normpath(args.model_path))
        suffix = model_name[len(prefix):] if model_name.startswith(prefix) else ''
        step_num = int(suffix) if suffix.isdecimal() else 0
        return [(step_num, args.model_path)]

    if not args.merged_dir:
        raise ValueError("必须指定 --model_path 或 --merged_dir")

    all_steps = get_checkpoint_steps(args.merged_dir)

    if args.steps:
        # Ignore empty items such as those produced by "100,,200" or a trailing comma.
        selected = {int(item) for item in (part.strip() for part in args.steps.split(',')) if item}
        missing = sorted(selected.difference(all_steps))
        if missing:
            logging.getLogger(__name__).warning(
                "Requested steps not found in %s: %s", args.merged_dir, missing)
        steps = [s for s in all_steps if s in selected]
    else:
        steps = all_steps

    return [(step, os.path.join(args.merged_dir, f'step_{step}')) for step in steps]


# ============================================================
# 模型推理
# ============================================================
class VLLMInference:
    """Thin wrapper around a vLLM ``LLM`` engine plus its HuggingFace tokenizer.

    Args:
        model_path: Model directory (or identifier) passed to both the
            tokenizer and vLLM.
        tp_size: ``tensor_parallel_size`` passed to vLLM.
        gpu_util: ``gpu_memory_utilization`` passed to vLLM.
        temperature: Sampling temperature (default 0.4, the previously fixed value).
        max_tokens: Maximum generated tokens per prompt (default 512, the
            previously fixed value).
    """

    def __init__(self, model_path: str, tp_size: int = 1, gpu_util: float = 0.7,
                 temperature: float = 0.4, max_tokens: int = 512):
        # Loaded for callers' convenience; not used inside this class.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.llm = LLM(
            model=model_path,
            tensor_parallel_size=tp_size,
            gpu_memory_utilization=gpu_util,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
        )

    def generate(self, prompts: List[str]) -> List[str]:
        """Generate one completion per prompt and return the first completion's text for each."""
        if not prompts:
            return []
        outputs = self.llm.generate(prompts, self.sampling_params)
        return [output.outputs[0].text for output in outputs]

    def cleanup(self):
        """Drop the engine reference and release cached CUDA memory.

        ``del`` alone only removes this attribute. Any reference cycles that
        still hold the engine are freed only by the garbage collector, so
        ``gc.collect()`` runs before emptying the CUDA cache. Safe to call more
        than once. NOTE(review): with tensor parallelism, vLLM may need
        additional distributed-state teardown before a new engine is created in
        the same process — confirm for the vLLM version in use.
        """
        import gc

        self.llm = None
        gc.collect()
        torch.cuda.empty_cache()


# ============================================================
# Trial ID提取
# ============================================================
def extract_trial_id(gamefile_path: str) -> Optional[str]:
    """Build the milestone-lookup id for an ALFWorld game file path.

    Expects a path of the form ``.../<task_dir>/trial_T<...>/game.tw-pddl`` and
    returns ``"<task_dir>_<trial_dir>"`` for the first ``trial_T*`` component
    that has a non-empty parent component.

    Returns:
        The id string, or None for an empty/None path or when no such
        component exists. Previously a leading ``trial_T*`` component wrapped
        around to ``parts[-1]``.
    """
    if not gamefile_path:
        return None
    parts = gamefile_path.split('/')
    for i, part in enumerate(parts):
        if not part.startswith('trial_T'):
            continue
        # Guard i == 0: parts[i - 1] would silently wrap around to the last element.
        if i > 0 and parts[i - 1]:
            task_dir = parts[i - 1]  # task-type directory
            return f"{task_dir}_{part}"
    return None


# ============================================================
# Rollout执行
# ============================================================
def extract_action(response: str) -> str:
    """Return the stripped text inside the first ``<action>...</action>`` pair of *response*.

    The closing tag is searched for only after the opening tag, so a stray
    ``</action>`` earlier in the text cannot produce an empty or garbled slice.
    Falls back to the whole stripped response when no complete pair exists.
    """
    open_tag, close_tag = '<action>', '</action>'
    start = response.find(open_tag)
    if start != -1:
        start += len(open_tag)
        end = response.find(close_tag, start)
        if end != -1:
            return response[start:end].strip()
    return response.strip()


def build_prompt(task_desc: str, obs: str, admissible: List[str],
                 history: List[Tuple[str, str]], step: int,
                 history_length: int = 10,
                 template: Optional[str] = None) -> str:
    """Render an ALFWorld prompt from the task, current observation and history.

    Not used by ``run_rollout``, which takes ready-made prompts from the
    EnvironmentManager. Kept as a standalone helper.

    Args:
        task_desc: Task description text.
        obs: Current observation text.
        admissible: Admissible actions; rendered comma-separated, or "None" if empty.
        history: ``(action, observation)`` pairs of previous steps, oldest first
            (None is treated as empty).
        step: Number of steps already taken; the current step is ``step + 1``.
        history_length: How many of the most recent history entries to include.
        template: Format string with the placeholders used below. When None,
            ``ALFWORLD_TEMPLATE`` is imported lazily. The name was previously
            referenced without any import, so every call raised NameError.
            NOTE(review): module path ``agent_system.environments.prompts`` is
            assumed — confirm.

    Returns:
        The formatted prompt string.
    """
    if template is None:
        from agent_system.environments.prompts import ALFWORLD_TEMPLATE as template

    history = history or []
    recent = history[-history_length:] if history_length > 0 else []
    # Label entries with their real step numbers so the labels stay consistent
    # with step_count once the history has been truncated.
    first_step = len(history) - len(recent)
    history_str = "".join(
        f"\nStep {first_step + i + 1}: Action: {act}\nObservation: {ob}"
        for i, (act, ob) in enumerate(recent)
    )

    admissible_str = ", ".join(admissible) if admissible else "None"

    return template.format(
        task_description=task_desc,
        step_count=step,
        history_length=len(recent),
        action_history=history_str if history_str else "None",
        current_step=step + 1,
        current_observation=obs,
        admissible_actions=admissible_str,
    )


def run_rollout(model: VLLMInference, env_manager, num_trajs: int, max_steps: int):
    """Roll out ``num_trajs`` episodes in lockstep through the training EnvironmentManager.

    Prompts are taken from the manager's observations (``obs['text']``), so
    they are built exactly as during training. Each iteration generates a
    response for every environment and steps all of them. A step is recorded
    only for environments that were still active at the start of the iteration.

    Args:
        model: Object with ``generate(prompts) -> list[str]``.
        env_manager: ``reset(kwargs=None)`` returns ``(obs_dict, infos)``;
            ``step(responses)`` returns ``(obs_dict, rewards, dones, infos)``.
        num_trajs: Number of parallel environments / trajectories.
        max_steps: Maximum number of environment steps.

    Returns:
        trajectories: Per-trajectory lists of step dicts (step_idx, action,
            observation, response, reward, done, info).
        episode_rewards: float32 array, sum of recorded step rewards.
        episode_lengths: Number of recorded steps per trajectory.
        success: bool array from the last recorded step's ``info['won']``
            (False for an empty trajectory).
        trial_ids: Milestone ids derived from each env's ``'extra.gamefile'``.

    Raises:
        ValueError: if ``reset`` does not return exactly ``num_trajs`` observations.
    """
    obs, infos = env_manager.reset(kwargs=None)
    text_obs = obs['text']  # the manager returns a dict; 'text' holds the prompts
    if len(text_obs) != num_trajs:
        raise ValueError(
            f"env_manager returned {len(text_obs)} observations, expected num_trajs={num_trajs}")

    # Convert gamefile paths into the id format used by the milestone JSON.
    trial_ids = [extract_trial_id(info.get('extra.gamefile')) for info in infos]

    trajectories = [[] for _ in range(num_trajs)]
    is_done = np.zeros(num_trajs, dtype=bool)
    episode_rewards = np.zeros(num_trajs, dtype=np.float32)

    for step_idx in range(max_steps):
        active_mask = ~is_done
        if not active_mask.any():
            break

        # All envs are stepped together, so every env needs a response; finished
        # envs are simply not recorded below.
        responses = model.generate(text_obs)
        next_obs, rewards, dones, next_infos = env_manager.step(responses)

        for i in range(num_trajs):
            if not active_mask[i]:
                continue
            info = next_infos[i]
            # NOTE(review): whether the manager stores the projected action in
            # info is not visible here. Without it, fall back to the
            # <action>...</action> content rather than the raw response, because
            # these strings are later compared against milestone actions.
            action = info['action'] if 'action' in info else extract_action(responses[i])
            trajectories[i].append({
                'step_idx': step_idx,
                'action': action,
                'observation': text_obs[i],
                'response': responses[i],
                'reward': rewards[i],
                'done': dones[i],
                'info': info,  # info from this step; carries the 'won' flag
            })
            episode_rewards[i] += rewards[i]

        is_done = np.logical_or(is_done, dones)
        text_obs = next_obs['text']

    success = np.array([traj[-1]['info'].get('won', False) if traj else False for traj in trajectories])
    episode_lengths = np.array([len(traj) for traj in trajectories])

    return trajectories, episode_rewards, episode_lengths, success, trial_ids


# ============================================================
# Advantage计算
# ============================================================
def compute_step_rewards(trajectories, trial_ids, gamma, threshold, scale: float = 10.0):
    """Compute milestone-based per-step rewards for every trajectory.

    For each trajectory that has a trial id with milestones, its recorded
    actions are matched via ``match_milestones(actions, milestones, threshold)``.
    The matches are turned into per-step rewards by
    ``compute_segment_rewards(num_steps, match_indices, gamma)``, and those
    rewards are multiplied by ``scale``. Trajectories without a trial id,
    without milestones, or without steps get all-zero rewards.

    Args:
        trajectories: Per-trajectory lists of step dicts with an ``'action'`` key.
        trial_ids: Milestone ids aligned with ``trajectories`` (may contain None).
        gamma: Passed to ``compute_segment_rewards``.
        threshold: Similarity threshold passed to ``match_milestones``.
        scale: Reward multiplier (default 10.0, the previously hard-coded
            value). DataCollector flags steps with reward >= 9.0 as milestones,
            so changing the scale affects that flag.

    Returns:
        List (one entry per trajectory) of per-step rewards as Python lists.
    """
    loader = get_milestone_loader()
    all_step_rewards = []

    for traj_idx, traj in enumerate(trajectories):
        trial_id = trial_ids[traj_idx]
        milestones = loader.get_milestones(trial_id) if trial_id else None

        if milestones is None or len(traj) == 0:
            all_step_rewards.append([0.0] * len(traj))
            continue

        actions = [step['action'] for step in traj]
        match_indices = match_milestones(actions, milestones, threshold)
        rewards, _ = compute_segment_rewards(len(actions), match_indices, gamma)
        # asarray keeps an ndarray's dtype unchanged and also accepts a plain list.
        all_step_rewards.append((np.asarray(rewards) * scale).tolist())

    return all_step_rewards


def compute_grpo_advantages(trajectories, episode_rewards, num_envs, group_size):
    """GRPO advantages: an episode-level term only, broadcast to every step.

    Trajectories are laid out as ``num_envs`` consecutive groups of
    ``group_size``. Each trajectory's advantage is its episode reward minus
    its group's mean episode reward (mean-centred only, no std scaling). All
    step rewards and step advantages are ``0.0``.

    Returns:
        dict with ``'step_rewards'``, ``'episode_advs'`` (ndarray),
        ``'step_advs'`` and ``'total_advs'`` (per-step lists).
    """
    episode_advs = np.zeros(len(trajectories))
    for env_idx in range(num_envs):
        group = slice(env_idx * group_size, env_idx * group_size + group_size)
        group_rewards = episode_rewards[group]
        episode_advs[group] = group_rewards - np.mean(group_rewards)

    # GRPO has no step-level signal: two independent zero-filled structures.
    step_rewards = [[0.0] * len(traj) for traj in trajectories]
    step_advs = [[0.0] * len(traj) for traj in trajectories]

    # Every step of a trajectory carries its episode advantage.
    total_advs = [[episode_advs[i]] * len(traj) for i, traj in enumerate(trajectories)]

    return {
        'step_rewards': step_rewards,
        'episode_advs': episode_advs,
        'step_advs': step_advs,
        'total_advs': total_advs,
    }


def compute_gigpo_advantages(trajectories, episode_rewards, num_envs, group_size,
                              step_advantage_w, threshold):
    """GiGPO-style advantages: group-centred episode term plus an anchor-state step term.

    Trajectories are laid out as ``num_envs`` consecutive groups of ``group_size``.

    * Episode advantage: episode reward minus the group's mean episode reward.
    * Step advantage: within each group, steps whose recorded ``observation``
      is exactly identical are pooled. Each pooled step gets its immediate
      reward minus the pool mean. Steps whose observation is unique within the
      group keep step reward and step advantage ``0.0``.
    * Total advantage: ``episode_adv + step_advantage_w * step_adv`` per step.

    Args:
        trajectories: Per-trajectory lists of step dicts with ``'observation'`` and ``'reward'``.
        episode_rewards: Array of per-trajectory episode rewards.
        num_envs: Number of groups.
        group_size: Trajectories per group.
        step_advantage_w: Weight of the step advantage in the total.
        threshold: Accepted for interface compatibility; unused, because
            grouping is by exact match.

    Returns:
        dict with ``'step_rewards'``, ``'episode_advs'`` (ndarray),
        ``'step_advs'`` and ``'total_advs'``.
    """
    num_trajs = len(trajectories)

    # Episode advantage: centre by the group mean.
    episode_advs = np.zeros(num_trajs)
    for env_idx in range(num_envs):
        start = env_idx * group_size
        end = start + group_size
        group_rewards = episode_rewards[start:end]
        episode_advs[start:end] = group_rewards - np.mean(group_rewards)

    step_rewards = [[0.0] * len(traj) for traj in trajectories]
    step_advs = [[0.0] * len(traj) for traj in trajectories]

    # Pool steps by (group, full observation) in one pass. The whole observation
    # must be the key: it is the full prompt, whose opening characters are
    # presumably the shared template/task text, so a fixed-length prefix key
    # would merge unrelated states.
    anchor_groups = defaultdict(list)
    for traj_idx, traj in enumerate(trajectories):
        env_idx = traj_idx // group_size
        if env_idx >= num_envs:
            continue  # outside the num_envs groups: never pooled
        for step_idx, step in enumerate(traj):
            anchor_groups[(env_idx, step['observation'])].append(
                (traj_idx, step_idx, step['reward']))

    for members in anchor_groups.values():
        if len(members) <= 1:
            continue  # a singleton pool has no baseline to compare against
        mean_r = np.mean([reward for _, _, reward in members])
        for traj_idx, step_idx, reward in members:
            step_rewards[traj_idx][step_idx] = reward
            step_advs[traj_idx][step_idx] = reward - mean_r

    total_advs = [
        [episode_advs[traj_idx] + step_advantage_w * adv for adv in step_advs[traj_idx]]
        for traj_idx in range(num_trajs)
    ]

    return {
        'step_rewards': step_rewards,
        'episode_advs': episode_advs,
        'step_advs': step_advs,
        'total_advs': total_advs,
    }


def compute_migpo_advantages(trajectories, episode_rewards, trial_ids, num_envs, group_size,
                              step_advantage_w, gamma, threshold):
    """MiGPO advantages: group-centred episode term plus a milestone-based step term.

    * ``step_rewards``: from ``compute_step_rewards(trajectories, trial_ids, gamma, threshold)``.
    * ``episode_advs``: for each of the ``num_envs`` consecutive groups of
      ``group_size`` trajectories, episode reward minus the group's mean.
    * ``step_advs``: each trajectory's step rewards minus the mean of that same
      trajectory's step rewards. Centring is per trajectory, not per
      milestone segment and not across the group.
    * ``total_advs``: ``episode_adv + step_advantage_w * step_adv`` for every step.

    Returns:
        dict with ``'step_rewards'``, ``'episode_advs'`` (ndarray),
        ``'step_advs'`` and ``'total_advs'``.
    """
    step_rewards = compute_step_rewards(trajectories, trial_ids, gamma, threshold)

    episode_advs = np.zeros(len(trajectories))
    for env_idx in range(num_envs):
        lo = env_idx * group_size
        hi = lo + group_size
        members = episode_rewards[lo:hi]
        episode_advs[lo:hi] = members - np.mean(members)

    step_advs = []
    for traj_idx, traj in enumerate(trajectories):
        if len(traj) == 0:
            step_advs.append([])
            continue
        per_step = np.array(step_rewards[traj_idx])
        step_advs.append((per_step - np.mean(per_step)).tolist())

    total_advs = []
    for traj_idx, traj in enumerate(trajectories):
        total_advs.append([
            episode_advs[traj_idx] + step_advantage_w * step_advs[traj_idx][k]
            for k in range(len(traj))
        ])

    return {
        'step_rewards': step_rewards,
        'episode_advs': episode_advs,
        'step_advs': step_advs,
        'total_advs': total_advs,
    }


def compute_advantages(algorithm, trajectories, episode_rewards, trial_ids, num_envs, group_size,
                       step_advantage_w, gamma, threshold):
    """Dispatch to the advantage estimator named by ``algorithm``.

    ``'grpo'`` uses only episode rewards. ``'gigpo'`` additionally takes
    ``step_advantage_w`` and ``threshold``. ``'migpo'`` additionally takes
    ``trial_ids`` and ``gamma``.

    Raises:
        ValueError: for any other ``algorithm`` value.
    """
    estimators = (
        ('grpo', lambda: compute_grpo_advantages(
            trajectories, episode_rewards, num_envs, group_size)),
        ('gigpo', lambda: compute_gigpo_advantages(
            trajectories, episode_rewards, num_envs, group_size,
            step_advantage_w, threshold)),
        ('migpo', lambda: compute_migpo_advantages(
            trajectories, episode_rewards, trial_ids, num_envs, group_size,
            step_advantage_w, gamma, threshold)),
    )
    for name, run_estimator in estimators:
        if algorithm == name:
            return run_estimator()
    raise ValueError(f"Unknown algorithm: {algorithm}")


# ============================================================
# 数据收集
# ============================================================
# ALFWorld task families recognised in trial ids. Order matters:
# extract_task_type returns the first entry found as a substring.
TASK_TYPES = [
    "pick_and_place",
    "pick_two_obj_and_place",
    "look_at_obj_in_light",
    "pick_heat_then_place_in_recep",
    "pick_cool_then_place_in_recep",
    "pick_clean_then_place_in_recep",
]


def extract_task_type(trial_id: str) -> str:
    """Classify a trial id by task family.

    Returns the first ``TASK_TYPES`` entry contained in ``trial_id``, or
    ``'unknown'`` when ``trial_id`` is empty/None or contains none of them.
    """
    if not trial_id:
        return 'unknown'
    return next((task for task in TASK_TYPES if task in trial_id), 'unknown')


def count_milestones(traj, trial_id, threshold):
    """Count the milestones of ``trial_id`` matched by the trajectory's actions.

    Returns:
        ``(completed, total)``. ``total`` is the number of milestones for
        ``trial_id``. ``completed`` counts entries of
        ``match_milestones(actions, milestones, threshold)`` that are not
        ``-1`` (NOTE(review): -1 presumably marks an unmatched milestone —
        confirm in migpo.core_migpo). ``(0, 0)`` when there is no trial id,
        no milestones for it, or the trajectory is empty.
    """
    milestone_loader = get_milestone_loader()
    milestones = None
    if trial_id:
        milestones = milestone_loader.get_milestones(trial_id)
    if milestones is None or len(traj) == 0:
        return 0, 0

    taken_actions = [s['action'] for s in traj]
    matched = match_milestones(taken_actions, milestones, threshold)
    completed = sum(1 for m in matched if m != -1)
    return completed, len(milestones)


class DataCollector:
    """Accumulates trajectory-level and step-level rows and writes them as CSV files.

    Rows are appended by ``add_data`` (once per analysed checkpoint) and
    written by ``save_to_excel``. Despite its historical name, that method
    writes CSV files.
    """

    # A step is flagged "milestone achieved" when its step reward is at least
    # this value. compute_step_rewards scales milestone rewards by 10.0 by
    # default; NOTE(review): 9.0 presumably leaves headroom below that scale —
    # confirm against compute_segment_rewards' output.
    MILESTONE_REWARD_THRESHOLD = 9.0

    def __init__(self):
        self.traj_data = []  # one dict per trajectory
        self.step_data = []  # one dict per recorded step

    def add_data(self, step_num, trajectories, episode_rewards, episode_lengths,
                 success, trial_ids, advantages, num_envs, group_size, threshold):
        """Append the rows for one checkpoint's rollout.

        Args:
            step_num: Training step of the checkpoint (stored as ``training_step``).
            trajectories: Per-trajectory step dicts with ``'action'``,
                ``'observation'`` and ``'reward'``.
            episode_rewards, episode_lengths, success: Per-trajectory arrays.
            trial_ids: Per-trajectory milestone ids (may contain None).
            advantages: Dict with ``'step_rewards'``, ``'step_advs'``,
                ``'episode_advs'`` and ``'total_advs'``, indexed by trajectory.
            num_envs: Unused here; kept for interface compatibility.
            group_size: Trajectories per env; ``env_id = traj_idx // group_size``.
            threshold: Similarity threshold forwarded to ``count_milestones``.
        """
        for traj_idx, traj in enumerate(trajectories):
            env_id = traj_idx // group_size
            trial_id = trial_ids[traj_idx]
            traj_id = f"step{step_num}_traj{traj_idx}"
            task_type = extract_task_type(trial_id)
            completed, total = count_milestones(traj, trial_id, threshold)

            # Trajectory-level row.
            self.traj_data.append({
                'training_step': step_num,
                'traj_id': traj_id,
                'env_id': env_id,
                'task_type': task_type,
                'is_success': bool(success[traj_idx]),
                'episode_reward': episode_rewards[traj_idx],
                'episode_length': episode_lengths[traj_idx],
                'milestones_completed': completed,
                'total_milestones': total,
            })

            if not traj:
                continue

            # Step-level rows. The advantage lists may be shorter than the
            # trajectory, so every per-step lookup is bounds-checked.
            step_rewards = advantages['step_rewards'][traj_idx]
            step_advs = advantages['step_advs'][traj_idx]
            total_advs = advantages['total_advs'][traj_idx]
            episode_adv = advantages['episode_advs'][traj_idx]

            ms_so_far = 0
            for s_idx, step in enumerate(traj):
                ms_achieved = (s_idx < len(step_rewards) and
                               step_rewards[s_idx] >= self.MILESTONE_REWARD_THRESHOLD)
                if ms_achieved:
                    ms_so_far += 1

                self.step_data.append({
                    'training_step': step_num,
                    'traj_id': traj_id,
                    'step_idx': s_idx,
                    'env_id': env_id,
                    'action': step['action'],
                    'observation': step['observation'],  # full prompt
                    'step_reward': step['reward'],
                    'episode_reward': episode_rewards[traj_idx],
                    'step_advantage': step_advs[s_idx] if s_idx < len(step_advs) else 0,
                    'episode_advantage': episode_adv,
                    'total_advantage': total_advs[s_idx] if s_idx < len(total_advs) else 0,
                    'milestone_achieved': ms_achieved,
                    'milestones_so_far': ms_so_far,
                })

    def save_to_excel(self, output_dir: str):
        """Write ``trajectory_data.csv`` and ``step_data.csv`` (no index column) to *output_dir*.

        Creates *output_dir* if needed. The method name is historical; the
        output is CSV.
        """
        log = logging.getLogger(__name__)
        os.makedirs(output_dir, exist_ok=True)

        outputs = (
            (self.traj_data, 'trajectory_data.csv', 'trajectories'),
            (self.step_data, 'step_data.csv', 'steps'),
        )
        for rows, filename, label in outputs:
            path = os.path.join(output_dir, filename)
            frame = pd.DataFrame(rows)
            frame.to_csv(path, index=False)
            log.info("Saved %d %s to %s", len(frame), label, path)


# ============================================================
# 主函数
# ============================================================
def _build_env_manager(args, is_train: bool, eval_dataset: str):
    """Create the ALFWorld envs and wrap them in the AlfWorldEnvironmentManager used for training."""
    envs = build_alfworld_envs(
        ALFWORLD_CONFIG_PATH, args.seed, args.num_envs, args.group_size,
        {'num_cpus': 0.1}, is_train=is_train,
        env_kwargs={'eval_dataset': eval_dataset}
    )
    # Minimal config: only env.history_length is provided.
    env_config = OmegaConf.create({
        'env': {'history_length': args.history_length}
    })
    projection_f = partial(alfworld_projection)
    return AlfWorldEnvironmentManager(envs, projection_f, env_config)


def _experiment_name(args, model_path: str) -> str:
    """Return the output grouping name.

    The name is merged_dir's basename, or otherwise the name of the model
    directory's parent. Paths are normalised first so a trailing separator
    cannot produce an empty name.
    """
    if args.merged_dir:
        return os.path.basename(os.path.normpath(args.merged_dir))
    return os.path.basename(os.path.dirname(os.path.normpath(model_path)))


def _analyze_checkpoint(args, step_num: int, model_path: str, output_subdir: str,
                        eval_dataset: str, is_train: bool) -> None:
    """Run the five analysis stages for one checkpoint and write its CSVs to *output_subdir*.

    A fresh DataCollector is used, so *output_subdir* holds only this
    checkpoint's rows rather than a re-dump of every earlier checkpoint. The
    environments and the vLLM model are released even if a stage raises.
    """
    num_trajs = args.num_envs * args.group_size
    collector = DataCollector()
    model = None
    env_manager = None
    try:
        # 1. Load the model.
        logger.info("[Step 1/5] 加载vLLM模型...")
        model = VLLMInference(model_path, args.tensor_parallel_size, args.gpu_memory_utilization)

        # 2. Create the environments (same EnvironmentManager as in training).
        logger.info(f"[Step 2/5] 创建ALFWorld环境 (split={args.split})...")
        env_manager = _build_env_manager(args, is_train, eval_dataset)

        # 3. Roll out.
        logger.info(f"[Step 3/5] 执行rollout ({num_trajs} 条轨迹)...")
        trajs, ep_rewards, ep_lengths, success, trial_ids = run_rollout(
            model, env_manager, num_trajs, args.max_steps
        )

        # 4. Compute advantages.
        logger.info(f"[Step 4/5] 计算{args.algorithm.upper()} advantages...")
        advs = compute_advantages(args.algorithm, trajs, ep_rewards, trial_ids, args.num_envs, args.group_size,
                                  args.step_advantage_w, args.migpo_gamma, args.migpo_threshold)

        # 5. Collect rows.
        logger.info("[Step 5/5] 收集数据...")
        collector.add_data(step_num, trajs, ep_rewards, ep_lengths, success,
                           trial_ids, advs, args.num_envs, args.group_size, args.migpo_threshold)

        logger.info(f"统计: 成功率={np.mean(success):.2%}, 平均长度={np.mean(ep_lengths):.1f}, 平均奖励={np.mean(ep_rewards):.2f}")

        # Save right away so finished checkpoints survive a later failure.
        logger.info("保存数据...")
        collector.save_to_excel(output_subdir)
    finally:
        logger.info("清理资源...")
        try:
            if env_manager is not None:
                env_manager.envs.close()
        finally:
            # Free GPU memory even if closing the envs failed.
            if model is not None:
                model.cleanup()


def main():
    """Analyse each selected checkpoint and write per-checkpoint CSV reports.

    Output layout: ``<output_dir>/<exp_name>/<split>/step_<N>/`` containing
    ``trajectory_data.csv`` and ``step_data.csv``.
    """
    args = parse_args()

    # Default the tensor-parallel size to the number of visible GPUs.
    if args.tensor_parallel_size is None:
        args.tensor_parallel_size = get_gpu_count()

    logger.info("=" * 60)
    logger.info("ALFWorld Rollout Analysis")
    logger.info("=" * 60)
    if args.model_path:
        logger.info(f"Model path: {args.model_path}")
    else:
        logger.info(f"Merged dir: {args.merged_dir}")
    logger.info(f"Split: {args.split}")
    logger.info(f"Output dir: {args.output_dir}")
    logger.info(f"Num envs: {args.num_envs}, Group size: {args.group_size}")
    logger.info(f"Tensor parallel size: {args.tensor_parallel_size}")

    logger.info("初始化Ray...")
    if not ray.is_initialized():
        ray.init()

    models = get_models_to_analyze(args)
    logger.info(f"找到 {len(models)} 个模型待分析")
    if not models:
        logger.warning("No checkpoints to analyze; check --merged_dir / --steps.")

    # Environment arguments for the chosen split.
    eval_dataset, is_train = SPLIT_MAPPING[args.split]

    for idx, (step_num, model_path) in enumerate(models):
        logger.info("")
        logger.info("=" * 60)
        logger.info(f"[{idx+1}/{len(models)}] 处理 {model_path}")
        logger.info("=" * 60)

        # Output directory: <output_dir>/<exp_name>/<split>/step_<step>/
        output_subdir = os.path.join(args.output_dir, _experiment_name(args, model_path),
                                     args.split, f"step_{step_num}")
        os.makedirs(output_subdir, exist_ok=True)
        logger.info(f"输出目录: {output_subdir}")

        _analyze_checkpoint(args, step_num, model_path, output_subdir, eval_dataset, is_train)

    logger.info("")
    logger.info("=" * 60)
    logger.info("完成!")


if __name__ == '__main__':
    main()
