from typing import Callable, Tuple, Iterable, List
from copy import deepcopy

import json
import torch
import torch.nn.functional as F
from sklearn import random_projection
import numpy as np

#  ------- HELPER METHODS --------


# def homogeniety_normalization(grad_dict_or_list: [dict, list], model: torch.nn.Module, train: bool=False) -> torch.tensor:
#     """
#     Compute normalization factor to remove scaling effect group
#     :param model:
#     :param train:
#     :return:
#     """
#     norm_sum: float = 0.0
#     if isinstance(grad_dict_or_list, dict):
#         for name, param in model.named_parameters():
#             norm_sum += torch.dot(grad_dict_or_list[name].view(-1).detach(),
#                                   grad_dict_or_list[name].view(-1).detach()) \
#                         / (1e-5 + torch.norm(param, p='fro').detach() ** 2)
#     elif isinstance(grad_dict_or_list, list):
#         count: int = 0
#         for name, param in model.named_parameters():
#             if count < len(grad_dict_or_list):
#                 norm_sum += torch.dot(grad_dict_or_list[count].view(-1).detach(),
#                                       grad_dict_or_list[count].view(-1).detach()) \
#                             / (1e-5 + torch.norm(param, p='fro').detach() ** 2)
#                 count += 1
#     return 0

def _riemann_sq_sum(grad_dict_or_list, model: torch.nn.Module, detach: bool):
    """
    Accumulate sum_l <g_l, g_l> / (1e-5 + ||w_l||_F^2) over (gradient, parameter) pairs

    :param grad_dict_or_list: Dict keyed by parameter name, or list/tuple ordered like model.parameters()
    :param model: Pytorch model supplying the parameters
    :param detach: If True, gradients and parameters are detached before being used
    :return: 0-dim tensor holding the sum, or None when there was no pair to accumulate
    :raises TypeError: If grad_dict_or_list is neither a dict nor a list/tuple
    """
    if isinstance(grad_dict_or_list, dict):
        pairs = ((grad_dict_or_list[name], param) for name, param in model.named_parameters())
    elif isinstance(grad_dict_or_list, (list, tuple)):
        # Positional pairing: k-th gradient <-> k-th parameter. zip stops at the shorter side,
        # so a gradient list covering only the leading parameters is accepted.
        pairs = zip(grad_dict_or_list, model.parameters())
    else:
        raise TypeError("grad_dict_or_list must be a dict, list or tuple, got "
                        + type(grad_dict_or_list).__name__)

    total = None
    for grad, param in pairs:
        if detach:
            grad, param = grad.detach(), param.detach()
        flat = grad.reshape(-1)
        # The 1e-5 offset keeps the quotient finite for all-zero parameters
        term = torch.dot(flat, flat) / (1e-5 + torch.norm(param, p='fro') ** 2)
        total = term if total is None else total + term
    return total


def riemann_norm(grad_dict_or_list: [dict, list], model: torch.nn.Module, train: bool=False) -> float:
    """
    Metric from "A Scale Invariant Flatness Measure for Deep Network Minima"
    The metric helps normalize away a scale invariant "nuisance" of neural networks
    such as where ||a W_1|| + ||1/a W_2|| != ||W_1|| + ||W_2|| even though
    (a W_1) * (1/a W_2) = W_1 * W_2 (if a != 1 and a != 0)
    URL: https://arxiv.org/pdf/1902.02434.pdf

    Computes sqrt( sum_l ||g_l||^2 / (1e-5 + ||w_l||_F^2) ) over the paired gradients and parameters.

    :param grad_dict_or_list: Gradients, either as a dict with parameter names as keys or as a list/tuple
                              ordered like model.parameters() (it may be shorter than the parameter list)
    :param model: Pytorch model
    :param train: If False everything is detached and a Python float is returned;
                  if True the autograd graph is kept and a 0-dim tensor is returned
    :return: Float (train=False) or 0-dim tensor (train=True) of the desired norm;
             zero when there is nothing to accumulate
    :raises TypeError: If grad_dict_or_list is neither a dict nor a list/tuple
    """
    total = _riemann_sq_sum(grad_dict_or_list, model, detach=not train)
    if train:
        return torch.zeros(()) if total is None else torch.sqrt(total)
    return 0.0 if total is None else float(torch.sqrt(total))


def grad_batch_correlation(x: torch.Tensor, model: torch.nn.Module, loss, dataloader: Iterable, device,
                           batch_select_freq: int) -> float:
    """ Calculating gradient correlation to passed gradient "x"
        Inspired by: "Stiffness: A New Perspective on Generalization in Neural Networks"
        (https://arxiv.org/abs/1901.09491)

    Averages the cosine similarity between "x" and the flattened loss gradient of every
    'batch_select_freq'-th batch. The caller's tensor "x" is left untouched.

    :param x: specific vector to calculate correlation for (same length as the flattened parameters)
    :param model: Pytorch model
    :param loss: Callable invoked as loss(prediction, target, reduction="mean")
    :param dataloader: Iterable yielding (data, target) pairs
    :param device: Device the batches are moved to
    :param batch_select_freq: Only batches whose index is a multiple of this value are used
    :return: Mean cosine similarity over the selected batches; NaN if no batch was selected
    """
    # Out-of-place normalization: "x /= ..." would silently rescale the caller's tensor
    unit_x = x / torch.norm(x)
    corr_sum = 0.0
    n_batches = 0

    with torch.enable_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % batch_select_freq != 0:
                continue
            data, target = data.to(device), target.to(device)
            task_loss = loss(model(data), target, reduction="mean")
            model.zero_grad()
            # Only first-order gradients are needed, so no graph of the backward pass is created
            grad_tuple = torch.autograd.grad(task_loss, model.parameters())
            grad_vec = torch.cat([g.contiguous().view(-1) for g in grad_tuple], dim=0).detach()
            corr_sum = corr_sum + torch.dot(unit_x, grad_vec / torch.norm(grad_vec))
            n_batches += 1

    if n_batches == 0:
        # The mean over zero batches is undefined
        return float("nan")
    return float(corr_sum) / n_batches


def get_batch_grad_vecs(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device, batch_select_freq: int,
                        normalize: bool=True) -> List[torch.Tensor]:
    """
    Collects a list of batch gradient vectors and moves them to CPU to avoid GPU memory bottlenecks
    :param model: Pytorch model
    :param loss: Callable invoked as loss(prediction, target, reduction="mean")
    :param dataloader: Iterable yielding (data, target) pairs
    :param device: Device the batches are moved to
    :param batch_select_freq: Only batches whose index is a multiple of this value are used
    :param normalize: If True every gradient vector is divided by its L2 norm
    :return: List of flattened, detached loss gradients (one CPU tensor per selected batch)
    """
    grad_vec_list = []
    # Gradient tracking is enabled explicitly (as in the sibling helpers) so the function also
    # works when it is called from inside a torch.no_grad() evaluation block
    with torch.enable_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % batch_select_freq != 0:
                continue
            model.zero_grad()
            data, target = data.to(device), target.to(device)

            # Loss of the selected batch
            task_loss = loss(model(data), target, reduction="mean")

            # Gradient of the loss w.r.t. all parameters, flattened into one vector
            grad_tuple = torch.autograd.grad(task_loss, model.parameters())
            grad_vec = torch.cat([g.contiguous().view(-1) for g in grad_tuple], dim=0).detach()

            if normalize:
                # Normalize batch gradient
                grad_vec /= torch.norm(grad_vec)
            grad_vec_list.append(grad_vec.cpu())
    return grad_vec_list


def tk_loop(model: torch.nn.Module, transformer: torch.Tensor, data: torch.Tensor, i: int)\
        -> Tuple[torch.Tensor, float]:
    """
    Inner loop of tangent kernel computation. Computes the gradient of the model prediction w.r.t. parameters
    and most importantly projects the gradients to a lower dimensional space via Sparse projection
    :param model: Pytorch model
    :param transformer: Sparse projection matrix; alternatively an indexable whose first element is None and
                        whose second element is the row width, in which case the gradients are stored unprojected
    :param data: Batch of inputs
    :param i: Index of the data point inside the batch
    :return: Tuple of the (optionally projected) Jacobian of the prediction w.r.t. the parameters on CPU
             (one row per model output) and the norm of the prediction for this data point
    """
    with torch.enable_grad():
        model.zero_grad()
        # Keep a leading batch dimension of size one for the selected data point
        data_i = data.reshape((1,) + data.size())[:, i]
        pred_target_i = model(data_i)

        # The projection decision does not change between outputs, so it is taken once
        use_projection = transformer[0] is not None
        n_cols = transformer.size(0) if use_projection else transformer[1]

        # Jacobian of the model prediction w.r.t. the parameters, one output at a time
        jacobi = torch.zeros((pred_target_i.size(1), n_cols))
        grad_y = torch.zeros_like(pred_target_i)
        for j in range(pred_target_i.size(1)):
            grad_y[0, j] = 1.
            # The result is detached right away, so no graph of the backward pass is needed;
            # retain_graph keeps the forward graph alive for the remaining outputs
            pred_grad_tuple = torch.autograd.grad(pred_target_i, model.parameters(), grad_outputs=grad_y,
                                                  retain_graph=True)
            pred_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_tuple], dim=0).detach()
            if use_projection:
                jacobi[j] = torch.squeeze(torch.sparse.mm(transformer, pred_grad_vec[:, None]))
            else:
                jacobi[j] = pred_grad_vec
            grad_y[0, j] = 0.
    return jacobi.cpu(), float(torch.norm(pred_target_i.detach()))


