import logging, argparse, os
import pytorch_lightning as pl
# from data.data_modules import ProteinDataModule
# from models.base_model import BaseProteinModel
from models import build_model
from data.data_modules import build_data_module
from utils.args import str2bool

# Configure the root logger once at import time: INFO and above, each record
# prefixed with a timestamp and a left-aligned, 8-character-wide level name.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
)

def parse_args(dm_cls=None, model_cls=None, argv=None):
    """Build the command-line parser and return the parsed arguments.

    Options are registered in four layers: the script-level options defined
    in this function, the ``pl.Trainer`` options, the data-module options and
    the model options.  When ``dm_cls`` / ``model_cls`` is not supplied, the
    class is resolved from ``--data_module`` / ``--model`` with a preliminary
    ``parse_known_args`` pass over the same command line.

    Args:
        dm_cls: Object exposing ``add_argparse_args(parser)`` that registers
            the data-module options.  If ``None``, it is obtained from
            ``build_data_module(<--data_module value>)``.
        model_cls: Object exposing ``add_argparse_args(parser)`` that
            registers the model options.  If ``None``, it is obtained from
            ``build_model(<--model value>)``.
        argv: Optional list of argument strings to parse.  ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``, post-processed as follows:
        ``vocab`` is derived from ``model_name_or_path`` when it was left
        empty, and ``check_val_every_n_epoch`` is set to ``None`` when
        ``--set_none_check_val_every_n_epoch`` was given.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--save-dir', default='models/clusterings_geo', type=str)
    parser.add_argument('--model-type', default='counting', type=str,
                        choices=["counting", "discrete", "continuous"])
    parser.add_argument('--max-mm-num', type=int, default=30)
    parser.add_argument('--max-time-diff', type=float, default=0.01)
    parser.add_argument('--day-intervals', type=float, default=0.005)
    parser.add_argument('--one-step-interval', type=float, default=0.0005)
    parser.add_argument('--max-subtree-size', type=int, default=100)
    parser.add_argument('--max-backward-step', type=int, default=50)
    parser.add_argument('--seed', type=int, default=1005)
    parser.add_argument('--epoch-num', type=int, default=3)
    parser.add_argument('--use-cuda', action="store_true")
    parser.add_argument('--debug', action="store_true")
    # String defaults are run through ``type`` by argparse, so "false"/"true"
    # below end up as whatever ``str2bool`` returns for them.
    parser.add_argument('--wandb_logger', type=str2bool, default="false")
    parser.add_argument('--clean_output', type=str2bool, default="true")

    # Validation schedule
    parser.add_argument('--val_every_n_steps', type=int, default=None,
                        help='Eval every n training steps.')
    parser.add_argument('--set_none_check_val_every_n_epoch', action="store_true")

    # Early stopping and checkpointing
    parser.add_argument('--early_stop', action="store_true")
    parser.add_argument('--early_stop_patience', type=int, default=3)
    parser.add_argument('--early_stop_monitor', type=str, default="val_loss")
    parser.add_argument('--early_stop_mode', type=str, default="min")
    parser.add_argument('--model_ckpt_monitor', type=str, default="val_loss")
    parser.add_argument('--model_ckpt_mode', type=str, default="min")
    parser.add_argument('--model_ckpt_save_top_k', type=int, default=1)
    parser.add_argument('--model_ckpt_every_n_train_steps', type=int, default=None,
                        help="Save the model checkpoints every N training batches. Should run the validation steps first if we want to save according to val loss.")

    # Run mode
    parser.add_argument('--test', action="store_true", help="Testing mode.")
    parser.add_argument('--predict', action="store_true", help="Prediction mode.")
    parser.add_argument('--save_prediction_path', type=str, default=None)
    parser.add_argument('--cudnn_deterministic', type=str2bool, default="true",
                        help="Whether to request deterministic cuDNN behaviour.")
    parser.add_argument('--model', type=str, default=None)
    parser.add_argument('-d', '--data_module', type=str, default="lm_weighted")

    # Testing
    parser.add_argument('--max_testing_time', type=int, default=-1)
    parser.add_argument('--min_testing_time', type=int, default=-1)

    # Basic optimisation settings for models
    parser.add_argument('--weight_decay_rate', type=float, default=0.01)
    parser.add_argument('--learning_rate', type=float, default=1e-5)
    parser.add_argument('--scheduler', type=str, default="linear", choices=["cosine", "linear"])

    parser.add_argument('--validate_with_generation', type=str2bool, default="false",
                        help="Generate sequences during validation.")

    # Add the args from Trainer.
    parser = pl.Trainer.add_argparse_args(parser)

    # Add the args for the DataModule.  Without an explicit class, a
    # preliminary parse is needed to learn which data module was requested;
    # parse_known_args tolerates the options that are not registered yet.
    # NOTE: a -h/--help on the command line is acted upon by this preliminary
    # parse, so the help output only lists the options registered so far.
    if dm_cls is not None:
        parser = dm_cls.add_argparse_args(parser)
    else:
        known, _ = parser.parse_known_args(argv)
        parser = build_data_module(known.data_module).add_argparse_args(parser)

    # Add the args for the Model, resolved the same way.
    if model_cls is not None:
        parser = model_cls.add_argparse_args(parser)
    else:
        known, _ = parser.parse_known_args(argv)
        parser = build_model(known.model).add_argparse_args(parser)

    # Final, strict parse now that every option is registered.
    args = parser.parse_args(argv)

    # If the vocab type (e.g., ESM, MSA etc...) is not specified, derive it
    # from model_name_or_path.  Both attributes come from the model /
    # data-module options, so they are read defensively in case the selected
    # classes do not register them.
    vocab = getattr(args, "vocab", None)
    model_path = getattr(args, "model_name_or_path", None)
    if vocab == "" and model_path:
        if "esm_msa" in model_path:
            args.vocab = "msa"
        else:
            # normpath drops a trailing separator, which would otherwise make
            # the last path component (and hence the vocab) an empty string.
            last_component = os.path.basename(os.path.normpath(model_path))
            args.vocab = last_component.split("_")[0]

    # The Trainer argparse default for check_val_every_n_epoch is 1; this
    # flag lets the caller override it with None from the command line.
    if args.set_none_check_val_every_n_epoch:
        args.check_val_every_n_epoch = None

    return args


