import os
import pathlib
import torch
import torch_geometric.utils
from torch_geometric.data import InMemoryDataset, Data
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
import networkx as nx

from datasets.abstract_dataset import AbstractDataModule, AbstractDatasetInfos


class MyPlanarGraphDataset(InMemoryDataset):
    """In-memory PyG dataset holding one split of unattributed graphs.

    ``process`` reads ``<raw_dir>/<split>.pt`` and caches the collated result
    in ``processed_paths[0]`` (``data_<split>.pt``). Whatever the raw layout,
    every graph is re-encoded identically (see ``_adj_to_data``):

    - node features are all ones;
    - every edge gets the attribute ``[0, 1]``;
    - ``y`` is an empty ``(1, 0)`` tensor;
    - ``n_nodes`` stores the node count.

    Edge weights present in the raw data are not kept.
    """

    def __init__(
        self,
        root,
        split='test',
        transform=None,
        pre_transform=None,
        pre_filter=None,
    ):
        # `split` must exist before super().__init__, because the base class
        # reads raw_file_names / processed_file_names (which use it) there.
        self.split = split
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [f"{self.split}.pt"]

    @property
    def processed_file_names(self):
        return [f"data_{self.split}.pt"]

    def download(self):
        # Raw files must be provided by the user; nothing is fetched.
        pass

    @staticmethod
    def _adj_to_data(adj):
        """Build a uniformly-featured ``Data`` object from a dense adjacency matrix."""
        n = adj.shape[-1]
        edge_index, _ = torch_geometric.utils.dense_to_sparse(adj)
        edge_attr = torch.zeros(edge_index.shape[-1], 2, dtype=torch.float)
        edge_attr[:, 1] = 1  # every existing edge gets the same one-hot type
        return Data(
            x=torch.ones(n, 1, dtype=torch.float),
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=torch.zeros([1, 0]).float(),
            n_nodes=n * torch.ones(1, dtype=torch.long),
        )

    @staticmethod
    def _edges_to_adj(edge_index, num_nodes):
        """Return a binary dense ``(num_nodes, num_nodes)`` adjacency from an edge list."""
        adj = torch.zeros(num_nodes, num_nodes)
        adj[edge_index[0], edge_index[1]] = 1
        return adj

    @classmethod
    def _iter_collated(cls, data, slices):
        """Yield per-graph dense adjacencies from a collated ``(data, slices)`` pair.

        Assumes the layout produced by ``InMemoryDataset.collate``. Graph ``i``
        owns node rows ``slices['x'][i]:slices['x'][i + 1]`` and edge columns
        ``slices['edge_index'][i]:slices['edge_index'][i + 1]``.
        """
        node_ptr = slices['x']
        edge_ptr = slices['edge_index']
        for i in range(len(node_ptr) - 1):
            num_nodes = int(node_ptr[i + 1] - node_ptr[i])
            # InMemoryDataset.collate does not offset edge_index per graph, so
            # the slice already holds local node ids 0..num_nodes-1. Subtracting
            # the node offset would give negative / out-of-range indices for
            # every graph after the first.
            edges = data.edge_index[:, edge_ptr[i]:edge_ptr[i + 1]]
            yield cls._edges_to_adj(edges, num_nodes)

    def _iter_adjacencies(self, loaded_data):
        """Yield one dense adjacency matrix per graph found in the raw file.

        Supported layouts:

        - a list of dense adjacency tensors (``Data`` items are also accepted);
        - an ``InMemoryDataset.collate``-style ``(data, slices)`` tuple, where
          ``slices`` is ``None`` when the collated list held a single graph;
        - any object exposing ``to_data_list()`` (e.g. a ``Batch``);
        - a single ``Data`` object.

        Raises:
            ValueError: for any other layout (raised when iteration starts).
        """
        if isinstance(loaded_data, list):
            for item in loaded_data:
                if isinstance(item, Data):
                    yield self._edges_to_adj(item.edge_index, item.num_nodes)
                else:
                    yield item
            return

        if isinstance(loaded_data, tuple) and len(loaded_data) == 2:
            data, slices = loaded_data
            if slices is not None:
                yield from self._iter_collated(data, slices)
                return
            # slices is None: the payload is one uncollated graph (or a Batch).
            loaded_data = data

        # Check to_data_list first: in PyG, Batch objects are also Data instances.
        if hasattr(loaded_data, 'to_data_list'):
            graphs = loaded_data.to_data_list()
        elif isinstance(loaded_data, Data):
            graphs = [loaded_data]
        else:
            raise ValueError(f"Unknown data format: {type(loaded_data)}")

        for graph in graphs:
            # num_nodes also works when a graph has no `x` attribute.
            yield self._edges_to_adj(graph.edge_index, graph.num_nodes)

    def process(self):
        """Convert the raw file for ``self.split`` and save the collated result.

        Raises:
            FileNotFoundError: if the raw file is missing (``download`` is a no-op).
            ValueError: if the raw layout is unsupported, or if no graph
                survives ``pre_filter``.
        """
        raw_path = os.path.join(self.raw_dir, f"{self.split}.pt")
        if not os.path.exists(raw_path):
            raise FileNotFoundError(f"[anonymized] {raw_path}，[anonymized]")

        data_list = []
        for adj in self._iter_adjacencies(torch.load(raw_path)):
            data_obj = self._adj_to_data(adj)
            if self.pre_filter is not None and not self.pre_filter(data_obj):
                continue
            if self.pre_transform is not None:
                data_obj = self.pre_transform(data_obj)
            data_list.append(data_obj)

        if not data_list:
            # collate() cannot handle an empty list; fail with a clear message.
            raise ValueError(
                f"No graphs to save for split '{self.split}' (raw file: {raw_path})"
            )
        torch.save(self.collate(data_list), self.processed_paths[0])


