import datetime
import os
import sys
os.environ['CURL_CA_BUNDLE'] = ''
import json
import warnings
warnings.filterwarnings("ignore")

import argparse
import torch
import pytorch_lightning as pl
from pytorch_lightning.trainer import Trainer
import pytorch_lightning.callbacks as plc
from model_interface import MInterface
from data_interface import DInterface
from src.tools.logger import SetupCallback, BestCheckpointCallback, BackupCodeCallback, TempFileCleanupCallback
from shutil import ignore_patterns
import pytorch_lightning.loggers as plog


def create_parser(argv=None):
    """Define the command-line interface and parse it.

    Despite its name, this returns the parsed ``argparse.Namespace`` rather than
    the ``ArgumentParser`` itself.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list lets the CLI be driven from tests or notebooks.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    # Set-up parameters
    parser.add_argument('--res_dir', default='./results', type=str)
    parser.add_argument('--ex_name', default='debug', type=str)
    parser.add_argument('--check_val_every_n_epoch', default=1, type=int)

    parser.add_argument('--dataset', default='PTM')
    parser.add_argument('--model_name', default='MeToken', choices=['MeToken'])
    parser.add_argument('--lr', default=1e-4, type=float, help='Learning rate')
    parser.add_argument('--lr_scheduler', default='onecycle')
    parser.add_argument('--offline', default=1, type=int)
    parser.add_argument('--seed', default=114514, type=int)

    # dataset parameters
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--num_workers', default=16, type=int)
    parser.add_argument('--pad', default=1024, type=int)
    parser.add_argument('--path', default='./data')

    # Training parameters
    parser.add_argument('--epoch', default=200, type=int, help='end epoch')

    # Model parameters
    parser.add_argument('--final_tau', default=1e-4, type=float)

    # Resume from checkpoints
    parser.add_argument("--pretrain", default=0, type=int)
    parser.add_argument("--test_only", default=0, type=int)
    parser.add_argument("--ckpt_from_deepspeed", default=0, type=int)
    parser.add_argument("--ckpt_path", default=None, type=str)

    # NOTE(review): the help text mentions "-1 for cpu", but the training script
    # hard-codes accelerator='gpu' — confirm how -1 is actually handled.
    parser.add_argument('--gpus', type=int, nargs='+', default=[0], help='gpu to use, -1 for cpu')
    parser.add_argument('--strategy', type=str, default='auto')
    parser.add_argument('--wandb_offline', type=int, default=1)

    # parse_args(None) falls back to sys.argv[1:], preserving the original behaviour.
    return parser.parse_args(argv)


def load_callbacks(args):
    """Assemble the Lightning callbacks for one run.

    All run artifacts are rooted at ``<res_dir>/<ex_name>``. The checkpoint
    directory ``checkpoints/`` and the config directory ``configs/`` live under it.

    Args:
        args: the parsed command-line namespace (see ``create_parser``).

    Returns:
        tuple ``(callbacks, ckptdir)``: the callback list, in registration order,
        and the checkpoint directory path.
    """
    run_dir = str(os.path.join(args.res_dir, args.ex_name))
    ckpt_dir = os.path.join(run_dir, "checkpoints")

    # Presumably snapshots the project sources (rooted at the parent of res_dir)
    # into the run directory. Data, results and generated folders are excluded
    # by these patterns — TODO confirm against BackupCodeCallback.
    skip = ignore_patterns('results*', 'pdb*', 'metadata*', 'vq_dataset*', 'bin*', 'data*',
                           '__pycache__', 'info', 'lib', 'requirements', 'debug', 'wandb')
    backup_cb = BackupCodeCallback(os.path.dirname(args.res_dir), run_dir, ignore_patterns=skip)

    # Both checkpoint selection and early stopping track validation F1, higher is better.
    monitor, monitor_mode = "val_f1", "max"
    best_ckpt_cb = BestCheckpointCallback(
        monitor=monitor,
        filename='best-{epoch:02d}-{val_f1:.3f}',
        save_top_k=15,
        mode=monitor_mode,
        save_last=True,
        dirpath=ckpt_dir,
        verbose=True,
        every_n_epochs=args.check_val_every_n_epoch,
    )

    timestamp = datetime.datetime.now().strftime("%m-%dT%H-%M-%S")
    config_dir = os.path.join(run_dir, "configs")
    setup_cb = SetupCallback(
        now=timestamp,
        logdir=run_dir,
        ckptdir=ckpt_dir,
        cfgdir=config_dir,
        config=args.__dict__,
        # Record the visible GPU count alongside the raw command line.
        argv_content=sys.argv + ["gpus: {}".format(torch.cuda.device_count())],
    )

    # Pre-training runs are given more patience (20 vs 5 validation checks).
    early_stop_cb = plc.EarlyStopping(monitor=monitor, mode=monitor_mode,
                                      patience=20 if args.pretrain else 5)

    callbacks = [backup_cb, best_ckpt_cb, setup_cb, early_stop_cb, TempFileCleanupCallback()]
    return callbacks, ckpt_dir


if __name__ == "__main__":
    # Training / evaluation entry point.
    # Without --test_only: set up all data, train, then test the best checkpoint.
    # With --test_only: set up only the test split and evaluate an existing
    # checkpoint. Rank 0 finally writes model.cal_metric(...) to metrics.json.

    # sys.path[0] is the script's directory when run as a script, so relative
    # paths (./results, ./data, ./wandb/) resolve independently of the caller's cwd.
    os.chdir(sys.path[0])
    args = create_parser()
    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))
    if args.test_only:
        data_module.setup(stage="test")
    else:
        data_module.setup()
        # Forwarded to the model via vars(args); presumably consumed by the LR
        # scheduler (--lr_scheduler, default 'onecycle') — TODO confirm in MInterface.
        args.steps_per_epoch = len(data_module.train_loader)
        args.total_steps = args.steps_per_epoch * args.epoch

    logger = plog.WandbLogger(project='PTM-MeToken', dir='./wandb/', name=args.ex_name, offline=args.wandb_offline, config=args.__dict__)

    callbacks, ckptdir = load_callbacks(args)
    # NOTE(review): `gpus`, `resume_from_checkpoint` and `Trainer.from_argparse_args`
    # are PyTorch Lightning 1.x APIs (removed in 2.0), whereas the `--strategy`
    # default 'auto' follows the 2.x convention — verify against the pinned version.
    trainer_config = {
        'gpus': args.gpus if args.ex_name != 'debug' else [0],
        'max_epochs': args.epoch,  # Maximum number of epochs to train for
        'strategy': args.strategy,  # e.g. 'ddp', 'deepspeed_stage_2'
        'accelerator': 'gpu',
        'callbacks': callbacks,
        'logger': logger,
        'gradient_clip_val': 1.0,
        'resume_from_checkpoint': args.ckpt_path if args.test_only else None,
    }

    trainer_opt = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(trainer_opt)
    model = MInterface(**vars(args))

    # Assumes BestCheckpointCallback writes `best.ckpt` into ckptdir — TODO confirm.
    best_ckpt = os.path.join(ckptdir, 'best.ckpt')
    if args.test_only:
        # Evaluation only. Only the test split was set up above, so fit() must not
        # run here. An explicit --ckpt_path takes precedence over the run's best checkpoint.
        test_ckpt = args.ckpt_path or best_ckpt
    else:
        # Train first: the best checkpoint only exists once fit() has produced it.
        trainer.fit(model, data_module)
        test_ckpt = best_ckpt
    trainer.test(model, data_module, ckpt_path=test_ckpt)

    if trainer.global_rank == 0:
        metrics = model.cal_metric(path=args.path)
        with open(os.path.join(args.res_dir, args.ex_name, 'metrics.json'), 'w') as file_obj:
            json.dump(metrics, file_obj)