from __future__ import annotations
from argparse import ArgumentParser
from ctypes import Union
from functools import wraps

import json
import logging
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
from typing import Any, Dict, List
from egr.log import init_logging

import networkx as nx
import numpy as np
from torch import Tensor, LongTensor

# Module-wide logger, registered under this module's dotted name.
LOG: logging.Logger = logging.getLogger(name=__name__)


def read_json(path: Path) -> Dict:
    """Read and decode a JSON file.

    Parameters
    ----------
    path : pathlib.Path or str
        Location of the JSON file. It is passed through ``normalize_path``,
        so ``~`` is expanded and relative paths are made absolute against
        the current working directory.

    Returns
    -------
    Dict
        The decoded JSON document (a ``dict`` for object-rooted files).
    """
    LOG.debug('Loading path %s', path)
    # JSON text is UTF-8 (RFC 8259); never depend on the platform's locale
    # encoding, which differs between systems (e.g. cp1252 on Windows).
    with normalize_path(path).open(encoding='utf-8') as f:
        return json.load(f)


def load_graph(path: Path) -> nx.Graph:
    """Rebuild a networkx graph from a node-link JSON file.

    Parameters
    ----------
    path : pathlib.Path
        Location of the JSON document in node-link format (the format that
        ``to_dict`` produces via ``node_link_data``).

    Returns
    -------
    nx.Graph
        Graph reconstructed by ``nx.json_graph.node_link_graph``.

    """
    node_link = read_json(path)
    return nx.json_graph.node_link_graph(node_link)


def now_ts(fmt: str = '%Y%m%d-%H%M%S') -> str:
    """Return the current local time rendered with the ``strftime`` pattern ``fmt``.

    The default pattern yields a compact, sortable stamp such as
    ``20240131-235959``.
    """
    current = datetime.now()
    return current.strftime(fmt)


class IoEncoder(json.JSONEncoder):
    """JSON encoder that also handles numpy, torch, path and datetime values.

    Use it as ``json.dumps(obj, cls=IoEncoder)``.
    """

    def default(self, o: Any) -> Any:
        """Convert ``o`` to a JSON-serializable builtin, or defer to the base class.

        Raises
        ------
        TypeError
            From ``json.JSONEncoder.default`` when ``o`` has an unsupported type.
        """
        if isinstance(o, np.integer):
            return int(o)
        if isinstance(o, np.floating):
            # Any numpy float width (float16, float32, ...). float64 normally
            # never reaches here because it subclasses Python's float.
            return float(o)
        if isinstance(o, np.bool_):
            # np.bool_ is neither a Python bool nor a numpy integer.
            return bool(o)
        if isinstance(o, np.ndarray):
            return o.tolist()
        if isinstance(o, Tensor):
            # Copy to host memory and drop autograd tracking before converting.
            return o.cpu().detach().tolist()
        if isinstance(o, Path):
            return str(o)
        if isinstance(o, datetime):
            return o.isoformat()
        return super().default(o)


def to_dict(G: nx.Graph) -> Dict:
    """Return ``G`` in a JSON-ready dict form.

    Graphs are converted with ``nx.json_graph.node_link_data``. A dict is
    assumed to be already converted and is returned unchanged (the same
    object, not a copy).

    Raises
    ------
    RuntimeError
        If ``G`` is neither a graph nor a dict.
    """
    if not isinstance(G, (nx.Graph, dict)):
        raise RuntimeError(f'Unsupported type {type(G)}')
    return nx.json_graph.node_link_data(G) if isinstance(G, nx.Graph) else G


def to_json(G: nx.Graph, **kwargs) -> str:
    """Serialize a graph to a JSON string.

    Parameters
    ----------
    G : networkx.Graph or dict
        Graph to serialize (converted to node-link form by ``to_dict``), or a
        dict that is already in serializable form.
    **kwargs : dict
        Passed through ``to_string`` to ``json.dumps`` (e.g. ``indent=2``).

    Returns
    -------
    str
        The JSON document.

    Raises
    ------
    RuntimeError
        If ``G`` is neither a graph nor a dict (raised by ``to_dict``).

    """
    payload = to_dict(G)
    return to_string(payload, **kwargs)


def to_string(data: Dict, **kwargs) -> str:
    """Encode ``data`` as a JSON string with :class:`IoEncoder`.

    ``**kwargs`` are the usual ``json.dumps`` encoder options (``indent``,
    ``sort_keys``, ...). They are given to the encoder's constructor, which
    is exactly what ``json.dumps(data, cls=IoEncoder, **kwargs)`` does.
    """
    encoder = IoEncoder(**kwargs)
    return encoder.encode(data)


