import logging
import typing as ty
from argparse import Namespace
from datetime import datetime
from pathlib import Path
from shutil import rmtree
from types import SimpleNamespace
from typing import Dict

import torch
from tqdm import tqdm

import workflows.config as wc
from apps.gnn_explainer import (
    explainer_main as explainer,
    configs as trainer_configs,
    train as trainer,
)
from egr.fsg import annotations

LOG = logging.getLogger(__name__)


def train(cfg: wc.WorkflowConfig, step_args: ty.Dict, **kw) -> ty.Tuple[bool, ty.Dict]:
    """Run the training step and report its wall-clock window.

    The namespace returned by ``trainer_configs.arg_parse()`` is overlaid
    with the parameters from ``cfg.train_input(**kw)`` and then with
    ``step_args``. The result is handed to ``trainer.main``.

    Args:
        cfg: workflow configuration that supplies the training inputs.
        step_args: extra attributes to set on the argument namespace.
        **kw: forwarded verbatim to ``cfg.train_input``.

    Returns:
        ``(True, {'train': {'begin': datetime, 'end': datetime}})``. The
        timestamps come from ``datetime.now()`` around ``trainer.main``.

    Raises:
        TypeError: if ``step_args`` shares a key with the training inputs,
            because both are unpacked into the same ``update`` call.
    """
    namespace = trainer_configs.arg_parse()
    namespace.__dict__.update(**cfg.train_input(**kw), **step_args)

    started_at = datetime.now()
    trainer.main(namespace)
    finished_at = datetime.now()

    window = {'begin': started_at, 'end': finished_at}
    return True, {'train': window}


def explain(cfg: wc.WorkflowConfig, step_args: ty.Dict, **kw) -> ty.Tuple[bool, ty.Dict]:
    """Run the explainer once per node and report the loop's wall-clock window.

    The training checkpoint at ``train_params['ckpt_path']`` is loaded once
    with ``torch.load``. Then, for every ``node_id`` in
    ``range(train_params['size'])``, the per-node parameters from
    ``cfg.explain_input`` are written onto a single reused argument
    namespace, together with ``step_args``, and ``explainer.do_explain`` is
    called.

    Args:
        cfg: workflow configuration that supplies the training and
            explanation inputs.
        step_args: extra attributes to set on the argument namespace. They
            are also forwarded to ``cfg.explain_input``.
        **kw: forwarded verbatim to ``cfg.train_input`` and
            ``cfg.explain_input``.

    Returns:
        ``(True, {'explain': {'begin': datetime, 'end': datetime}})``. The
        window covers only the per-node loop, not checkpoint loading.

    Raises:
        TypeError: if the per-node params share a key with ``step_args``,
            because both are unpacked into the same ``update`` call.
    """
    args = explainer.arg_parse()
    train_params: ty.Dict = cfg.train_input(**kw)
    # NOTE(review): torch.load unpickles the file; ckpt_path is assumed to
    # come from a trusted workflow config.
    ckpt: ty.Dict = torch.load(train_params['ckpt_path'])

    begin = datetime.now()
    # The context manager closes the progress bar even if do_explain raises.
    with tqdm(range(train_params['size'])) as pbar:
        for node_id in pbar:
            params = cfg.explain_input(node_id, **kw, **step_args)
            # ``args`` is reused across nodes: attributes set for an earlier
            # node persist unless this node's params overwrite them.
            args.__dict__.update(**params, **step_args)
            explainer.do_explain(args, ckpt)
            pbar.set_description(f'Finished node={node_id:04d}')
    end = datetime.now()
    return True, {'explain': {'begin': begin, 'end': end}}


def annotate(cfg: wc.WorkflowConfig, step_args: ty.Dict, **kw) -> ty.Tuple[bool, ty.Any]:
    """Run the annotation step, then delete its temporary directory.

    Builds an ``argparse.Namespace`` from ``cfg.annotate_input(**kw)`` merged
    with ``step_args`` and passes it to ``annotations.main``. It then
    recursively removes ``args.tmp_dir``.

    Args:
        cfg: workflow configuration that supplies the annotation inputs.
        step_args: extra attributes for the namespace. They are also logged
            at INFO level.
        **kw: forwarded verbatim to ``cfg.annotate_input``.

    Returns:
        ``(True, timings)``, where ``timings`` is whatever
        ``annotations.main`` returned.

    Raises:
        TypeError: if ``step_args`` shares a key with the annotation inputs,
            because both are unpacked into the same ``Namespace`` call.
    """
    LOG.info('step_args:%s', step_args)
    args = Namespace(**cfg.annotate_input(**kw), **step_args)
    timings = annotations.main(args)
    # NOTE(review): tmp_dir is only removed on success; if annotations.main
    # raises, the directory is left behind (possibly intentional, for
    # debugging). Confirm whether a try/finally cleanup is wanted.
    rmtree(args.tmp_dir)
    return True, timings
