#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
General useful functions for machine learning with Pytorch.
"""

# Python 2-3 compatible
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

from typing import List, Union, Sequence, Optional

import numpy as np
import torch
from torch import Tensor
from torch.nn import Module, Parameter

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn
import io


def shuffle_in_unison(dataset, seed=None, in_place=False):
    """
    Shuffle two (or more) arrays in unison, keeping the i-th elements of all
    of them aligned (e.g. images and their labels).

        Args:
            dataset: list of arrays (numpy arrays or tensors) sharing the same
                     first-dimension length.
            seed (int): optional seed for numpy's global RNG; ``0`` is a valid
                        seed and is honoured.
            in_place (bool): if True, permute every array of ``dataset`` in
                             place and return None; otherwise return a new
                             list of shuffled arrays.
        Returns:
            list: the shuffled arrays, in the same order as ``dataset``, if
                  in_place is False.
    """
    if seed is not None:
        np.random.seed(seed)
    # Capture the global RNG state and restore it after drawing the
    # permutation, so the draw itself does not advance the global stream.
    rng_state = np.random.get_state()
    order = np.arange(len(dataset[0]))
    np.random.shuffle(order)
    np.random.set_state(rng_state)

    if in_place:
        for x in dataset:
            # x[order] is materialized before the assignment, so this is safe
            x[:] = x[order]
        return None

    return [x[order] for x in dataset]


def shuffle_in_unison_pytorch(dataset, seed=None):
    """
    Shuffle a list of tensors in unison along their first dimension.

        Args:
            dataset: list of tensors sharing the same first-dimension size.
            seed (int): optional seed for torch's global RNG. It is applied
                        *before* the permutation is drawn, so it makes this
                        shuffle reproducible (``0`` is a valid seed).
        Returns:
            list: new tensors, all indexed by the same random permutation.
    """
    if seed is not None:
        torch.manual_seed(seed)
    perm = torch.randperm(dataset[0].size(0))
    return [x[perm] for x in dataset]


def pad_data(dataset, mb_size):
    """
    Pad every array in ``dataset`` so its length becomes a multiple of
    ``mb_size``. Padding rows are taken from the start of each array (cycling
    through it when the array is shorter than the padding needed) and are
    prepended. All arrays are assumed to share the same first dimension.

        Args:
            dataset: list of numpy arrays and/or tensors to pad. NOTE: the
                     list itself is modified (its items are replaced by the
                     padded arrays).
            mb_size: mini-batch size.
        Returns:
            list: padded data sets (the single padded array if ``dataset``
                  has exactly one element).
            int: number of iterations needed to cover the entire original
                 set with mb_size mini-batches (ceil(len / mb_size)).
    """
    n_rows = dataset[0].shape[0]
    n_missing = n_rows % mb_size
    it = n_rows // mb_size + (1 if n_missing > 0 else 0)

    if n_missing > 0:
        n_to_add = mb_size - n_missing
        # Cyclic indices: a plain data[:n_to_add] slice silently under-pads
        # when n_to_add > n_rows. For n_to_add <= n_rows the result is
        # identical to that slice. (n_missing > 0 implies n_rows > 0.)
        np_idx = np.arange(n_to_add) % n_rows
        for i, data in enumerate(dataset):
            if isinstance(data, Tensor):
                torch_idx = torch.arange(n_to_add, device=data.device) % n_rows
                dataset[i] = torch.cat((data[torch_idx], data))
            else:
                dataset[i] = np.concatenate((data[np_idx], data))

    if len(dataset) == 1:
        dataset = dataset[0]

    return dataset, it


def get_accuracy(model, criterion, batch_size, test_x, test_y, use_cuda=True,
                 mask=None, preproc=None):
    """
    Test accuracy given a model and the test data.

        Args:
            model (nn.Module): the pytorch model to test (set to eval mode).
            criterion (func): loss function, called as criterion(logits, y).
            batch_size (int): mini-batch size.
            test_x (np.ndarray): test data, converted to a FloatTensor.
            test_y (np.ndarray): integer test labels, converted to a
                                 LongTensor.
            use_cuda (bool): if we want to use gpu or cpu.
            mask: optional per-class sequence; classes whose entry is 0 get
                  their logits set to -10e10 so they are never predicted.
            preproc (callable): optional transform applied to the whole
                                test_x tensor before evaluation.
        Returns:
            ave_loss (float): sum of the per-batch criterion values divided by
                              the number of test patterns.
            acc (float): overall accuracy.
            accs (np.ndarray): accuracy per class (NaN for classes without
                               test patterns).
    """

    model.eval()
    model = maybe_cuda(model, use_cuda=use_cuda)

    num_class = int(np.max(test_y) + 1)
    hits_per_class = [0] * num_class
    pattern_per_class = [0] * num_class
    n_patterns = test_y.shape[0]
    # Ceil division. "n // batch_size + 1" added an empty final batch whenever
    # n was a multiple of batch_size; a mean-reduced criterion returns NaN on
    # an empty batch, which poisoned ave_loss.
    test_it = (n_patterns + batch_size - 1) // batch_size

    test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
    test_y = torch.from_numpy(test_y).type(torch.LongTensor)

    if preproc:
        test_x = preproc(test_x)

    correct_cnt, ave_loss = 0, 0.0
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for it in range(test_it):
            start = it * batch_size
            end = start + batch_size

            x = maybe_cuda(test_x[start:end], use_cuda=use_cuda)
            y = maybe_cuda(test_y[start:end], use_cuda=use_cuda)

            logits = model(x)

            if mask is not None:
                # A very large negative logit gives ~zero probability after
                # softmax, so masked classes are neither predicted nor weigh
                # on the loss. Build the index on the logits' device (a
                # hard-coded .cuda() crashed on CPU-only runs).
                keep = torch.as_tensor(mask, dtype=torch.float32, device=logits.device)
                idx = (keep == 0).nonzero().view(-1)
                logits[:, idx] = -10e10

            loss = criterion(logits, y)
            _, pred_label = torch.max(logits, 1)
            correct_cnt += int((pred_label == y).sum().item())
            # NOTE(review): with a mean-reduced criterion this sums batch
            # means, so ave_loss below is not the per-pattern mean loss.
            ave_loss += loss.item()

            for label, pred in zip(y.tolist(), pred_label.tolist()):
                pattern_per_class[label] += 1
                if pred == label:
                    hits_per_class[pred] += 1

    accs = np.asarray(hits_per_class) / \
        np.asarray(pattern_per_class).astype(float)

    acc = correct_cnt * 1.0 / test_y.size(0)

    ave_loss /= test_y.size(0)

    return ave_loss, acc, accs


# def get_accuracy2(model, batch_size, test_x, test_y, use_cuda=True,
#                  mask=None, preproc=None):
#     """
#     Test accuracy given a model and the test data.
#
#         Args:
#             model (nn.Module): the pytorch model to test.
#             criterion (func): loss function.
#             batch_size (int): mini-batch size.
#             test_x (tensor): test data.
#             test_y (tensor): test labels.
#             use_cuda (bool): if we want to use gpu or cpu.
#             mask (bool): if we want to maks out some classes from the results.
#         Returns:
#             ave_loss (float): average loss across the test set.
#             acc (float): average accuracy.
#             accs (list): average accuracy for class.
#     """
#
#     model.eval()
#
#     feature_extractor = MyMobilenetV1(pretrained=True, latent_layer_num=19)
#     replace_bn_with_brn(feature_extractor, momentum=0.01, r_d_max_inc_step=4.1e-05, r_max=1.25, d_max=0.5)
#     feature_extractor = maybe_cuda(feature_extractor, use_cuda=use_cuda)
#     feature_extractor.eval()
#
#     correct_cnt, ave_loss = 0, 0
#
#     num_class = int(np.max(test_y) + 1)
#     hits_per_class = [0] * num_class
#     pattern_per_class = [0] * num_class
#     test_it = test_y.shape[0] // batch_size + 1
#
#     test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
#     test_y = torch.from_numpy(test_y).type(torch.LongTensor)
#
#     if preproc:
#         test_x = preproc(test_x)
#
#     for i in range(test_it):
#         # indexing
#         start = i * batch_size
#         end = (i + 1) * batch_size
#
#         x = maybe_cuda(test_x[start:end], use_cuda=use_cuda)
#         y = maybe_cuda(test_y[start:end], use_cuda=use_cuda)
#
#         with torch.no_grad():
#             _, x = feature_extractor(x, return_lat_acts=True)
#             x = x.detach()
#
#         logits = model(x)
#
#         _, pred_label = logits.max(1)
#         correct_cnt += (pred_label == y).sum()
#
#         for label in y.data:
#             pattern_per_class[int(label)] += 1
#
#         for i, pred in enumerate(pred_label):
#             if pred == y.data[i]:
#                 hits_per_class[int(pred)] += 1
#
#     accs = np.asarray(hits_per_class) / \
#            np.asarray(pattern_per_class).astype(float)
#
#     acc = correct_cnt.item() * 1.0 / test_y.size(0)
#
#     return 0, acc, accs