def tk_riemann_norm_loop(model: torch.nn.Module, transformer: torch.Tensor, data: torch.Tensor, i: int)\
        -> Tuple[torch.Tensor, float]:
    """
    Inner loop of tangent kernel computation taking into account Riemann norm invariance
    from:
     Method: "A Scale Invariant Flatness Measure for Deep Network Minima"
        URL: https://arxiv.org/abs/1902.02434
    Computes the gradient of the model prediction w.r.t. parameters
    and most importantly projects the gradients to a lower dimensional space via Sparse projection
    :param model: Pytorch model
    :param transformer: Sparse projection matrix
    :param data: Batch of inputs
    :param i: Index of the data point inside the batch
    :return: Tuple of the projected, rescaled Jacobian on CPU (one row per model output)
             and the norm of the prediction for this data point
    """
    with torch.enable_grad():
        model.zero_grad()
        # Keep a leading batch dimension of size one for the selected data point
        data_i = data.reshape((1,) + data.size())[:, i]
        pred_target_i = model(data_i)

        params = list(model.parameters())
        # The squared parameter norms are the same for every output, so they are computed once
        sq_norms = [torch.norm(param).detach() ** 2 for param in params]

        # Jacobian of the model prediction w.r.t. the parameters, one output at a time
        jacobi = torch.zeros((pred_target_i.size(1), transformer.size(0)))
        grad_y = torch.zeros_like(pred_target_i)
        for j in range(pred_target_i.size(1)):
            grad_y[0, j] = 1.
            # The gradients are detached right away, so no graph of the backward pass is needed
            pred_grad_tuple = torch.autograd.grad(pred_target_i, params, grad_outputs=grad_y,
                                                  retain_graph=True)
            # Rescaling of Euclidean tangent vector to Riemannian tangent vector
            pred_grad_list = [grad.detach() * sq_norm for grad, sq_norm in zip(pred_grad_tuple, sq_norms)]

            rho = riemann_norm(pred_grad_list, model)
            pred_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_list], dim=0).detach()
            # NOTE(review): rho is zero when every rescaled gradient vanishes; the division then yields NaN
            pred_grad_vec /= rho

            jacobi[j] = torch.squeeze(torch.sparse.mm(transformer, pred_grad_vec[:, None]))

            grad_y[0, j] = 0.
    return jacobi.cpu(), float(torch.norm(pred_target_i.detach()))


def Gvp_loop(model: torch.nn.Module, loss: Callable, grad_vec: dict,
             data: torch.Tensor, target: torch.Tensor) -> dict:
    """
    Computes J^T H (J v) for one data point, where H is the Hessian of the loss w.r.t. the model prediction,
    J is the Jacobian of the prediction w.r.t. the parameters and "J v" uses a per-parameter dot product
    that is weighted by the squared norm of the parameter (scaling invariance).
    Presumably this is the Gauss-Newton-vector product the name refers to -- TODO confirm.

    Only row 0 of the prediction is differentiated and the Hessian covers the flattened prediction,
    so the shapes only line up for a batch of size one.

    :param model: Pytorch model
    :param loss: Callable invoked as loss(prediction, target, reduction="mean")
    :param grad_vec: Dict mapping parameter names to the tensors the product is taken with
    :param data: Input batch (of size one, see above)
    :param target: Targets belonging to data
    :return: Dict mapping parameter names to detached tensors shaped like the parameter
    """
    with torch.enable_grad():
        pred_target = model(data)
        task_loss = loss(pred_target, target, reduction="mean")

        # Gradient of the loss w.r.t. the model prediction; the graph is kept to differentiate once more
        loss_grad_tuple = torch.autograd.grad(task_loss, pred_target, create_graph=True)
        loss_grad_vec = torch.cat([g.contiguous().view(-1) for g in loss_grad_tuple], dim=0)
        loss_grad_tuple = None

        # Hessian of the loss w.r.t. the model prediction, one row per one-hot grad_outputs
        hess_rows = []
        grad_y = torch.zeros_like(loss_grad_vec)
        for j in range(len(loss_grad_vec)):
            grad_y[j] = 1.
            loss_pred_tuple = torch.autograd.grad(loss_grad_vec, pred_target, grad_outputs=grad_y, retain_graph=True)
            hess_rows.append(torch.cat([g.contiguous().view(-1) for g in loss_pred_tuple], dim=0).detach())
            loss_pred_tuple = None
            grad_y[j] = 0.
        L_hess = torch.stack(hess_rows, dim=0)
        hess_rows = None

        named_params = list(model.named_parameters())
        params = [param for _, param in named_params]
        jacobi = {name: [] for name, _ in named_params}
        out_vec_dot = []
        grad_y = torch.zeros_like(pred_target)
        # TODO: this gets too large
        for j in range(pred_target.size(1)):
            grad_y[0, j] = 1.
            pred_grad_tuple = torch.autograd.grad(pred_target, params, grad_outputs=grad_y,
                                                  retain_graph=True)
            # The dot product is restarted for every output; carrying it over would mix the outputs
            row_dot = 0
            for (name, param), pred_grad in zip(named_params, pred_grad_tuple):
                # Special dot product - taking to account scaling invariance
                row_dot = row_dot + torch.dot(grad_vec[name].reshape(-1).detach(), pred_grad.reshape(-1)) \
                    * torch.norm(param).detach() ** 2
                jacobi[name].append(pred_grad)
            out_vec_dot.append(row_dot)
            # Clear the one-hot entry, otherwise later rows would be sums over all previous outputs
            grad_y[0, j] = 0.
        pred_grad_tuple = None
        for name, _ in named_params:
            jacobi[name] = torch.stack(jacobi[name], dim=0)
        out_vec_dot = torch.stack(out_vec_dot, dim=0)

        # H (J v); squeezing only the column axis keeps a 1-D vector even for a single model output
        out = torch.mm(L_hess, out_vec_dot[:, None]).squeeze(1)
        L_hess = None
        y = dict()
        for name, param in named_params:
            # Move the output axis last so that matmul contracts it with "out": J^T (H J v)
            n_dims = len(jacobi[name].size())
            permutation_dims = tuple(range(1, n_dims)) + (0,)
            y[name] = torch.matmul(jacobi[name].permute(permutation_dims), out).detach()
            assert y[name].size() == param.size()
        jacobi = None
    return y


def inner_sensitivity_update_loop(model: torch.nn.Module, transformer: torch.tensor, data: torch.Tensor, i: int)\
        -> torch.Tensor:
    """
    Inner-loop of parameter-gradient computation of model output:
        Projects gradients to lower dimensional subspace via Sparse projection matrix
        Then accumulates parameter-gradients of model output
    :param model: Pytorch model
    :param transformer: Sparse projection matrix tensor
    :param data: Batch of inputs
    :param i: Index of the data point inside the batch
    :return: Torch tensor of parameter gradients of model output, one row per model output
    (projected to lower dimensional parameter subspace)
    """
    with torch.enable_grad():
        # Keep a leading batch dimension of size one for the selected data point
        data_i = data.reshape((1,) + data.size())[:, i]
        pred_target_i = model(data_i)
        # pred_target_i = F.log_softmax(model(data_i), dim=1)

        # Jacobian of the model prediction w.r.t. the parameters, one output at a time
        jacobi = torch.zeros((pred_target_i.size(1), transformer.size(0)))
        grad_y = torch.zeros_like(pred_target_i)
        for j in range(pred_target_i.size(1)):
            grad_y[0, j] = 1.
            # The result is detached right away, so no graph of the backward pass is needed;
            # retain_graph keeps the forward graph alive for the remaining outputs
            pred_grad_tuple = torch.autograd.grad(pred_target_i, model.parameters(), grad_outputs=grad_y,
                                                  retain_graph=True)
            pred_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_tuple], dim=0).detach()
            jacobi[j] = torch.squeeze(torch.sparse.mm(transformer, pred_grad_vec[:, None]))
            grad_y[0, j] = 0.
    return jacobi


def hess_vp_loop(model: torch.nn.Module, loss: Callable, grad_vec: torch.Tensor, data: torch.Tensor,
                 target: torch.Tensor, i: int) -> float:
    """
    Curvature of the loss of a single data point along a given direction, i.e. v^T H v with
    H the Hessian of the loss w.r.t. the parameters and v = grad_vec
    :param model: Pytorch model
    :param loss: Callable invoked as loss(prediction, target, reduction="none")
    :param grad_vec: Flat direction vector (same length as the flattened parameters)
    :param data: Batch of inputs
    :param target: Batch of targets
    :param i: Index of the data point inside the batch
    :return: Float of the quadratic form v^T H v
    """
    with torch.enable_grad():
        # Select data point 'i' while keeping a leading batch dimension of size one
        sample = data.reshape((1,) + data.size())[:, i]
        sample_target = target[None, i]
        sample_loss = loss(model(sample), sample_target, reduction="none")

        params = list(model.parameters())

        # First-order gradient with a graph attached, so that it can be differentiated again
        first_order = torch.autograd.grad(sample_loss, params, create_graph=True)
        flat_first_order = torch.cat([piece.contiguous().view(-1) for piece in first_order], dim=0)

        # Hessian-vector product: differentiate <gradient, grad_vec> w.r.t. the parameters
        hv_pieces = torch.autograd.grad(flat_first_order, params, grad_outputs=grad_vec, only_inputs=True)
        hess_v_vec = torch.cat([piece.contiguous().view(-1) for piece in hv_pieces], dim=0).detach()

    # Left-multiplying by the same direction gives the curvature along it
    curvature = torch.dot(grad_vec, hess_v_vec)
    return float(curvature)


def get_avg_grad(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device, batch_select_freq: int=50,
                 riemann_bool=False) -> Tuple[torch.Tensor, float]:
    """
    Computes average gradient of loss w.r.t. parameters across several batches
    Takes every 'batch_select_freq'-th batch of dataloader

    Side effects: the model is switched to train mode and its .grad attributes are overwritten.
    Parameters that receive no gradient (unused in the forward pass) contribute zeros.

    :param model: Pytorch model
    :param loss: Callable invoked as loss(prediction, target, reduction="mean")
    :param dataloader: Iterable yielding (data, target) pairs
    :param device: Device the batches are moved to
    :param batch_select_freq: Batch selection frequency
    :param riemann_bool: If True the gradient of every parameter is multiplied by the norm of that parameter
    :return: Tuple of Tensor of average gradient and float of norm of this average gradient
             (both are NaN if no batch was selected, since the sum is divided by a count of zero)
    """
    count = 0
    grad_vec = torch.cat([torch.zeros_like(g).contiguous().view(-1) for g in model.parameters()]).detach()

    with torch.enable_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % batch_select_freq != 0:
                continue
            data, target = data.to(device), target.to(device)

            model.zero_grad()
            model.train()
            task_loss = loss(model(data), target, reduction="mean")
            task_loss.backward()

            batch_pieces = []
            for param in model.parameters():
                # backward() leaves .grad at None for parameters that are not part of the graph
                grad = torch.zeros_like(param) if param.grad is None else param.grad.detach()
                if riemann_bool:
                    grad = grad * torch.norm(param).detach()
                batch_pieces.append(grad.reshape(-1))
            grad_vec += torch.cat(batch_pieces, dim=0).detach()
            count += 1
    grad_vec /= count
    grad_vec_norm_sum = torch.norm(grad_vec).detach()

    return grad_vec, float(grad_vec_norm_sum)

