import torch
import numpy as np
import time
from tqdm import tqdm
from torch_sparse import spspmm, transpose
import numba
from torch_scatter import segment_coo
from .pernode_ppr_neighbor import construct_sparse


# Per-dataset settings for the chunked PPR power iteration:
#   chunksize   -- number of seed nodes (columns) processed per chunk
#   alpha       -- teleport weight in r = (1 - alpha) * A * r + alpha * t
#   iters       -- number of power-iteration steps
#   top_percent -- fraction of entries kept after each step (indexed by step)
# NOTE: keep the key order (chunksize, alpha, iters, top_percent); consumers
# may unpack ``.values()`` positionally.
prams = {
    'arxiv': dict(
        chunksize=10000,
        alpha=0.05,
        iters=3,
        top_percent=[1, 0.01, 0.05],
    ),
    'products': dict(
        chunksize=10000,
        alpha=0.05,
        iters=3,
        top_percent=[1, 0.01, 0.05],
    ),
    'reddit': dict(
        chunksize=10000,
        alpha=0.05,
        iters=3,
        top_percent=[0.1, 0.01, 0.05],
    ),
}


def get_topk_neighbors_mask(
    num_nodes, index, ppr_scores, topk
):
    """
    Get mask that filters the PPR scores so that each node has at most `topk` neighbors.

    Assumes that `index` is sorted (entries of the same node are contiguous
    and nodes appear in ascending order).

    Args:
        num_nodes: number of nodes; valid values of `index` are in [0, num_nodes).
        index: 1-D integer tensor giving the owning node of every entry.
        ppr_scores: 1-D tensor of scores aligned with `index`.
        topk: maximum number of entries kept per node. May exceed the
            largest per-node entry count.

    Returns:
        Boolean tensor of length ``len(index)`` that is True for entries that
        are among the `topk` highest positive scores of their node. Entries
        with score <= 0 are never selected.
    """
    device = index.device

    # Nothing to select (also avoids `.max()` on an empty tensor below).
    if index.numel() == 0:
        return torch.zeros(0, device=device, dtype=torch.bool)

    # Number of entries per node.
    num_neighbors = torch.bincount(index, minlength=num_nodes)
    max_num_neighbors = int(num_neighbors.max())

    # Dense [num_nodes, max_num_neighbors] buffer used to sort each node's scores.
    # Zeros mark padding so it can be dropped later; use the scores' dtype so
    # `index_copy_` does not fail for e.g. float64 inputs.
    ppr_sort = torch.zeros(
        num_nodes * max_num_neighbors, device=device, dtype=ppr_scores.dtype
    )

    # Map each entry to (node, position within node) in the dense buffer.
    # This relies on `index` being sorted.
    index_neighbor_offset = torch.cumsum(num_neighbors, dim=0) - num_neighbors
    index_neighbor_offset_expand = torch.repeat_interleave(
        index_neighbor_offset, num_neighbors
    )
    index_sort_map = (
        index * max_num_neighbors
        + torch.arange(len(index), device=device)
        - index_neighbor_offset_expand
    )
    ppr_sort.index_copy_(0, index_sort_map, ppr_scores)
    ppr_sort = ppr_sort.view(num_nodes, max_num_neighbors)

    # Sort neighboring nodes based on PPR score and keep the `topk` best.
    # Slicing yields min(topk, max_num_neighbors) columns.
    ppr_sort, index_sort = torch.sort(ppr_sort, dim=1, descending=True)
    ppr_sort = ppr_sort[:, :topk]
    index_sort = index_sort[:, :topk]

    # Offset index_sort so that it indexes into `index`. Broadcast instead of
    # expanding to `topk` columns, which failed when topk > max_num_neighbors.
    index_sort = index_sort + index_neighbor_offset.view(-1, 1)

    # Remove "unused pairs" (padding) with PPR scores 0.
    mask_nonzero = ppr_sort > 0
    index_sort = torch.masked_select(index_sort, mask_nonzero)

    # index_sort now holds positions in `index` of each node's top-k entries.
    mask_topk = torch.zeros(len(index), device=device, dtype=torch.bool)
    mask_topk.index_fill_(0, index_sort, True)

    return mask_topk


