from pathlib import Path

import gym
import imageio
import numpy as np
import torch
import wandb
from PIL import Image
from einops import rearrange
from matplotlib import pyplot as plt

from envs.wrappers.tensor_wrapper import TensorWrapper
from helpers import init_envs, init_model
from policy.base import BasePolicy


# Key used by main() to pick one environment out of the validation-environment
# dict returned by init_envs().
VALID_ENV = "len300"

# Weights & Biases coordinates of the run to visualise. main() queries
# "<WANDB_ENTITY>/<WANDB_PROJECT>" for runs whose display name equals WANDB_RUN;
# WANDB_RUN is also used as the file name of the rendered video.
WANDB_PROJECT = "ENTER WANDB PROJECT HERE"
WANDB_ENTITY = "ENTER WANDB USERNAME HERE"
WANDB_RUN = "run-name-1"

# NOTE(review): not referenced by any code visible in this file -- presumably a
# leftover from a related analysis script; confirm before removing.
GROUPBY_CONFIG_KEY = "memory_type"

# Training episode whose checkpoint ("checkpoints/<EPISODE>.pt") main() downloads.
EPISODE = 100_000


class ActProbVisWrapper(gym.Wrapper):
    """Gym wrapper whose ``rgb_array`` render shows the agent's internals.

    Each rendered frame is a matplotlib figure with three panels: the wrapped
    environment's own frame, a bar chart of the most recently recorded
    per-action values (the chosen action highlighted), and an image of the
    most recently recorded memory output. The values are supplied from the
    outside through :meth:`record_stats`; the wrapper never queries the model
    on its own.
    """

    action_space: gym.spaces.Discrete

    def __init__(self, env, memory_type: str):
        """Wrap ``env``.

        Args:
            env: Environment to wrap. Must have a discrete action space and,
                for ``rgb_array`` rendering, a ``current_position`` attribute.
            memory_type: Name of the model's memory module. Used for the panel
                title and to decide how the memory output is laid out.
        """
        super().__init__(env)
        self.memory_type = memory_type
        self._clear_stats()

    def _clear_stats(self):
        """Reset the recorded statistics to their neutral placeholders."""
        # Flat lists, so that a render issued before the first record_stats()
        # call can still draw one bar per action.
        self.act_logits = [0.] * self.env.action_space.n
        self.act_idx = 0
        self.z = [0., 0., 0.]
        # None means "nothing recorded yet"; render() leaves the panel empty.
        self.ctx = None

    def record_stats(self, act_logits, act_idx, model):
        """Store the values that the next ``render`` call will draw.

        Args:
            act_logits: One value per action (tensor or array-like); tensors
                are converted to NumPy arrays.
            act_idx: Zero-based index of the selected action (tensor or int).
            model: Policy exposing ``net.z``, ``net.ctx`` and
                ``net.features_extractor``. Entry ``[0, 0]`` of ``z`` and
                ``ctx`` is recorded -- assumes the two leading axes are batch
                and time, and that the remaining ``ctx`` slice is 1-D
                (TODO confirm against the model).
        """
        if isinstance(act_logits, torch.Tensor):
            act_logits = act_logits.detach().cpu().numpy()

        if isinstance(act_idx, torch.Tensor):
            act_idx = act_idx.detach().cpu().numpy()

        self.act_logits = act_logits
        self.act_idx = act_idx

        net = model.net
        self.z = net.z.detach().cpu().numpy()[0, 0]
        ctx = net.ctx.detach().cpu().numpy()[0, 0]

        # Width of the 2-D layout used to display the memory output.
        extractor = net.features_extractor
        if self.memory_type.startswith("sith"):
            n_col = extractor.memory.sith.n_taus
        else:
            n_col = int(np.sqrt(extractor.memory.output_size))
            if extractor.add_z_skip:
                n_col += extractor.latent_size

        # Pad with NaN up to the next multiple of n_col so the vector can be
        # reshaped into rows. (-size) % n_col is 0 when the size already fits,
        # which avoids appending a superfluous all-NaN row in that case.
        n_pad = (-ctx.size) % n_col
        ctx = np.pad(ctx, (0, n_pad), mode='constant', constant_values=np.nan)
        self.ctx = ctx.reshape([-1, n_col])

    def reset(self, **kwargs):
        """Clear the recorded statistics and reset the wrapped environment.

        Keyword arguments are forwarded to the wrapped environment's
        ``reset``; its return value is passed through unchanged.
        """
        self._clear_stats()
        return self.env.reset(**kwargs)

    def render(self, mode: str = "human", **kwargs):
        """Render the environment.

        For ``mode == "rgb_array"`` an RGB image (NumPy array) of the
        three-panel figure is returned. Every other mode is delegated to the
        parent wrapper and its result is returned.
        """
        if mode != "rgb_array":
            return super().render(mode, **kwargs)

        env_rgb_array = self.env.render(mode)
        n_actions = self.env.action_space.n

        fig = plt.figure(figsize=(5, 5), constrained_layout=True, dpi=256)
        try:
            gs = fig.add_gridspec(2, 3)
            ax_env = fig.add_subplot(gs[:, 0])
            ax_act = fig.add_subplot(gs[0, 1:])
            ax_mem = fig.add_subplot(gs[1, 1:])

            # Environment panel: a 30-unit window centred on the agent. The
            # order of set / imshow / invert_yaxis is kept as-is because it
            # determines the final axis limits and orientation.
            curr_pos = self.env.current_position
            ax_env.set(title='Environment', ylim=[curr_pos+15, curr_pos-15], facecolor="paleturquoise")
            ax_env.imshow(env_rgb_array.swapaxes(0, 1))
            ax_env.set_xticks([])
            ax_env.set_xticklabels([])
            ax_env.set_xlabel(" ", labelpad=5)
            ax_env.invert_yaxis()

            # Action panel: one bar per action, the selected one in green.
            # NOTE(review): the axis is labelled "Probability" while the
            # recorded values are called logits -- confirm what the caller
            # actually passes to record_stats().
            ax_act.set(xlabel='Action', ylabel='Probability', ylim=[0, 1])
            barcolors = ["blue"] * n_actions
            barcolors[int(self.act_idx)] = "green"
            ax_act.bar(range(n_actions), self.act_logits, color=barcolors)
            ax_act.set_xticks(range(n_actions))
            if n_actions == 4:
                ax_act.set_xticklabels(["▼", "▲", "▶", "◀"])
            elif n_actions == 3:
                ax_act.set_xticklabels(["▲", "▶", "◀"])

            # Memory panel: left blank until record_stats() has been called.
            ax_mem.set(title=f'{self.memory_type.upper()} output')
            if self.ctx is not None:
                ax_mem.imshow(self.ctx, vmax=0.35)
            ax_mem.set_axis_off()

            # Rasterise the figure and drop the alpha channel.
            fig.canvas.draw()
            rgba = np.asarray(fig.canvas.buffer_rgba())
            im = Image.fromarray(rgba).convert("RGB")
            return np.array(im)  # noqa
        finally:
            # Always release the figure, even if drawing failed; otherwise
            # pyplot keeps every figure alive across frames.
            plt.close(fig)


