"""
Visualize REACH and CARRY embeddings separately from a trained encoder checkpoint.

Usage:
    python visualize_reach_carry.py --checkpoint=/path/to/checkpoint.pt
"""
import os
import sys
import torch
import numpy as np
import argparse

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

sys.path.insert(0, os.path.dirname(__file__))

from trajectory_encoder_normalized_ee import TrajectoryVAENormalizedEE
from full_trajectory_dataset import FullTrajectoryDataset


def build_local_frame_batched(start_pos, end_pos):
    """
    Compute an orthonormal frame attached to each start->end segment in a batch.

    The frame MUST be identical to the one used during training
    (train_normalized_ee_encoder.py), which follows the Gram-Schmidt CP
    definition in utils.py:
    - perp1 = world_up - (world_up · line_dir) * line_dir   (normalized)
    - perp2 = line_dir × perp1                              (normalized)
    If a segment is (numerically) parallel to world_up, world_forward
    (+y) is projected instead to obtain perp1.

    Args:
        start_pos: (B, 3) segment start points.
        end_pos:   (B, 3) segment end points.

    Returns:
        (line_dir, perp1, perp2, dist): three (B, 3) unit vectors and the
        (B,) segment lengths, clamped below at 1e-6.
    """
    batch = start_pos.shape[0]
    opts = dict(device=start_pos.device, dtype=start_pos.dtype)

    span = end_pos - start_pos
    length = torch.clamp(torch.norm(span, dim=1, keepdim=True), min=1e-6)
    axis = span / length

    # Gram-Schmidt: remove the component of world_up along the segment axis.
    up = torch.tensor([0.0, 0.0, 1.0], **opts).unsqueeze(0).expand(batch, -1)
    side = up - (up * axis).sum(dim=1, keepdim=True) * axis
    side_norm = torch.norm(side, dim=1, keepdim=True)
    degenerate = side_norm.squeeze(1) < 1e-6

    if not degenerate.any():
        side = side / side_norm
    else:
        # Near-vertical segments: world_up gives no usable direction, so
        # project world_forward instead, but only for the affected rows.
        fwd = torch.tensor([0.0, 1.0, 0.0], **opts).unsqueeze(0).expand(batch, -1)
        alt = fwd - (fwd * axis).sum(dim=1, keepdim=True) * axis
        alt = alt / torch.clamp(torch.norm(alt, dim=1, keepdim=True), min=1e-6)
        side = torch.where(degenerate.unsqueeze(1), alt, side / torch.clamp(side_norm, min=1e-6))

    normal = torch.cross(axis, side, dim=1)
    normal = normal / torch.norm(normal, dim=1, keepdim=True).clamp(min=1e-6)

    return axis, side, normal, length.squeeze(1)


def extract_and_normalize_ee(states):
    """
    Express end-effector (EE) positions in each trajectory's own frame.

    Must reproduce the training-time normalization exactly, so it reuses
    build_local_frame_batched (Gram-Schmidt frame).

    Args:
        states: (B, T, D) state tensor; columns 15:18 are taken as the EE
            xyz position.

    Returns:
        ee_normalized: (B, T, 3) with channels
            [progress along start->end, offset along perp1, offset along perp2],
            all divided by the start->end distance.
        dist: (B, 1) start->end distance (clamped below at 1e-6).
    """
    ee_xyz = states[:, :, 15:18]
    first = ee_xyz[:, 0, :]
    last = ee_xyz[:, -1, :]

    axis, side, normal, length = build_local_frame_batched(first, last)

    # Broadcast per-trajectory quantities over the time axis.
    origin = first[:, None, :]
    scale = length[:, None, None]

    rel = ee_xyz - origin
    progress = (rel * axis[:, None, :]).sum(dim=2, keepdim=True) / scale

    # Foot of the perpendicular on the start->end line; the residual is the
    # off-line displacement, decomposed along the two perpendicular axes.
    projected = origin + progress * (last - first)[:, None, :]
    residual = ee_xyz - projected

    off_a = (residual * side[:, None, :]).sum(dim=2, keepdim=True) / scale
    off_b = (residual * normal[:, None, :]).sum(dim=2, keepdim=True) / scale

    return torch.cat([progress, off_a, off_b], dim=2), length[:, None]


def _resolve_metadata_file(metadata_path):
    """Map a metadata/dataset path given on the CLI to a train_metadata.npy path."""
    if os.path.isfile(metadata_path) and metadata_path.endswith('.npy'):
        return metadata_path

    if 'processed' in metadata_path:
        # Either a file inside the processed dir (e.g. .../processed/train_raw.npz)
        # or the processed dir itself. The dataset root is the parent of the
        # processed dir in both cases. normpath drops a trailing separator so
        # '.../processed' and '.../processed/' resolve identically.
        if os.path.isdir(metadata_path):
            processed_dir = os.path.normpath(metadata_path)
        else:
            processed_dir = os.path.dirname(metadata_path)
        base_dir = os.path.dirname(processed_dir)
        return os.path.join(base_dir, 'train', 'train_metadata.npy')

    if 'train' in metadata_path:
        return os.path.join(metadata_path, 'train_metadata.npy')

    return os.path.join(metadata_path, 'train', 'train_metadata.npy')


