import argparse
import math
import os.path
import pickle
import time
from collections import defaultdict
from copy import deepcopy
from typing import Sequence, Tuple

import numpy as np
import pandas as pd
import torchvision.models
import torchvision.transforms as transforms
import torch
from torch import nn
from torch.utils.data import DataLoader, RandomSampler, BatchSampler, WeightedRandomSampler

import cole as cl
from evaluator import Evaluator
from tiny_imagenet import get_split_tiny_imagenet

# 'cole' is my small Continual Learning library; this sets the directory where it looks for datasets. It only
# matters for CIFAR100, and if the data is not found there it is downloaded anyway.
cl.set_data_path('../data')


def get_task_sequence(sequence_name: str, seed: int, joint: bool = False, select_task: int = None) \
        -> Tuple[Sequence[str], Sequence[Sequence[int]]]:
    """
    Get the task sequence from a file in ./task_sequences, identified by its name and seed.

    Every non-empty line of ``task_sequences/{sequence_name}_{seed}.txt`` has the form
    ``<dataset>,<class_id>,<class_id>,...``. Blank lines (e.g. a trailing empty line) are ignored.

    :param sequence_name: name of the sequence (file name without the seed suffix and extension).
    :param seed: seed that is part of the file name.
    :param joint: if True, task i contains the classes of all tasks up to and including task i (cumulative).
    :param select_task: if given (int or int-like string), only that task is returned, preceded by an empty dummy
        task.
    :return: a sequence with the dataset names and a sequence of the same length with the class ids per task.
    """
    path = os.path.join('task_sequences', f"{sequence_name}_{seed}.txt")
    ds, tasks = [], []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Without this, a blank line would become a task with dataset '' and no classes.
                continue
            elems = line.split(',')
            ds.append(elems[0])
            tasks.append([int(e) for e in elems[1:]])

    if joint:
        for i in range(1, len(tasks)):
            # tasks[i - 1] is already cumulative, so appending task i gives all classes up to and including task i.
            tasks[i] = tasks[i - 1] + tasks[i]

    if select_task is not None:
        select_task = int(select_task)
        # There need to be two tasks. For scratch learning (where this option is used) the first one doesn't matter
        # in practice, but handling a single task everywhere else would be annoying.
        return [ds[select_task], ds[select_task]], [[], tasks[select_task]]
    return ds, tasks


def get_transforms(exp_settings):
    """
    Create the train and test augmentations. Crops are 64 x 64 for 'tiny' and 32 x 32 (CIFAR) otherwise.

    :param exp_settings: experiment settings; reads 'new_data' and 'aug' ('cropflip', 'simclr', anything else means
        no train augmentation).
    :return: (train_aug, test_aug). train_aug is None when no known augmentation is selected.
    """
    crop_size = (64, 64) if exp_settings['new_data'] == 'tiny' else (32, 32)
    # NOTE(review): these are CIFAR100 statistics; they are also used for tiny imagenet.
    mean, std = (0.5071, 0.4866, 0.4409), (0.2009, 0.1984, 0.2023)

    if exp_settings['aug'] == 'cropflip':
        # Normalizing before the crop makes the zero padding equal to the per-channel mean color.
        train_aug = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize(mean, std),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.RandomCrop(size=crop_size, padding=crop_size[0] // 8)])
    elif exp_settings['aug'] == 'simclr':
        # Normalize must come last. ColorJitter clamps float tensors to [0, 1], and its hue adjustment assumes RGB
        # values in [0, 1]. Grayscale conversion should also mix real RGB values. Applied to normalized tensors,
        # these would wipe out all negative values.
        train_aug = transforms.Compose([
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomResizedCrop(size=crop_size, scale=(0.2, 1.0), antialias=True),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2,
                                        hue=0.1)], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=crop_size[0] // 20 * 2 + 1,
                                        sigma=(0.1, 2.0))], p=0.5),
            transforms.Normalize(mean, std)])
    else:
        train_aug = None

    test_aug = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return train_aug, test_aug


