from abc import ABC, abstractmethod
import io
from typing import List, Optional, Union
import chex

from flax.struct import PyTreeNode, field
import jax
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
import numpy as np
import wandb
from PIL import Image


class Visualization(PyTreeNode, ABC):
    """Abstract base for visualization payloads stored as flax pytree nodes.

    Subclasses hold the data to draw as pytree fields and implement
    ``visualize`` to turn it into renderable output.
    """

    @abstractmethod
    def visualize(self):
        """Render this visualization; concrete subclasses must override."""


class RolloutGroup(PyTreeNode):
    """A set of 2D rollouts drawn on one axes with a shared style.

    ``rollouts`` has a trailing feature axis: entries ``[..., :2]`` are an
    (x, y) position and entries ``[..., 2:]`` are a 2D direction vector drawn
    as quiver arrows. It may be ``(T, D)`` for a single rollout or
    ``(N, T, D)`` for ``N`` rollouts. Only ``rollouts`` is a pytree leaf; the
    style fields are static (``pytree_node=False``) metadata.
    """

    rollouts: Union[jax.Array, np.ndarray]
    color: Optional[str] = field(pytree_node=False, default=None)
    label: Optional[str] = field(pytree_node=False, default=None)
    linestyle: Optional[str] = field(pytree_node=False, default=None)
    alpha: float = field(pytree_node=False, default=1.0)

    def __getitem__(self, i: int):
        """Index into ``rollouts`` (e.g. select one batch element).

        All style fields (color, label, linestyle, alpha) are carried over.
        """
        # ``replace`` copies every field, so no style attribute is reset to
        # its default when selecting a sub-group.
        return self.replace(rollouts=self.rollouts[i])

    def to_numpy(self):
        """Return a copy whose ``rollouts`` is a NumPy array, style preserved."""
        return self.replace(rollouts=np.asarray(self.rollouts))

    def plot(self, ax: Axes):
        """Draw every rollout as a line plus heading arrows on ``ax``.

        Each rollout is prefixed with a start state at position (0, 0) and
        heading (1, 0) before drawing. The group's label is attached to a
        single line, so the group contributes one legend entry.
        """
        positions = self.rollouts[..., :2]
        directions = self.rollouts[..., 2:]

        # Promote a single (T, D) rollout to a batch of one: (1, T, D).
        if positions.ndim == 2:
            positions = positions[None]
            directions = directions[None]

        # Prepend the shared start state: origin, facing +x.
        # NOTE(review): presumably rollouts are expressed relative to the
        # start pose — confirm against the producer of ``rollouts``.
        init_pos = np.broadcast_to(np.array([0, 0]), positions[:, :1].shape)
        init_dir = np.broadcast_to(np.array([1, 0]), positions[:, :1].shape)
        positions = np.concatenate([init_pos, positions], axis=1)
        directions = np.concatenate([init_dir, directions], axis=1)

        px, py = positions[..., 0], positions[..., 1]
        dx, dy = directions[..., 0], directions[..., 1]

        # Transposing to (T + 1, N) makes matplotlib draw one line per rollout.
        lines = ax.plot(px.T, py.T, color=self.color, linestyle=self.linestyle)
        # Label only the first line. A plain string label on a 2D plot would be
        # copied onto every rollout line, giving N duplicate legend entries.
        if lines and self.label is not None:
            lines[0].set_label(self.label)

        ax.quiver(
            px.flatten(),
            py.flatten(),
            dx.flatten(),
            dy.flatten(),
            color=self.color,
            alpha=self.alpha,
            angles="xy",
            scale_units="xy",
            zorder=2,
        )


def visualize_rollout(
    image: np.ndarray,
    goal: np.ndarray,
    rollout_groups: List[RolloutGroup],
    time_to_goal: np.ndarray,
):
    """Render observation, goal and rollouts side by side as a wandb image.

    Args:
        image: ImageNet-normalized observation. It is un-normalized with
            ``x * std + mean`` and clipped to [0, 1] for display. Assumed
            channel-last RGB so it broadcasts against the 3-element mean/std
            — TODO confirm with callers.
        goal: Goal image, same convention as ``image``.
        rollout_groups: Groups drawn on the third panel via ``RolloutGroup.plot``.
        time_to_goal: Single-element array (``.item()`` is called on it),
            shown in the figure title.

    Returns:
        A ``wandb.Image`` wrapping the figure rendered as PNG.
    """
    IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
    IMAGENET_STD = np.array([0.229, 0.224, 0.225])

    fig, axs = plt.subplots(1, 3, figsize=(10, 5), dpi=300)
    try:
        axs[0].imshow(np.clip(image * IMAGENET_STD + IMAGENET_MEAN, 0, 1))
        axs[1].imshow(np.clip(goal * IMAGENET_STD + IMAGENET_MEAN, 0, 1))

        # Rollout actions
        for group in rollout_groups:
            group.plot(axs[2])

        axs[2].set_aspect("equal")
        # Only build a legend if something is labeled; otherwise matplotlib
        # warns "No artists with labels found to put in legend".
        handles, _ = axs[2].get_legend_handles_labels()
        if handles:
            axs[2].legend()

        fig.suptitle(f"Time to goal: {time_to_goal.item():.2f}")

        data_out = io.BytesIO()
        # Force PNG so PIL can decode the buffer regardless of the
        # user's rcParams["savefig.format"].
        fig.savefig(data_out, format="png")
    finally:
        # pyplot keeps every figure alive until closed. This runs once per
        # batch element, so not closing leaks figures (and memory).
        plt.close(fig)

    data_out.seek(0)
    return wandb.Image(Image.open(data_out))


class RolloutVisualization(Visualization):
    """Batched rollout visualization; one rendered image per batch element.

    Every field carries a leading batch axis (``image.shape[0]`` is the batch
    size). ``rollout_groups`` entries are indexed along that same axis.
    """

    image: jax.Array
    goal: jax.Array
    rollout_groups: list[RolloutGroup]
    time_to_goal: jax.Array

    def visualize(self):
        """Render each batch element with ``visualize_rollout``.

        Returns:
            A list with one ``wandb.Image`` per batch element, in batch order.
        """
        num_samples = self.image.shape[0]

        # Move everything off-device once, before the per-sample loop.
        images_np = np.asarray(self.image)
        goals_np = np.asarray(self.goal)
        groups_np = [grp.to_numpy() for grp in self.rollout_groups]
        ttg_np = np.asarray(self.time_to_goal)

        rendered = []
        for idx in range(num_samples):
            sample_groups = [grp[idx] for grp in groups_np]
            rendered.append(
                visualize_rollout(
                    images_np[idx],
                    goals_np[idx],
                    sample_groups,
                    ttg_np[idx],
                )
            )
        return rendered
