import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from vae import VAE
from rollout_buffer import RolloutBuffer
from utils import save_video


class NextStepPredictionSSM(nn.Module):
    """Recurrent state-space model that predicts the next latent observation.

    A window of ``context_size`` latent observations is flattened, concatenated
    with an action encoding and embedded; a GRU cell turns that embedding into
    the deterministic state ``h_t``. Two Gaussian heads follow: one on ``h_t``
    (called the prior in this file) and one on ``[h_t, sampled prior state]``
    (called the posterior), from which the predicted latent is sampled.

    Args:
        encoded_latent_size: Size of one latent observation.
        action_size: Number of discrete actions (one-hot width) or, when
            ``is_minerl`` is true, the width of the raw action vector.
        hidden_state_size: Size of the GRU hidden state.
        context_size: Number of latent observations fed in per step.
        is_minerl: If true, actions are used as float vectors as-is instead of
            being one-hot encoded.
    """

    def __init__(self, encoded_latent_size: int, action_size: int, hidden_state_size: int, context_size: int = 4, is_minerl: bool = False):
        super().__init__()
        # Deterministic state
        self.obs_action_embedding = nn.Linear(context_size * encoded_latent_size + action_size, hidden_state_size)
        self.gru = nn.GRUCell(hidden_state_size, hidden_state_size)
        # Stochastic prior
        self.stochastic_state_prior_embed = nn.Linear(hidden_state_size, hidden_state_size)
        self.stochastic_state_prior_mean = nn.Linear(hidden_state_size, encoded_latent_size)
        self.stochastic_state_prior_logvar = nn.Linear(hidden_state_size, encoded_latent_size)
        # Stochastic posterior
        self.state_posterior_embed = nn.Linear(hidden_state_size + encoded_latent_size, hidden_state_size)
        self.state_posterior_mean = nn.Linear(hidden_state_size, encoded_latent_size)
        self.state_posterior_logvar = nn.Linear(hidden_state_size, encoded_latent_size)

        self.hidden_state_size = hidden_state_size
        self.action_size = action_size
        self.context_size = context_size
        self.is_minerl = is_minerl

    def _encode_action(self, action, device):
        """Return the action representation that is concatenated to the observations."""
        if self.is_minerl:
            return action.float().to(device)
        one_hot = F.one_hot(action.long(), num_classes=self.action_size).squeeze(1).float()
        # NOTE(review): softmax over a one-hot vector yields a softened encoding rather
        # than a strict one-hot. Kept as-is so existing checkpoints stay compatible.
        return F.softmax(one_hot, dim=-1).to(device)

    def forward(self, h_t_prev, s_t_prev, a_t_prev, device):
        """Advance the model by one step.

        Args:
            h_t_prev: Previous hidden state, ``(batch, hidden_state_size)``.
            s_t_prev: Context window of latent observations; flattened per sample
                to ``context_size * encoded_latent_size`` values.
            a_t_prev: Action taken; a discrete index of shape ``(batch, 1)``, or a
                float vector ``(batch, action_size)`` when ``is_minerl`` is true.
            device: Device the inputs are moved to.

        Returns:
            Tuple ``(o_t, h_t, prior_mean, prior_logvar, post_mean, post_logvar)``
            where ``o_t`` is the sampled next latent observation.
        """
        # Get the deterministic state h_t = f(h_t-1, s_t-1, a_t-1)
        action_encoding = self._encode_action(a_t_prev, device)
        # reshape (instead of view) also accepts non-contiguous slices of a batch
        flat_context = s_t_prev.reshape(s_t_prev.size(0), -1).to(device)
        encoded_obs_act = self.obs_action_embedding(torch.cat([flat_context, action_encoding], dim=-1))
        h_t = self.gru(encoded_obs_act, h_t_prev)
        # Get the stochastic state p(s_t | h_t)
        s_t = self.stochastic_state_prior_embed(h_t)
        s_t_mean = self.stochastic_state_prior_mean(s_t)
        s_t_logvar = self.stochastic_state_prior_logvar(s_t)
        # NOTE(review): exp(logvar) is passed as the Normal *scale* (std), so the
        # "logvar" heads effectively parameterise log-std. The training loss in this
        # file uses the same convention, so it is kept unchanged here.
        s_t_sample = torch.distributions.Normal(s_t_mean, torch.exp(s_t_logvar)).rsample()
        # Get the next observation p(o_t | h_t, s_t)
        complete_state = torch.cat([h_t, s_t_sample], dim=-1)
        z_t = self.state_posterior_embed(complete_state)
        z_t_mean = self.state_posterior_mean(z_t)
        z_t_logvar = self.state_posterior_logvar(z_t)
        # Sample observation
        o_t = torch.distributions.Normal(z_t_mean, torch.exp(z_t_logvar)).rsample()

        return o_t, h_t, s_t_mean, s_t_logvar, z_t_mean, z_t_logvar

    def get_previous_hidden(self, o, a, device):
        """Compute an initial hidden state from one context window and action.

        Runs the embedding and GRU once from an all-zero hidden state, without
        tracking gradients.

        Args:
            o: Context window of latent observations (flattened per sample).
            a: Action accompanying the window (same format as in ``forward``).
            device: Device to run on.

        Returns:
            Hidden state of shape ``(batch, hidden_state_size)``.
        """
        with torch.no_grad():
            h_t_prev = torch.zeros(a.size(0), self.hidden_state_size, device=device)
            acts = self._encode_action(a, device)
            # Move the observations to the device *before* concatenating so that a
            # CPU-resident input does not clash with device-resident actions.
            flat_obs = o.reshape(o.size(0), -1).to(device)
            encoded_obs_act = self.obs_action_embedding(torch.cat([flat_obs, acts], dim=-1))
            h_t = self.gru(encoded_obs_act, h_t_prev)

        return h_t

    def predict(self, h_t, o_t, a_t, device, horizon=1):
        """Roll the model forward autoregressively for ``horizon`` steps.

        Args:
            h_t: Initial hidden state, ``(batch, hidden_state_size)``.
            o_t: Initial context window, ``(batch, context, latent)``.
            a_t: Actions indexed as ``a_t[:, t, :]`` for each step ``t``.
            device: Device to run on.
            horizon: Number of steps to predict; must not exceed ``a_t.size(1)``.

        Returns:
            List of ``horizon`` predicted latents, each ``(batch, latent)``.
        """
        time_sequence = []
        window = o_t.to(device)
        with torch.no_grad():
            for t in range(horizon):
                o_next, h_t, _, _, _, _ = self(h_t, window, a_t[:, t, :], device)
                time_sequence.append(o_next)
                # Slide the window along the context axis (dim 1): drop the oldest
                # frame and append the prediction. This keeps the batch dimension
                # intact, so rollouts work for any batch size.
                window = torch.cat([window[:, 1:], o_next.unsqueeze(1)], dim=1)

        return time_sequence


