"""
Evaluation dataset generator for grasp task with blocked zone constraints.

Generates evaluation demos for one blocked zone style per run (selected with --style):
- Style 1: only the 0 degree approach angle works (tests angle adaptation)
- Style 2: only the 0.04 m grasp height works (tests height adaptation)
- Style 3: novel configuration, 45 degree approach at 0.035 m (tests interpolation/generalization)

Usage:
  python dataset_generator_eval.py --style=1 --num_demos=10 --save_video
"""
import sys
import os

# Put this script's own directory on sys.path so the local helper modules imported below resolve
sys.path.insert(0, os.path.dirname(__file__))

from pyrep.const import RenderMode
from pyrep.objects.dummy import Dummy
from pyrep.errors import ConfigurationPathError, IKError

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointVelocity, JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment

import pickle
import numpy as np
import imageio

from absl import app
from absl import flags

from types import SimpleNamespace

from grasp_utils import (
    DEFAULT_PHASE_STEPS,
    get_cup_position,
    get_cup_base_z,
    reset_robot_to_default,
    generate_grasp_trajectory,
    check_task_success,
    check_grasp_success_manual,
)

from grasp_config import (
    CAMERA_POSITION,
    CAMERA_ORIENTATION,
    CAMERA_IMAGE_SIZE,
    CUP_VARIATION as DEFAULT_CUP_VARIATION,
    CONTROL_POINT_RADIUS as DEFAULT_CONTROL_POINT_RADIUS,
)

from blocked_zone import (
    BLOCKED_ZONE_STYLES,
    check_grasp_valid,
    sample_valid_config_for_style,
    get_target_config_for_style,
    print_style_info,
)


FLAGS = flags.FLAGS

# Default output root: $DPPO_DATA_DIR/grasp, falling back to /tmp/dppo_data/grasp
# when the DPPO_DATA_DIR environment variable is unset.
DEFAULT_SAVE_PATH = os.path.join(
    os.environ.get("DPPO_DATA_DIR", "/tmp/dppo_data"),
    "grasp"
)

# Command-line flags, parsed by absl's app.run() before main() is called.
# main() writes output under <save_path>/variation<cup_variation>/eval_style<style>.
flags.DEFINE_string("save_path", DEFAULT_SAVE_PATH, "Where to save the demos.")
# Must be a key of blocked_zone.BLOCKED_ZONE_STYLES; main() rejects anything else.
flags.DEFINE_integer("style", 1, "Blocked zone style (1, 2, or 3)")
flags.DEFINE_integer("num_demos", 10, "Number of evaluation demos to generate")
# When set, front RGB frames are also captured and written as video.mp4 per episode.
flags.DEFINE_bool("save_video", False, "Whether to save video recordings")
# Forwarded unchanged to generate_grasp_trajectory(control_point_radius=...).
flags.DEFINE_float("control_point_radius", DEFAULT_CONTROL_POINT_RADIUS, "Control point radius")
# Passed to task_env.set_variation() and used in the output directory name.
flags.DEFINE_integer("cup_variation", DEFAULT_CUP_VARIATION, "Cup variation (color index)")


def check_and_make(dir_path):
    """Create ``dir_path`` (including missing parents) unless it already exists.

    ``exist_ok=True`` makes this a single race-free call. A separate
    existence check followed by ``makedirs`` can raise ``FileExistsError``
    if another process creates the directory in between. If ``dir_path``
    exists but is not a directory, ``FileExistsError`` is raised here
    rather than later, when files are written into it.
    """
    os.makedirs(dir_path, exist_ok=True)


