import torch
from torch_sparse import SparseTensor
from torch import Tensor
import torch_sparse
from typing import List, Tuple


class PermIterator:
    """Iterate over batches of indices drawn from ``range(size)``.

    In training mode the indices are a random permutation (``torch.randperm``)
    and only full batches of ``bs`` indices are produced; a trailing partial
    batch is dropped. In evaluation mode the indices are ``0..size-1`` in order
    and the last batch may be shorter than ``bs``. No empty batch is ever
    produced, so iteration always agrees with ``len()``.
    """

    def __init__(self, device, size, bs, training=True) -> None:
        """
        Args:
            device: torch device on which the index tensor is created.
            size: number of indices to iterate over.
            bs: batch size.
            training: shuffle and drop the last partial batch when truthy.
        """
        self.bs = bs
        self.training = training
        self.idx = torch.randperm(
            size, device=device) if training else torch.arange(size,
                                                               device=device)
        # Initialised here too so ``next()`` works even before ``iter()``.
        self.ptr = 0

    def __len__(self):
        """Number of batches: floor division in training, ceiling otherwise."""
        return (self.idx.shape[0] + (self.bs - 1) *
                (not self.training)) // self.bs

    def __iter__(self):
        """Restart from the first batch (the permutation itself is reused)."""
        self.ptr = 0
        return self

    def __next__(self):
        """Return the next slice of ``self.idx`` of length at most ``bs``."""
        n = self.idx.shape[0]
        # Stop once every index was consumed; in training mode also stop when
        # the remaining indices cannot fill a whole batch.
        if self.ptr >= n or (self.training and self.ptr + self.bs > n):
            raise StopIteration
        ret = self.idx[self.ptr:self.ptr + self.bs]
        self.ptr += self.bs
        return ret



def sparsesample(adj: SparseTensor, deg: int) -> SparseTensor:
    """Draw ``deg`` entries, with replacement, from every non-empty row of ``adj``.

    Each draw picks one of the row's stored entries uniformly at random.
    The resulting coordinates are coalesced (duplicate draws collapse into a
    single entry) and every kept entry is given the value 1.0. Rows with no
    stored entries contribute nothing.

    Args:
        adj: sparse matrix to sample from.
        deg: number of draws per non-empty row.

    Returns:
        A SparseTensor with the same sparse sizes and device as ``adj``.
    """
    rowptr, col, _ = adj.csr()
    counts = adj.storage.rowcount()
    nonempty = counts > 0

    counts = counts[nonempty]
    # Offset into ``col`` where each non-empty row begins.
    starts = rowptr[:-1][nonempty]

    # A uniform value in [0, 1) scaled by the row length and truncated gives a
    # position in [0, count); adding the row start makes it global in ``col``.
    offsets = torch.rand((counts.size(0), deg), device=col.device)
    offsets.mul_(counts.to(offsets.dtype).reshape(-1, 1))
    positions = offsets.to(torch.long)
    positions.add_(starts.reshape(-1, 1))
    sampled_cols = col[positions].flatten()

    # Repeat every non-empty row id ``deg`` times, matching ``sampled_cols``.
    row_ids = torch.arange(adj.size(0), device=adj.device())[nonempty]
    sampled_rows = row_ids.reshape(-1, 1).expand(-1, deg).flatten()

    out = SparseTensor(row=sampled_rows,
                       col=sampled_cols,
                       sparse_sizes=adj.sparse_sizes())
    out = out.to_device(adj.device()).coalesce()
    return out.fill_value_(1.0)



def sparsesample2(adj: SparseTensor, deg: int) -> SparseTensor:
    """Subsample only the rows of ``adj`` that have more than ``deg`` entries.

    Rows with more than ``deg`` stored entries get ``deg`` draws, made with
    replacement and uniformly among that row's entries. Rows with at most
    ``deg`` entries are kept unchanged. All entries are set to 1.0 *before*
    coalescing. Duplicate draws are therefore combined by ``coalesce``'s
    reduction. NOTE(review): torch_sparse's default reduction is a sum, so a
    repeated draw would get a value > 1. That differs from ``sparsesample``,
    which coalesces first. Confirm this is intended.

    Args:
        adj: sparse matrix to sample from.
        deg: per-row cap / number of draws for over-full rows.

    Returns:
        A SparseTensor with the same sparse sizes and device as ``adj``.
    """
    rowptr, col, _ = adj.csr()
    counts = adj.storage.rowcount()
    sample_mask = counts > deg
    keep_mask = torch.logical_not(sample_mask)
    all_rows = torch.arange(adj.size(0), device=adj.device())

    # --- rows that are subsampled --------------------------------------
    big_counts = counts[sample_mask]
    big_starts = rowptr[:-1][sample_mask]
    draws = torch.rand((big_counts.size(0), deg), device=col.device)
    draws.mul_(big_counts.to(draws.dtype).reshape(-1, 1))
    picks = draws.to(torch.long)
    picks.add_(big_starts.reshape(-1, 1))
    sampled_cols = col[picks].flatten()
    sampled_rows = all_rows[sample_mask].reshape(-1, 1).expand(-1,
                                                              deg).flatten()

    # --- rows that are kept as-is --------------------------------------
    # ``adj[keep_mask]`` renumbers the selected rows from 0; map them back.
    local_rows, kept_cols = adj[keep_mask].coo()[:2]
    kept_rows = all_rows[keep_mask][local_rows]

    merged = SparseTensor(row=torch.cat((sampled_rows, kept_rows)),
                          col=torch.cat((sampled_cols, kept_cols)),
                          sparse_sizes=adj.sparse_sizes())
    merged = merged.to_device(adj.device()).fill_value_(1.0)
    return merged.coalesce()


