import torch
from torch_geometric.utils import scatter, add_remaining_self_loops, spmm
from torch import Tensor
from greatx.nn.models import Propagate


def project(num_budgets: int, values: Tensor, eps: float = 1e-7) -> Tensor:
    r"""Project :obj:`values` so that
    :math:`\text{num\_budgets} \ge \sum \Pi_{[0, 1]}(\text{values})`.

    If clamping ``values`` to ``[0, 1]`` already stays within the budget, the
    values are only clamped. Otherwise a scalar offset ``miu`` is searched by
    :func:`bisection` on ``[min(values) - 1, max(values)]``, and it is
    subtracted from every entry before clamping.

    Args:
        num_budgets: Upper bound on the sum of the clamped values.
        values: Tensor to project.
        eps: Final values are clamped to ``[eps, 1 - eps]``.

    Returns:
        The (possibly shifted) values clamped to ``[eps, 1 - eps]``.
    """
    if torch.clamp(values, 0, 1).sum() > num_budgets:
        # At offset ``min - 1`` every entry clamps to 1, and at offset ``max``
        # every entry clamps to 0, so the root lies in this bracket.
        left = (values - 1).min()
        right = values.max()
        miu = bisection(values, left, right, num_budgets)
        values = values - miu
    return torch.clamp(values, min=eps, max=1 - eps)


def bisection(edge_weight: Tensor, a: float, b: float, n_pert: int, eps=1e-5,
              max_iter=1e3) -> Tensor:
    """Bisection search for the projection offset.

    Searches ``[a, b]`` for an offset ``miu`` such that
    ``clamp(edge_weight - miu, 0, 1).sum() == n_pert``. The residual is
    non-increasing in the offset, so the interval is halved towards the side
    where its sign changes.

    Args:
        edge_weight: Values to be shifted.
        a: Lower end of the search interval.
        b: Upper end of the search interval.
        n_pert: Target sum of the clamped, shifted values.
        eps: Stop once the interval width is at most ``eps``.
        max_iter: Maximum number of halvings (converted with ``int``).

    Returns:
        The last midpoint evaluated, or ``a`` if no iteration ran.
    """
    def shift(offset: float):
        # Signed residual: positive while the clamped sum exceeds ``n_pert``.
        return (torch.clamp(edge_weight - offset, 0, 1).sum() - n_pert)

    # The residual at the lower end only changes when ``a`` moves, and then it
    # equals the residual just computed at the midpoint. Caching it (and the
    # midpoint residual) needs one reduction per iteration instead of three.
    shift_a = shift(a)
    miu = a
    for _ in range(int(max_iter)):
        miu = (a + b) / 2
        shift_miu = shift(miu)
        # Check if middle point is root
        if shift_miu == 0.0:
            break
        # Decide the side to repeat the steps
        if shift_miu * shift_a < 0:
            b = miu
        else:
            a = miu
            shift_a = shift_miu
        if (b - a) <= eps:
            break
    return miu



def propagate(x, edge_index, edge_weight=None, norm=True, nlayers=2):
    """GCN-like propagation, similar to
    https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gcn_conv.html#GCNConv

    Delegates to greatx's ``Propagate`` module.

    Args:
        x: Node features; presumably ``(num_nodes, num_features)`` -- TODO
            confirm against callers.
        edge_index: Graph connectivity, passed through to ``Propagate``.
        edge_weight: Optional edge weights, passed through to ``Propagate``.
        norm: Forwarded as ``Propagate(norm=...)``; presumably toggles
            GCN-style normalization -- confirm against greatx.
        nlayers: Forwarded as ``Propagate(nlayers=...)``; presumably the
            number of propagation steps -- confirm against greatx.

    Returns:
        The output of ``Propagate`` applied to ``(x, edge_index, edge_weight)``.
    """
    # A fresh module is built on every call with the requested settings.
    model = Propagate(nlayers=nlayers, norm=norm)
    return model(x, edge_index, edge_weight)



def num_possible_edges(n: int, is_undirected_graph: bool) -> int:
    """Count the candidate edges of a graph with ``n`` nodes.

    For an undirected graph this is the number of unordered pairs of distinct
    nodes. For a directed graph it is every ordered pair, self-loops included.
    """
    if not is_undirected_graph:
        # Self-loops are still counted here; they are meant to be filtered later.
        return int(n**2)
    return n * (n - 1) // 2


def linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor:
    """Convert linear indices into ``(row, col)`` indices of the upper
    triangle of an ``n x n`` matrix, diagonal excluded.

    Closed-form inversion adapted from
    https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498
    with the number of nodes decremented and the column index incremented
    by one.

    Returns:
        Row indices and column indices stacked along a new leading dimension.
    """
    num_pairs = n * (n - 1)
    # The square root is taken in float64 and floored back to an integer row.
    discriminant = -8 * lin_idx.double() + 4 * num_pairs - 7
    row_offset = torch.floor(torch.sqrt(discriminant) / 2.0 - 0.5).long()
    rows = n - 2 - row_offset
    span = n - rows
    cols = (1 + lin_idx + rows - num_pairs // 2
            + torch.div(span * (span - 1), 2, rounding_mode='floor'))
    return torch.stack((rows, cols))


def src_tar_to_triu_idx(n: int, candidate_nodes: Tensor, lin_idx: Tensor) -> Tensor:
    """Map linear upper-triangular indices to pairs of candidate node ids.

    ``lin_idx`` is decoded exactly as in :func:`linear_to_triu_idx` into
    positional ``(row, col)`` pairs with the diagonal excluded. Those positions
    are then translated into node ids by indexing ``candidate_nodes``.

    Args:
        n: Size of the index space ``lin_idx`` refers to; presumably
            ``len(candidate_nodes)`` -- TODO confirm against callers.
        candidate_nodes: Lookup table from position to node id.
        lin_idx: Linear indices into the strict upper triangle of an
            ``n x n`` matrix.

    Returns:
        Looked-up source ids and target ids stacked along a new leading
        dimension.
    """
    # Reuse the shared decoder instead of duplicating its closed-form formula.
    row, col = linear_to_triu_idx(n, lin_idx)
    return torch.stack((candidate_nodes[row], candidate_nodes[col]))


def linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor:
    """Map linear indices onto ``(row, col)`` positions of a dense ``n x n``
    matrix in row-major order, diagonal included.

    Returns:
        Row indices and column indices stacked along a new leading dimension.
    """
    return torch.stack((
        torch.div(lin_idx, n, rounding_mode='floor'),
        torch.remainder(lin_idx, n),
    ))