import sys

sys.dont_write_bytecode = True

import os
import datetime
import numpy as np
import random
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d


# Optional LaTeX text rendering (left disabled):
# plt.rcParams['text.usetex'] = True
# plt.rcParams['font.family'] = 'serif'

# Global matplotlib style, applied once at import time so every figure
# produced by this module shares the same look.
plt.rcParams.update(
    {
        # Background, border and grid colors
        "figure.facecolor": "#ffffff",  # white figure background
        "axes.facecolor": "#ffffff",  # white axes background
        "axes.edgecolor": "black",  # black border around axes
        "grid.color": "#d3d3d3",  # light gray grid lines
        "grid.alpha": 0.8,
        # Stroke widths
        "lines.linewidth": 0.5,
        "axes.linewidth": 0.5,
        # Base font
        "font.size": 22,
        "font.weight": "bold",
        # Axes titles, labels and tick labels
        "axes.titlesize": 24,
        "axes.titleweight": "bold",
        "axes.labelsize": 22,
        "axes.labelweight": "bold",
        "xtick.labelsize": 22,
        "ytick.labelsize": 22,
        # Figure-level titles
        "figure.titlesize": 20,
        "figure.titleweight": "bold",
    }
)


def set_seed(seed):
    """Seed the global PyTorch and NumPy random number generators.

    The standard-library ``random`` generator is not touched by this
    function.

    Args:
        seed: Integer seed handed to both libraries.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)

class Encoder(nn.Module):
    """Three-layer tanh MLP mapping an observation vector to a latent code.

    Layout: ``input_dim -> hidden_dim -> hidden_dim -> output_dim`` with a
    tanh after the first two linear layers and no activation on the output.
    Weights use Xavier-normal initialisation; biases start at zero.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.g = nn.Tanh()

        # Re-initialise in declaration order (fc1, fc2, fc3).
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the latent code for a batch of observations ``x``."""
        hidden = self.g(self.fc1(x))
        hidden = self.g(self.fc2(hidden))
        return self.fc3(hidden)


class TransitionModel(nn.Module):
    """Two-layer tanh MLP that maps ``(latent, action)`` to a latent output.

    The latent and the action encoding are concatenated along the last
    dimension, so ``input_dim`` must equal the sum of their widths.
    Weights use Xavier-normal initialisation; biases start at zero.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.g = nn.Tanh()

        # Re-initialise in declaration order (fc1, fc2).
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, a):
        """Return the network output for latent ``x`` and action encoding ``a``."""
        joint = torch.cat((x, a), dim=-1)
        hidden = self.g(self.fc1(joint))
        return self.fc2(hidden)


class QNetwork(nn.Module):
    """Three-layer tanh MLP producing one Q-value per action.

    Layout: ``input_dim -> hidden_dim -> hidden_dim -> output_dim`` with a
    tanh after the first two linear layers. Weights use Xavier-normal
    initialisation; biases start at zero.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.g = nn.Tanh()

        # Re-initialise in declaration order (fc1, fc2, fc3).
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the Q-values for a batch of abstract states ``x``."""
        hidden = self.g(self.fc1(x))
        hidden = self.g(self.fc2(hidden))
        return self.fc3(hidden)


