import os
import pathlib
import pickle

import networkx as nx
import torch
import torch_geometric.utils
from torch_geometric.data import InMemoryDataset, Data

from datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos


class MyTreeGraphDataset(InMemoryDataset):
    """In-memory dataset backed by a single ``<split>.pt`` file per split.

    The raw file is read from ``raw_dir`` and rewritten to ``processed_dir``
    as a ``(data, slices)`` pair, which is then loaded into ``self.data`` and
    ``self.slices``.

    Note: ``pre_transform`` and ``pre_filter`` are forwarded to the base class
    but are not applied anywhere in :meth:`process`.
    """

    def __init__(
        self,
        root,
        split='test',
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        """Store the split name, initialise the base class, load the data.

        Args:
            root: Dataset root directory handed to ``InMemoryDataset``.
            split: Name of the split; selects ``<split>.pt`` as the raw file
                and ``data_<split>.pt`` as the processed file.
            transform: Forwarded unchanged to the base class.
            pre_transform: Forwarded unchanged to the base class.
            pre_filter: Forwarded unchanged to the base class.
        """
        # The file-name properties below read ``self.split``, so it has to be
        # assigned before the base initialiser runs.
        self.split = split
        super().__init__(
            root,
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
        )
        # NOTE(review): torch.load unpickles the file; only use trusted data.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Return the one raw file expected for this split."""
        return ["{}.pt".format(self.split)]

    @property
    def processed_file_names(self):
        """Return the one processed file written for this split."""
        return ["data_{}.pt".format(self.split)]

    def download(self):
        """Do nothing; raw files are expected to be provided locally."""
        pass

    def _collate_batch(self, batch):
        """Split ``batch`` into single graphs and re-collate them.

        Args:
            batch: Object exposing ``to_data_list()``.

        Returns:
            The ``(data, slices)`` pair produced by ``self.collate``.
        """
        graphs = batch.to_data_list()
        data, slices = self.collate(graphs)
        print(f"Conversion complete. Number of graphs: {len(slices['x']) - 1}")
        return data, slices

    def process(self):
        """Convert the raw split file into a saved ``(data, slices)`` pair.

        Accepted raw layouts:

        * a 2-tuple ``(data, slices)`` with ``slices`` set: stored as-is;
        * a 2-tuple ``(data, None)`` where ``data`` has ``to_data_list``:
          re-collated;
        * a 2-tuple ``(data, None)`` without ``to_data_list``: stored as-is;
        * a bare object with ``to_data_list``: re-collated;
        * anything else: stored as ``(obj, None)``.

        Raises:
            FileNotFoundError: If the raw file for this split does not exist.
        """
        raw_path = os.path.join(self.raw_dir, f"{self.split}.pt")
        if not os.path.exists(raw_path):
            raise FileNotFoundError(f"[anonymized] {raw_path}，[anonymized]")

        # NOTE(review): torch.load unpickles the file; only use trusted data.
        loaded_data = torch.load(raw_path)
        is_pair = isinstance(loaded_data, tuple) and len(loaded_data) == 2

        if is_pair:
            data, slices = loaded_data
            if slices is not None:
                print(f"Data already in correct format with {len(slices['x']) - 1} graphs")
            elif hasattr(data, 'to_data_list'):
                print(f"Converting Batch data with {data.num_graphs} graphs...")
                data, slices = self._collate_batch(data)
            # Otherwise (no slices, not a batch) the pair is kept untouched.
        elif hasattr(loaded_data, 'to_data_list'):
            print(f"Converting single Batch data with {loaded_data.num_graphs} graphs...")
            data, slices = self._collate_batch(loaded_data)
        else:
            print("Unknown data format, using as-is")
            data, slices = loaded_data, None

        torch.save((data, slices), self.processed_paths[0])


class MyTreeGraphDataModule(AbstractDataModule):
    """Data module bundling the train/val/test ``MyTreeGraphDataset`` splits."""

    def __init__(self, cfg):
        """Build one dataset per split and hand them to the base data module.

        Args:
            cfg: Configuration object; ``cfg.dataset.datadir`` is the dataset
                directory, resolved relative to the third ancestor directory
                of this source file.
        """
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir

        # parents[2] is two levels above the directory containing this file.
        this_file = pathlib.Path(os.path.realpath(__file__))
        root_path = os.path.join(this_file.parents[2], self.datadir)

        # Construction order (train, val, test) matters only for the order in
        # which missing processed files get generated.
        datasets = {
            split_name: MyTreeGraphDataset(root=root_path, split=split_name)
            for split_name in ("train", "val", "test")
        }

        super().__init__(cfg, datasets)
        # presumably ``train_dataset`` is set by the base initialiser — verify
        # against AbstractDataModule.
        self.inner = self.train_dataset


class MyTreeDatasetInfos(AbstractDatasetInfos):
    """Dataset statistics for the ``my_tree`` dataset, read from a datamodule."""

    def __init__(self, datamodule, dataset_config):
        """Collect node/edge statistics and finalise the infos object.

        Args:
            datamodule: Object providing ``node_counts()``, ``node_types()``
                and ``edge_counts()``.
            dataset_config: Accepted for interface compatibility; not read
                here.
        """
        self.datamodule = datamodule
        self.dataset_name = "my_tree"

        source = self.datamodule
        self.n_nodes = source.node_counts()
        self.node_types = source.node_types()
        self.edge_types = source.edge_counts()

        # complete_infos is inherited; what it derives is not visible here.
        super().complete_infos(self.n_nodes, self.node_types)


def convert_nx_to_pyg_data(graph: nx.Graph) -> Data:
    """Convert a networkx graph into a PyG ``Data`` object.

    Node features are a single constant ``1`` per node, every edge gets the
    two-channel attribute ``[0, 1]``, ``y`` is an empty ``(1, 0)`` tensor and
    ``n_nodes`` holds the node count as a length-1 long tensor.

    Args:
        graph: Graph whose dense adjacency matrix defines the edges.

    Returns:
        The assembled ``Data`` object.
    """
    dense_adj = torch.as_tensor(nx.to_numpy_array(graph), dtype=torch.float)
    node_count = dense_adj.shape[-1]

    # Only the positions of the non-zero entries are kept; the adjacency
    # values themselves are discarded.
    edge_index, _ = torch_geometric.utils.dense_to_sparse(dense_adj)
    edge_count = edge_index.shape[-1]

    # Column 0 stays zero, column 1 is one for every edge.
    edge_attr = torch.cat(
        [
            torch.zeros(edge_count, 1, dtype=torch.float),
            torch.ones(edge_count, 1, dtype=torch.float),
        ],
        dim=1,
    )

    return Data(
        x=torch.ones(node_count, 1, dtype=torch.float),
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.zeros(1, 0, dtype=torch.float),
        n_nodes=torch.tensor([node_count], dtype=torch.long),
    )