def load_control_point_metadata(metadata_path):
    """Load per-trajectory control-point (CP) metadata for visualization.

    ``metadata_path`` may be:
      * a ``.npy`` file, which is loaded directly;
      * a path containing ``'processed'``, meaning the processed directory or a
        file inside it. It resolves to ``<parent of processed>/train/train_metadata.npy``;
      * a path containing ``'train'``, resolved to ``<path>/train_metadata.npy``;
      * anything else, resolved to ``<path>/train/train_metadata.npy``.

    Each episode entry expands into two consecutive trajectories, REACH then
    CARRY. Both carry the episode's ``mode`` and ``current_cp_params``.

    Returns:
        A dict with keys 'modes' (np.ndarray), 'cp_params' (list),
        'phase_types' (list of 'reach'/'carry') and 'num_trajectories' (int).
        Returns None if the file is missing or cannot be parsed; the reason is
        printed.
    """
    try:
        meta_file = _resolve_metadata_file(metadata_path)

        if not os.path.exists(meta_file):
            print(f"Metadata file not found: {meta_file}")
            return None

        # allow_pickle is required for the object array of per-episode dicts;
        # only load metadata files from trusted sources.
        episode_metadata = np.load(meta_file, allow_pickle=True)

        modes = []
        cp_params = []
        phase_types = []

        for ep_meta in episode_metadata:
            # NOTE(review): REACH and CARRY both use 'current_cp_params';
            # confirm the metadata has no separate per-phase CP entry.
            for phase in ('reach', 'carry'):
                modes.append(ep_meta['mode'])
                cp_params.append(ep_meta['current_cp_params'])
                phase_types.append(phase)

        return {
            'modes': np.array(modes),
            'cp_params': cp_params,
            'phase_types': phase_types,
            'num_trajectories': len(modes)
        }

    except Exception as e:
        # Best-effort loader: the caller treats None as "metadata unavailable".
        print(f"Error loading control point metadata: {e}")
        return None


def collect_embeddings(model, dataset, device):
    """Encode every trajectory of ``dataset`` one at a time.

    Each sample's ``conditions['state']`` tensor gets a batch axis and is moved
    to ``device``. It is normalized with extract_and_normalize_ee and passed to
    ``model.encode``. Gradients are disabled and the model is put in eval mode.

    Returns:
        np.ndarray stacking the per-trajectory latent vectors. The batch axis of
        each encoder output is dropped.
    """
    model.eval()
    embeddings = []

    with torch.no_grad():
        for idx in range(len(dataset)):
            sample = dataset[idx]
            state_seq = sample.conditions['state'].unsqueeze(0).to(device)
            ee_norm = extract_and_normalize_ee(state_seq)[0]
            latent = model.encode(ee_norm)
            embeddings.append(latent.cpu().numpy()[0])

    return np.array(embeddings)