def get_accuracy_conf_matrix(model, criterion, batch_size, test_x, test_y, use_cuda=True,
                             mask=None, preproc=None, classes_n=1000):
    """
    Test accuracy given a model and the test data, also building a confusion
    matrix.

        Args:
            model (nn.Module): the pytorch model to test (set to eval mode).
                               It must expose ``cur_j`` and ``past_j``,
                               indexable by class (presumably per-class
                               pattern counts — confirm with the caller).
            criterion (func): loss function, called as criterion(logits, y).
            batch_size (int): mini-batch size.
            test_x (np.ndarray or tensor): test data.
            test_y (np.ndarray or tensor): integer test labels.
            use_cuda (bool): if we want to use gpu or cpu.
            mask: optional per-class sequence; classes whose entry is 0 get
                  their logits set to -10e10 so they are never predicted.
            preproc (callable): optional transform applied to test_x.
            classes_n (int): side of the (square) confusion matrix.
        Returns:
            ave_loss (float): sum of the per-batch criterion values divided by
                              the number of test patterns.
            acc (float): overall accuracy.
            accs (np.ndarray): accuracy per class (NaN for classes without
                               test patterns).
            confusion_matrix (np.ndarray): ``[predicted][true]`` counts, only
                for patterns whose true class has ``model.cur_j > 0`` or
                ``model.past_j > 0``.
    """

    model.eval()

    confusion_matrix = np.zeros((classes_n, classes_n), dtype=int)

    model = maybe_cuda(model, use_cuda=use_cuda)

    if not isinstance(test_x, Tensor):
        test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
    if not isinstance(test_y, Tensor):
        test_y = torch.from_numpy(test_y).type(torch.LongTensor)

    # Computed after the conversion so tensor labels are supported too
    # (np.max on a torch tensor is not reliable).
    num_class = int(test_y.max().item()) + 1
    hits_per_class = [0] * num_class
    pattern_per_class = [0] * num_class
    n_patterns = test_y.shape[0]
    # Ceil division: avoids an empty last batch (NaN loss) when n_patterns is
    # a multiple of batch_size.
    test_it = (n_patterns + batch_size - 1) // batch_size

    if preproc:
        test_x = preproc(test_x)

    correct_cnt, ave_loss = 0, 0.0
    with torch.no_grad():
        for it in range(test_it):
            start = it * batch_size
            end = start + batch_size

            x = maybe_cuda(test_x[start:end], use_cuda=use_cuda)
            y = maybe_cuda(test_y[start:end], use_cuda=use_cuda)

            logits = model(x)

            if mask is not None:
                # Large negative logit -> ~zero softmax probability; the index
                # is built on the logits' device instead of a fixed .cuda().
                keep = torch.as_tensor(mask, dtype=torch.float32, device=logits.device)
                idx = (keep == 0).nonzero().view(-1)
                logits[:, idx] = -10e10

            loss = criterion(logits, y)
            _, pred_label = torch.max(logits, 1)
            correct_cnt += int((pred_label == y).sum().item())
            ave_loss += loss.item()

            for label, pred in zip(y.tolist(), pred_label.tolist()):
                pattern_per_class[label] += 1
                if pred == label:
                    hits_per_class[pred] += 1
                if model.cur_j[label] > 0 or model.past_j[label] > 0:
                    confusion_matrix[pred][label] += 1

    accs = np.asarray(hits_per_class) / \
        np.asarray(pattern_per_class).astype(float)

    acc = correct_cnt * 1.0 / test_y.size(0)

    ave_loss /= test_y.size(0)

    return ave_loss, acc, accs, confusion_matrix


