from __future__ import annotations

from typing import Dict, List, Optional, Tuple

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
from matplotlib.colors import LinearSegmentedColormap


def _edges_from_centers(vals: List[float], log: bool) -> np.ndarray:
    """ edges from centers for the given inputs."""
    v = np.asarray(vals, dtype=float)
    if v.size < 2:
        dv = 0.5 if not log else 2.0
        return np.array([v[0] / dv, v[0] * dv], dtype=float)

    edges = np.empty(v.size + 1, dtype=float)
    if log:
        if np.any(v <= 0):
            raise ValueError("Log-edges require strictly positive centers.")
        edges[1:-1] = np.sqrt(v[:-1] * v[1:])
        edges[0] = v[0] ** 2 / edges[1]
        edges[-1] = v[-1] ** 2 / edges[-2]
    else:
        edges[1:-1] = 0.5 * (v[:-1] + v[1:])
        edges[0] = v[0] - (edges[1] - v[0])
        edges[-1] = v[-1] + (v[-1] - edges[-2])
    return edges


def _shared_heatmap_norm(
    mats: Dict[str, np.ndarray], est_keys: Tuple[str, ...]
) -> mcolors.Normalize:
    """Return one colour normalisation shared by every panel in ``est_keys``.

    Only finite entries of ``mats[k]`` are considered. With no finite values
    a placeholder linear norm over ``[0, 1]`` is returned. If all finite
    values are strictly positive and span at least one decade, a ``LogNorm``
    is used; otherwise a linear ``Normalize`` over ``[min, max]``.
    """
    chunks = []
    for k in est_keys:
        Z = np.asarray(mats[k], dtype=float)
        chunks.append(Z[np.isfinite(Z)])
    values = np.concatenate(chunks) if chunks else np.array([], dtype=float)
    if values.size == 0:
        return mcolors.Normalize(vmin=0.0, vmax=1.0)

    vmin = float(values.min())
    vmax = float(values.max())
    # A log colour scale only pays off when the data cover >= one decade.
    if vmin > 0 and vmax > vmin and (vmax / vmin) >= 10.0:
        return mcolors.LogNorm(vmin=max(vmin, 1e-12), vmax=vmax)
    return mcolors.Normalize(vmin=vmin, vmax=vmax)


def _format_log_tick(val: float, _pos: int) -> str:
    """Format a tick value compactly as ``<coeff>e<exp>`` (just ``<coeff>`` when exp == 0).

    The mantissa is rounded to one decimal. A mantissa that rounds up to 10 is
    carried into the exponent, so 999 renders as ``1e3`` rather than ``10e2``.
    """
    if val == 0:
        return "0"
    exp = int(np.floor(np.log10(abs(val))))
    coeff = round(val / (10**exp), 1)
    if abs(coeff) >= 10:
        coeff /= 10
        exp += 1
    if np.isclose(coeff, round(coeff)):
        coeff = int(round(coeff))
    if exp == 0:
        return f"{coeff:g}"
    return f"{coeff:g}e{exp}"


def _thin_ticks(
    ticks: List[float], labels: Optional[List[str]], max_ticks: int
) -> Tuple[List[float], Optional[List[str]]]:
    """Evenly subsample ``ticks`` down to ``max_ticks`` positions.

    ``labels`` are subsampled with the same indices when they correspond
    one-to-one with the unthinned ticks, so positions and text stay aligned.
    Labels of any other length are returned unchanged. A falsy ``max_ticks``
    disables thinning.
    """
    if not max_ticks or len(ticks) <= max_ticks:
        return ticks, labels
    idx = np.linspace(0, len(ticks) - 1, num=max_ticks, dtype=int)
    thinned = [ticks[i] for i in idx]
    if labels is not None and len(labels) == len(ticks):
        labels = [labels[i] for i in idx]
    return thinned, labels


