from __future__ import annotations

from pathlib import Path
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.lines import Line2D

from unpaired_iv_experiments.plotting.curves import _plain_log_ticks
from unpaired_iv_experiments.plotting.style import default_method_styles


def plot_single_panel(
    summary: pd.DataFrame,
    outpath: Path,
    estimator_labels: Dict[str, str],
    dataset_name: str,
    x_log: bool = True,
    x_col: str = "N",
    x_label: str = "total sample size ($N$)",
) -> None:
    """Plot MAE curves for one dataset on a single panel and save the figure.

    One line (with a +/- ``se`` band) is drawn per estimator, and a
    figure-level legend lists every estimator selected from ``summary``.

    Args:
        summary: Table read through the columns ``"estimator"``,
            ``"dataset"``, ``"mae"``, ``"se"`` and ``x_col``.
        outpath: Destination passed to ``Figure.savefig``.
        estimator_labels: Maps estimator keys to legend labels. Its iteration
            order fixes the plotting and legend order; keys that do not occur
            in ``summary["estimator"]`` are ignored.
        dataset_name: Value of ``summary["dataset"]`` selecting the rows to
            plot. If no row matches, an axes carrying only the dataset name
            as title is saved.
        x_log: If true, the distinct x values are placed at evenly
            log-spaced positions on a log axis and labelled with their
            original values.
        x_col: Column providing the x values.
        x_label: Text of the x-axis label.
    """
    # Build the lookup once instead of recomputing unique() for every key.
    present = set(summary["estimator"].unique())
    est_keys = [k for k in estimator_labels if k in present]
    styles = default_method_styles(est_keys)
    sdf = summary[summary["dataset"] == dataset_name].copy()

    fig, ax = plt.subplots(
        1, 1, figsize=(1.5 * 3.4, 1.0 * 2.2), constrained_layout=False
    )
    # try/finally so the figure is released from pyplot even when plotting or
    # saving raises; otherwise failed calls accumulate open figures.
    try:
        if sdf.empty:
            ax.set_title(dataset_name)
        else:
            x_ticks = sorted(sdf[x_col].unique().tolist())
            if x_log and x_ticks:
                x_positions = np.logspace(
                    0.0, 1.0, num=len(x_ticks), base=10.0
                )
                # Key by the raw tick value. Truncating to int would collapse
                # distinct non-integer x values onto a single position.
                x_map = {
                    tick: float(pos)
                    for tick, pos in zip(x_ticks, x_positions)
                }
            else:
                x_positions = None
                x_map = {}

            for key in est_keys:
                t = sdf[sdf["estimator"] == key].sort_values(x_col)
                if t.empty:
                    continue
                st = styles[key]
                if x_map:
                    xs = np.array(
                        [x_map[v] for v in t[x_col].tolist()], dtype=float
                    )
                else:
                    xs = t[x_col].to_numpy().astype(float)
                means = t["mae"].to_numpy()
                se = t["se"].to_numpy()
                ax.plot(
                    xs,
                    means,
                    color=st.color,
                    linestyle=st.linestyle,
                    marker=st.marker,
                    label=estimator_labels.get(key, key),
                )
                ax.fill_between(
                    xs,
                    means - se,
                    means + se,
                    color=st.color,
                    alpha=0.12,
                )

            if x_log:
                ax.set_xscale("log")
                if x_positions is not None and x_ticks:
                    ax.set_xticks(x_positions)
                    ax.set_xticklabels([str(n) for n in x_ticks])
                else:
                    # Defensive fallback: a non-empty subset always yields at
                    # least one tick, so this is not reached in practice.
                    _plain_log_ticks(ax, sorted(sdf[x_col].unique().tolist()))

            # Use only finite values for the upper limit: NaN/Inf limits are
            # rejected by set_ylim.
            mae_values = sdf["mae"].to_numpy(dtype=float)
            finite_mae = mae_values[np.isfinite(mae_values)]
            max_mae = float(finite_mae.max()) if finite_mae.size else 0.0
            y_pad = max(0.002, 0.05 * max_mae)
            ax.set_ylim(-y_pad, max_mae + y_pad)
            ax.axhline(0.0, color="k", linewidth=1.0, linestyle="--", alpha=0.45)

        # Proxy artists so the legend lists every selected estimator, whether
        # or not it has rows for this dataset.
        handles = [
            Line2D(
                [0],
                [0],
                color=styles[k].color,
                linestyle=styles[k].linestyle,
                marker=styles[k].marker,
                label=estimator_labels.get(k, k),
            )
            for k in est_keys
        ]
        if handles:
            ncol = min(4, len(handles))
            fig.legend(
                handles=handles,
                loc="upper center",
                ncol=ncol,
                bbox_to_anchor=(0.55, 0.995),
                frameon=True,
                borderaxespad=0.0,
            )
            # Leave a strip at the top of the figure for the legend.
            fig.tight_layout(rect=(0.06, 0.08, 1.0, 0.94))
        else:
            fig.tight_layout(rect=(0.06, 0.08, 1.0, 1.0))

        # Draw first so the axes position reflects the final layout, then
        # place the figure-level labels relative to the axes box.
        fig.canvas.draw()
        ax_box = ax.get_position()
        x_center = 0.5 * (ax_box.x0 + ax_box.x1)
        y_center = 0.5 * (ax_box.y0 + ax_box.y1)
        x_label_x = min(0.98, x_center + 0.04)
        x_label_y = min(0.22, ax_box.y0 - 0.19)
        y_label_y = min(0.88, y_center + 0.04)
        fig.supxlabel(x_label, x=x_label_x, y=x_label_y)
        fig.supylabel(r"MAE ($\|\hat\beta-\beta\|_1$)", y=y_label_y)
        fig.savefig(outpath, bbox_inches="tight")
    finally:
        plt.close(fig)
