import sys

sys.dont_write_bytecode = True

import os
import datetime
import numpy as np
import random
import pandas as pd
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

# Global matplotlib styling applied to every figure this script produces.
# LaTeX text rendering is available but currently disabled:
# plt.rcParams['text.usetex'] = True
# plt.rcParams['font.family'] = 'serif'
plt.rcParams.update(
    {
        # White figure/axes backgrounds framed by thin black axes borders.
        "figure.facecolor": "#ffffff",
        "axes.facecolor": "#ffffff",
        "axes.edgecolor": "black",
        "axes.linewidth": 0.5,
        # Light-gray, mostly opaque grid lines and thin plot lines.
        "grid.color": "#d3d3d3",
        "grid.alpha": 0.8,
        "lines.linewidth": 0.5,
        # Bold text throughout, with larger axes and figure titles.
        "font.size": 18,
        "font.weight": "bold",
        "axes.titlesize": 22,
        "axes.titleweight": "bold",
        "axes.labelsize": 18,
        "axes.labelweight": "bold",
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
        "figure.titlesize": 20,
        "figure.titleweight": "bold",
    }
)


def set_seed(seed):
    """Seed the global NumPy and PyTorch random number generators.

    Python's built-in ``random`` module is not seeded here. Note that
    ``MiniGridEnvV2`` seeds it itself from its own ``seed`` argument.

    Args:
        seed: Integer seed passed to ``np.random.seed`` and ``torch.manual_seed``.
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)


class Encoder(nn.Module):
    """Three-layer tanh MLP that maps an input feature vector to a latent vector.

    Args:
        input_dim: Length of the input feature vector.
        hidden_dim: Width of both hidden layers.
        output_dim: Length of the produced latent vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        # Activation shared by both hidden layers (nn.Mish was also tried).
        self.g = nn.Tanh()

        # Xavier-normal weights and zero biases, applied in registration order
        # (fc1, fc2, fc3) so the RNG is consumed in the same sequence.
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        hidden = self.g(self.fc2(self.g(self.fc1(x))))
        return self.fc3(hidden)


