from __future__ import annotations

from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import numpy as np
from matplotlib.lines import Line2D

from .style import default_method_styles

# Default (width, height) handed to matplotlib's ``figsize`` (i.e. inches) by the
# plotting functions in this module: a 3.4 x 2.2 base enlarged by a factor of 1.5.
DEFAULT_FIGSIZE = (1.5 * 3.4, 1.5 * 2.2)


def _plain_log_ticks(ax: plt.Axes, ticks: List[float]) -> None:
    """Place major x-ticks at ``ticks`` and label them as plain numbers.

    The callers in this module invoke this right after ``set_xscale("log")``.

    Args:
        ax: Axes whose x-axis is reconfigured in place.
        ticks: Tick positions; each entry is converted with ``float``.
    """
    tick_positions = [float(t) for t in ticks]
    x_axis = ax.xaxis

    ax.set_xticks(tick_positions)
    # No minor ticks: only the explicitly requested positions are marked.
    x_axis.set_minor_locator(mticker.NullLocator())
    # ``ticklabel_format`` needs a ScalarFormatter on the axis, so install it
    # first and then switch it to plain (non-scientific) notation.
    x_axis.set_major_formatter(mticker.ScalarFormatter())
    ax.ticklabel_format(style="plain", axis="x")


def _downsample_ticks(ticks: List[float], max_ticks: int) -> List[float]:
    """ downsample ticks for the given inputs."""
    ticks = list(map(float, ticks))
    if max_ticks is None or max_ticks <= 0 or len(ticks) <= max_ticks:
        return ticks
    idx = np.linspace(0, len(ticks) - 1, num=max_ticks)
    idx = np.unique(np.round(idx).astype(int))
    return [ticks[i] for i in idx]


def _ordered_estimator_keys(
    available_keys: List[str],
    estimator_labels: Dict[str, str],
) -> List[str]:
    """ ordered estimator keys for the given inputs."""
    preferred = [k for k in estimator_labels.keys() if k in available_keys]
    extras = [k for k in available_keys if k not in estimator_labels]
    return preferred + extras


def plot_experiment_vs_N(
    results: Dict[str, List[Dict[str, float]]],
    estimator_labels: Dict[str, str],
    N_list: Optional[List[int]] = None,
    y_lim: Optional[Tuple[float, float]] = None,
    figsize: Tuple[float, float] = DEFAULT_FIGSIZE,
    legend_above: bool = False,
) -> plt.Figure:
    """Plot one error curve per estimator against total sample size.

    Each estimator is drawn as a line through its ``"mean"`` values with a
    shaded band between ``"lower"`` and ``"upper"``, on a log-scaled x-axis.

    Args:
        results: Maps estimator key to a list of records. Every record is read
            for the keys ``"N"``, ``"mean"``, ``"lower"`` and ``"upper"``.
            Estimators with no records are skipped.
        estimator_labels: Maps estimator key to legend label. Its key order
            also fixes the plotting order; keys without a label are plotted
            last and labelled with the raw key.
        N_list: When given and non-empty, the x-limits are set to its min/max
            and plain-number ticks are placed at exactly these values.
        y_lim: Optional ``(low, high)`` limits for the y-axis.
        figsize: Figure size passed to ``plt.subplots``.
        legend_above: If true, draw a figure-level legend above the axes;
            otherwise draw the legend inside the axes (upper right).

    Returns:
        The newly created figure.
    """
    # Use the same layout engine as the grid plot so that saved figures (often with
    # bbox_inches="tight") have consistent overall sizing/margins.
    fig, ax = plt.subplots(1, 1, figsize=figsize, constrained_layout=False)

    est_keys = _ordered_estimator_keys(list(results.keys()), estimator_labels)
    styles = default_method_styles(est_keys)

    for k in est_keys:
        recs = results.get(k)
        if not recs:
            continue
        Ns = [rec["N"] for rec in recs]
        means = [rec["mean"] for rec in recs]
        lows = [rec["lower"] for rec in recs]
        ups = [rec["upper"] for rec in recs]

        st = styles[k]
        ax.plot(
            Ns,
            means,
            color=st.color,
            linestyle=st.linestyle,
            marker=st.marker,
            label=estimator_labels.get(k, k),
        )
        ax.fill_between(Ns, lows, ups, color=st.color, alpha=0.12)

    ax.set_xscale("log")
    if N_list:
        ax.set_xlim(min(N_list), max(N_list))
        _plain_log_ticks(ax, N_list)

    if y_lim is not None:
        ax.set_ylim(*y_lim)

    # Zero-error reference line (unlabelled, so it stays out of the legend).
    ax.axhline(0.0, color="k", linewidth=1.0, linestyle="--", alpha=0.45)

    ax.set_xlabel("total sample size ($N$)")
    ax.set_ylabel(r"MAE ($\|\hat\beta-\beta\|_1$)")

    # Only build a legend when at least one labelled curve was drawn; an
    # unconditional ``ax.legend`` would emit a warning and draw an empty box.
    handles, labels = ax.get_legend_handles_labels()
    if legend_above:
        if handles:
            ncol = min(4, len(handles))
            fig.legend(
                handles=handles,
                labels=labels,
                loc="upper center",
                ncol=ncol,
                bbox_to_anchor=(0.5, 0.995),
                frameon=True,
                borderaxespad=0.0,
            )
        # Leave the top strip of the figure free for the figure-level legend.
        fig.tight_layout(rect=(0.06, 0.08, 1.0, 0.94))
    else:
        if handles:
            ax.legend(loc="upper right", ncol=2)
        fig.tight_layout()
    return fig


