# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
A minimal training script for DiT using PyTorch DDP.
"""
import torch
import wandb
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from collections import OrderedDict
from PIL import Image
from copy import deepcopy
from glob import glob
from time import time
import logging
import os

from utility.parse_args import parse_args
from dataset.md_dataset import get_dataloader
from model.train_epoch import train_epoch, test
from model.get_model import get_model, get_optim

from utility import functions as uf
torch.multiprocessing.set_sharing_strategy('file_system')
#################################################################################
#                             Training Helper Functions                         #
#################################################################################

@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Move every EMA parameter one step towards its live counterpart.

    For each named parameter of ``model`` the parameter with the same name in
    ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * param``. Runs under ``torch.no_grad()``
    so the update is not recorded by autograd.

    Args:
        ema_model: module holding the running average (modified in place).
        model: module providing the current parameter values.
        decay: weight kept on the previous EMA value.
    """
    shadow = OrderedDict(ema_model.named_parameters())
    blend = 1 - decay
    for pname, live in OrderedDict(model.named_parameters()).items():
        # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
        shadow[pname].mul_(decay).add_(live.data, alpha=blend)


def requires_grad(model, flag=True):
    """
    Toggle gradient tracking for every parameter of ``model``.

    Args:
        model: module whose parameters are modified in place.
        flag: value assigned to each parameter's ``requires_grad`` attribute
            (``True`` by default; pass ``False`` to freeze the module).
    """
    params = model.parameters()
    for weight in params:
        weight.requires_grad = flag