def project_world_to_image(points_3d, camera, image_size):
    """Project 3D world points into integer pixel coordinates of ``camera``.

    Uses a pinhole model built from the camera pose and its perspective angle.

    Args:
        points_3d: Iterable of world-frame points, each convertible to a
            length-3 array (x, y, z).
        camera: Object exposing ``get_position()``, ``get_matrix()`` and,
            optionally, ``get_perspective_angle()`` (interpreted as degrees).
            If the latter is missing or raises, a 60 degree field of view
            is assumed.
        image_size: ``(height, width)`` of the target image in pixels.

    Returns:
        One entry per input point: a ``(u, v)`` tuple of ints (truncated), or
        ``None`` when the point's camera-frame depth is <= 0.01, i.e. it is
        behind or practically on the camera. In-front points outside the
        image bounds are still returned; callers must check bounds.
    """
    cam_pos = np.array(camera.get_position())
    # Assumes get_matrix() is the camera-to-world pose, so the transpose of its
    # rotation block maps world-frame offsets into the camera frame.
    world_to_cam = np.asarray(camera.get_matrix())[:3, :3].T

    try:
        fov = camera.get_perspective_angle() * np.pi / 180.0
    except Exception:
        # Not every camera object offers a perspective angle; assume 60 degrees.
        fov = 60.0 * np.pi / 180.0

    height, width = image_size[0], image_size[1]
    # NOTE(review): the focal length is derived from the image height; confirm
    # this matches the simulator's perspective-angle convention for
    # non-square images.
    focal = height / (2.0 * np.tan(fov / 2.0))
    cx, cy = width / 2.0, height / 2.0

    projected = []
    for point in points_3d:
        x_cam, y_cam, z_cam = world_to_cam @ (np.array(point) - cam_pos)

        if z_cam > 0.01:
            # Minus signs flip camera axes into image axes (simulator convention).
            u = cx - focal * x_cam / z_cam
            v = cy - focal * y_cam / z_cam
            projected.append((int(u), int(v)))
        else:
            projected.append(None)

    return projected


def overlay_trajectory_on_frames(frames, ee_trace, camera, phase_labels=None):
    """Draw the end-effector trajectory onto video frames, coloured by phase.

    Frame ``k`` shows every trace segment up to point ``min(k, len(ee_trace) - 1)``
    plus a white dot with a black outline at that point, so the path grows as
    the video plays. Segments whose endpoints are unprojectable or outside the
    image are skipped.

    Args:
        frames: Sequence of image arrays whose first two dimensions are
            (height, width); all are assumed to match ``frames[0]``.
        ee_trace: Sequence of 3D world-frame end-effector positions.
        camera: Camera object passed to ``project_world_to_image``.
        phase_labels: Optional per-trace-point phase names. Segment ``i``
            (point ``i-1`` -> ``i``) takes the colour of ``phase_labels[i]``.
            Unknown or missing labels draw in white.

    Returns:
        ``frames`` itself if either input is empty. Otherwise a new list of
        annotated copies; the input frames are not modified.
    """
    if len(frames) == 0 or len(ee_trace) == 0:
        return frames

    # Imported lazily so OpenCV is only required when something is drawn.
    import cv2

    image_size = frames[0].shape[:2]
    height, width = image_size
    projected = project_world_to_image(ee_trace, camera, image_size)

    phase_colors = {
        "reach": (0, 255, 0),
        "descend": (0, 200, 255),
        "hold_grasp": (255, 255, 0),
        "close_gripper": (255, 0, 0),
        "lift": (255, 0, 255),
    }
    default_color = (255, 255, 255)

    def _visible(pt):
        # True when the projected point exists and lies inside the image.
        return pt is not None and 0 <= pt[0] < width and 0 <= pt[1] < height

    # segments[i] describes the line from projected[i-1] to projected[i], or is
    # None when it should not be drawn. It is resolved once here instead of
    # being re-checked for every frame.
    segments = [None]
    for i in range(1, len(projected)):
        p1, p2 = projected[i - 1], projected[i]
        if not (_visible(p1) and _visible(p2)):
            segments.append(None)
            continue
        if phase_labels is not None and i < len(phase_labels):
            color = phase_colors.get(phase_labels[i], default_color)
        else:
            color = default_color
        segments.append((p1, p2, color))

    last_point = len(projected) - 1
    frames_with_overlay = []

    for frame_idx, frame in enumerate(frames):
        frame_overlay = frame.copy()
        # Frames beyond the end of the trace keep showing the full trace.
        trace_idx = min(frame_idx, last_point)

        for segment in segments[1:trace_idx + 1]:
            if segment is not None:
                p1, p2, color = segment
                cv2.line(frame_overlay, p1, p2, color, thickness=2, lineType=cv2.LINE_AA)

        curr_pt = projected[trace_idx]
        if _visible(curr_pt):
            cv2.circle(frame_overlay, curr_pt, 6, (255, 255, 255), -1)
            cv2.circle(frame_overlay, curr_pt, 6, (0, 0, 0), 2)

        frames_with_overlay.append(frame_overlay)

    return frames_with_overlay


