import os.path
import sys
import argparse
import datetime
import random
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn

from pathlib import Path

from timm.models import create_model
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer

from datasets import build_continual_dataloader

import utils
import warnings

# Configure the PyTorch CUDA caching allocator: per the PyTorch CUDA memory
# management docs, max_split_size_mb stops it from splitting blocks larger than
# this size, which can reduce fragmentation. NOTE(review): presumably this must
# be in place before the first CUDA allocation to take effect — confirm.
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128")

# Silence a noisy deprecation warning about int-valued ``interpolation``
# arguments (presumably raised by torchvision transforms used by the datasets).
warnings.filterwarnings(
    action='ignore',
    message='Argument interpolation should be of type InterpolationMode instead of int',
)


def get_args():
    parser = argparse.ArgumentParser('Training and evaluation configs')
    config = parser.parse_known_args()[-1][0]
    subparser = parser.add_subparsers(dest='subparser_name')
    # import pdb; pdb.set_trace()
    if config == 'cifar100_hideprompt_5e':
        from configs.cifar100_hideprompt_5e import get_args_parser
        config_parser = subparser.add_parser('cifar100_hideprompt_5e', help='Split-CIFAR100 HiDe-Prompt configs')
    elif config == 'imr_hideprompt_5e':
        from configs.imr_hideprompt_5e import get_args_parser
    elif config == 'esc_hideprompt_5e':
        from configs.imr_hideprompt_5e import get_args_parser
        config_parser = subparser.add_parser('esc_hideprompt_5e', help='Split-ImageNet-R HiDe-Prompt configs')
    elif config == 'five_datasets_hideprompt_5e':
        from configs.five_datasets_hideprompt_5e import get_args_parser
        config_parser = subparser.add_parser('five_datasets_hideprompt_5e', help='five datasets HiDe-Prompt configs')
    elif config == 'cub_hideprompt_5e':
        from configs.cub_hideprompt_5e import get_args_parser
        config_parser = subparser.add_parser('cub_hideprompt_5e', help='Split-CUB HiDe-Prompt configs')
    elif config == 'cifar100_dualprompt':
        from configs.cifar100_dualprompt import get_args_parser
        config_parser = subparser.add_parser('cifar100_dualprompt', help='Split-CIFAR100 dual-prompt configs')
    elif config == 'imr_dualprompt':
        from configs.imr_dualprompt import get_args_parser
        config_parser = subparser.add_parser('imr_dualprompt', help='Split-ImageNet-R dual-prompt configs')
    elif config == 'five_datasets_dualprompt':
        from configs.five_datasets_dualprompt import get_args_parser
        config_parser = subparser.add_parser('five_datasets_dualprompt', help='five datasets dual-prompt configs')
    elif config == 'cub_dualprompt':
        from configs.cub_dualprompt import get_args_parser
        config_parser = subparser.add_parser('cub_dualprompt', help='Split-CUB dual-prompt configs')
    elif config == 'cifar100_sprompt_5e':
        from configs.cifar100_sprompt_5e import get_args_parser
        config_parser = subparser.add_parser('cifar100_sprompt_5e', help='Split-CIFAR100 s-prompt configs')
    elif config == 'imr_sprompt_5e':
        from configs.imr_sprompt_5e import get_args_parser
        config_parser = subparser.add_parser('imr_sprompt_5e', help='Split-ImageNet-R s-prompt configs')
    elif config == 'five_datasets_sprompt_5e':
        from configs.five_datasets_sprompt_5e import get_args_parser
        config_parser = subparser.add_parser('five_datasets_sprompt_5e', help='five datasets s-prompt configs')
    elif config == 'cub_sprompt_5e':
        from configs.cub_sprompt_5e import get_args_parser
        config_parser = subparser.add_parser('cub_sprompt_5e', help='Split-CUB s-prompt configs')
    elif config == 'cifar100_l2p':
        from configs.cifar100_l2p import get_args_parser
        config_parser = subparser.add_parser('cifar100_l2p', help='Split-CIFAR100 l2p configs')
    elif config == 'imr_l2p':
        from configs.imr_l2p import get_args_parser
        config_parser = subparser.add_parser('imr_l2p', help='Split-ImageNet-R l2p configs')
    elif config == 'five_datasets_l2p':
        from configs.five_datasets_l2p import get_args_parser
        config_parser = subparser.add_parser('five_datasets_l2p', help='five datasets l2p configs')
    elif config == 'cub_l2p':
        from configs.cub_l2p import get_args_parser
        config_parser = subparser.add_parser('cub_l2p', help='Split-CUB l2p configs')
    elif config == 'cifar100_hidelora':
        from configs.cifar100_hidelora import get_args_parser
        config_parser = subparser.add_parser('cifar100_hidelora', help='Split-CIFAR100 hidelora configs')
    elif config == 'imr_hidelora':
        from configs.imr_hidelora import get_args_parser
        config_parser = subparser.add_parser('imr_hidelora', help='Split-ImageNet-R hidelora configs')
    elif config == 'cifar100_continual_lora':
        from configs.cifar100_continual_lora import get_args_parser
        config_parser = subparser.add_parser('cifar100_continual_lora', help='Split-CIFAR100 continual lora configs')
    elif config == 'imr_continual_lora':
        from configs.imr_continual_lora import get_args_parser
        config_parser = subparser.add_parser('imr_continual_lora', help='Split-ImageNet-R continual lora configs')
    elif config == 'cifar100_hideadapter':
        from configs.cifar100_hideadapter import get_args_parser
        config_parser = subparser.add_parser('cifar100_hideadapter', help='Split-CIFAR100 hideadapter configs')
    elif config == 'imr_hideadapter':
        from configs.imr_hideadapter import get_args_parser
        config_parser = subparser.add_parser('imr_hideadapter', help='Split-ImageNet-R hideadapter configs')
    elif config == 'esc_continual_prompt':
        from configs.imr_continual_prompt import get_args_parser
        config_parser = subparser.add_parser('esc_continual_prompt', help='Split-ImageNet-R continual prompt config')
    elif config == 'imr_continual_adapter':
        from configs.imr_continual_adapter import get_args_parser
        config_parser = subparser.add_parser('imr_continual_adapter', help='Split-ImageNet-R continual adapter config')
    elif config == 'esc_naive_ft':
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_naive_ft', help='Split-ESC-50 naive fine-tuning config')
    elif config == 'esc_lae':
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_lae', help='Split-ESC-50 naive fine-tuning config')
    elif config == 'esc_acl':
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_acl', help='Split-ESC-50 naive fine-tuning config')
    elif config == 'esc_dualprompt':
        # from configs.esc_dualprompt import get_args_parser
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_dualprompt', help='Split-ESC-50 dualprompt config')
    elif config == 'esc_l2p':
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_l2p', help='Split-ESC-50 dualprompt config')
    elif config == 'esc_sprompt':
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_sprompt', help='Split-ESC-50 dualprompt config')
    elif config == 'esc_ranpac':
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_ranpac', help='Split-ESC-50 dualprompt config')
    elif config == 'esc_joint':
        from configs.esc_naive_ft import get_args_parser
        config_parser = subparser.add_parser('esc_joint', help='Split-ESC-50 dualprompt config')
    else:
        raise NotImplementedError

    get_args_parser(config_parser)
    args = parser.parse_args()
    args.config = config
    return args

