#!/usr/bin/env python
"""
Visualize end-effector (EE) trajectories for four grasp-constraint settings:
-1 (no constraint), 1, 2 and 3.
Outputs 8 files: 4 x 3D trajectory plots + 4 x 2D grasp-point plots.
One noise-free demo per mode (4 base modes, plus 1 extra 45-degree mode that
is plotted only for style 3). Plots have no legend and no title.

Usage:
  python visualize_grasp_four_styles.py
  python visualize_grasp_four_styles.py --save_path=/custom/path
"""
import sys
import os

sys.path.insert(0, os.path.dirname(__file__))

import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from pyrep.errors import ConfigurationPathError, IKError

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment

from absl import app, flags

from grasp_utils import (
    compute_grasp_waypoints,
    generate_reach_control_point,
    generate_phase_positions,
    get_cup_position,
    get_cup_base_z,
    set_cup_position,
    HOME_JOINTS,
)

from grasp_config import (
    CAMERA_POSITION,
    CAMERA_ORIENTATION,
    CAMERA_IMAGE_SIZE,
    CUP_VARIATION,
    CONTROL_POINT_RADIUS,
    FIXED_CUP_POSITION,
    APPROACH_ANGLES_DEG,
    GRASP_HEIGHTS,
    PHASE_STEPS,
    generate_canonical_mode_params,
)

from blocked_zone import (
    BLOCKED_ZONE_STYLES,
    check_grasp_valid,
)


FLAGS = flags.FLAGS

# Default output directory: "<parent of this script's directory>/block_setting".
# abspath() matters: on older Pythons __file__ can be a bare relative filename
# when the script is run from its own directory, in which case the double
# dirname() would collapse to "" and the default would silently become
# "./block_setting" instead of a sibling of the script's directory.
DEFAULT_SAVE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "block_setting",
)
flags.DEFINE_string("save_path", DEFAULT_SAVE_PATH, "Where to save visualizations.")
flags.DEFINE_float("control_point_radius", CONTROL_POINT_RADIUS, "Control point radius.")
flags.DEFINE_integer("cup_variation", CUP_VARIATION, "Cup variation (color index)")


def generate_ee_trajectory_only(task_env, cup_pos, cup_base_z,
                                 approach_angle, grasp_height,
                                 control_point_radius, phase_steps):
    """Plan the end-effector path for one grasp mode without executing it.

    The arm is snapped to ``HOME_JOINTS`` with dynamics disabled and the
    simulator is stepped 10 times. The tip position read afterwards is the
    starting point handed to the ``grasp_utils`` planners.

    Returns:
        ``(positions, waypoints, cp_reach, phase_indices)``: the positions
        from ``generate_phase_positions`` converted with ``np.array``, plus
        the waypoints, reach control point and phase indices, unchanged from
        the helpers.
    """
    scene = task_env._scene
    arm = scene.robot.arm
    tip = arm.get_tip()

    # Put the arm at its home configuration and let the scene settle before
    # sampling the tip position.
    arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    for _step in range(10):
        scene.pyrep.step()
    start_tip = np.array(tip.get_position())

    wp = compute_grasp_waypoints(
        start_tip, cup_pos, cup_base_z, approach_angle, grasp_height
    )
    reach_cp = generate_reach_control_point(
        wp["start"], wp["pregrasp"], approach_angle, control_point_radius
    )

    # Only positions and phase boundaries are needed; gripper states and
    # phase labels are discarded.
    traj_positions, _gripper_states, phase_idx, _phase_labels = (
        generate_phase_positions(wp, phase_steps, reach_cp)
    )
    return np.array(traj_positions), wp, reach_cp, phase_idx


