import torch
from jaxtyping import Bool, Float, Int
from torch import Tensor, sparse, sparse_coo_tensor


def make_sparse_matrix(index: Int[Tensor, "2 d"], values: Float[Tensor, "d"]):
    """Build a square sparse COO matrix from an index/value pair.

    The side length is inferred as ``max(index) + 1``, so the result has shape ``[n, n]``
    where ``n`` is one larger than the largest row or column index. The returned tensor is
    not coalesced; duplicate coordinates (if any) stay as separate stored entries.

    :param index: Row indices in ``index[0]`` and column indices in ``index[1]``.
    :param values: Value stored at each coordinate of ``index``.
    :return: Sparse COO tensor of shape ``[n, n]``.
    """
    side = index.max() + 1
    shape = (side, side)
    return sparse_coo_tensor(index, values, shape)


def make_normalized_sparse_matrix(
    index: Int[Tensor, "2 d"], values: Float[Tensor, "d"], eps: float = 1e-6
):
    """Construct a sparse matrix whose rows are scaled to (approximately) unit L2 norm.

    Each stored entry is divided by ``eps + ||row||_2``. The row norm is computed from the
    squared entries of ``values`` that fall into that row. Rows without any stored entry
    stay empty.

    :param index: Index of sparse matrix (``index[0]`` rows, ``index[1]`` columns).
    :param values: Values of sparse matrix.
    :param eps: Epsilon value to avoid division by zero.
    :return: Row-normalized sparse matrix of shape ``[n, n]`` with ``n = max(index) + 1``.
    """
    sparse_matrix = make_sparse_matrix(index, values)
    n = sparse_matrix.shape[0]

    # Accumulate squared values into a dense per-row vector of length n.
    # ``sparse.sum(...).values()`` would only yield entries for rows that occur in
    # ``index[0]``. That misaligns the norms with the rows (or fails to broadcast)
    # as soon as some row in ``0..n-1`` has no stored entry.
    squared_row_sums = torch.zeros(n, dtype=values.dtype, device=values.device)
    squared_row_sums.index_add_(0, index[0], torch.pow(values, 2))
    row_wise_norm = torch.sqrt(squared_row_sums)

    # Broadcast the per-row scale ([n, 1]) across the columns of the [n, n] matrix.
    return sparse_matrix * (1 / (eps + row_wise_norm)).unsqueeze(1)


def sparse_sum(
    index: Int[Tensor, "2 d"], values: Float[Tensor, "d"], axis: int
) -> Float[Tensor, "n"]:
    """Sum a square sparse matrix over one axis, including zero (empty) slots.

    The matrix is assumed to be ``[n, n]`` with ``n = max(index) + 1``. Entry ``i`` of the
    result is the sum of all stored values whose index along the remaining axis is ``i``.
    Rows/columns without any stored value get 0.

    :param index: Index of sparse matrix.
    :param values: Values of sparse matrix.
    :param axis: Axis to sum over (0 or 1; the negative forms -2 and -1 are accepted too).
    :return: Dense tensor of length ``n`` holding the sum over ``axis``. Floating-point
        ``values`` keep their dtype; other dtypes are accumulated in the default float dtype.
    :raises ValueError: If ``axis`` is not a valid axis of a 2-D matrix.
    """
    if axis not in (-2, -1, 0, 1):
        raise ValueError(f"axis must be one of -2, -1, 0, 1 for a 2-D matrix, got {axis}")
    # Map -2/-1 onto 0/1 so the remaining (kept) axis is chosen correctly.
    axis %= 2
    remaining_axis = 1 - axis

    n = int(torch.max(index)) + 1
    dtype = values.dtype if values.is_floating_point() else torch.get_default_dtype()
    result = torch.zeros(n, dtype=dtype, device=index.device)
    # Scatter-add every value onto its remaining-axis index; indices that never occur stay 0.
    result.index_add_(0, index[remaining_axis], values.to(dtype))
    return result


def sparse_softmax(
    index: Int[Tensor, "2 d"], values: Float[Tensor, "d"], axis: int
) -> Float[Tensor, "d"]:
    """Apply a softmax over the stored entries of a square sparse matrix.

    The matrix is built with side length ``n = max(index) + 1``. ``torch.sparse.softmax``
    only lets explicitly stored entries take part in the softmax; unspecified entries are
    ignored rather than treated as zeros.

    NOTE(review): the returned values follow the entry order of the sparse result
    (presumably coalesced, i.e. row-major sorted). They line up with ``values`` only if
    ``index`` is already sorted and duplicate-free; confirm with callers.

    :param index: Index of sparse matrix.
    :param values: Values of sparse matrix.
    :param axis: Axis along which the softmax is normalized.
    :return: Softmax of the stored entries along ``axis``.
    """
    matrix = make_sparse_matrix(index, values)
    softmaxed = sparse.softmax(matrix, dim=axis)
    return softmaxed.values()


