import torch
import torch.nn as nn
from torch.nn import functional as F
import random

import math
import numpy as np

'''
The general idea of this model is the following:
1. Have a causal decision transformer output an action at each step.
2. Feed the top layers to a decision transformer.
'''

class Attention(nn.Module):
    """
    Multi-head scaled dot-product attention followed by an output projection.

    Queries, keys and values are passed in separately, so the same module serves
    both self-attention (``attn(x, x, x)``) and cross-attention
    (``attn(x, src, src)``). The computation is spelled out by hand instead of
    using ``torch.nn.MultiheadAttention`` so every step is visible.

    Args:
        n_head: number of attention heads; must divide ``n_embd``.
        n_embd: embedding width of the inputs and of the output.
        attn_pdrop: dropout probability applied to the attention weights.
        resid_pdrop: dropout probability applied to the projected output.
        block_size: largest sequence length the causal mask supports.
        causal: if True, query position i may only attend to key positions <= i.
    """

    def __init__(self, n_head, n_embd, attn_pdrop, resid_pdrop, block_size, causal=False):
        super().__init__()
        assert n_embd % n_head == 0
        # Input projections for keys, queries and values, covering all heads at once.
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        # Dropout on the attention weights and on the block output.
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)
        # Merges the concatenated per-head outputs back to n_embd.
        self.proj = nn.Linear(n_embd, n_embd)
        self.causal = causal
        if causal:
            # Lower-triangular ones: entry (i, j) is 1 exactly when j <= i.
            lower = torch.tril(torch.ones(block_size, block_size))
            self.register_buffer("mask", lower.view(1, 1, block_size, block_size))
        else:
            self.mask = None
        self.n_head = n_head

    def forward(self, q, k, v, mask=None):
        """
        Attend from ``q`` over ``k``/``v``.

        Args:
            q: query input of shape (B, T, C).
            k: key input of shape (B, S, C).
            v: value input of shape (B, S, C).
            mask: optional extra mask; a head axis is inserted at dim 1 before it
                is applied, so (B, T, S) is the expected layout. Entries equal to
                0 block attention.

        Returns:
            Tensor of shape (B, T, C).
        """
        batch, tgt_len, width = q.size()
        _, src_len, width = k.size()
        head_dim = width // self.n_head

        # (B, L, C) -> (B, nh, L, hs)
        queries = self.query(q).view(batch, tgt_len, self.n_head, head_dim).transpose(1, 2)
        keys = self.key(k).view(batch, src_len, self.n_head, head_dim).transpose(1, 2)
        values = self.value(v).view(batch, src_len, self.n_head, head_dim).transpose(1, 2)

        # Scaled scores: (B, nh, T, hs) x (B, nh, hs, S) -> (B, nh, T, S)
        scale = 1.0 / math.sqrt(keys.size(-1))
        scores = (queries @ keys.transpose(-2, -1)) * scale
        if self.mask is not None:
            scores = scores.masked_fill(self.mask[:, :, :tgt_len, :src_len] == 0, float('-inf'))
        if mask is not None:
            # The caller's mask has no head axis; add one so it broadcasts over heads.
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # (B, nh, T, S) x (B, nh, S, hs) -> (B, nh, T, hs) -> (B, T, C)
        merged = (weights @ values).transpose(1, 2).contiguous().view(batch, tgt_len, width)
        return self.resid_drop(self.proj(merged))

class EncoderBlock(nn.Module):
    """
    Pre-LayerNorm transformer block: causal self-attention, then a GELU MLP.

    Computes ``h = x + attn(ln1(x))`` and returns ``h + mlp(ln2(h))``.
    """

    def __init__(self, n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio):
        super().__init__()
        hidden = mlp_ratio * n_embd
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        # Causal, so each position sees only itself and earlier positions.
        self.attn = Attention(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, causal=True)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x, mask=None):
        """Apply the block to ``x``; ``mask`` is forwarded to the self-attention."""
        normed = self.ln1(x)
        h = x + self.attn(normed, normed, normed, mask=mask)
        return h + self.mlp(self.ln2(h))

class DecoderBlock(nn.Module):
    """
    Pre-LayerNorm decoder block: causal self-attention over the target sequence,
    cross-attention from the target onto ``src``, then a GELU MLP.
    """

    def __init__(self, n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio):
        super().__init__()
        hidden = mlp_ratio * n_embd
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)
        self.self_attn = Attention(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, causal=True)
        self.cross_attn = Attention(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, causal=False)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, trg, src, cross_attn_mask=None):
        """
        Args:
            trg: target-sequence hidden states.
            src: sequence attended to by the cross-attention (keys and values).
            cross_attn_mask: optional mask for the cross-attention only; the
                self-attention relies solely on its built-in causal mask.
        """
        normed = self.ln1(trg)
        h = trg + self.self_attn(normed, normed, normed, mask=None)
        h = h + self.cross_attn(self.ln2(h), src, src, mask=cross_attn_mask)
        return h + self.mlp(self.ln3(h))

