"""
Generates a latent-space dataset from ARSO radar data using a pre-trained autoencoder.

This script loads an ARSO HDF5 file, encodes the 'zm_IN' and 'zm_OUT' datasets
into their latent representations using a specified AutoencoderKL model, and saves
the results to a new HDF5 file. All other datasets from the original file are
copied over unchanged. This creates a more compact dataset for efficiently training
downstream models.
"""

import os
import sys

sys.path.append(os.getcwd())

import random
import argparse

import numpy as np
import h5py
import torch
from tqdm import tqdm

from experiments.arso.dataset.arsodataset_autoencoder import pad_to_multiple_of_16
from diffusers.models.autoencoders import AutoencoderKL
from omegaconf import OmegaConf

# ------------------------------------------------------------
# Parser
# ------------------------------------------------------------
# Command-line interface: the YAML config is optional, while the autoencoder
# checkpoint and the input data file must always be supplied.
parser = argparse.ArgumentParser(
    description="Generate latent HDF5 dataset for ARSO (zm_IN and zm_OUT)."
)
parser.add_argument(
    "--config",
    type=str,
    default="experiments/arso/autoencoder/autoencoder_kl_config.yaml",
    help="Path to the YAML configuration file.",
)
parser.add_argument(
    "--autoencoder_path",
    type=str,
    required=True,
    help="Path to saved autoencoder model",
)
parser.add_argument(
    "--data_file",
    type=str,
    required=True,
    help="Path to the data file.",
)


args = parser.parse_args()
config = OmegaConf.load(args.config)

# --- Configuration ---
# The config is expected to provide these three sections; they are read
# further down when building the model and encoding the data.
run_params = config.run_params
training_params = config.training_params
model_params = config.model_params

DEBUG_MODE = run_params.debug_mode
DEBUG_PRINT_PREFIX = "[DEBUG] " if DEBUG_MODE else ""

DATA_FILE = args.data_file
# The output is written next to the input: "<name>.<ext>" -> "<name>_latent.h5".
OUT_FILE = os.path.splitext(DATA_FILE)[0] + "_latent.h5"
# os.path.dirname() returns "" for a bare filename (e.g. "data.h5") and
# os.makedirs("") raises FileNotFoundError, so only create the directory when
# the path actually has a directory component.
_out_dir = os.path.dirname(OUT_FILE)
if _out_dir:
    os.makedirs(_out_dir, exist_ok=True)

# ------------------------------------------------------------
# Set up the autoencoder model
# ------------------------------------------------------------
# Use CUDA when PyTorch reports it as available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"{DEBUG_PRINT_PREFIX}Using device: {device}")

# Fail early with a clear message instead of a less obvious torch.load error.
if not os.path.exists(args.autoencoder_path):
    raise FileNotFoundError(f"Model not found at {args.autoencoder_path}")

# Single-channel in/out; every other architecture setting comes from the
# `model_params` section of the YAML config.
model = AutoencoderKL(
    in_channels=1,
    out_channels=1,
    down_block_types=model_params.down_block_types,
    up_block_types=model_params.up_block_types,
    block_out_channels=model_params.block_out_channels,
    act_fn=model_params.act_fn,
    latent_channels=model_params.latent_channels,
    norm_num_groups=model_params.norm_num_groups,
    layers_per_block=model_params.layers_per_block,
)
# NOTE: torch.load deserializes with pickle; only load checkpoints that come
# from a trusted source.
checkpoint = torch.load(args.autoencoder_path, map_location=device)


def _strip_wrapper_prefix(state_key, prefix="module."):
    """Remove leading ``prefix`` occurrences from a state-dict key.

    Only the start of the key is touched, so a key that merely contains the
    text "module." further inside (e.g. "sub_module.weight") is left intact.
    Repeated leading prefixes are all removed.
    """
    while state_key.startswith(prefix):
        state_key = state_key[len(prefix):]
    return state_key


# Presumably the checkpoint may have been saved from a DataParallel /
# DistributedDataParallel wrapper, which prefixes parameter names with
# "module."; drop that prefix so the keys match the bare model.
checkpoint_model_state_dict = {
    _strip_wrapper_prefix(k): v for k, v in checkpoint["model_state_dict"].items()
}
model.load_state_dict(checkpoint_model_state_dict)
model = model.to(device)
# Inference only: switch the module to evaluation mode.
model.eval()


