import networkx as nx
import pandas as pd
import torch_geometric as pyg
import numpy as np
import torch

from gnn_xai_common.datasets import BaseGraphDataset
from gnn_xai_common.datasets.utils import default_ax, unpack_G


class IMDBDataset(BaseGraphDataset):
    """IMDB-MULTI graph-classification dataset (TU Dortmund text format).

    The raw archive provides an edge list, a node -> graph indicator and one
    label per graph. No node-label file is read, so every node is assigned
    the single node class 0; each graph belongs to one of three classes.
    """

    # Only one (unnamed) node class: ``generate`` labels every node 0.
    NODE_CLS = {
        0: '',
    }

    # Three graph classes; ``generate`` shifts the raw graph labels by -1.
    GRAPH_CLS = {
        0: '',
        1: '',
        2: '',
    }

    def __init__(self, *,
                 name='IMDB',
                 url='https://www.chrsmrrs.com/graphkerneldatasets/IMDB-MULTI.zip',
                 **kwargs):
        """Create the dataset.

        Args:
            name: dataset name forwarded to ``BaseGraphDataset``.
            url: location of the zip archive fetched by ``download``.
            **kwargs: forwarded unchanged to ``BaseGraphDataset.__init__``.
        """
        # Set before super().__init__, presumably because the base class may
        # trigger download() during construction and needs the URL by then.
        self.url = url
        super().__init__(name=name, **kwargs)

    @property
    def raw_file_names(self):
        """Raw files (relative to ``raw_dir``) that must exist after download:
        edge list, node -> graph indicator, and graph labels, in that order."""
        return ["IMDB-MULTI/IMDB-MULTI_A.txt",
                "IMDB-MULTI/IMDB-MULTI_graph_indicator.txt",
                "IMDB-MULTI/IMDB-MULTI_graph_labels.txt"]

    def download(self):
        """Fetch the zip archive from ``self.url`` and unpack it into ``raw_dir``."""
        # Use the path returned by download_url rather than a hard-coded
        # archive name, so a custom ``url`` (e.g. a mirror) still extracts.
        archive_path = pyg.data.download_url(self.url, self.raw_dir)
        pyg.data.extract_zip(archive_path, self.raw_dir)

    def generate(self):
        """Build the dataset's graphs from the raw text files.

        All graphs are first assembled into one combined networkx graph whose
        nodes carry a ``'label'`` (always 0) and a ``'graph'`` (graph index)
        attribute, with the array of graph labels stored as the graph-level
        ``label`` attribute.

        Returns:
            The result of ``unpack_G`` on the combined graph (presumably one
            graph per ``'graph'`` id -- confirm against ``unpack_G``).
        """
        edges_path, indicator_path, labels_path = self.raw_paths[:3]
        # Raw ids are shifted by -1 so nodes and graphs are 0-indexed
        # (the raw files are presumably 1-based). skipinitialspace tolerates
        # the "i, j" spacing used in the edge file.
        edges = pd.read_csv(edges_path, header=None,
                            skipinitialspace=True).to_numpy(dtype=int) - 1
        graph_idx = pd.read_csv(indicator_path, header=None)[0].to_numpy(dtype=int) - 1
        graph_labels = pd.read_csv(labels_path, header=None)[0].to_numpy(dtype=int) - 1

        super_G = nx.Graph(edges.tolist(), label=graph_labels)
        # nx.Graph(edges) only contains nodes that appear in some edge, and
        # set_node_attributes silently skips unknown nodes. Add every node in
        # the indicator file so isolated nodes are not dropped. Existing nodes
        # keep their insertion order; missing ones are appended.
        num_nodes = len(graph_idx)
        super_G.add_nodes_from(range(num_nodes))
        # No node-label file is read, so every node gets the single class 0.
        nx.set_node_attributes(super_G, dict.fromkeys(range(num_nodes), 0), name='label')
        # Record which graph each node belongs to.
        nx.set_node_attributes(super_G, dict(enumerate(graph_idx)), name='graph')
        return unpack_G(super_G)

    @default_ax
    def draw(self, G, pos=None, label=False, ax=None):
        """Draw ``G`` with black-outlined nodes and gray edges.

        Args:
            G: networkx graph to draw.
            pos: optional node -> position mapping; a Kamada-Kawai layout is
                computed when it is missing or empty.
            label: accepted for interface compatibility; currently unused.
            ax: matplotlib axes to draw on (presumably filled in by
                ``default_ax`` when None -- confirm).
        """
        if not pos:
            pos = nx.kamada_kawai_layout(G)
        nx.draw_networkx_nodes(G, pos,
                               ax=ax,
                               nodelist=G.nodes,
                               node_size=300,
                               edgecolors='black')
        # A subgraph over all of G's nodes is G itself, so draw G directly.
        nx.draw_networkx_edges(G, pos, ax=ax, width=1, edge_color='tab:gray')

    def process(self):
        """Delegate processing to ``BaseGraphDataset.process``."""
        super().process()
