from typing import Optional, List, NamedTuple, Tuple, Union
import torch
import os
from tqdm import trange

from torch import Tensor
from torch_geometric.loader import NeighborSampler

from torch_sparse import SparseTensor

# Process-wide cache of neighbor-sampling results. It is filled by
# `sample_k_hop_subgraph_cache` and written to disk by `flush_sample_cache`.
# Layout: {(dataset, num_hops): {node query: (n_id, edge_index, e_id)}}, where the
# inner value is exactly what `sample_k_hop_subgraph` returns.
sampler_cache_store: dict = {}

class Adj(NamedTuple):
    """One sampled hop: its edge list, the original edge ids, and the (rows, cols) size."""

    edge_index: torch.Tensor
    e_id: torch.Tensor
    size: Tuple[int, int]

    def to(self, *args, **kwargs):
        """Apply `Tensor.to(*args, **kwargs)` to both tensors; `size` is carried over unchanged."""
        moved_edges = self.edge_index.to(*args, **kwargs)
        moved_ids = self.e_id.to(*args, **kwargs)
        return Adj(moved_edges, moved_ids, self.size)

def preprocess(edge_index, num_nodes=None, flow='source_to_target'):
    """Build the CPU CSR adjacency consumed by `sample_k_hop_subgraph`.

    Args:
        edge_index: Edge list tensor with 2 rows (assumed PyG layout: row 0 = sources,
            row 1 = targets — confirm against callers).
        num_nodes: Number of nodes N. When None it is inferred as max(edge_index) + 1,
            or 0 for an empty edge list.
        flow: 'source_to_target' transposes the matrix built from `edge_index`; any other
            value keeps it untransposed.

    Returns:
        An N x N SparseTensor on the CPU. Its values are the edge ids 0..E-1, i.e. the
        column positions in `edge_index`, so sampled edges can be mapped back.
    """
    # Do everything on the CPU. The edge-id tensor below is created on the CPU and the
    # result is CPU-only anyway, so a CUDA edge_index would otherwise be mixed with CPU data.
    edge_index = edge_index.cpu()

    if num_nodes is None:
        # `max()` on an empty tensor raises, so treat an empty edge list as a 0-node graph.
        num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    N = num_nodes

    edge_ids = torch.arange(edge_index.size(1))
    adj = SparseTensor.from_edge_index(edge_index, edge_ids, (N, N), is_sorted=False)
    if flow == 'source_to_target':
        adj = adj.t()

    rowptr, col, value = adj.csr()  # convert to csr form
    # Rebuild from the raw CSR arrays only. Presumably this drops any auxiliary storage
    # cached by the transpose before the matrix is shared or saved — confirm.
    return SparseTensor(rowptr=rowptr, col=col, value=value, sparse_sizes=(N, N),
                        is_sorted=True, trust_data=True)

def _node_cache_key(node_idx):
    """Convert a node query into a hashable cache key that compares by value.

    Plain ints are returned unchanged, so caches built from int queries (as
    `precompute_cache` does) still match. Lists and tensors become tuples of ints:
    lists are unhashable, and tensors hash by object identity, so two equal tensor
    queries would otherwise never share a cache entry.
    """
    if isinstance(node_idx, Tensor):
        node_idx = node_idx.tolist()  # 0-d -> int, 1-d -> list
    if isinstance(node_idx, (list, tuple)):
        return tuple(node_idx)
    return node_idx


def sample_k_hop_subgraph_cache(
    dataset : str,
    node_idx: Union[int, List[int], Tensor],
    num_hops: int,
    whole_adj: Tensor,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    flow: str = 'source_to_target',
    size: int = 100,
    cache_only: bool = False,
):
    """Memoised wrapper around `sample_k_hop_subgraph` (see that function for arguments).

    Results are stored in `sampler_cache_store[(dataset, num_hops)]`. On the first access
    for a (dataset, num_hops) pair, that entry is seeded from
    `../sampling_cache/{dataset}_{num_hops}.p` (relative to this module) if the file exists.

    Note that the cache key does NOT include `flow`, `size`, `relabel_nodes` or
    `whole_adj`. Calls that differ only in those share cached results.

    Args:
        dataset: Name used as part of the cache key and file name.
        cache_only: If True, never sample; a cache miss raises instead.

    Returns:
        The cached or freshly sampled `(n_id, edge_index, e_id)` tuple.

    Raises:
        ValueError: if `cache_only` is True and the query is not cached.
    """
    key = (dataset, num_hops)

    if key not in sampler_cache_store:
        # First call for this key: seed from disk when a saved cache exists.
        cache_name = os.path.join(os.path.dirname(__file__), f'../sampling_cache/{dataset}_{num_hops}.p')
        if os.path.exists(cache_name):
            # NOTE: torch.load unpickles the file; only load cache files you produced yourself.
            with open(cache_name, 'rb') as f:
                sampler_cache_store[key] = torch.load(f)
        else:
            sampler_cache_store[key] = {}

    cache = sampler_cache_store[key]
    node_key = _node_cache_key(node_idx)
    if node_key in cache:
        return cache[node_key]
    if cache_only:
        raise ValueError("Query not in the neighbor sampling cache.")

    # Cache miss: sample now and remember the result.
    value = sample_k_hop_subgraph(node_idx, num_hops, whole_adj, relabel_nodes, num_nodes, flow, size)
    cache[node_key] = value
    return value