class DT(nn.Module):
    """
    Causal transformer policy over the sequence ``[mission tokens, image tokens]``.

    Each frame of ``obs['image']`` is encoded by a small CNN and max-pooled over its
    spatial positions into one ``n_embd`` token. Mission token ids are embedded with
    ``tok_emb`` and prepended, learned position embeddings are added, and the result
    runs through ``n_layer`` causal ``EncoderBlock``s. ``action_head`` maps the
    outputs at the image positions to action logits.
    """

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1,
                 block_size=384, mlp_ratio=2, n_layer=4,
                 vocab_size=40, embd_pdrop=0.1):
        super().__init__()
        # Per-frame image encoder; 3x3 kernels with padding=1 keep the spatial size.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, n_embd, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(n_embd),
            nn.ReLU(),
        )
        self.n_embd = n_embd

        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        # One learned position embedding per position, up to block_size.
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        self.blocks = nn.Sequential(*[EncoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio)
                                      for _ in range(n_layer)])
        self.action_head = nn.Linear(n_embd, num_actions)

        self.block_size = block_size
        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _encode_images(self, images):
        """Encode (B, S, c, h, w) frames into (B, S, n_embd) tokens (CNN + spatial max-pool)."""
        B, S, c, h, w = images.shape
        feats = self.cnn(images.view(B * S, c, h, w).float())
        pooled = torch.max(feats.view(B * S, self.n_embd, h * w), dim=-1)[0]
        return pooled.view(B, S, self.n_embd)

    def _position_embeddings(self, start, length):
        """
        Return the learned position embeddings for positions [start, start + length).

        Raises:
            ValueError: if the span extends past ``block_size``. Without this check a
                too-long sequence fails later with an opaque broadcasting error, and the
                attention causal masks are only ``block_size`` long as well.
        """
        end = start + length
        if end > self.block_size:
            raise ValueError(
                f"Sequence needs positions up to {end}, but block_size is {self.block_size}")
        return self.pos_emb[:, start:end, :]

    def forward(self, obs, labels=None, skip_size=-1):
        """
        Run the transformer over an observation sequence.

        Args:
            obs: dict with ``'image'`` of shape (B, S, c, h, w) and, unless
                ``skip_size > 0``, ``'mission'`` token ids of shape (B, M).
            labels: unused; accepted for call compatibility with the other models.
            skip_size: if > 0, no mission is used and the image tokens take position
                embeddings starting at index ``skip_size`` instead of after a mission.

        Returns:
            ``(actions, latents)``. ``actions`` is (B, S, num_actions) logits, or None
            when ``skip_size > 0``. ``latents`` is the final hidden state: (B, M + S,
            n_embd) with a mission, (B, S, n_embd) without.

        Raises:
            ValueError: if the sequence does not fit in ``block_size`` positions.
        """
        if skip_size > 0:
            assert 'mission' not in obs, "Using a skip and the mission was provided!"

        img = self._encode_images(obs['image'])

        if skip_size > 0:
            x = self.drop(img + self._position_embeddings(skip_size, img.shape[1]))
            x = self.blocks(x)
            # Without mission tokens there are no action predictions to return.
            return None, x

        mission_len = obs['mission'].shape[1]
        x = torch.cat((self.tok_emb(obs['mission']), img), dim=1)
        x = self.drop(x + self._position_embeddings(0, x.shape[1]))
        x = self.blocks(x)
        # Action logits come from the image positions only.
        actions = self.action_head(x[:, mission_len:, :])
        return actions, x

    @torch.no_grad()
    def predict(self, obs, deterministic=True, history=None):
        """
        Return the greedy action for the most recent timestep of ``history``.

        Runs without autograd: only a Python int is returned, so building a graph
        would just waste memory and time.

        Args:
            obs: current observation; only ``obs['mission']`` is used.
            deterministic: currently ignored; the action is always the argmax.
            history: dict whose ``'image'`` holds the frame sequence (batch size 1).

        Returns:
            The argmax action index as a Python int.
        """
        assert history is not None
        assert history['image'].shape[0] == 1, "Only a batch size of 1 is currently supported"
        # Send the mission once, alongside the whole image history.
        combined_obs = {'image': history['image'], 'mission': obs['mission']}
        action_logits, _ = self(combined_obs, skip_size=-1)
        # Only the last timestep's logits matter.
        return torch.argmax(action_logits[0, -1, :]).item()