def plot_grid_experiment(
    results: Dict[int, Dict[str, List[Dict[str, float]]]],
    r_list: List[int],
    estimator_labels: Dict[str, str],
    N_list: Optional[List[int]] = None,
    y_lim: Optional[Tuple[float, float]] = None,
    figsize: Tuple[float, float] = DEFAULT_FIGSIZE,
) -> plt.Figure:
    """Plot a two-column grid of error curves, one panel per entry of ``r_list``.

    Every panel shows, for each estimator, a line through the ``"mean"``
    values with a shaded band between ``"lower"`` and ``"upper"`` on a
    log-scaled x-axis. Panels share the y-axis and a single figure-level
    legend.

    Args:
        results: Maps ``int(r)`` to a mapping from estimator key to a list of
            records. Every record is read for the keys ``"N"``, ``"mean"``,
            ``"lower"`` and ``"upper"``. Panels whose ``r`` is missing from
            ``results`` are blanked.
        r_list: The ``r`` values to show, in panel order (row-major).
        estimator_labels: Maps estimator key to legend label; its key order
            also fixes the plotting order.
        N_list: When given and non-empty, the x-limits are set to its min/max,
            at most four of its values are used as ticks, and the tick labels
            show the values divided by 100.
        y_lim: Optional ``(low, high)`` limits applied to every panel.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The newly created figure. An empty ``r_list`` yields a figure with
        blank panels rather than an error.
    """
    n = len(r_list)
    ncols = 2
    # ``plt.subplots`` rejects a zero-row grid, so keep at least one row even
    # when there is nothing to draw.
    nrows = max(1, int(np.ceil(n / ncols)))
    fig, axes = plt.subplots(
        nrows, ncols, figsize=figsize, sharey=True, constrained_layout=False
    )
    axes = np.array(axes).ravel()

    # Restrict to estimators present for every available r so that all panels
    # show the same curves and one shared legend describes them all.
    per_r_keys = []
    for r in r_list:
        rec = results.get(int(r))
        if rec is not None:
            per_r_keys.append(set(rec.keys()))

    if per_r_keys:
        common_keys = sorted(set.intersection(*per_r_keys))
    else:
        common_keys = []

    est_keys = _ordered_estimator_keys(common_keys, estimator_labels)
    styles = default_method_styles(est_keys)

    # Panels are small, so show only a handful of the requested tick values.
    grid_ticks = _downsample_ticks(list(N_list), max_ticks=4) if N_list else None

    for i, r in enumerate(r_list):
        ax = axes[i]
        if int(r) not in results:
            ax.axis("off")
            continue
        for k in est_keys:
            recs = results[int(r)].get(k)
            if not recs:
                continue
            Ns = [rec["N"] for rec in recs]
            means = [rec["mean"] for rec in recs]
            lows = [rec["lower"] for rec in recs]
            ups = [rec["upper"] for rec in recs]

            st = styles[k]
            ax.plot(Ns, means, color=st.color, linestyle=st.linestyle, marker=st.marker)
            ax.fill_between(Ns, lows, ups, color=st.color, alpha=0.10)

        ax.set_xscale("log")
        if N_list:
            ax.set_xlim(min(N_list), max(N_list))
            _plain_log_ticks(ax, grid_ticks)
            # Show N in units of 100; the shared x-label states the factor.
            ax.set_xticklabels([f"{val / 100:g}" for val in grid_ticks])

        # Zero-error reference line.
        ax.axhline(0.0, color="k", linewidth=1.0, linestyle="--", alpha=0.45)

        if y_lim is not None:
            ax.set_ylim(*y_lim)

        # Panel tag in the upper-left corner identifying the r value.
        ax.text(
            0.05,
            0.93,
            f"$n/m={int(r)}$",
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontweight="bold",
            bbox=dict(
                boxstyle="round,pad=0.25",
                facecolor="white",
                alpha=0.85,
                edgecolor="none",
            ),
        )

        # Hide x tick labels only when a drawn panel sits directly below.
        # Testing for "last row" alone would strip the labels from a panel
        # that sits above an unused or blanked grid slot.
        below = i + ncols
        has_panel_below = below < n and int(r_list[below]) in results
        if has_panel_below:
            ax.tick_params(axis="x", labelbottom=False)
        else:
            for lab in ax.get_xticklabels():
                lab.set_rotation(0)
                lab.set_ha("center")

    # Blank the grid slots that have no corresponding r value.
    for j in range(len(r_list), len(axes)):
        axes[j].axis("off")

    # Panels carry no labels of their own; build proxy handles for one
    # figure-level legend instead.
    handles = []
    for k in est_keys:
        st = styles[k]
        handles.append(
            Line2D(
                [0],
                [0],
                color=st.color,
                linestyle=st.linestyle,
                marker=st.marker,
                label=estimator_labels.get(k, k),
            )
        )
    if handles:
        ncol = min(4, len(handles))
        fig.legend(
            handles=handles,
            loc="upper center",
            ncol=ncol,
            bbox_to_anchor=(0.5, 0.995),
            frameon=True,
            borderaxespad=0.0,
        )

        # Reserve the top strip for the legend and the left/bottom margins
        # for the shared axis labels added below.
        fig.tight_layout(rect=(0.06, 0.08, 1.0, 0.94))
    else:
        fig.tight_layout(rect=(0.06, 0.08, 1.0, 1.0))

    # Single centered axis labels for cleaner appearance. The scale factor is
    # mentioned only when the tick labels were actually divided by 100 above.
    x_label = "total sample size ($N$)"
    if N_list:
        x_label += " ($\\times 10^2$)"
    fig.supxlabel(x_label)
    fig.supylabel(r"MAE ($\|\hat\beta-\beta\|_1$)")

    return fig


