import sys

import torch

sys.path.append(".")
from src.tools.sharpness_tools.math_utils import hvp
from torch.cuda.amp import GradScaler


def fishr(model, data_loader):
    """
    Compute the scalar ``theta @ hvp(model, data_loader, theta, scaler)``.

    ``theta`` is the concatenation of every model parameter, each flattened
    to 1-D, so the result is a single number rather than a matrix.

    NOTE(review): ``hvp`` comes from ``math_utils`` and is presumably a
    Hessian-vector product evaluated over ``data_loader``; its exact
    semantics (and the shape it returns) are not visible here -- confirm
    against that module.

    :param model: object whose ``parameters()`` are flattened into ``theta``
    :param data_loader: passed through unchanged to ``hvp``
    :return: the product as a Python number (extracted with ``.item()``)
    """
    scaler = GradScaler()

    # reshape(-1) flattens exactly like view(-1) for contiguous tensors but
    # does not raise when a parameter happens to be non-contiguous.
    theta = torch.cat([param.reshape(-1) for param in model.parameters()])

    product = hvp(model, data_loader, theta, scaler)

    # theta is 1-D, so `.T` was a no-op on it (and is deprecated by PyTorch
    # for non-2-D tensors); a plain matmul yields the same value.
    return (theta @ product).item()
