"""
Training script for pick_place EE trajectory encoder.

This encoder ONLY uses the normalized end-effector trajectory (3D: progress, perp1, perp2).
This ensures REACH and CARRY trajectories with the same CP params produce identical embeddings.

Key insight: the 4th subplot in dataset_visualization.py shows that, after normalizing to
trajectory-relative coordinates, REACH and CARRY trajectories with the same CP have identical shapes.
The encoder should learn this invariance: the same normalized shape maps to the same z.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

from ee_trajectory_encoder import EETrajectoryVAE
from soft_dtw import SoftDTW
from full_trajectory_dataset import FullTrajectoryDataset


def build_local_frame_batched(start_pos, end_pos):
    """
    Build a trajectory-relative orthonormal frame for a batch of trajectories.

    Args:
        start_pos: (B, 3) tensor of start positions.
        end_pos: (B, 3) tensor of end positions.

    Returns:
        Tuple ``(line_dir, perp1, perp2, dist)``:
        - line_dir: (B, 3) unit vector pointing from start to end.
        - perp1: (B, 3) unit vector orthogonal to line_dir, obtained by
          removing the line component from world +Z; rows whose line is
          (nearly) parallel to +Z use world +Y instead.
        - perp2: (B, 3) unit vector ``line_dir x perp1``.
        - dist: (B,) start-to-end distance, clamped to at least 1e-6.
    """
    n = start_pos.shape[0]
    opts = dict(device=start_pos.device, dtype=start_pos.dtype)

    def _reject(axis, direction):
        # Remove the component of `axis` along `direction` (one Gram-Schmidt step).
        along = (axis * direction).sum(dim=1, keepdim=True)
        return axis - along * direction

    segment = end_pos - start_pos
    length = segment.norm(dim=1, keepdim=True).clamp(min=1e-6)
    direction = segment / length

    up_axis = torch.tensor([0.0, 0.0, 1.0], **opts).unsqueeze(0).expand(n, -1)
    primary = _reject(up_axis, direction)
    primary_len = primary.norm(dim=1, keepdim=True)

    # A line (almost) parallel to world up leaves no usable up-based
    # perpendicular, so those rows fall back to world forward (+Y).
    degenerate = primary_len < 1e-6
    fwd_axis = torch.tensor([0.0, 1.0, 0.0], **opts).unsqueeze(0).expand(n, -1)
    fallback = _reject(fwd_axis, direction)
    fallback = fallback / fallback.norm(dim=1, keepdim=True).clamp(min=1e-6)
    ortho1 = torch.where(degenerate, fallback, primary / primary_len.clamp(min=1e-6))

    ortho2 = torch.cross(direction, ortho1, dim=1)
    ortho2 = ortho2 / ortho2.norm(dim=1, keepdim=True).clamp(min=1e-6)

    return direction, ortho1, ortho2, length.squeeze(1)


def normalize_ee_trajectory(ee_pos):
    """
    Express each EE trajectory in its own start/end-relative frame.

    Args:
        ee_pos: (B, T, 3) tensor of end-effector positions.

    Returns:
        (B, T, 3) tensor whose channels are
        - progress: projection onto the start->end line divided by the
          (clamped) start-end distance, i.e. 0 at the start and 1 at the end;
        - perp1 offset: component along the up-derived perpendicular
          (roughly vertical), divided by the same distance;
        - perp2 offset: component along the remaining perpendicular
          (roughly horizontal), divided by the same distance.

    Dividing every channel by path length is what makes REACH and CARRY
    trajectories with the same CP map to identical normalized shapes.
    """
    # Unpacking also enforces a rank-3 input.
    _batch, _steps, _coords = ee_pos.shape

    first = ee_pos[:, 0, :]
    last = ee_pos[:, -1, :]
    axis_dir, axis_p1, axis_p2, length = build_local_frame_batched(first, last)

    # Insert a time axis so per-trajectory quantities broadcast over T.
    origin = first.unsqueeze(1)
    scale = length.unsqueeze(1).unsqueeze(2)

    rel = ee_pos - origin
    progress = torch.sum(rel * axis_dir.unsqueeze(1), dim=2, keepdim=True) / scale

    # Foot of the perpendicular on the start->end segment, per time step.
    foot = origin + progress * (last - first).unsqueeze(1)
    residual = ee_pos - foot

    p1 = torch.sum(residual * axis_p1.unsqueeze(1), dim=2, keepdim=True) / scale
    p2 = torch.sum(residual * axis_p2.unsqueeze(1), dim=2, keepdim=True) / scale
    return torch.cat([progress, p1, p2], dim=2)


def extract_and_normalize_ee(states):
    """
    Pull EE positions out of full states and normalize them.

    Args:
        states: (B, T, 22) full state tensor; columns 15:18 hold the EE xyz.

    Returns:
        (B, T, 3) normalized trajectory [progress, perp1, perp2], as produced
        by ``normalize_ee_trajectory``.
    """
    ee_columns = slice(15, 18)
    return normalize_ee_trajectory(states[:, :, ee_columns])


def compute_reconstruction_loss(pred_ee, gt_ee):
    """
    Mean-squared error between predicted and target EE trajectories.

    The squared error is averaged over every element of the tensors.
    """
    return F.mse_loss(input=pred_ee, target=gt_ee, reduction='mean')


def compute_kl_loss(z_mean, z_logvar):
    """
    KL divergence between a diagonal Gaussian posterior and N(0, I).

    Args:
        z_mean: (B, D) posterior means.
        z_logvar: (B, D) posterior log-variances.

    Returns:
        Scalar tensor: the per-sample KL (summed over latent dims) averaged
        over the batch.
    """
    integrand = 1 + z_logvar - z_mean.pow(2) - z_logvar.exp()
    per_sample = integrand.sum(dim=1) * -0.5
    return per_sample.mean()


def compute_dtw_loss(z, ee_traj, gamma=0.00001):
    """
    Align latent distances with soft-DTW distances between EE trajectories.

    Each sample is paired with the sample one position earlier in the batch
    (cyclically, via ``torch.roll``). For every pair, the Euclidean latent
    distance and the soft-DTW distance (project ``SoftDTW`` with
    ``normalize=True``) are each divided by their batch maximum. The loss is
    the MSE between the two scaled vectors.

    Args:
        z: (B, D) latent codes.
        ee_traj: normalized EE trajectories, presumably (B, T, 3); confirm
            against the SoftDTW input contract.
        gamma: soft-DTW smoothing parameter.

    Returns:
        Scalar loss, or a zero tensor when the batch has fewer than 2 samples.
    """
    if z.shape[0] < 2:
        # No pairs can be formed from a single sample.
        return torch.tensor(0.0, device=z.device, requires_grad=True)

    z_partner = torch.roll(z, 1, 0)
    traj_partner = torch.roll(ee_traj, 1, 0)

    # Small epsilon keeps sqrt differentiable when two codes coincide.
    latent_gap = torch.sqrt(((z - z_partner) ** 2).sum(dim=1) + 1e-8)
    traj_gap = SoftDTW(gamma=gamma, normalize=True)(ee_traj, traj_partner)

    latent_scaled = latent_gap / (latent_gap.max() + 1e-8)
    traj_scaled = traj_gap / (traj_gap.max() + 1e-8)
    return F.mse_loss(latent_scaled, traj_scaled)


def _metadata_candidates(metadata_path):
    """Return candidate metadata files: the path-heuristic choice first, then fallbacks."""
    if os.path.isfile(metadata_path) and metadata_path.endswith('.npy'):
        return [metadata_path]

    # Substring heuristics (kept for backward compatibility). They can misfire
    # on paths that merely contain the word, e.g. ".../constrained/...".
    if 'processed' in metadata_path:
        base_dir = os.path.dirname(os.path.dirname(metadata_path))
        primary = os.path.join(base_dir, 'train', 'train_metadata.npy')
    elif 'train' in metadata_path:
        primary = os.path.join(metadata_path, 'train_metadata.npy')
    else:
        primary = os.path.join(metadata_path, 'train', 'train_metadata.npy')

    fallbacks = [
        os.path.join(metadata_path, 'train_metadata.npy'),
        os.path.join(metadata_path, 'train', 'train_metadata.npy'),
    ]
    return [primary] + [f for f in fallbacks if f != primary]


def load_control_point_metadata(metadata_path):
    """
    Load per-trajectory control-point (CP) metadata for visualization.

    ``metadata_path`` may be:
    - a ``.npy`` file, used directly;
    - a path inside a ``processed`` directory (e.g. ``.../processed/train_raw.npz``),
      in which case ``<grandparent>/train/train_metadata.npy`` is tried;
    - a ``train`` directory that holds ``train_metadata.npy``;
    - a dataset root that holds ``train/train_metadata.npy``.

    The location picked by the substring heuristics is tried first. The other
    directory layouts are then tried as fallbacks, so a path that merely
    contains "train" or "processed" still resolves.

    Each episode record contributes two entries, REACH then CARRY, both with
    the episode's ``mode`` and ``current_cp_params``.

    Returns:
        dict with keys ``modes`` (np.ndarray), ``cp_params`` (list),
        ``phase_types`` (list of 'reach'/'carry') and ``num_trajectories``;
        or None if no metadata file is found or loading fails (visualization
        is optional, so errors are reported and swallowed).
    """
    try:
        candidates = _metadata_candidates(metadata_path)
        meta_file = next((c for c in candidates if os.path.exists(c)), None)
        if meta_file is None:
            print(f"Metadata file not found: {candidates[0]}")
            return None

        # NOTE(review): allow_pickle=True unpickles the file contents; only
        # load metadata produced by a trusted pipeline.
        episode_metadata = np.load(meta_file, allow_pickle=True)

        modes = []
        cp_params = []
        phase_types = []
        for ep_meta in episode_metadata:
            # Each episode yields a REACH trajectory followed by a CARRY one.
            for phase in ('reach', 'carry'):
                modes.append(ep_meta['mode'])
                cp_params.append(ep_meta['current_cp_params'])
                phase_types.append(phase)

        return {
            'modes': np.array(modes),
            'cp_params': cp_params,
            'phase_types': phase_types,
            'num_trajectories': len(modes)
        }

    except Exception as e:
        # Best-effort: missing/corrupt metadata only disables colored plots.
        print(f"Error loading control point metadata: {e}")
        return None


def visualize_embeddings(z, epoch, cp_metadata=None, save_path=None):
    """
    Visualize embeddings colored by mode and phase type, plus the CP parameters.

    Args:
        z: (N, D) array of embeddings.
            - D > 2 is reduced to 2-D with t-SNE when N > 1; a single point
              is shown via its first two coordinates.
            - D == 1 is padded with a zero column so it can still be drawn
              as a 2-D scatter.
        epoch: epoch number shown in the mode-plot title.
        cp_metadata: optional dict from ``load_control_point_metadata``. It is
            used only when it has at least N entries, and then only the first
            N are used.
        save_path: if given, the figure is also saved to this path.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    z = np.asarray(z)
    n_points, latent_dim = z.shape
    if latent_dim > 2 and n_points > 1:
        from sklearn.manifold import TSNE
        # t-SNE requires 0 < perplexity < n_samples.
        z_vis = TSNE(n_components=2, random_state=42, perplexity=min(30, n_points - 1)).fit_transform(z)
        title_suffix = f' (t-SNE from {latent_dim}D)'
    elif latent_dim > 2:
        # A single point cannot be embedded with t-SNE; plot its first two coordinates.
        z_vis = z[:, :2]
        title_suffix = f' (first 2 of {latent_dim}D)'
    elif latent_dim == 1:
        # Pad 1-D latents so the z_vis[:, 1] indexing below is valid.
        z_vis = np.concatenate([z, np.zeros_like(z)], axis=1)
        title_suffix = ' (1D, y=0)'
    else:
        z_vis = z
        title_suffix = ''

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    ax_mode = axes[0]
    ax_phase = axes[1]
    ax_params = axes[2]

    if cp_metadata is not None and len(cp_metadata['modes']) >= len(z_vis):
        modes = cp_metadata['modes'][:len(z_vis)]
        cp_params = cp_metadata['cp_params'][:len(z_vis)]
        phase_types = cp_metadata['phase_types'][:len(z_vis)]

        unique_modes = np.unique(modes)
        mode_colors = plt.cm.tab10(np.linspace(0, 1, len(unique_modes)))
        mode_color_map = {m: mode_colors[i] for i, m in enumerate(unique_modes)}
        point_colors_mode = [mode_color_map[m] for m in modes]

        ax_mode.scatter(z_vis[:, 0], z_vis[:, 1], c=point_colors_mode,
                        s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

        from matplotlib.patches import Patch
        legend_elements = [Patch(facecolor=mode_color_map[m], label=f'Mode {m}') for m in unique_modes]
        ax_mode.legend(handles=legend_elements, loc='upper right', fontsize=8)
        ax_mode.set_title(f'Embeddings by Mode (Epoch {epoch}){title_suffix}', fontsize=11)

        # Phase type coloring - REACH and CARRY should overlap!
        phase_color_map = {'reach': 'blue', 'carry': 'red'}
        point_colors_phase = [phase_color_map[p] for p in phase_types]

        ax_phase.scatter(z_vis[:, 0], z_vis[:, 1], c=point_colors_phase,
                         s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

        legend_elements_phase = [Patch(facecolor='blue', label='REACH'),
                                 Patch(facecolor='red', label='CARRY')]
        ax_phase.legend(handles=legend_elements_phase, loc='upper right', fontsize=10)
        ax_phase.set_title(f'Embeddings by Phase Type{title_suffix}\n(REACH & CARRY should overlap!)', fontsize=11)

        # CP params are plotted as (angle in radians -> degrees, distance fraction).
        angles = [np.degrees(p[0]) for p in cp_params]
        dists = [p[1] for p in cp_params]

        ax_params.scatter(angles, dists, c=point_colors_mode, s=50, alpha=0.7,
                          edgecolors='black', linewidths=0.5)
        ax_params.set_xlabel('Angle (degrees)', fontsize=11)
        ax_params.set_ylabel('Distance fraction', fontsize=11)
        ax_params.set_title('Control Point Parameters', fontsize=11)
        ax_params.set_xlim(-20, 380)
        ax_params.set_ylim(0, 1.2)
        ax_params.grid(True, alpha=0.3)
    else:
        # Without usable metadata, color points by their dataset index.
        scatter = ax_mode.scatter(z_vis[:, 0], z_vis[:, 1], c=np.arange(len(z_vis)),
                                  cmap='viridis', s=50, alpha=0.7)
        plt.colorbar(scatter, ax=ax_mode, label='Trajectory Index')

    ax_mode.set_xlabel('z[0]', fontsize=11)
    ax_mode.set_ylabel('z[1]', fontsize=11)
    ax_mode.grid(True, alpha=0.3)

    ax_phase.set_xlabel('z[0]', fontsize=11)
    ax_phase.set_ylabel('z[1]', fontsize=11)
    ax_phase.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def train_epoch(model, dataloader, optimizer, device, epoch, log_freq=50, kl_weight=0.001, dtw_weight=0.1,
                dtw_gamma=0.00001, max_grad_norm=1.0):
    """
    Train for one epoch using only normalized EE trajectories.

    Each batch must expose ``batch.conditions['state']``, a (B, T, 22) state
    tensor. The EE columns are normalized and fed to the model. The model
    output dict must contain ``z_mean``, ``z_logvar`` and ``pred_ee_traj``.

    Loss = recon + kl_weight * KL + dtw_weight * DTW, where the DTW term is
    computed on ``z_mean``.

    Args:
        model: the VAE being trained.
        dataloader: iterable of batches, as described above.
        optimizer: optimizer stepping ``model.parameters()``.
        device: device the states are moved to.
        epoch: epoch index, used for the progress bar and wandb logs.
        log_freq: log per-batch losses to wandb every ``log_freq`` batches.
        kl_weight: weight of the KL term.
        dtw_weight: weight of the DTW alignment term.
        dtw_gamma: soft-DTW smoothing parameter.
        max_grad_norm: gradient-norm clipping threshold.

    Returns:
        Tuple of epoch-average (total, recon, kl, dtw) losses; all 0 if the
        dataloader is empty.
    """
    model.train()

    total_loss = 0.0
    total_recon_loss = 0.0
    total_kl_loss = 0.0
    total_dtw_loss = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, batch in enumerate(pbar):
        states = batch.conditions['state'].to(device)  # (B, T, 22)

        # Extract and normalize EE trajectory
        ee_normalized = extract_and_normalize_ee(states)  # (B, T, 3)

        # Forward pass with ONLY the normalized EE trajectory
        outputs = model(ee_normalized)
        z_mean = outputs['z_mean']

        recon_loss = compute_reconstruction_loss(outputs['pred_ee_traj'], ee_normalized)
        kl_loss = compute_kl_loss(z_mean, outputs['z_logvar'])
        dtw_loss = compute_dtw_loss(z_mean, ee_normalized, gamma=dtw_gamma)

        loss = recon_loss + kl_weight * kl_loss + dtw_weight * dtw_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Each .item() forces a device->host sync, so read every scalar once.
        loss_val = loss.item()
        recon_val = recon_loss.item()
        kl_val = kl_loss.item()
        dtw_val = dtw_loss.item()

        total_loss += loss_val
        total_recon_loss += recon_val
        total_kl_loss += kl_val
        total_dtw_loss += dtw_val
        num_batches += 1

        pbar.set_postfix({
            'total': f'{loss_val:.4f}',
            'recon': f'{recon_val:.4f}',
            'kl': f'{kl_val:.6f}',
            'dtw': f'{dtw_val:.4f}'
        })

        if batch_idx % log_freq == 0:
            wandb.log({
                'batch_total_loss': loss_val,
                'batch_recon_loss': recon_val,
                'batch_kl_loss': kl_val,
                'batch_dtw_loss': dtw_val,
                'epoch': epoch,
                'batch': batch_idx
            })

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    avg_recon_loss = total_recon_loss / num_batches if num_batches > 0 else 0
    avg_kl_loss = total_kl_loss / num_batches if num_batches > 0 else 0
    avg_dtw_loss = total_dtw_loss / num_batches if num_batches > 0 else 0

    return avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss


def collect_embeddings(model, dataset, device, batch_size=256):
    """
    Encode every trajectory in ``dataset`` and return the stacked embeddings.

    The dataset is iterated in order (no shuffling). Each batch's
    ``conditions['state']`` is normalized to EE coordinates and passed to
    ``model.encode``. That output is presumably the latent mean (confirm
    against EETrajectoryVAE).

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if encoding fails.

    Args:
        model: model exposing ``encode(ee_normalized)``.
        dataset: dataset whose batches expose ``conditions['state']``.
        device: torch device (or device string) to run on.
        batch_size: encoding batch size.

    Returns:
        np.ndarray of embeddings concatenated along axis 0.
    """
    was_training = model.training
    model.eval()

    # Page-locked host memory only speeds up host->CUDA copies.
    use_pinned = torch.device(device).type == 'cuda'
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=use_pinned
    )

    all_z = []
    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Collecting embeddings", leave=False):
                states = batch.conditions['state'].to(device)
                ee_normalized = extract_and_normalize_ee(states)
                z = model.encode(ee_normalized)
                all_z.append(z.cpu().numpy())
    finally:
        model.train(was_training)

    return np.concatenate(all_z, axis=0)


def _save_checkpoint(path, epoch, model, optimizer, loss, config):
    """Write a checkpoint (epoch, model/optimizer state, loss, config) to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'config': config
    }, path)


def _render_embeddings(model, dataset, device, batch_size, epoch, cp_metadata, save_path):
    """Encode the whole dataset and return its embedding figure, also saved to ``save_path``."""
    all_z = collect_embeddings(model, dataset, device, batch_size=batch_size)
    return visualize_embeddings(all_z, epoch, cp_metadata=cp_metadata, save_path=save_path)


def main():
    """
    Command-line entry point: parse arguments, train the EE trajectory VAE,
    and log losses, checkpoints and embedding plots.

    Outputs go under ``<save_dir>/z<latent_dim>/``:
    - ``checkpoints/best_model.pt`` is rewritten whenever the epoch loss improves.
    - ``checkpoints/checkpoint_epoch_<N>.pt`` is written every ``save_freq`` epochs.
    - ``visualizations/`` holds embedding plots from epoch 1, every
      ``vis_freq`` epochs, and after training.
    """
    parser = argparse.ArgumentParser(description='Train Pick-and-Place EE Trajectory Encoder')
    # (flag, type, default) in the order they appear in --help.
    cli_options = (
        ('--horizon', int, 64),
        ('--latent_dim', int, 2),
        ('--hidden_dim', int, 128),
        ('--num_layers', int, 4),
        ('--num_heads', int, 4),
        ('--batch_size', int, 32),
        ('--num_epochs', int, 5000),
        ('--lr', float, 1e-4),
        ('--kl_weight', float, 0.001),
        ('--dtw_weight', float, 1.0),
        ('--vis_freq', int, 100),
        ('--save_freq', int, 500),
        ('--dataset_path', str,
         '/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/processed/train_raw.npz'),
        ('--metadata_path', str,
         '/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/train'),
        ('--save_dir', str,
         '/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder'),
        ('--device', str, 'cuda:0'),
        ('--wandb_project', str, 'pick-place-ee-encoder'),
    )
    for flag, arg_type, default in cli_options:
        parser.add_argument(flag, type=arg_type, default=default)

    args = parser.parse_args()

    config = {
        'dataset_path': args.dataset_path,
        'metadata_path': args.metadata_path,
        'ee_dim': 3,  # Only 3D normalized EE trajectory
        'horizon': args.horizon,
        'latent_dim': args.latent_dim,
        'hidden_dim': args.hidden_dim,
        'num_layers': args.num_layers,
        'num_heads': args.num_heads,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.lr,
        'kl_weight': args.kl_weight,
        'dtw_weight': args.dtw_weight,
        'device': args.device if torch.cuda.is_available() else 'cpu',
        'log_freq': 10,
        'vis_freq': args.vis_freq,
        'save_freq': args.save_freq,
        'save_dir': args.save_dir
    }

    os.makedirs(config['save_dir'], exist_ok=True)
    latent_dir = os.path.join(config['save_dir'], f'z{config["latent_dim"]}')
    checkpoint_dir = os.path.join(latent_dir, 'checkpoints')
    visualization_dir = os.path.join(latent_dir, 'visualizations')
    for output_dir in (checkpoint_dir, visualization_dir):
        os.makedirs(output_dir, exist_ok=True)

    wandb.init(
        project=args.wandb_project,
        name=f'ee_vae_z{config["latent_dim"]}_h{config["hidden_dim"]}',
        config=config
    )

    device = torch.device(config['device'])
    print(f"Using device: {device}")

    print(f"Loading dataset from {config['dataset_path']}")
    dataset = FullTrajectoryDataset(
        dataset_path=config['dataset_path'],
        horizon_steps=config['horizon'],
        device='cpu',
        max_n_episodes=None
    )
    print(f"Dataset size: {len(dataset)} trajectories")

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0, pin_memory=True
    )

    # EE-only VAE: consumes the 3-D normalized trajectory, not the 22-D state.
    model = EETrajectoryVAE(
        ee_dim=config['ee_dim'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        latent_dim=config['latent_dim'],
        horizon=config['horizon']
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'], eta_min=1e-6)

    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model parameters: {n_params:,}")

    cp_metadata = load_control_point_metadata(config['metadata_path'])
    if cp_metadata is not None:
        print(f"Loaded CP metadata for {cp_metadata['num_trajectories']} trajectories")

    best_loss = float('inf')
    epoch_pbar = tqdm(range(1, config['num_epochs'] + 1), desc="Training", position=0)

    for epoch in epoch_pbar:
        avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss = train_epoch(
            model, dataloader, optimizer, device, epoch,
            log_freq=config['log_freq'],
            kl_weight=config['kl_weight'],
            dtw_weight=config['dtw_weight']
        )

        wandb.log({
            'epoch': epoch,
            'avg_total_loss': avg_loss,
            'avg_recon_loss': avg_recon_loss,
            'avg_kl_loss': avg_kl_loss,
            'avg_dtw_loss': avg_dtw_loss,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        epoch_pbar.set_postfix({
            'total': f'{avg_loss:.4f}',
            'recon': f'{avg_recon_loss:.4f}',
            'kl': f'{avg_kl_loss:.6f}',
            'dtw': f'{avg_dtw_loss:.4f}'
        })

        if epoch % config['vis_freq'] == 0 or epoch == 1:
            fig = _render_embeddings(
                model, dataset, device, config['batch_size'], epoch, cp_metadata,
                os.path.join(visualization_dir, f'embeddings_epoch_{epoch}.png')
            )
            wandb.log({'embeddings': wandb.Image(fig), 'epoch': epoch})
            plt.close(fig)

        if avg_loss < best_loss:
            best_loss = avg_loss
            _save_checkpoint(os.path.join(checkpoint_dir, 'best_model.pt'),
                             epoch, model, optimizer, avg_loss, config)

        if epoch % config['save_freq'] == 0:
            _save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'),
                             epoch, model, optimizer, avg_loss, config)

        scheduler.step()

    epoch_pbar.close()

    print("\nGenerating final embedding visualization...")
    fig = _render_embeddings(
        model, dataset, device, config['batch_size'], config['num_epochs'], cp_metadata,
        os.path.join(visualization_dir, 'embeddings_final.png')
    )
    wandb.log({'final_embeddings': wandb.Image(fig)})
    plt.close(fig)

    wandb.finish()
    print("Training complete!")


if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    main()
