# Adapted from my_train_cifar10_ddp.py: reads a checkpoint and continues MLE training.
# This script is for residual fine-tuning.

import copy
import math
import os
import json
import gc
import numpy as np
import matplotlib.pyplot as plt
import shutil

import torch
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, setup
from torch.utils.tensorboard import SummaryWriter

from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper
from cleanfid import fid
# from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint as odeint

import torch.nn as nn
class VectorFieldWrapper(nn.Module):
    """Thin ``nn.Module`` adapter exposing a ``forward(t, x, ...)`` vector field.

    The wrapped ``model`` is called as ``model(t, x, *args, **kwargs)``; any
    extra positional or keyword arguments supplied by the caller are forwarded
    unchanged, so the wrapper can stand in for models whose ``forward`` accepts
    additional arguments.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the wrapped model at time ``t`` and state ``x``."""
        return self.model(t, x, *args, **kwargs)

# Shared absl flag registry; every option below is read through this object.
FLAGS = flags.FLAGS

# --- Model selection and paths -------------------------------------------------
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_string("output_dir", "./results_MLE_Res/", help="output_directory")
# Example resume path: ./results3/otcfm/otcfm_cifar10_weights_step_42780000.pt
flags.DEFINE_string("checkpoint_path", "", help="path to checkpoint")
flags.DEFINE_string(
    "pretrained_path",
    "./results_1GPU/otcfm/otcfm_cifar10_weights_step_400000.pt",
    help="path to pretrained model",
)

# --- Network architecture ------------------------------------------------------
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")
# train() only special-cases "mlp"; every other value builds the UNet.
flags.DEFINE_string(
    "res_model_type",
    "unet",
    help="residual model architecture: mlp, unet, or unet_simple",
)

