"""Matplotlib figure generation for BioDimBench."""

from __future__ import annotations

import os
from collections.abc import Sequence
from pathlib import Path

# Point matplotlib's config/cache directory at a project-local folder unless the
# caller already chose one. This has to run before matplotlib is imported below,
# which is why these statements sit between the import groups.
_MPLCONFIGDIR = Path(__file__).resolve().parents[1] / "outputs" / ".mplconfig"
if "MPLCONFIGDIR" not in os.environ:
    try:
        _MPLCONFIGDIR.mkdir(parents=True, exist_ok=True)
    except OSError:
        # Best effort: if the folder cannot be created (e.g. a read-only
        # checkout), leave MPLCONFIGDIR unset so matplotlib uses its default
        # location instead of this module failing to import.
        pass
    else:
        os.environ["MPLCONFIGDIR"] = str(_MPLCONFIGDIR)

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from .utils import ERROR_TYPE_DISPLAY, INVALID_CANDIDATE_TYPES, METHOD_ORDER, method_label


# Bar fill colours shared by the charts in this module, used in list order.
COLORS = [
    "#4C78A8",  # blue
    "#F58518",  # orange
    "#54A24B",  # green
    "#E45756",  # red
    "#72B7B2",  # teal
    "#B279A2",  # purple
]


def generate_figures(metrics: pd.DataFrame, error_recall: pd.DataFrame, figures_dir: Path) -> None:
    """Write every publication-oriented figure into *figures_dir*.

    Two per-verifier bar charts (invalid recall, then accuracy) are drawn from
    *metrics* after ordering it with ``_order_metrics``; a grouped recall chart
    per corrupted-solution type is drawn from *error_recall*. The output
    directory (and any missing parents) is created first.
    """
    figures_dir.mkdir(parents=True, exist_ok=True)
    ranked = _order_metrics(metrics)

    # (metric column, y-axis label, chart title, output file name)
    bar_specs = (
        (
            "invalid_recall",
            "Invalid recall",
            "Invalid recall by verifier",
            "invalid_recall_by_method.png",
        ),
        (
            "accuracy",
            "Accuracy",
            "Accuracy by verifier",
            "accuracy_by_method.png",
        ),
    )
    for column, axis_label, chart_title, file_name in bar_specs:
        _bar_chart(
            ranked,
            value_column=column,
            ylabel=axis_label,
            title=chart_title,
            output_path=figures_dir / file_name,
        )

    _error_type_chart(error_recall, figures_dir / "error_type_recall.png")


def _order_metrics(metrics: pd.DataFrame) -> pd.DataFrame:
    frame = metrics.copy()
    frame["method_order"] = frame["method"].map({m: i for i, m in enumerate(METHOD_ORDER)}).fillna(99)
    return frame.sort_values(["method_order", "split"]).reset_index(drop=True)


def _bar_chart(metrics: pd.DataFrame, value_column: str, ylabel: str, title: str, output_path: Path) -> None:
    """Save a single-series bar chart with one bar per row of *metrics*.

    Bars follow the row order of *metrics*, are labelled with
    ``method_label(method, split)``, and take their heights from
    *value_column* (cast to float). The y-axis is fixed to ``[0, 1.05]``, so
    the values are presumably fractions; larger ones would be cut off.

    The figure is always closed, even when saving fails, so repeated calls
    never accumulate open figures in pyplot's global registry.
    """
    labels = [method_label(row.method, row.split) for row in metrics.itertuples()]
    values = metrics[value_column].astype(float).to_numpy()
    positions = np.arange(len(values))
    # Cycle through the palette explicitly, the same way _error_type_chart does.
    bar_colors = [COLORS[i % len(COLORS)] for i in range(len(values))]

    fig, ax = plt.subplots(figsize=(8.0, 4.6))
    try:
        ax.bar(positions, values, color=bar_colors, edgecolor="black", linewidth=0.6)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=25, ha="right")
        ax.set_ylim(0, 1.05)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(axis="y", alpha=0.25)
        fig.tight_layout()
        fig.savefig(output_path, dpi=220)
    finally:
        # Close this specific figure, not whichever one pyplot considers current.
        plt.close(fig)


def _error_type_chart(error_recall: pd.DataFrame, output_path: Path) -> None:
    """Save a grouped bar chart of recall per corrupted-solution type.

    There is one bar group per entry of ``INVALID_CANDIDATE_TYPES``. Within a
    group there is one bar per distinct ``(method, split)`` pair in
    *error_recall*, ordered by ``_order_metrics`` (the same ordering used for
    the per-verifier charts). Reads the columns ``method``, ``split``,
    ``candidate_type`` and ``recall``; for each pair and type, the first
    matching ``recall`` is plotted.

    Missing combinations and NaN recalls are drawn as zero-height bars. The
    figure is always closed, even when saving fails.
    """
    frame = _order_metrics(error_recall)
    method_keys = list(
        frame[["method", "split"]].drop_duplicates().itertuples(index=False, name=None)
    )
    labels = [ERROR_TYPE_DISPLAY[candidate_type] for candidate_type in INVALID_CANDIDATE_TYPES]

    # Gather every bar series before creating the figure, so a data error
    # cannot leave an unclosed figure behind.
    series = []
    for method, split in method_keys:
        subset = frame[(frame["method"] == method) & (frame["split"] == split)]
        recalls = []
        for candidate_type in INVALID_CANDIDATE_TYPES:
            match = subset.loc[subset["candidate_type"] == candidate_type, "recall"]
            recalls.append(float(match.iloc[0]) if not match.empty else np.nan)
        series.append(np.nan_to_num(recalls, nan=0.0))

    positions = np.arange(len(INVALID_CANDIDATE_TYPES))
    n_series = len(method_keys)
    width = 0.82 / max(n_series, 1)

    fig, ax = plt.subplots(figsize=(10.0, 4.8))
    try:
        for idx, ((method, split), heights) in enumerate(zip(method_keys, series)):
            # Centre the group of n_series bars on each tick position.
            offset = (idx - (n_series - 1) / 2) * width
            ax.bar(
                positions + offset,
                heights,
                width=width,
                label=method_label(method, split),
                color=COLORS[idx % len(COLORS)],
                edgecolor="black",
                linewidth=0.45,
            )

        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=20, ha="right")
        ax.set_ylim(0, 1.05)
        ax.set_ylabel("Recall")
        ax.set_title("Recall by corrupted solution type")
        ax.grid(axis="y", alpha=0.25)
        if method_keys:
            # Skip the legend when nothing was plotted, avoiding an empty
            # legend and matplotlib's "no artists with labels" warning.
            ax.legend(frameon=False, ncol=2)
        fig.tight_layout()
        fig.savefig(output_path, dpi=220)
    finally:
        plt.close(fig)