class Seq2SeqDT(nn.Module):
    """
    ``DT`` policy with two auxiliary heads on its latents.

    * A transformer decoder cross-attends over the DT latents and, with teacher
      forcing, predicts ``labels``; its cross-entropy is returned as ``lang_aux``.
    * ``unsup_head`` (followed by ``unsup_mlp`` on non-target passes) projects the
      image-position latents for an unsupervised objective computed by the caller.
    """

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1,
                 block_size=384, mlp_ratio=2, n_enc_layer=4, n_dec_layer=1, embd_pdrop=0.1, vocab_size=40,
                 use_mask=True, unsup_dim=128, unsup_proj='mat'):
        super().__init__()
        self.dt = DT(num_actions, n_head=n_head, n_embd=n_embd, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
                     block_size=block_size, mlp_ratio=mlp_ratio, n_layer=n_enc_layer,
                     vocab_size=vocab_size, embd_pdrop=embd_pdrop)

        # Position embeddings for the decoder's target (language) sequence.
        self.pos_emb_trg = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        self.decoder = nn.ModuleList([DecoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio)
                                      for _ in range(n_dec_layer)])
        self.use_mask = use_mask
        self.vocab_head = nn.Linear(n_embd, vocab_size, bias=False)

        # Initialize the decoder with the same scheme as the DT.
        self.decoder.apply(self.dt._init_weights)

        # Heads for the unsupervised objective.
        self.unsup_head = nn.Linear(n_embd, unsup_dim)
        if unsup_proj == 'mat':
            self._unsup_proj = nn.Parameter(torch.rand(unsup_dim, unsup_dim))
        else:
            self._unsup_proj = None
        self.unsup_mlp = nn.Sequential(nn.Linear(unsup_dim, 2*unsup_dim), nn.ReLU(), nn.Linear(2*unsup_dim, unsup_dim))

    def _language_loss(self, obs, labels, latents):
        """Teacher-forced decoder cross-entropy: predict labels[:, 1:] from labels[:, :-1]."""
        tgt = self.dt.tok_emb(labels[:, :-1])
        tgt = self.drop(tgt + self.pos_emb_trg[:, :tgt.shape[1], :])
        B, T, _ = tgt.shape
        if self.use_mask:
            # Every target position may attend to all mission latents; obs['mask']
            # decides which image latents each target position can see. The mission
            # part is built with obs['mask']'s dtype and device so the concatenation
            # does not rely on implicit type promotion (zeros still mean "blocked").
            image_mask = obs['mask'][:, :-1]
            mission_mask = torch.ones(B, T, obs['mission'].shape[1],
                                      dtype=image_mask.dtype, device=image_mask.device)
            mask = torch.cat((mission_mask, image_mask), dim=2)
        else:
            mask = None
        for layer in self.decoder:
            tgt = layer(tgt, latents, cross_attn_mask=mask)
        lang_logits = self.vocab_head(tgt)
        return F.cross_entropy(lang_logits.reshape(-1, lang_logits.size(-1)),
                               labels[:, 1:].reshape(-1), ignore_index=0)

    def forward(self, obs, labels=None, is_target=False):
        """
        Args:
            obs: dict with ``'image'`` (B, S, c, h, w) and ``'mission'`` (B, M) token
                ids; also ``'next_image'`` when ``is_target`` is set, and ``'mask'``
                (expected (B, L, S)) when ``labels`` are given and ``use_mask`` is set.
            labels: optional (B, L) target token ids. Position 0 is only a decoder
                input and is never predicted; id 0 is ignored by the loss.
            is_target: if True, encode ``obs['next_image']`` in place of
                ``obs['image']`` and skip ``unsup_mlp``.

        Returns:
            ``(actions, lang_aux, unsup_logits)``: DT action logits, the language
            cross-entropy (or None without labels), and image-position features of
            width ``unsup_dim``.
        """
        if is_target:
            # Encode the next frames in place of the current ones.
            obs = {**obs, 'image': obs['next_image']}
        actions, latents = self.dt(obs)

        lang_aux = None
        if labels is not None:
            lang_aux = self._language_loss(obs, labels, latents)

        # Drop the mission positions; the unsupervised head only sees image latents.
        unsup_logits = self.unsup_head(latents[:, obs['mission'].shape[1]:, :])
        if not is_target:
            unsup_logits = self.unsup_mlp(unsup_logits)  # online branch gets the extra MLP
        return actions, lang_aux, unsup_logits

    @property
    def unsup_proj(self):
        return self._unsup_proj

    def predict(self, obs, deterministic=True, history=None, **kwargs):
        """Delegate action selection to the underlying DT; extra kwargs are ignored."""
        return self.dt.predict(obs, deterministic=deterministic, history=history)