# --- Optimisation --------------------------------------------------------------
# Learning rates tried so far: 2e-4, 1e-5.
flags.DEFINE_float("lr", 1e-5, help="target learning rate")
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
# An earlier step-based setup used 400001 here.
flags.DEFINE_integer("num_epochs", 100, help="total training epochs")
flags.DEFINE_integer("warmup", 10, help="learning rate warmup")
# Lipman et al. use a batch size of 512.
flags.DEFINE_integer("batch_size", 128, help="batch size")
flags.DEFINE_integer("num_workers", 24, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")

# --- Distributed training ------------------------------------------------------
flags.DEFINE_bool("parallel", True, help="multi gpu training")
flags.DEFINE_string(
    "master_addr",
    "localhost",
    help="master address for Distributed Data Parallel",
)
flags.DEFINE_string(
    "master_port",
    "29500",
    help="master port for Distributed Data Parallel",
)

# --- Checkpointing and evaluation cadence --------------------------------------
# An earlier setup used 20000 for save_step.
flags.DEFINE_integer(
    "save_step",
    10,
    help="frequency of saving checkpoints, 0 to disable during training",
)
flags.DEFINE_integer(
    "fid_step",
    0,
    help="frequency of computing FID (in epochs), 0 to disable during training",
)

# --- FID computation parameters ------------------------------------------------
flags.DEFINE_integer(
    "integration_steps", 100, help="number of inference steps for FID"
)
flags.DEFINE_string(
    "integration_method", "dopri5", help="integration method to use for FID"
)
flags.DEFINE_integer(
    "num_gen", 50000, help="number of samples to generate for FID"
)
flags.DEFINE_float(
    "tol", 1e-5, help="Integrator tolerance for FID (absolute and relative)"
)
flags.DEFINE_integer("batch_size_fid", 1024, help="Batch size to compute FID")
# -----------------------------------------------------------------------------
# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument("--local_rank", default=-1)
# FLAGS = parser.parse_args()
# local_rank = FLAGS.local_rank

# -----------------------------------------------------------------------------
def warmup_lr(step):
    """Return the linear warm-up multiplier for ``step``.

    The factor rises linearly from 0 to 1 over ``FLAGS.warmup`` steps and stays
    at 1 afterwards. A non-positive ``FLAGS.warmup`` disables warm-up (factor
    1.0) instead of dividing by zero.
    """
    warmup_steps = FLAGS.warmup
    if warmup_steps <= 0:
        return 1.0
    return min(step, warmup_steps) / warmup_steps


def _unwrap_module(model):
    """Return ``model.module`` if ``model`` exposes one (e.g. a DDP wrapper), else ``model``."""
    return model.module if hasattr(model, "module") else model


def _load_state_dict_flexible(model, state_dict):
    """Load ``state_dict`` into ``model`` whether or not either side is DDP-prefixed.

    The weights are loaded into the underlying module. If the direct load fails,
    a leading ``"module."`` is stripped from every key and the load is retried.
    """
    from collections import OrderedDict

    target = _unwrap_module(model)
    try:
        target.load_state_dict(state_dict)
    except RuntimeError:
        prefix = "module."
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        target.load_state_dict(stripped)


def _integrate_to_end(net, x, device, integration_steps, integration_method, tol):
    """Integrate the vector field ``net`` from t=0 to t=1 starting at ``x``; return the final state."""
    if integration_method == "euler":
        node = NeuralODE(net, solver=integration_method)
        t_span = torch.linspace(0, 1, integration_steps + 1, device=device)
        return node.trajectory(x, t_span=t_span)[-1]
    t_span = torch.linspace(0, 1, 2, device=device)
    traj = odeint(net, x, t_span, rtol=tol, atol=tol, method=integration_method)
    return traj[-1, :]


def _read_loss_history(losses_file):
    """Return the list of losses stored in ``losses_file``; ``[]`` if missing or unreadable."""
    if not os.path.exists(losses_file):
        return []
    try:
        with open(losses_file, "r") as f:
            history = json.load(f)
    except (OSError, ValueError):
        # Best effort: a corrupt or unreadable history file is treated as empty.
        return []
    return history if isinstance(history, list) else []


def _append_loss_history(losses_file, new_losses):
    """Append ``new_losses`` to the JSON list stored in ``losses_file``."""
    history = _read_loss_history(losses_file)
    history.extend(new_losses)
    with open(losses_file, "w") as f:
        json.dump(history, f)


def _write_fid_json(fid_file, fid_steps, fid_scores):
    """Write the FID history (global steps and scores) to ``fid_file`` as JSON."""
    with open(fid_file, "w") as f:
        json.dump({"steps": fid_steps, "fid_scores": fid_scores}, f)


def _save_curve(plot_args, title, ylabel, out_file):
    """Plot ``plt.plot(*plot_args)`` against training steps and save it to ``out_file``."""
    plt.figure(figsize=(10, 6))
    plt.plot(*plot_args)
    plt.title(title)
    plt.xlabel("Training Steps")
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.savefig(out_file, dpi=300, bbox_inches="tight")
    # Close the figure to free memory.
    plt.close()


class _MLPResidualModel(nn.Module):
    """Three-layer ReLU MLP over flattened inputs, used when ``res_model_type == "mlp"``."""

    def __init__(self, input_dim=3 * 32 * 32, hidden_dim=4 * 512, output_dim=3 * 32 * 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, t, x, **kwargs):
        # ``t`` and any extra keyword arguments are accepted but not used.
        batch_size = x.shape[0]
        output = self.net(x.view(batch_size, -1))
        return output.view_as(x)


def _build_residual_model(rank):
    """Create the trainable residual network selected by ``FLAGS.res_model_type`` on ``rank``."""
    if FLAGS.res_model_type == "mlp":
        return _MLPResidualModel().to(rank)
    # Any other value builds a reduced UNet (one res block, two resolution levels).
    return UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=1,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2],
        num_heads=4,
        num_head_channels=64,
        dropout=0.1,
    ).to(rank)


