import datetime
import os
import os.path as osp
import sys
import warnings
warnings.filterwarnings("ignore")

import argparse
import pytorch_lightning as pl
import torch
from pytorch_lightning.trainer import Trainer
import pytorch_lightning.callbacks as plc
import pytorch_lightning.loggers as plog
from model_interface import MInterface
from data_interface import DInterface
from src.tools.logger import SetupCallback, BackupCodeCallback, BestCheckpointCallback
import math
from shutil import ignore_patterns


def create_parser(argv=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which
            matches the previous behavior; passing an explicit list makes
            the function usable from tests or notebooks.

    Returns:
        argparse.Namespace: the parsed arguments (despite the function name,
        the parser itself is not returned).

    Raises:
        SystemExit: raised by argparse on invalid arguments (e.g. a
            ``--model_name`` outside the allowed choices) or on ``--help``.
    """
    parser = argparse.ArgumentParser()
    # set-up parameters
    parser.add_argument('--res_dir', default='./results', type=str)
    parser.add_argument('--ex_name', default='structgnn_afdb', type=str)
    parser.add_argument('--check_val_every_n_epoch', default=1, type=int)
    parser.add_argument('--dataset', default='SYNC')
    parser.add_argument('--model_name', default='StructGNN', choices=['StructGNN', 'GraphTrans', 'GVP', 'ESMIF', 'PiFold', 'ProteinMPNN'])
    parser.add_argument('--lr', default=1e-3, type=float, help='Learning rate')
    parser.add_argument('--lr_scheduler', default='onecycle')
    parser.add_argument('--offline', default=1, type=int)
    parser.add_argument('--seed', default=111, type=int)

    # dataset parameters
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--num_workers', default=12, type=int)
    parser.add_argument('--pad', default=1024, type=int)
    parser.add_argument('--min_length', default=40, type=int)
    parser.add_argument('--data_root', default='./dataset')

    # Training parameters
    parser.add_argument('--epoch', default=200, type=int, help='end epoch')
    parser.add_argument('--augment_eps', default=0.0, type=float, help='noise level')
    parser.add_argument('--mask_ratio', default=0.1, type=float)

    # Model parameters
    parser.add_argument('--use_dist', default=1, type=int)
    parser.add_argument('--use_product', default=0, type=int)
    parser.add_argument('--sync_data', default='select-0907-2')
    # argv=None keeps the original sys.argv-based behavior.
    args = parser.parse_args(argv)
    return args


def find_best_checkpoint(ckpt_dir):
    """Return the path of the "best" checkpoint inside ``ckpt_dir``.

    An entry qualifies when its name contains ``'best'`` and ends with
    ``'.ckpt'``. Entries are examined in sorted name order so that the result
    is deterministic when several entries match; ``os.listdir`` itself
    returns names in arbitrary order.

    Args:
        ckpt_dir: Directory to search (non-recursive).

    Returns:
        ``os.path.join(ckpt_dir, name)`` for the first matching name.

    Raises:
        FileNotFoundError: If ``ckpt_dir`` does not exist (raised by
            ``os.listdir``) or contains no matching entry.
    """
    for file_name in sorted(os.listdir(ckpt_dir)):
        if 'best' in file_name and file_name.endswith('.ckpt'):
            return os.path.join(ckpt_dir, file_name)
    raise FileNotFoundError(f"No best checkpoint found in {ckpt_dir}")


if __name__ == "__main__":
    # Evaluate the "best" checkpoint of experiment `ex_name` on the test set.
    args = create_parser()

    # Seed first so anything random in data/model setup is reproducible.
    pl.seed_everything(args.seed)

    data_module = DInterface(**vars(args))
    data_module.setup()

    ckpt_dir = os.path.join(args.res_dir, args.ex_name, 'checkpoints')
    best_ckpt_path = find_best_checkpoint(ckpt_dir)
    print(f"Loading checkpoint from: {best_ckpt_path}")

    # load_from_checkpoint is a classmethod that instantiates the model itself,
    # so no separate MInterface(...) construction is needed beforehand. The
    # CLI args are forwarded as extra keyword arguments to the constructor.
    model = MInterface.load_from_checkpoint(best_ckpt_path, **vars(args))

    # Use the GPU when one is present; fall back to CPU instead of failing
    # at Trainer construction on machines without CUDA.
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
    trainer = Trainer(accelerator=accelerator, devices=1)

    trainer.test(model, datamodule=data_module)
