#!/usr/bin/env python3
"""
sample_robot_latent_average.py

Pipeline:
1) Load a multi-body robot config (pickle)
2) Revoxelize it once per seed (--num-seeds seeds, starting at --seed-start) to change rigid body ID/color assignment
3) Encode each voxel grid with a VAE -> one latent vector per seed
4) Average the latent vectors
5) Save averaged latent vector (and optionally all latents) to disk
"""

import argparse
import os
import sys
from typing import Tuple, Optional, List

import numpy as np
import torch as t

# Add project root to path (same approach as vae_test.py)
sys.path.insert(
    0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

from model.vae_old.star_vae import StarVAE  # same as vae_test.py
from scripts.robot_generation.revoxelize_robot_config import (  # same module as revoxelize script
    load_robot_config,
    revoxelize_robot_config,
)

try:
    from scipy.ndimage import zoom as scipy_zoom
except Exception:
    scipy_zoom = None


def grid_to_onehot(grid: np.ndarray, num_classes: int) -> t.Tensor:
    """
    One-hot encode an integer label grid along a new leading channel axis.

    Labels outside 0..num_classes-1 are clipped into that range before
    encoding, so every voxel lands in exactly one channel.

    grid: label array, e.g. [X,Y,Z]
    returns: float32 tensor shaped [1, num_classes, *grid.shape]
    """
    labels = np.clip(grid, 0, num_classes - 1)
    # One class id per channel, shaped so it broadcasts across all grid axes.
    class_ids = np.arange(num_classes).reshape((num_classes,) + (1,) * labels.ndim)
    planes = (labels[np.newaxis] == class_ids).astype(np.float32)
    # Prepend the batch dimension.
    return t.tensor(planes).unsqueeze(0)


def maybe_resize_grid_nearest(grid: np.ndarray, target_size: int) -> np.ndarray:
    """
    Resize a cubic grid [N,N,N] -> [target_size,target_size,target_size] using nearest neighbor.

    Uses scipy.ndimage.zoom (order=0) if scipy was importable; otherwise falls
    back to np.repeat / strided slicing, which only works when one size is an
    integer multiple of the other.

    grid: 3-D array with three equal axis lengths
    target_size: desired edge length of the output cube
    returns: the input itself when it already has the target size, otherwise
             a resized array

    Raises:
        ValueError: grid is not a 3-D cube, or a resize is requested with an
                    empty grid or a non-positive target_size.
        RuntimeError: scipy is unavailable and the scale factor is not an integer
                      (or the reciprocal of one).
    """
    # Validate the shape first: a non-cubic grid must be rejected even when its
    # leading axis happens to equal target_size.
    if grid.ndim != 3 or not (grid.shape[0] == grid.shape[1] == grid.shape[2]):
        raise ValueError(f"Expected cubic grid, got shape={grid.shape}")

    src = grid.shape[0]
    if src == target_size:
        return grid

    # Guard the division below and reject meaningless targets up front.
    if src == 0 or target_size <= 0:
        raise ValueError(
            f"Cannot resize grid: src={src}, target={target_size} (both must be positive)"
        )

    scale = target_size / float(src)

    if scipy_zoom is not None:
        # order=0 => nearest neighbor
        return scipy_zoom(grid, zoom=(scale, scale, scale), order=0)

    # Fallback: only supports integer upsample/downsample cleanly
    if target_size % src == 0:
        k = target_size // src
        return np.repeat(np.repeat(np.repeat(grid, k, axis=0), k, axis=1), k, axis=2)
    if src % target_size == 0:
        k = src // target_size
        return grid[::k, ::k, ::k]

    raise RuntimeError(
        "Grid resize needs scipy (scipy.ndimage.zoom) for non-integer scale factors. "
        f"src={src}, target={target_size}"
    )


@t.no_grad()
def encode_latent(
    vae: StarVAE,
    grid: np.ndarray,
    device: str,
    use_mu: bool = True,
) -> np.ndarray:
    """
    Run one voxel grid through the VAE encoder and return its latent as numpy.

    The grid is one-hot encoded with max_num_nodes + 2 channels (taken from
    vae.hparams), moved to `device`, and passed to vae.encode.

    use_mu=True  -> the encoder mean is returned (deterministic)
    use_mu=False -> vae.rsample(mu, logvar) is returned instead

    The leading batch axis is squeezed away before conversion to numpy.
    """
    # Channel count the VAE was built with (see vae_test.py)
    channel_count = int(vae.hparams.max_num_nodes) + 2

    voxels = grid_to_onehot(grid, channel_count).to(device)  # [1,C,N,N,N]
    mean, log_variance = vae.encode(voxels)  # typically [1, e_dim]

    latent = mean if use_mu else vae.rsample(mean, log_variance)

    return latent.squeeze(0).detach().cpu().numpy()


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Revoxelize robot with many seeds, encode with VAE, average latents, save."
    )
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to robot config pickle (RS_StructureConfig).",
    )
    parser.add_argument(
        "--vae-checkpoint",
        type=str,
        required=True,
        help="Path to VAE checkpoint.",
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default=None,
        help="Output path. If endswith .npy => saves averaged latent only; "
        "otherwise saves .npz with avg + std + all latents + seeds. "
        "A separate <output>_std.npy is always written.",
    )
    parser.add_argument(
        "--num-seeds",
        type=int,
        default=1,
        help="Number of seeds (default 1), latent will be averaged.",
    )
    parser.add_argument(
        "--seed-start", type=int, default=0, help="First seed (default 0)."
    )

    # Revoxelization params (match revoxelize_robot_config defaults)
    parser.add_argument(
        "-N", "--grid-size", type=int, default=64, help="Output voxel grid size N."
    )
    parser.add_argument(
        "--n-used", type=int, default=40, help="Used region size N_used."
    )
    parser.add_argument(
        "-M", "--max-rigid-id", type=int, default=7, help="Max rigid segment ID M."
    )

    # VAE / device behavior
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device for VAE encode.",
    )
    parser.add_argument(
        "--use-mu",
        action="store_true",
        help="Use latent mean mu (deterministic). Recommended for averaging. "
        "This is the default.",
    )
    parser.add_argument(
        "--sample-z",
        action="store_true",
        help="Use sampled z instead of mu (stochastic). Overrides --use-mu.",
    )
    return parser


