import warnings
import os
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning import Trainer
import pytorch_lightning.callbacks as plc
from pytorch_lightning.loggers import TensorBoardLogger
from model import MInterface
from data import DInterface
from utils import load_model_path_by_args

# Silence every warning category for the whole process; this is equivalent
# to warnings.filterwarnings("ignore"). NOTE(review): it also hides library
# deprecation notices, so consider narrowing it if those matter.
warnings.simplefilter("ignore")

def load_callbacks(args=None):
    """Build the PyTorch Lightning callbacks used for training.

    * A ``ModelCheckpoint`` keeps the 3 best checkpoints by the logged
      ``loss`` metric (lower is better) and also maintains ``last.ckpt``.
    * A ``ModelCheckpoint`` keeps the 3 best checkpoints by the logged
      ``dci`` metric (higher is better). Its filename also embeds
      ``cosine``, ``MIG_discrete_mig``, ``modularity_score`` and
      ``explicitness_score_test``, so the model presumably logs all of
      them. Confirm this against MInterface.
    * A ``LearningRateMonitor`` (per epoch) is added when
      ``args.lr_scheduler`` is truthy.

    Args:
        args: Parsed command-line namespace. Only ``args.lr_scheduler`` is
            read. When omitted, the module-level ``args`` created in the
            ``__main__`` block is used, which matches the original
            zero-argument call.

    Returns:
        list: Callbacks to hand to ``Trainer``.

    Raises:
        NameError: If ``args`` is omitted and no module-level ``args``
            exists (e.g. this module was imported instead of run as a
            script).
    """
    if args is None:
        # Backward compatibility: this function used to read the global
        # `args` defined under `if __name__ == '__main__'` implicitly.
        try:
            args = globals()['args']
        except KeyError:
            raise NameError(
                "load_callbacks() needs the parsed CLI args: pass them "
                "explicitly or run this file as a script") from None

    loss_checkpoint = plc.ModelCheckpoint(
        monitor='loss',
        filename='best-{epoch:02d}-{loss:02f}',
        save_top_k=3,
        mode='min',
        save_last=True,
    )
    # Neither checkpoint sets `dirpath`, so both resolve to the same default
    # checkpoint directory. Enabling `save_last` on both would write the same
    # `last.ckpt` twice per save, so only the loss checkpoint keeps it.
    dci_checkpoint = plc.ModelCheckpoint(
        monitor='dci',
        filename='best-{epoch:02d}-{dci:04f}-{cosine:02f}-{MIG_discrete_mig:04f}-{modularity_score:04f}-{explicitness_score_test:04f}',
        save_top_k=3,
        mode='max',
    )
    # Monitors tried previously: 'similarity' (mode='max') and 'cosine'
    # (mode='min').
    callbacks = [loss_checkpoint, dci_checkpoint]

    if args.lr_scheduler:
        callbacks.append(plc.LearningRateMonitor(logging_interval='epoch'))
    return callbacks


def trainer_device_from(device):
    """Translate a torch-style device string into Lightning Trainer arguments.

    Args:
        device: Value of ``--gpu``, e.g. ``'cuda:0'``, ``'cuda'`` or ``'cpu'``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        tuple: ``(accelerator, devices)`` for
        ``Trainer(accelerator=..., devices=...)``. This is ``('cuda', [N])``
        for ``'cuda:N'``, ``('cuda', [0])`` for a bare ``'cuda'``, and
        ``('cpu', 1)`` for ``'cpu'``.

    Raises:
        ValueError: For any other device string or a non-numeric CUDA index.
    """
    kind, sep, index = device.strip().lower().partition(':')
    if kind == 'cpu' and not sep:
        return 'cpu', 1
    if kind == 'cuda' and (not sep or index.isdecimal()):
        return 'cuda', [int(index) if sep else 0]
    raise ValueError(
        f"unsupported --gpu value {device!r}; expected 'cpu', 'cuda' or 'cuda:N'")