def sparsesample_reweight(adj: SparseTensor, deg: int) -> SparseTensor:
    """Subsample rows with more than ``deg`` entries and rescale the samples.

    Rows with more than ``deg`` stored entries get ``deg`` draws, made with
    replacement and uniformly among the row's entries. Each draw carries the
    weight ``rowcount / deg``, so the row's total weight matches its original
    entry count. Rows with at most ``deg`` entries are kept with value 1.
    The original values of ``adj`` are ignored. Duplicate draws are combined
    by ``coalesce``; torch_sparse's default reduction is a sum.

    Args:
        adj: sparse matrix to sample from.
        deg: per-row cap / number of draws for over-full rows; must be > 0.

    Returns:
        A SparseTensor with the same sparse sizes and device as ``adj`` and
        floating-point values.
    """
    rowptr, col, _ = adj.csr()
    rowcount = adj.storage.rowcount()
    mask = rowcount > deg

    rowcount = rowcount[mask]
    rowptr = rowptr[:-1][mask]

    # Uniform position in [0, rowcount) per draw, shifted into ``col``.
    rand = torch.rand((rowcount.size(0), deg), device=col.device)
    rand.mul_(rowcount.to(rand.dtype).reshape(-1, 1))
    rand = rand.to(torch.long)
    rand.add_(rowptr.reshape(-1, 1))

    samplecol = col[rand].flatten()

    samplerow = torch.arange(adj.size(0), device=adj.device())[mask].reshape(
        -1, 1).expand(-1, deg).flatten()
    # Every draw stands in for rowcount/deg of the row's entries.
    samplevalue = (rowcount * (1 / deg)).reshape(-1, 1).expand(-1,
                                                               deg).flatten()

    mask = torch.logical_not(mask)
    nosamplerow, nosamplecol = adj[mask].coo()[:2]
    # ``adj[mask]`` renumbers the selected rows from 0; map them back.
    nosamplerow = torch.arange(adj.size(0),
                               device=adj.device())[mask][nosamplerow]
    # Give the unsampled entries the same floating dtype as the sampled
    # weights instead of relying on torch.cat promoting an int64 tensor.
    nosamplevalue = torch.ones_like(nosamplerow, dtype=samplevalue.dtype)

    ret = SparseTensor(row=torch.cat((samplerow, nosamplerow)),
                       col=torch.cat((samplecol, nosamplecol)),
                       value=torch.cat((samplevalue, nosamplevalue)),
                       sparse_sizes=adj.sparse_sizes()).to_device(
                           adj.device()).coalesce()
    return ret


def elem2spm(element: Tensor, sizes: List[int]) -> SparseTensor:
    """Decode packed int64 keys into a SparseTensor whose values are all 1.0.

    Each key holds the row index in its upper 32 bits and the column index in
    its lower 32 bits, i.e. the layout produced by ``spm2elem``.

    Args:
        element: 1-D tensor of packed keys.
        sizes: sparse sizes of the matrix to build.
    """
    low_bits = 0xffffffff
    rows = torch.bitwise_right_shift(element, 32)
    cols = torch.bitwise_and(element, low_bits)
    decoded = SparseTensor(row=rows, col=cols, sparse_sizes=sizes)
    return decoded.to_device(element.device).fill_value_(1.0)


def spm2elem(spm: SparseTensor) -> Tensor:
    """Pack each stored (row, col) coordinate of ``spm`` into one int64 key.

    ``key = (row << 32) + col``. This assumes ``col < 2**32`` and that ``row``
    fits in the remaining bits. When the storage is ordered by row and then
    by column, the keys come out in ascending order. The overlap helpers in
    this module depend on that ordering for ``torch.searchsorted``.
    NOTE(review): the ordering relies on torch_sparse keeping its storage
    row-major sorted, which should be confirmed.

    Returns:
        1-D int64 tensor with one key per stored entry.
    """
    row = spm.storage.row()
    col = spm.storage.col()
    # bitwise_left_shift allocates a new tensor, so add_ does not touch spm.
    return torch.bitwise_left_shift(row, 32).add_(col)