class MyPlanarGraphDataModule(AbstractDataModule):
    """Data module wiring the train/val/test splits of ``MyPlanarGraphDataset``.

    The dataset root is ``cfg.dataset.datadir``, joined onto
    ``pathlib.Path(<this file>).parents[2]``. That base is the grandparent of
    the directory containing this file.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.datadir = cfg.dataset.datadir
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)

        # Splits are constructed (and processed on first use) in train/val/test order.
        datasets = {
            split: MyPlanarGraphDataset(root=root_path, split=split)
            for split in ("train", "val", "test")
        }

        # len(dataset) is the number of graphs. Use it instead of reading the
        # internal collated storage (`.data`), which newer PyG versions warn
        # against accessing directly.
        print(
            f"Dataset sizes: train {len(datasets['train'])}, "
            f"val {len(datasets['val'])}, test {len(datasets['test'])}"
        )

        super().__init__(cfg, datasets)
        # Presumably `train_dataset` is set by AbstractDataModule.__init__ — confirm.
        self.inner = self.train_dataset

    def __getitem__(self, item):
        """Return element ``item`` of the training split."""
        return self.inner[item]


class MyPlanarDatasetInfos(AbstractDatasetInfos):
    """Dataset statistics for ``my_planar``, pulled from a data module.

    Args:
        datamodule: object exposing ``node_counts()``, ``node_types()`` and
            ``edge_counts()``; it is stored as ``self.datamodule``.
        dataset_config: accepted for interface compatibility; not used here.
    """

    def __init__(self, datamodule, dataset_config):
        dm = datamodule
        self.datamodule = dm
        self.dataset_name = "my_planar"
        self.n_nodes = dm.node_counts()
        self.node_types = dm.node_types()
        # NOTE(review): despite the name, this stores the result of
        # edge_counts(); presumably per-edge-type statistics — confirm with users.
        self.edge_types = dm.edge_counts()
        super().complete_infos(self.n_nodes, self.node_types)


def convert_nx_to_pyg_data(graph: nx.Graph) -> Data:
    """Convert a networkx graph into a uniformly-featured PyG ``Data`` object.

    The graph is first densified with ``nx.to_numpy_array``. Every node gets
    the feature ``[1.0]`` and every non-zero adjacency entry becomes an edge
    with attribute ``[0, 1]``. ``y`` is an empty ``(1, 0)`` tensor and
    ``n_nodes`` holds the node count as a length-1 long tensor.
    """
    dense_adj = torch.Tensor(nx.to_numpy_array(graph))
    num_vertices = dense_adj.shape[-1]

    edge_index, _ = torch_geometric.utils.dense_to_sparse(dense_adj)

    # Only column 1 is set, so all edges share one one-hot type
    # (presumably column 0 stands for "no edge" — confirm with the model).
    edge_attr = torch.zeros(edge_index.shape[-1], 2, dtype=torch.float)
    edge_attr[:, 1] = 1

    return Data(
        x=torch.ones(num_vertices, 1, dtype=torch.float),
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.zeros([1, 0]).float(),
        n_nodes=num_vertices * torch.ones(1, dtype=torch.long),
    )