import torch
import os
import numpy as np
from .loader_clustergcn import ClusterGCN_loader
from .loader_part_based import partition_fixed_loader
from .loader_part_ppr_mix import partition_ppr_loader
from .loader_rand import RandLoader, rand_fixed_loader
from .loader_ppr_based import ppr_fixed_loader
from torch_geometric.data import NeighborSampler
from torch_sparse import SparseTensor
from scipy.sparse import csr_matrix
from .MySaintSampler import SaintRWTrainSampler, SaintRWValSampler
from neighboring.layerwise_sampling import LadiesLoader


def get_loader(mode: str,
               neighbor_sampling: str,
               adj: SparseTensor,
               num_nodes: int,
               merge_max_size: int,
               num_batches: int,
               num_layers: int,
               partition_diffusion_param: float,
               n_sampling_params: dict,
               rw_sampling_params: dict,
               train: bool = True,
               partitions: list = None,
               part_topk: float = 1,
               prime_indices: np.ndarray = None,
               neighbors: list = None,
               ppr_mat: csr_matrix = None):
    """Build the batch loader selected by ``mode``.

    Args:
        mode: Loader family. Recognised values are ``'clustergcn'``,
            ``'part'``, ``'ppr'``, ``'n_sampling'``, ``'rw_sampling'``,
            ``'rand'``, ``'randfix'`` and ``'ladies'``.
        neighbor_sampling: Sub-mode, only consulted when ``mode == 'part'``
            (``'ladies'``, ``'batch_ppr'``, ``'batch_hk'`` or ``'ppr'``) and
            when deciding whether to add the extra validation loader.
        adj: Adjacency matrix handed to the underlying loaders.
        num_nodes: Node count handed to the underlying loaders; also used to
            size the extra validation loader.
        merge_max_size: Size limit forwarded to the loaders. In ``'part'``
            mode with a ``batch_*`` sub-mode it is multiplied by
            ``part_topk`` to obtain ``topk``.
        num_batches: Number of batches. NOTE: in ``'n_sampling'`` mode this
            argument is ignored and replaced by the value popped from
            ``n_sampling_params['num_batches']``.
        num_layers: Length of the per-layer size array given to
            ``LadiesLoader``.
        partition_diffusion_param: Forwarded to ``partition_fixed_loader``.
        n_sampling_params: Reads ``'num_batches'`` (a list, consumed from the
            front) and ``'n_nodes'`` (passed as ``sizes``).
        rw_sampling_params: Reads ``'batch_size'`` (index 0 for training,
            index 1 otherwise), ``'walk_length'``, ``'sample_coverage'`` and,
            for training only, ``'num_steps'``.
        train: Selects the training variant and the return shape (see below).
        partitions: Forwarded to the partition-based loaders. When it is not
            ``None`` and ``train`` is false, an extra validation loader is
            built.
        part_topk: Multiplier applied when computing ``topk``.
        prime_indices: Forwarded to the loaders; in ``'ppr'`` mode it also
            selects the columns of ``ppr_mat``.
        neighbors: Forwarded to the ppr / random loaders.
        ppr_mat: Only used in ``'ppr'`` mode.

    Returns:
        If ``train`` is true, the loader itself. Otherwise a two-element list
        ``[loader, lbmb_val_loader]`` where the second entry is ``None``
        unless the extra validation loader was built. For an unrecognised
        ``mode`` (or an unrecognised ``neighbor_sampling`` in ``'part'``
        mode) the loader slot is left as an empty list.

    Side effects:
        ``'n_sampling'`` mode removes the first element of
        ``n_sampling_params['num_batches']`` from the caller's dict.
        ``'rw_sampling'`` mode makes sure the sampler's save directory
        exists.
    """
    loader = [[], None]

    if mode == 'clustergcn':
        loader[0] = ClusterGCN_loader(partitions, prime_indices)

    elif mode == 'part':
        if neighbor_sampling in ('ladies', 'ppr'):
            # Both sub-modes start from the ClusterGCN batches and keep only
            # the first element of every (first, second) pair it yields.
            prime_batchlist = [p for (p, _) in ClusterGCN_loader(partitions, prime_indices)]
            if neighbor_sampling == 'ladies':
                loader[0] = LadiesLoader(prime_batchlist,
                                         np.full(num_layers, merge_max_size),
                                         adj.to_scipy('csr'))
            else:
                loader[0] = partition_ppr_loader(prime_batchlist, neighbors, prime_indices)

        elif neighbor_sampling in ('batch_ppr', 'batch_hk'):
            topk = merge_max_size * part_topk

            loader[0] = partition_fixed_loader(neighbor_sampling,
                                               adj,
                                               num_nodes,
                                               partitions,
                                               prime_indices,
                                               topk=int(topk),
                                               partition_diffusion_param=partition_diffusion_param)

    elif mode == 'ppr':
        loader[0] = ppr_fixed_loader(ppr_mat[:, prime_indices],
                                     prime_indices,
                                     neighbors,
                                     merge_max_size=merge_max_size)

    elif mode == 'n_sampling':
        # Consumes the caller's list on purpose-looking fashion: each call
        # takes the next configured batch count. NOTE(review): presumably one
        # entry per train/val/test call - confirm against the caller.
        num_batches = n_sampling_params['num_batches'].pop(0)
        # Ceiling division so that every prime node lands in some batch.
        batch_size = len(prime_indices) // num_batches + ((len(prime_indices) % num_batches) > 0)
        loader[0] = NeighborSampler(adj, node_idx=torch.from_numpy(prime_indices),
                                    sizes=n_sampling_params['n_nodes'],
                                    batch_size=batch_size,
                                    shuffle=True, num_workers=0)

    elif mode == 'rw_sampling':
        dir_name = './'
        # The previous check tested the function object ``os.path.isdir``
        # (always truthy), so the directory was never created. Create it
        # idempotently instead.
        os.makedirs(dir_name, exist_ok=True)

        if train:
            loader[0] = SaintRWTrainSampler(adj, num_nodes,
                                            batch_size=rw_sampling_params['batch_size'][0],
                                            walk_length=rw_sampling_params['walk_length'],
                                            num_steps=rw_sampling_params['num_steps'],
                                            sample_coverage=rw_sampling_params['sample_coverage'],
                                            save_dir=dir_name)
        else:
            loader[0] = SaintRWValSampler(adj, prime_indices, num_nodes,
                                          walk_length=rw_sampling_params['walk_length'],
                                          sample_coverage=rw_sampling_params['sample_coverage'],
                                          save_dir=dir_name,
                                          batch_size=rw_sampling_params['batch_size'][1])

    elif mode == 'rand':
        loader[0] = RandLoader(prime_indices, neighbors, merge_max_size)

    elif mode == 'randfix':
        loader[0] = rand_fixed_loader(prime_indices, neighbors, merge_max_size)

    elif mode == 'ladies':
        loader[0] = LadiesLoader(prime_indices,
                                 np.full(num_layers, merge_max_size),
                                 adj.to_scipy('csr'),
                                 num_batches)

    # For evaluation, additionally provide a fixed 'batch_ppr' partition
    # loader (LBMB validation) - unless the primary loader is already a
    # partition 'batch_*' loader or no partitions were supplied.
    already_batch_partition = mode == 'part' and 'batch' in neighbor_sampling
    if not train and not already_batch_partition and partitions is not None:
        # Uses the (possibly overridden, see 'n_sampling') num_batches.
        topk = int((num_nodes // num_batches + 1) * part_topk)

        loader[1] = partition_fixed_loader('batch_ppr',
                                           adj,
                                           num_nodes,
                                           partitions,
                                           prime_indices,
                                           topk=topk,
                                           partition_diffusion_param=partition_diffusion_param)

    return loader[0] if train else loader