class TransitionModel(nn.Module):
    """Two-layer tanh MLP applied to the concatenation ``[x, a]``.

    ``WorldModel`` adds the output to the current latent to predict the next one.

    Args:
        input_dim: Combined length of ``x`` and ``a`` along the last dimension.
        hidden_dim: Width of the hidden layer.
        output_dim: Length of the produced vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Hidden activation (nn.Mish was also tried).
        self.g = nn.Tanh()

        # Xavier-normal weights and zero biases, in registration order.
        for layer in (self.fc1, self.fc2):
            nn.init.xavier_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x, a):
        joint = torch.cat((x, a), dim=-1)
        return self.fc2(self.g(self.fc1(joint)))


class WorldModel:
    """Latent world model: an encoder plus a residual latent transition model.

    Encoded states (vectors of length ``2 * env.size + 4``) are mapped to
    ``latent_dim``-dimensional latents by :class:`Encoder`. :class:`TransitionModel`
    predicts the latent change caused by an action, so the predicted next latent
    is ``z + T(z, a)``. Both networks are trained jointly with two RMSprop
    optimisers, the PRAE contrastive loss and a linearly decaying learning rate.

    Args:
        env: Environment object. It must provide ``size``, ``num_actions``,
            ``sample_batch_transition``, ``_convert_transitions_to_tensors``,
            ``enabled_transition``, ``disabled_transition`` and
            ``draw_learned_latent``.
        latent_dim: Dimension of the latent space.
        hidden_dim: Hidden width of both networks.
        encoder_lr: Initial learning rate of the encoder optimiser.
        transition_lr: Initial learning rate of the transition optimiser.
        momentum: RMSprop momentum, shared by both optimisers.
        weight_decay: RMSprop weight decay, shared by both optimisers.

    Note:
        Metric names containing "MMR" (``Train_MMR``, ``Eval/MMR``, ...) hold the
        mean reciprocal rank (usually abbreviated MRR). The names are kept for
        compatibility with existing logs.
    """

    # Columns of the metrics table ``self.df`` (saved as ``log.csv`` by ``train``).
    _LOG_COLUMNS = (
        "Step",
        "Lr",
        "Loss",
        "TransitionLoss",
        "EntropyLoss",
        "IdentityLoss",
        "HypercubeLoss",
        "Train_MMR",
        "Train_H@1",
        "Train_H@5",
        "Eval_MMR",
        "Eval_H@1",
        "Eval_H@5",
    )

    def __init__(
        self,
        env,
        latent_dim=3,
        hidden_dim=32,
        encoder_lr=3e-4,
        transition_lr=1e-4,
        momentum=0.9,
        weight_decay=0,
    ):
        self.env = env
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.encoder_lr = encoder_lr
        self.transition_lr = transition_lr

        # Input size matches the state encoding produced by the environment.
        self.encoder = Encoder(
            input_dim=2 * env.size + 4, hidden_dim=hidden_dim, output_dim=latent_dim
        )
        self.transition_model = TransitionModel(
            input_dim=latent_dim + env.num_actions,
            hidden_dim=hidden_dim,
            output_dim=latent_dim,
        )
        # optim.Adafactor and SOAP were also tried, with the same lr / weight_decay.
        self.encoder_optimizer = optim.RMSprop(
            self.encoder.parameters(),
            lr=encoder_lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        self.transition_optimizer = optim.RMSprop(
            self.transition_model.parameters(),
            lr=transition_lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

        # Logging state. ``train`` sets log_dir / log_file / writer when it is given
        # a ``log_dir``; otherwise metrics are only printed and kept in ``self.df``.
        self.log_dir = None
        self.log_file = None
        self.writer = None
        self.df = pd.DataFrame(columns=list(self._LOG_COLUMNS))

    # ------------------------------------------------------------------ helpers

    def _actions_to_onehot(self, action):
        """Turn an ``(n, 1)`` long tensor of action ids into an ``(n, num_actions)`` one-hot float tensor."""
        action_onehot = torch.zeros(action.shape[0], self.env.num_actions)
        action_onehot.scatter_(1, action, 1)
        return action_onehot

    @staticmethod
    def _grad_norm(module):
        """Return the global L2 norm of all gradients currently stored on ``module``'s parameters."""
        total = 0
        for param in module.parameters():
            if param.grad is not None:
                total += param.grad.norm().item() ** 2
        return total**0.5

    @staticmethod
    def _linear_decay_lr(base_lr, cur_step, max_steps, min_lr=1e-7):
        """Learning rate decayed linearly from ``base_lr`` (step 0) to ``min_lr`` (step ``max_steps - 1``).

        ``max(max_steps - 1, 1)`` avoids a division by zero when ``max_steps == 1``.
        """
        fraction_done = cur_step / max(max_steps - 1, 1)
        linear_decay_scale = 1.0 - fraction_done * (1.0 - min_lr / base_lr)
        return base_lr * linear_decay_scale

    def _place_z_on_circle(self, z, r=1.0):
        """Map angles ``z`` to points ``(r cos z, r sin z)`` stacked along dim 1."""
        return torch.stack([torch.cos(z) * r, torch.sin(z) * r], dim=1)

    # ------------------------------------------------------------------- losses

    def _compute_prediction_loss(self, state, state_hat):
        """Mean L2 distance between ``state`` and ``state_hat`` along the last dimension."""
        return torch.linalg.norm(state - state_hat, ord=2, dim=-1).mean()

    def _compute_entropy_loss(self, state_x, state_y, C=5.0):
        """Mean of ``exp(-C * ||x - y||)``. It is largest when the paired rows are close,
        so minimising it pushes them apart."""
        l2_distance = torch.linalg.norm(state_x - state_y, ord=2, dim=-1)
        return (-C * l2_distance).exp().mean()

    def _compute_identity_loss(self, x):
        """Mean absolute value of ``x`` (an L1 penalty pulling ``x`` towards zero)."""
        return x.abs().mean()

    def _compute_hinge_loss(self, state_x, state_y, margin=1.0):
        """Mean of ``clamp(||x - y|| - margin, 0, 10)``: penalise pairs farther apart than ``margin``."""
        l2_distance = torch.linalg.norm(state_x - state_y, ord=2, dim=-1)
        return (l2_distance - margin).clamp(0, 10).mean()

    def _compute_hypercube_loss(self, state, r=1.0):
        """Mean of ``clamp(||state||_inf - r, 0, 10)``: penalise points outside the cube ``[-r, r]^d``."""
        linf_distance = state.abs().amax(dim=-1)
        return (linf_distance - r).clamp(0, 10).mean()

    def _compute_infonce_loss(self, pred, gt, temp=1.0):
        """InfoNCE loss with negative L2 distance as the similarity.

        Row ``i`` scores ``pred[i]`` against every ``gt[j]``; the positive is
        ``j == i``. This equals ``-log(exp(s_ii) / sum_j exp(s_ij))`` with
        ``s = -cdist(pred, gt) / temp``. It is computed with ``cross_entropy`` on
        logits, which avoids under/overflow in the explicit exp/log.

        Args:
            pred, gt: ``(b, d)`` tensors.
            temp: Softmax temperature.
        """
        logits = torch.cdist(pred, gt, p=2).neg() / temp
        targets = torch.arange(pred.shape[0], device=pred.device)
        return nn.functional.cross_entropy(logits, targets)

    def _get_off_diagonal_elements(self, matrix):
        """Return the off-diagonal entries of a square matrix as a 1-D tensor (row-major)."""
        mask = ~torch.eye(matrix.size(0), dtype=torch.bool, device=matrix.device)
        return matrix[mask]

    def _compute_prae_loss(self, pred, gt, eps=1.0):
        """
        Compute the loss (Eq. 8) from "Plannable Approximations to MDP Homomorphisms: Equivariance under Actions" (van der Pol et al. 2020)
        https://www.ifaamas.org/Proceedings/aamas2020/pdfs/p1431.pdf

        With ``d(u, v) = ||u - v||^2 / 2``, the positive term pulls ``pred[i]`` to
        ``gt[i]``. The hinge ``max(0, eps - d)`` pushes every other pair
        ``(pred[i], gt[j]), j != i`` in the batch at least ``eps`` apart. The sum
        is divided by the batch size.
        """
        # Squared distances are computed directly rather than as cdist(...)**2:
        # cdist's matmul-based path can lose precision and zero out gradients for
        # nearly-coincident points, which are exactly the positives near convergence.
        diff = pred.unsqueeze(1) - gt.unsqueeze(0)  # (b, b, d)
        sim = diff.square().sum(dim=-1) / 2.0
        pos = torch.diag(sim)
        off_diag = self._get_off_diagonal_elements(sim)
        hinge_neg = torch.clamp(eps - off_diag, min=0.0)
        return (pos.sum() + hinge_neg.sum()) / pred.shape[0]

    # ----------------------------------------------------------------- training

    def learn(self, cur_step, max_steps, batch_size, log_every=100):
        """Run one optimisation step on a freshly sampled batch, then decay the learning rates.

        Args:
            cur_step: 0-based index of the current step. It drives the LR schedule
                and the logging cadence.
            max_steps: Total number of steps. The LR reaches ``1e-7`` at step
                ``max_steps - 1``.
            batch_size: Number of transitions sampled from ``self.env``.
            log_every: Metrics are evaluated and recorded when ``cur_step % log_every == 0``.
        """
        state, action, next_state = self.env.sample_batch_transition(
            batch_size=batch_size
        )
        action_onehot = self._actions_to_onehot(action)

        z = self.encoder(state)
        next_z = self.encoder(next_state)
        z_hat = z + self.transition_model(z, action_onehot)

        # Alternatives that were tried:
        #   transition_loss = self._compute_infonce_loss(z_hat, next_z)
        #   hypercube_loss = self._compute_hypercube_loss(z_hat, r=1.0)
        #                    + self._compute_hinge_loss(z, next_z, margin=0.3)
        #   entropy_loss = self._compute_entropy_loss(z, z[torch.randperm(z.size(0))], C=5.0)
        transition_loss = self._compute_prae_loss(z_hat, next_z)
        # Auxiliary terms are disabled; they are kept as zeros so the logged columns stay stable.
        hypercube_loss = torch.tensor(0.0)
        entropy_loss = torch.tensor(0.0)
        identity_loss = torch.tensor(0.0)
        loss = transition_loss + entropy_loss + identity_loss + hypercube_loss

        self.encoder_optimizer.zero_grad()
        self.transition_optimizer.zero_grad()
        loss.backward()

        # Gradient norms are measured before the optimiser step, for logging only.
        encoder_grad_norm = self._grad_norm(self.encoder)
        transition_grad_norm = self._grad_norm(self.transition_model)

        self.encoder_optimizer.step()
        self.transition_optimizer.step()

        # Linear LR decay. The new rates take effect from the next call onwards.
        new_encoder_lr = self._linear_decay_lr(self.encoder_lr, cur_step, max_steps)
        for param_group in self.encoder_optimizer.param_groups:
            param_group["lr"] = new_encoder_lr
        new_transition_lr = self._linear_decay_lr(
            self.transition_lr, cur_step, max_steps
        )
        for param_group in self.transition_optimizer.param_groups:
            param_group["lr"] = new_transition_lr

        if cur_step % log_every == 0:
            self._log_metrics(
                cur_step,
                lr=new_transition_lr,
                losses={
                    "Total": loss.item(),
                    "Transition": transition_loss.item(),
                    "Entropy": entropy_loss.item(),
                    "Identity": identity_loss.item(),
                    "Hypercube": hypercube_loss.item(),
                },
                encoder_grad_norm=encoder_grad_norm,
                transition_grad_norm=transition_grad_norm,
            )

    def _log_metrics(self, cur_step, lr, losses, encoder_grad_norm, transition_grad_norm):
        """Evaluate ranking metrics and record everything to TensorBoard (if open), stdout and ``self.df``.

        Args:
            cur_step: Current step index.
            lr: Transition-model learning rate (the value logged as "Lr").
            losses: Maps "Total", "Transition", "Entropy", "Identity" and
                "Hypercube" to float values.
            encoder_grad_norm, transition_grad_norm: Gradient norms of this step.
        """
        eval_hits_at_k, eval_mmr = self._measure_performance(
            self.env.disabled_transition
        )
        train_hits_at_k, train_mmr = self._measure_performance(
            self.env.enabled_transition
        )

        if self.writer is not None:
            for name, value in losses.items():
                self.writer.add_scalar(f"Loss/{name}", value, cur_step)
            self.writer.add_scalar("GradNorm/Encoder", encoder_grad_norm, cur_step)
            self.writer.add_scalar(
                "GradNorm/Transition", transition_grad_norm, cur_step
            )
            self.writer.add_scalar("Eval/MMR", eval_mmr, cur_step)
            self.writer.add_scalar("Train/MMR", train_mmr, cur_step)

        # The LR uses scientific notation: with ":.4f" typical values print as 0.0000.
        print(
            f"Step: {cur_step}, Lr: {lr:.2e}, Loss: {losses['Total']:.4f}, Train/MMR {train_mmr:.4f}, Train/H@1: {train_hits_at_k[1]:.4f}, Train/H@5: {train_hits_at_k[5]:.4f}, Eval/MMR: {eval_mmr:.4f}, Eval/H@1: {eval_hits_at_k[1]:.4f}, Eval/H@5: {eval_hits_at_k[5]:.4f}"
        )
        # Flush so the line reaches the log file even if the run is killed.
        sys.stdout.flush()

        new_row = pd.DataFrame(
            [
                {
                    "Step": cur_step,
                    "Lr": lr,
                    "Loss": losses["Total"],
                    "TransitionLoss": losses["Transition"],
                    "EntropyLoss": losses["Entropy"],
                    "IdentityLoss": losses["Identity"],
                    "HypercubeLoss": losses["Hypercube"],
                    "Train_MMR": train_mmr,
                    "Train_H@1": train_hits_at_k[1],
                    "Train_H@5": train_hits_at_k[5],
                    "Eval_MMR": eval_mmr,
                    "Eval_H@1": eval_hits_at_k[1],
                    "Eval_H@5": eval_hits_at_k[5],
                }
            ]
        )
        # Recent pandas versions emit a FutureWarning when concatenating onto an
        # empty frame, so the first row simply becomes the table.
        if self.df.empty:
            self.df = new_row
        else:
            self.df = pd.concat([self.df, new_row], ignore_index=True)

    def _l2_distance(self, x, y):
        """L2 distance between ``x`` and ``y`` along the last dimension."""
        return torch.linalg.norm(x - y, ord=2, dim=-1)

    @torch.inference_mode()
    def _measure_performance(self, transitions):
        """Rank-based evaluation of one-step latent predictions.

        For every transition, the prediction ``z + T(z, a)`` is compared (by L2
        distance) with the encoded next states of *all* given transitions. The
        0-based rank of the true next state yields Hits@k (k = 1, 5, 10) and the
        mean reciprocal rank. Ref: https://arxiv.org/pdf/1911.12247 (Appendix C).

        Args:
            transitions: Sequence of transitions accepted by
                ``self.env._convert_transitions_to_tensors``.

        Returns:
            ``(hits_at_k, mrr)``. ``hits_at_k`` is a ``defaultdict(float)`` keyed
            by k and ``mrr`` is a number. Both are zero when ``transitions`` is empty.
        """
        mmr = 0
        hits_at_k = defaultdict(float)
        for k in (1, 5, 10):
            hits_at_k[k] = 0

        if len(transitions) == 0:
            return hits_at_k, mmr

        state, action, next_state = self.env._convert_transitions_to_tensors(
            transitions
        )
        z = self.encoder(state)
        next_z = self.encoder(next_state)
        z_hat = z + self.transition_model(z, self._actions_to_onehot(action))

        dist_matrix = torch.cdist(z_hat, next_z, p=2)
        # Prepend each row's positive distance as column 0. The stable sort then
        # places the positive ahead of any candidate at exactly the same distance
        # (including its own duplicate at column i + 1), so ties favour the positive.
        dist_matrix_augmented = torch.cat(
            [torch.diag(dist_matrix).unsqueeze(-1), dist_matrix], dim=1
        )
        sorted_indices = torch.argsort(dist_matrix_augmented, dim=1, stable=True)
        # Each row is a permutation, so column 0 occurs exactly once per row;
        # nonzero() returns the hits in row order.
        ranks = (sorted_indices == 0).nonzero()[:, 1].tolist()

        n = len(transitions)
        for k in (1, 5, 10):
            hits_at_k[k] = sum(1 for rank in ranks if rank < k) / n
        mmr = sum(1.0 / (rank + 1) for rank in ranks) / n
        return hits_at_k, mmr

    def evaluate(self, cur_step):
        """Plot the learned latent space via ``self.env.draw_learned_latent``.

        The plot is written into ``self.log_dir``. When training without a log
        directory there is nowhere to save it, so plotting is skipped.
        """
        if self.log_dir is None:
            return
        self.env.draw_learned_latent(self, cur_step)

    def train(
        self,
        max_steps=100000,
        eval_every=1000,
        batch_size=32,
        log_dir=None,
        log_every=100,
    ):
        """Train for ``max_steps`` steps, plotting the latent space every ``eval_every`` steps.

        When ``log_dir`` is given, a timestamped sub-directory is created. It holds
        ``output.log`` (stdout/stderr are redirected there while training),
        TensorBoard event files, the latent plots and ``log.csv``. The original
        stdout/stderr are restored, the log file and writer are closed, and the
        CSV is written even if training raises.

        Args:
            max_steps: Number of optimisation steps.
            eval_every: Plot the latent space when ``step % eval_every == 0``.
            batch_size: Transitions per step.
            log_dir: Parent directory for run artifacts, or None to only print.
            log_every: Metric logging cadence passed to ``learn``.
        """
        self.df = pd.DataFrame(columns=list(self._LOG_COLUMNS))
        original_stdout, original_stderr = sys.stdout, sys.stderr

        if log_dir is None:
            self.log_dir = None
        else:
            timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            self.log_dir = os.path.join(log_dir, timestamp)
            os.makedirs(self.log_dir, exist_ok=True)
            self.writer = SummaryWriter(log_dir=self.log_dir)
            self.log_file = open(
                os.path.join(self.log_dir, "output.log"), "w", encoding="utf-8"
            )
            sys.stdout = self.log_file
            sys.stderr = self.log_file

        try:
            for step in range(max_steps):
                if step % eval_every == 0:
                    self.evaluate(step)

                self.learn(
                    step,
                    max_steps=max_steps,
                    batch_size=batch_size,
                    log_every=log_every,
                )
        finally:
            if log_dir is not None:
                # Restore the real streams before closing the file they point to.
                sys.stdout, sys.stderr = original_stdout, original_stderr
                self.log_file.close()
                self.writer.close()
                self.writer = None
                self.df.to_csv(os.path.join(self.log_dir, "log.csv"), index=False)