def get_data(datasets: Sequence[str], task_labels: Sequence[Sequence[int]], train_transform=None, test_transform=None):
    """
    Load the train and test data for a task sequence.

    All tasks must come from a single dataset ('cifar10', 'cifar100' or 'tiny'). Mixing datasets (e.g. task one from
    cifar100, task two from tiny imagenet) was the original idea but was never implemented or needed. The test data
    is loaded in joint mode.

    :param datasets: dataset name per task.
    :param task_labels: class ids per task.
    :param train_transform: augmentation for the training data (not passed on for cifar10).
    :param test_transform: transform for the test data (not passed on for cifar10).
    :return: (train_ds, test_ds) as provided by the dataset loaders.
    :raises NotImplementedError: for mixed datasets or an unknown dataset name.
    """
    if len(set(datasets)) != 1:
        raise NotImplementedError('Multiple datasets not implemented.')

    name = datasets[0]
    if name == 'cifar10':
        # NOTE(review): the transforms are not passed for cifar10; confirm whether that is intended.
        train_split = cl.get_split_cifar10(task_labels=task_labels)
        test_split = cl.get_split_cifar10(task_labels=task_labels, joint=True)
        return train_split.train, test_split.test
    if name == 'cifar100':
        transform_kwargs = dict(train_transform=train_transform, test_transform=test_transform)
        train_split = cl.get_split_cifar100(task_labels=task_labels, **transform_kwargs)
        test_split = cl.get_split_cifar100(task_labels=task_labels, joint=True, **transform_kwargs)
        return train_split.train, test_split.test
    if name == 'tiny':
        split = get_split_tiny_imagenet(task_labels=task_labels, train_transform=train_transform,
                                        test_transform=test_transform, joint_test=True)
        return split.train, split.test
    raise NotImplementedError(f'Dataset {datasets[0]} not implemented.')


class IndexXYDataset(cl.XYDataset):
    """
    XYDataset variant whose items are ``(x, y, index)`` instead of ``(x, y)``.

    The extra element is the index of the sample within the dataset. It is used to log per-sample predictions
    during training.
    """

    def __init__(self, x, y, **kwargs):
        # Nothing extra to set up: everything is forwarded to cl.XYDataset.
        super(IndexXYDataset, self).__init__(x, y, **kwargs)

    def __getitem__(self, item):
        sample, label = super(IndexXYDataset, self).__getitem__(item)
        return sample, label, item


def to_indexed_dataset(ds: cl.XYDataset) -> IndexXYDataset:
    """Build an IndexXYDataset from ``ds.x``, ``ds.y`` and ``ds.transform``, so its items also carry their index."""
    indexed = IndexXYDataset(ds.x, ds.y, transform=ds.transform)
    return indexed


def get_num_classes(exp_settings):
    """
    Return the total number of classes of the dataset named by exp_settings['new_data'].

    :raises ValueError: if the dataset name is unknown.
    """
    class_counts = {'cifar10': 10, 'cifar100': 100, 'tiny': 200}
    data_name = exp_settings['new_data']
    if data_name not in class_counts:
        raise ValueError(f'Unknown data option {exp_settings["new_data"]}.')
    return class_counts[data_name]


def get_model(model_name: str, exp_settings, device='cuda', num_classes: int = 10):
    """
    Build a freshly initialized base model on ``device``.

    'resnet18' is torchvision's ResNet-18. 'my_resnet18' is the reduced ResNet-18 from cole that is typically used on
    cifar100 data; its input size is 3 x 64 x 64 for 'tiny' and 3 x 32 x 32 otherwise.

    :param model_name: 'resnet18' or 'my_resnet18'.
    :param exp_settings: experiment settings; 'new_data' determines the input size of 'my_resnet18'.
    :param device: device to move the model to.
    :param num_classes: number of outputs of the classification head (used by both models).
    :raises NotImplementedError: for an unknown model name.
    """
    if model_name == 'resnet18':
        # num_classes used to be ignored here, which gave torchvision's default 1000-way head.
        # NOTE(review): checkpoints stored with the old 1000-way head can no longer be loaded into this model.
        model = torchvision.models.resnet18(num_classes=num_classes).to(device)
    elif model_name == 'my_resnet18':
        input_size = [3, 64, 64] if exp_settings['new_data'] == 'tiny' else [3, 32, 32]
        model = cl.get_resnet18(num_classes, input_size=input_size).to(device)
    else:
        raise NotImplementedError(f"Model {model_name} not implemented")

    return model


def get_initialization(model, train_data, test_data, exp_settings, log_dir, device='cuda'):
    """
    Load (or create) the initialization of the model, based on the 'init_model' hyperparameter.

    - 'scratch': train from random initialization; the model is returned unchanged.
    - 'base': this experiment trains a model on the first task. This is mainly for efficiency and consistency, as
      we're not really interested in exactly how good the model is on the first task. After the first run, the model
      is stored in results/<exp>/models. If the same experiment is re-run, that model is loaded instead. On the first
      run, this also stores the random initialization (``*_init.pth``, needed e.g. for l2_init) and the per-sample
      predictions while training the first task (``*_pred.pkl``, used later by the sampling methods).
    - anything else: an experiment id whose stored first-task model is loaded. This means that different experiments
      don't have the noise of different models trained on the first task.

    :param model: freshly initialized model.
    :param train_data: first-task training data (only used when a 'base' model still has to be trained).
    :param test_data: first-task test data (only used when a 'base' model still has to be trained).
    :param exp_settings: experiment settings row (its name is the experiment id).
    :param log_dir: currently unused.
    :param device: training device, also used as ``map_location`` when loading checkpoints.
    :return: the initialized model.
    """
    model_name = f"{exp_settings['model']}_{exp_settings['tasks_file']}_{exp_settings['seed']}.pth"
    stem = model_name[:-4]  # model_name without the '.pth' extension

    if exp_settings['init_model'] == 'scratch':
        return model

    if exp_settings['init_model'] == 'base':
        models_dir = os.path.join('results', str(exp_settings.name), 'models')
        model_path = os.path.join(models_dir, model_name)

        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
        else:
            # Store the random initialization before training; it might be needed in l2_init experiments.
            torch.save(model.state_dict(), os.path.join(models_dir, stem + "_init.pth"))
            model, train_pred = train_init_model(train_data, test_data, model, exp_settings, device=device)
            torch.save(model.state_dict(), model_path)
            # Store the predictions recorded during training. Not really neat to have this here, but it works.
            with open(os.path.join(models_dir, f"{stem}_pred.pkl"), 'wb') as f:
                pickle.dump(train_pred, f)
    else:
        model_path = os.path.join('results', str(exp_settings['init_model']), 'models', model_name)
        # map_location (as in the 'base' branch) so checkpoints saved on a GPU also load on other devices.
        model.load_state_dict(torch.load(model_path, map_location=device))

    return model


