import os
import argparse
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from torch.utils.data import DataLoader, Dataset
import setproctitle
from models.data.data import webvid,msrvtt,ucf,text,jdb
import torch
from models.trainer import StableDiffusionTrainer

def get_parser(argv=None):
    """Define the command-line interface and parse it.

    Despite its name, this returns the parsed ``argparse.Namespace``, not the
    ``ArgumentParser`` itself.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before; passing a
            list makes the function reusable from other code and tests.

    Returns:
        argparse.Namespace holding the experiment, data, model and
        distributed-training options.
    """
    parser = argparse.ArgumentParser()

    # Base args
    parser.add_argument('--name', type=str, default='ETC500', help='experiment identifier')
    parser.add_argument('--savedir', type=str, default='logs', help='path to save checkpoints and logs')
    parser.add_argument('--savevideo', type=str, default='videos', help='path to save videos')
    parser.add_argument('--seed', type=int, default=42, help='random seed')
    parser.add_argument('--mode', type=str, default='train', choices=['train', 'test'], help='experiment mode to run')

    # Data args
    parser.add_argument('--dataset', type=str, default='datasets/WebVid', help='dataset path passed to the dataset class')
    parser.add_argument('--batch_size', type=int, default=6, help='batch size of each DataLoader')

    # Model args
    parser.add_argument("--config", type=str, default="config/base.yaml", help='OmegaConf YAML config file')
    parser.add_argument('--resume', type=str, default='logs/checkpoints/baseline+FPS',
                        help='checkpoint path; optional for train (used only if it exists), required for test')
    parser.add_argument('--H', type=int, default=256, help='height passed to the dataset and model config')
    parser.add_argument('--W', type=int, default=256, help='width passed to the dataset and model config')
    parser.add_argument('--T', type=int, default=16, help='T value passed to the dataset and model config')

    # Args about Training
    parser.add_argument('--nodes', type=int, default=1, help='nodes')
    parser.add_argument('--devices', type=int, default=8, help='e.g., gpu number')

    return parser.parse_args(argv)

def _apply_cli_overrides(config, args):
    """Copy command-line options onto the loaded config (mutates ``config``)."""
    config.name = args.name
    config.savedir = args.savedir
    config.mode = args.mode
    config.datasets = args.dataset
    config.batch_size = args.batch_size
    # The model is constructed from ``config.ddconfig`` only, so the options
    # it needs are copied into that sub-config as well.
    config.ddconfig.savevideo = args.savevideo
    config.ddconfig.name = args.name
    config.ddconfig.H = args.H
    config.ddconfig.W = args.W
    config.ddconfig.T = args.T


def _build_trainer(config, args):
    """Create the Lightning Trainer: DeepSpeed ZeRO-2, fp16, step-based checkpoints."""
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(config.savedir, 'checkpoints'),
        filename='{step}',
        monitor='step',
        save_last=False,
        save_top_k=-1,  # keep every checkpoint instead of only the best k
        verbose=True,
        every_n_train_steps=2000,
        save_on_train_epoch_end=True,
    )

    # ZeRO stage 2 with optimizer state offloaded to CPU memory.
    strategy = DeepSpeedStrategy(stage=2, offload_optimizer=True)

    return pl.Trainer(
        default_root_dir=config.savedir,
        callbacks=[checkpoint_callback],
        accelerator='cuda',
        benchmark=True,
        num_nodes=args.nodes,
        devices=args.devices,
        log_every_n_steps=1,
        precision=16,
        # Training stops at whichever of these two limits is reached first.
        max_epochs=config.num_train_epochs,
        max_steps=5000,
        strategy=strategy,
        sync_batchnorm=True,
        # An int interval means "run validation every 100 training batches".
        val_check_interval=100,
        check_val_every_n_epoch=1,
    )


def _run_train(trainer, model, args, batch_size):
    """Fit the model on WebVid, validating on a held-out WebVid split."""
    train_dataset = webvid(args.dataset, args.H, args.W, args.T, True)
    # NOTE(review): the trailing 23 is forwarded to webvid as-is; presumably
    # it limits the validation split size — confirm against models.data.
    val_dataset = webvid(args.dataset, args.H, args.W, args.T, False, 23)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=1)
    # NOTE(review): shuffle=True on an evaluation loader makes Lightning emit a
    # warning; kept in case the varying sample order is intentional.
    val_dataloader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_workers=1)

    # Resume only when the checkpoint exists; otherwise train from scratch.
    ckpt_path = args.resume if os.path.exists(args.resume) else None
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
        ckpt_path=ckpt_path,
    )


def _run_test(trainer, model, args, batch_size):
    """Evaluate the checkpoint at ``args.resume`` on MSR-VTT."""
    test_dataset = msrvtt(args.dataset, args.H, args.W, args.T, False, 10000)
    # NOTE(review): shuffle=True on a test loader — see the note in _run_train.
    test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size, num_workers=6)
    trainer.test(
        model=model,
        dataloaders=test_dataloader,
        ckpt_path=args.resume,
    )


def main():
    """Parse CLI args, configure Lightning, and run training or testing.

    Raises:
        FileNotFoundError: in ``--mode test`` when ``--resume`` does not exist.
    """
    args = get_parser()

    # Fail fast: testing requires a checkpoint, and a bad path would otherwise
    # only surface inside trainer.test(), after the model has been built.
    if args.mode == 'test' and not os.path.exists(args.resume):
        raise FileNotFoundError(f"--resume checkpoint not found: {args.resume!r}")

    pl.seed_everything(args.seed, workers=True)
    config = OmegaConf.load(args.config)
    _apply_cli_overrides(config, args)
    setproctitle.setproctitle(args.name)

    trainer = _build_trainer(config, args)
    trainer_model = StableDiffusionTrainer(config.ddconfig)

    # argparse restricts --mode to these two choices.
    if args.mode == 'train':
        _run_train(trainer, trainer_model, args, config.batch_size)
    elif args.mode == 'test':
        _run_test(trainer, trainer_model, args, config.batch_size)


if __name__ == "__main__":
    # Executed only when run as a script (not on import): parse the command
    # line and dispatch to training or testing.
    main()
    