import sys

sys.dont_write_bytecode = True

import os
import math
import datetime
import numpy as np
import random
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d


# Enable LaTeX rendering
# plt.rcParams['text.usetex'] = True
# plt.rcParams['font.family'] = 'serif'

# Global matplotlib style: white backgrounds, thin black axes borders, light-gray
# grid, thin lines, and bold 14 pt text (16 pt figure titles). Each key is set
# exactly once.
plt.rcParams.update(
    {
        # Backgrounds and borders
        "figure.facecolor": "#ffffff",  # White figure background
        "axes.facecolor": "#ffffff",  # White axes background
        "axes.edgecolor": "black",  # Black border around axes
        "axes.linewidth": 0.5,
        # Grid
        "grid.color": "#d3d3d3",  # Light gray grid lines
        "grid.alpha": 0.8,
        # Plot lines
        "lines.linewidth": 0.5,
        # Global font size (14 pt) and weight
        "font.size": 14,
        "font.weight": "bold",
        # Axes titles, labels, and ticks
        "axes.titlesize": 14,
        "axes.titleweight": "bold",
        "axes.labelsize": 14,
        "axes.labelweight": "bold",
        "xtick.labelsize": 14,
        "ytick.labelsize": 14,
        # Figure titles (suptitle)
        "figure.titlesize": 16,
        "figure.titleweight": "bold",
    }
)

# Add tight spacing for ICML figures
# plt.rcParams['figure.autolayout'] = True  # Avoid clipping and wasted space


def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's default generator with ``seed``.

    Python's built-in ``random`` module is deliberately not reseeded here.
    """
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)


def count_learnable_parameters(model: nn.Module) -> int:
    """Return how many scalar parameters of ``model`` require gradients.

    Args:
        model (nn.Module): Module whose ``parameters()`` are inspected.

    Returns:
        int: Sum of ``numel()`` over every parameter with ``requires_grad=True``.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


class Encoder(nn.Module):
    """Map an observation vector to a latent code with a small tanh MLP.

    Architecture: ``fc1 -> tanh -> fc2 -> tanh -> fc3``, with no activation on
    the output. Every linear layer starts from Xavier-normal weights and zero
    biases.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Creation order fixes registration order (fc1, fc2, fc3), which is
        # also the order the initializer below visits (and draws RNG for) them.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.g = nn.Tanh()

        linear_layers = [mod for mod in self.modules() if isinstance(mod, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return the latent code for ``x`` (last dimension ``input_dim``)."""
        hidden = self.g(self.fc1(x))
        hidden = self.g(self.fc2(hidden))
        return self.fc3(hidden)


