import torch
import math
import random
import numpy as np
import ShiftingWindowSetting as sw
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib import cm
from torch import nn


# This is a temporary solution for computing results of DeepCCG with a standard head (it should really inherit from DeepCCG...)
class DeepCCG_SH(sw.CLLearningAlgo):
    """DeepCCG memory selection combined with a standard (cross-entropy) classification head.

    - ``loss_fn`` trains with cross-entropy on the current batch (through the base class's
      ``predict``) plus cross-entropy on a replay sample drawn from the per-class memory buffer.
    - ``calc_new_mem`` (called after every optimiser step) adds the batch to the buffer and, for
      classes holding more points than their share of ``mem_size``, keeps the points whose fitted
      soft weights (see ``MemSelectModel``) best reproduce the class's embedding statistics.
    - Outside training, ``predict`` classifies using per-class log probabilities computed from the
      memory buffer's embedding statistics (see ``_batch_log_probs``).
    """

    def __init__(self, args, use_task_inc_loss=True, mem_size=1000, mem_batch_size=10, min_num_mem_points=1,  # 1
                 num_of_iter_new_mem=30, mem_lr=0.1, fit_mem_batch_size=32, diag_approx=True,
                 new_mem_diag_approx=False, s_coeff=1, fixed_var=True, calc_stat_batch_size=128,
                 mem_select_KLD=True, store_old_means=False):
        """
        Args:
            args: forwarded unchanged to the base class initialiser.
            use_task_inc_loss: if True, predictions (and the class set used in ``loss_fn``) are
                restricted to the current task's classes; otherwise all classes in memory are used.
            mem_size: total number of points the memory buffer may hold across all classes.
            mem_batch_size: number of memory points replayed per call of ``loss_fn``.
            min_num_mem_points: points per class that are never taken for replay; a class needs at
                least this many memory points to be used in ``loss_fn``.
            num_of_iter_new_mem: SGD iterations used to fit the memory-selection weights
                (0 keeps the random initial weights, i.e. an effectively random selection).
            mem_lr: learning rate of that SGD.
            fit_mem_batch_size: batch size used while fitting the memory-selection weights.
            diag_approx: use a diagonal approximation of the covariance.
            new_mem_diag_approx: stored but not currently used.
            s_coeff: scale of the prior ``S_init``.
            fixed_var: use a fixed isotropic variance instead of estimating one from memory.
            calc_stat_batch_size: batch size for computing memory representations/statistics.
            mem_select_KLD: fit memory weights to the mean only (forced False unless fixed_var).
            store_old_means: keep running sums of the representations of discarded memory points
                to use as a prior on the class means.
        """
        # Per-instance memory buffer: class label -> list of (x, task_id) pairs (see calc_new_mem).
        # This used to be a class attribute, so every instance shared (and leaked into) one buffer.
        # It is set before the base initialiser so that the base class can still override it.
        self.per_class_mem = {}
        super().__init__(args=args)
        self.mem_size = mem_size
        self.mem_batch_size = mem_batch_size
        self.seen_classes = []
        self.min_num_mem_points = min_num_mem_points
        self.use_task_inc_loss = use_task_inc_loss
        self.fixed_var = fixed_var
        self.mem_select_KLD = mem_select_KLD if fixed_var else False
        self.calc_stat_batch_size = calc_stat_batch_size
        for name, module in self.model.named_modules():
            if "linear" in name:
                self.rep_size = module.out_features

        # NOTE: overrides the value found in the loop above
        self.rep_size = self.model_width * 8  # * 16  # this is a temp test

        # these are used in evaluation to store the memory buffer's statistics
        self.mem_means = None
        self.mem_precisions = None
        self.mem_log_sqrt_dets = None

        # used for fitting new memory
        self.num_of_iter_new_mem = num_of_iter_new_mem  # can perform random sample for new mem by having this as 0
        self.mem_lr = mem_lr
        self.fit_mem_batch_size = fit_mem_batch_size
        self.class_sorted = {}

        # whether to use a diagonal approximation of the covariance, firstly in the method as a
        # whole and secondly in the memory selection part
        self.diag_approx = diag_approx
        # not used at the moment perhaps try out when rest of code is good
        self.new_mem_diag_approx = new_mem_diag_approx
        self.det_bias = 0.0  # this is to make sure that the det does not underflow

        # these are the priors to our model params
        self.s_coeff = s_coeff
        self.S_init = s_coeff*torch.ones(self.rep_size) if (self.diag_approx or self.fixed_var) \
            else s_coeff*torch.eye(self.rep_size)
        self.S_init = self.S_init.to(self.device)  # perhaps for larger datasets must only send to device when needed?

        # setting this to true makes the model store the means (using the embedding function at the time)
        # and number of thrown away data points for each class to be used as the mean prior
        self.store_old_means = store_old_means
        self.old_data_sum = None
        self.number_of_old_data_points = None
        self.old_data_means = None
        self.old_data_var_reg = None

    def _calc_mean_and_var(self, memory, classes):
        """Compute per-class embedding means and "variances" (not technically) from memory.

        Args:
            memory: dict mapping class label -> list of (x, task_id) items.
            classes: iterable of class labels to compute statistics for.

        Returns:
            (mem_means, mem_vars): dicts keyed by class. Means have shape (rep_size,). With
            ``diag_approx`` the variance is currently just ``S_init`` (a temporary testing choice);
            otherwise it is ``(S_init + sum r r^T) / n - mean mean^T``.
        """
        mem_means = {c: torch.zeros(self.rep_size, device=self.device) for c in classes}
        mem_vars = {c: torch.zeros(self.rep_size if self.diag_approx else (self.rep_size, self.rep_size),
                                   device=self.device) for c in classes}

        for c in classes:
            n = len(memory[c])
            mem_dataloader = DataLoader(memory[c], batch_size=self.calc_stat_batch_size)
            for (X, _) in mem_dataloader:
                X = X.to(self.device)
                r = self.model.calc_representation(X)
                mem_means[c] += r.sum(0)
                mem_vars[c] += (r**2).sum(0) if self.diag_approx else torch.matmul(torch.t(r), r)
            mem_means[c] /= n
            if self.diag_approx:
                # temporary while testing: the data-dependent estimate
                # 1/n * (S_init + sum r^2) - mean^2 is disabled
                mem_vars[c] = self.S_init
            else:
                mem_vars[c] = 1/n * (self.S_init + mem_vars[c]) - \
                              torch.mm(mem_means[c].reshape(-1, 1), mem_means[c].reshape(1, -1))

        return mem_means, mem_vars

    def _calc_means(self, memory, classes):
        """Return a dict class -> mean representation of that class's memory points."""
        mem_sums = self._calc_mem_sum(memory, classes)
        return {c: mem_sums[c] / len(memory[c]) for c in mem_sums}

    def _calc_mem_sum(self, memory, classes):
        """Return a dict class -> sum of the representations of that class's memory points.

        All requested classes are pooled into one shuffled DataLoader so batches can mix classes.
        Returns ``{}`` when ``classes`` is empty.
        """
        if len(classes) == 0:
            return {}

        mem_sum = {c: torch.zeros(self.rep_size, device=self.device) for c in classes}

        mem_dataloader = DataLoader([(data[0], c) for c in classes for data in memory[c]],
                                    shuffle=True, batch_size=self.calc_stat_batch_size)
        for (X, Y) in mem_dataloader:
            X = X.to(self.device)
            r = self.model.calc_representation(X)
            for c in classes:
                mem_sum[c] += r[Y == c].sum(0)
        return mem_sum

    def _invs_and_get_dets(self, mem_vars):
        """Invert each class's "variance" and compute log(sqrt(det V)).

        Args:
            mem_vars: dict class -> variance, a (rep_size,) vector under ``diag_approx`` or a
                (rep_size, rep_size) matrix otherwise.

        Returns:
            (precisions, log_sqrt_dets): precisions carry an extra leading dim of size 1 so they
            broadcast over the rows of a batch; log_sqrt_dets are Python floats.
        """
        precisions = {}
        log_sqrt_dets = {}
        var_dataloader = DataLoader(list(mem_vars.items()), batch_size=self.calc_stat_batch_size)
        for (C, V) in var_dataloader:
            V = V.to(self.device)
            if self.diag_approx:
                invV = 1 / V
                log_sqrt_detV = 0.5 * torch.log(V).sum(1)
            else:
                invV = torch.inverse(V)
                # log(sqrt(det V)) = 0.5 * log det V (this previously stored det(V) itself,
                # inconsistent with the diagonal branch and with how the value is used)
                log_sqrt_detV = 0.5 * torch.logdet(V)
            for i in range(C.shape[0]):
                precisions[C[i].item()] = invV[i]
                log_sqrt_dets[C[i].item()] = log_sqrt_detV[i].item()

        # add a leading dim so precisions broadcast along rows
        for y in precisions:
            precisions[y] = torch.unsqueeze(precisions[y], dim=0)

        return precisions, log_sqrt_dets

    def _batch_log_probs(self, X, means, precisions, log_sqrt_dets, classes, memory):
        """Return a (batch, len(classes)) tensor of per-class log probabilities of X's representations.

        With ``fixed_var`` each class is an isotropic Gaussian around its memory mean with variance
        ``s_coeff * (1 + 1/n)`` (n = memory points of the class, plus the discounted count of
        discarded points when ``store_old_means``); constant terms are dropped. Otherwise a
        Student-t style term ``-(1 + dof)/2 * log(1 + quad/dof)`` is used, with
        ``dof = n + rep_size`` and ``quad`` the precision-weighted squared distance to the mean.
        """
        X = X.to(self.device)
        r = self.model.calc_representation(X)
        log_probs = {}
        for c in classes:
            if self.fixed_var:
                if self.store_old_means:
                    var = self.number_of_old_data_points[c]/self.old_data_var_reg[c]
                    inv_n = 1/(len(memory[c])+var) + 1
                else:
                    inv_n = 1 / len(memory[c]) + 1
                k = -1 / (2 * self.s_coeff * inv_n)
                log_probs[c] = ((r - means[c]) ** 2).sum(1) * k - r.shape[1] / 2 * math.log(inv_n)
            else:
                centered_r = r - means[c]
                dof = len(memory[c]) + self.rep_size
                if self.diag_approx:
                    # precisions[c] is (1, rep_size) and broadcasts over the batch rows
                    quad = (precisions[c] * centered_r**2).sum(1)
                else:
                    # per-row quadratic form r_i^T P r_i (precisions[c] is stored as (1, d, d))
                    precision = precisions[c].reshape(self.rep_size, self.rep_size)
                    quad = (torch.matmul(centered_r, precision) * centered_r).sum(1)
                # NOTE(review): for a density the determinant term would normally be subtracted
                # (|V|^-1/2); kept as in the original pending confirmation
                log_probs[c] = log_sqrt_dets[c] - (1 + dof) / 2 * torch.log1p(quad / dof)
        return torch.stack([log_probs[c] for c in classes], dim=1)

    def _calc_stats(self, memory, classes):
        """Return (means, precisions, log_sqrt_dets) of the memory for the given classes.

        With ``fixed_var`` only the means are estimated; precisions are
        ``1 / ((1/n + 1) * S_init)`` and log_sqrt_dets is None.
        """
        if self.fixed_var:
            return self._calc_means(memory, classes), {c: 1/((1/len(memory[c])+1)*self.S_init) for c in classes}, None

        mem_means, mem_vars = self._calc_mean_and_var(memory, classes)
        precisions, log_sqrt_dets = self._invs_and_get_dets(mem_vars)
        return mem_means, precisions, log_sqrt_dets

    def _current_batch_loss(self, X, Y, means, precisions, log_sqrt_dets, classes, memory):
        """Mean negative log softmax (over ``classes``) of the batch's true-class log probs.

        Points whose label is not in ``classes`` contribute 0 but still count in the 1/len(Y) factor.
        """
        loss = torch.zeros(1, device=self.device)
        if len(classes) == 0:
            return loss
        data_log_probs = self._batch_log_probs(X, means, precisions, log_sqrt_dets, classes, memory)
        z = torch.logsumexp(data_log_probs, 1)
        for i in range(Y.shape[0]):
            try:
                j = classes.index(Y[i].item())
            except ValueError:
                continue
            loss -= data_log_probs[i, j] - z[i]
        return 1/Y.shape[0]*loss

    def _internal_replay_sample_calc(self, per_task_replay_data, memory, probs, num_of_sample_points):
        """Sample replay points per class by multinomial draws, moving them out of ``memory``.

        ``probs`` is indexed in the key order of ``self.per_class_mem``. Classes that cannot supply
        their draws (keeping ``min_num_mem_points``) get probability 0 and the shortfall is redrawn
        from the rest recursively. The renormalisation ``1 / (1/p - k)`` is exact only when the
        non-zero probabilities are uniform. Not called from within this class at present.
        """
        ys = np.random.multinomial(num_of_sample_points, probs)
        num_of_under_sampled_classes = 0
        delta = 0
        for i, n in enumerate(ys):
            if n == 0:
                continue
            y = list(self.per_class_mem.keys())[i]
            full_sample = True
            class_delta = len(memory[y]) - self.min_num_mem_points
            if class_delta <= 0:
                probs[i] = 0.0
                num_of_under_sampled_classes += 1
                delta += n
                continue
            if class_delta < n:
                indexes = random.sample(range(len(memory[y])), class_delta)
                full_sample = False
            else:
                indexes = random.sample(range(len(memory[y])), n)
            # delete from the end so earlier indexes stay valid
            indexes = sorted(indexes, reverse=True)
            for index in indexes:
                per_task_replay_data[memory[y][index][1]].append((memory[y][index][0], y))
                del memory[y][index]
            if not full_sample:
                probs[i] = 0.0
                num_of_under_sampled_classes += 1
                delta += n - class_delta

        # if some classes could not supply enough points, sample the rest from other classes
        if num_of_under_sampled_classes > 0:
            all_zero = True
            for j in range(len(probs)):
                if probs[j] > 0:
                    all_zero = False
                    probs[j] = 1 / (1 / probs[j] - num_of_under_sampled_classes)
            if not all_zero:
                self._internal_replay_sample_calc(per_task_replay_data, memory, probs, delta)

    def _calc_random_replay_sample(self, per_task_replay_data, memory, num_of_samples):
        """Move up to ``num_of_samples`` uniformly chosen points from ``memory`` into the replay data.

        Sampling is uniform over all eligible points of all classes; the first
        ``min_num_mem_points`` points of each class are never chosen, so they stay in ``memory``.
        Chosen points are appended to ``per_task_replay_data[task_id]`` as (x, y) pairs, using the
        task id stored with each point. ``memory`` is modified in place.
        """
        if num_of_samples == 0:
            return
        # number of replayable points per class, laid out consecutively in one global index space
        lengths = {}
        n = 0
        for y in memory:
            temp = len(self.per_class_mem[y]) - self.min_num_mem_points
            if temp > 0:
                lengths[y] = temp
                n += temp
        if n <= num_of_samples:
            indexes = list(range(n))
        else:
            indexes = random.sample(range(n), k=num_of_samples)
            indexes = sorted(indexes)
        # map each sorted global index back to a position within its class's memory list
        class_grouped_indexes = {y: [] for y in lengths}
        acc = 0
        last_acc = 0
        for y in lengths:
            acc += lengths[y]
            for i, index in enumerate(indexes):
                if index < acc:
                    class_grouped_indexes[y].append(index-last_acc+self.min_num_mem_points)
                else:
                    i -= 1
                    break
            del indexes[:i+1]
            last_acc = acc
        for y in class_grouped_indexes:
            # delete from the end so earlier indexes stay valid
            for index in reversed(class_grouped_indexes[y]):
                per_task_replay_data[memory[y][index][1]].append((memory[y][index][0], y))
                del memory[y][index]

    def _split_memory_for_per_task_replay(self):
        """Draw this step's replay sample from a copy of the memory buffer.

        Returns:
            (per_task_replay_data, memory, memory_classes): replay (x, y) pairs keyed by task id
            (tasks with no replayed points are dropped), the remaining memory (a per-class copy,
            ``self.per_class_mem`` is not modified), and the classes of each replayed task.
        """
        memory = {y: list(self.per_class_mem[y]) for y in self.per_class_mem}
        per_task_replay_data = {t: [] for t in range(self.task_id+1)}
        # alternative sampler: self._internal_replay_sample_calc(per_task_replay_data, memory,
        #     [1/len(self.per_class_mem)]*len(self.per_class_mem), self.mem_batch_size)
        self._calc_random_replay_sample(per_task_replay_data, memory, self.mem_batch_size)

        # calculate all the tasks and associated classes in replay buffer
        memory_classes = {}
        for t in range(self.task_id+1):
            if per_task_replay_data[t] == []:
                del per_task_replay_data[t]
            else:
                memory_classes[t] = self.calc_classes(t, self.task_stream.classes)

        return per_task_replay_data, memory, memory_classes

    def _calc_class_inc_replay_loss(self, per_task_replay_data, means, precisions,
                                    log_sqrt_dets, classes, memory):
        """Mean negative log softmax over ``classes`` of all replay points pooled across tasks."""
        replay_loss = torch.zeros(1, device=self.device)
        replay_data = []
        for t in per_task_replay_data:
            replay_data += per_task_replay_data[t]
        data_loader = DataLoader(replay_data, batch_size=self.batch_size)
        for (X, Y) in data_loader:
            log_probs = self._batch_log_probs(X, means, precisions, log_sqrt_dets, classes, memory)
            z = torch.logsumexp(log_probs, 1)
            for i in range(Y.shape[0]):
                j = classes.index(Y[i].item())
                replay_loss -= log_probs[i, j] - z[i]
        return 1/len(replay_data)*replay_loss if len(replay_data) != 0 else replay_loss

    def _calc_task_inc_replay_loss(self, per_task_replay_data, means, precisions,
                                   log_sqrt_dets, memory_classes, memory):
        """Mean negative log softmax of replay points, normalising only over each point's task classes."""
        replay_loss = torch.zeros(1, device=self.device)
        n = 0
        for t in per_task_replay_data:
            data_loader = DataLoader(per_task_replay_data[t], batch_size=self.batch_size)
            n += len(per_task_replay_data[t])
            for (X, Y) in data_loader:
                log_probs = self._batch_log_probs(X, means, precisions, log_sqrt_dets, memory_classes[t], memory)
                z = torch.logsumexp(log_probs, 1)
                for i in range(Y.shape[0]):
                    j = memory_classes[t].index(Y[i].item())
                    replay_loss -= log_probs[i, j] - z[i]
        return 1/n*replay_loss if n != 0 else replay_loss

    def loss_fn(self, X, Y):
        """Return the training loss for the batch (X, Y).

        The loss is the cross-entropy of the base class's ``predict(X)`` against ``Y`` plus, for
        every task with replayed memory points, the cross-entropy of ``sw.calc_model_output`` on
        those points (given that task's null classes). A zero loss is returned while the memory
        is empty or when no class has enough memory points.

        In the task-incremental setting each memory item is an (x, task_id) pair.
        """
        # if memory is empty cannot compute loss
        if len(self.per_class_mem) == 0:
            # at start of experiment init our old mean values (used in mean prior) if selected to do so
            if self.store_old_means:
                self.old_data_means = {y: torch.zeros(self.rep_size, device=self.device) for y in self.task_stream.classes}
                self.old_data_sum = {y: torch.zeros(self.rep_size, device=self.device) for y in self.task_stream.classes}
                self.number_of_old_data_points = {y: 0 for y in self.task_stream.classes}
                self.old_data_var_reg = {y: 1.0 for y in self.task_stream.classes}
            return torch.zeros(1, device=self.device, requires_grad=True)

        per_task_replay_data, memory, memory_classes = self._split_memory_for_per_task_replay()

        # current task classes which already have enough data in memory
        current_task_classes = [c for c in self.calc_classes(self.task_id, self.task_stream.classes)
                                if c in memory.keys() and len(memory[c]) >= self.min_num_mem_points]

        if self.use_task_inc_loss:
            classes = set()
            for t in memory_classes:
                if t == self.task_id:
                    memory_classes[t] = current_task_classes
                classes.update(memory_classes[t])
            classes.update(current_task_classes)
            classes = list(classes)
        else:
            classes = [y for y in memory if len(memory[y]) >= self.min_num_mem_points]

        # NOTE: in this standard-head variant the DeepCCG posterior-predictive loss
        # (_calc_stats + _current_batch_loss + _calc_*_replay_loss) is not used; `classes`
        # only decides this early exit.
        if len(classes) == 0:
            return torch.zeros(1, device=self.device, requires_grad=True)

        if self.store_old_means:
            for y in self.number_of_old_data_points:
                if self.number_of_old_data_points[y] > 0.01:
                    self.old_data_var_reg[y] += 0.5

        loss = torch.zeros(1, device=self.device)
        for t in per_task_replay_data:
            null_classes = self.calc_null_classes(t, self.task_stream.classes)
            mem_X, mem_Y = list(zip(*per_task_replay_data[t]))
            mem_X = torch.stack(mem_X).to(self.device)
            mem_Y = torch.tensor(mem_Y, device=self.device)
            loss += nn.functional.cross_entropy(sw.calc_model_output(self.model, mem_X, null_classes), mem_Y)

        return nn.functional.cross_entropy(super(DeepCCG_SH, self).predict(X), Y) + loss

    def eval(self):
        """Switch to evaluation mode and cache the memory statistics used by ``predict``."""
        super().eval()

        # the statistics are detached straight away, so don't build an autograd graph for them
        with torch.no_grad():
            self.mem_means, self.mem_precisions, self.mem_log_sqrt_dets = self._calc_stats(
                self.per_class_mem, self.per_class_mem.keys())
        for y in self.per_class_mem:
            self.mem_means[y] = self.mem_means[y].detach()
            self.mem_precisions[y] = self.mem_precisions[y].detach()
            self.mem_means[y].requires_grad = False
            self.mem_precisions[y].requires_grad = False

    def _calc_predictions(self, X, classes):
        """Return, for each row of X, the label in ``classes`` with the highest log probability."""
        log_probs = self._batch_log_probs(X, self.mem_means, self.mem_precisions,
                                          self.mem_log_sqrt_dets, classes, self.per_class_mem)
        classes = torch.tensor(classes, device=self.device)
        return classes[log_probs.argmax(1)]

    def calc_class_inc_model_predictions(self, X):
        """Predict among all classes currently held in memory (class-incremental evaluation)."""
        classes = list(self.per_class_mem.keys())
        return self._calc_predictions(X, classes)

    def calc_task_inc_model_predictions(self, X):
        """Predict among the current task's classes only (task-incremental evaluation)."""
        classes = self.calc_classes(self.task_id, self.task_stream.classes)
        return self._calc_predictions(X, classes)

    def predict(self, X):
        """Return class predictions for X outside training; returns X unchanged while training.

        (``loss_fn`` uses the base class's ``predict`` rather than this method.)
        """
        if self.training:
            return X
        else:
            return self.calc_task_inc_model_predictions(X) if self.use_task_inc_loss \
                else self.calc_class_inc_model_predictions(X)

    def _get_mem_representations(self, class_memory):
        """Return a list of detached per-point representations for a list of (x, task_id) items."""
        mem_dataloader = DataLoader(class_memory, batch_size=self.calc_stat_batch_size)
        r_class_memory = []
        # don't want to pass grad through the embedding
        with torch.no_grad():
            for (X, _) in mem_dataloader:
                X = X.to(self.device)
                r = self.model.calc_representation(X)
                r = r.detach()
                r.requires_grad = False
                for i in range(X.shape[0]):
                    r_class_memory.append(r[i])

        return r_class_memory

    def _mem_select_loss(self, class_memory_r, mem_selector, mem_mean, mem_var, new_mem_size):
        """Loss for fitting soft memory-selection weights for one class.

        Squared error between the class's statistics and the weighted statistics of its points
        (mean only when ``mem_select_KLD``, otherwise mean and second moment / covariance), plus a
        sum-of-weights penalty scaled by ``n / new_mem_size``.
        """
        n = len(class_memory_r)
        weighted_mean = torch.zeros(self.rep_size, device=self.device)
        if not self.mem_select_KLD:
            # same layout as the target statistics: a vector under diag_approx, else a matrix
            # (previously always a matrix, so the diagonal case broadcast into every row)
            weighted_var = torch.zeros(self.rep_size if self.diag_approx else (self.rep_size, self.rep_size),
                                       device=self.device)
        mem_dataloader = DataLoader(class_memory_r, batch_size=self.fit_mem_batch_size)

        start, stop = 0, 0  # which slice of the selector's weights the current batch uses
        for r in mem_dataloader:
            stop = start + r.shape[0]
            weighted_r = mem_selector(r, start, stop)  # weights * r
            weighted_mean += weighted_r.sum(0)
            if not self.mem_select_KLD:
                weighted_var += (r*weighted_r).sum(0) if self.diag_approx else torch.matmul(torch.t(r), weighted_r)
            start = stop
        mod_n = mem_selector.calc_weights(0, n).sum() if self.fixed_var else n
        weighted_mean /= mod_n
        if not self.mem_select_KLD:
            if self.diag_approx:
                weighted_var = 1 / mod_n * weighted_var
            else:
                weighted_var = 1 / mod_n * weighted_var - torch.mm(weighted_mean.reshape(-1, 1), weighted_mean.reshape(1, -1))

        # squared error with an l_1 penalty on the weights (number of points / new size as coefficient)
        if self.mem_select_KLD:
            return ((mem_mean-weighted_mean)**2).sum() + n/new_mem_size*mem_selector.calc_weights(0, n).abs().sum()
        else:
            return ((mem_mean-weighted_mean)**2).sum()+((mem_var-weighted_var)**2).sum() \
               + n/new_mem_size*mem_selector.calc_weights(0, n).sum()

    def _calc_per_class_mem_sizes(self):
        """Split ``mem_size`` between the classes in memory, as balanced as possible.

        Classes with no more points than an equal share keep all of them and the memory they
        leave unused is shared among the remaining classes. Returns a numpy int array indexed by
        class label, of length ``len(self.task_stream.classes)``.
        """
        per_class_mem_size = np.array([0]*len(self.task_stream.classes))
        classes = np.array(list(self.per_class_mem.keys()))
        start = 0
        end = 0
        remaining_mem = self.mem_size
        ns = [len(self.per_class_mem[y]) for y in self.per_class_mem]
        sorted_ns_indices = sorted(range(len(ns)), key=lambda j: ns[j])
        while start < len(ns):
            min_n = ns[sorted_ns_indices[start]]
            mem_n = remaining_mem//(len(ns)-start)
            if min_n > mem_n:
                # every remaining class exceeds an equal share: each gets exactly the share
                per_class_mem_size[classes[sorted_ns_indices[start:]]] = mem_n
                break
            more_left = False
            for end, sorted_index in enumerate(sorted_ns_indices[start:]):
                if ns[sorted_index] != min_n:
                    more_left = True
                    break
            if not more_left:
                per_class_mem_size[classes[sorted_ns_indices[start:]]] = min_n
                break
            end += start
            per_class_mem_size[classes[sorted_ns_indices[start:end]]] = min_n
            remaining_mem -= min_n*(end-start)
            start = end
        return per_class_mem_size

    def calc_new_mem(self, X, Y):
        """Add the batch (X, Y) to memory, then shrink over-full classes back to their share.

        Points are stored on the CPU as (x, task_id). For each class that needs shrinking, soft
        weights are fitted with SGD on ``_mem_select_loss`` and the ``size`` highest-weighted
        points are kept (in descending weight order). Leaves the model in train mode.
        """
        self.model.eval()

        # add batch to memory
        for i in range(Y.shape[0]):
            self.per_class_mem.setdefault(Y[i].item(), []).append((X[i].to("cpu"), self.task_id))

        per_class_mem_size = self._calc_per_class_mem_sizes()

        batch_classes = set(Y.tolist())

        # classes whose memory must be re-selected with the (expensive) fitting procedure
        fit_classes = []
        for y in self.per_class_mem:
            n = len(self.per_class_mem[y])

            if n <= per_class_mem_size[y]:
                if y in batch_classes:
                    self.class_sorted[y] = False
                continue
            # NOTE: nothing in this class sets class_sorted[y] to True, so this cheap
            # truncation path does not currently run
            if y not in batch_classes:
                if y in self.class_sorted and self.class_sorted[y]:
                    self.per_class_mem[y] = [self.per_class_mem[y][i] for i in range(per_class_mem_size[y])]
                    continue
            fit_classes.append(y)

        # memory statistics are fixed targets: compute them without building an autograd graph
        with torch.no_grad():
            if self.fixed_var:
                mem_means = self._calc_means(self.per_class_mem, fit_classes)
                mem_vars = {y: (1 + 1 / len(self.per_class_mem[y])) * self.S_init + mem_means[y]**2
                            for y in fit_classes}
            else:
                mem_means, mem_vars = self._calc_mean_and_var(self.per_class_mem, fit_classes)

        for y in fit_classes:
            mem_means[y] = mem_means[y].detach()
            mem_vars[y] = mem_vars[y].detach()
            mem_means[y].requires_grad, mem_vars[y].requires_grad = False, False

        for y in fit_classes:
            n = len(self.per_class_mem[y])

            # fit soft weights so the weighted memory matches the class statistics
            mem_selector = MemSelectModel(n, device=self.device)
            optimiser = torch.optim.SGD(mem_selector.parameters(), lr=self.mem_lr)
            class_mem_r = self._get_mem_representations(self.per_class_mem[y])
            for _ in range(self.num_of_iter_new_mem):
                optimiser.zero_grad()
                loss = self._mem_select_loss(class_mem_r, mem_selector, mem_means[y], mem_vars[y], per_class_mem_size[y])
                loss.backward()
                optimiser.step()
            _, indices = torch.topk(mem_selector.weights, int(per_class_mem_size[y]), sorted=True)
            keep = indices.tolist()

            if self.store_old_means:
                kept = set(keep)  # O(1) membership instead of scanning the index tensor
                dropped_rs = [class_mem_r[i] for i in range(n) if i not in kept]
                r_sum = torch.stack(dropped_rs).sum(dim=0)
                self.old_data_sum[y] += r_sum
                self.number_of_old_data_points[y] += len(dropped_rs)
                self.old_data_means[y] = self.old_data_sum[y]/self.number_of_old_data_points[y]

            self.per_class_mem[y] = [self.per_class_mem[y][i] for i in keep]

        self.model.train()

    def after_optimiser_step(self):
        """Update the memory buffer with the most recent training batch ``self.batch``."""
        self.calc_new_mem(*self.batch)