def visualize_phase(z_phase, modes_phase, cp_params_phase, phase_name, unique_modes, mode_color_map, save_path):
    """Save a two-panel figure for one phase (REACH or CARRY).

    Left panel: the first two latent dimensions, colored by mode.
    Right panel: CP angle (converted from radians to degrees) vs. distance
    fraction, using the same mode colors.

    Args:
        z_phase: (N, >=2) latent vectors for this phase.
        modes_phase: length-N mode labels (keys of ``mode_color_map``).
        cp_params_phase: length-N sequence of (angle_rad, distance_fraction, ...).
        phase_name: phase label used in the titles (upper-cased).
        unique_modes: modes listed in the legend.
        mode_color_map: mode -> color.
        save_path: output image path.
    """
    title_prefix = phase_name.upper()
    colors = [mode_color_map[mode] for mode in modes_phase]
    scatter_style = dict(c=colors, s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

    fig, (ax_latent, ax_cp) = plt.subplots(1, 2, figsize=(14, 5))

    # Latent space
    ax_latent.scatter(z_phase[:, 0], z_phase[:, 1], **scatter_style)
    handles = [Patch(facecolor=mode_color_map[mode], label=f'Mode {mode}') for mode in unique_modes]
    ax_latent.legend(handles=handles, loc='upper right', fontsize=8)
    ax_latent.set_title(f'{title_prefix} Embeddings by Mode', fontsize=12)
    ax_latent.set_xlabel('z[0]', fontsize=11)
    ax_latent.set_ylabel('z[1]', fontsize=11)
    ax_latent.grid(True, alpha=0.3)

    # Control-point parameters
    angle_deg = [np.degrees(params[0]) for params in cp_params_phase]
    dist_frac = [params[1] for params in cp_params_phase]
    ax_cp.scatter(angle_deg, dist_frac, **scatter_style)
    ax_cp.set_xlabel('Angle (degrees)', fontsize=11)
    ax_cp.set_ylabel('Distance fraction', fontsize=11)
    ax_cp.set_title(f'{title_prefix} Control Point Parameters', fontsize=12)
    ax_cp.set_xlim(-20, 380)
    ax_cp.set_ylim(0, 1.2)
    ax_cp.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Saved: {save_path}")
    plt.close(fig)


def main():
    """Encode every trajectory and save separate REACH / CARRY embedding plots.

    Steps: parse CLI args, load the encoder checkpoint, and embed every dataset
    trajectory. Then align the embeddings with the per-trajectory metadata
    (mode, CP params, phase) by position, and write
    ``embeddings_reach_final.png`` and ``embeddings_carry_final.png`` into
    ``--output_dir``.

    Exits with status 1 in three cases: the dataset yields no embeddings, the
    metadata cannot be loaded, or the metadata is empty.
    """
    parser = argparse.ArgumentParser(description='Visualize REACH and CARRY embeddings separately')
    parser.add_argument('--checkpoint', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder/z2_normalized_ee/checkpoints/checkpoint_epoch_5000.pt',
                        help='Path to checkpoint')
    parser.add_argument('--dataset_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/processed/train_raw.npz',
                        help='Path to dataset')
    parser.add_argument('--metadata_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/train',
                        help='Path to metadata')
    parser.add_argument('--output_dir', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder/z2_normalized_ee/visualizations',
                        help='Output directory')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device')

    args = parser.parse_args()

    # Fall back to CPU whenever CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load checkpoint
    print(f"Loading checkpoint: {args.checkpoint}")
    checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=True)

    # Architecture hyperparameters are hard-coded here; they must match the
    # trained checkpoint or load_state_dict will fail.
    model = TrajectoryVAENormalizedEE(
        state_dim=22, action_dim=8, hidden_dim=128,
        num_layers=4, num_heads=4, latent_dim=2, horizon=64
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Load dataset
    print(f"Loading dataset: {args.dataset_path}")
    dataset = FullTrajectoryDataset(
        dataset_path=args.dataset_path,
        horizon_steps=64,
        device='cpu'
    )
    print(f"Loaded {len(dataset)} trajectories")

    # Collect embeddings
    print("Collecting embeddings...")
    all_z = collect_embeddings(model, dataset, device)
    if all_z.size == 0:
        # all_z.min() below would raise on an empty array.
        print("Error: no embeddings were collected (empty dataset?)")
        sys.exit(1)
    print(f"z shape: {all_z.shape}")
    print(f"z range: [{all_z.min():.4f}, {all_z.max():.4f}]")

    # Load metadata
    cp_metadata = load_control_point_metadata(args.metadata_path)
    if cp_metadata is None:
        print("Error: Could not load metadata")
        sys.exit(1)

    # Embeddings and metadata are matched purely by position, so both must be
    # cut to the same length. Otherwise the boolean phase masks below would not
    # line up with all_z and numpy indexing would raise an IndexError.
    num_meta = cp_metadata['num_trajectories']
    n = min(len(all_z), num_meta)
    if n == 0:
        print("Error: metadata contains no trajectories")
        sys.exit(1)
    if len(all_z) != num_meta:
        print(f"Warning: {len(all_z)} embeddings but {num_meta} metadata entries; "
              f"using the first {n} of each")
    all_z = all_z[:n]
    modes = cp_metadata['modes'][:n]
    cp_params = cp_metadata['cp_params'][:n]
    phase_types = np.array(cp_metadata['phase_types'][:n])

    # Separate REACH and CARRY
    reach_mask = phase_types == 'reach'
    carry_mask = phase_types == 'carry'

    z_reach = all_z[reach_mask]
    z_carry = all_z[carry_mask]
    modes_reach = modes[reach_mask]
    modes_carry = modes[carry_mask]
    cp_params_reach = [p for p, keep in zip(cp_params, reach_mask) if keep]
    cp_params_carry = [p for p, keep in zip(cp_params, carry_mask) if keep]

    # One color per mode, shared across both figures so they are comparable.
    unique_modes = np.unique(modes)
    mode_colors = plt.cm.tab10(np.linspace(0, 1, len(unique_modes)))
    mode_color_map = {m: mode_colors[i] for i, m in enumerate(unique_modes)}

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Generate visualizations
    visualize_phase(z_reach, modes_reach, cp_params_reach, 'reach', unique_modes, mode_color_map,
                    os.path.join(args.output_dir, 'embeddings_reach_final.png'))

    visualize_phase(z_carry, modes_carry, cp_params_carry, 'carry', unique_modes, mode_color_map,
                    os.path.join(args.output_dir, 'embeddings_carry_final.png'))

    print("\nDone!")


if __name__ == '__main__':
    # Script entry point; importing this module does not run the pipeline.
    main()