def save_demo(demo, episode_path, save_video=False, ee_trace=None, camera=None, phase_labels=None):
    """Pickle ``demo`` to ``low_dim_obs.pkl`` and optionally render ``video.mp4``.

    The video is built from each observation's ``front_rgb`` (observations
    without one, or with ``None``, are skipped). If a non-empty ``ee_trace`` and
    a ``camera`` are given, the trajectory is drawn on the frames first.
    Nothing is written for the video when no frames are available.
    """
    pickle_path = os.path.join(episode_path, "low_dim_obs.pkl")
    with open(pickle_path, "wb") as handle:
        pickle.dump(demo, handle)

    if not save_video:
        return

    rgb_frames = [
        step.front_rgb
        for step in demo
        if hasattr(step, 'front_rgb') and step.front_rgb is not None
    ]
    if not rgb_frames:
        return

    draw_trace = ee_trace is not None and camera is not None and len(ee_trace) > 0
    if draw_trace:
        rgb_frames = overlay_trajectory_on_frames(rgb_frames, ee_trace, camera, phase_labels)

    video_path = os.path.join(episode_path, "video.mp4")
    imageio.mimwrite(video_path, rgb_frames, fps=30, codec='libx264', quality=8)
    print(f"\t  Video saved: {video_path} ({len(rgb_frames)} frames)")

def _record_episode(task_env, cfg, episode_path, camera, demo_idx,
                    approach_angle, grasp_height, style_config, is_valid):
    """Roll out one grasp, write its files to ``episode_path`` and return its metadata.

    Resets the robot and task, regenerates the trajectory for the given
    config, and writes ``low_dim_obs.pkl`` (plus video if enabled),
    ``metadata.npy`` and, when a trace is available, ``ee_trajectory.npy``.

    Raises:
        RuntimeError: if the generated trajectory is empty.
        Any exception from the simulator / trajectory helpers propagates so
        the caller can retry.
    """
    reset_robot_to_default(task_env)
    task_env.reset()

    # Re-read the cup pose after the reset instead of reusing the start-up pose.
    current_cup_pos, current_cup_ori = get_cup_position(task_env)

    demo, traj_metadata = generate_grasp_trajectory(
        task_env,
        start_pos=np.zeros(3),
        cup_pos=current_cup_pos,
        cup_ori=current_cup_ori,
        approach_angle=approach_angle,
        grasp_height=grasp_height,
        control_point_radius=cfg.control_point_radius,
        waypoint_params=None,
        phase_steps=None,
        steps_per_point=5,
    )
    if len(demo) == 0:
        raise RuntimeError("Empty trajectory")

    task_success = check_task_success(task_env)
    initial_cup_z = traj_metadata.get("initial_cup_z", current_cup_pos[2])
    grasp_success = check_grasp_success_manual(task_env, initial_cup_z)

    os.makedirs(episode_path, exist_ok=True)
    ee_trace = traj_metadata.get("trace")
    save_demo(
        demo, episode_path, save_video=cfg.save_video,
        ee_trace=ee_trace,
        camera=camera,
        phase_labels=traj_metadata.get("phase_labels"),
    )

    # Bulky per-step entries are stored elsewhere (or not at all), so they are
    # excluded from the metadata dict.
    excluded_keys = ('trace', 'phase_labels', 'gripper_states', 'waypoints')
    metadata = {
        'eval_demo_idx': demo_idx,
        'approach_angle': approach_angle,
        'grasp_height': grasp_height,
        'cup_pos': current_cup_pos.tolist(),
        'cup_ori': current_cup_ori.tolist(),
        'style_name': style_config['name'],
        'is_valid_config': is_valid,
        'task_success': task_success,
        'grasp_success': grasp_success,
        **{k: v for k, v in traj_metadata.items() if k not in excluded_keys},
    }
    np.save(os.path.join(episode_path, "metadata.npy"), metadata)

    if ee_trace is not None:
        np.save(os.path.join(episode_path, "ee_trajectory.npy"), ee_trace)

    print(f"  Demo length: {len(demo)} steps, task_success: {task_success}, grasp_success: {grasp_success}")
    return metadata