#  ------- HELPER METHODS END --------


#  ------- ANALYSIS METHODS MAPPING TO DICT OUTPUT --------

def model_weight_norm(model: torch.nn.Module, loss, dataloader: Iterable, device, batch_select_freq: int, **kwargs) \
        -> dict:
    """
    Registers the squared Frobenius norm of each parameter tensor and the overall sum of these norms in a dictionary
    https://arxiv.org/pdf/2001.00939.pdf
    :param model: Pytorch model
    :param loss: Unused, present to match the signature of the other analysis methods
    :param dataloader: Unused
    :param device: Unused
    :param batch_select_freq: Unused
    :param kwargs: "save_weights" (bool, default False) additionally stores all parameters as one flat,
                   detached tensor under the key "weights_tensor"
    :return: Dictionary of parameter name keys and squared weight norm values, plus "all_weight_sum"
    """
    weight_norm_dict = {name: float(torch.norm(param.detach(), p="fro") ** 2)
                        for name, param in model.named_parameters()}
    # Summed before the extra keys are added, so only per-parameter values contribute
    weight_norm_dict["all_weight_sum"] = sum(weight_norm_dict.values())

    # By default we do not save weights, but this can be changed in the config file
    if kwargs.get("save_weights", False):
        # Flatten every parameter itself (not the leftover variable of an earlier loop)
        flat_params = [param.contiguous().view(-1) for param in model.parameters()]
        weight_norm_dict["weights_tensor"] = (torch.cat(flat_params, dim=0).detach()
                                              if flat_params else torch.empty(0))

    return weight_norm_dict


def elasticity(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device, batch_select_freq: int, **kwargs) \
        -> dict:
    """
    'Elasticity' is our own concept that is meant to encompass how much the output of a model changes at a given
    time step. It is measured as the dot product between the parameter-gradient of a single model output
    and the average loss gradient.
    :param model: Pytorch model (switched to eval mode here)
    :param loss: Loss callable handed on to get_avg_grad
    :param dataloader: Iterable yielding (data, target) pairs
    :param device: Device the batches are moved to
    :param batch_select_freq: Not used by this function
    :param kwargs: "n_grads" (default 50) caps the number of output gradients that are evaluated,
                   "n_classes" (default 50) is the divisor mapping a flat output index to its target entry
    :return: Dict of elasticity statistics and the norm of average gradients across several batches
    """
    n_grads = kwargs.get("n_grads", 50)
    n_classes = kwargs.get("n_classes", 50)
    class_record = torch.zeros((n_grads * n_classes,))
    individual_record = torch.zeros((n_grads * n_classes,))

    # The average gradient is taken on a copy, so the train-mode switch and the .grad
    # bookkeeping done there do not touch the model that is analysed here
    reference_model = deepcopy(model)
    avg_grad, avg_grad_norm = get_avg_grad(reference_model, loss, dataloader, device)

    signed_total = 0.0
    abs_total = 0.0
    abs_normed_total = 0.0
    n_batches = 0
    n_taken = 0

    for data, target in dataloader:
        if n_taken >= n_grads:
            break
        data = data.to(device)
        target = target.to(device)
        batch_size = data.size(0)

        model.zero_grad()
        model.eval()
        flat_pred = model(data).view(-1)

        # One-hot selector over the flattened outputs, moved along one entry at a time
        selector = torch.zeros_like(flat_pred)
        for j in range(selector.numel()):
            selector[j] = 1.
            class_record[n_taken] = target[j // n_classes]
            output_grads = torch.autograd.grad(flat_pred, model.parameters(), grad_outputs=selector,
                                               retain_graph=True)
            flat_output_grad = torch.cat([g.contiguous().view(-1) for g in output_grads], dim=0).detach()

            alignment = torch.dot(flat_output_grad, avg_grad)
            signed_total = signed_total + alignment / batch_size
            abs_total = abs_total + torch.norm(alignment) / batch_size
            abs_normed_total = abs_normed_total + torch.norm(alignment / avg_grad_norm) / batch_size
            individual_record[n_taken] = torch.norm(alignment)

            n_taken += 1
            selector[j] = 0.
            if n_taken >= n_grads:
                break

        n_batches += 1

    mean_signed = torch.sum(signed_total / n_batches)
    mean_abs = abs_total / n_batches
    mean_abs_normed = abs_normed_total / n_batches
    return {"elasticity_sum": float(mean_signed),
            "elasticity_normed_grad_sum": float(mean_signed) / float(avg_grad_norm),
            "elasticity_norm_sum": float(mean_abs),
            "elasticity_norm_sum_normed_grad": float(mean_abs_normed),
            "avg_grad_norm_sum": float(avg_grad_norm),
            "elasticity_classes": class_record,
            "elasticity_individual": individual_record,
            }


def elasticity_riemann_norm(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device, batch_select_freq: int, **kwargs) \
        -> dict:
    """
    'Elasticity' is our own concept that is meant to encompass how much the output of a model changes at a given
    time step.
    In this function we also take into account insights from:
    Method: "A Scale Invariant Flatness Measure for Deep Network Minima"
    URL: https://arxiv.org/abs/1902.02434
    Every output gradient is rescaled per parameter by the squared parameter norm and divided
    by its Riemann norm before it is compared with the average loss gradient.
    :param model: Pytorch model (switched to eval mode here)
    :param loss: Loss callable handed on to get_avg_grad
    :param dataloader: Iterable yielding (data, target) pairs
    :param device: Device the batches are moved to
    :param batch_select_freq: Not used by this function
    :param kwargs: "n_grads" (default 50) caps the number of output gradients that are evaluated,
                   "n_classes" (default 50) is the divisor mapping a flat output index to its target entry
    :return: Dict of elasticity statistics and the norm of average gradients across several batches
    """
    model_copy = deepcopy(model)
    elasticity_sum: float = 0.0
    elasticity_norm_sum: float = 0.0
    elasticity_normed_grad: float = 0.0
    n_grads = kwargs.get("n_grads", 50)
    n_classes = kwargs.get("n_classes", 50)
    elasticity_classes = torch.zeros((n_grads * n_classes,))
    elasticity_individual = torch.zeros((n_grads * n_classes,))
    count: int = 0
    count_ij: int = 0
    grad_vec, grad_vec_norm_sum = get_avg_grad(model_copy, loss, dataloader, device, riemann_bool=True)

    params = list(model.parameters())
    # The parameters do not change during the analysis, so their squared norms are computed once
    sq_norms = [torch.norm(param).detach() ** 2 for param in params]

    for batch_idx, (data, target) in enumerate(dataloader):
        if count_ij >= n_grads:
            break
        data, target = data.to(device), target.to(device)

        model.zero_grad()
        model.eval()
        pred_target = model(data).view(-1)

        # One-hot selector over the flattened outputs, moved along one entry at a time
        grad_y = torch.zeros_like(pred_target)
        for j in range(len(grad_y)):
            grad_y[j] = 1.
            elasticity_classes[count_ij] = target[j // n_classes]
            # Only first-order gradients are used and they are detached below,
            # so no graph of the backward pass is created
            pred_grad_tuple = torch.autograd.grad(pred_target, params, grad_outputs=grad_y,
                                                  retain_graph=True)
            # Rescaling of Euclidean tangent vector to Riemannian tangent vector
            pred_grad_list = [grad * sq_norm for grad, sq_norm in zip(pred_grad_tuple, sq_norms)]
            rho = riemann_norm(pred_grad_list, model)
            pred_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_list], dim=0).detach()
            # NOTE(review): rho is zero when every rescaled gradient vanishes; the division then yields NaN
            pred_grad_vec /= rho

            alignment = torch.dot(pred_grad_vec, grad_vec)
            elasticity_sum += alignment / data.size(0)
            elasticity_norm_sum += torch.norm(alignment) / data.size(0)
            elasticity_normed_grad += torch.norm(alignment / grad_vec_norm_sum) / data.size(0)
            elasticity_individual[count_ij] = torch.norm(alignment)
            count_ij += 1
            grad_y[j] = 0.
            if count_ij >= n_grads:
                break

        count += 1
    elasticity_sum = torch.sum(elasticity_sum / count)
    elasticity_norm_sum /= count
    elasticity_normed_grad /= count
    return {"elasticity_sum": float(elasticity_sum),
            "elasticity_normed_grad_sum": float(elasticity_sum) / float(grad_vec_norm_sum),
            "elasticity_norm_sum": float(elasticity_norm_sum),
            "elasticity_norm_sum_normed_grad": float(elasticity_normed_grad),
            "avg_grad_norm_sum": float(grad_vec_norm_sum),
            "elasticity_classes": elasticity_classes,
            "elasticity_individual": elasticity_individual,
            }


