import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Callable, Mapping, Literal, Any, LiteralString, Sequence, Iterable
from .figutils import tabulate_results, get_eval_results, Result, Results, PathLike


def key2digits(x: str) -> tuple[int, int]:
    """Parse a result key of the form ``"<a>x<b>"`` into ``(a, b)`` ints.

    Used as the key mapper for ``get_eval_results`` so results are keyed by
    ``(x digit count, y digit count)``.

    Raises:
        ValueError: if ``x`` does not split into exactly two parts on ``'x'``,
            or if either part is not a valid integer.
    """
    parts = x.split('x')
    if len(parts) != 2:
        raise ValueError(
            f"expected a key of the form '<x digits>x<y digits>', got {x!r}"
        )
    return int(parts[0]), int(parts[1])


# Evaluation results keyed by an (int, int) pair. In this module that pair is
# (number of x's digits, number of y's digits), as loaded by ``plot_hotmaps``
# via ``get_eval_results(..., kmap=key2digits)``.
type MultResults = Results[tuple[int, int]]


def _get_overall_result(results: MultResults,
                        key: str = "score",
                        filter: Callable[[int, int], bool] | None = None,
                        percentage: bool = True) -> float:
    """Return the count-weighted mean of ``result[key]`` over ``results``.

    Each entry contributes its value for ``key`` weighted by ``result.count``.
    Entries whose value is missing or ``None``, or whose count is 0, are
    skipped. Values are read with ``.get`` to match how ``_plot_hotmap``
    reads the same records.

    Args:
        results: mapping from ``(dx, dy)`` digit counts to a result record.
        key: name of the metric to average.
        filter: optional predicate ``filter(dx, dy)``. Only entries for which
            it returns true are included. (The name shadows the builtin but
            is kept for backward compatibility with keyword callers.)
        percentage: if true, the mean is multiplied by 100.

    Returns:
        The weighted mean (times 100 when ``percentage``), or ``nan`` when no
        entry qualifies.
    """
    total_count = 0
    weighted_sum = 0.0
    for (dx, dy), r in results.items():
        v = r.get(key)
        if v is None or r.count == 0:
            continue
        if filter is not None and not filter(dx, dy):
            continue
        total_count += r.count
        weighted_sum += v * r.count
    if total_count == 0:
        return np.nan
    mu = weighted_sum / total_count
    return mu * 100 if percentage else mu


def _plot_hotmap(results: MultResults,
                 key: str,
                 title: str | None,
                 colorbar: bool,
                 bound: tuple[int, int] | None,
                 title_fontsize: int,
                 cell_fontsize: int,
                 axis_label_fontsize: int):
    """Draw one heat map of ``results`` on the current matplotlib axes.

    Columns are x digit counts and rows are y digit counts, spanning the
    observed min..max of each. Cells are coloured by the value for ``key`` on
    a fixed ``[0, 1]`` colour scale and labelled with that value times 100,
    rounded. Missing cells stay NaN (blank).

    Args:
        results: non-empty mapping from ``(x_digits, y_digits)`` to a result
            record.
        key: metric to plot; ``"score"`` or ``"refl_freq"``.
        title: optional title format string. When ``bound`` is given it is
            formatted with the averages inside and outside the bound;
            otherwise with the overall average (all as percentages).
        colorbar: whether to add a colour bar.
        bound: optional ``(max_x_digits, max_y_digits)``. A dashed white
            border is drawn around the cells with ``dx <= bound[0]`` and
            ``dy <= bound[1]``.
        title_fontsize: font size of the title.
        cell_fontsize: font size of the per-cell labels.
        axis_label_fontsize: font size of the axis labels.

    Raises:
        ValueError: if ``key`` is not a supported metric.
    """
    # key -> (colormap, cell text colour)
    styles = {
        "score": ('RdYlGn', 'black'),
        "refl_freq": ('hot', 'cyan'),
    }
    try:
        cmap, textcolor = styles[key]
    except KeyError:
        # An explicit error (unlike `assert`) survives `python -O`.
        raise ValueError(
            f"unsupported key {key!r}; expected one of {sorted(styles)}"
        ) from None

    xs = {k[0] for k in results}
    ys = {k[1] for k in results}
    xrange = range(min(xs), max(xs) + 1)
    yrange = range(min(ys), max(ys) + 1)
    # Rows index y digit counts, columns index x digit counts.
    mat = np.full((len(yrange), len(xrange)), np.nan)

    for (xi, yj), result in results.items():
        v = result.get(key)
        if v is None:
            continue
        i = xi - xrange.start
        j = yj - yrange.start
        mat[j, i] = v
        # round() raises on NaN/inf, so only label finite values.
        if np.isfinite(v):
            plt.text(i, j, f'{round(100 * v)}',
                     ha='center', va='center', color=textcolor, fontsize=cell_fontsize)
    plt.imshow(mat, cmap=cmap, interpolation='nearest', vmin=0, vmax=1)
    plt.xlabel("number of x's digits", fontsize=axis_label_fontsize)
    plt.ylabel("number of y's digits", fontsize=axis_label_fontsize)
    plt.xticks(ticks=range(mat.shape[1]), labels=list(map(str, xrange)))
    plt.yticks(ticks=range(mat.shape[0]), labels=list(map(str, yrange)))

    if bound is not None:
        # Cell centres sit on integer coordinates, so borders are at +0.5.
        xborder = bound[0] - xrange.start + 0.5
        yborder = bound[1] - yrange.start + 0.5
        plt.plot([-0.5, xborder], [yborder, yborder], '--', color="white", linewidth=2,)
        plt.plot([xborder, xborder], [-0.5, yborder], '--', color="white", linewidth=2,)
        avr_id = _get_overall_result(results, key,
                                     lambda dx, dy: dx <= bound[0] and dy <= bound[1],
                                     percentage=True)
        avr_ood = _get_overall_result(results, key,
                                      lambda dx, dy: dx > bound[0] or dy > bound[1],
                                      percentage=True)
        if title is not None:
            title = title.format(avr_id, avr_ood)
    else:
        avr = _get_overall_result(results, key)
        if title is not None:
            title = title.format(avr)

    if colorbar:
        plt.colorbar()
    if title is not None:
        plt.title(title, fontsize=title_fontsize)