def run_eval_split(task_env, cfg, style_path, style_config, cup_pos, cup_ori,
                   *, max_attempts=5, seed=42):
    """Generate evaluation demos for one blocked zone style.

    For each of ``cfg.num_demos`` demos, a configuration is sampled from the
    style's valid region and rolled out up to ``max_attempts`` times. Each
    attempt resets the environment first. Successful rollouts are written to
    ``<style_path>/episodes/episode<k>`` with consecutive ``k``. Demos whose
    attempts all fail are skipped, and their partial files are removed. The
    collected metadata list is saved to ``<style_path>/eval_metadata.npy``.

    Args:
        task_env: Task environment returned by ``Environment.get_task``.
        cfg: Namespace with ``num_demos``, ``save_video`` and
            ``control_point_radius``.
        style_path: Output directory for this style.
        style_config: Style entry; must provide ``'name'`` and ``'description'``.
        cup_pos: Cup position captured at start-up; only printed here.
        cup_ori: Not used here; kept for caller compatibility.
        max_attempts: Rollout attempts per demo before it is skipped.
        seed: Seed for the configuration-sampling RNG (reproducible splits).

    Returns:
        List of metadata dicts, one per saved episode.
    """
    import shutil

    num_demos = cfg.num_demos
    rng = np.random.default_rng(seed=seed)

    print(f"\n{'='*70}")
    print(f"Generating EVAL split for style: {style_config['name']}")
    print(f"  Description: {style_config['description']}")
    print(f"  Number of demos: {num_demos}")
    print(f"  Cup position: {cup_pos}")
    print(f"{'='*70}\n")

    episodes_path = os.path.join(style_path, "episodes")
    check_and_make(episodes_path)

    all_metadata = []
    episode_idx = 0

    custom_cam = task_env._scene._cam_front if cfg.save_video else None

    for demo_idx in range(num_demos):
        print(f"\nDemo {demo_idx}/{num_demos}")

        approach_angle, grasp_height = sample_valid_config_for_style(style_config, rng)

        # The validity verdict is reported and recorded, not enforced.
        is_valid, reason = check_grasp_valid(approach_angle, grasp_height, style_config)
        print(f"  Config: angle={np.degrees(approach_angle):.1f}deg, height={grasp_height:.4f}m")
        print(f"  Valid: {is_valid} ({reason})")

        for attempt in range(max_attempts):
            episode_path = os.path.join(episodes_path, f"episode{episode_idx}")
            try:
                metadata = _record_episode(
                    task_env, cfg, episode_path, custom_cam, demo_idx,
                    approach_angle, grasp_height, style_config, is_valid,
                )
            except Exception as e:
                # Any failure (e.g. an empty trajectory) triggers a retry with a
                # fresh reset.
                print(f"  Attempt {attempt+1} failed: {e}")
                # Remove partially written files so no episode<k> directory
                # exists without a matching eval_metadata entry. The same
                # index is reused by the next attempt or demo.
                shutil.rmtree(episode_path, ignore_errors=True)
                continue
            all_metadata.append(metadata)
            episode_idx += 1
            break
        else:
            print(f"  Skipping demo {demo_idx} after {max_attempts} failures")

    np.save(os.path.join(style_path, "eval_metadata.npy"), all_metadata)

    print(f"\n  Generated {episode_idx} evaluation episodes")
    print("  Saved metadata to eval_metadata.npy")

    return all_metadata


def _build_obs_config(save_video):
    """Return the ObservationConfig used for generation.

    Only joint/gripper/task low-dimensional state is recorded. The front RGB
    camera (at CAMERA_IMAGE_SIZE) is enabled only when ``save_video`` is set,
    because the frames are used solely to build the videos.
    """
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.task_low_dim_state = True

    if save_video:
        obs_config.front_camera.rgb = True
        obs_config.front_camera.image_size = CAMERA_IMAGE_SIZE

    return obs_config