def main(args):
    utils.init_distributed_mode(args)
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # fix the seed for reproducibility
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True
    # import pdb; pdb.set_trace()
    if hasattr(args, 'train_inference_task_only') and args.train_inference_task_only:
        import trainers.tii_trainer as tii_trainer
        tii_trainer.train(args)
    elif 'hideprompt' in args.config and not args.train_inference_task_only:
        import trainers.hideprompt_trainer as hideprompt_trainer
        hideprompt_trainer.train(args)
    elif 'l2p' in args.config or 'dualprompt' in args.config or 'sprompt' in args.config:
        import trainers.dp_trainer as dp_trainer
        dp_trainer.train(args)
    elif 'hidelora' in args.config and not args.train_inference_task_only:
        import trainers.hidelora_trainer as hidelora_trainer
        hidelora_trainer.train(args)
    elif 'continual_lora' in args.config:
        import trainers.continual_lora_trainer as continual_lora_trainer
        continual_lora_trainer.train(args)
    elif 'hideadapter' in args.config and not args.train_inference_task_only:
        import trainers.hideadapter_trainer as hideapater_trainer
        hideapater_trainer.train(args)
    elif 'continual_prompt' in args.config:
        import trainers.continual_prompt_trainer as continual_prompt_trainer
        continual_prompt_trainer.train(args)
    elif 'continual_adapter' in args.config:
        import trainers.continual_adapter_trainer as continual_adapter_trainer
        continual_adapter_trainer.train(args)
    elif 'naive_ft' in args.config:
        import trainers.naive_ft_trainer as naive_ft_trainer
        naive_ft_trainer.train(args)
    elif 'acl' in args.config:
        import trainers.acl_trainer as acl_trainer
        acl_trainer.train(args)
    elif 'ranpac' in args.config:
        import trainers.ranpac_trainer as ranpac_trainer
        ranpac_trainer.train(args)
    elif 'joint' in args.config:
        import trainers.joint_trainer as ranpac_trainer
        ranpac_trainer.train(args)
    elif 'lae' in args.config:
        import trainers.lae_trainer as lae_trainer
        lae_trainer.train(args)
    else:
        raise NotImplementedError

if __name__ == '__main__':
    # Script entry point: parse the CLI, echo the resolved namespace, then run.
    args = get_args()
    print(args)
    main(args)