def compute_fid_score(
    model,
    device,
    integration_steps,
    integration_method,
    tol,
    batch_size_fid,
    num_gen,
    pretrained_model=None,
):
    """Compute the CIFAR-10 FID of the two-stage sampler (pretrained net, then residual net).

    Each generated batch starts from Gaussian noise, is integrated from t=0 to
    t=1 through ``pretrained_model`` and then again through ``model``; the
    result is mapped to ``uint8`` images and handed to ``fid.compute_fid``.

    Args:
        model: residual network, optionally wrapped (its ``.module`` is used).
        device: device on which samples are generated.
        integration_steps: number of fixed steps, used when the method is "euler".
        integration_method: solver name; "euler" uses ``NeuralODE``, anything
            else is passed to ``odeint`` with ``rtol=atol=tol``.
        tol: absolute and relative tolerance for ``odeint``.
        batch_size_fid: number of images generated per call.
        num_gen: total number of images requested from ``fid.compute_fid``.
        pretrained_model: frozen first-stage network. When omitted, a
            module-level ``pretrained_model`` is used if one exists.

    Returns:
        The FID score returned by ``fid.compute_fid``.

    Raises:
        ValueError: if no first-stage network is available.
    """
    if pretrained_model is None:
        # Backward compatibility: earlier code looked this name up as a global.
        pretrained_model = globals().get("pretrained_model")
    if pretrained_model is None:
        raise ValueError(
            "compute_fid_score needs the first-stage network; pass pretrained_model=..."
        )

    res_net = _unwrap_module(model)
    pretrained_net = _unwrap_module(pretrained_model)

    # Remember the residual net's mode so it can be restored afterwards.
    res_was_training = res_net.training
    res_net.eval()
    pretrained_net.eval()

    def gen_1_img(unused_latent):
        with torch.no_grad():
            x = torch.randn(batch_size_fid, 3, 32, 32, device=device)
            # First stage: noise -> intermediate sample via the pretrained model.
            intermediate_x = _integrate_to_end(
                pretrained_net, x, device, integration_steps, integration_method, tol
            )
            # Second stage: intermediate sample -> final image via the residual model.
            final_x = _integrate_to_end(
                res_net, intermediate_x, device, integration_steps, integration_method, tol
            )
            return (final_x * 127.5 + 128).clip(0, 255).to(torch.uint8)

    print("Start computing FID")
    try:
        score = fid.compute_fid(
            gen=gen_1_img,
            dataset_name="cifar10",
            batch_size=batch_size_fid,
            dataset_res=32,
            num_gen=num_gen,
            dataset_split="train",
            mode="legacy_tensorflow",
        )
    finally:
        # The pretrained model stays in eval mode; only the residual net is restored.
        res_net.train(res_was_training)
    print(f"FID: {score}")
    return score


