#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8

import torch
from torch import Tensor
from torch_scatter import scatter, segment_csr, gather_csr



def index_to_mask(index, size):
    """Build a boolean mask that is ``True`` at the given positions.

    Args:
        index (Tensor): Positions to flag. The mask is allocated on the
            same device as this tensor.
        size (int): Length of the returned mask.

    Returns:
        Tensor: A ``torch.bool`` tensor of the requested size that is
        ``True`` at ``index`` and ``False`` everywhere else.
    """
    selected = index.new_zeros(size, dtype=torch.bool)
    selected[index] = True
    return selected


def random_planetoid_splits(data, num_classes, percls_trn=20, val_lb=500, Flag=0):
    """Attach random train/val/test masks with a fixed number of training
    nodes per class.

    The node indices of every class are shuffled independently and the
    first ``percls_trn`` of each class form the training set. The remaining
    nodes are divided according to ``Flag``:

    * ``Flag == 0``: the leftovers of all classes are pooled and shuffled;
      the first ``val_lb`` of them become validation nodes and the rest
      become test nodes.
    * otherwise: the next ``val_lb`` nodes *of each class* become
      validation nodes and everything after that (shuffled) is the test
      set.

    Args:
        data: Object exposing ``y`` (node labels) and ``num_nodes``; it
            receives ``train_mask``, ``val_mask`` and ``test_mask``.
        num_classes (int): Labels ``0 .. num_classes - 1`` are considered.
        percls_trn (int): Training nodes taken from each class.
        val_lb (int): Validation budget (total or per class, see ``Flag``).
        Flag (int): Selects the validation/test strategy described above.

    Returns:
        The same ``data`` object, with the three masks set.
    """
    shuffled_by_class = []
    for label in range(num_classes):
        members = (data.y == label).nonzero(as_tuple=False).view(-1)
        order = torch.randperm(members.size(0))
        shuffled_by_class.append(members[order])

    train_index = torch.cat(
        [members[:percls_trn] for members in shuffled_by_class], dim=0)

    if Flag == 0:
        # Pool every non-training node, then cut the pool at ``val_lb``.
        pool = torch.cat(
            [members[percls_trn:] for members in shuffled_by_class], dim=0)
        pool = pool[torch.randperm(pool.size(0))]
        val_index = pool[:val_lb]
        test_index = pool[val_lb:]
    else:
        # Class-wise validation slice; only the test remainder is shuffled.
        val_end = percls_trn + val_lb
        val_index = torch.cat(
            [members[percls_trn:val_end] for members in shuffled_by_class],
            dim=0)
        leftover = torch.cat(
            [members[val_end:] for members in shuffled_by_class], dim=0)
        test_index = leftover[torch.randperm(leftover.size(0))]

    data.train_mask = index_to_mask(train_index, size=data.num_nodes)
    data.val_mask = index_to_mask(val_index, size=data.num_nodes)
    data.test_mask = index_to_mask(test_index, size=data.num_nodes)
    return data


def random_class_balance_splits(data, num_classes, train_rate, val_rate):
    """Attach random train/val/test masks that split every class by the
    same proportions.

    For each class the node indices are shuffled, then cut at
    ``round(train_rate * n)`` and ``round((train_rate + val_rate) * n)``
    where ``n`` is the number of nodes in that class. The three resulting
    pieces go to the training, validation and test sets respectively.

    Args:
        data: Object exposing ``y`` (node labels) and ``num_nodes``; it
            receives ``train_mask``, ``val_mask`` and ``test_mask``.
        num_classes (int): Labels ``0 .. num_classes - 1`` are considered.
        train_rate (float): Fraction of each class used for training.
        val_rate (float): Fraction of each class used for validation.

    Returns:
        The same ``data`` object, with the three masks set.
    """
    train_parts, val_parts, test_parts = [], [], []
    for label in range(num_classes):
        members = (data.y == label).nonzero(as_tuple=False).view(-1)
        members = members[torch.randperm(members.size(0))]

        class_size = members.size(0)
        train_end = round(train_rate * class_size)
        val_end = round((train_rate + val_rate) * class_size)

        train_parts.append(members[:train_end])
        val_parts.append(members[train_end:val_end])
        test_parts.append(members[val_end:])

    train_index = torch.cat(train_parts, dim=0)
    val_index = torch.cat(val_parts, dim=0)
    test_index = torch.cat(test_parts, dim=0)

    total_nodes = data.num_nodes
    data.train_mask = index_to_mask(train_index, size=total_nodes)
    data.val_mask = index_to_mask(val_index, size=total_nodes)
    data.test_mask = index_to_mask(test_index, size=total_nodes)
    return data



