"""
Training script for grasp trajectory encoder/decoder with reconstruction, KL, and DTW losses.

Adapted from close_drawer encoder for the pick_up_cup grasp task.
The encoder learns to map trajectories to a 2D latent space representing (approach_angle, grasp_height).
"""
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb

# Set matplotlib backend to Agg (non-interactive) before importing pyplot
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm
import argparse

# Local imports
from trajectory_encoder import TrajectoryVAE
from soft_dtw import SoftDTW
from full_trajectory_dataset import FullTrajectoryDataset


def compute_reconstruction_loss(pred_states, pred_actions, gt_states, gt_actions):
    """
    Sum of MSE reconstruction errors for actions and (optionally) states.

    Predicted states are compared against the ground-truth states with the
    first timestep dropped (``gt_states[:, 1:, :]``). When ``pred_states`` is
    None, only the action term contributes (plus a zero tensor on the same
    device as the actions).
    """
    total = F.mse_loss(pred_actions, gt_actions)
    if pred_states is None:
        return total + torch.tensor(0.0, device=pred_actions.device)
    return total + F.mse_loss(pred_states, gt_states[:, 1:, :])


def compute_kl_loss(z_mean, z_logvar):
    """
    Batch-averaged KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior.

    The closed-form KL is summed over dim 1 (latent dimensions) for each
    sample and then averaged over the batch.
    """
    per_dim = 1 + z_logvar - z_mean.pow(2) - z_logvar.exp()
    per_sample = per_dim.sum(dim=1).mul(-0.5)
    return per_sample.mean()


def make_trajectory_relative(states, actions):
    """
    Express trajectories relative to their first state s0.

    For each batch element, the s0 value is subtracted from state dims 0:7,
    7:14 and 15:18 over all timesteps. The s0 values of state dims 0:7 are
    also subtracted from action dims 0:7. Every other dimension is copied
    unchanged, and the inputs themselves are not modified.

    (By the original variable names these slices are presumably joint
    positions, joint velocities and gripper xyz — confirm against the dataset.)
    """
    states_rel = states.clone()
    actions_rel = actions.clone()

    for lo, hi in ((0, 7), (7, 14), (15, 18)):
        states_rel[:, :, lo:hi] = states[:, :, lo:hi] - states[:, 0:1, lo:hi]

    actions_rel[:, :, 0:7] = actions[:, :, 0:7] - states[:, 0:1, 0:7]
    return states_rel, actions_rel


def extract_endeffector_position(states, relative=False):
    """
    Return the end-effector position slice (state dims 15:18).

    If ``relative`` is True, each trajectory's first-timestep position is
    subtracted, so every returned path starts at the origin.
    """
    positions = states[:, :, 15:18]
    return positions - positions[:, 0:1, :] if relative else positions


def compute_dtw_loss(z, states, gamma=0.00001):
    """
    Encourage latent distances to mirror end-effector trajectory distances.

    Each sample is paired with the sample obtained by rolling the batch by one
    position. Two distances are computed for each pair:
    - the Euclidean distance between the paired latents (with a 1e-8
      stabilizer inside the sqrt);
    - the normalized soft-DTW distance between the paired relative
      end-effector paths.

    Each distance vector is divided by its batch maximum, and the MSE between
    the two is returned. Batches with fewer than two samples yield a zero loss
    that still requires grad.
    """
    if z.shape[0] < 2:
        return torch.tensor(0.0, device=z.device, requires_grad=True)

    ee_path = extract_endeffector_position(states, relative=True)
    ee_path_shifted = torch.roll(ee_path, 1, 0)
    z_shifted = torch.roll(z, 1, 0)

    latent_dist = torch.sqrt(((z - z_shifted) ** 2).sum(dim=1) + 1e-8)
    path_dist = SoftDTW(gamma=gamma, normalize=True)(ee_path, ee_path_shifted)

    latent_target = latent_dist / (latent_dist.max() + 1e-8)
    path_target = path_dist / (path_dist.max() + 1e-8)
    return F.mse_loss(latent_target, path_target)


