import sys

sys.dont_write_bytecode = True

import os
import gym
import gymnasium
from vizdoom import gymnasium_wrapper
import math
import copy
import random
import datetime
import numpy as np
import cv2
from time import sleep
import wandb
import pickle

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import matplotlib
import pickle
import pandas as pd
from collections import deque, namedtuple

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR, StepLR, CyclicLR
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from transformers import get_cosine_schedule_with_warmup

from components.drawing import Arrow3D

# Render every matplotlib figure produced by this script with bold 11-pt text.
font = dict(weight="bold", size=11)
matplotlib.rc("font", **font)


def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and torch (CPU
    and all CUDA devices), then forces cuDNN into deterministic,
    non-benchmarking mode so repeated runs produce the same results.

    Note: ``PYTHONHASHSEED`` is exported for child processes only; the hash
    seed of the interpreter that is already running was fixed at startup.
    """
    for seed_fn in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

    # cuDNN may otherwise pick non-deterministic / autotuned kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    os.environ["PYTHONHASHSEED"] = str(seed)

    print("Global seeds set to", seed)


# One stored transition (s, a, r, s', done) plus an extra ``state_hash`` field
# that is kept in the buffer but not returned by ``ReplayBuffer.sample``.
Experience = namedtuple(
    "Experience", "state action reward next_state done state_hash"
)


class ReplayBuffer:
    """Store of ``Experience`` transitions with uniform mini-batch sampling.

    Args:
        buffer_size: Capacity of the internal ``deque`` created when ``memory``
            is not given; once full, appending drops the oldest experience.
            Ignored when ``memory`` is supplied.
        batch_size: Number of experiences drawn by :meth:`sample`.
        seed: Accepted for API compatibility but currently unused; sampling
            relies on the global ``random`` module, so seed that instead.
        memory: Optional existing sequence of experiences to use as storage
            (e.g. the lists built by ``split_buffer``). It is used as-is,
            without copying or bounding it to ``buffer_size``.
    """

    def __init__(self, buffer_size, batch_size, seed, memory=None):
        # Identity test: `memory == None` is ambiguous for array-like storage
        # (NumPy compares element-wise) and raises when used in an `if`.
        self.memory = deque(maxlen=buffer_size) if memory is None else memory
        self.batch_size = batch_size

    def add(self, state, action, reward, next_state, done, state_hash):
        """Add a new experience to memory.

        When ``memory`` is a bounded deque that is already full, the oldest
        experience is dropped.
        """
        self.memory.append(
            Experience(state, action, reward, next_state, done, state_hash)
        )

    def sample(self):
        """Draw ``batch_size`` distinct experiences uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of CPU
            tensors. States/next states are stacked along a new leading batch
            axis as ``float32``; actions are ``int64``; rewards and dones are
            ``float32``. Scalar actions/rewards/dones are stacked by
            ``np.vstack`` into shape ``(batch_size, 1)``. ``state_hash`` is not
            returned.

        Raises:
            ValueError: if the buffer holds fewer than ``batch_size``
                experiences (raised by ``random.sample``).
        """
        experiences = [
            e
            for e in random.sample(self.memory, k=self.batch_size)
            if e is not None
        ]

        states = torch.from_numpy(
            np.vstack([e.state[None, :] for e in experiences])
        ).float()
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).long()
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float()
        next_states = torch.from_numpy(
            np.vstack([e.next_state[None, :] for e in experiences])
        ).float()
        # Booleans go through uint8 before the torch conversion, then to float.
        dones = torch.from_numpy(
            np.vstack([e.done for e in experiences]).astype(np.uint8)
        ).float()

        return (states, actions, rewards, next_states, dones)

    def __len__(self):
        """Number of stored experiences."""
        return len(self.memory)

    def save(self, file_path):
        """Save the replay buffer to a file (pickle of ``memory``)."""
        with open(file_path, "wb") as f:
            pickle.dump(self.memory, f)
        print(f"Replay buffer saved to {file_path}.")

    def load(self, file_path):
        """Load the replay buffer from a file, replacing ``memory``.

        SECURITY: ``pickle.load`` can execute arbitrary code; only load files
        produced by a trusted :meth:`save`.
        """
        with open(file_path, "rb") as f:
            self.memory = pickle.load(f)
        print(f"Replay buffer loaded from {file_path}.")
        print(f"Replay buffer contains {len(self.memory)} experiences.")


def split_buffer(buffer, split_ratio=0.8, test_start=0.8):
    """Shuffle a replay buffer and split it into disjoint train and test buffers.

    The experiences are copied into a list and shuffled with the global
    ``random`` module. The train buffer takes the first ``split_ratio``
    fraction of the shuffled data. The test buffer always takes everything
    from the ``test_start`` fraction onwards, independently of
    ``split_ratio``. Data between the two cut points is left unused. Because
    the test slice does not depend on ``split_ratio``, two calls that produce
    the same shuffle yield the same test set whatever ``split_ratio`` is.

    Args:
        buffer: Source buffer; its ``memory`` and ``batch_size`` are read.
        split_ratio: Fraction of the shuffled data used for training, in [0, 1].
        test_start: Fraction at which the test slice begins, in [0, 1].
            The default of 0.8 holds out the last 20% of the data.

    Returns:
        ``(train_buffer, test_buffer)``: two ``ReplayBuffer`` objects wrapping
        plain lists.

    Raises:
        ValueError: if a ratio is outside [0, 1], or if ``split_ratio`` exceeds
            ``test_start`` (the train and test sets would overlap).
    """
    if not (0 <= split_ratio <= 1):
        raise ValueError("split_ratio must be between 0 and 1.")
    if not (0 <= test_start <= 1):
        raise ValueError("test_start must be between 0 and 1.")
    if split_ratio > test_start:
        # Previously the test slice silently overlapped the train slice here.
        raise ValueError(
            f"split_ratio ({split_ratio}) must not exceed test_start "
            f"({test_start}); the train and test sets would overlap."
        )

    # Shuffle a copy so the source buffer keeps its order.
    shuffled_memory = list(buffer.memory)
    random.shuffle(shuffled_memory)

    n = len(shuffled_memory)
    train_set = shuffled_memory[: int(n * split_ratio)]
    test_set = shuffled_memory[int(n * test_start) :]

    def _wrap(memory):
        # Wrap a slice as a buffer that shares the source's batch size.
        return ReplayBuffer(
            buffer_size=len(memory),
            batch_size=buffer.batch_size,
            memory=memory,
            seed=33,
        )

    return _wrap(train_set), _wrap(test_set)


class EncoderCNN(nn.Module):
    """Convolutional encoder mapping an image observation to a latent vector.

    Four unpadded 4x4 / stride-2 convolutions (32 -> 64 -> 128 -> 256 channels)
    are followed by an MLP (128 -> 64 -> 32 -> ``output_dim``). Tanh activations
    are used everywhere except after the final linear layer.

    Args:
        output_dim: Size of the latent vector produced by :meth:`forward`.
        input_dim: ``(channels, height, width)`` of a single observation.

    Raises:
        ValueError: if ``height``/``width`` are too small for the conv stack.
    """

    def __init__(self, output_dim, input_dim):
        super(EncoderCNN, self).__init__()
        self.num_channel, self.h, self.w = input_dim

        self.gate = nn.Tanh()
        self.conv_encoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.num_channel,
                out_channels=32,
                kernel_size=(4, 4),
                stride=2,
            ),
            self.gate,
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(4, 4), stride=2),
            self.gate,
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(4, 4), stride=2),
            self.gate,
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(4, 4), stride=2),
            self.gate,
        )

        # Derive the flattened size from the input instead of hard-coding it
        # (64x64 -> 31 -> 14 -> 6 -> 2, i.e. 256 * 2 * 2 = 1024 features).
        out_h, out_w = self._conv_output_hw(self.h, self.w)
        self.flatten_dim_after_conv = 256 * out_h * out_w

        self.fc_after_conv = nn.Sequential(
            nn.Linear(self.flatten_dim_after_conv, 128),
            self.gate,
            nn.Linear(128, 64),
            self.gate,
            nn.Linear(64, 32),
            self.gate,
            nn.Linear(32, output_dim),
        )

    @staticmethod
    def _conv_output_hw(height, width, num_layers=4, kernel=4, stride=2):
        """Return ``(height, width)`` after ``num_layers`` unpadded convolutions.

        Raises:
            ValueError: if the feature map becomes smaller than ``kernel``
                before one of the layers.
        """
        for _ in range(num_layers):
            if height < kernel or width < kernel:
                raise ValueError(
                    f"input of size {height}x{width} is too small for a "
                    f"{kernel}x{kernel} convolution in EncoderCNN"
                )
            height = (height - kernel) // stride + 1
            width = (width - kernel) // stride + 1
        return height, width

    def forward(self, x):
        """Encode a batch ``x`` of shape ``(N, C, H, W)`` into ``(N, output_dim)``."""
        x = self.conv_encoder(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc_after_conv(x)

        return x


class Transition(nn.Module):
    """MLP predicting a latent update from a latent state and an action.

    ``forward(z, action)`` concatenates ``z`` and ``action`` on the last axis
    and maps the result through three tanh-activated hidden layers of width 32
    to a vector of size ``output_dim``.

    Args:
        input_dim: Combined size of ``z`` and ``action``.
        output_dim: Size of the predicted vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        activation = nn.Tanh()
        hidden = 32

        # Build Linear/Tanh pairs in order so parameter init matches layer order.
        widths = [input_dim, hidden, hidden, hidden]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), activation])
        layers.append(nn.Linear(hidden, output_dim))

        self.fc = nn.Sequential(*layers)

    def forward(self, z, action):
        joint = torch.cat([z, action], dim=-1)
        return self.fc(joint)