def _build_action_mode():
    """Return a JointPosition(True) + Discrete gripper action mode with fixed bounds."""
    # 8 action dims: presumably 7 arm joints + 1 gripper. The arm entries look
    # like Franka Panda joint limits -- NOTE(review): confirm for the robot used.
    act_min = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    act_range = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (act_min, act_min + act_range)

    return CustomMoveArmThenGripper(JointPosition(True), Discrete())


def main(argv):
    """Generate the evaluation split for the blocked zone style chosen by --style.

    Launches a headless RLBench environment with the ``pick_up_cup`` task.
    It writes ``eval_config.npy`` and the episodes under
    ``<save_path>/variation<cup_variation>/eval_style<style>``, then prints a
    success summary. The simulator is always shut down, even if generation
    raises. An unknown style prints an error and returns before launching
    anything.
    """
    del argv  # Positional command-line arguments are not used.

    style_id = FLAGS.style
    if style_id not in BLOCKED_ZONE_STYLES:
        print(f"Error: Unknown style {style_id}. Available: {list(BLOCKED_ZONE_STYLES.keys())}")
        return

    style_config = BLOCKED_ZONE_STYLES[style_id]

    print(f"{'='*70}")
    print("GRASP EVALUATION DATASET GENERATOR")
    print(f"{'='*70}")
    print_style_info(style_id)
    print(f"Save path: {FLAGS.save_path}")
    print(f"Number of demos: {FLAGS.num_demos}")
    print(f"Save video: {FLAGS.save_video}")
    print(f"{'='*70}\n")

    obs_config = _build_obs_config(FLAGS.save_video)
    action_mode = _build_action_mode()

    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()
    try:
        task_class = task_file_to_task_class("pick_up_cup")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(FLAGS.cup_variation)

        # Initial reset so the start-up cup pose can be recorded in eval_config.npy.
        task_env.reset()
        cup_pos, cup_ori = get_cup_position(task_env)
        print(f"\nCup position: {cup_pos}")
        print(f"Cup orientation: {cup_ori}")

        style_path = os.path.join(
            FLAGS.save_path,
            f"variation{FLAGS.cup_variation}",
            f"eval_style{style_id}"
        )
        check_and_make(style_path)

        # Reposition the front camera; it supplies the RGB frames and the
        # projection used for the trajectory overlay.
        front_cam = task_env._scene._cam_front
        front_cam.set_position(CAMERA_POSITION)
        front_cam.set_orientation(CAMERA_ORIENTATION)

        config = {
            'style_id': style_id,
            'style_name': style_config['name'],
            'style_description': style_config['description'],
            'valid_angles_deg': style_config['valid_angles_deg'],
            'valid_heights': style_config['valid_heights'],
            'cup_pos': cup_pos.tolist(),
            'cup_ori': cup_ori.tolist(),
            'num_demos': FLAGS.num_demos,
        }
        np.save(os.path.join(style_path, "eval_config.npy"), config)

        cfg = SimpleNamespace(
            save_video=FLAGS.save_video,
            control_point_radius=FLAGS.control_point_radius,
            num_demos=FLAGS.num_demos,
        )

        eval_metadata = run_eval_split(
            task_env, cfg, style_path, style_config, cup_pos, cup_ori
        )

        n_task_success = sum(1 for m in eval_metadata if m.get("task_success", False))
        n_grasp_success = sum(1 for m in eval_metadata if m.get("grasp_success", False))
        n_episodes = len(eval_metadata)

        print(f"\n{'='*70}")
        print("EVAL DATA GENERATION COMPLETE")
        print(f"  Style: {style_config['name']}")
        print(f"  Total episodes: {n_episodes}")
        print(f"  Task success: {n_task_success}/{n_episodes}")
        print(f"  Grasp success: {n_grasp_success}/{n_episodes}")
        print(f"  Data saved to: {style_path}")
        print(f"{'='*70}\n")
    finally:
        # Shut the simulator down even if generation fails part-way, so it is
        # not left running.
        rlbench_env.shutdown()


if __name__ == "__main__":
    # Select the "spawn" start method before absl hands control to main();
    # presumably so child processes don't fork simulator state -- confirm
    # before changing.
    from multiprocessing import set_start_method

    set_start_method("spawn", force=True)
    app.run(main)
