"""
Visualize EE trajectories and embeddings for each mode/phase separately.
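Creates up to 16 plots: 8 modes × 2 phases (REACH/CARRY); a mode/phase
combination with no data is skipped.
Each plot has 2 subplots:
  - Left: the EE trajectories for that mode/phase (nominally 10), drawn on a 3D axes
  - Right: the matching embedding points (nominally 10), one per trajectory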
"""
import os
import sys
import torch
import numpy as np
import argparse

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

sys.path.insert(0, os.path.dirname(__file__))

from trajectory_encoder_normalized_ee import TrajectoryVAENormalizedEE
from full_trajectory_dataset import FullTrajectoryDataset


def build_local_frame_batched(start_pos, end_pos):
    """Build a per-trajectory orthonormal frame aligned with start -> end.

    The first axis points from ``start_pos`` to ``end_pos``. The second axis
    is world-up (0, 0, 1) with its component along the first axis removed
    (one Gram-Schmidt step); rows where that residual is shorter than 1e-6
    (line almost parallel to world-up) use world-forward (0, 1, 0) instead.
    The third axis is the normalized cross product of the first two.

    Args:
        start_pos: tensor indexed as (batch, 3) holding the start points.
        end_pos: tensor of the same layout holding the end points.

    Returns:
        Tuple ``(line_dir, perp1, perp2, dist)`` where the three axes keep the
        (batch, 3) layout and ``dist`` is the start-to-end length per row,
        clamped to at least 1e-6, with the trailing axis squeezed away.
    """
    eps = 1e-6
    n_rows = start_pos.shape[0]

    delta = end_pos - start_pos
    length = torch.clamp(delta.norm(dim=1, keepdim=True), min=eps)
    line_dir = delta / length

    def world_axis(x, y, z):
        # Constant axis replicated across the batch, matching input device/dtype.
        axis = torch.tensor([x, y, z], device=start_pos.device, dtype=start_pos.dtype)
        return axis.unsqueeze(0).expand(n_rows, -1)

    def reject_from_line(axis):
        # Remove the part of `axis` that lies along the line direction.
        along = (axis * line_dir).sum(dim=1, keepdim=True)
        residual = axis - along * line_dir
        return residual, residual.norm(dim=1, keepdim=True)

    up_residual, up_len = reject_from_line(world_axis(0.0, 0.0, 1.0))
    degenerate = up_len.squeeze(1) < eps

    if not degenerate.any():
        # Every residual is at least eps long, so a plain division is safe.
        perp1 = up_residual / up_len
    else:
        fwd_residual, fwd_len = reject_from_line(world_axis(0.0, 1.0, 0.0))
        fallback = fwd_residual / torch.clamp(fwd_len, min=eps)
        regular = up_residual / torch.clamp(up_len, min=eps)
        perp1 = torch.where(degenerate.unsqueeze(1), fallback, regular)

    third = torch.cross(line_dir, perp1, dim=1)
    perp2 = third / third.norm(dim=1, keepdim=True).clamp(min=eps)

    return line_dir, perp1, perp2, length.squeeze(1)


def extract_and_normalize_ee(states):
    """Express the EE path in its own start -> end frame, scaled by path span.

    Columns 15:18 of the last axis of ``states`` are taken as the EE position.
    Each point is rewritten as three numbers:

      * progress along the straight start -> end line (0 at the first step,
        1 at the last step when the two are more than 1e-6 apart),
      * the offset from that line along the frame's second axis,
      * the offset from that line along the frame's third axis,

    all divided by the (clamped) start-to-end distance.

    Args:
        states: tensor indexed as (batch, time, features) with at least 18
            features.

    Returns:
        ``(ee_normalized, dist)``: the normalized coordinates with a last axis
        of size 3, and the start-to-end distance with a trailing singleton axis.
    """
    ee_xyz = states[:, :, 15:18]
    origin = ee_xyz[:, 0, :]
    goal = ee_xyz[:, -1, :]

    axis, lateral_a, lateral_b, length = build_local_frame_batched(origin, goal)

    # Add a time axis so per-trajectory quantities broadcast over every step.
    origin_t = origin[:, None, :]
    scale = length[:, None, None]

    def scaled_component(vectors, direction):
        # Dot each step's vector with the per-trajectory direction, then scale.
        return (vectors * direction[:, None, :]).sum(dim=2, keepdim=True) / scale

    progress = scaled_component(ee_xyz - origin_t, axis)

    # Closest point on the start -> end line, then the remaining offset from it.
    chord = (goal - origin)[:, None, :]
    foot = origin_t + progress * chord
    residual = ee_xyz - foot

    ee_normalized = torch.cat(
        [
            progress,
            scaled_component(residual, lateral_a),
            scaled_component(residual, lateral_b),
        ],
        dim=2,
    )

    return ee_normalized, length[:, None]


def _parse_args():
    """Parse the command-line options; every option has a default path/value."""
    parser = argparse.ArgumentParser(description='Visualize per-mode embeddings')
    parser.add_argument('--checkpoint', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder/z2_normalized_ee/checkpoints/checkpoint_epoch_5000.pt')
    parser.add_argument('--dataset_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/processed/train_raw.npz')
    parser.add_argument('--metadata_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/train')
    parser.add_argument('--output_dir', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder/z2_normalized_ee/visualizations/per_mode')
    parser.add_argument('--device', type=str, default='cuda:0')
    return parser.parse_args()


def _collect_samples(model, dataset, metadata, device):
    """Encode every dataset item and pair it with its episode's metadata.

    Item ``i`` is mapped to episode ``i // 2``; even items are labelled
    'reach' and odd items 'carry'.

    Returns:
        Six NumPy arrays, one entry per dataset item: raw EE trajectories,
        embeddings, modes, phases, angles (converted with ``np.degrees``) and
        the second ``current_cp_params`` value (called dist here).

    Raises:
        IndexError: if the metadata has fewer entries than the dataset needs.
    """
    n_items = len(dataset)

    # Fail before the (slow) encoding loop rather than with a bare IndexError
    # halfway through it.
    episodes_needed = (n_items + 1) // 2
    if episodes_needed > len(metadata):
        raise IndexError(
            f"Dataset has {n_items} trajectories ({episodes_needed} episodes) "
            f"but metadata only has {len(metadata)} entries"
        )

    ee_list, z_list = [], []
    modes, phases, angles, dists = [], [], [], []

    with torch.no_grad():
        for i in range(n_items):
            batch = dataset[i]
            # Add a leading batch axis of size 1 for the encoder.
            states = batch.conditions['state'].unsqueeze(0).to(device)

            # Raw EE trajectory (state columns 15:18), kept for the 3D plot.
            ee_list.append(states[0, :, 15:18].cpu().numpy())

            ee_normalized, _ = extract_and_normalize_ee(states)
            z_list.append(model.encode(ee_normalized).cpu().numpy()[0])

            episode = metadata[i // 2]
            cp_params = episode['current_cp_params']
            modes.append(episode['mode'])
            phases.append('reach' if i % 2 == 0 else 'carry')
            angles.append(np.degrees(cp_params[0]))
            dists.append(cp_params[1])

    return (
        np.array(ee_list),
        np.array(z_list),
        np.array(modes),
        np.array(phases),
        np.array(angles),
        np.array(dists),
    )


def _plot_mode_phase(mode, phase, ee_subset, z_subset, angles_subset, dists_subset,
                     z_xlim, z_ylim, demo_colors, save_path):
    """Draw one figure (3D EE trajectories + 2D embeddings) and save it as PNG."""
    # Nominal values shown in the titles: modes 0-3 use dist 0.5, modes 4-7
    # use dist 1.0; within each group the angle steps by 45 degrees.
    base_angle = mode * 45 if mode < 4 else (mode - 4) * 45
    base_dist = 0.5 if mode < 4 else 1.0
    n_colors = len(demo_colors)

    fig = plt.figure(figsize=(14, 5))
    try:
        # Left: EE trajectories (3D); circle marks the start, square the end.
        ax = fig.add_subplot(1, 2, 1, projection='3d')
        for j, (ee, angle, dist) in enumerate(zip(ee_subset, angles_subset, dists_subset)):
            color = demo_colors[j % n_colors]
            ax.plot(ee[:, 0], ee[:, 1], ee[:, 2], '-', color=color, alpha=0.8, linewidth=1.5,
                    label=f'Demo {j}: {angle:.1f}°, d={dist:.2f}')
            ax.scatter(ee[0, 0], ee[0, 1], ee[0, 2], c=[color], s=50, marker='o', edgecolors='black', zorder=5)
            ax.scatter(ee[-1, 0], ee[-1, 1], ee[-1, 2], c=[color], s=50, marker='s', edgecolors='black', zorder=5)

        ax.set_xlabel('X', fontsize=10)
        ax.set_ylabel('Y', fontsize=10)
        ax.set_zlabel('Z', fontsize=10)
        ax.set_title(f'Mode {mode} {phase.upper()} - EE Trajectories (3D)\nBase: {base_angle}°, dist={base_dist}', fontsize=12)
        ax.legend(fontsize=6, loc='upper left')

        # Right: embeddings, coloured to match the trajectory with the same index.
        ax = fig.add_subplot(1, 2, 2)
        for j, (z, angle, dist) in enumerate(zip(z_subset, angles_subset, dists_subset)):
            color = demo_colors[j % n_colors]
            ax.scatter(z[0], z[1], c=[color], s=100, alpha=0.8,
                       edgecolors='black', linewidths=1,
                       label=f'Demo {j}: {angle:.1f}°, d={dist:.2f}')

        ax.set_xlabel('z[0]', fontsize=11)
        ax.set_ylabel('z[1]', fontsize=11)
        # Report the real number of plotted demos instead of assuming 10.
        ax.set_title(f'Mode {mode} {phase.upper()} - Embeddings\n{len(z_subset)} demos, base ~{base_angle}°, d={base_dist}', fontsize=12)
        ax.legend(fontsize=6, loc='upper right')
        ax.grid(True, alpha=0.3)
        # Shared limits so figures for different modes are directly comparable.
        ax.set_xlim(z_xlim)
        ax.set_ylim(z_ylim)

        fig.tight_layout()
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def main():
    """Encode all trajectories and write one PNG per (mode, phase) pair.

    Loads the checkpoint, dataset and episode metadata named on the command
    line, computes an embedding for every dataset item, and saves
    ``mode{M}_{phase}.png`` into the output directory for each of the 8 modes
    and the phases 'reach' / 'carry' that have data.
    """
    args = _parse_args()

    # Fall back to CPU whenever CUDA is unavailable, whatever --device says.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load model
    print(f"Loading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=True)
    model = TrajectoryVAENormalizedEE(
        state_dim=22, action_dim=8, hidden_dim=128,
        num_layers=4, num_heads=4, latent_dim=2, horizon=64
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Load dataset
    print(f"Loading dataset: {args.dataset_path}")
    dataset = FullTrajectoryDataset(
        dataset_path=args.dataset_path,
        horizon_steps=64,
        device='cpu'
    )

    # Load metadata
    # NOTE: allow_pickle=True unpickles arbitrary objects; only point
    # --metadata_path at files from a trusted source.
    meta_file = os.path.join(args.metadata_path, 'train_metadata.npy')
    metadata = np.load(meta_file, allow_pickle=True)
    print(f"Loaded {len(metadata)} episode metadata entries")

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    print("Collecting embeddings and EE trajectories...")
    all_ee, all_z, all_modes, all_phases, all_angles, all_dists = _collect_samples(
        model, dataset, metadata, device
    )

    # An empty dataset would otherwise crash in the min/max reductions below.
    if all_z.size == 0:
        print("No trajectories found in dataset; nothing to plot.")
        return

    # Get global z range (plus a 10% margin) for consistent axes
    z_min = all_z.min(axis=0)
    z_max = all_z.max(axis=0)
    z_margin = (z_max - z_min) * 0.1
    z_xlim = (z_min[0] - z_margin[0], z_max[0] + z_margin[0])
    z_ylim = (z_min[1] - z_margin[1], z_max[1] + z_margin[1])

    # Palette of 10 colours; reused cyclically if a group has more demos.
    demo_colors = plt.cm.viridis(np.linspace(0, 1, 10))

    # Generate plots for each mode and phase
    for mode in range(8):
        for phase in ('reach', 'carry'):
            indices = np.where((all_modes == mode) & (all_phases == phase))[0]

            if len(indices) == 0:
                print(f"No data for Mode {mode} {phase.upper()}")
                continue

            save_path = os.path.join(args.output_dir, f'mode{mode}_{phase}.png')
            _plot_mode_phase(
                mode, phase,
                all_ee[indices], all_z[indices],
                all_angles[indices], all_dists[indices],
                z_xlim, z_ylim, demo_colors, save_path,
            )
            print(f"Saved: {save_path}")

    print("\nDone!")

# Run the visualization only when executed as a script, not on import.
if __name__ == '__main__':
    main()
