#!/usr/bin/env python
"""
Visualize EE trajectories for 5 wall styles: -1 (no wall), 0, 1, 2, 3.
Outputs five individual 3D plots with thicker success trajectories,
no legend, and no title.

Usage:
  python visualize_wall_four_styles.py
  python visualize_wall_four_styles.py --save_path=/custom/path
"""
import sys
import os
import pickle

sys.path.insert(0, os.path.dirname(__file__))

import numpy as np

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

from pyrep.errors import ConfigurationPathError, IKError
from pyrep.objects.dummy import Dummy

from rlbench import ObservationConfig
from rlbench.action_modes.action_mode import MoveArmThenGripper
from rlbench.action_modes.arm_action_modes import JointPosition
from rlbench.action_modes.gripper_action_modes import Discrete
from rlbench.backend.utils import task_file_to_task_class
from rlbench.environment import Environment

from absl import app, flags

from push_utils import (
    DEFAULT_PHASE_STEPS,
    get_drawer_handle_position,
    set_drawer_open,
    fix_cabinet_orientation,
    reset_robot_to_default,
    compute_push_waypoints,
    generate_canonical_control_point_params,
    compute_control_point_from_params,
    generate_phase_positions,
    move_robot_to_start,
    parabola3D,
    HOME_JOINTS,
    lock_other_drawers,
    generate_push_trajectory,
)

from close_drawer_config import (
    CAMERA_POSITION,
    CAMERA_ORIENTATION,
    CAMERA_IMAGE_SIZE,
    DRAWER_VARIATION as DEFAULT_DRAWER_VARIATION,
    DRAWER_OPEN_AMOUNT as DEFAULT_DRAWER_OPEN_AMOUNT,
    CONTROL_POINT_RADIUS as DEFAULT_CONTROL_POINT_RADIUS,
)

from wall_collision import (
    DEFAULT_WALL_CONFIG,
    DEFAULT_OPENING_CONFIG,
    WALL_STYLES,
    create_wall,
    check_wall_collision,
)


FLAGS = flags.FLAGS

# Default output directory: "block_setting" inside the parent of the directory
# that contains this script. abspath() makes the result independent of how
# __file__ happens to be spelled (it can be a bare relative name such as
# "script.py", in which case dirname() would return "" and the default would
# silently become a path relative to the current working directory).
DEFAULT_SAVE_PATH = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "block_setting"
)
flags.DEFINE_string("save_path", DEFAULT_SAVE_PATH, "Where to save visualizations.")
flags.DEFINE_integer("num_modes", 8, "Number of control point modes.")
flags.DEFINE_float("control_point_radius", DEFAULT_CONTROL_POINT_RADIUS, "Control point radius.")
flags.DEFINE_integer("drawer_variation", DEFAULT_DRAWER_VARIATION, "Drawer: 0=bottom, 1=middle, 2=top")
flags.DEFINE_float("drawer_open_amount", DEFAULT_DRAWER_OPEN_AMOUNT, "How far drawer is open")