def plasticity(model: torch.nn.Module, loss, dataloader: Iterable, device, batch_select_freq: int, **kwargs) -> dict:
    """
        'Plasticity' is our own concept that is meant to encompass how much the parameter-derivative
        of a model changes at a given time step. It is measured through the second derivative of a single
        model output, contracted with the average loss gradient.
        :param model: Pytorch model (switched to eval mode here)
        :param loss: Loss callable handed on to get_avg_grad
        :param dataloader: Iterable yielding (data, target) pairs
        :param device: Device the batches are moved to
        :param batch_select_freq: Not used by this function
        :param kwargs: "n_grads" (default 50) caps the number of output gradients that are evaluated,
                       "n_classes" (default 50) is the divisor mapping a flat output index to its target entry
        :return: Dict of plasticity statistics
        """
    n_grads = kwargs.get("n_grads", 50)
    n_classes = kwargs.get("n_classes", 50)
    class_record = torch.zeros((n_grads * n_classes,))
    individual_record = torch.zeros((n_grads * n_classes,))

    reference_model = deepcopy(model)
    # Accumulator shaped like all parameters except the last one, which is left out of the
    # second derivative below
    flat_leading_params = torch.cat([p.contiguous().view(-1) for p in model.parameters()][:-1], dim=0).detach()
    second_grad_total = torch.zeros_like(flat_leading_params)
    norm_total = 0.0
    dot_total = 0.0
    abs_dot_total = 0.0
    cos_total = 0.0
    n_batches = 0
    avg_grad, avg_grad_norm = get_avg_grad(reference_model, loss, dataloader, device)
    n_taken = 0

    for data, target in dataloader:
        if n_taken >= n_grads:
            break
        data = data.to(device)
        target = target.to(device)

        model.zero_grad()
        model.eval()
        flat_pred = model(data).view(-1)

        # One-hot selector over the flattened outputs, moved along one entry at a time
        selector = torch.zeros_like(flat_pred)
        for j in range(len(selector)):
            selector[j] = 1.
            class_record[n_taken] = target[j // n_classes]

            # First derivative of the selected output, with a graph so it can be differentiated again
            first_grads = torch.autograd.grad(flat_pred, model.parameters(), grad_outputs=selector,
                                              create_graph=True, retain_graph=True)
            first_grad_vec = torch.cat([g.contiguous().view(-1) for g in first_grads], dim=0)

            # Second derivative contracted with the average gradient. The last parameter may be
            # missing from the graph (hence allow_unused); its entry is skipped.
            second_grads = torch.autograd.grad(first_grad_vec, model.parameters(), grad_outputs=avg_grad,
                                               retain_graph=True, allow_unused=True)
            second_grad_vec = torch.cat([g.contiguous().view(-1) for g in second_grads[:-1]],
                                        dim=0).detach()

            # Leading part of the first derivative that lines up with the second derivative
            width = second_grad_vec.size(0)
            first_grad_head = first_grad_vec[:width].detach()
            overlap = torch.dot(second_grad_vec, first_grad_head)
            second_grad_norm = torch.norm(second_grad_vec)

            second_grad_total += second_grad_vec / data.size(0)
            norm_total += second_grad_norm / data.size(0)
            dot_total += overlap / data.size(0)
            abs_dot_total += torch.norm(overlap) / data.size(0)
            # NOTE(review): the trailing division by 10.0 is unexplained in the code -- confirm its meaning
            cos_total += overlap \
                / second_grad_norm \
                / torch.norm(first_grad_head) / data.size(0) / 10.0
            individual_record[n_taken] = second_grad_norm

            n_taken += 1
            if n_taken >= n_grads:
                break
            selector[j] = 0.
        n_batches += 1

    mean_second_grad = torch.sum(second_grad_total / n_batches)
    norm_total /= n_batches
    dot_total /= n_batches
    abs_dot_total /= n_batches
    cos_total /= n_batches
    return {"plasticity_sum": float(mean_second_grad), "plasticity_norm_sum": float(norm_total),
            "plasticity_dot_sum": float(dot_total), "plasticity_dot_sum_norm": float(abs_dot_total),
            "plasticity_dot_sum_cos_dist": float(cos_total),
            "plasticity_sum_normed_grad": float(mean_second_grad) / float(avg_grad_norm),
            "plasticity_norm_sum_normed_grad": float(norm_total) / float(avg_grad_norm),
            "plasticity_dot_sum_normed_grad": float(dot_total) / float(avg_grad_norm),
            "plasticity_dot_sum_norm_normed_grad": float(abs_dot_total) / float(avg_grad_norm),
            "plasticity_classes": class_record,
            "plasticity_individual": individual_record,
            }


def plasticity_riemann_norm(model: torch.nn.Module, loss, dataloader: Iterable, device, batch_select_freq: int, **kwargs) -> dict:
    """
    Compute 'plasticity' statistics from scale-normalised output gradients.

    'Plasticity' is our own concept that is meant to encompass how much the parameter-derivative
    of a model changes at a given time step.
    In this function we also take into account insights from:
    Method: "A Scale Invariant Flatness Measure for Deep Network Minima"
    URL: https://arxiv.org/abs/1902.02434

    For up to ``n_grads`` single entries of the flattened model output, the gradient of that entry
    w.r.t. the parameters is multiplied per parameter tensor by ``||param||^2`` and divided by
    ``riemann_norm`` of the rescaled gradients. The resulting vector is differentiated once more
    w.r.t. the parameters, using the vector returned by ``get_avg_grad`` as ``grad_outputs``
    (presumably the average loss gradient -- confirm against ``get_avg_grad``).

    :param model: Pytorch model; it is switched to eval mode inside the batch loop
    :param loss: Loss callable, only forwarded to ``get_avg_grad``
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches are moved to
    :param batch_select_freq: Not used by this function
    :param kwargs: ``n_grads`` (default 50): number of output entries to evaluate;
        ``n_classes`` (default 50): used to size the per-entry buffers and to map a flattened
        output index back to its data point (assumes it equals the model's output width -- TODO confirm)
    :return: Dict of plasticity statistics (floats), plus the tensors ``plasticity_classes`` and
        ``plasticity_individual``
    """
    n_grads = kwargs.get("n_grads", 50)
    n_classes = kwargs.get("n_classes", 50)
    # Buffers are allocated with n_grads * n_classes entries; the loop below fills at most n_grads of them
    plasticity_classes = torch.zeros((n_grads * n_classes,))
    plasticity_individual = torch.zeros((n_grads * n_classes,))

    # get_avg_grad runs on a deep copy of the model rather than on `model` itself
    model_copy = deepcopy(model)
    # Flattened parameters without the last parameter tensor; only used as a shape/device template
    params = torch.cat([g.contiguous().view(-1) for g in model.parameters()][:-1], dim=0).detach()
    plasticity_sum = torch.zeros_like(params)
    plasticity_norm_sum: float = 0.0
    plasticity_dot_sum: float = 0.0
    plasticity_dot_sum_norm: float = 0.0
    plasticity_dot_sum_cos_dist: float = 0.0
    # Number of batches that were (at least partially) processed
    count: int = 0
    grad_vec, grad_vec_norm_sum = get_avg_grad(model_copy, loss, dataloader, device, riemann_bool=True)
    # Number of individual output entries processed so far, across all batches
    count_ij: int = 0
    for batch_idx, (data, target) in enumerate(dataloader):

        if count_ij >= n_grads:
            break
        data, target = data.to(device), target.to(device)

        model.zero_grad()
        model.eval()
        pred_target = model(data)
        # Flatten the batch output so that entry j belongs to data point j // n_classes
        pred_target = pred_target.view(-1)

        # One-hot selector over the flattened outputs; entry j is switched on and off again per iteration
        grad_y = torch.zeros_like(pred_target)
        for j in range(len(grad_y)):
            grad_y[j] = 1.
            plasticity_classes[count_ij] = target[j // n_classes]
            # The last parameters are not used for second-derivative calculation, since some of the parameters
            # are not present in graph, skip these instead

            # First derivative of output entry j w.r.t. all parameters, kept differentiable (create_graph=True)
            pred_grad_list = list(torch.autograd.grad(pred_target.view(-1), model.parameters(), grad_outputs=grad_y,
                                                      create_graph=True, retain_graph=True))
            # Rescale every parameter's gradient by the squared norm of that parameter
            pred_grad_list = [pred_grad_list[i] * torch.norm(param).detach() ** 2 for i, param in enumerate(model.parameters())]
            rho = riemann_norm(pred_grad_list, model)
            pred_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_list], dim=0)  # stays attached: differentiated again below
            pred_grad_vec /= rho

            # get second derivative:
            # differentiate the normalised gradient again with grad_vec as grad_outputs; the entry of the
            # last parameter tensor is dropped ([:-1]) because it may be unused (None) in this graph
            pred_grad_grad_list = list(torch.autograd.grad(pred_grad_vec, model.parameters(), grad_outputs=grad_vec,
                                                       retain_graph=True, allow_unused=True))[:-1]

            pred_grad_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_grad_list], dim=0).detach()
            # Since the last layer is not used by all parameters, its derivative is "None" or "unused"
            # gg_size: length of the second-derivative vector; the first-derivative vector is truncated to match
            gg_size = pred_grad_grad_vec.size(0)
            # Every per-entry statistic is divided by the batch size before being accumulated
            plasticity_sum += pred_grad_grad_vec / data.size(0)
            plasticity_norm_sum += torch.norm(pred_grad_grad_vec) / data.size(0)
            plasticity_dot_sum += torch.dot(pred_grad_grad_vec, pred_grad_vec[:gg_size].detach()) / data.size(0)
            # torch.norm of a scalar: absolute value of the dot product
            plasticity_dot_sum_norm += torch.norm(torch.dot(pred_grad_grad_vec,
                                                            pred_grad_vec[:gg_size].detach())) / data.size(0)
            # Cosine between first and second derivative vectors
            # NOTE(review): the extra division by 10.0 is a hard-coded constant (possibly a class count) -- confirm
            plasticity_dot_sum_cos_dist += torch.dot(pred_grad_grad_vec, pred_grad_vec[:gg_size].detach()) \
                                           / torch.norm(pred_grad_grad_vec) \
                                           / torch.norm(pred_grad_vec[:gg_size].detach()) / data.size(0) / 10.0
            plasticity_individual[count_ij] = torch.norm(pred_grad_grad_vec)
            count_ij += 1
            if count_ij >= n_grads:
                break
            grad_y[j] = 0.
        count += 1
    # Average over processed batches; plasticity_sum is additionally reduced to the sum of its entries
    plasticity_sum = torch.sum(plasticity_sum / count)
    plasticity_norm_sum /= count
    plasticity_dot_sum /= count
    plasticity_dot_sum_norm /= count
    plasticity_dot_sum_cos_dist /= count
    # The "*_normed_grad" entries divide the statistic by the second value returned by get_avg_grad
    return {"plasticity_sum": float(plasticity_sum), "plasticity_norm_sum": float(plasticity_norm_sum),
            "plasticity_dot_sum": float(plasticity_dot_sum), "plasticity_dot_sum_norm": float(plasticity_dot_sum_norm),
            "plasticity_dot_sum_cos_dist": float(plasticity_dot_sum_cos_dist),
            "plasticity_sum_normed_grad": float(plasticity_sum) / float(grad_vec_norm_sum),
            "plasticity_norm_sum_normed_grad": float(plasticity_norm_sum) / float(grad_vec_norm_sum),
            "plasticity_dot_sum_normed_grad": float(plasticity_dot_sum) / float(grad_vec_norm_sum),
            "plasticity_dot_sum_norm_normed_grad": float(plasticity_dot_sum_norm) / float(grad_vec_norm_sum),
            "plasticity_classes": plasticity_classes,
            "plasticity_individual": plasticity_individual,
            }