def save(G: nx.Graph, path: Path, **kwargs):
    """Save a graph (or an already-converted dict) to ``path`` as JSON.

    Parameters
    ----------
    G : nx.Graph or dict
        Graph object, converted to node-link form by ``to_dict``. A dict is
        written as-is.
    path : pathlib.Path
        Destination file. ``save_json`` creates missing parent directories.
    **kwargs: dict
        json.dump keyword args (e.g. ``indent``), forwarded to ``save_json``.

    Returns
    -------
    pathlib.Path or None
        Whatever ``save_json`` returns: the normalized path on success, or
        ``None`` when the data could not be serialized (the error is logged).

    """
    data = to_dict(G)
    # Forward the documented json options; they were previously dropped.
    return save_json(data, path, **kwargs)


def make_args(**kw) -> SimpleNamespace:
    """Bundle keyword arguments into an attribute-access namespace.

    Handy for building an ``argparse``-like config object in code, e.g.
    ``make_args(log_level='info').log_level == 'info'``.
    """
    namespace = SimpleNamespace()
    for key, value in kw.items():
        setattr(namespace, key, value)
    return namespace


def normalize_path(path: Union[str, Path]) -> Path:
    """Turn ``path`` into an absolute :class:`pathlib.Path`.

    ``~`` / ``~user`` prefixes are expanded. The result is then made absolute
    against the current working directory with ``Path.absolute()``, so
    symlinks and ``..`` components are left as they are (no ``resolve()``).
    """
    if not isinstance(path, Path):
        path = Path(path)
    expanded = path.expanduser()
    return expanded.absolute()


def save_json(data: Union[List, Dict], path: Path, **kwargs):
    """Serialize ``data`` with :class:`IoEncoder` and write it to ``path``.

    Parameters
    ----------
    data : list or dict
        JSON-compatible data. numpy, torch, path and datetime values are
        handled by ``IoEncoder``.
    path : pathlib.Path or str
        Destination file, normalized by ``normalize_path``. Missing parent
        directories are created.
    **kwargs : dict
        Extra ``json.dumps`` options (e.g. ``indent``).

    Returns
    -------
    pathlib.Path or None
        The normalized destination path. Returns ``None`` if ``data`` cannot
        be serialized; in that case the error is logged and no file is
        created or modified.
    """
    try:
        # Encode fully in memory first: failing part-way through a streaming
        # json.dump would leave a truncated file behind.
        text = json.dumps(data, cls=IoEncoder, **kwargs)
    except TypeError as e:
        LOG.error('%s\n%s', e, data)
        return None
    path = normalize_path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so output does not depend on the platform locale
    # (matters when ensure_ascii=False is passed).
    path.write_text(text, encoding='utf-8')
    return path


def load_indices(path: Path) -> Dict[str, np.ndarray]:
    """Read a JSON object of index lists and convert each value to a numpy array.

    Parameters
    ----------
    path : pathlib.Path
        JSON file whose top level is an object (read with ``read_json``).

    Returns
    -------
    Dict[str, np.ndarray]
        Same keys as the file, with every value passed through ``np.array``.
    """
    arrays: Dict[str, np.ndarray] = {}
    for name, values in read_json(path).items():
        arrays[name] = np.array(values)
    return arrays


def load_graph_features(data: Any) -> np.ndarray:
    """Stack the ``'feat'`` attribute of every node of a graph into an array.

    Parameters
    ----------
    data : nx.Graph, pathlib.Path or str
        A graph, or the path of a node-link JSON file loaded with ``load_graph``.

    Returns
    -------
    np.ndarray
        ``np.array`` of the per-node ``'feat'`` values, in ``G.nodes()``
        iteration order.

    Raises
    ------
    RuntimeError
        If ``data`` is not a graph, ``str`` or ``Path``.
    KeyError
        If some node has no ``'feat'`` attribute.
    """
    G = load_graph(data) if isinstance(data, (str, Path)) else data
    if not isinstance(G, nx.Graph):
        raise RuntimeError(f'{type(G)} is not supported')
    return np.array([G.nodes[n]['feat'] for n in G.nodes()])


