"""
Plotting helpers for training and evaluation.
"""

from __future__ import annotations

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.colors import BoundaryNorm, ListedColormap

from .constants import ACTION_LABELS, EPSILON_ACTIONS


def setting_title(base_title, m, mu, alpha):
    """
    Build a two-line plot title: ``base_title`` on top, settings below.

    The second line reads ``m=<m>, mu=<mu>, alpha=<alpha>``. ``m`` and ``mu``
    are shown with three decimals; ``alpha`` uses general (``g``) formatting.
    """
    settings_line = ", ".join((f"m={m:.3f}", f"mu={mu:.3f}", f"alpha={alpha:g}"))
    return f"{base_title}\n{settings_line}"


def trailing_mean_every_k(returns, k):
    """
    Average non-overlapping windows of ``k`` returns, one value every ``k`` episodes.

    ``k`` is coerced to an int of at least 1. Returns ``(ends, means)``:
    ``ends`` are the (exclusive) episode counts ``k, 2k, ...`` that fit in
    ``returns``, and ``means[j]`` is the mean of ``returns[ends[j] - k:ends[j]]``.
    Any trailing partial window is dropped. Empty input yields two empty
    arrays (int and float respectively).
    """
    values = np.asarray(returns, dtype=float)
    window = int(max(1, k))

    if not len(values):
        return np.array([], dtype=int), np.array([], dtype=float)

    checkpoints = np.arange(window, len(values) + 1, window, dtype=int)
    window_means = np.fromiter(
        (values[stop - window:stop].mean() for stop in checkpoints),
        dtype=float,
        count=len(checkpoints),
    )
    return checkpoints, window_means


def plot_rejection_curves_multi(t, curves, title="", filename=None, *, show_title=False):
    """
    Plot P(reject by time t) for an arbitrary dict of curves.

    Parameters
    ----------
    t : array-like
        x-values shared by every curve.
    curves : dict
        Mapping ``label -> y-values``. Each entry is drawn as one line and
        appears in the legend under ``label``.
    title : str, optional
        Axes title. It is ignored unless ``show_title`` is true, so by default
        the figure stays untitled, as it always has been.
    filename : str or path-like, optional
        Output path; defaults to ``"rejection_curves.png"``.
    show_title : bool, optional (keyword-only)
        Draw ``title`` (when non-empty) above the axes.
    """
    if filename is None:
        filename = "rejection_curves.png"

    fig, ax = plt.subplots(figsize=(5.9, 4.4))
    try:
        for label, p in curves.items():
            ax.plot(t, p, label=label)

        ax.set_xlabel("t", fontsize=14, fontweight="bold")
        ax.set_ylabel("P(reject by time t)", fontsize=14, fontweight="bold")

        ax.tick_params(axis="both", which="major", labelsize=12, width=1.2)
        for tick in ax.get_xticklabels() + ax.get_yticklabels():
            tick.set_fontweight("bold")

        # matplotlib logs a warning for a legend with no labelled artists,
        # so only build one when there is something to label.
        if curves:
            leg = ax.legend(loc="upper left", fontsize=12, frameon=True)
            for txt in leg.get_texts():
                txt.set_fontweight("bold")

        ax.grid(True, alpha=0.25)

        if show_title and title:
            ax.set_title(title, fontsize=14, fontweight="bold")

        fig.tight_layout()
        fig.savefig(filename, dpi=150, bbox_inches='tight')
    finally:
        # Close the figure even if saving fails, so repeated calls do not
        # accumulate open figures.
        plt.close(fig)


