from typing import NamedTuple, List, Tuple, Optional

import time
import code # for debugging

from loguru import logger

import torch
import numpy as np
from torch import Tensor
from torch.utils.data import DataLoader
from torch_sparse import SparseTensor
from torch_geometric.data import Data

from utils.compiler import load_grinnder_ext
from utils.adjacency_matrix import AdjacencyMatrixWithOffloader, AdjacencyMatrixWithGDS, AdjacencyMatrixWithSimpleGDS
from tensornvme import DiskOffloader
from utils.debug import get_memory_used
from utils.gdstensor import GDSTensor

# Load the GriNNder extension once, at import time, and expose the entry
# points used in this module under short aliases.
grinnder_ext = load_grinnder_ext()

relabel_fn = grinnder_ext.generate_contiguous_heterograph  # subgraph relabelling
mask_gen_fn, cache_mask_gen_fn = (
    grinnder_ext.gen_reuse_mask,  # reuse masks between two batches
    grinnder_ext.gen_cache_mask,  # masks for cached (preloaded) nodes
)


class MasksPair(NamedTuple):
    """A bundle of index masks together with the offloader meant to back them.

    NOTE(review): field meanings are inferred from the commented-out call
    sites in this module, where ``src`` receives the *src* masks and ``my``
    the *dst*/*my* masks returned by the extension's mask generators.
    The transfer hooks below are currently placeholders that do nothing.
    """
    # Source-side masks (presumably one tensor per partition) or None.
    src: List[Tensor] = None
    # Masks on this batch's own side, or None.
    my: List[Tensor] = None
    # Offloader intended to move these masks to/from storage, or None.
    offloader: DiskOffloader | None = None

    def storage_to_gpu(self, async_: bool = True) -> None:
        """Bring the masks from storage to the GPU (placeholder: no-op)."""

    def gpu_to_storage(self, async_: bool = True) -> None:
        """Send the masks from the GPU to storage (placeholder: no-op)."""

class SubData(NamedTuple):
    """One pre-processed mini-batch: a partition plus its boundary nodes."""
    # Per-batch graph object (holds ``adj_t`` and node-level attributes).
    data: Data
    # Number of in-batch nodes; they occupy the first ``batch_size`` slots of ``n_id``.
    batch_size: int
    n_id: Tensor  # The indices of mini-batched nodes (in-batch first, then boundary)
    boundaries: List  # One entry per partition: boundary vertices taken from
                      # that partition (``None`` for the batch's own slot)
    # for caching mechanisms; left as ``None`` when not set up
    reuse_masks: MasksPair = None
    cache_masks: MasksPair = None
    storage_masks: MasksPair = None

    def _present_masks(self):
        """Yield the mask bundles that are set, skipping ``None`` fields."""
        for masks in (self.reuse_masks, self.cache_masks, self.storage_masks):
            if masks is not None:
                yield masks

    def to(self, *args, **kwargs) -> 'SubData':
        """Return a copy with ``data`` and the boundary tensors moved via ``.to``.

        ``Tensor.to`` is not in-place, so the converted boundary tensors are
        collected into a new list; ``None`` placeholders are preserved.
        ``n_id`` and the mask bundles are carried over unchanged.
        """
        boundaries = [
            None if boundary is None else boundary.to(*args, **kwargs)
            for boundary in self.boundaries
        ]
        return self._replace(data=self.data.to(*args, **kwargs),
                             boundaries=boundaries)

    def masks_storage_to_gpu(self, async_: bool = True) -> None:
        """Call ``storage_to_gpu`` on every mask bundle that is set."""
        for masks in self._present_masks():
            masks.storage_to_gpu(async_)

    def masks_gpu_to_storage(self, async_: bool = True) -> None:
        """Call ``gpu_to_storage`` on every mask bundle that is set."""
        for masks in self._present_masks():
            masks.gpu_to_storage(async_)

    def synchronize(self) -> None:
        """Call ``synchronize()`` on the offloader of every set mask bundle."""
        for masks in self._present_masks():
            if masks.offloader is not None:
                masks.offloader.synchronize()



