# PyTorch GALIP: https://github.com/tobran/GALIP
# The MIT License (MIT)
# See license file or visit https://github.com/tobran/GALIP for details

# replaced with code/src/train.py for SONA training

import os, sys
import os.path as osp
import time
import random
import argparse
import numpy as np
from PIL import Image
import pprint
import datetime
from datetime import timedelta

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torchvision.utils import save_image,make_grid
from torch.utils.tensorboard import SummaryWriter
import wandb
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.data.distributed import DistributedSampler
from accelerate import Accelerator
from accelerate import DistributedDataParallelKwargs, InitProcessGroupKwargs
import multiprocessing as mp
# import open_clip

# Repository root: the parent of the directory containing this file. It is put
# at the front of sys.path below so the project-local `lib` / `models` packages
# resolve for the imports that follow (order matters: insert before importing).
ROOT_PATH_ORIG = osp.abspath(osp.join(osp.dirname(osp.abspath(__file__)),  ".."))
# Output root under which main() creates saved_models/, logs/ and imgs/.
# NOTE(review): this is a placeholder path — set it to a real location before running.
ROOT_PATH = "/root/path"
sys.path.insert(0, ROOT_PATH_ORIG)
from lib.utils import mkdir_p,get_rank,merge_args_yaml,get_time_stamp,save_args
from lib.utils import load_models_opt,save_models_opt,save_models_opt_hug,save_models,load_npz,params_count, save_all
from lib.perpare import prepare_dataloaders,prepare_models
from lib.modules_sona import sample_one_batch as sample, test as test, train as train
from lib.datasets import get_fix_data

import importlib
# The module name contains a hyphen, so it cannot be loaded with a plain
# `import` statement. NOTE(review): `aux_models` is not referenced in the code
# visible here; presumably imported for its side effects — confirm.
aux_models = importlib.import_module("models.GALIP-SONA")