def power_method_rho_hessian(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                             batch_select_freq: int, **kwargs) -> dict:
    """ Power iteration on the batch-averaged loss Hessian.

        Method: "A Scale Invariant Flatness Measure for Deep Network Minima"
        URL: https://arxiv.org/abs/1902.02434
        Power method source: https://people.inf.ethz.ch/arbenz/ewp/Lnotes/lsevp.pdf , Algorithm 7.2

        - Start from a random vector "x" with one entry per model parameter
        - Accumulate Hessian-vector products H_i x over every 50th batch and average them into "y"
        - Record the Euclidean norm of y as "rho_out" and normalise y to unit length
        - Compare y to x; stop when the relative change drops below 1e-3 (at most 30 iterations)
        - Set x <- y and repeat

    :param model: Pytorch model
    :param loss: Loss callable accepting ``reduction="mean"``
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches and the iteration vector are moved to
    :param batch_select_freq: Only forwarded to ``grad_batch_correlation``; the Hessian batches
        are selected with a fixed stride of 50
    :return: Dict with "rho_hessian" (last norm of the averaged Hessian-vector product) and
        "hessian_eig_vec_grad_corr" (result of ``grad_batch_correlation`` for the final vector)
    """
    def _flatten(tensors) -> torch.Tensor:
        # Concatenate a sequence of tensors into one 1-D vector
        return torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)

    max_iter: int = 30
    # Random start vector, drawn parameter by parameter
    x = _flatten([torch.rand(p.size()) for p in model.parameters()]).detach().to(device)

    rho_out = 0
    for itr in range(max_iter):
        # Accumulator for the Hessian-vector products of this sweep
        y = torch.zeros_like(x)
        n_batches = 0

        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % 50 != 0:
                continue
            data, target = data.to(device), target.to(device)
            with torch.enable_grad():
                model.zero_grad()
                batch_loss = loss(model(data), target, reduction="mean")

                # Differentiable batch gradient, then grad of <x, g> gives H x
                batch_grads = torch.autograd.grad(batch_loss, model.parameters(), create_graph=True)
                directional = torch.dot(x, _flatten(batch_grads))
                hvp = torch.autograd.grad(directional, model.parameters(), only_inputs=True)
                y += _flatten(hvp)
                n_batches += 1

        # Average over the selected batches, then normalise to unit length
        y /= n_batches
        rho_out = torch.norm(y)
        y /= rho_out

        # Relative change between consecutive iterates
        squared_diff = torch.norm(y - x).detach() ** 2
        vec_norm = torch.norm(y).detach()
        x = y.data.clone()
        error = torch.sqrt(squared_diff) / vec_norm
        print(f"Error {itr}: {error}, vec_norm: {vec_norm}, rho_out: {rho_out}")
        if error < 1e-3:
            break

    # Also check how correlated eigenvector of largest eigenvalue is to batch gradients
    x_corr = grad_batch_correlation(x, model, loss, dataloader, device, batch_select_freq)

    return {"rho_hessian": float(rho_out), "hessian_eig_vec_grad_corr": float(x_corr)}


def power_method_rho_hessian_normed(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                             batch_select_freq: int, **kwargs) -> dict:
    """ Scale-invariant power iteration on the batch-averaged loss Hessian.

        Method: "A Scale Invariant Flatness Measure for Deep Network Minima"
        URL: https://arxiv.org/abs/1902.02434
        Matrix power method iteration to get largest eigenvalue and associated eigenvector
        Source: https://people.inf.ethz.ch/arbenz/ewp/Lnotes/lsevp.pdf , Algorithm 7.2
        combined with above scaling-invariant metric

        - We start with an initial random vector "x" (one tensor per named parameter)
        - Then we iterate over every 50th batch to calculate y = 1/n sum_i^n H_i x, where each
          parameter's block is multiplied by ||param||^2
        - "rho_out" is sum over parameters of <y, x> / ||param||^2
        - We normalize y with the scale invariant norm (``riemann_norm``)
        - We compare y to x and stop when the relative change is below 1e-3 (at most 30 iterations)
        - For the next iteration, we set x <- y and start again

    :param model: Pytorch model
    :param loss: Loss callable accepting ``reduction="mean"``
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches and the iteration vectors are moved to
    :param batch_select_freq: Only forwarded to ``grad_batch_correlation``; the Hessian batches
        are selected with a fixed stride of 50
    :return: Dict with "rho_hessian" and "hessian_eig_vec_grad_corr"
    """
    max_iter: int = 30
    x = dict()
    for name, param in model.named_parameters():
        x[name] = torch.rand(param.size()).to(device).detach()

    rho_out = 0
    for itr in range(max_iter):
        y = dict()
        for name, param in model.named_parameters():
            y[name] = torch.zeros_like(param).detach()

        # x does not change while sweeping over the batches, so flatten it once per iteration
        x_flat = torch.cat([x[key].contiguous().view(-1) for key in x], dim=0).detach()

        count = 0
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % 50 != 0:
                continue
            data, target = data.to(device), target.to(device)
            with torch.enable_grad():
                model.zero_grad()
                pred_target = model(data)
                task_loss = loss(pred_target, target, reduction="mean")

                # Differentiable batch gradient
                grad_tuple_i = torch.autograd.grad(task_loss, model.parameters(), create_graph=True)
                grad_vec_i = torch.cat([g.contiguous().view(-1) for g in grad_tuple_i], dim=0)

                # grad of <x, g> w.r.t. the parameters is the Hessian-vector product H x
                out = torch.dot(x_flat, grad_vec_i)
                y_tuple = torch.autograd.grad(out, model.parameters(), only_inputs=True)
                # Rescale each parameter block by ||param||^2 while accumulating
                for (name, param), hvp in zip(model.named_parameters(), y_tuple):
                    y[name] += hvp.detach() * torch.norm(param).detach() ** 2
                count += 1

        # Average over the selected batches (y = 1/n sum_i H_i x). Without this step the reported
        # rho_out grows linearly with the number of selected batches.
        if count:
            for name in y:
                y[name] /= count

        # Scale-invariant Rayleigh-type estimate, taken before y is normalised
        rho_out = 0
        for name, param in model.named_parameters():
            rho_out += torch.dot(y[name].view(-1).detach(), x[name].view(-1).detach()) / torch.norm(param).detach() ** 2

        # Normalize y with scaling invariant metric
        rho = riemann_norm(y, model)
        for name, param in model.named_parameters():
            y[name] /= rho

        # Relative change between iterates; vec_norm is the sum of the per-parameter norms of y
        error = 0
        vec_norm = 0
        for name, param in model.named_parameters():
            error += torch.norm(y[name] - x[name]).detach()**2
            vec_norm += torch.norm(y[name]).detach()
            x[name] = y[name].data.clone()
        error = torch.sqrt(error) / vec_norm
        print(f"Error {itr}: {error}, vec_norm: {vec_norm}, rho_out: {rho_out}")
        if error < 1e-3:
            break
    # Also check how correlated eigenvector of largest eigenvalue is to batch gradients
    x = torch.cat([x[key].contiguous().view(-1) for key in x], dim=0)
    x_corr = grad_batch_correlation(x, model, loss, dataloader, device, batch_select_freq)

    return {"rho_hessian": float(rho_out), "hessian_eig_vec_grad_corr": float(x_corr)}


def individial_grad_correlation(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                                batch_select_freq: int, **kwargs) -> dict:
    """ Calculating gradient correlation of a low number of individual gradients "g_i"
            Inspired by: "Stiffness: A New Perspective on Generalization in Neural Networks"
            (https://arxiv.org/abs/1901.09491)

        Per-sample loss gradients are collected from every ``batch_select_freq``-th batch until
        ``n_grads`` of them are available. Each one is then passed to ``grad_batch_correlation``
        and the results are averaged.

        :param model: Pytorch model
        :param loss: Loss callable accepting ``reduction="none"`` (one loss value per sample)
        :param dataloader: Iterable yielding ``(data, target)`` batches
        :param device: Device the batches and gradients are moved to
        :param batch_select_freq: Stride with which batches are selected
        :param n_grads: (kwarg, default 1) number of individual gradients to correlate against
            batch gradients
        :return: Dict with the mean correlation under key "ind_corr"
        """
    # At least one gradient is always collected, otherwise the mean below is undefined
    n_grads = max(1, kwargs.get("n_grads", 1))
    ind_grads = []

    with torch.enable_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % batch_select_freq != 0:
                continue
            data, target = data.to(device), target.to(device)
            pred_target = model(data)
            task_loss = loss(pred_target, target, reduction="none")
            model.zero_grad()
            # Get local gradient for individual data points of the batch
            for i in range(data.size(0)):
                # Stop as soon as n_grads gradients are collected instead of finishing the batch
                if len(ind_grads) >= n_grads:
                    break
                # retain_graph keeps the shared forward graph for the next sample; no second-order
                # graph is needed because the gradient is detached right away
                grad_tuple_i = torch.autograd.grad(task_loss[i], model.parameters(), retain_graph=True)
                ind_grads.append(torch.cat([g.contiguous().view(-1) for g in grad_tuple_i], dim=0).detach().cpu())
            if len(ind_grads) >= n_grads:
                break

    ind_corr = 0
    for grad_i in ind_grads:
        # Reusing correlation function "grad_batch_correlation"
        ind_corr += grad_batch_correlation(grad_i.to(device), model, loss, dataloader, device, batch_select_freq)
    ind_corr /= len(ind_grads)
    return {"ind_corr": float(ind_corr)}