class SubgraphLoader(DataLoader):
    r"""A simple subgraph loader that, given a pre-partitioned :obj:`data` object,
    generates subgraphs from mini-batches in :obj:`ptr` (including their 1-hop
    neighbors).

    Every subgraph is built eagerly in ``__init__`` through
    :meth:`compute_subgraph`; the resulting list of :class:`SubData` is then
    served in a fixed order, one item per iteration step.

    Args:
        data: the full graph. Read here: ``num_nodes``, ``adj_t`` (via
            ``.csr()``), ``x`` and any other node-level tensor attributes.
        ptr: partition pointer; partition ``i`` is the ``i``-th consecutive
            chunk of node ids, of size ``ptr[i + 1] - ptr[i]``.
        use_cache: if ``True``, count in ``visiting_table_score`` how often
            each node appears as a boundary node of some batch.
        cache_size: stored for the (currently commented-out) caching path.
        bipartite: forwarded to the relabelling extension.
        log: log pre-processing progress and timings.
        shuffle: must be falsy; the batch order is fixed.
        storage_offload: stored on the loader; requires ``storage_path``.
        storage_path: directory handed to the storage-backed tensors.
        optimize_dataloader: keep in-batch features as a ``GDSTensor`` moved
            to storage instead of a host tensor; requires ``storage_path``.
        **kwargs: forwarded to both the pre-processing ``DataLoader`` and
            ``super().__init__``.
    """
    def __init__(self, data: Data, ptr: Tensor,
                 use_cache: bool = False, cache_size: Optional[int] = None,
                 bipartite: bool = True, log: bool = True,
                 shuffle: Optional[bool] = False,
                 storage_offload: bool = True,
                 storage_path: Optional[str] = None,
                 optimize_dataloader: bool = True,
                 **kwargs):

        if storage_offload or optimize_dataloader:
            assert storage_path is not None
        # ``shuffle`` is Optional[bool]: ``None`` means "no shuffle" as well.
        assert not shuffle, 'GriNNder is a exact framework, which does not use shuffle.'

        # Thread count passed to the extension's mask generators (only used
        # by the commented-out caching path below).
        self.num_threads = 32

        self.data = data
        self.ptr = ptr
        self.bipartite = bipartite
        self.log = log

        self.storage_offload = storage_offload
        self.storage_path = storage_path
        self.optimize_dataloader = optimize_dataloader

        self.use_cache = use_cache
        self.cache_size = cache_size
        if use_cache:
            logger.info('Use cache... so profiling the visiting table...')
            # for caching table
            self.visiting_table_score = torch.zeros(data.num_nodes, dtype=torch.int64)
            self.visiting_table_id = None
            # after filling caching table...
            # we need to sort the table
            # then we get the topk element id
            # from that information, we figure out the indices that
            # each SubData do not need to upload
            # because it is preloaded to GPU memory

        # Split the node ids into consecutive per-partition chunks.
        n_id = torch.arange(data.num_nodes)
        batches = n_id.split((ptr[1:] - ptr[:-1]).tolist())
        self.numel_per_batch = [batch.numel() for batch in batches]
        self.max_indices = [batch[-1].item() for batch in batches]
        self.min_indices = [batch[0].item() for batch in batches]
        batches = list(enumerate(batches))  # (partition_id, node_ids) pairs

        # preprocess
        if log:
            t = time.perf_counter()
            logger.info('Pre-processing subgraphs...')

        data_list = list(  # igb-med: 45.8 GB -> 188 GB (X) 87.8 GB (X) 52.08 GB
            DataLoader(batches, collate_fn=self.compute_subgraph,
                       batch_size=1, **kwargs))

        if log:
            logger.info(f'Pre-processing time: {time.perf_counter() - t} s')
            t = time.perf_counter()

        # Per-batch bookkeeping is derived here, in this process, from the
        # returned SubData objects: with ``num_workers > 0`` in kwargs,
        # compute_subgraph runs in worker processes and any attribute it
        # mutated on ``self`` would be lost.
        self.numall_per_batch = [len(subdata.n_id) for subdata in data_list]
        for subdata in data_list:
            self._boundary_check(subdata.batch_size, subdata.n_id, subdata.boundaries)
            if use_cache:
                # our visiting-based caching counts the boundary cases
                other_n_id = subdata.n_id[subdata.batch_size:]
                assert torch.all(other_n_id < self.data.num_nodes), 'invalid node id'
                self.visiting_table_score[other_n_id] += 1

        # NOTE: work-in-progress caching / reordering path, currently disabled.
        # if use_cache:
        #     print('[GriNNder] >>> Caching element selection finished...')
        #     self.visiting_table_score, self.visiting_table_id \
        #                             = self.visiting_table_score.sort(descending=True)
        #     self.cache_ids = self.visiting_table_id[:cache_size].sort()[0] if use_cache else None
        #     self.cache_id_per_part = self._cache_id_per_partition(self.cache_ids)

        #     print('[GriNNder] >>> Calculate the heuristic best ordering...')
        #     new_ordering = self._get_partition_ordering(data_list)
        #     assert len(set(new_ordering)) == len(data_list), 'invalid reordering...'
        #     print(f'[GriNNder] >>> new order: {new_ordering}')
        #     data_list: list[SubData] = [data_list[i] for i in new_ordering]
        #     # we must reorder numel_per_batch for functionality
        #     self.numel_per_batch = [self.numel_per_batch[i] for i in new_ordering]

        #     for i in range(len(data_list)):
        #         new_nid = self._reorder_nids(i, data_list[i].n_id, new_ordering, data_list[i].boundaries)
        #         data_list[i] =  SubData(data_list[i].data,
        #                                 data_list[i].batch_size,
        #                                 new_nid,
        #                                 [data_list[i].boundaries[j] for j in new_ordering])

        #     self.numall_per_batch = [self.numall_per_batch[i] for i in new_ordering]
        #     self.max_indices = [self.max_indices[i] for i in new_ordering]
        #     self.min_indices = [self.min_indices[i] for i in new_ordering]
        #     self.cache_id_per_part = [self.cache_id_per_part[i] for i in new_ordering]
        #     print('[GriNNder] >>> Reorder the DataLoader... Done!')

        #     for i, _ in enumerate(data_list):
        #         if i == 0:
        #             continue
        #         data_list[i-1].boundaries[i-1] = torch.empty(0, dtype=torch.int64)
        #         data_list[i].boundaries[i] = torch.tensor(0, dtype=torch.int64)
        #         reuse_src_masks, reuse_dst_masks \
        #             = mask_gen_fn(i-1, data_list[i-1].n_id, data_list[i-1].boundaries,
        #                         i, data_list[i].n_id, data_list[i].boundaries,
        #                         self.num_threads)
        #         cache_src_masks, cache_my_masks \
        #             = cache_mask_gen_fn(self.cache_id_per_part,
        #                                 i, data_list[i].n_id, data_list[i].boundaries,
        #                                 reuse_dst_masks, self.num_threads)

        #         # mask size check
        #         mask_sum = 0
        #         for j in range(len(reuse_dst_masks)):
        #             mask_sum += len(reuse_dst_masks[j])
        #             mask_sum += len(cache_my_masks[j])
        #         assert mask_sum == len(data_list[i].n_id), 'we assume that we have enough cache size to handle a single layer' 

        #         data_list[i-1].boundaries[i-1] = None
        #         data_list[i].boundaries[i] = None
        #         data_list[i] = SubData(data_list[i].data, data_list[i].batch_size,
        #                                data_list[i].n_id, data_list[i].boundaries,
        #                                MasksPair(reuse_src_masks, reuse_dst_masks, None),
        #                                MasksPair(cache_src_masks, cache_my_masks, None),
        #                                MasksPair(None, None, None))
        #         data_list[i].reuse_masks.offloader = DiskOffloader(f'{self.storage_path}/{id(data_list[i].reuse_masks)}', backend='aio')
        #         data_list[i].cache_masks.offloader = DiskOffloader(f'{self.storage_path}/{id(data_list[i].cache_masks)}', backend='aio')
        #         data_list[i].storage_masks.offloader = DiskOffloader(f'{self.storage_path}/{id(data_list[i].storage_masks)}', backend='aio')
        #         data_list[i].masks_gpu_to_storage(async_=True)

        super().__init__(data_list, batch_size=1,  # GriNNder do not need shuffle
                         collate_fn=lambda x: x[0], shuffle=shuffle, **kwargs)
        # The full graph is no longer needed once all subgraphs are built.
        del self.data

        if log:
            logger.info(f'Done! [{time.perf_counter() - t:.2f}s]')

    def _boundary_check(self, batch_size: int, n_id: Tensor, boundaries: List):
        """Assert that in-batch nodes plus all boundary nodes account for ``n_id``.

        ``None`` entries (the batch's own slot) contribute nothing.
        """
        total_size = batch_size + sum(len(boundary) for boundary in boundaries
                                      if boundary is not None)
        assert total_size == len(n_id), 'invalid boundary size'

    def _cache_id_per_partition(self, nids: Tensor):
        """Split global node ids ``nids`` by owning partition.

        Entry ``i`` of the result holds the elements of ``nids`` in
        ``(max_indices[i - 1], max_indices[i]]`` (with ``-1`` as the lower
        bound for partition 0), in their original order.
        """
        lower_bounds = [-1] + self.max_indices[:-1]
        cache_id_per_partition = [
            nids[(low < nids) & (nids <= high)]
            for low, high in zip(lower_bounds, self.max_indices)
        ]
        total_cache_size = sum(len(cache_ids) for cache_ids in cache_id_per_partition)
        assert total_cache_size == len(nids), 'invalid cache size'
        return cache_id_per_partition

    def _reorder_nids(self, changed_id: int, nids: Tensor, new_order: List[int],
                      orig_boundaries: List[Tensor]):
        """Permute the boundary segments of ``nids`` to follow ``new_order``.

        ``nids`` is laid out as ``[in-batch | boundary of partition 0 | ...]``
        in the original partition order, where ``None`` boundaries take no
        space. The in-batch prefix is kept, then the boundary segments are
        appended in the order ``new_order`` lists them. Position
        ``changed_id`` of ``new_order`` is skipped; in the commented-out
        caller it refers to this batch itself, whose boundary is ``None``.
        ``self.numel_per_batch`` is indexed by new position, so it must
        already be reordered (as the commented-out caller does).
        """
        batch_size = self.numel_per_batch[changed_id]

        new_nids = torch.empty_like(nids)
        new_nids[:batch_size] = nids[:batch_size]

        # Locate each partition's boundary segment in the original layout.
        orig_start_offsets = []
        orig_end_offsets = []
        cur_offset = batch_size
        for boundary in orig_boundaries:
            if boundary is None:
                orig_start_offsets.append(0)
                orig_end_offsets.append(batch_size)
                continue
            orig_start_offsets.append(cur_offset)
            cur_offset += len(boundary)
            orig_end_offsets.append(cur_offset)

        # Copy the segments back in the new partition order.
        cur_offset = batch_size
        for i, new_i in enumerate(new_order):
            if i == changed_id:
                continue
            n_cur_el = len(orig_boundaries[new_i])
            new_nids[cur_offset:cur_offset + n_cur_el] \
                = nids[orig_start_offsets[new_i]:orig_end_offsets[new_i]]
            cur_offset += n_cur_el
        assert cur_offset == new_nids.numel(), 'invalid numel'
        return new_nids

    def _get_partition_ordering(self, data_list: List[SubData] = None):
        """Compute a greedy visiting order for the partitions.

        Starts at the partition with the smallest expansion ratio
        (``numall / numel``). Then, among the partitions not yet chosen, it
        repeatedly picks the one whose subgraph takes the largest share of
        its nodes (``len(boundaries[cur]) / numall``) from the partition
        chosen last. Ties go to the lowest index (``np.argmax``).

        Returns:
            A list of plain ``int`` partition indices, one per partition.
        """
        expansion = np.array([total_numel / intra_numel for intra_numel, total_numel
                              in zip(self.numel_per_batch, self.numall_per_batch)])
        cur_part = int(np.argmin(expansion))
        new_ordering = [cur_part]
        visited = {cur_part}
        for _ in range(len(self.numel_per_batch) - 1):
            # Already-chosen partitions get -1 so argmax never picks them again.
            boundary_portion = np.array([
                -1 if i in visited
                else len(subdata.boundaries[cur_part]) / self.numall_per_batch[i]
                for i, subdata in enumerate(data_list)
            ])
            cur_part = int(np.argmax(boundary_portion))
            new_ordering.append(cur_part)
            visited.add(cur_part)
        return new_ordering

    def get_expansion_ratio(self) -> float:
        """Return the mean over batches of ``numall / numel`` (subgraph size
        divided by in-batch size)."""
        ratios = [total_numel / intra_numel for intra_numel, total_numel
                  in zip(self.numel_per_batch, self.numall_per_batch)]
        return sum(ratios) / len(self.numel_per_batch)

    def print_reusability(self) -> None:
        """Log mean and std over batches of the peak per-partition boundary share.

        For each batch the share is ``len(boundary) / numall * 100`` per
        partition (``None`` counts as 0), and the maximum is taken.
        """
        reusability_list = []
        # Iterate the stored SubData list directly rather than the DataLoader,
        # which would start worker processes when num_workers > 0.
        for i, subdata in enumerate(self.dataset):
            boundary_length = [0 if boundary is None else len(boundary)
                               for boundary in subdata.boundaries]
            reusability_list.append(
                np.max(np.array(boundary_length) / self.numall_per_batch[i] * 100))
        logger.info(f'Avg. Reusability: {np.mean(reusability_list)}, Std.: {np.std(reusability_list)}')

    def compute_subgraph(self, batches: List[Tuple[int, Tensor]]) -> SubData:
        """Build the :class:`SubData` for one partition (used as ``collate_fn``).

        ``batches`` is a list of ``(partition_id, node_ids)`` pairs (a single
        pair with ``batch_size=1``); the first id names the partition. This
        method does not mutate the loader: per-batch bookkeeping is derived
        from the returned objects in ``__init__`` so that it also works when
        ``collate_fn`` runs in DataLoader worker processes.
        """
        batch_ids, n_ids = zip(*batches)
        n_id = torch.cat(n_ids, dim=0)
        batch_id = batch_ids[0]

        # Number of in-batch nodes. The code below relies on relabel_fn
        # keeping them as the first ``batch_size`` entries of the new n_id.
        batch_size = n_id.numel()

        # 1. Retrieve the original adjacency in CSR format:
        rowptr, col, value = self.data.adj_t.csr()

        # 2. Relabel nodes so subgraph is contiguous:
        rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id, self.bipartite)

        # 3. Build boundary info: for every partition, the boundary nodes it
        #    owns, expressed relative to that partition's first node id.
        boundaries = [None] * len(self.max_indices)
        other_n_id = n_id[batch_size:]
        max_index_list = [-1] + self.max_indices
        for i in range(len(self.max_indices)):
            boundaries[i] = other_n_id[
                (max_index_list[i] < other_n_id) & (other_n_id <= max_index_list[i + 1])
            ] - self.min_indices[i]
        assert len(boundaries[batch_id]) == 0
        boundaries[batch_id] = None

        # 4. Create a subgraph adjacency using AdjacencyMatrixWithSimpleGDS.
        #    NOTE(review): ``self.storage_offload`` is not consulted here; a
        #    prior in-code note said storage offload is now the default.
        adj_t = AdjacencyMatrixWithSimpleGDS(
            rowptr=rowptr,
            col=col,
            value=value,
            sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
            is_sorted=True,
            storage_path=f'{self.storage_path}',
        )

        # 5. Create the sub-data object:
        data = self.data.__class__(adj_t=adj_t)
        for k, v in self.data:
            # Copy only those features that are node-level and not 'x':
            if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes and k != 'x':
                data[k] = v.index_select(0, n_id)

        # 6. Build the SubData struct:
        subdata = SubData(data, batch_size, n_id, boundaries)

        # 7. Keep in-batch node features, either offloaded to storage as a
        #    GDSTensor or as a plain host tensor:
        if self.optimize_dataloader:
            subdata.data.true_x = GDSTensor(
                self.data.x.index_select(0, n_id[:batch_size]),
                f'{self.storage_path}'
            )
            # Immediately move them to storage:
            subdata.data.true_x.to_inplace('storage')
            subdata.data.true_x.synchronize()
        else:
            subdata.data.true_x = self.data.x.index_select(0, n_id[:batch_size])

        return subdata


    # def compute_subgraph_old(self, batches: List[Tuple[int, Tensor]]) -> SubData:
    #     batch_ids, n_ids = zip(*batches)
    #     n_id = torch.cat(n_ids, dim=0)
    #     batch_id = batch_ids[0]

    #     batch_size = n_id.numel()

    #     # get the original graph's csr
    #     rowptr, col, value = self.data.adj_t.csr()
    #     # logger.info(f'Before relabel')
    #     rowptr, col, value, n_id = relabel_fn(rowptr, col, value, n_id,
    #                                           self.bipartite)
    #     # logger.info(f'After relabel')

    #     boundaries = [None] * len(self.max_indices)
    #     # batch_size elements are in-batch vertices
    #     other_n_id = n_id[batch_size:]
    #     max_index_list = [-1] + self.max_indices
    #     for i in range(len(self.max_indices)):
    #         boundaries[i] = other_n_id[(max_index_list[i] < other_n_id) & (other_n_id <= max_index_list[i + 1])] - self.min_indices[i]
    #     assert len(boundaries[batch_id]) == 0
    #     boundaries[batch_id] = None

    #     if self.use_cache:
    #         # for caching table
    #         # our visiting-based caching counts the boundary cases
    #         assert torch.all(other_n_id < self.data.num_nodes), 'invalid node id'
    #         self.visiting_table_score[other_n_id] += 1

    #     # now we get the extracted bipartite graph
    #     adj_t = AdjacencyMatrixWithOffloader(
    #                 rowptr=rowptr, col=col, value=value,
    #                 sparse_sizes=(rowptr.numel() - 1, n_id.numel()),
    #                 is_sorted=True,
    #             )
    #     if self.storage_offload:
    #         adj_t.offloader = DiskOffloader(f'{self.storage_path}/{id(adj_t)}', backend='uring')
    #     else:
    #         adj_t.offloader = None

    #     data = self.data.__class__(adj_t=adj_t)
    #     for k, v in self.data:
    #         if isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:
    #             if k != 'x':  # data.x is calculated after offloading
    #                 data[k] = v.index_select(0, n_id)

    #     self.numall_per_batch.append(len(n_id))
    #     subdata = SubData(data, batch_size, n_id, boundaries)
    #     if self.storage_offload:
    #         adj_t.gpu_to_storage(async_=True)
    #     if self.optimize_dataloader:
    #         subdata.data.true_x = GDSTensor(self.data.x.index_select(0, n_id[:batch_size]), f'{self.storage_path}')
    #         subdata.data.true_x.to_inplace('storage')
    #         subdata.data.true_x.synchronize()
    #     else:
    #         subdata.data.true_x = self.data.x.index_select(0, n_id[:batch_size])

    #     return subdata

    def get_numel(self) -> List:
        """ return numel size (V size)"""
        return self.numel_per_batch

    def get_numall(self) -> List:
        """ return numall size (U size)"""
        return self.numall_per_batch

    def __repr__(self):
        return f'{self.__class__.__name__}()'