import random
import uuid
from datetime import datetime
from pathlib import Path
from copy import copy
from concurrent.futures import ThreadPoolExecutor
import json, time, re
from typing import Callable, Any, Union, Optional

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.collections import QuadMesh
from matplotlib.colors import Normalize
from IPython.core.display_functions import display, update_display

from experiments.grid_runs_utils import get_axes_from_grid_config
from scripts.scaling_laws_utils import set_quarter_ticks_from_arrays
from theory.theory import IGNORED_VALUE

# --- global paths ---
# Project root: two directory levels above this file (adapt if the layout changes).
project_dir = Path(__file__).resolve().parent.parent
# Grid runs are looked up as sub-directories of this folder.
results_dir = project_dir.joinpath("results")


import matplotlib

# --- Global, paper-friendly matplotlib defaults (tweak as you like) ---

DPI = 100  # use 300 for print-quality output

_PAPER_RC_PARAMS = {
    # resolution, both on screen and when saving
    "figure.dpi": DPI,
    "savefig.dpi": DPI,
    # font sizes
    "font.size": 14,           # base size
    "axes.labelsize": 16,      # x/y axis labels
    "axes.titlesize": 18,      # subplot titles
    "xtick.labelsize": 14,
    "ytick.labelsize": 14,
    "legend.fontsize": 14,
}
matplotlib.rcParams.update(_PAPER_RC_PARAMS)



def get_last_run_dirs(n: int = 1, *, root: Optional[Path] = None) -> list[str]:
    """Return the names of the ``n`` most recent run directories, newest first.

    Run directories are the sub-directories of ``root`` (default: the module's
    ``results_dir``). Recency is parsed from the directory name, which is
    expected to end with ``__DD_MM_YYYY__HH_MM_SS``. Names that do not parse
    sort as the oldest.

    Args:
        n: number of names to return; ``n <= 0`` returns ``[]`` without
            touching the filesystem.
        root: directory to scan; defaults to ``results_dir``.

    Returns:
        Up to ``n`` directory names (not paths), newest first.
    """
    if n <= 0:
        return []

    def parse_dt(name: str) -> datetime:
        # the last two '__'-separated chunks hold the date and the time
        try:
            date_, time_ = name.rsplit("__", 2)[-2:]
            return datetime.strptime(f"{date_} {time_}", "%d_%m_%Y %H_%M_%S")
        except ValueError:
            # too few chunks (unpacking) or a non-matching date/time (strptime)
            return datetime.min

    base = results_dir if root is None else Path(root)
    names = [p.name for p in base.iterdir() if p.is_dir()]
    return sorted(names, key=parse_dt, reverse=True)[:n]


def dot_get(d, path):
    """Look up a dotted key path (e.g. ``"eval.accuracy"``) in nested dicts.

    Returns the value found at the end of the path, or ``None`` as soon as a
    step hits a non-dict or a missing key.
    """
    node = d
    for key in path.split("."):
        if not isinstance(node, dict) or key not in node:
            return None
        node = node[key]
    return node


def _load_one(log_file: Path, stats: Union[list[str], tuple[str, ...]] = ("accuracy",)):
    """Parse one per-run JSON log file.

    The grid coordinates and the seed come from the file name. The name must
    contain ``<LETTERS>_<int>_<LETTERS>_<int>_seed_<int>`` followed by ``_``,
    ``.`` or the end of the name, matched case-insensitively. Example:
    ``V_0327_L_0320_seed_0_cuda_1_0xbcc0.json``.

    Args:
        log_file: path of the JSON log.
        stats: dotted paths (see ``dot_get``) of the values to extract.

    Returns:
        ``((x, y, seed), vals)``. ``vals`` maps every stat path present in the
        log to its value converted with ``float``; missing stats are omitted.

    Raises:
        ValueError: if the file name does not match the expected pattern, or a
            stat value cannot be converted with ``float``.
        json.JSONDecodeError: if the file is not valid JSON (e.g. still being
            written).
    """
    name = log_file.name
    m = re.search(r"(?i)([A-Z]+)_(\d+)_([A-Z]+)_(\d+)_seed_(\d+)(?=(?:_|\.|$))", name)
    if not m:
        raise ValueError(f"Filename doesn't match expected pattern: {name}")

    # groups 1 and 3 are the axis labels; only the numbers are needed here
    _x_name, x_str, _y_name, y_str, seed_str = m.groups()
    x, y, seed = int(x_str), int(y_str), int(seed_str)

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding
    with log_file.open(encoding="utf-8") as fh:
        logs = json.load(fh)

    vals = {}
    for stat in stats:
        val = dot_get(logs, stat)
        if val is not None:
            vals[stat] = float(val)

    return (x, y, seed), vals