def app_config(f):
    """Decorate a function that registers CLI options on an ``ArgumentParser``.

    Calling the decorated function does the following:

    1. Ignores its own arguments.
    2. Builds a fresh ``ArgumentParser`` with a ``--log-level`` option
       (default ``'debug'``).
    3. Hands that parser to ``f`` so it can add more options; ``f``'s return
       value is discarded.
    4. Parses the process command line.
    5. Calls ``init_logging`` with the chosen level.
    6. Returns the parsed namespace.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        arg_parser = ArgumentParser()
        level_choices = ['debug', 'info', 'warning', 'error', 'critical']
        arg_parser.add_argument(
            '--log-level', type=str, default='debug', choices=level_choices
        )
        f(arg_parser)
        config = arg_parser.parse_args()
        init_logging(level_name=config.log_level)
        return config

    return wrapper


def save_features(fpath: Path, data: np.ndarray):
    """Write a feature array to disk in the format implied by the file suffix.

    Parameters
    ----------
    fpath : pathlib.Path or str
        Destination file. Supported suffixes:

        * ``.npy``: written with ``np.save``.
        * ``.txt``: written with ``np.savetxt`` (whitespace-delimited).
        * ``.csv``: written with ``np.savetxt(..., delimiter=',')``.
    data : np.ndarray
        Array to write. The text formats go through ``np.savetxt``, which
        only accepts 1-D or 2-D arrays.

    Raises
    ------
    RuntimeError
        If the suffix is not one of the supported ones.
    """
    # Accept plain strings / PathLike as well; `.suffix` needs a Path.
    fpath = Path(fpath)
    if fpath.suffix == '.npy':
        np.save(fpath, data)
    elif fpath.suffix == '.txt':
        np.savetxt(fpath, data)
    elif fpath.suffix == '.csv':
        np.savetxt(fpath, data, delimiter=',')
    else:
        raise RuntimeError(f'Unsupported file type {fpath.suffix}')


def load_features(fpath: Path, *, allow_pickle: bool = True) -> np.ndarray:
    """Load a feature array, choosing the reader from the file suffix.

    Parameters
    ----------
    fpath : pathlib.Path or str
        File to read. Supported suffixes:

        * ``.npy``: read with ``np.load``.
        * ``.txt``: read with whitespace-delimited ``np.loadtxt``.
        * ``.csv``: read with comma-delimited ``np.loadtxt``.
    allow_pickle : bool, optional
        Forwarded to ``np.load`` for ``.npy`` files. Defaults to ``True``
        for backward compatibility; pass ``False`` for files that are not
        fully trusted.

    Returns
    -------
    np.ndarray
        The loaded data.

    Raises
    ------
    RuntimeError
        If the suffix is not one of the supported ones.
    """
    # Accept plain strings / PathLike as well; `.suffix` needs a Path.
    fpath = Path(fpath)
    suffix = fpath.suffix
    if suffix == '.npy':
        # SECURITY: with allow_pickle=True a crafted .npy file can execute
        # arbitrary code when loaded. Only enable it for trusted files.
        return np.load(fpath, allow_pickle=allow_pickle)
    if suffix == '.txt':
        return np.loadtxt(fpath)
    if suffix == '.csv':
        return np.loadtxt(fpath, delimiter=',')
    raise RuntimeError(f'Unsupported file type {fpath.suffix}')


def load_labels(path: Path) -> Tensor:
    """Read comma-separated integer labels into a 1-D ``int64`` tensor.

    Parameters
    ----------
    path : pathlib.Path, str or text file object
        Path of the labels file, or an already-open object with a
        ``read()`` method returning ``str``.

    Returns
    -------
    Tensor
        1-D tensor of dtype ``torch.long``. It is empty when the input has
        no labels.

    Raises
    ------
    ValueError
        If a non-blank field is not an integer.
    """
    import torch

    if hasattr(path, 'read'):
        text = path.read()
    else:
        # pathlib.Path has no .read(); read the file contents explicitly.
        text = Path(path).read_text(encoding='utf-8')
    # Skip blank fields so an empty file (as written by save_labels([])) or a
    # trailing comma does not break int(); int() itself tolerates whitespace.
    labels = [int(token) for token in text.split(',') if token.strip()]
    # Build the int64 tensor directly: going through a default float32
    # Tensor would silently round labels larger than 2**24.
    return torch.tensor(labels, dtype=torch.long)


def save_labels(data: List, path: Path):
    """Write labels to ``path`` as one comma-separated line (no trailing newline).

    Parameters
    ----------
    data : list or Tensor
        Label sequence. A torch ``Tensor`` is converted with ``tolist()``
        first. Otherwise its elements would be written as ``'tensor(1)'``
        instead of ``'1'``.
    path : pathlib.Path or str
        Destination file. It is overwritten.
    """
    if isinstance(data, Tensor):
        data = data.tolist()
    text = ','.join(str(label) for label in data)
    # write_text opens, writes and closes the file, so the data is flushed
    # before returning. The previous bare open().write() left the handle open.
    Path(path).write_text(text, encoding='utf-8')