def spmoverlap_(adj1: SparseTensor, adj2: SparseTensor) -> SparseTensor:
    """Return the coordinates stored in both ``adj1`` and ``adj2``.

    The matrices must have identical sizes (checked with ``assert``). Their
    coordinates are packed with ``spm2elem``. The shorter key list is then
    binary-searched inside the longer one, which must be sorted ascending.
    The result has value 1.0 for every common coordinate.
    """
    assert adj1.sizes() == adj2.sizes()

    keys_a = spm2elem(adj1)
    keys_b = spm2elem(adj2)

    # Search in the longer list, query with the shorter one.
    if keys_b.shape[0] > keys_a.shape[0]:
        haystack, needles = keys_b, keys_a
    else:
        haystack, needles = keys_a, keys_b

    # Searching haystack[:-1] keeps every returned position a valid index
    # into haystack; a needle past the end is compared with the last item.
    positions = torch.searchsorted(haystack[:-1], needles)
    found = haystack[positions] == needles
    return elem2spm(needles[found], adj1.sizes())


def spmnotoverlap_(adj1: SparseTensor,
                   adj2: SparseTensor) -> Tuple[SparseTensor, SparseTensor]:
    """Return the coordinates that appear in only one of ``adj1`` / ``adj2``.

    Coordinates are packed with ``spm2elem``, and ``adj1``'s keys must be
    sorted ascending for the binary search.

    Returns:
        ``(only_in_adj1, only_in_adj2)``. Each has value 1.0 and the sparse
        sizes of its respective input.
    """
    element1 = spm2elem(adj1)
    element2 = spm2elem(adj2)

    if element1.shape[0] == 0:
        # adj1 is empty. Nothing can match, and ``element1[idx]`` below would
        # gather from an empty tensor and raise if adj2 is non-empty.
        retelem1 = element1
        retelem2 = element2
    else:
        # Searching element1[:-1] keeps every position a valid index.
        idx = torch.searchsorted(element1[:-1], element2)
        matchedmask = element1[idx] == element2

        # Drop the keys of adj1 that were matched by some key of adj2.
        maskelem1 = torch.ones_like(element1, dtype=torch.bool)
        maskelem1[idx[matchedmask]] = False
        retelem1 = element1[maskelem1]

        retelem2 = element2[torch.logical_not(matchedmask)]

    return elem2spm(retelem1, adj1.sizes()), elem2spm(retelem2, adj2.sizes())



def spmoverlap_notoverlap_(
        adj1: SparseTensor,
        adj2: SparseTensor) -> Tuple[SparseTensor, SparseTensor, SparseTensor]:
    """Split two sparse matrices into shared and exclusive coordinates.

    Coordinates are packed with ``spm2elem``, and ``adj1``'s keys must be
    sorted ascending for the binary search. An empty ``adj1`` is handled
    explicitly.

    Returns:
        ``(overlap, only_in_adj1, only_in_adj2)``. All three have value 1.0
        and use the sparse sizes of ``adj1``.
    """
    keys1 = spm2elem(adj1)
    keys2 = spm2elem(adj2)
    sizes = adj1.sizes()

    if keys1.shape[0] != 0:
        positions = torch.searchsorted(keys1[:-1], keys2)
        found = keys1[positions] == keys2

        keep1 = torch.ones_like(keys1, dtype=torch.bool)
        keep1[positions[found]] = 0

        shared = keys2[found]
        only1 = keys1[keep1]
        only2 = keys2[torch.logical_not(found)]
    else:
        # adj1 has no entries: no overlap, everything in adj2 is exclusive.
        shared = only1 = keys1
        only2 = keys2

    return (elem2spm(shared, sizes),
            elem2spm(only1, sizes),
            elem2spm(only2, sizes))