def loss_and_grad_mag(model: torch.nn.Module, loss: Callable, data: torch.Tensor, target: torch.Tensor, **kwargs) \
        -> dict:
    """Report the loss of one batch together with two gradient magnitudes.

    :param model: Pytorch model
    :param loss: Loss callable accepting ``reduction="mean"``
    :param data: Input batch
    :param target: Targets belonging to ``data``
    :return: Dict with "local_task_loss" (mean loss), "local_grad_norm" (norm of the loss gradient
        w.r.t. all model parameters) and "local_dldf_norm" (norm of the loss gradient w.r.t. the
        model predictions)
    """
    def _flat_norm(grads) -> float:
        # Euclidean norm of all gradient tensors viewed as one long vector
        flat = torch.cat([g.contiguous().view(-1) for g in grads], dim=0).detach()
        return float(torch.norm(flat))

    model.zero_grad()
    predictions = model(data)
    mean_loss = loss(predictions, target, reduction="mean")

    # Gradient w.r.t. the parameters; the graph is kept for the second call below
    param_grads = torch.autograd.grad(mean_loss, model.parameters(), retain_graph=True)
    param_grad_norm = _flat_norm(param_grads)

    # Gradient of the loss w.r.t. the predictions themselves
    prediction_grads = torch.autograd.grad(mean_loss, predictions)
    prediction_grad_norm = _flat_norm(prediction_grads)

    return {"local_task_loss": float(mean_loss), "local_grad_norm": param_grad_norm,
            "local_dldf_norm": prediction_grad_norm}


def compute_theta_jacobian_change(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                                  batch_select_freq: int, **kwargs) -> dict:
    """
    Computes ||df/dtheta||^2, where f is the softmax model output and theta are the model parameters.

    For every data point of every ``batch_select_freq``-th batch and for every output class, the
    gradient of that class probability w.r.t. all parameters is computed and its squared norm is
    added to the running sum. The sum is not averaged.

    :param model: Pytorch model producing one row of logits per data point
    :param loss: Unused, kept for a uniform call signature
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches are moved to
    :param batch_select_freq: Stride with which batches are selected
    :return: Dict with the accumulated squared norm under key "dfdtheta_sum" (a 0-dim tensor, or
        the float 0.0 if no data point was processed)
    """
    dfdtheta_sum: float = 0.0
    with torch.enable_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % batch_select_freq != 0:
                continue
            data = data.to(device)

            for idx in range(data.size(0)):
                model.zero_grad()
                # Forward a single data point with a leading batch dimension of 1
                pred_target0_i = F.softmax(model(data[idx][None, :]), dim=1)

                # One-hot selector over the classes. The class count is the second dimension;
                # len() of the (1, n_classes) output would always be 1 and only cover class 0.
                grad_y = torch.zeros_like(pred_target0_i)
                for j in range(pred_target0_i.size(1)):
                    grad_y[0, j] = 1.
                    # Gradient of class probability j w.r.t. all parameters; the graph is kept
                    # for the remaining classes of this data point
                    grad_tuple_i = torch.autograd.grad(
                        pred_target0_i, model.parameters(), grad_outputs=grad_y, retain_graph=True)
                    grad_vec_i = torch.cat([g.contiguous().view(-1) for g in grad_tuple_i], dim=0).detach()
                    dfdtheta_sum += torch.norm(grad_vec_i)**2
                    grad_y[0, j] = 0.
    return {"dfdtheta_sum": dfdtheta_sum}


def compute_input_jacobian_change(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                                  batch_select_freq: int, **kwargs) -> dict:
    """
        Computes ||df/dx||^2 and ||d^2f/dxdtheta||^2, where f is the softmax model output, x are the
        input data points and theta are the model parameters.

        For every data point of every ``batch_select_freq``-th batch and every output class:
        - ||df_j/dx||^2 is divided by ||f + 1||^2 and added to "dfdx_sum"
        - the parameter gradient of f_j is differentiated w.r.t. x with the first value returned by
          ``get_avg_grad`` as ``grad_outputs``; its squared norm is added to "dJdtheta"
        Both sums are divided by the number of (data point, class) pairs.

        :param model: Pytorch model producing one row of logits per data point
        :param loss: Loss callable, only forwarded to ``get_avg_grad``
        :param dataloader: Iterable yielding ``(data, target)`` batches of floating point inputs
        :param device: Device the batches are moved to
        :param batch_select_freq: Stride with which batches are selected
        :return: Dict with the averaged quantities under "dfdx_sum" and "dJdtheta"
        """

    avg_grad, avg_grad_norm = get_avg_grad(model, loss, dataloader, device)

    dJdtheta: float = 0.0
    dfdx_sum: float = 0.0
    count: int = 0

    with torch.enable_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % batch_select_freq != 0:
                continue
            data, target = data.to(device), target.to(device)

            # Inputs must track gradients, since we differentiate w.r.t. them
            data.requires_grad = True
            for idx in range(data.size(0)):
                data_i = data[idx].to(device)
                pred_target0_i = F.softmax(model(data_i[None, :]), dim=1)
                model.zero_grad()
                # Normaliser for df/dx. Detached so the running sum does not keep every
                # data point's autograd graph alive.
                pred_norm_sq = torch.norm(pred_target0_i.detach() + 1) ** 2

                # The class count is the second dimension; len() of the (1, n_classes) output
                # would always be 1 and only cover class 0.
                for j in range(pred_target0_i.size(1)):
                    # Fresh one-hot selector per class: the differentiable gradient below holds
                    # on to it, so it is not modified in place afterwards
                    grad_y = torch.zeros_like(pred_target0_i)
                    grad_y[0, j] = 1.

                    # First derivative of class probability j w.r.t. the input
                    dfdx_tuple_i = torch.autograd.grad(
                        pred_target0_i, data_i, grad_outputs=grad_y, retain_graph=True)
                    dfdx_vec_i = torch.cat([g.contiguous().view(-1) for g in dfdx_tuple_i], dim=0).detach()
                    dfdx_sum += torch.norm(dfdx_vec_i)**2 / pred_norm_sq

                    # Differentiable parameter gradient of class probability j
                    grad_tuple_i = torch.autograd.grad(
                        pred_target0_i, model.parameters(), grad_outputs=grad_y, create_graph=True,
                        retain_graph=True)
                    grad_vec_i = torch.cat([g.contiguous().view(-1) for g in grad_tuple_i], dim=0)

                    # Mixed derivative w.r.t. the input, contracted with avg_grad
                    grad_dxdtheta_tuple_i = torch.autograd.grad(
                        grad_vec_i, data_i, grad_outputs=avg_grad, retain_graph=True, allow_unused=True)
                    mixed_parts = [g.contiguous().view(-1) for g in grad_dxdtheta_tuple_i if g is not None]
                    # If the input is unused in the second-order graph, the mixed derivative is
                    # zero and adds nothing (torch.cat on an empty list would raise)
                    if mixed_parts:
                        dJdtheta += torch.norm(torch.cat(mixed_parts, dim=0).detach())**2

                    count += 1
    return {"dfdx_sum": dfdx_sum / count, "dJdtheta": dJdtheta / count}


def compute_fisher_information(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                               batch_select_freq: int, **kwargs) -> dict:
    """
    Trace of Fisher information matrix (FIM):
        A class is sampled from the model's own output distribution for every data point; the
        squared gradient of the sampled log-probability w.r.t. the trainable parameters is summed
        and averaged over all data points.
    Code inspired by: https://github.com/tudor-berariu/fisher-information-matrix/blob/master/fim.py
    :param model: Pytorch model producing one row of logits per data point
    :param loss: Unused, kept for a uniform call signature
    :param dataloader: Iterable yielding ``(data, target)`` batches; every batch is used
    :param device: Device the batches are moved to
    :param batch_select_freq: Unused, kept for a uniform call signature
    :param train_bool: (kwarg, default False) if false, the per-parameter contribution is detached
        before it is accumulated
    :return: Dict with the float trace estimate under key "fim"
    """
    keep_attached = kwargs.get("train_bool", False)

    with torch.enable_grad():
        # One accumulator per trainable parameter tensor
        fim = {name: 0 for name, param in model.named_parameters() if param.requires_grad}
        n_samples: int = 0
        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)
            log_probs = F.log_softmax(model(data), dim=1)
            # Draw one class per data point from the predicted distribution
            drawn = torch.distributions.Categorical(logits=log_probs).sample().unsqueeze(1).detach()
            drawn_log_probs = log_probs.gather(1, drawn)
            for idx in range(data.size(0)):
                model.zero_grad()
                # The graph is shared by all data points of the batch, hence retain_graph
                torch.autograd.backward(drawn_log_probs[idx], retain_graph=True)
                for name, param in model.named_parameters():
                    if not param.requires_grad:
                        continue
                    squared = torch.sum(param.grad * param.grad)
                    fim[name] += squared if keep_attached else squared.detach()
                n_samples += 1
    fim = sum(fim.values()) / n_samples
    return {"fim": float(fim)}


