from __future__ import annotations

import logging
import yaml
import typing as ty
from argparse import Namespace
from tempfile import mkdtemp
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional

import egr

# Module-level logger, registered under the fixed name 'config'.
LOG = logging.getLogger('config')


class RunConfig:
    """Paths and parameters for one run of a single sample and fold.

    The instance namespace is seeded with every instance attribute of the
    supplied workflow configuration and then overlaid with the keyword
    overrides.  After that overlay these attributes must exist, because the
    derived paths are computed from them: ``master_cfg``, ``variant``,
    ``fold``, ``iteration``, ``sample_id``, ``run_root`` and ``input_root``.
    Some properties additionally read ``gpu``, ``nproc``, ``run_id``,
    ``gaston_freq_threshold`` and ``data_dim``.
    """

    def __init__(self, flow_cfg: WorkflowConfig, **kw):
        """Copy the workflow attributes, apply overrides, derive all paths.

        Args:
            flow_cfg: Object whose ``__dict__`` supplies the base attributes.
            **kw: Per-run overrides; they win over ``flow_cfg`` attributes.

        Raises:
            KeyError: If ``variant`` is missing from ``master_cfg`` or its
                entry lacks ``details``/``total_size``.
            AttributeError: If a required attribute is absent after overlay.
        """
        self.__dict__.update(**flow_cfg.__dict__)
        self.__dict__.update(**kw)

        # Variant-specific settings come from the master configuration.
        self.variant_cfg = self.master_cfg[self.variant]['details']
        self.total_size: int = self.variant_cfg['total_size']

        # File names derived from the fold number and the variant size.
        self.idx_fname: str = f'{self.fold:02d}.json'
        self.feat_file_name: str = f'features-{self.total_size}.npy'

        # Locations below the run root (outputs of this workflow).
        self.variant_dir: Path = self.run_root / self.variant
        self.input_tag: str = f'r{self.iteration}'
        self.iteration_dir: Path = self.variant_dir / self.input_tag

        # Locations below the input root (pre-existing dataset files).
        self.dataset_dir: Path = self.input_root / self.variant
        self.index_path: Path = self.dataset_dir / 'indices' / self.idx_fname
        self.input_graph_path = self.dataset_dir / f'{self.sample_id}.json'
        self.input_label_path: Path = self.dataset_dir / 'labels.txt'

        # Per-sample, per-fold directories for this and the next iteration.
        self.input_dir: Path = self._io_dir(self.input_tag)
        self.output_path: Path = self._io_dir(self.output_tag)
        self.output_feature_path: Path = self.output_path / self.feat_file_name

        self.ckpt_path: Path = self.input_dir / f'{self.sample_id}.pt'
        self.train_dir: Path = self.input_dir / 'train'
        self.explain_dir = self.input_dir / 'explain'
        labels_file_name: str = f'predicted_labels-{self.sample_id}.txt'
        self.predicted_label_path: Path = self.input_dir / labels_file_name
        self.fsg_dir = self.input_dir / 'fsg'

    @property
    def output_tag(self) -> str:
        """Tag of the iteration that receives this run's outputs.

        Integer iterations advance by one (``r<n+1>``); any other iteration
        value keeps its text and gets a trailing underscore.
        """
        if isinstance(self.iteration, int):
            return f'r{self.iteration + 1}'
        return f'r{self.iteration}_'

    @property
    def train_input(self) -> ty.Dict:
        """Keyword arguments for the training step, including global ones."""
        return dict(
            index_file=self.index_path,
            input_graph_file=self.input_graph_path,
            input_label_file=self.input_label_path,
            input_feature_file=self.input_feature_path,
            ckpt_path=self.ckpt_path,
            logdir=self.train_dir,
            predicted_label_file=self.predicted_label_path,
            size=self.total_size,
            **self.global_attr,
        )

    def explain_input(self, node_id: int, **kw) -> ty.Dict:
        """Keyword arguments for explaining a single node.

        Args:
            node_id: Stored under the ``explain_node`` key.
            **kw: Extra entries merged into the result.  A key that repeats
                one of the fixed or global keys raises ``TypeError``.
        """
        return dict(
            ckpt_file=self.ckpt_path,
            logdir=self.explain_dir,
            explain_node=node_id,
            graph_idx=-1,
            graph_mode=False,
            multigraph_class=-1,
            output_type='json',
            sample_id=self.sample_id,
            nodes=list(range(self.total_size)),
            **kw,
            **self.global_attr,
        )

    @property
    def annotate_input(self) -> Dict:
        """Keyword arguments for the annotation step.

        NOTE: every access creates a new temporary directory via ``mkdtemp``
        for ``tmp_dir``; it is not removed here, so cleanup is left to the
        consumer of the returned mapping.
        """
        params = dict(
            index_file=self.index_path,
            data_root=self.explain_dir,
            output_feature_file=self.output_feature_path,
            input_graph_file=self.input_graph_path,
            input_label_file=self.input_label_path,
            run_id=self.run_id,
            output_dir=self.output_path,
            fsg_dir=self.fsg_dir,
            prev_fsg_dir=self.prev_fsg_dir,
            freq=self.gaston_freq_threshold,
            predicted_label_file=self.predicted_label_path,
            data_dim=self.data_dim,
            tmp_dir=Path(mkdtemp()),
            timeout_secs=None,
            **self.global_attr,
        )
        return params

    @property
    def input_feature_path(self) -> Path:
        """Feature file consumed by this iteration.

        Selection by ``iteration``:
            * ``0``   -> default features under the input root,
            * ``'L'`` -> ``label_features.npy`` in the dataset directory,
            * ``'R'`` -> ``random_``-prefixed features under the input root,
            * string starting with ``'P'`` -> per-sample file in a dataset
              sub-directory named after the iteration,
            * anything else -> the feature file inside ``input_dir``.
        """
        iteration = self.iteration
        if iteration == 0:
            return self.input_root / 'features/default' / self.feat_file_name
        if iteration == 'L':
            return self.dataset_dir / 'label_features.npy'
        if iteration == 'R':
            random_fname = f'random_{self.feat_file_name}'
            return self.input_root / 'features' / random_fname
        # startswith() instead of indexing so an empty string tag falls
        # through to the default location rather than raising IndexError.
        if isinstance(iteration, str) and iteration.startswith('P'):
            return self.dataset_dir / iteration / f'{self.sample_id}.npy'
        return self.input_dir / self.feat_file_name

    @property
    def global_attr(self) -> ty.Dict[str, ty.Any]:
        """Entries shared by every step's arguments: tag, gpu and nproc."""
        return dict(input_tag=self.input_tag, gpu=self.gpu, nproc=self.nproc)

    @property
    def prev_fsg_dir(self) -> Optional[Path]:
        """``fsg`` directory of the preceding iteration, if there is one.

        Returns ``None`` unless ``iteration`` is an integer greater than 0.
        """
        if isinstance(self.iteration, int) and self.iteration > 0:
            return self._io_dir(f'r{self.iteration - 1}') / 'fsg'
        return None

    def _io_dir(self, tag: str) -> Path:
        """Return ``<variant_dir>/<tag>/<sample_id>/<two-digit fold>``."""
        return self.variant_dir / tag / self.sample_id / f'{self.fold:02d}'