def get_original_init(model, exp_settings, device=None):
    """
    Load the stored random initialization weights into ``model`` and return it.

    The weights are read from results/<exp>/models/<model>_<tasks_file>_<seed>_init.pth. Here <exp> is this
    experiment for init_model 'base' or 'scratch', and the referenced experiment otherwise. They are stored
    nowadays (see get_initialization), but for earlier experiments they weren't.

    :param model: model with the same architecture as the stored one; modified in place.
    :param exp_settings: experiment settings row.
    :param device: ``map_location`` for torch.load. Optional: None restores tensors to the device they were saved
        from, so the function can also be called with only a model and the settings.
    :return: ``model`` with the original initialization loaded.
    """
    model_name = f"{exp_settings['model']}_{exp_settings['tasks_file']}_{exp_settings['seed']}_init.pth"
    init_model_exp = str(exp_settings.name) if exp_settings['init_model'] in ['base', 'scratch'] \
        else str(exp_settings['init_model'])
    init_path = os.path.join('results', init_model_exp, 'models', model_name)
    model.load_state_dict(torch.load(init_path, map_location=device))
    return model


def get_num_epochs(exp_settings, train_set, current_task=0):
    """
    Return the number of epochs to train on ``train_set``, mostly for configuring the learning rate scheduler.

    If 'epochs' is set, it is used directly. Otherwise the count is derived from 'max_iters' and the number of
    new-data batches per epoch, both rounded up.
    """
    new_batch_size = get_batch_sizes(exp_settings, current_task)[1]
    if exp_settings['epochs'] == 'na':
        max_iters = int(exp_settings['max_iters'])
        return math.ceil(max_iters / math.ceil(len(train_set) / new_batch_size))
    return int(exp_settings['epochs'])


def get_classes_per_task(exp_settings):
    """
    Return the number of classes per task: the base task first, followed by every new task.

    'num_classes_new' is a comma-separated list such as '10,10'. If a CSV column only holds single values, pandas
    parses it as a number, so it is converted to a string before splitting.
    """
    base_classes = int(exp_settings['num_classes_base'])
    new_classes = [int(c) for c in str(exp_settings['num_classes_new']).split(',')]
    return [base_classes, *new_classes]


def get_batch_sizes(exp_settings, current_task=None, full_batch_size=128):
    """
    Get the number of old and new samples per batch when training on a new task.

    The sizes are either given explicitly ('batch_size_old' / 'batch_size_new'), or, if either value is -1, computed
    so that the batch is balanced. In that case the old share of ``full_batch_size`` equals the share of classes
    seen so far that belong to old tasks.

    :param current_task: index of the new task (0 = first task after the base task). Only used for balanced batches;
        None is treated as 0.
    :param full_batch_size: total batch size used for balanced batches.
    :return: (old_batch_size, new_batch_size) as ints.
    """

    def is_auto(value):
        # Accept -1 as a number or as a string. Pandas yields strings for columns that also contain text (e.g. 'na'),
        # and '-1' == -1 is False.
        try:
            return float(value) == -1
        except (TypeError, ValueError):
            return False

    bso, bsn = exp_settings['batch_size_old'], exp_settings['batch_size_new']

    if is_auto(bso) or is_auto(bsn):
        task = 0 if current_task is None else current_task
        classes_per_task = get_classes_per_task(exp_settings)
        alpha = sum(classes_per_task[:task + 1]) / sum(classes_per_task[:task + 2])
        # floor + ceil of complementary shares add up to full_batch_size (up to float rounding).
        bso, bsn = math.floor(alpha * full_batch_size), math.ceil((1 - alpha) * full_batch_size)

    return int(bso), int(bsn)