def visualize_3d_single(all_results, save_path, style):
    """Save a 3D plot of the end-effector trajectories for one constraint style.

    Writes ``trajectories_3d_style{style}.png`` into ``save_path``. The plot
    has no legend and no title.

    Args:
        all_results: Sequence of dicts with keys ``"positions"`` (indexed as
            ``[:, 0]``, ``[:, 1]`` and ``[:, 2]`` for X/Y/Z), ``"is_valid"``,
            ``"grasp_point_idx"`` and optionally ``"is_extra"``.
        save_path: Directory the PNG is written to.
        style: Constraint style id; used only in the output filename.

    Drawing rules:
        * Valid trajectories are drawn in full as a thick blue line and end
          in a green star.
        * Invalid trajectories are drawn as a thin red line, truncated just
          after the grasp point.
        * Extra modes are dotted; base modes are solid.
        * The mean of all start positions is marked with a large blue circle.

    If ``all_results`` is empty there is nothing to plot, so no file is
    written. Computing the axis limits would otherwise fail on empty data.
    """
    if not all_results:
        print(f"No trajectories for style {style}; skipping 3D plot.")
        return

    fig = plt.figure(figsize=(8, 7))
    ax = fig.add_subplot(111, projection='3d')

    # Every point actually drawn; used to compute the axis limits.
    drawn_segments = []

    for result in all_results:
        positions = np.asarray(result["positions"])
        is_valid = result["is_valid"]
        is_extra = result.get("is_extra", False)

        # Invalid trajectories are cut off right after the grasp point.
        if is_valid:
            draw_end = len(positions)
        else:
            draw_end = min(result["grasp_point_idx"] + 1, len(positions))

        pos_to_draw = positions[:draw_end]
        if len(pos_to_draw) == 0:
            continue
        drawn_segments.append(pos_to_draw)
        x, y, z = pos_to_draw[:, 0], pos_to_draw[:, 1], pos_to_draw[:, 2]

        linestyle = ':' if is_extra else '-'
        if is_valid:
            color, linewidth, alpha = 'blue', 3.5, 0.9
        else:
            color, linewidth, alpha = 'red', 1.5, 0.6

        ax.plot(x, y, z, linestyle, color=color, linewidth=linewidth, alpha=alpha)

        # End marker only for valid trajectories.
        if is_valid:
            ax.scatter(x[-1], y[-1], z[-1], c='green', marker='*', s=150)

    # Start point (average over all trajectories).
    start_pos = np.mean([np.asarray(r["positions"])[0] for r in all_results], axis=0)
    ax.scatter(*start_pos, color='blue', s=200, marker='o',
               edgecolors='black', linewidths=2)

    # Axis limits: bounding box of the drawn points, padded by 30% of the
    # extent on each side.
    if drawn_segments:
        stacked = np.concatenate(drawn_segments, axis=0)
        lo = stacked.min(axis=0)
        hi = stacked.max(axis=0)
        pad = (hi - lo) * 0.3
        ax.set_xlim(lo[0] - pad[0], hi[0] + pad[0])
        ax.set_ylim(lo[1] - pad[1], hi[1] + pad[1])
        ax.set_zlim(lo[2] - pad[2], hi[2] + pad[2])

    ax.set_xlabel('X (m)', fontsize=12)
    ax.set_ylabel('Y (m)', fontsize=12)
    ax.set_zlabel('Z (m)', fontsize=12)

    ax.view_init(elev=30, azim=-135)

    fig.tight_layout()
    plot_path = os.path.join(save_path, f'trajectories_3d_style{style}.png')
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"Saved 3D plot to {plot_path}")
    plt.close(fig)


