import dgl
import torch
from graph_learning.data_setting import DataSettingConfig, DataTransform

@DataSettingConfig.register('print-data')
class PrintDataConfig(DataSettingConfig):
    """Data-setting config registered as 'print-data'; its builder is PrintData.

    Adds no state or parser options of its own; everything is delegated to
    DataSettingConfig.
    """

    @classmethod
    def define_parser(cls, parser):
        """Delegate parser setup to the base class (no extra options here)."""
        super().define_parser(parser)

    def __init__(self, args, context):
        # Nothing to store beyond what the base config keeps.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class associated with this setting."""
        return PrintData

class PrintData(DataTransform):
    """Transform that prints the graph and passes it through untouched."""

    def transform(self, graph):
        """Print ``graph`` to stdout and return the very same object.

        Args:
            graph: the graph object flowing through the data pipeline.

        Returns:
            ``graph`` itself (no copy, no modification).
        """
        # Side effect only: show the object's printed representation.
        print(graph)
        return graph

@DataSettingConfig.register('inspect-data')
class InspectDataConfig(DataSettingConfig):
    """Data-setting config registered as 'inspect-data'; its builder is InspectData.

    Adds no state or parser options of its own; everything is delegated to
    DataSettingConfig.
    """

    @classmethod
    def define_parser(cls, parser):
        """Delegate parser setup to the base class (no extra options here)."""
        super().define_parser(parser)

    def __init__(self, args, context):
        # Nothing to store beyond what the base config keeps.
        super().__init__(args, context)

    @property
    def builder(self):
        """Transform class associated with this setting."""
        return InspectData

class InspectData(DataTransform):
    """Transform that pauses in an interactive debugger to inspect the graph.

    Uses ``ipdb`` when it is installed and falls back to the standard-library
    ``pdb`` otherwise, so the transform does not fail with ``ImportError`` in
    environments without ipdb. The graph is returned unchanged.
    """

    def transform(self, graph):
        """Break into a debugger with ``graph`` in scope, then return it as-is.

        Args:
            graph: the graph object flowing through the data pipeline.

        Returns:
            ``graph`` itself, unmodified.
        """
        try:
            import ipdb as debugger
        except ImportError:
            # ipdb is an optional third-party package; pdb ships with Python
            # and exposes the same no-argument set_trace() entry point.
            import pdb as debugger
        # Execution stops here; `graph` is available as a local in the session.
        debugger.set_trace()
        return graph

@DataSettingConfig.register('write-graph')
class WriteGraphConfig(DataSettingConfig):
    """Data-setting config registered as 'write-graph'; its builder is WriteGraph."""

    # Named distinctly: reusing the name `PrintDataConfig` here would rebind
    # that module-level name and shadow the 'print-data' config class.

    def __init__(self, args, context):
        super().__init__(args, context)
        # WriteGraph takes a `logger` argument; presumably the framework feeds
        # this attribute to the builder (NOTE(review): confirm with
        # DataSettingConfig).
        self.logger = context.global_.logger

    @property
    def builder(self):
        """Transform class associated with this setting."""
        return WriteGraph

    @classmethod
    def define_parser(cls, parser):
        """Delegate parser setup to the base class (no extra options here)."""
        super().define_parser(parser)

class WriteGraph(DataTransform):
    """Transform that dumps the graph to a GEXF file and passes it through.

    The output path comes from the ``'visual'`` entry of the logger mapping
    supplied at construction.
    """

    def __init__(self, logger):
        # Indexable collection of loggers; transform() reads its 'visual' key.
        self.logger = logger

    def transform(self, graph):
        """Write ``graph`` as ``<name>_raw.gexf`` and return it unchanged.

        Args:
            graph: a DGL graph that also carries a ``gdata`` mapping with a
                ``"name"`` entry (used to build the file name).

        Returns:
            ``graph`` itself.
        """
        # Imported lazily so networkx is only needed when this transform runs.
        import networkx as nx

        # Move to CPU before converting, then hand the networkx copy to GEXF.
        on_cpu = graph.to('cpu')
        as_networkx = on_cpu.to_networkx()
        visual = self.logger['visual']
        gexf_path = visual.path(f'{graph.gdata["name"]}_raw.gexf')
        nx.write_gexf(as_networkx, gexf_path)
        return graph
