"""Programmatic plotting for the general-bias (gbias) experiment."""

from __future__ import annotations

import math
from pathlib import Path
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator

from .case_bias import _format_tau_label

# Method keys plotted when the caller does not pass ``methods``.
DEFAULT_METHODS = ["mse", "opt", "cons", "tgt"]

# Legend label shown for each method key.
NAME_MAP = dict(mse="ADP(1)", opt="ADP(decay)", cons="ADP(cons)", tgt="Target")

# Line appearance per method key; every method is drawn with the same width.
_LINE_WIDTH = 1.4
METHOD_STYLES = {
    key: {"linestyle": linestyle, "marker": marker, "color": color, "lw": _LINE_WIDTH}
    for key, linestyle, marker, color in (
        ("mse", "--", "o", "#1f77b4"),
        ("opt", "-", "s", "#ff7f0e"),
        ("cons", "-.", "^", "#2ca02c"),
        ("tgt", "--", None, "black"),
    )
}

# Metric column suffixes (CSV columns are named ``<method>_<metric>``) and the
# axis label used for each of them.
METRICS = ["rmse", "coverage", "CIlen"]
YLABELS = dict(rmse="log MSE", coverage="ECP", CIlen="Length of CI")

# Background colour applied to every axes.
FACECOLOR = "#f5f5ff"


def _resolve_gbias_files(
    input_dir: Path, dist_type: str, tau_values: Sequence[float], r_value: float | None
) -> list[Path]:
    """Locate the gbias result CSV file(s) inside ``input_dir``.

    Candidates are tried in this order and the first hit wins:

    1. a single combined file naming every tau,
       ``case_gbias_<dist_type>_tau<t1>_<t2>...[_r<r>].csv``;
    2. one file per tau following the same pattern, accepted only when
       every one of them exists;
    3. the legacy ``gbias.csv``.

    Raises:
        FileNotFoundError: when no candidate exists; the exception carries
            the combined path.
    """

    def csv_for(taus: Sequence[float]) -> Path:
        # Build the expected file path for the given group of tau values.
        stem = "case_gbias_{}_tau{}".format(
            dist_type, "_".join(_format_tau_label(t) for t in taus)
        )
        if r_value is not None:
            stem = f"{stem}_r{_format_tau_label(r_value)}"
        return input_dir / f"{stem}.csv"

    combined = csv_for(tau_values)
    if combined.exists():
        return [combined]

    # Per-tau files are all-or-nothing: one missing file discards the set.
    per_tau = [csv_for([t]) for t in tau_values]
    if per_tau and all(p.exists() for p in per_tau):
        return per_tau

    legacy = input_dir / "gbias.csv"
    if not legacy.exists():
        raise FileNotFoundError(combined)
    return [legacy]


def _resolve_gbias_target_files(
    input_dir: Path, dist_type: str, tau_values: Sequence[float], r_value: float | None
) -> list[Path]:
    """Locate the optional target-only gbias CSV file(s) inside ``input_dir``.

    The lookup order is: one combined
    ``case_gbias_target_<dist_type>_tau<t1>_<t2>...[_r<r>].csv`` file, then one
    such file per tau (only if all of them exist), then the legacy
    ``gbias_target.csv``.

    Returns:
        The matching paths, or an empty list when nothing is found (no
        exception is raised for missing target files).
    """

    def csv_for(taus: Sequence[float]) -> Path:
        # Build the expected file path for the given group of tau values.
        stem = "case_gbias_target_{}_tau{}".format(
            dist_type, "_".join(_format_tau_label(t) for t in taus)
        )
        if r_value is not None:
            stem = f"{stem}_r{_format_tau_label(r_value)}"
        return input_dir / f"{stem}.csv"

    combined = csv_for(tau_values)
    if combined.exists():
        return [combined]

    # Per-tau files are all-or-nothing: one missing file discards the set.
    per_tau = [csv_for([t]) for t in tau_values]
    if per_tau and all(p.exists() for p in per_tau):
        return per_tau

    legacy = input_dir / "gbias_target.csv"
    return [legacy] if legacy.exists() else []