# ------------------------------------------------------------
# Helper: encode a single event [T, H, W] -> latent [T, C, h, w]
# ------------------------------------------------------------
def encode_event(event_3d, model, normalize):
    """
    Encodes a single event sequence into its latent representation.

    Every frame is treated as one single-channel image: the T frames become a
    batch of shape (T, 1, H, W), are optionally normalized, padded with
    ``pad_to_multiple_of_16`` and passed through ``model.encode``. The mode of
    the returned latent distribution is used, so no sampling is involved.

    Args:
        event_3d (np.ndarray): The event data, shaped (T, H, W).
        model (torch.nn.Module): The pre-trained autoencoder. Inputs are moved
            to the device of the model's first parameter before encoding.
        normalize (bool): If True, normalizes the input frames by dividing by 57.0.

    Returns:
        np.ndarray: The float32 latent representation of the event, shaped
        (T, C, h, w). The latents are computed from the padded frames.

    Raises:
        ValueError: If ``event_3d`` is not 3-dimensional.
    """
    if event_3d.ndim != 3:
        raise ValueError(
            f"Expected an event of shape (T, H, W), got shape {event_3d.shape}"
        )
    num_frames, height, width = event_3d.shape

    frames = torch.from_numpy(event_3d).float().unsqueeze(1)  # [T, 1, H, W]
    if normalize:
        frames = frames / 57.0

    # Presumably pads H and W up to the next multiple of 16 (per the helper's
    # name); padding_info is only used for the diagnostic print below.
    padded_frames, padding_info = pad_to_multiple_of_16(frames)

    # Diagnostic output, emitted once per non-empty event.
    if num_frames > 0:
        print(
            f"[DEBUG generate_static_dataset] Original Pixel HxW: {height}x{width}, Padding Info: {padding_info}, Padded Pixel HxW: {padded_frames.shape[-2]}x{padded_frames.shape[-1]}"
        )

    # Run on whatever device the model lives on rather than relying on a
    # module-level global, so the helper works for any model passed in.
    model_device = next(model.parameters()).device
    padded_frames = padded_frames.to(model_device)
    with torch.no_grad():
        latent = model.encode(padded_frames).latent_dist.mode()  # [T, C, h, w]
    return latent.cpu().numpy().astype(np.float32)


# ------------------------------------------------------------
# encode_and_write: encode every event of one dataset and store the latents under `latent_key`
# ------------------------------------------------------------
def encode_and_write(
    in_data, out_h5, key, latent_key, model, normalize, debug_mode=False
):
    """
    Iterates through a dataset, encodes each sample, and writes to a new HDF5 file.

    The output dataset is created lazily from the shape of the first
    successfully encoded event, with one chunk per event and gzip compression.
    Events whose latents contain NaN are reported and skipped; their rows are
    never written and therefore keep the dataset's fill value (zero by default
    in h5py). The same holds for rows beyond the debug cut-off.

    Args:
        in_data (h5py.Dataset): The input HDF5 dataset to read from, shaped
            (N, T, H, W).
        out_h5 (h5py.File): The output HDF5 file to write to.
        key (str): The name of the input dataset (for logging).
        latent_key (str): The name of the new latent dataset to create.
        model (torch.nn.Module): The pre-trained autoencoder.
        normalize (bool): Flag for normalizing the data before encoding.
        debug_mode (bool): If True, processes only the first 50 events.

    Raises:
        ValueError: If ``in_data`` is not 4-dimensional.
    """
    if len(in_data.shape) != 4:
        raise ValueError(
            f"Expected dataset '{key}' of shape (N, T, H, W), got shape {in_data.shape}"
        )
    num_events = in_data.shape[0]
    debug_event_limit = 50

    dset = None
    for i in tqdm(range(num_events), desc=f"Encoding {key}"):
        # Check the limit before doing any work so that exactly
        # `debug_event_limit` events are visited, even when the event right
        # before the cut-off was skipped because of NaNs.
        if debug_mode and i >= debug_event_limit:
            print(f"{DEBUG_PRINT_PREFIX}Stopping early after 50 events (debug_mode).")
            break

        event_3d = in_data[i]  # [T, H, W]
        latents_4d = encode_event(event_3d, model, normalize)  # [T, C, h, w]
        if np.isnan(latents_4d).any():
            print(f"Warning: NaN found in encoded event {i} of {key}. Skipping.")
            continue

        if dset is None:
            # The latent shape is only known after the first encode.
            T_new, C, h, w = latents_4d.shape
            dset = out_h5.create_dataset(
                latent_key,
                shape=(num_events, T_new, C, h, w),
                dtype=latents_4d.dtype,
                chunks=(1, T_new, C, h, w),  # one event per chunk
                compression="gzip",
                compression_opts=4,
            )
        dset[i, ...] = latents_4d

    if dset is None:
        # Empty input or every event rejected: make the missing dataset
        # visible instead of failing silently.
        print(
            f"Warning: no events of {key} were encoded; dataset '{latent_key}' was not created."
        )


# Datasets that are replaced by their latent encoding, mapped to the name the
# encoded version gets in the output file. Anything not listed is copied.
_latent_names = {
    "zm_IN": "zm_IN_latent",
    "zm_OUT": "zm_OUT_latent",
}

with h5py.File(DATA_FILE, "r") as src_file, h5py.File(OUT_FILE, "w") as dst_file:
    for name in src_file.keys():
        target_name = _latent_names.get(name)
        if target_name is None:
            # Not a radar dataset to encode: carry it over unchanged.
            src_file.copy(name, dst_file)
            continue
        encode_and_write(
            src_file[name],
            dst_file,
            name,
            target_name,
            model,
            training_params.normalize_dataset,
            DEBUG_MODE,
        )
    print(f"Finished writing latent file: {OUT_FILE}")

print("Done generating ARSO latent HDF5 file with all keys preserved.")
