import os
import pathlib
import pickle as pkl

from networkx import to_numpy_array

import torch
import torch_geometric.utils
from torch_geometric.data import InMemoryDataset, download_url

from src.datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos


class SpectreGraphDataset(InMemoryDataset):
    """One split ("train", "val" or "test") of a graph dataset kept in memory.

    The raw download is a pickled mapping with ``"train"``, ``"val"`` and
    ``"test"`` entries, each an iterable of graphs accepted by networkx's
    ``to_numpy_array``. ``download`` converts every graph to a dense adjacency
    matrix and stores one raw ``.pt`` file per split; ``process`` turns the
    raw file of the requested split into collated ``Data`` objects.

    Args:
        dataset_name: Name of the dataset; only ``"planar"`` is recognised
            by ``download``.
        split: Which split this instance exposes: ``"train"``, ``"val"`` or
            ``"test"``.
        root: Root directory handed to ``InMemoryDataset``.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Optional callable applied to each ``Data`` object in
            ``process`` before it is stored.
        pre_filter: Optional predicate; graphs for which it returns a falsy
            value are dropped in ``process``.
    """

    # Position of each split's file inside ``raw_file_names`` / ``raw_paths``.
    # Shared by ``download`` and ``process`` so both agree on the layout.
    _SPLIT_TO_RAW_INDEX = {"train": 0, "val": 1, "test": 2}

    def __init__(
        self,
        dataset_name,
        split,
        root,
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        self.planar_file = "planar_64_200.pt"
        # These attributes must exist before the base constructor runs: the
        # file-name properties and download()/process() below read them.
        self.dataset_name = dataset_name
        self.split = split
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0], weights_only=False)
        # ``n_nodes`` holds one entry per graph, so its length is the graph count.
        self.num_graphs = len(self.data.n_nodes)

    @property
    def raw_file_names(self):
        """File names of the three per-split raw files, in train/val/test order."""
        return ["train.pt", "val.pt", "test.pt"]

    @property
    def processed_file_names(self):
        """Name of the single processed file belonging to this instance's split."""
        return [self.split + ".pt"]

    @staticmethod
    def _adjacency_matrices(graphs):
        """Return one dense ``torch.Tensor`` adjacency matrix per graph.

        The diagonal of every matrix is zeroed, i.e. self-loops are removed.
        """
        return [
            torch.Tensor(to_numpy_array(graph)).fill_diagonal_(0) for graph in graphs
        ]

    def download(self):
        """Download the raw dataset and write one adjacency-list file per split.

        Raises:
            ValueError: If ``dataset_name`` is not a known dataset.
        """
        if self.dataset_name == "planar":
            raw_url = "..."
        else:
            raise ValueError(f"Unknown dataset {self.dataset_name}")
        file_path = download_url(raw_url, self.raw_dir)

        # NOTE(security): unpickling can execute arbitrary code. This is only
        # acceptable while ``raw_url`` points at a trusted source.
        with open(file_path, "rb") as f:
            dataset = pkl.load(f)

        # Convert every split before writing anything, so a malformed download
        # fails without leaving a partial set of raw files behind.
        adjacency = {
            split: self._adjacency_matrices(dataset[split])
            for split in self._SPLIT_TO_RAW_INDEX
        }
        for split, raw_index in self._SPLIT_TO_RAW_INDEX.items():
            torch.save(adjacency[split], self.raw_paths[raw_index])

    def process(self):
        """Build ``Data`` objects for this split and save them collated.

        Every graph gets a constant node feature of 1 (shape ``(n, 1)``), an
        empty target ``y`` of shape ``(1, 0)``, two-column edge attributes
        whose second column is 1 for every edge, and ``n_nodes`` holding the
        node count as a length-1 long tensor.
        """
        raw_index = self._SPLIT_TO_RAW_INDEX[self.split]
        raw_dataset = torch.load(self.raw_paths[raw_index])

        data_list = []
        for adj in raw_dataset:
            n = adj.shape[-1]
            X = torch.ones(n, 1, dtype=torch.float)
            y = torch.zeros([1, 0]).float()
            edge_index, _ = torch_geometric.utils.dense_to_sparse(adj)
            # Presumably a one-hot edge type with class 0 meaning "no edge";
            # every edge present in the graph is marked in column 1.
            edge_attr = torch.zeros(edge_index.shape[-1], 2, dtype=torch.float)
            edge_attr[:, 1] = 1
            num_nodes = n * torch.ones(1, dtype=torch.long)
            data = torch_geometric.data.Data(
                x=X, edge_index=edge_index, edge_attr=edge_attr, y=y, n_nodes=num_nodes
            )

            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)

            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])


class SpectreGraphDataModule(AbstractDataModule):
    """Data module that builds the train/val/test ``SpectreGraphDataset`` splits.

    Args:
        cfg: Configuration object; ``cfg.dataset.datadir`` and
            ``cfg.dataset.name`` are read here.
        n_graphs: Not used by this class.
    """

    def __init__(self, cfg, n_graphs=200):
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir

        # The data directory is resolved relative to ``parents[2]`` of this
        # file's real path.
        this_file = pathlib.Path(os.path.realpath(__file__))
        root_path = os.path.join(this_file.parents[2], self.datadir)

        # Build the splits in train, val, test order.
        datasets = {
            split: SpectreGraphDataset(
                dataset_name=self.cfg.dataset.name, split=split, root=root_path
            )
            for split in ("train", "val", "test")
        }

        train_len, val_len, test_len = (
            len(datasets[split].data.n_nodes) for split in ("train", "val", "test")
        )
        print(f"Dataset sizes: train {train_len}, val {val_len}, test {test_len}")

        super().__init__(cfg, datasets)
        # Indexing the module is delegated to the training dataset.
        self.inner = self.train_dataset

    def __getitem__(self, item):
        """Return ``self.inner[item]``, i.e. index into the training dataset."""
        return self.inner[item]


class SpectreDatasetInfos(AbstractDatasetInfos):
    """Summary statistics of a Spectre dataset, gathered from its data module.

    Args:
        datamodule: Data module providing ``inner.dataset_name``,
            ``node_counts()``, ``node_types()`` and ``edge_counts()``.
        dataset_config: Not used by this constructor.
    """

    def __init__(self, datamodule, dataset_config):
        self.datamodule = datamodule
        self.dataset_name = datamodule.inner.dataset_name

        # Collect the statistics from the data module, in this order.
        self.n_nodes = datamodule.node_counts()
        self.node_types = datamodule.node_types()
        # ``edge_types`` stores whatever ``edge_counts()`` returns.
        self.edge_types = datamodule.edge_counts()

        # The remaining setup is delegated to the base class.
        super().complete_infos(self.n_nodes, self.node_types)