class WorkflowConfig:
    """Workflow-wide configuration merged from three YAML documents.

    The run defaults are overlaid with the run configuration and the merged
    keys become instance attributes.  The pattern master document is kept
    separately as ``master_cfg``.  Methods on this class build the paths and
    argument dictionaries for the individual workflow steps.
    """

    def __init__(self, args: Namespace):
        """Load the YAML inputs, merge them and resolve the root directories.

        Args:
            args: Parsed arguments exposing ``config``, ``run_defaults`` and
                ``pattern_master``; each must offer an ``open()`` method
                returning a context-managed stream (e.g. ``pathlib.Path``).

        Raises:
            KeyError: If the run configuration has no ``steps`` entry, or a
                step's ``type`` has no entry under the defaults' ``steps``.
        """
        configs: Dict = self._load_yaml(args.config)
        defaults: Dict = self._load_yaml(args.run_defaults)
        self.master_cfg: Dict = self._load_yaml(args.pattern_master)

        merged: ty.Dict = {**defaults}
        merged.update(**configs)
        # After the update, merged['steps'] is the very list held by
        # ``configs``; fill each step with the defaults of its type without
        # overriding keys the step already defines.
        for step in configs['steps']:
            for key, value in defaults['steps'][step['type']].items():
                step.setdefault(key, value)
        self.__dict__.update(**merged)

        # self.root_path = Path(self.root_path).expanduser().absolute()
        self.output_root: Path = Path(self.output_root).expanduser().absolute()
        self.run_root: Path = self.output_root / self.run_id
        data_root: Path = Path(self.input_data_root).expanduser().absolute()
        self.input_root: Path = data_root / 'input_data'
        LOG.info('Input:%s, Output:%s', self.input_root, self.run_root)
        if not hasattr(self, 'timeout_secs'):
            self.timeout_secs = None

    @staticmethod
    def _load_yaml(source) -> ty.Any:
        """Parse one YAML document from ``source`` and close the stream."""
        with source.open() as stream:
            return yaml.safe_load(stream)

    @property
    def folds(self) -> ty.List[int]:
        """Inclusive list of fold numbers described by the ``fold`` mapping.

        ``begin`` defaults to 1; ``end`` defaults to ``max``, which in turn
        defaults to 10.
        """
        begin: int = self.fold.get('begin', 1)
        end: int = self.fold.get('end', self.fold.get('max', 10))
        return list(range(begin, end + 1))

    def updated_params(
        self, variant: str, params: ty.Optional[Dict] = None
    ) -> SimpleNamespace:
        """Wrap step parameters in a namespace, ensuring ``node_range``.

        When ``params`` has no ``node_range``, one spanning ``0`` to the
        variant's ``total_size`` is supplied.  ``node_range`` itself is also
        converted to a namespace.  The caller's mapping is not modified.
        """
        data = dict(params or {})
        if 'node_range' not in data:
            end = self._master_details(variant).total_size
            data.update(node_range=dict(begin=0, end=end))
        data['node_range'] = SimpleNamespace(**data['node_range'])
        return SimpleNamespace(**data)

    def set_iteration(self, iteration: int | str):
        """Record the current iteration and its ``r<iteration>`` input tag."""
        self.iteration = iteration
        self.input_tag: str = f'r{self.iteration}'

    def _index_file_name(self, variant: str) -> str:
        """Configured ``index_file_name``, else ``indices-<size>.json``."""
        if not hasattr(self, 'index_file_name'):
            m = self._master_details(variant)
            return f'indices-{m.total_size}.json'
        return self.index_file_name

    def _index_path(self, variant: str) -> Path:
        """Index file location below ``root_path``."""
        # ``root_path`` is not converted in __init__ (the conversion there is
        # commented out), so it may be a plain string loaded from YAML.
        return Path(self.root_path) / self._index_file_name(variant)

    def train_input(self, **kw) -> Dict:
        """Training arguments from a ``RunConfig`` built with ``kw``."""
        cfg = RunConfig(self, **kw)
        return cfg.train_input

    def predicted_labels(self, variant: str, sample_id: str) -> Path:
        """Predicted-labels file for ``sample_id`` in the current iteration."""
        return (
            self.run_root
            / variant
            / self.input_tag
            / f'predicted_labels-{sample_id}.txt'
        )

    def explain_input(self, node_id: int, **kw: Dict) -> Dict:
        """Explain arguments for ``node_id``; ``kw`` configures the run."""
        cfg = RunConfig(self, **kw)
        return cfg.explain_input(node_id)

    def annotate_input(self, **kw: Dict) -> Dict:
        """Annotation arguments from a ``RunConfig`` built with ``kw``."""
        cfg = RunConfig(self, **kw)
        return cfg.annotate_input

    def _features_file_name(self, variant: str) -> str:
        """Configured ``features_file``, else ``features-<size>.npy``."""
        if not hasattr(self, 'features_file'):
            return f'features-{self._master_details(variant).total_size}.npy'
        return self.features_file

    def input_feature_file(self, variant: str, sample_id: str) -> Path:
        """Feature file consumed by the current iteration.

        Selection by ``iteration``:
            * ``0``   -> default features under the input root,
            * ``'L'`` -> ``label_features.npy`` in the variant's input dir,
            * ``'R'`` -> ``random_``-prefixed features under the input root,
            * string starting with ``'P'`` -> per-sample file in a variant
              sub-directory named after the iteration,
            * anything else -> the feature file in the sample's run dir.
        """
        variant_input = self.input_root / variant
        feat_filename = self._features_file_name(variant)
        iteration = self.iteration
        if iteration == 0:
            return self.input_root / 'features/default' / feat_filename
        if iteration == 'L':
            return variant_input / 'label_features.npy'
        if iteration == 'R':
            return self.input_root / 'features' / f'random_{feat_filename}'
        # startswith() instead of indexing so an empty string tag falls
        # through to the default location rather than raising IndexError.
        if isinstance(iteration, str) and iteration.startswith('P'):
            return variant_input / iteration / f'{sample_id}.npy'
        default_dir = self.run_root / variant / self.input_tag / sample_id
        return default_dir / feat_filename

    def output_feature_file(self, variant: str, sample_id: str) -> Path:
        """Feature file written for the next iteration, below the run root."""
        return self.h_path(self.run_root, variant, sample_id)

    def _resolved_output_tag(self) -> str:
        """Return the output tag, deriving it when none was configured.

        An ``output_tag`` attribute (e.g. loaded from the configuration) is
        used as-is.  Otherwise the tag follows the same rule as
        ``RunConfig.output_tag``: integer iterations advance by one, any
        other iteration keeps its text plus a trailing underscore.
        """
        if hasattr(self, 'output_tag'):
            return self.output_tag
        if isinstance(self.iteration, int):
            return f'r{self.iteration + 1}'
        return f'r{self.iteration}_'

    def h_path(self, root: Path, variant: str, sample_id: str) -> Path:
        """Feature file path for the output iteration below ``root``."""
        h_dir = root / variant / self._resolved_output_tag() / sample_id
        return h_dir / self._features_file_name(variant)

    def histogram_input(self, variant: str, sample_id: str) -> Dict:
        """Keyword arguments for the histogram step."""
        return dict(
            run_id=self.run_id,
            explainer_outputs=self.explain_dir(variant, sample_id),
            hist_output_dir=self.histogram_dir(variant, sample_id),
            variant=variant,
            sample_id=sample_id,
            ba_count=self.ba_count,
            input_graph_path=self.input_graph_path(variant, sample_id),
            input_label_path=self.input_label_path(variant),
            input_feature_path=self.input_feature_file(variant, sample_id),
            # NOTE(review): random_count reuses ba_count; confirm that this
            # is intended and not a copy-paste slip.
            random_count=self.ba_count,
        )

    def explain_subgraph_dir(self, variant: str, sample_id: str) -> Path:
        """``subgraph`` directory inside the sample's explain directory."""
        return self.explain_dir(variant, sample_id) / 'subgraph'

    def explain_feature_dir(self, variant: str, sample_id: str) -> Path:
        """``feature`` directory inside the sample's explain directory."""
        return self.explain_dir(variant, sample_id) / 'feature'

    def explain_dir(self, variant: str, sample_id: str) -> Path:
        """``explain`` directory of the sample in the current iteration."""
        return self.run_root / variant / self.input_tag / sample_id / 'explain'

    def fsg_dir(self, variant: str, sample_id: str) -> Path:
        """``fsg`` directory of the sample in the current iteration."""
        return self.run_root / variant / self.input_tag / sample_id / 'fsg'

    def prev_fsg_dir(self, variant: str, sample_id: str) -> Optional[Path]:
        """``fsg`` directory of the preceding iteration, if there is one.

        Returns ``None`` unless ``iteration`` is an integer greater than 0.
        """
        if isinstance(self.iteration, int) and self.iteration > 0:
            return (
                self.run_root
                / variant
                / f'r{self.iteration - 1}'
                / sample_id
                / 'fsg'
            )
        return None

    def histogram_dir(self, variant: str, sample_id: str) -> Path:
        """``hist`` directory of the variant; ``sample_id`` is not used."""
        return self.run_root / variant / self.input_tag / 'hist'

    def input_path(self, variant: str, sample_id: str) -> Path:
        """``<stem>.json`` inside the variant's ``input`` directory."""
        return self.input_dir(variant) / f'{self.stem(sample_id)}.json'

    def input_graph_path(self, variant: str, sample_id: str) -> Path:
        """``<sample_id>.json`` in the variant's directory of the input root."""
        return self.input_root / variant / f'{sample_id}.json'

    def input_label_path(self, variant: str) -> Path:
        """``labels.txt`` in the variant's directory of the input root."""
        return self.input_root / variant / 'labels.txt'

    def input_dir(self, variant: str) -> Path:
        """``input`` sub-directory of the variant below the input root."""
        return self.input_root / variant / 'input'

    def ckpt_path(self, variant: str, sample_id: str) -> Path:
        """``<sample_id>.pt`` in the variant's current iteration directory."""
        return self.run_root / variant / self.input_tag / f'{sample_id}.pt'

    def train_dir(self, variant: str, sample_id: str) -> Path:
        """``train`` directory of the sample in the current iteration."""
        return self.run_root / variant / self.input_tag / sample_id / 'train'

    def stem(self, sample_id: str) -> str:
        """File stem ``<data_prefix>-<sample_id>``."""
        return f'{self.data_prefix}-{sample_id}'

    def _master_details(self, variant: str) -> SimpleNamespace:
        """The variant's ``details`` mapping wrapped in a namespace."""
        return SimpleNamespace(**self.master_cfg[variant]['details'])

    def __str__(self) -> str:
        """Render the instance attributes via ``egr.util.to_string``."""
        return egr.util.to_string(self.__dict__)