def plot_training_curve(history, title="RL training: episodic return", filename=None):
    """
    Save a line plot of the return obtained in each training episode.

    ``history["rewards"]`` is plotted against episode index. The figure is
    written to ``filename`` (default ``"training_curve.png"``) and then closed.
    """
    rewards = np.array(history["rewards"], float)

    fig, ax = plt.subplots(figsize=(7.2, 4.2))
    ax.plot(rewards)
    ax.set_xlabel("episode")
    ax.set_ylabel("return")
    ax.set_title(title)
    fig.tight_layout()

    out_path = "training_curve.png" if filename is None else filename
    fig.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def plot_eps_on_logwealth(
    Y_path,
    eps_path,
    alpha=None,
    title="DQN ε-actions along log-wealth trajectory",
    filename="eps_on_logwealth.png",
):
    """
    Plot epsilon actions along a log-wealth trajectory.

    The trajectory is drawn as one line segment per step,
    ``(t, Y_path[t]) -> (t + 1, Y_path[t + 1])``. Segment ``t`` is coloured
    by ``eps_path[t]``.

    Parameters
    ----------
    Y_path : array-like, length ``len(eps_path) + 1``
        Log wealth at times ``0 .. len(eps_path)``.
    eps_path : array-like
        Action value used on each step. The colour scale spans the full
        range of ``EPSILON_ACTIONS``, so a given action keeps the same colour
        across plots and every action tick appears on the colorbar.
    alpha : float, optional
        If given, draw the threshold ``log(1 / alpha)``, keep it inside the
        y-range, and add a legend entry for it.
    title : str
        Axes title.
    filename : str or path-like
        Output path.

    Raises
    ------
    ValueError
        If ``len(Y_path) != len(eps_path) + 1``.
    """
    Y_path = np.asarray(Y_path, float)
    eps_path = np.asarray(eps_path, float)
    L = len(eps_path)
    if len(Y_path) != L + 1:
        raise ValueError(
            f"Y_path must have len(eps_path) + 1 = {L + 1} points, got {len(Y_path)}"
        )

    t = np.arange(L + 1)

    # Pair consecutive points into segments of shape (L, 2, 2).
    points = np.column_stack([t, Y_path]).reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    threshold = None if alpha is None else np.log(1.0 / alpha)

    fig, ax = plt.subplots(figsize=(8.0, 4.8))
    try:
        lc = LineCollection(segments, array=eps_path, cmap="viridis", linewidth=2.0)
        # Pin the colour scale to the whole action set instead of the actions
        # that happen to occur on this path (the default autoscaling).
        actions = np.asarray(EPSILON_ACTIONS, dtype=float)
        if actions.size and actions.max() > actions.min():
            lc.set_clim(actions.min(), actions.max())
        ax.add_collection(lc)

        ax.set_xlim(t.min(), t.max())
        # set_ylim turns off y-autoscaling, so the threshold must be included
        # here; otherwise a threshold outside the data range is clipped.
        y_lo, y_hi = Y_path.min(), Y_path.max()
        if threshold is not None:
            y_lo, y_hi = min(y_lo, threshold), max(y_hi, threshold)
        ypad = 0.05 * (y_hi - y_lo + 1e-12)
        ax.set_ylim(y_lo - ypad, y_hi + ypad)

        if threshold is not None:
            ax.axhline(threshold, linestyle=":", linewidth=1.0, label="threshold")

        cbar = fig.colorbar(lc, ax=ax)
        cbar.set_label("Action Chosen")
        cbar.set_ticks(EPSILON_ACTIONS)
        cbar.ax.set_yticklabels(ACTION_LABELS)

        ax.set_xlabel("t")
        ax.set_ylabel("log wealth")
        ax.set_title(title)
        ax.grid(True, alpha=0.25)

        if threshold is not None:
            ax.legend(loc="best")

        fig.tight_layout()
        fig.savefig(filename, dpi=150, bbox_inches="tight")
    finally:
        # Close the figure even if saving fails.
        plt.close(fig)


