import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import torch
import torch.nn as nn
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
from torchvision import datasets, transforms
from math import e
import torch.optim as optim
import random

from art.attacks.evasion import ProjectedGradientDescent
from art.classifiers import PyTorchClassifier

# Module-wide flag: True when PyTorch can see a CUDA device.
use_cuda = torch.cuda.is_available()
if not use_cuda:
    # NOTE(review): presumably a workaround for the Intel OpenMP "duplicate runtime" abort on
    # some CPU-only installs; confirm before removing.
    os.environ.update(KMP_DUPLICATE_LIB_OK='True')


class SoftHebbLayer(nn.Module):
    '''Softmax winner-take-all (WTA) linear layer with Hebbian plasticity à la STDP of Nessler et al.

    ``forward`` computes the preactivations ``u = x @ W.T + b`` and the WTA output ``y``:
    a softmax in base ``base`` for ``softness='soft'``, or a one-hot of the argmax for
    ``softness='hard'``. With ``adapt=True`` it then updates the weights in place with
    the SoftHebb rule ``dW = lr * sum_batch y * (x - u * W)``. For the ``'exp'`` activation
    it also updates the biases with ``db = lr_b * lr * base**-b * sum_batch (y - base**b)``,
    clipped to [-50, 0].

    Weights and biases are plain tensors (not ``nn.Parameter``s or buffers). Autograd and
    ``Module.to()`` do not touch them; they are allocated on ``device`` directly.

    Only ``weight_init='positive'``, ``plasticity='SoftHebb'`` and ``softness`` in
    ``('soft', 'hard')`` are implemented; other values raise ``ValueError``.
    ``softness_change_point``, ``delta``, ``from_k``, ``up_to_k`` and ``decrease`` are
    stored as attributes but are not read by this class.
    '''
    def __init__(self,
                 in_channels,
                 n_neurons,
                 learning_rate,
                 softness,
                 softness_change_point,
                 plasticity,
                 bias,
                 learning_rate_b,
                 bias_training_start_point,
                 normalize_weights,
                 activation_fn,
                 weight_init,
                 weight_init_range,
                 delta,
                 from_k,
                 up_to_k,
                 decrease,
                 base,
                 device,
                 min_learning_rate,
                 normalize_inp,
                 learning_rate_decay=None):
        super(SoftHebbLayer, self).__init__()
        if weight_init != 'positive':
            raise ValueError("Unsupported weight_init %r; only 'positive' is implemented." % (weight_init,))

        self.model_or_layer_flag = 'layer'
        self.in_channels = in_channels
        self.n_neurons = n_neurons
        self.activation_fn = activation_fn

        self.learning_rate = learning_rate
        self.learning_rate_init = self.learning_rate  # reference value for the learning-rate schedule
        self.softness = softness
        self.softness_change_point = softness_change_point
        self.plasticity = plasticity
        self.exp = False
        self.bias = bias  # replaced by a tensor below when bias is enabled, stays falsy otherwise
        self.learning_rate_b = torch.tensor(learning_rate_b).to(device)
        self.bias_training_start_point = bias_training_start_point
        self.normalize_weights = normalize_weights
        self.delta = delta
        self.from_k = from_k
        self.up_to_k = up_to_k
        self.decrease = decrease
        # Uniform in [0.1, 0.1 + weight_init_range): all weights start positive.
        self.weight = weight_init_range * torch.rand((n_neurons, in_channels)) + 0.1
        self.delta_w = torch.zeros((n_neurons, in_channels)).to(device)
        self.learning_rate_a = torch.tensor(1.)
        if self.bias:
            # Uniform initial priors log(1/n_neurons), expressed in the softmax's base.
            self.bias = torch.ones(self.n_neurons) \
                        * torch.log(torch.tensor(1 / self.n_neurons)) / torch.log(torch.tensor(base))
            self.delta_b = torch.zeros(n_neurons)

            self.bias = self.bias.to(device)
            self.delta_b = self.delta_b.to(device)
        self.learning_rate_decay = learning_rate_decay
        self.adapted_steps = 0
        self.preactivations = torch.zeros(self.n_neurons).to(device)
        self.w_times_x = torch.zeros(self.n_neurons).to(device)
        self.wta = torch.zeros(self.n_neurons).to(device)

        self.activity_counter = torch.zeros(self.n_neurons).long()
        self.ensemble = False
        self.test_uses_softmax = False
        self.neuron_labels = None
        self.base = torch.tensor(base).to(device)
        self.weight = self.weight.to(device)
        self.device = device
        self.min_learning_rate = min_learning_rate
        self.normalize_inp = normalize_inp

    def get_wta(self, softness=None):
        """Return the WTA output for the current ``self.preactivations`` (batch, n_neurons).

        ``'soft'`` uses the module-level ``softmax`` in base ``self.base``; ``'hard'`` returns
        a float one-hot of the argmax. ``softness`` defaults to ``self.softness``.
        """
        if softness is None:
            softness = self.softness
        if softness == 'soft':
            return softmax(self.preactivations, self.base, self.activation_fn)
        if softness == 'hard':
            winners = self.preactivations.argmax(dim=1)
            return nn.functional.one_hot(winners, num_classes=self.preactivations.shape[1]).to(torch.float)
        raise ValueError("Unsupported softness %r; expected 'soft' or 'hard'." % (softness,))

    def plasticity_fn(self, inp):
        """Apply one SoftHebb update to ``self.weight`` (and ``self.bias``) for batch ``inp``.

        Uses ``self.wta`` and ``self.preactivations`` from the preceding forward pass.
        """
        if self.plasticity != 'SoftHebb':
            raise ValueError("Unsupported plasticity %r; only 'SoftHebb' is implemented." % (self.plasticity,))
        wta_plast = self.wta
        yu = torch.multiply(wta_plast, self.preactivations)  # = y*u, shape (batch, n_neurons)
        pt1 = torch.matmul(wta_plast.t(), inp)  # = sum_batch y*x
        # = sum_batch y*u*w. W does not depend on the sample, so sum y*u over the batch first
        # instead of materialising a (batch, n_neurons, in_channels) tensor.
        pt2 = yu.sum(dim=0).unsqueeze(1) * self.weight
        ds = pt1 - pt2  # = y*x - y*u*w = y*(x - u*w)
        self.delta_w = (self.learning_rate_a * self.learning_rate) * ds
        self.weight += self.delta_w
        if self.bias is not False:
            if self.activation_fn == 'exp':
                # Neuron priors base**b; exp() for the natural base.
                prior = torch.exp(self.bias) if self.base == e else torch.pow(self.base, self.bias)
                # eta / base**b * sum_batch (y - base**b)
                self.delta_b = self.learning_rate_b * self.learning_rate / prior \
                    * torch.sum(self.wta - prior, dim=0)
                self.bias += self.delta_b
                self.bias.clip_(-50, 0)  # prevents permanently inactive neurons from getting -inf biases
            elif self.activation_fn == 'relu':
                # NOTE(review): delta_b is computed but never added to self.bias on this path;
                # confirm whether that is intended.
                self.delta_b = self.learning_rate_b * self.learning_rate \
                               * (torch.sum(self.wta, dim=0) - self.wta.shape[0]*self.bias - self.wta.shape[0])  # eta * (y-w-1)

        if self.normalize_weights:
            # Zero-mean, unit-L2-norm rows; rows with zero norm are left un-scaled.
            self.weight -= self.weight.mean(dim=1, keepdim=True)
            norms = self.weight.norm(dim=1, keepdim=True)
            self.weight /= torch.where(norms != 0, norms, torch.ones_like(norms))

    def forward(self, inp, adapt=True):
        """Flatten ``inp`` per sample, compute the WTA output, and learn from it if ``adapt``."""
        inp = inp.view((inp.shape[0], -1))
        if self.normalize_inp:
            inp = 10 * nn.functional.normalize(inp)  # rows scaled to L2 norm 10
        self.w_times_x = inp @ self.weight.t()
        if self.bias is not False:
            self.preactivations = torch.add(self.w_times_x, self.bias)
        else:
            self.preactivations = self.w_times_x
        self.wta = self.get_wta()
        if adapt:
            self.plasticity_fn(inp)
        return self.wta


class SoftHebbModel2Layers(nn.Module):
    """A ``SoftHebbLayer`` followed by a linear readout with 10 outputs.

    ``hparams`` is passed verbatim to ``SoftHebbLayer``; ``hparams['n_neurons']`` sets the
    readout's input size. The readout is placed on the same device as the SoftHebb layer.
    """
    def __init__(self, hparams):
        super(SoftHebbModel2Layers, self).__init__()
        self.isSoftHebb = True
        self.fc_wta = SoftHebbLayer(**hparams)
        self.fc = nn.Linear(hparams['n_neurons'], 10)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.softmax_out = nn.Softmax()
        self.tanh_out = nn.Tanh()
        self.model_or_layer_flag = 'model'
        self.device = self.fc_wta.device
        # The SoftHebb layer already allocates its tensors on `device`. Put the readout on the
        # same device; `.cuda()` would pick the current default GPU, which differs from
        # `device` for a non-zero GPU index or an explicit CPU device.
        self.fc.to(self.device)

    def forward(self, x, adapt, wta_prop=True):
        """Return readout logits.

        Args:
            x: input batch, flattened per sample by the SoftHebb layer.
            adapt: forwarded to the SoftHebb layer (True applies plasticity).
            wta_prop: if True, feed the WTA output to the readout. Otherwise feed the
                preactivations (``base ** u`` for the ``'exp'`` activation).
        """
        if wta_prop:
            x = self.fc_wta(x, adapt)
        else:
            # Still run the layer: it refreshes `preactivations` and applies plasticity if adapt.
            self.fc_wta(x, adapt)
            x = self.fc_wta.preactivations
            if self.fc_wta.activation_fn == 'exp':
                x = torch.pow(self.fc_wta.base, x)
        return self.fc(x)


