#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import os

# Select the CUDA device (adapt to your hardware). Devices are enumerated in
# PCI bus order; this stays above the torch/model imports so it is in place
# before CUDA gets initialised.
os.environ.update(CUDA_DEVICE_ORDER="PCI_BUS_ID", CUDA_VISIBLE_DEVICES="2")

from data.core50_data_loader import CORE50
import copy
import json
import random
from models.mobilenet import MyMobilenetV1
from models.pretrainedCVAE import PretrainedCVAE
from models import replace_bn_with_brn, change_brn_pars
from models.classifier import ClassifierBN
from models.generator_losses import DistillationLossWithClassifier, KLDLoss, ReconstructionLoss
from models.pretrained_classifier import get_pretrained_classifier_bn
import torch.nn.functional as F
from utils.utils import *
import configparser
import argparse
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter
from tsne.t_sne import T_SNE
import pandas as pd
import PIL.Image
from torchvision.transforms import ToTensor

# --------------------------------- Setup --------------------------------------

# recover exp configuration name (and, optionally, the cfg file that defines it)
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name',  default='DEFAULT',
                    help='name of the experiment you want to run.')
parser.add_argument('--config', dest='config_file', default='params_ar1_lgr_nic.cfg',
                    help='path of the cfg file holding the experiment sections.')
args = parser.parse_args()

# recover config file for the experiment.
# ConfigParser.read() silently skips files it cannot open and returns the list of
# files it did read, so check it explicitly instead of failing later with a KeyError.
config = configparser.ConfigParser()
if not config.read(args.config_file):
    parser.error("cannot read config file '{}'".format(args.config_file))
if args.exp_name not in config:
    parser.error("experiment '{}' not found in '{}'".format(args.exp_name, args.config_file))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# Recover every hyper-parameter from the experiment section of the cfg file.
# NOTE(review): each raw value goes through eval(), i.e. the cfg file is executed
# as Python code -- only ever load trusted configuration files.


def _cfg(key):
    """Evaluate the raw string stored under ``key`` in the experiment section.

    ``globals()`` is passed explicitly so the expression is evaluated against the
    module namespace, exactly as a module-level ``eval`` would.
    """
    return eval(exp_config[key], globals())


# ---- experiment / classifier parameters ----
exp_name = _cfg('exp_name')
comment = _cfg('comment')
use_cuda = _cfg('use_cuda')
init_lr = _cfg('init_lr')
inc_lr = _cfg('inc_lr')
mb_size = _cfg('mb_size')
init_train_ep_classifier = _cfg('init_train_ep_classifier')
inc_train_ep_classifier = _cfg('inc_train_ep_classifier')
init_update_rate = _cfg('init_update_rate')
inc_update_rate = _cfg('inc_update_rate')
max_r_max = _cfg('max_r_max')
max_d_max = _cfg('max_d_max')
inc_step = _cfg('inc_step')
rm_classificator_sz = _cfg('rm_classificator_sz')
momentum = _cfg('momentum')
l2 = _cfg('l2')
freeze_below_layer = _cfg('freeze_below_layer')
latent_layer_num = _cfg('latent_layer_num')
ewc_lambda = _cfg('ewc_lambda')

# ---- generator (CVAE) parameters ----
learning_rate_cvae = _cfg("learning_rate_cvae")
optimizer_cvae = _cfg("optimizer_cvae")
recon_loss_cvae = _cfg("recon_loss_cvae")
init_train_ep_generator = _cfg('init_train_ep_generator')
inc_train_ep_generator = _cfg('inc_train_ep_generator')
use_kld = _cfg("use_kld")
use_distillation = _cfg("use_distillation")
alpha = _cfg("alpha")
beta = _cfg("beta")
gamma = _cfg("gamma")
latent_size = _cfg("latent_size")
rm_generator_sz = _cfg("rm_generator_sz")

