import collections
import dataclasses
import json
import logging
import types
import typing as ty
from pathlib import Path
from typing import Dict, List
from IPython.display import display, HTML, Markdown
from matplotlib import pyplot as plt

import numpy as np
import pandas as pd

import egr.util as eu
import egr.pandas_helpers as ep
import nb.reports.vis_helpers as vh

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Human-readable description of each round/run tag. Used further down as the
# row labels of the metrics tables and as the confusion-matrix titles.
# NOTE(review): 'rP2' and 'rP4' both map to 'Adversarial' -- confirm that the
# duplicated label is intentional.
TAG_DESC = {
    'rR': 'Random',
    'rL': 'Label Encoded',
    'rP1': 'Arbitrary',
    'rP2': 'Adversarial',
    'rP4': 'Adversarial',
    'r0': '(Vanilla) Round-0',
    'r1': 'Round-1',
    'r2': 'Round-2',
    'r3': 'Round-3',
    'r4': 'Round-4',
    'r5': 'Round-5',
    'r6': 'Round-6',
    'r7': 'Round-7',
    'r8': 'Round-8',
}

# Default spec file for get_variant_spec(). The relative path is made
# absolute against the working directory in effect when this module is
# imported, so the result depends on where the notebook/process was started.
PATTERN_MASTER_FILE = (
    Path('../../run_configs/pattern_master.yml').expanduser().absolute()
)


def round_metrics(round_dir: Path, samples: List, fold: int) -> Dict:
    """Collect the test metrics of every sample for one round and fold.

    For each sample id the file
    ``<round_dir>/<sid>/<fold:02d>/train/test_metrics.json`` is read when it
    exists; a missing file is logged as a warning and skipped.

    Args:
        round_dir: Directory of the round whose samples are read.
        samples: Sample ids (sub-directory names below ``round_dir``).
        fold: Fold number; rendered as a zero-padded two-digit directory.

    Returns:
        A dict mapping every top-level key found in the metrics files to the
        list of its values, one entry per sample file that contained the
        key, in the order of ``samples``. Empty when no file was found.
    """
    suffix = 'train/test_metrics.json'
    data: Dict = {}
    for sid in samples:
        metrics_file = round_dir / sid / f'{fold:02d}' / suffix
        if not metrics_file.exists():
            LOG.warning('file %s does not exist', metrics_file)
            continue
        # Use a context manager so the handle is closed deterministically
        # rather than whenever the garbage collector gets to it.
        with metrics_file.open() as fh:
            metrics = json.load(fh)
        for key, value in metrics.items():
            data.setdefault(key, []).append(value)
    return data


def make_metrics_table(
    dirpath: Path,
    samples: List,
    variants: List,
    rounds: List,
    fold: int = 0,
):
    """Build one metrics table per variant.

    Each table is produced by ``make_table`` and therefore has the columns
    'Precision', 'Recall' and 'F1-Score' (in percent), one row per round
    for which data was found.

    Args:
        dirpath: Root directory holding one sub-directory per variant.
        samples: Sample ids to average over.
        variants: Variant names; each becomes a key of the result.
        rounds: Round tags (keys of ``TAG_DESC``).
        fold: Fold number forwarded to ``make_table``.
            NOTE(review): the default of 0 only keeps four-argument calls
            working -- confirm which fold callers expect and pass it
            explicitly.

    Returns:
        A dict mapping each variant name to its ``pandas.DataFrame``.
    """
    # The per-round aggregation is delegated to make_table() so that the
    # required `fold` argument is forwarded to round_metrics() and rounds
    # without any data are skipped in the same way for both entry points.
    # No 'Accuracy' column is produced: nothing in this module reads an
    # accuracy value from the metrics files, and an unfilled column makes
    # DataFrame construction fail because of mismatched column lengths.
    return {
        var: make_table(dirpath, samples, var, rounds, fold)
        for var in variants
    }