@numba.njit(cache=True, locals={'mask': numba.boolean[::1]})
def sort_coo(i, edge_index, val, topk, dim=0):
    """Return the ids of the `topk` highest-valued entries of node `i`, sorted.

    Entries are those with ``edge_index[1 - dim] == i``; among them the `topk`
    with the largest `val` are kept and their ``edge_index[dim]`` ids are
    returned in ascending order. If `topk` exceeds the number of entries, all
    of them are returned; if `topk` is 0, none are.
    """
    mask = edge_index[1 - dim] == i
    vals = val[mask]
    # Explicit non-negative start: the former ``[-topk:]`` returned *all*
    # entries for topk == 0 because ``-0 == 0``.
    start = max(vals.shape[0] - topk, 0)
    keep = np.argsort(vals)[start:]
    return np.sort(edge_index[dim][mask][keep])


@numba.njit(cache=True, parallel=True)
def postprocess(edge_index, val, _size, topk):
    """Collect, for every column id ``i < _size[1]``, its top-`topk` row ids.

    Each list entry is ``sort_coo(i, edge_index, val, topk)``, i.e. the row
    ids of the `topk` highest-valued entries in column ``i``, sorted.
    """
    num_cols = _size[1]
    # Placeholders must share edge_index's dtype so numba can store the arrays
    # returned by sort_coo; a hard-coded int32 fails to type-check for int64
    # edge indices.
    neighbors = [np.empty(0, dtype=edge_index.dtype)] * num_cols

    for i in numba.prange(num_cols):
        neighbors[i] = sort_coo(i, edge_index, val, topk)

    return neighbors


def chunk_csr_col(mat, chunksize, device='cuda', compute_device=None):
    """Split `mat` into transposed row chunks of at most `chunksize` rows.

    Slicing row-wise then transposing is faster than slicing column-wise.
    For a symmetric `mat` the transposed row chunk ``i`` equals column chunk
    ``i`` of `mat`. NOTE(review): callers presumably pass a symmetric
    adjacency matrix; confirm.

    Args:
        mat: torch_sparse ``SparseTensor``.
        chunksize: maximum number of rows per chunk.
        device: device on which the returned tensors are stored.
        compute_device: device used for the transpose. Defaults to CUDA when
            available, otherwise CPU (previously CUDA was hard-coded).

    Returns:
        List of ``(index, value, num_rows)`` tuples, one per chunk, holding the
        COO entries of the transposed chunk (shape ``mat.size(1) x num_rows``).
    """
    if compute_device is None:
        compute_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Rows are what we slice, so both the chunk count and the slice bound
    # are based on the number of rows.
    num_rows = mat.size(0)
    num_chunks = (num_rows + chunksize - 1) // chunksize
    parts = []

    pbar = tqdm(range(num_chunks))
    pbar.set_description("Chunking")

    for i in pbar:
        start = i * chunksize
        stop = min(start + chunksize, num_rows)
        part_mat = mat[start:stop, :]

        # Public COO accessor: the private `storage._row` may be unset.
        row, col, val = part_mat.coo()
        if val is None:
            # Unweighted matrix: treat every stored entry as 1.
            val = torch.ones(row.numel())
        idx = torch.stack((row, col), 0).to(compute_device)
        val = val.to(compute_device)

        idx, val = transpose(idx, val, *tuple(part_mat.sizes()))

        parts.append((idx.to(device), val.to(device), part_mat.size(0)))

    return parts