def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point: revoxelize, encode, average, save.

    argv: argument list to parse; None means use sys.argv[1:] (argparse default).

    Output files:
      * output path ending in .npy -> averaged latent only (np.save)
      * any other output path      -> .npz archive with keys latent_avg,
        latent_std, latents, seeds (".npz" is appended when missing)
      * <output without extension>_std.npy -> per-dimension std, always written

    Exits via parser.error (SystemExit code 2) on invalid arguments and raises
    RuntimeError when the encoder returns a latent of unexpected shape.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)

    # Averaging over zero latents would silently produce NaNs.
    if args.num_seeds < 1:
        parser.error("--num-seeds must be >= 1")

    # Output default
    if args.output is None:
        base_name = os.path.splitext(os.path.basename(args.config_path))[0]
        args.output = os.path.join(
            os.path.dirname(args.config_path),
            f"{base_name}_latent_seed_num_{args.num_seeds}_seed_start_{args.seed_start}_n_used_{args.n_used}.npy",
        )

    # Generate std output path
    std_output = f"{os.path.splitext(args.output)[0]}_std.npy"

    # Anything that is not an explicit .npy becomes an .npz archive. The suffix
    # is added here (np.savez would append it anyway) so the logged path is the
    # file that actually gets written.
    out = args.output
    archive_mode = not out.endswith(".npy")
    if archive_mode and not out.endswith(".npz"):
        out += ".npz"

    # mu is the default; --use-mu is accepted for explicitness, --sample-z overrides it.
    use_mu = not args.sample_z

    print("=" * 70)
    print("Sample Robot Latent Average")
    print("=" * 70)
    print(f"Config:           {args.config_path}")
    print(f"VAE checkpoint:   {args.vae_checkpoint}")
    print(
        f"Seeds:            {args.seed_start}..{args.seed_start + args.num_seeds - 1} (count={args.num_seeds})"
    )
    print(f"Revox N/N_used/M: {args.grid_size}/{args.n_used}/{args.max_rigid_id}")
    print(
        f"Latent mode:      {'mu (deterministic)' if use_mu else 'sampled z (stochastic)'}"
    )
    print(f"Device:           {args.device}")
    print(f"Output (avg):     {out}")
    print(f"Output (std):     {std_output}")

    # Device fallback
    device = args.device
    if device == "cuda" and not t.cuda.is_available():
        print("CUDA not available -> falling back to CPU")
        device = "cpu"

    # Load structure config
    structure_config = load_robot_config(args.config_path)

    # Load VAE
    vae = StarVAE.load_from_checkpoint(args.vae_checkpoint, map_location=device)
    # map_location governs where checkpoint tensors are loaded; move the module
    # explicitly so it is guaranteed to sit on the same device as the inputs.
    vae.to(device)
    vae.eval()

    expected_grid_size = int(vae.hparams.grid_size)
    e_dim = int(vae.hparams.e_dim)
    print("\nVAE:")
    print(f"  expected grid_size: {expected_grid_size}")
    print(f"  latent dim (e_dim): {e_dim}")
    print(f"  max_num_nodes:      {int(vae.hparams.max_num_nodes)}")

    # Encode over seeds
    seeds: List[int] = list(range(args.seed_start, args.seed_start + args.num_seeds))
    latents = np.zeros((args.num_seeds, e_dim), dtype=np.float32)

    for i, s in enumerate(seeds):
        grid = revoxelize_robot_config(
            structure_config,
            N=args.grid_size,
            N_used=args.n_used,
            M=args.max_rigid_id,
            seed=s,
        )

        # Resize if the VAE expects a different grid size (same logic idea as vae_test.py)
        if grid.shape[0] != expected_grid_size:
            grid = maybe_resize_grid_nearest(grid, expected_grid_size)

        z = encode_latent(vae, grid, device=device, use_mu=use_mu)
        # Compare the full shape so a multi-dimensional latent is reported
        # clearly instead of failing later on the row assignment.
        if z.shape != (e_dim,):
            raise RuntimeError(
                f"Unexpected latent dim: got {z.shape}, expected ({e_dim},)"
            )

        latents[i] = z.astype(np.float32)

        if (i + 1) % 16 == 0 or (i + 1) == args.num_seeds:
            print(f"  Encoded {i+1}/{args.num_seeds} seeds...")

    latent_avg = latents.mean(axis=0)
    latent_std = latents.std(axis=0)

    # Save
    os.makedirs(os.path.dirname(out) or ".", exist_ok=True)

    if archive_mode:
        np.savez(
            out,
            latent_avg=latent_avg,
            latent_std=latent_std,
            latents=latents,
            seeds=np.asarray(seeds, dtype=np.int64),
        )
        print(f"\nSaved latent archive (avg/std/latents/seeds) to: {out}")
    else:
        np.save(out, latent_avg)
        print(f"\nSaved averaged latent to: {out}")

    # Save std
    np.save(std_output, latent_std)
    print(f"Saved latent std to: {std_output}")

    # Quick stats
    print("\nLatent stats:")
    print(f"  latents shape:    {latents.shape}")
    print(f"  avg range:        [{latent_avg.min():.4f}, {latent_avg.max():.4f}]")
    print(f"  std range:        [{latent_std.min():.4f}, {latent_std.max():.4f}]")
    print(f"  std mean:         {latent_std.mean():.4f}")
    print("Done.")


if __name__ == "__main__":
    main()
