import warnings
import torch
from torch import Tensor
from torch_scatter import scatter


def cheby(i, x):
    """Evaluate the Chebyshev polynomial of the first kind ``T_i`` at ``x``.

    Uses the three-term recurrence ``T_k = 2 * x * T_{k-1} - T_{k-2}`` with
    ``T_0 = 1`` and ``T_1 = x``. ``x`` may be any value supporting ``*`` and
    ``-`` with Python numbers (e.g. a float or a tensor).

    Args:
        i: Non-negative polynomial degree.
        x: Point(s) at which to evaluate the polynomial.

    Returns:
        ``T_i(x)``. Note that ``i == 0`` returns the Python int ``1``
        regardless of the type of ``x`` (it broadcasts in arithmetic).

    Raises:
        ValueError: If ``i`` is negative (previously this surfaced as an
            ``UnboundLocalError`` from the recurrence loop).
    """
    if i < 0:
        raise ValueError(f"Chebyshev degree must be non-negative, got {i}")
    if i == 0:
        return 1
    t_prev, t_curr = 1, x
    # Each iteration advances (T_{k-1}, T_k) -> (T_k, T_{k+1}); i-1 steps
    # take us from (T_0, T_1) to (T_{i-1}, T_i).
    for _ in range(i - 1):
        t_prev, t_curr = t_curr, 2 * x * t_curr - t_prev
    return t_curr


def initialize_edge_weight(data):
    """Attach a unit weight to every edge of ``data``.

    Sets ``data.edge_weight`` to a float32 tensor of ones with one entry per
    column of ``data.edge_index``. The tensor is allocated on the same device
    as ``edge_index`` so downstream ops do not hit a CPU/GPU mismatch.

    Args:
        data: Graph object exposing an ``edge_index`` tensor whose second
            dimension indexes edges.

    Returns:
        The same ``data`` object, modified in place.
    """
    edge_index = data.edge_index
    data.edge_weight = torch.ones(edge_index.shape[1], dtype=torch.float,
                                  device=edge_index.device)
    return data


def initialize_node_features(data):
    """Give every node of ``data`` a single constant feature equal to 1.

    Sets ``data.x`` to a float32 tensor of ones with shape
    ``(num_nodes, 1)``, allocated on the same device as ``data.edge_index``.

    The node count comes from ``data.num_nodes`` when the object provides a
    non-``None`` value. This ensures that isolated nodes which never appear
    in ``edge_index`` still receive a feature row. Otherwise the count is
    inferred as ``edge_index.max() + 1``, or 0 when there are no edges.

    Args:
        data: Graph object exposing an ``edge_index`` tensor and, optionally,
            a ``num_nodes`` attribute.

    Returns:
        The same ``data`` object, modified in place.
    """
    edge_index = data.edge_index
    num_nodes = getattr(data, 'num_nodes', None)
    if num_nodes is None:
        # ``max()`` raises on an empty tensor, so edgeless graphs get 0 nodes.
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    data.x = torch.ones((int(num_nodes), 1), device=edge_index.device)
    return data


def set_tu_dataset_y_shape(data):
    """Insert a size-1 axis at position 1 of ``data.y``.

    ``data.y`` is replaced by ``data.y.unsqueeze(1)``, so a 1-D tensor of
    shape ``(N,)`` becomes ``(N, 1)``. Presumably this yields one column for
    a single prediction task; confirm against the consuming model.

    Args:
        data: Object with a tensor attribute ``y`` of at least one dimension.

    Returns:
        The same ``data`` object, modified in place.
    """
    labels = data.y
    data.y = labels.unsqueeze(dim=1)
    return data


def spmm(src, other, reduce: str = "sum") -> Tensor:
    """Multiply a sparse matrix by a dense matrix with a row-wise reduction.

    Args:
        src: 2-D sparse matrix in COO, CSR or CSC layout.
        other: 2-D dense matrix whose row count matches ``src``'s columns.
        reduce: Reduction applied over each row of ``src``. Use ``'sum'``
            (or its alias ``'add'``) or ``'mean'``. ``'min'`` and ``'max'``
            are accepted names but are not implemented for ``torch.sparse``
            tensors.

    Returns:
        Dense result of ``src @ other``. For ``'mean'``, each output row is
        divided by the number of stored entries in the matching row of
        ``src``. Rows with no entries are divided by 1 and therefore stay
        zero.

    Raises:
        ValueError: If ``reduce`` is unknown or not implemented, or (for
            ``'mean'``) if ``src`` has an unsupported sparse layout.
    """
    reduce = 'sum' if reduce == 'add' else reduce

    if reduce not in ['sum', 'mean', 'min', 'max']:
        raise ValueError(f"`reduce` argument '{reduce}' not supported")

    if reduce == 'sum':
        return torch.sparse.mm(src, other)

    if reduce == 'mean':
        num_rows = src.size(0)
        if src.layout == torch.sparse_csr:
            # Row degrees fall straight out of the compressed row pointer.
            ptr = src.crow_indices()
            deg = ptr[1:] - ptr[:-1]
        elif src.layout == torch.sparse_csc:
            # Count stored entries per row index.
            deg = torch.bincount(src.row_indices(), minlength=num_rows)
        elif src.layout == torch.sparse_coo:
            # Coalesce so duplicate (row, col) entries are merged before
            # counting, matching the product computed below.
            src = src.coalesce()
            # ``indices()`` is (2, nnz); only the row indices (first row)
            # determine the per-row degree.
            deg = torch.bincount(src.indices()[0], minlength=num_rows)
        else:
            raise ValueError(f"Unsupported sparse layout '{src.layout}' for "
                             f"`reduce='mean'`")

        # clamp to 1 so empty rows are not divided by zero.
        return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)

    raise ValueError(f"`{reduce}` reduction is not supported for "
                     f"'torch.sparse.Tensor' on device '{src.device}'")
