import numpy as np
import torch_geometric.utils as pygutils
from torch_geometric.datasets import TUDataset
from torch_geometric.transforms import Compose, ToDense
from gcip.preparators.graph.preparator_graph import GraphPreparator
from gcip.utils.io import dict_to_cn

import gcip.utils.io as pb_io


class MyFilter:
    """Callable predicate that accepts graphs with at most ``max_nodes`` nodes."""

    def __init__(self, max_nodes):
        """Store the inclusive upper bound on the node count.

        Args:
            max_nodes: Largest ``num_nodes`` value that still passes the filter.
        """
        self.max_nodes = max_nodes

    def __call__(self, data):
        """Return whether ``data.num_nodes`` does not exceed ``self.max_nodes``.

        Args:
            data: Any object exposing a ``num_nodes`` attribute.
        """
        node_count = data.num_nodes
        return node_count <= self.max_nodes


class TUPreparator(GraphPreparator):
    """Preparator for graph-level tasks backed by ``TUDataset``.

    Configures the ``GraphPreparator`` base with ``transductive=False`` and
    ``task='graph'`` and loads the raw graphs through
    ``torch_geometric.datasets.TUDataset``.

    NOTE(review): ``self.is_dense``, ``self.max_nodes``, ``self.root``,
    ``self.name`` and ``self.scale`` are read here but never assigned in this
    class; presumably they are provided by ``GraphPreparator`` -- confirm.
    """

    def __init__(self, name, **kwargs):
        """Initialise the preparator for the dataset called ``name``.

        Args:
            name: Dataset name; also passed to ``TUDataset`` when loading.
            **kwargs: Forwarded unchanged to ``GraphPreparator.__init__``.
        """
        # Assigned before the base initialiser runs; keep this ordering in
        # case the base class populates ``self.dataset`` itself.
        self.dataset = None

        super().__init__(name=name,
                         transductive=False,
                         task='graph',
                         **kwargs)

    @classmethod
    def params(cls, dataset):
        """Build constructor keyword arguments from a dataset config.

        Args:
            dataset: Config object, or a plain ``dict`` which is first
                converted with ``dict_to_cn``.

        Returns:
            dict: A fresh dict holding the entries produced by
            ``GraphPreparator.params(dataset)``; this class adds none.
        """
        if isinstance(dataset, dict):
            dataset = dict_to_cn(dataset)

        my_dict = {}
        my_dict.update(GraphPreparator.params(dataset))
        return my_dict

    @classmethod
    def loader(cls, dataset):
        """Alternate constructor: instantiate ``cls`` from a dataset config.

        Args:
            dataset: Config accepted by ``params``.

        Returns:
            An instance of ``cls`` built from ``cls.params(dataset)``.
        """
        # Resolve ``params`` through ``cls`` (not the hard-coded class name)
        # so that a subclass overriding ``params`` is honoured.
        return cls(**cls.params(dataset))

    def get_transform_fn(self):
        """Return the per-sample transform applied when graphs are accessed.

        Returns:
            callable: Function that appends a trailing dimension to ``data.y``
            (modifying ``data`` in place) and returns the same object.
        """
        def transform(data):
            data.y = data.y.unsqueeze(-1)
            return data

        return transform

    @property
    def use_node_attr(self):
        """Whether ``TUDataset`` is asked to load node attributes (always True)."""
        return True

    def _get_dataset_raw(self):
        """Load the raw ``TUDataset`` in dense or sparse form.

        In dense mode, graphs with more than ``self.max_nodes`` nodes are
        excluded via ``pre_filter``, each sample goes through
        ``ToDense(self.max_nodes)`` followed by the label transform, and
        ``edge_attr`` is deleted from the dataset's ``_data`` object.
        In sparse mode only the label transform is applied and edge
        attributes are requested from ``TUDataset``.

        Returns:
            TUDataset: The loaded dataset.
        """
        if self.is_dense:
            transform = Compose([ToDense(self.max_nodes),
                                 self.get_transform_fn()])
            dataset = TUDataset(root=self.root,
                                name=self.name,
                                transform=transform,
                                pre_filter=MyFilter(self.max_nodes),
                                # Read the property, as the sparse branch does,
                                # instead of repeating a literal ``True``.
                                use_node_attr=self.use_node_attr)
            # NOTE(review): relies on the private ``_data`` attribute of the
            # dataset -- confirm it exists in the pinned torch_geometric version.
            del dataset._data.edge_attr
        else:
            dataset = TUDataset(root=self.root,
                                name=self.name,
                                transform=self.get_transform_fn(),
                                pre_transform=None,
                                pre_filter=None,
                                use_node_attr=self.use_node_attr,
                                use_edge_attr=True,
                                cleaned=False)
        return dataset

    def edge_attr_dim(self):
        """Return the edge-attribute dimensionality; always ``None`` here."""
        return None

    def _metric_names(self):
        """Return the names of the metrics reported for this task."""
        return ['accuracy', 'precision', 'recall']

    def _transform_dataset_pre_split(self, dataset_raw):
        """Return ``dataset_raw`` unchanged; no pre-split transformation is applied."""
        return dataset_raw

    def convert_pyg_to_nx(self, data):
        """Convert a PyG graph to an undirected NetworkX graph.

        Args:
            data: PyG data object; its node attribute ``x`` is copied onto
                the NetworkX nodes.

        Returns:
            The graph produced by ``torch_geometric.utils.to_networkx``.
        """
        return pygutils.to_networkx(data,
                                    to_undirected=True,
                                    node_attrs=['x'])

    def get_scaler_info(self):
        """Return the scaler specification for the configured ``self.scale``.

        Returns:
            dict: Maps ``'x'`` to ``[('identity', None)]`` when ``self.scale``
            is ``'default'``.

        Raises:
            NotImplementedError: For any other ``self.scale`` value.
        """
        if self.scale != 'default':
            raise NotImplementedError(
                f"Unsupported scale {self.scale!r}; only 'default' is implemented."
            )

        return {
            'x': [('identity', None)]
        }
