"""Visualization for case_r and related varying-r scenarios."""

from __future__ import annotations

import ast
from pathlib import Path
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D
from matplotlib.ticker import MaxNLocator

# Line style per method key; the same keys are used as the CSV column
# prefixes ("<method>_<metric>") looked up by plot_case_r.
METHOD_STYLES = dict(
    opt=dict(linestyle="-", marker="s", color="#ff7f0e", lw=1.4),
    mse=dict(linestyle="--", marker="o", color="#1f77b4", lw=1.4),
    consvar=dict(linestyle="-.", marker="^", color="#2ca02c", lw=1.4),
    tgt=dict(linestyle="--", marker=None, color="black", lw=1.4),
    dpsgd=dict(linestyle="-", marker="v", color="#9467bd", lw=1.4),
)
# Legend label shown for each method key.
NAME_MAP = dict(
    opt="ADP(decay)",
    mse="ADP(1)",
    consvar="ADP(consvar)",
    tgt="Target",
    dpsgd="DPSGD",
)
# Metric column suffixes; one subplot per metric, left to right in this order.
METRICS = ["rmse", "coverage", "CIlen"]
# Y-axis label per metric ("rmse" columns are plotted as log(rmse ** 2)).
YLABELS = dict(rmse="log MSE", coverage="ECP", CIlen="Length of CI")
# Background colour of every subplot.
FACECOLOR = "#f5f5ff"


def _build_case_r_suffix(dist_type: str, tau: float | None, target_r: float | None) -> str:
    parts = [dist_type]
    if tau is not None:
        parts.append(f"tau{_format_numeric(tau)}")
    if target_r is not None:
        parts.append(f"targetR{_format_numeric(target_r)}")
    return "_".join(parts)


def _format_numeric(value: float) -> str:
    text = f"{value:.6g}"
    if "." in text:
        text = text.rstrip("0").rstrip(".")
    return text


def _source_r(rs_value) -> float:
    if isinstance(rs_value, (list, tuple, np.ndarray)):
        vec = rs_value
    else:
        try:
            vec = ast.literal_eval(rs_value)
        except (ValueError, SyntaxError, TypeError):
            raise ValueError(f"Unable to parse rs value: {rs_value!r}")
    if len(vec) < 2:
        raise ValueError("rs vector must include target and at least one source entry.")
    return float(vec[1])


def _resolve_case_r_file(input_dir: Path, case: str, dist_type: str, tau: float, target_r: float | None) -> Path:
    """Locate the results CSV for a varying-r experiment.

    The preferred name is ``{case}_{suffix}.csv``, with the suffix from
    ``_build_case_r_suffix(dist_type, tau, target_r)``. If that file is
    missing, fall back to the legacy name ``case_r_dist_{dist_type}_r_1.csv``,
    or ``case_r_dist_{dist_type}_r_1_consvar.csv`` for any ``case`` other
    than ``"case_r"``.

    Raises:
        FileNotFoundError: if neither file exists; the message names both paths.
    """
    suffix = _build_case_r_suffix(dist_type, tau, target_r)
    candidate = input_dir / f"{case}_{suffix}.csv"
    if candidate.exists():
        return candidate
    # The legacy name embeds the distribution. Using dist_type (instead of a
    # hard-coded "normal") avoids silently plotting normal-distribution data
    # under another distribution's label.
    legacy_stem = f"case_r_dist_{dist_type}_r_1"
    if case != "case_r":
        legacy_stem += "_consvar"
    fallback = input_dir / f"{legacy_stem}.csv"
    if fallback.exists():
        return fallback
    raise FileNotFoundError(f"{candidate} (legacy fallback {fallback} not found either)")


# Style used for method keys that have no entry in METHOD_STYLES.
_DEFAULT_METHOD_STYLE = {"linestyle": "-", "marker": "o", "color": "gray", "lw": 1.4}
# Methods plotted (and legend order) when the caller does not pass `methods`.
_DEFAULT_METHOD_ORDER = ("mse", "opt", "consvar", "dpsgd", "tgt")
# Font sizes applied only while a case_r figure is being built and saved.
_CASE_R_RC = {
    "font.size": 12,
    "axes.titlesize": 12,
    "axes.labelsize": 12,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 10,
}


def _line_kwargs(method: str) -> dict:
    """Translate a method's style entry into Line2D / Axes.plot keyword arguments."""
    style = METHOD_STYLES.get(method, _DEFAULT_METHOD_STYLE)
    marker = style.get("marker")
    return {
        "color": style["color"],
        "linestyle": style["linestyle"],
        "marker": marker,
        "markersize": 4 if marker else 0,
        "linewidth": style["lw"],
    }


def _load_case_r_frame(csv_path: Path, tau: float) -> pd.DataFrame:
    """Read *csv_path*, keep rows for *tau*, log-transform RMSE columns, sort by source r."""
    df = pd.read_csv(csv_path)
    # Relative-tolerance match: CSV float parsing is not guaranteed to
    # round-trip bit-exactly, so `==` could drop the requested rows.
    df = df[np.isclose(df["tau"], tau, rtol=1e-9, atol=0.0)].copy()
    if df.empty:
        raise ValueError(f"No rows found for tau={tau} in {csv_path.name}")
    rmse_cols = [c for c in df.columns if c.endswith("_rmse")]
    if rmse_cols:
        # Plotted as "log MSE", i.e. log(rmse ** 2).
        df[rmse_cols] = np.log(df[rmse_cols] ** 2)
    df["source_r"] = df["rs"].apply(_source_r)
    return df.sort_values("source_r")


