"""
Training script for pick_place trajectory encoder with endpoint conditioning.

Key changes from EE-only encoder:
1. Encodes full state (22D) + actions (8D) trajectory
2. Uses s0-relative coordinates
3. Decoder conditioned on z + endpoint_direction (3D) + time
4. This forces z to encode trajectory shape (HOW), while endpoint_direction provides WHERE

The endpoint_direction is a normalized vector from start_ee to end_ee, giving minimal
context about trajectory destination without allowing decoder to overfit to full s0.
"""
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

from trajectory_encoder_endpoint import TrajectoryVAEWithEndpoint
from soft_dtw import SoftDTW
from full_trajectory_dataset import FullTrajectoryDataset


def make_trajectory_relative(states, actions):
    """
    Re-express a batch of trajectories relative to their first state (s0).

    Subtracting the s0 values strips absolute positioning so that downstream
    consumers see trajectory shape rather than absolute location.

    State layout:  joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
    Action layout: joint_pos(7) + gripper_open(1)

    Args:
        states: (B, T, 22) - absolute states
        actions: (B, T, 8) - absolute actions

    Returns:
        states_rel: (B, T, 22) - states with joint_pos, joint_vel and gripper
            xyz offset by their s0 values; gripper_open (index 14) and the
            remaining pose components (18:22) are copied unchanged
        actions_rel: (B, T, 8) - actions with joint_pos offset by the s0 joint
            positions; gripper_open (index 7) is copied unchanged

    The inputs are not modified; both outputs are fresh clones.
    """
    # First timestep of every trajectory, kept as (B, 1, 22) so it broadcasts over T.
    anchor = states[:, 0:1, :]

    states_rel = states.clone()
    actions_rel = actions.clone()

    # State channels that become s0-relative: joint_pos, joint_vel, gripper xyz.
    relative_spans = (slice(0, 7), slice(7, 14), slice(15, 18))
    for span in relative_spans:
        states_rel[:, :, span] = states[:, :, span] - anchor[:, :, span]

    # Action joint targets are offset by the s0 joint positions taken from the state.
    joint_span = slice(0, 7)
    actions_rel[:, :, joint_span] = actions[:, :, joint_span] - anchor[:, :, joint_span]

    return states_rel, actions_rel


def compute_endpoint_direction(states):
    """
    Return the unit-length direction from the first to the last EE position.

    Only the direction of travel is exposed (not the magnitude and not the
    full s0), which gives minimal context about WHERE the trajectory goes.

    Args:
        states: (B, T, 22) - absolute states (NOT s0-relative); the EE xyz
            position is read from channels 15:18

    Returns:
        direction: (B, 3) - displacement divided by (its L2 norm + 1e-8), so a
            trajectory whose start and end coincide yields a zero vector
    """
    ee_xyz = states[:, :, 15:18]                  # (B, T, 3)
    displacement = ee_xyz[:, -1] - ee_xyz[:, 0]   # (B, 3)

    # The epsilon keeps the division finite when the displacement is zero.
    length = displacement.norm(dim=1, keepdim=True)
    return displacement / (length + 1e-8)


def extract_relative_ee(states):
    """
    Slice the EE xyz channels out of a batch of states.

    Used for the DTW loss. When the input is s0-relative, the returned
    positions start at the origin.

    states: (B, T, 22) - s0-relative states
    Returns: (B, T, 3) - channels 15:18 of every state
    """
    ee_channels = slice(15, 18)
    return states[:, :, ee_channels]


def compute_reconstruction_loss(pred_states, pred_actions, gt_states, gt_actions):
    """
    Sum of the action MSE and (when states are predicted) the state MSE.

    pred_states: (B, T-1, state_dim) - predicted s1, s2, ..., s_{T-1}, or None
    pred_actions: (B, T, action_dim) - predicted a0, a1, ..., a_{T-1}
    gt_states: (B, T, state_dim) - ground truth states (s0-relative)
    gt_actions: (B, T, action_dim) - ground truth actions (s0-relative)

    Returns a scalar tensor: mse(pred_actions, gt_actions) plus
    mse(pred_states, gt_states[:, 1:]); the state term is 0 when
    pred_states is None.
    """
    action_term = F.mse_loss(pred_actions, gt_actions)

    if pred_states is None:
        # No state prediction: contribute an explicit zero on the same device.
        state_term = torch.tensor(0.0, device=pred_actions.device)
    else:
        # The first ground-truth state is dropped: predictions start at s1.
        state_term = F.mse_loss(pred_states, gt_states[:, 1:, :])

    return action_term + state_term