def plot_modal_eps_grid(
    modal_eps,
    y_edges,
    epsilon_actions,
    alpha=None,
    title="Modal ε-action over (t, log wealth)",
    filename="modal_eps_grid.png",
    t_edges=None,
    y_min=-4.0,
    *,
    dp_value=None,
    hopeless_threshold=None,
    hopeless_label="Hopeless",
    hopeless_color="0.70",
    show_title=False,
):
    """
    Plot modal epsilon grid, optionally overlaying DP hopeless region.

    Each cell of ``modal_eps`` holds the modal ε-action for one
    (log-wealth bin, t bin). A cell is coloured by which entry of
    ``epsilon_actions`` it matches (absolute tolerance ``1e-8``). Cells that
    match no action (e.g. NaN) are drawn white.

    Parameters
    ----------
    modal_eps : array-like, shape (num_y_bins, num_t_bins)
        Modal action value per cell.
    y_edges : array-like
        Log-wealth bin edges (y coordinates for ``pcolormesh``).
    epsilon_actions : array-like
        Action values, in the same order as ``ACTION_LABELS``. At most three
        are supported (one fill colour each).
    alpha : float, optional
        If given, draw the threshold ``log(1 / alpha)`` with a legend entry.
    title : str
        Axes title. It is only drawn when ``show_title`` is true.
    filename : str or path-like
        Output path.
    t_edges : array-like, optional
        Time bin edges; defaults to ``0 .. num_t_bins``.
    y_min : float or None
        Lower y-limit, never set below ``y_edges[0]``. ``None`` leaves the
        limit automatic.
    dp_value : array-like, optional
        Same shape as ``modal_eps``. When given together with
        ``hopeless_threshold``, cells with a finite
        ``dp_value < hopeless_threshold`` are painted ``hopeless_color`` and
        labelled ``hopeless_label``, overriding their action colour.
    show_title : bool, keyword-only
        Draw ``title``. Defaults to False, which keeps the previous output.

    Raises
    ------
    ValueError
        If ``modal_eps`` is not 2-D, if ``dp_value``'s shape differs from
        ``modal_eps``'s, or if more than three ``epsilon_actions`` are given.
    """
    modal_eps = np.asarray(modal_eps, dtype=float)
    # The unpacking also rejects non-2-D input (rows = y bins, cols = t bins).
    _, num_t_bins = modal_eps.shape

    if t_edges is None:
        t_edges = np.arange(num_t_bins + 1)

    eps_vals = np.asarray(epsilon_actions, dtype=float)

    aggr_fill = "#E69F00"
    kelly_fill = "#56B4E9"
    less_fill = "#009E73"
    dead_fill = "1.0"
    # action_colors[k] is the fill colour for epsilon_actions[k].
    action_colors = [less_fill, kelly_fill, aggr_fill]
    if len(eps_vals) > len(action_colors):
        raise ValueError(
            f"got {len(eps_vals)} epsilon_actions but only "
            f"{len(action_colors)} action colours are defined"
        )

    include_hopeless = (dp_value is not None) and (hopeless_threshold is not None)
    # Code 0 is reserved for the hopeless class only when that class is
    # drawn. Action codes start right after it, so code c always selects
    # colors[c] / labels[c]. Without this offset, every action was shifted
    # one colour/label up when the overlay was off.
    first_action_code = 1 if include_hopeless else 0

    # -1 marks cells that match no action; they are masked below.
    A = np.full(modal_eps.shape, -1, dtype=np.int32)
    for i, ev in enumerate(eps_vals):
        A[np.isclose(modal_eps, ev, atol=1e-8, rtol=0.0)] = first_action_code + i

    if include_hopeless:
        dp_value = np.asarray(dp_value, dtype=float)
        if dp_value.shape != modal_eps.shape:
            raise ValueError(
                f"dp_value shape {dp_value.shape} must match modal_eps shape {modal_eps.shape}"
            )
        hopeless_mask = np.isfinite(dp_value) & (dp_value < float(hopeless_threshold))
        A[hopeless_mask] = 0

    colors = list(action_colors)
    labels = list(ACTION_LABELS)
    if include_hopeless:
        colors.insert(0, hopeless_color)
        labels.insert(0, hopeless_label)

    cmap = ListedColormap(colors, name=f"actions{len(colors)}")
    cmap.set_bad(dead_fill)

    # Integer code c falls in the bin [c - 0.5, c + 0.5).
    bounds = np.arange(-0.5, len(colors) + 0.5, 1.0)
    norm = BoundaryNorm(bounds, ncolors=cmap.N)

    Z = np.ma.masked_where(A < 0, A)

    fig, ax = plt.subplots(figsize=(5.9, 4.4))
    try:
        im = ax.pcolormesh(
            t_edges, y_edges, Z,
            cmap=cmap, norm=norm, shading="auto"
        )

        ax.set_xlabel("t", fontsize=14, fontweight="bold")
        ax.set_ylabel("Log Wealth", fontsize=14, fontweight="bold")
        ax.grid(True, alpha=0.2)

        ax.tick_params(axis="both", which="major", labelsize=12, width=1.2)
        for tick in ax.get_xticklabels() + ax.get_yticklabels():
            tick.set_fontweight("bold")

        if show_title and title:
            ax.set_title(title, fontsize=14, fontweight="bold")

        if y_min is not None:
            ax.set_ylim(bottom=max(float(y_min), float(y_edges[0])))

        if alpha is not None:
            T = np.log(1.0 / alpha)
            ax.axhline(T, linestyle=":", linewidth=1.0, label="threshold")
            leg = ax.legend(loc="upper right", fontsize=12, frameon=True)
            for txt in leg.get_texts():
                txt.set_fontweight("bold")

        cbar_ticks = np.arange(len(colors))
        cbar = fig.colorbar(im, ax=ax, ticks=cbar_ticks)
        cbar.set_label("DQN Policy", fontsize=12, fontweight="bold")
        cbar.ax.set_yticklabels(labels)
        cbar.ax.tick_params(labelsize=11, width=1.0)
        for tick in cbar.ax.get_yticklabels():
            tick.set_fontweight("bold")

        fig.tight_layout()
        fig.savefig(filename, dpi=150, bbox_inches="tight")
    finally:
        # Close the figure even if saving fails.
        plt.close(fig)