class Arrow3D(FancyArrowPatch):
    """Arrow patch whose two endpoints are given in 3-D data coordinates.

    Args:
        xs, ys, zs: ``(tail, head)`` pairs of x, y and z coordinates.
        *args, **kwargs: Forwarded unchanged to ``FancyArrowPatch``.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        # Placeholder 2-D endpoints; the real ones are set in do_3d_projection.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def do_3d_projection(self, renderer=None):
        """Project the 3-D endpoints with ``self.axes.M`` and update the 2-D arrow.

        Returns the smallest projected z value. The 3-D axes presumably use it
        to order artists; see the mplot3d documentation.
        """
        proj_x, proj_y, proj_z = proj3d.proj_transform(*self._verts3d, self.axes.M)
        tail = (proj_x[0], proj_y[0])
        head = (proj_x[1], proj_y[1])
        self.set_positions(tail, head)
        return np.min(proj_z)


class MiniGridEnvV2:
    """Deterministic ``size`` x ``size`` grid world with an oriented agent.

    A state is ``[x, y, direction]``, where ``direction`` indexes
    ``["N", "E", "S", "W"]`` and ``move_actions[direction]`` is the ``(dx, dy)``
    of a forward move. There are two actions:

    * 0 turns right (``direction + 1 mod 4``);
    * 1 moves forward. Moves that would leave the grid are not valid transitions.

    A random ``pct_transitions_disabled`` fraction of the valid transitions is
    held out as ``disabled_transition``. Training batches are sampled only from
    ``enabled_transition``.

    Note:
        The constructor seeds Python's *global* ``random`` module with ``seed``,
        and both the held-out split and ``sample_transition`` draw from it.
    """

    def __init__(self, size=5, seed=42, pct_transitions_disabled=0.1):
        print("MiniGridEnvV2")
        self.size = size
        self.agent_state = [0, 0, 0]  # [x, y, direction]
        self.move_actions = [[-1, 0], [0, 1], [1, 0], [0, -1]]
        self.num_actions = 2  # 0: Turn right, 1: Move forward
        self.directions = ["N", "E", "S", "W"]
        self.seed = seed
        random.seed(self.seed)  # Fixed seed for the global RNG

        self.all_states = self._get_all_states()
        self.valid_transitions = self._get_all_valid_transitions()
        print("Number of valid transitions: ", len(self.valid_transitions))

        self.disabled_transition_indices = self._get_disabled_indices(
            pct_transitions_disabled=pct_transitions_disabled
        )
        self.disabled_transition = [
            self.valid_transitions[i] for i in self.disabled_transition_indices
        ]
        # Set membership keeps this linear instead of quadratic.
        disabled = set(self.disabled_transition_indices)
        self.enabled_transition_indices = [
            i for i in range(len(self.valid_transitions)) if i not in disabled
        ]
        self.enabled_transition = [
            self.valid_transitions[i] for i in self.enabled_transition_indices
        ]
        print("Number of disabled transitions: ", len(self.disabled_transition_indices))

    def _get_disabled_indices(self, pct_transitions_disabled=0.1):
        """Sample, without replacement, ``int(pct * n)`` indices of transitions to hold out."""
        num_transitions = len(self.valid_transitions)
        num_transitions_disabled = int(pct_transitions_disabled * num_transitions)
        return random.sample(range(num_transitions), num_transitions_disabled)

    def _get_all_states(self):
        """Return every ``[x, y, direction]`` state, in x-major, then y, then direction order."""
        return [
            [x, y, direction]
            for x in range(self.size)
            for y in range(self.size)
            for direction in range(4)
        ]

    def _get_all_valid_transitions(self):
        """Enumerate all ``(state, action, next_state)`` triples.

        For each cell, the four turn-right transitions come first, followed by
        the forward moves that stay on the grid.
        """
        valid_transitions = []
        for x in range(self.size):
            for y in range(self.size):
                # Turning right is always possible.
                for direction in range(4):
                    valid_transitions.append(
                        ([x, y, direction], 0, [x, y, (direction + 1) % 4])
                    )

                # Moving forward is only valid if the target cell is on the grid.
                for direction in range(4):
                    new_x = x + self.move_actions[direction][0]
                    new_y = y + self.move_actions[direction][1]
                    if 0 <= new_x < self.size and 0 <= new_y < self.size:
                        valid_transitions.append(
                            ([x, y, direction], 1, [new_x, new_y, direction])
                        )

        return valid_transitions

    def sample_transition(self, k=1):
        """Sample ``k`` transitions uniformly, with replacement, from the enabled transitions."""
        return [
            self.valid_transitions[i]
            for i in random.choices(self.enabled_transition_indices, k=k)
        ]

    def _onehot_encode(self, state):
        """Encode ``[x, y, direction]`` as concatenated one-hots of length ``2 * size + 4``."""
        onehot_state = torch.zeros(self.size * 2 + 4)
        onehot_state[state[0]] = 1
        onehot_state[self.size + state[1]] = 1
        onehot_state[self.size * 2 + state[2]] = 1
        return onehot_state.float()

    def _convert_transitions_to_tensors(self, transitions):
        """Return ``(state, action, next_state)`` tensors.

        The state tensors are stacked one-hot encodings. ``action`` is a long
        tensor of shape ``(n, 1)``.
        """
        state = torch.stack(
            [self._onehot_encode(transition[0]) for transition in transitions]
        )
        action = torch.tensor(
            [transition[1] for transition in transitions], dtype=torch.long
        ).view(-1, 1)
        next_state = torch.stack(
            [self._onehot_encode(transition[2]) for transition in transitions]
        )
        return state, action, next_state

    def sample_batch_transition(self, batch_size=8):
        """Sample ``batch_size`` enabled transitions and convert them to tensors."""
        return self._convert_transitions_to_tensors(
            self.sample_transition(k=batch_size)
        )

    @torch.inference_mode()
    def draw_learned_latent(self, world_model, cur_step):
        """Plot encoded states and predicted transitions in 3-D.

        Each state is drawn at the first three coordinates of its encoding,
        coloured by ``(x + y)``. Every valid ``(state, action)`` pair gets an
        arrow from ``z`` to the predicted ``z + T(z, a)``. Held-out (disabled)
        transitions are drawn red and dashed. The figure is saved to
        ``world_model.log_dir/latent_{cur_step}.pdf`` and then closed.
        """
        colors = [
            "tab:blue",
            "tab:orange",
            "tab:green",
            "tab:red",
            "tab:purple",
            "tab:brown",
            "tab:pink",
            "tab:gray",
            "tab:olive",
            "tab:cyan",
        ]
        unseen_transition_linestyle = (0, (5, 10))  # loosely dashed

        fig = plt.figure(figsize=(8, 8), dpi=200)
        try:
            ax = fig.add_subplot(111, projection="3d")
            annotated_cells = set()
            for state in self.all_states:
                x, y, _ = state
                abs_state = world_model.encoder(
                    self._onehot_encode(state).unsqueeze(0)
                ).detach()
                ax.plot(
                    abs_state[0, 0],
                    abs_state[0, 1],
                    abs_state[0, 2],
                    "o",
                    color=colors[(x + y) % len(colors)],
                    markersize=5,
                )

                # Label each grid cell once, at the first of its four orientations.
                if (x, y) not in annotated_cells:
                    ax.text(
                        abs_state[0, 0],
                        abs_state[0, 1],
                        abs_state[0, 2],
                        f"({x}, {y})",
                        color="black",
                    )
                    annotated_cells.add((x, y))

                for action in range(self.num_actions):
                    next_state = state.copy()
                    if action == 0:
                        next_state[2] = (next_state[2] + 1) % 4
                    else:
                        next_state[0] += self.move_actions[next_state[2]][0]
                        next_state[1] += self.move_actions[next_state[2]][1]
                        if not (
                            0 <= next_state[0] < self.size
                            and 0 <= next_state[1] < self.size
                        ):
                            continue

                    is_unseen = (state, action, next_state) in self.disabled_transition
                    action_onehot = torch.zeros(1, self.num_actions)
                    action_onehot[0, action] = 1
                    d = world_model.transition_model(abs_state, action_onehot).detach()
                    pred_next_state = abs_state + d

                    arrow = Arrow3D(
                        [abs_state[0, 0], pred_next_state[0, 0]],
                        [abs_state[0, 1], pred_next_state[0, 1]],
                        [abs_state[0, 2], pred_next_state[0, 2]],
                        connectionstyle="arc3, rad=0",
                        linestyle=unseen_transition_linestyle if is_unseen else "solid",
                        color="red" if is_unseen else "black",
                        mutation_scale=5,
                    )
                    ax.add_artist(arrow)

            fig.tight_layout()
            fig.savefig(
                os.path.join(world_model.log_dir, f"latent_{cur_step}.pdf"), dpi=200
            )
        finally:
            # Release the figure; otherwise every evaluation leaks one open figure.
            plt.close(fig)


def main(
    seed=42,
    latent_dim=3,
    hidden_dim=32,
    batch_size=64,
    max_steps=50000,
    encoder_lr=1e-4,
    transition_lr=1e-4,
    momentum=0.9,
    weight_decay=0,
    pct_transitions_disabled=0.1,
    log_dir="logs/minigrid_v2_generalization",
    env_size=5,
    env_seed=42,
    eval_every=1000,
    log_every=100,
):
    """Train a WorldModel on MiniGridEnvV2 and track generalisation to held-out transitions.

    Used as the python-fire CLI entry point, so every argument is also a flag.

    Args:
        seed: Seed for NumPy and PyTorch (e.g. network initialisation).
        latent_dim, hidden_dim: WorldModel network sizes.
        batch_size: Transitions per training step.
        max_steps: Number of training steps.
        encoder_lr, transition_lr, momentum, weight_decay: Optimiser settings.
        pct_transitions_disabled: Fraction of valid transitions held out from training.
        log_dir: Parent directory for run artifacts, or None to only print.
        env_size: Side length of the square grid.
        env_seed: Seed for the environment's Python ``random`` stream. It fixes
            which transitions are held out and the sequence of sampled training
            batches. It is independent of ``seed``, so changing ``seed`` keeps the
            same split.
        eval_every: Plot the latent space every this many steps.
        log_every: Log metrics every this many steps.
    """
    env = MiniGridEnvV2(
        size=env_size, seed=env_seed, pct_transitions_disabled=pct_transitions_disabled
    )
    disabled_indices = env.disabled_transition_indices
    print(
        f"Number of disabled transitions: {len(disabled_indices)}/{len(env.valid_transitions)}"
    )
    for i in disabled_indices:
        state, action, next_state = env.valid_transitions[i]
        print(state, action, next_state)

    set_seed(seed)
    world_model = WorldModel(
        env=env,
        latent_dim=latent_dim,
        hidden_dim=hidden_dim,
        encoder_lr=encoder_lr,
        transition_lr=transition_lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )
    world_model.train(
        max_steps=max_steps,
        eval_every=eval_every,
        batch_size=batch_size,
        log_dir=log_dir,
        log_every=log_every,
    )


if __name__ == "__main__":
    # python-fire turns every keyword argument of main() into a command-line flag.
    import fire

    fire.Fire(main)
