# Adapted from my_train_cifar10_ddp.py: reads a checkpoint and continues training with the MLE objective.

import copy
import math
import os
import json
import gc
import numpy as np
import matplotlib.pyplot as plt
import shutil

import torch
from absl import app, flags
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop, setup
from torch.utils.tensorboard import SummaryWriter

from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper
from cleanfid import fid
# from torchdiffeq import odeint
from torchdiffeq import odeint_adjoint as odeint

import torch.nn as nn
class VectorFieldWrapper(nn.Module):
    """Thin ``nn.Module`` adapter that exposes a model as a ``(t, x)`` vector field.

    The wrapped object is stored on ``self.model`` and called as
    ``model(t, x)``; its result is returned untouched. The original note on
    this class says the purpose is to make sure the model receives both the
    ``t`` and ``x`` arguments when used with DDP.
    """

    def __init__(self, model):
        super(VectorFieldWrapper, self).__init__()
        self.model = model

    def forward(self, t, x):
        """Evaluate the wrapped model at time ``t`` and state ``x``."""
        velocity = self.model(t, x)
        return velocity

# Shared absl flag container; the DEFINE_* calls below register this script's
# command-line options on it.
FLAGS = flags.FLAGS

# ---- Model choice and file locations ----
flags.DEFINE_string("model", "otcfm", help="flow matching model type")
flags.DEFINE_string("output_dir", "./results_MLE_4GPU/", help="output_directory")
# Checkpoint to resume from. Alternative path kept for reference:
#   ./results3/otcfm/otcfm_cifar10_weights_step_42780000.pt
flags.DEFINE_string(
    "checkpoint_path",
    "./results_1GPU/otcfm/otcfm_cifar10_weights_step_20000.pt",
    help="path to checkpoint",
)

# ---- UNet ----
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")