def make_table(
    dirpath: Path, samples: List, variant: str, rounds: List, fold: int
):
    """Average precision, recall and F1 over the samples of each round.

    Rounds for which ``round_metrics`` finds no data are reported with a
    warning and left out of the table.

    Args:
        dirpath: Root directory holding one sub-directory per variant.
        samples: Sample ids to average over.
        variant: Variant name (sub-directory of ``dirpath``).
        rounds: Round tags (keys of ``TAG_DESC``).
        fold: Fold number forwarded to ``round_metrics``.

    Returns:
        A ``pandas.DataFrame`` with the columns 'Precision', 'Recall' and
        'F1-Score', multiplied by 100, indexed by the round descriptions.
    """
    # (table column, key in the metrics file) in output column order.
    metric_keys = (
        ('Precision', 'prec'),
        ('Recall', 'recall'),
        ('F1-Score', 'f1_score'),
    )
    averages = {column: [] for column, _ in metric_keys}
    kept_rounds = []
    for tag in rounds:
        per_sample = round_metrics(dirpath / variant / tag, samples, fold)

        if not per_sample:
            LOG.warning(
                'No data found for variant:%s, samples:%s, round:%s',
                variant,
                ','.join(samples),
                tag,
            )
            continue

        kept_rounds.append(tag)
        for column, key in metric_keys:
            averages[column].append(np.mean(per_sample[key]))

    # Only rounds that produced data get a row, so the index is built last.
    row_labels = [TAG_DESC[tag] for tag in kept_rounds]
    table = pd.DataFrame(data=averages, index=row_labels)
    return table * 100


def to_latex(
    dirpath: ty.Union[str, Path],
    variant: str,
    samples: List,
    rounds: List,
    fold: int = 0,
):
    """Print the metrics table of one variant as LaTeX source.

    Args:
        dirpath: Root directory holding one sub-directory per variant; a
            plain string is converted to a ``Path``.
        variant: Variant name; also used (underscores replaced by spaces,
            capitalised) as the table caption.
        samples: Sample ids to average over.
        rounds: Round tags (keys of ``TAG_DESC``).
        fold: Fold number forwarded to ``make_table``.
            NOTE(review): the default of 0 only keeps four-argument calls
            working -- confirm which fold callers expect and pass it
            explicitly.
    """
    # make_table() requires the fold and joins `dirpath` with the `/`
    # operator, so forward the fold and normalise the path type here.
    df = make_table(Path(dirpath), samples, variant, rounds, fold)
    name = variant.replace('_', ' ').capitalize()
    print(df.style.format('{:.2f}').to_latex(caption=name, hrules=True))


def to_pretty(
    dirpath: ty.Union[str, Path],
    variant: str,
    samples: List,
    rounds: List,
    fold: int,
):
    """Display the metrics table of one variant as an HTML table.

    Args:
        dirpath: Root directory holding one sub-directory per variant; a
            plain string is converted to a ``Path``.
        variant: Variant name (sub-directory of ``dirpath``).
        samples: Sample ids to average over.
        rounds: Round tags (keys of ``TAG_DESC``).
        fold: Fold number forwarded to ``make_table``.
    """
    # make_table() joins `dirpath` with the `/` operator, which a plain
    # string does not support; normalise the type before delegating.
    df = make_table(Path(dirpath), samples, variant, rounds, fold)
    display(HTML(df.to_html()))


def show_table(
    dirpath: ty.Union[str, Path],
    variant: str,
    samples: List,
    rounds: List,
    fold: int,
):
    """Display a Markdown fold heading followed by the metrics table.

    Args:
        dirpath: Root directory holding one sub-directory per variant; a
            plain string is converted to a ``Path``.
        variant: Variant name (sub-directory of ``dirpath``).
        samples: Sample ids to average over.
        rounds: Round tags (keys of ``TAG_DESC``).
        fold: Fold number shown in the heading and used to pick the data.
    """
    # Zero-padded like every other fold rendering in this module; a
    # space-padded width ('{fold:2d}') would render as 'Fold- 3'.
    display(Markdown(f'### Fold-{fold_tag(fold)}'))
    to_pretty(Path(dirpath), variant, samples, rounds, fold)