def plot_hotmaps(
    datas: Mapping[str, PathLike],
    dpi: int = 144,
    max_ncol: int = 3,
    col_width: float = 3.5,
    row_height: float = 3.5,
    title_fontsize: int = 12,
    cell_fontsize: int = 9,
    axis_label_fontsize: int = 11,
    bound: int | tuple[int, int] | None = None,
    root: PathLike | None = None,
    show: bool = False,
    save: str | Path | None = None,
    key: str = "score",
):
    """Plot a grid of heat maps, one subplot per entry of ``datas``.

    Each path is loaded with ``get_eval_results(path, kmap=key2digits)``. A
    missing file (``FileNotFoundError``) or empty result gives a subplot
    titled ``"<title>\\nNO RESULT FOUND"``.

    Args:
        datas: mapping from subplot title (a format string, filled with the
            average(s) computed by ``_plot_hotmap``) to a result path.
        dpi: figure resolution.
        max_ncol: maximum number of subplot columns (must be >= 1).
        col_width: figure width per column (matplotlib ``figsize`` units).
        row_height: figure height per row (matplotlib ``figsize`` units).
        title_fontsize: font size of subplot titles.
        cell_fontsize: font size of per-cell labels.
        axis_label_fontsize: font size of axis labels.
        bound: optional bound region for the heat maps; an int ``n`` is
            treated as ``(n, n)``.
        root: optional directory that each path in ``datas`` is joined onto.
        show: whether to call ``plt.show()`` at the end.
        save: optional output path. Parent directories are created and the
            figure is written with ``plt.savefig``.
        key: metric to plot (passed to ``_plot_hotmap``).

    Raises:
        ValueError: if ``datas`` is empty or ``max_ncol`` is less than 1.
    """
    if not datas:
        raise ValueError("datas must contain at least one entry")
    if max_ncol < 1:
        raise ValueError(f"max_ncol must be >= 1, got {max_ncol}")
    if root is not None:
        root = Path(root)
    if isinstance(bound, int):
        bound = (bound, bound)

    ncol = min(max_ncol, len(datas))
    nrow = -(-len(datas) // ncol)  # ceiling division
    plt.figure(figsize=(ncol * col_width, nrow * row_height), dpi=dpi)
    for i, (title, path) in enumerate(datas.items(), start=1):
        if root is not None:
            path = root / path
        try:
            results = get_eval_results(path, kmap=key2digits)
        except FileNotFoundError:
            results = None
        plt.subplot(nrow, ncol, i)
        if results:
            _plot_hotmap(results, key, title, False, bound,
                         title_fontsize, cell_fontsize, axis_label_fontsize)
        else:
            plt.title(title + "\nNO RESULT FOUND", fontsize=title_fontsize)
    plt.tight_layout()
    if save:
        save = Path(save)
        save.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(save)
    if show:
        plt.show()
