import datetime
import sys
import warnings
warnings.filterwarnings("ignore")
import torch
import argparse
import os.path as osp
import pytorch_lightning as pl
from pytorch_lightning.trainer import Trainer
import pytorch_lightning.callbacks as plc
import pytorch_lightning.loggers as plog
from interface import RefinerPR_ITF, DInterface
from src.tools.logger import SetupCallback, BackupCodeCallback, BestCheckpointCallback
import math
from shutil import ignore_patterns

### DEBUG MODE ###
import os
# Optional: pick a random rendezvous port so concurrent distributed runs on
# the same host do not collide (currently disabled).
# import random
# os.environ['MASTER_PORT'] = str(random.randint(10000, 60000))
# Make CUDA kernel launches synchronous so device-side errors are reported at
# the call that caused them. This is a debugging aid and slows down training.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
### DEBUG MODE ###

def create_parser(argv=None):
    """Define the command-line options and return the parsed arguments.

    Despite its name, this function returns the parsed
    ``argparse.Namespace``, not the ``ArgumentParser`` itself.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) the arguments are read from ``sys.argv[1:]``,
            which is the behaviour existing callers rely on. Passing an
            explicit list allows the options to be parsed programmatically
            (e.g. from tests or notebooks).

    Returns:
        argparse.Namespace: one attribute per option declared below.
    """
    parser = argparse.ArgumentParser()
    # set-up parameters
    parser.add_argument('--res_dir', default='./results', type=str)
    parser.add_argument('--ex_name', default='debug', type=str)
    parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
    parser.add_argument('--offline', default=1, type=int)
    parser.add_argument('--seed', default=111, type=int)
    parser.add_argument('--log_epoch', default=1, type=int)
    parser.add_argument('--test_path', default=None, type=str)
    parser.add_argument('--resume_training', default=1, type=int)

    # dump, debug and visualization parameters
    parser.add_argument('--loss_dump', default=None, type=str)  # dump loss in this path, None for no dump
    parser.add_argument('--features_dump', default=None, type=str, help='dump features like alpha, beta, gamma angles and dihedral angles')
    parser.add_argument('--predict', default=None, type=str, help="dump predicted results")

    # dataset parameters
    parser.add_argument('--dataset_type', default='MixDataset', choices=['MixDataset', 'PairDataset', 'AllTestDataset', 'PDBDataset', 'FeaturePredictionDataset', "PTMDataset"])
    parser.add_argument('--data_path', default='./example_data/pdb_test.jsonl')
    parser.add_argument('--mixAFDB', default=0, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pad', default=1024, type=int)
    parser.add_argument('--min_length', default=30, type=int)
    parser.add_argument('--input_source', default='pdb', choices=["pdb", 'afdb'])

    # training parameters
    parser.add_argument('--gpus', type=int, nargs='+', default=[0, 1])
    parser.add_argument('--epoch', default=20, type=int, help='end epoch')
    parser.add_argument('--warmup_epoch', default=3, type=int)
    parser.add_argument('--cycle_epoch', default=1, type=int)
    parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
    parser.add_argument('--lr_scheduler', default='cosine')
    parser.add_argument('--lr_decay_steps', default=1000, type=int)

    # model parameters
    parser.add_argument('--stage', default='refiner', choices=['rewarder', 'refiner'])
    # parser.add_argument('--r_enc_layers', default=1, type=int)
    parser.add_argument('--r_path', default=None, type=str)
    parser.add_argument('--rf_path', default=None, type=str)
    parser.add_argument('--dropout', default=0.0, type=float)
    parser.add_argument('--rew_type', default='if', type=str, choices=['dis', 'if', 'pair'])

    # parser.add_argument('--num_atoms', default=4, type=int)
    parser.add_argument('--enc_layers', default=8, type=int)
    parser.add_argument('--geo_layers', default=3, type=int)
    parser.add_argument('--edge_layers', default=3, type=int)
    parser.add_argument('--dec_layers', default=6, type=int)
    parser.add_argument('--dec_topk', default=50, type=int)
    parser.add_argument('--hidden_dim', default=128, type=int)
    parser.add_argument('--aug_noise_eps', default=0.0, type=float)
    parser.add_argument('--noise_eps', default=0.1, type=float)
    parser.add_argument('--model_type', default=0, type=int)
    parser.add_argument('--af_weight', default=0.1, type=float)
    parser.add_argument('--atom_mask_ratio', default=0.0, type=float)
    parser.add_argument('--oxygen_mask_ratio', default=None, type=float)
    parser.add_argument('--rotation_noise_eps', default=0.0, type=float)
    parser.add_argument('--bond_noise_eps', default=0.0, type=float)

    parser.add_argument('--loss_type', default="refiner", type=str, choices=['refiner', 'rmsd'])

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

def load_callbacks(args):
    """Assemble the list of trainer callbacks for one run.

    Side effect: the run directory ``<res_dir>/<ex_name>`` is stored on
    ``args.savedir``.

    Args:
        args: parsed options; reads ``res_dir``, ``ex_name``,
            ``check_val_every_n_epoch`` and ``lr_scheduler``.

    Returns:
        list: code-backup, best-checkpoint and setup callbacks (in that
        order), followed by a learning-rate monitor when ``lr_scheduler``
        is truthy.
    """
    run_dir = str(osp.join(args.res_dir, args.ex_name))
    args.savedir = run_dir
    checkpoint_dir = osp.join(run_dir, "checkpoints")

    # Paths matching these patterns are skipped by the code backup.
    backup_skip = ignore_patterns(
        '__pycache__', 'results*', 'raw_data*', 'example_data*',
        'checkpoint*', 'evaluation*', 'lightning_logs*', 'wandb*', '*.jsonl',
    )
    code_backup = BackupCodeCallback(
        osp.dirname(args.res_dir), run_dir, ignore_patterns=backup_skip
    )

    # Keep the 15 checkpoints with the lowest validation loss, plus the last.
    best_checkpoint = BestCheckpointCallback(
        monitor="val_loss",
        filename='best-{epoch:02d}-{val_loss:.4f}',
        save_top_k=15,
        mode='min',
        save_last=True,
        dirpath=checkpoint_dir,
        verbose=True,
        every_n_epochs=args.check_val_every_n_epoch,
    )

    timestamp = datetime.datetime.now().strftime("%m-%dT%H-%M-%S")
    launch_info = sys.argv + ["gpus: {}".format(torch.cuda.device_count())]
    run_setup = SetupCallback(
        now=timestamp,
        logdir=run_dir,
        ckptdir=checkpoint_dir,
        cfgdir=osp.join(run_dir, "configs"),
        config=vars(args),
        argv_content=launch_info,
    )

    callbacks = [code_backup, best_checkpoint, run_setup]
    if args.lr_scheduler:
        callbacks.append(plc.LearningRateMonitor(logging_interval=None))
    return callbacks

if __name__ == "__main__":
    # Script entry point: parse the CLI, build the data module, trainer and
    # model interface, then train and/or evaluate depending on --test_path
    # and --resume_training.
    args = create_parser()
    pl.seed_everything(args.seed)
    callbacks = load_callbacks(args)  # also sets args.savedir

    # NOTE(review): 'selector' is not one of the --stage choices, so this flag
    # is currently always False — confirm whether that is intended.
    remove_pdb = args.stage == 'selector'
    data_module = DInterface(
        data_path=args.data_path,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        dataset_type=args.dataset_type,
        remove_pdb=remove_pdb,
    )
    data_module.setup()
    # Number of batches needed to cover the training set once.
    args.steps_per_epoch = math.ceil(len(data_module.trainset) / args.batch_size)

    # NOTE(review): which of these keys the Trainer accepts ('gpus',
    # 'strict_loading', 'strategy'='auto') depends on the installed
    # pytorch_lightning version — confirm against the pinned release.
    trainer_config = {
        'gpus': args.gpus,  # Device indices selected with --gpus
        'max_epochs': args.epoch,  # Maximum number of epochs to train for
        'num_nodes': 1,  # Number of nodes to use for distributed training
        'strategy': 'auto',  # alternatives: 'ddp', 'deepspeed_stage_2'
        'accelerator': 'gpu',  # Run on GPU devices
        'callbacks': callbacks,
        # Alternative Weights & Biases logger, kept for reference. It must
        # stay as real comments: a string literal placed here would be
        # concatenated with the following 'logger' key and silently drop the
        # logger setting.
        # 'logger': plog.WandbLogger(
        #             project='AF2DB-Elite',
        #             name=args.ex_name,
        #             save_dir=args.savedir,
        #             offline=args.offline,
        #             id="_".join(args.ex_name.split("/")),
        #             entity="chengtan9907"),
        'logger': plog.CSVLogger(str(os.path.join(args.res_dir, args.ex_name))),
        'strict_loading': False,
        'reload_dataloaders_every_n_epochs': 1,
    }

    trainer_opt = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(trainer_opt)

    # Resolve the model interface class by name from the `interface` module.
    # Only RefinerPR_ITF is imported at the top of this file, so referring to
    # the other classes as bare names would raise NameError.
    import interface

    if args.stage == 'rewarder':
        itf_name = 'Rewarder_ITF'
    elif args.stage == 'refiner':
        itf_name = {
            'dis': 'Refiner_ITF',
            'if': 'RefinerIF_ITF',
            'pair': 'RefinerPR_ITF',
        }[args.rew_type]
    model_itf = getattr(interface, itf_name, None)
    if model_itf is None:
        raise ImportError(
            "interface module does not provide '{}' (stage={}, rew_type={})".format(
                itf_name, args.stage, args.rew_type
            )
        )

    if args.test_path is None:
        # No checkpoint given: train from scratch.
        model = model_itf(**vars(args))
        trainer.fit(model, data_module)
    elif args.resume_training:
        # Continue training from the given checkpoint's weights.
        model = model_itf.load_from_checkpoint(
            args.test_path, savedir=args.savedir, lr=args.lr
        )
        trainer.fit(model, data_module)
    else:
        # Evaluation only: load the checkpoint and override the
        # evaluation-related settings from the command line.
        model = model_itf.load_from_checkpoint(
            args.test_path,
            loss_dump=args.loss_dump,
            loss_type=args.loss_type,
            savedir=args.savedir,
            features_dump=args.features_dump,
            predict=args.predict,
            input_source=args.input_source,
            aug_noise_eps=args.aug_noise_eps,
            atom_mask_ratio=args.atom_mask_ratio,
            oxygen_mask_ratio=args.oxygen_mask_ratio,
            rotation_noise_eps=args.rotation_noise_eps,
            bond_noise_eps=args.bond_noise_eps,
        )

    # Runs in every branch above, i.e. also right after training.
    if args.predict:
        trainer.predict(model, data_module)
    else:
        trainer.test(model, data_module)
    model.cal_metric()