# Modified from:
#   fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/train.py
#   nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
import torch
# the first flag below was False when we tested this script but True makes A100 training a lot faster:
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchvision.utils import make_grid
from huggingface_hub import upload_folder

# from torchmetrics.image.fid import FrechetInceptionDistance as FID

import warnings
warnings.filterwarnings('ignore')

from PIL import Image
from tqdm import tqdm
import ruamel.yaml as yaml

import os
import time
import argparse
from glob import glob
from copy import deepcopy
import sys
import math
import numpy as np

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../..'))
sys.path.append(project_root)
from utils.logger import create_logger
from utils.distributed import init_distributed_mode
from utils.ema import update_ema, requires_grad
from dataset.augmentation import random_crop_arr, center_crop_arr
from dataset.build import build_dataset
from tokenizer.tokenizer_image.xqgan_model import VQ_models
from tokenizer.tokenizer_image.vq_loss import VQLoss, VQLoss_no_disc

from timm.scheduler import create_scheduler_v2 as create_scheduler

# from evaluator import Evaluator
# import tensorflow.compat.v1 as tf

from torcheval.metrics import FrechetInceptionDistance as FID


try:
    import horovod.torch as hvd
except ImportError:
    hvd = None

import warnings
warnings.filterwarnings('ignore')

import wandb
#################################################################################
#                                  Training Loop                                #
#################################################################################
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy.  This converter interprets the
    text instead, so ``--flag False`` really disables the flag.

    Args:
        value: The raw token from the command line, or an existing ``bool``.

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, matching what ``bool("")`` returned before.
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the training argument parser and parse ``sys.argv``.

    Precedence, from lowest to highest: the defaults declared below, the
    values found in the YAML file given by ``--config``, and finally options
    passed explicitly on the command line.

    Returns:
        argparse.Namespace: The resolved training configuration.

    Raises:
        OSError: If the file given by ``--config`` cannot be opened.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-path", type=str, default='/mnt/localssd/ImageNet2012/train')
    parser.add_argument("--data-face-path", type=str, default=None, help="face datasets to improve vq model")
    parser.add_argument("--cloud-save-path", type=str, default='output/debug', help='please specify a cloud disk path, if not, local path')
    parser.add_argument("--no-local-save", action='store_true', help='no save checkpoints to local path for limited disk volume')
    parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
    parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for resume training")
    parser.add_argument("--finetune", action='store_true', help="finetune a pre-trained vq model")
    parser.add_argument("--ema", action='store_true', help="whether using ema training")
    parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
    parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
    # Previously declared twice ("--codebook-l2-norm" as a store_true flag with
    # default True, and "--codebook_l2_norm" with type=bool and default False).
    # Both wrote to the same dest and the first default (True) always won, so
    # True is kept as the default.  One action now serves both spellings:
    # a bare flag means True, and an explicit value such as "False" is honoured.
    parser.add_argument("--codebook-l2-norm", "--codebook_l2_norm", dest="codebook_l2_norm",
                        type=_str2bool, nargs='?', const=True, default=True, help="l2 norm codebook")
    parser.add_argument("--codebook-weight", type=float, default=1.0, help="codebook loss weight for vector quantization")
    parser.add_argument("--entropy-loss-ratio", type=float, default=0.0, help="entropy loss ratio in codebook loss")
    parser.add_argument("--commit-loss-beta", type=float, default=0.25, help="commit loss beta in codebook loss")
    parser.add_argument("--reconstruction-weight", type=float, default=1.0, help="reconstruction loss weight of image pixel")
    parser.add_argument("--reconstruction-loss", type=str, default='l2', help="reconstruction loss type of image pixel")
    parser.add_argument("--perceptual-weight", type=float, default=1.0, help="perceptual loss weight of LPIPS")
    parser.add_argument("--disc-weight", type=float, default=0.5, help="discriminator loss weight for gan training")
    parser.add_argument("--disc-epoch-start", type=int, default=0, help="epoch to start discriminator training and loss")
    # Overwritten by the training entry point from --disc-epoch-start.
    parser.add_argument("--disc-start", type=int, default=0,
                        help="iteration to start discriminator training and loss (set automatically from --disc-epoch-start)")
    parser.add_argument("--disc-type", type=str, choices=['patchgan', 'stylegan', 'dinodisc', 'none'], default='patchgan', help="discriminator type")
    parser.add_argument("--disc-loss", type=str, choices=['hinge', 'vanilla', 'non-saturating'], default='hinge', help="discriminator loss")
    parser.add_argument("--gen-loss", type=str, choices=['hinge', 'non-saturating'], default='hinge', help="generator loss for gan training")
    parser.add_argument("--compile", action='store_true', default=False)

    # Dropout probability handed to the VQ model constructor.
    parser.add_argument("--dropout-p", type=float, default=0.0, help="dropout_p")

    parser.add_argument("--results-dir", type=str, default="/ImageFolder/results_tokenizer_image")
    parser.add_argument("--dataset", type=str, default='imagenet')
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--epochs", type=int, default=40)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--disc_lr", type=float, default=1e-4)
    # Previously declared twice with conflicting defaults (0.0 and 1.0) on the
    # same dest; argparse silently used the first one, so 0.0 (no clipping) is
    # kept as the default and both spellings map to this single action.
    parser.add_argument("--max-grad-norm", "--max_grad_norm", dest="max_grad_norm", type=float, default=0.0,
                        help="Max gradient norm (0.0 disables gradient clipping).")
    parser.add_argument("--lr_scheduler", type=str, default='none')
    parser.add_argument("--weight-decay", type=float, default=0.0, help="Weight decay to use.")  # generator optimizer
    parser.add_argument("--disc-weight-decay", type=float, default=0.0,
                        help="Weight decay to use for the discriminator optimizer.")

    parser.add_argument("--beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--beta2", type=float, default=0.95, help="The beta2 parameter for the Adam optimizer.")

    parser.add_argument("--global-batch-size", type=int, default=128)
    parser.add_argument("--global-seed", type=int, default=0)
    parser.add_argument("--num-workers", type=int, default=16)
    parser.add_argument("--log-every", type=int, default=100)
    parser.add_argument("--vis-every", type=int, default=5000)

    parser.add_argument("--gradient-accumulation-steps", type=int, default=1)
    parser.add_argument("--mixed-precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
    parser.add_argument("--save_best", action='store_true', default=False)
    parser.add_argument("--val_data_path", type=str, default="/mnt/localssd/ImageNet2012/val")
    parser.add_argument("--sample_folder_dir", type=str, default='/ImageFolder/results_tokenizer_image/samples')
    parser.add_argument("--reconstruction_folder_dir", type=str, default='/ImageFolder/results_tokenizer_image/reconstruction')
    parser.add_argument("--v-patch-nums", type=int, default=[1, 2, 3, 4, 5, 6, 8, 10, 13, 16], nargs='+',
                        help="number of patch numbers of each scale")
    parser.add_argument("--enc_type", type=str, default="cnn")
    parser.add_argument("--dec_type", type=str, default="cnn")
    parser.add_argument("--semantic_guide", type=str, default="none")
    parser.add_argument("--detail_guide", type=str, default="none")
    parser.add_argument("--num_latent_tokens", type=int, default=256)
    parser.add_argument("--encoder_model", type=str, default='vit_small_patch14_dinov2.lvd142m',
                        help='encoder model name')
    parser.add_argument("--decoder_model", type=str, default='vit_small_patch14_dinov2.lvd142m',
                        help='decoder model name')
    parser.add_argument("--disc_adaptive_weight", type=_str2bool, default=False)  # adaptive weight for the discriminator loss
    parser.add_argument("--abs_pos_embed", type=_str2bool, default=False)  # absolute position embedding (dinov2)
    parser.add_argument("--product_quant", type=int, default=1)
    parser.add_argument("--share_quant_resi", type=int, default=4)
    parser.add_argument("--codebook_drop", type=float, default=0.0)
    parser.add_argument("--half_sem", type=_str2bool, default=False)
    parser.add_argument("--start_drop", type=int, default=1)
    parser.add_argument("--lecam_loss_weight", type=float, default=None)
    parser.add_argument("--sem_loss_weight", type=float, default=0.1)
    parser.add_argument("--detail_loss_weight", type=float, default=0.1)
    parser.add_argument("--enc_tuning_method", type=str, default='none')
    parser.add_argument("--dec_tuning_method", type=str, default='none')
    parser.add_argument("--clip_norm", type=_str2bool, default=False)
    parser.add_argument("--sem_loss_scale", type=float, default=1.0)
    parser.add_argument("--detail_loss_scale", type=float, default=1.0)
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--norm_type", type=str, default='bn')

    parser.add_argument("--aug_prob", type=float, default=1.0)  # augmentation for disc loss
    parser.add_argument("--aug_fade_steps", type=int, default=0)  # fade in steps for disc loss
    parser.add_argument("--disc_reinit", type=int, default=0)
    parser.add_argument("--debug_disc", type=_str2bool, default=False)

    parser.add_argument("--guide_type_1", type=str, default='class', choices=["patch", "class"])  # semantic guide type
    parser.add_argument("--guide_type_2", type=str, default='class', choices=["patch", "class"])  # detail guide type
    parser.add_argument("--lfq", action='store_true', default=False, help="if use LFQ")

    parser.add_argument("--kmeans_init", type=_str2bool, default=False)
    parser.add_argument("--ae_training", type=_str2bool, default=False)
    parser.add_argument("--ae_load", type=_str2bool, default=False)
    parser.add_argument("--ae_path", type=str, default=None)

    parser.add_argument("--ckpt_per_epoch", type=int, default=5)  # checkpoint interval, in epochs
    parser.add_argument("--test_data_path", type=str, default=None)
    parser.add_argument("--cosine_min_ratio", type=float, default=0.01)
    parser.add_argument("--min_lr", type=float, default=1e-6)

    parser.add_argument("--patience_epochs", type=int, default=20)

    # First pass: only needed to discover whether a config file was given.
    args = parser.parse_args()
    if args.config is not None:
        with open(args.config, 'r', encoding='utf-8') as f:
            config_args = yaml.YAML().load(f)
        # An empty YAML document loads as None; treat it as "no overrides"
        # instead of crashing on ``**None``.
        if config_args:
            parser.set_defaults(**config_args)
        # Second pass: the config values are now the defaults, so anything
        # given explicitly on the command line still takes priority.
        args = parser.parse_args()
    return args

def main(args):
    """
    Trains a new model.
    """
    assert torch.cuda.is_available(), "Training currently requires at least one GPU."

    # Setup DDP:
    init_distributed_mode(args)
    assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
    rank = dist.get_rank()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * dist.get_world_size() + rank
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    np.random.seed(seed)

    torch.cuda.set_device(device)

    # Setup an experiment folder:
    if rank == 0:
        os.makedirs(args.results_dir, exist_ok=True)  # Make results folder (holds all experiment subfolders)
        experiment_index = len(glob(f"{args.results_dir}/*"))
        model_string_name = args.vq_model.replace("/", "-")
        experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}"  # Create an experiment folder
        checkpoint_dir = f"{experiment_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(experiment_dir)
        logger.info(f"Experiment directory created at {experiment_dir}")

        cloud_results_dir = f"{args.cloud_save_path}"
        cloud_checkpoint_dir = f"{cloud_results_dir}"
        os.makedirs(cloud_checkpoint_dir, exist_ok=True)
        logger.info(f"Experiment directory created in cloud at {cloud_checkpoint_dir}")

        experiment_config = vars(args)
        with open(os.path.join(cloud_checkpoint_dir, 'config.yaml'), 'w', encoding='utf-8') as f:
            # Use the round_trip_dump method to preserve the order and style
            file_yaml = yaml.YAML()
            file_yaml.dump(experiment_config, f)
    
    else:
        logger = create_logger(None)

    # training args
    logger.info(f"{args}")

    # training env
    logger.info(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")



    # Setup data:
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: random_crop_arr(pil_image, args.image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    dataset = build_dataset(args, transform=transform)
    sampler = DistributedSampler(
        dataset,
        num_replicas=dist.get_world_size(),
        rank=rank,
        shuffle=True,
        seed=args.global_seed
    )
    loader = DataLoader(
        dataset,
        batch_size=int(args.global_batch_size // dist.get_world_size()),
        shuffle=False,
        sampler=sampler,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

    if args.save_best:
        transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
        ])
        args.data_path = args.val_data_path
        val_dataset = build_dataset(args, transform=transform)
        val_sampler = DistributedSampler(
            val_dataset,
            num_replicas=dist.get_world_size(),
            rank=rank,
            shuffle=False,
            seed=args.global_seed
        )
        val_loader = DataLoader(
            val_dataset,
            batch_size=int(args.global_batch_size // dist.get_world_size()),
            shuffle=False,
            sampler=val_sampler,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=True
        )
        if rank % torch.cuda.device_count() == 0:
            os.makedirs(args.sample_folder_dir, exist_ok=True)
            os.makedirs(args.reconstruction_folder_dir, exist_ok=True)
            logger.info(f"Saving .png samples at {args.sample_folder_dir}")
            logger.info(f"Saving .png reconstruction at {args.reconstruction_folder_dir}")

    num_update_steps_per_epoch = len(loader)
    max_train_steps = args.epochs * num_update_steps_per_epoch
    args.disc_start = args.disc_epoch_start * num_update_steps_per_epoch

    # Initialize FID
    # fid = FID().to(device)


    # create and load model
    vq_model = VQ_models[args.vq_model](
        ae_training=args.ae_training,
        
        # vq 
        codebook_size=args.codebook_size,
        codebook_embed_dim=args.codebook_embed_dim,
        commit_loss_beta=args.commit_loss_beta,
        entropy_loss_ratio=args.entropy_loss_ratio,
        dropout_p=args.dropout_p,
        v_patch_nums=args.v_patch_nums,
        enc_type=args.enc_type,
        encoder_model=args.encoder_model,
        dec_type=args.dec_type,
        decoder_model=args.decoder_model,
        semantic_guide=args.semantic_guide,
        detail_guide=args.detail_guide,
        num_latent_tokens=args.num_latent_tokens,
        abs_pos_embed=args.abs_pos_embed,
        share_quant_resi=args.share_quant_resi,
        product_quant=args.product_quant,
        codebook_drop=args.codebook_drop,
        half_sem=args.half_sem,
        start_drop=args.start_drop,
        sem_loss_weight=args.sem_loss_weight,
        detail_loss_weight=args.detail_loss_weight,
        clip_norm=args.clip_norm,
        sem_loss_scale=args.sem_loss_scale,
        detail_loss_scale=args.detail_loss_scale,
        guide_type_1=args.guide_type_1,
        guide_type_2=args.guide_type_2,
        lfq=args.lfq,

        # XQGAN's configuration
        codebook_l2_norm=args.codebook_l2_norm
    )
    logger.info(f"VQ Model Configs: {vq_model.config}")


    logger.info(f"VQ Model Parameters: {sum(p.numel() for p in vq_model.parameters()):,}")
    if args.ema:
        ema = deepcopy(vq_model).to(device)  # Create an EMA of the model for use after training
        requires_grad(ema, False)
        logger.info(f"VQ Model EMA Parameters: {sum(p.numel() for p in ema.parameters()):,}")
    vq_model = vq_model.to(device)

    if args.ae_load:
        if args.ae_path is None:
            raise ValueError("AE path is not provided")
        ae_model = torch.load(args.ae_path, map_location="cpu")
        ae_model = ae_model["model"]
        missing, unexpected = vq_model.load_state_dict(ae_model, strict=False)
        logger.info(f"Missing keys: {missing}")
        logger.info(f"Unexpected keys: {unexpected}")
        assert len(unexpected) == 0, "Unexpected keys found in AE model"

    # K-means initialization
    
    with torch.no_grad():
        if args.kmeans_init:
            vq_model.init_embedding(loader, device)
    
    if args.disc_type != 'none':
        vq_loss = VQLoss(
            disc_start=args.disc_start, 
            disc_weight=args.disc_weight,
            disc_type=args.disc_type,
            disc_loss=args.disc_loss,
            gen_adv_loss=args.gen_loss,
            image_size=args.image_size,
            perceptual_weight=args.perceptual_weight,
            reconstruction_weight=args.reconstruction_weight,
            reconstruction_loss=args.reconstruction_loss,
            codebook_weight=args.codebook_weight,
            lecam_loss_weight=args.lecam_loss_weight,
            disc_adaptive_weight=args.disc_adaptive_weight,
            norm_type=args.norm_type,
            aug_prob=args.aug_prob,
        ).to(device)
    else:
        vq_loss = VQLoss_no_disc(
            image_size=args.image_size,
            reconstruction_weight=args.reconstruction_weight,
            reconstruction_loss=args.reconstruction_loss,
            codebook_weight=args.codebook_weight,
            perceptual_weight=args.perceptual_weight,
            norm_type=args.norm_type,
        ).to(device)
    
    if args.disc_type != 'none':
        logger.info(f"Discriminator Parameters: {sum(p.numel() for p in vq_loss.discriminator.parameters()):,}")

    args.lr = args.lr * args.global_batch_size / 128
    args.disc_lr = args.disc_lr * args.global_batch_size / 128
    # initialize a GradScaler. If enabled=False scaler is a no-op
    scaler = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision =='fp16'))
    scaler_disc = torch.cuda.amp.GradScaler(enabled=(args.mixed_precision =='fp16'))
    # Setup optimizer
    optimizer = torch.optim.AdamW(vq_model.parameters(), lr=args.lr, betas=(args.beta1, args.beta2),
                                  weight_decay=args.weight_decay, )
    if args.disc_type != 'none':
        optimizer_disc = torch.optim.AdamW(vq_loss.discriminator.parameters(), lr=args.disc_lr,
                                       betas=(args.beta1, args.beta2), weight_decay=args.disc_weight_decay, )
    else:
        optimizer_disc = None
        
        
    reduce_lr_scheduler_config = {
        'sched': 'plateau',             
        'num_epochs': args.epochs,
        'warmup_epochs': 20,            
        'warmup_lr': 1e-6,         
        'min_lr': args.min_lr,
        'patience_epochs': args.patience_epochs,
        'cooldown_epochs': 0,
        'decay_rate': 0.2,
        'plateau_mode': 'min',
        'step_on_epochs': True,
    }



    # create lr scheduler
    if args.lr_scheduler == 'none':
        vqvae_lr_scheduler = None
        disc_lr_scheduler = None
    elif args.lr_scheduler == 'ReduceLROnPlateau':
        vqvae_lr_scheduler, _ = create_scheduler(optimizer=optimizer, **reduce_lr_scheduler_config)
    else:
        vqvae_lr_scheduler, _ = create_scheduler(
            sched=args.lr_scheduler,
            optimizer=optimizer,
            patience_epochs=0,
            step_on_epochs=True,
            updates_per_epoch=num_update_steps_per_epoch,
            num_epochs=args.epochs,
            warmup_epochs=10, # 1
            min_lr= args.lr * args.cosine_min_ratio, #5e-5 
        )
        
    # print(vqvae_lr_scheduler.lr_scheduler.num_bad_epochs)
        
    if optimizer_disc is not None:
        disc_lr_scheduler, _ = create_scheduler(
            sched=args.lr_scheduler,
            optimizer=optimizer_disc,
            patience_epochs=0,
            step_on_epochs=True,
            updates_per_epoch=num_update_steps_per_epoch,
            num_epochs=args.epochs - args.disc_epoch_start,
            warmup_epochs=5, #5,
            min_lr= args.lr * args.cosine_min_ratio, # 5e-5 # args.lr * args.cosine_min_ratio,
        )
        

    logger.info(f"num_update_steps_per_epoch {num_update_steps_per_epoch:,} max_train_steps ({max_train_steps})")

    # Prepare models for training:
    if args.vq_ckpt:
        checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
        vq_model.load_state_dict(checkpoint["model"])
        if args.ema:
            ema.load_state_dict(checkpoint["ema"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if not args.debug_disc:
            vq_loss.discriminator.load_state_dict(checkpoint["discriminator"])
            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
        else:
            num_step = checkpoint["optimizer_disc"]["state"][next(iter(checkpoint["optimizer_disc"]["state"]))]['step']
            for param_state in optimizer_disc.state.values():
                param_state['step'] = num_step
        if not args.finetune:
            train_steps = checkpoint["steps"] if "steps" in checkpoint else int(args.vq_ckpt.split('/')[-1].split('.')[0])
            start_epoch = int(train_steps / int(len(dataset) / args.global_batch_size)) + 1
            train_steps = int(start_epoch * int(len(dataset) / args.global_batch_size))
        else:
            train_steps = 0
            start_epoch = 0           
        del checkpoint
        vq_model.finetune(args.enc_tuning_method, args.dec_tuning_method)
        logger.info(f"Resume training from checkpoint: {args.vq_ckpt}")
        logger.info(f"Initial state: steps={train_steps}, epochs={start_epoch}")
    else:
        train_steps = 0
        start_epoch = 0
        if args.ema:
            update_ema(ema, vq_model, decay=0)  # Ensure EMA is initialized with synced weights
    

    if args.compile:
        logger.info("compiling the model... (may take several minutes)")
        vq_model = torch.compile(vq_model, mode='max-autotune')  # requires PyTorch 2.0
    
    vq_model = DDP(vq_model.to(device), device_ids=[args.gpu])
    vq_model.train()
    if args.ema:
        ema.eval()  # EMA model should always be in eval mode
    if args.disc_type != 'none':
        vq_loss = DDP(vq_loss.to(device), device_ids=[args.gpu])
        vq_loss.train()

    ptdtype = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.mixed_precision]

    # Variables for monitoring/logging purposes:
    log_steps = 0
    running_loss = 0
    start_time = time.time()
    curr_fid = None

    logger.info(f"Training for {args.epochs} epochs...")
    for epoch in range(start_epoch, args.epochs):
        sampler.set_epoch(epoch)
        logger.info(f"Beginning epoch {epoch}...")
        
        if args.disc_reinit != 0:
            if epoch % args.disc_reinit == 0:
                vq_loss.module.discriminator.reinit()
        for x, y in loader:
            imgs = x.to(device, non_blocking=True)

            if args.aug_fade_steps >= 0:
                fade_blur_schedule = 0 if train_steps < args.disc_start else min(1.0, (train_steps - args.disc_start) / (args.aug_fade_steps + 1))
                fade_blur_schedule = 1 - fade_blur_schedule
            else:
                fade_blur_schedule = 0
            # generator training
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(dtype=ptdtype):  
                recons_imgs, codebook_loss, sem_loss, detail_loss, dependency_loss = vq_model(imgs, epoch)
                loss_gen = vq_loss(codebook_loss, sem_loss, detail_loss, dependency_loss, imgs, recons_imgs, optimizer_idx=0, global_step=train_steps+1,
                                   last_layer=vq_model.module.decoder.last_layer, 
                                   logger=logger, log_every=args.log_every, fade_blur_schedule=fade_blur_schedule)

            scaler.scale(loss_gen).backward()
            if args.max_grad_norm != 0.0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(vq_model.parameters(), args.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            if args.ema:
                update_ema(ema, vq_model.module._orig_mod if args.compile else vq_model.module)

            # discriminator training  
            if optimizer_disc is not None:
                optimizer_disc.zero_grad()

                with torch.cuda.amp.autocast(dtype=ptdtype):
                    loss_disc = vq_loss(codebook_loss, sem_loss, detail_loss, dependency_loss, imgs, recons_imgs, optimizer_idx=1, global_step=train_steps+1,
                                        logger=logger, log_every=args.log_every, fade_blur_schedule=fade_blur_schedule)
                scaler_disc.scale(loss_disc).backward()
                if args.max_grad_norm != 0.0:
                    scaler_disc.unscale_(optimizer_disc)
                    torch.nn.utils.clip_grad_norm_(vq_loss.module.discriminator.parameters(), args.max_grad_norm)
                scaler_disc.step(optimizer_disc)
                scaler_disc.update()
            else:
                loss_disc = torch.tensor(0.0, device=device)
            
            # # Log loss values:
            
            
            running_loss += loss_gen.item() + loss_disc.item()
            
            log_steps += 1
            train_steps += 1
            if train_steps % args.log_every == 0:
                # Measure training speed:
                torch.cuda.synchronize()
                end_time = time.time()
                steps_per_sec = log_steps / (end_time - start_time)
                # Reduce loss history over all processes:
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                # Reset monitoring variables:
                running_loss = 0
                log_steps = 0
                start_time = time.time()


                log_module = vq_loss.module if args.disc_type != 'none' else vq_loss
                
                if dist.get_rank() == 0:
                    log_module.wandb_tracker.log({
                        "lr": optimizer.param_groups[0]["lr"],
                        "train_loss": avg_loss},
                        step=train_steps
                    )
                    # show images and recon images
                    if train_steps % args.vis_every == 0:
                        with torch.no_grad():
                            recons_with_scale = vq_model.module.img_to_reconstructed_img(imgs[:4], last_one=False)
                        image = torch.cat(recons_with_scale + [imgs[:4]], dim=0)
                        image = torch.clamp(image, min=-1, max=1)
                        image = make_grid((image + 1) / 2, nrow=4, padding=0, pad_value=1.0)
                        image = image.permute(1, 2, 0).mul_(255).cpu().numpy()
                        image = Image.fromarray(image.astype(np.uint8))

                        log_module.wandb_tracker.log({"recon_images": [wandb.Image(image)]}, step=train_steps)







        ########## My modification: ADD validation loss ##########
        
        vq_model.eval()
        stats_sum = torch.zeros(7, device=device)  # rec, p, adv, total, vq, commit, usage
        num_batches_local = len(val_loader)

        with torch.no_grad():
            for x, _ in tqdm(val_loader, disable=(rank != 0),
                            desc=f'eval step {train_steps:07d} epoch {epoch+1}'):
                x = x.to(device, non_blocking=True)
                recons, codebook_loss, sem_loss, detail_loss, dep_loss = vq_model(x, epoch)

                disc = vq_loss.module if args.disc_type != 'none' else vq_loss
                rec, p, adv, total = disc.val_forward(
                    codebook_loss, sem_loss, detail_loss, dep_loss,
                    x, recons, global_step=train_steps+1, fade_blur_schedule=0)

                scalars = [rec, p, adv, total,
                        codebook_loss[0], codebook_loss[1],
                        np.mean(codebook_loss[3])]   # usage
                scalars = [s if torch.is_tensor(s) else torch.tensor(s, device=device)
                        for s in scalars]
                batch_stats = torch.stack(scalars)  # shape (7,)

                stats_sum += batch_stats / num_batches_local  # avg batch

        # all processes sum and avg
        dist.all_reduce(stats_sum)
        stats_mean = stats_sum / dist.get_world_size()  # now global avg

        (final_rec_loss,
        final_p_loss,
        final_generator_adv_loss,
        final_loss,
        final_vq_loss,
        final_commit_loss,
        final_usage) = stats_mean.tolist()

        validation_loss = final_loss if args.disc_type == 'none' else final_loss - final_generator_adv_loss


        vq_model.train()

        dist.barrier()
        if rank == 0:
            log_module = vq_loss.module if args.disc_type != 'none' else vq_loss
            
            log_module.wandb_tracker.log({
                "val_rec_loss": final_rec_loss,
                "val_perceptual_loss": final_p_loss,
                "val_generator_adv_loss": final_generator_adv_loss,
                "val_loss": final_loss,
                "val_vq_loss": final_vq_loss,
                "val_usage": final_usage,
                "val_commit_loss": final_commit_loss,
                "val_p_mse_loss": final_p_loss + final_rec_loss,
                "Measure_loss": validation_loss,
            }, step=train_steps)
    
        dist.barrier()

        # Save checkpoint:
        if (epoch + 1) % args.ckpt_per_epoch == 0:    
            

            
            if args.save_best:
                vq_model.eval()
                total = 0
                samples = []
                gt = []
                for x, _ in tqdm(val_loader, desc=f'evaluation for step {train_steps:07d}, epoch {epoch + 1}', disable=not rank == 0):
                    with torch.no_grad():
                        x = x.to(device, non_blocking=True)
                        sample = vq_model.module.img_to_reconstructed_img(x)
                        # sample = torch.clamp(127.5 * sample + 128.0, 0, 255).permute(0, 2, 3, 1).to(torch.uint8).contiguous()
                        # x = torch.clamp(127.5 * x + 128.0, 0, 255).permute(0, 2, 3, 1).to(torch.uint8).contiguous()
                        
                        ########## My modification ##########
                        #sample = torch.clamp(127.5 * sample + 128.0, 0, 255).to(torch.uint8).contiguous()
                        
                        #x = torch.clamp(127.5 * x + 128.0, 0, 255).to(torch.uint8).contiguous()
                        
                        sample = torch.clamp((sample + 1.0) / 2.0, 0.0, 1.0).contiguous()
                        x = torch.clamp((x + 1.0) / 2.0, 0.0, 1.0).contiguous()
                        
                        
                        ########## My modification ##########

                    sample = torch.cat(dist.nn.all_gather(sample), dim=0)
                    x = torch.cat(dist.nn.all_gather(x), dim=0)
                    # samples.append(sample.to("cpu", dtype=torch.uint8).numpy())
                    # gt.append(x.to("cpu", dtype=torch.uint8).numpy())
                    samples.append(sample)
                    gt.append(x)

                    total += sample.shape[0]
                vq_model.train()
                logger.info(f"Ealuate total {total} files.")
                dist.barrier()

                ############### My FID calculation ###############
                # FID is computed on rank 0 with torcheval's FrechetInceptionDistance on this process's device (the original code used a TensorFlow evaluator, but cuDNN did not seem to work with it).
                if rank == 0:
                #    # samples = np.concatenate(samples, axis=0)
                #    #gt = np.concatenate(gt, axis=0)
                
                    # samples = torch.cat(samples, dim=0)
                    #gt = torch.cat(gt, dim=0)
                    fid = FID(device=device)
                    for i in range(len(samples)):
                        print('samples[i].shape: ', samples[i].shape)
                        fid.update(samples[i], is_real=False)
                        fid.update(gt[i], is_real=True)
                    # logger.info(f"begin compute")
                    fid_value = fid.compute()
                    logger.info(f"FID: {fid_value}")
                    fid.reset()
                    logger.info(f"FID: {fid_value}")
                    
                    if curr_fid == None:
                        curr_fid = [fid_value, train_steps]
                    elif fid_value <= curr_fid[0]:
                        # os.remove(f"{cloud_checkpoint_dir}/{curr_fid[1]:07d}.pt")
                        curr_fid = [fid_value, train_steps]
                    if optimizer_disc is not None:
                        vq_loss.module.wandb_tracker.log({"eval FID": fid_value}, step=train_steps)
                    else:
                        vq_loss.wandb_tracker.log({"eval FID": fid_value}, step=train_steps)
                        
                        # release memory
                        del samples, gt
                        torch.cuda.empty_cache()
                


                ## End of FID calculation
                dist.barrier()
                

            if rank == 0:
                if args.compile:
                    model_weight = vq_model.module._orig_mod.state_dict()
                else:
                    model_weight = vq_model.module.state_dict()  
                    
                optimizer_disc_weight = optimizer_disc.state_dict() if optimizer_disc is not None else None 
                disc_weight = vq_loss.module.discriminator.state_dict() if args.disc_type != 'none' else None
                checkpoint = {
                    "model": model_weight,
                    "optimizer": optimizer.state_dict(),
                    "discriminator": disc_weight,
                    "optimizer_disc": optimizer_disc_weight,
                    "steps": train_steps,
                    "args": args
                }
                if args.ema:
                    checkpoint["ema"] = ema.state_dict()
                if not args.no_local_save:
                    checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}_{epoch + 1:03d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")
                
                # Save checkpoint to the cloud checkpoint directory (always done, independent of --no-local-save handling above)
                cloud_checkpoint_path = f"{cloud_checkpoint_dir}/{train_steps:07d}_{epoch + 1:03d}.pt"
                torch.save(checkpoint, cloud_checkpoint_path)
                logger.info(f"Saved checkpoint in cloud to {cloud_checkpoint_path}")

                if args.save_best:
                    last_checkpoint_path = f"{args.cloud_save_path}/last_ckpt.pt"
                    if os.path.exists(last_checkpoint_path):
                        os.remove(last_checkpoint_path)
                    else:
                        os.makedirs(f"{args.cloud_save_path}", exist_ok=True)
                    torch.save(checkpoint, last_checkpoint_path)
                    logger.info(f"Saved checkpoint in cloud to {last_checkpoint_path}")
                    if curr_fid[1] == train_steps:
                        best_checkpoint_path = f"{args.cloud_save_path}/best_ckpt.pt"
                        torch.save(checkpoint, best_checkpoint_path)
                        logger.info(f"Saved checkpoint in cloud to {best_checkpoint_path}")

            dist.barrier()

        if vqvae_lr_scheduler is not None:
            if args.lr_scheduler == 'ReduceLROnPlateau':
                vqvae_lr_scheduler.step(epoch + 1, validation_loss)
            else:
                vqvae_lr_scheduler.step(epoch + 1)
        if optimizer_disc is not None and disc_lr_scheduler is not None and epoch >= args.disc_epoch_start:
            disc_lr_scheduler.step(epoch + 1 - args.disc_epoch_start)

        # Early stopping (My modification for ReduceLROnPlateau)
        if args.lr_scheduler == 'ReduceLROnPlateau':
            current_lr = vqvae_lr_scheduler.optimizer.param_groups[0]['lr']
            num_bad_epochs = vqvae_lr_scheduler.lr_scheduler.num_bad_epochs

            if current_lr <= args.min_lr and num_bad_epochs >= args.patience_epochs:
                print("Early stopping triggered.")
                break

    
    vq_model.eval()  # important! This disables randomized embedding dropout
    # do any sampling/FID calculation/etc. with ema (or model) in eval mode ...

    logger.info("Done!")
    dist.destroy_process_group()



if __name__ == "__main__":
    # Script entry point: this branch runs only when the file is executed
    # directly, not when it is imported as a module.
    # `args` is bound at module level here and then handed to main().
    args = parse_args()
    main(args)