def get_accuracy_conf_matrix_from_dataloader(model, criterion, test_loader, num_class, use_cuda=True,
                                             mask=None, classes_n=1000):
    """
    Test accuracy given a model and a loader of test data, also building a
    confusion matrix.

        Args:
            model (nn.Module): the pytorch model to test (set to eval mode).
            criterion (func): loss function, called as criterion(logits, y).
            test_loader: iterable of ``(x, y)`` mini-batches.
            num_class (int): overall number of classes (size of the
                             per-class counters).
            use_cuda (bool): if we want to use gpu or cpu.
            mask: optional per-class sequence; classes whose entry is 0 get
                  their logits set to -10e10 so they are never predicted.
            classes_n (int): side of the (square) confusion matrix.
        Returns:
            ave_loss (float): sum of the per-batch criterion values divided by
                              the number of test patterns.
            acc (float): overall accuracy.
            accs (tuple): ``(hits_per_class, patterns_per_class)`` as numpy
                          arrays (divide them to get per-class accuracy).
            confusion_matrix (np.ndarray): ``[predicted][true]`` counts.
        Raises:
            ValueError: if test_loader yields no pattern.
    """

    model.eval()

    confusion_matrix = np.zeros((classes_n, classes_n), dtype=int)

    correct_cnt, ave_loss = 0, 0.0
    model = maybe_cuda(model, use_cuda=use_cuda)

    hits_per_class = [0] * num_class
    pattern_per_class = [0] * num_class

    n_test_patterns = 0

    with torch.no_grad():
        for x, y in test_loader:
            n_test_patterns += len(x)
            x = maybe_cuda(x, use_cuda=use_cuda)
            y = maybe_cuda(y, use_cuda=use_cuda)

            logits = model(x)

            if mask is not None:
                # Large negative logit -> ~zero softmax probability; the index
                # is built on the logits' device instead of a fixed .cuda().
                keep = torch.as_tensor(mask, dtype=torch.float32, device=logits.device)
                idx = (keep == 0).nonzero().view(-1)
                logits[:, idx] = -10e10

            loss = criterion(logits, y)
            _, pred_label = torch.max(logits, 1)
            correct_cnt += int((pred_label == y).sum().item())
            ave_loss += loss.item()

            for label, pred in zip(y.tolist(), pred_label.tolist()):
                pattern_per_class[label] += 1
                if pred == label:
                    hits_per_class[pred] += 1
                confusion_matrix[pred][label] += 1

    if n_test_patterns == 0:
        raise ValueError('test_loader did not yield any test pattern')

    accs = (np.asarray(hits_per_class), np.asarray(pattern_per_class))

    acc = correct_cnt * 1.0 / n_test_patterns

    ave_loss /= n_test_patterns

    return ave_loss, acc, accs, confusion_matrix