def check_ee_trajectory_wall_collision(ee_positions, wall_config, bidirectional=False):
    """Check whether a piecewise-linear EE trajectory passes through the wall.

    The wall is an axis-aligned rectangle lying in the plane
    ``y == wall_config["wall_y"]``. It spans ``[wall_min_x, wall_max_x]`` in X
    and ``[wall_min_z, wall_max_z]`` in Z. An optional
    ``wall_config["opening"]`` dict (keys ``min_x``, ``max_x``, ``min_z``,
    ``max_z``) describes a hole that segments may pass through. All bounds are
    inclusive.

    Each consecutive pair of points is treated as a straight segment. Its
    intersection with the wall plane is linearly interpolated and tested
    against the wall and opening rectangles.

    Args:
        ee_positions: Sequence of points indexable as ``p[0], p[1], p[2]``
            (x, y, z), in trajectory order.
        wall_config: Wall description dict (see above).
        bidirectional: If False (default, original behaviour), only segments
            moving from ``y >= wall_y`` to ``y < wall_y`` are tested. If True,
            segments crossing the plane in the opposite direction
            (``y < wall_y`` to ``y >= wall_y``) are tested as well.

    Returns:
        ``(True, i)`` where ``i`` is the index of the first point after the
        first blocked crossing, or ``(False, None)`` if no segment hits the
        wall (including when fewer than two points are given).
    """
    wall_y = wall_config["wall_y"]
    min_x, max_x = wall_config["wall_min_x"], wall_config["wall_max_x"]
    min_z, max_z = wall_config["wall_min_z"], wall_config["wall_max_z"]
    opening = wall_config.get("opening")

    for i in range(1, len(ee_positions)):
        prev_pos = ee_positions[i - 1]
        curr_pos = ee_positions[i]

        forward = prev_pos[1] >= wall_y > curr_pos[1]
        backward = bidirectional and prev_pos[1] < wall_y <= curr_pos[1]
        if not (forward or backward):
            continue  # this segment does not cross the wall plane

        dy = curr_pos[1] - prev_pos[1]
        if abs(dy) > 1e-6:
            t = (wall_y - prev_pos[1]) / dy
            cross_x = prev_pos[0] + t * (curr_pos[0] - prev_pos[0])
            cross_z = prev_pos[2] + t * (curr_pos[2] - prev_pos[2])
        else:
            # Tiny step across the plane: use the endpoint instead of
            # dividing by a near-zero dy.
            cross_x, cross_z = curr_pos[0], curr_pos[2]

        if not (min_x <= cross_x <= max_x and min_z <= cross_z <= max_z):
            continue  # crosses the plane outside the wall's extent

        if opening is not None and (
            opening["min_x"] <= cross_x <= opening["max_x"]
            and opening["min_z"] <= cross_z <= opening["max_z"]
        ):
            continue  # passes through the opening

        return True, i

    return False, None


def generate_ee_trajectory_only(task_env, handle_pos, handle_ori, cp_idx, canonical_params,
                                 control_point_radius, phase_steps, drawer_variation):
    """Plan the EE positions for one control-point mode without following them.

    Side effect: the arm's joints are set directly to ``HOME_JOINTS``
    (dynamics disabled), and the simulation is stepped 10 times before the
    tip position is read. No planned waypoint is sent to the robot.

    Args:
        task_env: RLBench task environment (its ``_scene`` robot and pyrep
            handle are used).
        handle_pos: Drawer handle position, forwarded to
            ``compute_push_waypoints``.
        handle_ori: Drawer handle orientation, forwarded to
            ``compute_push_waypoints``.
        cp_idx: Index into ``canonical_params`` selecting the mode.
        canonical_params: Sequence of ``(angle, dist_frac, pos_frac)`` tuples.
        control_point_radius: Radius passed to
            ``compute_control_point_from_params``.
        phase_steps: Per-phase step counts passed to
            ``generate_phase_positions``.
        drawer_variation: Unused; kept so existing call sites keep working.

    Returns:
        ``(positions, waypoints, cp_reach, phase_indices)``. ``positions`` is
        the output of ``generate_phase_positions`` converted with
        ``np.array``; callers index it as ``[:, 0..2]`` (presumably shape
        ``(N, 3)`` xyz; confirm). ``waypoints`` is the dict from
        ``compute_push_waypoints``.
    """
    robot = task_env._scene.robot
    tip = robot.arm.get_tip()

    # Start every mode from the same home configuration.
    robot.arm.set_joint_positions(HOME_JOINTS, disable_dynamics=True)
    for _ in range(10):
        # NOTE(review): presumably lets the scene settle before reading the tip pose.
        task_env._scene.pyrep.step()

    home_pos = np.array(tip.get_position())
    waypoints = compute_push_waypoints(home_pos, handle_pos, handle_ori, None)

    angle, dist_frac, pos_frac = canonical_params[cp_idx]
    cp_reach = compute_control_point_from_params(
        waypoints["start"], waypoints["handle"],
        control_point_radius, angle, dist_frac, pos_frac
    )

    positions, phase_indices, _phase_labels = generate_phase_positions(
        waypoints, phase_steps, cp_reach
    )

    return np.array(positions), waypoints, cp_reach, phase_indices