def log_videos(env: gym.Env, model: BasePolicy, memory_type: str,
               n_episodes: int = 1, out_dir: str = "out/agent_movies"):
    """Roll out ``model`` greedily in ``env`` and save the frames as an MP4.

    The environment is wrapped in ``TensorWrapper`` and ``ActProbVisWrapper``
    so that every frame also shows the per-action values and the memory
    output. The video is written to ``<out_dir>/<WANDB_RUN>.mp4``.

    Args:
        env: Environment to roll out in. Its ``metadata["render_fps"]`` sets
            the video frame rate.
        model: Policy providing ``greedy(obs, hidden)``.
        memory_type: Name of the model's memory module, used for the
            visualisation. ``"sith_sub_nosum"`` is displayed as ``"sith_sub"``.
        n_episodes: Number of episodes appended to the video (default 1).
        out_dir: Directory the video is written to; created if missing.
    """
    env = TensorWrapper(env)
    if memory_type == "sith_sub_nosum":
        memory_type = "sith_sub"
    env = ActProbVisWrapper(env, memory_type)

    # NOTE(review): the environment is reset three times and the observations
    # are discarded before recording starts. Presumably this advances the
    # environment to a particular episode -- kept as-is; confirm whether it is
    # still needed.
    for _ in range(3):
        env.reset()

    fps = env.metadata["render_fps"]
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    video_path = out_path / f"{WANDB_RUN}.mp4"

    with imageio.get_writer(str(video_path), fps=fps) as writer, torch.no_grad():
        for _ in range(n_episodes):
            obs = env.reset()
            h = None  # hidden state; None at the start of each episode
            done = False

            while not done:
                obs_in = rearrange(obs, '... -> 1 1 ...')  # add batch and time dim

                # select action
                act, h, act_logits, _, _ = model.greedy(obs_in, h)

                act = rearrange(act, '1 1 -> ')  # remove batch and time dim
                act_logits = rearrange(act_logits, '1 1 a -> a')  # remove batch and time dim

                # Convert the action to a zero-based index for the bar chart.
                act_idx = act - env.action_space.start
                env.record_stats(act_logits, act_idx, model)

                # Render before stepping, so the frame shows the state the
                # action was chosen in.
                frame = env.render(mode="rgb_array")
                writer.append_data(frame)

                # apply action to environment
                obs, _, done, _ = env.step(act.item())


def main():
    """Fetch the configured W&B run's checkpoint and render a rollout video.

    Looks up the run named ``WANDB_RUN`` in ``WANDB_ENTITY/WANDB_PROJECT``,
    rebuilds environments and model from the run's config, downloads
    ``checkpoints/<EPISODE>.pt`` into ``./tmp/<run name>/``, loads its
    ``"model_state_dict"`` and records a video in the ``VALID_ENV``
    validation environment.

    Raises:
        RuntimeError: If the name does not match exactly one run.
    """
    api = wandb.Api()
    runs = api.runs(f'{WANDB_ENTITY}/{WANDB_PROJECT}', filters={"display_name": WANDB_RUN})

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(runs) != 1:
        raise RuntimeError(
            f"Expected exactly one run named {WANDB_RUN!r} in "
            f"{WANDB_ENTITY}/{WANDB_PROJECT}, found {len(runs)}."
        )

    run = runs[0]

    config = dict(run.config)

    train_env, valid_envs_dict = init_envs(config)
    model = init_model(config, train_env)

    print(f"{run.name} ({EPISODE = }) ({run.url})")

    file = run.file(f"checkpoints/{EPISODE}.pt")

    # A non-positive size is treated as "no checkpoint for this episode".
    if file.size <= 0:
        print("Episode not found.")
        return

    file.download(root=f"./tmp/{run.name}", replace=True)

    # map_location="cpu" lets a checkpoint saved on a GPU be loaded on a
    # machine without one; load_state_dict then copies the values into the
    # model's own parameters.
    # NOTE: torch.load unpickles the file -- only load checkpoints from runs
    # you trust.
    checkpoint = torch.load(f"./tmp/{run.name}/{file.name}", map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    valid_env = valid_envs_dict[VALID_ENV]

    log_videos(valid_env, model, config["memory_type"])


# Run the script only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