def _select_methods(df: pd.DataFrame, methods: Sequence[str] | None, metrics: Sequence[str]) -> list[str]:
    """Return requested methods that have a ``<method>_<metric>`` column for every metric."""
    requested = list(methods) if methods is not None else list(_DEFAULT_METHOD_ORDER)
    available = [m for m in requested if all(f"{m}_{metric}" in df.columns for metric in metrics)]
    if not available:
        raise ValueError("No plotting methods available in the CSV.")
    return available


def _style_axis(ax, metric: str, x_min: float, x_max: float) -> None:
    """Apply the shared look, labels and limits to one metric subplot."""
    ax.set_facecolor(FACECOLOR)
    ax.grid(True, color="white", zorder=-1)
    ax.tick_params(axis="both", length=0)
    ax.yaxis.set_major_locator(MaxNLocator(nbins=4))
    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_visible(False)
    if metric == "coverage":
        # Nominal 95% coverage reference line.
        ax.axhline(0.95, ls="--", lw=0.8, color="grey", zorder=0)
        ax.set_ylim(0.6, 1.0)
    ax.set_ylabel(YLABELS.get(metric, metric), labelpad=6)
    ax.set_xlabel("Source privacy level r", labelpad=4)
    # Identical limits (a single r value) would only trigger a matplotlib warning.
    if x_min < x_max:
        ax.set_xlim(x_min, x_max)


def plot_case_r(
    case: str,
    input_dir: str | Path,
    output_dir: str | Path,
    tau: float,
    dist_type: str = "normal",
    keep_samples: Sequence[int] | None = None,
    target_r: float | None = None,
    methods: Sequence[str] | None = None,
    metrics: Sequence[str] = METRICS,
    figure_name: str | None = None,
):
    """Plot varying-r experiments (case_r / case_r_consvar).

    Draws one subplot per metric, with the source privacy level (second entry
    of the ``rs`` column) on the x axis and one line per method. A single
    legend sits above the row, and the figure is saved to *output_dir*.

    Args:
        case: experiment name; used in the CSV lookup and default figure name.
        input_dir: directory searched for the results CSV.
        output_dir: directory the figure is written to (created if missing).
        tau: only CSV rows whose ``tau`` matches this value are plotted.
        dist_type: distribution tag used in file names.
        keep_samples: accepted for interface compatibility; currently unused.
        target_r: optional tag included in the CSV / figure file names.
        methods: method keys to plot. Defaults to mse, opt, consvar, dpsgd, tgt.
            Methods missing any ``<method>_<metric>`` column are skipped.
        metrics: metric column suffixes, one subplot each.
        figure_name: output file name; defaults to ``{case}_{suffix}.pdf``.

    Returns:
        Path of the saved figure.

    Raises:
        ValueError: if ``metrics`` is empty, no rows match ``tau``, an ``rs``
            value cannot be parsed, or no requested method has all metric columns.
        FileNotFoundError: if no results CSV can be located.
    """
    metrics = list(metrics)
    if not metrics:
        raise ValueError("At least one metric is required.")

    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    csv_path = _resolve_case_r_file(input_dir, case, dist_type, tau, target_r)
    df = _load_case_r_frame(csv_path, tau)
    available_methods = _select_methods(df, methods, metrics)

    if figure_name is None:
        figure_name = f"{case}_{_build_case_r_suffix(dist_type, tau, target_r)}.pdf"
    out_path = output_dir / figure_name

    # rc_context scopes the font sizes to this figure instead of mutating the
    # global rcParams. Saving happens inside it because tick labels are
    # created lazily at draw time.
    with plt.rc_context(_CASE_R_RC):
        fig, axes = plt.subplots(
            1,
            len(metrics),
            figsize=(2.8 * len(metrics), 2.8),
            squeeze=False,
            sharex=True,
        )
        try:
            x = df["source_r"]
            for ax, metric in zip(axes[0], metrics):
                for m in available_methods:
                    ax.plot(x, df[f"{m}_{metric}"], **_line_kwargs(m))
                _style_axis(ax, metric, x.min(), x.max())

            handles = [Line2D([], [], **_line_kwargs(m)) for m in available_methods]
            labels = [NAME_MAP.get(m, m) for m in available_methods]
            fig.legend(
                handles,
                labels,
                ncol=len(handles),
                loc="upper center",
                bbox_to_anchor=(0.5, 1.05),
                frameon=False,
                borderaxespad=0.1,
                handletextpad=0.4,
            )
            fig.subplots_adjust(top=0.82, wspace=0.35, left=0.08, right=0.99, bottom=0.18)
            fig.savefig(out_path, dpi=300, bbox_inches="tight")
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
    return out_path