def hess_vp(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device: list,
            batch_select_freq: int, **kwargs) -> dict:
    """Average the Hessian-gradient product g^T H g over batches and gradient directions.

    The products are evaluated by ``hess_vp_loop`` (Pearlmutter trick for Hessian-vector products).
    The idea of looking at this direction is from:
    "On the Convex Behavior of Deep Neural Networks in Relation to the Layers' Width"
    URL: https://arxiv.org/abs/2001.04878
    Pearlmutter, Barak A. "Fast exact multiplication by the Hessian." Neural computation 6.1 (1994): 147-160.

    :param model: Pytorch model
    :param loss: Loss callable, forwarded to the helpers
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches and gradient vectors are moved to
    :param batch_select_freq: Stride with which batches are selected
    :return: Dict with the averaged product under key "hessian_sum"; the average runs over all
        (selected batch, gradient direction) pairs, each pair contributing its per-sample mean
    """
    directions = get_batch_grad_vecs(model, loss, dataloader, device, batch_select_freq, normalize=False)

    total: float = 0.0
    n_pairs: int = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        if batch_idx % batch_select_freq != 0:
            continue
        model.zero_grad()
        data, target = data.to(device), target.to(device)
        batch_size = data.size(0)

        for direction in directions:
            n_pairs += 1
            # Per-sample contributions, weighted so that each pair adds its batch mean
            for sample_idx in range(batch_size):
                total += hess_vp_loop(model, loss, direction.to(device), data, target, sample_idx) / batch_size
    return {"hessian_sum": float(total / n_pairs)}


def tk_kernel_riemann_norm(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                           batch_select_freq: int=100, **kwargs) -> dict:
    """
    Compute a randomly projected version of the Neural Tangent Kernel (NTK) from the gradients
    returned by ``tk_riemann_norm_loop``.
    For a definition of the NTK see for example:
    https://papers.nips.cc/paper/9063-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent.pdf

    :param model: Pytorch model
    :param loss: Unused, kept for a uniform call signature
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches and the projection matrix are moved to
    :param batch_select_freq: Unused, kept for a uniform call signature
    :param kwargs: "n_grads" (default 50) data points to use, "n_classes" (default 50) rows per
        data point, "n_random_components" (default 10000) dimension of the random projection
    :return: Dict with the kernel, its norm and trace, the labels, the largest eigenvalue, the
        ratio of the first to the last of 20 computed eigenvalues, the accumulated prediction norm
        and the projected gradients
    """
    n_grads = kwargs.get("n_grads", 50)
    n_classes = kwargs.get("n_classes", 50)
    n_components = kwargs.get("n_random_components", 10000)

    tk_classes = torch.zeros((n_grads * n_classes,))
    n_params = sum([np.prod(param.size()) for param in model.parameters()])
    y_grad = torch.zeros(n_grads, n_classes, n_components)

    # Project high-dimensional gradients into 'n_components'-dimensional subspace:
    #   https://scikit-learn.org/stable/modules/generated/sklearn.random_projection.SparseRandomProjection.html
    # The projection is fitted on a dummy array of the right width, then copied into a torch sparse tensor.
    projector = random_projection.SparseRandomProjection(n_components=n_components, density='auto')
    projector.fit(np.zeros((1, n_params)))
    sparse_components = projector.components_
    coo = sparse_components.tocoo()
    indices = torch.stack([
        torch.from_numpy(coo.row.astype(np.int64)).to(torch.long),
        torch.from_numpy(coo.col.astype(np.int64)).to(torch.long),
    ], dim=0)
    values = torch.from_numpy(sparse_components.data.astype(np.float64)).to(torch.float)
    transf = torch.sparse.FloatTensor(indices, values, torch.Size(sparse_components.shape)).to(device)
    # Release the scipy/sklearn objects, only the torch tensor is needed from here on
    del projector, sparse_components, coo

    # Assemble individual projected gradients
    count = 0
    pred_norm_sum = 1
    for batch_idx, (data, target) in enumerate(dataloader):
        model.zero_grad()
        data, target = data.to(device), target.to(device)
        # Use as many data points of this batch as are still missing
        for i in range(min(data.size(0), n_grads - count)):
            projected, pred_norm = tk_riemann_norm_loop(model, transf, data, i)
            y_grad[count] = projected
            tk_classes[count * n_classes:(count + 1) * n_classes] = target[i]
            pred_norm_sum += pred_norm
            count += 1
        if count >= n_grads:
            break
    y_grad = y_grad.view(-1, n_components)

    # Compute NTK as Gram matrix of the projected gradients
    tk = torch.mm(y_grad, y_grad.T)

    lambdas, eigvecs = torch.lobpcg(tk, k=20)
    largest, last = float(lambdas[0]), float(lambdas[-1])
    condition_num = largest / (last + 1e-6)

    return {"tangent_kernel_rnormed": tk, "tk_norm_rnormed": float(torch.norm(tk)),
            "tk_trace_rnormed": float(torch.trace(tk)), "tk_labels": tk_classes,
            "rho_tk_rnormed": float(lambdas[0]), "condition_num_rnormed": condition_num,
            "pred_norm_sum": float(pred_norm_sum), "y_grad": y_grad}


def tk_kernel(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device, batch_select_freq: int=100,
              **kwargs) -> dict:
    """
    Compute a randomly projected version of the Neural Tangent Kernel (NTK) from the gradients
    returned by ``tk_loop``.
    For a definition of the NTK see for example:
    https://papers.nips.cc/paper/9063-wide-neural-networks-of-any-depth-evolve-as-linear-models-under-gradient-descent.pdf

    :param model: Pytorch model
    :param loss: Unused, kept for a uniform call signature
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches and the projection matrix are moved to
    :param batch_select_freq: Unused, kept for a uniform call signature
    :param kwargs: "n_grads" (default 50) data points to use, "n_classes" (default 50) rows per
        data point, "n_random_components" (default 10000) dimension of the random projection;
        "random_projection" is read but currently overridden (projection is always on)
    :return: Dict with the kernel, its norm and trace, the largest eigenvalue, the ratio of the
        first to the last of 20 computed eigenvalues, the accumulated prediction norm and the
        projected gradients
    """
    n_grads = kwargs.get("n_grads", 50)
    n_classes = kwargs.get("n_classes", 50)
    n_components = kwargs.get("n_random_components", 10000)
    # The keyword is read, but the projection is forced on right afterwards
    random_proj = kwargs.get("random_projection", True)
    random_proj = True

    n_params = sum([np.prod(param.size()) for param in model.parameters()])
    y_grad = torch.zeros(n_grads, n_classes, n_components)

    # Project high-dimensional gradients into 'n_components'-dimensional subspace:
    #   https://scikit-learn.org/stable/modules/generated/sklearn.random_projection.SparseRandomProjection.html
    # The projection is fitted on a dummy array of the right width, then copied into a torch sparse tensor.
    projector = random_projection.SparseRandomProjection(n_components=n_components, density='auto')
    projector.fit(np.zeros((1, n_params)))
    sparse_components = projector.components_
    coo = sparse_components.tocoo()
    indices = torch.stack([
        torch.from_numpy(coo.row.astype(np.int64)).to(torch.long),
        torch.from_numpy(coo.col.astype(np.int64)).to(torch.long),
    ], dim=0)
    values = torch.from_numpy(sparse_components.data.astype(np.float64)).to(torch.float)
    transf = torch.sparse.FloatTensor(indices, values, torch.Size(sparse_components.shape)).to(device)
    # Release the scipy/sklearn objects, only the torch tensor is needed from here on
    del projector, sparse_components, coo
    if not random_proj:
        # Without projection only the target dimension is handed to tk_loop
        transf = (None, transf.size(0))

    # Assemble individual projected gradients
    count = 0
    pred_norm_sum = 1
    for batch_idx, (data, target) in enumerate(dataloader):
        model.zero_grad()
        data, target = data.to(device), target.to(device)
        # Use as many data points of this batch as are still missing
        for i in range(min(data.size(0), n_grads - count)):
            projected, pred_norm = tk_loop(model, transf, data, i)
            y_grad[count] = projected
            pred_norm_sum += pred_norm
            count += 1
        if count >= n_grads:
            break
    y_grad = y_grad.view(-1, n_components)

    # Compute NTK as Gram matrix of the projected gradients
    tk = torch.mm(y_grad, y_grad.T)

    lambdas, eigvecs = torch.lobpcg(tk, k=20)
    largest, last = float(lambdas[0]), float(lambdas[-1])
    condition_num = largest / (last + 1e-6)

    return {"tangent_kernel": tk, "tk_norm": float(torch.norm(tk)), "tk_trace": float(torch.trace(tk)),
            "rho_tk": float(lambdas[0]), "condition_num": condition_num, "pred_norm_sum": float(pred_norm_sum),
            "y_grad": y_grad}


def sensitivity_update(model: torch.nn.Module, loss: Callable, dataloader: Iterable, device,
                       batch_select_freq: int=100, **kwargs) -> dict:
    """
    Compute a randomly projected version of the parameter derivative of the model output,
    nabla_theta f(x, theta), as returned by ``inner_sensitivity_update_loop``.
    Outputs its largest singular value and the sum of its singular values.

    :param model: Pytorch model
    :param loss: Unused, kept for a uniform call signature
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches and the projection matrix are moved to
    :param batch_select_freq: Unused, kept for a uniform call signature
    :param kwargs: "n_grads" (default 50) data points to use, "n_classes" (default 50) rows per
        data point, "n_random_components" (default 10000) dimension of the random projection
    :return: Dict with "largest_eig_dfdtheta" (first singular value), "trace_dfdtheta" (sum of
        singular values) and the projected gradients "y_grad"
    """
    n_grads = kwargs.get("n_grads", 50)
    n_classes = kwargs.get("n_classes", 50)
    n_components = kwargs.get("n_random_components", 10000)

    n_params = sum([np.prod(param.size()) for param in model.parameters()])
    y_grad = torch.zeros(n_grads, n_classes, n_components)

    # Project high-dimensional gradients into 'n_components'-dimensional subspace:
    #   https://scikit-learn.org/stable/modules/generated/sklearn.random_projection.SparseRandomProjection.html
    # The projection is fitted on a dummy array of the right width, then copied into a torch sparse tensor.
    projector = random_projection.SparseRandomProjection(n_components=n_components, density='auto')
    projector.fit(np.zeros((1, n_params)))
    sparse_components = projector.components_
    coo = sparse_components.tocoo()
    indices = torch.stack([
        torch.from_numpy(coo.row.astype(np.int64)).to(torch.long),
        torch.from_numpy(coo.col.astype(np.int64)).to(torch.long),
    ], dim=0)
    values = torch.from_numpy(sparse_components.data.astype(np.float64)).to(torch.float)
    transf = torch.sparse.FloatTensor(indices, values, torch.Size(sparse_components.shape)).to(device)
    # Release the scipy/sklearn objects, only the torch tensor is needed from here on
    del projector, sparse_components, coo

    # Assemble individual projected gradients
    count = 0
    for batch_idx, (data, target) in enumerate(dataloader):
        model.zero_grad()
        data, target = data.to(device), target.to(device)
        # Use as many data points of this batch as are still missing
        for i in range(min(data.size(0), n_grads - count)):
            y_grad[count] = inner_sensitivity_update_loop(model, transf, data, i)
            count += 1
        if count >= n_grads:
            break
    y_grad = y_grad.view(-1, n_components)

    # Only the singular values are of interest
    _, singular_values, _ = torch.svd(y_grad, compute_uv=False)

    return {"largest_eig_dfdtheta": float(singular_values[0]), "trace_dfdtheta": float(torch.sum(singular_values)),
            "y_grad": y_grad}