def show_latex(dirpath: str, variant: str, samples: List, rounds: List):
    """Print the LaTeX metrics table of one variant via ``to_latex``."""
    to_latex(
        dirpath=dirpath,
        variant=variant,
        samples=samples,
        rounds=rounds,
    )


def fold_tag(fold: int) -> str:
    """Return ``fold`` as a zero-padded, at least two-digit string."""
    return format(fold, '02d')


def get_index_file(variant: str, fold: int, root_path: Path) -> Path:
    """Return the path of the index file for a variant and fold.

    The file is ``<root_path>/indices/<total_size>/<fold:02d>.json`` where
    ``total_size`` is read from the ``details`` of the variant's spec.

    Raises:
        AssertionError: if the spec has no positive ``total_size`` or the
            resulting file does not exist.
    """
    details = get_variant_spec(variant).details
    total_size = details.get('total_size', 0)
    assert total_size > 0
    relative = f'indices/{total_size}/{fold:02d}.json'
    index_path = root_path / relative
    assert index_path.exists(), f'{index_path} not found'
    return index_path


def get_label_file(variant: str, root_path: Path) -> Path:
    """Return ``<root_path>/input_data/<variant>/labels.txt``."""
    return root_path.joinpath('input_data', variant, 'labels.txt')


@dataclasses.dataclass
class ReportArgs:
    """Arguments shared by the report/plot helpers, plus derived paths."""

    # Root directory of the run outputs.
    dirpath: Path
    # Variant name (sub-directory of ``dirpath``).
    variant: str
    # Round/run tags to report on.
    tags: ty.List[str]
    # Sample ids.
    samples: ty.List[str]
    # Fold number; rendered zero-padded to two digits in paths.
    fold: int
    # When set, plot_conf_mats() draws the fixed tags r0, r1, r2 only.
    eegl_only: bool = False
    # When set, plot_conf_mats() also writes the figure below ``vis_root``.
    save: bool = False
    # Sample id whose results are visualised.
    vis_sample: str = '0001'

    @property
    def metrics_paths(self) -> ty.List[Path]:
        """Test-metrics file of the visualised sample for every tag.

        NOTE(review): this previously called ``self.metrics_path``, which
        is not defined on this class; ``vis_metrics_path`` is the only
        metrics-path method here -- confirm it is the intended one.
        """
        return [self.vis_metrics_path(tag) for tag in self.tags]

    @property
    def vis_base_dir(self) -> Path:
        """``<dirpath>/<variant>/<vis_sample>/<fold>`` (no tag component)."""
        return self.dirpath / self.variant / self.vis_sample / self.fold_str

    @property
    def fold_str(self) -> str:
        """The fold number as a zero-padded two-digit string."""
        return f'{self.fold:02d}'

    @property
    def vis_root(self) -> Path:
        """``<dirpath>/vis/<variant>/<vis_sample>``, where figures go."""
        return self.dirpath / 'vis' / self.variant / self.vis_sample

    def vis_metrics_path(self, tag: str) -> Path:
        """Test-metrics JSON file of the visualised sample for ``tag``."""
        return self.vis_fold_dir(tag) / 'train/test_metrics.json'

    def vis_fold_dir(self, tag: str) -> Path:
        """``<dirpath>/<variant>/<tag>/<vis_sample>/<fold>``."""
        return (
            self.dirpath / self.variant / tag / self.vis_sample / self.fold_str
        )