# TensorBoard writer: every run logs under logs_NIC/<exp_name>
log_dir = 'logs_NIC/' + exp_name
writer = SummaryWriter(log_dir=log_dir)

# Store the raw experiment section as JSON text so the hyper-parameters are kept
# next to the run's curves.
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)

# Global classifier iteration counter (x-axis of the per-iteration curves) and the
# two replay memories, which stay empty until the first batch has been learned.
tot_it_step_classifier, rm_classifier, rm_generator = 0, None, None

# Create the dataset object (CORe50, NICv2-391 scenario).
# TODO: point `root` to your local CORe50 copy.
dataset = CORE50(root='your/path/here', scenario="nicv2_391", run=0)
preproc = preprocess_imgs

# Get the fixed test set.
# np.long is missing in NumPy 1.24-1.26 and platform-dependent (C long) in NumPy 2;
# int64 is the dtype torch.long expects.
test_x_all, test_y_all = dataset.get_test_set()
test_y_all = test_y_all.astype(np.int64)

# ********************** CLASSIFIER ******************************
# Pretrained MobileNetV1 whose forward pass can also take/return latent activations
# at layer `latent_layer_num`.
classifier_model = MyMobilenetV1(pretrained=True, latent_layer_num=latent_layer_num)

# Replace the normalisation layers (presumably BatchNorm -> Batch Renormalization,
# judging by the helper's name) with the settings used for the first batch.
replace_bn_with_brn(
    classifier_model,
    momentum=init_update_rate,
    r_d_max_inc_step=inc_step,
    max_r_max=max_r_max,
    max_d_max=max_d_max,
)
freeze_up_to(classifier_model, freeze_below_layer, only_conv=False)

# AR1 bookkeeping: consolidated weights and per-class example counters (50 classes).
classifier_model.saved_weights = {}
classifier_model.past_j = dict.fromkeys(range(50), 0)
classifier_model.cur_j = dict.fromkeys(range(50), 0)

# Optional EWC-style regularisation state, only when its weight is non-zero.
if ewc_lambda != 0:
    ewcData, synData = create_syn_data(classifier_model)

classifier_optimizer = torch.optim.SGD(
    classifier_model.parameters(),
    lr=init_lr,
    momentum=momentum,
    weight_decay=l2,
)

# ********************** GENERATOR ******************************
# Conditional VAE (conditioned on one-hot labels over 50 classes) initialised from
# pretrained encoder/decoder checkpoints.
_pretrained_ae_dir = "../pretrained_models/pretrained_AE/"
generative_model = PretrainedCVAE(latent_size=latent_size, num_classes=50, pretrained_encoder=True)
generative_model.load_encoder(_pretrained_ae_dir + "pretrained_encoder_imagenet_imagenet_encoder.pth")
generative_model.load_decoder(_pretrained_ae_dir + "pretrained_decoder_imagenet_imagenet_encoder.pth")

# Adam only receives the parameters that are still trainable.
trainable_generator_params = [par for par in generative_model.parameters() if par.requires_grad]
generator_optimizer = torch.optim.Adam(trainable_generator_params, lr=learning_rate_cvae)

# Generator loss terms.
reconstruction_loss = ReconstructionLoss(recon_loss_cvae)
kld_loss = KLDLoss()
distillation_loss = DistillationLossWithClassifier()

# ********************************** UTILITIES ************************************************
ce_loss = torch.nn.CrossEntropyLoss()

# Classes met so far; `recalculate_test_set` is raised whenever new classes appear
# and consumed by the generator evaluation code.
seen_classes = set()
nr_seen_classes = 0
recalculate_test_set = False
tsne_image_set_x = []
tsne_image_set_y = []

# Confusion-matrix slot assigned to each of the 50 classes in order of first
# appearance (-1 = class not met yet), plus the next free slot.
cf_labels_inc = [-1] * 50
cf_count_inc = 0

