from torch_geometric.data import Data
from torch_geometric.utils import subgraph
import torch_geometric
from ogb.nodeproppred import PygNodePropPredDataset
from scipy.sparse import csr_matrix, eye, csgraph
from scipy.special import softmax
from neighboring.pernode_ppr_neighbor import topk_ppr_matrix
from .modified_tsp import solve_tsp_simulated_annealing
from torch_sparse import SparseTensor
import numpy as np
import itertools
import torch
import time
from train.train_utils import MyGraph


def get_time():
    """Return the wall-clock time (``time.time()``) after pending CUDA work finishes.

    Synchronizing first makes timings include asynchronously launched GPU
    kernels. On machines without CUDA the synchronization is skipped, since
    ``torch.cuda.synchronize()`` would raise there.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


def check_consistence(mode: str, 
                      neighbor_sampling: str, 
                      order: bool, 
                      sample: bool):
    """Validate that a combination of batching options is supported.

    Rules:
      * ``mode`` must be one of the known modes.
      * 'ppr', 'part' and 'randfix' may enable at most one of ``sample`` /
        ``order``; every other mode (including 'clustergcn') enables neither.
      * 'ppr' / 'rand' require ``neighbor_sampling`` in ['ppr', 'hk', 'pnorm'];
        'part' requires it in ['ladies', 'batch_hk', 'batch_ppr', 'ppr'].
      * ``neighbor_sampling == 'ladies'`` is only allowed with mode 'part'.

    Raises:
        AssertionError: if the combination is not supported.
    """
    assert mode in ['ppr', 'rand', 'randfix', 'part', 'clustergcn', 'n_sampling', 'rw_sampling', 'ladies'], \
        f"unsupported mode {mode!r}"
    if mode in ['ppr', 'part', 'randfix']:
        assert not (sample and order), \
            f"mode {mode!r} does not allow both sample and order"
    else:
        assert not (sample or order), \
            f"mode {mode!r} supports neither sample nor order"

    if mode in ['ppr', 'rand']:
        assert neighbor_sampling in ['ppr', 'hk', 'pnorm'], \
            f"neighbor_sampling {neighbor_sampling!r} is not supported with mode {mode!r}"
    elif mode == 'part':
        assert neighbor_sampling in ['ladies', 'batch_hk', 'batch_ppr', 'ppr'], \
            f"neighbor_sampling {neighbor_sampling!r} is not supported with mode 'part'"

    if neighbor_sampling == 'ladies':
        assert mode == 'part', "neighbor_sampling 'ladies' requires mode 'part'"


def _split_to_numpy(indices) -> np.ndarray:
    """Return one split's node indices as a NumPy array.

    Heterogeneous OGB splits (e.g. ogbn-mag) are dicts keyed by node type;
    only the 'paper' entries are kept.
    """
    if isinstance(indices, dict):
        indices = indices['paper']
    return indices.cpu().detach().numpy()


def load_data(dataset_name: str, 
              small_trainingset: float):
    """Load a node-classification dataset and its train/valid/test split.

    ``dataset_name`` is first tried as the OGB dataset ``ogbn-<dataset_name>``
    (stored under ``./``). If OGB rejects the name with a ValueError, 'reddit'
    and 'reddit2' are loaded from PyG instead; their split is read from the
    graph's boolean masks, which are then removed from the graph.

    Args:
        dataset_name: OGB suffix (e.g. 'arxiv', 'mag') or 'reddit' / 'reddit2'.
        small_trainingset: fraction of training nodes to keep; values < 1
            subsample the training set without replacement using seed 2021.

    Returns:
        ``(graph, (train_indices, val_indices, test_indices))`` with NumPy index
        arrays. For 'mag' the graph is reduced to paper nodes and
        paper-cites-paper edges.

    Raises:
        ValueError: if the name is neither an OGB dataset nor 'reddit'/'reddit2'.
    """
    try:
        dataset = PygNodePropPredDataset(name="ogbn-{:s}".format(dataset_name), root='./')
    except ValueError as err:
        if dataset_name == 'reddit2':
            dataset = torch_geometric.datasets.Reddit2('./')
        elif dataset_name == 'reddit':
            dataset = torch_geometric.datasets.Reddit('./')
        else:
            # fail with the offending name instead of an UnboundLocalError on `dataset`
            raise ValueError(
                "unknown dataset {!r}: not an OGB node-property dataset and "
                "not 'reddit'/'reddit2'".format(dataset_name)) from err
        graph = dataset[0]
        if isinstance(graph, torch_geometric.data.Data):
            split_idx = {'train': graph.train_mask.nonzero().reshape(-1), 
                         'valid': graph.val_mask.nonzero().reshape(-1), 
                         'test': graph.test_mask.nonzero().reshape(-1)}
            graph.train_mask, graph.val_mask, graph.test_mask = None, None, None
    else:
        split_idx = dataset.get_idx_split()
        graph = dataset[0]

    train_indices = _split_to_numpy(split_idx["train"])
    if small_trainingset < 1:
        # NOTE: reseeds NumPy's global RNG, so later unseeded np.random calls are affected too
        np.random.seed(2021)
        train_indices = np.sort(np.random.choice(train_indices, 
                                                 size=int(len(train_indices) * small_trainingset), 
                                                 replace=False, 
                                                 p=None))

    val_indices = _split_to_numpy(split_idx["valid"])
    test_indices = _split_to_numpy(split_idx["test"])

    if dataset_name == 'mag':
        # keep only the homogeneous paper-cites-paper graph
        graph = Data(x=graph.x_dict['paper'],
                     edge_index=graph.edge_index_dict[('paper', 'cites', 'paper')],
                     y=graph.y_dict['paper'])

    return graph, (train_indices, val_indices, test_indices,)


def graph_preprocess(graph: Data, 
                     self_loop: bool = True, 
                     to_undirected: bool = True, 
                     normalization: str = 'sym'):
    """Preprocess ``graph`` in place and store a normalized adjacency as ``adj_t``.

    Steps:
      * flatten ``graph.y`` to 1-D if it has more dimensions;
      * build a boolean SciPy CSR adjacency from ``graph.edge_index`` and set
        ``graph.edge_index`` to None;
      * optionally add the transpose (``to_undirected``) and the identity
        (``self_loop``);
      * normalize with ``normalize_adjmat`` (``'sym'`` or ``'rw'``) and store
        the result as ``graph.adj_t`` (a ``SparseTensor``).

    Returns None; ``graph`` is modified in place.
    """
    if graph.y.dim() > 1:
        graph.y = graph.y.reshape(-1)

    # Read the node count before dropping edge_index: PyG's Data.num_nodes can
    # fall back to inferring it from edge_index when nothing else defines it.
    num_nodes = graph.num_nodes

    row, col = graph.edge_index.cpu().detach().numpy()
    graph.edge_index = None
    data = np.ones_like(row, dtype=np.bool_)
    adj = csr_matrix((data, (row, col)), shape=(num_nodes, num_nodes))

    if to_undirected:
        adj += adj.transpose()

    if self_loop:
        adj += eye(num_nodes, dtype=np.bool_)

    # normalize_adjmat resets every stored value to 1 before scaling
    adj = normalize_adjmat(adj, normalization, inplace=True)
    graph.adj_t = SparseTensor.from_scipy(adj)
    
    
def get_partitions(mode: str, 
                   mat: SparseTensor, 
                   num_parts: int, 
                   force: bool = False) -> list:
    """Split the nodes of ``mat`` into ``num_parts`` groups via ``SparseTensor.partition``.

    Partitioning runs only for the 'part' and 'clustergcn' modes, or whenever
    ``force`` is set; in every other case None is returned.

    Returns:
        A list with one NumPy array of node ids per part, or None.
    """
    if not (force or mode in ['part', 'clustergcn']):
        return None

    _, boundaries, node_order = mat.partition(num_parts=num_parts, recursive=False, weighted=False)
    # boundaries[k]:boundaries[k + 1] delimits part k inside node_order
    return [node_order[start:stop].cpu().detach().numpy()
            for start, stop in zip(boundaries[:-1], boundaries[1:])]


def get_ppr_mat(mode: str, 
                neighbor_sampling: str,
                prime_indices: np.ndarray, 
                scipy_adj: csr_matrix, 
                topk=256, 
                eps=None,
                alpha: float = 0.05) -> csr_matrix:
    """Build a top-k PPR matrix for ``prime_indices`` when PPR is requested.

    The matrix is computed only when ``mode`` or ``neighbor_sampling`` is
    'ppr'; otherwise None is returned.

    Args:
        mode / neighbor_sampling: batching options; one must be 'ppr'.
        prime_indices: node ids whose PPR rows are computed.
        scipy_adj: adjacency matrix passed to ``topk_ppr_matrix``.
        topk: number of entries kept per row by ``topk_ppr_matrix``.
        eps: tolerance passed to ``topk_ppr_matrix``. When None it is picked
            from the number of adjacency entries per prime node: 1e-4 if that
            ratio is below 100, else 1e-5.
        alpha: second positional argument of ``topk_ppr_matrix``, presumably
            the PPR teleport probability (default 0.05 as before) — confirm
            against neighboring.pernode_ppr_neighbor.
    """
    if 'ppr' not in (mode, neighbor_sampling):
        return None

    if eps is None:
        # Many prime nodes relative to edges -> fewer PPR pairs are needed per
        # node, so the looser tolerance is used. max(..., 1) guards an empty set.
        num_prime = max(len(prime_indices), 1)
        eps = 1e-4 if (scipy_adj.nnz / num_prime < 100) else 1e-5
    return topk_ppr_matrix(scipy_adj, alpha, eps, prime_indices, topk=topk, normalization='sym')

    
def normalize_adjmat(adj: [SparseTensor, csr_matrix], 
                     normalization: str, 
                     inplace: bool = False):
    """Return an unweighted, degree-normalized version of ``adj``.

    Every stored entry is first set to 1 (edge weights are ignored). The matrix
    is then scaled by the inverse column sums ``degree = adj.sum(0)``:

      * ``'sym'``: D^-1/2 A D^-1/2 (columns and rows scaled by degree^-1/2)
      * ``'rw'``:  each row i divided by degree[i]

    Args:
        adj: a torch_sparse ``SparseTensor`` or a SciPy ``csr_matrix``.
        normalization: ``'sym'`` or ``'rw'``.
        inplace: if False, ``adj`` is copied before its values are reset. If
            True, a csr input has its ``.data`` overwritten with ones. The
            normalized result is a new matrix object in both cases.

    Returns:
        A ``SparseTensor`` for SparseTensor input. For csr input, the SciPy
        sparse matrix produced by ``multiply``.

    Raises:
        AssertionError: if ``normalization`` is not 'sym' or 'rw'.
        TypeError: if ``adj`` is neither a SparseTensor nor a csr_matrix.
    """
    assert normalization in ['sym', 'rw']

    if isinstance(adj, SparseTensor):
        if not inplace:
            adj = adj.clone()
        adj = adj.fill_value(1, dtype=torch.float32)
        degree = adj.sum(0)
    elif isinstance(adj, csr_matrix):
        if not inplace:
            adj = adj.copy()
        adj.data = np.ones_like(adj.data, dtype=np.float32)
        degree = adj.sum(0).A1
    else:
        # previously fell through to an UnboundLocalError on `degree`
        raise TypeError("normalize_adjmat expects a SparseTensor or csr_matrix, "
                        f"got {type(adj).__name__}")

    # NOTE(review): degrees are column sums; they equal row degrees only for a
    # symmetric adj (e.g. graph_preprocess with to_undirected=True). Confirm
    # the intended convention for directed input, especially for 'rw'.
    # Zero degrees are clamped to avoid division by zero.
    degree[degree == 0.] = 1e-12
    deg_inv = 1 / degree

    if normalization == 'sym':
        deg_inv_sqrt = deg_inv ** 0.5
        if isinstance(adj, csr_matrix):
            adj = adj.multiply(deg_inv_sqrt.reshape(1, -1))
            adj = adj.multiply(deg_inv_sqrt.reshape(-1, 1))
        elif isinstance(adj, SparseTensor):
            adj = adj * deg_inv_sqrt.reshape(1, -1)
            adj = adj * deg_inv_sqrt.reshape(-1, 1)

    elif normalization == 'rw':
        if isinstance(adj, csr_matrix):
            adj = adj.multiply(deg_inv.reshape(-1, 1))
        elif isinstance(adj, SparseTensor):
            adj = adj * deg_inv.reshape(-1, 1)

    return adj


def kl_divergence(p: np.ndarray, q: np.ndarray):
    """Return the Kullback-Leibler divergence KL(p || q) = sum_k p_k * log(p_k / q_k).

    Entries with p_k == 0 contribute 0 (the standard 0 * log 0 = 0 convention)
    instead of NaN. An entry with p_k > 0 and q_k == 0 makes the result +inf.
    """
    p = np.asarray(p)
    q = np.asarray(q)
    # 0/0 and log(0) may occur on entries that np.where discards; silence them
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = np.where(p > 0, p * np.log(p / q), 0.0)
    return terms.sum()


def get_pair_wise_distance(ys: list, num_classes: int, dist_type: str = 'kl'):
    """Return a symmetric matrix of distances between per-batch label distributions.

    For each batch ``ys[i]`` the label histogram over ``num_classes`` classes
    is Laplace-smoothed (+1 per class) and normalized to sum to 1. Each pair of
    batches is then compared with the L1 distance (``'l1'``) or the symmetrized
    KL divergence KL(p||q) + KL(q||p) (``'kl'``).

    Off-diagonal entries get a 1e-5 offset for numerical stability; the
    diagonal is 0.

    Args:
        ys: one label array per batch. Labels are used directly as column
            indices, so they are assumed to lie in [0, num_classes).
        num_classes: number of label classes.
        dist_type: ``'kl'`` or ``'l1'``.

    Returns:
        A ``(len(ys), len(ys))`` float64 array.

    Raises:
        ValueError: if ``dist_type`` is not ``'kl'`` or ``'l1'``.
    """
    # validate up front: the old check lived inside the pair loop and never ran for < 2 batches
    if dist_type not in ('l1', 'kl'):
        raise ValueError(f"unknown dist_type {dist_type!r}; expected 'l1' or 'kl'")

    num_batches = len(ys)

    counts = np.zeros((num_batches, num_classes), dtype=np.int32)
    for i, labels in enumerate(ys):
        unique, count = np.unique(labels, return_counts=True)
        counts[i, unique] = count

    # Laplace smoothing keeps every class probability > 0, so KL terms stay finite
    counts += 1
    counts = counts / counts.sum(1).reshape(-1, 1)

    pairwise_dist = np.zeros((num_batches, num_batches), dtype=np.float64)
    for i, j in itertools.combinations(range(num_batches), 2):
        if dist_type == 'l1':
            pairwise_dist[i, j] = np.sum(np.abs(counts[i] - counts[j]))
        else:
            pairwise_dist[i, j] = kl_divergence(counts[i], counts[j]) + kl_divergence(counts[j], counts[i])

    # mirror the upper triangle into the lower one
    pairwise_dist = pairwise_dist + pairwise_dist.T

    pairwise_dist += 1e-5   # for numerical stability
    np.fill_diagonal(pairwise_dist, 0.)

    return pairwise_dist


def tsp_brute_force(pairwise_dist: np.ndarray):
    """Exhaustively find the closed tour with the LARGEST total distance.

    This maximizes rather than minimizes the cycle length, consistent with
    tsp_heuristic, which passes ``-pairwise_dist`` to its solver. A tour is
    scored as a cycle (the last node returns to the first), so all rotations
    of a tour score the same. Node 0 is therefore fixed as the start, which
    cuts the search from n! to (n-1)! permutations.

    Returns:
        ``(best_perm, best_dist)``: the tour as a list of node indices starting
        at 0 (start not repeated) and its cyclic total distance. An empty
        matrix yields ``([], 0.0)``.
    """
    num_batches = pairwise_dist.shape[0]
    if num_batches == 0:
        return [], 0.0

    best_dist = -np.inf
    best_perm = None
    for rest in itertools.permutations(range(1, num_batches)):
        tour = [0, *rest]
        # edges tour[k] -> tour[k + 1], plus the closing edge back to node 0
        dist = pairwise_dist[tour, tour[1:] + [0]].sum()
        if dist > best_dist:
            best_dist = dist
            best_perm = tour

    return best_perm, best_dist


def tsp_heuristic(pairwise_dist: np.ndarray, max_processing_time: float = 1800):
    """Approximately find the closed tour with the largest total distance.

    The distance matrix is negated before it is handed to the simulated-annealing
    solver, so minimizing the solver's objective maximizes the original distance.

    NOTE(review): ``best_dist`` is returned exactly as the solver reports it
    for the negated matrix, so it is presumably the NEGATIVE tour length. That
    differs from tsp_brute_force, which returns a positive total. Confirm
    against modified_tsp before comparing the two.

    Args:
        pairwise_dist: non-negative distance matrix (assumed square).
        max_processing_time: time budget forwarded to the solver (default 1800,
            as before; presumably seconds).

    Raises:
        AssertionError: if any entry of ``pairwise_dist`` is negative.
    """
    assert np.all(pairwise_dist >= 0.)

    best_perm, best_dist = solve_tsp_simulated_annealing(-pairwise_dist, 
                                                         max_processing_time=max_processing_time)

    return best_perm, best_dist