def _resolve_gbias_changesites_file(
    input_dir: Path,
    dist_type: str,
    tau: float,
    n_sites: int | None,
) -> Path:
    """Return the CSV holding the change-sites results.

    The parameter-specific file
    ``case_gbias_changesites_<dist_type>_tau<tau>[_nsites<n_sites>].csv`` is
    preferred; the generic ``case_gbias_changesites.csv`` is the fallback.

    Raises:
        FileNotFoundError: when neither file exists; the exception carries
            the parameter-specific path.
    """
    stem = f"case_gbias_changesites_{dist_type}_tau{_format_tau_label(tau)}"
    if n_sites is not None:
        stem = f"{stem}_nsites{n_sites}"
    specific = input_dir / f"{stem}.csv"
    generic = input_dir / "case_gbias_changesites.csv"

    for path in (specific, generic):
        if path.exists():
            return path
    raise FileNotFoundError(specific)


def _load_dataframe(
    input_dir: Path,
    dist_type: str,
    tau_values: Sequence[float],
    r_value: float | None,
    keep_samples: Sequence[int] | None,
) -> pd.DataFrame:
    """Read the gbias results into one frame ready for plotting.

    Steps, in order:

    * concatenate the result CSVs found by ``_resolve_gbias_files``;
    * if target CSVs exist, inner-join them on
      ``n_samples``/``source_prop``/``tau``/``bias_id``;
    * replace every ``*_rmse`` column by ``log(rmse ** 2)``;
    * when ``keep_samples`` is non-empty, keep only rows whose ``n_samples``
      is listed there.
    """

    def read_all(paths: Sequence[Path]) -> pd.DataFrame:
        # Stack several CSV files into one frame with a fresh index.
        return pd.concat([pd.read_csv(p) for p in paths], ignore_index=True)

    frame = read_all(_resolve_gbias_files(input_dir, dist_type, tau_values, r_value))

    target_paths = _resolve_gbias_target_files(input_dir, dist_type, tau_values, r_value)
    if target_paths:
        frame = frame.merge(
            read_all(target_paths),
            on=["n_samples", "source_prop", "tau", "bias_id"],
            how="inner",
        )

    rmse_cols = [name for name in frame.columns if name.endswith("_rmse")]
    frame[rmse_cols] = np.log(frame[rmse_cols] ** 2)

    if not keep_samples:
        return frame
    return frame[frame["n_samples"].isin(keep_samples)]


def _bias_values_for_sample_size(n_samples: int, num_points: int) -> np.ndarray:
    """Return the bias grid assumed for a given sample size.

    The grid holds ``num_points`` values evenly spaced on the natural-log
    scale.  For ``n_samples`` below 200000 it runs from ``exp(-5)`` to ``1``;
    from 200000 upwards both endpoints are divided by four (``exp(-5) / 4``
    to ``0.25``).  It is meant to mirror the grid used in the simulations.
    """
    large_sample = n_samples >= 200000
    log_low = -5.0 - math.log(4.0) if large_sample else -5.0
    high = 0.25 if large_sample else 1.0
    return np.logspace(log_low, math.log(high), num=num_points, base=np.e)


def _gbias_line_kwargs(method: str) -> dict:
    """Return the line/marker kwargs shared by a method's curve and its legend handle."""
    style = METHOD_STYLES.get(method, {"linestyle": "-", "marker": "o", "color": "gray", "lw": 1.4})
    marker = style.get("marker")
    return {
        "color": style["color"],
        "linestyle": style["linestyle"],
        "marker": marker,
        "markersize": 4 if marker else 0,
        "linewidth": style["lw"],
    }


def _draw_gbias_panel(ax, subset: pd.DataFrame, metric: str, methods: Sequence[str]) -> None:
    """Draw every method's curve for one metric and apply the shared axes look."""
    for m in methods:
        ax.plot(subset["log_bias"], subset[f"{m}_{metric}"], **_gbias_line_kwargs(m))

    ax.set_facecolor(FACECOLOR)
    ax.grid(True, color="white", zorder=-1)
    ax.tick_params(axis="both", length=0)
    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    if metric == "coverage":
        # Reference line at the 0.95 level and a fixed coverage range.
        ax.axhline(0.95, ls="--", lw=0.8, color="grey", zorder=0)
        ax.set_ylim(0.6, 1.0)