# This class is a wrapper for weighting our representations so that a torch optimiser can be used to fit the weights
class MemSelectModel(torch.nn.Module):
    """Learnable soft weights, one per memory point, fitted with a torch optimiser.

    Point i has a raw parameter ``weights[i]``; its effective weight is ``sigmoid(2 * weights[i])``.
    """

    def __init__(self, num_of_data_points, device):
        super().__init__()
        raw = torch.ones(num_of_data_points, device=device)
        # random start around 1.25 (effective weight sigmoid(2.5) ~ 0.92) with a small spread
        torch.nn.init.normal_(raw, mean=1.25, std=0.1)
        self.weights = torch.nn.Parameter(raw)

    def calc_weights(self, start, stop):
        """Return the effective weights of points [start, stop) as a column of shape (stop - start, 1)."""
        selected = self.weights[start:stop]
        return torch.sigmoid(2 * selected).unsqueeze(-1)

    def forward(self, r, start, stop):
        """Scale each row of ``r`` by the effective weight of the matching point in [start, stop)."""
        column = self.calc_weights(start, stop)  # broadcasts across the columns of r
        return column * r


def t_sne_plot(means, per_class_r):
    """Project class means and per-class representations to 2-D with t-SNE and scatter-plot them.

    Args:
        means: dict mapping class label -> mean representation (torch tensor or array-like).
        per_class_r: dict mapping class label -> list of representations of that class; must
            contain every key of ``means``.

    The first ``len(means)`` rows given to t-SNE are the means, followed by each class's
    representations in the key order of ``means``. Each class (its mean included) is drawn in the
    ``tab20`` colour indexed by its label, so labels are expected to be small non-negative ints.
    """
    def _to_numpy(v):
        # sklearn cannot consume torch tensors that are on the GPU or track gradients
        if torch.is_tensor(v):
            return v.detach().cpu().numpy()
        return np.asarray(v)

    classes = list(means.keys())
    data = [_to_numpy(means[y]) for y in classes]
    for y in classes:
        data += [_to_numpy(r) for r in per_class_r[y]]
    tsne = TSNE(2, verbose=1)
    tsne_proj = tsne.fit_transform(np.stack(data))
    # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is still available
    cmap = plt.get_cmap('tab20')
    fig, ax = plt.subplots(figsize=(8, 8))
    offset = len(classes)
    for i, y in enumerate(classes):
        # row i is the class mean; the class's points follow at the running offset
        indices = [i] + list(range(offset, offset+len(per_class_r[y])))
        offset += len(per_class_r[y])
        ax.scatter(tsne_proj[indices, 0], tsne_proj[indices, 1], c=np.array(cmap(y)).reshape(1, 4), label=y,
                   alpha=0.5)
    ax.legend(fontsize='large', markerscale=2)
    plt.show()
