import torch
from torch import Tensor
import numpy as np

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_url,
    extract_zip,
)
from torch_geometric.io import fs

from typing import List, Optional, Callable
import os.path as osp


import shutil
import os



def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:
    """Concatenate the usable tensors of *seq* along their last dimension.

    Entries that are ``None`` or contain no elements are skipped. Any
    remaining 1-D tensor is given a trailing singleton dimension so it can be
    joined column-wise with the others.

    Returns:
        The concatenated tensor, or ``None`` when no entry is usable.
    """
    columns = []
    for item in seq:
        # Skip placeholders and empty tensors; they contribute no columns.
        if item is None or item.numel() == 0:
            continue
        columns.append(item.unsqueeze(-1) if item.dim() == 1 else item)

    if not columns:
        return None
    return torch.cat(columns, dim=-1)


def load_graph_dataset(
        name: str, root: str, mode: Optional[str] = 'train',
        use_decomp: str = 'all_graphs', dim: int = 10, pad_multi_n: int = 4,
        scales: Optional[list] = None,
        force_reload=True):
    """Build a :class:`TransTUDataset` for *name* rooted at *root*.

    Args:
        name: Dataset name, forwarded to ``TransTUDataset``.
        root: Root directory, forwarded to ``TransTUDataset``.
        mode: ``None``, ``'train'`` or ``'test'``; forwarded to the dataset.
        use_decomp: Forwarded to the dataset unchanged.
        dim: Forwarded to the dataset unchanged.
        pad_multi_n: Forwarded to the dataset unchanged.
        scales: Forwarded to the dataset. ``None`` (the default) selects
            ``[0.25, 0.5, 1, 2, 5, 10]``; a fresh list is created per call so
            the default is never shared between datasets.
        force_reload: Forwarded to the dataset **only when** ``mode is None``.
            For any other mode it is not passed on, so the dataset's own
            default applies.

    Returns:
        The constructed ``TransTUDataset``.
    """
    if scales is None:
        scales = [0.25, 0.5, 1, 2, 5, 10]

    dataset_kwargs = dict(
        root=root,
        name=name,
        mode=mode,
        use_decomp=use_decomp,
        dim=dim,
        pad_multi_n=pad_multi_n,
        scales=scales,
    )
    # NOTE(review): force_reload is deliberately left out for 'train'/'test'
    # to match the existing behaviour -- confirm this asymmetry is intended.
    if mode is None:
        dataset_kwargs['force_reload'] = force_reload

    return TransTUDataset(**dataset_kwargs)