class MLP(nn.Module):
    """Fully connected 784 -> n_hidden -> 10 classifier with a ReLU hidden layer.

    ``forward`` flattens each input to 28*28 values and returns raw logits (no softmax).
    """
    def __init__(self, device, n_hidden):
        super().__init__()
        n_pixels = 28 * 28
        self.isSoftHebb = False
        self.fc1 = nn.Linear(n_pixels, n_hidden)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(n_hidden, 10)
        self.softmax_out = nn.Softmax()
        self.device = device

    def forward(self, x):
        flat = x.view(-1, 28 * 28)
        hidden = self.relu1(self.fc1(flat))
        return self.fc2(hidden)


def init_script(gpu_id=0):
    """Turn on matplotlib interactive mode and return the compute device.

    Returns ``cuda:<gpu_id>`` when CUDA is available, otherwise the CPU device.
    """
    plt.ion()
    if torch.cuda.is_available():
        return torch.device('cuda:' + str(gpu_id))
    return torch.device('cpu')


def set_presets(preset):
    """Return the hyper-parameter dictionary of a named preset.

    The dictionary holds every SoftHebbLayer hyper-parameter plus ``preset``, ``seed``,
    ``epochs`` and ``batch_size``. Keys are in alphabetical order.

    Raises:
        ValueError: if ``preset`` is not a known preset name.
    """
    print('presets:', preset)
    # Settings common to every preset.
    shared = dict(
        min_learning_rate=0,
        n_neurons=2000,
        softness='soft',
        softness_change_point=None,
        plasticity='SoftHebb',
        bias=True,
        bias_training_start_point=0,  # epoch
        normalize_weights=False,
        activation_fn='exp',
        weight_init='positive',
        delta=0.,
        from_k=1,
        decrease=1.1,
        weight_init_range=0.05,
        normalize_inp=True,
    )
    per_preset = {
        'softhebb2000normb0.05_base1e6_1epoch':
            dict(base=1.e6, seed=4, learning_rate=0.05, epochs=1, batch_size=1),
        'softhebb2000normb0.03_base1000':
            dict(base=1000., seed=18, learning_rate=0.03, epochs=100, batch_size=128),
    }
    if preset not in per_preset:
        raise ValueError("No valid presets defined.")
    values = {**shared, **per_preset[preset], 'preset': preset}
    # Derived settings.
    values['learning_rate_b'] = 1 / values['n_neurons']
    values['up_to_k'] = values['from_k'] + 1
    return dict(sorted(values.items()))


