"""
Generalization test for close_drawer trajectory encoder - Inner ring & Up-down.

Tests:
3. 12 different angles at distance=0.5, position=0.5 (inner ring generalization)
4. 12 trajectories with angle 90 and 270, each with 6 distances (up-down generalization)

Generates trajectories, encodes them, and visualizes the z embeddings.
Complements the existing generalization_test.py which tests:
  - Test 1: 12 angles at distance=1.0 (outer ring)
  - Test 2: angle 0 and 180 at 6 distances (left-to-right)
"""
import os
import sys
import torch
import numpy as np
import pickle

# Set matplotlib backend before importing pyplot
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

# Add parent directories for imports
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'make_dataset'))

from trajectory_encoder import TrajectoryVAE
from train_close_drawer_encoder import make_trajectory_relative
from generalization_test import (
    load_encoder,
    encode_trajectories,
    get_12_colors,
    visualize_angle_test,
    visualize_distance_test,
    visualize_combined,
)


def generate_inner_updown_trajectories(output_dir, num_angles=12, num_distances=6):
    """
    Generate test trajectories by running the RLBench simulation.

    Test 3: ``num_angles`` evenly spaced angles at distance=0.5, position=0.5 (inner ring)
    Test 4: angle 90 and 270, each with ``num_distances`` distances evenly
            spaced in [0.2, 1.0] at position=0.5 (up-down)

    Each successful episode is written to disk as ``states.npy``,
    ``actions.npy``, ``ee_trajectory.npy`` and ``metadata.npy``:

    * inner ring -> ``<output_dir>/inner_ring_test/episode<i>`` where ``i`` is
      the angle index (a failed angle leaves a gap in the numbering);
    * up-down    -> ``<output_dir>/updown_test/episode<k>`` where ``k`` counts
      successful episodes only.

    An episode is only added to the returned lists after its files were
    written, so the in-memory result and the on-disk data stay in agreement.
    Episodes that raise are reported with ``Failed: ...`` and skipped.

    Args:
        output_dir: Root directory receiving the two test sub-directories.
        num_angles: Number of angles sampled on the inner ring.
        num_distances: Number of distance fractions tested per up-down angle.

    Returns:
        Tuple ``(inner_ring_trajectories, inner_ring_angles,
        updown_trajectories, updown_params)``. The trajectory lists hold dicts
        with keys ``'states'``/``'actions'`` (float tensors), ``'params'``
        (``(angle_deg, dist_frac, pos_frac)``) and ``'success'``;
        ``inner_ring_angles`` is a NumPy array of angles in degrees and
        ``updown_params`` a list of ``(angle_deg, dist_frac)`` tuples.
    """
    # NOTE(review): RenderMode, ConfigurationPathError and IKError are not
    # referenced below; the imports are kept in case PyRep import order
    # matters - confirm and remove if it does not.
    from pyrep.const import RenderMode
    from pyrep.errors import ConfigurationPathError, IKError

    from rlbench import ObservationConfig
    from rlbench.action_modes.action_mode import MoveArmThenGripper
    from rlbench.action_modes.arm_action_modes import JointPosition
    from rlbench.action_modes.gripper_action_modes import Discrete
    from rlbench.backend.utils import task_file_to_task_class
    from rlbench.environment import Environment

    from push_utils import (
        get_drawer_handle_position,
        set_drawer_open,
        fix_cabinet_orientation,
        reset_robot_to_default,
        generate_push_trajectory,
    )
    from close_drawer_config import DRAWER_VARIATION, DRAWER_OPEN_AMOUNT, CONTROL_POINT_RADIUS
    from push_utils import check_task_success
    from pyrep.objects.joint import Joint

    # Create output directories
    inner_ring_dir = os.path.join(output_dir, 'inner_ring_test')
    updown_dir = os.path.join(output_dir, 'updown_test')
    os.makedirs(inner_ring_dir, exist_ok=True)
    os.makedirs(updown_dir, exist_ok=True)

    # Setup environment: only the low-dimensional observations read below
    obs_config = ObservationConfig()
    obs_config.set_all(False)
    obs_config.joint_positions = True
    obs_config.joint_velocities = True
    obs_config.gripper_open = True
    obs_config.gripper_pose = True
    obs_config.task_low_dim_state = True

    ACT_MIN = np.array([-2.8973, -1.7628, -2.8973, -3.0718,
                        -2.8973, -0.0175, -2.8973, 0.0], dtype=np.float32)
    ACT_RANGE = np.array([5.7946, 3.5256, 5.7946, 3.0020,
                          5.7946, 3.7700, 5.7946, 1.0], dtype=np.float32)

    class CustomMoveArmThenGripper(MoveArmThenGripper):
        def action_bounds(self):
            return (ACT_MIN, ACT_MIN + ACT_RANGE)

    drawer_variation = DRAWER_VARIATION
    drawer_open_amount = DRAWER_OPEN_AMOUNT
    control_point_radius = CONTROL_POINT_RADIUS
    drawer_names = ['bottom', 'middle', 'top']

    def _print_validation(trace, phase_indices, handle_pos):
        """Print reach/push diagnostics when the trajectory metadata has them."""
        if trace is None or len(trace) == 0 or not phase_indices:
            return
        reach_range = phase_indices.get("reach", (0, 0))
        if reach_range[1] <= 0:
            return
        reach_end_pos = trace[reach_range[1] - 1]
        dist_to_handle = np.linalg.norm(reach_end_pos - handle_pos)

        push_range = phase_indices.get("push", (0, 0))
        if push_range[1] <= push_range[0]:
            return
        push_start_pos = trace[push_range[0]]
        push_end_pos = trace[push_range[1] - 1]
        push_distance = np.linalg.norm(push_end_pos - push_start_pos)

        print(f"  Validation:")
        print(f"    Handle position: {handle_pos}")
        print(f"    Reach end -> Handle distance: {dist_to_handle:.4f}m")
        print(f"    Push distance: {push_distance:.4f}m")

    def _run_episode(task_env, angle_deg, dist_frac, pos_frac, episode_dir):
        """Simulate one push, save it under ``episode_dir`` and return its record."""
        canonical_params = np.array([[np.radians(angle_deg), dist_frac, pos_frac]])

        # Bring robot and drawer back to the same starting configuration.
        reset_robot_to_default(task_env)
        task_env.reset()
        fix_cabinet_orientation(task_env)
        set_drawer_open(task_env, drawer_variation, drawer_open_amount)

        current_handle_pos, current_handle_ori = get_drawer_handle_position(
            task_env, drawer_variation
        )

        demo, traj_metadata = generate_push_trajectory(
            task_env,
            start_pos=np.zeros(3),
            handle_pos=current_handle_pos,
            handle_ori=current_handle_ori,
            cp_idx=0,
            canonical_params=canonical_params,
            control_point_radius=control_point_radius,
            waypoint_params=None,
            phase_steps=None,
            steps_per_point=5,
            target_drawer_idx=drawer_variation,
        )

        if len(demo) == 0:
            raise RuntimeError("Empty trajectory")

        # Validation checks
        trace = traj_metadata.get("trace")
        _print_validation(trace, traj_metadata.get("phase_indices", {}), current_handle_pos)

        # Check if task succeeded
        task_success = check_task_success(task_env)

        # Check drawer joint position
        final_drawer_joint = Joint(f'drawer_joint_{drawer_names[drawer_variation]}')
        final_drawer_pos = final_drawer_joint.get_joint_position()
        print(f"    Drawer joint after trajectory: {final_drawer_pos:.4f}m (need <0.04m for success)")
        print(f"    Task success: {task_success}")

        # Extract states and actions
        states = np.array([
            np.concatenate([
                step.joint_positions,
                step.joint_velocities,
                [step.gripper_open],
                step.gripper_pose
            ])
            for step in demo
        ])
        # Prefer the recorded joint-position action; otherwise fall back to
        # the observed joint positions plus gripper state.
        actions = np.array([
            step.misc['joint_position_action']
            if hasattr(step, 'misc') and 'joint_position_action' in step.misc
            else np.concatenate([step.joint_positions, [step.gripper_open]])
            for step in demo
        ])

        # Save trajectory (before the caller records it, so a failed write
        # cannot leave memory and disk out of sync)
        os.makedirs(episode_dir, exist_ok=True)
        np.save(os.path.join(episode_dir, 'states.npy'), states)
        np.save(os.path.join(episode_dir, 'actions.npy'), actions)
        np.save(os.path.join(episode_dir, 'ee_trajectory.npy'), trace)
        np.save(os.path.join(episode_dir, 'metadata.npy'), {
            'angle': angle_deg,
            'distance': dist_frac,
            'position': pos_frac,
            'success': task_success,
            'drawer_final_pos': final_drawer_pos
        })

        print(f"  Demo: {len(demo)} steps, success: {task_success}")

        return {
            'states': torch.from_numpy(states).float(),
            'actions': torch.from_numpy(actions).float(),
            'params': (angle_deg, dist_frac, pos_frac),
            'success': task_success
        }

    action_mode = CustomMoveArmThenGripper(JointPosition(True), Discrete())

    rlbench_env = Environment(action_mode=action_mode, obs_config=obs_config, headless=True)
    rlbench_env.launch()

    inner_ring_trajectories = []
    inner_ring_angles = []
    updown_trajectories = []
    updown_params = []

    # Always shut the simulator down, even if setup or a loop raises.
    try:
        task_class = task_file_to_task_class("close_drawer")
        task_env = rlbench_env.get_task(task_class)
        # NOTE(review): the variation is hard-coded to 2 while every helper
        # below is driven by DRAWER_VARIATION - confirm the two agree.
        task_env.set_variation(2)  # Top drawer

        # Initialize
        task_env.reset()
        fix_cabinet_orientation(task_env)
        set_drawer_open(task_env, drawer_variation, drawer_open_amount)
        handle_pos, _ = get_drawer_handle_position(task_env, drawer_variation)

        print(f"Handle position: {handle_pos}")

        # =====================================================================
        # Test 3: 12 different angles at distance=0.5, position=0.5 (inner ring)
        # =====================================================================
        print("\n" + "="*60)
        print("Test 3: Inner Ring (distance=0.5, position=0.5)")
        print("="*60)

        for i in range(num_angles):
            angle_deg = i * (360.0 / num_angles)

            print(f"\nAngle {i+1}/{num_angles}: {angle_deg:.1f} degrees")

            try:
                record = _run_episode(
                    task_env, angle_deg, 0.5, 0.5,
                    os.path.join(inner_ring_dir, f'episode{i}')
                )
            except Exception as e:
                print(f"  Failed: {e}")
            else:
                inner_ring_trajectories.append(record)
                inner_ring_angles.append(angle_deg)

        # =====================================================================
        # Test 4: angle 90 and 270, each with 6 distances (up-down)
        # =====================================================================
        print("\n" + "="*60)
        print("Test 4: Up-Down Distance Generalization (angle=90,270, position=0.5)")
        print("="*60)

        # Distances from 0.2 to 1.0
        distances = np.linspace(0.2, 1.0, num_distances)
        test_angles = [90, 270]

        # Counts successful episodes only, so saved directories are contiguous.
        episode_idx = 0
        for angle_deg in test_angles:
            for dist_frac in distances:
                print(f"\nEpisode {episode_idx+1}: angle={angle_deg:.0f}, distance={dist_frac:.2f}")

                try:
                    record = _run_episode(
                        task_env, angle_deg, dist_frac, 0.5,
                        os.path.join(updown_dir, f'episode{episode_idx}')
                    )
                except Exception as e:
                    print(f"  Failed: {e}")
                else:
                    updown_trajectories.append(record)
                    updown_params.append((angle_deg, dist_frac))
                    episode_idx += 1
    finally:
        rlbench_env.shutdown()

    return (inner_ring_trajectories, np.array(inner_ring_angles),
            updown_trajectories, updown_params)


