import hashlib

import scipy.sparse as sp
import torch


def hash_tensor(tensor):
    """
    Compute the SHA256 hash of a tensor.

    The tensor is cast to ``uint8`` before hashing and only the resulting
    bytes are digested. Tensors that differ solely in information lost by
    that cast, or solely in shape, therefore produce the same hash.

    Args:
        tensor (torch.Tensor): The input tensor. It may live on any device
            and may be a non-contiguous view (e.g. a transpose).

    Returns:
        str: The SHA256 hash of the tensor.
    """
    # hashlib needs a C-contiguous CPU buffer: move the data to host memory
    # and materialise strided views before handing the bytes over.
    data = tensor.detach().cpu().byte().contiguous().numpy()
    return hashlib.sha256(data).hexdigest()


def get_margin_incorrect_vectorized(logits, y):
    """
    Batched margin between the true-class logit and the best other logit.

    Args:
        logits (torch.Tensor): Logits of shape
            (n_subsets, n_models, n_nodes, n_classes).
        y (torch.Tensor): Integer class labels, either of shape
            (n_subsets, n_nodes) or a single 1-D row of labels that is shared
            by every subset. Negative labels mark invalid entries.

    Returns:
        torch.Tensor: Margins of shape (n_subsets, n_models, n_nodes), with
        NaN wherever the label is negative.
    """
    n_subsets, n_models, _, _ = logits.shape

    # A label tensor without the subset dimension is shared across subsets.
    if y.dim() != logits.dim() - 2:
        y = y[None].expand(n_subsets, -1)
    # Every model within a subset is scored against the same labels.
    labels = torch.repeat_interleave(y.unsqueeze(1), n_models, dim=1)

    is_valid = labels >= 0
    # Invalid labels are pointed at class 0 so gather/scatter stay in range;
    # the margins computed for them are overwritten with NaN below.
    class_index = labels.clamp(min=0).unsqueeze(-1)

    true_logit = logits.gather(-1, class_index).squeeze(-1)

    # Knock out the true class, then take the strongest remaining logit.
    without_true = logits.scatter(-1, class_index, float("-inf"))
    best_other, _ = without_true.max(dim=-1)

    margins = true_logit - best_other

    # To double check the result, compare each (subset, model) slice against
    # get_margin_incorrect on the entries with non-negative labels. Compare
    # only those entries, because `float("NaN") == float("NaN")` is False.
    return margins.masked_fill(~is_valid, float("nan"))


def get_margin_incorrect(logits, y):
    """
    Return the margin of the correct class with the highest incorrect class.

    With a single logit per node the result is not a logit difference:
    the sigmoid probability is returned where ``y == 1`` and one minus that
    probability everywhere else.

    Args:
        logits (torch.Tensor): The logits tensor of shape (n_nodes, n_classes).
        y (torch.Tensor): The true class labels tensor of shape (n_nodes,).

    Returns:
        torch.Tensor: The margins tensor of shape (n_nodes,).
    """
    n_nodes, n_classes = logits.shape

    if n_classes == 1:
        # squeeze() also drops the node dimension when n_nodes == 1, but
        # torch.where broadcasts the result back against y.
        probs = logits.sigmoid().squeeze()
        margins = torch.where(y == 1, probs, 1 - probs)
    else:
        # Build the helper index tensors on the same device as the logits;
        # a default-device arange breaks the comparison and the masked fill
        # below once the inputs live on an accelerator.
        device = logits.device
        logits_true = logits[torch.arange(n_nodes, device=device), y]
        mask = torch.arange(n_classes, device=device).expand(n_nodes, n_classes)
        mask = mask == y.unsqueeze(1)

        # Apply the mask by setting correct logits to -infinity
        logits_masked = logits.masked_fill(mask, -float("Inf"))

        # Find the highest incorrect logit
        highest_incorrect_logit = logits_masked.max(dim=1).values
        margins = logits_true - highest_incorrect_logit

    return margins


def get_margin_second(logits):
    """
    Compute the gap between the largest and the second largest logit.

    Args:
        logits (torch.Tensor): Logits of shape (batch_size, num_classes).

    Returns:
        torch.Tensor: Tensor of shape (batch_size,) holding, for each row,
        the top logit minus the runner-up logit.
    """
    # topk returns values sorted in descending order along dim 1.
    top_two = torch.topk(logits, k=2, dim=1).values
    best, runner_up = top_two.unbind(dim=1)
    return best - runner_up


def sparse_tensor_to_scipy(sparse_tensor):
    """
    Convert a sparse tensor to a SciPy sparse matrix.

    The tensor is coalesced first, so duplicate coordinates are summed into
    a single entry in the result.

    Parameters:
        sparse_tensor (torch.sparse.Tensor): The input sparse tensor. It is
            unpacked into row and column indices, so it must be
            two-dimensional.

    Returns:
        scipy.sparse.csr_matrix: The converted SciPy sparse matrix.
    """
    # indices()/values() are only available on coalesced tensors, and a
    # freshly constructed sparse COO tensor is not coalesced. NumPy export
    # additionally needs a detached tensor in host memory.
    coalesced = sparse_tensor.detach().cpu().coalesce()

    rows, cols = coalesced.indices().numpy()
    values = coalesced.values().numpy()
    shape = tuple(coalesced.size())

    return sp.csr_matrix((values, (rows, cols)), shape=shape)


def flatten_input(x):
    """
    Flatten a dense or sparse input into a one-dimensional tensor.

    Args:
        x (numpy.ndarray or SciPy sparse matrix): The input to flatten.
            Sparse inputs of any SciPy format are densified first.

    Returns:
        torch.Tensor: A 1-D tensor holding the elements of ``x``.
    """
    # Accept every SciPy sparse format, not only CSR: anything sparse has to
    # be densified before torch.from_numpy can consume it.
    if sp.issparse(x):
        x = x.toarray()
    return torch.from_numpy(x).flatten()