def get_accuracy_conf_matrix_nic(model, criterion, batch_size, test_x, test_y, use_cuda=True,
                                 mask=None, preproc=None, cf_labels_inc=None, classes_n=1000):
    """
    Test accuracy given a model and the test data, also building a confusion
    matrix whose rows/columns are remapped through ``cf_labels_inc``.

        Args:
            model (nn.Module): the pytorch model to test (set to eval mode).
                               It must expose ``cur_j`` and ``past_j``,
                               indexable by class.
            criterion (func): loss function, called as criterion(logits, y).
            batch_size (int): mini-batch size.
            test_x (np.ndarray or tensor): test data.
            test_y (np.ndarray or tensor): integer test labels.
            use_cuda (bool): if we want to use gpu or cpu.
            mask: optional per-class sequence; classes whose entry is 0 get
                  their logits set to -10e10 so they are never predicted.
            preproc (callable): optional transform applied to test_x.
            cf_labels_inc (list): maps a class index to its confusion-matrix
                                  row/column; identity when None.
            classes_n (int): side of the (square) confusion matrix.
        Returns:
            ave_loss (float): sum of the per-batch criterion values divided by
                              the number of test patterns.
            acc (float): overall accuracy.
            accs (np.ndarray): accuracy per class (NaN for classes without
                               test patterns).
            confusion_matrix (np.ndarray): ``[mapped predicted][mapped true]``
                counts, only for patterns whose true class has
                ``model.cur_j > 0`` or ``model.past_j > 0``.
    """
    if cf_labels_inc is None:
        cf_labels_inc = list(range(classes_n))
    model.eval()

    confusion_matrix = np.zeros((classes_n, classes_n), dtype=int)

    model = maybe_cuda(model, use_cuda=use_cuda)

    if not isinstance(test_x, Tensor):
        test_x = torch.from_numpy(test_x).type(torch.FloatTensor)
    if not isinstance(test_y, Tensor):
        test_y = torch.from_numpy(test_y).type(torch.LongTensor)

    num_class = int(test_y.max().item()) + 1
    hits_per_class = [0] * num_class
    pattern_per_class = [0] * num_class
    n_patterns = test_y.shape[0]
    # Ceil division: avoids an empty last batch (NaN loss) when n_patterns is
    # a multiple of batch_size.
    test_it = (n_patterns + batch_size - 1) // batch_size

    if preproc:
        test_x = preproc(test_x)

    correct_cnt, ave_loss = 0, 0.0
    with torch.no_grad():
        for it in range(test_it):
            start = it * batch_size
            end = start + batch_size

            x = maybe_cuda(test_x[start:end], use_cuda=use_cuda)
            y = maybe_cuda(test_y[start:end], use_cuda=use_cuda)

            logits = model(x)

            if mask is not None:
                # Large negative logit -> ~zero softmax probability; the index
                # is built on the logits' device instead of a fixed .cuda().
                keep = torch.as_tensor(mask, dtype=torch.float32, device=logits.device)
                idx = (keep == 0).nonzero().view(-1)
                logits[:, idx] = -10e10

            loss = criterion(logits, y)
            _, pred_label = torch.max(logits, 1)
            correct_cnt += int((pred_label == y).sum().item())
            ave_loss += loss.item()

            for label, pred in zip(y.tolist(), pred_label.tolist()):
                pattern_per_class[label] += 1
                if pred == label:
                    hits_per_class[pred] += 1
                if model.cur_j[label] > 0 or model.past_j[label] > 0:
                    confusion_matrix[cf_labels_inc[pred]][cf_labels_inc[label]] += 1

    accs = np.asarray(hits_per_class) / \
        np.asarray(pattern_per_class).astype(float)

    acc = correct_cnt * 1.0 / test_y.size(0)

    ave_loss /= test_y.size(0)

    return ave_loss, acc, accs, confusion_matrix


def preprocess_imgs(img_batch, scale=True, norm=True, channel_first=True):
    """
    Normalize a batch of channel-last images as expected by the PyTorch
    ImageNet pre-trained models. The input object is never modified.

        Args:
            img_batch: batch of images with the channel on the last axis,
                       i.e. (batch, d1, d2, channels). Presumably a numpy
                       array: np.transpose is used when channel_first is True.
            scale (bool): divide by 255 to bring pixel values into [0, 1].
            norm (bool): subtract the ImageNet per-channel mean and divide by
                         the per-channel std (first three channels only).
            channel_first (bool): if True, move the channel axis to
                                  position 1, giving (batch, channels, d1, d2).
        Returns:
            the pre-processed batch (a new floating-point array).
    """
    if scale:
        # convert to float in [0, 1] (creates a new array)
        img_batch = img_batch / 255
    elif norm:
        # The normalization below writes channel by channel in place: work on
        # a floating-point copy so the caller's array is not mutated and
        # integer inputs are not truncated on assignment.
        img_batch = img_batch * 1.0

    if norm:
        means = (0.485, 0.456, 0.406)
        stds = (0.229, 0.224, 0.225)
        for ch, (mean, std) in enumerate(zip(means, stds)):
            img_batch[:, :, :, ch] = (img_batch[:, :, :, ch] - mean) / std

    if channel_first:
        # Move the channel axis right after the batch axis.
        img_batch = np.transpose(img_batch, (0, 3, 1, 2))

    return img_batch


def maybe_cuda(what, use_cuda=True, **kw):
    """
    Return ``what`` moved to the GPU via ``.cuda()`` when GPU use is
    requested and CUDA is available; otherwise return it untouched.

        Args:
            what (object): any object exposing ``.cuda()`` (tensor, module,
                           ...). It is only touched when it will be moved.
            use_cuda (bool): any value other than the literal ``False``
                             requests the GPU.
            **kw: accepted for call-site compatibility and ignored.
        Returns
            object: ``what`` itself, or the result of ``what.cuda()``.
    """
    if use_cuda is False or not torch.cuda.is_available():
        return what
    return what.cuda()


