"""
Distributed training script for a Variational Autoencoder (VAE) with a GAN-style
discriminator on the ARSO dataset.

This script uses `torchrun` for multi-GPU training. It handles the end-to-end
training process for an AutoencoderKL model, configured via a YAML file.
Key features include:
- DistributedDataParallel (DDP) for multi-GPU training.
- Dynamic dataset loading from HDF5 files.
- A training loop that alternates between generator (autoencoder) and discriminator updates.
- A validation loop for monitoring performance, saving checkpoints, and early stopping.
- Logging to Weights & Biases (wandb).
"""

import sys
import os

sys.path.append(os.getcwd())  # Add the current working directory to the path

os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"
import datetime
import numpy as np
import wandb
import namegenerator
from matplotlib import pyplot as plt
import h5py
from sklearn.model_selection import train_test_split

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, SequentialLR
from common.utils.utils import warmup_lambda
from common.autoencoder.losses.lpips import LPIPSWithDiscriminator

import random
from tqdm import tqdm

from experiments.arso.dataset.arsodataset_autoencoder import (
    DynamicAutoencoderArsoDataset,
    pad_to_multiple_of_16,
    remove_padding,
)
from diffusers.models.autoencoders import AutoencoderKL
from common.autoencoder.utils.early_stopping import EarlyStopping
import argparse
from omegaconf import OmegaConf