# Set to True to run the (slow) generator evaluation -- sampling accuracy, test-set
# reconstruction accuracy and a classifier retrained on generated features -- after the
# last generator epoch of every batch. It used to be disabled via the unreachable
# condition `i == 10000`.
EVAL_GENERATOR = False


def _balanced_labels(n_samples, classes):
    """Return a shuffled LongTensor of ``n_samples`` labels spread evenly over ``classes``.

    Every class receives ``n_samples // len(classes)`` labels; the remaining
    ``n_samples % len(classes)`` labels go to distinct, randomly chosen classes.

    Args:
        n_samples: total number of labels to produce.
        classes: non-empty collection (e.g. a ``set``) of class ids.
    """
    per_class, remainder = divmod(n_samples, len(classes))
    labels = [c for c in classes for _ in range(per_class)]
    if remainder > 0:
        # random.sample() needs a sequence: sets are rejected since Python 3.11
        # (deprecated since 3.9). Sorting also makes the draw order-independent.
        labels.extend(random.sample(sorted(classes), remainder))
    random.shuffle(labels)
    return torch.from_numpy(np.asarray(labels, dtype=np.int64))


def _generate_replay_memory(model, n_samples, classes, use_cuda, batch_size=128):
    """Sample a class-balanced replay memory of latent features from the CVAE decoder.

    Latent codes are drawn from a standard normal, concatenated with the one-hot
    label (50 classes) and decoded in chunks of at most ``batch_size``.

    Args:
        model: generator exposing ``latent_size`` and ``decode``.
        n_samples: number of replay patterns to generate.
        classes: non-empty collection of class ids to balance over.
        use_cuda: whether the decoding runs on the GPU.
        batch_size: decoding chunk size.

    Returns:
        ``[features, labels]``: decoded features and their LongTensor labels, both on the CPU.
    """
    labels = _balanced_labels(n_samples, classes)
    model.eval()
    chunks = []
    for start in range(0, n_samples, batch_size):
        batch_dim = min(batch_size, n_samples - start)
        samples = np.random.multivariate_normal(torch.zeros(model.latent_size),
                                                torch.eye(model.latent_size), batch_dim)
        z = maybe_cuda(torch.from_numpy(samples).view(-1, model.latent_size).float(), use_cuda=use_cuda)
        y_one_hot = maybe_cuda(F.one_hot(labels[start:start + batch_dim], 50).float(), use_cuda=use_cuda)
        z = torch.cat((z, y_one_hot), dim=1)
        # the memory is a fixed replay target: no autograd graph is needed
        with torch.no_grad():
            chunks.append(model.decode(z).cpu())
    return [torch.cat(chunks, dim=0), labels]


# *************************************************************************************************************
# *************************************************************************************************************
# ***************************************  TRAINING PROCEDURE  ************************************************
# *************************************************************************************************************
# *************************************************************************************************************

