import warnings
from typing import Optional

import torch
from torch import Tensor
try:
    from torch_scatter import scatter, segment_csr, gather_csr
    from torch_geometric.utils.num_nodes import maybe_num_nodes
except ModuleNotFoundError:
    warnings.warn("Please install torch-scatter and torch-geometric, if you would like to use the GNN")

from adl4cv.utils.utils import SerializableEnum


class AttentionScalingType(SerializableEnum):
    """Selects how the attention logits of each group are rescaled in ``softmax``.

    The string values are presumably what ``SerializableEnum`` uses when the
    setting is written to / read from a config (TODO confirm against
    ``adl4cv.utils.utils``).
    """

    # Logits are passed to the softmax unchanged.
    NO_SCALING = "no_scaling"
    # Per group: subtract the minimum, divide by the resulting maximum
    # (min-max normalisation), then multiply by ``threshold``.
    MAX_SCALING = "max_scaling"
    # Per group: subtract the mean, divide by the standard deviation,
    # then multiply by ``threshold``.
    VAR_SCALING = "var_scaling"


def softmax(src: Tensor, index: Optional[Tensor] = None,
            ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None,
            dim: int = 0, scaling: Optional[AttentionScalingType] = None, threshold: Optional[float] = None) -> Tensor:
    r"""Computes a sparsely evaluated softmax with optional per-group rescaling.

    Given a value tensor :attr:`src`, this function first groups the values
    along dimension :attr:`dim`, based either on the indices in :attr:`index`
    or on the CSR pointer :attr:`ptr`. It then optionally rescales the values
    of each group according to :attr:`scaling`. Finally, it computes the
    softmax individually for each group.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor, optional): The indices of elements for applying the
            softmax. (default: :obj:`None`)
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. Takes precedence over
            :attr:`index`. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        dim (int, optional): The dimension in which to normalize.
            (default: :obj:`0`)
        scaling (AttentionScalingType, optional): Per-group rescaling applied
            to :attr:`src` before the softmax.
            ``MAX_SCALING`` min-max normalises each group to roughly
            ``[0, 1]``. ``VAR_SCALING`` standardises each group to zero mean
            and unit variance. In both cases the result is then multiplied by
            :attr:`threshold`. :obj:`None` or ``NO_SCALING`` leaves
            :attr:`src` unchanged. (default: :obj:`None`)
        threshold (float, optional): Multiplier applied to the rescaled
            values; it acts as an inverse temperature for the softmax.
            Treated as ``1.0`` when :obj:`None`. Unused without scaling.
            (default: :obj:`None`)

    Raises:
        NotImplementedError: If neither :attr:`index` nor :attr:`ptr` is given.

    :rtype: :class:`Tensor`
    """
    eps = 1e-16

    # Build one "reduce each group, broadcast back to its members" operation
    # so that the scaling and the softmax share the same code for both the
    # CSR (ptr) and the scatter (index) grouping.
    if ptr is not None:
        dim = dim + src.dim() if dim < 0 else dim
        size = ([1] * dim) + [-1]
        ptr = ptr.view(size)

        def group_reduce(x: Tensor, reduce: str) -> Tensor:
            # Reduce every CSR segment, then expand back to segment members.
            return gather_csr(segment_csr(x, ptr, reduce=reduce), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)

        def group_reduce(x: Tensor, reduce: str) -> Tensor:
            # Reduce every index group, then expand back to group members.
            reduced = scatter(x, index, dim, dim_size=N, reduce=reduce)
            return reduced.index_select(dim, index)
    else:
        raise NotImplementedError(
            "softmax requires either `index` or `ptr` to define the groups")

    # A missing threshold previously crashed on `src * None`; use a neutral
    # multiplier instead.
    scale = 1.0 if threshold is None else threshold

    if scaling == AttentionScalingType.MAX_SCALING:
        src = src - group_reduce(src, 'min')
        # eps belongs in the denominator. A group whose values are all equal
        # has max == 0 after the shift, and 0 / 0 would produce NaN.
        src = src / (group_reduce(src, 'max') + eps) * scale
    elif scaling == AttentionScalingType.VAR_SCALING:
        src = src - group_reduce(src, 'mean')
        # eps goes inside the sqrt. For constant groups (var == 0) this keeps
        # the forward pass (0 / 0) finite. It also keeps the backward pass
        # finite, since d sqrt(v) / dv is infinite at v == 0.
        src_std = torch.sqrt(group_reduce(src ** 2, 'mean') + eps)
        src = src / src_std * scale

    # Numerically stable softmax: subtract the per-group max before exp.
    src_max = group_reduce(src, 'max')
    out = (src - src_max).exp()
    out_sum = group_reduce(out, 'sum')

    return out / (out_sum + eps)