def consolidate_weights(model, cur_clas, cur_clas_batch, classes_n=1000):
    """
    Mean-shift the temporary-head weights of ``model.fc`` and consolidate
    them into ``model.saved_weights`` (CWR-style).

    For every class ``c`` in ``cur_clas`` the row ``model.fc.weight[c]`` is
    zero-centred (its own mean is subtracted). Then:
      * new class (not yet in ``model.saved_weights``): the centred row is
        stored as-is;
      * known class present in ``cur_clas_batch``: stored and centred rows are
        averaged, the stored one weighted by ``sqrt(past_j[c] / cur_j[c])``;
      * known class only in replay: the stored row is left unchanged.

    :param model: the classifier model; reads ``fc.weight``, ``saved_weights``
        (dict class -> numpy row), ``past_j`` and ``cur_j``
    :param cur_clas: the classes in the current batch (actual batch + replay)
    :param cur_clas_batch: the classes actually in the current batch (not
        coming from the replay memory)
    :param classes_n: the total number of classes (length of the result)
    :return: list of length ``classes_n``; entry ``c`` is the mean absolute
        change of the stored weights of class ``c`` (for a new class, the mean
        absolute value of its stored row); 0.0 for classes not in ``cur_clas``
    """
    with torch.no_grad():
        w_changes = [0.0 for _ in range(classes_n)]
        # Convert the weight matrix once: it is not modified in the loop, and
        # converting it per class copied the whole matrix (device -> host)
        # twice for every class.
        fc_weights = model.fc.weight.detach().cpu().numpy()
        for c in cur_clas:
            w = fc_weights[c]
            # zero-centre the class row (subtract its own mean); new array
            new_w = w - np.average(w)
            # the model has already seen this class
            if c in model.saved_weights.keys():
                # class in the current batch: normal CWR running average
                if c in cur_clas_batch:
                    # NOTE(review): cur_j[c] == 0 here would divide by zero.
                    wpast_j = np.sqrt(model.past_j[c] / model.cur_j[c])  # sqrt wpastj
                    new_weight = (model.saved_weights[c] * wpast_j
                                  + new_w) / (wpast_j + 1)
                # class only in replay: its weights are not changed at all
                else:
                    new_weight = model.saved_weights[c]
                # mean absolute difference between old and new stored weights
                diff = np.abs(model.saved_weights[c] - new_weight).mean()
                model.saved_weights[c] = new_weight
                w_changes[c] = diff
            # new class: store the (already mean-subtracted) temporary row
            else:
                model.saved_weights[c] = new_w
                w_changes[c] = np.abs(new_w).mean()
    return w_changes


def consolidate_weights_no_mean_sub(model, cur_clas, cur_clas_batch, classes_n=1000):
    """
    Consolidate the temporary-head weights of ``model.fc`` into
    ``model.saved_weights`` (CWR-style) WITHOUT subtracting the row mean.

    For every class ``c`` in ``cur_clas``:
      * new class (not yet in ``model.saved_weights``): a copy of
        ``model.fc.weight[c]`` is stored;
      * known class present in ``cur_clas_batch``: stored and current rows are
        averaged, the stored one weighted by ``sqrt(past_j[c] / cur_j[c])``;
      * known class only in replay: the stored row is left unchanged.

    :param model: the classifier model; reads ``fc.weight``, ``saved_weights``
        (dict class -> numpy row), ``past_j`` and ``cur_j``
    :param cur_clas: the classes in the current batch (actual batch + replay)
    :param cur_clas_batch: the classes actually in the current batch (not
        coming from the replay memory)
    :param classes_n: the total number of classes (length of the result)
    :return: list of length ``classes_n``; entry ``c`` is the mean absolute
        change of the stored weights of class ``c`` (for a new class, the mean
        absolute value of its stored row); 0.0 for classes not in ``cur_clas``
    """
    with torch.no_grad():
        w_changes = [0.0 for _ in range(classes_n)]
        # Convert once (not modified in the loop). On CPU, .numpy() shares
        # memory with the live parameter.
        fc_weights = model.fc.weight.detach().cpu().numpy()
        for c in cur_clas:
            # Copy: storing a view would let later in-place updates of
            # fc.weight (training, fill_(0.0) in reset_weights) silently
            # overwrite the saved weights.
            new_w = fc_weights[c].copy()
            # the model has already seen this class
            if c in model.saved_weights.keys():
                # class in the current batch: normal CWR running average
                if c in cur_clas_batch:
                    # NOTE(review): cur_j[c] == 0 here would divide by zero.
                    wpast_j = np.sqrt(model.past_j[c] / model.cur_j[c])  # sqrt wpastj
                    new_weight = (model.saved_weights[c] * wpast_j
                                  + new_w) / (wpast_j + 1)
                # class only in replay: its weights are not changed at all
                else:
                    new_weight = model.saved_weights[c]
                # mean absolute difference between old and new stored weights
                diff = np.abs(model.saved_weights[c] - new_weight).mean()
                model.saved_weights[c] = new_weight
                w_changes[c] = diff
            # new class: store the temporary-head row as-is
            else:
                model.saved_weights[c] = new_w
                w_changes[c] = np.abs(new_w).mean()
    return w_changes



def set_consolidate_weights(model):
    """
    Load the consolidated weights into the classifier head.

    Every row of ``model.fc.weight`` is zeroed, then each class stored in
    ``model.saved_weights`` (class -> numpy row) gets its saved row back.
    """
    with torch.no_grad():
        head = model.fc.weight
        head.fill_(0.0)
        for cls_idx, saved_row in model.saved_weights.items():
            head[cls_idx].copy_(torch.from_numpy(saved_row))