def visualize_2d_grasp_points(all_results, cup_pos, save_path, style,
                              circle_radius=0.04):
    """Save a top-down (X-Y) plot of the grasp points for one constraint style.

    Writes ``grasp_points_2d_style{style}.png`` into ``save_path``. The plot
    has no legend and no title. Valid grasps are drawn as blue dots, invalid
    ones as red crosses, and the cup center as a purple star.

    Args:
        all_results: Sequence of dicts with keys ``"positions"``,
            ``"grasp_point_idx"`` and ``"is_valid"``. The grasp point is
            ``positions[grasp_point_idx]``; only its first two coordinates
            are plotted.
        cup_pos: Cup position; only ``cup_pos[0]`` and ``cup_pos[1]`` are used.
        save_path: Directory the PNG is written to.
        style: Constraint style id; used only in the output filename.
        circle_radius: Radius of the dashed reference circle drawn around the
            cup center (gripper offset from the cup center), in the same
            units as the positions. Defaults to 0.04.
    """
    fig, ax = plt.subplots(figsize=(7, 7))

    for result in all_results:
        grasp_pos = result["positions"][result["grasp_point_idx"]]

        if result["is_valid"]:
            ax.scatter(grasp_pos[0], grasp_pos[1],
                       c='blue', s=200, marker='o', alpha=0.8,
                       edgecolors='black', linewidths=0.5, zorder=5)
        else:
            ax.scatter(grasp_pos[0], grasp_pos[1],
                       c='red', s=200, marker='x', linewidths=2,
                       alpha=0.8, zorder=5)

    # Cup center
    ax.scatter(cup_pos[0], cup_pos[1],
               c='purple', s=200, marker='*', zorder=10,
               edgecolors='black')

    # Reference circle (gripper offset from cup center)
    theta = np.linspace(0, 2 * np.pi, 100)
    circle_x = cup_pos[0] + circle_radius * np.cos(theta)
    circle_y = cup_pos[1] + circle_radius * np.sin(theta)
    ax.plot(circle_x, circle_y, 'k--', alpha=0.3, linewidth=1)

    ax.set_xlabel('X (m)', fontsize=12)
    ax.set_ylabel('Y (m)', fontsize=12)
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plot_path = os.path.join(save_path, f'grasp_points_2d_style{style}.png')
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"Saved 2D plot to {plot_path}")
    plt.close(fig)


def _make_environment():
    """Create and launch the headless RLBench environment used by this script.

    Observations enable joint positions/velocities, gripper open state and
    pose, and the front-camera RGB image (size ``CAMERA_IMAGE_SIZE``). The arm
    uses ``JointPosition(True)`` and the gripper ``Discrete()``, with action
    bounds overridden to ``[ACT_MIN, ACT_MIN + ACT_RANGE]``.
    """
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.front_camera.rgb = True
    obs_config.front_camera.image_size = CAMERA_IMAGE_SIZE

    # 8 action dimensions. NOTE(review): presumably the 7 arm joints followed
    # by the gripper (last entry spans [0, 1]); confirm.
    ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())
    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()
    return rlbench_env


def _generate_all_trajectories(task_env, cup_pos, cup_base_z):
    """Plan one EE trajectory per canonical mode plus one extra 45-degree mode.

    Returns:
        List of dicts with keys ``mode_idx``, ``positions``,
        ``approach_angle``, ``grasp_height``, ``grasp_point_idx`` and
        ``is_extra``. A mode whose planning raises is reported and left out
        (best effort), so the list may be shorter than the number of modes.
    """
    # Base modes + 1 extra (45 degrees at the first grasp height).
    canonical_params = list(generate_canonical_mode_params())
    phase_steps = PHASE_STEPS.copy()
    num_base_modes = len(canonical_params)

    extra_mode_angle = np.radians(45)
    extra_mode_height = GRASP_HEIGHTS[0]
    canonical_params_with_extra = canonical_params + [np.array([extra_mode_angle, extra_mode_height])]

    print(f"\nGenerating {len(canonical_params_with_extra)} trajectory modes "
          f"({num_base_modes} base + 1 extra)...")
    all_trajectories = []
    for mode_idx, (angle, height) in enumerate(canonical_params_with_extra):
        is_extra = mode_idx >= num_base_modes
        try:
            positions, _waypoints, _cp_reach, phase_indices = generate_ee_trajectory_only(
                task_env, cup_pos, cup_base_z, angle, height,
                FLAGS.control_point_radius, phase_steps
            )

            # Grasp point index: end of descend phase.
            # NOTE(review): assumes phase_indices["descend"] is (start, end)
            # with an exclusive end; confirm against grasp_utils.
            grasp_point_idx = phase_indices["descend"][1] - 1

            angle_deg = np.degrees(angle)
            extra_tag = " (extra)" if is_extra else ""
            print(f"  Mode {mode_idx}: angle={angle_deg:.0f} deg, height={height:.3f}m, "
                  f"{len(positions)} steps, grasp_idx={grasp_point_idx}{extra_tag}")

            all_trajectories.append({
                "mode_idx": mode_idx,
                "positions": positions,
                "approach_angle": angle,
                "grasp_height": height,
                "grasp_point_idx": grasp_point_idx,
                "is_extra": is_extra,
            })
        except Exception as e:  # best effort: report the failing mode, keep going
            print(f"  Mode {mode_idx} failed: {e}")
    return all_trajectories


