#!/usr/bin/env python3
"""
decode_robot_latent.py

Load a latent vector of shape (e_dim,), e.g. (512,), from a .npy file (produced by
encode_robot_config.py) and decode it with the VAE to reconstruct the robot voxel grid.
"""

import argparse
import os
import sys
import numpy as np
import torch as t
from matplotlib import pyplot as plt

# Add project root to path
sys.path.insert(
    0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

from model.vae_old.star_vae import StarVAE
from utils.plot import plot_binary, plot_colored_rigid_by_id


def onehot_to_grid(one_hot: t.Tensor) -> np.ndarray:
    """Convert a one-hot (or per-class score) tensor to an integer class grid.

    The leading class axis is reduced with ``argmax``, so every output cell
    holds the index of the highest-valued class at that position.

    Args:
        one_hot: Tensor whose first axis is the class axis (e.g.
            ``[C, N, N, N]``), optionally preceded by a batch axis of size 1
            (``[1, C, N, N, N]``).

    Returns:
        NumPy integer array with the class axis removed.

    Raises:
        ValueError: If a 5-D tensor is passed whose batch axis is not of
            size 1.
    """
    if one_hot.dim() == 5:
        batch_size = one_hot.shape[0]
        if batch_size != 1:
            # squeeze(0) would be a no-op for batch_size > 1, and the argmax
            # below would then reduce the batch axis instead of the class
            # axis, silently returning a meaningless grid.
            raise ValueError(
                "onehot_to_grid expects a single sample; "
                f"got batch size {batch_size} (shape {tuple(one_hot.shape)})"
            )
        one_hot = one_hot[0]
    return t.argmax(one_hot, dim=0).cpu().numpy()


def hide_axes_and_grid(ax):
    """Strip all decoration from a 3D axis so only the plotted data remains.

    Disables the grid, blanks the axis labels, clears the ticks, makes the
    panes and axis lines invisible, switches the axis off, and finally fixes a
    1:1:1 box aspect so the content is not deformed.
    """
    ax.grid(False)

    # Blank the labels and drop the ticks on all three dimensions.
    for set_label in (ax.set_xlabel, ax.set_ylabel, ax.set_zlabel):
        set_label("")
    for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        set_ticks([])

    # Fully transparent white, used to hide the axis lines.
    invisible = (1.0, 1.0, 1.0, 0.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        # Background pane: no fill and no outline.
        axis.pane.fill = False
        axis.pane.set_edgecolor("none")
        axis.line.set_color(invisible)

    ax.set_axis_off()
    ax.set_box_aspect([1, 1, 1])


def _fit_to_grid(ax, N):
    """Give ``ax`` a cubic box with limits [0, N] on every axis."""
    ax.set_box_aspect([1, 1, 1])  # Equal aspect ratio
    ax.set_xlim(0, N)
    ax.set_ylim(0, N)
    ax.set_zlim(0, N)


def _draw_soft(ax, grid, N, color, bare=False):
    """Plot the soft voxels (cells equal to 1) of ``grid`` on ``ax``."""
    plot_binary(ax, grid == 1, color=color)
    if bare:
        hide_axes_and_grid(ax)
    _fit_to_grid(ax, N)


def _draw_rigid(ax, grid, N, colors, bare=False):
    """Plot the rigid voxels (cells >= 2) of ``grid`` on ``ax``, colored by segment."""
    is_rigid = grid >= 2
    # Grid value 2 becomes segment 1, 3 becomes 2, ...; non-rigid cells get 0.
    segment_id = np.where(is_rigid, grid - 1, 0)
    plot_colored_rigid_by_id(ax, is_rigid, segment_id, colors)
    if bare:
        hide_axes_and_grid(ax)
    _fit_to_grid(ax, N)


def _save_single_panel(draw, path, label):
    """Render ``draw(ax)`` on a fresh 8x8 3D figure, save it to ``path``, and close it."""
    panel = plt.figure(figsize=(8, 8))
    try:
        draw(panel.add_subplot(111, projection="3d"))
        # Save through the figure object rather than pyplot's "current figure".
        panel.savefig(path, dpi=300, bbox_inches="tight", transparent=True)
        print(f"Saved {label} to: {path}")
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(panel)


def visualize_reconstructed(
    reconstructed_grid: np.ndarray,
    N: int,
    output_path: str = None,
    title: str = "VAE Decoded Robot",
):
    """Visualize the decoded robot as soft voxels and colored rigid segments.

    A two-panel figure is built (soft voxels on the left, rigid segments on
    the right) and shown interactively. When ``output_path`` is given, the
    combined figure is saved there, and two decoration-free single-panel
    images are saved next to it as ``<base>_soft.png`` and
    ``<base>_rigid.png``, where ``<base>`` is ``output_path`` without its
    extension.

    Args:
        reconstructed_grid: Integer voxel grid; cells equal to 1 are drawn as
            soft voxels, cells >= 2 as rigid segments.
        N: Upper axis limit applied to x, y and z (the caller passes the grid
            edge length).
        output_path: Destination of the combined image. ``None`` or an empty
            string disables saving.
        title: Figure title of the combined view.
    """
    segment_colors = [
        (0.8, 0.2, 0.2, 1.0),
        (0.2, 0.8, 0.2, 1.0),
        (0.2, 0.2, 0.8, 1.0),
        (0.8, 0.8, 0.2, 1.0),
        (0.8, 0.2, 0.8, 1.0),
        (0.2, 0.8, 0.8, 1.0),
        (0.9, 0.5, 0.1, 1.0),
        (0.5, 0.1, 0.9, 1.0),
    ]
    soft_color = (0.5, 0.8, 1.0, 0.8)

    fig = plt.figure(figsize=(14, 6))
    try:
        fig.suptitle(title, fontsize=14)

        # Soft voxels
        ax1 = fig.add_subplot(1, 2, 1, projection="3d")
        ax1.set_title("Soft Voxels (ID=1)")
        _draw_soft(ax1, reconstructed_grid, N, soft_color)

        # Rigid segments
        ax2 = fig.add_subplot(1, 2, 2, projection="3d")
        ax2.set_title("Rigid Segments (ID=2 to M)")
        _draw_rigid(ax2, reconstructed_grid, N, segment_colors)

        fig.tight_layout()

        if output_path:
            # Save main comparison figure
            os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
            fig.savefig(output_path, dpi=300, bbox_inches="tight", transparent=True)
            print(f"Saved visualization to: {output_path}")

            # Save each panel separately, without axes decoration
            base_path = os.path.splitext(output_path)[0]
            _save_single_panel(
                lambda ax: _draw_soft(
                    ax, reconstructed_grid, N, soft_color, bare=True
                ),
                f"{base_path}_soft.png",
                "soft voxels",
            )
            _save_single_panel(
                lambda ax: _draw_rigid(
                    ax, reconstructed_grid, N, segment_colors, bare=True
                ),
                f"{base_path}_rigid.png",
                "rigid segments",
            )

        # Show interactive window (only the combined figure is still open)
        plt.show()
    finally:
        plt.close(fig)


def main():
    """Command-line entry point: decode a saved latent vector into a voxel grid.

    Steps: parse the arguments, load and validate the latent ``.npy`` file,
    load the StarVAE checkpoint, decode the latent, print voxel statistics,
    save the integer grid as ``<name>_decoded.npy`` and, unless
    ``--no-visualize`` is given, render it with ``visualize_reconstructed``.

    Raises:
        ValueError: If the latent is not a non-empty 1-D vector, or its length
            differs from the VAE's ``e_dim``.
    """
    parser = argparse.ArgumentParser(
        description="Decode latent vector from .npy file using VAE"
    )
    parser.add_argument(
        "latent_path", type=str, help="Path to latent .npy file (shape: (512,))"
    )
    parser.add_argument(
        "--vae-checkpoint", type=str, required=True, help="VAE checkpoint path"
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        type=str,
        default=None,
        help="Output directory for decoded grid and visualizations (default: same as latent file)",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", choices=["cpu", "cuda"], help="Device"
    )
    parser.add_argument(
        "--no-visualize", action="store_true", help="Skip visualization"
    )

    args = parser.parse_args()

    print("=" * 70)
    print("Decode Robot Latent Vector")
    print("=" * 70)
    print(f"Latent:      {args.latent_path}")
    print(f"VAE:         {args.vae_checkpoint}")
    print(
        f"Output dir:  {args.output_dir if args.output_dir else 'same as latent file'}"
    )
    print(f"Device:      {args.device}")

    # Load latent vector
    latent = np.load(args.latent_path)
    print(f"\nLoaded latent: shape={latent.shape}, dtype={latent.dtype}")

    # Validate before computing statistics: min()/max() on an empty array
    # would fail with an unrelated NumPy reduction error.
    if latent.ndim != 1:
        raise ValueError(f"Expected 1D latent vector, got shape {latent.shape}")
    if latent.size == 0:
        raise ValueError("Latent vector is empty")
    print(f"  Range: [{latent.min():.4f}, {latent.max():.4f}]")

    # Device setup
    device = args.device
    if device == "cuda" and not t.cuda.is_available():
        print("CUDA not available, using CPU")
        device = "cpu"

    # Load VAE
    vae = StarVAE.load_from_checkpoint(args.vae_checkpoint, map_location=device)
    # map_location controls where the checkpoint tensors are loaded; depending
    # on the Lightning version the module itself may still live on the CPU.
    # Move it explicitly so it is on the same device as the latent below.
    vae.to(device)
    vae.eval()

    e_dim = int(vae.hparams.e_dim)
    grid_size = int(vae.hparams.grid_size)
    max_nodes = int(vae.hparams.max_num_nodes)

    print("\nVAE parameters:")
    print(f"  Latent dim (e_dim):      {e_dim}")
    print(f"  Grid size:               {grid_size}")
    print(f"  Max nodes:               {max_nodes}")

    if latent.shape[0] != e_dim:
        raise ValueError(
            f"Latent dimension mismatch: loaded {latent.shape[0]}, VAE expects {e_dim}"
        )

    # Decode latent to grid
    z = t.tensor(latent, dtype=t.float32).unsqueeze(0).to(device)  # [1, e_dim]
    print(f"\nDecoding latent: {z.shape}")

    with t.no_grad():
        output_tensor = vae.decode(z)  # [1, C, N, N, N]
        print(f"Decoded tensor: {output_tensor.shape}")

    reconstructed_grid = onehot_to_grid(output_tensor)
    print(f"Reconstructed grid: {reconstructed_grid.shape}")
    print(f"  Range: [{reconstructed_grid.min()}, {reconstructed_grid.max()}]")
    print(f"  Unique values: {np.unique(reconstructed_grid)}")

    # Statistics (0 = void, 1 = soft, >= 2 = rigid)
    occupied = np.sum(reconstructed_grid > 0)
    soft = np.sum(reconstructed_grid == 1)
    rigid = np.sum(reconstructed_grid >= 2)
    print("\nVoxel statistics:")
    print(f"  Total voxels:    {reconstructed_grid.size}")
    print(f"  Void (0):        {reconstructed_grid.size - occupied}")
    print(f"  Soft (1):        {soft}")
    print(f"  Rigid (2+):      {rigid}")
    print(f"  Occupied ratio:  {occupied / reconstructed_grid.size * 100:.2f}%")

    # Determine output directory
    base_name = os.path.splitext(os.path.basename(args.latent_path))[0]
    if args.output_dir:
        output_dir = args.output_dir
    else:
        # Default: same directory as latent file
        output_dir = os.path.dirname(args.latent_path) or "."

    os.makedirs(output_dir, exist_ok=True)

    # Save reconstructed grid
    grid_output = os.path.join(output_dir, f"{base_name}_decoded.npy")
    np.save(grid_output, reconstructed_grid)
    print(f"\nSaved reconstructed grid to: {grid_output}")

    # Visualize
    if not args.no_visualize:
        viz_path = os.path.join(output_dir, f"{base_name}_decoded_viz.png")
        visualize_reconstructed(
            reconstructed_grid,
            N=reconstructed_grid.shape[0],
            output_path=viz_path,
            title=f"Decoded: {base_name}",
        )

    print("\nDone!")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