def load_trajectories_from_disk(output_dir):
    """Load previously generated trajectories from disk.

    Reads ``<output_dir>/inner_ring_test/episode*`` and
    ``<output_dir>/updown_test/episode*``. Episode directories are visited in
    numeric order of their suffix (``episode2`` before ``episode10``), i.e.
    the order in which they were generated; a plain lexicographic sort would
    interleave them as soon as there are ten or more episodes.

    Args:
        output_dir: Root directory that holds the two test sub-directories.
            A missing sub-directory simply yields an empty result for it.

    Returns:
        Tuple ``(inner_ring_trajectories, inner_ring_angles,
        updown_trajectories, updown_params)``. Each trajectory is a dict with
        ``'states'``/``'actions'`` (float tensors), ``'params'``
        (``(angle, distance, position)``) and ``'success'`` (``None`` when the
        stored metadata has no such entry). ``inner_ring_angles`` is a NumPy
        array of angles and ``updown_params`` a list of ``(angle, distance)``.

    Raises:
        FileNotFoundError: If an episode directory lacks one of the expected
            ``.npy`` files.
    """
    prefix = 'episode'

    def _episode_order(name):
        """Sort key: numeric suffixes first, in numeric order; others by name."""
        suffix = name[len(prefix):]
        if suffix.isdecimal():
            return (0, int(suffix), name)
        return (1, 0, name)

    def _load_split(split_dir):
        """Return ``[(trajectory_dict, metadata_dict), ...]`` for one test folder."""
        if not os.path.isdir(split_dir):
            return []

        loaded = []
        episode_names = sorted(
            (d for d in os.listdir(split_dir) if d.startswith(prefix)),
            key=_episode_order,
        )
        for ep_name in episode_names:
            ep_path = os.path.join(split_dir, ep_name)
            states = np.load(os.path.join(ep_path, 'states.npy'))
            actions = np.load(os.path.join(ep_path, 'actions.npy'))
            # SECURITY: allow_pickle=True unpickles the stored dict; only point
            # this at output directories produced by this script.
            metadata = np.load(os.path.join(ep_path, 'metadata.npy'), allow_pickle=True).item()

            loaded.append(({
                'states': torch.from_numpy(states).float(),
                'actions': torch.from_numpy(actions).float(),
                'params': (metadata['angle'], metadata['distance'], metadata['position']),
                # Same key the generation path provides.
                'success': metadata.get('success'),
            }, metadata))
        return loaded

    # Load inner ring test trajectories
    inner_ring = _load_split(os.path.join(output_dir, 'inner_ring_test'))
    inner_ring_trajectories = [traj for traj, _ in inner_ring]
    inner_ring_angles = [meta['angle'] for _, meta in inner_ring]

    # Load up-down test trajectories
    updown = _load_split(os.path.join(output_dir, 'updown_test'))
    updown_trajectories = [traj for traj, _ in updown]
    updown_params = [(meta['angle'], meta['distance']) for _, meta in updown]

    return (inner_ring_trajectories, np.array(inner_ring_angles),
            updown_trajectories, updown_params)