def train_init_model(train_data, test_data, model, exp_settings, device):
    """
    Train the model on the first task of a sequence, recording the per-sample predictions during training.

    Train and test results are only printed to standard out (condor .out file) after every epoch; they haven't
    really been needed so far.

    :param train_data: training data of the first task. It is wrapped with to_indexed_dataset so that every batch
        also carries the sample indices.
    :param test_data: test data, only used for the per-epoch printout.
    :param exp_settings: reads 'init_optim', 'init_sched', 'init_epochs', 'init_regularizer' and 'init_reg_strength'.
    :return: (model, train_acc_track). train_acc_track maps a sample index to a list with one entry per epoch: 1 if
        the sample was classified correctly in its training step, 0 otherwise.
    """
    model.to(device)

    train_loader = DataLoader(to_indexed_dataset(train_data), batch_size=128, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_data, batch_size=128, num_workers=4)
    num_epochs = int(exp_settings['init_epochs'])
    learning_rate = 0.001 if exp_settings['init_optim'] == 'adam' else 0.1
    optim, sched = get_optim_and_scheduler(model, exp_settings['init_optim'], exp_settings['init_sched'],
                                           num_epochs, learning_rate)
    regularizer = Regularizer(exp_settings['init_regularizer'], exp_settings['init_reg_strength'], exp_settings, device)
    train_acc_track = defaultdict(list)

    for e in range(num_epochs):
        print(e, flush=True)
        # The evaluation at the end of the previous epoch may have left the model in eval mode (NOTE(review): depends
        # on cl.test), so put it back into training mode. This is a no-op if it already is.
        model.train()
        for data, target, idx in train_loader:
            data, target = data.to(device), target.to(device)
            optim.zero_grad()
            output = model(data)
            loss = torch.nn.functional.cross_entropy(output, target)
            loss = loss + regularizer(model)
            loss.backward()
            optim.step()

            # One device->host transfer per batch instead of an .item() synchronization per sample.
            correct = (torch.argmax(output, dim=1) == target).tolist()
            for sample_idx, is_correct in zip(idx.tolist(), correct):
                train_acc_track[sample_idx].append(int(is_correct))

        # cl.test does not work with an indexed dataset, so evaluate the train data through a plain loader.
        temp_train_loader = DataLoader(train_data, batch_size=128, shuffle=True, num_workers=4)
        cl.test(model, [temp_train_loader], avg=False, print_result=True, device=device)
        cl.test(model, [test_loader], avg=False, print_result=True, device=device)
        del temp_train_loader

        sched.step()

    return model, train_acc_track


def init_logging(exp_settings):
    """
    Create the log directory for this run and mark the run as unfinished.

    Creates results/<exp_id>/<seed>_<MMDD_HHMMSS>/ and results/<exp_id>/models/ (if they don't exist yet). Also
    creates an empty 'running.txt' in the run directory, which indicates that the experiment is not finished (or
    crashed).

    :param exp_settings: experiment settings row; its name is the experiment id and it must contain 'seed'.
    :return: path of the run directory.
    """
    exp_root = os.path.join('results', str(exp_settings.name))
    stamp = time.strftime('%m%d_%H%M%S')
    run_dir = os.path.join(exp_root, f'{exp_settings["seed"]}_{stamp}')

    for needed_dir in (run_dir, os.path.join(exp_root, 'models')):
        os.makedirs(needed_dir, exist_ok=True)

    marker_path = os.path.join(run_dir, 'running.txt')
    with open(marker_path, 'w'):
        pass
    return run_dir


def create_buffer(dataset, exp_setting, batch_size):
    """
    Create the loader over the old data (the "buffer").

    Its sampler draws with replacement and is very long, so it is never exhausted before the loader over the new
    data. (It can't be set to infinity.)

    If 'easy_cutoff' is not 'na', samples are weighted by how easy they were while training the first task (ranking
    from get_sample_order). The easiest 'easy_cutoff' fraction and the hardest 'hard_cutoff' fraction get weight
    'cutoff_prob'; all others get weight 1. Otherwise all samples are drawn uniformly.

    :param dataset: the old data.
    :param exp_setting: experiment settings.
    :param batch_size: number of old samples per batch; 0 means no buffer.
    :return: a DataLoader yielding batches of ``batch_size`` old samples, or None if ``batch_size`` is 0.
    """
    if int(batch_size) == 0:
        return None

    if exp_setting['easy_cutoff'] != 'na':
        sample_order = get_sample_order(exp_setting)

        num_samples = len(dataset)
        # Samples that do not appear in sample_order keep weight 0, i.e. they are never drawn.
        sampler_weight = torch.zeros(num_samples)
        lower_limit = float(exp_setting['easy_cutoff']) * num_samples
        upper_limit = (1 - float(exp_setting['hard_cutoff'])) * num_samples

        print(lower_limit, upper_limit, num_samples)

        for rank, sample_idx in enumerate(sample_order):
            # rank 0 is the easiest sample. Strict '<' so that only ranks below easy_cutoff * N are affected (none
            # for a cutoff of 0). This mirrors the hard side, where hard_cutoff 0 affects no sample.
            if rank < lower_limit or rank >= upper_limit:
                weight = float(exp_setting['cutoff_prob'])
            else:
                weight = 1.0
            sampler_weight[sample_idx] = weight

        # WeightedRandomSampler creates the full tensor of drawn indices at once, so use a finite number of samples.
        # 2^17 samples should be enough for cifar100, but maybe not for other datasets.
        sampler = WeightedRandomSampler(weights=sampler_weight, replacement=True, num_samples=2 ** 17)
    else:
        # Uniform sampling over all data.
        sampler = RandomSampler(dataset, replacement=True, num_samples=2 ** 31)

    batch_sampler = BatchSampler(sampler, batch_size=int(batch_size), drop_last=False)
    buffer_loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)
    return buffer_loader