class GoalPretrainDT(nn.Module):
    """
    ``DT`` policy that can also be trained to reconstruct the mission from images.

    In language mode the mission is removed from the encoder input. A decoder that
    cross-attends over the image latents then predicts the mission tokens with
    teacher forcing, starting from a learned start embedding.
    """

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1,
                 block_size=384, mlp_ratio=2, n_enc_layer=4, n_dec_layer=1, embd_pdrop=0.1, vocab_size=40):
        super().__init__()
        self.dt = DT(num_actions, n_head=n_head, n_embd=n_embd, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
                     block_size=block_size, mlp_ratio=mlp_ratio, n_layer=n_enc_layer,
                     vocab_size=vocab_size, embd_pdrop=embd_pdrop)

        self.drop = nn.Dropout(embd_pdrop)
        self.decoder = nn.ModuleList([DecoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio)
                                      for _ in range(n_dec_layer)])
        self.vocab_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Learned embedding used as the first decoder input.
        self.start_token = nn.Parameter(torch.zeros(1, 1, n_embd))
        # Initialize the decoder with the DT's scheme, as the other seq2seq models
        # in this file do, rather than leaving PyTorch's default Linear init.
        self.decoder.apply(self.dt._init_weights)

    def forward(self, obs, labels=None):
        """
        Args:
            obs: dict with ``'image'`` (B, S, c, h, w) and ``'mission'`` (B, M) token ids.
            labels: only a mode switch. None selects behavior cloning; anything else
                selects mission reconstruction, whose targets are ``obs['mission']``
                itself (the value of ``labels`` is not used).

        Returns:
            ``(actions, None)`` in BC mode, ``(None, aux)`` in language mode, where
            ``aux`` is the cross-entropy over the mission tokens (id 0 ignored).
        """
        if labels is None:
            actions, _ = self.dt(obs)
            return actions, None

        mission = obs['mission']
        # Decoder input [start, m_0, ..., m_{M-2}] predicts [m_0, ..., m_{M-1}].
        start = self.start_token.expand(mission.shape[0], -1, -1)  # (B, 1, C)
        tgt = torch.cat((start, self.dt.tok_emb(mission[:, :-1])), dim=1)
        text_len = tgt.shape[1]
        # The text reuses the DT's first text_len position embeddings, and the image
        # tokens are placed right after them via skip_size.
        tgt = self.drop(tgt + self.dt.pos_emb[:, :text_len, :])
        obs_without_mission = {k: v for k, v in obs.items() if k != 'mission'}
        _, latents = self.dt(obs_without_mission, skip_size=text_len)
        for layer in self.decoder:
            tgt = layer(tgt, latents, cross_attn_mask=None)
        lang_logits = self.vocab_head(tgt)
        aux = F.cross_entropy(lang_logits.reshape(-1, lang_logits.size(-1)),
                              mission.reshape(-1).long(), ignore_index=0)
        return None, aux

    def predict(self, obs, deterministic=True, history=None):
        """Delegate action selection to the underlying DT."""
        return self.dt.predict(obs, deterministic=deterministic, history=history)

class SeqGPTMask(nn.Module):
    """
    Single causal transformer over ``[mission, images]`` for behavior cloning, with
    language tokens appended as ``[mission, images, language]`` when training on labels.

    In the language case, ``obs['mask']`` additionally restricts which image tokens
    each language token may attend to.
    """

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1,
                 block_size=768, mlp_ratio=2, n_layer=4, embd_pdrop=0.1, vocab_size=40):
        super().__init__()
        # Per-frame image encoder; 3x3 kernels with padding=1 keep the spatial size.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, n_embd, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(n_embd),
            nn.ReLU(),
        )
        self.n_embd = n_embd

        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        # A ModuleList (not Sequential) so a mask can be passed to every layer.
        self.blocks = nn.ModuleList([EncoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio)
                                     for _ in range(n_layer)])

        self.action_head = nn.Linear(n_embd, num_actions)
        self.block_size = block_size
        self.vocab_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, obs, labels=None):
        """
        Args:
            obs: dict with ``'image'`` (B, S, c, h, w) and ``'mission'`` (B, M) token
                ids. When ``labels`` are given, ``obs['mask'][:, :-1, :]`` must be
                (B, T, S).
            labels: optional (B, T + 1) token ids; ``labels[:, :-1]`` is fed in and
                ``labels[:, 1:]`` is predicted.

        Returns:
            ``(actions, aux)``: action logits at the image positions, and the
            language cross-entropy (id 0 ignored), or None when ``labels`` is None.
        """
        B, S, c, h, w = obs['image'].shape
        img = self.cnn(obs['image'].view(B*S, c, h, w).float())
        img = torch.max(img.view(B*S, self.n_embd, h*w), dim=-1)[0]
        img = img.view(B, S, self.n_embd)  # one token per frame
        mission = self.tok_emb(obs['mission'])
        M = mission.shape[1]
        x = torch.cat((mission, img), dim=1)

        if labels is None:
            x = self.drop(x + self.pos_emb[:, :x.shape[1], :])
            for layer in self.blocks:
                x = layer(x)
            return self.action_head(x[:, M:, :]), None

        # Sequence becomes [mission (M), images (S), language (T)]. On top of each
        # layer's causal mask, the language rows get obs['mask'] over the image
        # columns; every other entry stays True (visible).
        tgt = self.tok_emb(labels[:, :-1])
        T = tgt.shape[1]
        x = torch.cat((x, tgt), dim=1)
        mask = torch.ones(B, T+M+S, T+M+S, device=obs['image'].device, dtype=torch.bool)
        mask[:, M+S:M+S+T, M:M+S] = obs['mask'][:, :-1, :]
        x = self.drop(x + self.pos_emb[:, :x.shape[1], :])
        for layer in self.blocks:
            x = layer(x, mask=mask)
        actions = self.action_head(x[:, M:M+S, :])
        lang_logits = self.vocab_head(x[:, M+S:, :])
        aux = F.cross_entropy(lang_logits.reshape(-1, lang_logits.size(-1)), labels[:, 1:].reshape(-1), ignore_index=0)
        return actions, aux

    @torch.no_grad()
    def predict(self, obs, deterministic=True, history=None):
        """
        Return the greedy action for the most recent timestep of ``history``.

        Runs without autograd because only a Python int is returned.
        ``deterministic`` is currently ignored (always argmax).
        """
        assert history is not None
        assert history['image'].shape[0] == 1, "Only a batch size of 1 is currently supported"
        # Send the mission once, alongside the whole image history.
        combined_obs = {'image': history['image'], 'mission': obs['mission']}
        action_logits, _ = self(combined_obs)
        # Only the last timestep's logits matter.
        return torch.argmax(action_logits[0, -1, :]).item()