class TransTUDataset(InMemoryDataset):
    """In-memory TU-style graph dataset built by ``read_and_process``.

    The on-disk layout is selected by ``mode``:

    * ``mode=None`` -- raw and processed files live under
      ``<root>/<name>/raw[_cleaned]`` and ``<root>/<name>/processed[_cleaned]``;
      :meth:`download` fetches ``<name>.zip`` from :attr:`url` (or
      :attr:`cleaned_url` when ``cleaned`` is set).
    * any other mode (``'train'`` / ``'test'``) -- files live directly under
      ``<root>/raw[_cleaned]`` and ``<root>/processed[_cleaned]``;
      :meth:`download` does nothing, and ``<dirname(root)>/<name>_train`` is
      passed to ``read_and_process`` as ``src_dir``.
    """

    url = 'https://www.chrsmrrs.com/graphkerneldatasets'
    cleaned_url = ('https://raw.githubusercontent.com/nd7141/'
                   'graph_datasets/master/datasets')

    # Used when ``scales`` is not supplied; copied per instance so instances
    # never share one list object.
    _DEFAULT_SCALES = (0.25, 0.5, 1, 2, 5, 10)

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None,
                 use_node_attr: bool = True,
                 use_edge_attr: bool = False,
                 cleaned: bool = False,
                 use_decomp: str = 'all_graphs',
                 dim: int = 10,
                 pad_multi_n: int = 4,
                 scales: Optional[list] = None,
                 mode: Optional[str] = None,
                 force_reload: bool = False,
                 align_feat: bool = False,
                 ):
        """Create the dataset, processing it first if required.

        Args:
            root: Root directory of the dataset.
            name: Dataset name; used for file names, directories and the URL.
            transform: Passed to ``InMemoryDataset``.
            pre_transform: Passed to ``InMemoryDataset``; applied in
                :meth:`process` before saving.
            pre_filter: Passed to ``InMemoryDataset``; applied in
                :meth:`process` before saving.
            use_node_attr: When ``False``, the leading
                ``num_node_attributes`` columns of ``x`` are dropped.
            use_edge_attr: When ``False``, the leading
                ``num_edge_attributes`` columns of ``edge_attr`` are dropped.
            cleaned: Selects the ``*_cleaned`` directories and
                :attr:`cleaned_url`.
            use_decomp: Forwarded to ``read_and_process``.
            dim: Forwarded to ``read_and_process``.
            pad_multi_n: Forwarded to ``read_and_process``.
            scales: Forwarded to ``read_and_process``. ``None`` selects
                ``[0.25, 0.5, 1, 2, 5, 10]``.
            mode: ``None``, ``'train'`` or ``'test'``.
            force_reload: Passed to ``InMemoryDataset``.
            align_feat: Forwarded to ``read_and_process``; only honoured when
                ``mode == 'test'`` and forced to ``False`` otherwise.

        Raises:
            RuntimeError: If the processed file does not hold a 3-tuple.
        """
        # These attributes feed raw_dir / processed_dir and process(), so they
        # must be set before the base-class constructor runs.
        self.name = name
        self.cleaned = cleaned
        self.use_decomp = use_decomp
        self.dim = dim
        # Attribute spelling ("muti") kept as-is for backward compatibility.
        self.pad_muti_n = pad_multi_n
        self.mode = mode
        self.scales = list(self._DEFAULT_SCALES) if scales is None else scales

        # NOTE(review): osp.dirname of a root with a trailing separator yields
        # root itself rather than its parent -- confirm callers never pass one.
        self.src_dir = (osp.join(osp.dirname(root), f'{name}_train')
                        if mode is not None else None)

        self.align_feat = align_feat if mode == 'test' else False

        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)

        out = torch.load(self.processed_paths[0], weights_only=True)
        if not isinstance(out, tuple) or len(out) != 3:
            raise RuntimeError(
                "The 'data' object was created by an older version of PyG. "
                "If this error occurred while loading an already existing "
                "dataset, remove the 'processed/' directory in the dataset's "
                "root folder and try again.")
        data, self.slices, self.sizes = out
        self.data = Data.from_dict(data) if isinstance(data, dict) else data

        if self._data.x is not None and not use_node_attr:
            num_node_attributes = self.num_node_attributes
            self._data.x = self._data.x[:, num_node_attributes:]
        if self._data.edge_attr is not None and not use_edge_attr:
            num_edge_attrs = self.num_edge_attributes
            self._data.edge_attr = self._data.edge_attr[:, num_edge_attrs:]

    def _stage_dir(self, stem: str) -> str:
        """Return the directory for *stem* (``'raw'`` or ``'processed'``)."""
        folder = f'{stem}{"_cleaned" if self.cleaned else ""}'
        if self.mode is not None:
            return osp.join(self.root, folder)
        return osp.join(self.root, self.name, folder)

    @property
    def raw_dir(self) -> str:
        """Directory holding the raw text files for the current layout."""
        return self._stage_dir('raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed ``data.pt`` for the current layout."""
        return self._stage_dir('processed')

    @property
    def num_node_attributes(self) -> int:
        """Value of ``sizes['num_node_attributes']`` from the processed file."""
        return self.sizes['num_node_attributes']

    @property
    def num_edge_labels(self) -> int:
        """Value of ``sizes['num_edge_labels']`` from the processed file."""
        return self.sizes['num_edge_labels']

    @property
    def num_edge_attributes(self) -> int:
        """Value of ``sizes['num_edge_attributes']`` from the processed file."""
        return self.sizes['num_edge_attributes']

    @property
    def raw_file_names(self) -> List[str]:
        """Raw files expected in :attr:`raw_dir`."""
        names = ['A', 'graph_indicator']
        return [f'{self.name}_{name}.txt' for name in names]

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file."""
        return 'data.pt'

    def download(self):
        """Fetch and unpack the raw archive; a no-op when ``mode`` is set."""
        # For train or test mode, assume the dataset has already been
        # downloaded and split.
        if self.mode is not None:
            print(f"Using {self.mode} mode, skipping download...")
            return

        # For None mode, download the dataset normally.
        url = self.cleaned_url if self.cleaned else self.url
        folder = osp.join(self.root, self.name)
        path = download_url(f'{url}/{self.name}.zip', folder)
        extract_zip(path, folder)
        os.unlink(path)
        # Drop any existing raw directory so the extracted folder can be
        # renamed into its place; tolerate the directory being absent.
        if osp.isdir(self.raw_dir):
            shutil.rmtree(self.raw_dir)
        os.rename(osp.join(folder, self.name), self.raw_dir)

    def process(self):
        """Run ``read_and_process``, apply pre-filter/transform, and save."""
        from .read_dataset import read_and_process

        self.data, self.slices, sizes = read_and_process(
            self.root,
            self.name,
            mode=self.mode,
            use_decomp=self.use_decomp,
            scales=self.scales,
            dim=self.dim,
            pad_multi_n=self.pad_muti_n,
            src_dir=self.src_dir,
            align_feat=self.align_feat,
        )

        if self.pre_filter is not None or self.pre_transform is not None:
            # Materialise individual graphs so the hooks can act per graph,
            # then collate the survivors back into one storage object.
            data_list = [self.get(idx) for idx in range(len(self))]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.data, self.slices = self.collate(data_list)
            self._data_list = None  # Reset cache.

        torch.save((self._data.to_dict(), self.slices, sizes),
                   self.processed_paths[0])

    def __repr__(self) -> str:
        return f'{self.name}({len(self)})'


    