def train(rank, total_num_gpus, argv):
    """Fine-tune a residual flow on top of a frozen pretrained flow on CIFAR-10.

    Each step draws noise ``x0``, integrates it through the frozen pretrained
    network and then through the trainable residual network, and minimises the
    mean squared error between the result and a data batch ``x1``. Rank 0 logs
    to TensorBoard, periodically writes losses, samples, checkpoints and
    (optionally) FID scores under ``FLAGS.output_dir + FLAGS.model + "/"``.

    Args:
        rank: integer rank (also used as the CUDA device index) in the
            multi-GPU case, or a ``torch.device`` in the single-device case.
        total_num_gpus: world size; also used to derive steps per epoch.
        argv: unused; kept for the absl entry-point signature.

    Raises:
        NotImplementedError: if ``FLAGS.model`` is not a known matcher name.
    """
    # A torch.device rank means single-device training; it logs as rank 0.
    rank_int = rank if isinstance(rank, int) else 0
    is_main = rank_int == 0

    # Every GPU uses the same batch size, so the global batch scales with world size.
    batch_size_per_gpu = FLAGS.batch_size
    if FLAGS.parallel and total_num_gpus > 1:
        setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port)
    elif FLAGS.parallel:
        # Single process: fall back to the non-distributed code path.
        FLAGS.parallel = False

    # DATASETS/DATALOADER
    dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )
    sampler = DistributedSampler(dataset) if FLAGS.parallel else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        sampler=sampler,
        shuffle=False if FLAGS.parallel else True,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )
    datalooper = infiniteloop(dataloader)

    # Steps per epoch for each process; the dataset is split across processes.
    steps_per_epoch = math.ceil(len(dataset) / (FLAGS.batch_size * total_num_gpus))
    num_epochs = FLAGS.num_epochs
    total_steps = num_epochs * steps_per_epoch
    if is_main:
        print(
            "lr, total_steps, ema decay, save_step:",
            FLAGS.lr,
            total_steps,
            FLAGS.ema_decay,
            FLAGS.save_step,
        )

    # MODELS
    # Frozen first-stage model (full-size UNet).
    pretrained_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(rank)

    map_location = f"cuda:{rank}" if isinstance(rank, int) else rank
    pretrained_checkpoint = torch.load(FLAGS.pretrained_path, map_location=map_location)
    _load_state_dict_flexible(pretrained_model, pretrained_checkpoint["net_model"])

    # Freeze the pretrained model and switch it to eval mode so that its
    # dropout layers are inactive while it is used as a fixed first stage.
    for param in pretrained_model.parameters():
        param.requires_grad = False
    pretrained_model.eval()

    res_model = _build_residual_model(rank)
    ema_res_model = copy.deepcopy(res_model)
    optim = torch.optim.Adam(res_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=total_steps)

    if FLAGS.parallel:
        # Only the residual models are wrapped; the frozen model needs no DDP.
        res_model = DistributedDataParallel(res_model, device_ids=[rank])
        ema_res_model = DistributedDataParallel(ema_res_model, device_ids=[rank])

    # Show model size.
    model_size = sum(param.data.nelement() for param in res_model.parameters())
    if is_main:
        print("Model params: %.2f M" % (model_size / 1024 / 1024))

    #################################
    #            OT-CFM
    #################################

    sigma = 0.0
    matcher_classes = {
        "otcfm": ExactOptimalTransportConditionalFlowMatcher,
        "icfm": ConditionalFlowMatcher,
        "fm": TargetConditionalFlowMatcher,
        "si": VariancePreservingConditionalFlowMatcher,
    }
    if FLAGS.model not in matcher_classes:
        raise NotImplementedError(
            f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
        )
    FM = matcher_classes[FLAGS.model](sigma=sigma)

    savedir = FLAGS.output_dir + FLAGS.model + "/"
    losses_file = savedir + f"{FLAGS.model}_training_losses.json"
    resuming = bool(FLAGS.checkpoint_path)
    writer = None
    if is_main:
        # Start from a clean directory only for fresh runs: wiping it on resume
        # could delete the very checkpoint that is about to be loaded.
        if os.path.exists(savedir) and not resuming:
            shutil.rmtree(savedir)
        os.makedirs(savedir, exist_ok=True)

        log_dir = savedir + "tensorboard_logs/"
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)
        print(f"TensorBoard logs will be saved to: {log_dir}")

    if resuming:
        checkpoint = torch.load(FLAGS.checkpoint_path, map_location=map_location)
        _load_state_dict_flexible(res_model, checkpoint["res_model"])
        _load_state_dict_flexible(ema_res_model, checkpoint["ema_res_model"])
        optim.load_state_dict(checkpoint["optim"])
        sched.load_state_dict(checkpoint["sched"])
        # NOTE(review): the stored step is rescaled by 4 as in the original
        # code (reason not visible here — confirm). Integer division keeps the
        # counter an int so the modulo-based save/log triggers keep firing.
        global_step = int(checkpoint["step"]) // 4

        current_epoch = global_step // steps_per_epoch
        remaining_steps = total_steps - global_step
        remaining_epochs = num_epochs - current_epoch
        if is_main:
            print(f"Resuming training from step {global_step} (Epoch {current_epoch})")
            print(f"Total steps to run: {total_steps} (for {num_epochs} epochs)")
            print(f"Remaining steps: {remaining_steps}")
            print(f"Remaining epochs: {remaining_epochs}")
    else:
        global_step = 0  # global step counter for the training loop
        if is_main:
            print(f"Starting training from scratch. Total steps to run: {total_steps} (for {num_epochs} epochs)")

    # Losses recorded since the last flush to ``losses_file``; the file holds
    # the complete history, so this buffer stays small.
    pending_losses = []
    fid_scores = []
    fid_steps = []  # global steps at which FID was computed

    # The ODEs are integrated through the underlying modules.
    # NOTE(review): this bypasses the DDP wrapper's forward — confirm that
    # gradients are still synchronised across ranks in multi-GPU runs.
    pretrained_actual_model = _unwrap_module(pretrained_model)
    res_actual_model = _unwrap_module(res_model)
    t_span = torch.linspace(0, 1, 2).to(rank)

    with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
        for epoch in epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
            if sampler is not None:
                sampler.set_epoch(epoch)

            steps_this_epoch = 0
            with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
                for step in step_pbar:
                    if global_step >= total_steps:
                        break
                    global_step += 1
                    steps_this_epoch += 1

                    optim.zero_grad()

                    x1 = next(datalooper).to(rank)
                    x0 = torch.randn_like(x1)
                    # Only matchers that carry an OT sampler re-pair noise and data.
                    if hasattr(FM, "ot_sampler"):
                        x0, x1 = FM.ot_sampler.sample_plan(x0, x1)

                    # Stage 1: frozen pretrained model, no gradient tracking.
                    with torch.no_grad():
                        pretrained_node = NeuralODE(
                            pretrained_actual_model,
                            sensitivity="adjoint",
                        )
                        _, pretrained_traj = pretrained_node(x0, t_span)
                        intermediate_x1 = pretrained_traj[-1]

                    # Stage 2: trainable residual model.
                    res_node = NeuralODE(
                        res_actual_model,
                        sensitivity="adjoint",
                    )
                    _, res_traj = res_node(intermediate_x1, t_span)
                    pred_x1 = res_traj[-1]

                    # MSE between the final prediction and the target batch.
                    loss = torch.mean((pred_x1 - x1) ** 2)
                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(res_model.parameters(), FLAGS.grad_clip)
                    optim.step()
                    sched.step()
                    ema(res_model, ema_res_model, FLAGS.ema_decay)

                    # Explicitly drop intermediate tensors and release cached memory.
                    del pretrained_traj, intermediate_x1, res_traj, pred_x1
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    gc.collect()

                    if is_main:
                        loss_value = loss.item()
                        pending_losses.append(loss_value)
                        writer.add_scalar("Loss/train", loss_value, global_step)
                        step_pbar.set_postfix(loss=f"{loss_value:.4f}")

                        # Flush the buffered losses to disk every 10 steps.
                        if global_step % 10 == 0:
                            _append_loss_history(losses_file, pending_losses)
                            pending_losses = []
                            print(f"Training losses saved incrementally to {losses_file}")

                            gc.collect()
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        # Sample images and save the weights.
                        if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
                            generate_samples(
                                res_model, FLAGS.parallel, savedir, global_step, net_="normal"
                            )
                            generate_samples(
                                ema_res_model, FLAGS.parallel, savedir, global_step, net_="ema"
                            )
                            torch.save(
                                {
                                    "res_model": res_model.state_dict(),
                                    "ema_res_model": ema_res_model.state_dict(),
                                    "pretrained_model": pretrained_model.state_dict(),
                                    "sched": sched.state_dict(),
                                    "optim": optim.state_dict(),
                                    "step": global_step,
                                    # Residual architecture used for this run.
                                    "res_model_type": FLAGS.res_model_type,
                                },
                                savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
                            )

            # FID at the end of the epoch, if enabled. This runs before the
            # termination check so the final epoch is evaluated too, and only
            # if the epoch actually trained.
            if (
                is_main
                and FLAGS.fid_step > 0
                and steps_this_epoch > 0
                and (epoch + 1) % FLAGS.fid_step == 0
            ):
                print(f"Computing FID at epoch {epoch + 1}")
                try:
                    device = torch.device(f"cuda:{rank}") if isinstance(rank, int) else rank

                    # FID is computed for the EMA weights.
                    fid_score = compute_fid_score(
                        ema_res_model,
                        device,
                        FLAGS.integration_steps,
                        FLAGS.integration_method,
                        FLAGS.tol,
                        FLAGS.batch_size_fid,
                        FLAGS.num_gen,
                        pretrained_model=pretrained_model,
                    )

                    fid_scores.append(fid_score)
                    fid_steps.append(global_step)

                    # Persist the scores after every evaluation.
                    fid_file = savedir + f"{FLAGS.model}_fid_scores.json"
                    _write_fid_json(fid_file, fid_steps, fid_scores)
                    print(f"FID scores saved to {fid_file}")

                    writer.add_scalar("FID/score", fid_score, epoch + 1)
                    writer.add_scalar("FID/score_vs_steps", fid_score, global_step)
                    print(f"FID score {fid_score} logged to TensorBoard at epoch {epoch + 1}")

                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    print("Memory cleared after FID computation")

                except Exception as e:
                    # Deliberate best effort: a failed evaluation must not stop training.
                    print(f"Error computing FID at epoch {epoch + 1}: {e}")
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

            if global_step >= total_steps:
                break

    if is_main:
        # Flush whatever is still buffered, then work from the full history on disk.
        if pending_losses:
            _append_loss_history(losses_file, pending_losses)
            pending_losses = []
        all_losses = _read_loss_history(losses_file)

        if all_losses:
            print(f"Training losses saved to {losses_file}")
            plot_file = savedir + f"{FLAGS.model}_training_loss_curve.png"
            _save_curve(
                (all_losses,),
                f"Training Loss Curve - {FLAGS.model}",
                "Loss",
                plot_file,
            )
            print(f"Training loss plot saved to {plot_file}")

        writer.close()
        print("TensorBoard writer closed.")

        # Also save the losses as a numpy array for easier loading.
        np_file = savedir + f"{FLAGS.model}_training_losses.npy"
        np.save(np_file, np.array(all_losses))
        print(f"Training losses saved as numpy array to {np_file}")

    # After training completes, save and plot the FID scores if any were computed.
    if is_main and fid_scores:
        fid_file = savedir + f"{FLAGS.model}_fid_scores.json"
        _write_fid_json(fid_file, fid_steps, fid_scores)
        print(f"FID scores saved to {fid_file}")

        fid_np_file = savedir + f"{FLAGS.model}_fid_scores.npy"
        np.save(fid_np_file, np.array(fid_scores))
        print(f"FID scores saved as numpy array to {fid_np_file}")

        steps_np_file = savedir + f"{FLAGS.model}_fid_steps.npy"
        np.save(steps_np_file, np.array(fid_steps))
        print(f"FID steps saved as numpy array to {steps_np_file}")

        fid_plot_file = savedir + f"{FLAGS.model}_fid_curve.png"
        _save_curve(
            (fid_steps, fid_scores, "o-"),
            f"FID Score Curve - {FLAGS.model}",
            "FID Score",
            fid_plot_file,
        )
        print(f"FID curve plot saved to {fid_plot_file}")


def main(argv):
    """Entry point: pick the multi-process or single-device path and run ``train``.

    The world size is read from the ``WORLD_SIZE`` environment variable
    (default 1). With ``FLAGS.parallel`` set and more than one process, the
    integer ``RANK`` environment variable (default 0) is passed as ``rank``;
    otherwise a ``torch.device`` ("cuda" when available, else "cpu") is passed.
    """
    world_size = int(os.getenv("WORLD_SIZE", 1))

    if not (FLAGS.parallel and world_size > 1):
        # Single-device path: hand train() a torch.device instead of a rank.
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        train(rank=torch.device(device_name), total_num_gpus=world_size, argv=argv)
        return

    process_rank = int(os.getenv("RANK", 0))
    train(rank=process_rank, total_num_gpus=world_size, argv=argv)


# Script entry point: hand control to absl, which calls main(argv).
if __name__ == "__main__":
    app.run(main)
