import datetime
import os
import os.path as osp
import sys
import warnings
warnings.filterwarnings("ignore")

import argparse
import pytorch_lightning as pl
import torch
from pytorch_lightning.trainer import Trainer
import pytorch_lightning.loggers as plog
from model_interface import MInterface
from data_interface import DInterface
from src.tools.logger import SetupCallback, BackupCodeCallback, BestCheckpointCallback
import math
from shutil import ignore_patterns


def create_parser():
    parser = argparse.ArgumentParser()
    # set-up parameters
    parser.add_argument('--res_dir', default='./results', type=str)
    parser.add_argument('--ex_name', default='default', type=str)
    parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
    parser.add_argument('--dataset', default='SYNC')
    parser.add_argument('--model_name', default='StructGNN', choices=['StructGNN', 'GraphTrans', 'GVP', 'ESMIF', 'PiFold', 'ProteinMPNN',"StructGNN_Plus"])
    parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
    parser.add_argument('--lr_scheduler', default='onecycle')
    parser.add_argument('--offline', default=1, type=int)
    parser.add_argument('--seed', default=111, type=int)
    parser.add_argument('--gpus', type=int, nargs='+', default=[0])
    parser.add_argument('--test_path',type=str,default=None)

    # dataset parameters
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--num_workers', default=12, type=int)
    parser.add_argument('--pad', default=1024, type=int)
    parser.add_argument('--min_length', default=40, type=int)
    parser.add_argument('--data_root', default='./dataset')
    # parser.add_argument('--input_source',default='pdb',choices=["pdb",'afdb'])
    
    # Training parameters
    parser.add_argument('--epoch', default=200, type=int, help='end epoch')
    parser.add_argument('--augment_eps', default=0.01, type=float, help='noise level')
    parser.add_argument('--mask_ratio', default=0.1, type=float)

    # Model parameters
    parser.add_argument('--use_dist', default=1, type=int)
    parser.add_argument('--use_product', default=0, type=int)
    parser.add_argument('--sync_data', default='select-0920-4')
    parser.add_argument('--use_refine_only', default=False)

    # dataset_specific parameters, if not using, leave it alone
    # af_cath_mix_dataset
    parser.add_argument('--afdb_rmsd_score', default=None,choices=[None,'high','middle','low'],help="choose what kind of afdb sets as train set")
    # pairwise_cath_mix_dataset
    parser.add_argument('--pairwise_source', default=None,choices=[None],help="choose what kind of afdb sets as train set") # nomask here means no masking when being refined, but have mask during training
    parser.add_argument('--test_source',default='default',choices=['default','afdb']) # it means the afdb version of 'CATH4.2' or its original version
    # ts_dataset
    parser.add_argument('--category',default='ts50',choices=['ts50','ts500'])
    args = parser.parse_args()
    return args

def load_callbacks(args):
    """Build the list of Lightning callbacks for a run.

    The returned list contains, in order:
      1. ``BackupCodeCallback`` - given the parent directory of
         ``args.res_dir`` as source root and the run log directory as
         destination, with results/dataset/cache/notebook patterns ignored.
      2. ``BestCheckpointCallback`` - monitors ``recovery`` (mode ``max``),
         keeps the top-1 checkpoint plus a "last" checkpoint, written to
         ``<res_dir>/<ex_name>/checkpoints``.
      3. ``SetupCallback`` - receives the log/checkpoint/config directories,
         the full argument dict, and ``sys.argv`` plus the visible GPU count.

    Args:
        args: Parsed namespace from ``create_parser``; reads ``res_dir``,
            ``ex_name`` and ``check_val_every_n_epoch``.

    Returns:
        list of callback instances.
    """
    logdir = str(os.path.join(args.res_dir, args.ex_name))
    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")

    # Source tree to back up = directory that contains the results folder.
    # Normalise first so "./results/" does not resolve to the results dir
    # itself, and fall back to "." when res_dir has no parent component
    # (e.g. "results"), where dirname() would return an empty path.
    code_root = os.path.dirname(os.path.normpath(args.res_dir)) or "."
    callbacks = [
        BackupCodeCallback(
            code_root,
            logdir,
            ignore_patterns=ignore_patterns('results*', 'pdb*', 'metadata*', 'vq_dataset*', '.vscode', '__pycache__', '*.ipynb', 'dataset*'),
        )
    ]

    metric = "recovery"
    sv_filename = 'best-{epoch:02d}-{recovery:.3f}'
    callbacks.append(
        BestCheckpointCallback(
            monitor=metric,
            filename=sv_filename,
            save_top_k=1,
            mode='max',
            save_last=True,
            dirpath=ckptdir,
            verbose=True,
            every_n_epochs=args.check_val_every_n_epoch,
        )
    )

    # Timestamp handed to SetupCallback (local, naive time).
    now = datetime.datetime.now().strftime("%m-%dT%H-%M-%S")
    callbacks.append(
        SetupCallback(
            now=now,
            logdir=logdir,
            ckptdir=ckptdir,
            cfgdir=cfgdir,
            config=args.__dict__,
            argv_content=sys.argv + ["gpus: {}".format(torch.cuda.device_count())],
        )
    )
    return callbacks


if __name__ == "__main__":
    # Entry point: parse args, build data + model, then train (unless a
    # checkpoint was supplied via --test_path) and finally run the test loop.
    args = create_parser()
    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))
    data_module.setup()

    # Devices the training set is split across: requested GPUs capped by the
    # number visible. Clamp to 1 so a host with no visible CUDA device does
    # not raise ZeroDivisionError here.
    n_devices = max(1, min(torch.cuda.device_count(), len(args.gpus)))
    args.steps_per_epoch = math.ceil(len(data_module.trainset) / args.batch_size / n_devices)

    if args.test_path:
        print("LOAD EXISTING MODEL")
        # load_from_checkpoint is a classmethod: call it on the class (newer
        # Lightning versions reject instance calls) instead of building a
        # throwaway model first. Keyword args override the stored hparams.
        model = MInterface.load_from_checkpoint(
            args.test_path,
            seed=args.seed,
            gpus=args.gpus,
            test_path=args.test_path,
            test_source=args.test_source,
            mask_ratio=args.mask_ratio,
            ex_name=args.ex_name,
            dataset=args.dataset,
            category=args.category,
        )
    else:
        model = MInterface(**vars(args))

    logger = plog.CSVLogger(str(os.path.join(args.res_dir, args.ex_name)))
    trainer_config = {
        # `gpus=` was removed from Trainer in Lightning 2.0 (and `strategy='auto'`
        # is 2.x-only); accelerator='gpu' + devices=<index list> selects the
        # same GPUs on both 1.x and 2.x.
        'devices': args.gpus,
        'max_epochs': args.epoch,
        'num_nodes': 1,
        "strategy": 'auto',
        'accelerator': 'gpu',
        'callbacks': load_callbacks(args),
        'logger': logger,
        'gradient_clip_val': 1.0
    }

    trainer = Trainer(**trainer_config)
    # Same truthiness test as the load branch above, so an empty --test_path
    # trains instead of testing an untrained model.
    if not args.test_path:
        trainer.fit(model, data_module)
    print("TEST STAGE")
    trainer.test(model, data_module)