"""Plotting utilities."""
from swmpo.state_machine import StateMachine
from pathlib import Path
from itertools import product
import matplotlib as mpl
from matplotlib import cm
import matplotlib.colors
import tempfile
import networkx as nx
import ffmpeg
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from swmpo.transition import Transition
from swmpo.state_machine import get_local_model_errors
from swmpo.state_machine import get_state_machine_errors
from swmpo.state_machine import get_visited_states
from swmpo.sequence_distance import _get_best_permutation
import random


def plot_state_machine(
        state_machine: StateMachine,
        active_state: int,
        output_path: Path,
        ):
    """Draw the state machine as a directed graph and save it to disk.

    Every local model becomes a node. An edge ``src -> dst`` is drawn whenever
    ``state_machine.transition_histogram[src][dst]`` is positive. The node for
    ``active_state`` is coloured blue and every other node black. The image is
    written to ``output_path`` with the title "Active state".
    """
    nodes = range(len(state_machine.local_models))
    histogram = state_machine.transition_histogram

    graph = nx.DiGraph()
    graph.add_nodes_from(nodes)
    # Only keep transitions that were actually observed.
    graph.add_edges_from(
        (src, dst)
        for src, dst in product(nodes, nodes)
        if histogram[src][dst] > 0
    )

    palette = ["blue" if node == active_state else "black" for node in nodes]

    # Off-screen figure backed by an Agg canvas (no pyplot state involved).
    fig = Figure()
    FigureCanvas(fig)
    axes = fig.add_subplot()
    nx.draw_circular(graph, ax=axes, node_color=palette)

    fig.suptitle('Active state')
    fig.savefig(output_path)


def plot_animation(
        state_machine: StateMachine,
        visited_states: list[int],
        output_path: Path,
        fps: int,
        ):
    """Render one diagram per visited state and encode them into a video.

    Each frame is produced by ``plot_state_machine`` with the state active at
    that step highlighted. Frames go into a temporary directory under
    zero-padded names, so their lexicographic order matches the step order.
    They are then fed to ffmpeg through a glob pattern at ``fps`` frames per
    second and written to ``output_path``.
    """
    with tempfile.TemporaryDirectory() as scratch:
        frames = Path(scratch)
        for step, active in enumerate(visited_states):
            frame_name = f"{step}.png".rjust(10, "0")
            plot_state_machine(
                state_machine=state_machine,
                active_state=active,
                output_path=frames / frame_name,
            )
        stream = ffmpeg.input(frames / "*.png", pattern_type="glob", framerate=fps)
        stream = stream.output(str(output_path))
        stream.run(quiet=True)


def plot_state_machine_errors_diagram(
        state_machine: StateMachine,
        local_model_errors: list[float],
        max_error: float,
        output_path: Path,
        ):
    """Draw the state machine graph with each node shaded by its error.

    Nodes are the local models. An edge ``src -> dst`` exists whenever the
    transition histogram counts at least one such transition. Node ``k`` gets
    the viridis colour of ``local_model_errors[k]``, normalised linearly over
    ``[0, max_error]``, and a matching colour bar is added. The image is
    written to ``output_path`` with the title "Node errors".
    """
    nodes = range(len(state_machine.local_models))
    histogram = state_machine.transition_histogram

    graph = nx.DiGraph()
    graph.add_nodes_from(nodes)
    graph.add_edges_from(
        (src, dst)
        for src, dst in product(nodes, nodes)
        if histogram[src][dst] > 0
    )

    # A fixed upper bound lets callers keep colours comparable across frames.
    scalar_map = cm.ScalarMappable(
        norm=matplotlib.colors.Normalize(vmin=0.0, vmax=max_error),
        cmap=mpl.colormaps['viridis'],
    )
    node_colors = [
        scalar_map.to_rgba(np.array([err]))[0] for err in local_model_errors
    ]

    fig = Figure()
    FigureCanvas(fig)
    axes = fig.add_subplot()
    nx.draw_circular(graph, ax=axes, node_color=node_colors)
    fig.colorbar(scalar_map, ax=axes)

    fig.suptitle('Node errors')
    fig.savefig(output_path)