def reset_weights(model, cur_clas):
    """
    Reset the classifier head before training on a new batch.

    All rows of ``model.fc.weight`` are zeroed; the saved rows of the classes
    that are both in ``model.saved_weights`` and in ``cur_clas`` are then
    copied back.
    """
    with torch.no_grad():
        head = model.fc.weight
        head.fill_(0.0)
        to_restore = [(cls_idx, row) for cls_idx, row in model.saved_weights.items()
                      if cls_idx in cur_clas]
        for cls_idx, row in to_restore:
            head[cls_idx].copy_(torch.from_numpy(row))


def examples_per_class(train_y, classes_n=1000):
    """
    Count the training examples of every class.

    Returns a dict mapping each class index in ``range(classes_n)`` to the
    number of labels in ``train_y`` equal to it. A label outside that range
    raises KeyError.
    """
    counts = dict.fromkeys(range(classes_n), 0)
    for label in train_y:
        counts[int(label)] += 1
    return counts


def freeze_up_to_legacy(model, freeze_below_layer, only_conv=False):
    """
    Set ``requires_grad = False`` on the parameters of ``model``, in
    ``named_parameters()`` order, up to and including the parameter named
    ``freeze_below_layer``.

    With ``only_conv`` set, only parameters whose name contains ``"conv"``
    are frozen, but the scan still stops at ``freeze_below_layer``.
    """
    for pname, param in model.named_parameters():
        if not only_conv or "conv" in pname:
            param.requires_grad = False
        if pname == freeze_below_layer:
            break


def freeze_up_to(model: Module, freeze_below_layer: str, only_conv: bool = False,
                 print_frozen_parameters: bool = False, set_eval_mode: bool = True,
                 module_prefix: str = '') -> bool:
    """
    Recursively freeze the sub-modules of ``model`` that come before the
    module whose dotted name equals ``freeze_below_layer``.

    Children are visited in registration order. For each child met before the
    target:
      * its own (non-recursive) parameters get ``requires_grad = False``;
      * it is switched to eval mode if it is a leaf module and
        ``set_eval_mode`` is set;
      * its own children are processed recursively.
    The target module itself and everything after it are left untouched.

    Args:
        model: module whose children are scanned.
        freeze_below_layer: dotted name (e.g. ``'features.3'``) of the first
            module NOT to freeze. An empty value disables freezing.
        only_conv: only handle children whose own name contains ``'conv'``.
            NOTE(review): children whose name lacks ``'conv'`` are skipped
            entirely (not recursed into), so a target nested inside them is
            never reached.
        print_frozen_parameters: print the dotted name of each frozen module.
        set_eval_mode: put frozen leaf modules in eval mode.
        module_prefix: dotted prefix of ``model``; used by the recursion.

    Returns:
        True if the target module was reached (or nothing had to be frozen),
        False otherwise.
    """
    if not freeze_below_layer:
        return True

    for child_name, child in model.named_children():
        qualified_name = (module_prefix + '.' + child_name) if module_prefix else child_name
        if qualified_name == freeze_below_layer:
            return True

        if only_conv and 'conv' not in child_name:
            continue

        touched = False
        if set_eval_mode and not list(child.children()):
            child.eval()
            touched = True

        for param in child.parameters(recurse=False):
            param.requires_grad = False
            touched = True

        if print_frozen_parameters and touched:
            print('Freezing', qualified_name)

        reached = freeze_up_to(child, freeze_below_layer, only_conv=only_conv, set_eval_mode=set_eval_mode,
                               print_frozen_parameters=print_frozen_parameters, module_prefix=qualified_name)
        if reached:
            return True

    return False


def create_syn_data(model, bn_prefix='bn', classification_layer='output'):
    """
    Allocate the EWC / synaptic-data buffers for ``model``.

    Only parameters whose name contains neither ``bn_prefix`` nor
    ``classification_layer`` are tracked; ``size`` is their total number of
    elements. Each tracked parameter name and its size is printed.

    Returns:
        (ewcData, synData): ``ewcData`` is a zero float32 tensor of shape
        ``(2, size)`` (row 0: params at the loss minimum, row 1: parameter
        importance); ``synData`` is a dict of zero float32 vectors of length
        ``size`` under the keys 'old_theta', 'new_theta', 'grad',
        'trajectory' and 'cum_trajectory'.
    """
    print('Creating Syn data for Optimal params and their Fisher info')

    tracked = [(name, param.flatten().size(0))
               for name, param in model.named_parameters()
               if bn_prefix not in name and classification_layer not in name]
    for name, numel in tracked:
        print(name, numel)
    size = sum(numel for _, numel in tracked)

    keys = ('old_theta', 'new_theta', 'grad', 'trajectory', 'cum_trajectory')
    synData = {key: torch.zeros(size, dtype=torch.float32) for key in keys}

    return torch.zeros((2, size), dtype=torch.float32), synData


def extract_weights(model, target, bn_prefix='bn', classification_layer='output'):
    """
    Copy the flattened weights of ``model`` into ``target`` (in place).

    Only parameters whose name contains neither ``bn_prefix`` nor
    ``classification_layer`` are used, concatenated in ``named_parameters()``
    order. ``target`` must have exactly that many elements (see
    create_syn_data).
    """
    with torch.no_grad():
        chunks = [param.flatten() for name, param in model.named_parameters()
                  if bn_prefix not in name and classification_layer not in name]
        # One concatenation: re-concatenating a growing vector for every
        # parameter copied the data quadratically. No matching parameter
        # yields an empty vector (matching a size-0 buffer).
        weights_vector = torch.cat(chunks, 0) if chunks else target.new_empty(0)
        target[...] = weights_vector.cpu()


