import copy
import math
import warnings
from contextlib import nullcontext

import numpy as np
import torch
from minihack.agent.common.models.embed import GlyphEmbedding
from minihack.agent.common.models.transformer import TransformerEncoder
from minihack.agent.polybeast.models.base import NUM_FEATURES, Crop, NetHackNet
from minihack.agent.polybeast.models.intrinsic import IntrinsicRewardNet
from tokenizers import BertWordPieceTokenizer
from torch import nn, optim
from torch.distributions import Categorical

from . import message as message_models


def init_tokenizer(vocab_file):
    """Build a lower-casing BERT WordPiece tokenizer from ``vocab_file``."""
    tokenizer = BertWordPieceTokenizer(vocab_file, lowercase=True)
    return tokenizer


class Teacher(NetHackNet):
    def __init__(self, observation_shape, flags, device):
        """Build the grid-goal teacher network.

        Args:
            observation_shape: indexable whose first two entries are read as
                the map height ``H`` and width ``W``.
            flags: configuration object; this constructor reads
                ``flags.minihack.{embedding_dim, hidden_dim, glyph_type,
                use_index_select, layers}`` and
                ``flags.minihack.msg.{model, hidden_dim, embedding_dim}``.
            device: forwarded to ``GlyphEmbedding``.

        Raises:
            NotImplementedError: if ``flags.minihack.msg.model`` is neither
                ``"word_gru"`` nor ``"none"``.
        """
        super().__init__()

        self.flags = flags

        self.observation_shape = observation_shape

        self.H = observation_shape[0]
        self.W = observation_shape[1]

        self.k_dim = flags.minihack.embedding_dim
        self.h_dim = flags.minihack.hidden_dim

        self.num_features = NUM_FEATURES

        self.glyph_type = flags.minihack.glyph_type
        self.glyph_embedding = GlyphEmbedding(
            flags.minihack.glyph_type,
            flags.minihack.embedding_dim,
            device,
            flags.minihack.use_index_select,
        )

        # MESSAGING MODEL
        self.msg_model = flags.minihack.msg.model
        self.msg_hdim = flags.minihack.msg.hidden_dim
        self.msg_edim = flags.minihack.msg.embedding_dim
        if self.msg_model == "word_gru":
            # NOTE(review): 30522 presumably matches the BERT uncased WordPiece
            # vocabulary size -- confirm against the configured vocab file.
            self.msg_net = message_models.WordGRU(
                self.msg_hdim, self.msg_edim, vocab_size=30522
            )
        elif self.msg_model != "none":
            raise NotImplementedError
        # With msg_model == "none" no ``self.msg_net`` attribute is created.

        # Two-layer MLP mapping the flat feature vector to a k_dim embedding.
        self.embed_features = nn.Sequential(
            nn.Linear(self.num_features, self.k_dim),
            nn.ReLU(),
            nn.Linear(self.k_dim, self.k_dim),
            nn.ReLU(),
        )

        FDIM = 3  # filter dimensions
        S = 1  # stride
        P = 1  # padding
        M = 16  # number of intermediate filters
        self.Y = 1  # number of output filters
        L = flags.minihack.layers  # number of convnet layers

        # Number of inputs to convnet:
        # 1 k_dim for features, 1 k_dim for glyph embeddings
        core_state_dim = 2 * self.k_dim
        # messaging model
        if self.msg_model != "none":
            core_state_dim += self.msg_hdim
        self.core_state_dim = core_state_dim

        # Channel plan: core_state_dim -> M -> ... -> M -> Y over L layers.
        in_channels = [core_state_dim] + [M] * (L - 1)
        out_channels = [M] * (L - 1) + [self.Y]

        def interleave(xs, ys):
            # [x0, y0, x1, y1, ...]; used to alternate conv layers and activations.
            return [val for pair in zip(xs, ys) for val in pair]

        # 3x3 kernels with stride 1 and padding 1 keep the H x W spatial size.
        conv_extract = [
            nn.Conv2d(
                in_channels=in_channels[i],
                out_channels=out_channels[i],
                kernel_size=(FDIM, FDIM),
                stride=S,
                padding=P,
            )
            for i in range(L)
        ]

        # The list multiplication repeats one and the same nn.ELU instance
        # after every conv layer (including the last one).
        self.extract_representation = nn.Sequential(
            *interleave(conv_extract, [nn.ELU()] * len(conv_extract))
        )

        # One action (goal) per map cell.
        self.num_actions = self.H * self.W

        # The baseline head sees the pooled core state concatenated with the logits.
        self.baseline = nn.Linear(self.core_state_dim + self.num_actions, 1)
        self.logits_size = self.num_actions
        self.raw_goal_size = 1  # Message size

    def initial_state(self, batch_size=1):
        """Return the recurrent state, which is always an empty tuple here.

        ``batch_size`` is accepted but not used.
        """
        return ()

    def prepare_input(self, inputs):
        """Flatten time and batch and return ``(glyphs, features)``.

        ``inputs["glyphs"]`` must be 4-dimensional ([T x B x H x W]); the glyph
        preparation itself is delegated to ``self.glyph_embedding``.
        ``inputs["blstats"]`` is reshaped to [T * B x F] and cast to float.
        """
        # -- [T x B x H x W]
        n_steps, n_envs, _, _ = inputs["glyphs"].shape
        flat_batch = n_steps * n_envs

        # Glyph selection and time/batch merging happen inside the embedding.
        glyphs = self.glyph_embedding.prepare_input(inputs)

        # -- [T x B x F] -> [B' x F]
        features = inputs["blstats"].view(flat_batch, -1).float()

        return glyphs, features

    def embed_environment(self, inputs):
        """Embed an observation and score every map cell.

        Returns:
            ``(core_state, logits)`` where ``core_state`` is the spatial mean
            of the per-cell input channels ([T * B x K']) and ``logits`` is
            the flattened convnet output ([T * B x -1]).
        """
        n_steps, n_envs, *_ = inputs["glyphs"].shape
        flat_batch = n_steps * n_envs

        glyphs, features = self.prepare_input(inputs)

        # -- [B x K]
        features_emb = self.embed_features(features.view(flat_batch, -1).float())
        assert features_emb.shape[0] == flat_batch

        # Per-sample vectors that get broadcast onto every map cell.
        per_sample_reps = [features_emb]

        # MESSAGING MODEL
        if self.msg_model != "none":
            per_sample_reps.append(
                self.embed_split_messages(
                    inputs["split_messages"], inputs["split_messages_len"]
                )
            )

        # -- [B x H x W x K] -> [B x K x H x W]
        glyphs_emb = self.glyph_embedding(glyphs).permute((0, 3, 1, 2))

        # Tile each [B x D] vector to [B x D x H x W].
        tiled = [
            rep.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, self.H, self.W)
            for rep in per_sample_reps
        ]

        # -- [B x K' x H x W]
        glyphs_aug = torch.cat([glyphs_emb] + tiled, 1)

        # -- [B x K'] average over both spatial axes
        core_state = glyphs_aug.mean(-1).mean(-1)

        # -- [B x Y x H x W]
        cell_scores = self.extract_representation(glyphs_aug)
        assert cell_scores.shape[0] == flat_batch

        # -- [B x (Y * H * W)]
        return core_state, cell_scores.view(flat_batch, -1)

    def forward(self, inputs, prev_goals=None):
        """Sample one goal index per (time, batch) element.

        ``prev_goals`` is accepted but not used by this class.

        Returns:
            dict with ``generator_logits`` [T x B x num_actions],
            ``generator_baseline`` [T x B], and the sampled index under both
            ``goal`` and ``raw_goal`` [T x B].
        """
        n_steps, n_envs, *_ = inputs["glyphs"].shape
        core_output, flat_logits = self.embed_environment(inputs)

        # -- [B x A] value estimate from the pooled state plus the logits
        value = self.baseline(torch.cat([core_output, flat_logits], 1))

        sampled = Categorical(logits=flat_logits).sample()

        out_logits = flat_logits.view(n_steps, n_envs, self.num_actions)
        out_value = value.view(n_steps, n_envs)
        goal = sampled.view(n_steps, n_envs)

        return {
            "generator_logits": out_logits,
            "generator_baseline": out_value,
            "goal": goal,
            "raw_goal": goal,
        }

    def embed_split_messages(self, split_messages, split_messages_len=None):
        """Embed ``N`` messages per sample and average them.

        ``split_messages`` is read as [T x B x N x ...]; every message is
        embedded independently via ``embed_message`` and the ``N`` embeddings
        of each sample are mean-pooled, giving a [T * B x D] tensor.
        """
        n_steps, n_envs, n_msgs, *_ = split_messages.shape
        n_flat = n_steps * n_envs * n_msgs

        flat_messages = split_messages.view(n_flat, -1)
        flat_lens = (
            None if split_messages_len is None else split_messages_len.view(n_flat)
        )

        per_message = self.embed_message(flat_messages, flat_lens)
        # Average across the N messages
        return per_message.view(n_steps * n_envs, n_msgs, -1).mean(-2)

    def embed_message(self, messages, message_len=None):
        """Embed token-id messages with the configured message network.

        Args:
            messages: token ids, either [T x B x L] or already flattened
                ([N x L]).
            message_len: optional per-message lengths. When omitted, the
                length is the number of non-zero tokens, clamped to at
                least 1.

        Returns:
            Whatever ``self.msg_net`` returns for the flattened batch.

        Raises:
            RuntimeError: if the model was built with ``msg_model == "none"``.

        Both input layouts are normalised the same way before the network is
        called: tokens are cast to long and lengths become a flat long vector.
        Previously only the 3-D path did this, so 2-D input (as produced by
        ``embed_split_messages``) reached the network with uncast tokens and,
        when lengths were inferred, an [N x 1] length tensor.
        """
        if self.msg_model == "none":
            raise RuntimeError("No message model")

        if message_len is None:
            # Infer lengths from the padding (token id 0); keepdim leaves a
            # trailing singleton axis that is flattened below.
            message_len = (messages != 0).long().sum(-1, keepdim=True)
            # Min length of 1
            message_len = torch.clamp(message_len, min=1)

        if messages.ndim > 2:
            T, B, *_ = messages.shape
            # [T x B x 256] -> [T * B x 256]
            messages = messages.long().view(T * B, -1)
            message_len = message_len.long().view(T * B)
        else:
            # Already flattened: apply the same dtype/shape normalisation.
            messages = messages.long()
            message_len = message_len.long().reshape(-1)

        return self.msg_net(messages, message_len)


