from __future__ import annotations

import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import networkx as nx
import torch
from torch import Tensor

from egr.util import load_features, load_graph, load_labels

LOG = logging.getLogger(__name__)


class EgrData:
    """Base class for EGR graph datasets exposing size/direction summaries.

    This class defines none of ``V``, ``E`` or ``G`` itself; subclasses must
    supply them (or override the properties below that read them).
    """

    @property
    def num_nodes(self) -> int:
        """Node count, taken as the total element count of ``self.V``."""
        vertices = self.V
        return vertices.numel()

    @property
    def N(self) -> int:
        """Short alias for ``num_nodes``."""
        return self.num_nodes

    @property
    def num_edges(self) -> int:
        """Edge count, taken as the size of the first dimension of ``self.E``."""
        edges = self.E
        return edges.size(0)

    @property
    def directed(self) -> bool:
        """Whether the graph stored in ``self.G`` reports itself as directed."""
        graph = self.G
        return graph.is_directed()

    def __str__(self) -> str:
        # Compact summary such as '|N|=10,|E|=25'.
        return '|N|={},|E|={}'.format(self.N, self.num_edges)


@dataclass
class EgrDenseData(EgrData):
    """Graph dataset backed by a ``networkx`` graph plus per-node labels.

    ``num_nodes`` and ``num_edges`` are computed from ``G``. This dataclass
    has no ``V``/``E`` attributes for the base-class versions to read, so
    those would raise ``AttributeError`` (e.g. in ``str(data)``).
    """

    # One label per node of ``G`` (``read_new`` checks the counts match);
    # ``read`` builds it as an int64 tensor.
    y: Tensor
    # The graph; nodes may carry a ``'feat'`` attribute.
    G: nx.Graph

    @property
    def num_nodes(self) -> int:
        """Return the number of nodes in ``G``."""
        return self.G.number_of_nodes()

    @property
    def num_edges(self) -> int:
        """Return the number of edges in ``G``.

        NOTE(review): for an undirected ``G`` each edge is counted once;
        confirm this matches the convention of any ``E``-based subclass.
        """
        return self.G.number_of_edges()

    @classmethod
    def read(cls, path: Path) -> EgrDenseData:
        """Load a graph whose labels and features are embedded in the graph.

        Labels come from the graph attribute ``'labels'``. Each node's
        ``'feat'`` attribute is converted to a tensor in place.

        :param path: file understood by ``egr.util.load_graph``.
        :return: the loaded dataset.
        """
        G: nx.Graph = load_graph(path)
        LOG.debug('Loaded %s from %s path', G, path)
        # Build the int64 tensor directly. Going through a float32 Tensor
        # first cannot represent integers above 2**24 exactly.
        y = torch.tensor(G.graph['labels'], dtype=torch.long)
        for n in G.nodes:
            G.nodes[n]['feat'] = Tensor(G.nodes[n]['feat'])
        return cls(y, G=G)

    @classmethod
    def read_new(
        cls,
        graph_path: Path,
        label_path: Path,
        feature_path: Optional[Path] = None,
    ) -> EgrDenseData:
        """Load a dataset from separate graph, label and (optional) feature files.

        :param graph_path: compact JSON graph as written by ``save``.
        :param label_path: label file passed (as an open text handle) to
            ``egr.util.load_labels``.
        :param feature_path: optional feature file for
            ``egr.util.load_features``. Row ``i`` is stored as node ``i``'s
            ``'feat'`` attribute.
        :raises ValueError: if the label count differs from the node count.
        """
        G: nx.Graph = cls.load_graph(graph_path)
        # Close the label file deterministically instead of leaking the handle.
        with label_path.open() as fp:
            y = load_labels(fp)
        if G.number_of_nodes() != y.shape[0]:
            raise ValueError(
                f'graph has {G.number_of_nodes()} nodes but '
                f'{y.shape[0]} labels were loaded from {label_path}'
            )
        if feature_path is not None:
            h = load_features(feature_path)
            nx.set_node_attributes(G, {i: {'feat': h[i]} for i in G.nodes()})
        return cls(y, G=G)

    @staticmethod
    def load_graph(path: Path) -> nx.Graph:
        """Read a compact JSON graph ``{'n': int, 'u': [...], 'v': [...]}``.

        Nodes ``0..n-1`` are created first, then edge ``(u[i], v[i])`` is
        added for each ``i``.

        :raises ValueError: if ``u`` and ``v`` have different lengths.
        """
        with path.open() as fp:
            data = json.load(fp)
        u, v = data['u'], data['v']
        if len(u) != len(v):
            raise ValueError(
                f'edge endpoint lists differ in length: {len(u)} != {len(v)}'
            )
        G = nx.Graph()
        # Add nodes explicitly so isolated nodes survive the round trip.
        G.add_nodes_from(range(data['n']))
        G.add_edges_from(zip(u, v))
        return G

    @staticmethod
    def to_compact(G: nx.Graph) -> Dict:
        """Convert ``G`` to the compact dict form read by ``load_graph``.

        ``u`` and ``v`` are tuples of edge endpoints (empty for an edgeless
        graph). ``load_graph`` rebuilds nodes as ``range(n)``, so a faithful
        round trip presumably needs ``G``'s nodes to be the integers
        ``0..n-1``. TODO confirm with callers.
        """
        edges = list(G.edges())
        # zip(*edges) would be empty for an edgeless graph and break the
        # two-way unpack, so build the endpoint tuples explicitly.
        u = tuple(a for a, _ in edges)
        v = tuple(b for _, b in edges)
        return dict(n=G.number_of_nodes(), u=u, v=v)

    @classmethod
    def save(cls, G: nx.Graph, path: Path):
        """Write ``G`` to ``path`` as compact JSON (see ``to_compact``)."""
        # Convert before opening so a failure does not truncate the target.
        data = cls.to_compact(G)
        # Closing via ``with`` guarantees the JSON is flushed to disk.
        with path.open('w') as fp:
            json.dump(data, fp=fp)