def adjoverlap(adj1: SparseTensor,
               adj2: SparseTensor,
               tarei: Tensor,
               filled1: bool = False,
               calresadj: bool = False,
               cnsampledeg: int = -1,
               ressampledeg: int = -1):
    """Intersect selected rows of two adjacency matrices pairwise.

    Rows ``tarei[0]`` of ``adj1`` are matched against rows ``tarei[1]`` of
    ``adj2``. Presumably ``tarei`` is a (2, E) index tensor of target pairs,
    so row k of each result belongs to pair k; confirm with callers.

    Args:
        adj1, adj2: sparse adjacency matrices.
        tarei: row indices; ``tarei[0]`` selects from adj1 and ``tarei[1]``
            selects from adj2.
        filled1: not used by this function; kept for interface compatibility.
        calresadj: also return the non-overlapping residual matrices.
        cnsampledeg: if > 0, subsample the overlap with
            ``sparsesample_reweight``.
        ressampledeg: if > 0 (and ``calresadj``), subsample both residuals
            the same way.

    Returns:
        The overlap matrix, or ``(overlap, residual1, residual2)`` when
        ``calresadj`` is true.
    """
    rows1 = adj1[tarei[0]]
    rows2 = adj2[tarei[1]]

    if not calresadj:
        overlap = spmoverlap_(rows1, rows2)
        if cnsampledeg > 0:
            overlap = sparsesample_reweight(overlap, cnsampledeg)
        return overlap

    overlap, residual1, residual2 = spmoverlap_notoverlap_(rows1, rows2)
    if cnsampledeg > 0:
        overlap = sparsesample_reweight(overlap, cnsampledeg)
    if ressampledeg > 0:
        residual1 = sparsesample_reweight(residual1, ressampledeg)
        residual2 = sparsesample_reweight(residual2, ressampledeg)
    return overlap, residual1, residual2

def block_matrix_multiply(spadj, block_size):
    """Compute ``spadj @ spadj`` tile by tile and return it as a SparseTensor.

    Each output tile ``[i:i+block_size, j:j+block_size]`` is computed densely.
    Its non-zero coordinates are shifted from tile-local to global indices
    and collected. The result is assembled once at the end. Tiles are
    disjoint, so no coordinate repeats and no summation is needed.

    ``spadj`` must be square, otherwise ``torch.matmul`` raises a shape
    error. NOTE: the whole of ``spadj`` is still materialised densely, as in
    the previous version, so memory is O(rows * cols).

    Args:
        spadj: square sparse matrix.
        block_size: tile edge length; must be >= 1.

    Returns:
        SparseTensor of sparse size ``(rows, cols)`` holding the non-zero
        entries of ``spadj @ spadj``.

    Raises:
        ValueError: if ``block_size`` < 1.
    """
    if block_size < 1:
        raise ValueError(
            f"block_size must be a positive integer, got {block_size}")

    rows = spadj.size(0)
    cols = spadj.size(1)
    device = spadj.device()

    spadj_dense = spadj.to_dense()

    row_parts, col_parts, value_parts = [], [], []
    for i in range(0, rows, block_size):
        row_end = min(i + block_size, rows)
        block_a = spadj_dense[i:row_end, :]
        for j in range(0, cols, block_size):
            col_end = min(j + block_size, cols)
            block_b = spadj_dense[:, j:col_end]

            block_result_dense = torch.matmul(block_a, block_b)

            # nonzero() yields tile-local coordinates; shift by the tile
            # origin so the entries land in the right place globally.
            local_row, local_col = block_result_dense.nonzero(as_tuple=True)
            row_parts.append(local_row + i)
            col_parts.append(local_col + j)
            value_parts.append(block_result_dense[local_row, local_col])

    if row_parts:
        row = torch.cat(row_parts)
        col = torch.cat(col_parts)
        value = torch.cat(value_parts)
    else:
        # Zero-sized matrix: no tiles were visited.
        row = torch.empty(0, dtype=torch.long, device=device)
        col = torch.empty(0, dtype=torch.long, device=device)
        value = torch.empty(0, dtype=spadj_dense.dtype, device=device)

    return SparseTensor(row=row, col=col, value=value,
                        sparse_sizes=(rows, cols))


def sparse_tensor_multiply(spadj, block_size=1024):
    """Thin wrapper around ``block_matrix_multiply`` with a default tile size of 1024."""
    return block_matrix_multiply(spadj, block_size)

if __name__ == "__main__":
    # Small hand-built matrices for a quick smoke test of the helpers above.
    adj1 = SparseTensor.from_edge_index(
        torch.tensor([[0, 0, 1, 2, 3], [0, 1, 1, 2, 3]], dtype=torch.long))
    adj2 = SparseTensor.from_edge_index(
        torch.tensor([[0, 3, 1, 2, 3], [0, 1, 1, 2, 3]], dtype=torch.long))
    adj3 = SparseTensor.from_edge_index(
        torch.tensor([[0, 1, 2, 2, 2, 2, 3, 3, 3],
                      [1, 0, 2, 3, 4, 5, 4, 5, 6]], dtype=torch.long))

    demos = (
        (spmnotoverlap_, (adj1, adj2)),
        (spmoverlap_, (adj1, adj2)),
        (spmoverlap_notoverlap_, (adj1, adj2)),
        (sparsesample2, (adj3, 2)),
        (sparsesample_reweight, (adj3, 2)),
    )
    for fn, args in demos:
        print(fn(*args))