#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import random
from collections import defaultdict

import torch
from torch.utils.data import DataLoader

from cl_tools import NCProtocolIterator, TransformationSubset
from utils.cwr_utils import *

# Set the CUDA device for this process (adjust to your hardware): devices are
# ordered by PCI bus id and only device "0" is made visible.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# The CUDA device is selected above; the remaining imports follow.
from utils.imagenet_benchmark import make_imagenet_40_25_benchmark

import json
import configparser
import argparse
import sklearn
from pprint import pprint
from torch.utils.tensorboard import SummaryWriter
import PIL
import pandas as pd
from torchvision.transforms import ToTensor
import numpy as np
from models import change_bn_momentum
from models.latent_resnet import resnet18
import math
from models.CVAE_resnet18_layer4 import CVAEResnet18
from models.generator_losses import DistillationLoss, ReconstructionLoss, KLDLoss
import torch.nn.functional as F
from utils import get_expanded_or_shrunken_head_from_model
from torchvision import transforms

# --------------------------------- Setup --------------------------------------

# recover exp configuration name
# Command line: a single optional --name selecting the section of the cfg file to run.
parser = argparse.ArgumentParser(description='Run CL experiments')
parser.add_argument('--name', dest='exp_name', default='DEFAULT',
                    help='name of the experiment you want to run.')
args = parser.parse_args()

# recover config file for the experiment
_config_path = "params_generative_positive_replay.cfg"
config = configparser.ConfigParser()
# ConfigParser.read() silently skips files it cannot open and returns the list
# of files actually parsed; fail loudly here instead of with an opaque KeyError
# on the first parameter lookup further down.
if not config.read(_config_path):
    raise FileNotFoundError(
        "Configuration file '{}' not found in the current working directory".format(_config_path))
# The default section is always addressable; any other name must exist as a section.
if args.exp_name != config.default_section and not config.has_section(args.exp_name):
    raise KeyError("Experiment '{}' is not a section of '{}' (available: {})".format(
        args.exp_name, _config_path, config.sections()))
exp_config = config[args.exp_name]
print("Experiment name:", args.exp_name)
pprint(dict(exp_config))

# recover parameters from the cfg file and compute the dependent ones.
def _cfg(key):
    """Return the cfg entry *key* of the selected experiment, evaluated as a Python expression."""
    # NOTE(security): eval() executes whatever the cfg file contains; only run
    # configuration files you trust.
    return eval(exp_config[key])


# --- classifier hyper-parameters ---
exp_name = _cfg('exp_name')
comment = _cfg('comment')
use_cuda = _cfg('use_cuda')
init_lr = _cfg('init_lr')
inc_lr = _cfg('inc_lr')
cwr_inc_lr = _cfg('cwr_inc_lr')
mb_size = _cfg('mb_size')
init_train_ep_classifier = _cfg('init_train_ep_classifier')
inc_train_ep_classifier = _cfg('inc_train_ep_classifier')
rm_classificator_sz = _cfg('rm_classificator_sz')
momentum = _cfg('momentum')
l2 = _cfg('l2')
freeze_below_layer = _cfg('freeze_below_layer')
latent_layer = _cfg('latent_layer')
ewc_lambda = _cfg('ewc_lambda')
wi = _cfg('wi')

# --- generator (CVAE) hyper-parameters ---
learning_rate_cvae = _cfg("learning_rate_cvae")
optimizer_cvae = _cfg("optimizer_cvae")
recon_loss_cvae = _cfg("recon_loss_cvae")
init_train_ep_generator = _cfg('init_train_ep_generator')
inc_train_ep_generator = _cfg('inc_train_ep_generator')
use_kld = _cfg("use_kld")
use_distillation = _cfg("use_distillation")
alpha = _cfg("alpha")
beta = _cfg("beta")
gamma = _cfg("gamma")
latent_size = _cfg("latent_size")
rm_generator_sz = _cfg("rm_generator_sz")


# --- reproducibility ---
seed = _cfg('seed')

# Seed the Python, NumPy and PyTorch generators with the configured seed.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# With a GPU available, seed CUDA as well and pin cuDNN to deterministic,
# non-benchmarking mode.
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# TensorBoard writer: one log directory per experiment name
log_dir = 'logs/' + exp_name
writer = SummaryWriter(log_dir)

# Store the raw experiment parameters (as JSON text) next to the scalars
hyper = json.dumps(dict(exp_config))
writer.add_text("parameters", hyper, 0)

n_runs = 1

# per-key lists of test results
dict_test_accuracy, dict_test_loss = defaultdict(list), defaultdict(list)

# for run in range(n_runs):
run = 0
print("#################### RUN {} ####################".format(run))

# ********************** CLASSIFIER ******************************
# The benchmark factory hands back the incremental protocol, the train/test
# transforms and the classifier network in a single 4-tuple.
_benchmark = make_imagenet_40_25_benchmark(
    class_order_pkl='../data/seed_1993_imagenet_order_run_0.pkl')
protocol, transform, transform_test, classifier_model = _benchmark
print(classifier_model)

# Other variables init
tot_it_step_classifier = 0
tot_it_gen = 0
num_classes = 1000
rm_classifier = None
rm_generator = None

# an empty latent layer name in the cfg is normalised to None
if latent_layer == "":
    latent_layer = None