class LanguageTeacher(Teacher):
    """
    This is the BaseNet minihack model, with modifications
    """

    def __init__(self, observation_shape, flags, device):
        """Build the language-goal teacher on top of ``Teacher``.

        In addition to the flags read by ``Teacher.__init__`` this reads
        ``flags.minihack.{crop_model, crop_dim}``, ``flags.language_goals``,
        ``flags.max_online_goals`` and
        ``flags.minihack.msg.word.max_message_len``.

        Raises:
            NotImplementedError: if ``flags.language_goals`` is neither
                ``"online_naive"`` nor ``"online_grounding"``.
        """
        super(LanguageTeacher, self).__init__(observation_shape, flags, device)

        self.crop_model = flags.minihack.crop_model
        self.crop_dim = flags.minihack.crop_dim

        self.crop = Crop(self.H, self.W, self.crop_dim, self.crop_dim, device)
        K = flags.minihack.embedding_dim  # number of input filters
        FDIM = 3  # filter dimensions
        S = 1  # stride
        P = 1  # padding
        M = 16  # number of intermediate filters
        self.Y = 8  # number of output filters
        L = flags.minihack.layers  # number of convnet layers

        # Channel plan: K -> M -> ... -> M -> Y over L layers.
        in_channels = [K] + [M] * (L - 1)
        out_channels = [M] * (L - 1) + [self.Y]

        def interleave(xs, ys):
            # [x0, y0, x1, y1, ...]; used to alternate conv layers and activations.
            return [val for pair in zip(xs, ys) for val in pair]

        conv_extract = [
            nn.Conv2d(
                in_channels=in_channels[i],
                out_channels=out_channels[i],
                kernel_size=(FDIM, FDIM),
                stride=S,
                padding=P,
            )
            for i in range(L)
        ]

        # Replaces the single-output-channel convnet created by the parent
        # constructor with one that takes only the glyph embedding as input.
        self.extract_representation = nn.Sequential(
            *interleave(conv_extract, [nn.ELU()] * len(conv_extract))
        )

        # Encoder for the crop around the agent. For any other crop_model
        # value no ``extract_crop_representation`` attribute is created here.
        if self.crop_model == "transformer":
            self.extract_crop_representation = TransformerEncoder(
                K,
                N=L,
                heads=8,
                height=self.crop_dim,
                width=self.crop_dim,
                device=device,
            )
        elif self.crop_model == "cnn":
            conv_extract_crop = [
                nn.Conv2d(
                    in_channels=in_channels[i],
                    out_channels=out_channels[i],
                    kernel_size=(FDIM, FDIM),
                    stride=S,
                    padding=P,
                )
                for i in range(L)
            ]

            self.extract_crop_representation = nn.Sequential(
                *interleave(conv_extract_crop, [nn.ELU()] * len(conv_extract))
            )

        if flags.language_goals in {"online_naive", "online_grounding"}:
            # Token ids of every known goal, one row per goal; row 0 is the
            # null goal, hence the "+ 1".
            self.register_buffer(
                "goals",
                torch.zeros(
                    (
                        flags.max_online_goals + 1,
                        flags.minihack.msg.word.max_message_len,
                    ),
                    dtype=torch.int64,
                ),
            )
            # True for rows of ``goals`` that hold a goal seen so far.
            self.register_buffer(
                "goals_mask", torch.zeros(flags.max_online_goals + 1, dtype=torch.bool)
            )
            # Plain Python lookups (not buffers): decoded bytes -> row index,
            # row index -> decoded string, raw token bytes -> row index.
            self.goals_str = {}
            self.lang_templates = []
            self.lang_hashes = {}

            # Null goal: row 0 gets the arbitrary token id 254 as its first
            # token (presumably so the row is not all padding -- confirm).
            self.goals[0, 0] = 254
            self.goals_mask[0] = True  # Null goal
            self.goals_str[b"<NULL>"] = 0
            self.lang_templates.append("<NULL>")  # Lang templates use strings.
            # NOTE(review): this key is the *str* "b<NULL>", while update_goals
            # stores *bytes* keys in lang_hashes -- confirm whether b"<NULL>"
            # was intended.
            self.lang_hashes["b<NULL>"] = 0

            if flags.language_goals == "online_grounding":
                # Projects goal embeddings for the achievability predictor.
                self.grounding_head = nn.Linear(self.msg_hdim, self.h_dim)
        else:
            raise NotImplementedError

        # just added up the output dimensions of the input featurizers
        # feature / status dim
        out_dim = self.k_dim
        # CNN over full glyph map
        out_dim += self.H * self.W * self.Y
        if self.crop_model == "transformer":
            out_dim += self.crop_dim ** 2 * K
        elif self.crop_model == "cnn":
            out_dim += self.crop_dim ** 2 * self.Y
        # messaging model
        if self.msg_model != "none":
            out_dim += self.msg_hdim

        self.fc = nn.Sequential(
            nn.Linear(out_dim, self.h_dim),
            nn.ReLU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ReLU(),
        )

        # Map message state to hidden state
        self.msg2st = nn.Linear(self.msg_hdim, self.h_dim)

        # Both branches yield the same number, since ``goals`` has
        # ``max_online_goals + 1`` rows.
        if flags.language_goals == "online_naive":
            self.num_actions = flags.max_online_goals + 1
        else:
            self.num_actions = self.goals.shape[0]
        self.baseline = nn.Linear(self.h_dim + self.num_actions, 1)
        self.logits_size = self.num_actions
        self.raw_goal_size = flags.minihack.msg.word.max_message_len

        # Message tokenizer (created on first use by decode_messages)
        self.tokenizer = None

    def embed_environment(self, inputs):
        """Encode an observation into a single [T * B x h_dim] state vector.

        The state is ``self.fc`` applied to the concatenation of, in order:
        the feature embedding, the encoded crop around the agent, the encoded
        full glyph map and (when a message model is configured) the embedding
        of ``inputs["message"]``.
        """
        n_steps, n_envs, *_ = inputs["glyphs"].shape
        flat_batch = n_steps * n_envs

        glyphs, features = self.prepare_input(inputs)

        # -- [B x 2] x,y coordinates
        coordinates = features[:, :2]

        # -- [B x K]
        features_emb = self.embed_features(features.view(flat_batch, -1).float())
        assert features_emb.shape[0] == flat_batch

        # -- [B x H' x W'] crop of every glyph channel around the agent
        crop = self.glyph_embedding.GlyphTuple(
            *(self.crop(channel, coordinates) for channel in glyphs)
        )
        # -- [B x H' x W' x K]
        crop_emb = self.glyph_embedding(crop)

        if self.crop_model == "transformer":
            crop_rep = self.extract_crop_representation(crop_emb, mask=None)
        elif self.crop_model == "cnn":
            # Channels-first for the convnet: [B x K x W' x H']
            crop_rep = self.extract_crop_representation(crop_emb.transpose(1, 3))

        # -- [B x K']
        crop_rep = crop_rep.view(flat_batch, -1)
        assert crop_rep.shape[0] == flat_batch

        # -- [B x H x W x K] -> [B x K x W x H] -> convnet -> [B x K'']
        full_map_emb = self.glyph_embedding(glyphs).transpose(1, 3)
        glyphs_rep = self.extract_representation(full_map_emb).view(flat_batch, -1)
        assert glyphs_rep.shape[0] == flat_batch

        parts = [features_emb, crop_rep, glyphs_rep]

        # MESSAGING MODEL
        if self.msg_model != "none":
            # [T x B x 256] -> [T * B x 256]
            parts.append(self.embed_message(inputs["message"]))

        # -- [B x K]
        return self.fc(torch.cat(parts, dim=1))

    def forward_grounder(self, inputs):
        """Score every stored goal for achievability in the given observation.

        Returns the dict produced by ``predict_achievable`` for the embedded
        environment and the embeddings of all rows of ``self.goals``.
        """
        env_emb = self.embed_environment(inputs)
        all_goal_emb = self.embed_message(self.goals.unsqueeze(0))
        return self.predict_achievable(env_emb=env_emb, goal_emb=all_goal_emb)

    def forward(self, inputs, prev_goals=None):
        """Sample one language goal per (time, batch) element.

        Policy logits are dot products between the projected embeddings of
        all stored goals and the environment state. Sampling uses a detached
        copy of those logits (or all-zero logits with probability
        ``flags.generator_eps``), restricted by a goal mask:

        * ``online_naive``: the mask of goals seen so far;
        * ``online_grounding``: seen goals combined with the achievability
          prediction, as a hard boolean mask or as soft probabilities
          depending on ``flags.hard_grounding``.

        Returns:
            dict with ``generator_logits`` [T x B x num_actions],
            ``generator_baseline`` [T x B], ``goal`` [T x B] (row index into
            ``self.goals``), ``raw_goal`` (the selected token rows) and
            ``grounder_logits`` (the mask for ``online_grounding``, ones
            otherwise).
        """
        n_steps, n_envs, *_ = inputs["glyphs"].shape
        flags = self.flags
        core_output = self.embed_environment(inputs)

        # Select from goals: [B' x G x h] @ [B' x h x 1] -> [B' x G]
        goal_emb = self.embed_message(self.goals.unsqueeze(0))
        goal_keys = self.msg2st(goal_emb).unsqueeze(0)
        goal_keys = goal_keys.expand(core_output.shape[0], -1, -1)
        policy_logits = torch.bmm(goal_keys, core_output.unsqueeze(-1)).squeeze(-1)

        if np.random.random() < flags.generator_eps:
            # Sample uniformly.
            sampling_logits = torch.zeros_like(policy_logits)
        else:
            sampling_logits = policy_logits.detach().clone()

        # -- [B x A]
        baseline = self.baseline(torch.cat([core_output, policy_logits], 1))

        subgoal_mask = None
        if flags.language_goals == "online_naive":
            subgoal_mask = self.goals_mask
        elif flags.language_goals == "online_grounding":
            # To be available to sample, goal must be predicted AND seen (as
            # revealed by goals mask)
            with torch.no_grad():
                grounding_output = self.predict_achievable(
                    env_emb=core_output, goal_emb=goal_emb
                )
            # NOTE(review): ``prev_goals`` indexes the first axis of the mask
            # below; confirm it has the layout callers intend.
            if flags.hard_grounding:
                subgoal_mask = grounding_output["preds"] & self.goals_mask
                if flags.force_new_goals and prev_goals is not None:
                    subgoal_mask[prev_goals] = False
                subgoal_mask[:, 0] = True  # Subgoal mask always 1 for null goal.
            else:
                subgoal_mask = (
                    torch.sigmoid(grounding_output["logits"]) * self.goals_mask
                )
                if flags.force_new_goals and prev_goals is not None:
                    subgoal_mask[prev_goals] = 0.0
                subgoal_mask[:, 0] = 0.5  # Null goal always possible

        if subgoal_mask is not None:
            # A 3-D mask is reduced to its first slice; after that a 2-D mask
            # is per-sample and anything else is shared across the batch.
            if subgoal_mask.ndim == 3:
                subgoal_mask = subgoal_mask[0]
            per_sample = subgoal_mask.ndim == 2

            if flags.hard_grounding:
                if per_sample:
                    sampling_logits[~subgoal_mask] = -np.inf
                else:
                    sampling_logits[:, ~subgoal_mask] = -np.inf
            else:
                # Product of experts - add probabilities in log space = mult
                if per_sample:
                    sampling_logits += subgoal_mask.float().log()
                else:
                    sampling_logits += subgoal_mask.unsqueeze(0).float().log()
            goal_dist = Categorical(logits=sampling_logits)

        action = goal_dist.sample()

        policy_logits = policy_logits.view(n_steps, n_envs, self.num_actions)
        baseline = baseline.view(n_steps, n_envs)
        action = action.view(n_steps, n_envs)

        if flags.language_goals == "online_grounding":
            grounder_logits = subgoal_mask.view(n_steps, n_envs, -1)
        else:
            grounder_logits = torch.ones_like(policy_logits)

        return {
            "generator_logits": policy_logits,
            "generator_baseline": baseline,
            "goal": action,
            "raw_goal": self.goals[action],
            "grounder_logits": grounder_logits,
        }

    def decode_messages(self, messages):
        """Decode batches of token ids into UTF-8 encoded byte strings.

        The tokenizer is created on first use from
        ``flags.minihack.msg.word.vocab_file`` and cached on the instance.
        """
        tokenizer = self.tokenizer
        if tokenizer is None:
            tokenizer = init_tokenizer(self.flags.minihack.msg.word.vocab_file)
            self.tokenizer = tokenizer
        return [text.encode("utf-8") for text in tokenizer.decode_batch(messages)]

    def update_goals(self, batch):
        """Register every not-yet-seen message in ``batch`` as a new goal.

        ``batch["split_messages"]`` is flattened to rows of token ids (its
        last axis is kept), de-duplicated and decoded. Each decoded message
        that is non-empty and unknown is written to the next free row of
        ``self.goals``; its mask entry is set and the string, template and
        hash lookups are updated.

        Once every row of the goal buffer is in use the method returns
        immediately. The previous guard compared against the buffer size plus
        one, a count ``goals_str`` can never reach, so the early exit never
        triggered and the batch was de-duplicated and decoded on every call.
        If the buffer fills up part-way through a batch, a warning is issued
        and the remaining messages are dropped.
        """
        capacity = self.goals_mask.shape[0]
        # Goal buffer is full
        if len(self.goals_str) >= capacity:
            return

        message_flat = batch["split_messages"].view(
            -1, batch["split_messages"].shape[-1]
        )
        message_flat = message_flat.unique(dim=0)
        message_np = message_flat.cpu().numpy()
        message_strs = self.decode_messages(message_np)

        for message, mnp, mstr in zip(message_flat, message_np, message_strs):
            if not mstr:  # No message
                continue
            if mstr in self.goals_str:  # Goal has been seen already
                continue

            # Add this goal to the generator goals
            msg_index = len(self.goals_str)
            if msg_index >= len(self.goals):
                warnings.warn(f"Overflowed {len(self.goals)} goals")
                return  # Too many goals; stop
            self.goals[msg_index] = message
            # Update goals mask
            self.goals_mask[msg_index] = True
            # Update hashable values
            self.goals_str[mstr] = msg_index
            self.lang_templates.append(mstr.decode("utf-8"))
            # Quick tobytes for hashing later on
            self.lang_hashes[mnp.tobytes().rstrip(b"\x00")] = msg_index

    def predict_achievable(self, inputs=None, env_emb=None, goal_emb=None):
        """Score each goal against each environment embedding.

        Args:
            inputs: must be ``None``; raw observations are not supported.
            env_emb: environment embeddings, one row per sample.
            goal_emb: goal embeddings, one row per goal.

        Returns:
            dict with ``logits`` (dot product of the projected goal
            embedding and the environment embedding, [samples x goals]) and
            ``preds`` (``logits > 0``).

        Raises:
            NotImplementedError: if ``inputs`` is given.
        """
        if inputs is not None:
            raise NotImplementedError

        projected = self.grounding_head(goal_emb)
        # Share the same goal matrix across every sample in the batch.
        tiled = projected.unsqueeze(0).expand(env_emb.shape[0], -1, -1)
        scores = torch.bmm(tiled, env_emb.unsqueeze(-1)).squeeze(-1)
        return {"logits": scores, "preds": scores > 0}