def compute_kl_loss(z_mean, z_logvar):
    """
    KL divergence KL(q(z|x) || p(z)) with p(z) = N(0, I), averaged over the batch.

    z_mean, z_logvar: (B, latent_dim) - mean and log-variance of q(z|x)
    Returns: scalar tensor - per-sample KL summed over dim 1, then batch mean
    """
    per_dim = 1 + z_logvar - z_mean ** 2 - z_logvar.exp()
    per_sample = -0.5 * per_dim.sum(dim=1)
    return per_sample.mean()


def compute_dtw_loss(z, states_rel, gamma=0.00001):
    """
    DTW loss: pairwise distances in z-space should match EE trajectory distances.

    Each sample is paired with its neighbour in the batch (the batch rolled by
    one along dim 0). For every pair, the L2 distance between latents is
    compared with the soft-DTW distance between the relative EE paths, after
    each set of distances has been divided by its own batch maximum.

    z: (B, latent_dim) - latent embeddings
    states_rel: (B, T, state_dim) - s0-relative trajectory states
    gamma: forwarded to SoftDTW

    Returns: scalar tensor - MSE between the two max-scaled distance vectors,
        or a constant 0.0 (with requires_grad=True) when B < 2.
    """
    if z.shape[0] < 2:
        # A single sample has no partner to compare against.
        return torch.tensor(0.0, device=z.device, requires_grad=True)

    # Relative EE paths (start at origin), (B, T, 3).
    ee_paths = extract_relative_ee(states_rel)

    # Partner of sample i is sample i-1 (wrapping around).
    z_partner = torch.roll(z, 1, 0)
    ee_partner = torch.roll(ee_paths, 1, 0)

    # L2 distance between paired latents; epsilon is added under the sqrt.
    latent_gap = torch.sqrt(torch.sum((z - z_partner) ** 2, dim=1) + 1e-8)

    # Soft-DTW distance between paired EE paths.
    soft_dtw = SoftDTW(gamma=gamma, normalize=True)
    path_gap = soft_dtw(ee_paths, ee_partner)

    # Scale each distance vector by its batch maximum so the two are comparable.
    latent_gap_scaled = latent_gap / (torch.max(latent_gap) + 1e-8)
    path_gap_scaled = path_gap / (torch.max(path_gap) + 1e-8)

    return F.mse_loss(latent_gap_scaled, path_gap_scaled)


def load_control_point_metadata(metadata_path):
    """
    Load control-point (CP) metadata for visualization.

    The metadata file is resolved from ``metadata_path`` as follows:
      * an existing file ending in '.npy' is used directly;
      * a path containing 'processed' maps to
        <two directory levels up>/train/train_metadata.npy;
      * a path containing 'train' maps to <path>/train_metadata.npy;
      * anything else maps to <path>/train/train_metadata.npy.

    Every episode record contributes two entries (a 'reach' and a 'carry'
    trajectory) that share the episode's 'mode' and 'current_cp_params'.

    Returns:
        dict with 'modes' (np.ndarray), 'cp_params' (list), 'phase_types'
        (list of 'reach'/'carry') and 'num_trajectories' (int), or None when
        the file is missing or any error occurs while loading (the error is
        printed together with a traceback).
    """
    try:
        is_npy_file = os.path.isfile(metadata_path) and metadata_path.endswith('.npy')
        if is_npy_file:
            meta_file = metadata_path
        elif 'processed' in metadata_path:
            dataset_root = os.path.dirname(os.path.dirname(metadata_path))
            meta_file = os.path.join(dataset_root, 'train', 'train_metadata.npy')
        elif 'train' in metadata_path:
            meta_file = os.path.join(metadata_path, 'train_metadata.npy')
        else:
            meta_file = os.path.join(metadata_path, 'train', 'train_metadata.npy')

        if not os.path.exists(meta_file):
            print(f"Metadata file not found: {meta_file}")
            return None

        # NOTE(review): allow_pickle=True unpickles arbitrary objects; only
        # point this at metadata files from a trusted source.
        episodes = np.load(meta_file, allow_pickle=True)

        modes, cp_params, phase_types = [], [], []
        for episode in episodes:
            # One REACH and one CARRY trajectory per episode, in that order.
            for phase in ('reach', 'carry'):
                modes.append(episode['mode'])
                cp_params.append(episode['current_cp_params'])
                phase_types.append(phase)

        return {
            'modes': np.array(modes),
            'cp_params': cp_params,
            'phase_types': phase_types,
            'num_trajectories': len(modes)
        }

    except Exception as e:
        # Best-effort: metadata is only used for plotting, so report and carry on.
        print(f"Error loading control point metadata: {e}")
        import traceback
        traceback.print_exc()
        return None