def flush_sample_cache_all():
    """Write every in-memory sampling cache to its file via `flush_sample_cache`."""
    for dataset, num_hops in sampler_cache_store:
        flush_sample_cache(dataset, num_hops)

def flush_sample_cache(dataset: str, num_hops: int):
    """Save the in-memory cache for (dataset, num_hops) with `torch.save`.

    The file is `../sampling_cache/{dataset}_{num_hops}.p`, relative to this module's
    directory. That is the same path `sample_k_hop_subgraph_cache` reads from.

    Raises:
        AssertionError: if no cache exists yet for (dataset, num_hops).
    """
    key = (dataset, num_hops)
    assert key in sampler_cache_store, "Cache must exist before being flushed."

    cache_name = os.path.join(os.path.dirname(__file__), f'../sampling_cache/{dataset}_{num_hops}.p')

    # Create the directory the file is actually written to (anchored at this module),
    # not a CWD-relative "../sampling_cache". The two differ whenever the process
    # runs from another directory.
    os.makedirs(os.path.dirname(cache_name), exist_ok=True)
    with open(cache_name, 'wb') as f:
        torch.save(sampler_cache_store[key], f)


def sample_k_hop_subgraph(
    node_idx: Union[int, List[int], Tensor],
    num_hops: int,
    whole_adj: Tensor,
    relabel_nodes: bool = False,
    num_nodes: Optional[int] = None,
    flow: str = 'source_to_target',
    size: int = 100
):
    '''
    input: similar to k_hop_subgraph (check https://pytorch-geometric.readthedocs.io/en/1.5.0/modules/utils.html?highlight=subgraph#torch_geometric.utils.k_hop_subgraph)
      key difference:
        (1) we need preprocess function and achieve the adj (of type SparseTensor)
        (2) argument `size` is added to control the number of nodes in the sampled subgraph
      node_idx may be an int, a sequence of ints, or an integer tensor (0-d or 1-d).
      num_nodes is accepted for signature compatibility and is not used.
      num_hops == 0 returns just the seeds with an empty edge set.
    output:
      n_id: the nodes involved in the subgraph
      edge_index: the edge_index in the subgraph (indices into n_id if relabel_nodes,
        otherwise original node ids)
      edge_id: the id of edges in the subgraph
    '''
    # Normalise the seeds to a 1-D int64 tensor. An isinstance check against the
    # subscripted generic `List[int]` raises TypeError, and wrapping a list in an extra
    # [...] would create a 2-D tensor.
    node_idx = torch.as_tensor(node_idx, dtype=torch.long).reshape(-1)

    assert isinstance(whole_adj, SparseTensor)

    n_id = node_idx
    hop_edge_indices = []
    hop_edge_ids = []
    for _ in range(num_hops):
        # Sample up to `size` neighbours per current node. The concatenated local indices
        # below assume sample_adj keeps its input nodes first in the returned n_id
        # (torch_sparse behaviour — confirm), so earlier hops stay valid in the final n_id.
        adj, n_id = whole_adj.sample_adj(n_id, size, replace=False)
        row, col, e_id = adj.coo()
        if flow == 'source_to_target':
            edge_index = torch.stack([col, row], dim=0)
        else:
            edge_index = torch.stack([row, col], dim=0)
        hop_edge_indices.append(edge_index)
        hop_edge_ids.append(e_id)

    if hop_edge_indices:
        edge_index = torch.cat(hop_edge_indices, dim=1)
        e_id = torch.cat(hop_edge_ids)
    else:
        # num_hops == 0: no edges were sampled.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        e_id = torch.empty((0,), dtype=torch.long)

    return n_id, edge_index if relabel_nodes else n_id[edge_index], e_id