def train(ssm: NextStepPredictionSSM, vae: VAE, rollout_buffer: RolloutBuffer, test_rollout_buffer: RolloutBuffer, epochs: int, iterations_per_epoch: int, optimizer: torch.optim.Optimizer,
          track_name: str):
    """Train the SSM on sampled latent transitions, plot periodically, save the weights.

    Besides its arguments this function reads the module-level names
    ``batch_size``, ``device``, ``beta``, ``test_frequency``, ``test_batch_size``
    and ``num_stored``, which the ``__main__`` section of this file defines.

    Args:
        ssm: Model to optimise.
        vae: Used only to decode latents for the periodic comparison plots.
        rollout_buffer: Source of training batches (``sample(batch_size)``).
        test_rollout_buffer: Source of batches for the comparison plots.
        epochs: Number of epochs.
        iterations_per_epoch: Optimisation steps per epoch.
        optimizer: Optimizer over ``ssm``'s parameters.
        track_name: Used in the checkpoint file name; names starting with
            ``"MineRL"`` select a 10-column action split, everything else 1.
    """
    # Column at which the sampled action tensor is split into
    # "previous action" (before) and "current action" (after).
    action_split = 10 if track_name.startswith("MineRL") else 1

    for e in range(epochs):
        epoch_metrics = {
            "reconstruction_loss": 0.0,
            "kl_divergence_loss": 0.0,
            "total_loss": 0.0
        }
        for _ in range(iterations_per_epoch):
            batch = rollout_buffer.sample(batch_size)
            # Window used to compute the initial hidden state
            previous_obs = batch.observation[:, :-2, :].to(device)
            previous_act = batch.action[:, :action_split]
            # Window (shifted by one frame) used for the actual prediction
            obs = batch.observation[:, 1:-1, :]
            act = batch.action[:, action_split:]
            # The last frame of the sample is the prediction target
            target = batch.observation[:, -1, :].to(device)

            h_t = ssm.get_previous_hidden(previous_obs, previous_act, device)
            o_t, h_t, prior_mean, prior_logvar, post_mean, post_logvar = ssm(h_t, obs, act, device)
            prior_distrib = torch.distributions.Normal(prior_mean, torch.exp(prior_logvar))
            posterior_distrib = torch.distributions.Normal(post_mean, torch.exp(post_logvar))

            reconstruction_loss = F.mse_loss(o_t, target, reduction="mean")
            kl_loss = torch.distributions.kl.kl_divergence(posterior_distrib, prior_distrib).sum(-1).mean()
            loss = beta * kl_loss + reconstruction_loss

            optimizer.zero_grad()
            loss.backward()
            # Clipping has to happen after backward(): before it the gradients are
            # still zeroed/absent and clipping would have no effect.
            nn.utils.clip_grad_norm_(ssm.parameters(), 1000., norm_type=2)
            optimizer.step()

            epoch_metrics["reconstruction_loss"] += reconstruction_loss.item()
            epoch_metrics["kl_divergence_loss"] += beta * kl_loss.item()
            epoch_metrics["total_loss"] += loss.item()

        steps = max(iterations_per_epoch, 1)  # avoid dividing by zero when no step ran
        mean_rec_loss = epoch_metrics["reconstruction_loss"] / steps
        mean_kld_loss = epoch_metrics["kl_divergence_loss"] / steps
        print(f"[Epoch {e + 1}/{epochs}] | Reconstruction Loss: {mean_rec_loss:.5f} | KL Loss: {mean_kld_loss:.5f}")

        if (e + 1) % test_frequency == 0:
            ssm.eval()
            test_batch = test_rollout_buffer.sample(test_batch_size)
            previous_obs = test_batch.observation[:, :-2, :]
            previous_act = test_batch.action[:, :action_split]
            obs = test_batch.observation[:, 1:-1, :]
            act = test_batch.action[:, action_split:]
            with torch.no_grad():
                h_t = ssm.get_previous_hidden(previous_obs.to(device), previous_act, device)
                o_t, h_t, _, _, _, _ = ssm(h_t, obs, act, device)

                decoded = vae.decode(test_batch.observation[:, -1, :].to(device))
                decoded_next = vae.decode(o_t)

            # hstack joins along axis 1 (image height): ground truth on top,
            # prediction below, one column per sample.
            img = np.hstack([decoded.permute(0, 2, 3, 1).cpu().numpy(), decoded_next.permute(0, 2, 3, 1).cpu().numpy()])
            # squeeze=False keeps `axes` two-dimensional even for a single image
            fig, axes = plt.subplots(1, img.shape[0], figsize=(0.66 * (test_batch_size + 1), 1.5), squeeze=False)
            fig.suptitle(f"Epoch {e + 1}")
            for axis, picture in zip(axes[0], img):
                axis.imshow(picture)
                axis.axis('off')
            plt.show()
            plt.close(fig)  # do not accumulate figures over a long run

        ssm.train()

    os.makedirs("models/ssm_ablation", exist_ok=True)
    torch.save(ssm.state_dict(), f"models/ssm_ablation/ssm_{track_name}_{epochs}_{num_stored}.pth")