class MBAgent:
    """Offline learner for a latent world model (CNN encoder + transition MLP).

    The latent ``z`` is split into an angle ``z[:, 0]`` (wrapped to
    ``[0, 2*pi)`` and embedded on the unit circle as ``(cos, sin)``) and
    spatial coordinates ``z[:, 1:]``. The transition network predicts a
    displacement ``d`` and the predicted next latent is ``z_hat = z + d``.
    Training combines:

    * an InfoNCE loss between predicted and true next embeddings;
    * action-specific L1 penalties on coordinates an action should not move;
    * hypercube and hinge regularisers on the spatial coordinates.

    Experiences come from a replay buffer (optionally loaded from ``rb_path``)
    that is split into train/test sets with ``split_buffer``.

    Args:
        env: Environment providing ``observation_space.shape`` (encoder input),
            ``get_num_actions()`` and ``draw_latent`` (used by :meth:`evaluate`).
        latent_dim: Size of ``z``. At least 3 is needed, because the
            "turn right" constraint indexes coordinates 1 and 2.
        batch_size: Mini-batch size for training and metric evaluation.
        device: Torch device for the networks and batches.
        debug: When True, skip gradient-norm and histogram logging.
        global_lr: AdamW learning rate for both networks.
        max_norm: Gradient-norm clipping threshold, applied per network.
        buffer_size: Capacity of the freshly created replay buffer. A buffer
            loaded from ``rb_path`` keeps whatever was pickled.
        seed: Stored and forwarded to the replay buffer.
        wandb_run: wandb run used for logging; logging is skipped when None.
        temp: InfoNCE temperature.
        split_ratio: Training fraction passed to ``split_buffer``.
        rb_path: Optional pickle file to load the replay buffer from.
    """

    # Action id -> latent coordinates the predicted displacement should leave
    # at zero (0 = nothing, 1 = turn left, 2 = turn right, 3 = forward).
    # NOTE(review): action 1 (turn left) is unconstrained -- confirm intended.
    _STATIC_COORDS = {2: [1, 2], 3: [0], 0: [0]}

    def __init__(
        self,
        env: gym.Env,
        latent_dim: int,
        batch_size: int,
        device: str = "cpu",
        debug: bool = False,
        global_lr: float = 1e-6,
        max_norm: float = 0.1,
        buffer_size=10000,
        seed=2022,
        wandb_run=None,
        temp=0.1,
        split_ratio=0.8,
        rb_path=None,
    ):
        self.env = env
        self.device = device
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.debug = debug
        self.max_norm = max_norm
        self.replays = ReplayBuffer(
            buffer_size=buffer_size, batch_size=batch_size, seed=seed
        )
        if rb_path is not None:
            self.replays.load(rb_path)
        self.replays_train, self.replays_test = split_buffer(
            self.replays, split_ratio=split_ratio
        )
        self.temp = temp
        self.split_ratio = split_ratio
        self.seed = seed
        self.num_actions = self.env.get_num_actions()
        self.wandb_run = wandb_run

        self.encoder = EncoderCNN(
            input_dim=self.env.observation_space.shape, output_dim=latent_dim
        ).to(self.device)
        self.transition = Transition(
            input_dim=latent_dim + self.num_actions, output_dim=latent_dim
        ).to(self.device)

        self.optimizer = torch.optim.AdamW(
            list(self.encoder.parameters()) + list(self.transition.parameters()),
            lr=global_lr,
        )

        # Created by train(); learn() keeps a fixed learning rate without it.
        self.lr_scheduler = None

        self.step_count = 0
        self.grad_update_count = 0

    # ------------------------------------------------------------------ losses

    def _compute_infonce_loss(self, pred, gt, temp=1.0):
        """InfoNCE over the batch: row ``i`` of ``pred`` should be closest to row ``i`` of ``gt``.

        Similarities are negative L2 distances divided by ``temp``. The loss is
        the mean negative log-softmax of the diagonal (positive pairs).
        """
        dist = torch.cdist(pred, gt, p=2)
        sim = -dist / temp
        logits = F.log_softmax(sim, dim=1)
        pos_logits = torch.diag(logits)
        loss = -pos_logits.mean()

        return loss

    def _compute_hinge_loss(self, state_x, state_y, margin=1.0):
        """Mean of ``clamp(||x - y||_2 - margin, 0, 100)`` over the batch."""
        l2_distance = torch.linalg.norm(state_x - state_y, ord=2, dim=-1)
        hinge_loss = (
            torch.max(torch.zeros_like(l2_distance), l2_distance - margin)
            .clamp(0, 100)
            .mean()
        )
        return hinge_loss

    def _compute_hypercube_loss(self, state, r=1.0):
        """Mean of ``clamp(||state||_inf - r, 0, 100)``: keeps states inside a cube of half-width ``r``."""
        linf_distance = torch.linalg.norm(state.abs(), ord=float("inf"), dim=-1)
        hypercube_loss = (
            torch.max(torch.zeros_like(linf_distance), linf_distance - r)
            .clamp(0, 100)
            .mean()
        )
        return hypercube_loss

    def _compute_identity_loss(self, x):
        """Mean absolute value of ``x``; zero (still attached to the graph) when ``x`` is empty.

        An empty selection happens when a batch contains no sample of a given
        action. ``mean`` of an empty tensor is NaN and would poison the loss.
        """
        if x.numel() == 0:
            return x.sum()
        l1_loss = (x.abs()).mean()
        return l1_loss

    # ------------------------------------------------------- latent geometry

    def _place_z_on_circle(self, z):
        """Map a 1-D tensor of angles to points ``(cos, sin)``, shape ``(N, 2)``."""
        x = torch.cos(z)
        y = torch.sin(z)
        return torch.cat([x.unsqueeze(1), y.unsqueeze(1)], dim=1)

    def _embed_on_circle(self, z):
        """Return ``[cos(z0), sin(z0), z[:, 1:]]`` for each row of ``z``."""
        return torch.cat([self._place_z_on_circle(z[:, 0]), z[:, 1:]], dim=-1)

    def _one_hot(self, actions):
        """Return a ``(N, num_actions)`` float one-hot matrix for integer ``actions``."""
        flat = actions.reshape(-1)
        onehot = torch.zeros(flat.shape[0], self.num_actions, device=self.device)
        onehot[torch.arange(flat.shape[0], device=self.device), flat] = 1
        return onehot

    def _forward_latents(self, states, next_states, onehot_actions):
        """Encode a batch and predict the next latents.

        Returns:
            ``(d, s, next_s, x_hat, next_x)``, where:

            * ``d`` is the predicted displacement;
            * ``s`` / ``next_s`` are the spatial coordinates ``z[:, 1:]`` of
              the current / next encodings;
            * ``x_hat`` / ``next_x`` are the circle-embedded predicted / true
              next latents.
        """
        z = self.encoder(states)
        next_z = self.encoder(next_states)
        # The transition sees the raw z: Transition concatenates (copies) it
        # before the in-place wrapping below.
        d = self.transition(z, onehot_actions)

        two_pi = 2 * np.pi
        z[:, 0] = torch.remainder(z[:, 0], two_pi)
        next_z[:, 0] = torch.remainder(next_z[:, 0], two_pi)
        z_hat = z + d
        z_hat[:, 0] = torch.remainder(z_hat[:, 0], two_pi)

        return (
            d,
            z[:, 1:],
            next_z[:, 1:],
            self._embed_on_circle(z_hat),
            self._embed_on_circle(next_z),
        )

    # ---------------------------------------------------------------- logging

    def _log(self, data):
        """Log ``data`` to the wandb run at the current update step, if a run is set."""
        if self.wandb_run is not None:
            self.wandb_run.log(data=data, step=self.grad_update_count)

    def _current_lr(self):
        """Learning rate of the first parameter group."""
        if self.lr_scheduler is not None:
            return self.lr_scheduler.get_last_lr()[0]
        return self.optimizer.param_groups[0]["lr"]

    def _log_param_histograms(self, prefix, module):
        """Log gradient/weight histograms and gradient L2 norms of ``module``'s parameters."""
        for name, param in module.named_parameters():
            if param.grad is None:
                continue
            self._log(
                {
                    f"grad/{prefix}_{name}": wandb.Histogram(param.grad.cpu().numpy()),
                    f"weight/{prefix}_{name}": wandb.Histogram(
                        param.detach().cpu().numpy()
                    ),
                    f"grad_norm/{prefix}_{name}": param.grad.norm(2).item(),
                }
            )

    # --------------------------------------------------------------- training

    def learn(self):
        """Perform one optimisation step on a mini-batch from the train split.

        The loss is the sum of:

        * InfoNCE between the predicted and true next embeddings
          (temperature ``temp``);
        * identity penalties on the coordinates each action should not change;
        * a hypercube penalty (L-inf norm of the spatial latent above 1);
        * a hinge penalty (L2 step between consecutive spatial latents
          above 0.1).

        Gradients are clipped per network to ``max_norm``.

        Logging schedule:

        * every 100 updates: loss terms and learning rate are printed and
          logged;
        * every update (unless ``debug``): gradient norms are logged;
        * every 1000 updates: weight/gradient histograms (unless ``debug``)
          and retrieval metrics on both splits are logged.
        """
        # evaluate() switches the networks to eval mode; switch back here.
        self._train_mode()

        states, actions, _rewards, next_states, _dones = self.replays_train.sample()
        states = states.to(self.device)
        actions = actions.to(self.device)
        next_states = next_states.to(self.device)
        onehot_actions = self._one_hot(actions)

        ### FORWARD PASS ###
        d, s, next_s, x_hat, next_x = self._forward_latents(
            states, next_states, onehot_actions
        )

        ### CONTRASTIVE REPRESENTATION LEARNING ###
        transition_loss = self._compute_infonce_loss(x_hat, next_x, temp=self.temp)
        spatial_hypercube_loss = self._compute_hypercube_loss(
            s, r=1.0
        ) + self._compute_hinge_loss(s, next_s, margin=0.1)

        ### ACTION PRIORS ###
        # Penalise predicted displacement on coordinates an action should not move.
        flat_actions = actions.reshape(-1)
        identity_loss = 0
        for action_id, coords in self._STATIC_COORDS.items():
            d_action = d[flat_actions == action_id]
            identity_loss = identity_loss + self._compute_identity_loss(
                d_action[:, coords]
            )

        loss = transition_loss + identity_loss + spatial_hypercube_loss

        if self.grad_update_count % 100 == 0:
            last_lr = self._current_lr()
            print(
                f"step {self.grad_update_count} - "
                f"transition_loss: {transition_loss.item():.4f} - "
                f"identity_loss: {identity_loss.item():.4f} - "
                f"spatial_hypercube_loss: {spatial_hypercube_loss.item():.4f} - "
                f"lr: {last_lr:.6f}"
            )
            self._log(
                {
                    "loss/transition_loss": transition_loss.item(),
                    "loss/identity_loss": identity_loss.item(),
                    "loss/spatial_hypercube_loss": spatial_hypercube_loss.item(),
                    "lr": last_lr,
                }
            )

        # Backward pass
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (separately for each network)
        total_norm_encoder = nn.utils.clip_grad_norm_(
            self.encoder.parameters(), max_norm=self.max_norm
        )
        total_norm_transition = nn.utils.clip_grad_norm_(
            self.transition.parameters(), max_norm=self.max_norm
        )
        if not self.debug:
            self._log(
                {
                    "total_norm/encoder": total_norm_encoder.item(),
                    "total_norm/transition": total_norm_transition.item(),
                }
            )

        # Gradients update
        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        if self.grad_update_count % 1000 == 0:
            if not self.debug:
                self._log_param_histograms("encoder", self.encoder)
                self._log_param_histograms("transition", self.transition)
            # Retrieval metrics on both splits (every 1000 updates).
            self.compute_hk_test()
            self.compute_hk_train()

        self.grad_update_count += 1

    def train(self, global_step):
        """Train for ``global_step`` updates with cosine learning-rate annealing.

        Creates a ``CosineAnnealingLR`` schedule spanning the whole run
        (``eta_min=1e-7``). :meth:`evaluate` is called after every update whose
        index is a multiple of 1000, starting with update 0.
        """
        self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=global_step, eta_min=1e-7
        )

        for step in range(global_step):
            self.learn()

            if step % 1000 == 0:
                self.evaluate(epoch=step)

    # -------------------------------------------------------------- metrics

    @torch.inference_mode()
    def _retrieval_metrics(self, memory):
        """Compute mean reciprocal rank and hits@{1,3,5} over ``memory``.

        Within each batch, the predicted next embedding of sample ``i`` is
        compared (by L2 distance) with the true next embeddings of every
        sample in the batch. The rank of the true match is the number of
        candidates that are strictly closer. Batches come from an unshuffled
        DataLoader of size ``batch_size``, so the last batch may offer fewer
        candidates.

        Returns:
            ``(mrr, (hits_at_1, hits_at_3, hits_at_5))`` as Python floats,
            averaged over all experiences. All values are zero when ``memory``
            is empty.
        """
        size = len(memory)
        if size == 0:
            return 0.0, (0.0, 0.0, 0.0)

        hits = {1: 0, 3: 0, 5: 0}
        rr_sum = 0.0
        data_loader = torch.utils.data.DataLoader(
            memory, batch_size=self.batch_size, shuffle=False
        )

        for batch in data_loader:
            states, actions, _rewards, next_states, _dones, _hashes = batch
            states = states.to(self.device).float()
            next_states = next_states.to(self.device).float()
            onehot_actions = self._one_hot(actions.to(self.device))

            _, _, _, x_hat, next_x = self._forward_latents(
                states, next_states, onehot_actions
            )

            dist = torch.cdist(x_hat, next_x, p=2)
            # Number of candidates strictly closer than the true next state.
            rank = (dist < torch.diag(dist).unsqueeze(-1)).sum(dim=1)

            for k in hits:
                hits[k] += (rank < k).sum().item()
            rr_sum += torch.reciprocal(rank.float() + 1).sum().item()

        return rr_sum / size, (hits[1] / size, hits[3] / size, hits[5] / size)

    @torch.inference_mode()
    def compute_hk_test(self):
        """Log and return retrieval metrics (MRR, hits@1/3/5) on the test split."""
        rr, (hits_at_1, hits_at_3, hits_at_5) = self._retrieval_metrics(
            self.replays_test.memory
        )
        self._log(
            {
                "metrics/hits_at_1": hits_at_1,
                "metrics/hits_at_3": hits_at_3,
                "metrics/hits_at_5": hits_at_5,
                "metrics/reciprocal_rank": rr,
            }
        )
        return rr, (hits_at_1, hits_at_3, hits_at_5)

    @torch.inference_mode()
    def compute_hk_train(self):
        """Log and return retrieval metrics (MRR, hits@1/3/5) on the train split."""
        rr, (hits_at_1, hits_at_3, hits_at_5) = self._retrieval_metrics(
            self.replays_train.memory
        )
        self._log(
            {
                "metrics/hits_at_1_train": hits_at_1,
                "metrics/hits_at_3_train": hits_at_3,
                "metrics/hits_at_5_train": hits_at_5,
                "metrics/reciprocal_rank_train": rr,
            }
        )
        return rr, (hits_at_1, hits_at_3, hits_at_5)

    # ----------------------------------------------------------- evaluation

    def evaluate(self, epoch):
        """Switch the networks to eval mode and let ``env`` plot the latent space."""
        self._infer_mode()
        self.env.draw_latent(model=self, epoch=epoch)

    def _train_mode(self):
        """Put both networks in training mode."""
        self.encoder.train()
        self.transition.train()

    def _infer_mode(self):
        """Put both networks in evaluation mode."""
        self.encoder.eval()
        self.transition.eval()