class RNDNet(IntrinsicRewardNet):
    def __init__(
        self, observation_shape, num_actions, flags, device, message_novelty=False
    ):
        """Build the target/predictor network pairs for RND.

        For every input stream selected by ``self.intrinsic_input`` a frozen
        target network (``rndtgt_*``) and a trainable predictor network
        (``rndprd_*``) are created. With ``message_novelty=True`` a second,
        message-only pair (``*_m``) with its own optimizer is added.

        ``self.intrinsic_input``, ``self.extract_crop_representation``,
        ``self.extract_representation``, ``self.msg_net`` and the dimension
        attributes used below are not set in this method (presumably by the
        parent constructor -- confirm).

        Raises:
            NotImplementedError: for an unsupported ``self.intrinsic_input``.
        """
        super(RNDNet, self).__init__(observation_shape, num_actions, flags, device)

        # NOTE(review): assumed to equal the output channel count of the
        # extractors copied below -- confirm against the parent class.
        Y = 8  # number of output filters

        # IMPLEMENTED HERE: RND net using the default feature extractor
        self.rndtgt_embed = GlyphEmbedding(
            flags.glyph_type,
            flags.embedding_dim,
            device,
            flags.use_index_select,
        ).requires_grad_(False)
        self.rndprd_embed = GlyphEmbedding(
            flags.glyph_type,
            flags.embedding_dim,
            device,
            flags.use_index_select,
        )

        if self.intrinsic_input not in (
            "crop_only",
            "glyph_only",
            "full",
            "glyph_msg",
            "crop_msg",
        ):
            raise NotImplementedError("RND input type %s" % self.intrinsic_input)

        # Running total of the concatenated feature width fed to rnd*_fc.
        rnd_out_dim = 0
        if self.intrinsic_input in ("full", "crop_only", "crop_msg"):
            self.rndtgt_extract_crop_representation = copy.deepcopy(
                self.extract_crop_representation
            ).requires_grad_(False)
            self.rndprd_extract_crop_representation = copy.deepcopy(
                self.extract_crop_representation
            )

            rnd_out_dim += self.crop_dim ** 2 * Y  # crop dim

        if self.intrinsic_input in ("full", "glyph_only", "glyph_msg"):
            self.rndtgt_extract_representation = copy.deepcopy(
                self.extract_representation
            ).requires_grad_(False)
            self.rndprd_extract_representation = copy.deepcopy(
                self.extract_representation
            )
            rnd_out_dim += self.H * self.W * Y  # glyph dim

            if self.intrinsic_input == "full":
                self.rndtgt_embed_features = nn.Sequential(
                    nn.Linear(self.num_features, self.k_dim),
                    nn.ELU(),
                    nn.Linear(self.k_dim, self.k_dim),
                    nn.ELU(),
                ).requires_grad_(False)
                self.rndprd_embed_features = nn.Sequential(
                    nn.Linear(self.num_features, self.k_dim),
                    nn.ELU(),
                    nn.Linear(self.k_dim, self.k_dim),
                    nn.ELU(),
                )
                rnd_out_dim += self.k_dim  # feature dim

        if (
            self.intrinsic_input in ("full", "glyph_msg", "crop_msg")
            and self.msg_model != "none"
        ):
            # we only implement the lt_cnn msg model for RND for simplicity & speed
            self.rndtgt_msg_net = copy.deepcopy(self.msg_net)
            for param in self.rndtgt_msg_net.parameters():
                param.requires_grad = False

            self.rndprd_msg_net = copy.deepcopy(self.msg_net)
            rnd_out_dim += self.msg_hdim

        self.rndtgt_fc = nn.Sequential(  # matching RND paper making this smaller
            nn.Linear(rnd_out_dim, self.h_dim)
        ).requires_grad_(False)
        self.rndprd_fc = nn.Sequential(  # matching RND paper making this bigger
            nn.Linear(rnd_out_dim, self.h_dim),
            nn.ELU(),
            nn.Linear(self.h_dim, self.h_dim),
            nn.ELU(),
            nn.Linear(self.h_dim, self.h_dim),
        )

        # Collected here and re-initialised by the loop at the end of __init__.
        modules_to_init = [
            self.rndtgt_embed,
            self.rndprd_embed,
            self.rndtgt_fc,
            self.rndprd_fc,
        ]

        SQRT_2 = math.sqrt(2)

        def init(p):
            # Orthogonal init with gain sqrt(2) for conv/linear/embedding
            # weights and zero biases; every other module type is left as is.
            if isinstance(p, nn.Conv2d) or isinstance(p, nn.Linear):
                # init method used in paper
                nn.init.orthogonal_(p.weight, SQRT_2)
                p.bias.data.zero_()
            if isinstance(p, nn.Embedding):
                nn.init.orthogonal_(p.weight, SQRT_2)

        # manually init all to orthogonal dist

        if self.intrinsic_input in ("full", "crop_only", "crop_msg"):
            modules_to_init.append(self.rndtgt_extract_crop_representation)
            modules_to_init.append(self.rndprd_extract_crop_representation)
        if self.intrinsic_input in ("full", "glyph_only", "glyph_msg"):
            modules_to_init.append(self.rndtgt_extract_representation)
            modules_to_init.append(self.rndprd_extract_representation)
        if self.intrinsic_input in ("full",):
            modules_to_init.append(self.rndtgt_embed_features)
            modules_to_init.append(self.rndprd_embed_features)
        if self.intrinsic_input in ("full", "glyph_msg", "crop_msg"):
            if self.msg_model != "none":
                modules_to_init.append(self.rndtgt_msg_net)
                modules_to_init.append(self.rndprd_msg_net)

        # Built from the parameters registered *so far*: the message-novelty
        # modules created below are not part of this optimizer and get their
        # own one instead.
        self.optimizer = optim.RMSprop(
            self.parameters(),
            lr=0.0002,
            momentum=0,
            eps=0.000001,
            alpha=0.99,
        )

        self.message_novelty = message_novelty
        if self.message_novelty:
            assert self.msg_model != "none"
            # Separate message predictor networks
            self.rndtgt_msg_net_m = copy.deepcopy(self.msg_net)
            for param in self.rndtgt_msg_net_m.parameters():
                param.requires_grad = False
            self.rndprd_msg_net_m = copy.deepcopy(self.msg_net)
            self.rndtgt_fc_m = nn.Sequential(  # matching RND paper making this smaller
                nn.Linear(self.msg_hdim, self.h_dim)
            ).requires_grad_(False)
            self.rndprd_fc_m = nn.Sequential(  # matching RND paper making this bigger
                nn.Linear(self.msg_hdim, self.h_dim),
                nn.ELU(),
                nn.Linear(self.h_dim, self.h_dim),
                nn.ELU(),
                nn.Linear(self.h_dim, self.h_dim),
            )

            # Only the message predictor branch is trained by this optimizer.
            self.message_optimizer = optim.RMSprop(
                list(self.rndprd_msg_net_m.parameters())
                + list(self.rndprd_fc_m.parameters()),
                lr=0.0002,
                momentum=0,
                eps=0.000001,
                alpha=0.99,
            )

            modules_to_init.append(self.rndtgt_msg_net_m)
            modules_to_init.append(self.rndprd_msg_net_m)
            modules_to_init.append(self.rndtgt_fc_m)
            modules_to_init.append(self.rndprd_fc_m)

        # Target and predictor are initialised independently, so deep copies
        # of the same source module end up with different weights.
        for m in modules_to_init:
            for p in m.modules():
                init(p)

    def _rnd_branch(
        self, prefix, flat_batch, glyphs, crop, features, messages, messages_len
    ):
        """Run one RND branch and return its [T * B x h_dim] embedding.

        ``prefix`` is ``"rndtgt"`` (frozen target) or ``"rndprd"``
        (predictor); it selects which set of sub-modules is used. The streams
        that are concatenated depend on ``self.intrinsic_input``:

        * ``crop_only`` / ``crop_msg``: crop (+ message for ``crop_msg``);
        * ``glyph_only`` / ``glyph_msg``: full map (+ message for ``glyph_msg``);
        * anything else (``full``): features, crop, full map and, when a
          message model is configured, the message.
        """
        mode = self.intrinsic_input
        crop_mode = mode in ("crop_only", "crop_msg")
        glyph_mode = mode in ("glyph_only", "glyph_msg")
        full_mode = not (crop_mode or glyph_mode)

        embed = getattr(self, f"{prefix}_embed")
        reps = []

        if full_mode:
            reps.append(getattr(self, f"{prefix}_embed_features")(features))

        if crop_mode or full_mode:
            # Channels-first for the extractor.
            crop_emb = embed(crop).transpose(1, 3)
            crop_rep = getattr(self, f"{prefix}_extract_crop_representation")(
                crop_emb
            )
            reps.append(crop_rep.view(flat_batch, -1))

        if glyph_mode or full_mode:
            glyphs_emb = embed(glyphs).transpose(1, 3)
            glyphs_rep = getattr(self, f"{prefix}_extract_representation")(glyphs_emb)
            reps.append(glyphs_rep.view(flat_batch, -1))

        use_message = mode in ("crop_msg", "glyph_msg") or (
            full_mode and self.msg_model != "none"
        )
        if use_message:
            reps.append(getattr(self, f"{prefix}_msg_net")(messages, messages_len))

        return getattr(self, f"{prefix}_fc")(torch.cat(reps, dim=1))

    def forward(self, inputs, optimize=False):
        """Compute RND novelty and optionally take one predictor update step.

        Novelty is the L2 distance between the frozen target embedding and
        the predictor embedding of the same observation. The forward pass
        through the base policy trunk that used to run here has been removed:
        its result was never used.

        Args:
            inputs: observation dict; ``glyphs`` is always read, ``message``
                (and optionally ``message_len``) when a message model is
                configured.
            optimize: when true, gradients are tracked and one RMSprop step
                is taken on the predictor (and on the message predictor if
                message novelty is enabled); otherwise everything runs under
                ``torch.no_grad()``.

        Returns:
            ``(novelty, loss, message_novelty, message_loss)`` where
            ``novelty`` and ``message_novelty`` are detached. ``novelty`` is
            [T x B] and ``loss`` is ``0.01 * novelty.mean()``. Without message
            novelty the two message values are zero scalars.
        """
        ctx = nullcontext() if optimize else torch.no_grad()

        with ctx:
            T, B, *_ = inputs["glyphs"].shape
            flat_batch = T * B

            glyphs, features = self.prepare_input(inputs)

            # -- [B x 2] x,y coordinates
            coordinates = features[:, :2]

            features = features.view(flat_batch, -1).float()

            # -- [B x H' x W']
            crop = self.glyph_embedding.GlyphTuple(
                *[self.crop(g, coordinates) for g in glyphs]
            )

            # MESSAGING MODEL inputs (shared by target and predictor)
            messages = None
            messages_len = None
            if self.msg_model != "none":
                # [T x B x 256] -> [T * B x 256]
                messages = inputs["message"].long().view(flat_batch, -1)
                if "message_len" in inputs:
                    # Squeeze last dim.
                    messages_len = inputs["message_len"].long().view(flat_batch)

            # TARGET NETWORK (never receives gradients)
            with torch.no_grad():
                tgt_st = self._rnd_branch(
                    "rndtgt",
                    flat_batch,
                    glyphs,
                    crop,
                    features,
                    messages,
                    messages_len,
                )

            # PREDICTOR NETWORK
            prd_st = self._rnd_branch(
                "rndprd", flat_batch, glyphs, crop, features, messages, messages_len
            )

            assert tgt_st.size() == prd_st.size()

            tgt_st = tgt_st.view(T, B, -1)
            prd_st = prd_st.view(T, B, -1)

            novelty = torch.norm(tgt_st - prd_st, dim=-1, p=2)
            loss = 0.01 * novelty.mean()  # RND forward cost

            if self.message_novelty:
                with torch.no_grad():
                    tgt_st_m = self.rndtgt_msg_net_m(messages, messages_len)
                    tgt_st_m = self.rndtgt_fc_m(tgt_st_m)
                prd_st_m = self.rndprd_msg_net_m(messages, messages_len)
                prd_st_m = self.rndprd_fc_m(prd_st_m)

                assert tgt_st_m.size() == prd_st_m.size()
                tgt_st_m = tgt_st_m.view(T, B, -1)
                prd_st_m = prd_st_m.view(T, B, -1)

                message_novelty = torch.norm(tgt_st_m - prd_st_m, dim=-1, p=2)
                message_loss = 0.01 * message_novelty.mean()
            else:
                message_novelty = torch.tensor(0.0)
                message_loss = torch.tensor(0.0)

        if optimize:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if self.message_novelty:
                self.message_optimizer.zero_grad()
                message_loss.backward()
                self.message_optimizer.step()

        novelty = novelty.detach()
        message_novelty = message_novelty.detach()

        return novelty, loss, message_novelty, message_loss