def visualize_embeddings(z, epoch, cp_metadata=None, save_path=None):
    """
    Plot latent embeddings in a three-panel figure.

    Panels: (1) embeddings colored by mode, (2) embeddings colored by phase
    type, (3) control-point parameters (angle vs. distance fraction).
    Embeddings with more than two dimensions are first projected to 2D with
    t-SNE. When ``cp_metadata`` is missing or has fewer entries than there are
    embeddings, only the first panel is drawn, colored by trajectory index.

    Args:
        z: (N, latent_dim) array of embeddings
        epoch: epoch number shown in the first panel's title
        cp_metadata: optional dict with 'modes', 'cp_params' and 'phase_types'
        save_path: optional path; when given the figure is also written there

    Returns:
        The matplotlib figure (the caller is responsible for closing it).
    """
    latent_dim = z.shape[1]
    if latent_dim > 2:
        from sklearn.manifold import TSNE
        projector = TSNE(n_components=2, random_state=42, perplexity=min(30, len(z) - 1))
        z_vis = projector.fit_transform(z)
        title_suffix = f' (t-SNE from {latent_dim}D)'
    else:
        z_vis = z
        title_suffix = ''

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    ax_mode, ax_phase, ax_params = axes[0], axes[1], axes[2]

    n_points = len(z_vis)
    has_labels = cp_metadata is not None and len(cp_metadata['modes']) >= n_points

    if has_labels:
        # Metadata is truncated to the number of embeddings.
        modes = cp_metadata['modes'][:n_points]
        cp_params = cp_metadata['cp_params'][:n_points]
        phase_types = cp_metadata['phase_types'][:n_points]

        # Marker styling shared by all three scatter plots.
        marker_style = dict(s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

        # --- Panel 1: one tab10 color per distinct mode ---
        unique_modes = np.unique(modes)
        palette = plt.cm.tab10(np.linspace(0, 1, len(unique_modes)))
        mode_color_map = dict(zip(unique_modes, palette))
        point_colors_mode = [mode_color_map[m] for m in modes]

        ax_mode.scatter(z_vis[:, 0], z_vis[:, 1], c=point_colors_mode, **marker_style)

        from matplotlib.patches import Patch
        mode_handles = [Patch(facecolor=mode_color_map[m], label=f'Mode {m}') for m in unique_modes]
        ax_mode.legend(handles=mode_handles, loc='upper right', fontsize=8)
        ax_mode.set_title(f'Embeddings by Mode (Epoch {epoch}){title_suffix}', fontsize=11)

        # --- Panel 2: phase type coloring - REACH and CARRY should overlap! ---
        phase_color_map = {'reach': 'blue', 'carry': 'red'}
        point_colors_phase = [phase_color_map[p] for p in phase_types]

        ax_phase.scatter(z_vis[:, 0], z_vis[:, 1], c=point_colors_phase, **marker_style)

        phase_handles = [Patch(facecolor='blue', label='REACH'),
                         Patch(facecolor='red', label='CARRY')]
        ax_phase.legend(handles=phase_handles, loc='upper right', fontsize=10)
        ax_phase.set_title(f'Embeddings by Phase Type{title_suffix}\n(REACH & CARRY should overlap!)', fontsize=11)

        # --- Panel 3: control-point parameters, colored by mode ---
        # The first parameter is converted with np.degrees; the second is
        # plotted as-is on the 'Distance fraction' axis.
        angles = [np.degrees(p[0]) for p in cp_params]
        dists = [p[1] for p in cp_params]

        ax_params.scatter(angles, dists, c=point_colors_mode, **marker_style)
        ax_params.set_xlabel('Angle (degrees)', fontsize=11)
        ax_params.set_ylabel('Distance fraction', fontsize=11)
        ax_params.set_title('Control Point Parameters', fontsize=11)
        ax_params.set_xlim(-20, 380)
        ax_params.set_ylim(0, 1.2)
        ax_params.grid(True, alpha=0.3)
    else:
        # No usable metadata: color points by their index in the dataset.
        scatter = ax_mode.scatter(z_vis[:, 0], z_vis[:, 1], c=np.arange(n_points),
                                  cmap='viridis', s=50, alpha=0.7)
        plt.colorbar(scatter, ax=ax_mode, label='Trajectory Index')

    # The two embedding panels share the same axis decoration.
    for ax in (ax_mode, ax_phase):
        ax.set_xlabel('z[0]', fontsize=11)
        ax.set_ylabel('z[1]', fontsize=11)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def train_epoch(model, dataloader, optimizer, device, epoch, log_freq=50, kl_weight=0.001, dtw_weight=0.1):
    """
    Train for one epoch using full state+action with endpoint conditioning.

    For every batch: compute the endpoint direction from the absolute states,
    convert states/actions to s0-relative coordinates, run the model, and
    optimize ``recon + kl_weight * kl + dtw_weight * dtw`` with gradient-norm
    clipping at 1.0.

    Args:
        model: callable as ``model(states_rel, actions_rel, endpoint_direction)``
            returning a dict with 'z_mean', 'z_logvar', 'pred_states' and
            'pred_actions'
        dataloader: yields batches exposing ``batch.conditions['state']`` and
            ``batch.actions``
        optimizer: optimizer over the model's parameters
        device: device the batch tensors are moved to
        epoch: epoch number (used for the progress bar and logging)
        log_freq: log per-batch losses to wandb every ``log_freq`` batches
        kl_weight: weight of the KL term
        dtw_weight: weight of the DTW term

    Returns:
        (avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss) averaged over
        the batches of this epoch; all 0 when the dataloader yields nothing.
    """
    model.train()

    total_loss = 0.0
    total_recon_loss = 0.0
    total_kl_loss = 0.0
    total_dtw_loss = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, batch in enumerate(pbar):
        states = batch.conditions['state'].to(device)  # (B, T, 22)
        actions = batch.actions.to(device)  # (B, T, 8)

        # Compute endpoint direction BEFORE making relative (needs absolute positions)
        endpoint_direction = compute_endpoint_direction(states)  # (B, 3)

        # Convert to s0-relative coordinates
        states_rel, actions_rel = make_trajectory_relative(states, actions)

        # Forward pass
        outputs = model(states_rel, actions_rel, endpoint_direction)

        z_mean = outputs['z_mean']
        z_logvar = outputs['z_logvar']

        # Compute losses (use relative coordinates for reconstruction).
        # The DTW term is computed on the posterior mean, not a sampled z.
        recon_loss = compute_reconstruction_loss(
            outputs['pred_states'], outputs['pred_actions'], states_rel, actions_rel
        )
        kl_loss = compute_kl_loss(z_mean, z_logvar)
        dtw_loss = compute_dtw_loss(z_mean, states_rel, gamma=0.00001)

        loss = recon_loss + kl_weight * kl_loss + dtw_weight * dtw_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Read each scalar back once and reuse it for the running totals,
        # the progress bar and wandb, instead of calling .item() repeatedly.
        loss_val = loss.item()
        recon_val = recon_loss.item()
        kl_val = kl_loss.item()
        dtw_val = dtw_loss.item()

        total_loss += loss_val
        total_recon_loss += recon_val
        total_kl_loss += kl_val
        total_dtw_loss += dtw_val
        num_batches += 1

        pbar.set_postfix({
            'total': f'{loss_val:.4f}',
            'recon': f'{recon_val:.4f}',
            'kl': f'{kl_val:.6f}',
            'dtw': f'{dtw_val:.4f}'
        })

        if batch_idx % log_freq == 0:
            wandb.log({
                'batch_total_loss': loss_val,
                'batch_recon_loss': recon_val,
                'batch_kl_loss': kl_val,
                'batch_dtw_loss': dtw_val,
                'epoch': epoch,
                'batch': batch_idx
            })

    # Guard against an empty dataloader.
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    avg_recon_loss = total_recon_loss / num_batches if num_batches > 0 else 0
    avg_kl_loss = total_kl_loss / num_batches if num_batches > 0 else 0
    avg_dtw_loss = total_dtw_loss / num_batches if num_batches > 0 else 0

    return avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss


def collect_embeddings(model, dataset, device, batch_size=256):
    """
    Encode every trajectory in ``dataset`` and return the stacked embeddings.

    The model is switched to eval mode and the dataset is traversed in order
    (no shuffling) with gradients disabled. Each batch is converted to
    s0-relative coordinates before being passed to ``model.encode``.

    Returns:
        np.ndarray with the per-batch outputs of ``model.encode`` concatenated
        along axis 0.
    """
    model.eval()

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    chunks = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Collecting embeddings", leave=False):
            # Same s0-relative preprocessing as during training.
            states_rel, actions_rel = make_trajectory_relative(
                batch.conditions['state'].to(device),
                batch.actions.to(device),
            )
            embedding = model.encode(states_rel, actions_rel)
            chunks.append(embedding.cpu().numpy())

    return np.concatenate(chunks, axis=0)


def main():
    """
    Command-line entry point: train the endpoint-conditioned trajectory VAE.

    Parses hyperparameters and paths, builds the dataset, model, optimizer and
    cosine LR schedule, then runs the epoch loop. Metrics are logged to wandb,
    embedding plots are produced on epoch 1 and every ``vis_freq`` epochs, the
    best model (lowest average training loss) is saved whenever it improves,
    and a numbered checkpoint is written every ``save_freq`` epochs. A final
    embedding visualization is produced after training.
    """
    parser = argparse.ArgumentParser(description='Train Pick-and-Place Trajectory Encoder with Endpoint Conditioning')
    parser.add_argument('--horizon', type=int, default=64)
    parser.add_argument('--latent_dim', type=int, default=2)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--num_heads', type=int, default=4)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_epochs', type=int, default=5000)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--kl_weight', type=float, default=0.001)
    parser.add_argument('--dtw_weight', type=float, default=1.0)
    parser.add_argument('--vis_freq', type=int, default=100)
    parser.add_argument('--save_freq', type=int, default=500)
    parser.add_argument('--dataset_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/processed/train_raw.npz')
    parser.add_argument('--metadata_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/train')
    parser.add_argument('--save_dir', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/stack_blocks/variation0/encoder')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--wandb_project', type=str, default='pick-place-trajectory-encoder')

    args = parser.parse_args()

    # Single config dict: passed to wandb and stored inside every checkpoint.
    # state_dim, action_dim and log_freq are fixed here rather than exposed as CLI flags.
    config = {
        'dataset_path': args.dataset_path,
        'metadata_path': args.metadata_path,
        'state_dim': 22,
        'action_dim': 8,
        'horizon': args.horizon,
        'latent_dim': args.latent_dim,
        'hidden_dim': args.hidden_dim,
        'num_layers': args.num_layers,
        'num_heads': args.num_heads,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.lr,
        'kl_weight': args.kl_weight,
        'dtw_weight': args.dtw_weight,
        # Falls back to CPU whenever CUDA is unavailable, regardless of --device.
        'device': args.device if torch.cuda.is_available() else 'cpu',
        'log_freq': 10,
        'vis_freq': args.vis_freq,
        'save_freq': args.save_freq,
        'save_dir': args.save_dir
    }

    # Output layout: <save_dir>/z<latent_dim>_endpoint/{checkpoints,visualizations}
    os.makedirs(config['save_dir'], exist_ok=True)
    latent_dir = os.path.join(config['save_dir'], f'z{config["latent_dim"]}_endpoint')
    checkpoint_dir = os.path.join(latent_dir, 'checkpoints')
    visualization_dir = os.path.join(latent_dir, 'visualizations')
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(visualization_dir, exist_ok=True)

    wandb.init(
        project=args.wandb_project,
        name=f'traj_vae_endpoint_z{config["latent_dim"]}_h{config["hidden_dim"]}',
        config=config
    )

    device = torch.device(config['device'])
    print(f"Using device: {device}")

    # The dataset is constructed with device='cpu'; batches are moved to
    # `device` inside train_epoch / collect_embeddings.
    print(f"Loading dataset from {config['dataset_path']}")
    dataset = FullTrajectoryDataset(
        dataset_path=config['dataset_path'],
        horizon_steps=config['horizon'],
        device='cpu',
        max_n_episodes=None
    )
    print(f"Dataset size: {len(dataset)} trajectories")

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=config['batch_size'], shuffle=True, num_workers=0, pin_memory=True
    )

    # Use TrajectoryVAE with endpoint conditioning
    model = TrajectoryVAEWithEndpoint(
        state_dim=config['state_dim'],
        action_dim=config['action_dim'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        latent_dim=config['latent_dim'],
        horizon=config['horizon']
    ).to(device)

    # Cosine schedule spans the whole run (T_max = num_epochs) and is stepped once per epoch.
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['num_epochs'], eta_min=1e-6)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Metadata is optional: when it cannot be loaded (None), the
    # visualizations fall back to coloring by trajectory index.
    cp_metadata = load_control_point_metadata(config['metadata_path'])
    if cp_metadata is not None:
        print(f"Loaded CP metadata for {cp_metadata['num_trajectories']} trajectories")

    best_loss = float('inf')
    epoch_pbar = tqdm(range(1, config['num_epochs'] + 1), desc="Training", position=0)

    for epoch in epoch_pbar:
        avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss = train_epoch(
            model, dataloader, optimizer, device, epoch,
            log_freq=config['log_freq'],
            kl_weight=config['kl_weight'],
            dtw_weight=config['dtw_weight']
        )

        # Logged before scheduler.step(), so this is the LR used during this epoch.
        wandb.log({
            'epoch': epoch,
            'avg_total_loss': avg_loss,
            'avg_recon_loss': avg_recon_loss,
            'avg_kl_loss': avg_kl_loss,
            'avg_dtw_loss': avg_dtw_loss,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        epoch_pbar.set_postfix({
            'total': f'{avg_loss:.4f}',
            'recon': f'{avg_recon_loss:.4f}',
            'kl': f'{avg_kl_loss:.6f}',
            'dtw': f'{avg_dtw_loss:.4f}'
        })

        # Visualize on the first epoch and then every vis_freq epochs.
        if epoch % config['vis_freq'] == 0 or epoch == 1:
            all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])

            # Print z statistics
            print(f"\n  z stats: mean={all_z.mean(0)}, std={all_z.std(0)}, min={all_z.min(0)}, max={all_z.max(0)}")

            fig = visualize_embeddings(
                all_z, epoch, cp_metadata=cp_metadata,
                save_path=os.path.join(visualization_dir, f'embeddings_epoch_{epoch}.png')
            )
            wandb.log({'embeddings': wandb.Image(fig), 'epoch': epoch})
            # Close the figure so figures do not accumulate across epochs.
            plt.close(fig)

        # "Best" is judged on the average training loss of the epoch;
        # no separate validation set is used in this script.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'config': config
            }, os.path.join(checkpoint_dir, 'best_model.pt'))

        # Periodic numbered checkpoint, independent of the best-model file.
        if epoch % config['save_freq'] == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'config': config
            }, os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'))

        scheduler.step()

    epoch_pbar.close()

    # Final visualization uses the model as it stands after the last epoch
    # (not the best_model.pt checkpoint).
    print("\nGenerating final embedding visualization...")
    all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])
    print(f"Final z stats: mean={all_z.mean(0)}, std={all_z.std(0)}, min={all_z.min(0)}, max={all_z.max(0)}")

    fig = visualize_embeddings(
        all_z, config['num_epochs'], cp_metadata=cp_metadata,
        save_path=os.path.join(visualization_dir, 'embeddings_final.png')
    )
    wandb.log({'final_embeddings': wandb.Image(fig)})
    plt.close(fig)

    wandb.finish()
    print("Training complete!")


# Run training only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
