"""
Evaluation utilities for ManiSkill policies.
Environment parameters are kept identical to those used in training (ppo_opt.py, bc_ghn_student.py).
"""
import os
import torch
import numpy as np

import gymnasium as gym
from mani_skill.utils.wrappers.flatten import FlattenActionSpaceWrapper
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv
import mani_skill.envs


# Environment construction kwargs shared by every evaluation helper below.
# These must stay in sync with the training scripts (ppo_opt.py,
# bc_ghn_student.py) so evaluated policies see the same obs/action spaces.
ENV_KWARGS = {
    'obs_mode': 'state',
    'render_mode': 'rgb_array',
    'sim_backend': 'physx_cuda',
    'control_mode': 'pd_joint_delta_pos',  # same controller as used in training
}


def evaluate_policy(policy, env_id, eval_steps=50, device='cuda', seed=None, num_envs=10):
    """
    Evaluate a single policy for multiple episodes in parallel.

    Args:
        policy: The policy model (nn.Sequential); called as ``policy(obs)`` on the
            batched observation returned by the vectorized env.
        env_id: ManiSkill environment ID
        eval_steps: Number of steps per episode (must be >= 1)
        device: Currently unused; kept for API compatibility with callers.
        seed: Random seed for environment reset (for deterministic evaluation)
        num_envs: Number of parallel environments (episodes) to run

    Returns:
        metrics dict with success_once, success_at_end, mean_return (averaged over all envs)

    Raises:
        ValueError: if eval_steps < 1 (no step is taken, so there are no metrics).
    """
    if eval_steps < 1:
        raise ValueError(f"eval_steps must be >= 1, got {eval_steps}")

    env = gym.make(env_id, num_envs=num_envs, **ENV_KWARGS)
    try:
        if isinstance(env.action_space, gym.spaces.Dict):
            env = FlattenActionSpaceWrapper(env)

        env = ManiSkillVectorEnv(env, num_envs, ignore_terminations=True, record_metrics=True)

        obs, _ = env.reset(seed=seed)

        for _ in range(eval_steps):
            with torch.no_grad():
                action = policy(obs)
            obs, _, _, _, infos = env.step(action)

        # "final_info" may be absent when no episode ended on the last step
        # (e.g. eval_steps does not line up with the env's time limit). In that
        # case use the running per-step episode metrics that record_metrics=True
        # places under infos["episode"].
        final_info = infos.get("final_info")
        if final_info is not None and "episode" in final_info:
            ep_data = final_info["episode"]
        else:
            ep_data = infos["episode"]

        # Average each metric over all parallel envs.
        metrics = {
            'success_once': float(ep_data["success_once"].cpu().numpy().mean()),
            'success_at_end': float(ep_data["success_at_end"].cpu().numpy().mean()),
            'mean_return': float(ep_data["return"].cpu().numpy().mean()),
        }
    finally:
        # Release the simulator even if the policy or env raises mid-rollout.
        env.close()
    return metrics


def evaluate_policy_with_video(policy, env_id, video_path, eval_steps=50, device='cuda', early_termination=False, seed=None, num_envs=10):
    """
    Evaluate a single policy for multiple episodes in parallel and save video.

    Args:
        policy: The policy model (nn.Sequential); called as ``policy(obs)`` on the
            batched observation returned by the vectorized env.
        env_id: ManiSkill environment ID
        video_path: Full path for video file (without extension). A bare file
            name with no directory part is written to the current directory.
        eval_steps: Number of steps per episode (must be >= 1)
        device: Currently unused; kept for API compatibility with callers.
        early_termination: If True, stop stepping once every env has succeeded
            at least once.
        seed: Random seed for environment reset (for deterministic evaluation)
        num_envs: Number of parallel environments (episodes) to run

    Returns:
        metrics dict with success_once, success_at_end, mean_return (averaged over all envs).
        Each value is taken from ``infos["final_info"]["episode"]`` when the last
        step provided it, otherwise from the per-step tracking done here.

    Raises:
        ValueError: if eval_steps < 1 (no step is taken, so there are no metrics).
    """
    if eval_steps < 1:
        raise ValueError(f"eval_steps must be >= 1, got {eval_steps}")

    # os.path.dirname("clip") == "" and os.makedirs("") raises, so bare file
    # names fall back to the current directory.
    video_dir = os.path.dirname(video_path) or "."
    video_name = os.path.basename(video_path)

    os.makedirs(video_dir, exist_ok=True)

    env = gym.make(env_id, num_envs=num_envs, **ENV_KWARGS)
    try:
        if isinstance(env.action_space, gym.spaces.Dict):
            env = FlattenActionSpaceWrapper(env)

        env = RecordEpisode(
            env,
            output_dir=video_dir,
            save_trajectory=False,
            save_video=True,
            trajectory_name=video_name,
            video_fps=30,
            max_steps_per_video=eval_steps
        )

        # Always ignore native terminations - we handle early stopping ourselves based on success
        env = ManiSkillVectorEnv(env, num_envs, ignore_terminations=True, record_metrics=True)

        obs, _ = env.reset(seed=seed)

        total_returns = np.zeros(num_envs)
        success_once_flags = np.zeros(num_envs, dtype=bool)
        success_at_end_flags = np.zeros(num_envs, dtype=bool)

        for _ in range(eval_steps):
            with torch.no_grad():
                action = policy(obs)
            obs, reward, _, _, infos = env.step(action)

            total_returns += reward.cpu().numpy()

            if "success" in infos:
                # astype(bool) keeps the in-place |= below valid if the env
                # reports success with a non-bool dtype.
                success_mask = infos["success"].cpu().numpy().astype(bool)
                success_once_flags |= success_mask
                success_at_end_flags = success_mask  # success state at the latest step

                # Early termination once every env has succeeded at least once
                if early_termination and success_once_flags.all():
                    break

        # Prefer the env's own episode metrics when the last step reported them
        final_info = infos.get("final_info")
        if final_info is not None:
            ep_data = final_info.get("episode", {})
            if "success_once" in ep_data:
                success_once_flags = ep_data["success_once"].cpu().numpy()
            if "success_at_end" in ep_data:
                success_at_end_flags = ep_data["success_at_end"].cpu().numpy()
            if "return" in ep_data:
                total_returns = ep_data["return"].cpu().numpy()

        # Average metrics over all envs
        metrics = {
            'success_once': float(success_once_flags.mean()),
            'success_at_end': float(success_at_end_flags.mean()),
            'mean_return': float(total_returns.mean()),
        }
    finally:
        # Close the wrapped env (including the recording wrapper) even if the
        # rollout raises.
        env.close()
    return metrics