def maybe_num_nodes(edge_index, num_nodes=None):
    """Return the number of nodes, inferring it from ``edge_index`` if needed.

    Args:
        edge_index: Either a :class:`Tensor` of node indices, or another
            object exposing ``size(dim)`` for dimensions 0 and 1
            (presumably a sparse adjacency matrix -- confirm with callers).
        num_nodes (int, optional): If given, it is returned unchanged.
            (default: :obj:`None`)

    Returns:
        int: ``num_nodes`` when supplied; otherwise ``max index + 1`` for a
        tensor (``0`` when the tensor is empty), or the larger of the two
        dimensions for any other input.
    """
    if num_nodes is not None:
        return num_nodes
    if isinstance(edge_index, Tensor):
        # ``Tensor.max()`` raises on an empty tensor, so a graph without
        # edges (and without an explicit node count) is reported as empty.
        if edge_index.numel() == 0:
            return 0
        return int(edge_index.max()) + 1
    return max(edge_index.size(0), edge_index.size(1))


def add_self_loops(edge_index, edge_weight=None,
                   fill_value: float = 1., num_nodes=None):
    r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node
    :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`.
    In case the graph is weighted, self-loops will be added with edge weights
    denoted by :obj:`fill_value`.

    Self-loops are appended unconditionally; existing self-loops in
    :attr:`edge_index` are not detected or removed.

    Args:
        edge_index (LongTensor): The edge indices, of shape :obj:`[2, E]`.
        edge_weight (Tensor, optional): Edge weights whose first dimension
            indexes edges, *e.g.* shape :obj:`[E]` or :obj:`[E, d]`.
            (default: :obj:`None`)
        fill_value (float, optional): If :obj:`edge_weight` is not :obj:`None`,
            will add self-loops with edge weights of :obj:`fill_value` to the
            graph. (default: :obj:`1.`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)

    :rtype: (:class:`LongTensor`, :class:`Tensor`)
    """
    N = maybe_num_nodes(edge_index, num_nodes)

    loop_index = torch.arange(0, N, dtype=torch.long, device=edge_index.device)
    loop_index = loop_index.unsqueeze(0).repeat(2, 1)

    if edge_weight is not None:
        # Mirror every trailing dimension of ``edge_weight`` so that both
        # scalar-per-edge ([E]) and vector-per-edge ([E, d]) weights work.
        loop_shape = (N,) + tuple(edge_weight.shape[1:])
        loop_weight = edge_weight.new_full(loop_shape, fill_value)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    edge_index = torch.cat([edge_index, loop_index], dim=1)

    return edge_index, edge_weight


# @torch.jit.script
def softmax(src: Tensor, index, ptr=None,
            num_nodes=None, temperature=1.0):
    r"""Computes a sparsely evaluated softmax.
    Given a value tensor :attr:`src`, this function first groups the values
    along the first dimension based on the indices specified in :attr:`index`,
    and then proceeds to compute the softmax individually for each group.

    :attr:`src` is left untouched; the temperature scaling is applied to a
    copy.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements for applying the softmax.
        ptr (LongTensor, optional): If given, computes the softmax based on
            sorted inputs in CSR representation. (default: :obj:`None`)
        num_nodes (int, optional): The number of nodes, *i.e.*
            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)
        temperature (float, optional): Divisor applied to :attr:`src` before
            exponentiation. (default: :obj:`1.0`)

    Raises:
        NotImplementedError: If both :attr:`index` and :attr:`ptr` are
            :obj:`None`.

    :rtype: :class:`Tensor`
    """
    # Out-of-place on purpose: ``src /= temperature`` would overwrite the
    # caller's tensor (and fails on leaf tensors that require grad).
    scaled = src / temperature
    # Shift by the global maximum so that exp() cannot overflow; the common
    # factor cancels in the normalisation below (up to the epsilon).
    out = (scaled - scaled.max()).exp()

    if ptr is not None:
        out_sum = gather_csr(segment_csr(out, ptr, reduce='sum'), ptr)
    elif index is not None:
        N = maybe_num_nodes(index, num_nodes)
        out_sum = scatter(out, index, dim=0, dim_size=N, reduce='sum')[index]
    else:
        raise NotImplementedError

    # The epsilon guards against division by zero for groups whose
    # exponentials all underflow.
    return out / (out_sum + 1e-16)


def weighted_degree(index, edge_weight=None, num_nodes=None,
                    dtype=None):
    """Sum the edge weights attached to every node.

    Args:
        index (LongTensor): One node index per edge.
        edge_weight (Tensor, optional): One weight per edge. When omitted,
            every edge counts as ``1`` so the result is the plain degree.
            (default: :obj:`None`)
        num_nodes (int, optional): Length of the result; inferred from
            ``index`` when omitted. (default: :obj:`None`)
        dtype (torch.dtype, optional): Dtype of the accumulator.
            (default: :obj:`None`)

    Returns:
        Tensor: One-dimensional tensor with one entry per node holding the
        accumulated weight.
    """
    node_count = maybe_num_nodes(index, num_nodes)
    degree = torch.zeros((node_count, ), dtype=dtype, device=index.device)

    weights = edge_weight
    if weights is None:
        # Unit weights share the accumulator's dtype and device.
        weights = degree.new_ones((index.size(0), ))

    degree.scatter_add_(0, index, weights)
    return degree


def multi_weighted_degree(index, edge_weight=None, num_nodes=None,
                          dtype=None):
    """Sum multi-channel edge weights per node.

    Args:
        index (LongTensor): One node index per edge, shape ``[E]``.
        edge_weight (Tensor, optional): Per-edge weights of shape
            ``[E, C]``. When omitted, a single channel of unit weights is
            used, so the result has shape ``[N, 1]`` and holds the plain
            degree. (default: :obj:`None`)
        num_nodes (int, optional): Number of rows ``N`` of the result;
            inferred from ``index`` when omitted. (default: :obj:`None`)
        dtype (torch.dtype, optional): Dtype of the accumulator.
            (default: :obj:`None`)

    Returns:
        Tensor: Tensor of shape ``[N, C]`` whose row ``i`` is the sum of the
        weight rows of all edges with ``index == i``.
    """
    N = maybe_num_nodes(index, num_nodes)

    if edge_weight is None:
        # The channel count cannot be read from a missing tensor; fall back
        # to one channel of ones, created with the accumulator's dtype.
        C = 1
        edge_weight = torch.ones((index.size(0), C), dtype=dtype,
                                 device=index.device)
    else:
        C = edge_weight.size(1)

    out = torch.zeros((N, C), dtype=dtype, device=index.device)
    # scatter_add_ needs an index entry for every element of the source, so
    # the per-edge node id is broadcast across the channel dimension.
    index = index.view(-1, 1).expand(index.size(0), C)
    return out.scatter_add_(0, index, edge_weight)


def how_sharp_a_distribution(arr, p=2):
    """Rescale the p-norm of each vector along the last dimension.

    With ``n = arr.size(-1)``, the p-norm is mapped linearly so that a norm
    of ``n ** (1/p - 1)`` becomes ``0`` and a norm of ``1`` becomes ``1``.
    Those are the p-norms of a uniform and of a one-hot vector of length
    ``n`` whose entries sum to one.

    Args:
        arr (Tensor): Input whose last dimension holds the vectors
            (presumably probability distributions -- confirm with callers).
        p (float): Order of the norm. (default: ``2``)

    Returns:
        Tensor: Rescaled norms, with the last dimension of ``arr`` reduced.
    """
    length = torch.tensor(1.0 * arr.size(-1))
    lower_bound = torch.pow(length, (1.0 / p) - 1)
    p_norm = torch.norm(arr, p=p, dim=-1)
    return (p_norm - lower_bound) / (1 - lower_bound)


def weighted_confusion_matrix(y, score):
    """Accumulate score rows by label.

    Row ``h`` of the result is the sum of ``score[i]`` over all samples
    ``i`` with ``y[i] == h``.

    Args:
        y (LongTensor): One label per sample, shape ``[S]``.
        score (Tensor): Per-sample scores, shape ``[S, W]``.

    Returns:
        Tensor: Matrix of shape ``[y.max() + 1, W]`` with the dtype and
        device of ``score``. For empty ``y`` the result has shape ``[0, W]``.
    """
    W = score.size(1)
    if y.numel() == 0:
        # ``max()`` is undefined on an empty tensor; there are no rows.
        return score.new_zeros((0, W))

    H = int(y.max()) + 1
    # Allocate from ``score`` so dtype and device always match the source
    # passed to scatter_add_ (a plain torch.zeros is float32 on CPU).
    out = score.new_zeros((H, W))
    index = y.view(-1, 1).expand(y.size(0), W)
    return out.scatter_add_(0, index, score)

def sparse_tensor_to_edge_index(adj_t):
    """Convert a sparse adjacency object into a stacked index tensor.

    The input is transposed with ``t()`` and read out with ``coo()``; the
    first two returned tensors are stacked into a ``[2, E]`` tensor and the
    third one (the values) is discarded.

    Args:
        adj_t: Sparse matrix exposing ``t()`` and ``coo()`` (presumably a
            ``torch_sparse.SparseTensor`` -- confirm with callers).

    Returns:
        Tensor: The stacked row/column indices of the transposed matrix.
    """
    rows, cols, _values = adj_t.t().coo()
    return torch.stack((rows, cols), dim=0)


def cal_heterophilous_ratio_torchversion(M, c_degs, l):
    """Scaled pairwise distances between the rows of ``M ** l``.

    Computes the ``l``-th matrix power of ``M``, the Euclidean distance
    between every pair of its rows, and multiplies the resulting distance
    matrix by ``sqrt(c_degs ** l / 2)`` (broadcast against the distance
    matrix).

    Args:
        M (Tensor): Square matrix.
        c_degs (Tensor): Scaling values (presumably per-class degrees --
            confirm with callers).
        l (int): Exponent applied to both ``M`` and ``c_degs``.

    Returns:
        Tensor: The scaled pairwise row-distance matrix.
    """
    powered = torch.matrix_power(M, l)
    # Broadcasting [n, 1, n] against [n, n] yields row_i - row_j for all pairs.
    row_gaps = powered.unsqueeze(1) - powered
    row_distances = row_gaps.pow(2).sum(dim=-1).sqrt()
    scale = torch.sqrt((c_degs ** l) / 2)
    return row_distances * scale