def plot_two_heatmaps(
    mats: Dict[str, np.ndarray],
    est_keys: Tuple[str, str],
    estimator_labels: Dict[str, str],
    x_vals: List[int],
    y_vals: List[int],
    x_label: str,
    y_label: str,
    log_x: bool = True,
    log_y: bool = True,
    x_ticks: Optional[List[float]] = None,
    y_ticks: Optional[List[float]] = None,
    y_ticklabels: Optional[List[str]] = None,
    x_ticklabels: Optional[List[str]] = None,
    figsize: Tuple[float, float] = (10, 4),
    show_grid: bool = False,
    font_size: int = 12,
    title_size: Optional[int] = None,
    max_xticks: int = 8,
    x_tick_rotation: float = 45.0,
    x_tick_ha: str = "right",
    cbar_label: str = r"MAE ($\|\hat\beta-\beta^*\|_1$)",
) -> plt.Figure:
    """Plot side-by-side heatmaps for two estimators with one shared colorbar.

    Args:
        mats: Maps estimator key -> 2-D matrix. Each matrix is drawn with
            ``pcolormesh`` against cell edges derived from ``y_vals`` (rows)
            and ``x_vals`` (columns), so its expected shape is
            ``(len(y_vals), len(x_vals))``. Non-finite cells are masked and
            drawn light gray.
        est_keys: Keys of ``mats`` to draw, left panel then right panel.
        estimator_labels: Panel titles by key. A missing key falls back to
            the key itself.
        x_vals, y_vals: Cell centers along each axis.
        x_label, y_label: Shared x label (figure level) and left-panel y label.
        log_x, log_y: Use log scale (and log-space cell edges) per axis.
        x_ticks: Explicit x tick positions. Defaults to ``x_vals``. Thinned
            to at most ``max_xticks`` evenly spaced entries.
        y_ticks: Explicit y tick positions, shown in plain notation.
        y_ticklabels: Labels for ``y_ticks``. Ignored when ``y_ticks`` is None.
        x_ticklabels: Labels for the x ticks. When they match the unthinned
            tick count they are thinned together with the ticks.
        figsize: Figure size in inches.
        show_grid: Draw a light major grid.
        font_size: Font size for labels, tick labels and colorbar.
        title_size: Panel title size. Defaults to ``font_size``.
        max_xticks: Upper bound on x tick count. Falsy disables thinning.
        x_tick_rotation, x_tick_ha: Rotation and horizontal alignment of x
            tick labels.
        cbar_label: Colorbar label. The default keeps the historical label.

    Returns:
        The created matplotlib Figure.
    """
    x_edges = _edges_from_centers([float(v) for v in x_vals], log=log_x)
    y_edges = _edges_from_centers([float(v) for v in y_vals], log=log_y)

    # One norm for both panels so equal colours mean equal values.
    norm = _shared_heatmap_norm(mats, est_keys)

    cmap = LinearSegmentedColormap.from_list(
        "okabe_ito_green_yellow_red", ["#009E73", "#F0E442", "#D55E00"], N=256
    ).copy()
    cmap.set_bad(color="lightgray")

    if title_size is None:
        title_size = int(font_size)

    # Tick positions/labels are identical for both panels: compute them once.
    # Labels are thinned with the same indices as the ticks; otherwise a
    # labels/ticks length mismatch makes matplotlib raise.
    xt = [float(t) for t in (x_vals if x_ticks is None else x_ticks)]
    xt, xt_labels = _thin_ticks(
        xt, None if x_ticklabels is None else list(x_ticklabels), max_xticks
    )

    fig, axes = plt.subplots(1, 2, figsize=figsize, sharey=True)

    meshes = []
    for ax, k in zip(axes, est_keys):
        Z = np.ma.masked_invalid(np.asarray(mats[k], dtype=float))
        mesh = ax.pcolormesh(
            x_edges,
            y_edges,
            Z,
            shading="auto",
            norm=norm,
            cmap=cmap,
            edgecolors="none",
            linewidth=0,
            antialiased=False,
        )
        # Rasterize the mesh so vector outputs (PDF/SVG) stay small.
        mesh.set_rasterized(True)
        meshes.append(mesh)

        ax.set_title(estimator_labels.get(k, k), fontsize=title_size)

        if log_x:
            ax.set_xscale("log")
        if log_y:
            ax.set_yscale("log")

        ax.set_xticks(xt)
        ax.xaxis.set_minor_locator(mticker.NullLocator())
        if xt_labels is not None:
            ax.set_xticklabels(xt_labels)
        elif log_x:
            ax.xaxis.set_major_formatter(mticker.FuncFormatter(_format_log_tick))
        else:
            ax.xaxis.set_major_formatter(mticker.ScalarFormatter())
            ax.ticklabel_format(style="plain", axis="x")
        ax.tick_params(axis="both", which="major", labelsize=font_size)
        for lab in ax.get_xticklabels():
            lab.set_rotation(x_tick_rotation)
            lab.set_ha(x_tick_ha)

        ax.yaxis.set_minor_locator(mticker.NullLocator())
        if y_ticks is not None:
            ax.set_yticks([float(t) for t in y_ticks])
            ax.yaxis.set_major_formatter(mticker.ScalarFormatter())
            ax.ticklabel_format(style="plain", axis="y")
            if y_ticklabels is not None:
                ax.set_yticklabels(y_ticklabels)

        ax.grid(False)
        if show_grid:
            ax.grid(True, which="major", axis="both", alpha=0.25)

    axes[0].set_ylabel(y_label, fontsize=font_size)
    fig.supxlabel(x_label, fontsize=font_size, y=0.04)
    # Rotated tick labels need extra room at the bottom.
    bottom = 0.26 if x_tick_rotation else 0.2
    fig.subplots_adjust(bottom=bottom, right=0.88, wspace=0.08)

    # Dedicated colorbar axis just right of the last panel, same height.
    pos = axes[-1].get_position()
    cax_left = min(pos.x1 + 0.02, 0.95)
    cax = fig.add_axes([cax_left, pos.y0, 0.02, pos.height])
    cbar = fig.colorbar(meshes[-1], cax=cax)
    cbar.set_label(cbar_label, fontsize=font_size)
    cbar.ax.tick_params(labelsize=font_size)

    if isinstance(norm, mcolors.LogNorm):
        cbar.locator = mticker.LogLocator(base=10, subs=(1.0, 2.0, 5.0))
        cbar.formatter = mticker.FuncFormatter(lambda v, _pos: f"{v:.3g}")
    else:
        cbar.locator = mticker.MaxNLocator(nbins=5)
    cbar.update_ticks()

    return fig