def plot_conf_mats(
    args: ReportArgs,
    fontsize: int = 9,
    figsize: ty.Tuple[int, int] = (10, 4),
    cmap: str = 'viridis_r',
) -> None:
    """Plot the confusion matrix of each tag on a three-column grid.

    The matrices are read from ``args.vis_metrics_path(tag)`` (key
    ``'conf_mat'``). A tag whose metrics file is missing is logged and its
    grid cell is left empty. When ``args.eegl_only`` is set the fixed tags
    r0, r1 and r2 are drawn instead of ``args.tags``. When ``args.save`` is
    set the figure is written to
    ``<args.vis_root>/cm_fold-<fold>[_eegl_only].pdf`` before it is shown.

    Args:
        args: Report arguments (paths, tags, fold, flags).
        fontsize: Font size forwarded to ``vh.show_cm``.
        figsize: Size of the whole figure.
        cmap: Colormap name forwarded to ``vh.show_cm``.
    """
    ncols: int = 3
    tags = ['r0', 'r1', 'r2'] if args.eegl_only else args.tags
    # Ceiling division, with at least one row so that an empty tag list
    # does not request a zero-row grid.
    nrows: int = max(1, -(-len(tags) // ncols))
    # squeeze=False always yields a 2-D array of axes; without it a
    # single-row grid comes back 1-D and axs[row][col] fails.
    _, axs = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False
    )
    for i, tag in enumerate(tags):
        path = args.vis_metrics_path(tag)
        if not path.exists():
            LOG.warning('File %s does not exist, skipping', path)
            continue
        data = eu.read_json(path)
        df = pd.DataFrame(data['conf_mat'])
        row, col = divmod(i, ncols)
        vh.show_cm(
            ax=axs[row][col],
            df=df,
            cmap=cmap,
            title=TAG_DESC[tag],
            fontsize=fontsize,
        )
    # One layout pass once every matrix has been drawn.
    plt.tight_layout()
    if args.save:
        args.vis_root.mkdir(parents=True, exist_ok=True)
        prefix: str = f'cm_fold-{args.fold:02d}'
        name = f'{prefix}_eegl_only.pdf' if args.eegl_only else f'{prefix}.pdf'
        fig_path: Path = args.vis_root / name
        plt.savefig(fig_path)
    plt.show()


def plot_all_cm(
    run_dir: Path,
    variant: str,
    tags: List[str],
    samples: List[str],
    title: str,
    fold: int,
    fontsize: int = 12,
    figsize: ty.Tuple[int, int] = (7, 3),
    cmap: str = 'viridis_r',
):
    """Show, per tag, the sample-averaged confusion matrices.

    For every tag the files
    ``<run_dir>/<variant>/<tag>/<sample>/<fold:02d>/train/test_metrics.json``
    are read; their ``'conf_mat'`` and ``'conf_mat_normalized_all'``
    entries are combined with ``ep.mean_df`` and drawn side by side
    ('Counts' and 'Normalized'). Missing files are logged and skipped; a
    tag without any file is logged and not plotted.

    Args:
        run_dir: Root directory of the run outputs.
        variant: Variant name (sub-directory of ``run_dir``).
        tags: Round/run tags; one figure is shown per tag.
        samples: Sample ids to average over.
        title: Prefix of each figure title (``'<title>: <tag>'``).
        fold: Fold number; rendered as a zero-padded two-digit directory.
        fontsize: Font size of the 'Counts' panel; the 'Normalized' panel
            uses ``fontsize - 2``.
        figsize: Size of each figure.
        cmap: Colormap name forwarded to ``vh.show_cm``.
    """
    for tag in tags:
        df_cms, df_cm_norms = [], []
        for sample in samples:
            p = (
                run_dir
                / variant
                / tag
                / sample
                / f'{fold:02d}'
                / 'train/test_metrics.json'
            )
            if not p.exists():
                LOG.warning('file %s does not exist', p)
                continue

            data = eu.read_json(p)
            df_cms.append(pd.DataFrame(data['conf_mat']))
            df_cm_norms.append(pd.DataFrame(data['conf_mat_normalized_all']))

        if not df_cms:
            # Nothing to average for this tag; skip it rather than hand an
            # empty list to ep.mean_df() and plot a meaningless figure.
            LOG.warning(
                'no metrics found for variant:%s, tag:%s, skipping plot',
                variant,
                tag,
            )
            continue

        df_cm = ep.mean_df(df_cms)
        df_cm_norm = ep.mean_df(df_cm_norms)

        # plt.subplots() creates a fresh figure, so no plt.clf() is needed;
        # calling it with no current figure would create an extra empty one.
        _, ax = plt.subplots(nrows=1, ncols=2, figsize=figsize)
        vh.show_cm(
            ax=ax[0],
            df=df_cm,
            cmap=cmap,
            title='Counts',
            fontsize=fontsize,
        )
        vh.show_cm(
            ax=ax[1],
            df=df_cm_norm,
            cmap=cmap,
            title='Normalized',
            fontsize=fontsize - 2,
            fmt='.2f',
        )
        suptitle: str = f'{title}: {tag}'
        plt.suptitle(suptitle, fontsize=12, fontweight='bold')
        plt.tight_layout()
        plt.show()


# def make_cm_df()


def get_variant_spec(
    name: str, spec_path: Path = PATTERN_MASTER_FILE
) -> types.SimpleNamespace:
    """Load the spec of one variant from a YAML spec file.

    Args:
        name: Top-level key of the variant in the YAML file.
        spec_path: YAML file to read; a plain string is converted to a
            ``Path``.

    Returns:
        A ``types.SimpleNamespace`` whose attributes are the entries stored
        under ``name``; it has no attributes when the file is empty or the
        key is missing or null.

    Raises:
        AssertionError: if ``spec_path`` does not exist.
    """
    # Kept as a function-local import, as before, so importing this module
    # does not itself require yaml.
    import yaml

    spec_path = Path(spec_path)
    assert spec_path.exists(), f'{spec_path} does not exist'
    # Close the file deterministically; safe_load() returns None for an
    # empty document, hence the `or {}` fallbacks.
    with spec_path.open() as fh:
        specs = yaml.safe_load(fh) or {}
    return types.SimpleNamespace(**(specs.get(name) or {}))


def count_labels(labels: ty.List[int]) -> ty.Dict[int, int]:
    """Count how often each label occurs.

    Args:
        labels: Sequence of hashable labels.

    Returns:
        A plain dict mapping each distinct label to its number of
        occurrences, with keys in order of first appearance in ``labels``.
    """
    # Counter does the tallying; convert back to a plain dict so callers
    # keep receiving the same type as before.
    return dict(collections.Counter(labels))


def make_label_stats(
    v: str, f: int, root_dir: Path, group: str = 'test'
) -> ty.Tuple[pd.DataFrame, str]:
    """Tabulate how often each label occurs in one index group of a fold.

    Args:
        v: Variant name.
        f: Fold number.
        root_dir: Root directory containing the index and label files (see
            ``get_index_file`` and ``get_label_file``).
        group: Key of the index group to select from the loaded indices.

    Returns:
        A tuple ``(table, title)``: a one-row ``pandas.DataFrame`` (row
        'Frequency') with one column per distinct label, named by the
        label's string form, and the title ``'Label Frequencies'``.
    """
    indices = eu.load_indices(get_index_file(v, f, root_dir))[group]
    # Keep the labels file open only while it is being read.
    # NOTE(review): assumes eu.load_labels reads the handle eagerly; the
    # .numpy() conversion is done inside the block as well to be safe.
    with get_label_file(v, root_dir).open() as fh:
        labels = eu.load_labels(fh).numpy()
    group_labels: np.ndarray = np.sort(labels[indices])
    count_data = count_labels(group_labels.tolist())
    # Distinct loop names: the parameter `v` is not reused here.
    data = {str(label): [count] for label, count in count_data.items()}
    return pd.DataFrame(data=data, index=['Frequency']), 'Label Frequencies'
