import logging
from pathlib import Path
from typing import List, Optional, Tuple

import IPython.display as disp
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns

from egr import util
import egr.util as eu

# Module-scoped logger: records carry this module's name and still propagate
# to the root logger's handlers, but configuring LOG no longer reconfigures
# the root logger for the whole process.
LOG = logging.getLogger(__name__)


# Color sequence of matplotlib's qualitative 'Set1' colormap; visualize()
# indexes it by class label to pick each node's color.
COLORS = plt.get_cmap('Set1').colors


def visualize(
    G: nx.Graph,
    class_labels: Optional[List] = None,
    node_size=500,
    font_size=8,
    figsize: Optional[Tuple[int, int]] = None,
):
    """Draw ``G`` with every node annotated by its id and class label.

    Nodes are placed with the Kamada-Kawai layout and colored from
    ``COLORS`` according to their class label.

    Args:
        G: Graph to draw.
        class_labels: Integer class index per node, in ``G.nodes()``
            iteration order. ``None`` or an empty sequence puts every node
            in class 0. Array-like inputs (e.g. numpy arrays) are accepted.
        node_size: Marker size forwarded to ``nx.draw``.
        font_size: Label font size forwarded to ``nx.draw``.
        figsize: Figure size in inches; defaults to ``(10, 10)``.
    """
    figsize = figsize or (10, 10)
    plt.figure(figsize=figsize)
    pos = nx.kamada_kawai_layout(G)

    # Explicit None/len test instead of `class_labels or ...`: truth-testing
    # a multi-element numpy array raises "truth value is ambiguous".
    if class_labels is None or len(class_labels) == 0:
        class_labels = [0] * G.number_of_nodes()

    labels = {
        node_id: f'id:{node_id}\nc:{class_labels[idx]}'
        for idx, node_id in enumerate(G.nodes())
    }
    # Wrap around the palette so graphs with more classes than colors are
    # still drawn (colors repeat) instead of raising IndexError.
    palette_size = len(COLORS)
    node_color = [COLORS[cls % palette_size] for cls in class_labels]

    nx.draw(
        G,
        pos=pos,
        label='label',
        labels=labels,
        node_color=node_color,
        node_size=node_size,
        font_size=font_size,
        width=1.25,
    )
    plt.show()


def show_features(X: np.ndarray):
    """Show the matrix ``X`` as a greyscale image on the current figure.

    The current pyplot figure is cleared first; the image is drawn without
    interpolation and stretched to fill the axes.
    """
    plt.clf()
    plt.imshow(
        X,
        cmap=plt.get_cmap('Greys'),
        aspect='auto',
        interpolation='none',
    )
    plt.tight_layout()
    plt.show()


def fsg_tag(
    data_dir: Path,
    r: int,
    axs,
    num_features: int = 10,
    show_metrics: bool = False,
):
    """Draw the graphs stored under ``data_dir`` into row ``r`` of ``axs``.

    JSON files are discovered recursively, sorted by path, and at most the
    first ``num_features`` of them are loaded and drawn, one per column.
    The node equal to ``G.graph['__root__']`` is drawn red, all others grey.

    Args:
        data_dir: Directory searched recursively for ``*.json`` graph files.
        r: Row index into ``axs``; also used for the row label
            ``R{r} -> R{r + 1}``.
        axs: 2-D indexable grid of matplotlib axes (``axs[row, col]``) with
            at least ``num_features`` columns.
        num_features: Maximum number of graphs (columns) to draw.
        show_metrics: When true, append precision/recall/accuracy/F1 from
            the graph attributes to the column title.

    Raises:
        KeyError: If a drawn graph lacks ``'__root__'``, lacks ``'label'``
            on the first row, or lacks a metric key when ``show_metrics``
            is set on the first row.
    """
    label = f'R{r} $\\rightarrow$ R{r + 1}'
    axs[r, 0].set_ylabel(
        label,
        rotation=0,
        fontsize=16,
        fontweight='bold',
        labelpad=50,
        loc='center',
    )
    # Only load the files that will actually be drawn.
    paths = sorted(data_dir.rglob('*.json'))[:num_features]
    graphs = [util.load_graph(path) for path in paths]
    for idx, G in enumerate(graphs):
        colors = [
            'red' if n == G.graph['__root__'] else 'grey' for n in G.nodes
        ]
        ax = axs[r, idx]
        nx.draw(
            G,
            ax=ax,
            node_size=200,
            width=0.5,
            node_color=colors,
        )
        if r == 0:
            # Column titles are drawn on the top row only, so the label and
            # metric attributes are read only when they are needed.
            title = G.graph['label']
            fontsize: int = 16
            fontweight: str = 'bold'
            if show_metrics:
                acc = G.graph['accuracy']
                prec = G.graph['precision']
                recall = G.graph['recall']
                f1 = G.graph['f1_score']
                title = f'{title}' + (
                    f'\nP:{prec:.4f},  R:{recall:.4f}\nA:{acc:.4f},  F1:{f1:.4f}'
                )
                fontsize = 11
                fontweight = 'normal'
            ax.set_title(title, fontsize=fontsize, fontweight=fontweight)
        # nx.draw switches the axes frame off; turn it back on so every
        # cell of the grid is framed.
        ax.axis('on')


