"""Programmatic version of the legacy case_bias plotting notebook."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, Sequence

import math
import matplotlib.pyplot as plt
import matplotlib.transforms as mtrans
import numpy as np
import pandas as pd
from brokenaxes import BrokenAxes, brokenaxes
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator

# Method keys in plotting order; the legend follows the same order:
# ADP(1), ADP(decay), ADP(cons), IVW, DPSGD, Target.
DEFAULT_METHODS = ["mse", "opt", "cons", "inv", "dpsgd", "tgt"]

# Human-readable legend label for each method key.
DEFAULT_NAME_MAP = {
    "opt": "ADP(decay)",
    "tgt": "Target",
    "mse": "ADP(1)",
    "inv": "IVW",
    "cons": "ADP(cons)",
    "dpsgd": "DPSGD",
}

# Metric suffixes plotted by default (data columns are "<method>_<metric>").
DEFAULT_METRICS = ["rmse", "coverage", "CIlen"]

# Per-method line style: linestyle, marker (None = no marker), color, line width.
METHOD_STYLES = {
    "mse": {"linestyle": "--", "marker": "o", "color": "#1f77b4", "lw": 1.4},
    "opt": {"linestyle": "-", "marker": "s", "color": "#ff7f0e", "lw": 1.4},
    "cons": {"linestyle": "-.", "marker": "^", "color": "#2ca02c", "lw": 1.4},
    "inv": {"linestyle": ":", "marker": "D", "color": "#d62728", "lw": 1.4},
    "tgt": {"linestyle": "--", "marker": None, "color": "black", "lw": 1.4},
    "dpsgd": {"linestyle": "-", "marker": "v", "color": "#9467bd", "lw": 1.4},
}

# Y-axis label for each metric.
YLABELS = {"rmse": "log MSE", "coverage": "ECP", "CIlen": "Length of CI"}

# Default y-limits per metric. A single (low, high) tuple describes one axis;
# a list of (low, high) tuples (ascending) describes the spans of a broken axis.
YLIMS = {
    "CIlen": (0.01, 0.115),
    "coverage": [(-0.08, 0.70), (0.80, 1.0)],
    "rmse": [(-12.9, -7.4), (-6.5, -0.4)],
}

# Distribution-specific replacements for individual YLIMS entries.
DIST_YLIM_OVERRIDES = {
    "cauchy": {
        "rmse": [(-13.2, -7.2), (-7.0, 0.5)],
        "CIlen": (0.015, 0.28),
    },
}

# Relative heights of the broken-axis panels (passed to brokenaxes as height_ratios).
RATIOS = {"coverage": [3, 1], "rmse": [1, 3]}

# Background color applied to every plot panel.
FACECOLOR = "#f5f5ff"


def _format_tau_label(value: float) -> str:
    """Format a numeric parameter compactly for use in file names.

    Uses at most six significant digits (``.6g``). The ``g`` presentation
    type already removes insignificant trailing zeros from the significand,
    so the extra stripping below is only a safeguard for plain decimal
    output. It must never touch exponent output: stripping ``"0"`` from
    ``"2.5e+20"`` would corrupt it to ``"2.5e+2"``.
    """
    text = f"{value:.6g}"
    if "." in text and "e" not in text:
        text = text.rstrip("0").rstrip(".")
    return text


def _resolve_case_bias_file(
    input_dir: Path, base_name: str, dist_type: str, tau: float, r_value: float | None = None
) -> Path:
    """Locate the CSV for ``base_name`` inside ``input_dir``.

    Candidates are tried in order and the first existing one is returned:
    ``{base}_{dist}_tau{tau}_r{r}.csv`` (only when ``r_value`` is given), then
    ``{base}_{dist}_tau{tau}.csv``. If neither exists, ``{base}.csv`` is
    returned without an existence check, so a missing file surfaces when the
    caller tries to read it.
    """
    stem = f"{base_name}_{dist_type}_tau{_format_tau_label(tau)}"
    suffixes = [] if r_value is None else [f"_r{_format_tau_label(r_value)}"]
    suffixes.append("")
    for suffix in suffixes:
        path = input_dir / f"{stem}{suffix}.csv"
        if path.exists():
            return path
    return input_dir / f"{base_name}.csv"


def _axes_iter(obj):
    """Return the concrete axes behind ``obj``.

    A ``BrokenAxes`` yields its sub-axes list (``obj.axs``); any other object
    is wrapped in a one-element list so callers can loop uniformly.
    """
    if isinstance(obj, BrokenAxes):
        return obj.axs
    return [obj]


def _axis_limits(obj):
    """Return the y-limits of ``obj`` rounded to 6 decimals.

    For a ``BrokenAxes`` the result is a tuple with one ``(bottom, top)`` pair
    per sub-axis; for an ordinary axes it is a single ``(bottom, top)`` pair.
    Rounding makes the values suitable for equality comparison between panels.
    """

    def rounded_ylim(axis):
        bottom, top = axis.get_ylim()
        return (round(bottom, 6), round(top, 6))

    if not isinstance(obj, BrokenAxes):
        return rounded_ylim(obj)
    return tuple(map(rounded_ylim, obj.axs))


def _add_ylabel(fig, obj, text, x_shift=-0.133, extra_shift=0.0):
    """Attach a y-axis label to a panel.

    Ordinary axes get a regular ``set_ylabel``. For a ``BrokenAxes`` a single
    vertical figure-level text is placed instead, right-aligned at
    ``x_shift`` (figure coordinates) left of the union of the sub-axes'
    positions and vertically centred on that union, nudged by ``extra_shift``.
    The positions are read at call time.
    """
    if not isinstance(obj, BrokenAxes):
        obj.set_ylabel(text, labelpad=2)
        return
    union_box = mtrans.Bbox.union([sub_ax.get_position() for sub_ax in obj.axs])
    label_x = union_box.x0 + x_shift
    label_y = (union_box.y0 + union_box.y1) / 2 + extra_shift
    fig.text(label_x, label_y, text, ha="right", va="center", rotation="vertical")


def _metric_value_range(df: pd.DataFrame, methods: Sequence[str], metric: str):
    """Return ``(min, max)`` over all ``<method>_<metric>`` columns present in ``df``.

    Methods without a matching column are skipped and NaNs are ignored.
    Returns ``None`` when no column matches or every value is NaN.
    """
    present = [name for name in (f"{m}_{metric}" for m in methods) if name in df.columns]
    if len(present) == 0:
        return None
    flat = df[present].to_numpy().ravel()
    not_nan = ~np.isnan(flat)
    if not not_nan.any():
        return None
    kept = flat[not_nan]
    return float(kept.min()), float(kept.max())


def _pad_extent(value: float) -> float:
    """Return a 2% padding relative to ``|value|``, never scaled below 1 (so >= 0.02)."""
    scale = abs(value) if abs(value) > 1.0 else 1.0
    return scale * 0.02


def _span_limits(data_min: float, data_max: float) -> tuple[float, float]:
    """Return ``(low, high)`` enclosing ``[data_min, data_max]`` with padding.

    The padding is the larger of 5% of the data spread (floored at 1e-6) and
    2% of the largest magnitude among 1, ``|data_min|`` and ``|data_max|``.
    """
    spread = max(data_max - data_min, 1e-6)
    magnitude = max(1.0, abs(data_min), abs(data_max))
    margin = max(0.05 * spread, 0.02 * magnitude)
    return data_min - margin, data_max + margin


def _base_limits(metric: str, dist_type: str | None = None):
    """Return the default y-limits for ``metric``.

    A ``DIST_YLIM_OVERRIDES`` entry for a truthy ``dist_type`` takes precedence;
    otherwise the value comes from ``YLIMS`` (``KeyError`` for unknown metrics).
    """
    overrides = DIST_YLIM_OVERRIDES.get(dist_type, {}) if dist_type else {}
    candidate = overrides.get(metric)
    return YLIMS[metric] if candidate is None else candidate


def _resolve_ylims(base_limits, data_range):
    """Return y-limits for a panel, widening ``base_limits`` to fit ``data_range``.

    Args:
        base_limits: Either a single ``(low, high)`` pair, or a sequence of
            ``(low, high)`` spans (detected by its first element being a
            list/tuple), as passed to ``brokenaxes(ylims=...)``.
        data_range: ``(data_min, data_max)``, or a falsy value in which case
            ``base_limits`` is returned unchanged.

    Returns:
        For span sequences, a list of ``(low, high)`` tuples. For a single
        pair, ``base_limits`` itself when the data already fits, otherwise a
        padded ``(low, high)`` from ``_span_limits``.
    """
    if not data_range:
        return base_limits
    # Metrics like rmse / coverage use broken axes (list of ranges)
    if isinstance(base_limits[0], (list, tuple)):
        # Mutable copies so individual span edges can be widened below.
        # NOTE(review): the logic below assumes spans are sorted ascending and
        # non-overlapping — confirm for any new YLIMS entries.
        spans = [list(span) for span in base_limits]
        data_min, data_max = data_range

        def overlaps(span):
            return data_max >= span[0] and data_min <= span[1]

        # Case 1: the data touches at least one span -> widen every touched
        # span so it contains the data, padded by _pad_extent.
        # NOTE(review): if the data overlaps more than one span, each of them
        # is widened toward the full data range, so adjacent spans can end up
        # overlapping each other. The visible call site (plot_case_bias) only
        # passes data_range=None, so this path is not exercised there.
        adjusted = False
        for span in spans:
            if overlaps(span):
                adjusted = True
                if data_min < span[0]:
                    span[0] = data_min - _pad_extent(data_min)
                if data_max > span[1]:
                    span[1] = data_max + _pad_extent(data_max)
        if adjusted:
            return [tuple(span) for span in spans]

        # Case 2: all data lies below the first span -> replace that span
        # with a padded range around the data.
        if data_max < spans[0][0]:
            spans[0][0], spans[0][1] = _span_limits(data_min, data_max)
            return [tuple(span) for span in spans]
        # Case 3: all data lies above the last span -> replace that span.
        if data_min > spans[-1][1]:
            spans[-1][0], spans[-1][1] = _span_limits(data_min, data_max)
            return [tuple(span) for span in spans]

        # Case 4: all data lies inside a gap between two spans -> stretch the
        # nearer neighbouring span (ties go to the lower one) over the data.
        for idx in range(len(spans) - 1):
            upper = spans[idx][1]
            lower = spans[idx + 1][0]
            if data_min > upper and data_max < lower:
                left_gap = data_min - upper
                right_gap = lower - data_max
                if left_gap <= right_gap:
                    spans[idx][1] = data_max + _pad_extent(data_max)
                else:
                    spans[idx + 1][0] = data_min - _pad_extent(data_min)
                break
        return [tuple(span) for span in spans]

    # Single-axis case: keep the base limits if they already contain the data,
    # otherwise switch entirely to padded data-driven limits.
    data_min, data_max = data_range
    base_min, base_max = base_limits
    if data_min >= base_min and data_max <= base_max:
        return base_limits
    return _span_limits(data_min, data_max)


def _single_axis_limits(metric: str, data_range, base_limits):
    """Return ``(low, high)`` limits for a single (non-broken) axis.

    If ``data_range`` is truthy, padded data-driven limits from
    ``_span_limits`` are used. Otherwise, for ``"rmse"`` the broken-axis spans
    in ``base_limits`` are collapsed to one range (first span's low to last
    span's high). Any other metric returns ``base_limits`` unchanged.
    """
    if not data_range:
        if metric != "rmse":
            return base_limits
        first_span, last_span = base_limits[0], base_limits[-1]
        return first_span[0], last_span[1]
    return _span_limits(*data_range)


def _log_mse_filtered(
    df: pd.DataFrame,
    tau: float,
    keep_samples: Sequence[int] | None,
) -> pd.DataFrame:
    """Filter ``df`` to one ``tau`` (and optional sample sizes) and log-transform errors.

    Shared by ``_prepare_dataframe`` and ``_prepare_single_case_dataframe``.

    Rows are kept when their ``tau`` column is within a tight tolerance of
    ``tau`` (``rtol=1e-9``). A value parsed from CSV therefore still matches an
    argument that differs only by floating-point rounding (e.g. ``0.1 + 0.2``
    against a stored ``0.3``), which an exact ``==`` would reject. When
    ``keep_samples`` is truthy, rows are further restricted to those
    ``n_samples`` values.

    Every column ending in ``_rmse`` is replaced by ``log(rmse ** 2)``, the
    natural log of the squared value.

    Returns:
        A new DataFrame (the input is not modified) that keeps the original
        index of the selected rows.
    """
    mask = np.isclose(df["tau"].to_numpy(dtype=float), tau, rtol=1e-9, atol=1e-12)
    if keep_samples:
        mask &= df["n_samples"].isin(keep_samples).to_numpy()
    # Filter first and copy, so the log transform only touches kept rows and
    # the assignment below targets an independent frame.
    out = df.loc[mask].copy()
    rmse_cols = [c for c in out.columns if c.endswith("_rmse")]
    if rmse_cols:
        out[rmse_cols] = np.log(out[rmse_cols] ** 2)
    return out


def _prepare_dataframe(
    df_path: Path,
    target_path: Path,
    tau: float,
    keep_samples: Sequence[int] | None,
) -> pd.DataFrame:
    """Load the main and target CSVs, inner-join them, and filter/transform.

    The two files are joined on ``n_samples``, ``source_prop``, ``tau`` and
    ``bias_id``. The merged frame is then restricted to ``tau`` (and
    ``keep_samples`` when truthy), and every ``*_rmse`` column becomes
    ``log(rmse ** 2)``.
    """
    df = pd.read_csv(df_path)
    df_target = pd.read_csv(target_path)
    key_cols = ["n_samples", "source_prop", "tau", "bias_id"]
    df_all = pd.merge(df, df_target, on=key_cols, how="inner")
    return _log_mse_filtered(df_all, tau, keep_samples)


def _prepare_single_case_dataframe(
    df_path: Path,
    tau: float,
    keep_samples: Sequence[int] | None,
) -> pd.DataFrame:
    """Load one CSV, restrict it to ``tau`` / ``keep_samples``, and log-transform ``*_rmse`` columns."""
    df = pd.read_csv(df_path)
    return _log_mse_filtered(df, tau, keep_samples)


def plot_case_bias(
    input_dir: str | Path,
    output_dir: str | Path,
    tau: float = 0.25,
    keep_samples: Sequence[int] | None = (20000, 200000),
    methods: Sequence[str] = DEFAULT_METHODS,
    metrics: Sequence[str] = DEFAULT_METRICS,
    name_map: dict[str, str] | None = None,
    dist_type: str = "normal",
    r_value: float | None = None,
    figure_name: str | None = None,
):
    """
    Reproduce the figure previously generated via plot_case1_submit.ipynb.

    Builds a grid with one row per entry of ``metrics`` and one column per
    distinct ``n_samples`` value. Each panel plots every method's
    ``<method>_<metric>`` column against ``bias_id``. The result is saved as a
    single file.

    Args:
        input_dir: Directory holding the ``case_bias*`` and
            ``case_bias_target*`` CSVs (located via ``_resolve_case_bias_file``).
        output_dir: Directory for the figure; created if missing.
        tau: Only rows whose ``tau`` matches this value are plotted.
        keep_samples: If truthy, restrict to these ``n_samples`` values.
        methods: Method keys to plot. Keys lacking a column for any of
            ``metrics`` are dropped.
        metrics: Metric suffixes, one figure row each.
        name_map: Legend label per method key (falls back to ``DEFAULT_NAME_MAP``).
        dist_type: Used for file lookup, the default figure name, and y-limit
            overrides. ``"cauchy"`` also disables the broken rmse axis.
        r_value: Optional extra file/figure-name component. When given and not
            ~0.5, rmse/CIlen y-limits are derived from the data when possible.
        figure_name: Output file name. Derived from ``dist_type``/``tau``/
            ``r_value`` when omitted.

    Returns:
        Path of the saved figure.

    Raises:
        ValueError: If no rows remain after filtering, or if no method in
            ``methods`` has columns for every requested metric.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    name_map = name_map or DEFAULT_NAME_MAP
    # Style used for method keys that have no entry in METHOD_STYLES.
    fallback_style = {"linestyle": "-", "marker": "o", "color": "gray", "lw": 1.4}

    df_path = _resolve_case_bias_file(input_dir, "case_bias", dist_type, tau, r_value=r_value)
    target_path = _resolve_case_bias_file(
        input_dir, "case_bias_target", dist_type, tau, r_value=r_value
    )
    df_all = _prepare_dataframe(df_path, target_path, tau=tau, keep_samples=keep_samples)
    n_samples_vals = sorted(df_all["n_samples"].unique())
    if not n_samples_vals:
        raise ValueError("No rows remain after filtering; check tau/keep_samples.")

    # Keep only methods that have a column for every requested metric.
    methods = [
        m
        for m in methods
        if all(f"{m}_{metric}" in df_all.columns for metric in metrics)
    ]
    if not methods:
        # Same check as plot_case_bias_smallk. Without it the figure would have
        # no lines and the legend would be built with ncol=0.
        raise ValueError("No plotting methods available in the CSV.")

    plt.rcParams.update(
        {
            "font.size": 12,
            "axes.titlesize": 12,
            "axes.labelsize": 12,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "legend.fontsize": 10,
        }
    )

    if figure_name is None:
        suffix = f"{dist_type}_tau{_format_tau_label(tau)}"
        if r_value is not None:
            suffix += f"_r{_format_tau_label(r_value)}"
        figure_name = f"case_bias_{suffix}.pdf"

    # The legacy layout (fixed limits, broken rmse axis) applies when no
    # r_value is given or r_value is 0.5.
    use_legacy_all = (
        r_value is None
        or math.isclose(r_value, 0.5, rel_tol=1e-9, abs_tol=1e-9)
    )

    fig = plt.figure(figsize=(2.6 * len(n_samples_vals), 2.4 * len(metrics)))
    try:
        gs = fig.add_gridspec(len(metrics), len(n_samples_vals))
        # Reference y-limits per metric row. They are set by the first column
        # and replaced whenever a later column's limits differ. A panel whose
        # limits match the reference hides its y tick labels.
        first_column_limits: dict[str, tuple] = {}

        for c, n_samp in enumerate(n_samples_vals):
            sub = df_all.query("n_samples == @n_samp").sort_values("bias_id")
            for r, metric in enumerate(metrics):
                metric_range = None
                if not use_legacy_all and metric in ("rmse", "CIlen"):
                    metric_range = _metric_value_range(sub, methods, metric)
                base_limits = _base_limits(metric, dist_type)

                allow_rmse_broken = (
                    metric == "rmse" and use_legacy_all and dist_type != "cauchy"
                )
                use_broken = metric == "coverage" or allow_rmse_broken

                if use_broken:
                    ylims = _resolve_ylims(base_limits, None)
                    axp = brokenaxes(
                        ylims=ylims,
                        height_ratios=RATIOS[metric],
                        subplot_spec=gs[r, c],
                        despine=True,
                        diag_color="none",
                    )
                else:
                    axp = fig.add_subplot(gs[r, c])

                for m in methods:
                    style = METHOD_STYLES.get(m, fallback_style)
                    axp.plot(
                        sub["bias_id"],
                        sub[f"{m}_{metric}"],
                        color=style["color"],
                        linestyle=style["linestyle"],
                        marker=style.get("marker"),
                        markersize=4 if style.get("marker") is not None else 0,
                        linewidth=style["lw"],
                    )

                # Single (non-broken) rmse/CIlen axes get explicit limits. They
                # are data-driven when metric_range was computed, otherwise they
                # come from the base limits.
                if not use_broken and metric in ("rmse", "CIlen"):
                    lims = _single_axis_limits(metric, metric_range, base_limits)
                    if lims:
                        axp.set_ylim(*lims)

                for a in _axes_iter(axp):
                    a.set_facecolor(FACECOLOR)
                    a.grid(True, color="white", zorder=-1)
                    a.tick_params(axis="both", length=0)
                    a.yaxis.set_major_locator(MaxNLocator(nbins=4))
                    for side in ("top", "right", "left", "bottom"):
                        a.spines[side].set_visible(False)
                    if metric == "coverage":
                        # Reference line at 0.95 (presumably the nominal level).
                        a.axhline(0.95, ls="--", lw=0.8, color="grey", zorder=0)

                current_limits = _axis_limits(axp)

                if c == 0:
                    first_column_limits[metric] = current_limits
                    # NOTE(review): for broken axes, the label position is read
                    # from axes positions before fig.subplots_adjust() below moves
                    # them. The x_shift/extra_shift offsets appear tuned for that;
                    # confirm before reordering.
                    _add_ylabel(
                        fig,
                        axp,
                        YLABELS[metric],
                        extra_shift=0.03 if metric == "rmse" else 0.0,
                    )
                else:
                    same_scale = first_column_limits.get(metric) == current_limits
                    for a in _axes_iter(axp):
                        a.tick_params(axis="y", left=not same_scale, labelleft=not same_scale)
                    if not same_scale:
                        first_column_limits[metric] = current_limits

                for a in _axes_iter(axp):
                    if r == len(metrics) - 1:
                        a.xaxis.set_major_locator(MaxNLocator(integer=True))
                        a.set_xlabel("Bias Level", labelpad=1)
                    else:
                        a.tick_params(axis="x", bottom=False, labelbottom=False)

                if r == 0:
                    axp.set_title(rf"$\,n_0 = {{{n_samp}}}\,$", pad=2)

        handles = []
        labels = []
        for m in methods:
            style = METHOD_STYLES.get(m, fallback_style)
            handles.append(
                Line2D(
                    [],
                    [],
                    color=style["color"],
                    linestyle=style["linestyle"],
                    marker=style.get("marker"),
                    markersize=4 if style.get("marker") is not None else 0,
                    linewidth=style["lw"],
                )
            )
            labels.append(name_map.get(m, m))
        fig.legend(
            handles,
            labels,
            ncol=len(methods),
            loc="upper center",
            bbox_to_anchor=(0.5, 1.0),
            frameon=False,
            borderaxespad=0.1,
            handletextpad=0.4,
        )
        fig.subplots_adjust(wspace=0.14, hspace=0.12, top=0.93, left=0.08, right=0.985, bottom=0.08)

        out_path = output_dir / figure_name
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving raises, so repeated
        # calls do not accumulate open pyplot figures.
        plt.close(fig)
    return out_path


def plot_case_bias_smallk(
    input_dir: str | Path,
    output_dir: str | Path,
    tau: float = 0.25,
    keep_samples: Sequence[int] | None = None,
    methods: Sequence[str] = DEFAULT_METHODS,
    metrics: Sequence[str] = DEFAULT_METRICS,
    name_map: dict[str, str] | None = None,
    dist_type: str = "normal",
    r_value: float | None = None,
    figure_name: str | None = None,
):
    """Single-row visualization for the small-K scenario.

    Draws one panel per entry of ``metrics`` from the ``case_bias_smallK*``
    CSV (located via ``_resolve_case_bias_file``). Each panel plots every
    method's ``<method>_<metric>`` column against ``bias_id``, and the figure
    is saved to ``output_dir``.

    Args:
        input_dir: Directory holding the ``case_bias_smallK*`` CSV.
        output_dir: Directory for the figure; created if missing.
        tau: Only rows whose ``tau`` matches this value are plotted.
        keep_samples: If truthy, restrict to these ``n_samples`` values.
        methods: Method keys to plot. Keys lacking a column for any of
            ``metrics`` are dropped.
        metrics: Metric suffixes, one panel each.
        name_map: Legend label per method key (falls back to ``DEFAULT_NAME_MAP``).
        dist_type: Used for file lookup, the default figure name, and y-limit
            overrides.
        r_value: Optional extra file/figure-name component.
        figure_name: Output file name. Derived from ``dist_type``/``tau``/
            ``r_value`` when omitted.

    Returns:
        Path of the saved figure.

    Raises:
        ValueError: If no rows remain after filtering, or if no method in
            ``methods`` has columns for every requested metric.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    name_map = name_map or DEFAULT_NAME_MAP
    # Style used for method keys that have no entry in METHOD_STYLES.
    fallback_style = {"linestyle": "-", "marker": "o", "color": "gray", "lw": 1.4}

    df_path = _resolve_case_bias_file(
        input_dir, "case_bias_smallK", dist_type, tau, r_value=r_value
    )
    df_all = _prepare_single_case_dataframe(df_path, tau=tau, keep_samples=keep_samples)
    if df_all.empty:
        raise ValueError("No rows remain after filtering; check tau/keep_samples.")

    # Keep only methods that have a column for every requested metric.
    methods = [
        m
        for m in methods
        if all(f"{m}_{metric}" in df_all.columns for metric in metrics)
    ]
    if not methods:
        raise ValueError("No plotting methods available in the CSV.")

    if figure_name is None:
        suffix = f"{dist_type}_tau{_format_tau_label(tau)}"
        if r_value is not None:
            suffix += f"_r{_format_tau_label(r_value)}"
        figure_name = f"case_bias_smallK_{suffix}.pdf"

    plt.rcParams.update(
        {
            "font.size": 12,
            "axes.titlesize": 12,
            "axes.labelsize": 12,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "legend.fontsize": 10,
        }
    )

    # NOTE(review): all remaining rows form one series per method. If the CSV
    # holds several n_samples values (keep_samples defaults to None here),
    # points from different sample sizes are interleaved on the same line.
    # Confirm the small-K CSV contains a single n_samples.
    sub = df_all.sort_values("bias_id")
    x_vals = sub["bias_id"].values
    x_ticks = sub["bias_id"].unique()
    fig, axes = plt.subplots(
        1,
        len(metrics),
        figsize=(2.9 * len(metrics), 2.7),
        squeeze=False,
        sharex=True,
    )
    try:
        for ax, metric in zip(axes[0], metrics):
            metric_range = _metric_value_range(sub, methods, metric)
            base_limits = _base_limits(metric, dist_type)
            for m in methods:
                style = METHOD_STYLES.get(m, fallback_style)
                ax.plot(
                    x_vals,
                    sub[f"{m}_{metric}"],
                    color=style["color"],
                    linestyle=style["linestyle"],
                    marker=style.get("marker"),
                    markersize=4 if style.get("marker") is not None else 0,
                    linewidth=style["lw"],
                )
            ax.set_facecolor(FACECOLOR)
            ax.grid(True, color="white", zorder=-1)
            ax.tick_params(axis="both", length=0)
            ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
            for side in ("top", "right", "left", "bottom"):
                ax.spines[side].set_visible(False)
            if metric == "coverage":
                # Reference line at 0.95 (presumably the nominal level) and a
                # fixed coverage window.
                ax.axhline(0.95, ls="--", lw=0.8, color="grey", zorder=0)
                ax.set_ylim(0.6, 1.0)
            else:
                # Data-driven limits when data exists, else the base limits.
                limits = _single_axis_limits(metric, metric_range, base_limits)
                if limits:
                    ax.set_ylim(*limits)
            ax.set_xlabel("Bias Level", labelpad=3)
            ax.set_ylabel(YLABELS.get(metric, metric), labelpad=6)
            ax.set_xticks(x_ticks)

        handles = []
        labels = []
        for m in methods:
            style = METHOD_STYLES.get(m, fallback_style)
            handles.append(
                Line2D(
                    [],
                    [],
                    color=style["color"],
                    linestyle=style["linestyle"],
                    marker=style.get("marker"),
                    markersize=4 if style.get("marker") is not None else 0,
                    linewidth=style["lw"],
                )
            )
            labels.append(name_map.get(m, m))
        fig.legend(
            handles,
            labels,
            ncol=len(methods),
            loc="upper center",
            bbox_to_anchor=(0.5, 1.05),
            frameon=False,
            borderaxespad=0.1,
            handletextpad=0.4,
        )
        fig.subplots_adjust(top=0.8, wspace=0.28, left=0.075, right=0.99, bottom=0.2)

        out_path = output_dir / figure_name
        fig.savefig(out_path, dpi=300, bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving raises.
        plt.close(fig)
    return out_path