def _agreement_panel_values(
    block: Dict[str, List[Dict[str, Any]]],
    key_base: str,
    key_ana: str,
) -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Return flattened, element-aligned ``rep_bhat`` values for both estimators.

    Records of the two estimators are matched on ``float(rec["N"])``. For each
    matched pair the two 2-D ``"rep_bhat"`` arrays are cropped to their common
    shape and flattened. Returns ``None`` when nothing usable is found.
    """
    recs_b = block.get(key_base, [])
    recs_a = block.get(key_ana, [])
    if not recs_b or not recs_a:
        return None

    by_n_base = {float(rec["N"]): rec for rec in recs_b}
    by_n_ana = {float(rec["N"]): rec for rec in recs_a}

    xs, ys = [], []
    for n_val in sorted(set(by_n_base) & set(by_n_ana)):
        bb = by_n_base[n_val].get("rep_bhat", None)
        ba = by_n_ana[n_val].get("rep_bhat", None)
        if bb is None or ba is None:
            continue

        base = np.asarray(bb, dtype=float)
        ana = np.asarray(ba, dtype=float)
        if base.ndim != 2 or ana.ndim != 2:
            continue

        n_rows = min(base.shape[0], ana.shape[0])
        if n_rows == 0:
            continue
        n_cols = min(base.shape[1], ana.shape[1])

        # Crop both arrays to the shared shape so entries pair up one-to-one.
        xs.append(base[:n_rows, :n_cols].ravel())
        ys.append(ana[:n_rows, :n_cols].ravel())

    if not xs:
        return None
    return np.concatenate(xs), np.concatenate(ys)


def _symmetric_quantile_limits(
    chunks: List[np.ndarray],
    quantiles: Tuple[float, float],
) -> Tuple[float, float]:
    """Return padded axis limits symmetric about zero from pooled quantiles.

    The half-width is the larger absolute value of the two requested
    quantiles of all finite values in ``chunks`` (falling back to the largest
    absolute value when both quantiles are zero), plus 2% padding. Returns
    ``(-1.0, 1.0)`` when there are no finite values.
    """
    if not chunks:
        return -1.0, 1.0
    pooled = np.concatenate(chunks)
    pooled = pooled[np.isfinite(pooled)]
    if not pooled.size:
        return -1.0, 1.0

    q_lo, q_hi = quantiles
    lo_q, hi_q = np.quantile(pooled, [q_lo, q_hi])
    half = float(max(abs(lo_q), abs(hi_q)))
    if half <= 0:
        half = float(np.max(np.abs(pooled)))
    pad = 0.02 * (2 * half + 1e-12)
    return -half - pad, half + pad


def plot_agreement_grid_scatter(
    results: Dict[int, Dict[str, List[Dict[str, Any]]]],
    r_list: List[int],
    key_base: str = "up_gmm_hd",
    key_ana: str = "up_gmm_hd_analytic",
    n_panels: int = 4,
    ncols: int = 2,
    figsize: Tuple[float, float] = DEFAULT_FIGSIZE,
    show_rel_text: bool = False,
    *,
    use_log_axes: bool = False,
    robust_quantiles: Tuple[float, float] = (0.005, 0.995),
    include_residual: bool = True,
) -> plt.Figure:
    """Plot a grid of residual scatters comparing two estimators.

    Each panel (one per ``r``) plots the difference ``key_ana - key_base`` of
    the per-replicate ``"rep_bhat"`` entries against the ``key_base`` entries.
    All panels share axis limits derived from quantiles of the pooled data.

    Args:
        results: Maps ``int(r)`` to a mapping from estimator key to a list of
            records. Records are read for ``"N"`` and ``"rep_bhat"``; only
            2-D ``"rep_bhat"`` arrays are used.
        r_list: Candidate ``r`` values; the first ``n_panels`` are shown.
        key_base: Estimator key plotted on the x-axis.
        key_ana: Estimator key whose difference to ``key_base`` is plotted on
            the y-axis.
        n_panels: Maximum number of panels.
        ncols: Number of grid columns.
        figsize: Figure size passed to ``plt.subplots``.
        show_rel_text: If true, annotate each panel with the 99th percentile
            and the maximum of the absolute differences.
        use_log_axes: Accepted for backward compatibility; has no effect.
        robust_quantiles: ``(low, high)`` quantiles used to derive the shared
            axis limits.
        include_residual: Accepted for backward compatibility; has no effect
            (the residual view is the only view rendered).

    Returns:
        The newly created figure. Panels without usable data are blanked, and
        an empty selection yields a figure with blank panels.
    """
    # Log axes don't make sense with signed coefficients / residuals, and the
    # residual view is the only one rendered, so both flags are ignored.
    del use_log_axes, include_residual

    r_sel = list(r_list)[:n_panels]
    n = len(r_sel)
    # ``plt.subplots`` rejects a zero-row grid, so keep at least one row even
    # when there is nothing to draw.
    nrows = max(1, int(np.ceil(n / ncols)))

    base_chunks: List[np.ndarray] = []
    ana_chunks: List[np.ndarray] = []
    delta_chunks: List[np.ndarray] = []
    per_panel: List[Tuple[int, Optional[np.ndarray], Optional[np.ndarray]]] = []

    for r in r_sel:
        paired = _agreement_panel_values(results.get(int(r), {}), key_base, key_ana)
        if paired is None:
            per_panel.append((int(r), None, None))
            continue

        x, y = paired
        per_panel.append((int(r), x, y))

        # Only pairs that are finite on both sides contribute to the limits.
        finite = np.isfinite(x) & np.isfinite(y)
        if np.any(finite):
            xf = x[finite]
            yf = y[finite]
            base_chunks.append(xf)
            ana_chunks.append(yf)
            delta_chunks.append(yf - xf)

    # Robust shared limits: the x-axis range is derived from both estimators'
    # values, the y-axis range from their differences.
    coef_lo, coef_hi = _symmetric_quantile_limits(
        base_chunks + ana_chunks, robust_quantiles
    )
    d_lo, d_hi = _symmetric_quantile_limits(delta_chunks, robust_quantiles)

    # Residual-only grid: (analytic - base) vs base.
    #
    # NOTE: We intentionally do not render the "agreement" panels (with diagonal
    # line) anymore; those were visually redundant once the residual view exists.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharex=True, sharey=True)
    axes_arr = np.array(axes).ravel()

    for i, (r, x, y) in enumerate(per_panel):
        ax = axes_arr[i]
        if x is None or y is None:
            ax.axis("off")
            continue

        d = y - x
        ax.scatter(x, d, s=10, alpha=0.35)
        ax.axhline(0.0, linestyle="--", linewidth=1)
        ax.set_xlim(coef_lo, coef_hi)
        ax.set_ylim(d_lo, d_hi)

        # Panel tag in the upper-left corner identifying the r value.
        ax.text(
            0.05,
            0.93,
            f"$n/m={int(r)}$",
            transform=ax.transAxes,
            ha="left",
            va="top",
            fontweight="bold",
            bbox=dict(
                boxstyle="round,pad=0.25",
                facecolor="white",
                alpha=0.85,
                edgecolor="none",
            ),
        )

        if show_rel_text:
            diff = np.abs(d)
            diff = diff[np.isfinite(diff)]
            if diff.size:
                ax.text(
                    0.05,
                    0.80,
                    f"p99 |Δ|: {np.quantile(diff, 0.99):.3g}\nmax |Δ|: {diff.max():.3g}",
                    transform=ax.transAxes,
                    ha="left",
                    va="top",
                    bbox=dict(
                        boxstyle="round,pad=0.20",
                        facecolor="white",
                        alpha=0.80,
                        edgecolor="none",
                    ),
                )

    # Turn off unused axes
    for j in range(len(per_panel), len(axes_arr)):
        axes_arr[j].axis("off")

    fig.tight_layout()
    # Add extra margin so labels don't overlap tick labels or get clipped.
    fig.subplots_adjust(left=0.18, bottom=0.18)

    # Center labels relative to the plotted axes grid, not the full figure.
    # A draw is needed first so axes positions and tick-label extents are final.
    fig.canvas.draw()
    # NOTE(review): panels blanked via ``axis("off")`` may still report
    # ``get_visible()`` as true and so be included here -- confirm intent.
    visible_axes = [ax for ax in axes_arr[: len(per_panel)] if ax.get_visible()]
    if visible_axes:
        left = min(ax.get_position().x0 for ax in visible_axes)
        right = max(ax.get_position().x1 for ax in visible_axes)
        bottom = min(ax.get_position().y0 for ax in visible_axes)
        top = max(ax.get_position().y1 for ax in visible_axes)
    else:
        left, right, bottom, top = 0.1, 0.9, 0.1, 0.9

    x_center = 0.5 * (left + right)
    y_center = 0.5 * (bottom + top)

    # Collect the visible tick-label boxes in figure coordinates so the shared
    # axis labels can be placed just outside of them.
    renderer = fig.canvas.get_renderer()
    inv = fig.transFigure.inverted()
    x_tick_bboxes = []
    y_tick_bboxes = []
    for ax in visible_axes:
        for label in ax.get_xticklabels():
            if label.get_visible():
                x_tick_bboxes.append(
                    inv.transform_bbox(label.get_window_extent(renderer))
                )
        for label in ax.get_yticklabels():
            if label.get_visible():
                y_tick_bboxes.append(
                    inv.transform_bbox(label.get_window_extent(renderer))
                )

    gap = 0.01
    if x_tick_bboxes:
        x_label_y = min(b.y0 for b in x_tick_bboxes) - gap
    else:
        x_label_y = bottom - 0.04
    if y_tick_bboxes:
        y_label_x = min(b.x0 for b in y_tick_bboxes) - gap
    else:
        y_label_x = left - 0.04

    fig.text(
        x_center,
        x_label_y,
        r"$\hat\beta^{\mathrm{HD}}_{GMM}$",
        ha="center",
        va="top",
        clip_on=False,
    )
    fig.text(
        y_label_x,
        y_center,
        r"$\hat\beta^{\mathrm{HD}, analytic}_{GMM} - \hat\beta^{\mathrm{HD}}_{GMM}$",
        rotation="vertical",
        va="center",
        ha="right",
        clip_on=False,
    )
    return fig

