"""
Training script for close_drawer trajectory encoder/decoder with reconstruction, KL, and DTW losses.

Adapted from RLBench/train_relative_trajectory_encoder.py for the close_drawer task.
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb

# Set matplotlib backend to Agg (non-interactive) before importing pyplot
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

# Local imports
from trajectory_encoder import TrajectoryVAE
from soft_dtw import SoftDTW
from full_trajectory_dataset import FullTrajectoryDataset


def compute_reconstruction_loss(pred_states, pred_actions, gt_states, gt_actions):
    """
    Sum of the action and state reconstruction MSE losses.

    pred_actions: (B, horizon, action_dim) - decoded actions a0 .. a_{H-1}
    gt_actions:   (B, horizon, action_dim) - target actions
    pred_states:  (B, horizon-1, state_dim) or None - decoded states s1 .. s_{H-1}
    gt_states:    (B, horizon, state_dim) - target states; s0 is dropped before
                  the comparison because it is treated as given

    Returns a scalar tensor. When pred_states is None the state term is zero.
    """
    # Every action step, including a0, is a reconstruction target.
    loss_actions = F.mse_loss(pred_actions, gt_actions)

    if pred_states is None:
        # No state predictions: contribute a zero state term on the same device.
        loss_states = torch.tensor(0.0, device=pred_actions.device)
    else:
        target_states = gt_states[:, 1:, :]
        loss_states = F.mse_loss(pred_states, target_states)

    return loss_actions + loss_states


def compute_kl_loss(z_mean, z_logvar):
    """
    Batch-mean KL divergence between a diagonal Gaussian posterior and N(0, I).

    Pulls q(z|x) toward a standard normal prior. This keeps latent values
    bounded while leaving enough spread for distinct trajectory modes to
    stay separable.

    z_mean:   (B, latent_dim) - posterior means (not squashed)
    z_logvar: (B, latent_dim) - posterior log-variances
    Returns a scalar tensor.
    """
    variance = z_logvar.exp()
    # Closed form per sample: -1/2 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2)
    per_sample = torch.sum(1 + z_logvar - z_mean.pow(2) - variance, dim=1) * -0.5
    return per_sample.mean()


def make_trajectory_relative(states, actions):
    """
    Re-express trajectories as offsets from their first state s0.

    Removing absolute positioning forces the encoder to capture the
    trajectory's shape/mode rather than where it happens to start.

    State layout:  joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7)
    Action layout: joint_pos(7) + gripper_open(1)

    Shifted channels:
      - states: joint_pos [0:7], joint_vel [7:14], gripper xyz [15:18]
      - actions: joint_pos [0:7], shifted by s0's joint_pos
    Left unchanged: gripper_open (state 14, action 7) and the gripper
    quaternion (state 18:22).

    Args:
        states: (B, T, 22) - absolute states (not modified)
        actions: (B, T, 8) - absolute actions (not modified)

    Returns:
        (states_rel, actions_rel): new tensors with the same shapes as the inputs
    """
    # State channel ranges that become offsets from timestep 0.
    shifted_state_ranges = (slice(0, 7), slice(7, 14), slice(15, 18))

    states_rel = states.clone()
    for channels in shifted_state_ranges:
        states_rel[:, :, channels] = states[:, :, channels] - states[:, 0:1, channels]

    # Action joint targets are referenced to s0's joint positions, not to a0.
    actions_rel = actions.clone()
    actions_rel[:, :, 0:7] = actions[:, :, 0:7] - states[:, 0:1, 0:7]

    return states_rel, actions_rel


def extract_endeffector_position(states, relative=False):
    """
    Return the end-effector xyz track from a batch of state sequences.

    State layout: joint_pos(7) + joint_vel(7) + gripper_open(1) + gripper_pose(7),
    where gripper_pose = [x, y, z, qx, qy, qz, qw]. The xyz part therefore
    sits at indices 15:18.

    states:   (B, T, 22)
    relative: when True, subtract the t=0 position so every track starts at [0, 0, 0]
    Returns:  (B, T, 3) - a slice of ``states`` when relative is False,
              otherwise a new tensor
    """
    xyz = states[:, :, 15:18]
    if not relative:
        return xyz
    start = xyz[:, 0:1, :]
    return xyz - start


def compute_dtw_loss(z, states, gamma=0.00001):
    """
    Metric-matching loss between latent distances and trajectory distances.

    Each sample i is paired with sample i-1, cyclically via torch.roll. For
    every pair, two distances are computed:
      - the L2 distance between the latent codes;
      - the SoftDTW(normalize=True) distance between the end-effector xyz
        tracks, re-based so each track starts at the origin.
    Each distance vector is divided by its batch maximum, and the MSE between
    the two normalized vectors is returned.

    z:      (B, latent_dim) latent codes
    states: (B, T, state_dim) state sequences; xyz is read from indices 15:18
    gamma:  smoothing parameter passed to SoftDTW

    With fewer than two samples there are no pairs, so a zero scalar with
    requires_grad=True is returned.
    """
    batch_size = z.shape[0]
    if batch_size < 2:
        return torch.tensor(0.0, device=z.device, requires_grad=True)

    eps = 1e-8

    # Shape-only tracks: every trajectory starts at [0, 0, 0].
    tracks = extract_endeffector_position(states, relative=True)

    # The partner of sample i is sample i-1 (wrapping around).
    z_partner = torch.roll(z, 1, 0)
    tracks_partner = torch.roll(tracks, 1, 0)

    # eps inside the sqrt keeps the gradient finite for identical codes.
    latent_dist = torch.sqrt(((z - z_partner) ** 2).sum(dim=1) + eps)
    traj_dist = SoftDTW(gamma=gamma, normalize=True)(tracks, tracks_partner)

    latent_scaled = latent_dist / (latent_dist.max() + eps)
    traj_scaled = traj_dist / (traj_dist.max() + eps)
    return F.mse_loss(latent_scaled, traj_scaled)


def load_control_point_metadata(data_path):
    """
    Gather per-episode control-point metadata for visualization.

    Sub-folders of the episodes directory whose names start with 'episode'
    are visited in sorted (lexicographic) name order. Each folder's
    ``metadata.npy`` (loaded with allow_pickle and ``.item()``, expected to
    hold a dict), when present, contributes its 'cp_reach', 'mode' and
    'canonical_cp_params' entries. Each key is collected independently, so
    the three lists only line up index-for-index if every episode provides
    every key.

    If ``data_path`` does not exist and contains 'processed', the fallback
    ``<dirname(dirname(data_path))>/train/episodes`` is tried instead.

    NOTE(review): lexicographic order places 'episode10' before 'episode2';
    confirm this matches the order in which the training dataset was built.

    Args:
        data_path: Path to the episodes directory

    Returns:
        dict with 'control_points' (np.ndarray), 'modes' (np.ndarray),
        'cp_params' (list) and 'num_episodes' (count of episode folders,
        with or without metadata). Returns None when the directory is
        missing, no 'cp_reach' entry is found, or any exception is raised
        (the traceback is printed).
    """
    try:
        episodes_dir = data_path
        if not os.path.exists(episodes_dir) and 'processed' in data_path:
            # From .../processed/<file> climb two levels, then go to train/episodes.
            parent_dir = os.path.dirname(os.path.dirname(data_path))
            episodes_dir = os.path.join(parent_dir, 'train', 'episodes')

        if not os.path.exists(episodes_dir):
            print(f"Episodes directory not found: {episodes_dir}")
            return None

        folder_names = sorted(
            name for name in os.listdir(episodes_dir) if name.startswith('episode')
        )

        # metadata key -> values gathered across episodes (in folder order)
        collected = {'cp_reach': [], 'mode': [], 'canonical_cp_params': []}
        for name in folder_names:
            meta_file = os.path.join(episodes_dir, name, 'metadata.npy')
            if not os.path.exists(meta_file):
                continue
            meta = np.load(meta_file, allow_pickle=True).item()
            for key, bucket in collected.items():
                if key in meta:
                    bucket.append(meta[key])

        if not collected['cp_reach']:
            print("No control point data found in episode metadata")
            return None

        return {
            'control_points': np.array(collected['cp_reach']),
            'modes': np.array(collected['mode']),
            'cp_params': collected['canonical_cp_params'],
            'num_episodes': len(folder_names)
        }

    except Exception as e:
        # Metadata is optional (visualization only): report and carry on.
        print(f"Error loading control point metadata: {e}")
        import traceback
        traceback.print_exc()
        return None


def visualize_embeddings(z, epoch, cp_metadata=None, save_path=None):
    """
    Plot trajectory embeddings, colored by control-point mode when available.

    Left panel: the embeddings in 2D. Embeddings with more than two dimensions
    are projected with t-SNE (when there are at least two points). 1D
    embeddings are drawn on the line y = 0. Right panel: the control-point
    (angle, distance) parameters, when metadata has one entry per point.

    z: (N, latent_dim) numpy array of embeddings; a 1-D array is treated as
       N one-dimensional embeddings. Row i is matched to metadata entry i.
    epoch: epoch number shown in the title
    cp_metadata: dict with 'modes' and 'cp_params' as returned by
       load_control_point_metadata, or None
    save_path: if given, the figure is also saved there (dpi=150)

    Returns the matplotlib Figure; the caller is responsible for closing it.
    """
    z = np.asarray(z)
    if z.ndim == 1:
        z = z[:, None]
    n_points, latent_dim = z.shape

    if latent_dim > 2 and n_points > 1:
        from sklearn.manifold import TSNE
        print(f"Applying t-SNE to reduce {latent_dim}D embeddings to 2D...")
        # sklearn requires 0 < perplexity < n_samples.
        z_vis = TSNE(n_components=2, random_state=42,
                     perplexity=min(30, n_points - 1)).fit_transform(z)
        title_suffix = f' (t-SNE from {latent_dim}D)'
        x_label, y_label = 'Component 1', 'Component 2'
    else:
        if latent_dim > 2:
            # t-SNE needs at least two samples; show the first two coordinates instead.
            z_vis = z[:, :2]
        elif latent_dim == 1:
            # Pad with a zero column so the 2D scatter below always has two columns.
            z_vis = np.concatenate([z, np.zeros_like(z)], axis=1)
        else:
            z_vis = z
        title_suffix = ''
        x_label, y_label = 'z[0]', ('' if latent_dim == 1 else 'z[1]')

    fig, (ax_main, ax_params) = plt.subplots(1, 2, figsize=(14, 6))

    has_modes = cp_metadata is not None and len(cp_metadata['modes']) >= n_points
    if has_modes:
        modes = cp_metadata['modes'][:n_points]

        # One tab10 color per distinct mode.
        unique_modes = np.unique(modes)
        colors = plt.cm.tab10(np.linspace(0, 1, len(unique_modes)))
        color_map = {m: colors[i] for i, m in enumerate(unique_modes)}
        point_colors = [color_map[m] for m in modes]

        ax_main.scatter(z_vis[:, 0], z_vis[:, 1], c=point_colors,
                        s=50, alpha=0.7, edgecolors='black', linewidths=0.5)

        from matplotlib.patches import Patch
        legend_elements = [Patch(facecolor=color_map[m], label=f'Mode {m}') for m in unique_modes]
        ax_main.legend(handles=legend_elements, loc='upper right', fontsize=8)

        # load_control_point_metadata collects cp_params independently of modes,
        # so it can be shorter; only plot it with one entry per point, otherwise
        # the color list and the x/y data would have different lengths.
        cp_params = cp_metadata['cp_params']
        if len(cp_params) >= n_points:
            cp_params = cp_params[:n_points]
            angles = [np.degrees(p[0]) for p in cp_params]
            dists = [p[1] for p in cp_params]

            ax_params.scatter(angles, dists, c=point_colors, s=50, alpha=0.7,
                              edgecolors='black', linewidths=0.5)
            ax_params.set_xlabel('Angle (degrees)', fontsize=11)
            ax_params.set_ylabel('Distance fraction', fontsize=11)
            ax_params.set_title('Control Point Parameters', fontsize=12)
            ax_params.set_xlim(-20, 380)
            ax_params.set_ylim(0, 1.2)
            ax_params.grid(True, alpha=0.3)
        else:
            ax_params.text(0.5, 0.5, 'No CP parameter metadata available',
                           transform=ax_params.transAxes, ha='center', va='center')
    else:
        # No usable metadata: color points by their index instead.
        scatter = ax_main.scatter(z_vis[:, 0], z_vis[:, 1], c=np.arange(n_points),
                                  cmap='viridis', s=50, alpha=0.7,
                                  edgecolors='black', linewidths=0.5)
        fig.colorbar(scatter, ax=ax_main, label='Trajectory Index')
        ax_params.text(0.5, 0.5, 'No CP metadata available',
                       transform=ax_params.transAxes, ha='center', va='center')

    ax_main.set_xlabel(x_label, fontsize=11)
    ax_main.set_ylabel(y_label, fontsize=11)
    ax_main.set_title(f'Trajectory Embeddings (Epoch {epoch}){title_suffix}', fontsize=12)
    ax_main.grid(True, alpha=0.3)

    # Operate on this figure explicitly rather than pyplot's "current figure".
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def train_epoch(model, dataloader, optimizer, device, epoch, log_freq=50, kl_weight=0.001, dtw_weight=0.1):
    """
    Run one optimization pass over ``dataloader``.

    For each batch:
      - convert to s0-relative coordinates (make_trajectory_relative);
      - run the model;
      - optimize recon + kl_weight * KL + dtw_weight * DTW, where DTW is
        computed on z_mean;
      - clip the gradient norm to 1.0.
    Batches whose state sequence is shorter than the action sequence are
    skipped.

    Args:
        model: module whose forward ``model(states, actions)`` returns a dict
            with 'z_mean', 'z_logvar', 'pred_states' and 'pred_actions'
        dataloader: iterable of batches exposing ``.actions`` and
            ``.conditions['state']``
        optimizer: optimizer over the model's parameters
        device: device the batch tensors are moved to
        epoch: epoch index, used for the progress-bar label and wandb logs
        log_freq: log batch losses to wandb every ``log_freq`` batches
        kl_weight: weight of the KL term
        dtw_weight: weight of the DTW term

    Returns:
        (avg_total, avg_recon, avg_kl, avg_dtw) as floats averaged over the
        processed batches; all 0.0 when no batch was processed.
    """
    model.train()

    total_loss = 0.0
    total_recon_loss = 0.0
    total_kl_loss = 0.0
    total_dtw_loss = 0.0
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, batch in enumerate(pbar):
        actions = batch.actions.to(device)             # (B, horizon, action_dim)
        states = batch.conditions['state'].to(device)  # (B, horizon, state_dim)

        # Skip batches that lack a state for every action step (e.g. only s0).
        if states.shape[1] < actions.shape[1]:
            continue

        # Train on s0-relative coordinates (removes absolute position info).
        states_rel, actions_rel = make_trajectory_relative(states, actions)
        outputs = model(states_rel, actions_rel)

        recon_loss = compute_reconstruction_loss(
            outputs['pred_states'], outputs['pred_actions'], states_rel, actions_rel
        )
        kl_loss = compute_kl_loss(outputs['z_mean'], outputs['z_logvar'])
        # DTW structure is imposed on the posterior mean, not the sampled z.
        dtw_loss = compute_dtw_loss(outputs['z_mean'], states_rel, gamma=0.00001)

        loss = recon_loss + kl_weight * kl_loss + dtw_weight * dtw_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Convert each loss to a Python float once and reuse it for accumulation,
        # the progress bar and wandb; every .item() is a device-to-host copy.
        loss_val = loss.item()
        recon_val = recon_loss.item()
        kl_val = kl_loss.item()
        dtw_val = dtw_loss.item()

        total_loss += loss_val
        total_recon_loss += recon_val
        total_kl_loss += kl_val
        total_dtw_loss += dtw_val
        num_batches += 1

        pbar.set_postfix({
            'total': f'{loss_val:.4f}',
            'recon': f'{recon_val:.4f}',
            'kl': f'{kl_val:.6f}',
            'dtw': f'{dtw_val:.4f}'
        })

        if batch_idx % log_freq == 0:
            wandb.log({
                'batch_total_loss': loss_val,
                'batch_recon_loss': recon_val,
                'batch_kl_loss': kl_val,
                'batch_dtw_loss': dtw_val,
                'epoch': epoch,
                'batch': batch_idx
            })

    if num_batches == 0:
        return 0.0, 0.0, 0.0, 0.0

    return (
        total_loss / num_batches,
        total_recon_loss / num_batches,
        total_kl_loss / num_batches,
        total_dtw_loss / num_batches,
    )


def collect_embeddings(model, dataset, device, batch_size=256):
    """
    Encode every trajectory in ``dataset`` and return the embeddings in dataset order.

    A dedicated DataLoader with shuffle=False keeps rows in dataset order, so
    they can be matched to per-episode metadata. Inputs go through
    make_trajectory_relative, as in training. The model is switched to eval
    mode for encoding, and its previous train/eval mode is restored
    afterwards, even on error.

    Args:
        model: module exposing ``encode(states, actions)``
        dataset: dataset whose batches expose ``.actions`` and ``.conditions['state']``
        device: device the inputs are moved to (str or torch.device)
        batch_size: encoding batch size

    Returns:
        np.ndarray: the per-batch encoder outputs concatenated along axis 0

    Raises:
        ValueError: if no batch could be encoded (empty dataset, or every
            batch skipped).
    """
    was_training = model.training
    model.eval()

    # Pinned host memory only helps host->CUDA copies; on CPU it just warns.
    pin_memory = torch.device(device).type == 'cuda'
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory
    )

    all_z = []
    num_skipped = 0

    try:
        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Collecting embeddings", leave=False):
                actions = batch.actions.to(device)
                states = batch.conditions['state'].to(device)

                if states.shape[1] < actions.shape[1]:
                    num_skipped += 1
                    continue

                states_rel, actions_rel = make_trajectory_relative(states, actions)
                z = model.encode(states_rel, actions_rel)
                all_z.append(z.cpu().numpy())
    finally:
        model.train(was_training)

    if num_skipped:
        # Skipping a batch shifts all later rows relative to the episode order.
        print(f"Warning: skipped {num_skipped} batch(es) while collecting embeddings; "
              f"embeddings may not line up with episode metadata")

    if not all_z:
        raise ValueError("No embeddings collected: the dataset is empty or every batch was skipped")

    return np.concatenate(all_z, axis=0)


def _parse_args():
    """Parse the command-line options for a close_drawer encoder training run."""
    parser = argparse.ArgumentParser(description='Train Close Drawer Trajectory Encoder VAE')
    parser.add_argument('--horizon', type=int, default=87, help='Horizon length for trajectories (87 for close_drawer)')
    parser.add_argument('--latent_dim', type=int, default=2, help='Latent dimension')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden dimension')
    parser.add_argument('--num_layers', type=int, default=4, help='Number of transformer layers')
    parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=5000, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--kl_weight', type=float, default=0.001, help='KL loss weight')
    parser.add_argument('--dtw_weight', type=float, default=1.0, help='DTW loss weight')
    parser.add_argument('--vis_freq', type=int, default=100, help='Visualization frequency (epochs)')
    parser.add_argument('--save_freq', type=int, default=500, help='Checkpoint save frequency (epochs)')
    parser.add_argument('--dataset_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/close_drawer/variation2/processed/train_normalized.npz',
                        help='Path to training dataset')
    parser.add_argument('--episodes_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/close_drawer/variation2/train/episodes',
                        help='Path to episodes directory (for control point metadata)')
    parser.add_argument('--save_dir', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/close_drawer/variation2/encoder',
                        help='Directory to save checkpoints and visualizations')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use')
    parser.add_argument('--wandb_project', type=str, default='close-drawer-encoder', help='Wandb project name')
    return parser.parse_args()


def _save_checkpoint(path, epoch, model, optimizer, scheduler, loss, config):
    """Write model, optimizer and LR-scheduler state plus run metadata to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Without this, a resumed run would restart the cosine LR schedule from scratch.
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': config
    }, path)