def ppr_power_iter(adj, dataset_name, topk, device='cuda'):
    """Approximate per-node personalized PageRank by chunked power iteration.

    Seed nodes are processed ``chunksize`` at a time. For each chunk, the
    score matrix ``r`` (one column per seed) is updated ``iters`` times via
    ``r <- (1 - alpha) * A @ r + alpha * t``. After every step it is pruned to
    the ``top_percent[it]`` fraction of largest entries; seed (diagonal)
    entries are always retained. Finally each seed keeps at most `topk`
    neighbours.

    Args:
        adj: torch_sparse ``SparseTensor`` adjacency matrix (presumably square
            and symmetric; confirm against ``chunk_csr_col``).
        dataset_name: key into ``prams`` selecting chunksize, alpha, iters and
            top_percent.
        topk: maximum number of neighbours kept per seed node.
        device: device used for the sparse products.

    Returns:
        ``(neighbors, pprmat)``:
            - ``neighbors`` is a 1-D object array whose ``j``-th entry holds
              the neighbour ids of seed ``j``.
            - ``pprmat`` is ``construct_sparse(...).tocsr()`` built from the
              same neighbours and weights.

    Raises:
        ValueError: if ``top_percent`` has fewer than ``iters`` entries.
    """
    # Look settings up by name instead of unpacking `.values()` positionally.
    cfg = prams[dataset_name]
    chunksize = cfg['chunksize']
    alpha = cfg['alpha']
    iters = cfg['iters']
    top_percent = cfg['top_percent']
    if len(top_percent) < iters:
        raise ValueError(
            f"top_percent needs at least {iters} entries, got {len(top_percent)}"
        )

    use_cuda = torch.cuda.is_available()

    neighbor_list = []
    weights_list = []

    # Public COO accessor: the private `storage._row` may be unset.
    row1, col1, value1 = adj.coo()
    if value1 is None:
        value1 = torch.ones(row1.numel())
    index1 = torch.stack((row1, col1), 0).to(device)
    value1 = value1.to(device)
    parts = chunk_csr_col(adj, chunksize=chunksize, device='cpu')

    pbar = tqdm(range(len(parts)))
    for i in pbar:
        index2, value2, size1 = parts[i]

        with torch.no_grad():
            index2, value2 = index2.to(device), value2.to(device)
            for it in range(iters):
                if it > 0:
                    index2, value2 = spspmm(
                        index1, value1, index2, value2,
                        adj.size(0), adj.size(1), size1,
                    )

                # r = (1 - alpha) * A * r + alpha * t
                value2 *= 1 - alpha

                # Entries on the global diagonal: row id == seed id of the column.
                # NOTE(review): alpha is only added where such an entry already
                # exists in the sparsity pattern (at it == 0 this needs
                # self-loops in adj). Confirm this is intended.
                mask1 = torch.where(index2[0] == (index2[1] + i * chunksize))[0]
                value2[mask1] += alpha

                # Pick the globally highest scores (can be unfair to some seeds),
                # but always keep the seed entries themselves.
                mask2 = torch.topk(
                    value2, k=int(len(value2) * top_percent[it]), sorted=False
                ).indices
                mask = torch.cat((mask1, mask2), dim=0).unique(sorted=True)
                index2, value2 = index2[:, mask], value2[mask]

        # Transpose so each row is one seed. torch_sparse's transpose coalesces,
        # so rows are sorted, as get_topk_neighbors_mask requires. Computing the
        # mask in this order avoids relying on an unstable sort of the columns
        # happening to match transpose's ordering.
        (row, col), val = transpose(index2, value2, adj.size(0), size1)
        keep = get_topk_neighbors_mask(size1, row, val, topk)
        row, col, val = row[keep], col[keep], val[keep]

        # Split per seed by counts. This keeps exactly one (possibly empty) slot
        # per seed, so the lists stay aligned with node ids. It also avoids the
        # 0-d `squeeze()` result that np.array_split would read as a
        # section count.
        counts = torch.bincount(row, minlength=size1)
        offsets = torch.cumsum(counts, dim=0)[:-1].cpu().numpy()
        neighbor_list += np.split(col.cpu().numpy(), offsets)
        weights_list += np.split(val.cpu().numpy(), offsets)

        index2, value2 = None, None
        parts[i] = None

        # CUDA memory statistics are unavailable on CPU-only hosts.
        if use_cuda:
            torch.cuda.empty_cache()
            pbar.set_postfix(max_memory = torch.cuda.max_memory_allocated(),
                             memory = torch.cuda.memory_allocated(),
                             max_memory_reserve = torch.cuda.max_memory_reserved(),
                             memory_reserve = torch.cuda.memory_reserved())

    # Release the device copies of adj (no need to copy them back to the CPU).
    del index1, value1
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    pprmat = construct_sparse(neighbor_list, weights_list, adj.sizes()).tocsr()

    # np.array() on ragged arrays raises on NumPy >= 1.24 (and silently builds
    # a 2-D array when all lengths match), so build the object array explicitly.
    neighbors = np.empty(len(neighbor_list), dtype=object)
    for j, nbrs in enumerate(neighbor_list):
        neighbors[j] = nbrs

    return neighbors, pprmat