class WorldModel:
    def __init__(
        self,
        env,
        latent_dim=3,
        hidden_dim=32,
        encoder_lr=3e-4,
        transition_lr=1e-4,
        q_lr=1e-4,
        momentum=0.9,
        weight_decay=0,
    ):
        """Build the encoder, transition model, Q-networks and their optimizers.

        Args:
            env: Environment object; ``env.size`` and ``env.num_actions`` are
                read here to size the networks.
            latent_dim: Width of the encoder output. Other methods of this
                class index latent coordinates 0, 1 and 2 explicitly, so
                values other than 3 are presumably unsupported - TODO confirm.
            hidden_dim: Hidden width of the encoder and transition model.
            encoder_lr: Initial RMSprop learning rate of the encoder.
            transition_lr: Initial RMSprop learning rate of the transition model.
            q_lr: AdamW learning rate of the online Q-network.
            momentum: RMSprop momentum for both world-model optimizers.
            weight_decay: Weight decay applied by all three optimizers.
        """
        self.env = env
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.encoder_lr = encoder_lr
        self.transition_lr = transition_lr
        self.q_lr = q_lr

        # Observations are three concatenated one-hot groups of widths
        # size, size and 4.
        self.encoder = Encoder(
            input_dim=2 * env.size + 4, hidden_dim=hidden_dim, output_dim=latent_dim
        )
        self.transition_model = TransitionModel(
            input_dim=latent_dim + env.num_actions,
            hidden_dim=hidden_dim,
            output_dim=latent_dim,
        )
        self.encoder_optimizer = optim.RMSprop(
            self.encoder.parameters(),
            lr=encoder_lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        self.transition_optimizer = optim.RMSprop(
            self.transition_model.parameters(),
            lr=transition_lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        # The Q-networks consume the "abstract state": latent coordinate 0 is
        # replaced by its 2-D circle embedding, so the input is one wider than
        # the latent (4 for the default latent_dim=3).
        q_input_dim = latent_dim + 1
        self.q_network = QNetwork(
            input_dim=q_input_dim,
            hidden_dim=32,
            output_dim=env.num_actions,
        )
        self.q_network_target = QNetwork(
            input_dim=q_input_dim,
            hidden_dim=32,
            output_dim=env.num_actions,
        )
        # Start the target network as an exact copy of the online network.
        # Without this the bootstrapped targets come from unrelated random
        # weights, which the small soft-update rate only corrects slowly.
        self.q_network_target.load_state_dict(self.q_network.state_dict())
        self.q_optimizer = optim.AdamW(
            self.q_network.parameters(),
            lr=q_lr,
            weight_decay=weight_decay,
        )

    def _place_z_on_circle(self, z, r=1.0):
        x = torch.cos(z) * r
        y = torch.sin(z) * r
        return torch.cat([x.unsqueeze(1), y.unsqueeze(1)], dim=1)

    def _compute_prediction_loss(self, state, state_hat):
        l2_loss = torch.linalg.norm(state - state_hat, ord=2, dim=-1).mean()
        return l2_loss
        # l1_loss = (state - state_hat).abs().mean()
        # return l1_loss

    def _compute_entropy_loss(self, state_x, state_y, C=5.0):
        l2_distance = torch.linalg.norm(state_x - state_y, ord=2, dim=-1)
        entropy_loss = (-C * l2_distance).exp().mean()
        return entropy_loss

    def _compute_identity_loss(self, x):
        l1_loss = x.abs().mean()
        return l1_loss
        # l2_loss = torch.linalg.norm(x, ord=2, dim=-1).mean()
        # return l2_loss

    def _compute_hinge_loss(self, state_x, state_y, margin=1.0):
        l2_distance = torch.linalg.norm(state_x - state_y, ord=2, dim=-1)
        hinge_loss = (
            torch.max(torch.zeros_like(l2_distance), l2_distance - margin)
            .clamp(0, 100)
            .mean()
        )
        return hinge_loss

    def _compute_hypercube_loss(self, state, r=1.0):
        linf_distance = torch.linalg.norm(state.abs(), ord=float("inf"), dim=-1)
        hypercube_loss = (
            torch.max(torch.zeros_like(linf_distance), linf_distance - r)
            .clamp(0, 100)
            .mean()
        )
        return hypercube_loss

    def _compute_infonce_loss(self, pred, gt, temp=1.0):
        # InfoNCE loss
        # pred, gt shape is b x d
        # Compute a pair-wise similarity matrix based on L2 distance
        sim = (torch.cdist(pred, gt, p=1).neg() / temp).exp()
        # Compute logits by dividing each similarity score by the sum of all similarities for each batch
        # (softmax across the columns for gt against all pred)
        pos_logits = torch.diag(sim)  # Positive samples are on the diagonal
        loss = -torch.log(pos_logits / sim.sum(dim=1)).mean()
        return loss

    def learn(self, cur_step, max_steps, batch_size, log_every=100):
        """Run one optimisation step of the encoder and the transition model.

        A batch is sampled from ``self.env``; the transition model predicts a
        latent displacement ``d`` and the prediction ``z + d`` is compared
        with the encoding of the true next state. Latent coordinate 0 is
        treated as an angle (wrapped to ``[0, 2*pi)`` and embedded on the unit
        circle); the remaining coordinates are used as-is. Both learning
        rates are then decayed linearly towards ``1e-7``.

        Every ``log_every`` steps the losses, gradient norms and ranking
        metrics are written to ``self.writer``, printed, and appended to
        ``self.wm_df``; both attributes are created by ``train()``.

        Args:
            cur_step: Zero-based index of the current training step.
            max_steps: Total number of steps, used by the decay schedule.
            batch_size: Number of transitions to sample.
            log_every: Logging period in steps.
        """
        state, action, next_state, _reward, _done = self.env.sample_batch_transition(
            batch_size=batch_size
        )
        action_onehot = torch.zeros(batch_size, self.env.num_actions)
        action_onehot.scatter_(1, action, 1)

        # Forward pass. The transition model sees the raw encoder output,
        # i.e. before the angular coordinate is wrapped below.
        z = self.encoder(state)
        next_z = self.encoder(next_state)
        d = self.transition_model(z, action_onehot)

        # Wrap the angular coordinate, then build circle + plane embeddings.
        two_pi = 2 * np.pi
        z[:, 0] = torch.remainder(z[:, 0], two_pi)
        next_z[:, 0] = torch.remainder(next_z[:, 0], two_pi)
        z_hat = z + d
        z_hat[:, 0] = torch.remainder(z_hat[:, 0], two_pi)
        radius = 1.0
        c_hat = self._place_z_on_circle(z_hat[:, 0], r=radius)
        next_c = self._place_z_on_circle(next_z[:, 0], r=radius)
        s = z[:, 1:]
        s_hat = z_hat[:, 1:]
        next_s = next_z[:, 1:]
        x_hat = torch.cat([c_hat, s_hat], dim=-1)
        next_x = torch.cat([next_c, next_s], dim=-1)

        act0_indices = torch.where(action == 0)[0]
        act1_indices = torch.where(action == 1)[0]

        # Per-action losses. For each action the contrastive transition loss
        # is computed within that action's sub-batch, and the displacement is
        # pushed to zero on the coordinates the action should not change:
        #   action 0 (turn right) -> coordinates 1, 2
        #   action 1 (go forward) -> coordinate 0
        # A sub-batch can be empty when every sampled transition uses the
        # other action; it is skipped because a mean over zero rows is NaN
        # and would poison the parameters through the optimizer step.
        transition_loss = torch.tensor(0.0)
        identity_loss = torch.tensor(0.0)
        for indices, zero_coords in ((act0_indices, [1, 2]), (act1_indices, [0])):
            if indices.numel() == 0:
                continue
            transition_loss = transition_loss + self._compute_infonce_loss(
                x_hat[indices], next_x[indices]
            )
            identity_loss = identity_loss + self._compute_identity_loss(
                d[indices][:, zero_coords]
            )

        # Currently disabled; kept as a zero so the logged columns stay stable.
        entropy_loss = torch.tensor(0.0)

        # Keep the planar coordinates inside the unit hypercube and keep
        # consecutive states within a small distance of each other.
        hypercube_loss = self._compute_hypercube_loss(
            s, r=1.0
        ) + self._compute_hinge_loss(s, next_s, margin=0.1)

        # Total loss
        loss = transition_loss + entropy_loss + identity_loss + hypercube_loss

        # Backward pass
        self.encoder_optimizer.zero_grad()
        self.transition_optimizer.zero_grad()
        loss.backward()

        # Global gradient norms (reported only, no clipping is applied).
        encoder_grad_norm = 0
        for param in self.encoder.parameters():
            if param.grad is not None:
                encoder_grad_norm += param.grad.norm().item() ** 2
        encoder_grad_norm = encoder_grad_norm**0.5

        transition_grad_norm = 0
        for param in self.transition_model.parameters():
            if param.grad is not None:
                transition_grad_norm += param.grad.norm().item() ** 2
        transition_grad_norm = transition_grad_norm**0.5

        self.encoder_optimizer.step()
        self.transition_optimizer.step()

        # Linear learning-rate decay from the initial value (step 0) down to
        # min_lr (last step). max(..., 1) avoids a division by zero when
        # max_steps == 1.
        min_lr = 1e-7
        fraction_done = cur_step / max(max_steps - 1, 1)

        encoder_scale = 1.0 - fraction_done * (1.0 - min_lr / self.encoder_lr)
        for param_group in self.encoder_optimizer.param_groups:
            param_group["lr"] = self.encoder_lr * encoder_scale

        transition_scale = 1.0 - fraction_done * (1.0 - min_lr / self.transition_lr)
        # The logged "Lr" is the transition model's learning rate.
        new_lr = self.transition_lr * transition_scale
        for param_group in self.transition_optimizer.param_groups:
            param_group["lr"] = new_lr

        # Logging
        if cur_step % log_every == 0:
            self.writer.add_scalar("Loss/Total", loss.item(), cur_step)
            self.writer.add_scalar("Loss/Transition", transition_loss.item(), cur_step)
            self.writer.add_scalar("Loss/Entropy", entropy_loss.item(), cur_step)
            self.writer.add_scalar("Loss/Identity", identity_loss.item(), cur_step)
            self.writer.add_scalar("Loss/Hypercube", hypercube_loss.item(), cur_step)
            self.writer.add_scalar("GradNorm/Encoder", encoder_grad_norm, cur_step)
            self.writer.add_scalar(
                "GradNorm/Transition", transition_grad_norm, cur_step
            )

            # Ranking metrics on the training transitions ("MMR" in the keys
            # below is the mean reciprocal rank).
            train_hits_at_k, train_mmr = self._measure_performance(
                self.env.enabled_transition
            )
            self.writer.add_scalar("Train/MMR", train_mmr, cur_step)

            # Scientific notation: the rate moves between ~1e-4 and 1e-7,
            # which fixed-point formatting would show as 0.0001 / 0.0000.
            print(
                f"Step: {cur_step}, Lr: {new_lr:.3e}, Loss: {loss.item():.4f}, Train/MMR {train_mmr:.4f}, Train/H@1: {train_hits_at_k[1]:.4f}, Train/H@5: {train_hits_at_k[5]:.4f}"
            )

            # Flush the stdout buffer so that the logs are written to the file
            sys.stdout.flush()

            # Log to a DataFrame
            new_row = pd.DataFrame(
                [
                    {
                        "Step": cur_step,
                        "Lr": new_lr,
                        "Loss": loss.item(),
                        "TransitionLoss": transition_loss.item(),
                        "EntropyLoss": entropy_loss.item(),
                        "IdentityLoss": identity_loss.item(),
                        "HypercubeLoss": hypercube_loss.item(),
                        "Train_MMR": train_mmr,
                        "Train_H@1": train_hits_at_k[1],
                        "Train_H@5": train_hits_at_k[5],
                    }
                ]
            )

            self.wm_df = pd.concat([self.wm_df, new_row], ignore_index=True)

    def _l2_distance(self, x, y):
        return torch.linalg.norm(x - y, ord=2, dim=-1)

    @torch.inference_mode()
    def _measure_performance(self, transitions):
        """Compute ranking metrics of the latent predictions on ``transitions``.

        For every transition the predicted next embedding is compared (L2)
        with the embeddings of all true next states in the batch; the rank of
        the matching one yields Hits@K for K in {1, 5, 10} and the mean
        reciprocal rank.
        Ref: https://arxiv.org/pdf/1911.12247 (Appendix C)

        Args:
            transitions: Sequence of transition tuples accepted by
                ``self.env._convert_transitions_to_tensors``.

        Returns:
            ``(hits_at_k, mmr)`` where ``hits_at_k`` maps 1, 5 and 10 to the
            fraction of hits and ``mmr`` is the mean reciprocal rank. All
            values are 0 when ``transitions`` is empty.
        """
        cutoffs = [1, 5, 10]
        hits_at_k = defaultdict(float)
        for k in cutoffs:
            hits_at_k[k] = 0
        mmr = 0

        num_transitions = len(transitions)
        if num_transitions == 0:
            return hits_at_k, mmr

        state, action, next_state, _, _ = self.env._convert_transitions_to_tensors(
            transitions
        )
        action_onehot = torch.zeros(num_transitions, self.env.num_actions)
        action_onehot.scatter_(1, action, 1)

        # Same pipeline as training: predict from the raw latent, then wrap
        # the angular coordinate and embed it on the unit circle.
        z = self.encoder(state)
        next_z = self.encoder(next_state)
        d = self.transition_model(z, action_onehot)
        z[:, 0] = torch.remainder(z[:, 0], 2 * np.pi)
        next_z[:, 0] = torch.remainder(next_z[:, 0], 2 * np.pi)
        z_hat = z + d
        z_hat[:, 0] = torch.remainder(z_hat[:, 0], 2 * np.pi)
        radius = 1.0
        c_hat = self._place_z_on_circle(z_hat[:, 0], r=radius)
        next_c = self._place_z_on_circle(next_z[:, 0], r=radius)
        x_hat = torch.cat([c_hat, z_hat[:, 1:]], dim=-1)
        next_x = torch.cat([next_c, next_z[:, 1:]], dim=-1)

        # Prepend each row's ground-truth distance as column 0. With a stable
        # sort, column 0 wins every tie, so its position in the sorted row is
        # the (zero-based) rank of the ground truth.
        dist_matrix = torch.cdist(x_hat, next_x, p=2)
        truth_column = torch.diag(dist_matrix).unsqueeze(-1)
        augmented = torch.cat([truth_column, dist_matrix], dim=1)
        order = torch.argsort(augmented, dim=1, stable=True)
        # Each row of `order` is a permutation, so argmin locates value 0.
        ranks = order.argmin(dim=1).tolist()

        for rank in ranks:
            for k in cutoffs:
                if rank < k:
                    hits_at_k[k] += 1
            mmr += 1.0 / (rank + 1)

        for k in cutoffs:
            hits_at_k[k] /= num_transitions
        mmr /= num_transitions
        return hits_at_k, mmr

    def evaluate(self, cur_step):
        """Delegate to ``self.env.draw_learned_latent`` for step ``cur_step``.

        The world model itself is passed along as the first argument.
        """
        self.env.draw_learned_latent(self, cur_step)

    def train(
        self,
        max_steps=100000,
        eval_every=1000,
        batch_size=32,
        log_dir=None,
        log_every=100,
    ):
        """Train the encoder and transition model for ``max_steps`` steps.

        A timestamped run directory is created under ``log_dir``. From then
        on ``sys.stdout`` and ``sys.stderr`` are redirected to ``output.log``
        inside it, and a TensorBoard writer plus a metrics DataFrame are set
        up for ``learn()``. The streams are not restored by this method.

        Args:
            max_steps: Number of optimisation steps.
            eval_every: Period (in steps) of ``evaluate()`` calls.
            batch_size: Batch size forwarded to ``learn()``.
            log_dir: Parent directory of the run directory. Required.
            log_every: Logging period forwarded to ``learn()``.

        Raises:
            ValueError: If ``log_dir`` is ``None``. Evaluation and logging
                both write into the run directory, so training cannot run
                without one.
        """
        if log_dir is None:
            raise ValueError(
                "log_dir must be provided: training logs and figures are "
                "written to a run directory created under it"
            )

        # Create the timestamped run directory and redirect output into it.
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        self.log_dir = os.path.join(log_dir, timestamp)
        os.makedirs(self.log_dir, exist_ok=True)
        self.log_file = open(os.path.join(self.log_dir, "output.log"), "w")
        sys.stdout = self.log_file
        sys.stderr = self.log_file
        self.writer = SummaryWriter(log_dir=self.log_dir)

        # Create a DataFrame to log the metrics
        self.wm_df = pd.DataFrame(
            columns=[
                "Step",
                "Lr",
                "Loss",
                "TransitionLoss",
                "EntropyLoss",
                "IdentityLoss",
                "HypercubeLoss",
                "Train_MMR",
                "Train_H@1",
                "Train_H@5",
            ]
        )

        csv_path = os.path.join(self.log_dir, "wm_log.csv")
        for step in range(max_steps):
            if step % eval_every == 0:
                self.evaluate(step)

            self.learn(
                step, max_steps=max_steps, batch_size=batch_size, log_every=log_every
            )

            # learn() only appends a row on logging steps, so the CSV is
            # rewritten only then instead of on every single step.
            if step % log_every == 0:
                self.wm_df.to_csv(csv_path, index=False)

    def q_learning(self, step, batch_size=32, gamma=0.9):
        """Perform one DQN update on top of the frozen latent representation.

        A batch is sampled from all valid transitions, encoded without
        gradients, and used for a one-step TD update of ``self.q_network``
        against ``self.q_network_target``. The target network is then moved
        towards the online network by a soft update, a greedy evaluation
        rollout is run, and the result is printed and appended to
        ``self.dqn_df``.

        Args:
            step: Index of this update, used for logging.
            batch_size: Number of transitions to sample.
            gamma: Discount factor.
        """
        state, action, next_state, reward, done = self.env.sample_batch_transition(
            batch_size=batch_size, enabled_invalid=True
        )

        def _abstract(observation):
            # Encode, wrap the angular coordinate, and replace it by its
            # unit-circle embedding.
            latent = self.encoder(observation)
            latent[:, 0] = torch.remainder(latent[:, 0], 2 * np.pi)
            circle = self._place_z_on_circle(latent[:, 0], r=1.0)
            return torch.cat([circle, latent[:, 1:]], dim=-1)

        with torch.no_grad():
            abs_state = _abstract(state)
            abs_next_state = _abstract(next_state)

        # One-step TD target from the target network.
        next_q_values = self.q_network_target(abs_next_state).detach()
        q_target_next = next_q_values.max(dim=1, keepdim=True).values
        q_target = reward + (1 - done) * gamma * q_target_next
        q_expected = self.q_network(abs_state).gather(1, action)

        loss = nn.functional.mse_loss(q_expected, q_target)
        self.q_optimizer.zero_grad()
        loss.backward()
        self.q_optimizer.step()

        # Soft update the target network
        tau = 1e-3
        for target_param, param in zip(
            self.q_network_target.parameters(), self.q_network.parameters()
        ):
            blended = (1 - tau) * target_param.data + tau * param.data
            target_param.data.copy_(blended)

        # Evaluate the Q-learning performance
        eval_reward = self.evaluate_q_learing(step)

        print(
            f"Q-learning step: {step}, Loss: {loss.item():.4f}, Eval Reward: {eval_reward:.4f}"
        )

        # Log to a DataFrame
        record = {
            "Step": step,
            "Loss": loss.item(),
            "Eval_Reward": eval_reward,
        }
        self.dqn_df = pd.concat(
            [self.dqn_df, pd.DataFrame([record])], ignore_index=True
        )

    @torch.inference_mode()
    def _q_learning_act(self, state):
        """Return the greedy action (an ``int``) for a raw environment state.

        The state is one-hot encoded by the environment, mapped to its
        abstract representation (circle embedding of the angular latent plus
        the remaining latent coordinates) and scored by ``self.q_network``.
        """
        # inference_mode already disables gradient tracking for this call.
        observation = self.env._onehot_encode(state).unsqueeze(0)
        latent = self.encoder(observation)
        latent[:, 0] = torch.remainder(latent[:, 0], 2 * np.pi)
        circle = self._place_z_on_circle(latent[:, 0], r=1.0)
        abs_state = torch.cat([circle, latent[:, 1:]], dim=-1)
        q_values = self.q_network(abs_state)
        return q_values.argmax(dim=1).item()

    def evaluate_q_learing(self, step, max_len=100):
        """Roll out the greedy policy from ``[0, 0, 0]`` and return the total reward.

        The grid dynamics are simulated locally: action 0 turns right, any
        other action moves one cell forward when that cell is inside the
        grid. Each step costs -1, except the move that enters the goal cell,
        which gives 0 and ends the episode.

        Args:
            step: Unused; accepted for call-site compatibility.
            max_len: Maximum number of steps in the rollout.

        Returns:
            The sum of per-step rewards.
        """
        position = [0, 0, 0]  # Start from the top-left corner
        episode_return = 0
        for _ in range(max_len):
            action = self._q_learning_act(position)
            successor = position.copy()
            step_reward, reached_goal = -1, False

            if action == 0:
                successor[2] = (position[2] + 1) % 4
            else:
                env = self.env
                offset = env.move_actions[position[2]]
                new_x = position[0] + offset[0]
                new_y = position[1] + offset[1]
                inside = 0 <= new_x < env.size and 0 <= new_y < env.size
                if inside:
                    if [new_x, new_y] == env.goal_state:
                        step_reward, reached_goal = 0, True
                    successor[0], successor[1] = new_x, new_y

            episode_return += step_reward
            position = successor
            if reached_goal:
                break
        return episode_return

    def train_dqn(self, log_dir, max_steps=100, batch_size=32, gamma=0.97):
        """Train the Q-network on the frozen encoder and save ``dqn_log.csv``.

        The CSV is written to the run directory created by ``train()`` when
        there is one, otherwise to ``log_dir``. Afterwards the log file and
        TensorBoard writer opened by ``train()`` (if any) are closed and the
        standard streams are pointed back at the interpreter defaults.

        Args:
            log_dir: Fallback output directory, used only when ``train()``
                has not set ``self.log_dir``.
            max_steps: Number of Q-learning updates.
            batch_size: Batch size of each update.
            gamma: Discount factor.
        """
        self.encoder.eval()
        self.dqn_df = pd.DataFrame(
            columns=[
                "Step",
                "Loss",
                "Eval_Reward",
            ]
        )
        for step in range(max_steps):
            self.q_learning(step=step, batch_size=batch_size, gamma=gamma)

        # Prefer the run directory from train(); fall back to the argument so
        # this method also works when called on its own.
        out_dir = getattr(self, "log_dir", None)
        if out_dir is None:
            out_dir = log_dir
            os.makedirs(out_dir, exist_ok=True)
        self.dqn_df.to_csv(os.path.join(out_dir, "dqn_log.csv"), index=False)

        log_file = getattr(self, "log_file", None)
        if log_file is not None:
            # Undo the redirection first: leaving sys.stdout/sys.stderr
            # pointing at a closed file makes every later print() raise.
            if sys.stdout is log_file:
                sys.stdout = sys.__stdout__
            if sys.stderr is log_file:
                sys.stderr = sys.__stderr__
            log_file.close()
            self.log_file = None

        writer = getattr(self, "writer", None)
        if writer is not None:
            writer.close()


class MiniGridEnvV2:
    def __init__(self, size=5, seed=42, pct_transitions_disabled=0.1):
        """Build a ``size`` x ``size`` grid world and split its transitions.

        All valid transitions are enumerated up front; a random fraction of
        them is marked as disabled (excluded from default sampling) and the
        rest as enabled.

        Note: this seeds the global ``random`` module with ``seed``.

        Args:
            size: Side length of the square grid.
            seed: Seed for the ``random`` module, which selects the disabled
                transitions.
            pct_transitions_disabled: Fraction (0..1) of transitions to disable.
        """
        print("MiniGridEnvV2")
        self.size = size
        self.agent_state = [0, 0, 0]  # [x, y, direction]
        # Position offset for each direction index (see self.directions).
        self.move_actions = [[-1, 0], [0, 1], [1, 0], [0, -1]]
        self.num_actions = 2  # 0: Turn right, 1: Move forward
        self.goal_state = [size - 1, size - 1]  # [x, y]
        self.directions = ["N", "E", "S", "W"]
        self.seed = seed
        random.seed(self.seed)  # Fixed seed

        self.all_states = self._get_all_states()
        self.valid_transitions = self._get_all_valid_transitions()
        print("Number of valid transitions: ", len(self.valid_transitions))

        self.disabled_transition_indices = self._get_disabled_indices(
            pct_transitions_disabled=pct_transitions_disabled
        )
        self.disabled_transition = [
            self.valid_transitions[i] for i in self.disabled_transition_indices
        ]
        # Set lookup keeps the complement computation linear instead of
        # scanning the disabled list once per transition.
        disabled_lookup = set(self.disabled_transition_indices)
        self.enabled_transition_indices = [
            i for i in range(len(self.valid_transitions)) if i not in disabled_lookup
        ]
        self.enabled_transition = [
            self.valid_transitions[i] for i in self.enabled_transition_indices
        ]
        print("Number of disabled transitions: ", len(self.disabled_transition_indices))

    def _get_disabled_indices(self, pct_transitions_disabled=0.1):
        # Sample a random subset of transitions to disable

        num_transitions = len(self.valid_transitions)
        num_transitions_disabled = int(pct_transitions_disabled * num_transitions)
        disabled_indices = random.sample(
            range(num_transitions), num_transitions_disabled
        )
        return disabled_indices

    def _get_all_states(self):
        """Enumerate every ``[x, y, direction]`` state.

        States are ordered by ``x``, then ``y``, then ``direction`` (0..3).
        """
        side = range(self.size)
        return [
            [x, y, direction]
            for x in side
            for y in side
            for direction in range(4)
        ]

    def _get_all_valid_transitions(self):
        """Enumerate all ``(state, action, next_state, reward, done)`` tuples.

        For each cell, in ``x`` then ``y`` order, the four "turn right"
        transitions (action 0) are listed first, followed by the four "move
        forward" transitions (action 1). A forward move that would leave the
        grid keeps the agent in place. The reward is 0 with ``done=True``
        when the resulting position is the goal cell, otherwise -1 with
        ``done=False``.
        """

        def outcome(px, py):
            # Reward/termination depend only on the cell the agent ends up in.
            if [px, py] == self.goal_state:
                return 0, True
            return -1, False

        transitions = []
        for x in range(self.size):
            for y in range(self.size):
                here_reward, here_done = outcome(x, y)

                # Turn right: the position is unchanged, the direction advances.
                for direction in range(4):
                    transitions.append(
                        (
                            [x, y, direction],
                            0,
                            [x, y, (direction + 1) % 4],
                            here_reward,
                            here_done,
                        )
                    )

                # Move forward along the current direction.
                for direction in range(4):
                    offset = self.move_actions[direction]
                    new_x, new_y = x + offset[0], y + offset[1]
                    in_bounds = 0 <= new_x < self.size and 0 <= new_y < self.size

                    if in_bounds:
                        reward, done = outcome(new_x, new_y)
                        target = [new_x, new_y, direction]
                    else:
                        # Bumping into the wall leaves the agent where it is.
                        reward, done = here_reward, here_done
                        target = [x, y, direction]

                    transitions.append(([x, y, direction], 1, target, reward, done))

        return transitions

    def sample_transition(self, k=1, enabled_invalid=False):
        # Randomly sample k transitions with replacement from the non-disabled transitions
        if enabled_invalid:
            return random.choices(self.valid_transitions, k=k)
        else:
            return [
                self.valid_transitions[i]
                for i in random.choices(self.enabled_transition_indices, k=k)
            ]

    def _onehot_encode(self, state):
        # One-hot encode the state with vector size = 2 * size + 4
        onehot_state = torch.zeros(self.size * 2 + 4)
        onehot_state[state[0]] = 1
        onehot_state[self.size + state[1]] = 1
        onehot_state[self.size * 2 + state[2]] = 1
        return onehot_state.float()

    def _convert_transitions_to_tensors(self, transitions):
        # Convert the transitions to tensors
        state = torch.stack(
            [self._onehot_encode(transition[0]) for transition in transitions]
        )
        action = torch.tensor(
            [transition[1] for transition in transitions], dtype=torch.long
        )
        next_state = torch.stack(
            [self._onehot_encode(transition[2]) for transition in transitions]
        )
        reward = torch.tensor(
            [transition[3] for transition in transitions], dtype=torch.float
        )
        done = torch.tensor(
            [transition[4] for transition in transitions], dtype=torch.float
        )
        action = action.view(-1, 1)
        reward = reward.view(-1, 1)
        done = done.view(-1, 1)

        return state, action, next_state, reward, done

    def sample_batch_transition(self, batch_size=8, enabled_invalid=False):
        # Randomly sample a batch of transitions
        sampled_transitions = self.sample_transition(
            k=batch_size, enabled_invalid=enabled_invalid
        )
        return self._convert_transitions_to_tensors(sampled_transitions)

    @torch.inference_mode()
    def draw_learned_latent(self, world_model, cur_step):
        """Save two PDF figures of the learned latent space.

        ``c_<cur_step>.pdf`` has one panel per grid cell showing latent
        coordinate 0 (wrapped to ``[0, 2*pi)``) of the four headings in that
        cell, with arrows for the predicted "turn right" displacement: a
        solid arrow to the unwrapped target and a faint one to the wrapped
        target. ``s_<cur_step>.pdf`` shows latent coordinates 1 and 2 of
        every state, with arrows for the predicted "move forward"
        displacement of in-bounds moves. Arrows of disabled transitions are
        drawn in red. Both files go to ``world_model.log_dir``.

        Assumes the latent has at least three coordinates.

        Args:
            world_model: Object exposing ``encoder``, ``transition_model``
                and ``log_dir``.
            cur_step: Training step, used in the output file names.
        """
        colors = [
            "tab:blue",
            "tab:orange",
            "tab:green",
            "tab:red",
            "tab:purple",
            "tab:brown",
            "tab:pink",
            "tab:gray",
            "tab:olive",
            "tab:cyan",
        ]
        two_pi = 2 * np.pi

        # Disabled transitions are stored as 5-tuples
        # (state, action, next_state, reward, done). Key them on the first
        # three fields so lookups do not depend on reward/done bookkeeping.
        disabled_keys = {
            (tuple(t[0]), t[1], tuple(t[2])) for t in self.disabled_transition
        }

        # Encode every state once and predict both actions' displacements
        # from the raw (unwrapped) latent, exactly as training does.
        num_states = len(self.all_states)
        onehots = torch.stack([self._onehot_encode(st) for st in self.all_states])
        z_all = world_model.encoder(onehots)

        turn_onehot = torch.zeros(num_states, self.num_actions)
        turn_onehot[:, 0] = 1
        forward_onehot = torch.zeros(num_states, self.num_actions)
        forward_onehot[:, 1] = 1
        d_turn = world_model.transition_model(z_all, turn_onehot)
        d_forward = world_model.transition_model(z_all, forward_onehot)

        angle_t = torch.remainder(z_all[:, 0], two_pi)
        turn_target_t = angle_t + d_turn[:, 0]
        angle = angle_t.detach().numpy()
        turn_target = turn_target_t.detach().numpy()
        turn_target_wrapped = torch.remainder(turn_target_t, two_pi).detach().numpy()
        z_np = z_all.detach().numpy()
        s = z_np[:, 1:]
        d_forward_np = d_forward.detach().numpy()

        # ---- Figure 1: angular coordinate, one panel per grid cell ----
        style = "Simple, tail_width=0.1, head_width=3, head_length=5"
        kw = dict(arrowstyle=style, color="k", linewidth=2)
        disabled_kw = dict(arrowstyle=style, color="r", linewidth=2)
        # squeeze=False keeps `ax` two-dimensional even for a 1x1 grid.
        fig, ax = plt.subplots(
            self.size, self.size, figsize=(15, 15), dpi=200, squeeze=False
        )

        for i in range(self.size):
            for j in range(self.size):
                panel = ax[i, j]
                panel.set_title(f"Agent at position ({i+1}, {j+1})", fontsize=14)
                panel.set_xticks([])
                panel.set_yticks([])
                panel.set_ylim((-0.5, 0.5))
                panel.axhline(y=0, color="black", linestyle="-")
                panel.axvline(x=0, color="black", linestyle="-")
                panel.axvline(x=2 * np.pi, color="black", linestyle="--")
                panel.text(
                    2 * np.pi - 0.4, 0.1, r"$x=2*\pi$", rotation=90, fontsize=12
                )

        for idx, (x, y, direction) in enumerate(self.all_states):
            panel = ax[x, y]
            color = colors[(x + y) % len(colors)]
            panel.plot(angle[idx], 0, "o", color=color)
            panel.annotate(
                text=f"({self.directions[direction]})",
                xy=(angle[idx], 0.05),
                fontsize=12,
            )

            key = ((x, y, direction), 0, (x, y, (direction + 1) % 4))
            arrow_kw = disabled_kw if key in disabled_keys else kw
            # Solid arrow: target before wrapping.
            panel.add_patch(
                patches.FancyArrowPatch(
                    (angle[idx], 0),
                    (turn_target[idx], 0),
                    connectionstyle="arc3, rad=-.5",
                    **arrow_kw,
                )
            )
            # Faint arrow: the same target wrapped back into [0, 2*pi).
            panel.add_patch(
                patches.FancyArrowPatch(
                    (angle[idx], 0),
                    (turn_target_wrapped[idx], 0),
                    connectionstyle="arc3, rad=-.5",
                    alpha=0.2,
                    **arrow_kw,
                )
            )

        fig.tight_layout()
        fig.savefig(os.path.join(world_model.log_dir, f"c_{cur_step}.pdf"), dpi=200)
        # Close explicitly: this runs at every evaluation and open figures
        # would otherwise accumulate for the whole training run.
        plt.close(fig)

        # ---- Figure 2: planar coordinates (latent dims 1 and 2) ----
        fig, ax = plt.subplots(1, 1, figsize=(8, 8), dpi=200)
        annotated_cells = set()
        for idx, (x, y, direction) in enumerate(self.all_states):
            color = colors[(x + y) % len(colors)]
            ax.plot(s[idx, 0], s[idx, 1], marker="o", color=color)

            # Label each grid cell once, at its first state.
            if (x, y) not in annotated_cells:
                ax.annotate(
                    text=f"({x}, {y})",
                    xy=(s[idx, 0], s[idx, 1]),
                )
                annotated_cells.add((x, y))

            # If going forward is valid, draw the arrow
            next_x = x + self.move_actions[direction][0]
            next_y = y + self.move_actions[direction][1]
            if 0 <= next_x < self.size and 0 <= next_y < self.size:
                key = ((x, y, direction), 1, (next_x, next_y, direction))
                ax.arrow(
                    z_np[idx, 1],
                    z_np[idx, 2],
                    d_forward_np[idx, 1],
                    d_forward_np[idx, 2],
                    alpha=1,
                    head_width=0.03,
                    head_length=0.05,
                    linewidth=2,
                    color="red" if key in disabled_keys else "black",
                )

        fig.tight_layout()
        fig.savefig(os.path.join(world_model.log_dir, f"s_{cur_step}.pdf"), dpi=200)
        plt.close(fig)


def main(
    seed=42,
    size=5,
    hidden_dim=32,
    batch_size=64,
    max_steps=50000,
    encoder_lr=1e-4,
    transition_lr=1e-4,
    momentum=0.9,
    weight_decay=0,
    pct_transitions_disabled=0.1,
    dqn_steps=10000,
    gamma=0.97,
    q_lr=1e-4,
    log_dir="logs/minigrid_v2_generalization",
):
    """Train the world model on the grid world, then a DQN on its latent space.

    The environment is always built with seed 42, so the set of disabled
    transitions does not depend on ``seed``; ``seed`` is applied to NumPy
    and PyTorch only, after the environment has been created.
    NOTE(review): confirm that keeping the split fixed across seeds is
    intended.
    """
    env = MiniGridEnvV2(
        size=size, seed=42, pct_transitions_disabled=pct_transitions_disabled
    )

    # Report which transitions are held out of default sampling.
    held_out = env.disabled_transition_indices
    print(
        f"Number of disabled transitions: {len(held_out)}/{len(env.valid_transitions)}"
    )
    for index in held_out:
        print(*env.valid_transitions[index])

    set_seed(seed)
    world_model = WorldModel(
        env=env,
        latent_dim=3,
        hidden_dim=hidden_dim,
        encoder_lr=encoder_lr,
        transition_lr=transition_lr,
        momentum=momentum,
        weight_decay=weight_decay,
        q_lr=q_lr,
    )

    world_model.train(
        max_steps=max_steps,
        eval_every=1000,
        batch_size=batch_size,
        log_dir=log_dir,
        log_every=100,
    )
    world_model.train_dqn(
        log_dir=log_dir,
        max_steps=dqn_steps,
        batch_size=batch_size,
        gamma=gamma,
    )


if __name__ == "__main__":
    # Imported here so that `fire` is only required when the file is run as
    # a script; Fire exposes main()'s keyword arguments on the command line.
    from fire import Fire

    Fire(main)