# Zero-padded sample identifiers '0001' .. '0005'.
all_samples = [f'{i:04d}' for i in range(1, 6)]
# control_rounds = ['rR', 'rL', 'rP2']
# eegl_rounds = [f'r{i}' for i in range(3)]
# all_rounds = control_rounds + eegl_rounds
# Root directory of the result tree (plain literal: there is nothing to
# interpolate, so no f-string).
data_root: Path = Path('/data/results')
# Number of graphs drawn per row; it also fixes the column count of the grid
# built in fsg().
num_features: int = 10
ncols: int = num_features


def fsg(run_id: str, tags, variant, sample, fold, save: bool = False):
    """Plot a grid of graphs with one row per tag and ``ncols`` columns.

    For each tag, graphs are read from
    ``/data/results/{run_id}/{variant}/{tag}/{sample}/{fold:02d}/fsg`` and
    drawn by ``fsg_tag`` into the matching row.

    Args:
        run_id: Name of the run directory below ``/data/results``.
        tags: Sized sequence of tag directory names; one grid row per tag.
        variant: Path component below the run directory.
        sample: Path component below the tag directory.
        fold: Integer fold number, formatted as two digits in paths.
        save: When true, also write the figure to
            ``/data/results/{run_id}/vis/{variant}/{sample}/fsg_fold-{fold:02d}.pdf``,
            creating the directory if needed.
    """
    # Named run_root so it does not shadow the module-level data_root.
    run_root: Path = Path(f'/data/results/{run_id}')
    vis_dir: Path = run_root / 'vis' / variant / sample

    plt.clf()
    nrows = len(tags)
    # squeeze=False keeps axs 2-D even for a single row (or column);
    # otherwise axs[0, 0] below fails when only one tag is given.
    fig, axs = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(21.875, nrows * 2.1875),
        squeeze=False,
    )
    plt.figtext(
        -0.8,
        1.12,
        'Labels:',
        ha='left',
        va='top',
        transform=axs[0, 0].transAxes,
        fontweight='bold',
        fontsize=14,
    )

    for idx, tag in enumerate(tags):
        data_dir = run_root / variant / tag / sample / f'{fold:02d}' / 'fsg'
        # Tie the number of drawn graphs to the grid width so fsg_tag can
        # never index past the last column.
        fsg_tag(data_dir, idx, axs, num_features=ncols)
    fig.tight_layout()
    if save:
        save_path: Path = vis_dir / f'fsg_fold-{fold:02d}.pdf'
        vis_dir.mkdir(parents=True, exist_ok=True)
        # Save through the figure object rather than pyplot's implicit
        # "current figure".
        fig.savefig(save_path)
    fig.show()


