# Command-line argument parsing (get_args) and model factories (get_models, get_models_MTL).

from models.EWC import EWC
from models.SGD import SGD
from models.NashMTL import NashMTLModel
from models.IMTL import IMTLModel
from models.PCGradMTL import PCGradModel
from models.MTAN import MTAN
from models.MTL import MTL
from models.DER import DER
from models.Co2L import Co2L
from models.Single import Single
from models.base import ContinualLearning, MultitaskLearning
from models.DERPP import DERPP
from models.ER import ER
from models.FDR import FDR
from models.GSS import GSS
from models.LWP import LwP
from models.DynamicDistMTL import DynamicDistMTL
from models.LWF import LwF
from models.SI import SI
import argparse
import os
import sys
from typing import Tuple
import torch


# Append the parent of the current working directory to the import path:
# abspath('baseline_models') resolves against the CWD, and the two dirname()
# calls strip 'baseline_models' and then the CWD's own last path component.
# NOTE(review): this runs after the `models.*` imports above, so it cannot
# affect how those are resolved; presumably it is meant for imports made
# later at runtime — confirm, or move it above the imports.
sys.path.append(os.path.dirname(
    os.path.dirname(os.path.abspath('baseline_models'))))


def _str2bool(value):
    """Convert a command-line token such as ``'true'``/``'false'`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``'False'``), so boolean options that
    take a value need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f"Boolean value expected, got {value!r}")


def get_args(argv=None):
    """Parse the command-line options for a continual-learning / MTL run.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, which is what the function always did before.

    Returns:
        dict mapping each option's destination name (e.g. ``'epochs'``,
        ``'batch_size'``, ``'model'``) to its parsed value.

    Notes:
        Selected defaults: z_dim=512, z_prime_dim=512, epochs=20,
        batch_size=256, buffer_size=512, model='co2l'.
        Boolean options (``--pretrain``, ``--all_binary``, ``--split_task``)
        accept true/false, yes/no, on/off or 1/0; given with no value they
        mean True.
        NOTE(review): the default model 'co2l' is currently rejected by
        get_models ("Model not working") — consider changing the default.
    """
    parser = argparse.ArgumentParser(
        description='This is a simple argument parser')
    # only cl or mtl:
    parser.add_argument('-ns', '--num_seed', help='num_seed', type=int, default=1)
    parser.add_argument('-j', '--job', help='Continual Learning or MTL',
                        type=str, default='cl', choices=['mtl', 'cl'])
    parser.add_argument('-v', '--verbose',
                        help='Verbose output', action='store_true')
    parser.add_argument('-z', '--z_dim', help='z_dim', type=int, default=512)
    parser.add_argument('-z_prime', '--z_prime_dim',
                        help='z_prime_dim', type=int, default=512)
    parser.add_argument('-e', '--epochs', help='epochs', type=int, default=20)
    parser.add_argument('-b', '--batch_size',
                        help='batch_size', type=int, default=256)
    parser.add_argument('-buffer', '--buffer_size',
                        help='buffer_size', type=int, default=512)
    parser.add_argument('-m', '--model', help='model',
                        type=str, default='co2l')
    # type=bool would turn '--pretrain False' into True; use _str2bool instead.
    parser.add_argument('-pretrain', '--pretrain', help='pretrain',
                        type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('-l', '--lr', help='lr', type=float, default=0.0001)
    parser.add_argument('-a', '--augment', help='augment',
                        type=str, default='default')
    parser.add_argument('-c', '--cls_output_dim',
                        help='cls_output_dim', type=int, default=2)
    parser.add_argument('-d', '--dataset', help='dataset',
                        type=str, default='celeba')
    parser.add_argument('-es', '--early_stopping_tolerance',
                        help='early_stopping_tolerance', type=int, default=5)
    # Float default so the value has the same type whether or not it is given.
    parser.add_argument('-tsr', '--train_subsample_ratio',
                        help='train_subsample_ratio', type=float, default=1.0)
    parser.add_argument('-is', '--input_size', help='input_size',
                        type=int, default=64)
    parser.add_argument('-ab', '--all_binary', help='convert to binary cls',
                        type=_str2bool, nargs='?', const=True, default=True)
    parser.add_argument('-st', '--split_task', help='activate train data splitting based on task',
                        type=_str2bool, nargs='?', const=True, default=True)
    parser.add_argument('-ep', '--eval_period', help='(for speed eval) evaluation after this iteration',
                        type=int, default=5)
    parser.add_argument('-sn', '--save_name', help='save_name',
                        type=str, default='')

    # lwp specific arguments
    parser.add_argument('--disable_dynamic', help='disable dynamic masking', action='store_true')
    parser.add_argument('--dist_method', help='distance method',
                        choices=['orig', 'co2l', 'rkd', 'cos', 'rbf'], type=str, default='orig')

    return vars(parser.parse_args(argv))


def get_models(args: dict, *other_args, **kwargs) -> ContinualLearning:
    """Instantiate the continual-learning model named by ``args['model']``.

    Args:
        args: Parsed options; only the ``'model'`` key is read here.
        *other_args, **kwargs: Forwarded unchanged to the model constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        KeyError: if ``args`` has no ``'model'`` entry.
        ValueError: if the model is disabled (``'co2l'``) or unknown.
    """
    name = args['model']

    # Models that are imported but deliberately disabled (Co2L), mapped to the
    # error raised when they are requested.
    disabled = {
        'co2l': "Model not working",
    }
    if name in disabled:
        raise ValueError(disabled[name])

    # Built per call (not at module level) so the mapping is only evaluated
    # when a model is actually requested.
    factories = {
        'der': DER,
        'derpp': DERPP,
        'er': ER,
        'fdr': FDR,
        'gss': GSS,
        'lwp': LwP,
        'lwf': LwF,
        'single': Single,
        'si': SI,
        'sgd': SGD,
        'ewc': EWC,
    }
    try:
        model_cls = factories[name]
    except KeyError:
        raise ValueError(f"Model not found, {name}") from None
    # Constructed outside the try so a KeyError raised by a constructor is
    # not misreported as "Model not found".
    return model_cls(*other_args, **kwargs)


def get_models_MTL(args: dict, *other_args, **kwargs) -> MultitaskLearning:
    """Instantiate the multitask-learning model named by ``args['model']``.

    Args:
        args: Parsed options; only the ``'model'`` key is read here.
        *other_args, **kwargs: Forwarded unchanged to the model constructor.

    Returns:
        A new instance of the selected model class.

    Raises:
        KeyError: if ``args`` has no ``'model'`` entry.
        ValueError: if the model is disabled (``'mtan'``, ``'dynamicdistmtl'``)
            or unknown.
    """
    name = args['model']

    # Models that are imported but deliberately disabled (MTAN and
    # DynamicDistMTL), mapped to the error raised when they are requested.
    disabled = {
        'mtan': "Model not available, debatable?",
        'dynamicdistmtl': "Model not working, need to do the rename to LwP as well as fix the usability dynamic size of input",
    }
    if name in disabled:
        raise ValueError(disabled[name])

    # Built per call (not at module level) so the mapping is only evaluated
    # when a model is actually requested.
    factories = {
        'mtl': MTL,
        'pcgrad': PCGradModel,
        'imtl': IMTLModel,
        'nashmtl': NashMTLModel,
    }
    try:
        model_cls = factories[name]
    except KeyError:
        raise ValueError(f"Model not found, {name}") from None
    # Constructed outside the try so a KeyError raised by a constructor is
    # not misreported as "Model not found".
    return model_cls(*other_args, **kwargs)