def _next_observations(observations):
    """Return the sequence shifted one step forward, zero-padded at the last step."""
    return torch.cat([observations[1:], torch.zeros_like(observations[:1])], dim=0)


def _add_supertuxkart_trajectories(buffer, vae, directory, limit, device):
    """Encode up to `limit` SuperTuxKart trajectory files from `directory` into `buffer`."""
    # os.listdir returns entries in arbitrary order; sort for a reproducible selection.
    for file_name in sorted(os.listdir(directory))[:limit]:
        # NOTE: allow_pickle=True can execute arbitrary code -- only load trusted files.
        data = np.load(os.path.join(directory, file_name), allow_pickle=True)
        frames = torch.tensor(data["observations"] / 255.).float().permute(0, 3, 1, 2).to(device)
        # The latents are only stored as data, so no autograd graph is needed.
        with torch.no_grad():
            latents, means, logvars = vae.get_latent(frames)
        observations = latents.float()
        buffer.add_trajectory(
            observations=observations,
            actions=torch.from_numpy(data["actions"]).float().unsqueeze(-1),
            next_observations=_next_observations(observations),
            means=means,
            logvars=logvars
        )


def _stack_minerl_actions(data):
    """Concatenate every ``action*`` array of a MineRL archive (except ``action$place``) column-wise."""
    columns = []
    for key in data.files:
        if not key.startswith("action") or key == "action$place":
            continue
        values = data[key]  # read each array from the archive only once
        if values.ndim == 1:
            columns.append(np.expand_dims(values.astype(np.float32), -1))
        else:
            for i in range(values.shape[-1]):
                columns.append(np.expand_dims(values[:, i].astype(np.float32), -1))
    return np.concatenate(columns, axis=-1)