def load_grasp_metadata(data_path):
    """
    Load per-episode grasp metadata (mode, approach angle, grasp height).

    Sub-folders whose names start with 'episode' are visited in sorted
    (lexicographic) order. Each one's ``metadata.npy`` (a pickled dict) is
    read when present. Each field is collected only for episodes whose dict
    contains that key.

    Callers match these arrays to dataset trajectories by position, so a
    warning is printed whenever the number of approach angles or grasp heights
    differs from the number of episode folders. In that case the positional
    alignment can no longer be trusted.

    Args:
        data_path: Path to the episodes directory. If it does not exist and the
            path contains 'processed', ``<grandparent>/train/episodes`` is
            tried instead.

    Returns:
        dict with 'modes', 'approach_angles', 'grasp_heights' (numpy arrays)
        and 'num_episodes' (number of episode folders found), or None if the
        directory is missing, no approach angles were found, or an error
        occurred.
    """
    try:
        episodes_dir = data_path
        # Fall back from a processed-dataset path to the raw episode folder.
        if not os.path.exists(episodes_dir) and 'processed' in data_path:
            base_dir = os.path.dirname(os.path.dirname(data_path))
            episodes_dir = os.path.join(base_dir, 'train', 'episodes')

        if not os.path.exists(episodes_dir):
            print(f"Episodes directory not found: {episodes_dir}")
            return None

        episode_folders = sorted(d for d in os.listdir(episodes_dir) if d.startswith('episode'))

        modes = []
        approach_angles = []
        grasp_heights = []
        num_with_metadata = 0

        for ep_folder in episode_folders:
            metadata_path = os.path.join(episodes_dir, ep_folder, 'metadata.npy')
            if not os.path.exists(metadata_path):
                continue

            # metadata.npy stores a pickled dict, so allow_pickle is required;
            # only point this at trusted, locally generated datasets.
            metadata = np.load(metadata_path, allow_pickle=True).item()
            num_with_metadata += 1

            if 'mode' in metadata:
                modes.append(metadata['mode'])
            if 'approach_angle' in metadata:
                approach_angles.append(metadata['approach_angle'])
            if 'grasp_height' in metadata:
                grasp_heights.append(metadata['grasp_height'])

        if len(approach_angles) == 0:
            print("No grasp mode data found in episode metadata")
            return None

        # Missing files or keys shift every later entry relative to the
        # episode order, so make that visible instead of silently misaligning.
        num_folders = len(episode_folders)
        if len(approach_angles) != num_folders or len(grasp_heights) != num_folders:
            print(f"Warning: grasp metadata incomplete - {num_folders} episode folders, "
                  f"{num_with_metadata} with metadata.npy, "
                  f"{len(approach_angles)} approach angles, "
                  f"{len(grasp_heights)} grasp heights; "
                  f"per-episode alignment may be off")

        return {
            'modes': np.array(modes),
            'approach_angles': np.array(approach_angles),
            'grasp_heights': np.array(grasp_heights),
            'num_episodes': num_folders
        }

    except Exception as e:
        # Best-effort: metadata is optional, so report the failure and let the
        # caller fall back to returning None instead of aborting.
        print(f"Error loading grasp metadata: {e}")
        import traceback
        traceback.print_exc()
        return None