def parse_args():
    """Parse the command-line options for GALIP-SONA training.

    Boolean-like options (``--diffaug``, ``--train``, ``--multi_gpus``, ...) are
    parsed as plain strings such as ``"True"`` / ``"False"``.
    NOTE(review): presumably ``merge_args_yaml`` converts these to real booleans
    — confirm, because a raw ``"False"`` string is truthy.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Text2Img')
    parser.add_argument('--cfg', dest='cfg_file', type=str, default='../cfg/coco.yml',
                        help='optional config file')
    # The help previously advertised cpu_count()-1 although the default is 4.
    parser.add_argument('--num_workers', type=int, default=4,
                        help='number of workers (default: %(default)s)')
    parser.add_argument('--stamp', type=str, default='normal',
                        help='the stamp of model')
    parser.add_argument('--pretrained_model_path', type=str, default='model',
                        help='the model for training')
    parser.add_argument('--log_dir', type=str, default='new',
                        help='file path to log directory')
    parser.add_argument('--model', type=str, default='GALIP-SONA',
                        help='the model for training')
    parser.add_argument('--state_epoch', type=int, default=100,
                        help='state epoch')
    parser.add_argument('--batch_size', type=int, default=1024,
                        help='batch size')
    parser.add_argument('--num_san_layers', type=int, default=3,
                        help='The number of layers in san head')
    parser.add_argument('--c_h_dim', type=int, default=128,
                        help='hidden dim of conv layers in netC')
    # NOTE(review): noise_until is not read in this script; its exact meaning
    # lives in the training modules — document it precisely there.
    parser.add_argument('--noise_until', type=int, default=0,
                        help='noise_until value passed to the training code (default: %(default)s)')
    parser.add_argument('--k_gp', type=float, default=2.0,
                        help='coef k in MAGP')
    parser.add_argument('--p_gp', type=float, default=6.0,
                        help='power p in MAGP')
    parser.add_argument('--mi_g', type=float, default=0.0,
                        help='MI reg on the Generator')
    parser.add_argument('--mi_d', type=float, default=1.0,
                        help='MI reg on the Discriminator')
    parser.add_argument("--diffaug", type=str, default="False")
    parser.add_argument("--spectral", type=str, default="False")
    parser.add_argument("--anneal", type=str, default="False")
    parser.add_argument("--img_img_sim", type=str, default="False")
    parser.add_argument('--train', type=str, default='True',
                        help='if train model')
    parser.add_argument('--mixed_precision', type=str, default='False',
                        help='if use mixed precision')
    parser.add_argument('--tf32', type=str, default='True',
                        help='option for tf32 (only when mixed precision is not used)')
    parser.add_argument('--multi_gpus', type=str, default='False',
                        help='if use multi-gpu')
    parser.add_argument('--accelerator', type=str, default='hug', help='vanilla or ddp or hug (HuggingFace)')
    parser.add_argument('--gpu_id', type=int, default=1,
                        help='gpu id')
    parser.add_argument('--local_rank', default=-1, type=int,
                        help='node rank for distributed training')
    # store_true combined with default=True means this option is always True
    # from the command line; it can only be changed by the merged YAML config.
    parser.add_argument('--random_sample', action='store_true', default=True,
                        help='whether to sample the dataset with random sampler')
    return parser.parse_args()


def _is_main_process(args, accelerator):
    """Return True if this process should do logging and write output files.

    Args:
        args: run configuration; only ``args.multi_gpus`` is read.
        accelerator: the HuggingFace ``Accelerator``, or None on the ddp/vanilla paths.

    Returns:
        ``accelerator.is_main_process`` when an accelerator is given; otherwise
        True unless ``args.multi_gpus`` is enabled and ``get_rank()`` is non-zero.
    """
    if accelerator is not None:
        return accelerator.is_main_process
    # `== True` kept to match the script's other multi_gpus checks.
    return not (args.multi_gpus == True and get_rank() != 0)


def _barrier(args):
    """Synchronize all torch.distributed ranks when ``args.multi_gpus`` is enabled."""
    if args.multi_gpus == True:
        torch.distributed.barrier()


def _build_accelerator(args):
    """Create the HuggingFace ``Accelerator`` when ``args.accelerator == "hug"``.

    Returns None for every other launch mode. On the "hug" path it also:
    sets ``args.device`` to the accelerator's device; resets
    ``args.mixed_precision`` to False (that flag drives the manual GradScaler
    path, which Accelerate replaces); and, when mixed precision is off,
    applies ``args.tf32`` to the TF32 backend flags.
    """
    if args.accelerator != "hug":
        return None
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
    # Named precision_mode (not `mp`) so the multiprocessing import is not shadowed.
    if args.mixed_precision:
        # NOTE(review): an original inline note said "not worked well with this
        # model"; unclear whether it refers to bf16 or the alternative "fp16".
        precision_mode = "bf16"
    else:
        precision_mode = "no"
        # TF32 only matters when running in full fp32.
        torch.backends.cuda.matmul.allow_tf32 = args.tf32
        torch.backends.cudnn.allow_tf32 = args.tf32

    args.mixed_precision = False
    accelerator = Accelerator(mixed_precision=precision_mode, kwargs_handlers=[ddp_kwargs, init_kwargs])
    args.device = accelerator.device
    if accelerator.is_main_process:
        print(f"Mixed Precision: {precision_mode}")
        accelerator.print("-----Using Huggingface Accelerator-----")
    return accelerator


def _save_reference_batch(fixed_img, fixed_caption, img_save_dir):
    """Write the fixed ground-truth batch as ``gt.png`` (8 per row) and its captions as ``prompt.txt``."""
    vutils.save_image(fixed_img.data, osp.join(img_save_dir, 'gt.png'), nrow=8, normalize=True)
    with open(osp.join(img_save_dir, "prompt.txt"), "w") as f:
        f.writelines([_c + "\n" for _c in fixed_caption])


def main(args):
    """Run GALIP-SONA training end to end.

    Sets up output directories, logging and (optionally) HuggingFace Accelerate.
    Builds dataloaders and models, and optionally resumes from
    ``state_epoch_XXX.pth``. Then, for each epoch it trains and, at the
    configured intervals, checkpoints, samples images and evaluates FID / CLIP score.

    ``args.accelerator`` selects the launch mode: "hug" (HuggingFace
    Accelerate), "ddp" (torch.distributed set up by the caller) or "vanilla".
    """
    time_stamp = get_time_stamp()
    stamp = '_'.join([str(args.model), 'nf' + str(args.nf), str(args.stamp),
                      str(args.CONFIG_NAME), str(args.imsize), time_stamp])
    args.model_save_file = osp.join(ROOT_PATH, 'saved_models', str(args.CONFIG_NAME), stamp)
    log_dir = args.log_dir
    if log_dir == 'new':
        log_dir = osp.join(ROOT_PATH, 'logs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp)))
    args.img_save_dir = osp.join(ROOT_PATH, 'imgs/{0}'.format(osp.join(str(args.CONFIG_NAME), 'train', stamp)))

    args.best_fid = float('inf')

    accelerator = _build_accelerator(args)
    # One consistent notion of "main process" for every mode. The old per-branch
    # checks dereferenced `accelerator` even when it was None (ddp/vanilla).
    is_main = _is_main_process(args, accelerator)

    if is_main:
        mkdir_p(osp.join(ROOT_PATH, 'logs'))
        mkdir_p(args.model_save_file)
        mkdir_p(args.img_save_dir)

    if args.accelerator == "ddp":
        args.gpus = torch.distributed.get_world_size()
        # For cc12m, batch_size is divided across GPUs; otherwise it is used per GPU.
        if "cc12m" in args.dataset_name:
            args.local_batch_size = args.batch_size // args.gpus
        else:
            args.local_batch_size = args.batch_size

    # TensorBoard writer (synced to wandb) lives on the main process only.
    if is_main:
        wandb.init(project='GALIP-crusoe', name=f"{args.CONFIG_NAME}{stamp}", sync_tensorboard=True, config=args)
        writer = SummaryWriter(log_dir)
    else:
        writer = None

    # Captioned loaders are built only to draw the fixed visualization batch;
    # training uses the caption=False loaders created right after.
    train_dl, valid_dl, train_ds, valid_ds, sampler = prepare_dataloaders(args, caption=True)
    CLIP4trn, CLIP4evl, image_encoder, text_encoder, netG, netD, netC = prepare_models(args)
    fixed_img, fixed_sent, fixed_words, fixed_z, fixed_caption = get_fix_data(train_dl, valid_dl, text_encoder, args)
    if args.diffaug:
        image_encoder.diffaug = True

    train_dl, valid_dl, train_ds, valid_ds, sampler = prepare_dataloaders(args, caption=False)
    if args.accelerator == "hug":
        train_dl, valid_dl, CLIP4trn, image_encoder, text_encoder, netG, netD, netC = accelerator.prepare(
            train_dl, valid_dl, CLIP4trn, image_encoder, text_encoder, netG, netD, netC
        )
    if is_main:
        print('**************G_paras: ', params_count(netG))
        print('**************D_paras: ', params_count(netD) + params_count(netC))
        print(args)
        print(f'dataloader size: {len(train_dl)}, {len(valid_dl)}')
        # Previously skipped in vanilla mode because of a "vannila" typo.
        _save_reference_batch(fixed_img, fixed_caption, args.img_save_dir)

    # prepare optimizer
    D_params = list(netD.parameters()) + list(netC.parameters())
    if "lora" in args.model:
        D_params += list(image_encoder.get_lora_parameters())
    optimizerD = torch.optim.Adam(D_params, lr=args.lr_d, betas=(0.0, 0.9), eps=1e-6)
    optimizerG = torch.optim.Adam(netG.parameters(), lr=args.lr_g, betas=(0.0, 0.9), eps=1e-6)

    if args.accelerator == "hug":
        optimizerD, optimizerG = accelerator.prepare(optimizerD, optimizerG)

    if args.mixed_precision == True:
        scaler_D = torch.cuda.amp.GradScaler(growth_interval=args.growth_interval)
        scaler_G = torch.cuda.amp.GradScaler(growth_interval=args.growth_interval)
    else:
        scaler_D = None
        scaler_G = None
    m1, s1 = load_npz(args.npz_path)

    # Resume from a checkpoint unless state_epoch == 1.
    start_epoch = 1
    if args.state_epoch != 1:
        start_epoch = args.state_epoch + 1
        path = osp.join(args.pretrained_model_path, 'state_epoch_%03d.pth' % (args.state_epoch))
        # save_all stores the LoRA image encoder in every mode, so restore it in every mode.
        _ie = image_encoder if 'lora' in args.model else None
        if args.accelerator == "hug":
            netG, netD, netC, optimizerG, optimizerD = load_models_opt(
                accelerator.unwrap_model(netG), accelerator.unwrap_model(netD), accelerator.unwrap_model(netC),
                optimizerG, optimizerD, path, args.multi_gpus, image_encoder=_ie)
            # NOTE(review): the dataloaders and encoders were already prepared
            # above; preparing them again may wrap them twice — confirm intent.
            train_dl, valid_dl, CLIP4trn, image_encoder, text_encoder, netG, netD, netC = accelerator.prepare(
                train_dl, valid_dl, CLIP4trn, image_encoder, text_encoder, netG, netD, netC
            )
        else:
            netG, netD, netC, optimizerG, optimizerD = load_models_opt(
                netG, netD, netC, optimizerG, optimizerD, path, args.multi_gpus, image_encoder=_ie)

    if is_main:
        pprint.pprint(args)
        save_args(osp.join(log_dir, 'args.yaml'), args)
        print("Start Training")

    test_interval, gen_interval, save_interval = args.test_interval, args.gen_interval, args.save_interval
    for epoch in range(start_epoch, args.max_epoch + 1):
        if args.multi_gpus == True and sampler is not None:
            # Reshuffle the distributed sampler differently each epoch.
            train_dl.sampler.set_epoch(epoch)
        if is_main:
            print(f"start epoch: {epoch}")
        _barrier(args)

        start_t = time.time()
        # training
        args.current_epoch = epoch
        torch.cuda.empty_cache()
        _barrier(args)
        train(train_dl, netG, netD, netC, text_encoder, image_encoder, optimizerG, optimizerD,
              scaler_G, scaler_D, args, writer, accelerator)
        torch.cuda.empty_cache()

        # save
        _barrier(args)
        if epoch % save_interval == 0:
            _ie = image_encoder if "lora" in args.model else None
            save_all(netG, netD, netC, optimizerG, optimizerD, epoch, args.multi_gpus,
                     args.model_save_file, image_encoder=_ie, acc=accelerator)
            # save_state is called on every process; only available on the "hug" path
            # (this used to crash on ddp/vanilla where accelerator is None).
            if accelerator is not None:
                accelerator.save_state(osp.join(args.model_save_file, f"states_ep{epoch}"))
        _barrier(args)

        # sample
        if epoch % gen_interval == 0:
            if is_main:
                print("Start sampling images...")
            sample(fixed_z, fixed_sent, netG, args.multi_gpus, epoch, args.img_save_dir, args)
        _barrier(args)

        # test
        if epoch % test_interval == 0:
            if is_main:
                print("Start Evaluation")
            bs = args.local_batch_size if "local_batch_size" in args else args.batch_size
            FID, TI_score = test(valid_dl, text_encoder, netG, CLIP4evl, args.device, m1, s1, epoch,
                                 args.max_epoch, args.sample_times, args.z_dim, bs, args)

        if is_main:
            writer.add_scalar('sim_w', args.current_sim_w, epoch)
            print(f"current sim_w: {args.current_sim_w}")
            if epoch % test_interval == 0:
                writer.add_scalar('FID', FID, epoch)
                writer.add_scalar('CLIP_Score', TI_score, epoch)
                print('The %d epoch FID: %.2f, CLIP_Score: %.2f' % (epoch, FID, TI_score * 100))
            end_t = time.time()
            print('The epoch %d costs %.2fs' % (epoch, end_t - start_t))
            print('*' * 40)

        _barrier(args)

    if is_main:
        # Flush pending TensorBoard events before wandb finalizes the synced run.
        writer.close()
        wandb.finish()


if __name__ == "__main__":
    args = merge_args_yaml(parse_args())

    # Seed Python, NumPy and torch RNGs; fall back to 100 when no seed is configured.
    if args.manual_seed is None:
        args.manual_seed = 100
    random.seed(args.manual_seed)
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)

    # HuggingFace Accelerate sets up its own process group, so the manual
    # torch.distributed path below is switched off in that mode.
    if args.accelerator == "hug":
        args.multi_gpus = False

    if not args.cuda:
        args.device = torch.device('cpu')
    else:
        torch.cuda.manual_seed_all(args.manual_seed)
        if args.multi_gpus:
            # LOCAL_RANK is expected from the distributed launcher (e.g. torchrun).
            torch.distributed.init_process_group(backend="nccl", timeout=datetime.timedelta(seconds=300))
            rank_on_node = int(os.environ["LOCAL_RANK"])
            torch.cuda.set_device(rank_on_node)
            args.device = torch.device("cuda", rank_on_node)
            args.local_rank = rank_on_node
        else:
            torch.cuda.set_device(args.gpu_id)
            args.device = torch.device("cuda")
    main(args)