def seed_init_fn(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same value.

    Also used as a DataLoader ``worker_init_fn``.

    Args:
        seed: integer seed, reduced modulo 2**32 because ``np.random.seed`` only accepts
            values in [0, 2**32). ``None`` leaves every RNG untouched, matching how
            ``seed=None`` means "don't seed" elsewhere in this file.
    """
    if seed is None:
        return
    seed = seed % 2**32
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)


def make_data_loaders(batch_sizes=None, shuffle=False, training_examples_used=60000, dataseed=None):
    """Build the MNIST train and test DataLoaders.

    Args:
        batch_sizes: ``[train_batch_size, test_batch_size]``; defaults to ``[128, 10000]``.
        shuffle: shuffle the kept training indices (NumPy RNG) and the test loader.
        training_examples_used: keep only the first this-many of the 60000 training indices.
        dataseed: if given, reseed all RNGs via ``seed_init_fn`` first.

    Returns:
        ``(train_loader, test_loader)``. Images only go through ``ToTensor`` (no mean/std
        normalisation). The train loader always samples its subset in random order.
    """
    # NOTE(review): MNIST_PATH is not defined in the visible part of this file; confirm it is
    # a module-level constant elsewhere.
    if batch_sizes is None:
        batch_sizes = [128, 10000]
    if dataseed is not None:
        seed_init_fn(dataseed)
    kept_indices = list(range(60000))[:training_examples_used]
    if shuffle:
        np.random.shuffle(kept_indices)
    subset_sampler = SubsetRandomSampler(kept_indices)
    to_tensor = transforms.Compose([transforms.ToTensor()])

    mnist_train = datasets.MNIST(MNIST_PATH, train=True, download=True, transform=to_tensor)
    train_loader = torch.utils.data.DataLoader(mnist_train,
                                               batch_size=batch_sizes[0],
                                               sampler=subset_sampler,
                                               worker_init_fn=seed_init_fn)

    mnist_test = datasets.MNIST(MNIST_PATH, train=False, transform=to_tensor)
    test_loader = torch.utils.data.DataLoader(mnist_test,
                                              batch_size=batch_sizes[1],
                                              shuffle=shuffle,
                                              worker_init_fn=seed_init_fn)
    return train_loader, test_loader


def softmax(x, base=torch.tensor(e), activation_fn='exp', dim=1):
    """Softmax along ``dim`` with an arbitrary exponential base.

    Since ``base**x_i / sum_j base**x_j == softmax(x * ln(base))``, a non-natural base only
    rescales the logits.

    Args:
        x: logits tensor.
        base: exponent base, as a tensor or a Python number.
        activation_fn: only ``'exp'`` is supported.
        dim: dimension to normalise over.

    Raises:
        ValueError: for any ``activation_fn`` other than ``'exp'`` (instead of returning None).
    """
    if activation_fn != 'exp':
        raise ValueError("softmax only supports activation_fn='exp', got %r" % (activation_fn,))
    if base == e:
        return torch.softmax(x, dim)
    # as_tensor lets plain float bases through; tensors are used unchanged.
    return torch.softmax(x * torch.log(torch.as_tensor(base)), dim)


def lr_scheduler_lin_perbatch(model, n_samples, epochs, batch_size, dataset_size=60000):
    """Linear per-batch learning-rate decay.

    Each call subtracts ``model.learning_rate_init / (epochs * ceil(dataset_size / batch_size))``,
    so the rate reaches 0 after ``epochs`` full epochs.

    Args:
        model: object with ``learning_rate``, ``learning_rate_init`` and ``min_learning_rate``.
        n_samples: step counter; no decay is applied when it is <= 0.
        epochs: total number of training epochs.
        batch_size: training batch size.
        dataset_size: number of training examples per epoch (MNIST train set by default).

    Returns:
        The new learning rate, floored at 0. The rate is returned unchanged once it is
        below ``model.min_learning_rate``.
    """
    batches_per_epoch = np.ceil(dataset_size / batch_size)
    decrease_lin = model.learning_rate_init / (epochs * batches_per_epoch)
    if n_samples > 0 and model.learning_rate >= model.min_learning_rate:
        if n_samples % 1000 == 0:
            print('adapted rate ', model.learning_rate - decrease_lin)
        return max(model.learning_rate - decrease_lin, 0)
    return model.learning_rate


def train_hebb_layer(wta_model, batch_idx_start, n_batches, show_intermediate, fig_counter, subfolder,
                     train_loader, epochs, set_loss_every=500, starting_epoch=0, get_loss=False,
                     save_atevery_batch=False, replayfromdisk=False, also_setloss=False, dataseed=None):
    """Train the SoftHebb layer without gradients, one pass over ``train_loader`` per epoch.

    In each epoch only batches ``batch_idx_start .. batch_idx_start + n_batches - 1`` are used
    for learning. The learning rate is still updated through ``learning_rate_decay`` after
    every batch, skipped ones included. Bias learning is off (rate 0) until epoch
    ``bias_training_start_point``.

    Args:
        wta_model: a ``SoftHebbModel2Layers`` (``model_or_layer_flag == 'model'``) or a bare
            ``SoftHebbLayer``. ``get_loss`` and ``replayfromdisk`` need the whole model.
        batch_idx_start, n_batches: window of batch indices trained in each epoch.
        show_intermediate: if truthy, save a weight figure every ``show_intermediate`` samples.
        fig_counter: running index used to name saved figures.
        subfolder: directory under ``./models_saved/`` for figures and checkpoints.
        train_loader: DataLoader yielding ``(data, target)``.
        epochs: total number of epochs; training covers ``range(starting_epoch, epochs)``.
        set_loss_every: samples between whole-set loss evaluations (with ``also_setloss``).
        starting_epoch: first epoch to run (for resuming).
        get_loss: track cross-entropy losses of the readout during training.
        save_atevery_batch: save ``(weight, bias)`` before every batch.
        replayfromdisk: before each step, load the ``(weight, bias)`` saved for that batch
            index if the file exists.
        also_setloss: also compute the loss on the whole train/test sets.
        dataseed: seed for the extra loaders built for loss evaluation.

    Returns:
        ``(batch_idx, loss_running_history, loss_train_history, loss_test_history)`` if
        ``get_loss``, else ``(batch_idx, fig_counter)``. ``batch_idx`` is the last batch index seen.
    """
    print('Training SoftHebb layer')
    if replayfromdisk:
        print('Will attempt to replay from disk')
    if save_atevery_batch:
        print('Will save weights before every batch')
    device = wta_model.device
    message_printed = False
    savedir_fig = './models_saved/' + subfolder + '/figs/'
    savedir = './models_saved/' + subfolder + '/every_batch/'
    # makedirs also creates missing parents (e.g. './models_saved/'), where os.mkdir would raise.
    os.makedirs(savedir_fig, exist_ok=True)
    os.makedirs(savedir, exist_ok=True)
    if wta_model.model_or_layer_flag == 'model':
        wta_layer = wta_model.fc_wta
        if get_loss:
            criterion = nn.CrossEntropyLoss()
            loss_running_history = []
            loss_train_history = []
            loss_test_history = []
            loss_train_trainingbatch_avg, loss_test_trainingbatch_avg = 0, 0
            loss_train_loader, loss_test_loader = make_data_loaders([60000, 10000], dataseed=dataseed)
    else:
        if get_loss:
            raise Exception("Can't get loss if 'wta_model' is only a layer. Needs whole model.")
        wta_layer = wta_model

    # Bias learning is disabled until `bias_training_start_point`. Keep the configured rate
    # before zeroing it; re-enabling from the zeroed attribute would keep the bias frozen.
    # The value is cached on the layer so a resumed call still finds the configured rate.
    learning_rate_b_target = getattr(wta_layer, 'learning_rate_b_target', None)
    if learning_rate_b_target is None:
        learning_rate_b_target = torch.as_tensor(wta_layer.learning_rate_b).to(device)
        wta_layer.learning_rate_b_target = learning_rate_b_target
    wta_layer.learning_rate_b = torch.zeros(1).to(device)
    batch_idx_global = 0

    def replay_fn(message_printed, whole_model=False):
        # Load the weights saved for this batch (or, the first time one is missing, those of
        # the previous batch), then run a training step. Uses `data` and `batch_idx_global`
        # from the enclosing loop.
        if not os.path.exists(
                './models_saved/' + subfolder + '/every_batch/batch' + str(batch_idx_global) + '.pt'):
            if not message_printed:
                print("Couldn't replay this batch's training from disk")
                message_printed = True
                if batch_idx_global > 0:
                    w_prev, b_prev = torch.load(
                        './models_saved/' + subfolder + '/every_batch/batch' + str(batch_idx_global - 1) + '.pt',
                        map_location=torch.device(device))
                    wta_model.fc_wta.weight = w_prev
                    wta_model.fc_wta.bias = b_prev
        else:
            w, b = torch.load(
                './models_saved/' + subfolder + '/every_batch/batch' + str(batch_idx_global) + '.pt',
                map_location=torch.device(device))
            wta_model.fc_wta.weight = w
            wta_model.fc_wta.bias = b
        if whole_model:
            output = wta_model(data.to(device), adapt=True, wta_prop=True)
        else:
            output = wta_layer(data.to(device), adapt=True)
        return output, message_printed

    def training_step(message_printed):
        # One plasticity step on the current `data` batch, optionally checkpointed/replayed.
        if save_atevery_batch:
            torch.save((wta_layer.weight, wta_layer.bias),
                       savedir + 'batch' + str(batch_idx_global) + '.pt')
        if replayfromdisk:
            output, message_printed = replay_fn(message_printed, whole_model=True)
        else:
            output = wta_model(data.to(device), adapt=True, wta_prop=True)
        return output, message_printed

    def get_running_loss_at_t(message_printed):
        # Training step plus the readout's cross-entropy on the same batch.
        output, message_printed = training_step(message_printed)
        loss_running_trainingbatch_avg = criterion(output, target.to(device))
        return loss_running_trainingbatch_avg, message_printed

    with torch.no_grad():
        for epoch in range(starting_epoch, epochs):
            loss_running_thistrainingepoch_sum = 0
            loss_test_thistrainingepoch_sum = 0
            loss_train_thistrainingepoch_sum = 0
            bias_start = wta_layer.bias_training_start_point
            if wta_layer.bias is not False and bias_start is not None and epoch >= bias_start:
                wta_layer.learning_rate_b = learning_rate_b_target
            for batch_idx, (data, target) in enumerate(train_loader):
                # Read before the skip check: the learning-rate schedule below needs it for every batch.
                batch_size = target.shape[0]
                if batch_idx >= batch_idx_start:
                    if batch_idx == batch_idx_start + n_batches:
                        break
                    if get_loss:
                        loss_running_trainingbatch_avg, message_printed = get_running_loss_at_t(message_printed)
                        if also_setloss and (batch_idx_global % (set_loss_every/data.shape[0]) == 0):
                            loss_train_trainingbatch_avg, loss_test_trainingbatch_avg \
                                = get_set_loss_at_t(loss_train_loader, loss_test_loader, wta_model, criterion)
                        loss_running_thistrainingepoch_sum, loss_train_thistrainingepoch_sum, loss_test_thistrainingepoch_sum, loss_train_history, loss_test_history = update_loss_history(
                            batch_idx_global, loss_running_thistrainingepoch_sum, loss_train_thistrainingepoch_sum,
                            loss_test_thistrainingepoch_sum, loss_running_trainingbatch_avg,
                            loss_train_trainingbatch_avg, loss_test_trainingbatch_avg, epochs, set_loss_every/data.shape[0],
                            loss_running_history, loss_train_history, loss_test_history, batch_idx, train_loader)
                    else:
                        output, message_printed = training_step(message_printed)
                    print_every = 2048
                    if (batch_idx*batch_size) % print_every == 0 or n_batches <= 100:
                        # NOTE(review): single-argument torch.where returns a tuple with one entry
                        # per dimension, so this condition is always True for a 1-D wta row.
                        if len(torch.where(wta_layer.wta[0] == 1)) > 0:
                            print('\r', 'training epoch:', epoch, ', sample:', '{}'.format(batch_idx*batch_size), ', winner:',
                                  torch.argmax(wta_layer.preactivations[0]), end='')
                        else:
                            print('\r', 'training epoch:', epoch, ', sample:', '{}'.format(batch_idx*batch_size),
                                  'no_output', end='')
                    if show_intermediate and (batch_idx*batch_size) % show_intermediate == 0 and batch_idx+epoch > 0:
                        show_weights(wta_layer.weight, wta_layer.exp, base=wta_layer.base, norm_each_digit=True)
                        plt.pause(0.0002)
                        fig_counter += 1
                        plt.savefig(savedir_fig+str(fig_counter))
                        plt.close()
                batch_idx_global += 1
                wta_layer.learning_rate = wta_layer.learning_rate_decay(wta_layer, batch_idx_global+1, epochs,
                                                                        batch_size)
            if (epoch+1) % 10 == 0 and (wta_layer.n_neurons == 2000 or wta_layer.n_neurons == 4000):
                torch.save((wta_layer.weight, wta_layer.bias),
                           savedir+str(epoch)+'.pt')
    print('Training complete.')
    if get_loss:
        return batch_idx, loss_running_history, loss_train_history, loss_test_history
    else:
        return batch_idx, fig_counter


def test_hebb_layer(wta_layer, train_loader, test_loader, fig_counter=0, plots_bool=False):
    """Evaluate a trained SoftHebb layer as a stand-alone classifier.

    Each neuron gets the label of the class it wins most often on ``train_loader``. Neurons
    that never win fall back to their summed soft activations per class. Test samples are
    classified by their winning neuron's label, or by per-label ensembles when
    ``wta_layer.ensemble`` is set. Training and test accuracy are printed.

    Args:
        wta_layer: trained ``SoftHebbLayer``; its ``neuron_labels`` attribute is overwritten.
        train_loader, test_loader: DataLoaders whose ``dataset`` has a ``data`` attribute.
        fig_counter: index used to name saved figures.
        plots_bool: if True, draw and save the learned weights.

    Returns:
        The test accuracy (fraction of correct predictions) as computed by ``get_accuracy``.
    """
    device = wta_layer.device
    wta_layer.neuron_labels = None

    def infer_dataset(data_loader, also_test=True):
        # Run the layer without plasticity and collect winners, targets and (optionally) outputs.
        wta_layer.adapt = False
        targets = []
        winner_ids = []
        preactivations = []
        predictions = []
        n_samples = data_loader.dataset.data.shape[0]
        n_batches_seen = 0
        for batch_idx, (data, target) in enumerate(data_loader):
            data = data.to(device)
            target = target.to(device)
            if batch_idx == n_samples:
                break
            prediction = wta_layer(data, adapt=False)
            preactivation = wta_layer.preactivations
            if also_test:
                preactivations.append(preactivation)
                predictions.append(prediction)
            targets.append(target)
            winner_ids.append(prediction.argmax(dim=1))
            n_batches_seen += 1
            if batch_idx % 500 == 0:
                print('\rTesting:[{}]'.format(batch_idx), end='')
        print('Inferred ', str(n_batches_seen), ' batches')
        winner_ids = torch.cat(winner_ids).to(device=device, dtype=torch.int64)
        targets = torch.cat(targets).to(device=device, dtype=torch.int64)
        if not also_test:
            return winner_ids, targets
        preactivations = torch.cat(preactivations).to(device)
        predictions = torch.cat(predictions).to(device)
        return winner_ids, targets, preactivations, predictions, data_loader

    def get_neuron_labels(winner_ids, targets, preactivations):
        # Label each neuron with the class it won most often; neurons that never won use
        # their summed soft activations per class instead.
        targets_onehot = nn.functional.one_hot(targets, num_classes=preactivations.shape[1]).to(torch.float32)
        winner_ids_onehot = nn.functional.one_hot(winner_ids, num_classes=preactivations.shape[1]).to(torch.float32)
        responses_matrix = torch.matmul(winner_ids_onehot.t(), targets_onehot)

        activations_soft = softmax(preactivations, wta_layer.base, dim=1)
        neuron_outputs_for_label_total = torch.matmul(activations_soft.t(), targets_onehot)

        never_won = responses_matrix.sum(dim=1) == 0
        responses_matrix[never_won] = neuron_outputs_for_label_total[never_won]
        neuron_labels = responses_matrix.argmax(1)
        return neuron_labels

    def get_accuracy(winner_ids, targets, preactivations, neuron_labels):
        # Fraction of samples whose predicted label equals the target; printed and returned.
        n_samples = preactivations.shape[0]
        if not wta_layer.ensemble:
            # Vectorised lookup of each sample's winning-neuron label.
            predlabels = neuron_labels[winner_ids].to(torch.float32)
        else:
            if wta_layer.test_uses_softmax:
                soft_acts = softmax(preactivations, wta_layer.base, wta_layer.activation_fn, dim=1)
                winner_ensembles = [
                    np.argmax([np.sum(np.where(neuron_labels == ensemble, soft_acts[sample], np.asarray(0))) for
                               ensemble in range(10)]) for sample in range(n_samples)]
            else:
                winner_ensembles = [
                    np.argmax([np.sum(np.where(neuron_labels == ensemble, preactivations[sample], np.asarray(0))) for
                               ensemble in range(10)]) for sample in range(n_samples)]
            predlabels = winner_ensembles
        correct_pred = predlabels == targets
        n_correct = correct_pred.sum()
        accuracy = n_correct / len(targets)
        print(accuracy)

        return accuracy

    if wta_layer.neuron_labels is None:
        winner_ids, targets, preactivations, predictions, last_data_loader = infer_dataset(data_loader=train_loader)
        neuron_labels = get_neuron_labels(winner_ids, targets, preactivations)
        wta_layer.neuron_labels = neuron_labels
        print("Training accuracy (single SoftHebb layer):")
        get_accuracy(winner_ids, targets, preactivations, neuron_labels)

    winner_ids, targets, preactivations, predictions, last_data_loader = infer_dataset(data_loader=test_loader)

    def plots_posttraining(wta_layer, fig_counter):
        # Save the raw weights and, with biases, the weights scaled by base**bias.
        device = wta_layer.device
        if wta_layer.n_neurons == 2000 or wta_layer.n_neurons == 4000:
            show_weights(wta_layer.weight, wta_layer.exp, norm_each_digit=True, rows=40, col=50)
        else:
            show_weights(wta_layer.weight, wta_layer.exp, norm_each_digit=False)
        fig_counter += 1
        plt.savefig('figs_gpu'+str(device.index)+'/'+str(fig_counter))
        plt.close()
        if wta_layer.bias is not False:
            w = wta_layer.weight.clone().t()*torch.pow(wta_layer.base, wta_layer.bias)
            w = w.t()
            show_weights(w, wta_layer.exp, base=wta_layer.base, norm_each_digit=False)
            plt.pause(0.0001)
            fig_counter += 1
            plt.savefig('figs_gpu'+str(device.index)+'/' + str(fig_counter))
            plt.close()
        if use_cuda:
            plt.close('all')

    if plots_bool:
        print('Drawing figure of trained weights of all neurons, please wait a bit.')
        plots_posttraining(wta_layer, fig_counter)

    print("Testing accuracy (single SoftHebb layer):")
    accuracy_test = get_accuracy(winner_ids, targets, preactivations, wta_layer.neuron_labels)
    return accuracy_test


def train_sgd_model(multilayer_model, epochs, lr, train_loader, optimizer=optim.Adam, also_setloss=False,
                    set_loss_every=500, dataseed=None, adapt=False, wta_prop=True):
    """Train ``multilayer_model`` with a gradient-based optimizer and cross-entropy loss.

    Args:
        multilayer_model: model with ``isSoftHebb`` and ``device`` attributes. SoftHebb models
            are called as ``model(x, adapt, wta_prop)``, others as ``model(x)``.
        epochs: number of passes over ``train_loader``.
        lr: learning rate given to ``optimizer``.
        train_loader: DataLoader yielding ``(inputs, labels)``.
        optimizer: optimizer class, instantiated on ``multilayer_model.parameters()``.
        also_setloss: periodically evaluate the loss on the whole train and test sets.
        set_loss_every: samples between whole-set loss evaluations.
        dataseed: seed passed to ``make_data_loaders`` for the whole-set evaluation loaders.
        adapt, wta_prop: forwarded to SoftHebb models.

    Returns:
        ``(loss_running_history, loss_train_history, loss_test_history)``.
    """
    loss_running_history = []
    loss_train_history = []
    loss_test_history = []
    loss_train_trainingbatch_avg, loss_test_trainingbatch_avg = 0, 0
    loss_train_loader, loss_test_loader = make_data_loaders([60000, 10000], dataseed=dataseed)

    criterion = nn.CrossEntropyLoss()
    log_every = 10000  # training examples between progress prints
    batch_idx_global = 0
    optimizer = optimizer(multilayer_model.parameters(), lr=lr)
    loss_running_thistrainingepoch_sum = 0
    loss_train_thistrainingepoch_sum = 0
    loss_test_thistrainingepoch_sum = 0
    device = multilayer_model.device
    for epoch in range(epochs):  # loop over the dataset multiple times

        running_loss = 0.0
        for batch_idx, data in enumerate(train_loader, 0):
            # data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            # forward + backward + optimize
            if multilayer_model.isSoftHebb:
                outputs = multilayer_model(inputs, adapt, wta_prop)
            else:
                outputs = multilayer_model(inputs)
            loss_running_trainingbatch_avg = criterion(outputs, labels)

            # ########### update loss history log
            if also_setloss and (batch_idx % (set_loss_every/inputs.shape[0]) == 0):  # get loss on whole sets
                loss_train_trainingbatch_avg, loss_test_trainingbatch_avg = get_set_loss_at_t(loss_train_loader,
                                                                                              loss_test_loader, multilayer_model,
                                                                                              criterion)
            loss_running_thistrainingepoch_sum, loss_train_thistrainingepoch_sum, loss_test_thistrainingepoch_sum, loss_train_history, loss_test_history = update_loss_history(
                batch_idx_global, loss_running_thistrainingepoch_sum, loss_train_thistrainingepoch_sum,
                loss_test_thistrainingepoch_sum, loss_running_trainingbatch_avg, loss_train_trainingbatch_avg,
                loss_test_trainingbatch_avg, epochs, set_loss_every/inputs.shape[0], loss_running_history, loss_train_history,
                loss_test_history, batch_idx, train_loader)

            loss_running_trainingbatch_avg.backward()
            optimizer.step()

            # print statistics every `log_every` training examples
            running_loss += loss_running_trainingbatch_avg.item()
            batch_size = labels.shape[0]
            # At least 1 batch, so the modulo below is never by zero (a batch larger than
            # `log_every` would otherwise give a NaN remainder and never log).
            log_every_batches = max(1, int(np.floor(log_every / batch_size)))
            if batch_idx % log_every_batches == log_every_batches - 1:
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, batch_idx + 1, running_loss / log_every_batches))
                running_loss = 0.0
            batch_idx_global += 1
    print('Finished Training')
    return loss_running_history, loss_train_history, loss_test_history


def test_multilayer_model(multilayer_model, adapt=False, wta_prop=True, noiseattack=False, dataseed=None):
    """Print and return the top-1 accuracy (in percent) of a model on the MNIST test set.

    Args:
        multilayer_model: model with ``isSoftHebb`` and ``device`` attributes.
        adapt, wta_prop: forwarded to SoftHebb models.
        noiseattack: add standard-normal noise to every test image before inference.
        dataseed: seed passed to ``make_data_loaders``.
    """
    _, test_loader = make_data_loaders([10000, 10000], training_examples_used=10000, dataseed=dataseed)
    n_correct = 0
    n_seen = 0
    device = multilayer_model.device
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            if noiseattack:
                images += torch.randn(images.shape).to(device)
            labels = labels.to(device)
            if multilayer_model.isSoftHebb:
                outputs = multilayer_model(images, adapt, wta_prop)
            else:
                outputs = multilayer_model(images)
            # Predicted class = index of the largest output.
            _, predicted = torch.max(outputs.data, 1)
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()
    accuracy = 100 * n_correct / n_seen
    print('2-layer accuracy of the network on the 10000 test images: %f' % accuracy)
    return accuracy


def traintest_mlp(mlp_model, batch_size_supervised, lr, get_loss, epochs, optimizer=optim.SGD, seed=None, dataseed=None):
    """Train a model with ``train_sgd_model`` and report its test accuracy.

    Args:
        mlp_model: model accepted by ``train_sgd_model`` / ``test_multilayer_model``.
        batch_size_supervised: training batch size.
        lr: learning rate.
        get_loss: also record whole-set train/test losses during training.
        epochs: number of training epochs.
        optimizer: optimizer class.
        seed: global RNG seed; ``None`` leaves the RNGs untouched.
        dataseed: seed for the data loaders.

    Returns:
        ``(loss_running_history, loss_train_history, loss_test_history, accuracy)``.
    """
    if seed is not None:
        seed_init_fn(seed)
    train_loader, _ = make_data_loaders([batch_size_supervised, 10000], dataseed=dataseed, shuffle=True)
    also_setloss = get_loss
    loss_running_history_sgd_or_adam, loss_train_history_sgd_or_adam, loss_test_history_sgd_or_adam = train_sgd_model(
        mlp_model,
        epochs=epochs,
        lr=lr,
        train_loader=train_loader,
        optimizer=optimizer,
        also_setloss=also_setloss,
        set_loss_every=500,
        dataseed=dataseed)
    # test_multilayer_model builds its own test loader. Its second positional parameter is
    # `adapt`, so passing a loader there would enable plasticity while testing.
    acc = test_multilayer_model(mlp_model)
    return loss_running_history_sgd_or_adam, loss_train_history_sgd_or_adam, loss_test_history_sgd_or_adam, acc


def show_weights(weights, exp=False, base=e, norm_each_digit=False, rows=10, col=10):
    """Draw each row of ``weights`` as a 28x28 image on a ``rows`` x ``col`` grid with a colorbar.

    Args:
        weights: tensor with one flattened 28x28 image per row (``view(n, 28, 28)`` must succeed).
        exp: if True, plot ``base ** weights`` instead of the raw weights.
        base: exponent base used when ``exp`` is True (number or tensor; moved to the CPU).
        norm_each_digit: use a separate colour range per image instead of one shared range.
        rows, col: grid size, replaced by a 2-row grid when there are fewer than 100 images.
            Only the first ``rows * col`` images are drawn; surplus grid cells stay blank.
    """
    weights = weights.to('cpu')
    n_images = weights.size()[0]
    if n_images < 100:
        rows = 2
        col = max(1, int(np.ceil(n_images / rows)))

    # squeeze=False keeps `axes` 2-D even for a single row or column of subplots.
    fig, axes = plt.subplots(rows, col, figsize=(col, rows), squeeze=False)
    plt.pause(0.0001)
    w = weights.view(n_images, 28, 28)
    plt.pause(0.0001)
    if exp:
        if torch.is_tensor(base):
            base = base.to('cpu')  # `w` is on the CPU; a CUDA base would be a device mismatch
        w = torch.pow(base, w)

    if not norm_each_digit:
        vmin = float(w.min())
        vmax = float(w.max())
    im = None
    for i in range(rows):
        for j in range(col):
            ax = axes[i, j]
            idx = i * col + j
            if idx < n_images:
                w_ij = w[idx]
                if norm_each_digit:
                    vmin = float(w_ij.min())
                    vmax = float(w_ij.max())
                im = ax.imshow(w_ij, cmap='coolwarm', vmin=vmin, vmax=vmax)
            else:
                ax.axis('off')  # grid cell with no image (e.g. odd image count)
            ax.set_xticks([])
            ax.set_yticks([])
    plt.pause(0.0001)
    fig.tight_layout(pad=0.0)
    plt.show()
    plt.pause(0.0001)
    adjust_factor = 0.9
    fig.subplots_adjust(right=adjust_factor)
    if im is not None:
        cbar_ax = fig.add_axes([0.95, 0.15, 0.01, 0.7])
        fig.colorbar(im, cax=cbar_ax)
    # Widen the figure so the space reserved for the colorbar does not shrink the image grid.
    fig.set_size_inches(fig.get_size_inches()[0] * (1 + (1 - adjust_factor) / adjust_factor),
                        fig.get_size_inches()[1])


def get_set_loss_at_t(loss_train_loader, loss_test_loader, model, criterion):
    """Return the mean per-batch loss of ``model`` over the whole train and test loaders.

    SoftHebb models are run with ``adapt=False``, so no plasticity is applied. Everything runs
    under ``torch.no_grad()``: only scalar loss values are needed, so building autograd graphs
    over the full sets would only waste memory.

    Returns:
        ``(train_loss_avg, test_loss_avg)`` as Python floats.
    """
    def get_loss_wholeset(dataloader, model, criterion):
        # Average of the per-batch losses over `dataloader`.
        adapt_set = False
        loss_set_trainingbatch_sum = 0
        device = model.device
        for (data_set, target_set) in dataloader:
            data_set = data_set.to(device)
            target_set = target_set.to(device)
            if model.isSoftHebb:
                output_set = model(data_set, adapt=adapt_set, wta_prop=True)
            else:
                output_set = model(data_set)
            loss_set_setbatch = criterion(output_set, target_set)
            loss_set_trainingbatch_sum += loss_set_setbatch.item()
        return loss_set_trainingbatch_sum / len(dataloader)

    with torch.no_grad():
        loss_train_trainingbatch_avg = get_loss_wholeset(loss_train_loader, model, criterion)
        loss_test_trainingbatch_avg = get_loss_wholeset(loss_test_loader, model, criterion)
    return loss_train_trainingbatch_avg, loss_test_trainingbatch_avg


def make_model(device, presets, epochs, subfolder=None, seed=None, dataseed=None, load_final_weights=True,
               plots_bool=False, replayfromdisk=True, save_atevery_batch=False, replace_presets_dict=None):
    """Build the 2-layer SoftHebb model, then load or train both of its layers.

    The SoftHebb (WTA) layer ``fc_wta`` is trained with ``train_hebb_layer`` and its
    single-layer accuracy is measured with ``test_hebb_layer``. A supervised layer
    ``fc`` is then trained on top with ``train_sgd_model``. Trained parameters are
    saved as ``wtalayer.pt`` / ``suplayer.pt`` in './models_saved/<subfolder>/'. When
    ``load_final_weights`` is set (and ``save_atevery_batch`` is not), they are
    reloaded from there, and a layer is retrained only if loading fails.

    Args:
        device: torch device the model is moved to (when CUDA is available).
        presets: preset name passed to ``set_presets``.
        epochs: SoftHebb training epochs.
        subfolder: save folder name under './models_saved/'. Defaults to
            '<presets>_seed<seed>'. Each ``replace_presets_dict`` entry appends
            '_<key><value>' to it.
        seed: if not None, passed to ``seed_init_fn`` before the model is created.
        dataseed: forwarded to ``make_data_loaders`` / ``train_hebb_layer``.
        load_final_weights: try loading saved layer parameters before training.
        plots_bool: enables intermediate weight plots during SoftHebb training.
        replayfromdisk, save_atevery_batch: forwarded to ``train_hebb_layer``.
        replace_presets_dict: overrides applied on top of the preset dictionary.

    Returns:
        ``(model_adaptive, wta_prop, fig_counter, accuracy_testset)``, where
        ``accuracy_testset`` is the single (SoftHebb) layer's test accuracy.
    """
    if subfolder is None:
        subfolder = presets + '_seed' + str(seed)
    wta_prop = True  # propagates the output of the wta to the next layer
    save_model = True
    preset_dict = set_presets(presets)
    if replace_presets_dict is not None:
        for key, value in replace_presets_dict.items():
            preset_dict[key] = value
            # str() so that non-string override values can also be part of the folder name
            subfolder += '_' + key + str(value)
    show_weights_every = 1000
    batch_size = preset_dict['batch_size']
    batch_size_supervised = 100
    fig_counter = 0

    # making a dictionary with initialization parameters of the SoftHebb model
    hparams = preset_dict.copy()
    hparams['in_channels'] = 28 * 28
    hparams['learning_rate_decay'] = lr_scheduler_lin_perbatch
    hparams['device'] = device

    # these presets are not parameters to give to the model. removing them from the dictionary.
    for key in ('batch_size', 'epochs', 'preset', 'seed'):
        hparams.pop(key, None)
    if seed is not None:
        seed_init_fn(seed)

    # ############### initialize SoftHebb model instance and save initial model
    model_adaptive = SoftHebbModel2Layers(hparams)
    if use_cuda:
        model_adaptive = model_adaptive.to(device)
    savedir = './models_saved/' + presets + '_seed' + str(seed)
    # makedirs also creates './models_saved' itself if it is missing (os.mkdir would fail)
    os.makedirs(savedir, exist_ok=True)
    torch.save(model_adaptive.fc_wta, savedir + '/init.pt')
    # Trained layers are saved under the (possibly suffixed) subfolder; make sure it exists.
    subfolder_dir = './models_saved/' + subfolder
    os.makedirs(subfolder_dir, exist_ok=True)

    def traintest_hebb_layer_wrapper(batch_idx, plots_bool, show_weights_every, epochs, fig_counter,
                                     save_atevery_batch=save_atevery_batch,
                                     replayfromdisk=replayfromdisk, subfolder=subfolder, dataseed=dataseed, batch_size=batch_size):
        """Train the SoftHebb layer, save it, and return ``(test_accuracy, fig_counter)``."""
        train_loader, test_loader = make_data_loaders([batch_size, 10000], dataseed=dataseed, shuffle=True)
        batch_idx, fig_counter = train_hebb_layer(model_adaptive, batch_idx, 60000,
                                                  show_intermediate=plots_bool and show_weights_every,
                                                  epochs=epochs,
                                                  fig_counter=fig_counter,
                                                  get_loss=False, save_atevery_batch=save_atevery_batch,
                                                  replayfromdisk=replayfromdisk, subfolder=subfolder, train_loader=train_loader,
                                                  dataseed=dataseed)
        if save_model:
            torch.save((model_adaptive.fc_wta.weight, model_adaptive.fc_wta.bias),
                       './models_saved/' + subfolder + '/wtalayer.pt')
        train_loader, test_loader = make_data_loaders([60000, 10000], dataseed=dataseed)
        # NOTE(review): plots_bool is hard-coded to True here, unlike the load path below,
        # which forwards plots_bool — confirm whether this is intended.
        accuracy_testset = test_hebb_layer(model_adaptive.fc_wta, train_loader=train_loader,
                                           test_loader=test_loader, fig_counter=fig_counter, plots_bool=True)
        return accuracy_testset, fig_counter

    def train_hebbsgd_model_wrapper():
        """Train the supervised output layer on top of the SoftHebb layer and save it."""
        print('Now training supervised classifier on top')
        train_loader, test_loader = make_data_loaders([batch_size_supervised, 10000], shuffle=True,
                                                      dataseed=dataseed)
        train_sgd_model(model_adaptive, epochs=60, lr=0.001, wta_prop=wta_prop, train_loader=train_loader)
        if save_model:
            torch.save((model_adaptive.fc.weight, model_adaptive.fc.bias),
                       './models_saved/' + subfolder + '/suplayer.pt')

    def retrain_hebb_layer():
        """Train and test the SoftHebb layer from scratch; returns ``(accuracy, fig_counter)``."""
        return traintest_hebb_layer_wrapper(batch_idx, plots_bool, show_weights_every,
                                            epochs, fig_counter,
                                            save_atevery_batch=save_atevery_batch,
                                            replayfromdisk=replayfromdisk, subfolder=subfolder,
                                            dataseed=dataseed, batch_size=batch_size)

    # ############### load or train 2-layer SoftHebb network
    batch_idx = 0
    if load_final_weights and not save_atevery_batch:
        # Only the load itself is guarded, so an error while *testing* a loaded layer is
        # reported instead of silently triggering a retrain. Exception (not a bare except)
        # lets KeyboardInterrupt/SystemExit through.
        try:  # load SoftHebb single layer if possible. otherwise train
            w, b = torch.load(subfolder_dir + '/wtalayer.pt', map_location=torch.device(device))
            model_adaptive.fc_wta.weight = w
            model_adaptive.fc_wta.bias = b
        except Exception as err:
            print("Couldn't load wta parameters. Retraining now.")
            print('  reason: ' + repr(err))
            accuracy_testset, fig_counter = retrain_hebb_layer()
        else:
            train_loader, test_loader = make_data_loaders([60000, 10000], dataseed=dataseed)
            accuracy_testset = test_hebb_layer(model_adaptive.fc_wta, train_loader=train_loader,
                                               test_loader=test_loader, fig_counter=fig_counter, plots_bool=plots_bool)

        try:  # load or train supervised classifier layer on top of SoftHebb layer
            w, b = torch.load(subfolder_dir + '/suplayer.pt', map_location=torch.device(device))
            model_adaptive.fc.weight = w
            model_adaptive.fc.bias = b
        except Exception as err:
            print("Couldn't load supervised parameters. Retraining now.")
            print('  reason: ' + repr(err))
            train_hebbsgd_model_wrapper()
    else:
        accuracy_testset, fig_counter = retrain_hebb_layer()
        train_hebbsgd_model_wrapper()
    return model_adaptive, wta_prop, fig_counter, accuracy_testset


def main(gpu_id, presets='softhebb2000normb0.05_base1e6_1epoch_seed0', get_loss=False, plots_bool=False,
                seed=None, dataseed=None, also_setloss=True, load_final_weights=True, replayfromdisk=True, replace_presets_dict=None):
    """Train or load the 2-layer SoftHebb model and record its accuracies and loss histories.

    Results are merged into '<savedir>/loss_history.pt', where savedir is
    './models_saved/<presets>_seed<seed>' plus one '_<key><value>' suffix per
    ``replace_presets_dict`` entry.

    Args:
        gpu_id: passed to ``init_script`` to select the device.
        presets: preset name passed to ``set_presets``.
        get_loss: also compute the SoftHebb layer's loss history by replaying training.
        plots_bool: forwarded to ``make_model``.
        seed: model seed. Defaults to the preset's 'seed'.
        dataseed: data seed. Defaults to ``seed``.
        also_setloss: forwarded to ``train_hebb_layer`` when computing loss history.
        load_final_weights, replayfromdisk, replace_presets_dict: forwarded to ``make_model``.

    Returns:
        ``(save_dict, model_adaptive)``. ``save_dict`` has keys 'acc_wta_1l', 'acc_wta_2l',
        'loss_running_history_wta', 'loss_train_history_wta' and 'loss_test_history_wta'.
    """

    def get_loss_history(model_adaptive, presets, seed=None, fig_counter=0, epochs=1, also_setloss=True, dataseed=None):
        """Re-initialise the WTA layer from 'init.pt' and replay training to log its losses."""
        savedir = './models_saved/' + presets+'_seed'+str(seed)
        model_adaptive.fc_wta = torch.load(savedir + '/init.pt', map_location=torch.device(model_adaptive.device))
        if seed is not None:
            seed_init_fn(seed)
        if epochs == 1:
            batch_size = 1
        else:
            batch_size = 100
        train_loader, test_loader = make_data_loaders([batch_size, 10000], shuffle=True, dataseed=dataseed)
        batch_idx = 0
        set_loss_hebb_every = 500
        print('Will now get the post-hoc cross-entropy')
        batch_idx, loss_running_history, loss_train_history, loss_test_history = train_hebb_layer(model_adaptive, batch_idx,
                      60000, show_intermediate=False, fig_counter=fig_counter, epochs=epochs, get_loss=True,
                      replayfromdisk=True, subfolder=presets+'_seed'+str(seed), set_loss_every=set_loss_hebb_every,
                      train_loader=train_loader, also_setloss=also_setloss, dataseed=dataseed)
        save_dict = dict()
        save_dict['loss_running_history_wta'] = loss_running_history
        save_dict['loss_train_history_wta'] = loss_train_history
        save_dict['loss_test_history_wta'] = loss_test_history
        save_dict['acc_wta_1l'] = None
        save_dict['acc_wta_2l'] = None
        file = savedir+'/loss_history.pt'
        torch.save(save_dict, file)
        # The per-batch snapshots are no longer needed; a missing folder is fine.
        shutil.rmtree(savedir+'/every_batch/', ignore_errors=True)
        return save_dict

    preset_dict = set_presets(presets)
    savedir_suffix = ''
    if replace_presets_dict is not None:
        for key, value in replace_presets_dict.items():
            preset_dict[key] = value
            savedir_suffix += '_' + key + str(value)
    # Resolve seeds *before* building paths. make_model names its folder with the resolved
    # seed, so building savedir from a None seed would point to a '..._seedNone' folder.
    if seed is None:
        seed = preset_dict['seed']
    if dataseed is None:
        dataseed = seed
    savedir = './models_saved/' + presets + '_seed' + str(seed) + savedir_suffix
    os.makedirs(savedir, exist_ok=True)
    epochs = preset_dict['epochs']
    device = init_script(gpu_id)
    file = savedir + '/loss_history.pt'
    # Loss history is obtained by replaying per-batch snapshots, so they must be saved first.
    save_atevery_batch = bool(get_loss and replayfromdisk)

    # creates softhebb model, with added supervised classifier layer on top, trains the softhebb layer, tests the single layer's accuracy
    # then trains the supervised classifier, and returns the trained 2-layer model and the accuracy of the single layer
    model_adaptive, wta_prop, fig_counter, acc_wta_1l = make_model(device, epochs=epochs, presets=presets, seed=seed,
                                                                   load_final_weights=load_final_weights,
                                                                   save_atevery_batch=save_atevery_batch, replayfromdisk=replayfromdisk and not save_atevery_batch,  # do not replayfromdisk what you are only now saving
                                                                   replace_presets_dict=replace_presets_dict, plots_bool=plots_bool,
                                                                   dataseed=dataseed)
    if os.path.exists(file):
        save_dict = torch.load(file, map_location=torch.device(device))
    else:
        save_dict = dict()
    save_dict['acc_wta_1l'] = acc_wta_1l
    torch.save(save_dict, file)

    history_keys = ('loss_running_history_wta', 'loss_train_history_wta', 'loss_test_history_wta')
    if get_loss:
        model_adaptive.fc_wta.learning_rate = preset_dict['learning_rate']
        # Replay with the same epochs/dataseed as training so that batch size and data order
        # match the saved per-batch snapshots.
        # NOTE(review): get_loss_history reads/writes the un-suffixed '<presets>_seed<seed>'
        # folder even when replace_presets_dict is given — confirm this is intended.
        save_dict_ = get_loss_history(model_adaptive, presets, seed, fig_counter, epochs=epochs,
                                      also_setloss=also_setloss, dataseed=dataseed)
        for key in history_keys:
            save_dict[key] = save_dict_[key]
    else:
        for key in history_keys:
            save_dict.setdefault(key, None)
    torch.save(save_dict, file)

    acc_wta_2l = test_multilayer_model(model_adaptive, wta_prop=wta_prop, noiseattack=False, dataseed=dataseed)
    save_dict['acc_wta_2l'] = acc_wta_2l
    torch.save(save_dict, file)

    return save_dict, model_adaptive


def update_loss_history(batch_idx_global, loss_running_thistrainingepoch_sum, loss_train_thistrainingepoch_sum,
                        loss_test_thistrainingepoch_sum, loss_running_trainingbatch_avg, loss_train_trainingbatch_avg,
                        loss_test_trainingbatch_avg, epochs, set_loss_every, loss_running_history, loss_train_history,
                        loss_test_history, batch_idx, train_loader):
    """Record the current batch's losses into the history lists.

    Short runs (``epochs <= 3``) log every batch: the running loss is appended each
    call, and the train/test losses every ``set_loss_every`` global batches. Longer
    runs accumulate per-epoch sums and append one averaged value per epoch, on the
    epoch's last batch. The train/test averages are rescaled by ``set_loss_every``
    because they are sampled only every ``set_loss_every`` batches.

    ``loss_running_trainingbatch_avg`` must support ``.item()`` (e.g. a 0-d tensor).
    The history lists are mutated in place.

    Returns:
        ``(running_sum, train_sum, test_sum, loss_train_history, loss_test_history)``.
    """
    if epochs > 3:
        # per-epoch aggregation
        loss_running_thistrainingepoch_sum += loss_running_trainingbatch_avg.item()
        if batch_idx_global % set_loss_every == 0:
            loss_train_thistrainingepoch_sum += loss_train_trainingbatch_avg
            loss_test_thistrainingepoch_sum += loss_test_trainingbatch_avg
        n_batches = len(train_loader)
        if batch_idx == n_batches - 1:
            loss_running_history.append(loss_running_thistrainingepoch_sum / n_batches)
            loss_train_history.append(loss_train_thistrainingepoch_sum / n_batches * set_loss_every)
            loss_test_history.append(loss_test_thistrainingepoch_sum / n_batches * set_loss_every)
    else:
        # per-batch logging
        loss_running_history.append(loss_running_trainingbatch_avg.item())
        if batch_idx_global % set_loss_every == 0:
            loss_train_history.append(loss_train_trainingbatch_avg)
            loss_test_history.append(loss_test_trainingbatch_avg)
        if batch_idx_global % 1000 == 0:
            print("loss:", loss_running_trainingbatch_avg)
    return (loss_running_thistrainingepoch_sum, loss_train_thistrainingepoch_sum,
            loss_test_thistrainingepoch_sum, loss_train_history, loss_test_history)


def compare_with_sup(presets, epochs, device, from_savedMLP=True, from_savedWTA=True, replayfromdisk_if_notfromsaved=True,
                     load_final_weights_if_notfromsavedandnotreplay=True, online=True, loss_single_layer=True,
                     running_train_test_options=None, seed=None, dataseed=None, get_loss=True, lr=None, n_neurons=2000):
    """Compare the SoftHebb model with an SGD-trained MLP (accuracies and loss histories).

    SoftHebb results are either loaded from
    './models_saved/<presets>_seed<seed>/loss_history.pt' (``from_savedWTA``) or produced
    by calling ``main``. MLP results are either loaded from
    './loss_history_sup/epochs_<epochs>/' (``from_savedMLP``) or produced by training an
    ``MLP`` with ``traintest_mlp`` and saved there. All accuracies are printed. If
    ``get_loss`` is set, a loss figure is written to 'fastloss<suffix>.png'.

    Args:
        presets: SoftHebb preset name (also part of the save paths).
        epochs: MLP training epochs; also selects the MLP history folder.
        device: torch.device used for loading; ``device.index`` is passed to ``main``.
        from_savedMLP, from_savedWTA: load histories from disk instead of training.
        replayfromdisk_if_notfromsaved, load_final_weights_if_notfromsavedandnotreplay:
            control how ``main`` obtains SoftHebb weights when not loading from disk.
        online: train the MLP with batch size 1 (and use the '_online' files).
        loss_single_layer: also record the MLP's first-layer-only loss history.
        running_train_test_options: which histories to plot. Defaults to ['test'].
        seed, dataseed: seeds. ``seed`` defaults to ``torch.initial_seed()`` and
            ``dataseed`` defaults to ``seed``.
        get_loss: compute and plot loss histories.
        lr: MLP learning rate. The default depends on ``online``.
        n_neurons: MLP hidden-layer width.

    Returns:
        ``(save_dict_wta, save_dict_mlp, model_adaptive)``. ``model_adaptive`` is None
        when the SoftHebb results were loaded from disk.

    Raises:
        The error from ``torch.load`` (after printing a hint) when a requested saved
        history cannot be loaded.
    """
    replayfromdisk = replayfromdisk_if_notfromsaved and not from_savedWTA
    load_final_weights = load_final_weights_if_notfromsavedandnotreplay and not from_savedWTA and not replayfromdisk
    if seed is None:
        seed = torch.initial_seed()
    if dataseed is None:
        dataseed = seed

    def moving_avg(values, rolling_window):
        """Simple moving average of ``values`` over every full window of ``rolling_window`` items."""
        cumsum, moving_aves = [0], []
        for i, value in enumerate(values, 1):
            cumsum.append(cumsum[i-1] + value)
            # '>=' so the first full window (items 0..rolling_window-1) is included; the
            # x-positions in plot_onehistory assume windows start at the first item.
            if i >= rolling_window:
                moving_aves.append((cumsum[i] - cumsum[i-rolling_window]) / rolling_window)
        return moving_aves

    def plot_onehistory(loss_histories, algo, axes, running_train_ortest='running', rolling_window=1500, loss_single_layer=False):
        """Plot one loss curve: running loss on ``axes[0]``, train/test loss on ``axes[1]``.

        ``loss_histories`` is ``[running, train, test]``. The x axis is in training
        examples, assuming the sequence spans 60000 examples.
        """
        max_length = 60000.
        if algo == 'SoftHebb':
            color = 'C0'
        elif algo == 'SGD' or algo == 'Adam':
            color = 'C8' if loss_single_layer else 'C1'
        else:
            raise ValueError('Unknown algo: ' + str(algo))
        loss_running_history, loss_train_history, loss_test_history = loss_histories[0], loss_histories[1], loss_histories[2]

        linewidth = 0.8
        if running_train_ortest == 'running':
            sma = True
            sequence = loss_running_history
            label_pt2 = ''+loss_single_layer*('('+loss_single_layer*'pretrained 2nd layer'+')')
            linewidth = 0.5
        elif running_train_ortest == 'train':
            sma = False
            sequence = loss_train_history
            label_pt2 = ''+loss_single_layer*('('+running_train_ortest+'ing set'+loss_single_layer*', pretrained 2nd layer'+')')
        elif running_train_ortest == 'test':
            sma = False
            sequence = loss_test_history
            label_pt2 = ''+loss_single_layer*(loss_single_layer*'(pretrained 2nd layer)')
        else:
            raise ValueError("running_train_ortest must be 'running', 'train' or 'test', got "
                             + repr(running_train_ortest))

        batch_size = max_length / len(sequence)  # training examples per history entry
        if sma:
            # window length in history entries; at least 1 so moving_avg never divides by zero
            rolling_window_ = max(1, int(rolling_window / batch_size))
            sequence = moving_avg(sequence, rolling_window_)
            examples = [j*batch_size+int(rolling_window/2) for j in range(len(sequence))]
        else:
            examples = [j*batch_size+int(batch_size/2) for j in range(len(sequence))]
        axes[0+(running_train_ortest != 'running')*1].plot(examples, sequence, label=algo + ' ' + label_pt2, color=color, linewidth=linewidth)

    def history_figure(algos, running_train_test_options, loss_single_layer):
        """Draw and save the loss-history figure as 'fastloss<suffix>.png'."""
        # Always two panels: plot_onehistory indexes axes[0]/axes[1], as does the formatting
        # below. plt.subplots() without a grid returns a single, non-indexable Axes.
        fig, axes = plt.subplots(1, 2)
        for algo in algos:
            for running_train_ortest in running_train_test_options:
                if algo == 'SoftHebb':
                    plot_onehistory(loss_histories_wta, algo=algo, axes=axes, running_train_ortest=running_train_ortest, loss_single_layer=False)
                else:
                    plot_onehistory(loss_histories, algo=algo, axes=axes, running_train_ortest=running_train_ortest, loss_single_layer=False)
                if loss_single_layer and algo != 'SoftHebb':
                    plot_onehistory(loss_histories_singlelayer, algo=algo, axes=axes, running_train_ortest=running_train_ortest, loss_single_layer=True)
        fontsize = 14
        axes[0].set_ylabel('Cross entropy loss', fontsize=fontsize)
        axes[0].legend()
        axes[0].title.set_text('Running loss \n(moving average)')
        axes[1].title.set_text('Test set loss \n')
        axes[0].yaxis.tick_right()
        axes[0].yaxis.set_ticklabels([])
        for ax in axes:
            ax.set_xticks([0, 30000, 60000])
            ax.tick_params(axis='both', which='major', labelsize=fontsize)
            ax.set_xlabel('training examples', fontsize=fontsize)
            ax.set_ylim([1.466, 2.366])
        file_suffix = ''.join(running_train_test_options)
        file_suffix += online*'_online'+loss_single_layer*'_loss_single_layer'+'_seed'+str(seed)
        plt.savefig("fastloss"+file_suffix+".png")
        print('Loss plot saved, check it out')

    if from_savedWTA:
        file = './models_saved/'+presets+'_seed'+str(seed)+'/loss_history.pt'
        try:
            save_dict_wta = torch.load(file, map_location=torch.device(device))
        except Exception:
            # Previously execution continued and failed later with a NameError.
            print('Could not load saved SoftHebb loss history. Try again with from_savedWTA=False.')
            raise
        print('Loaded SoftHebb loss history from disk')
        model_adaptive = None
    else:
        print('Training SoftHebb+supervised classifier once, and then again the WTA layer only to get its post-hoc cross-entropy history.')
        save_dict_wta, model_adaptive = main(gpu_id=device.index, presets=presets, get_loss=get_loss, plots_bool=False,
                                             seed=seed, dataseed=dataseed, load_final_weights=load_final_weights,
                                             replayfromdisk=replayfromdisk)
    acc_wta_1l = save_dict_wta['acc_wta_1l']
    acc_wta_2l = save_dict_wta['acc_wta_2l']
    print('acc_wta_1l=' + str(acc_wta_1l))
    print('acc_wta_2l=' + str(acc_wta_2l))

    print('Moving on to get MLP\'s loss history')
    if from_savedMLP:
        print('Loading MLP\'s loss history from disk')
        file = './loss_history_sup/epochs_'+str(epochs)+'/loss_history_sgd'+online*'_online'+'_loss_single_layer'\
               +'_seed'+str(seed)+'.pt'
        try:
            if os.path.isfile(file):
                save_dict_mlp = torch.load(file, map_location=torch.device(device))
            else:
                file = './loss_history_sup/epochs_'+str(epochs)+'/loss_history_sgd'+online*'_online'+'_seed'+str(seed)+'.pt'
                save_dict_mlp = torch.load(file, map_location=torch.device(device))
        except Exception:
            print('Could not load saved MLP loss history. Try again with from_savedMLP=False')
            raise
        loss_histories = save_dict_mlp['loss_histories']
        # Older files may lack the single-layer entries; .get lets the check below handle that.
        loss_histories_singlelayer = save_dict_mlp.get('loss_histories_singlelayer')
        acc_sgd = save_dict_mlp['acc_sgd']
        acc_sgd_singlelayer = save_dict_mlp.get('acc_sgd_singlelayer')
        if acc_sgd_singlelayer is None:
            print('Single layer case is missing from disk. Re-run with from_savedMLP=False')
            loss_single_layer = False
    else:
        def traintest_mlp_wrapper(lr, epochs, device, loss_single_layer=False, optimizer=optim.Adam, seed=None, get_loss=False, dataseed=None, online=False, batch_size_supervised=None):
            """Train an MLP (and optionally its first layer alone); return histories, accuracies and model."""
            def pretrain_output_layer(mlp_model, optimizer=optimizer):
                """Train the full MLP, then reset fc1 and freeze fc2 for single-layer training."""
                train_loader, test_loader = make_data_loaders([100, 10000], dataseed=dataseed)
                try:
                    train_sgd_model(mlp_model, epochs=60, lr=0.001, train_loader=train_loader, optimizer=optimizer,
                                    also_setloss=False, set_loss_every=500, dataseed=dataseed)
                except Exception as err:
                    # Still best-effort, but no longer silent: the frozen output layer below is
                    # only meaningful if this pretraining succeeded.
                    print('Warning: pretraining of the MLP output layer failed: ' + repr(err))

                seed_init_fn(seed)
                mlp_model.fc1.reset_parameters()
                for param in mlp_model.fc2.parameters():
                    param.requires_grad = False

            if online:
                batch_size_supervised = 1
                if lr is None:
                    lr = 0.04
            else:
                if batch_size_supervised is None:
                    batch_size_supervised = 4
                # The lr default applies even when batch_size_supervised is given explicitly.
                if lr is None:
                    lr = 0.2
            if seed is None:
                seed = torch.initial_seed()
            seed_init_fn(seed)
            mlp_model = MLP(device, n_neurons)
            seed_init_fn(seed)
            mlp_model.fc1.reset_parameters()
            if use_cuda:
                mlp_model = mlp_model.to(device)
            loss_running_history_sgd, loss_train_history_sgd, loss_test_history_sgd, acc_sgd = traintest_mlp(mlp_model, batch_size_supervised,
                                                                                                     lr, get_loss=get_loss, epochs=epochs,
                                                                                                     seed=seed,
                                                                                                     dataseed=dataseed)
            loss_histories = [loss_running_history_sgd, loss_train_history_sgd, loss_test_history_sgd]
            if loss_single_layer and get_loss:  # gets loss of first layer's training, after 2-layer MLP is trained ~exhaustively and second layer frozen
                pretrain_output_layer(mlp_model)  # trains model ~exhaustively, freezes second layer, resets first layer
                loss_running_history_sgd_singlelayer, loss_train_history_sgd_singlelayer, loss_test_history_sgd_singlelayer, \
                acc_sgd_singlelayer = traintest_mlp(mlp_model, batch_size_supervised, lr, get_loss, epochs=epochs, seed=seed, dataseed=dataseed)  # trains first layer alone, gets loss
                loss_histories_singlelayer = [loss_running_history_sgd_singlelayer, loss_train_history_sgd_singlelayer,
                                              loss_test_history_sgd_singlelayer]
            else:
                loss_histories_singlelayer = []
                acc_sgd_singlelayer = None
            return loss_histories, loss_histories_singlelayer, acc_sgd, acc_sgd_singlelayer, mlp_model
        loss_histories, loss_histories_singlelayer, acc_sgd, acc_sgd_singlelayer, mlp_model = traintest_mlp_wrapper(seed=seed,
                        get_loss=get_loss, lr=lr, epochs=epochs, dataseed=dataseed, online=online, loss_single_layer=loss_single_layer, optimizer=optim.SGD, device=device)
        save_dict_mlp = dict()
        save_dict_mlp['loss_histories'] = loss_histories
        save_dict_mlp['loss_histories_singlelayer'] = loss_histories_singlelayer
        save_dict_mlp['acc_sgd'] = acc_sgd
        save_dict_mlp['acc_sgd_singlelayer'] = acc_sgd_singlelayer
        save_dir = './loss_history_sup/epochs_'+str(epochs)
        os.makedirs(save_dir, exist_ok=True)  # torch.save does not create missing folders
        torch.save(save_dict_mlp, save_dir + '/loss_history_sgd' + online * '_online' +
                   loss_single_layer * '_loss_single_layer' + '_seed' + str(seed) + '.pt')
        print('saved histories')
    print('acc_wta_1l=' + str(acc_wta_1l))
    print('acc_wta_2l=' + str(acc_wta_2l))
    print('acc_sgd=' + str(acc_sgd))
    print('acc_sgd_singlelayer=' + str(acc_sgd_singlelayer))
    if get_loss:
        print('Making loss figure')
        plt.close('all')
        algos = ['SoftHebb', 'SGD']
        if running_train_test_options is None:
            running_train_test_options = ['test']
        loss_running_history_wta = save_dict_wta['loss_running_history_wta']
        loss_train_history_wta = save_dict_wta['loss_train_history_wta']
        loss_test_history_wta = save_dict_wta['loss_test_history_wta']
        loss_histories_wta = [loss_running_history_wta, loss_train_history_wta, loss_test_history_wta]
        history_figure(algos, running_train_test_options, loss_single_layer)

    return save_dict_wta, save_dict_mlp, model_adaptive


def get_adv_samples(attack, classifier, data):
    """Craft adversarial versions of one batch and collect clean/adversarial predictions.

    Args:
        attack: object with ``generate(images, labels)``, e.g. an ART evasion attack.
        classifier: object with ``predict(x)`` returning per-sample outputs.
        data: ``(images, labels)`` pair for a single batch.

    Returns:
        ``(adversarial_samples, logits)``. ``logits`` column-stacks the classifier
        outputs on the clean images (first) and on the adversarial images (second),
        one row per sample.
    """
    images, labels = data
    # Call order matters for stateful attacks/classifiers: clean predict, attack, adversarial predict.
    clean_out = classifier.predict(images)
    perturbed = attack.generate(images, labels)
    perturbed_out = classifier.predict(perturbed)

    stacked_logits = np.column_stack((np.concatenate([clean_out], 0),
                                      np.concatenate([perturbed_out], 0)))
    return np.concatenate([perturbed], 0), stacked_logits


def create_art_classifier(model_adaptive, input_shape, device):
    """Wrap ``model_adaptive`` in an ART ``PyTorchClassifier`` for adversarial attacks.

    The model is called as ``model(images, False, True)`` (no adaptation, WTA output
    propagated), behind a small adapter so that ART sees a plain ``forward(images)``.

    Args:
        model_adaptive: the SoftHebb model to wrap.
        input_shape: forwarded to ART as ``input_shape``.
        device: 'cpu' or 'cuda'; mapped to ART's 'cpu' / 'gpu' device names
            (any other value raises KeyError).

    Returns:
        The ART ``PyTorchClassifier`` (cross-entropy loss, Adam lr=0.001,
        clip values (0, 1), 10 classes).
    """
    art_device_names = {'cpu': 'cpu', 'cuda': 'gpu'}
    art_device = art_device_names[device]

    class FixedForwardParamsModel(nn.Module):
        """Adapter that fixes the extra forward arguments of the wrapped model."""

        def __init__(self, model, adapt, wta_prop):
            super().__init__()
            self.model = model
            self.adapt = adapt
            self.wta_prop = wta_prop

        def forward(self, images):
            return self.model(images, self.adapt, self.wta_prop)

    wrapped_model = FixedForwardParamsModel(model_adaptive, adapt=False, wta_prop=True)
    return PyTorchClassifier(
        model=wrapped_model,
        clip_values=(0, 1),
        loss=nn.CrossEntropyLoss(),
        optimizer=optim.Adam(model_adaptive.parameters(), lr=0.001),
        input_shape=input_shape,
        nb_classes=10,
        device_type=art_device,
    )


def main_adversarial(gpu_id, presets='softhebb2000normb0.03_base1000',
         seed=None, dataseed=None):
    """Load or train the SoftHebb model, report its accuracy, and draw PGD adversarial examples.

    A 10x10 grid is drawn. Each column is one PGD perturbation budget (increasing left
    to right; the column title is epsilon in 0-255 pixel units), and each row is one
    of the 10 test images in the sampled batch.

    Args:
        gpu_id: passed to ``init_script`` to select the device.
        presets: preset name passed to ``set_presets`` / ``make_model``.
        seed: model seed. Defaults to the preset's 'seed'.
        dataseed: data seed for ``make_model``. Defaults to ``seed``.

    Returns:
        The (2-layer) SoftHebb model.
    """
    preset_dict = set_presets(presets)

    if seed is None:
        seed = preset_dict['seed']
    if dataseed is None:
        dataseed = seed
    epochs = preset_dict['epochs']
    device = init_script(gpu_id)
    model_adaptive, _, _, _ = make_model(device, epochs=epochs, presets=presets, seed=seed,
                                         load_final_weights=True, dataseed=dataseed)

    # run natural test
    accuracy_testset = test_multilayer_model(model_adaptive)
    print('Model accuracy (SoftHebb + Supervised Classifier) {:.2f} %'.format(accuracy_testset))

    # run adversarial test
    batch_size = 10
    print('Loading MNIST test dataset in minibatches')
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(MNIST_PATH, train=False, transform=transforms.Compose([
            transforms.ToTensor(),
        ])),
        batch_size=batch_size, shuffle=True, worker_init_fn=seed_init_fn)
    print('Running PGD attack.')
    images, labels = next(iter(test_loader))
    # ART estimators operate on NumPy arrays, and ``input_shape`` is the shape of ONE
    # sample (no batch dimension), per the ART PyTorchClassifier documentation.
    data_sample = (images.numpy(), labels.numpy())
    input_shape = tuple(images.shape[1:])
    art_classifier = create_art_classifier(
        model_adaptive, input_shape, device.type)

    num_iters = 200
    num_restarts = 5
    plt.figure()
    gs1 = gridspec.GridSpec(10, 10)
    gs1.update(wspace=0.0, hspace=0.01)  # set the spacing between axes.
    print('Making increasingly larger attack perturbations and drawing the adversarial examples...')
    # Exponents 2,3,4,5 then 6..8 in 6 steps; eps0 = 2**j in 0-255 pixel units, capped at 255.
    for col_counter, j in enumerate(np.concatenate((np.linspace(2, 5, 4), np.linspace(6, 8, 6)))):
        eps0 = 2**j
        if eps0 == 256:
            eps0 = 255
        eps = eps0 / 255  # epsilon parameter of PGD attack, in [0, 1] input units
        attack = ProjectedGradientDescent(
            art_classifier, eps=eps, max_iter=num_iters, num_random_init=num_restarts)

        # logits column-stacks the benign and adversarial model outputs for every input sample
        adv_samples, logits = get_adv_samples(
            attack, art_classifier, data_sample)
        for i in range(10):
            ax = plt.subplot(gs1[i*10+col_counter])
            if i == 0:
                if col_counter < 9:
                    ax.set_title(r'${:.0f}$'.format(2**j), fontsize=7)
                else:
                    ax.set_title(r'${:.0f}$'.format(2**j-1), fontsize=7)
            plt.imshow(adv_samples[i, 0], cmap='gray_r', vmin=0, vmax=1)
            plt.axis('off')
            plt.pause(0.0002)
    print('Done.')
    return model_adaptive


# Root folder where torchvision stores / looks for the MNIST dataset.
MNIST_PATH = './opt/datasets'

if __name__ == '__main__':
    # Script entry point: run the adversarial-robustness demo on GPU 0. According to the
    # original author's note, the script falls back to CPU when that GPU does not exist.
    gpu_id = 0
    model_adaptive = main_adversarial(gpu_id=gpu_id, presets='softhebb2000normb0.03_base1000')