def get_wall_config_for_style(style):
    """Build a fresh wall-config dict for ``style``.

    * ``-1``: no wall; returns ``None``.
    * ``0``: a copy of ``DEFAULT_WALL_CONFIG`` with ``"opening"`` set to ``None``.
    * any other key of ``WALL_STYLES``: a copy of that entry. Its ``"opening"``
      dict, when present and not ``None``, is copied as well.

    The top-level dict (and the opening dict) are copies, so callers may
    modify them without altering the shared module-level tables.

    Raises:
        ValueError: if ``style`` is not -1, 0, or a key of ``WALL_STYLES``.
    """
    if style == -1:
        return None

    if style == 0:
        solid = DEFAULT_WALL_CONFIG.copy()
        solid["opening"] = None
        return solid

    if style not in WALL_STYLES:
        raise ValueError(f"Unknown wall style: {style}")

    cfg = WALL_STYLES[style].copy()
    hole = cfg.get("opening")
    if hole is not None:
        cfg["opening"] = hole.copy()
    return cfg


def _padded_limits(values, frac=0.3, min_pad=0.01):
    """Return ``(lo, hi)`` covering ``values`` plus ``frac`` of their range on each side."""
    lo, hi = float(np.min(values)), float(np.max(values))
    pad = (hi - lo) * frac
    if pad <= 0.0:
        # All values coincide. Use a fixed pad so the axis limits are not
        # identical, which would make matplotlib collapse the axis.
        pad = min_pad
    return lo - pad, hi + pad


def _wall_polygons(wall_config, x_lim, z_lim):
    """Return the wall as quads (4 xyz vertices each), clipped to the X/Z view window."""
    wall_y = wall_config["wall_y"]
    wx0 = max(wall_config["wall_min_x"], x_lim[0])
    wx1 = min(wall_config["wall_max_x"], x_lim[1])
    wz0 = max(wall_config["wall_min_z"], z_lim[0])
    wz1 = min(wall_config["wall_max_z"], z_lim[1])
    if wx0 >= wx1 or wz0 >= wz1:
        return []  # the wall lies entirely outside the plotted region

    def quad(xa, xb, za, zb):
        return [[xa, wall_y, za], [xb, wall_y, za], [xb, wall_y, zb], [xa, wall_y, zb]]

    opening = wall_config.get("opening")
    if opening is None:
        return [quad(wx0, wx1, wz0, wz1)]

    # Clamp each opening bound into the visible wall on BOTH sides. Clamping
    # only one side lets an opening that lies outside the view produce
    # inverted or oversized wall pieces.
    ox0 = min(max(opening["min_x"], wx0), wx1)
    ox1 = min(max(opening["max_x"], wx0), wx1)
    oz0 = min(max(opening["min_z"], wz0), wz1)
    oz1 = min(max(opening["max_z"], wz0), wz1)
    if ox0 >= ox1 or oz0 >= oz1:
        return [quad(wx0, wx1, wz0, wz1)]  # opening not visible: draw a solid wall

    pieces = []
    if wz0 < oz0:
        pieces.append(quad(wx0, wx1, wz0, oz0))  # below the opening
    if oz1 < wz1:
        pieces.append(quad(wx0, wx1, oz1, wz1))  # above the opening
    if wx0 < ox0:
        pieces.append(quad(wx0, ox0, oz0, oz1))  # min-x side of the opening
    if ox1 < wx1:
        pieces.append(quad(ox1, wx1, oz0, oz1))  # max-x side of the opening
    return pieces