def plot_case_gbias(
    input_dir: str | Path,
    output_dir: str | Path,
    tau: float = 0.25,
    taus: Sequence[float] | None = None,
    dist_type: str = "normal",
    keep_samples: Sequence[int] | None = None,
    r_value: float | None = None,
    methods: Sequence[str] = DEFAULT_METHODS,
    metrics: Sequence[str] = METRICS,
    figure_name: str | None = None,
) -> Path:
    """Reproduce the legacy general-bias figure (without broken axes).

    The figure has one row per metric.  Columns are either the sample sizes
    (when the data holds several ``n_samples`` and a single tau) or the tau
    values (otherwise).  Each panel plots the metric against the log of the
    bias value, one curve per method.

    Args:
        input_dir: Directory containing the result CSV files.
        output_dir: Directory the figure is written to (created if missing).
        tau: Tau value used when ``taus`` is empty or ``None``.
        taus: Tau values to plot; takes precedence over ``tau``.
        dist_type: Distribution label that is part of the CSV file names.
        keep_samples: If non-empty, only these ``n_samples`` values are kept.
        r_value: Optional value encoded as an ``_r<...>`` suffix in the CSV
            and figure file names.
        methods: Method keys to plot; a method is silently skipped unless the
            data has a ``<method>_<metric>`` column for every metric.
        metrics: Metric suffixes, one subplot row each.
        figure_name: Output file name; derived from the arguments when ``None``.

    Returns:
        Path of the saved figure.

    Raises:
        FileNotFoundError: If no result CSV can be located.
        ValueError: If filtering leaves no rows, no requested method is
            available, several taus *and* several sample sizes are present,
            or a panel has no data.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    tau_values = list(taus) if taus else [tau]
    df = _load_dataframe(
        input_dir,
        dist_type=dist_type,
        tau_values=tau_values,
        r_value=r_value,
        keep_samples=keep_samples,
    )
    # Copy the filtered rows: columns are added below, and assigning into a
    # slice of the loaded frame triggers pandas' chained-assignment warnings.
    df = df[df["tau"].isin(tau_values)].copy()
    if df.empty:
        raise ValueError("No rows found for the given tau / keep_samples configuration.")

    n_samples_vals = sorted(df["n_samples"].unique())

    methods = [m for m in methods if all(f"{m}_{metric}" in df.columns for metric in metrics)]
    if not methods:
        raise ValueError("No plotting methods available in the provided CSV files.")

    # bias_id is a 1-based index into the grid of the row's sample size.
    max_bias = int(df["bias_id"].max())
    bias_maps = {
        n: dict(enumerate(_bias_values_for_sample_size(n, max_bias), start=1))
        for n in n_samples_vals
    }
    df["bias_value"] = [
        bias_maps[n][int(bias_id)] for n, bias_id in zip(df["n_samples"], df["bias_id"])
    ]
    df["log_bias"] = np.log(df["bias_value"])

    multi_samples = len(n_samples_vals) > 1
    if multi_samples and len(tau_values) > 1:
        raise ValueError("Plotting multiple taus and sample sizes simultaneously is not supported.")
    column_mode = "sample" if multi_samples else "tau"
    column_values = n_samples_vals if column_mode == "sample" else tau_values

    # Select each column's rows up front so that a missing combination is
    # reported before any figure is allocated.
    panels = []
    for col_value in column_values:
        if column_mode == "sample":
            sample, tau_val = col_value, tau_values[0]
        else:
            sample, tau_val = n_samples_vals[0], col_value
        subset = df[(df["n_samples"] == sample) & (df["tau"] == tau_val)].sort_values("bias_id")
        if subset.empty:
            raise ValueError(f"No rows found for tau={tau_val} and n_samples={sample}.")
        panels.append((sample, tau_val, subset))

    if figure_name is None:
        tau_suffix = "_".join(_format_tau_label(val) for val in tau_values)
        suffix = f"{dist_type}_tau{tau_suffix}"
        if r_value is not None:
            suffix += f"_r{_format_tau_label(r_value)}"
        if keep_samples:
            suffix += "_" + "_".join(str(s) for s in keep_samples)
        figure_name = f"case_gbias_{suffix}.pdf"
    output_path = output_dir / figure_name

    n_rows = len(metrics)
    n_cols = len(column_values)
    col_width = 3.1 if column_mode == "sample" else 2.45
    rc_overrides = {
        "font.size": 12,
        "axes.titlesize": 12,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
    }

    # rc_context applies the font sizes to this figure only and restores the
    # caller's rcParams afterwards instead of changing them process-wide.
    with plt.rc_context(rc_overrides):
        fig, axes = plt.subplots(
            n_rows,
            n_cols,
            figsize=(col_width * n_cols, 2.3 * n_rows),
            squeeze=False,
        )
        try:
            for c, (sample, tau_val, subset) in enumerate(panels):
                for r, metric in enumerate(metrics):
                    ax = axes[r, c]
                    _draw_gbias_panel(ax, subset, metric, methods)

                    if c == 0:
                        ax.set_ylabel(YLABELS[metric], labelpad=6)
                    elif column_mode == "sample":
                        ax.set_ylabel("", labelpad=2)
                        ax.tick_params(axis="y", left=True, labelleft=True)
                    else:
                        # NOTE(review): y axes are not shared between columns, so
                        # hiding these tick labels drops each column's own scale —
                        # confirm this is intended.
                        ax.tick_params(axis="y", left=False, labelleft=False)

                    if r == n_rows - 1:
                        # Ticks sit at the log-bias positions but show the raw bias.
                        ax.set_xticks(subset["log_bias"])
                        ax.set_xticklabels(
                            subset["bias_value"].map("{:.2g}".format),
                            rotation=45,
                            ha="center",
                            fontsize=8,
                        )
                        ax.set_xlabel("Bias", labelpad=2)
                    else:
                        ax.tick_params(axis="x", bottom=False, labelbottom=False)

                    if r == 0:
                        if column_mode == "tau":
                            ax.set_title(f"τ = {tau_val}", pad=3 if c == 0 else 1.5)
                        else:
                            ax.set_title(rf"$\,n_0 = {{{sample}}}\,$", pad=3)

            fig.legend(
                [Line2D([], [], **_gbias_line_kwargs(m)) for m in methods],
                [NAME_MAP.get(m, m) for m in methods],
                ncol=len(methods),
                loc="upper center",
                bbox_to_anchor=(0.5, 0.99),
                frameon=False,
                borderaxespad=0.1,
                handletextpad=0.4,
            )
            fig.align_ylabels(axes[:, 0])
            fig.subplots_adjust(
                wspace=0.16 if column_mode == "sample" else 0.08,
                hspace=0.14,
                top=0.935,
                left=0.09,
                right=0.995,
                bottom=0.14,
            )
            fig.savefig(output_path, dpi=300, bbox_inches="tight")
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
    return output_path


def plot_case_gbias_changesites(
    input_dir: str | Path,
    output_dir: str | Path,
    tau: float = 0.5,
    dist_type: str = "normal",
    n_sites: int | None = None,
    keep_samples: Sequence[int] | None = None,
    figure_name: str | None = None,
) -> Path:
    """Plot the case_gbias_changesites scenario in a 1×3 layout.

    One panel is drawn per entry of ``METRICS``; each panel plots the metric
    against the log of the bias value, one curve per available method.

    Args:
        input_dir: Directory containing the change-sites CSV.
        output_dir: Directory the figure is written to (created if missing).
        tau: Tau value to plot; rows with a different ``tau`` are dropped.
        dist_type: Distribution label that is part of the CSV file name.
        n_sites: Optional site count encoded as ``_nsites<n>`` in the CSV and
            figure file names.
        keep_samples: If non-empty, only these ``n_samples`` values are kept.
        figure_name: Output file name; derived from the arguments when ``None``.

    Returns:
        Path of the saved figure.

    Raises:
        FileNotFoundError: If no change-sites CSV can be located.
        ValueError: If filtering leaves no rows, more than one sample size
            remains, or no method has all metric columns.
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    csv_path = _resolve_gbias_changesites_file(input_dir, dist_type, tau, n_sites)
    raw = pd.read_csv(csv_path)
    keep = raw["tau"] == tau
    if keep_samples:
        keep &= raw["n_samples"].isin(keep_samples)
    # Copy once after all row filters: the columns written below must not be
    # assigned into a slice of the raw frame.
    df = raw[keep].copy()
    if df.empty:
        raise ValueError("No rows remain after filtering; check tau/keep_samples.")

    rmse_cols = [c for c in df.columns if c.endswith("_rmse")]
    if rmse_cols:
        # Plot log MSE = log(rmse ** 2), matching the "log MSE" axis label.
        df[rmse_cols] = np.log(df[rmse_cols] ** 2)

    n_samples_vals = sorted(df["n_samples"].unique())
    if len(n_samples_vals) != 1:
        raise ValueError("case_gbias_changesites plot expects a single sample size in the CSV.")
    n_samp = n_samples_vals[0]

    methods = [m for m in ("mse", "opt", "cons", "tgt") if all(f"{m}_{metric}" in df.columns for metric in METRICS)]
    if not methods:
        raise ValueError("Missing ADP/Tgt methods in CSV.")

    # bias_id is a 1-based index into the bias grid of this sample size.
    max_bias = int(df["bias_id"].max())
    bias_map = dict(enumerate(_bias_values_for_sample_size(n_samp, max_bias), start=1))
    df["bias_value"] = df["bias_id"].map(bias_map)
    df["log_bias"] = np.log(df["bias_value"])
    df = df.sort_values("bias_id")

    # Line/marker kwargs per method, shared by the curves and the legend handles.
    line_kwargs = {}
    for m in methods:
        style = METHOD_STYLES.get(m, {"linestyle": "-", "marker": "o", "color": "gray", "lw": 1.4})
        marker = style.get("marker")
        line_kwargs[m] = {
            "color": style["color"],
            "linestyle": style["linestyle"],
            "marker": marker,
            "markersize": 4 if marker else 0,
            "linewidth": style["lw"],
        }

    if figure_name is None:
        suffix = f"{dist_type}_tau{_format_tau_label(tau)}"
        if n_sites is not None:
            suffix += f"_nsites{n_sites}"
        figure_name = f"case_gbias_changesites_{suffix}.pdf"
    output_path = output_dir / figure_name

    rc_overrides = {
        "font.size": 12,
        "axes.titlesize": 12,
        "axes.labelsize": 12,
        "xtick.labelsize": 10,
        "ytick.labelsize": 10,
        "legend.fontsize": 10,
    }

    # rc_context applies the font sizes to this figure only and restores the
    # caller's rcParams afterwards instead of changing them process-wide.
    with plt.rc_context(rc_overrides):
        fig, axes = plt.subplots(
            1,
            len(METRICS),
            figsize=(3.0 * len(METRICS), 2.5),
            squeeze=False,
            sharex=True,
        )
        try:
            axes_row = axes[0]
            for ax, metric in zip(axes_row, METRICS):
                for m in methods:
                    ax.plot(df["log_bias"], df[f"{m}_{metric}"], **line_kwargs[m])
                ax.set_facecolor(FACECOLOR)
                ax.grid(True, color="white", zorder=-1)
                ax.tick_params(axis="both", length=0)
                ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
                for side in ("top", "right", "left", "bottom"):
                    ax.spines[side].set_visible(False)
                if metric == "coverage":
                    # Reference line at the 0.95 level and a fixed coverage range.
                    ax.axhline(0.95, ls="--", lw=0.8, color="grey", zorder=0)
                    ax.set_ylim(0.6, 1.0)
                ax.set_ylabel(YLABELS[metric], labelpad=6)
                ax.set_xlabel("Bias", labelpad=2)

            # Ticks sit at the log-bias positions but show the raw bias value.
            for ax in axes_row:
                ax.set_xticks(df["log_bias"])
                ax.set_xticklabels(df["bias_value"].map("{:.3g}".format), rotation=45, ha="center", fontsize=9)

            fig.legend(
                [Line2D([], [], **line_kwargs[m]) for m in methods],
                [NAME_MAP.get(m, m) for m in methods],
                ncol=len(methods),
                loc="upper center",
                bbox_to_anchor=(0.5, 1.02),
                frameon=False,
                borderaxespad=0.1,
                handletextpad=0.4,
            )
            fig.subplots_adjust(top=0.84, wspace=0.35, left=0.08, right=0.995, bottom=0.3)
            fig.savefig(output_path, dpi=300, bbox_inches="tight")
        finally:
            # Release the figure even when drawing or saving fails.
            plt.close(fig)
    return output_path