class SeqGPTRand(nn.Module):
    """
    Single causal transformer over ``[mission, images]`` for behavior cloning.

    When trained on labels, the language tokens are appended after the mission and
    a randomly sized prefix of the images, and only the language loss is returned.
    """

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1,
                 block_size=768, mlp_ratio=2, n_layer=4, embd_pdrop=0.1, vocab_size=40):
        super().__init__()
        # Per-frame image encoder; 3x3 kernels with padding=1 keep the spatial size.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, n_embd, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(n_embd),
            nn.ReLU(),
        )
        self.n_embd = n_embd

        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        self.blocks = nn.ModuleList([EncoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio)
                                     for _ in range(n_layer)])

        self.action_head = nn.Linear(n_embd, num_actions)
        self.block_size = block_size
        self.vocab_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, obs, labels=None):
        """
        Args:
            obs: dict with ``'image'`` (B, S, c, h, w) and ``'mission'`` (B, M) token ids.
            labels: optional (B, T + 1) token ids; ``labels[:, :-1]`` is fed in and
                ``labels[:, 1:]`` is predicted.

        Returns:
            ``(actions, None)`` without labels; ``(None, aux)`` with labels, where
            ``aux`` is the language cross-entropy (id 0 ignored).
        """
        B, S, c, h, w = obs['image'].shape
        img = self.cnn(obs['image'].view(B*S, c, h, w).float())
        img = torch.max(img.view(B*S, self.n_embd, h*w), dim=-1)[0]
        img = img.view(B, S, self.n_embd)  # one token per frame
        mission = self.tok_emb(obs['mission'])
        M = mission.shape[1]
        x = torch.cat((mission, img), dim=1)

        if labels is None:
            x = self.drop(x + self.pos_emb[:, :x.shape[1], :])
            for layer in self.blocks:
                x = layer(x)
            return self.action_head(x[:, M:, :]), None

        # Keep the whole mission plus a random prefix of the images (the same crop
        # for the whole batch), then append the language tokens. random.random() is
        # in [0, 1), so crop_size is in [M, M + S - 1]: between 0 and S - 1 images
        # are kept.
        # NOTE(review): the full image sequence is never kept — confirm intended.
        crop_size = M + int(S * random.random())
        x = torch.cat((x[:, :crop_size], self.tok_emb(labels[:, :-1])), dim=1)
        x = self.drop(x + self.pos_emb[:, :x.shape[1], :])
        for layer in self.blocks:
            x = layer(x)
        lang_logits = self.vocab_head(x[:, crop_size:])
        aux = F.cross_entropy(lang_logits.reshape(-1, lang_logits.size(-1)), labels[:, 1:].reshape(-1), ignore_index=0)
        return None, aux

    @torch.no_grad()
    def predict(self, obs, deterministic=True, history=None):
        """
        Return the greedy action for the most recent timestep of ``history``.

        Runs without autograd because only a Python int is returned.
        ``deterministic`` is currently ignored (always argmax).
        """
        assert history is not None
        assert history['image'].shape[0] == 1, "Only a batch size of 1 is currently supported"
        # Send the mission once, alongside the whole image history.
        combined_obs = {'image': history['image'], 'mission': obs['mission']}
        action_logits, _ = self(combined_obs)
        # Only the last timestep's logits matter.
        return torch.argmax(action_logits[0, -1, :]).item()

class InstructionDT(nn.Module):
    """
    Causal transformer policy conditioned on the mission and a subgoal token sequence.

    The input sequence is ``[mission tokens, subgoal tokens, image tokens]``; action
    logits are read from the image positions.
    """

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1,
                 block_size=512, mlp_ratio=2, n_layer=4,
                 vocab_size=40, embd_pdrop=0.1):
        super().__init__()
        # Per-frame image encoder; 3x3 kernels with padding=1 keep the spatial size.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, n_embd, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(n_embd),
            nn.ReLU(),
        )
        self.n_embd = n_embd

        # Shared embedding table for mission and subgoal tokens.
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        self.blocks = nn.Sequential(*[EncoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio)
                                      for _ in range(n_layer)])
        self.action_head = nn.Linear(n_embd, num_actions)

        self.block_size = block_size
        self.apply(self._init_weights)

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) weights, zero biases, identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, obs, labels):
        """
        Args:
            obs: dict with ``'image'`` (B, S, c, h, w) and ``'mission'`` (B, M) token ids.
            labels: (B, L) subgoal token ids placed between the mission and the images.

        Returns:
            ``(actions, latents)``: (B, S, num_actions) logits at the image positions
            and the final hidden state of shape (B, M + L + S, n_embd).
        """
        B, S, c, h, w = obs['image'].shape
        img = self.cnn(obs['image'].view(B*S, c, h, w).float())
        img = torch.max(img.view(B*S, self.n_embd, h*w), dim=-1)[0]
        img = img.view(B, S, self.n_embd)  # one token per frame

        mission = self.tok_emb(obs['mission'])
        subgoals = self.tok_emb(labels)

        x = torch.cat((mission, subgoals, img), dim=1)
        x = self.drop(x + self.pos_emb[:, :x.shape[1], :])
        x = self.blocks(x)
        # Skip the mission and subgoal positions; actions come from image positions.
        actions = self.action_head(x[:, obs['mission'].shape[1] + labels.shape[1]:, :])
        return actions, x

    @torch.no_grad()
    def predict(self, obs, labels, deterministic=True, history=None):
        """
        Return the greedy action for the most recent timestep of ``history``.

        Runs without autograd because only a Python int is returned.
        ``deterministic`` is currently ignored (always argmax).
        """
        assert history is not None
        assert history['image'].shape[0] == 1, "Only a batch size of 1 is currently supported"
        # Send the mission once, alongside the whole image history.
        combined_obs = {'image': history['image'], 'mission': obs['mission']}
        action_logits, _ = self(combined_obs, labels)
        # Only the last timestep's logits matter.
        return torch.argmax(action_logits[0, -1, :]).item()