class VizdoomSingleRoom(gymnasium.Env):
    """Gymnasium wrapper around ``VizdoomSingleRoom-v0`` with a goal-reaching reward.

    Observations are processed as follows:

    * the ``screen`` buffer is resized to 64x64 (``INTER_AREA``);
    * values are scaled to [0, 1];
    * the result is transposed to channels-first and returned as ``float32``,
      matching ``observation_space`` (3, 64, 64).

    The agent position is read from the ``gamevariables`` buffer.

    Rewards and episode ends:

    * reward is 10 when the normalised distance to the goal (180, 180) is
      below 0.1, and the negative normalised distance otherwise;
    * ``done`` is True at the goal;
    * ``truncated`` is True after more than 2500 steps.

    Discrete actions 0-3 map to ViZDoom's ``{"binary", "continuous"}``
    action dict:

    * 0: no-op;
    * 1: continuous command -36;
    * 2: continuous command +36;
    * 3: binary button 1.
    """

    # Action id -> (binary button value, continuous command) sent to ViZDoom.
    _ACTION_TABLE = {0: (0, 0.0), 1: (0, -36.0), 2: (0, 36.0), 3: (1, 0.0)}
    # Goal position, in the same coordinates as the ``gamevariables`` buffer.
    _GOAL = np.array([180, 180])
    # Episodes are truncated once the step counter exceeds this value.
    _MAX_STEPS = 2500

    def __init__(self, render_mode="human", num_stack=4):
        self.env = gymnasium.make("VizdoomSingleRoom-v0", render_mode=render_mode)
        # NOTE(review): num_stack is stored but no frame stacking is done here.
        self.num_stack = num_stack
        self.env.observation_space = self.env.observation_space.spaces["screen"]
        self.observation_space = gymnasium.spaces.Box(
            low=0, high=1, shape=(3, 64, 64), dtype=np.float32
        )
        self.action_space = self.env.action_space
        self.curr_step = 0
        self.curr_pos = np.array([0, 0])
        # Distance normaliser: diagonal from (-224, -224) to (224, 224)
        # (presumably the room bounds -- confirm against the scenario).
        self.max_dist = np.linalg.norm(np.array([-224, -224]) - np.array([224, 224]))

    # ------------------------------------------------------------ gym API

    def step(self, action, savefig=False):
        """Apply discrete ``action`` (0-3) and return ``(obs, reward, done, truncated, info)``.

        ``reward`` and ``done`` come from :meth:`reward` / :meth:`is_done`.
        ``truncated`` is True once more than 2500 steps were taken. The
        wrapped env's own reward and termination flags are discarded.
        ``savefig`` is accepted but unused.

        Raises:
            ValueError: if ``action`` is not one of 0, 1, 2, 3.
        """
        self.curr_step += 1

        if isinstance(action, np.ndarray):
            action = int(action.item())

        if action not in [0, 1, 2, 3]:
            raise ValueError(f"Invalid action: {action}")
        binary, continuous = self._ACTION_TABLE[int(action)]
        env_action = {
            "binary": binary,
            "continuous": np.array([continuous], dtype=np.float32),
        }

        obs, _reward, _done, _truncated, info = self.env.step(env_action)
        obs = self._process_obs(obs)

        reward = self.reward()
        done = self.is_done()
        truncated = self.curr_step > self._MAX_STEPS

        return obs, reward, done, truncated, info

    def get_num_actions(self):
        """Number of discrete actions accepted by :meth:`step`."""
        return 4

    def reset(self, **kwargs):
        """Reset the wrapped env (forwarding ``kwargs``) and return ``(obs, info)``."""
        self.curr_step = 0
        obs, info = self.env.reset(**kwargs)
        return self._process_obs(obs), info

    def reward(self):
        """10.0 within 0.1 normalised distance of the goal, otherwise minus the distance."""
        dist = self._goal_distance()
        return 10.0 if dist < 0.1 else -dist

    def is_done(self):
        """True when the agent is within 0.1 normalised distance of the goal."""
        return bool(self._goal_distance() < 0.1)

    def render(self, mode="human"):
        """Render the wrapped env.

        Gymnasium's ``render()`` takes no arguments; the mode is fixed by
        ``render_mode`` at construction. ``mode`` is kept for backward
        compatibility and ignored.
        """
        return self.env.render()

    def close(self):
        """Close the wrapped ViZDoom env."""
        self.env.close()

    def seed(self, seed=None):
        """Seed the wrapped env.

        Gymnasium removed ``Env.seed``; seeding goes through
        ``reset(seed=...)``, so this also resets the episode.
        """
        self.reset(seed=seed)

    # ------------------------------------------------------------ helpers

    def _process_obs(self, obs):
        """Record the agent position and return the preprocessed screen."""
        self.curr_pos = np.array(obs["gamevariables"])
        screen = cv2.resize(obs["screen"], (64, 64), interpolation=cv2.INTER_AREA)
        # Scale to [0, 1] (computed in float64), move channels first, then
        # cast to float32 to match ``observation_space``.
        return (np.asarray(screen) / 255.0).transpose(2, 0, 1).astype(np.float32)

    def _goal_distance(self):
        """Euclidean distance from ``curr_pos`` to the goal, divided by ``max_dist``."""
        return np.linalg.norm(self._GOAL - self.curr_pos) / self.max_dist

    # ------------------------------------------------------- visualisation

    @torch.inference_mode()
    def draw_latent(self, model, epoch, saveimg=False):
        """Roll out a fixed action script and log latent-space plots to wandb.

        The script is 30 x action 3, then 30 x action 0, then 20 x action 2,
        starting from a fresh reset. ``model.encoder`` is applied to the 80
        states each action was taken from, and the angle ``z[:, 0]`` is
        wrapped to [0, 2*pi).

        Two figures are logged at ``model.grad_update_count`` when
        ``model.wandb_run`` is set:

        * ``plot/rotation``: the angles of states 59..79, with arrows to the
          transition's prediction before (solid) and after (dashed) wrapping;
        * ``plot/spatial``: ``(z[:, 1], z[:, 2])`` for all 80 states,
          coloured by time step.

        ``epoch`` is accepted but unused. ``saveimg`` is forwarded to
        :meth:`step` as ``savefig``.
        """
        actions = [3] * 30 + [0] * 30 + [2] * (2 * 10)
        state, _ = self.reset()
        states = [state]
        for action in actions:
            state = self.step(action, savefig=saveimg)[0]
            states.append(state)
        # Drop the final state so states[k] is the observation action k was taken from.
        states = torch.tensor(np.stack(states[:-1])).float().to(model.device)

        zs = model.encoder(states)
        zs[:, 0] = torch.remainder(zs[:, 0], 2 * math.pi)

        self._plot_rotation(model, zs, actions)
        self._plot_spatial(model, zs)

    def _plot_rotation(self, model, zs, actions, start=59):
        """Plot the wrapped angles ``zs[start:, 0]`` and arrows to their predicted successors."""
        arrow_kw = dict(
            arrowstyle="Simple, tail_width=0.1, head_width=3, head_length=5",
            color="k",
        )
        fig, ax = plt.subplots(1, 1, figsize=(9, 9))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_ylim((-0.5, 0.5))
        ax.axhline(y=0, color="black", linestyle="-")
        ax.axvline(x=0, color="black", linestyle="-")
        ax.axvline(x=2 * math.pi, color="black", linestyle="--")
        ax.text(2 * math.pi - 0.4, 0.3, "x=2*pi", rotation=90, fontsize=9)

        cur_zs = zs[start:]
        cur_actions = actions[start:]
        num_state = cur_zs.shape[0]
        onehot_actions = torch.zeros(
            (num_state, self.get_num_actions()), device=model.device
        )
        onehot_actions[np.arange(num_state), cur_actions] = 1

        d = model.transition(cur_zs, onehot_actions)
        unwrapped = cur_zs[:, 0] + d[:, 0]
        wrapped = torch.remainder(unwrapped, 2 * math.pi)

        cmap = plt.get_cmap("coolwarm")
        norm = plt.Normalize(vmin=0, vmax=num_state)
        for k in range(num_state):
            x = cur_zs[k, 0].cpu().numpy()
            ax.plot(x, 0, "o", color=cmap(norm(k)))
            ax.add_patch(
                patches.FancyArrowPatch(
                    (x, 0),
                    (unwrapped[k].cpu().numpy(), 0),
                    connectionstyle="arc3, rad=-.5",
                    **arrow_kw,
                )
            )
            ax.add_patch(
                patches.FancyArrowPatch(
                    (x, 0),
                    (wrapped[k].cpu().numpy(), 0),
                    connectionstyle="arc3, rad=-.5",
                    alpha=0.2,
                    linestyle="--",
                    **arrow_kw,
                )
            )

        self._add_colorbar(fig, ax, cmap, norm)
        self._log_figure(model, "plot/rotation", fig)

    def _plot_spatial(self, model, zs):
        """Scatter the spatial latents ``(z[:, 1], z[:, 2])`` coloured by time step."""
        fig, ax = plt.subplots(1, 1, figsize=(9, 9))
        cmap = plt.get_cmap("coolwarm")
        norm = plt.Normalize(vmin=0, vmax=zs.shape[0])
        points = zs[:, 1:3].cpu().numpy()
        for i in range(points.shape[0]):
            ax.plot(points[[i], 0], points[[i], 1], "o", color=cmap(norm(i)))

        self._add_colorbar(fig, ax, cmap, norm)
        self._log_figure(model, "plot/spatial", fig)

    @staticmethod
    def _add_colorbar(fig, ax, cmap, norm):
        """Attach a colorbar for ``cmap``/``norm`` to ``ax``."""
        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        fig.colorbar(mappable, ax=ax)

    @staticmethod
    def _log_figure(model, key, fig):
        """Log ``fig`` under ``key`` to ``model.wandb_run`` (if set), then close it."""
        if model.wandb_run is not None:
            model.wandb_run.log(
                {key: wandb.Image(fig)}, step=model.grad_update_count
            )
        plt.close(fig)


