import io

import gym
import imageio
import numpy as np
import torch
import wandb  # FIXME: refactor to remove wandb dependency
from PIL import Image
from einops import rearrange
from matplotlib import pyplot as plt

from envs.wrappers.tensor_wrapper import TensorWrapper
from net.memory.sith import SITHMemory
from policy.base import BasePolicy


class ActProbVisWrapper(gym.Wrapper):
    """Gym wrapper whose ``rgb_array`` render shows the env next to policy internals.

    The rendered frame is a matplotlib figure with four panels: the wrapped
    environment's own frame, a bar plot of the encoder output ``z``, a bar
    plot of the per-action values passed to :meth:`record_stats` (the chosen
    action highlighted in green) and an image of the memory output ``ctx``.

    The wrapped environment must expose a ``Discrete``-like ``action_space``
    (``.n`` is read) and a ``current_position`` attribute (used to centre the
    environment panel).
    """

    action_space: gym.spaces.Discrete

    def __init__(self, env, memory_type: str):
        """Wrap ``env``.

        Args:
            env: Environment to wrap.
            memory_type: Name of the memory module; used for the panel title
                and, when it starts with ``"sith"``, to choose how the memory
                output is laid out in :meth:`record_stats`.
        """
        super().__init__(env)
        self.memory_type = memory_type
        self._clear_stats()

    def _clear_stats(self):
        """Reset the recorded statistics to neutral placeholders."""
        # Flat lists, so a render issued before the first `record_stats`
        # call draws one zero-height bar per entry.
        self.act_logits = [0.] * self.env.action_space.n
        self.act_idx = 0
        self.z = [0., 0., 0.]
        # No memory output yet; `render` leaves that panel empty.
        self.ctx = None

    def record_stats(self, act_logits, act_idx, model):
        """Store the values that the next ``rgb_array`` render will display.

        Args:
            act_logits: One value per action (tensor or array-like); drawn as
                the bar heights of the action panel.
            act_idx: Zero-based index of the selected action (tensor or int).
            model: Policy exposing ``net.z``, ``net.ctx`` and
                ``net.features_extractor``. ``z`` and ``ctx`` are indexed
                with ``[0, 0]`` -- presumably leading batch and time
                dimensions of size 1; verify against the caller.
        """
        if isinstance(act_logits, torch.Tensor):
            act_logits = act_logits.detach().cpu().numpy()

        if isinstance(act_idx, torch.Tensor):
            act_idx = act_idx.detach().cpu().numpy()

        self.act_logits = act_logits
        self.act_idx = act_idx

        self.z = model.net.z.detach().cpu().numpy()[0, 0]
        ctx = model.net.ctx.detach().cpu().numpy()[0, 0]

        features_extractor = model.net.features_extractor
        if self.memory_type.startswith("sith"):
            n_col = features_extractor.memory.sith.n_taus + 1
        else:
            n_col = features_extractor.memory.output_size
            if features_extractor.add_z_skip:
                n_col += features_extractor.latent_size

        # Flatten first so the padding is appended once at the end rather
        # than along every axis of a multi-dimensional array.
        ctx = ctx.reshape(-1)
        # Pad only up to the next multiple of n_col; an already aligned
        # array gets no padding instead of a whole extra row of NaNs.
        n_pad = (-ctx.size) % n_col
        ctx = np.pad(ctx, (0, n_pad), mode='constant', constant_values=np.nan)
        self.ctx = ctx.reshape([-1, n_col])

    def reset(self, **kwargs):
        """Clear the recorded statistics and reset the wrapped environment.

        Args:
            **kwargs: Forwarded unchanged to the wrapped environment's
                ``reset``.

        Returns:
            Whatever the wrapped environment's ``reset`` returns.
        """
        self._clear_stats()
        return self.env.reset(**kwargs)

    def render(self, mode: str = "human", **kwargs):
        """Render the environment.

        Args:
            mode: ``"rgb_array"`` produces the composite figure described in
                the class docstring; any other mode is delegated to
                ``gym.Wrapper.render``.
            **kwargs: Forwarded to the underlying ``render`` call.

        Returns:
            For ``"rgb_array"``, an RGB ``numpy`` array of the drawn figure;
            otherwise whatever the delegated ``render`` returns.
        """
        if mode != "rgb_array":
            return super().render(mode, **kwargs)

        env_rgb_array = self.env.render(mode, **kwargs)
        n_actions = self.env.action_space.n

        fig = plt.figure(figsize=(5, 5), constrained_layout=True, dpi=256)
        try:
            gs = fig.add_gridspec(2, 3)

            ax0 = fig.add_subplot(gs[:, 0])
            ax1 = fig.add_subplot(gs[0, 1])
            ax2 = fig.add_subplot(gs[0, 2])
            ax3 = fig.add_subplot(gs[1, 1:])

            # Environment panel: a 30-unit window centred on the current position.
            curr_pos = self.env.current_position
            ax0.set(title='Environment', ylim=[curr_pos+15, curr_pos-15], facecolor="paleturquoise")
            ax0.imshow(env_rgb_array.swapaxes(0, 1))
            ax0.set_xticks([])
            ax0.set_xticklabels([])
            ax0.set_xlabel(" ", labelpad=5)
            ax0.invert_yaxis()

            # Encoder output panel.
            ax1.bar(range(len(self.z)), self.z)
            ax1.set(title='Encoder output', ylabel='Activation', ylim=(0, None))
            ax1.set_xticks([])
            ax1.set_xticklabels([])

            # Action panel: the selected action is drawn in green.
            ax2.set(xlabel='Action', ylabel='Probability', ylim=[0, 1])
            barcolors = ["blue"] * n_actions
            # `act_idx` may be a 0-d numpy array after `record_stats`.
            barcolors[int(self.act_idx)] = "green"
            ax2.bar(range(n_actions), self.act_logits, color=barcolors)
            ax2.set_xticks(range(n_actions))
            if n_actions == 4:
                ax2.set_xticklabels(["▼", "▲", "▶", "◀"])
            elif n_actions == 3:
                ax2.set_xticklabels(["▲", "▶", "◀"])

            # Memory panel: left empty until `record_stats` has supplied data.
            ax3.set(title=f'{self.memory_type.upper()} output')
            if self.ctx is not None:
                ax3.imshow(self.ctx)
            ax3.set_axis_off()

            fig.canvas.draw()
            rgba = np.asarray(fig.canvas.buffer_rgba())
            im = Image.fromarray(rgba).convert("RGB")
        finally:
            # Always release the figure, even if a drawing call raised.
            plt.close(fig)
        return np.array(im)  # noqa