def plot_confidence_grid(
    confidence,
    y_edges,
    alpha=None,
    title="Confidence of modal ε-action over (t, log wealth)",
    filename="modal_confidence_grid.png",
    vmin=0.0,
    vmax=1.0,
    t_edges=None,
    y_min=-4.0,
):
    """
    Save a heatmap of how dominant the modal ε-action is in each grid cell.

    ``confidence`` must be 2-D and indexed ``[y_bin, t_bin]``. Non-finite
    entries are masked. The colour scale is fixed to ``[vmin, vmax]``. When
    ``t_edges`` is omitted, unit time bins ``0 .. num_t_bins`` are used. The
    x-label reads ``"t"`` when the first two time edges are exactly 1 apart
    and ``"t (binned)"`` otherwise.

    If ``y_min`` is not None, the lower y-limit is raised to it (never below
    ``y_edges[0]``). If ``alpha`` is given, the threshold ``log(1 / alpha)``
    is drawn with a legend entry. The figure is saved to ``filename`` and
    closed.
    """
    grid = np.asarray(confidence, dtype=float)
    _, n_t = grid.shape  # unpacking also rejects non-2-D input

    edges_t = np.arange(n_t + 1) if t_edges is None else t_edges
    masked_grid = np.ma.masked_invalid(grid)

    fig, ax = plt.subplots(figsize=(10.5, 4.8))
    mesh = ax.pcolormesh(
        edges_t,
        y_edges,
        masked_grid,
        shading="auto",
        vmin=vmin,
        vmax=vmax,
    )

    unit_spaced = len(edges_t) > 1 and edges_t[1] - edges_t[0] == 1
    ax.set_xlabel("t" if unit_spaced else "t (binned)")
    ax.set_ylabel("log wealth")
    ax.set_title(title)
    ax.grid(True, alpha=0.2)

    if y_min is not None:
        floor = max(float(y_min), float(y_edges[0]))
        ax.set_ylim(bottom=floor)

    if alpha is not None:
        ax.axhline(np.log(1.0 / alpha), linestyle=":", linewidth=1.0, label="threshold")
        ax.legend(loc="best")

    colorbar = fig.colorbar(mesh, ax=ax)
    colorbar.set_label("confidence = max_count / total_visits")

    fig.tight_layout()
    fig.savefig(filename, dpi=150, bbox_inches="tight")
    plt.close(fig)


# Public API of this module: the names exported by ``from <module> import *``.
__all__ = [
    "setting_title",
    "trailing_mean_every_k",
    "plot_rejection_curves_multi",
    "plot_training_curve",
    "plot_eps_on_logwealth",
    "plot_modal_eps_grid",
    "plot_confidence_grid",
]