def main(
    global_step=50000,
    seed=27,
    batch_size=250,
    latent_dim=3,
    buffer_size=1000000,
    global_lr=1e-4,
    max_norm=0.5,
    device="cuda",
    temp=0.5,
    split_ratio=0.07,
    rb_path=None,
):
    """Train one ``MBAgent`` run on the ViZDoom single-room dataset, logging to wandb.

    Steps:

    * seed all RNGs;
    * start a wandb run named ``prior_seed_{seed}_{split_ratio}``;
    * build the environment (used for latent visualisations) and the agent;
    * train for ``global_step`` updates.

    The environment is always closed and the wandb run always finished. If
    anything raises, the run is finished with a non-zero exit code so it is
    marked as failed.
    """
    set_seed(seed)
    run = wandb.init(
        project="vizdoom_generalization",
        # We pass a run name (otherwise it'll be randomly assigned).
        name=f"prior_seed_{seed}_{split_ratio}",
        # Track hyperparameters and run metadata
        config={
            "encoder_lr": global_lr,
            "architecture": "CNN",
            "global step": global_step,
            "seed": seed,
            "prior": "test_old_code",
            "split_ratio": split_ratio,
            "temp": temp,
            "batch_size": batch_size,
            # Remaining hyperparameters, so every run is fully described.
            "latent_dim": latent_dim,
            "buffer_size": buffer_size,
            "max_norm": max_norm,
            "device": device,
            "rb_path": rb_path,
        },
    )

    env = None
    exit_code = 1  # reset to 0 only if training completes
    try:
        env = VizdoomSingleRoom(render_mode=None)

        agent = MBAgent(
            env=env,
            latent_dim=latent_dim,
            batch_size=batch_size,
            buffer_size=buffer_size,
            global_lr=global_lr,
            max_norm=max_norm,
            device=device,
            seed=seed,
            wandb_run=run,
            temp=temp,
            split_ratio=split_ratio,
            rb_path=rb_path,
        )

        agent.train(global_step=global_step)
        exit_code = 0
    finally:
        # Release the ViZDoom game and close the wandb run even on failure.
        if env is not None:
            env.close()
        run.finish(exit_code=exit_code)


if __name__ == "__main__":
    REPLAY_BUFFER_PATH = "/rl-transitions/src/envs/vizdoom/dataset/replay_buffer.pickle"

    # Sweep: every seed is run for each train-split ratio (ratio is the outer loop).
    sweep = [
        (ratio, run_seed)
        for ratio in (0.8, 0.1)
        for run_seed in (5, 55, 27, 33, 96)
    ]
    for ratio, run_seed in sweep:
        main(seed=run_seed, split_ratio=ratio, rb_path=REPLAY_BUFFER_PATH)