def _encode_video(vae, video_path, device, max_frames=None):
    """Encode the frames of a video with the VAE; return lists of latents, means and logvars."""
    latents, means, logvars = [], [], []
    capture = cv2.VideoCapture(video_path)
    try:
        with torch.no_grad():
            while capture.isOpened():
                if max_frames and len(latents) >= max_frames:
                    break
                ok, frame = capture.read()
                if not ok:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pixels = torch.tensor(frame / 255.).float().unsqueeze(0).permute(0, 3, 1, 2).to(device)
                latent, mean, logvar = vae.get_latent(pixels)
                latents.append(latent)
                means.append(mean)
                logvars.append(logvar)
    finally:
        # Release the decoder even if encoding a frame fails.
        capture.release()
    return latents, means, logvars


def _add_minerl_trajectories(buffer, vae, episode_dirs, device, max_frames=None):
    """Encode each MineRL episode directory (rendered.npz + recording.mp4) into `buffer`."""
    for episode_dir in episode_dirs:
        # NOTE: allow_pickle=True can execute arbitrary code -- only load trusted files.
        with np.load(os.path.join(episode_dir, "rendered.npz"), allow_pickle=True) as data:
            actions = _stack_minerl_actions(data)
        video_path = os.path.join(episode_dir, "recording.mp4")
        latents, means, logvars = _encode_video(vae, video_path, device, max_frames)
        # Frames and actions may differ in count; keep only the common prefix so
        # observations and actions stay aligned.
        steps = min(len(latents), actions.shape[0])
        if steps == 0:
            raise RuntimeError(f"No usable frames/actions found for episode {episode_dir}")
        observations = torch.cat(latents[:steps], dim=0)
        buffer.add_trajectory(
            observations=observations,
            actions=torch.tensor(actions[:steps]),
            next_observations=_next_observations(observations),
            means=torch.cat(means[:steps], dim=0),
            logvars=torch.cat(logvars[:steps], dim=0)
        )