def get_sample_order(exp_settings):
    """
    Rank the training samples of the current 'init' model from easy to hard.

    The per-sample predictions are read from results/<exp>/models/<model>_<tasks_file>_<seed>_pred.pkl. Here <exp>
    is this experiment for init_model 'base' and the referenced experiment otherwise.

    :param exp_settings: experiment settings row.
    :return: np.ndarray of positions in the sorted sample keys, ordered from easy (highest mean training accuracy)
        to hard. NOTE(review): these equal sample indices only if the stored keys are 0..N-1.
    """
    init_model = exp_settings['init_model']
    source_exp = exp_settings.name if init_model == 'base' else init_model
    file_name = f"{exp_settings['model']}_{exp_settings['tasks_file']}_{exp_settings['seed']}_pred.pkl"
    score_path = os.path.join('results', str(source_exp), 'models', file_name)

    print(score_path)

    # Default dict {sample_idx: [0, 0, 1, ..., 1]}, with 1 meaning classified correctly in that epoch.
    # Pickle: only load files this project wrote itself (get_initialization), never untrusted ones.
    with open(score_path, 'rb') as handle:
        per_sample_hits = pickle.load(handle)

    mean_accuracy = np.array([np.mean(per_sample_hits[key]) for key in sorted(per_sample_hits)])
    ascending = np.argsort(mean_accuracy)
    # Reverse so that high scores (easy samples) come first.
    return ascending[::-1]


def before_training(model: torch.nn.Module, exp_settings, device):
    """
    Prepare the model before training on new data, depending on the 'init_method' setting.

    - 'shrink_perturb': shrink and perturb towards a freshly initialized random model ('shrink' / 'perturb').
    - 'interpolate': weighted sum of the model (weight 'shrink') and the stored original random initialization
      (weight 1 - 'shrink').
    - 'reinit_last': re-initialize the classification head.
    - 'scratch': replace the model with a freshly initialized one.
    Any other value leaves the model unchanged.

    :return: the model to train (a new object for 'scratch').
    """
    method = exp_settings['init_method']
    if method == 'shrink_perturb':
        random_model = get_model(exp_settings['model'], exp_settings, device, get_num_classes(exp_settings))
        model = shrink_and_perturb(model, random_model, float(exp_settings['shrink']), float(exp_settings['perturb']))
    elif method == 'interpolate':
        random_model = get_model(exp_settings['model'], exp_settings, device, get_num_classes(exp_settings))
        # get_original_init requires the device; it was previously omitted here, which raised a TypeError.
        random_model = get_original_init(random_model, exp_settings, device)
        alpha = float(exp_settings['shrink'])
        model = shrink_and_perturb(model, random_model, shrink=alpha, perturb=(1 - alpha))
    elif method == 'reinit_last':
        # The original code expects a head called 'linear'; torchvision's ResNet calls it 'fc'.
        head = model.linear if hasattr(model, 'linear') else model.fc
        head.reset_parameters()
    elif method == 'scratch':
        model = get_model(exp_settings['model'], exp_settings, device, get_num_classes(exp_settings))
    return model


