"""
Test script to pass a revoxelized robot config through a VAE.

This script:
1. Loads a revoxelized robot config (numpy array)
2. Converts it to the format expected by the VAE (one-hot encoding)
3. Encodes it to get the latent representation
4. Decodes it back to get the reconstruction
5. Visualizes the original and reconstructed robot
"""

import argparse
import os
import sys

import numpy as np
import torch as t
from matplotlib import pyplot as plt

# Add project root to path.
# Prepending (index 0) lets the project-local imports below (``model.*`` and
# ``utils.*``) resolve ahead of any same-named installed packages. The three
# ``dirname`` calls on this file's absolute path walk up to the directory two
# levels above the one containing this script (assumed to be the repo root).
sys.path.insert(
    0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

from model.vae_old.star_vae import StarVAE
from utils.plot import filter_surface_voxels, plot_binary, plot_colored_rigid_by_id


def grid_to_onehot(grid: np.ndarray, num_classes: int) -> t.Tensor:
    """
    Convert integer label grid to one-hot encoded tensor.

    Args:
        grid: Integer array of shape [X, Y, Z] with values 0 to num_classes-1
        num_classes: Number of classes (channels in one-hot encoding)

    Returns:
        Tensor of shape [1, num_classes, X, Y, Z] (batch size 1)
    """
    # Clip values to valid range
    grid = np.clip(grid, 0, num_classes - 1)

    # Create one-hot encoding
    one_hot = np.zeros((num_classes,) + grid.shape, dtype=np.float32)
    for c in range(num_classes):
        one_hot[c] = (grid == c).astype(np.float32)

    # Add batch dimension
    return t.tensor(one_hot).unsqueeze(0)


def onehot_to_grid(one_hot: t.Tensor) -> np.ndarray:
    """
    Convert one-hot (or per-class score) tensor back to an integer label grid.

    Each voxel is assigned the index of its largest channel, so raw logits
    work as well as exact one-hot input.

    Args:
        one_hot: Tensor of shape [1, num_classes, X, Y, Z] or
            [num_classes, X, Y, Z].

    Returns:
        Integer array of shape [X, Y, Z].

    Raises:
        ValueError: If a 5-D tensor has a batch size other than 1, or if the
            tensor is neither 4-D nor 5-D. Without this check the argmax would
            silently run over the batch axis (or a spatial axis) and return a
            meaningless grid.
    """
    if one_hot.dim() == 5:
        if one_hot.shape[0] != 1:
            raise ValueError(
                f"Expected a batch size of 1, got {one_hot.shape[0]}; "
                "convert one sample at a time"
            )
        one_hot = one_hot[0]  # Remove batch dimension
    elif one_hot.dim() != 4:
        raise ValueError(
            "Expected a tensor of shape [1, C, X, Y, Z] or [C, X, Y, Z], "
            f"got {tuple(one_hot.shape)}"
        )

    # Argmax along the channel dimension gives the class label per voxel.
    return t.argmax(one_hot, dim=0).cpu().numpy()


def visualize_comparison(
    original_grid: np.ndarray,
    reconstructed_grid: np.ndarray,
    M: int,
    title: str = "VAE Reconstruction Comparison",
):
    """
    Show a 2x2 figure comparing the original and the reconstructed robot.

    The top row plots ``original_grid`` and the bottom row plots
    ``reconstructed_grid``. In each row, the left panel shows soft voxels
    (label 1). The right panel shows rigid voxels (label >= 2), coloured by
    segment id (label - 1).

    Args:
        original_grid: Original voxel grid
        reconstructed_grid: Reconstructed voxel grid from VAE
        M: Maximum rigid segment ID (currently not used by this function)
        title: Plot title
    """
    # Colors for rigid segments, passed as-is to plot_colored_rigid_by_id.
    palette = [
        (0.8, 0.2, 0.2, 1.0),  # Red
        (0.2, 0.8, 0.2, 1.0),  # Green
        (0.2, 0.2, 0.8, 1.0),  # Blue
        (0.8, 0.8, 0.2, 1.0),  # Yellow
        (0.8, 0.2, 0.8, 1.0),  # Magenta
        (0.2, 0.8, 0.8, 1.0),  # Cyan
        (0.9, 0.5, 0.1, 1.0),  # Orange
        (0.5, 0.1, 0.9, 1.0),  # Purple
    ]
    soft_rgba = (0.5, 0.8, 1.0, 0.8)

    # One entry per figure row: (soft-panel title, rigid-panel title, grid).
    rows = (
        ("Original - Soft Voxels", "Original - Rigid Segments", original_grid),
        (
            "Reconstructed - Soft Voxels",
            "Reconstructed - Rigid Segments",
            reconstructed_grid,
        ),
    )

    figure = plt.figure(figsize=(16, 12))
    figure.suptitle(title, fontsize=14)

    for row_index, (soft_title, rigid_title, voxels) in enumerate(rows):
        left_slot = 2 * row_index + 1  # subplot indices are 1-based, row-major

        soft_ax = figure.add_subplot(2, 2, left_slot, projection="3d")
        soft_ax.set_title(soft_title)
        plot_binary(soft_ax, voxels == 1, color=soft_rgba)

        rigid_ax = figure.add_subplot(2, 2, left_slot + 1, projection="3d")
        rigid_ax.set_title(rigid_title)
        rigid_mask = voxels >= 2
        segment_ids = np.where(rigid_mask, voxels - 1, 0)
        plot_colored_rigid_by_id(rigid_ax, rigid_mask, segment_ids, palette)

    plt.tight_layout()
    plt.show()


def _parse_args() -> argparse.Namespace:
    """Define and parse this script's command-line interface."""
    parser = argparse.ArgumentParser(
        description="Test VAE encoding/decoding on revoxelized robot config"
    )
    parser.add_argument(
        "voxel_path",
        type=str,
        help="Path to the revoxelized voxel grid (.npy file)",
    )
    parser.add_argument(
        "--vae-checkpoint",
        type=str,
        required=True,
        help="Path to the VAE checkpoint file",
    )
    parser.add_argument(
        "-M",
        "--max-rigid-id",
        type=int,
        default=7,
        help="Maximum rigid segment ID used in revoxelization",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run VAE on",
    )
    parser.add_argument(
        "--no-visualize",
        action="store_true",
        help="Skip visualization",
    )
    return parser.parse_args()


def _resize_to_grid_size(grid: np.ndarray, grid_size: int) -> np.ndarray:
    """Nearest-neighbour resize ``grid`` so that every axis has ``grid_size`` voxels."""
    from scipy.ndimage import zoom

    target_shape = (grid_size,) * grid.ndim
    # One zoom factor per axis. A single scalar factor derived from the first
    # axis only leaves non-cubic grids with the wrong size on the other axes.
    factors = [target / current for target, current in zip(target_shape, grid.shape)]
    # order=0 is nearest-neighbour interpolation, so no new (blended) label
    # values are invented.
    return zoom(grid, factors, order=0)


def main():
    """
    Run one encode/decode round trip through a StarVAE and report the result.

    Steps:

    1. Load the ``.npy`` voxel grid given on the command line.
    2. Resize it (nearest neighbour) when its shape does not match the VAE's
       ``grid_size`` on every axis.
    3. One-hot encode it with ``max_num_nodes + 2`` channels.
    4. Encode, sample a latent with ``vae.rsample``, and decode.
    5. Print reconstruction statistics and, unless ``--no-visualize`` is
       given, plot the original and reconstruction side by side.
    """
    args = _parse_args()

    print("=" * 60)
    print("VAE Test - Robot Config Encoding/Decoding")
    print("=" * 60)
    print(f"Voxel grid: {args.voxel_path}")
    print(f"VAE checkpoint: {args.vae_checkpoint}")
    print(f"Max rigid ID (M): {args.max_rigid_id}")
    print(f"Device: {args.device}")

    # Load the revoxelized grid
    print("\nLoading voxel grid...")
    grid = np.load(args.voxel_path)
    print(f"  Grid shape: {grid.shape}")
    print(f"  Value range: [{grid.min()}, {grid.max()}]")
    print(f"  Unique values: {np.unique(grid)}")

    # Load VAE
    print("\nLoading VAE...")
    device = args.device
    if device == "cuda" and not t.cuda.is_available():
        print("  CUDA not available, falling back to CPU")
        device = "cpu"

    vae = StarVAE.load_from_checkpoint(args.vae_checkpoint, map_location=device)
    vae.eval()
    print("  VAE loaded successfully")
    print(f"  Latent dimension: {vae.hparams.e_dim}")
    print(f"  Grid size expected: {vae.hparams.grid_size}")
    print(f"  Max num nodes: {vae.hparams.max_num_nodes}")

    # Check grid size compatibility on every axis, not just the first one.
    expected_grid_size = vae.hparams.grid_size
    expected_shape = (expected_grid_size,) * grid.ndim
    if grid.shape != expected_shape:
        print("\n  WARNING: Grid size mismatch!")
        print(f"    Input grid: {grid.shape}")
        print(f"    VAE expects: {expected_shape}")
        print("  Resizing grid to match VAE...")
        grid = _resize_to_grid_size(grid, expected_grid_size)
        print(f"  Resized grid shape: {grid.shape}")

    # Number of classes for one-hot encoding
    # VAE expects f_dim = max_num_nodes + 2 channels
    num_classes = vae.hparams.max_num_nodes + 2
    print(f"\nConverting to one-hot encoding ({num_classes} classes)...")

    # grid_to_onehot clips labels into range; make that visible rather than
    # silently merging out-of-range labels into a boundary class.
    if grid.min() < 0 or grid.max() > num_classes - 1:
        print(
            f"  WARNING: labels outside [0, {num_classes - 1}] will be clipped "
            f"(grid range [{grid.min()}, {grid.max()}])"
        )

    # Convert to one-hot
    input_tensor = grid_to_onehot(grid, num_classes).to(device)
    print(f"  Input tensor shape: {input_tensor.shape}")

    # Encode
    print("\nEncoding...")
    with t.no_grad():
        mu, logvar = vae.encode(input_tensor)
        print(f"  Latent mu shape: {mu.shape}")
        print(f"  Latent mu range: [{mu.min().item():.4f}, {mu.max().item():.4f}]")

        # Sample from latent distribution
        z = vae.rsample(mu, logvar)
        print(f"  Sampled z shape: {z.shape}")

        # Decode
        print("\nDecoding...")
        output_tensor = vae.decode(z)
        print(f"  Output tensor shape: {output_tensor.shape}")

    # Convert back to grid
    reconstructed_grid = onehot_to_grid(output_tensor)
    print(f"\nReconstructed grid shape: {reconstructed_grid.shape}")
    print(
        f"Reconstructed value range: [{reconstructed_grid.min()}, {reconstructed_grid.max()}]"
    )
    print(f"Reconstructed unique values: {np.unique(reconstructed_grid)}")

    # Compute reconstruction statistics. Voxel-wise comparison is only
    # meaningful (and only well-defined) when both grids have the same shape.
    if reconstructed_grid.shape != grid.shape:
        print(
            f"\nWARNING: reconstructed shape {reconstructed_grid.shape} differs from "
            f"input shape {grid.shape}; skipping voxel-wise statistics"
        )
    else:
        original_mask = grid > 0
        recon_mask = reconstructed_grid > 0
        original_occupied = np.sum(original_mask)
        recon_occupied = np.sum(recon_mask)
        matching = np.sum(grid == reconstructed_grid)
        accuracy = matching / grid.size * 100
        # Whole-grid accuracy can be inflated by empty voxels when most of the
        # grid is empty; IoU of the occupied region scores the shape itself.
        union = np.sum(original_mask | recon_mask)
        occupied_iou = (
            np.sum(original_mask & recon_mask) / union * 100 if union else 100.0
        )

        print("\nReconstruction Statistics:")
        print(f"  Original occupied voxels: {original_occupied}")
        print(f"  Reconstructed occupied voxels: {recon_occupied}")
        print(f"  Matching voxels: {matching} ({accuracy:.2f}%)")
        print(f"  Occupied-voxel IoU: {occupied_iou:.2f}%")

    # Visualize
    if not args.no_visualize:
        print("\nVisualizing...")
        visualize_comparison(grid, reconstructed_grid, args.max_rigid_id)

    print("\nDone!")


# Run the CLI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