def main():
    """
    Train the close_drawer trajectory VAE end to end.

    Setup:
      - parse CLI options;
      - load the dataset and the optional control-point metadata;
      - build the model, the AdamW optimizer and the cosine LR schedule.

    Each epoch:
      - train and log the epoch averages to wandb;
      - visualize embeddings at epoch 1 and every ``vis_freq`` epochs;
      - save the lowest-training-loss checkpoint and one every ``save_freq``
        epochs;
      - step the LR scheduler.

    A final embedding visualization is produced after training.
    """
    args = _parse_args()

    config = {
        'dataset_path': args.dataset_path,
        'episodes_path': args.episodes_path,
        'state_dim': 22,
        'action_dim': 8,
        'horizon': args.horizon,
        'latent_dim': args.latent_dim,
        'hidden_dim': args.hidden_dim,
        'num_layers': args.num_layers,
        'num_heads': args.num_heads,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.lr,
        'kl_weight': args.kl_weight,
        'dtw_weight': args.dtw_weight,
        'device': args.device if torch.cuda.is_available() else 'cpu',
        'log_freq': 10,
        'vis_freq': args.vis_freq,
        'save_freq': args.save_freq,
        'save_dir': args.save_dir
    }

    # Layout: <save_dir>/z<latent_dim>/{checkpoints,visualizations}
    os.makedirs(config['save_dir'], exist_ok=True)
    latent_dir = os.path.join(config['save_dir'], f'z{config["latent_dim"]}')
    checkpoint_dir = os.path.join(latent_dir, 'checkpoints')
    visualization_dir = os.path.join(latent_dir, 'visualizations')
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(visualization_dir, exist_ok=True)

    wandb.init(
        project=args.wandb_project,
        name=f'close_drawer_vae_z{config["latent_dim"]}_h{config["hidden_dim"]}',
        config=config
    )

    device = torch.device(config['device'])
    print(f"Using device: {device}")

    print(f"Loading dataset from {config['dataset_path']}")
    dataset = FullTrajectoryDataset(
        dataset_path=config['dataset_path'],
        horizon_steps=config['horizon'],
        device='cpu',
        max_n_episodes=None
    )

    print(f"Dataset size: {len(dataset)} trajectories")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        # Pinned memory only speeds up host->CUDA copies; on CPU it just warns.
        pin_memory=device.type == 'cuda'
    )

    model = TrajectoryVAE(
        state_dim=config['state_dim'],
        action_dim=config['action_dim'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        latent_dim=config['latent_dim'],
        horizon=config['horizon']
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=1e-5
    )

    # Cosine decay over the whole run, stepped once per epoch.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['num_epochs'],
        eta_min=1e-6
    )

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    cp_metadata = load_control_point_metadata(config['episodes_path'])
    if cp_metadata is not None:
        print(f"Loaded control point metadata for {cp_metadata['num_episodes']} episodes")
    else:
        print("Control point metadata not available - using default visualization")

    best_loss = float('inf')
    epoch_pbar = tqdm(range(1, config['num_epochs'] + 1), desc="Training", position=0)

    for epoch in epoch_pbar:
        avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss = train_epoch(
            model, dataloader, optimizer, device, epoch,
            log_freq=config['log_freq'],
            kl_weight=config['kl_weight'],
            dtw_weight=config['dtw_weight']
        )

        wandb.log({
            'epoch': epoch,
            'avg_total_loss': avg_loss,
            'avg_recon_loss': avg_recon_loss,
            'avg_kl_loss': avg_kl_loss,
            'avg_dtw_loss': avg_dtw_loss,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        epoch_pbar.set_postfix({
            'total': f'{avg_loss:.4f}',
            'recon': f'{avg_recon_loss:.4f}',
            'kl': f'{avg_kl_loss:.6f}',
            'dtw': f'{avg_dtw_loss:.4f}'
        })

        if epoch % config['vis_freq'] == 0 or epoch == 1:
            all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])

            fig = visualize_embeddings(
                all_z,
                epoch,
                cp_metadata=cp_metadata,
                save_path=os.path.join(visualization_dir, f'embeddings_epoch_{epoch}.png')
            )

            wandb.log({
                'embeddings': wandb.Image(fig),
                'epoch': epoch
            })
            plt.close(fig)

        # "Best" is judged on the average training loss of this epoch.
        if avg_loss < best_loss:
            best_loss = avg_loss
            _save_checkpoint(os.path.join(checkpoint_dir, 'best_model.pt'),
                             epoch, model, optimizer, scheduler, avg_loss, config)

        if epoch % config['save_freq'] == 0:
            _save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'),
                             epoch, model, optimizer, scheduler, avg_loss, config)

        scheduler.step()

    epoch_pbar.close()

    print("\nGenerating final embedding visualization...")
    all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])
    fig = visualize_embeddings(
        all_z,
        config['num_epochs'],
        cp_metadata=cp_metadata,
        save_path=os.path.join(visualization_dir, 'embeddings_final.png')
    )
    wandb.log({'final_embeddings': wandb.Image(fig)})
    plt.close(fig)

    wandb.finish()
    print("Training complete!")


if __name__ == '__main__':
    main()
