import EWC
import ShiftingWindowSetting
import AGEM
import PackNet
import PrevL2
import DeepCCG
import ER
import EntropySS
import GSS
import StandardheadDeepCCG
import ReservoirDeepCCG
import argparse


def pair(arg):
    """Parse a comma-separated pair of integers, for use as an argparse ``type``.

    Args:
        arg: A string such as ``"42,1"`` (whitespace around each number is
            tolerated). Non-string iterables are passed through ``tuple()``
            unchanged, preserving the previous behaviour for values that are
            already parsed (e.g. a tuple default).

    Returns:
        A 2-tuple of ints when ``arg`` is a string, otherwise ``tuple(arg)``.

    Raises:
        argparse.ArgumentTypeError: If a string ``arg`` does not consist of
            exactly two comma-separated integers. argparse turns this into a
            clean usage error instead of a traceback.
    """
    if not isinstance(arg, str):
        # Already a sequence (e.g. a non-string default); keep old behaviour.
        return tuple(arg)
    message = 'expected two integers separated by a comma, got {!r}'.format(arg)
    parts = arg.split(',')
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(message)
    try:
        # int() already ignores surrounding whitespace, so " 3 , 4 " works.
        return tuple(int(part) for part in parts)
    except ValueError as err:
        raise argparse.ArgumentTypeError(message) from err


def parse_arguments(argv=None):
    """Build the experiment command-line interface and parse the options.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.
        ``seed`` is always a 2-tuple of ints, whether it came from the
        default or from ``--seed A B``.
    """
    parser = argparse.ArgumentParser(description='Argument parser')
    parser.add_argument('--setting', default='shifting_window', type=str, help='Cl Setting')
    parser.add_argument('--algo', default='SGD', type=str, help='The algo to run')
    parser.add_argument('--dataset', default='CIFAR100', type=str, help='The dataset to be used')
    parser.add_argument('--mem_size', default=1000, type=int,
                        help='the number of samples allowed to be stored in memory')
    parser.add_argument('--batch_size', default=10, type=int, help='the size of data batches given to the method')
    parser.add_argument('--replay_batch_size', default=10, type=int,
                        help='the number of samples to be replayed (for replay methods)')
    # float rather than int so fractional coefficients (e.g. 0.5) are accepted;
    # the default stays the int 1 so default runs are unchanged.
    parser.add_argument('--reg_coef', default=1, type=float, help='the regularisation coefficient to use')
    parser.add_argument('--mask_rate', default=0.3, type=float, help='the mask rate if PackNet is the selected algo')
    parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('--seed', default=(42, 1), type=int, nargs=2,
                        help='the seed to be used, a pair of ints')
    parser.add_argument('--mem_samples_per_task', default=11, type=int, help='the number samples per task (for AGEM)')
    parser.add_argument('--epochs', default=1, type=int, help='the number of epochs per task')
    args = parser.parse_args(argv)
    # nargs=2 produces a list when --seed is given but the default is a tuple;
    # normalise so downstream code always sees the same type.
    args.seed = tuple(args.seed)
    return args


def run_experiment(args):
    """Construct the learning algorithm named by ``args.algo`` and run it
    under the continual-learning setting named by ``args.setting``.

    Args:
        args: Parsed CLI namespace (as produced by ``parse_arguments``). Its
            ``algo`` and ``setting`` fields select what to run. The relevant
            hyper-parameter fields are passed to the chosen constructor, and
            the whole namespace is forwarded as ``args=``.

    Raises:
        ValueError: If ``args.algo`` is unknown. The algorithm is checked and
            constructed before the setting is looked at. Also raised if
            ``args.setting`` is unknown.
    """
    # Factories are lambdas so only the selected module attribute is looked
    # up and only the selected algorithm is ever constructed.
    algo_factories = {
        'ER_reservoir': lambda: ER.ER_reservoir(
            args=args, mem_size=args.mem_size, replay_batch_size=args.replay_batch_size),
        'SGD': lambda: ShiftingWindowSetting.CLLearningAlgo(args=args),
        'EntropySS': lambda: EntropySS.EntropySS(
            args=args, mem_size=args.mem_size, replay_batch_size=args.replay_batch_size),
        'GSS': lambda: GSS.GSS(args=args, num_of_mem_samples=args.mem_size),
        'EWC_reg': lambda: EWC.EWC_reg(args=args, reg_coef=args.reg_coef),
        'EWC_constMemReg': lambda: EWC.EWC_constMemReg(args=args, reg_coef=args.reg_coef),
        'PackNet': lambda: PackNet.PackNet(args=args, mask_rate=args.mask_rate),
        'AGEM': lambda: AGEM.AGEM(args=args, mem_samples_per_task=args.mem_samples_per_task),
        'DeepCCG': lambda: DeepCCG.DeepCCG(
            args=args, mem_size=args.mem_size, mem_batch_size=args.replay_batch_size),
        'DeepCCG_reservoir': lambda: ReservoirDeepCCG.DeepCCG_reservoir(
            args=args, mem_size=args.mem_size, mem_batch_size=args.replay_batch_size),
        'DeepCCG_SH': lambda: StandardheadDeepCCG.DeepCCG_SH(
            args=args, mem_size=args.mem_size, mem_batch_size=args.replay_batch_size),
    }
    build_algo = algo_factories.get(args.algo)
    if build_algo is None:
        raise ValueError('Algo given is not supported')
    learning_algo = build_algo()

    # Map each setting name to the runner method it invokes on the algorithm.
    setting_runners = {
        'disjoint_tasks': 'run_disjoint_tasks_setting',
        'shifting_window': 'run_shifting_window_setting',
        'dummy_window': 'run_dummy_window_setting',
        'dummy_disjoint_tasks': 'run_dummy_disjoint_tasks_setting',
    }
    runner_name = setting_runners.get(args.setting)
    if runner_name is None:
        raise ValueError('Setting chosen is not supported')
    getattr(learning_algo, runner_name)()


if __name__ == '__main__':
    # Script entry point: parse the command line and hand the resulting
    # namespace straight to the experiment runner.
    run_experiment(parse_arguments())