# loop over the training incremental batches
for i, train_batch in enumerate(dataset):

    if ewc_lambda != 0:
        init_batch(classifier_model, ewcData, synData)

    # freeze below conv5_4/dw
    freeze_up_to(classifier_model, freeze_below_layer, only_conv=False)

    if i == 1:
        # first incremental batch: switch the normalisation layers to their incremental
        # settings and rebuild the optimizer (the output layer gets 10x the incremental lr)
        change_brn_pars(classifier_model, momentum=inc_update_rate, r_d_max_inc_step=0,
                        r_max=max_r_max, d_max=max_d_max)
        base_params = [par for name, par in classifier_model.named_parameters()
                       if name != 'output.weight']
        classifier_optimizer = torch.optim.SGD([
            {'params': base_params},
            {'params': classifier_model.output.parameters(), 'lr': inc_lr * 10.0}],
            lr=inc_lr, momentum=momentum, weight_decay=l2)

    train_x, train_y = train_batch
    train_x = preproc(train_x)
    # np.long is missing in NumPy 1.24-1.26; int64 is what torch.long expects
    train_y = train_y.astype(np.int64)

    seen_classes.update(train_y)
    if len(seen_classes) > nr_seen_classes:
        print("New classes here!")
        recalculate_test_set = True
        nr_seen_classes = len(seen_classes)

    # cur_class: classes in this batch plus those in the classifier replay memory
    # cur_class_batch: classes in this batch only
    if i == 0:
        cur_class = [int(o) for o in set(train_y)]
        cur_class_batch = [int(o) for o in set(train_y)]
        classifier_model.cur_j = examples_per_class(train_y)
    else:
        cur_class = [int(o) for o in set(train_y).union(set(rm_classifier[1].tolist()))]
        cur_class_batch = [int(o) for o in set(train_y)]
        classifier_model.cur_j = examples_per_class(list(train_y) + list(rm_classifier[1]))

    # give every class a confusion-matrix slot the first time it is met
    for cc in cur_class:
        if cf_labels_inc[cc] == -1:
            cf_labels_inc[cc] = cf_count_inc
            cf_count_inc += 1

    print("----------- batch {0} -------------".format(i))
    print("train_x shape: {}, train_y shape: {}"
          .format(train_x.shape, train_y.shape))

    # Transform the data
    (train_x, train_y), it_x_ep = pad_data([train_x, train_y], mb_size)
    shuffle_in_unison([train_x, train_y], in_place=True)
    train_x = torch.from_numpy(train_x).type(torch.FloatTensor)
    train_y = torch.from_numpy(train_y).type(torch.LongTensor)
    # NOTE: batch tensors are no longer appended to tsne_image_set_x / tsne_image_set_y:
    # nothing reads those lists, and the retained views kept every training batch in memory.

    # *************************************************************************************************************
    # *************************************** CLASSIFIER TRAINING *************************************************
    # *************************************************************************************************************
    print("************************ CLASSIFIER TRAINING ****************************************************")

    generative_model.eval()
    classifier_model.train()
    classifier_model.lat_features.eval()

    reset_weights(classifier_model, cur_class_batch)
    classifier_model = maybe_cuda(classifier_model, use_cuda=use_cuda)

    cur_ep_classifier = 0

    acc_classifier = None
    ave_loss_classifier = 0

    if i == 0:
        train_ep_classifier = init_train_ep_classifier
    else:
        train_ep_classifier = inc_train_ep_classifier

    for ep in range(train_ep_classifier):

        print("training ep: ", ep)
        correct_cnt, loss_sum_classifier = 0, 0

        if i > 0:
            # incremental batches: 64 new patterns per minibatch, the remaining
            # slots are filled from the classifier replay memory
            cur_sz = 64
            it_x_ep = train_x.size(0) // cur_sz
            n2inject = max(0, mb_size - cur_sz)
        else:
            n2inject = 0
        print("total sz:", train_x.size(0) + rm_classificator_sz)
        print("n2inject", n2inject)
        print("it x ep: ", it_x_ep)

        for it in range(it_x_ep):

            if ewc_lambda != 0:
                pre_update(classifier_model, synData)

            start = it * (mb_size - n2inject)
            end = (it + 1) * (mb_size - n2inject)

            classifier_optimizer.zero_grad()

            x_mb = maybe_cuda(train_x[start:end], use_cuda=use_cuda)

            if i == 0:
                lat_mb_x = None
                y_mb = maybe_cuda(train_y[start:end], use_cuda=use_cuda)
            else:
                # replayed patterns enter the network as latent activations;
                # mode='wrap' cycles through the memory when it is exhausted
                lat_mb_x = np.take(rm_classifier[0], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                lat_mb_y = np.take(rm_classifier[1], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                y_mb = maybe_cuda(
                    torch.cat((train_y[start:end], lat_mb_y), 0),
                    use_cuda=use_cuda)
                lat_mb_x = maybe_cuda(lat_mb_x, use_cuda=use_cuda)

            logits = classifier_model(x_mb, latent_input=lat_mb_x, return_lat_acts=False)

            _, pred_label = torch.max(logits, 1)
            correct_cnt += (pred_label == y_mb).sum()

            classifier_loss = ce_loss(logits, y_mb)
            if ewc_lambda != 0:
                classifier_loss += compute_ewc_loss(classifier_model, ewcData, lambd=ewc_lambda)
            loss_sum_classifier += classifier_loss.item()

            classifier_loss.backward()
            classifier_optimizer.step()

            if ewc_lambda != 0:
                post_update(classifier_model, synData)

            acc_classifier = correct_cnt.item() / \
                             ((it + 1) * y_mb.size(0))
            # running mean of the (already minibatch-averaged) loss; the accumulator is
            # not divided in place, otherwise later iterations would mix means and sums
            ave_loss_classifier = loss_sum_classifier / (it + 1)

            if it % 10 == 0:
                print(
                    '==>>> it: {}, avg. loss: {:.6f}, '
                    'running train acc: {:.3f}'
                        .format(it, ave_loss_classifier, acc_classifier)
                )

            # Log scalar values (scalar summary) to TB
            tot_it_step_classifier += 1
            writer.add_scalar('classifier_train_loss', ave_loss_classifier, tot_it_step_classifier)
            writer.add_scalar('classifier_train_accuracy', acc_classifier, tot_it_step_classifier)

        cur_ep_classifier += 1

    w_changes = consolidate_weights_nic(classifier_model, cur_class, cur_class)
    if ewc_lambda != 0:
        update_ewc_data(classifier_model, ewcData, synData, clip_to=0.001, c=0.00002)

    # update number examples encountered over time
    for c, n in classifier_model.cur_j.items():
        classifier_model.past_j[c] += n
    set_consolidate_weights(classifier_model)

    ave_loss_classifier, acc_classifier, accs, cf = get_accuracy_conf_matrix_nic(
        classifier_model, ce_loss, mb_size, test_x_all, test_y_all, preproc=preproc,
        cf_labels_inc=cf_labels_inc
    )

    # ------------------ Statistics classifier -------------------------------------
    writer.add_scalar('classifier_test_loss', ave_loss_classifier, i)
    writer.add_scalar('classifier_test_accuracy', acc_classifier, i)

    df_cm = pd.DataFrame(cf, range(50), range(50))
    plot_buf = gen_plot(df_cm)
    image_cf = PIL.Image.open(plot_buf)
    image_cf = ToTensor()(image_cf)
    writer.add_image('classifier_confusion_matrix', image_cf, i)

    print("---------------------------------")
    print("Accuracy classifier: ", acc_classifier)
    print("---------------------------------")

    if ewc_lambda != 0:
        weight_stats(classifier_model, ewcData, clip_to=0.001)

    # *************************************************************************************************************
    # *************************************** GENERATOR TRAINING **************************************************
    # *************************************************************************************************************

    print("************************ GENERATOR TRAINING ****************************************************")
    if i == 0:
        train_ep_generator = init_train_ep_generator
    else:
        train_ep_generator = inc_train_ep_generator

    cur_ep_generator = 0

    generative_model = maybe_cuda(generative_model, use_cuda=use_cuda)

    for ep in range(train_ep_generator):

        print("training ep: ", ep)
        ave_loss = 0

        ave_recon_loss_real = 0
        ave_recon_loss_replay = 0
        ave_kld_loss = 0
        ave_dist_loss = 0

        classifier_model.eval()
        generative_model.train()

        if i > 0:
            # 64 real patterns per minibatch, the remaining slots come from the generator memory
            cur_sz = 64
            it_x_ep = train_x.size(0) // cur_sz
            n2inject = max(0, mb_size - cur_sz)
        else:
            cur_sz = mb_size
            n2inject = 0
        print("total sz:", train_x.size(0) + rm_generator_sz)
        print("n2inject", n2inject)
        print("it x ep: ", it_x_ep)

        for it in range(it_x_ep):
            start = it * (mb_size - n2inject)
            end = (it + 1) * (mb_size - n2inject)

            generator_optimizer.zero_grad()

            x_mb = maybe_cuda(train_x[start:end], use_cuda=use_cuda)

            # the generator is trained on the classifier's latent activations
            with torch.no_grad():
                _, features = classifier_model(x_mb, return_lat_acts=True)
                features = features.detach()

            if i == 0:
                replay_mb_x = None
                y_mb = maybe_cuda(train_y[start:end], use_cuda=use_cuda)

            else:
                # wrap around the memory (like the classifier's np.take(mode='wrap')) so every
                # minibatch gets n2inject replay samples; a plain slice past the end of a small
                # memory yielded short/empty replay batches (an empty mean-reduced loss is NaN)
                replay_idx = torch.arange(it * n2inject, (it + 1) * n2inject) % rm_generator[0].size(0)
                replay_mb_x = maybe_cuda(rm_generator[0][replay_idx], use_cuda=use_cuda)
                features = torch.cat((features, replay_mb_x), dim=0)
                replay_mb_y = rm_generator[1][replay_idx]
                y_mb = maybe_cuda(torch.cat((train_y[start:end], replay_mb_y), 0), use_cuda=use_cuda)

            recon_batch, mu, logvar = generative_model(features, F.one_hot(y_mb, 50).float())

            if i > 0:
                # the first cur_sz rows are real patterns, the rest are replayed ones
                recon_loss_real = reconstruction_loss(recon_batch[0:cur_sz], features[0:cur_sz])
                recon_loss_replay = reconstruction_loss(recon_batch[cur_sz:mb_size], features[cur_sz:mb_size])
            else:
                recon_loss_real = reconstruction_loss(recon_batch, features)
                recon_loss_replay = maybe_cuda(torch.tensor([0.0]), use_cuda=use_cuda)

            if use_kld:
                kld = maybe_cuda(kld_loss(mu, logvar), use_cuda=use_cuda)
            else:
                kld = maybe_cuda(torch.tensor([0.0]), use_cuda=use_cuda)
            if use_distillation:
                dist_loss = maybe_cuda(distillation_loss(classifier_model, recon_batch, y_mb), use_cuda=use_cuda)
            else:
                dist_loss = maybe_cuda(torch.tensor([0.0]), use_cuda=use_cuda)
            # reconstruction terms weighted by the share of real / replayed patterns in the minibatch
            generator_loss = (alpha * ((recon_loss_real * (cur_sz / mb_size)) + (recon_loss_replay * (n2inject / mb_size)))) \
                + (beta * kld) + (gamma * dist_loss)

            ave_loss += generator_loss.item()
            ave_recon_loss_real += recon_loss_real.item()
            ave_recon_loss_replay += recon_loss_replay.item()
            ave_dist_loss += dist_loss.item()
            ave_kld_loss += kld.item()

            generator_loss.backward()
            generator_optimizer.step()

            ave_loss_mean = ave_loss / ((it + 1) * y_mb.size(0))

            if it % 10 == 0:
                print("==>>> it: {}, avg. loss: {:.6f}".format(it, ave_loss_mean))
                print("    recon loss real: {}".format(recon_loss_real.item()))
                print("    recon loss replay: {}".format(recon_loss_replay.item()))
                print("    KLD loss: {}".format(kld.item()))
                print("    distillation loss: {}".format(dist_loss.item()))

        cur_ep_generator = cur_ep_generator + 1

        # **************************** CALCULATE METRICS GENERATOR *************************************************
        generative_model.eval()
        if cur_ep_generator == train_ep_generator and EVAL_GENERATOR:
            ave_kld_loss /= it_x_ep
            ave_dist_loss /= it_x_ep
            ave_recon_loss_real /= it_x_ep
            ave_recon_loss_replay /= it_x_ep

            writer.add_scalar('generator recon loss/train/real', ave_recon_loss_real, i)
            writer.add_scalar('generator recon loss/train/replay', ave_recon_loss_replay, i)
            writer.add_scalar('generator KLD loss/train', ave_kld_loss, i)
            writer.add_scalar('generator distillation loss/train', ave_dist_loss, i)

            # rebuild the generator test set (seen classes only) whenever new classes appeared
            if recalculate_test_set:
                test_x_tmp = []
                test_y_tmp = []
                for ti in range(len(test_x_all)):
                    if test_y_all[ti].item() in seen_classes:
                        test_y_tmp.append(test_y_all[ti])
                        test_x_tmp.append(test_x_all[ti])
                # preprocess the NumPy images before tensor conversion, as the training path does
                test_x_generator = preproc(np.array(test_x_tmp))
                test_y_generator = np.array(test_y_tmp)
                test_x_generator = torch.from_numpy(test_x_generator).type(torch.FloatTensor)
                test_y_generator = torch.from_numpy(test_y_generator).type(torch.LongTensor)

                print("test set lengths")
                print(len(test_y_generator))
                print(len(test_x_generator))
                recalculate_test_set = False

            # --------------- Seen classes generation (sampling) accuracy + confusion matrix ---------------------

            tot_correct = 0
            generative_model.eval()
            for c in seen_classes:
                # 128 images per class
                samples = np.random.multivariate_normal(torch.zeros(latent_size),
                                                        torch.eye(latent_size), 128)
                z = maybe_cuda(torch.from_numpy(samples), use_cuda=use_cuda)
                z = z.view(-1, latent_size).float()
                y_one_hot = maybe_cuda(F.one_hot(torch.full((128,), c, dtype=torch.long), 50).float(),
                                       use_cuda=use_cuda)
                z = torch.cat((z, y_one_hot), dim=1)
                with torch.no_grad():
                    recon = generative_model.decode(z).detach()
                    output = classifier_model(x=None, latent_input=recon, return_lat_acts=False)
                _, pred = output.max(1)
                tot_correct += (pred == c).sum().item()

            acc_sampling = tot_correct / (128 * len(seen_classes))
            print("---------------------------------")
            print("Seen classes generation accuracy: ", acc_sampling)
            print("---------------------------------")
            writer.add_scalar('generator sampling accuracy', acc_sampling, i)

            # ------------------------------- Test set reconstruction -------------------------------------

            generative_model.eval()
            tot_recon_correct = 0

            test_it = test_y_generator.shape[0] // 128
            if test_y_generator.shape[0] % 128 > 0:
                test_it = test_it + 1

            test_avg_recon_loss_test = 0
            for ti in range(test_it):
                # indexing
                start = ti * 128
                end = (ti + 1) * 128

                x = maybe_cuda(test_x_generator[start:end], use_cuda=use_cuda)
                y = maybe_cuda(test_y_generator[start:end], use_cuda=use_cuda)

                classifier_model.eval()
                with torch.no_grad():
                    _, x = classifier_model(x, return_lat_acts=True)
                    x = x.detach()
                    recon_batch, _, _ = generative_model(x, torch.nn.functional.one_hot(y, 50).float())

                test_avg_recon_loss_test += reconstruction_loss(recon_batch, x).item()

                with torch.no_grad():
                    logits = classifier_model(x=None, latent_input=recon_batch, return_lat_acts=False)

                _, pred_label = logits.max(1)
                tot_recon_correct += (pred_label == y).sum().item()

            test_avg_recon_loss_test /= test_it

            writer.add_scalar('generator recon loss/test', test_avg_recon_loss_test, i)

            acc_recon_test = tot_recon_correct * 1.0 / test_y_generator.size(0)

            print("All test set reconstruction accuracy: ", acc_recon_test)
            writer.add_scalar('generator test reconstruction accuracy', acc_recon_test, i)

            # ------------------- classification retrain ----------------

            # create a new classifier
            retrained_classifier = maybe_cuda(ClassifierBN(n_classes=50, weights_file=None), use_cuda=use_cuda)

            classificator_retrained_optimizer = torch.optim.SGD(params=retrained_classifier.parameters(), lr=0.005,
                                                                momentum=0.9, weight_decay=0.0005)

            confusion_matrix_retrain = np.zeros((50, 50), dtype=int)

            generative_model.eval()
            retrained_classifier.train()

            # number of generated patterns used to train the new classifier grows with the
            # number of batches seen so far
            features_for_training = 3000 + (300 * i)
            nr_batches_retrain = (features_for_training // 128) + 1

            for _ in range(nr_batches_retrain):
                samples = np.random.multivariate_normal(torch.zeros(latent_size),
                                                        torch.eye(latent_size), 128)
                z = maybe_cuda(torch.from_numpy(samples).view(-1, latent_size).float(), use_cuda=use_cuda)
                lab = maybe_cuda(torch.from_numpy(np.random.choice(list(seen_classes), 128)), use_cuda=use_cuda)
                y_one_hot = maybe_cuda(F.one_hot(lab, 50).float(), use_cuda=use_cuda)
                z = torch.cat((z, y_one_hot), dim=1)
                with torch.no_grad():
                    recon = generative_model.decode(z).detach()
                classificator_retrained_optimizer.zero_grad()
                output = retrained_classifier(recon)
                loss = ce_loss(output, lab)
                loss.backward()
                classificator_retrained_optimizer.step()

            retrained_classifier.eval()
            tot_recon_correct = 0
            test_it = test_y_generator.shape[0] // 128
            if test_y_generator.shape[0] % 128 > 0:
                test_it = test_it + 1

            for ti in range(test_it):
                # indexing
                start = ti * 128
                end = (ti + 1) * 128

                x = maybe_cuda(test_x_generator[start:end], use_cuda=use_cuda)
                y = maybe_cuda(test_y_generator[start:end], use_cuda=use_cuda)

                classifier_model.eval()
                with torch.no_grad():
                    _, x = classifier_model(x, return_lat_acts=True)
                    x = x.detach()

                logits = retrained_classifier(x)

                _, pred_label = logits.max(1)
                tot_recon_correct += (pred_label == y).sum()

                for j in range(x.shape[0]):
                    confusion_matrix_retrain[pred_label[j]][y[j].item()] += 1

            acc_recon = tot_recon_correct.item() * 1.0 / test_y_generator.size(0)
            print("Test retrain classifier accuracy: {}".format(acc_recon))
            writer.add_scalar('generator retraining classificator test accuracy', acc_recon, i)

            df_cm_retrain = pd.DataFrame(confusion_matrix_retrain, range(50), range(50))
            plot_buf = gen_plot(df_cm_retrain)
            image_retrain = PIL.Image.open(plot_buf)
            image_retrain = ToTensor()(image_retrain)
            writer.add_image('generator confusion matrix classificator retrain', image_retrain, i)

    # *************************************************************************************************************
    # *************************************** REPLAY MEMORY CALCULATION *******************************************
    # *************************************************************************************************************

    print("************************ MEMORY CALCULATION ****************************************************")

    # generator memory: replayed (and reconstructed) during the next generator training
    rm_generator = _generate_replay_memory(generative_model, rm_generator_sz, seen_classes, use_cuda)

    # classifier memory: injected as `latent_input` during the next classifier training
    # (previously decoded with a hard-coded use_cuda=True)
    rm_classifier = _generate_replay_memory(generative_model, rm_classificator_sz, seen_classes, use_cuda)

    # ########################  END OF BATCH  #################################
writer.close()