if __name__ == '__main__':
    # NOTE: train() reads batch_size, beta, device, test_frequency, test_batch_size
    # and num_stored from module scope, so these must stay top-level names.
    input_shape = (1, 3, 64, 64)
    hidden_state_size = 256
    context_size = 4
    obs_size = (64, 64)
    epochs = 250
    iterations_per_epoch = 500
    batch_size = 64
    # NOTE(review): unused -- the optimizer below hard-codes lr=3e-4. Confirm which
    # value is intended before wiring this in.
    learning_rate = 1e-3
    beta = 0.1
    track_name = "MineRLTreechop-v0"
    env = "minerl"
    if env not in ("super_tux_kart", "minerl"):
        raise ValueError(f"Unsupported env: {env!r}")
    is_minerl = env == "minerl"
    latent_size = 128 if env == "super_tux_kart" else 512
    action_size = 9 if env == "super_tux_kart" else 10
    max_episode_length = 1500 if env == "super_tux_kart" else 10000
    num_test_trajectories = 5
    if env == "super_tux_kart":
        data_path = f"/home/federico/PycharmProjects/PythonProject/{env}/trajectories/{track_name}/train"
        test_data_path = f"/home/federico/PycharmProjects/PythonProject/{env}/trajectories/{track_name}/test"
        vae_weight_path = f"/home/federico/PycharmProjects/planet-torch/vae_models/{env}/{track_name}/encoder_{latent_size}_250_image.pth"
    else:
        data_path = f"/home/federico/PycharmProjects/PythonProject/{env}/{track_name}"
        test_data_path = data_path
        vae_weight_path = f"/home/federico/PycharmProjects/planet-torch/vae_models/{env}/{track_name}/encoder_{latent_size}_50.pth"
    test_frequency = 25
    test_batch_size = 10
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    vae = VAE(
        input_shape=input_shape,
        latent_size=latent_size,
        custom_encoder_class=None,
        custom_encoder_args=None
    )
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    vae.load_state_dict(torch.load(vae_weight_path, map_location=device))
    # NOTE(review): vae.eval() was disabled in the original script and is left off
    # here; confirm whether the latents are meant to be computed in train mode.
    vae.to(device)

    # Arguments shared by the train and test buffers.
    buffer_kwargs = dict(
        obs_size=latent_size,
        num_actions=action_size,
        max_episode_len=max_episode_length,
        context_size=context_size,
        is_minerl=is_minerl
    )

    if is_minerl:
        # List once and sort: the train/test split below relies on a stable order.
        episode_dirs = sorted(os.path.join(data_path, x) for x in os.listdir(data_path))

    for num_stored in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]:
        print(f"Training SSM on {num_stored} trajectories")
        ssm = NextStepPredictionSSM(
            encoded_latent_size=latent_size,
            action_size=action_size,
            hidden_state_size=hidden_state_size,
            context_size=context_size,
            is_minerl=is_minerl
        )
        ssm.to(device)

        rollout_buffer = RolloutBuffer(num_collected=num_stored, **buffer_kwargs)
        test_rollout_buffer = RolloutBuffer(num_collected=num_test_trajectories, **buffer_kwargs)

        # Pre-compute latents for convenience
        if is_minerl:
            # Test episodes are the ones right after the training slice.
            train_dirs = episode_dirs[:num_stored]
            test_dirs = episode_dirs[num_stored: num_stored + num_test_trajectories]
            _add_minerl_trajectories(rollout_buffer, vae, train_dirs, device)
            _add_minerl_trajectories(test_rollout_buffer, vae, test_dirs, device)
        else:
            _add_supertuxkart_trajectories(rollout_buffer, vae, data_path, num_stored, device)
            _add_supertuxkart_trajectories(test_rollout_buffer, vae, test_data_path, num_test_trajectories, device)

        optimizer = torch.optim.Adam(ssm.parameters(), lr=3e-4, eps=1e-4)

        train(
            ssm=ssm,
            vae=vae,
            rollout_buffer=rollout_buffer,
            test_rollout_buffer=test_rollout_buffer,
            optimizer=optimizer,
            epochs=epochs,
            iterations_per_epoch=iterations_per_epoch,
            track_name=track_name
        )

    # Disabled long-horizon rollout / visualisation snippet, kept verbatim as a
    # string literal (it is never executed).
    """horizon = 350
    horizon_img = 35
    ssm.load_state_dict(torch.load(f"models/ssm_{track_name}_{epochs}_c51.pth"))
    #ssm.eval()
    test_batch = rollout_buffer.sample_sequence(horizon=horizon)

    previous_obs = test_batch.observation[:ssm.context_size, :].unsqueeze(0)
    previous_act = test_batch.action[ssm.context_size - 1].unsqueeze(0)
    h_t = ssm.get_previous_hidden(previous_obs, previous_act, device)

    obs = test_batch.observation[1 :ssm.context_size + 1, :].unsqueeze(0)
    act = test_batch.action[ssm.context_size: ssm.context_size + horizon].unsqueeze(0)

    sequence = ssm.predict(h_t, obs, act, device, horizon)

    side_to_side = []
    img_comparison = []
    for obs, pred in zip(test_batch.next_observation[ssm.context_size:], sequence):
        with torch.no_grad():
            decoded_obs = vae.decode(obs.unsqueeze(0).to(device))
            decoded_pred = vae.decode(pred)
        side_to_side.append(torch.cat([decoded_obs, decoded_pred], dim=-1).squeeze(0).cpu().numpy())
        img_comparison.append(torch.cat([decoded_obs, decoded_pred], dim=-2).squeeze(0).permute(1, 2, 0).cpu().numpy())

    save_video(np.array(side_to_side), "/home/federico/PycharmProjects/planet-torch", f"{track_name}_{horizon}_ssm_{epochs}")
    # Save image
    plt.tight_layout()
    fig, ax = plt.subplots(1, horizon_img, figsize=(0.64 * horizon_img, 1.5))
    #fig.suptitle(f"Trajectory prediction")
    text_xs = []
    for i in range(horizon_img):
        ax[i].imshow(img_comparison[i])
        ax[i].axis('off')
        if i % 5 == 0:
            ax[i].set_title(f"t = {i}")

    fig.savefig(f"trajectory_prediction_{track_name}_{epochs}.png")
    plt.show()"""
