import numpy as np
import matplotlib

matplotlib.use("Agg")
import matplotlib.cm as cm
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from typing import Dict, List, Optional, Tuple


def _add_trace(ax, points, cmap) -> Optional[Tuple[float, float, float, float]]:
    """Draw *points* as a polyline whose color deepens along the path.

    Args:
        ax: Axes to draw on.
        points: Sequence (or array) of [x, y, ...] rows; only the first two
            columns are used.
        cmap: Colormap called with values in [0.2, 0.9] to color segments.

    Returns:
        The (min_x, max_x, min_y, max_y) extent of the finite points, or
        ``None`` when *points* is missing, malformed, has fewer than two
        rows of at least two columns, or contains no finite point.
    """
    if points is None:
        return None
    try:
        # Works for lists and NumPy arrays alike (``not points`` would raise
        # for a multi-element array).
        pts = np.asarray(points, dtype=float)
    except (TypeError, ValueError):
        # Ragged or non-numeric trace: skip it, like malformed boxes below.
        return None
    if pts.ndim != 2 or pts.shape[1] < 2 or len(pts) < 2:
        return None
    pts = pts[:, :2]

    # One segment per consecutive pair of points -> shape (n - 1, 2, 2).
    segments = np.stack([pts[:-1], pts[1:]], axis=1)
    # Light at the start, dark at the end, so direction of travel is visible.
    colors = cmap(np.linspace(0.2, 0.9, len(segments)))
    ax.add_collection(
        LineCollection(
            segments,
            colors=colors,
            linewidths=2,
            alpha=0.9,
            capstyle="projecting",  # match Line2D's default so joints overlap
        )
    )

    finite = pts[np.isfinite(pts).all(axis=1)]
    if len(finite) == 0:
        return None
    return (
        float(finite[:, 0].min()),
        float(finite[:, 0].max()),
        float(finite[:, 1].min()),
        float(finite[:, 1].max()),
    )


def _add_box(
    ax, box, linestyle: str, edgecolor: str
) -> Optional[Tuple[float, float, float, float]]:
    """Draw an unfilled axis-aligned box given as (min_x, max_x, min_y, max_y).

    Returns:
        The box's (min_x, max_x, min_y, max_y) when drawn, or ``None`` when
        *box* is missing, not four numeric values, non-finite, or inverted
        (max < min on either axis).
    """
    if box is None:
        return None
    try:
        # Unpacking raises ValueError on the wrong length and TypeError on a
        # non-iterable, so every malformed shape lands in the except clause.
        min_x, max_x, min_y, max_y = (float(v) for v in box)
    except (TypeError, ValueError):
        return None
    corners = (min_x, max_x, min_y, max_y)
    if not (np.isfinite(corners).all() and max_x >= min_x and max_y >= min_y):
        return None
    ax.add_patch(
        Rectangle(
            (min_x, min_y),
            max_x - min_x,
            max_y - min_y,
            fill=False,
            linestyle=linestyle,
            edgecolor=edgecolor,
            linewidth=1.5,
            alpha=0.7,
        )
    )
    return corners


def plot_robot_trace(
    robot_traces: List[Dict],
    title: Optional[str] = None,
    *,
    axis_limits: Optional[Tuple[float, float, float, float]] = (-4.0, 4.0, -4.0, 4.0),
) -> "Optional[plt.Figure]":
    """Plot robot trajectories with color gradients and draw start/end boxes.

    Args:
        robot_traces (List[Dict]):
            A list of dictionaries. Each dict may contain:
            - "start_robot_positions" (List[float] | Tuple[float, float, float, float]):
              Start bounding box as (min_x, max_x, min_y, max_y), drawn solid blue.
            - "end_robot_positions" (List[float] | Tuple[float, float, float, float]):
              End bounding box as (min_x, max_x, min_y, max_y), drawn dashed red.
            - "com_trace" (List[List[float]] | np.ndarray): Sequence of [x, y]
              trajectory points (extra columns are ignored), drawn as a
              light-to-dark red polyline.
            Other keys (e.g. "robot_id", "robot_name") are currently ignored.
            Malformed traces or boxes are skipped silently.
        title (Optional[str]): Title for the figure; omitted when falsy.
        axis_limits (Optional[Tuple[float, float, float, float]]):
            Keyword-only. Axis limits as (x_min, x_max, y_min, y_max).
            Defaults to the fixed window [-4, 4] x [-4, 4]. Pass ``None`` to
            fit the limits to the drawn data with a 5% margin (at least 0.05
            per side).

    Returns:
        Optional[plt.Figure]: The created figure if at least one trace or box
        was drawn; otherwise ``None`` (any figure created is closed). The
        returned figure is left open; the caller is responsible for closing it.
    """
    if not robot_traces:
        print("No trace data provided")
        return None

    fig, ax = plt.subplots(figsize=(10, 10))

    # Background
    bg_color = (1, 1, 1)
    fig.patch.set_facecolor(bg_color)
    ax.set_facecolor(bg_color)

    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; pyplot's
    # version remains available. Looked up once, not per trace.
    cmap = plt.get_cmap("Reds")

    # (min_x, max_x, min_y, max_y) of everything drawn; used both to detect
    # "nothing drawn" and to auto-fit limits when requested.
    extents: List[Tuple[float, float, float, float]] = []
    for trace in robot_traces:
        drawn = (
            _add_trace(ax, trace.get("com_trace"), cmap),
            _add_box(ax, trace.get("start_robot_positions"), "-", "blue"),
            _add_box(ax, trace.get("end_robot_positions"), "--", "red"),
        )
        extents.extend(extent for extent in drawn if extent is not None)

    if not extents:
        print("No valid traces to plot")
        plt.close(fig)
        return None

    # Labels, title, legend
    ax.set_xlabel("X position", fontsize=12)
    ax.set_ylabel("Y position", fontsize=12)
    if title:
        ax.set_title(title, fontsize=14, fontweight="bold")

    # Proxy artists: not added to the axes, used only for the legend entries.
    legend_handles = [
        Rectangle((0, 0), 1, 1, fill=False, linestyle="-", edgecolor="blue", linewidth=1.5),
        Rectangle((0, 0), 1, 1, fill=False, linestyle="--", edgecolor="red", linewidth=1.5),
        Line2D([], [], lw=2, color="red", alpha=0.7),
    ]
    ax.legend(
        legend_handles,
        ["start position", "end position", "trace"],
        loc="upper right",
        framealpha=0.9,
    )

    if axis_limits is None:
        x_min = min(e[0] for e in extents)
        x_max = max(e[1] for e in extents)
        y_min = min(e[2] for e in extents)
        y_max = max(e[3] for e in extents)
        # Pad by 5% of the span, but never less than 0.05, so a degenerate
        # (zero-width) extent still gets a visible window.
        pad_x = 0.05 * max(1.0, x_max - x_min)
        pad_y = 0.05 * max(1.0, y_max - y_min)
        axis_limits = (x_min - pad_x, x_max + pad_x, y_min - pad_y, y_max + pad_y)
    ax.set_xlim(axis_limits[0], axis_limits[1])
    ax.set_ylim(axis_limits[2], axis_limits[3])

    ax.set_aspect("equal", adjustable="box")
    ax.grid(True, alpha=0.3, color="lightgray")

    print("Trace image plotted successfully")
    return fig