def draw_features(path: str | Path, ax):
    """Render the feature matrix stored at ``path`` as a greyscale image.

    The matrix is loaded with ``util.load_features``. One tick is placed per
    column and per row, the file stem (capitalized) becomes the title, and
    the color scale is pinned to the range [0, 1].

    Args:
        path: Location of the feature file, as a string or ``Path``.
        ax: Matplotlib axes to draw on.
    """
    feature_path = path if isinstance(path, Path) else Path(path)
    matrix = util.load_features(feature_path)

    column_ticks = list(range(matrix.shape[1]))
    ax.set_xticks(column_ticks)
    row_ticks = list(range(matrix.shape[0]))
    ax.set_yticks(row_ticks)

    ax.set_title(feature_path.stem.capitalize())
    greys = plt.get_cmap('Greys')
    ax.imshow(
        matrix,
        aspect='auto',
        vmin=0,
        vmax=1,
        cmap=greys,
        interpolation='none',
    )


def show_cm(
    ax,
    df: pd.DataFrame,
    title: str,
    cmap: str = 'viridis_r',
    fontsize: int = 10,
    fmt: str = '.2g',
):
    """Draw ``df`` as an annotated square-cell heatmap on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        df: Matrix to display; every cell is annotated with its value.
        title: Axes title, drawn bold at ``fontsize + 2``.
        cmap: Colormap name passed to seaborn.
        fontsize: Font size of the cell annotations.
        fmt: Format spec used for the cell annotations.

    Returns:
        The axes object returned by ``sns.heatmap``.
    """
    # NOTE: this changes pandas' global display precision as a side effect.
    pd.set_option('display.precision', 2)
    heatmap_ax = sns.heatmap(
        df,
        ax=ax,
        cmap=cmap,
        annot=True,
        annot_kws={'fontsize': fontsize},
        fmt=fmt,
        cbar=False,
        linewidths=1,
        square=True,
    )
    heatmap_ax.set_title(title, fontsize=fontsize + 2, fontweight='bold')
    return heatmap_ax


def show_cm_for_file(path: Path, ax, cmap: str = 'viridis_r'):
    """Plot the raw and normalized confusion matrices stored in a JSON file.

    Args:
        path: JSON file read with ``eu.read_json``; it must provide the keys
            ``'conf_mat'`` and ``'conf_mat_normalized_all'``.
        ax: Indexable pair of axes; ``ax[0]`` receives the counts and
            ``ax[1]`` the normalized matrix.
        cmap: Colormap name forwarded to ``show_cm``.
    """
    payload = eu.read_json(path)
    # Build both frames before drawing so a missing key fails up front.
    frames = [
        pd.DataFrame(payload[key])
        for key in ('conf_mat', 'conf_mat_normalized_all')
    ]
    headings = ('Counts', 'Normalized')
    for slot, (frame, heading) in enumerate(zip(frames, headings)):
        show_cm(ax=ax[slot], df=frame, cmap=cmap, title=heading)


# HTML attribute string attached to every table rendered by
# show_side_by_size(); `display:inline` lets consecutive tables share a row.
TABLE_STYLE = "style='display:inline;padding-left:15px;vertical-align:top;'"


def show_side_by_size(*dfs):
    """Display several DataFrames next to each other as captioned HTML tables.

    Args:
        *dfs: ``(DataFrame, title)`` pairs. Each frame is rendered with
            two-digit precision, styled inline via ``TABLE_STYLE``, and
            captioned with ``title`` as an ``<h4>`` heading.
    """
    html = ''.join(
        df.style.set_table_attributes(TABLE_STYLE)
        .set_caption('<h4>' + title + '</h4>')
        .format(precision=2)
        # Public API; `_repr_html_` may return None depending on the
        # `styler.render.repr` option.
        .to_html()
        for df, title in dfs
    )
    # The payload is HTML, so publish it as text/html. `display_latex` would
    # tag it as text/latex and the tables would not render.
    disp.display_html(html, raw=True)


# Correctly spelled alias; the original name is kept for existing callers.
show_side_by_side = show_side_by_size