def _read_run_logs(logs_dir: Path, stats, n_threads=1, use_tqdm=False):
    """Read every ``*.json`` log in ``logs_dir`` and group the results by grid point.

    Returns ``{(x, y): {seed: {stat: value}}}``. Logs sharing the same
    (x, y, seed) are merged; a file processed later (files are processed in
    reverse name order) overrides values from earlier ones.
    """
    files = sorted(logs_dir.glob("*.json"), reverse=True)
    # 'accuracy' drives best-seed selection, so it is always loaded alongside the requested stats
    stats_to_load = list(dict.fromkeys(["accuracy", *stats]))

    def load(path):
        return _load_one(path, stats_to_load)

    runs_by_xy = {}

    def collect(results):
        if use_tqdm:
            results = tqdm(results, total=len(files))
        for (x, y, seed), vals in results:
            runs_by_xy.setdefault((x, y), {}).setdefault(seed, {}).update(vals)

    if n_threads > 1 and len(files) > 1:
        # consume inside the with-block so a progress bar tracks the actual work
        with ThreadPoolExecutor(max_workers=min(n_threads, len(files))) as ex:
            collect(ex.map(load, files))
    else:
        collect(map(load, files))
    return runs_by_xy


def _load_data_from_results_json_files(
        grid_run_dir: Path,
        logs_subdir_name: str,
        stats: list[str],
        n_threads=1,
        use_tqdm=False,
):
    """Load a grid run's per-run JSON logs and lay them out on the grid.

    For each grid point (x, y), the seed with the highest ``accuracy`` is
    selected ("best of n seeds"). Ties go to the lowest seed, and every seed
    found in the logs is considered. That seed's requested stats are written
    into per-stat arrays that follow the grid axes.

    Args:
        grid_run_dir: grid-run directory; must contain ``run_config.json``
            with a ``"grid"`` entry.
        logs_subdir_name: sub-directory of ``grid_run_dir`` holding the logs.
        stats: dotted stat paths to extract (see ``_load_one``).
        n_threads: threads used to read the log files (1 = serial).
        use_tqdm: show a progress bar while reading.

    Returns:
        ``(grid_axes, stat_arrays, metadata, run_config, max_seed_per_grid_point)``.
        ``stat_arrays`` maps each stat to a float ``(I, J)`` array
        (I = len(x axis), J = len(y axis)), filled with ``IGNORED_VALUE`` where
        no value exists. ``max_seed_per_grid_point`` is an ``(I, J)`` float
        array holding the largest seed seen at each point (0 where none).
        When there are no logs, the arrays are simply left unfilled.
    """
    logs_dir = grid_run_dir / logs_subdir_name

    # unified run config
    run_config = json.loads((grid_run_dir / "run_config.json").read_text())
    grid_config = run_config["grid"]

    # metadata = the grid's constant (non-swept) parameters
    grid_axes, grid_constants = get_axes_from_grid_config(grid_config)
    metadata = grid_constants.as_dict()
    if 'N' in metadata.keys() and grid_config.get('scale_N_with_D', False):
        metadata['N'] = f"D/{grid_config['D_to_N_ratio']}"

    x_axis = grid_axes.x.axis
    y_axis = grid_axes.y.axis
    I, J = len(x_axis), len(y_axis)

    runs_by_xy = _read_run_logs(logs_dir, stats, n_threads=n_threads, use_tqdm=use_tqdm)

    # dtype=float so stat values are never truncated if IGNORED_VALUE is an int
    stat_arrays = {s: np.full((I, J), fill_value=IGNORED_VALUE, dtype=float) for s in stats}
    max_seed_per_grid_point = np.zeros((I, J))

    # log file names carry integer coordinates, so match against the int-cast axes
    x_keys = np.asarray(x_axis).astype(int)
    y_keys = np.asarray(y_axis).astype(int)

    for i, x in enumerate(x_keys):
        for j, y in enumerate(y_keys):
            runs = runs_by_xy.get((int(x), int(y)))
            if not runs:
                continue
            max_seed_per_grid_point[i, j] = max(runs)

            # best of n seeds: highest accuracy; ascending order makes ties go to the lowest seed
            scored_seeds = [seed for seed in sorted(runs) if 'accuracy' in runs[seed]]
            if not scored_seeds:
                continue
            best_vals = runs[max(scored_seeds, key=lambda seed: runs[seed]['accuracy'])]

            for s in stats:
                v = best_vals.get(s)
                if v is not None:
                    stat_arrays[s][i, j] = v

    return grid_axes, stat_arrays, metadata, run_config, max_seed_per_grid_point


