"""Plot the learned representation of a mode world model."""
from swmpo.plotting import get_partition_colors
from swmpo.transition import Transition
from swmpo.partition import ModeWorldModel
from swmpo.partition import get_mode_vector
from swmpo.transition_prunning.epsilon_prunning import StatePartitionItem
from pathlib import Path
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from sklearn.preprocessing import StandardScaler
from swmpo.partition import get_partition_modes
import umap


def get_dimension_reduced_vectors(
    vectors: list[list[float]],
) -> list[list[float]]:
    """Project `vectors` into a low-dimensional space with UMAP.

    Each feature is first standardized (zero mean, unit variance) with
    scikit-learn's `StandardScaler`. The standardized matrix is then
    embedded by a default-configured `umap.UMAP`; UMAP's default
    `n_components` is 2, so each returned row is an (x, y) pair.

    Returns the embedding as nested Python lists, one row per input vector.
    """
    standardized = StandardScaler().fit_transform(vectors)
    projection = umap.UMAP().fit_transform(standardized)
    return projection.tolist()


def plot_learned_representation(
    mode_world_model: ModeWorldModel,
    partition: list[StatePartitionItem],
    episodes: list[list[Transition]],
    ground_truth_modes: list[list[int]],
    device: str,
    output_dir: Path,
):
    """Plot learned representation by the given mode world model.

    Every transition of every episode is embedded with `get_mode_vector`.
    The embeddings are projected to 2-D with `get_dimension_reduced_vectors`.
    Three scatter plots of that projection are then saved as SVG files in
    `output_dir`. They show the same points under three colorings:

    - time: step index divided by episode length (`j / len(episode)`),
      mapped through matplotlib's default colormap;
    - ground truth: `ground_truth_modes[i][j]` looked up in the palette
      returned by `get_partition_colors()`;
    - partition subset: the label from `get_partition_modes` for that
      transition, looked up in the same palette.

    `ground_truth_modes` must be nested like `episodes` (one label per
    transition). Both label sets are used as indices into the
    `get_partition_colors()` palette.

    The `output_dir` is assumed to exist.
    """
    # Partition-subset label per transition, indexed like `episodes[i][j]`.
    partition_modes = [
        get_partition_modes(
            trajectory=trajectory,
            partition=partition,
        )
        for trajectory in episodes
    ]

    # (episode index, step index) for each embedding, in embedding order.
    # Row k of the reduced vectors therefore belongs to indices[k].
    indices = [
        (i, j)
        for i, trajectory in enumerate(episodes)
        for j in range(len(trajectory))
    ]

    # Get transition embeddings
    embeddings = [
        get_mode_vector(
            episodes[i][j],
            mode_world_model=mode_world_model,
            device=device,
        )
        for i, j in indices
    ]

    # Reduce dimensionality of the embeddings (one (x, y) row per embedding)
    vectors = get_dimension_reduced_vectors(embeddings)

    x = [xp for xp, _ in vectors]
    y = [yp for _, yp in vectors]

    # Decide colors for each point
    available_mode_colors = get_partition_colors()
    time_colors = [j/len(episodes[i]) for i, j in indices]
    ground_truth_colors = [
        available_mode_colors[ground_truth_modes[i][j]]
        for i, j in indices
    ]
    partition_colors = [
        available_mode_colors[partition_modes[i][j]]
        for i, j in indices
    ]

    plots = [
        ("Learned representation (color=time)", time_colors),
        ("Learned representation (color=ground truth label)", ground_truth_colors),
        ("Learned representation (color=partition subset)", partition_colors),
    ]

    for (title, colors) in plots:
        # Create matplotlib figure. The Agg canvas is attached to the
        # figure so that it can be rendered/saved without pyplot.
        fig = Figure()
        _ = FigureCanvas(fig)
        ax = fig.add_subplot()
        ax.scatter(x, y, c=colors, alpha=0.5)

        # Save figure. No legend is drawn: the scatter has no labeled
        # artists, so `fig.legend()` would only emit a warning.
        output_path = output_dir/f"embedding_dimensionality_reduction_{title}.svg"
        fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(output_path)
