import glob
from pathlib import Path
from typing import Dict, List

import torch
from torch import Tensor
from torch_geometric.data import Data, InMemoryDataset


class GinkgoDataset(InMemoryDataset):
    """Loads simulated jets from root folder.

    Every ``*.pt`` file found directly in ``root`` is loaded as a raw graph dictionary,
    enhanced with candidate edges and target labels, and stored as a `Data` object. The
    collated result is written to the processed file and loaded back into memory.
    """

    def __init__(self, root: Path) -> None:
        """Initialize `GinkgoDataset`.

        :param root: Path to directory in which simulated jets are stored.
        """
        self.root = root
        super().__init__(self.root, transform=None, pre_transform=None)
        # NOTE(review): `torch.load` unpickles the file, so `root` must only ever point at
        # trusted data.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        """Return the list of raw file names for preprocessing. We do not need preprocessing so we
        return an empty list.

        :return: An empty list as this dataset does not use raw file names.
        """
        return []

    @property
    def processed_file_names(self) -> List[str]:
        """Return the list of processed file names. If the files do not exist triggers the
        'process' method.

        :return: A list containing the processed file name "data.pt".
        """
        return ["data.pt"]

    def process(self) -> None:
        """Iterate over all raw graphs in root folder and construct data objects.

        Files are visited in sorted path order so the resulting dataset order is
        deterministic. The graphs are then collated and stored in the file specified in
        `processed_file_names`.
        """
        # Escape the directory part so glob metacharacters ("[", "*", "?") occurring in the
        # root path are matched literally instead of being interpreted as a pattern.
        pattern = f"{glob.escape(str(self.root))}/*.pt"
        data_list = [
            self.make_data_from_raw_graph(torch.load(path))
            for path in sorted(glob.glob(pattern))
        ]
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    @staticmethod
    def make_data_from_raw_graph(raw_graph: Dict[str, Tensor]) -> Data:
        """Convert a raw graph dictionary to a `Data` object.

        Node indices ``[0, n_leaves)`` are treated as leaves and ``[n_leaves, n_nodes)`` as
        parents, with ``n_leaves = (n_nodes + 1) // 2``. Candidate edges run from every leaf
        to every parent, and from every parent to each parent with a higher index. An edge
        ``(i, j)`` is labelled 1 in ``edge_attr_target`` when ``raw_graph["A"][i] == j`` and
        0 otherwise.

        :param raw_graph: The raw graph data loaded from a .pt file. ``"X"`` must be a 2-D
            node tensor; ``"A"`` must hold one target node index per node (presumably the
            index of the node's true parent -- confirm against the data generator).
        :return: A `Data` object containing the graph data and additional attributes:
            ``edge_index``, all-zero ``edge_attr`` and ``t``, ``edge_attr_target``, the
            per-node outgoing candidate edge count ``n_parents`` and ``parent_mask``.
        """
        x = raw_graph["X"]
        # Flattened so that a column vector of shape (n, 1) is looked up like a 1-D tensor.
        parent_of = raw_graph["A"].reshape(-1)

        n_nodes, _ = x.shape
        n_leaves = (n_nodes + 1) // 2
        n_parents = n_nodes // 2
        leaf_indices = torch.arange(n_leaves)
        parent_indices = torch.arange(n_leaves, n_nodes)

        # Leaf -> parent candidates: every leaf paired with every parent.
        leaf_edge_index = torch.cartesian_prod(leaf_indices, parent_indices).T
        # Parent -> parent candidates: strictly upper-triangular pairs, shifted from local
        # parent numbering to global node indices.
        parent_edge_index = torch.triu_indices(n_parents, n_parents, 1) + n_leaves
        edge_index = torch.cat([leaf_edge_index, parent_edge_index], dim=1)

        _, n_edges = edge_index.shape
        edge_attr = torch.zeros(n_edges)

        # Every edge source lies in [0, n_nodes - 2] (the last node has no outgoing edge), so
        # one indexed comparison labels exactly the edges a per-node scan would label.
        is_target_edge = parent_of[edge_index[0]] == edge_index[1]
        edge_attr_target = is_target_edge.to(edge_attr.dtype)

        # Number of outgoing candidate edges per node: each leaf reaches all parents, while
        # the k-th parent only reaches the parents that come after it.
        n_candidate_parents = torch.tensor(
            [n_parents] * n_leaves + list(range(n_parents - 1, -1, -1)),
            dtype=torch.long,
        )

        parent_mask = torch.zeros(n_nodes, dtype=torch.bool)
        parent_mask[n_leaves:] = True

        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            edge_attr_target=edge_attr_target,
            n_parents=n_candidate_parents,
            parent_mask=parent_mask,
            t=torch.zeros(n_nodes),
        )