def visualize_3d_single(all_results, wall_config, save_path, style):
    """
    Create a single 3D plot for one wall style and save it as a PNG.

    No legend and no title.
    * Colliding trajectories are drawn thin and red, truncated one sample past
      their ``collision_idx``.
    * Successful trajectories are drawn thick and blue, with a green star at
      their last point.
    * Results flagged ``is_extra`` use a dotted line.
    * The mean of all start points is marked with a large blue dot.
    * The wall (if any) is drawn as translucent red quads clipped to the
      padded plot bounds, leaving a gap for the opening.

    Args:
        all_results: List of dicts with keys ``"positions"`` (array indexed as
            ``[:, 0..2]``, presumably ``(N, 3)`` xyz), ``"collision"``,
            ``"collision_idx"`` and optionally ``"is_extra"``.
        wall_config: Wall description dict, or ``None`` for "no wall".
        save_path: Directory that receives ``trajectories_3d_style{style}.png``.
        style: Wall style id, used in the output filename and log messages.

    Returns:
        None. If there is nothing to draw (no results, or only empty
        trajectories), a message is printed and no file is written.
    """
    # Trajectories without samples cannot be drawn and would break x[-1].
    results = [r for r in all_results if len(r["positions"]) > 0]
    if not results:
        print(f"No trajectories to plot for style {style}; skipping.")
        return

    fig = plt.figure(figsize=(8, 7))
    try:
        ax = fig.add_subplot(111, projection='3d')
        drawn_segments = []

        for result in results:
            positions = np.asarray(result["positions"])
            collision = result["collision"]
            collision_idx = result["collision_idx"]

            if collision and collision_idx is not None:
                draw_end = min(collision_idx + 1, len(positions))
            else:
                draw_end = len(positions)
            pos_to_draw = positions[:draw_end]
            drawn_segments.append(pos_to_draw)
            x, y, z = pos_to_draw[:, 0], pos_to_draw[:, 1], pos_to_draw[:, 2]

            linestyle = ':' if result.get("is_extra", False) else '-'
            if collision:
                color, linewidth, alpha = 'red', 1.5, 0.6
            else:
                color, linewidth, alpha = 'blue', 3.5, 0.9
            ax.plot(x, y, z, linestyle, color=color, linewidth=linewidth, alpha=alpha)

            if not collision:
                ax.scatter(x[-1], y[-1], z[-1], c='green', marker='*', s=150)

        # Start point (average over the full, untruncated trajectories).
        start_pos = np.mean([np.asarray(r["positions"])[0] for r in results], axis=0)
        ax.scatter(*start_pos, color='blue', s=200, marker='o',
                   edgecolors='black', linewidths=2)

        # Plot bounds: drawn points plus 30% padding per axis.
        drawn = np.concatenate(drawn_segments, axis=0)
        x_lim = _padded_limits(drawn[:, 0])
        y_lim = _padded_limits(drawn[:, 1])
        z_lim = _padded_limits(drawn[:, 2])

        if wall_config is not None:
            edge_width = 2 if wall_config.get("opening") is None else 1
            for quad in _wall_polygons(wall_config, x_lim, z_lim):
                ax.add_collection3d(Poly3DCollection(
                    [quad], alpha=0.3, facecolor='red',
                    edgecolor='darkred', linewidth=edge_width,
                ))

        ax.set_xlim(*x_lim)
        ax.set_ylim(*y_lim)
        ax.set_zlim(*z_lim)

        ax.set_xlabel('X (m)', fontsize=12)
        ax.set_ylabel('Y (m)', fontsize=12)
        ax.set_zlabel('Z (m)', fontsize=12)

        # View angle to match camera frame
        ax.view_init(elev=30, azim=-135)

        fig.tight_layout()
        plot_path = os.path.join(save_path, f'trajectories_3d_style{style}.png')
        fig.savefig(plot_path, dpi=150, bbox_inches='tight')
        print(f"Saved 3D plot to {plot_path}")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def _launch_environment():
    """Create and launch the headless RLBench environment used by this script."""
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.front_camera.rgb = True
    obs_config.front_camera.image_size = CAMERA_IMAGE_SIZE

    # Action bounds: 7 arm joints + 1 gripper value, as (min, min + range).
    ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())
    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()
    return rlbench_env


def _generate_all_trajectories(task_env, handle_pos, handle_ori):
    """Plan every base mode plus the style-3 extra mode; failed modes are reported and skipped."""
    # Canonical control points are shared across all styles.
    canonical_params = list(generate_canonical_control_point_params(FLAGS.num_modes))
    phase_steps = DEFAULT_PHASE_STEPS.copy()
    num_base_modes = len(canonical_params)

    # Style 3 extra mode: angle=315°, dist=1.0, pos_frac=0.5
    extra_mode = (np.radians(315), 1.0, 0.5)
    canonical_params_with_extra = canonical_params + [extra_mode]

    print(f"\nGenerating {len(canonical_params_with_extra)} trajectory modes ({num_base_modes} base + 1 extra)...")
    all_trajectories = []
    for mode_idx, cp_params in enumerate(canonical_params_with_extra):
        try:
            positions, _waypoints, _cp_reach, _phase_indices = generate_ee_trajectory_only(
                task_env, handle_pos, handle_ori, mode_idx, canonical_params_with_extra,
                FLAGS.control_point_radius, phase_steps, FLAGS.drawer_variation
            )
        except Exception as e:  # best effort: one failing mode must not abort the run
            print(f"  Mode {mode_idx} failed: {e}")
            continue

        is_extra = mode_idx >= num_base_modes
        angle_deg = np.degrees(cp_params[0])
        dist_frac = cp_params[1]
        extra_tag = " (extra)" if is_extra else ""
        print(f"  Mode {mode_idx}: angle={angle_deg:.0f} deg, dist={dist_frac:.1f}, {len(positions)} steps{extra_tag}")
        all_trajectories.append({
            "mode_idx": mode_idx,
            "positions": positions,
            "cp_params": cp_params,
            "is_extra": is_extra,
        })
    return all_trajectories