def visualize_embeddings(z, epoch, grasp_metadata=None, save_path=None):
    """
    Plot embeddings, colored by approach angle when grasp metadata is usable.

    For the grasp task, z is expected to form a ring-like structure:
    - the angular position on the ring should track the approach angle;
    - points with the same approach angle should cluster together.

    With metadata, three panels are drawn: the Cartesian z-space, a polar view
    (z-space angle vs. radius), and z-space angle vs. approach angle. Without
    metadata, the first panel is colored by trajectory index and the other two
    show a placeholder message.

    Args:
        z: (N, latent_dim) numpy array of embeddings. If latent_dim > 2 it is
            reduced to 2D with t-SNE before plotting.
        epoch: Epoch number shown in the titles.
        grasp_metadata: Optional dict with 'approach_angles' in radians, which
            are matched to embeddings by position. It is used only if it holds
            at least N angles.
        save_path: If given, the figure is also saved to this path.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    if z.shape[1] > 2:
        from sklearn.manifold import TSNE
        print(f"Applying t-SNE to reduce {z.shape[1]}D embeddings to 2D...")
        z_vis = TSNE(n_components=2, random_state=42, perplexity=min(30, len(z) - 1)).fit_transform(z)
        title_suffix = f' (t-SNE from {z.shape[1]}D)'
    else:
        z_vis = z
        title_suffix = ''

    # Three panels: z-space, polar z-space, and angle correspondence
    fig, (ax_main, ax_polar, ax_angle) = plt.subplots(1, 3, figsize=(18, 5))

    if grasp_metadata is not None and len(grasp_metadata['approach_angles']) >= len(z_vis):
        angles_deg = np.degrees(grasp_metadata['approach_angles'][:len(z_vis)])

        # Shared styling for every angle-colored scatter; hsv is a cyclic
        # colormap so 0 and 360 degrees get the same color.
        angle_style = dict(c=angles_deg, cmap='hsv', s=60, alpha=0.8,
                           edgecolors='black', linewidths=0.5,
                           vmin=0, vmax=360)

        # === Plot 1: z-space (Cartesian) ===
        scatter1 = ax_main.scatter(z_vis[:, 0], z_vis[:, 1], **angle_style)
        cbar1 = plt.colorbar(scatter1, ax=ax_main)
        cbar1.set_label('Approach Angle (degrees)', fontsize=10)

        ax_main.set_xlabel('z[0]', fontsize=11)
        ax_main.set_ylabel('z[1]', fontsize=11)
        ax_main.set_title(f'z-Space Embeddings (Epoch {epoch}){title_suffix}', fontsize=12)
        ax_main.grid(True, alpha=0.3)
        ax_main.set_aspect('equal')

        # Reference circle (centered at the z-space origin) at the mean radius
        z_radius = np.sqrt(z_vis[:, 0]**2 + z_vis[:, 1]**2)
        avg_radius = np.mean(z_radius)
        theta_ref = np.linspace(0, 2*np.pi, 100)
        ax_main.plot(avg_radius * np.cos(theta_ref), avg_radius * np.sin(theta_ref),
                     'k--', alpha=0.3, linewidth=1, label=f'Avg radius={avg_radius:.2f}')
        ax_main.legend(loc='upper right', fontsize=8)

        # === Plot 2: Polar view of z-space ===
        z_angles_deg = np.degrees(np.arctan2(z_vis[:, 1], z_vis[:, 0])) % 360

        ax_polar.scatter(z_angles_deg, z_radius, **angle_style)
        ax_polar.set_xlabel('z-Space Angle (degrees)', fontsize=11)
        ax_polar.set_ylabel('z-Space Radius', fontsize=11)
        ax_polar.set_title('Polar View of z-Space', fontsize=12)
        ax_polar.set_xlim(-10, 370)
        ax_polar.grid(True, alpha=0.3)

        # === Plot 3: z-angle vs approach angle (linear for a good embedding) ===
        ax_angle.scatter(angles_deg, z_angles_deg, **angle_style)
        ax_angle.plot([0, 360], [0, 360], 'k--', alpha=0.5, label='Ideal 1:1')
        ax_angle.set_xlabel('Approach Angle (degrees)', fontsize=11)
        ax_angle.set_ylabel('z-Space Angle (degrees)', fontsize=11)
        ax_angle.set_title('Angle Correspondence', fontsize=12)
        ax_angle.set_xlim(-10, 370)
        ax_angle.set_ylim(-10, 370)
        ax_angle.grid(True, alpha=0.3)
        ax_angle.legend(loc='upper left', fontsize=8)
        ax_angle.set_aspect('equal')

    else:
        # Default coloring by index
        scatter = ax_main.scatter(z_vis[:, 0], z_vis[:, 1], c=np.arange(len(z_vis)),
                                  cmap='viridis', s=50, alpha=0.7,
                                  edgecolors='black', linewidths=0.5)
        plt.colorbar(scatter, ax=ax_main, label='Trajectory Index')
        ax_main.set_xlabel('z[0]', fontsize=11)
        ax_main.set_ylabel('z[1]', fontsize=11)
        ax_main.set_title(f'Trajectory Embeddings (Epoch {epoch}){title_suffix}', fontsize=12)
        ax_main.grid(True, alpha=0.3)

        ax_polar.text(0.5, 0.5, 'No grasp metadata available',
                      transform=ax_polar.transAxes, ha='center', va='center')
        ax_angle.text(0.5, 0.5, 'No grasp metadata available',
                      transform=ax_angle.transAxes, ha='center', va='center')

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig


def train_epoch(model, dataloader, optimizer, device, epoch, log_freq=50, recon_weight=1.0,
                kl_weight=0.001, dtw_weight=0.1, dtw_gamma=0.00001):
    """
    Train ``model`` for one pass over ``dataloader``.

    Each batch is made relative to its first state, encoded/decoded, and
    optimized on ``recon_weight * recon + kl_weight * kl + dtw_weight * dtw``.
    Gradients are clipped to norm 1.0. Batch losses are sent to wandb every
    ``log_freq`` batches.

    Args:
        dtw_gamma: Soft-DTW smoothing passed to ``compute_dtw_loss``.

    Returns:
        Tuple (total, recon, kl, dtw) of losses averaged over the batches
        actually trained on, or all zeros if every batch was skipped.
    """
    model.train()

    sums = {'total': 0.0, 'recon': 0.0, 'kl': 0.0, 'dtw': 0.0}
    num_batches = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)

    for batch_idx, batch in enumerate(pbar):
        actions = batch.actions.to(device)
        states = batch.conditions['state'].to(device)

        # Batches with fewer state steps than action steps are skipped.
        if states.shape[1] < actions.shape[1]:
            continue

        states_rel, actions_rel = make_trajectory_relative(states, actions)
        outputs = model(states_rel, actions_rel)

        recon_loss = compute_reconstruction_loss(
            outputs['pred_states'], outputs['pred_actions'], states_rel, actions_rel)
        kl_loss = compute_kl_loss(outputs['z_mean'], outputs['z_logvar'])
        # The DTW term is applied to the posterior mean, not the sampled z.
        dtw_loss = compute_dtw_loss(outputs['z_mean'], states_rel, gamma=dtw_gamma)

        loss = recon_weight * recon_loss + kl_weight * kl_loss + dtw_weight * dtw_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Read each scalar once: every .item() forces a host-device sync.
        loss_val = loss.item()
        recon_val = recon_loss.item()
        kl_val = kl_loss.item()
        dtw_val = dtw_loss.item()

        sums['total'] += loss_val
        sums['recon'] += recon_val
        sums['kl'] += kl_val
        sums['dtw'] += dtw_val
        num_batches += 1

        pbar.set_postfix({
            'total': f'{loss_val:.4f}',
            'recon': f'{recon_val:.4f}',
            'kl': f'{kl_val:.6f}',
            'dtw': f'{dtw_val:.4f}'
        })

        if batch_idx % log_freq == 0:
            wandb.log({
                'batch_total_loss': loss_val,
                'batch_recon_loss': recon_val,
                'batch_kl_loss': kl_val,
                'batch_dtw_loss': dtw_val,
                'epoch': epoch,
                'batch': batch_idx
            })

    if num_batches == 0:
        return 0, 0, 0, 0

    return tuple(sums[key] / num_batches for key in ('total', 'recon', 'kl', 'dtw'))


def collect_embeddings(model, dataset, device, batch_size=256):
    """
    Encode every trajectory in ``dataset`` for visualization.

    The model is switched to eval mode and batches are read without shuffling,
    so embeddings come out in dataset order. Batches with fewer state steps
    than action steps are skipped, which shifts that order.

    Returns:
        numpy array of the concatenated ``model.encode`` outputs.

    Raises:
        ValueError: if no batch was encoded (empty dataset or all batches
            skipped), since there is nothing to concatenate.
    """
    model.eval()

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        # Pinned memory only helps host-to-GPU copies; PyTorch warns when it
        # is requested without an accelerator.
        pin_memory=str(device).startswith('cuda')
    )

    all_z = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Collecting embeddings", leave=False):
            actions = batch.actions.to(device)
            states = batch.conditions['state'].to(device)

            if states.shape[1] < actions.shape[1]:
                continue

            states_rel, actions_rel = make_trajectory_relative(states, actions)

            z = model.encode(states_rel, actions_rel)
            all_z.append(z.cpu().numpy())

    if not all_z:
        raise ValueError(
            "collect_embeddings: no batches were encoded "
            "(dataset empty or every batch skipped by the state/action length check)")

    return np.concatenate(all_z, axis=0)


def main():
    """
    Parse CLI arguments, train the grasp trajectory VAE, and log to wandb.

    Checkpoints and embedding plots are written under
    ``<save_dir>/z<latent_dim>/{checkpoints,visualizations}``.
    """
    parser = argparse.ArgumentParser(description='Train Grasp Trajectory Encoder VAE')
    parser.add_argument('--horizon', type=int, default=108, help='Horizon length for trajectories (108 for grasp task)')
    parser.add_argument('--latent_dim', type=int, default=2, help='Latent dimension')
    parser.add_argument('--hidden_dim', type=int, default=128, help='Hidden dimension')
    parser.add_argument('--num_layers', type=int, default=4, help='Number of transformer layers')
    parser.add_argument('--num_heads', type=int, default=4, help='Number of attention heads')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=5000, help='Number of epochs')
    parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--recon_weight', type=float, default=1.0, help='Reconstruction loss weight')
    parser.add_argument('--kl_weight', type=float, default=0.001, help='KL loss weight')
    parser.add_argument('--dtw_weight', type=float, default=1.0, help='DTW loss weight')
    parser.add_argument('--vis_freq', type=int, default=100, help='Visualization frequency (epochs)')
    parser.add_argument('--save_freq', type=int, default=500, help='Checkpoint save frequency (epochs)')
    parser.add_argument('--dataset_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/grasp/variation0/processed/train_normalized.npz',
                        help='Path to training dataset')
    parser.add_argument('--episodes_path', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/grasp/variation0/train/episodes',
                        help='Path to episodes directory (for grasp metadata)')
    parser.add_argument('--save_dir', type=str,
                        default='/scratch4/workspace/placeholder-hdp1/dppo/data/grasp/variation0/encoder',
                        help='Directory to save checkpoints and visualizations')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to use')
    parser.add_argument('--wandb_project', type=str, default='grasp-encoder', help='Wandb project name')

    args = parser.parse_args()

    config = {
        'dataset_path': args.dataset_path,
        'episodes_path': args.episodes_path,
        'state_dim': 22,
        'action_dim': 8,
        'horizon': args.horizon,
        'latent_dim': args.latent_dim,
        'hidden_dim': args.hidden_dim,
        'num_layers': args.num_layers,
        'num_heads': args.num_heads,
        'batch_size': args.batch_size,
        'num_epochs': args.num_epochs,
        'learning_rate': args.lr,
        'recon_weight': args.recon_weight,
        'kl_weight': args.kl_weight,
        'dtw_weight': args.dtw_weight,
        'device': args.device if torch.cuda.is_available() else 'cpu',
        'log_freq': 10,
        'vis_freq': args.vis_freq,
        'save_freq': args.save_freq,
        'save_dir': args.save_dir
    }

    os.makedirs(config['save_dir'], exist_ok=True)
    latent_dir = os.path.join(config['save_dir'], f'z{config["latent_dim"]}')
    checkpoint_dir = os.path.join(latent_dir, 'checkpoints')
    visualization_dir = os.path.join(latent_dir, 'visualizations')
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(visualization_dir, exist_ok=True)

    wandb.init(
        project=args.wandb_project,
        name=f'grasp_vae_z{config["latent_dim"]}_h{config["hidden_dim"]}',
        config=config
    )

    device = torch.device(config['device'])
    print(f"Using device: {device}")

    print(f"Loading dataset from {config['dataset_path']}")
    dataset = FullTrajectoryDataset(
        dataset_path=config['dataset_path'],
        horizon_steps=config['horizon'],
        device='cpu',
        max_n_episodes=None
    )

    print(f"Dataset size: {len(dataset)} trajectories")

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        # Pinning only helps host-to-GPU copies and warns on CPU-only runs.
        pin_memory=device.type == 'cuda'
    )

    model = TrajectoryVAE(
        state_dim=config['state_dim'],
        action_dim=config['action_dim'],
        hidden_dim=config['hidden_dim'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        latent_dim=config['latent_dim'],
        horizon=config['horizon']
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=1e-5
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['num_epochs'],
        eta_min=1e-6
    )

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    def save_checkpoint(path, epoch, loss):
        """Write model/optimizer state, epoch, loss and config to ``path``."""
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'config': config
        }, path)

    # Load grasp metadata for visualization
    grasp_metadata = load_grasp_metadata(config['episodes_path'])
    if grasp_metadata is not None:
        print(f"Loaded grasp metadata for {grasp_metadata['num_episodes']} episodes")
        # Round before casting/deduplicating: radians->degrees and m->mm
        # conversions leave float artifacts (e.g. 119.999...) that int() would
        # truncate or that would show up as spurious near-duplicates.
        unique_angles = np.unique(np.round(np.degrees(grasp_metadata['approach_angles'])).astype(int))
        unique_heights = np.unique(np.round(grasp_metadata['grasp_heights'] * 1000).astype(int))
        print(f"  Unique angles: {unique_angles} degrees")
        print(f"  Unique heights: {unique_heights} mm")
    else:
        print("Grasp metadata not available - using default visualization")

    best_loss = float('inf')
    epoch_pbar = tqdm(range(1, config['num_epochs'] + 1), desc="Training", position=0)

    for epoch in epoch_pbar:
        avg_loss, avg_recon_loss, avg_kl_loss, avg_dtw_loss = train_epoch(
            model, dataloader, optimizer, device, epoch,
            log_freq=config['log_freq'],
            recon_weight=config['recon_weight'],
            kl_weight=config['kl_weight'],
            dtw_weight=config['dtw_weight']
        )

        wandb.log({
            'epoch': epoch,
            'avg_total_loss': avg_loss,
            'avg_recon_loss': avg_recon_loss,
            'avg_kl_loss': avg_kl_loss,
            'avg_dtw_loss': avg_dtw_loss,
            'learning_rate': optimizer.param_groups[0]['lr']
        })

        epoch_pbar.set_postfix({
            'total': f'{avg_loss:.4f}',
            'recon': f'{avg_recon_loss:.4f}',
            'kl': f'{avg_kl_loss:.6f}',
            'dtw': f'{avg_dtw_loss:.4f}'
        })

        if epoch % config['vis_freq'] == 0 or epoch == 1:
            all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])

            fig = visualize_embeddings(
                all_z,
                epoch,
                grasp_metadata=grasp_metadata,
                save_path=os.path.join(visualization_dir, f'embeddings_epoch_{epoch}.png')
            )

            wandb.log({
                'embeddings': wandb.Image(fig),
                'epoch': epoch
            })
            plt.close(fig)

        if avg_loss < best_loss:
            best_loss = avg_loss
            save_checkpoint(os.path.join(checkpoint_dir, 'best_model.pt'), epoch, avg_loss)

        if epoch % config['save_freq'] == 0:
            save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt'), epoch, avg_loss)

        scheduler.step()

    epoch_pbar.close()

    print("\nGenerating final embedding visualization...")
    all_z = collect_embeddings(model, dataset, device, batch_size=config['batch_size'])
    fig = visualize_embeddings(
        all_z,
        config['num_epochs'],
        grasp_metadata=grasp_metadata,
        save_path=os.path.join(visualization_dir, 'embeddings_final.png')
    )
    wandb.log({'final_embeddings': wandb.Image(fig)})
    plt.close(fig)

    wandb.finish()
    print("Training complete!")


if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()