def cleanup():
    """
    End DDP training by destroying the default process group.

    Safe to call when ``torch.distributed`` is unavailable or no process group
    has been initialised (e.g. after a failed setup or a repeated call): it is
    then a no-op instead of raising.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def create_logger(logging_dir):
    """
    Build the logger used by the training loop.

    On rank 0 the root logging configuration is set up (via
    ``logging.basicConfig``) to emit INFO-level records both to the console
    (a default ``StreamHandler``, i.e. stderr) and to ``{logging_dir}/log.txt``.
    On every other rank no logging configuration is applied; the logger only
    gets a ``NullHandler`` attached, and ``logging_dir`` is not used.

    Args:
        logging_dir: directory that receives ``log.txt`` (only read on rank 0).

    Returns:
        The logger named after this module.
    """
    if dist.get_rank() != 0:
        # Non-primary ranks: attach a do-nothing handler and skip basicConfig.
        silent = logging.getLogger(__name__)
        silent.addHandler(logging.NullHandler())
        return silent

    log_file = f"{logging_dir}/log.txt"
    logging.basicConfig(
        level=logging.INFO,
        format='[\033[34m%(asctime)s\033[0m] %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        handlers=[logging.StreamHandler(), logging.FileHandler(log_file)],
    )
    return logging.getLogger(__name__)

#################################################################################
#                                  Training Loop                                #
#################################################################################

def main(args):
    """
    Train the model returned by ``get_model`` with PyTorch DDP, one process per GPU.

    Steps: initialise the NCCL process group and the per-rank seed/device,
    start a wandb run, create the experiment/checkpoint folders and logger
    (rank 0), build dataloaders, model, optimizer and an EMA copy of the
    model, optionally resume from ``args.resume``, then loop over epochs.
    Every ``args.test_epochs`` epochs the model is evaluated on the validation
    and test splits; when the validation loss improves and ``args.save_model``
    is set, rank 0 saves a checkpoint holding the model, EMA model, optimizer
    state and ``args``.

    Args:
        args: parsed command-line namespace (see ``utility.parse_args``).
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."
    dtype = torch.float32

    # Setup DDP:
    dist.init_process_group("nccl")
    world_size = dist.get_world_size()
    assert args.global_batch_size % world_size == 0, "Batch size must be divisible by world size."
    rank = dist.get_rank()
    # Local GPU index on this node. It equals ``rank`` only on a single node,
    # so ``device`` (not ``rank``) must be used wherever a CUDA index is needed.
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Weights & Biases: fully disabled when ``args.no_wandb``, else online/offline.
    # NOTE(review): every rank starts its own run here; consider enabling it on rank 0 only.
    if args.no_wandb:
        mode = 'disabled'
    else:
        mode = 'online' if args.online else 'offline'
    kwargs = {'entity': args.wandb_usr, 'name': args.exp_name, 'project': 'ood_diffusion_mg', 'config': args,
              'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': mode}
    wandb.init(**kwargs)
    wandb.save('*.txt')

    # Setup an experiment folder (rank 0 only; other ranks get a no-op logger):
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.exp_name
        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" if args.experiment_dir is None else args.experiment_dir
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        args.checkpoint_dir = checkpoint_dir
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")
    else:
        logger = create_logger(None)

    # Retrieve dataloaders (dict keyed by 'train'/'valid'/'test') and their samplers.
    dataloaders, samplers = get_dataloader(args, dist, rank)
    model = get_model(args, device)

    total_params = sum(p.numel() for p in model.parameters())
    total_size = total_params * 4 / (1024 ** 2)  # Assuming float32 (4 bytes)
    print(f'Total parameters: {total_params}')
    print(f'Model size: {total_size:.2f} MB')

    model = model.to(device)
    optim = get_optim(args, model)
    gradnorm_queue = uf.Queue()
    gradnorm_queue.add(3000)  # initial entry of the gradient-norm history passed to train_epoch

    model_ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    ema = uf.EMA(args.ema_decay)

    if rank == 0:
        # Print the architecture and list any frozen parameters.
        print(model)
        for name, param in model.named_parameters():
            if not param.requires_grad:
                print(f"Layer: {name}, Requires Gradient: {param.requires_grad}")

    begin_epoch = args.start_epoch
    if args.resume is not None:
        # The checkpoint also stores ``args`` (an arbitrary Python object), so it
        # needs full unpickling: PyTorch >= 2.6 defaults to ``weights_only=True``
        # and would reject it, while PyTorch < 1.13 has no ``weights_only``
        # argument at all (hence the fallback). Only resume from trusted files.
        try:
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage, weights_only=False)
        except TypeError:
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint["model"])
        model_ema.load_state_dict(checkpoint["model_ema"])
        optim.load_state_dict(checkpoint["opt"])
        print(f'load from {args.resume}')
        # ``current_epoch`` is saved as ``epoch + 1`` below.
        begin_epoch = checkpoint["args"].current_epoch

    requires_grad(model_ema, False)
    # ``device_ids`` takes the *local* CUDA index; passing the global rank
    # breaks multi-node runs where rank >= GPUs per node.
    model = DDP(model, device_ids=[device])

    # Variables for monitoring/logging purposes:
    best_nll_val = 1e8
    best_nll_test = 1e8
    if rank == 0:
        logger.info(f"Training for {args.epochs} epochs...")
    dist.barrier()
    for epoch in range(begin_epoch, args.epochs):
        samplers['train'].set_epoch(epoch)
        if rank == 0:
            logger.info(f"Beginning epoch {epoch}...")
        start_epoch = time()
        # NOTE(review): when resuming, training is skipped for ``begin_epoch``
        # itself (evaluation still runs if due). Because ``current_epoch`` is
        # saved as ``epoch + 1``, that epoch is never trained — confirm intended.
        if epoch != begin_epoch or args.resume is None:
            train_epoch(args=args, loader=dataloaders['train'], epoch=epoch, model=model,
                        model_ema=model_ema, ema=ema, device=device, dtype=dtype,
                        gradnorm_queue=gradnorm_queue, optim=optim, rank=rank, dist=dist)
        print(f'{rank} finished epoch {epoch}')
        dist.barrier()
        if rank == 0:
            logger.info(f"Epoch {epoch} took {time() - start_epoch:.1f} seconds.")

        if epoch % args.test_epochs == 0:
            samplers['valid'].set_epoch(epoch)
            nll_val = test(args=args, loader=dataloaders['valid'], epoch=epoch, eval_model=model,
                           partition='valid', device=device, dtype=dtype, rank=rank)
            samplers['test'].set_epoch(epoch)
            nll_test = test(args=args, loader=dataloaders['test'], epoch=epoch, eval_model=model,
                            partition='Test', device=device, dtype=dtype, rank=rank)
            dist.barrier()
            if nll_val < best_nll_val:
                best_nll_val = nll_val
                best_nll_test = nll_test
                if args.save_model:
                    args.current_epoch = epoch + 1
                    if rank == 0:
                        checkpoint = {
                            "model": model.module.state_dict(),
                            "model_ema": model_ema.state_dict(),
                            "opt": optim.state_dict(),
                            "args": args
                        }
                        checkpoint_path = f"{checkpoint_dir}/{args.exp_name}.pt"
                        torch.save(checkpoint, checkpoint_path)
                        logger.info(f"Saved checkpoint to {checkpoint_path}")
            if rank == 0:
                logger.info('E: %d Val loss: %.9f \t Test loss:  %.9f' % (epoch, nll_val, nll_test))
                logger.info('E: %d Best val loss: %.9f \t Best test loss:  %.9f (%.5f)' % (epoch, best_nll_val, best_nll_test, best_nll_test * 1000))
                # One call so all three metrics share a single wandb step
                # (separate commit=True calls each advanced the step counter).
                wandb.log({
                    "Val loss ": nll_val,
                    "Test loss ": nll_test,
                    "Best cross-validated test loss ": best_nll_test,
                }, commit=True)
            dist.barrier()
    logger.info("Done!")
    cleanup()


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then launch (distributed) training.
    args = parse_args()
    main(args=args)