def extract_grad(model, target, bn_prefix='bn', classification_layer='output'):
    """
    Copy the flattened gradients of ``model`` into ``target`` (in place).

    Uses the same parameter filter and order as extract_weights, so the
    layout matches. Parameters without a gradient (``grad is None``, e.g.
    frozen or unused ones) contribute zeros instead of raising.
    """
    with torch.no_grad():
        chunks = []
        for name, param in model.named_parameters():
            if bn_prefix in name or classification_layer in name:
                continue
            grad = param.grad if param.grad is not None else torch.zeros_like(param)
            chunks.append(grad.flatten())
        # single concatenation instead of a quadratic cat-in-a-loop
        grad_vector = torch.cat(chunks, 0) if chunks else target.new_empty(0)
        target[...] = grad_vector.cpu()


def weight_stats(net, ewcData, clip_to, bn_prefix='bn', classification_layer='output'):
    """
    Print saturation statistics of the importance vector ``ewcData[1]``.

    The first line reports the overall saturation,
    ``100 * sum(importance) / (len(importance) * clip_to)`` percent, together
    with its maximum and size. Then one line per tracked parameter (same
    filter/order as extract_weights) reports the saturation of its slice.
    The bias ("B") column is always printed as 0.
    """
    importance = ewcData[1]
    n_total = importance.shape[0]
    overall = 100 * torch.sum(importance) / (n_total * clip_to)
    print('Average F saturation = %.3f%%' % overall,
          ' Max = ', torch.max(importance), ' Size = ', n_total)

    offset = 0
    for levname, param in net.named_parameters():
        if bn_prefix in levname or classification_layer in levname:
            continue
        sizew = param.flatten().size(0)
        level_saturation = 100 * torch.sum(importance[offset:offset + sizew]) / (sizew * clip_to)
        print(levname, ' W = %.3f%%  B = %.3f%%' % (level_saturation, 0.0))
        offset += sizew


def init_batch(net, ewcData, synData, bn_prefix='bn', classification_layer='output'):
    """
    Prepare the synaptic bookkeeping for a new training batch.

    The current (filtered) weights of ``net`` are written into ``ewcData[0]``
    and ``synData['trajectory']`` is reset to the scalar 0 (it becomes a
    tensor again on the first ``+=`` in post_update).
    """
    anchor = ewcData[0]
    extract_weights(net, anchor, classification_layer=classification_layer, bn_prefix=bn_prefix)  # Keep initial weights
    synData['trajectory'] = 0


def pre_update(net, synData, bn_prefix='bn', classification_layer='output'):
    """Store the current (filtered) weights of ``net`` in ``synData['old_theta']``."""
    old_theta = synData['old_theta']
    extract_weights(net, old_theta, classification_layer=classification_layer, bn_prefix=bn_prefix)


def post_update(net, synData, bn_prefix='bn', classification_layer='output'):
    """
    Accumulate the path contribution of the last optimizer step.

    The new weights go into ``synData['new_theta']`` and the gradients into
    ``synData['grad']``. Then ``grad * (new_theta - old_theta)`` is added to
    ``synData['trajectory']``, where ``old_theta`` was filled by pre_update.
    """
    filters = dict(bn_prefix=bn_prefix, classification_layer=classification_layer)
    extract_weights(net, synData['new_theta'], **filters)
    extract_grad(net, synData['grad'], **filters)

    step = synData['new_theta'] - synData['old_theta']
    synData['trajectory'] += synData['grad'] * step


def update_ewc_data(net, ewcData, synData, clip_to, c=0.0015, bn_prefix='bn', classification_layer='output'):
    """
    Fold the batch trajectory into the importance and reset the anchor.

    Steps:
      1. the current (filtered) weights of ``net`` go into
         ``synData['new_theta']``;
      2. ``synData['cum_trajectory'] += c * trajectory /
         ((new_theta - ewcData[0])**2 + eps)``;
      3. ``ewcData[1] = clamp(-cum_trajectory, max=clip_to)``;
      4. ``ewcData[0]`` becomes a copy of ``new_theta``.

    Args:
        net: the network.
        ewcData: ``[anchor weights, importance]`` (see create_syn_data).
        synData: synaptic buffers (see create_syn_data).
        clip_to: upper clamp for the importance values.
        c: scaling factor of the trajectory contribution.
    """
    extract_weights(net, synData['new_theta'], bn_prefix=bn_prefix, classification_layer=classification_layer)
    eps = 0.0000001  # 0.001 in few task - 0.1 used in a more complex setup

    # torch ops instead of np.square: numpy cannot consume CUDA tensors.
    drift_sq = (synData['new_theta'] - ewcData[0]) ** 2
    synData['cum_trajectory'] += c * synData['trajectory'] / (drift_sq + eps)

    # change sign here because the Ewc regularization
    # in Caffe (theta - thetaold) is inverted w.r.t. syn equation [4]
    # (thetaold - theta). Negation yields a new tensor, so ewcData[1] never
    # aliases cum_trajectory.
    ewcData[1] = torch.clamp(-synData['cum_trajectory'], max=clip_to)
    # (except CWR)
    ewcData[0] = synData['new_theta'].clone().detach()