# Bookkeeping attributes attached to the model: saved head weights and
# per-class example counters (past and current), all starting at zero.
classifier_model.saved_weights = {}
classifier_model.past_j = dict.fromkeys(range(protocol.n_classes), 0)
classifier_model.cur_j = dict.fromkeys(range(protocol.n_classes), 0)

# EWC/synaptic buffers are only created when the EWC term is enabled
if ewc_lambda != 0:
    ewcData, synData = create_syn_data(classifier_model, classification_layer='fc')

classifier_optimizer = torch.optim.SGD(
    classifier_model.parameters(),
    lr=init_lr,
    momentum=momentum,
    weight_decay=l2,
)

change_bn_momentum(classifier_model, momentum=0.001)

# ********************** GENERATOR ******************************
generative_model = CVAEResnet18(
    latent_size=latent_size,
    num_classes=num_classes,
    pretrained_encoder=True,
    use_shortcut_in_decoder=True,
    use_upsampling=True,
)
generative_model.load_decoder("../data/pretrained_ae/decoder_upsampling_skip_pretrained_places_ep4.pt")

# Resolve the optimizer class first, then build it once over the parameters
# that require gradient.
if optimizer_cvae == "adam":
    _generator_optimizer_cls = torch.optim.Adam
elif optimizer_cvae == "sgd":
    _generator_optimizer_cls = torch.optim.SGD
else:
    raise NameError("{} is not a known optimizer".format(optimizer_cvae))

_trainable_generator_params = [par for par in generative_model.parameters() if par.requires_grad]
generator_optimizer = _generator_optimizer_cls(_trainable_generator_params, lr=learning_rate_cvae)

# define generative model losses
reconstruction_loss = ReconstructionLoss(recon_loss_cvae)
kld_loss = KLDLoss()
distillation_loss = DistillationLoss()

# ********************************** UTILITIES ************************************************
# classes encountered so far, and how many of them there were at the last check
seen_classes = set()
nr_seen_classes = 0

# cross-entropy, averaged and per-sample
ce_loss = torch.nn.CrossEntropyLoss()
ce_loss_no_reduction = torch.nn.CrossEntropyLoss(reduction="none")

# log-softmax + per-sample NLL
lsm = torch.nn.LogSoftmax(dim=1)
nll = torch.nn.NLLLoss(reduction="none")