class HierarchicalSeq2SeqDT(nn.Module):
    """
    Two-level model: a language planner plus an instruction-conditioned policy.

    ``lang_enc`` (a ``DT``) encodes mission and images, and a decoder over its latents
    produces a language plan. ``inst_dt`` (an ``InstructionDT``) picks actions
    conditioned on the mission, the plan (padded/truncated to ``MAX_PLAN_TOKENS``)
    and the images. ``unsup_dim`` and ``unsup_proj`` are accepted but unused.
    """

    # Fixed plan length fed to inst_dt; inst_dt's block_size is enlarged by this much.
    MAX_PLAN_TOKENS = 128

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1,
                 block_size=384, mlp_ratio=2, n_enc_layer=4, n_dec_layer=1, embd_pdrop=0.1, vocab_size=40,
                 unsup_dim=128, unsup_proj='mat'):
        super().__init__()
        self.lang_enc = DT(num_actions, n_head=n_head, n_embd=n_embd, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
                           block_size=block_size, mlp_ratio=mlp_ratio, n_layer=n_enc_layer,
                           vocab_size=vocab_size, embd_pdrop=embd_pdrop)

        self.inst_dt = InstructionDT(num_actions, n_head=n_head, n_embd=n_embd, attn_pdrop=attn_pdrop,
                                     resid_pdrop=resid_pdrop, block_size=block_size + self.MAX_PLAN_TOKENS,
                                     mlp_ratio=mlp_ratio, n_layer=n_enc_layer,
                                     vocab_size=vocab_size, embd_pdrop=embd_pdrop)

        # Position embeddings for the decoder's target (language) sequence.
        self.pos_emb_trg = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        self.decoder = nn.ModuleList([DecoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio)
                                      for _ in range(n_dec_layer)])
        self.vocab_head = nn.Linear(n_embd, vocab_size, bias=False)
        # Initialize the decoder with the same scheme as the DT.
        self.decoder.apply(self.lang_enc._init_weights)

    @property
    def unsup_proj(self):
        # Unsupervised losses are not implemented for the hierarchical models.
        return None

    def _pad_plan(self, plan, dtype):
        """Truncate, or right-pad with id 0, a (B, L) token tensor to MAX_PLAN_TOKENS columns."""
        n = self.MAX_PLAN_TOKENS
        if plan.shape[1] > n:
            return plan[:, :n]
        padded = torch.zeros(plan.shape[0], n, dtype=dtype, device=plan.device)
        padded[:, :plan.shape[1]] = plan
        return padded

    def forward(self, obs, labels=None, is_target=False):
        """
        Train both levels at once (they can share one optimizer).

        Args:
            obs: dict with ``'image'`` (B, S, c, h, w), ``'mission'`` (B, M) token ids
                and ``'mask'`` (expected (B, L, S)) selecting which image latents each
                target position may attend to.
            labels: (B, L) plan token ids, required. Used for teacher forcing in
                the decoder, and (padded/truncated) as inst_dt's plan input.
            is_target: must be False.

        Returns:
            ``(actions, lang_aux, None)``: inst_dt action logits and the decoder's
            cross-entropy (id 0 ignored).
        """
        assert not is_target, "Lang losses not implemented"
        assert labels is not None, "Training HRL model requires language labels"
        _, language_latents = self.lang_enc(obs)
        tgt = self.lang_enc.tok_emb(labels[:, :-1])  # inputs exclude the last token
        tgt = self.drop(tgt + self.pos_emb_trg[:, :tgt.shape[1], :])
        B, T, _ = tgt.shape
        # Every target position may attend to all mission latents; obs['mask'] limits
        # the image latents. The mission part uses obs['mask']'s dtype and device so
        # the concatenation does not rely on implicit type promotion.
        image_mask = obs['mask'][:, :-1]
        mission_mask = torch.ones(B, T, obs['mission'].shape[1],
                                  dtype=image_mask.dtype, device=image_mask.device)
        mask = torch.cat((mission_mask, image_mask), dim=2)
        for layer in self.decoder:
            tgt = layer(tgt, language_latents, cross_attn_mask=mask)
        lang_logits = self.vocab_head(tgt)
        lang_aux = F.cross_entropy(lang_logits.reshape(-1, lang_logits.size(-1)),
                                   labels[:, 1:].reshape(-1), ignore_index=0)

        # The policy sees the ground-truth plan, fixed to MAX_PLAN_TOKENS tokens.
        padded_labels = self._pad_plan(labels, labels.dtype)
        actions, _ = self.inst_dt(obs, padded_labels)
        return actions, lang_aux, None

    @torch.no_grad()
    def predict(self, obs, deterministic=True, history=None):
        """
        Greedily decode a plan, then return the greedy action for the last timestep.

        Runs without autograd: the autoregressive loop would otherwise grow an
        autograd graph at every step, and only a Python int is returned.
        ``deterministic`` is currently ignored (always argmax / top-1).
        """
        assert history is not None
        assert history['image'].shape[0] == 1, "Only a batch size of 1 is currently supported"
        # Send the mission once, alongside the whole image history.
        combined_obs = {'image': history['image'], 'mission': obs['mission']}

        from lang_hrl.envs.babyai_wrappers import WORD_TO_IDX
        # Plans are seeded with the END_MISSION token.
        plan = WORD_TO_IDX['END_MISSION'] * torch.ones(1, 1, dtype=torch.long, device=obs['mission'].device)
        _, language_latents = self.lang_enc(combined_obs)
        for _ in range(1, self.MAX_PLAN_TOKENS):
            tgt = self.lang_enc.tok_emb(plan)
            tgt = self.drop(tgt + self.pos_emb_trg[:, :tgt.shape[1], :])
            for layer in self.decoder:
                tgt = layer(tgt, language_latents, cross_attn_mask=None)  # no mask at inference
            lang_logits = self.vocab_head(tgt)[:, -1, :]  # logits for the next token
            _, ix = torch.topk(lang_logits, k=1, dim=-1)
            plan = torch.cat((plan, ix), dim=1)
            if ix.item() == WORD_TO_IDX['END']:
                break

        padded_plan = self._pad_plan(plan, obs['mission'].dtype)
        action_logits, _ = self.inst_dt(combined_obs, padded_plan)
        # Only the last timestep's logits matter.
        return torch.argmax(action_logits[0, -1, :]).item()