class TransitionModel(nn.Module):
    """Predict a latent update from a latent code and an action vector.

    The inputs are concatenated along the last dimension and passed through
    ``fc1 -> tanh -> fc2`` (no output activation). ``input_dim`` must equal the
    sum of the two inputs' last dimensions. Linear layers start from
    Xavier-normal weights and zero biases.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.g = nn.Tanh()

        for layer in filter(lambda mod: isinstance(mod, nn.Linear), self.modules()):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, a):
        """Return ``fc2(tanh(fc1([x, a])))``."""
        joint = torch.cat((x, a), dim=-1)
        return self.fc2(self.g(self.fc1(joint)))


class WorldModel:
    """Latent world model trained contrastively on transitions sampled from ``env``.

    An ``Encoder`` maps a one-hot observation (length ``2 * env.size``) to a latent
    ``z`` of size ``latent_dim``. A ``TransitionModel`` maps ``(z, one-hot action)``
    to a displacement ``d``, so that ``z + d`` should match the encoding of the next
    observation. Both networks are optimized with RMSprop (see :meth:`learn`).

    ``env`` must provide ``size``, ``num_actions``, ``sample_batch_transition``,
    ``_convert_transitions_to_tensors``, ``enabled_transition``,
    ``disabled_transition`` and ``draw_learned_latent``.
    """

    # Column order of the metrics table kept in ``self.df`` and written to log.csv.
    _LOG_COLUMNS = (
        "Step",
        "Lr",
        "Loss",
        "TransitionLoss",
        "EntropyLoss",
        "IdentityLoss",
        "HypercubeLoss",
        "Train_MMR",
        "Train_H@1",
        "Train_H@5",
        "Eval_MMR",
        "Eval_H@1",
        "Eval_H@5",
    )

    def __init__(
        self,
        env,
        latent_dim=2,
        hidden_dim=32,
        encoder_lr=3e-4,
        transition_lr=1e-4,
        momentum=0.9,
        weight_decay=0,
    ):
        """Build the encoder/transition networks and their RMSprop optimizers.

        Args:
            env: Environment providing the interface listed in the class docstring.
            latent_dim: Size of the latent code ``z``.
            hidden_dim: Hidden width of both networks.
            encoder_lr: Initial learning rate of the encoder optimizer.
            transition_lr: Initial learning rate of the transition optimizer.
            momentum: RMSprop momentum (both optimizers).
            weight_decay: RMSprop weight decay (both optimizers).
        """
        self.env = env
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.encoder_lr = encoder_lr
        self.transition_lr = transition_lr

        # Logging state. ``train`` fills these in when given a ``log_dir``;
        # initializing them here lets ``learn``/``evaluate`` run without one.
        self.log_dir = None
        self.log_file = None
        self.writer = None
        self.df = pd.DataFrame(columns=list(self._LOG_COLUMNS))

        self.encoder = Encoder(
            input_dim=2 * env.size, hidden_dim=hidden_dim, output_dim=latent_dim
        )
        self.transition_model = TransitionModel(
            input_dim=latent_dim + env.num_actions,
            hidden_dim=hidden_dim,
            output_dim=latent_dim,
        )
        self.encoder_optimizer = optim.RMSprop(
            self.encoder.parameters(),
            lr=encoder_lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        self.transition_optimizer = optim.RMSprop(
            self.transition_model.parameters(),
            lr=transition_lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        # Print the number of learnable parameters
        encoder_params = count_learnable_parameters(self.encoder)
        transition_params = count_learnable_parameters(self.transition_model)
        print("Number of learnable parameters in Encoder: ", encoder_params)
        print(
            "Number of learnable parameters in TransitionModel: ", transition_params
        )
        print(
            "Total number of learnable parameters: ",
            encoder_params + transition_params,
        )

    def _place_z_on_circle(self, z, r=1.0):
        """Map angles ``z`` to ``(r*cos z, r*sin z)``.

        A 1-D ``z`` of length B yields a (B, 2) tensor. Not called by ``learn``.
        """
        x = torch.cos(z) * r
        y = torch.sin(z) * r
        return torch.cat([x.unsqueeze(1), y.unsqueeze(1)], dim=1)

    def _compute_prediction_loss(self, state, state_hat):
        """Mean L2 distance (over the last dim) between ``state`` and ``state_hat``."""
        return torch.linalg.norm(state - state_hat, ord=2, dim=-1).mean()

    def _compute_entropy_loss(self, state_x, state_y, C=5.0):
        """Mean of ``exp(-C * ||x - y||)``.

        The value shrinks as paired points move apart, so minimizing it spreads
        the latents out.
        """
        l2_distance = torch.linalg.norm(state_x - state_y, ord=2, dim=-1)
        return (-C * l2_distance).exp().mean()

    def _compute_identity_loss(self, x):
        """Mean absolute value of ``x`` (an L1 penalty)."""
        return x.abs().mean()

    def _compute_hinge_loss(self, state_x, state_y, margin=1.0):
        """Mean of ``max(0, ||x - y|| - margin)``, with each term clipped to 100."""
        l2_distance = torch.linalg.norm(state_x - state_y, ord=2, dim=-1)
        hinge_loss = (
            torch.max(torch.zeros_like(l2_distance), l2_distance - margin)
            .clamp(0, 100)
            .mean()
        )
        return hinge_loss

    def _compute_hypercube_loss(self, state, r=1.0):
        """Mean of ``max(0, ||state||_inf - r)``, with each term clipped to 100.

        This penalizes points outside the hypercube ``[-r, r]^d``.
        """
        linf_distance = torch.linalg.norm(state.abs(), ord=float("inf"), dim=-1)
        hypercube_loss = (
            torch.max(torch.zeros_like(linf_distance), linf_distance - r)
            .clamp(0, 100)
            .mean()
        )
        return hypercube_loss

    def _compute_infonce_loss(self, pred, gt, temp=1.0):
        """InfoNCE loss that uses negative L2 distance as the similarity.

        ``pred`` and ``gt`` are (batch, dim). Row ``i`` treats ``gt[i]`` as the
        positive for ``pred[i]`` and every other ``gt[j]`` as a negative:
        ``-mean_i log softmax_j(-||pred_i - gt_j|| / temp)[i]``.

        The loss is computed with ``log_softmax`` rather than ``exp`` then ``log``
        of a ratio. Large distances therefore no longer underflow to 0 and
        produce inf/nan.
        """
        logits = -torch.cdist(pred, gt, p=2) / temp
        return -torch.log_softmax(logits, dim=1).diagonal().mean()

    @staticmethod
    def _grad_norm(module):
        """Global L2 norm of the gradients currently stored on ``module``'s parameters.

        Returns 0.0 when no parameter has a gradient.
        """
        squared = sum(
            param.grad.norm().item() ** 2
            for param in module.parameters()
            if param.grad is not None
        )
        return squared**0.5

    @staticmethod
    def _decayed_lr(base_lr, fraction_done, min_lr=1e-7):
        """Linearly interpolate from ``base_lr`` (fraction 0) to ``min_lr`` (fraction 1)."""
        return base_lr * (1.0 - fraction_done * (1.0 - min_lr / base_lr))

    def learn(self, cur_step, max_steps, batch_size, log_every=100):
        """Run one optimization step on a freshly sampled batch, then decay the LRs.

        Args:
            cur_step: Zero-based index of this step. It drives LR decay and logging.
            max_steps: Total number of steps. The LRs reach ``min_lr`` at
                ``max_steps - 1`` and stay there afterwards.
            batch_size: Number of transitions sampled from ``env``.
            log_every: Metrics are computed and logged when
                ``cur_step % log_every == 0``.
        """
        state, action, next_state = self.env.sample_batch_transition(
            batch_size=batch_size
        )
        action_onehot = torch.zeros(batch_size, self.env.num_actions)
        action_onehot.scatter_(1, action, 1)

        # Forward pass: the next latent is predicted as a residual update of z.
        z = self.encoder(state)
        next_z = self.encoder(next_state)
        z_hat = z + self.transition_model(z, action_onehot)

        transition_loss = self._compute_infonce_loss(z_hat, next_z)
        # Keep predictions inside [-1, 1]^d, and penalize consecutive latents
        # that are more than 0.3 apart.
        hypercube_loss = self._compute_hypercube_loss(
            z_hat, r=1.0
        ) + self._compute_hinge_loss(z, next_z, margin=0.3)
        # Push randomly paired latents apart.
        randperm_z = torch.randperm(z.size(0))
        entropy_loss = self._compute_entropy_loss(z, z[randperm_z, :], C=5.0)
        # Identity loss is disabled; it is kept at 0 so the logged columns stay stable.
        identity_loss = torch.tensor(0.0)

        loss = transition_loss + entropy_loss + identity_loss + hypercube_loss

        # Backward pass
        self.encoder_optimizer.zero_grad()
        self.transition_optimizer.zero_grad()
        loss.backward()

        # Gradient norms are measured before the optimizer steps.
        encoder_grad_norm = self._grad_norm(self.encoder)
        transition_grad_norm = self._grad_norm(self.transition_model)

        self.encoder_optimizer.step()
        self.transition_optimizer.step()

        # Linear LR decay. The denominator is clamped so max_steps == 1 does not
        # divide by zero, and the fraction is clamped so steps past the end cannot
        # drive the LR negative.
        fraction_done = min(cur_step / max(max_steps - 1, 1), 1.0)
        encoder_lr = self._decayed_lr(self.encoder_lr, fraction_done)
        for param_group in self.encoder_optimizer.param_groups:
            param_group["lr"] = encoder_lr
        # ``new_lr`` (the transition LR) is the value that gets printed and logged.
        new_lr = self._decayed_lr(self.transition_lr, fraction_done)
        for param_group in self.transition_optimizer.param_groups:
            param_group["lr"] = new_lr

        if cur_step % log_every != 0:
            return

        # NOTE: "MMR" in the tags and columns is the mean reciprocal rank (MRR).
        # The names are kept for compatibility with existing logs.
        eval_hits_at_k, eval_mmr = self._measure_performance(
            self.env.disabled_transition
        )
        train_hits_at_k, train_mmr = self._measure_performance(
            self.env.enabled_transition
        )

        if self.writer is not None:
            scalars = {
                "Loss/Total": loss.item(),
                "Loss/Transition": transition_loss.item(),
                "Loss/Entropy": entropy_loss.item(),
                "Loss/Identity": identity_loss.item(),
                "Loss/Hypercube": hypercube_loss.item(),
                "GradNorm/Encoder": encoder_grad_norm,
                "GradNorm/Transition": transition_grad_norm,
                "Eval/MMR": eval_mmr,
                "Train/MMR": train_mmr,
            }
            for tag, value in scalars.items():
                self.writer.add_scalar(tag, value, cur_step)

        # Scientific notation: with ``.4f`` small LRs would print as 0.0000.
        print(
            f"Step: {cur_step}, Lr: {new_lr:.3e}, Loss: {loss.item():.4f}, Train/MMR {train_mmr:.4f}, Train/H@1: {train_hits_at_k[1]:.4f}, Train/H@5: {train_hits_at_k[5]:.4f}, Eval/MMR: {eval_mmr:.4f}, Eval/H@1: {eval_hits_at_k[1]:.4f}, Eval/H@5: {eval_hits_at_k[5]:.4f}"
        )
        # Flush so the line reaches output.log immediately when stdout is redirected.
        sys.stdout.flush()

        new_row = pd.DataFrame(
            [
                {
                    "Step": cur_step,
                    "Lr": new_lr,
                    "Loss": loss.item(),
                    "TransitionLoss": transition_loss.item(),
                    "EntropyLoss": entropy_loss.item(),
                    "IdentityLoss": identity_loss.item(),
                    "HypercubeLoss": hypercube_loss.item(),
                    "Train_MMR": train_mmr,
                    "Train_H@1": train_hits_at_k[1],
                    "Train_H@5": train_hits_at_k[5],
                    "Eval_MMR": eval_mmr,
                    "Eval_H@1": eval_hits_at_k[1],
                    "Eval_H@5": eval_hits_at_k[5],
                }
            ]
        )
        # Concatenating onto an empty frame triggers a pandas FutureWarning, so
        # the first row replaces the empty frame instead.
        if self.df is None or self.df.empty:
            self.df = new_row
        else:
            self.df = pd.concat([self.df, new_row], ignore_index=True)

    def _l2_distance(self, x, y):
        """L2 distance between ``x`` and ``y`` along the last dimension."""
        return torch.linalg.norm(x - y, ord=2, dim=-1)

    @torch.inference_mode()
    def _measure_performance(self, transitions):
        """Rank-based evaluation of one-step latent predictions.

        For each transition ``i``, the prediction ``z_i + d_i`` is compared by L2
        distance with the encoded next states of all transitions in the set. The
        rank of the true next state is then recorded.
        Ref: https://arxiv.org/pdf/1911.12247 (Appendix C).

        Args:
            transitions: List of ``(state, action, next_state)`` tuples, accepted
                by ``env._convert_transitions_to_tensors``.

        Returns:
            A ``(hits_at_k, mmr)`` tuple:
            - ``hits_at_k`` is a defaultdict mapping k in {1, 5, 10} to the
              fraction of transitions whose ground truth ranks in the top k.
            - ``mmr`` is the mean reciprocal rank.
            Every value is NaN when ``transitions`` is empty.
        """
        ks = (1, 5, 10)
        hits_at_k = defaultdict(float)
        if len(transitions) == 0:
            # No transitions, so the metrics are undefined (no crash, no divide by zero).
            for k in ks:
                hits_at_k[k] = float("nan")
            return hits_at_k, float("nan")

        batch_size = len(transitions)
        state, action, next_state = self.env._convert_transitions_to_tensors(
            transitions
        )
        action_onehot = torch.zeros(batch_size, self.env.num_actions)
        action_onehot.scatter_(1, action, 1)

        z = self.encoder(state)
        next_z = self.encoder(next_state)
        z_hat = z + self.transition_model(z, action_onehot)

        dist_matrix = torch.cdist(z_hat, next_z, p=2)
        # Prepend the positive (diagonal) distance as column 0. With a stable sort
        # it wins every tie, including the tie with its own copy at column i + 1.
        dist_matrix_augmented = torch.cat(
            [torch.diag(dist_matrix).unsqueeze(-1), dist_matrix], dim=1
        )
        sorted_indices = torch.argsort(dist_matrix_augmented, dim=1, stable=True)
        # Each row is a permutation, so column 0 appears exactly once per row.
        # nonzero() lists rows in order, which gives the zero-based rank per row.
        ranks = (sorted_indices == 0).nonzero()[:, 1].double()

        for k in ks:
            hits_at_k[k] = (ranks < k).double().mean().item()
        mmr = (1.0 / (ranks + 1.0)).mean().item()
        return hits_at_k, mmr

    def evaluate(self, cur_step):
        """Plot the learned latent space through ``env.draw_learned_latent``.

        This is skipped when no log directory is set, since the figure is saved
        under ``self.log_dir``.
        """
        if self.log_dir is None:
            return
        self.env.draw_learned_latent(self, cur_step)

    def train(
        self,
        max_steps=100000,
        eval_every=1000,
        batch_size=32,
        log_dir=None,
        log_every=100,
    ):
        """Run ``max_steps`` calls to :meth:`learn`, evaluating every ``eval_every`` steps.

        Args:
            max_steps: Number of optimization steps.
            eval_every: Draw the latent-space plot every this many steps. This
                needs a log directory.
            batch_size: Transitions per step.
            log_dir: If given, a timestamped subdirectory is created under it.
                The subdirectory holds ``output.log`` (stdout/stderr are redirected
                there while training runs), TensorBoard events, the latent plots
                and ``log.csv``.
            log_every: Metric logging period passed to :meth:`learn`.
        """
        import traceback

        self.df = pd.DataFrame(columns=list(self._LOG_COLUMNS))
        saved_stdout, saved_stderr = sys.stdout, sys.stderr

        if log_dir is not None:
            # Create the timestamped log folder if it does not exist
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            self.log_dir = os.path.join(log_dir, timestamp)
            os.makedirs(self.log_dir, exist_ok=True)
            self.log_file = open(os.path.join(self.log_dir, "output.log"), "w")
            sys.stdout = self.log_file
            sys.stderr = self.log_file
            self.writer = SummaryWriter(log_dir=self.log_dir)

        try:
            for step in range(max_steps):
                if step % eval_every == 0:
                    self.evaluate(step)

                self.learn(
                    step,
                    max_steps=max_steps,
                    batch_size=batch_size,
                    log_every=log_every,
                )
        except BaseException:
            if log_dir is not None:
                # stderr still points at output.log here, so record the failure there.
                traceback.print_exc()
            raise
        finally:
            if log_dir is not None:
                # Restore the real streams before closing the file they point to.
                # Otherwise any later print would write to a closed file.
                sys.stdout, sys.stderr = saved_stdout, saved_stderr
                self.log_file.close()
                self.writer.close()

                # Save the log DataFrame (also after a failed run)
                self.df.to_csv(os.path.join(self.log_dir, "log.csv"), index=False)


class TorusEnvV2:
    """Deterministic grid world on a ``size`` x ``size`` torus.

    States are ``[x, y]`` cells. Action 0 increments ``y`` and action 1 increments
    ``x``, both wrapping modulo ``size``. A random ``pct_transitions_disabled``
    fraction of the ``2 * size**2`` transitions is held out
    (``disabled_transition``) and never sampled for training; the rest form
    ``enabled_transition``.

    Note: the constructor reseeds Python's global ``random`` module with ``seed``,
    and ``sample_transition`` draws from that same global generator.
    """

    def __init__(self, size=5, seed=42, pct_transitions_disabled=0.1):
        print("TorusEnvV2")
        self.size = size
        self.agent_state = [0, 0]  # [x, y]
        self.move_actions = [[0, 1], [1, 0]]
        self.num_actions = 2  # 0: Go right, 1: Go down
        self.seed = seed
        random.seed(self.seed)  # Fixed seed: makes the held-out split reproducible

        self.all_states = self._get_all_states()
        self.valid_transitions = self._get_all_valid_transitions()
        print("Number of valid transitions: ", len(self.valid_transitions))

        self.disabled_transition_indices = self._get_disabled_indices(
            pct_transitions_disabled=pct_transitions_disabled
        )
        self.disabled_transition = [
            self.valid_transitions[i] for i in self.disabled_transition_indices
        ]
        # Set lookup instead of scanning the disabled list for every index.
        disabled = set(self.disabled_transition_indices)
        self.enabled_transition_indices = [
            i for i in range(len(self.valid_transitions)) if i not in disabled
        ]
        self.enabled_transition = [
            self.valid_transitions[i] for i in self.enabled_transition_indices
        ]
        print("Number of disabled transitions: ", len(self.disabled_transition_indices))

    def _get_disabled_indices(self, pct_transitions_disabled=0.1):
        """Sample ``int(pct * n)`` distinct transition indices to hold out (global ``random``)."""
        num_transitions = len(self.valid_transitions)
        num_transitions_disabled = int(pct_transitions_disabled * num_transitions)
        return random.sample(range(num_transitions), num_transitions_disabled)

    def _get_all_states(self):
        """All ``[x, y]`` cells in row-major order (x outer, y inner)."""
        return [[x, y] for x in range(self.size) for y in range(self.size)]

    def _get_all_valid_transitions(self):
        """All ``(state, action, next_state)`` triples.

        Two entries are produced per state in row-major order: action 0 first,
        then action 1.
        """
        valid_transitions = []
        for x in range(self.size):
            for y in range(self.size):
                # Action 0 (go right): y wraps around
                valid_transitions.append(([x, y], 0, [x, (y + 1) % self.size]))
                # Action 1 (go down): x wraps around
                valid_transitions.append(([x, y], 1, [(x + 1) % self.size, y]))
        return valid_transitions

    def sample_transition(self, k=1):
        """Sample ``k`` enabled transitions with replacement (global ``random``)."""
        return [
            self.valid_transitions[i]
            for i in random.choices(self.enabled_transition_indices, k=k)
        ]

    def _onehot_encode(self, state):
        """Encode ``[x, y]`` as a float vector of length ``2 * size``.

        The vector has ones at index ``x`` and at index ``size + y``.
        """
        onehot_state = torch.zeros(self.size * 2)
        onehot_state[state[0]] = 1
        onehot_state[self.size + state[1]] = 1
        return onehot_state.float()

    def _convert_transitions_to_tensors(self, transitions):
        """Stack transitions into tensors.

        Returns ``(state, action, next_state)`` with shapes (B, 2*size), (B, 1)
        (long) and (B, 2*size).
        """
        state = torch.stack(
            [self._onehot_encode(transition[0]) for transition in transitions]
        )
        action = torch.tensor(
            [transition[1] for transition in transitions], dtype=torch.long
        )
        next_state = torch.stack(
            [self._onehot_encode(transition[2]) for transition in transitions]
        )
        action = action.view(-1, 1)

        return state, action, next_state

    def sample_batch_transition(self, batch_size=8):
        """Sample ``batch_size`` enabled transitions and return them as tensors."""
        sampled_transitions = self.sample_transition(k=batch_size)
        return self._convert_transitions_to_tensors(sampled_transitions)

    @torch.inference_mode()
    def draw_learned_latent(self, world_model, cur_step):
        """Plot encoded states and predicted transitions, saved to ``latent-<cur_step>.pdf``.

        The PDF goes in ``world_model.log_dir``. Each state is encoded with
        ``world_model.encoder`` and drawn as a black dot (only the first two latent
        coordinates are plotted). For each action, an arrow runs from ``z`` to
        ``z + transition_model(z, a)``: black for transitions seen in training,
        red for held-out ones.

        Nothing is drawn when ``world_model`` has no ``log_dir``. The figure is
        always closed afterwards so repeated calls do not accumulate open figures.
        """
        log_dir = getattr(world_model, "log_dir", None)
        if log_dir is None:
            return

        fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=200)
        try:
            connectionstyle = "arc3, rad=0"
            arrowstyle = "Simple, tail_width=0.1, head_width=3, head_length=5"
            kw = dict(arrowstyle=arrowstyle, color="black")
            kw_unseen = dict(arrowstyle=arrowstyle, color="red")
            ax.set_title(r"Abstract state space $Z$")

            for state in self.all_states:
                x, y = state
                z = world_model.encoder(self._onehot_encode(state).unsqueeze(0))
                z0, z1 = z[0, 0].item(), z[0, 1].item()
                ax.plot(z0, z1, "o", color="black")

                for a in range(self.num_actions):
                    next_state = state.copy()
                    if a == 0:
                        next_state[1] = (y + 1) % self.size
                    else:
                        next_state[0] = (x + 1) % self.size

                    is_unseen = (state, a, next_state) in self.disabled_transition

                    action_onehot = torch.zeros(1, self.num_actions)
                    action_onehot[0, a] = 1
                    pred_z = z + world_model.transition_model(z, action_onehot)

                    ax.add_patch(
                        patches.FancyArrowPatch(
                            (z0, z1),
                            (pred_z[0, 0].item(), pred_z[0, 1].item()),
                            connectionstyle=connectionstyle,
                            linestyle="-",
                            **(kw_unseen if is_unseen else kw),
                        )
                    )

            legend_tr_arrow = patches.FancyArrowPatch(
                (0, 0),
                (1, 0),
                color="black",
                label="Seen transition",
                arrowstyle=arrowstyle,
            )
            legend_gd_arrow = patches.FancyArrowPatch(
                (0, 0),
                (1, 0),
                color="red",
                label="Unseen transition",
                arrowstyle=arrowstyle,
            )
            ax.legend(
                handles=[legend_tr_arrow, legend_gd_arrow],
                loc="lower center",
                fontsize=16,
                bbox_to_anchor=(0.5, -0.15),  # Below the plot, outside the axes
                borderaxespad=0.0,  # Padding between the axes and legend
                ncol=2,
                frameon=False,
            )

            fig.tight_layout()
            fig.savefig(os.path.join(log_dir, f"latent-{cur_step}.pdf"), dpi=200)
        finally:
            # Without this, each evaluation leaks one 1200x1200 figure.
            plt.close(fig)


def main(
    seed=42,
    latent_dim=2,
    hidden_dim=32,
    batch_size=64,
    max_steps=50000,
    encoder_lr=1e-4,
    transition_lr=1e-4,
    momentum=0.9,
    weight_decay=0,
    pct_transitions_disabled=0.1,
    log_dir="logs/torus_v2_generalization",
    env_size=5,
    env_seed=42,
    eval_every=1000,
    log_every=100,
):
    """Build a ``TorusEnvV2``, list its held-out transitions, and train a ``WorldModel``.

    ``seed`` seeds NumPy/PyTorch (model initialization and random permutations).
    The environment is seeded separately with ``env_seed``; its default of 42 is
    the value that was previously hard-coded. ``TorusEnvV2`` seeds Python's
    ``random`` with it, which fixes both the held-out split and batch sampling, so
    changing ``seed`` alone leaves the split unchanged.

    Args:
        env_size: Side length of the torus grid.
        env_seed: Seed for the environment's ``random``-based sampling.
        eval_every: Steps between latent-space plots.
        log_every: Steps between metric logs.
        (Remaining arguments are forwarded to ``WorldModel`` / ``WorldModel.train``.)
    """
    env = TorusEnvV2(
        size=env_size,
        seed=env_seed,
        pct_transitions_disabled=pct_transitions_disabled,
    )
    disabled_indices = env.disabled_transition_indices
    print(
        f"Number of disabled transitions: {len(disabled_indices)}/{len(env.valid_transitions)}"
    )
    for i in disabled_indices:
        state, action, next_state = env.valid_transitions[i]
        print(state, action, next_state)

    set_seed(seed)
    world_model = WorldModel(
        env=env,
        latent_dim=latent_dim,
        hidden_dim=hidden_dim,
        encoder_lr=encoder_lr,
        transition_lr=transition_lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    world_model.train(
        max_steps=max_steps,
        eval_every=eval_every,
        batch_size=batch_size,
        log_dir=log_dir,
        log_every=log_every,
    )


if __name__ == "__main__":
    # python-fire exposes main()'s keyword arguments as command-line flags.
    import fire

    fire.Fire(main)