# ---- Training ----
flags.DEFINE_float("lr", 5e-5, help="target learning rate")  # alternative to try: 2e-4
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer("num_epochs", 3000001, help="total training epochs")  # 400001 noted as an alternative
flags.DEFINE_integer("warmup", 10, help="learning rate warmup")
flags.DEFINE_integer("batch_size", 128, help="batch size")  # Lipman et al uses 128
flags.DEFINE_integer("num_workers", 24, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", True, help="multi gpu training")
flags.DEFINE_string("master_addr", "localhost", help="master address for Distributed Data Parallel")
flags.DEFINE_string("master_port", "29500", help="master port for Distributed Data Parallel")

# ---- Evaluation / checkpointing cadence ----
# Values noted for save_step: 10 and 20000.
flags.DEFINE_integer("save_step", 10, help="frequency of saving checkpoints, 0 to disable during training")
flags.DEFINE_integer("save_step_loss", 10, help="frequency of saving loss values, 0 to disable during training")
flags.DEFINE_integer("fid_step", 0, help="frequency of computing FID (in epochs), 0 to disable during training")

# ---- FID computation parameters ----
flags.DEFINE_integer("integration_steps", 100, help="number of inference steps for FID")
flags.DEFINE_string("integration_method", "dopri5", help="integration method to use for FID")
flags.DEFINE_integer("num_gen", 50000, help="number of samples to generate for FID")
flags.DEFINE_float("tol", 1e-5, help="Integrator tolerance for FID (absolute and relative)")
flags.DEFINE_integer("batch_size_fid", 1024, help="Batch size to compute FID")
# -----------------------------------------------------------------------------
# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument("--local_rank", default=-1)
# FLAGS = parser.parse_args()
# local_rank = FLAGS.local_rank

# -----------------------------------------------------------------------------
def warmup_lr(step, warmup=None):
    """Return the linear warm-up multiplier for the learning rate.

    The factor grows linearly from 0 at ``step == 0`` to 1 at
    ``step == warmup`` and stays at 1 afterwards.

    Args:
        step: Scheduler step counter (passed in by ``LambdaLR``).
        warmup: Number of warm-up steps. Defaults to ``FLAGS.warmup`` when
            omitted, so existing single-argument callers are unaffected.

    Returns:
        A float in ``[0, 1]`` to multiply the base learning rate by.
    """
    if warmup is None:
        warmup = FLAGS.warmup
    # A non-positive warm-up means "no warm-up"; this also avoids dividing by
    # zero when the script is launched with --warmup=0.
    if warmup <= 0:
        return 1.0
    return min(step, warmup) / warmup


def compute_fid_score(model, device, integration_steps, integration_method, tol, batch_size_fid, num_gen):
    """Compute the FID score of samples generated by ``model``.

    Samples are produced by integrating the model from Gaussian noise at
    ``t=0`` to ``t=1`` and are handed to ``cleanfid`` batch by batch.

    Args:
        model: Vector-field network; a wrapper exposing ``.module`` (e.g. DDP)
            is unwrapped first.
        device: Device on which the noise batches are created.
        integration_steps: Number of fixed steps, used only when
            ``integration_method == "euler"``.
        integration_method: Solver name. ``"euler"`` goes through torchdyn's
            ``NeuralODE``; anything else is passed to ``odeint`` as ``method``.
        tol: Absolute and relative tolerance for the ``odeint`` path.
        batch_size_fid: Number of images generated per call.
        num_gen: Total number of images requested from ``cleanfid``.

    Returns:
        The score returned by ``fid.compute_fid``.
    """
    # Handle DDP model by getting the underlying module.
    net = model.module if hasattr(model, "module") else model

    # Remember the caller's mode so it can be restored afterwards, including
    # when FID computation raises.
    was_training = net.training
    net.eval()

    # NOTE: the requested solver is honoured here. Previously it was
    # overwritten with "dopri5", which made the Euler branch unreachable.
    use_euler = integration_method == "euler"
    node = NeuralODE(net, solver=integration_method) if use_euler else None

    def gen_1_img(unused_latent):
        # cleanfid passes its own latent; it is ignored and fresh noise is drawn.
        with torch.no_grad():
            x = torch.randn(batch_size_fid, 3, 32, 32, device=device)
            if use_euler:
                t_span = torch.linspace(0, 1, integration_steps + 1, device=device)
                traj = node.trajectory(x, t_span=t_span)
            else:
                t_span = torch.linspace(0, 1, 2, device=device)
                traj = odeint(
                    net, x, t_span, rtol=tol, atol=tol, method=integration_method
                )
        final_state = traj[-1, :]
        # Map the final state to uint8 pixel values in [0, 255].
        return (final_state * 127.5 + 128).clip(0, 255).to(torch.uint8)

    print("Start computing FID")
    try:
        score = fid.compute_fid(
            gen=gen_1_img,
            dataset_name="cifar10",
            batch_size=batch_size_fid,
            dataset_res=32,
            num_gen=num_gen,
            dataset_split="train",
            mode="legacy_tensorflow",
        )
    finally:
        # Put the network back into whatever mode it was in on entry.
        net.train(was_training)
    print(f"FID: {score}")
    return score


def _path_is_inside(path, directory):
    """Return True when ``path`` resolves to a location inside ``directory``."""
    if not path:
        return False
    target = os.path.realpath(path)
    root = os.path.realpath(directory)
    try:
        return os.path.commonpath([target, root]) == root
    except ValueError:
        # Raised e.g. for paths on different drives; those cannot be nested.
        return False


def _load_weights(model, state_dict):
    """Load ``state_dict`` into ``model``, reconciling a ``module.`` prefix mismatch."""
    try:
        model.load_state_dict(state_dict)
        return
    except RuntimeError:
        pass
    prefix = "module."
    if hasattr(model, "module"):
        # Wrapped model, bare checkpoint: add the prefix where it is missing.
        remapped = {
            (k if k.startswith(prefix) else prefix + k): v for k, v in state_dict.items()
        }
    else:
        # Bare model, wrapped checkpoint: drop the prefix where present.
        remapped = {
            (k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()
        }
    model.load_state_dict(remapped)


def _average_gradients(model, world_size):
    """Average ``.grad`` of every parameter of ``model`` across all ranks."""
    import torch.distributed as dist

    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)


def _read_loss_history(losses_file):
    """Return the list of losses stored in ``losses_file`` ([] if missing or unreadable)."""
    if not os.path.exists(losses_file):
        return []
    try:
        with open(losses_file, "r") as f:
            history = json.load(f)
    except (OSError, ValueError):
        # Best effort: an unreadable/corrupt history file is treated as empty.
        return []
    return history if isinstance(history, list) else []


def _flush_losses(losses_file, pending):
    """Append ``pending`` to the JSON history on disk, clear it, and return the full history."""
    history = _read_loss_history(losses_file)
    history.extend(pending)
    with open(losses_file, "w") as f:
        json.dump(history, f)
    pending.clear()
    return history


def _sample_and_checkpoint(net_model, ema_model, sched, optim, savedir, global_step):
    """Write sample grids for both networks and save a checkpoint for ``global_step``."""
    generate_samples(net_model, FLAGS.parallel, savedir, global_step, net_="normal")
    generate_samples(ema_model, FLAGS.parallel, savedir, global_step, net_="ema")
    torch.save(
        {
            "net_model": net_model.state_dict(),
            "ema_model": ema_model.state_dict(),
            "sched": sched.state_dict(),
            "optim": optim.state_dict(),
            "step": global_step,
        },
        savedir + f"{FLAGS.model}_cifar10_weights_step_{global_step}.pt",
    )


def _write_fid_json(fid_file, fid_steps, fid_scores):
    """Dump the FID history (steps and scores) to ``fid_file`` as JSON."""
    with open(fid_file, "w") as f:
        json.dump({"steps": fid_steps, "fid_scores": fid_scores}, f)


def _save_curve_plot(plot_args, title, ylabel, plot_file):
    """Plot ``plt.plot(*plot_args)`` against training steps and save it to ``plot_file``."""
    plt.figure(figsize=(10, 6))
    plt.plot(*plot_args)
    plt.title(title)
    plt.xlabel("Training Steps")
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.savefig(plot_file, dpi=300, bbox_inches="tight")
    # Close the figure to free memory.
    plt.close()


def train(rank, total_num_gpus, argv):
    """Resume from ``FLAGS.checkpoint_path`` (if set) and continue training.

    Each step integrates the network as an ODE from noise ``x0`` over
    ``t in [0, 1]`` with torchdyn's adjoint sensitivity and minimises the mean
    squared error between the end state and the data batch ``x1``.

    Rank 0 additionally writes samples, checkpoints, TensorBoard logs, the loss
    history (JSON, NPY, PNG) and, when ``FLAGS.fid_step > 0``, FID scores into
    ``FLAGS.output_dir + FLAGS.model + "/"``. An existing output directory is
    wiped at start-up.

    Args:
        rank: Integer process rank for multi-GPU runs, or a ``torch.device``
            for single-device runs.
        total_num_gpus: World size.
        argv: Unused; kept for the absl entry-point signature.

    Raises:
        NotImplementedError: If ``FLAGS.model`` is not a known matcher name.
        ValueError: If the resume checkpoint lives inside the output directory
            (which would be deleted before the checkpoint is read).
    """
    # For single GPU, rank is a device, so the logging rank is 0.
    rank_int = rank if isinstance(rank, int) else 0
    is_main = rank_int == 0

    savedir = FLAGS.output_dir + FLAGS.model + "/"
    # The output directory is wiped below; refuse to run if that would delete
    # the very checkpoint we are about to load. Checked on every rank so that
    # all processes fail together.
    if _path_is_inside(FLAGS.checkpoint_path, savedir):
        raise ValueError(
            f"checkpoint_path {FLAGS.checkpoint_path} is inside the output directory "
            f"{savedir}, which is deleted at start-up; move the checkpoint or change output_dir"
        )

    # Each GPU uses the same batch size, so the global batch size scales with
    # the number of GPUs.
    batch_size_per_gpu = FLAGS.batch_size
    if FLAGS.parallel and total_num_gpus > 1:
        setup(rank, total_num_gpus, FLAGS.master_addr, FLAGS.master_port)
    elif FLAGS.parallel:
        # Single process: fall back to the non-distributed code path.
        FLAGS.parallel = False

    # DATASETS/DATALOADER
    dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )

    sampler = DistributedSampler(dataset) if FLAGS.parallel else None
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        sampler=sampler,
        shuffle=not FLAGS.parallel,
        num_workers=FLAGS.num_workers,
        drop_last=True,
    )
    datalooper = infiniteloop(dataloader)

    # Nominal number of steps per epoch for each GPU.
    steps_per_epoch = math.ceil(len(dataset) / (FLAGS.batch_size * total_num_gpus))
    num_epochs = FLAGS.num_epochs
    total_steps = num_epochs * steps_per_epoch
    if is_main:
        print(
            "lr, total_steps, ema decay, save_step:",
            FLAGS.lr,
            total_steps,
            FLAGS.ema_decay,
            FLAGS.save_step,
        )

    # MODELS
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=FLAGS.num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(rank)

    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
    if FLAGS.parallel:
        net_model = DistributedDataParallel(net_model, device_ids=[rank])
        ema_model = DistributedDataParallel(ema_model, device_ids=[rank])

    # show model size
    model_size = sum(param.data.nelement() for param in net_model.parameters())
    print("Model params: %.2f M" % (model_size / 1024 / 1024))

    #################################
    #            OT-CFM
    #################################
    sigma = 0.0
    matchers = {
        "otcfm": ExactOptimalTransportConditionalFlowMatcher,
        "icfm": ConditionalFlowMatcher,
        "fm": TargetConditionalFlowMatcher,
        "si": VariancePreservingConditionalFlowMatcher,
    }
    if FLAGS.model not in matchers:
        raise NotImplementedError(
            f"Unknown model {FLAGS.model}, must be one of ['otcfm', 'icfm', 'fm', 'si']"
        )
    FM = matchers[FLAGS.model](sigma=sigma)
    # Only matchers that carry an OT plan sampler can re-pair (x0, x1). The
    # previous "!= icfm" test raised AttributeError for "fm" and "si".
    ot_sampler = getattr(FM, "ot_sampler", None)

    writer = None
    if is_main:
        if os.path.exists(savedir):
            shutil.rmtree(savedir)
        os.makedirs(savedir)

        # Initialize TensorBoard writer
        log_dir = savedir + "tensorboard_logs/"
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)
        print(f"TensorBoard logs will be saved to: {log_dir}")

    if FLAGS.checkpoint_path:
        map_location = f"cuda:{rank}" if isinstance(rank, int) else rank
        checkpoint = torch.load(FLAGS.checkpoint_path, map_location=map_location)

        _load_weights(net_model, checkpoint["net_model"])
        _load_weights(ema_model, checkpoint["ema_model"])

        # Optimizer/scheduler state is deliberately not restored; only the
        # weights and the step counter are carried over.
        global_step = checkpoint["step"]

        current_epoch = global_step // steps_per_epoch
        remaining_steps = total_steps - global_step
        remaining_epochs = num_epochs - current_epoch

        if is_main:
            print(f"Resuming training from step {global_step} (Epoch {current_epoch})")
            print(f"Total steps to run: {total_steps} (for {num_epochs} epochs)")
            print(f"Remaining steps: {remaining_steps}")
            print(f"Remaining epochs: {remaining_epochs}")
    else:
        global_step = 0  # to keep track of the global step in training loop
        if is_main:
            print(f"Starting training from scratch. Total steps to run: {total_steps} (for {num_epochs} epochs)")

    losses_file = savedir + f"{FLAGS.model}_training_losses.json"
    fid_file = savedir + f"{FLAGS.model}_fid_scores.json"
    # Losses recorded since the last flush to disk; the file holds the full history.
    pending_losses = []
    fid_scores = []
    fid_steps = []

    # Generate samples and save model before training starts
    if is_main:
        print("Generating samples and saving model before training starts...")
        _sample_and_checkpoint(net_model, ema_model, sched, optim, savedir, global_step)
        print("Initial samples generated and model saved before training.")

    # Loop invariants.
    t_span = torch.linspace(0, 1, 2).to(rank)
    # NeuralODE is given the bare module rather than the DDP wrapper.
    actual_model = net_model.module if hasattr(net_model, "module") else net_model
    # Number of chunks each batch is split into to reduce memory usage.
    n_chunks = 1

    with trange(num_epochs, dynamic_ncols=True) as epoch_pbar:
        for epoch in epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
            if sampler is not None:
                sampler.set_epoch(epoch)

            with trange(steps_per_epoch, dynamic_ncols=True) as step_pbar:
                for _ in step_pbar:
                    if global_step >= total_steps:
                        break
                    global_step += 1

                    x1 = next(datalooper).to(rank)
                    x0 = torch.randn_like(x1)
                    if ot_sampler is not None:
                        x0, x1 = ot_sampler.sample_plan(x0, x1)

                    x0_chunks = torch.chunk(x0, n_chunks, dim=0)
                    x1_chunks = torch.chunk(x1, n_chunks, dim=0)

                    total_loss = 0.0
                    # Process each chunk separately and update weights after each chunk.
                    for x0_chunk, x1_chunk in zip(x0_chunks, x1_chunks):
                        optim.zero_grad()

                        node = NeuralODE(actual_model, sensitivity="adjoint")
                        _, traj_chunk = node(x0_chunk, t_span)

                        pred_x1_chunk = traj_chunk[-1]
                        loss_chunk = torch.mean((pred_x1_chunk - x1_chunk) ** 2)
                        loss_chunk.backward()

                        if FLAGS.parallel:
                            # The forward pass above bypasses the DDP wrapper, and
                            # DDP only all-reduces gradients for passes made through
                            # its own forward(); average them explicitly so the
                            # ranks stay in sync.
                            # NOTE(review): non-zero ranks now wait here while rank 0
                            # samples/checkpoints/computes FID — confirm the
                            # collective timeout is long enough for FID runs.
                            _average_gradients(net_model, total_num_gpus)

                        torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)
                        optim.step()
                        # LambdaLR applies warmup_lr(0) == 0 at construction, so the
                        # scheduler must be stepped or the learning rate stays at 0.
                        sched.step()
                        ema(net_model, ema_model, FLAGS.ema_decay)

                        total_loss += loss_chunk.item()

                        # Explicitly release intermediates to keep peak memory down.
                        del traj_chunk, pred_x1_chunk, loss_chunk
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                        gc.collect()

                    # torch.chunk may return fewer chunks than requested.
                    loss = total_loss / len(x0_chunks)

                    if is_main:
                        pending_losses.append(loss)
                        writer.add_scalar("Loss/train", loss, global_step)
                        step_pbar.set_postfix(loss=f"{loss:.4f}")

                        if FLAGS.save_step_loss > 0 and global_step % FLAGS.save_step_loss == 0:
                            _flush_losses(losses_file, pending_losses)
                            print(f"Training losses saved incrementally to {losses_file}")
                            gc.collect()
                            if torch.cuda.is_available():
                                torch.cuda.empty_cache()

                        # Sample and save the weights.
                        if FLAGS.save_step > 0 and global_step % FLAGS.save_step == 0:
                            _sample_and_checkpoint(
                                net_model, ema_model, sched, optim, savedir, global_step
                            )

            if global_step >= total_steps:
                break

            # Compute FID at the end of epoch if enabled
            if is_main and FLAGS.fid_step > 0 and (epoch + 1) % FLAGS.fid_step == 0:
                print(f"Computing FID at epoch {epoch + 1}")
                try:
                    device = torch.device(f"cuda:{rank}") if isinstance(rank, int) else rank

                    # Compute FID for EMA model
                    fid_score = compute_fid_score(
                        ema_model,
                        device,
                        FLAGS.integration_steps,
                        FLAGS.integration_method,
                        FLAGS.tol,
                        FLAGS.batch_size_fid,
                        FLAGS.num_gen,
                    )

                    fid_scores.append(fid_score)
                    fid_steps.append(global_step)

                    # Save FID scores after each computation
                    _write_fid_json(fid_file, fid_steps, fid_scores)
                    print(f"FID scores saved to {fid_file}")

                    writer.add_scalar("FID/score", fid_score, epoch + 1)
                    writer.add_scalar("FID/score_vs_steps", fid_score, global_step)
                    print(f"FID score {fid_score} logged to TensorBoard at epoch {epoch + 1}")
                except Exception as e:
                    # Deliberate best effort: training continues if FID fails.
                    print(f"Error computing FID at epoch {epoch + 1}: {e}")
                else:
                    print("Memory cleared after FID computation")
                finally:
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

    if is_main:
        # Flush whatever has not been written yet and work from the complete
        # on-disk history for the plot and the NumPy export.
        if pending_losses:
            all_losses = _flush_losses(losses_file, pending_losses)
        else:
            all_losses = _read_loss_history(losses_file)

        if all_losses:
            print(f"Training losses saved to {losses_file}")
            plot_file = savedir + f"{FLAGS.model}_training_loss_curve.png"
            _save_curve_plot(
                (all_losses,), f"Training Loss Curve - {FLAGS.model}", "Loss", plot_file
            )
            print(f"Training loss plot saved to {plot_file}")

        writer.close()
        print("TensorBoard writer closed.")

        # Also save losses as numpy array for easier loading
        np_file = savedir + f"{FLAGS.model}_training_losses.npy"
        np.save(np_file, np.array(all_losses))
        print(f"Training losses saved as numpy array to {np_file}")

    # After training completes, save and plot FID scores if computed
    if is_main and fid_scores:
        _write_fid_json(fid_file, fid_steps, fid_scores)
        print(f"FID scores saved to {fid_file}")

        fid_np_file = savedir + f"{FLAGS.model}_fid_scores.npy"
        np.save(fid_np_file, np.array(fid_scores))
        print(f"FID scores saved as numpy array to {fid_np_file}")

        steps_np_file = savedir + f"{FLAGS.model}_fid_steps.npy"
        np.save(steps_np_file, np.array(fid_steps))
        print(f"FID steps saved as numpy array to {steps_np_file}")

        fid_plot_file = savedir + f"{FLAGS.model}_fid_curve.png"
        _save_curve_plot(
            (fid_steps, fid_scores, "o-"),
            f"FID Score Curve - {FLAGS.model}",
            "FID Score",
            fid_plot_file,
        )
        print(f"FID curve plot saved to {fid_plot_file}")


def main(argv):
    """Pick the rank/device for this process and start training.

    The world size is read from the ``WORLD_SIZE`` environment variable
    (default 1). With ``FLAGS.parallel`` set and more than one process, the
    integer ``RANK`` (default 0) is passed to ``train``; otherwise a
    ``torch.device`` is passed, CUDA when available and CPU if not.
    """
    world_size = int(os.getenv("WORLD_SIZE", 1))
    distributed = FLAGS.parallel and world_size > 1

    if distributed:
        rank = int(os.getenv("RANK", 0))
    else:
        rank = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train(rank=rank, total_num_gpus=world_size, argv=argv)


# Script entry point: hand control to absl, which parses the command-line
# flags defined above and then calls main(argv).
if __name__ == "__main__":
    app.run(main)