def _tensor_and_normalize():
    """Return fresh ToTensor + channel-wise Normalize steps used at the end of both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# classifier training: random resized crop + random flip
train_transform_classifier = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ] + _tensor_and_normalize()
)

# generator training: fixed resize/center crop + random flip
train_transform_generator = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(),
    ] + _tensor_and_normalize()
)

# *************************************************************************************************************
# *************************************************************************************************************
# ***************************************  TRAINING PROCEDURE  ************************************************
# *************************************************************************************************************
# *************************************************************************************************************

classifier_model.train()

# Iterate over the incremental training batches produced by the protocol.
batch_info: NCProtocolIterator
for i, (train_batch, batch_info) in enumerate(protocol):
    print("----------- batch {0} -------------".format(i))

    if ewc_lambda != 0:
        init_batch(classifier_model, ewcData, synData, classification_layer='fc')

    # labels of every pattern in this batch, plus the distinct labels among them
    train_y_all = np.array(train_batch.targets)
    _batch_label_set = set(train_y_all)
    print("train lab - min: {}".format(min(train_y_all)))
    print("train lab - max: {}".format(max(train_y_all)))
    print("Nr. classes in batch: {}".format(len(_batch_label_set)))

    if i == 1:
        # On the second batch: freeze the lower layers and rebuild the optimizer
        # with one parameter group per learning-rate policy.
        _frozen_ok = freeze_up_to(classifier_model, freeze_below_layer, only_conv=False)  # Sets .eval() too
        if not _frozen_ok:
            raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))
        standard_params_lr, cwr_params_lr, frozen_params_lr = ar1f_get_params_lr(
            classifier_model, 'fc', freeze_below_layer=freeze_below_layer)

        _param_groups = [
            {'params': standard_params_lr},  # Whole net - fc.weight - fc.bias
            {'params': cwr_params_lr, 'lr': cwr_inc_lr},  # fc.weight
            {'params': frozen_params_lr, 'lr': 0},  # fc.bias
        ]
        classifier_optimizer = torch.optim.SGD(_param_groups, lr=inc_lr, momentum=momentum, weight_decay=l2)

    # track the classes seen so far and report when the set has grown
    seen_classes.update(train_y_all.tolist())
    if len(seen_classes) > nr_seen_classes:
        print("New classes here!")
        nr_seen_classes = len(seen_classes)

    # cur_class_batch: classes in this batch; cur_class: the same plus the
    # classes held in the classifier replay memory (after the first batch).
    cur_class_batch = [int(o) for o in _batch_label_set]
    if i == 0:
        cur_class = [int(o) for o in _batch_label_set]
        classifier_model.cur_j = examples_per_class(train_y_all)
    else:
        _replay_labels = rm_classifier[1].tolist()
        cur_class = [int(o) for o in _batch_label_set.union(set(_replay_labels))]
        classifier_model.cur_j = examples_per_class(list(train_y_all) + _replay_labels)

    print("train_y shape: {}".format(train_y_all.shape))
    print("train_y shape: {}".format(train_y_all.shape))

    # *************************************************************************************************************
    # *************************************** CLASSIFIER TRAINING *************************************************
    # *************************************************************************************************************
    # Trains the classifier on the current batch mixed with replayed features
    # from rm_classifier, then consolidates the head weights. The first batch
    # (i == 0) is not trained here: the epoch loop breaks immediately.
    print("************************ CLASSIFIER TRAINING ****************************************************")

    if i > 0:
        reset_weights(classifier_model, cur_class_batch)
    classifier_model = maybe_cuda(classifier_model, use_cuda=use_cuda)

    cur_ep_classifier = 0

    acc_classifier = None
    ave_loss_classifier = 0

    if i == 0:
        train_ep_classifier = init_train_ep_classifier
    else:
        train_ep_classifier = inc_train_ep_classifier

    # Logistic decay of the base learning rate with the batch index i:
    # 1 - 0.9 / (1 + e^-(1.5 * i - 8)).
    lr_multiplier = -(0.9 / (1 + math.pow(math.e, -((1.5 * i) - 8)))) + 1  # same of random replay
    classifier_optimizer.param_groups[0]['lr'] = inc_lr * lr_multiplier

    h = min(rm_classificator_sz // (i + 1), len(train_batch))
    print("nr. pattern to insert in memory: {}".format(h))

    train_batch.dataset.transform = train_transform_classifier

    for ep in range(train_ep_classifier):
        if i == 0:
            # Skip training for first batch
            print('Skipping training for the first batch')
            break

        # halfway through the epochs every group's lr is divided by 10
        # (undone after the epoch loop)
        if ep == (train_ep_classifier // 2):
            for group in classifier_optimizer.param_groups:
                group['lr'] = group['lr'] * 0.1

        print("training ep: ", ep)
        print("lr = {}".format([group['lr'] for group in classifier_optimizer.param_groups]))
        correct_cnt, ave_loss_classifier = 0, 0

        if i > 0:
            # cur_sz real patterns per mini-batch; the remaining n2inject slots
            # are filled with replay patterns so the total is mb_size.
            cur_sz = len(train_batch) // ((len(train_batch) + rm_classificator_sz) // mb_size)
            it_x_ep = len(train_batch) // cur_sz
            n2inject = max(0, mb_size - cur_sz)
        else:
            n2inject = 0

        print("total sz:", len(train_batch) + rm_classificator_sz)
        print("n2inject", n2inject)

        train_loader = DataLoader(train_batch, batch_size=(mb_size - n2inject), shuffle=True,
                                  pin_memory=True, drop_last=True, num_workers=8)

        print("it x ep: ", len(train_loader))

        for it, (train_x, train_y) in enumerate(train_loader):
            if ewc_lambda != 0:
                pre_update(classifier_model, synData, classification_layer='fc')

            classifier_optimizer.zero_grad()

            y_mb = train_y

            if n2inject > 0:
                # take the next n2inject replay patterns, wrapping around the memory
                lat_mb_x = np.take(rm_classifier[0], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                lat_mb_y = np.take(rm_classifier[1], range(it * n2inject, (it + 1) * n2inject), axis=0, mode='wrap')
                y_mb = torch.cat((train_y, lat_mb_y), 0)

            if n2inject > 0:
                lat_mb_x = maybe_cuda(lat_mb_x, use_cuda=use_cuda)
            x_mb = maybe_cuda(train_x, use_cuda=use_cuda)
            y_mb = maybe_cuda(y_mb, use_cuda=use_cuda)

            if i > 0:
                classifier_model.train()
                if not freeze_up_to(classifier_model, freeze_below_layer, only_conv=False):  # Sets .eval() too
                    raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))
            else:
                classifier_model.train()
            if n2inject > 0:
                logits = classifier_model(x_mb, lat_mb_x, latent_layer=latent_layer, return_lat_acts=False)
            else:
                logits = classifier_model(x_mb)
            loss = ce_loss(logits, y_mb)
            # the logged running loss is the plain cross-entropy, as before
            ave_loss_classifier += loss.mean().item()

            if ewc_lambda != 0:
                # The EWC penalty has to be part of the loss *before* backward();
                # adding it afterwards leaves the gradients unaffected.
                loss = loss + compute_ewc_loss(classifier_model, ewcData, lambd=ewc_lambda,
                                               classification_layer='fc')

            loss.backward()

            # running accuracy is measured on the real (non-replayed) patterns only
            _, pred_label = torch.max(logits[0: mb_size - n2inject], 1)
            correct_cnt += (pred_label == y_mb[0: mb_size - n2inject]).sum()

            classifier_optimizer.step()

            if ewc_lambda != 0:
                post_update(classifier_model, synData, classification_layer='fc')

            acc_classifier = correct_cnt.item() / \
                             ((it + 1) * y_mb[0: mb_size - n2inject].size(0))

            if it % 10 == 0:
                print(
                    '==>>> it: {}, avg. loss: {:.6f}, '
                    'running train acc: {:.3f}'.format(it, ave_loss_classifier / (it + 1), acc_classifier))
                writer.add_scalar('classifier_train_loss', ave_loss_classifier / (it + 1), tot_it_step_classifier)
                writer.add_scalar('classifier_train_accuracy', acc_classifier, tot_it_step_classifier)
            # Log scalar values (scalar summary) to TB
            tot_it_step_classifier += 1

        # if i >= 1:
        #     optimizer_scheduler.step()
        cur_ep_classifier += 1

    # undo the mid-training x0.1 learning-rate reduction
    for group in classifier_optimizer.param_groups:
        group['lr'] = group['lr'] * 10

    # replay positive
    w_changes = consolidate_weights(classifier_model, cur_class, cur_class)  # replay positive
    if ewc_lambda != 0:
        update_ewc_data(classifier_model, ewcData, synData, clip_to=0.001, c=wi, classification_layer='fc')

    # update number examples encountered over time
    for c, n in classifier_model.cur_j.items():
        classifier_model.past_j[c] += n
    set_consolidate_weights(classifier_model)

    # Evaluate the classifier on the cumulative test set of the batches seen so far.
    _cumulative_test_set = batch_info.get_cumulative_test_set()
    test_loader = DataLoader(
        _cumulative_test_set,
        batch_size=mb_size,
        shuffle=False,
        drop_last=False,
        num_workers=8,
        pin_memory=True,
    )

    ave_loss_classifier, acc_classifier, accs, cf = get_accuracy_conf_matrix_from_dataloader(
        classifier_model, ce_loss, test_loader, protocol.n_classes)

    # ------------------ Statistics classifier -------------------------------------
    for _tag, _value in (('classifier_test_loss', ave_loss_classifier),
                         ('classifier_test_accuracy', acc_classifier)):
        writer.add_scalar(_tag, _value, i)

    hit_per_class, tot_test_class = accs

    # confusion matrix rendered as an image for TensorBoard
    _class_index = range(protocol.n_classes)
    df_cm = pd.DataFrame(cf, _class_index, _class_index)
    plot_buf = gen_plot(df_cm, max(tot_test_class))
    image_cf = ToTensor()(PIL.Image.open(plot_buf))
    writer.add_image('classifier_confusion_matrix/run_{}'.format(run), image_cf, i)

    print("---------------------------------")
    print("Accuracy classifier: ", acc_classifier)
    print("---------------------------------")

    # per-task accuracy: each task owns a contiguous block of 40 class indices
    for past_task in range(i + 1):
        _task_classes = slice(past_task * 40, (past_task * 40) + 40)
        tot_hit = sum(hit_per_class[_task_classes])
        tot_class_task = sum(tot_test_class[_task_classes])
        acc_task = tot_hit / tot_class_task
        writer.add_scalar('classifier_task_accuracy/task_{}'.format(past_task), acc_task, i)

    # tasks not encountered yet (out of 25) are logged with accuracy 0
    for fut_task in range(i + 1, 25):
        writer.add_scalar('classifier_task_accuracy/task_{}'.format(fut_task), 0.0, i)

    if ewc_lambda != 0:
        weight_stats(classifier_model, ewcData, clip_to=0.001, classification_layer='fc')

    # *************************************************************************************************************
    # *************************************** GENERATOR TRAINING **************************************************
    # *************************************************************************************************************
    # Trains the conditional VAE on the second output of the classifier
    # (requested with return_lat_acts=True at `latent_layer`) for the current
    # batch; after the first batch, features stored in rm_generator are
    # concatenated to every mini-batch.

    print("************************ GENERATOR TRAINING ****************************************************")


    # NOTE(review): the generator section is gated on rm_classificator_sz, not
    # rm_generator_sz — confirm this is intended.
    if rm_classificator_sz > 0:

        if i == 0:
            train_ep_generator = init_train_ep_generator
        else:
            train_ep_generator = inc_train_ep_generator

        cur_ep_generator = 0


        # switch the dataset to the generator pipeline (resize + center crop)
        train_batch.dataset.transform = train_transform_generator

        generative_model = maybe_cuda(generative_model, use_cuda=use_cuda)
        classifier_model = maybe_cuda(classifier_model, use_cuda=use_cuda)

        # same logistic decay in the batch index i that is applied to the classifier lr
        lr_multiplier_generator = -(0.9 / (1 + math.pow(math.e, -((1.5 * i) - 8)))) + 1  # last multiplier = 0.1
        generator_optimizer.param_groups[0]['lr'] = learning_rate_cvae * lr_multiplier_generator

        print("generator lr = {}".format(generator_optimizer.param_groups[0]['lr']))

        for ep in range(train_ep_generator):

            print("training ep: ", ep)
            # running sums over the iterations of this epoch
            ave_loss = 0

            ave_recon_loss_real = 0
            ave_recon_loss_replay = 0
            ave_kld_loss = 0
            ave_dist_loss = 0

            # the classifier is only used as a fixed feature extractor here
            classifier_model.eval()
            generative_model.train()

            if i > 0:
                # cur_sz real patterns per mini-batch, n2inject replayed ones
                cur_sz = len(train_batch) // ((len(train_batch) + rm_generator_sz) // mb_size)
                n2inject = max(0, mb_size - cur_sz)
            else:
                cur_sz = mb_size
                n2inject = 0
            it_x_ep = len(train_batch) // cur_sz
            print("total sz:", len(train_batch) + rm_generator_sz)
            print("n2inject", n2inject)

            train_loader = DataLoader(train_batch, batch_size=(mb_size - n2inject), shuffle=True,
                                      pin_memory=True, drop_last=True, num_workers=8)

            for it, (train_x, train_y) in enumerate(train_loader):
                generator_optimizer.zero_grad()

                train_x = maybe_cuda(train_x, use_cuda=use_cuda)

                # extract the features of the real patterns without tracking gradients
                with torch.no_grad():
                    _, features = classifier_model(train_x, None, latent_layer=latent_layer, return_lat_acts=True)
                    features = features.detach()

                if i == 0:
                    replay_mb_x = None
                    y_mb = maybe_cuda(train_y, use_cuda=use_cuda)

                else:
                    # NOTE(review): unlike the classifier replay (np.take with
                    # mode='wrap'), these slices do not wrap; if rm_generator
                    # holds fewer than (it + 1) * n2inject rows the replay part
                    # is shorter than n2inject — confirm against the memory size.
                    replay_mb_x = maybe_cuda(rm_generator[0][it * n2inject: (it + 1) * n2inject], use_cuda=use_cuda)
                    features = torch.cat((features, replay_mb_x), dim=0)
                    replay_mb_y = rm_generator[1][it * n2inject: (it + 1) * n2inject]
                    y_mb = maybe_cuda(torch.cat((train_y[0: cur_sz], replay_mb_y), 0), use_cuda=use_cuda)

                # the CVAE is conditioned on the one-hot encoded label
                recon_batch, mu, logvar = generative_model(features, F.one_hot(y_mb, num_classes).float())

                if i > 0:
                    # real patterns occupy rows [0, cur_sz), replayed ones [cur_sz, mb_size)
                    recon_loss_real = reconstruction_loss(recon_batch[0:cur_sz], features[0:cur_sz])
                    recon_loss_replay = reconstruction_loss(recon_batch[cur_sz:mb_size], features[cur_sz:mb_size])
                else:
                    recon_loss_real = reconstruction_loss(recon_batch, features)
                    recon_loss_replay = maybe_cuda(torch.tensor([0.0]), use_cuda=use_cuda)

                # optional terms default to a zero tensor so the sum below is always defined
                if use_kld:
                    kld = maybe_cuda(kld_loss(mu, logvar), use_cuda=use_cuda)
                else:
                    kld = maybe_cuda(torch.tensor([0.0]), use_cuda=use_cuda)
                if use_distillation:
                    dist_loss = maybe_cuda(distillation_loss(classifier_model, recon_batch, y_mb,
                                                             latent_layer=latent_layer), use_cuda=use_cuda)
                else:
                    dist_loss = maybe_cuda(torch.tensor([0.0]), use_cuda=use_cuda)
                # total = alpha * (reconstruction, weighted by the real/replay share of
                # the mini-batch) + beta * KLD + gamma * distillation
                generator_loss = (alpha * ((recon_loss_real * (cur_sz / mb_size)) +
                                           (recon_loss_replay * (n2inject / mb_size)))) + (beta * kld) + (
                                         gamma * dist_loss)

                ave_loss += generator_loss.item()
                ave_recon_loss_real += recon_loss_real.item()
                ave_recon_loss_replay += recon_loss_replay.item()
                ave_dist_loss += dist_loss.item()
                ave_kld_loss += kld.item()

                generator_loss.backward()
                generator_optimizer.step()

                # NOTE(review): this average is additionally divided by the
                # mini-batch size, unlike the per-iteration averages logged below.
                ave_loss_mean = ave_loss / ((it + 1) * y_mb.size(0))

                if it % 10 == 0:
                    print("==>>> it: {}, avg. loss: {:.6f}".format(it, ave_loss_mean))
                    print("    recon loss real: {}".format(recon_loss_real.item()))
                    print("    recon loss replay: {}".format(recon_loss_replay.item()))
                    print("    KLD loss: {}".format(kld.item()))
                    print("    distillation loss: {}".format(dist_loss.item()))

                    writer.add_scalar('generator recon loss/train/real', ave_recon_loss_real / (it + 1), tot_it_gen)
                    writer.add_scalar('generator recon loss/train/replay', ave_recon_loss_replay / (it + 1), tot_it_gen)
                    writer.add_scalar('generator KLD loss/train', ave_kld_loss / (it + 1), tot_it_gen)
                    writer.add_scalar('generator distillation loss/train', ave_dist_loss / (it + 1), tot_it_gen)
                # global generator iteration counter used as the TensorBoard step
                tot_it_gen += 1

            # **************************** CALCULATE METRICS GENERATOR *************************************************
            # Runs at the end of every generator epoch; scalars are written with
            # the batch index i as step.

            generative_model.eval()
            # NOTE(review): these epoch averages are not read again in the
            # visible part of the file.
            ave_kld_loss /= it_x_ep
            ave_dist_loss /= it_x_ep
            ave_recon_loss_real /= it_x_ep
            ave_recon_loss_replay /= it_x_ep

            # --------------- Seen classes generation (sampling) accuracy + confusion matrix ---------------------
            # For each seen class: draw latent codes from a standard normal,
            # decode them conditioned on the class, and count how often the
            # classifier predicts that class for the decoded features.

            tot_correct = 0
            generative_model.eval()
            for c in seen_classes:
                # 128 images per class
                samples = np.random.multivariate_normal(torch.zeros(latent_size),
                                                        torch.eye(latent_size), 128)
                z = maybe_cuda(torch.from_numpy(samples), use_cuda=use_cuda)
                z = z.view(-1, latent_size).float()
                y_one_hot = maybe_cuda(F.one_hot(torch.full((128,), c, dtype=torch.long), num_classes).float(),
                                       use_cuda=use_cuda)
                # decoder input = latent code concatenated with the one-hot label
                z = torch.cat((z, y_one_hot), dim=1)
                with torch.no_grad():
                    recon = generative_model.decode(z).detach()
                    # decoded features go in through the `y` argument (x=None) —
                    # presumably injected at latent_layer; confirm in the model
                    output = classifier_model(x=None, y=recon, latent_layer=latent_layer, return_lat_acts=False)
                _, pred = output.max(1)
                tot_correct += (pred == c).sum().item()

            acc_sampling = tot_correct / (128 * len(seen_classes))
            print("---------------------------------")
            print("Seen classes generation accuracy: ", acc_sampling)
            print("---------------------------------")
            writer.add_scalar('generator sampling accuracy', acc_sampling, i)

            # ------------------------------- Train set reconstruction ------------------------------------
            # Encode/decode the features of the current training batch and
            # measure how often the classifier still predicts the true label.

            generative_model.eval()
            tot_recon_correct = 0

            test_loader = DataLoader(train_batch, batch_size=mb_size, shuffle=False,
                                     drop_last=False, num_workers=8, pin_memory=True)

            for ti, (x, y) in enumerate(test_loader):
                x = maybe_cuda(x, use_cuda=use_cuda)
                y = maybe_cuda(y, use_cuda=use_cuda)

                classifier_model.eval()
                with torch.no_grad():
                    # x is rebound from the input images to the classifier's second output
                    _, x = classifier_model(x=x, y=None, latent_layer=latent_layer, return_lat_acts=True)
                    x = x.detach()
                    recon_batch, _, _ = generative_model(x, torch.nn.functional.one_hot(y, num_classes).float())

                with torch.no_grad():
                    logits = classifier_model(x=None, y=recon_batch, latent_layer=latent_layer, return_lat_acts=False)

                _, pred_label = logits.max(1)
                tot_recon_correct += (pred_label == y).sum().item()

            acc_recon_test = tot_recon_correct * 1.0 / len(test_loader.dataset)

            print("All train set reconstruction accuracy: ", acc_recon_test)
            writer.add_scalar('generator train reconstruction accuracy', acc_recon_test, i)

            # ------------------------------- Test set reconstruction -------------------------------------
            # Same procedure on the cumulative test set, also averaging the
            # reconstruction loss over the mini-batches.

            generative_model.eval()
            tot_recon_correct = 0

            # this loader is reused further down to evaluate the retrained classifier
            test_loader = DataLoader(batch_info.get_cumulative_test_set(), batch_size=mb_size, shuffle=False,
                                     drop_last=False, num_workers=8, pin_memory=True)

            test_avg_recon_loss_test = 0
            for ti, (x, y) in enumerate(test_loader):
                x = maybe_cuda(x, use_cuda=use_cuda)
                y = maybe_cuda(y, use_cuda=use_cuda)

                classifier_model.eval()
                with torch.no_grad():
                    _, x = classifier_model(x=x, y=None, latent_layer=latent_layer, return_lat_acts=True)
                    x = x.detach()
                    recon_batch, _, _ = generative_model(x, torch.nn.functional.one_hot(y, num_classes).float())

                # both operands were produced under no_grad, so no graph is built here
                test_avg_recon_loss_test += reconstruction_loss(recon_batch, x).item()

                with torch.no_grad():
                    logits = classifier_model(x=None, y=recon_batch, latent_layer=latent_layer, return_lat_acts=False)

                _, pred_label = logits.max(1)
                tot_recon_correct += (pred_label == y).sum().item()

            # mean over mini-batches (len(test_loader) is the number of batches)
            test_avg_recon_loss_test /= len(test_loader)

            writer.add_scalar('generator recon loss/test', test_avg_recon_loss_test, i)

            acc_recon_test = tot_recon_correct * 1.0 / len(test_loader.dataset)

            print("All test set reconstruction accuracy: ", acc_recon_test)
            writer.add_scalar('generator test reconstruction accuracy', acc_recon_test, i)

            # Generator quality check: start a fresh classifier from the
            # first-batch checkpoint, train it only on features sampled from the
            # generator, and evaluate it on real test images (`test_loader`).
            retrained_classifier = resnet18(pretrained=False, num_classes=40)

            # map_location='cpu' lets the checkpoint load whatever device it was
            # saved from; load_state_dict copies the values into the model.
            retrained_classifier.load_state_dict(
                torch.load("../data/first_batch_models/seed_827_imagenet_40_batch0_45ep_noBias.pth",
                           map_location='cpu'))
            # replace the 40-way head with a num_classes-way head
            setattr(retrained_classifier, 'fc',
                    get_expanded_or_shrunken_head_from_model(retrained_classifier, 'fc', 40, num_classes))

            retrained_classifier = maybe_cuda(retrained_classifier, use_cuda=use_cuda)

            classificator_retrained_optimizer = torch.optim.SGD(params=retrained_classifier.parameters(), lr=0.1,
                                                                momentum=0.9, weight_decay=0.0005)

            # rows = predicted class, columns = true class
            confusion_matrix_retrain = np.zeros((num_classes, num_classes), dtype=int)

            generative_model.eval()
            retrained_classifier.train()
            if not freeze_up_to(retrained_classifier, freeze_below_layer=freeze_below_layer, only_conv=False):
                raise ValueError('Can\'t find freeze below layer ' + str(freeze_below_layer))

            # The number of generated patterns grows with the batch index:
            # 50,000 per incremental batch seen so far, in mini-batches of 128.
            features_for_training = 50000 * (i + 1)
            nr_batches_retrain = (features_for_training // 128) + 1

            for _ in range(nr_batches_retrain):
                # latent codes from a standard normal, labels drawn uniformly from the seen classes
                samples = np.random.multivariate_normal(torch.zeros(latent_size),
                                                        torch.eye(latent_size), 128)
                z = maybe_cuda(torch.from_numpy(samples).view(-1, latent_size).float(), use_cuda=use_cuda)
                lab = maybe_cuda(torch.from_numpy(np.random.choice(list(seen_classes), 128)), use_cuda=use_cuda)
                y_one_hot = maybe_cuda(F.one_hot(lab, num_classes).float(), use_cuda=use_cuda)
                z = torch.cat((z, y_one_hot), dim=1)
                with torch.no_grad():
                    recon = generative_model.decode(z).detach()
                classificator_retrained_optimizer.zero_grad()
                output = retrained_classifier(x=None, y=recon, latent_layer=latent_layer, return_lat_acts=False)
                loss = ce_loss(output, lab)
                loss.backward()
                classificator_retrained_optimizer.step()

            retrained_classifier.eval()
            tot_recon_correct = 0
            # Pure evaluation: no autograd graph is needed for the forward passes.
            with torch.no_grad():
                for ti, (x, y) in enumerate(test_loader):
                    x = maybe_cuda(x, use_cuda=use_cuda)
                    y = maybe_cuda(y, use_cuda=use_cuda)

                    logits = retrained_classifier(x)

                    _, pred_label = logits.max(1)
                    tot_recon_correct += (pred_label == y).sum()

                    # one host transfer per mini-batch; np.add.at accumulates
                    # correctly when the same (pred, true) pair occurs repeatedly
                    np.add.at(confusion_matrix_retrain,
                              (pred_label.cpu().numpy(), y.cpu().numpy()), 1)

            acc_recon = tot_recon_correct.item() * 1.0 / len(test_loader.dataset)
            print("Test retrain classifier accuracy: {}".format(acc_recon))
            writer.add_scalar('generator retraining classificator test accuracy', acc_recon, i)

            df_cm_retrain = pd.DataFrame(confusion_matrix_retrain, range(num_classes), range(num_classes))
            plot_buf = gen_plot(df_cm_retrain, np.max(confusion_matrix_retrain))
            image_retrain = PIL.Image.open(plot_buf)
            image_retrain = ToTensor()(image_retrain)
            writer.add_image('generator confusion matrix classificator retrain', image_retrain, i)

    # *************************************************************************************************************
    # *************************************** REPLAY MEMORY CALCULATION *******************************************
    # *************************************************************************************************************

    print("************************ MEMORY CALCULATION ****************************************************")

    # --------------------------- Classificator replay memory ---------------------------------
    # Rebuilds rm_classifier = [features, labels] from scratch with
    # rm_classificator_sz generated patterns, balanced over the seen classes.
    # During the first 5000 sampling rounds only patterns that the classifier
    # assigns to the requested label are kept; afterwards everything is kept.
    if rm_classificator_sz > 0:
        generated_features = None
        insert_per_class, offset = divmod(rm_classificator_sz, nr_seen_classes)

        # balanced list of labels still to be generated
        new_labels = []
        for c in seen_classes:
            new_labels.extend([c] * insert_per_class)
        if offset > 0:
            # random.sample() needs a sequence: passing a set raises TypeError
            # on Python >= 3.11 (deprecated since 3.9).
            offset_classes = random.sample(sorted(seen_classes), offset)
            new_labels.extend(offset_classes)

        random.shuffle(new_labels)

        generative_model.eval()
        classifier_model.eval()

        nr_inserted = 0
        nr_attempts = 0
        list_memory = []
        list_labels = []

        # generate latent features; inference only, so no autograd graph is built
        with torch.no_grad():
            while nr_inserted < rm_classificator_sz:
                b_size = min(128, len(new_labels))
                samples = np.random.multivariate_normal(torch.zeros(generative_model.latent_size),
                                                        torch.eye(generative_model.latent_size), b_size)
                # honour the configured use_cuda flag instead of a hard-coded True
                z = maybe_cuda(torch.from_numpy(samples).view(-1, generative_model.latent_size).float(),
                               use_cuda=use_cuda)
                # distinct positions of new_labels, so each removal below finds its value
                current_labels_list = random.sample(new_labels, b_size)
                # np.long was removed in NumPy 1.24; build the int64 tensor directly
                current_labels = maybe_cuda(torch.tensor(current_labels_list, dtype=torch.long),
                                            use_cuda=use_cuda)
                y_one_hot = maybe_cuda(torch.nn.functional.one_hot(current_labels, num_classes),
                                       use_cuda=use_cuda).type(torch.float)
                z = torch.cat((z, y_one_hot), dim=1)
                gen_features = generative_model.decode(z).detach()

                if nr_attempts < 5000:
                    classifier_output = classifier_model(x=None, y=gen_features, latent_layer=latent_layer,
                                                         return_lat_acts=False).detach()
                    _, preds = torch.max(classifier_output.data, 1)

                    # single host transfer of the per-pattern "predicted == requested" flags
                    acc_all = (preds == current_labels).tolist()

                    for cont, corr in enumerate(acc_all):
                        if corr:
                            list_memory.append(gen_features[cont].cpu())
                            list_labels.append(current_labels[cont].cpu())
                            new_labels.remove(current_labels_list[cont])
                            nr_inserted += 1
                    nr_attempts += 1
                else:  # after 5k attempts insert everything
                    for cont, feat in enumerate(gen_features):
                        list_memory.append(feat.cpu())
                        list_labels.append(current_labels[cont].cpu())
                        new_labels.remove(current_labels_list[cont])
                        nr_inserted += 1

        list_memory, list_labels = sklearn.utils.shuffle(list_memory, list_labels)  # for selection of patterns in memory
        generated_features = torch.stack(list_memory, dim=0)  # for selection of patterns in memory
        new_labels_fin = torch.stack(list_labels, dim=0)  # for selection of patterns in memory
        # drop the per-pattern lists once the stacked tensors exist
        list_memory = []
        list_labels = []
        print("memory dim")

        rm_classifier = [generated_features, new_labels_fin]

        print(rm_classifier[0].shape)
        print(rm_classifier[1].shape)
    else:
        # no replay memory requested: keep empty tensors so later code can still index [0]/[1]
        rm_classifier = [torch.as_tensor([], dtype=torch.float), torch.as_tensor([], dtype=torch.long)]

    # --------------------------- Generator replay memory ---------------------------------
    # Fill the generator replay memory with `rm_generator_sz` features decoded by
    # the generative model, label-balanced over `seen_classes`.
    # During the first 5000 batches a sample is stored only if the classifier
    # predicts the label the sample was conditioned on; afterwards every sample
    # is stored, so the loop always terminates.
    # Result: rm_generator = [stacked features, stacked labels].

    if rm_generator_sz > 0:
        print("Generator memory")

        insert_per_class, offset = divmod(rm_generator_sz, nr_seen_classes)

        # balanced pool of labels still to be generated: `insert_per_class` copies
        # of every seen class, plus `offset` distinct random classes so the pool
        # holds exactly `rm_generator_sz` entries
        new_labels = [c for c in seen_classes for _ in range(insert_per_class)]
        if offset > 0:
            new_labels.extend(random.sample(seen_classes, offset))

        random.shuffle(new_labels)

        generative_model.eval()
        classifier_model.eval()

        nr_inserted = 0
        nr_attempts = 0
        list_memory = []
        list_labels = []

        # generate latent features; this is pure inference, so no autograd graph is needed
        with torch.no_grad():
            while nr_inserted < rm_generator_sz:
                # every stored sample removes one entry from new_labels, so the
                # pool is non-empty for as long as the loop condition holds
                b_size = min(128, len(new_labels))

                # zero mean and identity covariance == i.i.d. standard normal draws
                samples = np.random.standard_normal((b_size, generative_model.latent_size))
                z = maybe_cuda(torch.from_numpy(samples).float(), True)

                current_labels_list = random.sample(new_labels, b_size)
                # np.int64 instead of the np.long alias (removed in NumPy 1.24); it matches torch.long
                current_labels = maybe_cuda(torch.from_numpy(np.asarray(current_labels_list, dtype=np.int64)),
                                            True)
                y_one_hot = maybe_cuda(torch.nn.functional.one_hot(current_labels, num_classes), True).float()
                gen_features = generative_model.decode(torch.cat((z, y_one_hot), dim=1))

                if nr_attempts < 5000:
                    classifier_output = classifier_model(x=None, y=gen_features, latent_layer=latent_layer,
                                                         return_lat_acts=False)
                    _, preds = torch.max(classifier_output, 1)
                    # keep only the samples the classifier assigns to their conditioning label
                    keep_mask = preds == current_labels
                    nr_attempts += 1
                else:  # after 5k attempts insert everything
                    keep_mask = torch.ones_like(current_labels, dtype=torch.bool)

                kept_flags = keep_mask.tolist()
                nr_kept = sum(kept_flags)
                if nr_kept:
                    # masked indexing copies only the kept rows, so the stored
                    # per-sample views do not keep the whole batch alive
                    list_memory.extend(gen_features[keep_mask].cpu().unbind(0))
                    list_labels.extend(current_labels[keep_mask].cpu().unbind(0))
                    for lbl, is_kept in zip(current_labels_list, kept_flags):
                        if is_kept:
                            new_labels.remove(lbl)
                    nr_inserted += nr_kept

        list_memory, list_labels = sklearn.utils.shuffle(list_memory, list_labels)  # for selection of patterns in memory
        generated_features = torch.stack(list_memory, dim=0)  # for selection of patterns in memory
        new_labels_fin = torch.stack(list_labels, dim=0)  # for selection of patterns in memory
        # drop the per-sample references now that the stacked copies exist
        list_memory = []
        list_labels = []
        print("memory dim")

        rm_generator = [generated_features, new_labels_fin]

        print(rm_generator[0].shape)
        print(rm_generator[1].shape)
    else:
        # replay disabled: empty feature / label tensors
        rm_generator = [torch.as_tensor([], dtype=torch.float), torch.as_tensor([], dtype=torch.long)]

# Runs once, after the indented block above has ended: close `writer`
# (the object used above for `add_image` logging).
writer.close()