def compute_ewc_loss(model, ewcData, lambd=0, bn_prefix='bn', classification_layer='output'):
    """
    EWC-style quadratic penalty:
    ``lambd / 2 * sum(importance * (w - anchor) ** 2)``.

    ``w`` is the concatenation of the tracked parameters of ``model`` (same
    filter/order as extract_weights). ``ewcData[0]`` is the anchor and
    ``ewcData[1]`` the importance. The result stays differentiable w.r.t.
    the model parameters.
    """
    chunks = [param.flatten() for name, param in model.named_parameters()
              if bn_prefix not in name and classification_layer not in name]
    # one concatenation instead of a quadratic cat-in-a-loop
    weights_vector = torch.cat(chunks, 0)

    # Move anchor/importance to wherever the model lives. Forcing .cuda()
    # broke CPU models on GPU hosts (device mismatch).
    device = weights_vector.device
    anchor = ewcData[0].to(device)
    importance = ewcData[1].to(device)
    loss = (lambd / 2) * torch.dot(importance, (weights_vector - anchor) ** 2)
    return loss


def gen_plot(df, max_value=None):
    """
    Render ``df`` as a heatmap and return it as an in-memory JPEG.

        Args:
            df: 2-D data accepted by ``seaborn.heatmap`` (e.g. a confusion
                matrix).
            max_value: upper bound of the colour scale (``vmax``). None lets
                       seaborn derive it from the data.
        Returns:
            io.BytesIO: buffer rewound to position 0, containing the image.
    """
    fig = plt.figure(figsize=(25, 25))
    try:
        seaborn.set(font_scale=1)
        seaborn.heatmap(df, annot=False, annot_kws={"size": 12}, fmt="d", cmap="YlGnBu_r", cbar=False,
                        xticklabels=False, yticklabels=False, vmin=0, vmax=max_value)
        buf = io.BytesIO()
        fig.savefig(buf, format='jpeg')
    finally:
        # pyplot keeps every figure alive until it is closed. Calling this
        # repeatedly (e.g. once per evaluation) used to leak memory.
        plt.close(fig)
    buf.seek(0)
    return buf


def ar1f_get_params_lr(model: Module, cwr_layers_names: Union[str, Sequence[str]],
                       freeze_below_layer: Optional[str] = None, print_frozen_parameters: bool = False) -> \
        (List[Parameter], List[Parameter], List[Parameter]):
    """
    Split the parameters of ``model`` into three optimizer groups.

    Args:
        model: the network.
        cwr_layers_names: name (or names) of the CWR layer(s). Their
            ``.weight`` goes to the CWR group; their ``.bias`` always goes to
            the frozen group.
        freeze_below_layer: name of the first trainable layer. Every
            parameter listed before ``<name>.weight`` / ``<name>.bias`` in
            ``model.named_parameters()`` goes to the frozen group. None or ''
            freezes nothing (apart from the CWR biases).
        print_frozen_parameters: print each parameter put in the frozen group
            by the scan.

    Returns:
        (standard_params, cwr_mul_params, frozen_params), each list in
        ``model.named_parameters()`` order. ``requires_grad`` is not changed.

    Raises:
        ValueError: if ``freeze_below_layer`` is given but none of its
            parameters is found.
    """
    layer_names = [cwr_layers_names] if isinstance(cwr_layers_names, str) else cwr_layers_names
    cwr_weight_names = [layer + '.weight' for layer in layer_names]
    cwr_bias_names = [layer + '.bias' for layer in layer_names]

    stop_at = '' if freeze_below_layer is None else freeze_below_layer
    stop_param_names = [stop_at + '.weight', stop_at + '.bias']

    frozen_names = set(cwr_bias_names)
    standard_names = set()
    stop_found = False
    # An empty stop layer means nothing is frozen by the scan.
    still_freezing = bool(stop_at)

    named = list(model.named_parameters())
    for pname, _ in named:
        if still_freezing and pname not in stop_param_names:
            # Parameter before the stop layer.
            if pname not in cwr_weight_names:
                frozen_names.add(pname)
                if print_frozen_parameters:
                    print("Freezing parameter " + pname)
            continue

        # Stop layer reached (or no freezing requested): trainable from here.
        still_freezing = False
        stop_found = True
        if pname not in frozen_names and pname not in cwr_weight_names:
            standard_names.add(pname)

    if stop_at and not stop_found:
        raise ValueError('Freeze layer not found: ' + str(stop_at))

    standard_params = [p for n, p in named if n in standard_names]
    cwr_mul_params = [p for n, p in named if n in cwr_weight_names]
    frozen_params = [p for n, p in named if n in frozen_names]

    return standard_params, cwr_mul_params, frozen_params


# Public API of this module. Every public helper is listed, including
# consolidate_weights_no_mean_sub, which was previously missing.
__all__ = ['shuffle_in_unison', 'shuffle_in_unison_pytorch', 'pad_data', 'get_accuracy',
           'get_accuracy_conf_matrix', 'get_accuracy_conf_matrix_from_dataloader', 'get_accuracy_conf_matrix_nic',
           'preprocess_imgs', 'maybe_cuda', 'consolidate_weights', 'consolidate_weights_no_mean_sub',
           'set_consolidate_weights', 'reset_weights',
           'examples_per_class', 'freeze_up_to_legacy', 'freeze_up_to', 'create_syn_data', 'extract_weights',
           'extract_grad', 'weight_stats', 'init_batch', 'pre_update', 'post_update', 'update_ewc_data',
           'compute_ewc_loss', 'gen_plot', 'ar1f_get_params_lr']