def shrink_and_perturb(train_model: torch.nn.Module, random_model: torch.nn.Module, shrink=0.4, perturb=10e-4):
    """
    Implements the shrink and perturb algorithm, modifying ``train_model`` in place.

    Every parameter that does not belong to a BatchNorm layer becomes ``shrink * param + perturb * random_param``.
    The random parameter is the one at the same position in ``random_model``, so both models must have the same
    architecture. BatchNorm2d running statistics are rescaled: the mean by ``shrink``, the variance by
    ``shrink ** 2``.

    :return: ``train_model`` (the same object).
    """
    # BatchNorm parameters are skipped. Besides the original name check ('bn' in the name), detect them by module
    # type: BN layers inside a downsample/shortcut Sequential have names like 'downsample.1.weight' without 'bn'.
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
    bn_param_ids = {id(p) for mod in train_model.modules() if isinstance(mod, bn_types)
                    for p in mod.parameters(recurse=False)}

    with torch.no_grad():
        for (name, real_param), rand_param in zip(train_model.named_parameters(), random_model.parameters()):
            if 'bn' in name or id(real_param) in bn_param_ids:
                continue
            # Same values as shrink * real + perturb * rand, computed in place without building an autograd graph.
            real_param.mul_(shrink).add_(rand_param, alpha=perturb)

        for mod in train_model.modules():
            # This works quite well and is principled, so I think it is best to leave it like this. (Presumably, with
            # the preceding weights shrunk, pre-BN activations scale by ~shrink: mean by shrink, variance by shrink^2.)
            if isinstance(mod, nn.BatchNorm2d):
                if mod.running_mean is not None:
                    mod.running_mean.mul_(shrink)
                if mod.running_var is not None:
                    mod.running_var.mul_(shrink * shrink)
    return train_model


def get_optim_and_scheduler(model, optim, sched, total_epochs, lr):
    """
    Create the optimizer and the learning rate scheduler (stepped once per epoch by the training loops in this file).

    :param model: model whose parameters are optimized.
    :param optim: 'sgd' (momentum 0.9, weight decay 5e-4) or 'adam' (AdamW with its default weight decay).
    :param sched: 'none' (constant lr), 'multistep' (x0.2 at 60/80/90% of the epochs), 'debug' (x0.2 at epochs
        60/120/160), 'restart' (cosine with warm restarts every quarter of the epochs) or 'cos' (cosine annealing).
    :param total_epochs: number of epochs the schedule spans.
    :param lr: initial learning rate; converted with float() because it may come from the CSV as a string.
    :return: (optimizer, scheduler).
    :raises ValueError: for an unknown optimizer or scheduler option.
    """
    lr = float(lr)

    if optim == 'sgd':
        # Weight decay in SGD was only added after experiment 200. Most of the time AdamW is used anyway, which always
        # has weight decay.
        optim = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    elif optim == 'adam':
        optim = torch.optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError(f"Unknown option {optim} for optimizer")

    if sched == 'none':
        # Dummy scheduler: gamma 1.0 never changes the learning rate.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=[total_epochs], gamma=1.0)
    elif sched == 'multistep':
        milestones = [int(frac * total_epochs) for frac in [0.6, 0.8, 0.9]]
        print(f"Approx. epochs {total_epochs} \t"
              f"Milestones scheduled at epoch {','.join([str(s) for s in milestones])}.")
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=milestones, gamma=0.2)
    elif sched == 'debug':
        # Scheduler to try out stuff without having to make a different option all the time.
        milestones = [60, 120, 160]
        print(f"Approx. epochs {total_epochs} \t"
              f"Milestones scheduled at epoch {','.join([str(s) for s in milestones])}.")
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optim, milestones=milestones, gamma=0.2)
    elif sched == 'restart':
        # Restart every quarter of the training; rarely used, so not a parameter. CosineAnnealingWarmRestarts
        # rejects T_0 < 1, which total_epochs // 4 would give for fewer than 4 epochs.
        restart_epoch = max(1, total_epochs // 4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optim, T_0=restart_epoch, eta_min=1e-6)
    elif sched == 'cos':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=total_epochs,
                                                               eta_min=1e-6)
    else:
        raise ValueError(f"Unknown option {sched} for scheduler")

    return optim, scheduler


def reweigh_loss(loss, exp_settings, current_task):
    """
    Combine the per-sample losses of a mixed batch into one balanced loss.

    The batch is laid out as [new samples, old (buffer) samples], as built in train_new_data. The two parts are
    averaged separately; the new part gets weight (1 - alpha) and the old part weight alpha. The final loss thus
    stays balanced even if a batch holds relatively more or fewer new samples. With 'loss_ratio_old_new' ==
    'balanced', alpha is the share of classes seen so far that belong to old tasks; otherwise it is the given float.
    Balanced was never worse than a manual alpha, so it should just be kept like that.

    :param loss: 1-D tensor of per-sample losses (cross entropy with reduction='none').
    :return: scalar loss tensor.
    """
    alpha = exp_settings['loss_ratio_old_new']
    old_batch_size, _ = get_batch_sizes(exp_settings, current_task)
    # The buffer's sampler is very long, so its batches practically always contain old_batch_size samples. The last
    # new-data batch of an epoch can be smaller, though, so locate the split from the end.
    split = max(len(loss) - old_batch_size, 0)

    if alpha == 'balanced':
        classes_per_task = get_classes_per_task(exp_settings)
        alpha = sum(classes_per_task[:current_task + 1]) / sum(classes_per_task[:current_task + 2])
    else:
        alpha = float(alpha)

    new_loss, old_loss = loss[:split], loss[split:]
    # The mean of an empty tensor is NaN, so fall back to the other part when one part is empty (e.g. no buffer).
    if len(old_loss) == 0:
        return torch.mean(new_loss)
    if len(new_loss) == 0:
        return torch.mean(old_loss)
    return (1 - alpha) * torch.mean(new_loss) + alpha * torch.mean(old_loss)