def compute_loss(model: torch.nn.Module, loss: Callable, dataloader, device, batch_select_freq, **kwargs) -> dict:
    """
    Compute average loss and accuracy of a model on a dataset (dataloader) and a given loss.

    Every batch contributes to the loss and accuracy averages. The raw model outputs are only
    collected for batches whose index is a multiple of ``batch_select_freq``.

    :param model: Pytorch model; its output is arg-maxed over dim 1 to obtain the predicted class
    :param loss: Callable invoked as ``loss(prediction, target, reduction="mean")``
    :param dataloader: Iterable yielding ``(data, target)`` batches
    :param device: Device the batches are moved to before the forward pass
    :param batch_select_freq: Outputs of every ``batch_select_freq``-th batch are kept
    :param kwargs: Accepted but not used by this function
    :return: Dictionary with the per-sample mean loss ("loss_mean"), the fraction of correctly
             classified samples ("mean_accuracy") and the concatenated outputs of the selected
             batches ("predictions")
    """
    loss_sum: float = 0.0
    correct_sum: float = 0.0
    n_points: int = 0
    predictions = []

    # Nothing returned here needs gradients, so skip building autograd graphs altogether.
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(dataloader):
            # Kept from the original implementation: clears gradients stored on the parameters
            model.zero_grad()
            data, target = data.to(device), target.to(device)
            batch_size = data.size(0)

            pred_target = model(data)
            if batch_idx % batch_select_freq == 0:
                predictions.append(pred_target.detach())

            # Multiply the batch mean by the batch size so that unevenly sized batches
            # are weighted per sample in the final average.
            loss_sum += loss(pred_target, target, reduction="mean").detach() * batch_size

            max_pred = pred_target.argmax(dim=1, keepdim=True)
            correct_sum += max_pred.eq(target.view_as(max_pred)).sum().item()
            n_points += batch_size

    predictions = torch.cat([g.contiguous() for g in predictions], dim=0)
    return {"loss_mean": float(loss_sum) / n_points, "mean_accuracy": float(correct_sum) / n_points,
            "predictions": predictions}


class OnlineDigestAnalysis:
    """
    Class object for analysis requiring memory of previous gradient update.

    Each analysis method is meant to be called twice on the same dataloader: first with
    kind="before" (caches the parameters and per-batch quantities), then, once the model
    parameters have changed, with kind="after" (compares against the cached state).
    """
    def __init__(self, batch_selection_frequency):
        # Flattened parameter vector cached by the most recent kind="before" call
        self.params_before = None
        # Per selected batch: largest model output per sample / its gradient w.r.t. the parameters
        self.pred_before = []
        self.pred_grad_before = []
        # Only batches whose index is a multiple of this value are analysed
        self.batch_selection_frequency = batch_selection_frequency
        # Per selected batch: loss value / flattened loss gradient from the "before" pass
        self.loss_before = None
        self.grad_before = None
        # Not used by the methods of this class visible here
        self.loss_pred_grad = None

    def approx_2nd_order_model(self, model: torch.nn.Module, dataloader: Iterable, device, kind: str) -> dict:
        r"""
        Second-order change of the model output for the same datapoints after a parameter update.

        With f the largest model output of a sample and d = \theta_{t+1} - \theta_t, each selected
        batch contributes the batch mean of
            2 * (f(\theta_{t+1}) - f(\theta_t) - <\nabla_\theta f(\theta_t), d>) / (||d|| + 1e-4)
        and the result is averaged over the selected batches.
        Approximation from: "On the Convex Behavior of Deep Neural Networks in Relation to the Layers' Width"
        URL: https://arxiv.org/abs/2001.04878
        Note: A factor "2" is missing in the linked paper
        NOTE(review): the remainder is divided by ||d||, not ||d||^2 -- confirm against the paper.

        :param model: Pytorch model
        :param dataloader: Iterable of (data, target) batches; must yield the same batches in the
                           same order for the "before" and the "after" call
        :param device: Device the data is moved to before the forward pass
        :param kind: "before" caches the current state, "after" evaluates against the cached state
        :return: {"Delta_model_grad_before": 0} for kind="before",
                 {"Delta_model_grad": float} for kind="after"
        """
        assert kind == "before" or kind == "after", "Choose a valid 'kind' option."
        output = None
        model_params = list(model.parameters())
        params = torch.cat([p.contiguous().view(-1) for p in model_params], dim=0).detach()
        if kind == "before":
            self.params_before = params
            self.pred_before = []
            self.pred_grad_before = []
        else:
            # The parameter step is the same for every batch, so compute it once
            diff = (params - self.params_before).cpu()
            diff_norm = torch.norm(diff) + 1e-4

        count = 0
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % self.batch_selection_frequency != 0:
                continue
            model.zero_grad()
            data = data.to(device)
            pred_target = model(data)
            samples, _ = torch.max(pred_target, dim=1)

            if kind == "before":
                # Per-sample gradients are only needed for the cached "before" state
                jacobi = []
                for i in range(data.size(0)):
                    pred_grad_tuple_i = torch.autograd.grad(samples[i], model_params, retain_graph=True)
                    jacobi.append(torch.cat([g.contiguous().view(-1) for g in pred_grad_tuple_i],
                                            dim=0).detach().cpu())
                self.pred_grad_before.append(torch.stack(jacobi, dim=0))
                # Detach before caching: otherwise every cached tensor keeps its whole
                # autograd graph (and the batch activations) alive until the next "before" call
                self.pred_before.append(samples.detach().cpu())
            else:
                output_temp = 2 * (samples.detach().cpu() - self.pred_before[count].cpu() -
                                   torch.squeeze(torch.mm(self.pred_grad_before[count], diff[:, None]))) / diff_norm

                if output is None:
                    output = torch.mean(output_temp)
                else:
                    output += torch.mean(output_temp)
            count += 1

        if kind == "before":
            return {"Delta_model_grad_before": 0}
        else:
            return {"Delta_model_grad": float(output) / count}

    def approx_2nd_order_loss(self, model: torch.nn.Module, loss: Callable,
                              dataloader: Iterable, device, kind: str) -> dict:
        r"""
        Second-order change of the loss for the same datapoints after a parameter update.

        With L the batch loss and d = \theta_{t+1} - \theta_t, each selected batch contributes
            2 * (L(\theta_{t+1}) - L(\theta_t) - <\nabla_\theta L(\theta_t), d>) / (||d|| + 1e-4)
        and the result is averaged over the selected batches.
        Approximation from: "On the Convex Behavior of Deep Neural Networks in Relation to the Layers' Width"
        URL: https://arxiv.org/abs/2001.04878
        Note: A factor "2" is missing in the linked paper
        NOTE(review): the remainder is divided by ||d||, not ||d||^2 -- confirm against the paper.

        :param model: Pytorch model
        :param loss: Callable invoked as loss(prediction, target); its result is differentiated
                     w.r.t. the model parameters
        :param dataloader: Iterable of (data, target) batches; must yield the same batches in the
                           same order for the "before" and the "after" call
        :param device: Device the data and targets are moved to before the forward pass
        :param kind: "before" caches the current state, "after" evaluates against the cached state
        :return: {"Delta_grad_before": 0} for kind="before", {"Delta_grad": float} for kind="after"
        """
        assert kind == "before" or kind == "after", "Choose a valid 'kind' option."
        output = None
        model_params = list(model.parameters())
        params = torch.cat([p.contiguous().view(-1) for p in model_params], dim=0).detach()
        if kind == "before":
            self.params_before = params
            self.loss_before = []
            self.grad_before = []
        else:
            # The parameter step is the same for every batch, so compute it once
            diff = (params - self.params_before).cpu()
            diff_norm = torch.norm(diff) + 1e-4

        count = 0
        for batch_idx, (data, target) in enumerate(dataloader):
            if batch_idx % self.batch_selection_frequency != 0:
                continue
            model.zero_grad()
            data, target = data.to(device), target.to(device)
            pred_target = model(data)
            task_loss = loss(pred_target, target)

            if kind == "before":
                # Single backward pass per batch, so the graph does not need to be retained
                grad_tuple = torch.autograd.grad(task_loss, model_params)
                grad_vec = torch.cat([g.contiguous().view(-1) for g in grad_tuple], dim=0).detach()
                self.grad_before.append(grad_vec.cpu())
                # Detach before caching so the autograd graph of the batch can be freed
                self.loss_before.append(task_loss.detach().cpu())
            else:
                output_temp = 2 * (task_loss.detach().cpu() - self.loss_before[count].cpu() -
                                   torch.dot(self.grad_before[count], diff)) / diff_norm

                if output is None:
                    output = torch.mean(output_temp)
                else:
                    output += torch.mean(output_temp)
            count += 1

        if kind == "before":
            return {"Delta_grad_before": 0}
        else:
            return {"Delta_grad": float(output) / count}