def _evaluate_style(all_trajectories, style):
    """Check each trajectory relevant to ``style`` against that style's blocked zone.

    Style 3 uses all trajectories (including the extra mode); every other
    style uses the base modes only. Style -1 applies no constraint, so every
    trajectory is valid. Returns the per-trajectory records consumed by the
    plotting functions.
    """
    if style == 3:
        trajs_for_style = all_trajectories
    else:
        trajs_for_style = [t for t in all_trajectories if not t["is_extra"]]

    # Look the constraint config up once per style, not once per trajectory.
    style_config = None if style == -1 else BLOCKED_ZONE_STYLES[style]

    style_results = []
    for traj in trajs_for_style:
        if style == -1:
            is_valid = True
        else:
            is_valid, _reason = check_grasp_valid(
                traj["approach_angle"], traj["grasp_height"], style_config
            )

        status = "VALID" if is_valid else "INVALID"
        angle_deg = np.degrees(traj["approach_angle"])
        print(f"  Mode {traj['mode_idx']} (angle={angle_deg:.0f}): {status}")

        style_results.append({
            "mode_idx": traj["mode_idx"],
            "positions": traj["positions"],
            "is_valid": is_valid,
            "is_extra": traj["is_extra"],
            "grasp_point_idx": traj["grasp_point_idx"],
            "approach_angle": traj["approach_angle"],
        })
    return style_results


def main(argv):
    """Plan all grasp trajectories once, then write a 3D and a 2D plot per style.

    Plots go to ``FLAGS.save_path``. The simulator is shut down in a
    ``finally`` clause so it is released even if setup or plotting raises.
    """
    del argv  # Unused; absl.app.run passes the remaining command-line args.

    styles = [-1, 1, 2, 3]

    print(f"{'='*70}")
    print("FOUR-STYLE GRASP TRAJECTORY VISUALIZATION")
    print(f"Styles: {styles}")
    print(f"{'='*70}")

    os.makedirs(FLAGS.save_path, exist_ok=True)

    rlbench_env = _make_environment()
    try:
        task_class = task_file_to_task_class("pick_up_cup")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(FLAGS.cup_variation)
        task_env.reset()

        # Set cup to fixed position
        if FIXED_CUP_POSITION is not None:
            set_cup_position(task_env, FIXED_CUP_POSITION)

        cup_pos, _cup_ori = get_cup_position(task_env)
        cup_base_z = get_cup_base_z(task_env)
        print(f"Cup position: {cup_pos}")
        print(f"Cup base Z: {cup_base_z}")

        # Generate all trajectories once and reuse them for every style.
        all_trajectories = _generate_all_trajectories(task_env, cup_pos, cup_base_z)

        print(f"\nGenerated {len(all_trajectories)} trajectories.")
        print(f"{'='*70}")

        for style in styles:
            print(f"\n--- Style {style} ---")

            style_results = _evaluate_style(all_trajectories, style)

            n_valid = sum(1 for r in style_results if r["is_valid"])
            n_invalid = len(style_results) - n_valid
            print(f"  Results: {n_valid} valid, {n_invalid} invalid")

            if not style_results:
                # Every relevant mode failed to plan: nothing to draw.
                print("  No trajectories for this style; skipping plots.")
                continue

            visualize_3d_single(style_results, FLAGS.save_path, style)
            visualize_2d_grasp_points(style_results, cup_pos, FLAGS.save_path, style)
    finally:
        rlbench_env.shutdown()

    print("\nDone!")


if __name__ == "__main__":
    # Switch multiprocessing to the "spawn" start method before absl starts
    # main(); force=True overrides any start method already configured.
    # NOTE(review): presumably so child processes do not fork-inherit the
    # simulator state; confirm.
    from multiprocessing import set_start_method
    set_start_method("spawn", force=True)
    app.run(main)