class Regularizer:
    """
    Implements the regularization: l2-loss and l2-init-loss.

    Calling an instance with a model returns ``strength * reg_loss``:
    - 'na': no regularization; returns 0.0, so the regularizer can always be called during training.
    - 'l2': Euclidean norm (not squared) of all parameters.
    - 'l2_init': Euclidean norm of the difference between the parameters and those of a reference model (the
      initialization). Keeping that reference model around is why this is a class.
    """

    def __init__(self, regularizer, strength, exp_settings, device=None, model=None):
        """
        :param regularizer: 'na', 'l2' or 'l2_init'.
        :param strength: multiplier of the regularization term; ignored (0.0) for 'na'.
        :param exp_settings: experiment settings. For 'l2_init' without ``model``, they are used to load the stored
            original initialization.
        :param device: device for that loaded initialization.
        :param model: reference model for 'l2_init'. If None, the stored original initialization is loaded.
        """
        self.regularizer = regularizer
        self.strength = float(strength) if self.regularizer != 'na' else 0.0

        if self.regularizer == 'l2_init':
            if model is None:
                model = get_model(exp_settings['model'], exp_settings, device, get_num_classes(exp_settings))
                model = get_original_init(model, exp_settings, device)
            self.model_zero = model
        else:
            self.model_zero = None

    def __call__(self, model: nn.Module):
        if self.regularizer == 'na':
            reg_loss = 0.0
        elif self.regularizer == 'l2':
            reg_loss = torch.norm(torch.cat([p.reshape(-1) for p in model.parameters()]))
        elif self.regularizer == 'l2_init':
            # Detach the reference weights so backward() does not accumulate gradients into model_zero.
            reg_loss = torch.norm(torch.cat([(p - p_init.detach()).reshape(-1) for p, p_init
                                             in zip(model.parameters(), self.model_zero.parameters())]))
        else:
            raise ValueError(f'Unknown regularization option {self.regularizer}')

        return self.strength * reg_loss


def train_new_data(new_data, old_loader, model, optim, scheduler, batch_size, exp_settings, evaluator: Evaluator,
                   test_ds, device, current_task=0):
    """
    Train the model on the new data. This is the heart of the code.

    Each step uses a batch of ``batch_size`` new samples. When ``old_loader`` is given, a batch of old samples is
    appended (new first, then old). The evaluator is called after every step. Training stops after 'max_iters'
    iterations (if set) or after the configured number of epochs. Once per epoch, the scheduler steps and the
    evaluator dumps its results.

    :param current_task: index of the new task, used to balance the loss.
    """
    new_loader = DataLoader(new_data, batch_size=batch_size, shuffle=True, num_workers=4)
    # This is a bit of a hack: when training from scratch, the initialization isn't stored, so we take it here.
    zero_model = deepcopy(model) if exp_settings['init_model'] == 'scratch' else None
    regularizer = Regularizer(exp_settings['regularizer'], exp_settings['reg_strength'], exp_settings, device,
                              zero_model)

    # Training datasets the evaluator reports on. This is constant during training, so it is built once. Without an
    # old data loader there is no old dataset to report on.
    if exp_settings['init_model'] == 'scratch' or old_loader is None:
        train_dataset = [new_data]
    else:
        train_dataset = [new_data, old_loader.dataset]

    # Old code worked with epochs. When training by iterations, use a large epoch count so that 'max_iters' is
    # reached first.
    num_epochs = 1_000 if exp_settings['epochs'] == 'na' else int(exp_settings['epochs'])
    stop_training = False

    for e in range(num_epochs):
        print(e, flush=True)
        # Fresh iterator over the old data (if any) every epoch. Its sampler is very long, so it should not run out
        # before the new data loader does.
        old_data_iter = iter(old_loader) if old_loader is not None else None

        for data, target in new_loader:
            if old_data_iter is not None:
                buf_data, buf_target = next(old_data_iter)
                data = torch.cat([data, buf_data])
                target = torch.cat([target, buf_target])

            data, target = data.to(device), target.to(device)
            # The evaluator may leave the model in eval mode (NOTE(review): depends on Evaluator.evaluate), so
            # re-enable training mode before every step. This is a no-op if the model is already training.
            model.train()
            optim.zero_grad()

            output = model(data)

            if exp_settings['balance']:
                loss = torch.nn.functional.cross_entropy(output, target, reduction='none')
                loss = reweigh_loss(loss, exp_settings, current_task)
            else:
                loss = torch.nn.functional.cross_entropy(output, target)

            loss = loss + regularizer(model)

            loss.backward()
            optim.step()

            # The evaluator keeps track of the current iteration, so it must be called every iteration. It only
            # evaluates when a test is scheduled for this iteration, but it always increases the iteration count.
            evaluator.evaluate(model, test_ds, train_dataset, device=device)
            if not exp_settings['max_iters'] == 'na' and evaluator.current_iter >= int(exp_settings['max_iters']):
                stop_training = True
                break

        scheduler.step()
        evaluator.dump_results()

        if stop_training:
            break