def main(args):
    """Seed, build the data/model/logger/callbacks, and run training.

    Args:
        args: Parsed command-line namespace. It is splatted into
            ``DInterface`` and ``MInterface`` and forwarded to
            ``Trainer.from_argparse_args``. This function also sets
            ``args.callbacks`` and ``args.logger`` on it before building the
            trainer.

    Raises:
        ValueError: If ``args.gpu`` is not ``'cpu'``, ``'cuda'`` or
            ``'cuda:N'``.
    """
    pl.seed_everything(args.seed)
    # Parse the device up front so a bad --gpu fails before any data or model
    # setup.
    accelerator, devices = trainer_device_from(args.gpu)

    data_module = DInterface(**vars(args))

    # Resuming is currently disabled: the former `load_path` branch always
    # built the model exactly like this. To resume, pass `ckpt_path=...` to
    # `trainer.fit` (check the Trainer docs for the Lightning version in use).
    model = MInterface(**vars(args))
    model = model.to(args.gpu)

    # To change where TensorBoard logs are written, edit `save_dir` here.
    logger = TensorBoardLogger(save_dir='./information_logs', name=args.log_dir)
    args.callbacks = load_callbacks()
    args.logger = logger
    # Train on the device chosen by --gpu. It used to be hard-coded to CUDA
    # device 0 even when --gpu pointed elsewhere.
    trainer = Trainer.from_argparse_args(args, devices=devices, accelerator=accelerator)
    trainer.fit(model, data_module)


if __name__ == '__main__':
    # `args` is intentionally a module-level global: load_callbacks() falls
    # back to it when called without arguments.
    parser = ArgumentParser()

    # Basic training control
    parser.add_argument('--batch_size', default=50, type=int)
    parser.add_argument('--num_workers', default=32, type=int)
    parser.add_argument('--seed', default=3907, type=int)
    parser.add_argument('--lr', default=3e-4, type=float)
    # Torch-style device string; main() moves the model to it.
    parser.add_argument('--gpu', default='cuda:0', type=str)

    # LR scheduler. Presumably 'step' or 'cosine' (handled by the model;
    # confirm). load_callbacks() adds a LearningRateMonitor whenever this
    # value is a non-empty string.
    parser.add_argument('--lr_scheduler', default='step', type=str)
    parser.add_argument('--lr_decay_steps', default=100, type=int)
    parser.add_argument('--lr_decay_rate', default=0.5, type=float)
    parser.add_argument('--lr_decay_min_lr', default=1e-6, type=float)
    parser.add_argument('--warmup_epochs', default=50, type=int)

    # Restart control. main() does not read these in this file (its
    # load_model_path_by_args call is commented out). They are presumably
    # meant for utils.load_model_path_by_args; confirm before relying on them.
    parser.add_argument('--load_best', action='store_true')
    parser.add_argument('--load_dir', default=None, type=str)
    parser.add_argument('--load_ver', default=None, type=str)
    parser.add_argument('--load_v_num', default=None, type=int)

    # Training info
    parser.add_argument('--dataset', default='standard_data', type=str)
    # No leading space: ' ./3dshapeWhole.csv' would name a path that starts
    # with a space character.
    parser.add_argument('--csv_file', default='./3dshapeWhole.csv', type=str, help='train csv')
    # Despite its name, this holds the path of the test CSV file.
    parser.add_argument('--val_dir', default='./3dshapeWhole.csv', type=str, help='test csv')
    parser.add_argument('--model_name', default='CDQAE', type=str)  # alternative: DeFnetV3
    parser.add_argument('--loss', default='ce', type=str)
    parser.add_argument('--weight_decay', default=1e-5, type=float)
    parser.add_argument('--no_augment', action='store_true')
    parser.add_argument('--log_dir', default='CDQAE_shape', type=str)

    # Model hyperparameters
    parser.add_argument('--hidden_size', default=512, type=int)
    parser.add_argument('--beta', default=0.7, type=float)
    parser.add_argument('--moving_average_decay', default=0.99, type=float)
    parser.add_argument('--nmf_lambda', default=0.4, type=float)  # previously 0.3
    parser.add_argument('--orthogonal_lambda', default=0.3, type=float)
    parser.add_argument('--similarity_lambda', default=0.5, type=float)
    parser.add_argument('--vae_lambda', default=0.1, type=float)

    # Other
    parser.add_argument('--aug_prob', default=0.5, type=float)
    # Machine-specific absolute default; override --npz_path on other hosts.
    parser.add_argument('--npz_path', default='/home/star/Projects/g2/gyh/gyh/information/metric/DefnetV3/test.npz', type=str)

    # Trainer flags are not registered on the parser (the deprecated
    # Trainer.add_argparse_args hook is no longer used). This default
    # therefore only puts `max_epochs` into the namespace, for
    # Trainer.from_argparse_args in main() to pick up.
    parser.set_defaults(max_epochs=2000)

    args = parser.parse_args()

    # List-valued settings that are not exposed on the command line.
    # Presumably the normalization mean/std consumed by the data and model
    # interfaces; confirm.
    args.mean_sen = [0.5]
    args.std_sen = [0.5]

    main(args)