def plot_errors_animation(
        state_machine: StateMachine,
        local_model_errors: list[list[float]],
        output_path: Path,
        fps: int,
        ):
    """Plot an animation of the per-node prediction errors over time.

    One frame is rendered per entry of ``local_model_errors`` with
    ``plot_state_machine_errors_diagram``. All frames share one colour scale,
    ``[0, max_error]``, where ``max_error`` is the largest error across every
    step, so colours are comparable between frames. The frames are then
    encoded into a video at ``output_path`` with ffmpeg at ``fps`` frames per
    second.

    :param local_model_errors: for each step, one error per local model
        (presumably indexed like ``state_machine.local_models`` -- confirm).
    :raises ValueError: if ``local_model_errors`` (or any step in it) is empty.
    """
    if not local_model_errors:
        raise ValueError("local_model_errors must contain at least one step")
    max_error = max(max(errors) for errors in local_model_errors)
    with tempfile.TemporaryDirectory() as tdir:
        frame_dir = Path(tdir)
        for i, errors in enumerate(local_model_errors):
            # Zero-pad the index (as plot_animation does) so the glob below
            # visits frames in step order without relying on space-filled
            # filenames.
            state_path = frame_dir/(f"{i}.png").rjust(10, "0")
            plot_state_machine_errors_diagram(
                state_machine=state_machine,
                local_model_errors=errors,
                max_error=max_error,
                output_path=state_path,
            )
        (
            ffmpeg
            .input(frame_dir/"*.png", pattern_type="glob", framerate=fps)
            .output(str(output_path))
            .run(quiet=True)
        )


def get_starts_widths(visited_states: list[int]) -> tuple[list[int], list[int], list[int]]:
    """Run-length encode a sequence of visited states for matplotlib's `barh`.

    Consecutive equal states are merged into one run. For each run this
    returns where it starts, how many steps it lasts and which state it
    belongs to. As a result, ``barh(label, widths, left=starts, ...)`` draws one
    bar per run, and the bars together cover exactly ``len(visited_states)``
    steps.

    Example: ``[0, 0, 1, 1, 1]`` -> ``([0, 2], [2, 3], [0, 1])``.

    :param visited_states: the state active at each step; may be empty.
    :returns: ``(starts, widths, colors)``. ``colors`` holds the state index
        of each run; callers map it to an actual colour.
    """
    starts: list[int] = []
    widths: list[int] = []
    colors: list[int] = []
    for i, state in enumerate(visited_states):
        if colors and state == colors[-1]:
            # Same state as the current run: extend it by one step.
            widths[-1] += 1
        else:
            # State change (or first element): open a new run of width 1.
            starts.append(i)
            widths.append(1)
            colors.append(state)
    return starts, widths, colors


def get_partition_colors() -> list[tuple[float, float, float, float]]:
    """Return the ``tab20b`` colormap entries as RGBA values in shuffled order.

    The shuffle uses a ``random.Random`` seeded with the constant ``"seed"``,
    so the colour assigned to each index is the same on every call. The
    entries come from ``ndarray.tolist()``, so each colour is a 4-element
    list.
    """
    palette = mpl.colormaps['tab20b']
    rgba_table = palette(list(range(palette.N)))
    shuffled = rgba_table.tolist()
    rng = random.Random("seed")
    rng.shuffle(shuffled)
    return shuffled