class NeighborSampler:
    """k-hop neighbor sampler whose results are memoised per (dataset, num_hops).

    Note: this class shadows `torch_geometric.loader.NeighborSampler`, which is
    imported at the top of this module.
    """

    def __init__(
        self,
        dataset : str, # cache key
        graph, # pyg graph
        num_hops: int,
        relabel_nodes: bool = False,
        flow: str = 'source_to_target',
        size: int = 100,
    ):
        self.dataset = dataset
        self.graph = graph
        self.num_hops = num_hops
        self.relabel_nodes = relabel_nodes
        self.flow = flow
        self.size = size
        # Built lazily, only when a query misses the cache.
        self.whole_adj = None

    def sample_node(self, node_idx):
        """Return `(n_id, edge_index, e_id)` for node_idx, preferring cached results.

        On a cache miss the adjacency is preprocessed once, and the sample is drawn
        and stored in the cache.
        """
        try:
            return sample_k_hop_subgraph_cache(
                self.dataset,
                node_idx,
                self.num_hops,
                self.whole_adj,
                self.relabel_nodes,
                flow=self.flow,
                size=self.size,
                cache_only=True,
            )
        except ValueError:
            # Cache miss (raised by cache_only=True); fall through and sample.
            # Other errors propagate instead of being swallowed.
            pass

        if self.whole_adj is None:
            self.whole_adj = preprocess(self.graph.edge_index, self.graph.num_nodes, self.flow)
        return sample_k_hop_subgraph_cache(
            self.dataset,
            node_idx,
            self.num_hops,
            self.whole_adj,
            self.relabel_nodes,
            flow=self.flow,
            size=self.size,
            cache_only=False,
        )

    def flush_cache(self):
        """Save this sampler's (dataset, num_hops) cache to disk."""
        flush_sample_cache(self.dataset, self.num_hops)

    def precompute_cache(self):
        """Sample every node of the graph once, then save the cache to disk."""
        # Just need to do this once before experiment training.
        print("Precomputing neighbor sampling cache.")
        print(f"Creating neighbor sampling cache for {self.dataset}")
        for node in trange(self.graph.num_nodes):
            self.sample_node(node)

        self.flush_cache()


class NeighborSamplerNoCache:
    """Samples k-hop subgraphs directly from a preprocessed adjacency; results are not cached."""

    # Class-level default, so subclasses whose __init__ does not set `flow` still
    # sample with the default orientation.
    flow: str = 'source_to_target'

    def __init__(
        self,
        graph, # pyg graph
        num_hops: int,
        relabel_nodes: bool = False,
        flow: str = 'source_to_target',
        size: int = 100,
    ):
        self.num_hops = num_hops
        self.relabel_nodes = relabel_nodes
        # Stored so sampling uses the same orientation the adjacency was built with.
        self.flow = flow
        self.size = size
        self.whole_adj = preprocess(graph.edge_index, graph.num_nodes, flow)
        # Move the adjacency storage into shared memory (presumably so worker
        # processes can reuse it without copying — confirm).
        self.whole_adj.share_memory_()

    def sample_node(self, node_idx):
        """Return `(n_id, edge_index, e_id)` sampled around node_idx."""
        return sample_k_hop_subgraph(
            node_idx,
            num_hops=self.num_hops,
            whole_adj=self.whole_adj,
            relabel_nodes=self.relabel_nodes,
            flow=self.flow,
            size=self.size,
        )


class NeighborSamplerCacheAdj(NeighborSamplerNoCache):
    """Like NeighborSamplerNoCache, but persists the preprocessed adjacency at `cache_path`.

    The base __init__ is intentionally not called: it always runs `preprocess`, which
    this class avoids when a saved adjacency already exists.
    """

    def __init__(
        self,
        cache_path,
        graph, # pyg graph
        num_hops: int,
        relabel_nodes: bool = False,
        flow: str = 'source_to_target',
        size: int = 100,
    ):
        self.num_hops = num_hops
        self.relabel_nodes = relabel_nodes
        # Stored so sampling matches the orientation requested by the caller.
        self.flow = flow
        self.size = size
        if os.path.exists(cache_path):
            # NOTE(review): the saved file does not record `flow`. A file built with a
            # different flow is loaded as-is. torch.load unpickles the file, so only load
            # files you created. Newer torch versions that default to weights_only=True may
            # refuse a pickled SparseTensor — confirm for your torch version.
            print(f"Loading adjacent matrix for neighbor sampling from {cache_path}")
            self.whole_adj = torch.load(cache_path)
            print(f"Loaded adjacent matrix for neighbor sampling from {cache_path}")
        else:
            print("Preprocessing adjacent matrix for neighbor sampling")
            self.whole_adj = preprocess(graph.edge_index, graph.num_nodes, flow)
            print(f"Saving adjacent matrix for neighbor sampling to {cache_path}")
            # torch.save does not create missing parent directories.
            cache_dir = os.path.dirname(cache_path)
            if cache_dir:
                os.makedirs(cache_dir, exist_ok=True)
            torch.save(self.whole_adj, cache_path)
            print(f"Saved adjacent matrix for neighbor sampling to {cache_path}")
        self.whole_adj.share_memory_()