def run_exp(exp_id: int) -> None:
    """
    Run the experiment with the given exp_id, as described in experiments.csv.

    Steps: create the log directory (with a 'running.txt' marker), load the task sequence and data, and initialize
    the model (possibly by training or loading a first-task model). Then, for every following task: prepare the
    model, create the old-data buffer, optimizer and scheduler, train on the new data, and evaluate. The marker file
    is removed only when everything finished.

    :param exp_id: row index ('exp_id' column) in experiments.csv.
    """
    device = 'cuda'

    exp_file = pd.read_csv('experiments.csv', index_col='exp_id', dtype={'balance': bool})
    exp_settings = exp_file.loc[exp_id]
    log_dir = init_logging(exp_settings)

    print(exp_settings)

    # Prepare the data. The task sequence is cumulative (joint) when training from scratch.
    select_task = exp_settings['task'] if exp_settings['task'] != 'na' else None
    ds_names, task_labels = get_task_sequence(exp_settings['tasks_file'], exp_settings['seed'], select_task=select_task,
                                              joint=exp_settings['init_model'] == 'scratch')
    # Temporarily set joint to False here if the sampling on the old data does not work properly:
    # ds_names, task_labels = get_task_sequence(exp_settings['tasks_file'], exp_settings['seed'], select_task=select_task,
    #                                           joint=False)
    train_transform, test_transform = get_transforms(exp_settings)
    train_ds, test_ds = get_data(ds_names, task_labels, train_transform, test_transform)

    # Get the model initialization (from scratch, trained on / loaded for the first task).
    model = get_model(exp_settings['model'], exp_settings, device, get_num_classes(exp_settings))
    model = get_initialization(model, train_ds[0], test_ds[0], exp_settings, log_dir, device)

    # Init evaluator.
    evaluator = Evaluator(log_dir, result_name='results', num_classes=get_num_classes(exp_settings), frequency='dense',
                          checkpoints='none')

    # Loop over the new tasks (everything after the first one). Only the smart sampling doesn't work yet with multiple
    # tasks, because the predictions are not stored when training on new data.
    for t, (tl, tds) in enumerate(zip(task_labels[1:], train_ds[1:])):
        print(tl)
        model = before_training(model, exp_settings, device)
        evaluator.evaluate(model, test_ds, train_ds, device=device, force_eval=True)

        # Create the buffer over all previous tasks. This must come after get_initialization: for init_model 'base',
        # that call writes the prediction file that create_buffer reads (via get_sample_order) when 'easy_cutoff'
        # is set.
        # TODO: this does not work when we need to sample from more than one task.
        old_batch_size, new_batch_size = get_batch_sizes(exp_settings, current_task=t, full_batch_size=128)
        old_loader = create_buffer(torch.utils.data.ConcatDataset(train_ds[:t + 1]),
                                   exp_settings, batch_size=old_batch_size)

        # Get optim and scheduler. Re-initialized for every task.
        num_epochs = get_num_epochs(exp_settings, tds, current_task=t)
        optim, scheduler = get_optim_and_scheduler(model, exp_settings['optim'], exp_settings['sched'], num_epochs,
                                                   exp_settings['lr'])

        # Actual training. Pass the current task id (used for loss balancing).
        train_new_data(tds, old_loader, model, optim, scheduler, new_batch_size, exp_settings, evaluator, test_ds,
                       device, current_task=t)

        # Final evaluation for this task; the evaluator's iteration count restarts for the next task.
        evaluator.evaluate(model, test_ds, train_ds, device=device, force_eval=True)
        evaluator.dump_results()
        evaluator.reset_iter_count()

    # If we get here, the program finished successfully, and we can remove the marker file created by init_logging.
    os.remove(os.path.join(log_dir, 'running.txt'))


def main():
    """Parse the command line and run experiment ``exp_id`` ``--reps`` times in a row."""
    cli = argparse.ArgumentParser()
    cli.add_argument('exp_id', type=int, help='the exp_id from the experiments.csv file')
    cli.add_argument('-r', '--reps', type=int, default=1, help='number of repetitions to run')
    options = cli.parse_args()

    repetition = 0
    while repetition < options.reps:
        run_exp(options.exp_id)
        repetition += 1


if __name__ == '__main__':
    # Only run experiments when executed as a script, not when this module is imported.
    main()