def _draw_numbers_over_array_plot(
    ax,
    x_axis, y_axis,
    Z, numbers,
    *,
    cmap=None, norm=None,
    skip_zeros=True,
    fontsize=8,
    fmt: Union[str, Callable[..., str]] = None,
    text_kw=None,
    black_white_threshold=0.5,
    cell_value_to_ignore=None,
):
    """Overlay numbers on a pcolormesh drawn with 1-x x/y axes.

    x_axis, y_axis: 1-x arrays of edges (len = n+1) or centers (len = n).
    Z, numbers: shape (ny, nx).  numbers are displayed in each cell.
    """

    if fmt is None:
        fmt = "{:.2f}"

    Z = np.asarray(Z)
    numbers = np.asarray(numbers)
    if Z.shape != numbers.shape:
        raise ValueError(f"Z.shape {Z.shape} must match numbers.shape {numbers.shape}")

    # inherit cmap/norm from last QuadMesh if not provided
    if cmap is None or norm is None:
        for coll in reversed(ax.collections):
            if isinstance(coll, QuadMesh):
                cmap = cmap or coll.get_cmap()
                norm = norm or coll.norm
                break
    cmap = cmap or plt.get_cmap("viridis")
    if norm is None:
        norm = plt.Normalize(np.nanmin(Z.astype(float)), np.nanmax(Z.astype(float)))

    # mask invalid cells
    mask = np.zeros_like(Z, dtype=bool)
    if np.ma.isMaskedArray(Z):
        mask |= np.ma.getmaskarray(Z)
    mask |= ~np.isfinite(Z.astype(float))

    top_z = (max([c.get_zorder() for c in ax.collections], default=0) + 1)
    base_kw = dict(ha="center", va="center", fontsize=fontsize, clip_on=True, zorder=top_z)

    if text_kw:
        base_kw.update(text_kw)

    for j, y in enumerate(y_axis):
        # y = yc[j]
        for i, x in enumerate(x_axis):
            if mask[j, i]:
                continue
            n = numbers[j, i]
            if (cell_value_to_ignore is not None) and (Z[j, i] == cell_value_to_ignore):
                continue
            # choose contrasting color from background
            # r, g, b, _ = cmap(norm(float(Z[j, i])))
            # lum = 0.2126*r + 0.7152*g + 0.0722*b
            kw = dict(base_kw)
            kw["color"] = "black" if float(Z[j, i]) > black_white_threshold else "white"
            # kw["color"] = "black"
            # kw.setdefault("path_effects", [pe.withStroke(linewidth=1, foreground="k", alpha=0.5)])
            if isinstance(fmt, str):
                number_str = fmt.format(int(n))
            elif isinstance(fmt, Callable):
                number_str = fmt(n)
            else:
                raise ValueError("unknown format type")
            ax.text(x, y, number_str, **kw)