def plot_state_machine_errors(
        state_machine: StateMachine,
        episode: list[Transition],
        initial_state: int,
        dt: float,
        ground_truth_visited_states: list[int] | None,
        output_path: Path,
        ):
    """Plot the error of the state machine for predicting the given episode.

    The figure has two panels.

    Top panel: one error curve per local model (from
    ``get_local_model_errors``) plus a dotted black curve for the full state
    machine (from ``get_state_machine_errors``).

    Bottom panel: horizontal bars showing, per step:
    - the states the state machine visited ("FSM states");
    - the local model with the lowest error at each step
      ("Minimum-loss induced states");
    - optionally, the ground-truth states.

    When ``ground_truth_visited_states`` is given, local-model colours are
    permuted with ``_get_best_permutation``. The permutation is presumably a
    mapping from local-model index to ground-truth label (it is queried with
    ``perm[i]`` and ``perm.keys()``); the intent is that matching modes share
    a colour.

    :param ground_truth_visited_states: ground-truth state per step, or None.
        An empty list still triggers the permutation step but draws no
        ground-truth row.
    :param output_path: where the figure is saved.

    NOTE(review): ``local_model_colors`` has ``len(get_partition_colors())``
    entries and is indexed by local-model index, so more local models than
    palette colours would raise IndexError.
    """
    # Run state machine
    local_model_errors = get_local_model_errors(
        state_machine=state_machine,
        episode=episode,
        dt=dt,
    )
    state_machine_errors = get_state_machine_errors(
        state_machine=state_machine,
        episode=episode,
        initial_state=initial_state,
        dt=dt,
    )
    visited_states = get_visited_states(
        state_machine=state_machine,
        initial_state=initial_state,
        episode=episode,
        dt=dt,
    )

    # Create matplotlib figure (off-screen, Agg canvas; no pyplot state)
    fig = Figure()
    _ = FigureCanvas(fig)

    # Choose model colors
    partition_colors = get_partition_colors()

    # Permute local models for easy visualization
    if ground_truth_visited_states is not None:
        # Get mode permutation
        perm = _get_best_permutation(
            sequence=visited_states,
            ground_truth=ground_truth_visited_states,
            indices=list(range(len(state_machine.local_models))),
        )
        # Models covered by the permutation borrow their matched colour;
        # every other index keeps its own palette entry.
        local_model_colors = [
            partition_colors[perm[i]] if i in perm.keys() else partition_colors[i]
            for i in range(len(partition_colors))
        ]
    else:
        local_model_colors = list(partition_colors)

    # Plot errors
    ax1 = fig.add_subplot(2, 1, 1)
    x = list(range(len(local_model_errors)))
    for i in range(len(state_machine.local_models)):
        # local_model_errors is indexed [step][model]; extract model i's curve.
        model_errors = [
            local_model_errors[j][i]
            for j in x
        ]
        color = local_model_colors[i]
        ax1.plot(x, model_errors, label=f"Local model {i}", color=color)
    # x is rebound here: every set_xlim below uses the length of
    # state_machine_errors, not of local_model_errors.
    x = list(range(len(state_machine_errors)))
    ax1.plot(
        x,
        state_machine_errors,
        color='black',
        linestyle=':',
        label="Full state machine",
    )
    ax1.set_xlim(left=0, right=len(x))

    # Plot active states
    ax2 = fig.add_subplot(2, 1, 2)
    labels = ["FSM states"]
    starts, widths, colors = get_starts_widths(visited_states)
    # colors holds state indices; map them to the (possibly permuted) palette.
    color = [local_model_colors[i] for i in colors]
    ax2.barh(labels, widths, left=starts, height=0.5, color=color)
    ax2.set_xlim(left=0, right=len(x))

    # Plot "partition": per step, the local model with the lowest error
    # (min() returns the first index on ties).
    labels = ["Minimum-loss induced states"]
    induced_visited_states = list()
    for step_losses in local_model_errors:
        indices = list(range(len(step_losses)))
        min_loss_i = min(indices, key=lambda i: step_losses[i])
        induced_visited_states.append(min_loss_i)
    starts, widths, colors = get_starts_widths(induced_visited_states)
    color = [local_model_colors[i] for i in colors]
    ax2.barh(labels, widths, left=starts, height=0.5, color=color)
    ax2.set_xlim(left=0, right=len(x))

    # Plot ground truth. It uses the unpermuted palette: the permutation
    # above recolours local models to match these ground-truth colours.
    # NOTE(review): height=1.5 here vs 0.5 for the other rows -- confirm
    # intended; on a categorical y axis it may overlap neighbouring rows.
    if ground_truth_visited_states is not None and len(ground_truth_visited_states) > 0:
        labels = ["Ground truth states"]
        starts, widths, colors = get_starts_widths(ground_truth_visited_states)
        color = [partition_colors[i] for i in colors]
        ax2.barh(labels, widths, left=starts, height=1.5, color=color)

    # Save figure (figure-level legend collects labelled artists from all axes)
    fig.legend()
    fig.suptitle('State machine prediction errors')
    fig.tight_layout()
    fig.savefig(output_path)