def log_videos(envs_dict: dict[str, gym.Env], model: BasePolicy, memory_type: str, n_episodes: int = 5):
    """Record greedy rollouts of ``model`` in each environment and log them to wandb.

    For every environment, ``n_episodes`` episodes are played with
    ``model.greedy`` under ``torch.no_grad()``. Each step is rendered through
    :class:`ActProbVisWrapper` and appended to a single in-memory mp4, which
    is then logged as ``video/<name>`` with ``commit=False``.

    Args:
        envs_dict: Mapping from a display name to the environment to record.
            Each environment's ``metadata["render_fps"]`` sets the video
            frame rate.
        model: Policy providing ``greedy(obs, hidden)``; the first three
            returned values are used as the action, the next hidden state
            and the per-action values.
        memory_type: Forwarded to :class:`ActProbVisWrapper`.
        n_episodes: Number of episodes recorded per environment.
    """
    for name, env in envs_dict.items():
        env = TensorWrapper(env)
        env = ActProbVisWrapper(env, memory_type)

        fps = env.metadata["render_fps"]
        # The buffer is a context manager so it is released even if encoding
        # or logging raises.
        with io.BytesIO() as buff:
            with imageio.get_writer(buff, fps=fps, format="mp4") as writer, torch.no_grad():
                for _ in range(n_episodes):
                    obs = env.reset()
                    h = None
                    done = False

                    while not done:
                        obs_in = rearrange(obs, '... -> 1 1 ...')  # add batch and time dim

                        # select action
                        act, h, act_logits, _, _ = model.greedy(obs_in, h)

                        act = rearrange(act, '1 1 -> ')  # remove batch and time dim
                        act_logits = rearrange(act_logits, '1 1 a -> a')  # remove batch and time dim

                        # The action space may not start at 0; the wrapper
                        # expects a zero-based index for the bar plot.
                        act_idx = act - env.action_space.start
                        env.record_stats(act_logits, act_idx, model)
                        frame = env.render(mode="rgb_array")
                        writer.append_data(frame)

                        # apply action to environment
                        obs, _, done, _ = env.step(act.item())

            # Rewind so the video is read from the beginning of the buffer.
            buff.seek(0)
            wandb.log({f"video/{name}": wandb.Video(buff, format="mp4")}, commit=False)