def plot_grid_run_stats(
        run_name: str,
        stat_list: Union[list[str], tuple[str, ...]] = ('accuracy', 'error_rate', 'loss'),
        logs_dir_name: str = 'train_logs',
        colorbar=True,
        cmap='inferno',
        figsize=(12, 5),
        binary_accuracy=False, threshold=1,
        invert=False,
        theory_lines_plot_function: Optional[Callable] = None,
        aspect='auto',
        save_figure=False,
        save_dir=None,
        save_name=None,
        metric_to_overlay=None,
        overlay_only_diff=False,
        prev_stat_arrays_dict=None,
        upper_title=None,
        no_title=False,
        show=True,
):
    """Plot one heat-map subplot per stat of the grid run ``results_dir / run_name``.

    Args:
        run_name: name of the grid-run directory inside ``results_dir``.
        stat_list: stats to load and plot, one subplot each.
        logs_dir_name: sub-directory containing the per-run JSON logs.
        colorbar: currently not used (a narrow gridspec column is still reserved).
        cmap: colormap passed to ``pcolormesh``.
        figsize: figure size.
        binary_accuracy: replace the accuracy map with ``accuracy >= threshold``.
        threshold: threshold for ``binary_accuracy``; must be in [0, 1].
        invert: show accuracy as a binned error rate. The bin value counts how
            many of the thresholds 10^-0.5, 10^-1, 10^-1.5, 10^-2, 10^-2.5 the
            error rate is <=.
        theory_lines_plot_function: called per subplot (only with
            ``binary_accuracy``) as ``f(x_axis, y_axis, y_label, run_title=..., p=threshold)``.
        aspect: passed to ``ax.set_aspect``.
        save_figure, save_dir, save_name: optionally save via ``save_figure_to_png``.
        metric_to_overlay: ``None``, ``'stat'`` (print the cell value) or
            ``'seed'`` (print the largest seed seen in the cell).
        overlay_only_diff: with ``prev_stat_arrays_dict``, only annotate cells
            whose value changed since that previous result.
        prev_stat_arrays_dict: raw stat arrays from a previous call (second
            element of this function's return value).
        upper_title: replaces the model-variant line of the subplot titles.
        no_title: suppress subplot titles.
        show: call ``plt.show()`` at the end.

    Returns:
        ``(grid_axes, stat_arrays_dict, metadata_title)``. ``stat_arrays_dict``
        holds the raw (unmasked) arrays as loaded.

    Raises:
        NotImplementedError: for an unknown ``metric_to_overlay``.
    """
    stat_list = list(stat_list)

    # --- figure: one subplot per stat, plus a narrow column reserved for a colorbar ---
    n = len(stat_list)
    fig = plt.figure(figsize=figsize, facecolor='white')
    gs = fig.add_gridspec(1, n + 1, width_ratios=[1] * n + [0.05], wspace=0.25)
    fig_axes = [fig.add_subplot(gs[0, i]) for i in range(n)]
    for ax in fig_axes:
        ax.set_facecolor('white')
        ax.patch.set_alpha(1)
    fig.patch.set_alpha(1)

    # --- load ---
    grid_run_dir = results_dir / run_name
    _, stat_arrays_dict, metadata, run_config, max_seed_per_grid_point = (
        _load_data_from_results_json_files(grid_run_dir, logs_dir_name, stat_list))

    model_variant = run_config['model']['variant']
    grid_axes, _ = get_axes_from_grid_config(run_config['grid'])
    x_axis = grid_axes.x.axis
    y_axis = grid_axes.y.axis
    x_label = grid_axes.x.name
    y_label = grid_axes.y.name

    # --- title parts shared by every subplot (computed before the loop so they exist even if n == 0) ---
    trained = "_trained" if run_config['runtime']['should_train'] else ""
    metadata_title = ", ".join([f"{k}={v}" for k, v in metadata.items()])
    if upper_title is None:
        title_str = (f"{model_variant}{trained}".replace('_', ' ').title()
                     .replace('Ssm', 'SSM')
                     .replace('Simplified', 'Linear')
                     .replace('Ideal', 'Optimal'))
    else:
        title_str = upper_title.replace('_', ' ').title()

    masked_stat_arrays_dict = {k: np.ma.masked_equal(v, IGNORED_VALUE) for k, v in stat_arrays_dict.items()}

    if binary_accuracy:
        assert (threshold >= 0) and (threshold <= 1)
        # same key, so the accuracy subplot shows the thresholded map
        masked_stat_arrays_dict['accuracy'] = masked_stat_arrays_dict['accuracy'] >= threshold

    for ax, stat in zip(fig_axes, masked_stat_arrays_dict):

        norm = Normalize(0, 1) if stat in ('accuracy', 'error_rate') else None
        stat_array = masked_stat_arrays_dict[stat]

        if invert and stat == 'accuracy':
            error_rate = 1 - stat_array
            # half-decade bins: 0 means error > ~0.316, 5 means error <= ~0.00316
            stat_array = (
                    (error_rate <= 3.162e-3).astype(int) +
                    (error_rate <= 1.000e-2).astype(int) +
                    (error_rate <= 3.162e-2).astype(int) +
                    (error_rate <= 1.000e-1).astype(int) +
                    (error_rate <= 3.162e-1).astype(int)
            )
            norm = None

        ax.pcolormesh(x_axis, y_axis, stat_array.T, cmap=cmap, norm=norm)

        if binary_accuracy and (theory_lines_plot_function is not None):
            # NOTE(review): the callback receives no Axes; presumably it draws on the
            # current axes, so make this subplot current first.
            plt.sca(ax)
            theory_lines_plot_function(x_axis, y_axis, y_label, run_title=upper_title, p=threshold)

        if metric_to_overlay is not None:
            if metric_to_overlay == 'stat':
                numbers_array_to_draw = stat_array * 1
                fmt = lambda v: f"{v:.2f}".replace(".00", "").replace("0.", ".")
            elif metric_to_overlay == 'seed':
                numbers_array_to_draw = max_seed_per_grid_point
                fmt = "{:d}"
            else:
                raise NotImplementedError

            stat_array_to_plot = stat_array * 1
            if (overlay_only_diff and (prev_stat_arrays_dict is not None)
                    and stat in prev_stat_arrays_dict):
                prev_array = np.ma.masked_equal(prev_stat_arrays_dict[stat], IGNORED_VALUE)
                diff_array = masked_stat_arrays_dict[stat] - prev_array
                # unchanged cells are marked to be skipped by the overlay
                stat_array_to_plot[diff_array == 0] = IGNORED_VALUE

            _draw_numbers_over_array_plot(
                ax, x_axis, y_axis,
                Z=stat_array_to_plot.T,  # drives text contrast and which cells are skipped
                numbers=numbers_array_to_draw.T,
                cmap=cmap, norm=norm,
                fontsize=7,
                fmt=fmt,
                skip_zeros=False,
                cell_value_to_ignore=IGNORED_VALUE,
            )

        # axes cosmetics
        ax.set_aspect(aspect, adjustable='box')
        shown_y_label = y_label
        if y_label == "V":
            ax.set_yscale('log')
            shown_y_label = y_label + ' (Log)'
        ax.set_xlabel(x_label)
        ax.set_ylabel(shown_y_label)

        set_quarter_ticks_from_arrays(ax, x_axis, y_axis)  # temp

        # subplot title: stat / model / constants
        if (stat == 'accuracy') and binary_accuracy:
            stat_name = f'accuracy >= {threshold}' if threshold < 1 else f'accuracy == {threshold}'
        else:
            stat_name = stat

        if not no_title:
            ax.set_title("\n".join([stat_name.replace('_', ' ').title(), title_str, metadata_title]))

    if save_figure:
        save_figure_to_png(fig, save_dir=save_dir, save_name=save_name, dpi=300)

    if show:
        plt.show()

    return grid_axes, stat_arrays_dict, metadata_title


