"""Plot sequences of visited states."""
from swmpo.plotting import get_starts_widths
from swmpo.plotting import get_partition_colors
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from pathlib import Path


def plot_visited_states(
        visited_states: dict[str, list[int]],
        available_indices: set[int],
        output_path: Path,
        ):
    """Plot, for each entry, the sequence of states that were visited.

    One horizontal-bar subplot is stacked vertically per entry of
    ``visited_states``. ``get_starts_widths`` splits each sequence into bars
    (start, width, state index); each bar is coloured by indexing the palette
    from ``get_partition_colors`` with its state index. The x axis spans one
    unit per element of the sequence.

    Args:
        visited_states: Mapping from a label (used as the bar's y label) to
            the list of visited state indices.
        available_indices: Currently unused; kept for interface
            compatibility with existing callers.
        output_path: File the rendered figure is written to.
    """
    # Create matplotlib figure; attaching an Agg canvas lets it be saved
    # without going through pyplot.
    fig = Figure()
    FigureCanvas(fig)

    # Palette indexed by state index
    palette = get_partition_colors()
    n_rows = len(visited_states)

    # Plot active states, one subplot per entry
    for row, (model_id, model_visited_states) in enumerate(visited_states.items()):
        ax = fig.add_subplot(n_rows, 1, row + 1)
        starts, widths, state_ids = get_starts_widths(model_visited_states)
        bar_colors = [palette[state] for state in state_ids]
        ax.barh([model_id], widths, left=starts, height=0.5, color=bar_colors)
        ax.set_xlim(left=0, right=len(model_visited_states))

    # Add the title BEFORE tight_layout so the layout can reserve room for it;
    # otherwise the suptitle may overlap the top subplot.
    fig.suptitle('Visited states (non-initial states may be permutated per-trajectory)')
    fig.tight_layout()
    fig.savefig(output_path)