def sparse_random_choice(
    index: Int[Tensor, "2 d"], probability: Float[Tensor, "d"]
) -> Float[Tensor, "d"]:
    """Sample one index element for each row according to the given probability, and
    one-hot encode the sampled elements over the given index.

    For every row that occurs in ``index[0]``, one column is drawn with ``torch.multinomial``
    using that row's probabilities (non-negative, not necessarily normalized). The result
    holds 1 at each sampled coordinate and 0 at every other coordinate of ``index``.

    NOTE(review): values come back in coalesced (row-major sorted, duplicate-merged) order,
    so they line up with ``index`` only if ``index`` is already coalesced. A row whose
    probabilities are all zero makes ``torch.multinomial`` raise.

    :param index: Index of sparse matrix.
    :param probability: Probability of each index.
    :return: One-hot encoded sampled indices (one value per coordinate of ``index``).
    """
    n = torch.max(index) + 1

    # Every row with at least one stored entry receives exactly one sample.
    start = torch.unique(index[0])
    # multinomial works on dense rows; this materializes an [n, n] matrix.
    matrix = sparse_coo_tensor(index, probability, [n, n]).to_dense()
    # squeeze(1) keeps a 1-D result even when there is a single row. A bare squeeze()
    # would produce a 0-d tensor there and make torch.stack fail.
    end = torch.multinomial(matrix[start], num_samples=1).squeeze(1)
    sampled_index = torch.stack([start, end])

    # Zeros over the full index plus ones at the sampled coordinates -> one-hot values.
    # The ones share probability's dtype so the sparse addition does not mix dtypes.
    zero_sparse_matrix = sparse_coo_tensor(index, torch.zeros_like(probability), [n, n])
    sample_sparse_matrix = sparse_coo_tensor(
        sampled_index, torch.ones_like(start, dtype=probability.dtype), [n, n]
    )
    return (zero_sparse_matrix + sample_sparse_matrix).coalesce().values()


def sparse_normalization(
    index: Int[Tensor, "2 d"], values: Float[Tensor, "d"]
) -> Float[Tensor, "d"]:
    """Rescale a sparse matrix so that every row with a nonzero total sums to 1.

    Rows whose entries sum to exactly zero are divided by 1 (i.e. left unchanged) to avoid
    a division by zero.

    :param index: Index of sparse matrix.
    :param values: Values of sparse matrix.
    :return: Values of the row-normalized matrix in coalesced (sorted, duplicate-merged)
        order. They match the order of ``values`` only when ``index`` is already coalesced.
    """
    row_totals = sparse_sum(index, values, axis=1)
    # Zero-sum rows get a divisor of 1 so they pass through unchanged.
    safe_totals = torch.where(row_totals == 0, torch.ones_like(row_totals), row_totals)
    scale = (1 / safe_totals).unsqueeze(1)

    scaled = make_sparse_matrix(index, values) * scale
    return scaled.coalesce().values()


def sparse_mode(index: Int[Tensor, "2 d"], values: Float[Tensor, "d"]) -> Int[Tensor, "n"]:
    """Compute, for each row of the sparse matrix, the column of its largest stored value.

    The matrix is assumed to be ``[n, n]`` with ``n = max(index) + 1``. Values at duplicate
    coordinates are summed (via ``to_dense``).

    Only explicitly stored entries are candidates. Otherwise an implicit zero could win over
    stored values that are all <= 0, and the reported column would not be stored at all.
    Rows without any stored entry yield 0. Ties resolve to the smallest column index.

    :param index: Index of sparse matrix.
    :param values: Values of sparse matrix.
    :return: Column index of the row-wise maximum, one per row (length ``n``).
    """
    matrix = make_sparse_matrix(index, values).to_dense()

    # Mark which dense slots correspond to stored coordinates.
    is_stored = torch.zeros_like(matrix, dtype=torch.bool)
    is_stored[index[0], index[1]] = True

    # Push unstored slots below any representable value so argmax never selects them.
    if matrix.is_floating_point():
        fill_value = float("-inf")
    else:
        fill_value = torch.iinfo(matrix.dtype).min
    matrix = matrix.masked_fill(~is_stored, fill_value)

    mode = torch.argmax(matrix, dim=1)
    return mode


def index_subset_mask(
    index: Int[Tensor, "2 d"], index_subset: Int[Tensor, "2 d_subset"]
) -> Bool[Tensor, "d"]:
    """Build a boolean mask selecting the entries of ``index`` that also appear in ``index_subset``.

    Both index sets are placed in ``[n, n]`` sparse matrices with ``n = max(index) + 1``.
    Every coordinate of ``index`` contributes 0 and every coordinate of ``index_subset``
    contributes 1. After summing and coalescing, a nonzero value marks a subset coordinate.

    NOTE(review): the mask is in coalesced (row-major sorted) order, so it aligns with
    ``index`` only if ``index`` is already sorted and duplicate-free. ``index_subset`` must
    really be a subset of ``index`` (and lie within ``[0, n)``). Otherwise the output gains
    extra entries, or construction may fail.

    :param index: Index of sparse matrix.
    :param index_subset: Subset of `index`.
    :return: Mask to access the elements of `index_subset` in `index`.
    """
    side = index.max() + 1
    shape = (side, side)
    base = sparse_coo_tensor(index, torch.zeros_like(index[0]), shape)
    marks = sparse_coo_tensor(index_subset, torch.ones_like(index_subset[0]), shape)
    counts = (base + marks).coalesce().values()
    return counts != 0


def pad_values_of_sparse_matrix(
    sparse_matrix,
    index: Int[Tensor, "2 d"],
) -> Float[Tensor, "d"]:
    """Return the values of ``sparse_matrix`` laid out over the coordinates of ``index``.

    An explicit zero is placed at every coordinate of ``index``, and ``sparse_matrix`` is
    added on top. Coordinates of ``index`` that are missing from ``sparse_matrix``
    therefore come out as 0. ``sparse_matrix`` has to be a sparse COO tensor of shape
    ``[n, n]`` with ``n = max(index) + 1`` for the addition and ``coalesce`` to succeed.

    :param sparse_matrix: Sparse matrix whose values should be padded.
    :param index: Index of sparse matrix.
    :return: Padded values in coalesced (sorted, duplicate-merged) order.
    """
    side = index.max() + 1
    zeros = torch.zeros(index.shape[1], device=index.device)
    padding = sparse_coo_tensor(index, zeros, (side, side))
    return (padding + sparse_matrix).coalesce().values()