def _evaluate_style(style, all_trajectories):
    """Print the wall for ``style``, collision-check its trajectories, and return (wall_config, results)."""
    print(f"\n--- Style {style} ---")
    wall_config = get_wall_config_for_style(style)

    if wall_config is not None:
        print(f"  Wall Y={wall_config['wall_y']}, "
              f"X=[{wall_config['wall_min_x']}, {wall_config['wall_max_x']}], "
              f"Z=[{wall_config['wall_min_z']}, {wall_config['wall_max_z']}]")
        if wall_config.get("opening"):
            o = wall_config["opening"]
            print(f"  Opening: X=[{o['min_x']}, {o['max_x']}], Z=[{o['min_z']}, {o['max_z']}]")
    else:
        print("  No wall")

    # Style 3 includes the extra mode; others use base modes only.
    if style == 3:
        trajs_for_style = all_trajectories
    else:
        trajs_for_style = [t for t in all_trajectories if not t["is_extra"]]

    style_results = []
    for traj in trajs_for_style:
        if wall_config is None:
            collision, collision_idx = False, None
        else:
            collision, collision_idx = check_ee_trajectory_wall_collision(
                traj["positions"], wall_config
            )

        status = "COLLISION" if collision else "SUCCESS"
        angle_deg = np.degrees(traj["cp_params"][0])
        print(f"  Mode {traj['mode_idx']} (angle={angle_deg:.0f}): {status}")

        style_results.append({
            "mode_idx": traj["mode_idx"],
            "positions": traj["positions"],
            "collision": collision,
            "collision_idx": collision_idx,
            "is_extra": traj["is_extra"],
        })

    n_success = sum(1 for r in style_results if not r["collision"])
    n_fail = len(style_results) - n_success
    print(f"  Results: {n_success} success, {n_fail} collision")
    return wall_config, style_results


def main(argv):
    """Generate the trajectory modes once, then collision-check and plot them per wall style.

    The simulator is always shut down, even if any step after launching it
    raises.
    """
    del argv  # Unused; absl passes any leftover positional arguments.
    styles = [-1, 0, 1, 2, 3]

    print(f"{'='*70}")
    print("FIVE-STYLE TRAJECTORY VISUALIZATION")
    print(f"Styles: {styles}")
    print(f"{'='*70}")

    os.makedirs(FLAGS.save_path, exist_ok=True)

    rlbench_env = _launch_environment()
    try:
        task_class = task_file_to_task_class("close_drawer")
        task_env = rlbench_env.get_task(task_class)
        task_env.set_variation(FLAGS.drawer_variation)

        task_env.reset()
        fix_cabinet_orientation(task_env)
        set_drawer_open(task_env, FLAGS.drawer_variation, FLAGS.drawer_open_amount)

        handle_pos, handle_ori = get_drawer_handle_position(task_env, FLAGS.drawer_variation)
        print(f"Handle position: {handle_pos}")

        all_trajectories = _generate_all_trajectories(task_env, handle_pos, handle_ori)
        print(f"\nGenerated {len(all_trajectories)} trajectories.")
        print(f"{'='*70}")

        if not all_trajectories:
            print("No trajectories could be generated; nothing to plot.")
            return

        for style in styles:
            wall_config, style_results = _evaluate_style(style, all_trajectories)
            visualize_3d_single(style_results, wall_config, FLAGS.save_path, style)
    finally:
        rlbench_env.shutdown()

    print("\nDone!")


if __name__ == "__main__":
    import multiprocessing

    # The start method is forced to "spawn" before absl parses flags and
    # invokes main().
    multiprocessing.set_start_method("spawn", force=True)
    app.run(main)