def setup_ddp():
    """Initialize the DDP process group when started by a distributed launcher.

    Distributed mode is detected by the presence of both ``RANK`` and
    ``WORLD_SIZE`` in the environment; ``LOCAL_RANK`` is then read as well.
    Without those variables the script falls back to a single device.

    Returns:
        tuple: ``(rank, world_size, local_rank, device)``. In the fallback
        case this is ``(0, 1, 0, device)`` where ``device`` is CUDA when
        available and CPU otherwise.

    Raises:
        KeyError: If ``RANK``/``WORLD_SIZE`` are set but ``LOCAL_RANK`` is not.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        print("Not running in distributed mode. Using single device.")
        fallback_device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        return 0, 1, 0, fallback_device  # Rank 0, World Size 1, Local Rank 0

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    print(f"Initializing DDP: Rank {rank}/{world_size}, Local Rank {local_rank}")

    # Bind this process to its GPU *before* creating the process group so that
    # NCCL communicators (and the barrier below) use the right device instead
    # of defaulting to GPU 0.
    torch.cuda.set_device(local_rank)

    # The rendezvous URL is typically provided by the launcher (torchrun)
    # through the environment, hence init_method="env://".
    dist.init_process_group(
        backend="nccl", init_method="env://", rank=rank, world_size=world_size
    )
    # Make sure every process has finished initialization before continuing.
    dist.barrier()
    return rank, world_size, local_rank, torch.device(f"cuda:{local_rank}")


def cleanup_ddp():
    """Destroy the default process group, if one is currently initialized."""
    if not dist.is_initialized():
        # Single-device run (or already cleaned up): nothing to tear down.
        return
    dist.destroy_process_group()
    print("Cleaned up DDP.")


# Helper function to flatten a nested dictionary
def flatten_dict(d, parent_key="", sep="."):
    """
    Flattens a nested dictionary into a single level.

    Nested keys are joined with ``sep``; e.g. ``{"a": {"b": 1}}`` becomes
    ``{"a.b": 1}``. Keys that have no parent are kept unchanged (including
    their type). Nested keys are formatted with ``str()``, so non-string keys
    such as integers are supported below the top level.

    Args:
        d (dict): The dictionary to flatten.
        parent_key (str): The base key for recursive calls. An empty string or
            ``None`` means "no prefix".
        sep (str): Separator to use between keys.

    Returns:
        dict: The flattened dictionary. Nested dictionaries that are empty
        contribute no entries.
    """
    # Compare explicitly instead of relying on truthiness so that a falsy but
    # meaningful parent key (e.g. the integer 0) still produces a prefix.
    has_parent = parent_key is not None and parent_key != ""

    flat = {}
    for key, value in d.items():
        # f-string formatting avoids the TypeError that string concatenation
        # raises for non-string keys.
        full_key = f"{parent_key}{sep}{key}" if has_parent else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat


# Initialize DDP first: rank / world_size drive everything that follows.
rank, world_size, local_rank, device = setup_ddp()

if rank == 0:
    print(f"Using device: {device}")
    if device.type == "cpu":
        print("CPU is used")

# Create argument parser
parser = argparse.ArgumentParser(
    description="Script for distributed training of ARSO AutoencoderKL."
)
parser.add_argument(
    "--config",
    type=str,
    default="experiments/arso/autoencoder/autoencoder_kl_config.yaml",
    help="Path to the YAML configuration file.",
)
parser.add_argument(
    "--data_file",
    type=str,
    default="datasets/arso/final_sequence_data/sequence_data_ds1.h5",
    help="Path to the data file.",
)

# Parse arguments
args = parser.parse_args()
config = OmegaConf.load(args.config)

# --- Configuration from YAML ---
run_params = config.run_params
training_params = config.training_params
optimizer_params = config.optimizer_params
scheduler_params = config.scheduler_params
model_params = config.model_params
loss_params = config.loss_params
DATA_FILE = args.data_file

# Assign arguments to variables
DEBUG_MODE = run_params.debug_mode
RUN_STRING = run_params.run_string
FINAL_TRAINING = training_params.final_training

# RUN_ID generation - ensure consistency across ranks.
# Both the timestamp and the random name are produced on rank 0 and then
# broadcast. Letting every rank call datetime.now() on its own can yield
# different timestamps (processes may straddle a second boundary), which would
# give each rank a different RUN_ID.
run_id_parts = [None, None]  # [timestamp + run string, random name]
if rank == 0:
    run_id_parts[0] = (
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_" + RUN_STRING
    )
    run_id_parts[1] = namegenerator.gen()

if world_size > 1:  # Only broadcast if distributed
    dist.broadcast_object_list(run_id_parts, src=0)
run_id_timestamp_string, random_name_part = run_id_parts

# Neither part should be None if the broadcast worked. As a best-effort
# fallback, generate the missing part locally but flag it loudly.
if run_id_timestamp_string is None:
    print(
        f"[WARNING] Rank {rank} did not receive run_id_timestamp_string, generating locally. RUN_ID might be inconsistent."
    )
    run_id_timestamp_string = (
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + "_" + RUN_STRING
    )
if random_name_part is None:
    print(
        f"[WARNING] Rank {rank} did not receive random_name_part, generating locally. RUN_ID might be inconsistent."
    )
    random_name_part = namegenerator.gen()


RUN_ID = run_id_timestamp_string + "_" + random_name_part


# Prefix used by all status messages of this rank.
DEBUG_PRINT_PREFIX = f"[DEBUG Rank {rank}] " if DEBUG_MODE else f"[Rank {rank}] "

# Run / training settings
ENABLE_WANDB = run_params.enable_wandb
NORMALIZE_DATASET = training_params.normalize_dataset
PRELOAD_MODEL = training_params.preload_model
BATCH_SIZE = training_params.micro_batch_size  # per-process batch size
NUM_EPOCHS = training_params.num_epochs
NUM_WORKERS = training_params.num_workers
EARLY_STOPPING_PATIENCE = training_params.early_stopping_patience
WARMUP_GENERATOR_EPOCHS = training_params.warmup_generator_epochs

# Optimizer settings
LEARNING_RATE = optimizer_params.learning_rate
OPTIMIZER_TYPE = optimizer_params.optimizer_type
WEIGHT_DECAY = optimizer_params.weight_decay

# Scheduler settings
SCHEDULER_TYPE = scheduler_params.scheduler_type
LR_PLATEAU_FACTOR = scheduler_params.lr_plateau_factor
LR_PLATEAU_PATIENCE = scheduler_params.lr_plateau_patience
LR_COSINE_MIN_LR_RATIO = scheduler_params.lr_cosine_min_lr_ratio
LR_COSINE_WARMUP_ITER_PERCENTAGE = scheduler_params.lr_cosine_warmup_iter_percentage
LR_COSINE_MIN_WARMUP_LR_RATIO = scheduler_params.lr_cosine_min_warmup_lr_ratio

# Gradient clipping threshold (max gradient norm)
GRADIENT_CLIP_VAL = training_params.gradient_clip_val

# Autoencoder architecture
LATENT_CHANNELS = model_params.latent_channels
NORM_NUM_GROUPS = model_params.norm_num_groups
LAYERS_PER_BLOCK = model_params.layers_per_block
ACT_FN = model_params.act_fn
BLOCK_OUT_CHANNELS = model_params.block_out_channels
DOWN_BLOCK_TYPES = model_params.down_block_types
UP_BLOCK_TYPES = model_params.up_block_types

# Loss weights
KL_WEIGHT = loss_params.kl_weight
DISC_WEIGHT = loss_params.disc_weight


# Echo the effective configuration once, on the main process only.
if rank == 0:
    config_summary = (
        ("Run ID", RUN_ID),
        ("Debug Mode", DEBUG_MODE),
        ("Data File", DATA_FILE),
        ("Preload Model", PRELOAD_MODEL),
        ("Batch Size", BATCH_SIZE),
        ("Learning Rate", LEARNING_RATE),
        ("Number of Epochs", NUM_EPOCHS),
        ("Number of Workers", NUM_WORKERS),
        ("Early Stopping Patience", EARLY_STOPPING_PATIENCE),
        ("Warmup Generator Epochs", WARMUP_GENERATOR_EPOCHS),
        ("Optimizer Type", OPTIMIZER_TYPE),
        ("Scheduler Type", SCHEDULER_TYPE),
        ("LR Plateau Factor", LR_PLATEAU_FACTOR),
        ("LR Plateau Patience", LR_PLATEAU_PATIENCE),
        ("Weight Decay", WEIGHT_DECAY),
        ("LR Cosine Warmup Iter Percentage", LR_COSINE_WARMUP_ITER_PERCENTAGE),
        ("LR Cosine Min Warmup LR Ratio", LR_COSINE_MIN_WARMUP_LR_RATIO),
        ("LR Cosine Min LR Ratio", LR_COSINE_MIN_LR_RATIO),
        ("Gradient Clip Value", GRADIENT_CLIP_VAL),
        ("Latent Channels", LATENT_CHANNELS),
        ("Norm Num Groups", NORM_NUM_GROUPS),
        ("Layers Per Block", LAYERS_PER_BLOCK),
        ("Activation Function", ACT_FN),
        ("Block Out Channels", BLOCK_OUT_CHANNELS),
        ("Down Block Types", DOWN_BLOCK_TYPES),
        ("Up Block Types", UP_BLOCK_TYPES),
        ("KL Weight", KL_WEIGHT),
        ("Discriminator Weight", DISC_WEIGHT),
    )
    for summary_label, summary_value in config_summary:
        print(f"{DEBUG_PRINT_PREFIX}{summary_label}: {summary_value}")


# Start the wandb run before any training work; only the main process logs.
if rank == 0 and ENABLE_WANDB:
    # wandb receives a flat mapping, so the resolved config is flattened first.
    flat_config = flatten_dict(OmegaConf.to_container(config, resolve=True))
    wandb.init(
        project="arso-autoencoder-kl",
        config=flat_config,
        name=RUN_ID,
    )


# Root folder for every artifact of this run
ARTIFACTS_FOLDER = f"artifacts/arso/autoencoder_kl/{RUN_ID}"
# Plots live inside the artifacts folder
PLOTS_FOLDER = f"{ARTIFACTS_FOLDER}/plots"
# Animations and metrics are sub-folders of the plots folder
ANIMATIONS_FOLDER = f"{PLOTS_FOLDER}/animations"
METRICS_FOLDER = f"{PLOTS_FOLDER}/metrics"

# Seed Python's `random` module.
# NOTE(review): NumPy and torch RNGs are not seeded here — confirm whether
# full reproducibility (e.g. of weight initialization) is expected.
random.seed(42)

# Only the main process creates folders and owns the checkpoint path.
if rank == 0:
    MODEL_SAVE_DIR = f"{ARTIFACTS_FOLDER}/models"
    for folder in (PLOTS_FOLDER, ANIMATIONS_FOLDER, METRICS_FOLDER, MODEL_SAVE_DIR):
        os.makedirs(folder, exist_ok=True)
    MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_DIR, "early_stopping_model.pt")
else:
    MODEL_SAVE_PATH = None

# Fractions of the dataset assigned to the train / validation / test splits
TRAIN_VAL_TEST_FRAC = [0.6, 0.2, 0.2]


def generate_loaders():
    """
    Build the distributed train and validation DataLoaders.

    The number of samples is read from the ``zm_IN`` dataset of ``DATA_FILE``.
    Sample indices are split without shuffling according to
    ``TRAIN_VAL_TEST_FRAC``: the leading part becomes the training set and the
    remainder is divided into validation and test parts. The test indices are
    computed but not used here.

    Returns:
        tuple: ``(train_loader, val_loader, train_sampler, val_sampler)``.
    """
    # Open the HDF5 file once, only to learn how many samples it holds.
    with h5py.File(DATA_FILE, "r") as h5_file:
        sample_count = len(h5_file["zm_IN"])
    all_indices = np.arange(sample_count)

    train_frac, val_frac, test_frac = TRAIN_VAL_TEST_FRAC

    # First cut: train vs. the rest (validation + test).
    train_idx, holdout_idx = train_test_split(
        all_indices, train_size=train_frac, random_state=42, shuffle=False
    )
    # Second cut: validation vs. test, with the validation share expressed
    # relative to the held-out part.
    val_idx, _test_idx = train_test_split(
        holdout_idx,
        train_size=val_frac / (val_frac + test_frac),
        random_state=42,
        shuffle=False,
    )

    # Both datasets read lazily from the same HDF5 file with identical options.
    dataset_options = {
        "transform": None,
        "channel_last": False,
        "normalize": NORMALIZE_DATASET,
    }
    train_dataset = DynamicAutoencoderArsoDataset(
        DATA_FILE, indices=train_idx, **dataset_options
    )
    val_dataset = DynamicAutoencoderArsoDataset(
        DATA_FILE, indices=val_idx, **dataset_options
    )

    # Options shared by both loaders; shuffle must stay False because a
    # sampler is supplied.
    loader_options = {
        "batch_size": BATCH_SIZE,
        "shuffle": False,
        "num_workers": NUM_WORKERS,
        "pin_memory": True,
    }

    train_sampler = DistributedSampler(
        train_dataset, num_replicas=world_size, rank=rank, shuffle=True
    )
    train_loader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        drop_last=True,  # recommended for multi-GPU
        **loader_options,
    )

    val_sampler = DistributedSampler(
        val_dataset, num_replicas=world_size, rank=rank, shuffle=False
    )
    val_loader = DataLoader(
        val_dataset,
        sampler=val_sampler,
        drop_last=False,
        **loader_options,
    )

    return train_loader, val_loader, train_sampler, val_sampler


train_loader, val_loader, train_sampler, val_sampler = generate_loaders()

# Determine the batch shape on rank 0 and share it with every other rank.
input_shape_list = [None]  # single-slot list so it can be broadcast in place
if rank == 0:
    temp_input_shape = None
    if len(train_loader) > 0:
        # Only the first batch is needed to learn the shape.
        inputs, _ = next(iter(train_loader))
        print(f"{DEBUG_PRINT_PREFIX}Inputs shape (from rank 0): {inputs.shape}")
        temp_input_shape = inputs.shape
    else:
        print(
            f"{DEBUG_PRINT_PREFIX}Train loader is empty on rank 0, cannot determine input shape."
        )
    input_shape_list[0] = temp_input_shape

if world_size > 1:
    dist.broadcast_object_list(input_shape_list, src=0)

input_shape = input_shape_list[0]

# Without a shape the model cannot be built, so every rank aborts.
if input_shape is None:
    error_message = f"Rank {rank}: Could not determine input_shape. Rank 0 value: {input_shape_list[0]}."
    if world_size > 1:
        # Synchronize first so all ranks fail together rather than piecemeal.
        dist.barrier()
    raise RuntimeError(error_message)


# Values optionally restored from a checkpoint
preload_model_state_dict = None
preload_global_step = None
preload_best_val_loss = None

# Broadcast container: [model state_dict, global_step, best_val_loss]
loaded_checkpoint_info = [None, None, None]

# Only rank 0 touches the checkpoint file; the result is broadcast afterwards.
if rank == 0:
    if PRELOAD_MODEL is None:
        print(f"{DEBUG_PRINT_PREFIX}No preload model specified on rank 0.")
    else:
        print(f"{DEBUG_PRINT_PREFIX}Attempting to load checkpoint: {PRELOAD_MODEL}")
        try:
            # Map to CPU so loading does not consume GPU memory on rank 0.
            model_info_cpu = torch.load(PRELOAD_MODEL, map_location="cpu")
            loaded_checkpoint_info[0] = model_info_cpu.get("model_state_dict")
            loaded_checkpoint_info[1] = model_info_cpu.get("global_step")
            # The best validation loss may be stored under either key; prefer
            # "best_val_loss" and fall back to "val_loss_min".
            loaded_checkpoint_info[2] = model_info_cpu.get(
                "best_val_loss", model_info_cpu.get("val_loss_min")
            )
            if loaded_checkpoint_info[0] is None:
                print(
                    f"{DEBUG_PRINT_PREFIX}Checkpoint loaded, but model_state_dict not found. Starting model from scratch."
                )
            else:
                print(
                    f"{DEBUG_PRINT_PREFIX}Successfully loaded model information from checkpoint (on rank 0 CPU)."
                )
        except FileNotFoundError:
            print(
                f"{DEBUG_PRINT_PREFIX}Preload model file not found: {PRELOAD_MODEL}. Starting model from scratch."
            )
        except Exception as exc:
            # Best effort: any other load failure falls back to a fresh model.
            print(
                f"{DEBUG_PRINT_PREFIX}Error loading checkpoint: {exc}. Starting model from scratch."
            )

if world_size > 1:
    dist.broadcast_object_list(loaded_checkpoint_info, src=0)

# Every rank unpacks the (possibly broadcast) checkpoint information.
(
    preload_model_state_dict,
    preload_global_step,
    preload_best_val_loss,
) = loaded_checkpoint_info

# Create the model (on CPU initially, then move to device, then DDP).
# input_shape[1] is used as the channel count for both input and output.
model = AutoencoderKL(
    in_channels=input_shape[1],
    out_channels=input_shape[1],
    down_block_types=DOWN_BLOCK_TYPES,
    up_block_types=UP_BLOCK_TYPES,
    block_out_channels=BLOCK_OUT_CHANNELS,
    act_fn=ACT_FN,
    latent_channels=LATENT_CHANNELS,
    norm_num_groups=NORM_NUM_GROUPS,
    layers_per_block=LAYERS_PER_BLOCK,
)

# Load the model state dict (all ranks do this with the possibly broadcast
# state_dict).
# NOTE(review): this loads into the unwrapped model; if the checkpoint was
# saved from the DDP wrapper its keys would carry a "module." prefix — confirm
# against how EarlyStopping saves checkpoints.
if preload_model_state_dict is not None:
    model.load_state_dict(preload_model_state_dict)
    if rank == 0:  # Print only on rank 0
        print(
            DEBUG_PRINT_PREFIX
            + "Successfully loaded model state dict to model instance."
        )
elif rank == 0:
    print(
        DEBUG_PRINT_PREFIX
        + "No pre-trained model state_dict to load, or PRELOAD_MODEL not specified/found."
    )


# Each rank moves its model instance to its assigned device, then wraps it.
# NOTE(review): DDP needs an initialized process group, so the single-device
# fallback of setup_ddp() presumably cannot get past this point — confirm.
model = model.to(device)
model = DDP(model, device_ids=[local_rank], output_device=local_rank)


# Release every remaining reference to the checkpoint so its (CPU) memory can
# be reclaimed: the unpacked state dict, the slot in the broadcast list, and
# the full checkpoint object that rank 0 loaded.
preload_model_state_dict = None
loaded_checkpoint_info[0] = None
model_info_cpu = None
model_info = None  # kept for backward compatibility with the original name
# Clear the CUDA allocator cache
torch.cuda.empty_cache()


# Initialize loss function (reconstruction / KL / adversarial terms).
# NOTE(review): the criterion, including its discriminator, is not wrapped in
# DDP anywhere in this script, so discriminator gradients are presumably not
# synchronized across ranks — confirm whether that is intended.
criterion = LPIPSWithDiscriminator(
    kl_weight=KL_WEIGHT,
    disc_weight=DISC_WEIGHT,
    disc_in_channels=1,
).to(device)

# One optimizer for the autoencoder ("generator") and one for the
# discriminator. Note that the "adam" variant does not apply WEIGHT_DECAY.
if OPTIMIZER_TYPE == "adam":
    gen_optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    disc_optimizer = torch.optim.Adam(
        criterion.discriminator.parameters(), lr=LEARNING_RATE
    )
elif OPTIMIZER_TYPE == "adamw":
    gen_optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        betas=(0.9, 0.999),
    )
    disc_optimizer = torch.optim.AdamW(
        criterion.discriminator.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        betas=(0.9, 0.999),
    )
else:
    raise ValueError(f"Invalid optimizer type: {OPTIMIZER_TYPE}")

# Number of scheduler steps over the whole run. The cosine schedulers are
# stepped once per training batch; with FINAL_TRAINING the validation loader
# is also trained on every epoch, so its batches must be counted as well.
# Otherwise the schedule would be stepped past T_max.
steps_per_epoch = len(train_loader)
if FINAL_TRAINING:
    steps_per_epoch += len(val_loader)
total_num_steps = steps_per_epoch * NUM_EPOCHS
warmup_iter = int(np.round(LR_COSINE_WARMUP_ITER_PERCENTAGE * total_num_steps))


def _build_cosine_schedule(optimizer):
    """Return ``(warmup, cosine, sequential)`` schedulers for ``optimizer``.

    The sequential scheduler runs the warm-up (built from ``warmup_lambda``)
    for ``warmup_iter`` steps and then switches to cosine annealing for the
    remaining steps.
    """
    warmup = LambdaLR(
        optimizer,
        lr_lambda=warmup_lambda(
            warmup_steps=warmup_iter, min_lr_ratio=LR_COSINE_MIN_WARMUP_LR_RATIO
        ),
    )
    cosine = CosineAnnealingLR(
        optimizer,
        T_max=(total_num_steps - warmup_iter),
        eta_min=LR_COSINE_MIN_LR_RATIO * LEARNING_RATE,
    )
    sequential = SequentialLR(
        optimizer,
        schedulers=[warmup, cosine],
        milestones=[warmup_iter],
    )
    return warmup, cosine, sequential


# The generator and the discriminator get identically configured schedulers.
if SCHEDULER_TYPE == "cosine":
    gen_warmup_scheduler, gen_cosine_scheduler, gen_scheduler = (
        _build_cosine_schedule(gen_optimizer)
    )
    disc_warmup_scheduler, disc_cosine_scheduler, disc_scheduler = (
        _build_cosine_schedule(disc_optimizer)
    )
elif SCHEDULER_TYPE == "plateau":
    gen_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        gen_optimizer,
        mode="min",
        factor=LR_PLATEAU_FACTOR,
        patience=LR_PLATEAU_PATIENCE,
    )
    disc_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        disc_optimizer,
        mode="min",
        factor=LR_PLATEAU_FACTOR,
        patience=LR_PLATEAU_PATIENCE,
    )
else:
    raise ValueError(f"Invalid scheduler type: {SCHEDULER_TYPE}")

# Resume the best validation loss from the checkpoint when one was restored.
best_val_loss = (
    preload_best_val_loss if preload_best_val_loss is not None else float("inf")
)

# Early stopping (and checkpoint saving) is handled by the main process only.
early_stopping = None
if rank == 0:
    early_stopping = EarlyStopping(
        patience=EARLY_STOPPING_PATIENCE,
        verbose=True,
        path=MODEL_SAVE_PATH,
        val_loss_min=best_val_loss,
    )
    print(f"{DEBUG_PRINT_PREFIX}Starting training, run id: {RUN_ID}")

# Global step counter, resumed from the checkpoint when available
global_step = preload_global_step if preload_global_step is not None else 0


def train_epoch(
    train_bar,
    global_step,
    model,
    criterion,
    device,
    gen_optimizer,
    disc_optimizer,
    gen_scheduler,
    disc_scheduler,
    scheduler_type,
):
    """
    Runs one full training epoch.

    Iterates through the training data, performing two updates per batch:
    1. Generator (Autoencoder) Update: the criterion is evaluated with
       ``optimizer_idx=0`` and the autoencoder's weights are updated.
    2. Discriminator Update: the criterion is evaluated with
       ``optimizer_idx=1`` and the discriminator's weights are updated.

    Each optimizer's gradients are cleared immediately before its own backward
    pass, so that gradients produced by the other update never leak into a step.

    Besides its arguments, this function reads the module-level ``rank``,
    ``GRADIENT_CLIP_VAL``, ``ENABLE_WANDB``, ``DEBUG_MODE`` and
    ``DEBUG_PRINT_PREFIX``.

    Args:
        train_bar: Iterable of ``(input, _)`` batches (a DataLoader, or a tqdm
            bar wrapping one on rank 0).
        global_step (int): Current global step for logging.
        model (DDP): The wrapped AutoencoderKL model.
        criterion: The loss function (LPIPSWithDiscriminator).
        device: The computation device for the current rank.
        gen_optimizer: Optimizer for the autoencoder.
        disc_optimizer: Optimizer for the discriminator.
        gen_scheduler: Learning rate scheduler for the autoencoder.
        disc_scheduler: Learning rate scheduler for the discriminator.
        scheduler_type (str): The type of scheduler ('cosine' or 'plateau').
            Only 'cosine' schedulers are stepped here (once per batch).

    Returns:
        A tuple with the updated global step (advanced by the local batch size
        for every batch), and average total, generator, and discriminator
        losses over the processed batches.

    Raises:
        ZeroDivisionError: If ``train_bar`` yields no batches.
    """
    global_step_counter = global_step
    model.train()
    criterion.train()
    total_gen_loss = 0.0
    total_disc_loss = 0.0
    count = 0

    for batch in train_bar:
        input_frame, _ = batch
        input_frame = input_frame.to(device)

        # Pad input to multiple of 16
        padded_input, padding_info = pad_to_multiple_of_16(input_frame)

        # -------------------------
        # 1) Generator update
        # -------------------------
        gen_optimizer.zero_grad()

        # Forward pass through the DDP wrapper, then strip the padding from
        # the reconstruction so it matches the unpadded input.
        outputs_dict = model(padded_input)
        recon = remove_padding(outputs_dict.sample, padding_info)
        # Obtain posterior (latent distribution) from a separate encoding pass
        posterior = model.module.encode(padded_input)
        last_layer = model.module.decoder.conv_out.weight

        # Compute generator loss (optimizer_idx = 0)
        loss_gen, _ = criterion(
            input_frame,
            recon,
            posterior,
            optimizer_idx=0,
            mask=None,
            last_layer=last_layer,
            split="train",
        )
        loss_gen.backward()
        if GRADIENT_CLIP_VAL is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VAL)
        gen_optimizer.step()

        # -------------------------
        # 2) Discriminator update
        # -------------------------
        # Clear discriminator gradients only now: the generator's backward
        # pass above may have left gradients on the discriminator parameters
        # (if the generator loss runs through the discriminator), and those
        # must not be mixed into the discriminator step.
        disc_optimizer.zero_grad()
        loss_disc, _ = criterion(
            input_frame,
            recon,
            posterior,
            optimizer_idx=1,
            mask=None,
            last_layer=last_layer,
            split="train",
        )
        loss_disc.backward()
        if GRADIENT_CLIP_VAL is not None:
            torch.nn.utils.clip_grad_norm_(
                criterion.discriminator.parameters(), GRADIENT_CLIP_VAL
            )
        disc_optimizer.step()

        # Read each loss from the device once and reuse the Python floats.
        gen_loss_value = loss_gen.item()
        disc_loss_value = loss_disc.item()
        total_gen_loss += gen_loss_value
        total_disc_loss += disc_loss_value
        if rank == 0:
            train_bar.set_postfix(
                {
                    "train_gen_loss": f"{gen_loss_value:.3f}",
                    "train_disc_loss": f"{disc_loss_value:.3f}",
                }
            )

        global_step_counter += input_frame.shape[0]
        count += 1

        if ENABLE_WANDB and rank == 0:
            wandb.log(
                {
                    "train_gen_loss": gen_loss_value,
                    "train_disc_loss": disc_loss_value,
                },
                step=global_step_counter,
            )

        if DEBUG_MODE and count > 20:
            if rank == 0:  # Guard this print statement
                print(f"{DEBUG_PRINT_PREFIX}Breaking early due to DEBUG_MODE")
            break

        # Cosine schedules advance per batch; plateau schedules are stepped
        # during validation instead.
        if scheduler_type == "cosine":
            gen_scheduler.step()
            disc_scheduler.step()

    avg_gen_loss = total_gen_loss / count
    avg_disc_loss = total_disc_loss / count
    avg_total_loss = (total_gen_loss + total_disc_loss) / (2 * count)
    return global_step_counter, avg_total_loss, avg_gen_loss, avg_disc_loss


###############################################
# Adapted validate function using LPIPSWithDiscriminator
###############################################
def validate(
    model,
    criterion,
    device,
    val_bar,
    gen_scheduler,
    disc_scheduler,
    scheduler_type,
    epoch,
):
    """
    Runs a full validation pass on the validation dataset.

    Computes generator and discriminator losses without gradient updates. When
    a process group is initialized, the loss sums and batch counts are summed
    over all ranks before averaging, so that every rank sees the same averages.
    This keeps 'plateau' schedulers in lock-step across ranks and bases early
    stopping on the whole validation set rather than on rank 0's shard.

    Also steps 'plateau' schedulers, logs metrics to wandb on rank 0, activates
    the discriminator once ``epoch >= WARMUP_GENERATOR_EPOCHS``, and afterwards
    calls the early stopping handler on rank 0.

    Besides its arguments, this function reads the module-level
    ``gen_optimizer``, ``disc_optimizer``, ``early_stopping``, ``global_step``,
    ``rank``, ``ENABLE_WANDB``, ``DEBUG_MODE``, ``DEBUG_PRINT_PREFIX`` and
    ``WARMUP_GENERATOR_EPOCHS``.

    Args:
        model (DDP): The wrapped AutoencoderKL model.
        criterion: The loss function.
        device: The computation device.
        val_bar: Iterable of ``(input, _)`` batches (a DataLoader, or a tqdm
            bar wrapping one on rank 0).
        gen_scheduler: Learning rate scheduler for the autoencoder.
        disc_scheduler: Learning rate scheduler for the discriminator.
        scheduler_type (str): The type of scheduler.
        epoch (int): The current epoch number.

    Returns:
        A tuple with the average total, generator, and discriminator
        validation losses.

    Raises:
        ZeroDivisionError: If no validation batch was processed on any rank.
    """
    model.eval()
    criterion.eval()
    total_gen_loss = 0.0
    total_disc_loss = 0.0
    count = 0

    with torch.no_grad():
        for batch in val_bar:
            input_frame, _ = batch
            input_frame = input_frame.to(device)

            # Pad input to multiple of 16
            padded_input, padding_info = pad_to_multiple_of_16(input_frame)

            outputs_dict = model(padded_input)
            # Remove padding from reconstruction
            recon = remove_padding(outputs_dict.sample, padding_info)
            posterior = model.module.encode(padded_input)
            last_layer = model.module.decoder.conv_out.weight

            # Compute generator loss (optimizer_idx = 0)
            loss_gen, _ = criterion(
                input_frame,
                recon,
                posterior,
                optimizer_idx=0,
                mask=None,
                last_layer=last_layer,
                split="val",
            )
            # Compute discriminator loss (optimizer_idx = 1)
            loss_disc, _ = criterion(
                input_frame,
                recon,
                posterior,
                optimizer_idx=1,
                mask=None,
                last_layer=last_layer,
                split="val",
            )

            gen_loss_value = loss_gen.item()
            disc_loss_value = loss_disc.item()
            total_gen_loss += gen_loss_value
            total_disc_loss += disc_loss_value
            if rank == 0:
                val_bar.set_postfix(
                    {
                        "val_gen_loss": f"{gen_loss_value:.3f}",
                        "val_disc_loss": f"{disc_loss_value:.3f}",
                    }
                )

            count += 1
            if DEBUG_MODE and count > 20:
                if rank == 0:  # Guard this print statement
                    print(f"{DEBUG_PRINT_PREFIX}Breaking early due to DEBUG_MODE")
                break

    # Combine the per-rank sums so that every rank derives identical averages.
    batch_count = float(count)
    if dist.is_available() and dist.is_initialized():
        totals = torch.tensor(
            [total_gen_loss, total_disc_loss, batch_count],
            dtype=torch.float64,
            device=device,
        )
        dist.all_reduce(totals, op=dist.ReduceOp.SUM)
        total_gen_loss, total_disc_loss, batch_count = totals.tolist()

    avg_gen_loss = total_gen_loss / batch_count
    avg_disc_loss = total_disc_loss / batch_count
    avg_total_loss = (total_gen_loss + total_disc_loss) / (2 * batch_count)

    # Plateau schedulers react to the (globally averaged) validation loss.
    if scheduler_type == "plateau":
        gen_scheduler.step(avg_total_loss)
        disc_scheduler.step(avg_total_loss)

    current_lr = gen_optimizer.param_groups[0]["lr"]
    if ENABLE_WANDB and rank == 0:
        wandb.log(
            {
                "val_total_loss": avg_total_loss,
                "val_gen_loss": avg_gen_loss,
                "val_disc_loss": avg_disc_loss,
                "learning_rate": current_lr,
            }
        )

    # Early stopping and warm-up manager: the first epoch at or past the
    # warm-up limit switches the discriminator on; later epochs run early
    # stopping on the generator loss.
    if epoch >= WARMUP_GENERATOR_EPOCHS:
        if not criterion.discriminator_active:
            if rank == 0:  # Guard this print statement
                print(DEBUG_PRINT_PREFIX + "Switching to full mode")
            criterion.activate_discriminator()
        elif rank == 0 and early_stopping is not None:
            early_stopping(
                avg_gen_loss,
                model,
                gen_optimizer,
                disc_optimizer,
                epoch,
                global_step,
            )

    return avg_total_loss, avg_gen_loss, avg_disc_loss


# Train the model
# Pre-bind the progress bars so the final `del` also works when NUM_EPOCHS is 0.
train_bar = None
val_bar = None

for epoch in range(NUM_EPOCHS):
    train_sampler.set_epoch(epoch)
    val_sampler.set_epoch(epoch)
    # Only rank 0 shows a progress bar; other ranks iterate the loader directly.
    train_bar = (
        tqdm(train_loader, desc=f"Training Epoch {epoch}")
        if rank == 0
        else train_loader
    )

    # Train for one epoch
    global_step, avg_total_train_loss, avg_gen_train_loss, avg_disc_train_loss = (
        train_epoch(
            train_bar,
            global_step,
            model,
            criterion,
            device,
            gen_optimizer,
            disc_optimizer,
            gen_scheduler,
            disc_scheduler,
            SCHEDULER_TYPE,
        )
    )

    # In final-training mode the validation set is used for training as well.
    if FINAL_TRAINING:
        if rank == 0:
            print(f"{DEBUG_PRINT_PREFIX}Training on validation set for epoch {epoch}")
            val_train_bar = tqdm(
                val_loader, desc=f"Final Training Epoch {epoch} (Val Set)"
            )
        else:
            val_train_bar = val_loader

        global_step, _, _, _ = train_epoch(
            val_train_bar,
            global_step,
            model,
            criterion,
            device,
            gen_optimizer,
            disc_optimizer,
            gen_scheduler,
            disc_scheduler,
            SCHEDULER_TYPE,
        )

    # Validation TQDM
    val_bar = (
        tqdm(val_loader, desc=f"Validation Epoch {epoch}")
        if rank == 0
        else val_loader
    )

    # Validate the model
    avg_total_val_loss, avg_gen_val_loss, avg_disc_val_loss = validate(
        model,
        criterion,
        device,
        val_bar,
        gen_scheduler,
        disc_scheduler,
        SCHEDULER_TYPE,
        epoch,
    )

    if rank == 0:
        print(
            f"Finished Epoch {epoch} - Train Loss: {avg_total_train_loss:.3f}, Val Loss: {avg_total_val_loss:.3f}, Gen Train Loss: {avg_gen_train_loss:.3f}, Gen Val Loss: {avg_gen_val_loss:.3f}, Disc Train Loss: {avg_disc_train_loss:.3f}, Disc Val Loss: {avg_disc_val_loss:.3f}"
        )

    # The early-stopping decision is made on rank 0, but every rank has to
    # leave the loop together: a rank that keeps training would block forever
    # in its next collective once rank 0 has exited.
    should_stop = bool(
        rank == 0 and early_stopping is not None and early_stopping.early_stop
    )
    stop_flag = torch.tensor([int(should_stop)], dtype=torch.int32, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.broadcast(stop_flag, src=0)
    if stop_flag.item():
        if rank == 0:
            print(DEBUG_PRINT_PREFIX + "Early stopping")
        break

if rank == 0:
    print(DEBUG_PRINT_PREFIX + "Finished training, run id: " + RUN_ID)

cleanup_ddp()

# Free memory: drop the model, optimizers, schedulers, loaders and progress
# bars, then release cached CUDA memory.
del (
    model,
    gen_optimizer,
    disc_optimizer,
    gen_scheduler,
    disc_scheduler,
    train_loader,
    val_loader,
    train_bar,
    val_bar,
)
torch.cuda.empty_cache()