def save_figure_to_png(fig, save_dir=None, save_name=None, dpi=300):
    """Save ``fig`` as ``<save_dir>/<save_name>.png`` on a white, tightly-cropped background.

    Args:
        fig: matplotlib Figure to save.
        save_dir: output directory (``str`` or ``Path``); defaults to
            ``project_dir / "figures"``. It is created if missing.
        save_name: file name without extension; defaults to ``'temp_out_name'``.
        dpi: output resolution.

    Returns:
        Path of the written PNG file.
    """
    # Path() so a plain string directory works as well
    out_dir = Path(project_dir) / "figures" if save_dir is None else Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    file_name = 'temp_out_name' if save_name is None else save_name
    out_png = out_dir / f"{file_name}.png"

    fig.savefig(
        out_png,
        dpi=dpi, facecolor='white', bbox_inches='tight',
        pad_inches=0.05, transparent=False
    )
    return out_png


def live_plot_grid_run_stats(
        run_names: list[str],
        stat_list: Union[list[str], tuple[str, ...]] = ('accuracy', 'loss', 'grad_norm'),
        logs_dir_name: str = 'train_logs',
        figsize=(12, 3),
        sleep_time_sec=10,
        add_binary_accuracy=False, threshold=1,
        metric_to_overlay=None,
        overlay_only_diff=True,
):
    """Periodically re-plot grid runs in place in an IPython/Jupyter output area.

    Every ``sleep_time_sec`` seconds, each run in ``run_names`` is reloaded and
    re-plotted with ``plot_grid_run_stats``. Each run gets its own display slot,
    which is updated in place. With ``metric_to_overlay`` set and
    ``overlay_only_diff`` True, only cells that changed since that run's
    previous refresh are annotated. The loop runs until interrupted
    (KeyboardInterrupt).
    """
    base_display_id = "live_grid_run_plot_minimal"
    # previous refresh's raw stat arrays, kept per run so diffs compare a run with itself
    prev_stats_by_run = {}
    shown_display_ids = set()
    open_figures = []

    try:
        while True:
            for k, run_name in enumerate(run_names):
                _, stat_arrays_dict, _ = plot_grid_run_stats(
                    run_name=run_name,
                    stat_list=stat_list,
                    logs_dir_name=logs_dir_name,
                    figsize=figsize,
                    binary_accuracy=add_binary_accuracy,
                    threshold=threshold,
                    metric_to_overlay=metric_to_overlay,
                    overlay_only_diff=overlay_only_diff,
                    prev_stat_arrays_dict=prev_stats_by_run.get(run_name),
                    show=False,  # displayed below via IPython instead of plt.show()
                )
                prev_stats_by_run[run_name] = stat_arrays_dict

                fig = plt.gcf()  # the figure plot_grid_run_stats just created
                open_figures.append(fig)

                display_id = base_display_id if k == 0 else f"{base_display_id}_{k}"
                if display_id in shown_display_ids:
                    update_display(fig, display_id=display_id)
                else:
                    display(fig, display_id=display_id)
                    shown_display_ids.add(display_id)

            # let the GUI update smoothly
            plt.pause(0.1)
            time.sleep(sleep_time_sec)

            # close every figure of this refresh, not only the current one
            for fig in open_figures:
                plt.close(fig)
            open_figures.clear()

    except KeyboardInterrupt:
        print("\nlive plot terminated by user")
    finally:
        for fig in open_figures:
            plt.close(fig)