def main():
    """Command-line entry point: obtain, encode and visualize test trajectories.

    Trajectories are either generated in simulation or, with
    ``--skip_generation``, loaded from ``--output_dir``. Each non-empty test
    set is encoded with the checkpoint given by ``--encoder_path``; embeddings
    and the parameters they belong to are saved under ``--output_dir`` and the
    plots under ``<output_dir>/plots``. Returns early when there is nothing to
    encode.
    """
    parser = argparse.ArgumentParser(
        description='Generalization test (inner ring + up-down) for close_drawer encoder'
    )
    parser.add_argument('--encoder_path', type=str, required=True,
                       help='Path to encoder checkpoint')
    parser.add_argument('--output_dir', type=str, required=True,
                       help='Directory to save test data and visualizations')
    parser.add_argument('--device', type=str, default='cuda:0',
                       help='Device to use for encoding')
    parser.add_argument('--skip_generation', action='store_true',
                       help='Skip trajectory generation and load from disk')
    parser.add_argument('--num_angles', type=int, default=12,
                       help='Number of angles to test (inner ring)')
    parser.add_argument('--num_distances', type=int, default=6,
                       help='Number of distances to test per angle (up-down)')

    args = parser.parse_args()

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    plots_dir = os.path.join(args.output_dir, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # Generate or load trajectories
    if args.skip_generation:
        print("Loading trajectories from disk...")
        (inner_ring_trajs, inner_ring_angles,
         updown_trajs, updown_params) = load_trajectories_from_disk(args.output_dir)
    else:
        print("Generating test trajectories...")
        (inner_ring_trajs, inner_ring_angles,
         updown_trajs, updown_params) = generate_inner_updown_trajectories(
            args.output_dir, num_angles=args.num_angles, num_distances=args.num_distances
        )

    print(f"\nLoaded {len(inner_ring_trajs)} inner ring test trajectories")
    print(f"Loaded {len(updown_trajs)} up-down test trajectories")

    if len(inner_ring_trajs) == 0 and len(updown_trajs) == 0:
        print("No trajectories to encode!")
        return

    # Load encoder; fall back to CPU when CUDA is unavailable
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    model, config = load_encoder(args.encoder_path)
    model = model.to(device)

    # Encode inner ring test trajectories
    if len(inner_ring_trajs) > 0:
        print("\nEncoding inner ring test trajectories...")
        z_inner_ring = encode_trajectories(model, inner_ring_trajs, device)

        # Save embeddings together with the angle of each row. The angles are
        # stored explicitly because --num_angles or a failed episode can make
        # them differ from the outer ring test's angle list.
        np.save(os.path.join(args.output_dir, 'z_inner_ring_test.npy'), z_inner_ring)
        np.save(os.path.join(args.output_dir, 'inner_ring_angles.npy'),
                np.asarray(inner_ring_angles))

        # Visualize (reuse angle test visualization)
        visualize_angle_test(
            z_inner_ring, inner_ring_angles,
            os.path.join(plots_dir, 'inner_ring_test.png')
        )

        print(f"\nInner ring test embeddings:")
        for angle, z in zip(inner_ring_angles, z_inner_ring):
            print(f"  angle={angle:6.1f}: z=[{z[0]:7.4f}, {z[1]:7.4f}]")

    # Encode up-down test trajectories
    if len(updown_trajs) > 0:
        print("\nEncoding up-down test trajectories...")
        z_updown = encode_trajectories(model, updown_trajs, device)

        # Save embeddings
        np.save(os.path.join(args.output_dir, 'z_updown_test.npy'), z_updown)
        np.save(os.path.join(args.output_dir, 'updown_params.npy'), np.array(updown_params))

        # Visualize (reuse distance test visualization)
        visualize_distance_test(
            z_updown, updown_params,
            os.path.join(plots_dir, 'updown_test.png')
        )

        print(f"\nUp-down test embeddings:")
        for (angle, dist), z in zip(updown_params, z_updown):
            print(f"  angle={angle:3.0f}, dist={dist:.2f}: z=[{z[0]:7.4f}, {z[1]:7.4f}]")

    # Combined visualization (needs both embedding sets)
    if len(inner_ring_trajs) > 0 and len(updown_trajs) > 0:
        visualize_combined(
            z_inner_ring, inner_ring_angles, z_updown, updown_params,
            os.path.join(plots_dir, 'combined_inner_updown.png')
        )

    print(f"\nResults saved to {args.output_dir}")
    print(f"Plots saved to {plots_dir}")


if __name__ == '__main__':
    # Select the "spawn" start method (replacing any method already set)
    # before the script's entry point runs.
    from multiprocessing import set_start_method
    set_start_method("spawn", force=True)
    main()