class HierarchicalSGSeq2SeqDT(nn.Module):
    """Two-level policy: a language-plan decoder feeding an instruction-conditioned DT.

    ``lang_enc`` (a ``DT``) encodes the observation into ``language_latents``. A stack of
    ``DecoderBlock`` layers plus ``vocab_head`` decodes plan tokens from those latents.
    ``inst_dt`` (an ``InstructionDT``) maps the observation and a fixed-length plan of
    ``MAX_PLAN_TOKENS`` tokens to action logits.
    """

    # Length of the plan fed to ``inst_dt``. It is shared by ``forward`` (ground-truth
    # labels) and ``predict`` (decoded plan) so training and inference agree.
    MAX_PLAN_TOKENS = 10

    def __init__(self, num_actions, n_head=2, n_embd=128, attn_pdrop=0.1, resid_pdrop=0.1, 
                       block_size=384, mlp_ratio=2, n_enc_layer=4, n_dec_layer=1, embd_pdrop=0.1, vocab_size=40, unsup_dim=128, unsup_proj='mat'):
        super().__init__()
        # Observation encoder whose second output is used as the decoder's cross-attention memory.
        self.lang_enc = DT(num_actions, n_head=n_head, n_embd=n_embd, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
                    block_size=block_size, mlp_ratio=mlp_ratio, n_layer=n_enc_layer, 
                    vocab_size=vocab_size, embd_pdrop=embd_pdrop)

        # Low-level policy conditioned on the plan. The extra 128 positions presumably make
        # room for plan tokens in its sequence; confirm against InstructionDT.
        self.inst_dt = InstructionDT(num_actions, n_head=n_head, n_embd=n_embd, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop,
                    block_size=block_size+128, mlp_ratio=mlp_ratio, n_layer=n_enc_layer, 
                    vocab_size=vocab_size, embd_pdrop=embd_pdrop)

        # Learned positional embedding for plan tokens. It is zero-initialised and is not
        # covered by the ``_init_weights`` call below.
        self.pos_emb_trg = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)
        self.decoder = nn.ModuleList([DecoderBlock(n_head, n_embd, attn_pdrop, resid_pdrop, block_size, mlp_ratio) 
                                        for _ in range(n_dec_layer)])
        self.vocab_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.decoder.apply(self.lang_enc._init_weights) # apply the proper initialization for the transformer.

    @property
    def unsup_proj(self):
        # Did not implement unsupervised losses for the hierarchical models.
        return None

    def _decode(self, tokens, language_latents):
        """Run the plan decoder over ``tokens`` and return per-position vocabulary logits.

        Args:
            tokens: integer plan tokens indexed as (batch, position).
            language_latents: cross-attention memory produced by ``lang_enc``.

        Returns:
            Logits from ``vocab_head``, one vector per input position.
        """
        tgt = self.lang_enc.tok_emb(tokens)
        tgt = self.drop(tgt + self.pos_emb_trg[:, :tgt.shape[1], :])
        for layer in self.decoder:
            # No cross-attention mask: every latent is visible to every plan position.
            tgt = layer(tgt, language_latents, cross_attn_mask=None)
        return self.vocab_head(tgt)

    @staticmethod
    def _pad_to_length(tokens, length, dtype):
        """Truncate or zero-pad dim 1 of ``tokens`` to exactly ``length``.

        Args:
            tokens: 2-D token tensor indexed as (batch, position).
            length: target size of dim 1.
            dtype: dtype of the returned tensor.

        Returns:
            A tensor of shape ``(tokens.shape[0], length)`` on ``tokens.device``. Missing
            positions are filled with 0, the index ignored by the plan loss.
        """
        if tokens.shape[1] >= length:
            return tokens[:, :length].to(dtype)
        padded = torch.zeros(tokens.shape[0], length, dtype=dtype, device=tokens.device)
        padded[:, :tokens.shape[1]] = tokens
        return padded

    def forward(self, obs, labels=None, is_target=False):
        """Compute action logits and the teacher-forced plan-decoding loss.

        Args:
            obs: observation passed unchanged to ``lang_enc`` and ``inst_dt``.
            labels: 2-D integer plan tokens (batch, position). Required. Index 0 is treated
                as padding: it is ignored by the loss and used as the pad value.
            is_target: must be False, since language losses are not implemented for it.

        Returns:
            ``(actions, lang_aux, None)``. ``actions`` is the first output of ``inst_dt`` on
            the ground-truth plan padded or truncated to ``MAX_PLAN_TOKENS``. ``lang_aux`` is
            the cross-entropy of predicting ``labels[:, 1:]`` from ``labels[:, :-1]``.

        Raises:
            AssertionError: if ``is_target`` is truthy or ``labels`` is None.
        """
        assert not is_target, "Lang losses not implemented"
        assert labels is not None, "Training HRL model requires language labels"
        # Both models can be trained simultaneously with the same optimizer.
        _, language_latents = self.lang_enc(obs)
        # Teacher forcing: the decoder sees every token except the last and predicts the next one.
        lang_logits = self._decode(labels[:, :-1], language_latents)
        lang_aux = F.cross_entropy(
            lang_logits.reshape(-1, lang_logits.size(-1)),
            labels[:, 1:].reshape(-1),
            ignore_index=0,
        )

        # The low-level policy is trained on the ground-truth plan at the fixed plan length.
        padded_labels = self._pad_to_length(labels, self.MAX_PLAN_TOKENS, labels.dtype)
        actions, _ = self.inst_dt(obs, padded_labels)
        return actions, lang_aux, None

    @torch.no_grad()  # Inference only: outputs become Python ints, so no graph is needed.
    def predict(self, obs, deterministic=True, history=None):
        """Greedily decode a plan, then return the action for the last timestep.

        Args:
            obs: current observation. Only ``obs['mission']`` is used.
            deterministic: unused. Decoding is always greedy (top-1 token, argmax action).
            history: dict whose ``'image'`` entry has batch size 1. Required.

        Returns:
            The argmax action index (a Python int) at the last timestep of ``inst_dt``'s
            output.

        Raises:
            AssertionError: if ``history`` is None or its batch size is not 1.
        """
        assert history is not None
        assert history['image'].shape[0] == 1, "Only a batch size of 1 is currently supported"
        # We only want to send in the mission once. Can use the current timestep one.
        combined_obs = {'image': history['image'], 'mission': obs['mission']}

        # Imported locally, as in the original (possibly to avoid an import cycle).
        from lang_hrl.envs.babyai_wrappers import WORD_TO_IDX
        end_token = WORD_TO_IDX['END_SUBGOAL']
        # The plan starts from the END_MISSION token.
        plan = torch.full((1, 1), WORD_TO_IDX['END_MISSION'], dtype=torch.long, device=obs['mission'].device)

        _, language_latents = self.lang_enc(combined_obs)
        # At most MAX_PLAN_TOKENS - 1 new tokens, so the plan never exceeds MAX_PLAN_TOKENS.
        for _ in range(self.MAX_PLAN_TOKENS - 1):
            # The whole prefix is re-decoded each step; only the last position's logits are used.
            # NOTE(review): the original author flagged this last-position selection as
            # "BUG ... incorrect" without explanation. Confirm against how DecoderBlock
            # masks self-attention.
            lang_logits = self._decode(plan, language_latents)[:, -1, :]
            _, ix = torch.topk(lang_logits, k=1, dim=-1)
            plan = torch.cat((plan, ix), dim=1)
            # Stop once the end-of-subgoal token has been emitted (it is kept in the plan).
            if ix.item() == end_token:
                break

        # Pad the plan to the length inst_dt was trained with.
        padded_plan = self._pad_to_length(plan, self.MAX_PLAN_TOKENS, obs['mission'].dtype)
        action_logits, _ = self.inst_dt(combined_obs, padded_plan)

        # We only care about the last timestep action logits
        return torch.argmax(action_logits[0, -1, :]).item()
