from functools import partial
import math
import numpy as np

from itertools import repeat
import collections.abc
from collections import OrderedDict
import warnings

import torch
from torch import nn
import torch.nn.functional as F

from torch.nn.init import _calculate_fan_in_and_fan_out

from typing import Any, Callable, Optional, Tuple

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor

def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Overwrite ``tensor`` with samples from a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    :math:`[a, b]`. The work is delegated to :func:`_no_grad_trunc_normal_`,
    which samples via the inverse CDF and clamps the result to the bounds.
    The method works best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`, filled in place
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input ``tensor``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled


def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Initialise ``tensor`` in place with variance ``scale / denom``.

    ``denom`` is chosen from the tensor's fan-in / fan-out as computed by
    ``torch.nn.init._calculate_fan_in_and_fan_out``.

    Args:
        tensor: tensor to initialise in place (may be an ``nn.Parameter``).
        scale: numerator of the target variance.
        mode: ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'`` (mean of the two).
        distribution: ``'truncated_normal'``, ``'normal'`` or ``'uniform'``.

    Raises:
        ValueError: if ``mode`` or ``distribution`` is not one of the values above.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        denom = fan_in
    elif mode == 'fan_out':
        denom = fan_out
    elif mode == 'fan_avg':
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode fell through and crashed on an unbound `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == "normal":
        # In-place init of a leaf that requires grad must not be tracked by autograd.
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        # Uniform on [-bound, bound] has variance bound**2 / 3.
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")

def lecun_normal_(tensor):
    """Initialise ``tensor`` in place via fan-in variance scaling with a truncated normal."""
    variance_scaling_(
        tensor,
        scale=1.0,
        mode='fan_in',
        distribution='truncated_normal',
    )

class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask.

    Args:
        dim: channel size of the input and output tokens.
        num_heads: number of attention heads; must divide ``dim``.
        qkv_bias: whether the fused q/k/v projection has a bias.
        attn_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
        causal: if True, position ``i`` may only attend to positions ``<= i``.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., causal=False):
        super().__init__()
        if dim % num_heads != 0:
            # Otherwise forward() fails later with an opaque reshape error.
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.causal = causal
        if self.causal:
            # Precomputed lower-triangular mask covering sequences up to 256 tokens.
            # Kept as a buffer named "mask" so existing checkpoints still load.
            self.register_buffer("mask", torch.tril(torch.ones(256, 256).view(1, 1, 256, 256)))

    def _causal_mask(self, n, device):
        """Return a (1, 1, n, n) lower-triangular mask, using the cached buffer when it is big enough."""
        if n <= self.mask.shape[-1]:
            return self.mask[:, :, :n, :n]
        # Sequence longer than the cached buffer: build the mask on the fly.
        return torch.tril(torch.ones(n, n, device=device, dtype=self.mask.dtype)).view(1, 1, n, n)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C) and return a tensor of the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        if self.causal:
            # -inf scores become zero weight after the softmax.
            attn = attn.masked_fill(self._causal_mask(N, x.device) == 0, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``trg``, keys and values from ``src``."""

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the dot products.
        self.scale = (dim // num_heads) ** -0.5

        # Separate projections because queries and keys/values read different inputs.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, src, trg, mask=None):
        """Attend from ``trg`` (B, T, C) over ``src`` (B, S, C); returns (B, T, C).

        ``mask``, when given, is unsqueezed at dim 1 (the head axis) and positions
        where it equals 0 are excluded from the softmax.
        """
        batch, src_len, channels = src.shape
        _, trg_len, _ = trg.shape
        head_dim = channels // self.num_heads

        def split_heads(projected, length):
            # (B, L, C) -> (B, heads, L, head_dim)
            return projected.view(batch, length, self.num_heads, head_dim).transpose(1, 2)

        queries = split_heads(self.q(trg), trg_len)
        keys = split_heads(self.k(src), src_len)
        values = split_heads(self.v(src), src_len)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel dimension.
        context = torch.matmul(weights, values).transpose(1, 2).reshape(batch, trg_len, channels)
        return self.proj_drop(self.proj(context))

class Mlp(nn.Module):
    """Two-layer feed-forward network (Linear -> activation -> Linear) with dropout after each layer.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when falsy.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features if hidden_features else in_features
        out_width = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        # A single dropout module is reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run ``x`` through fc1 -> act -> drop -> fc2 -> drop."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    During training each sample in the batch is zeroed with probability
    ``drop_prob``; surviving samples are scaled by ``1 / (1 - drop_prob)``.
    In eval mode, or when ``drop_prob`` is ``None`` or 0, the input is returned unchanged.
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # `None` (the constructor default) means "no dropping"; previously it
        # crashed in training mode on `1 - None`.
        drop_prob = 0. if self.drop_prob is None else self.drop_prob
        if drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - drop_prob
        if keep_prob <= 0.:
            # Every path is dropped; dividing by keep_prob would produce NaN/inf.
            return torch.zeros_like(x)
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output

class Block(nn.Module):
    """Pre-norm transformer encoder block: self-attention then MLP, each in a residual branch."""

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SelfAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # Stochastic depth on the residual branches; a no-op module when disabled.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)

class DecoderBlock(nn.Module):
    """Pre-norm transformer decoder block.

    Applies causal self-attention over ``trg``, cross-attention from ``trg`` onto
    ``src``, then an MLP; each sub-layer sits in a residual branch.
    """

    def __init__(self, dim, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        shared_attn_kwargs = dict(num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)

        self.norm1 = norm_layer(dim)
        self.self_attn = SelfAttention(dim, causal=True, **shared_attn_kwargs)
        # Stochastic depth on the residual branches; a no-op module when disabled.
        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.cross_attn = CrossAttention(dim, **shared_attn_kwargs)
        self.norm3 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, trg, src, mask=None):
        """Update ``trg`` using itself (causally) and ``src``; ``mask`` is forwarded to the cross-attention."""
        self_branch = self.self_attn(self.norm1(trg))
        x = trg + self.drop_path(self_branch)

        # Note: only the query side is normalised here; `src` is passed through as given.
        cross_branch = self.cross_attn(src, self.norm2(x), mask=mask)
        x = x + self.drop_path(cross_branch)

        mlp_branch = self.mlp(self.norm3(x))
        return x + self.drop_path(mlp_branch)

def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
    """ ViT weight initialization
    * When called without n, head_bias, jax_impl args it will behave exactly the same
      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        elif name.startswith('pre_logits'):
            lecun_normal_(module.weight)
            nn.init.zeros_(module.bias)
        else:
            if jax_impl:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
            else:
                trunc_normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    elif jax_impl and isinstance(module, nn.Conv2d):
        # NOTE conv was left to pytorch default in my original init
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
        nn.init.zeros_(module.bias)
        nn.init.ones_(module.weight)

class StateEncoderFree(nn.Module):
    """Turn the raw observation parts into one token sequence of width ``embed_dim``.

    The output concatenates, along dim 1: 25 grid-embedding tokens, 25 projected
    one-hot tokens, one inventory token and one goal token (52 tokens in total).
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Projects each 7-wide one-hot cell vector up to the embedding width.
        self.onehot_fc = nn.Linear(7, embed_dim)

    def forward(self, grid_embedding, grid_onehot, inventory, goal):
        width = self.embed_dim
        # (part, tokens contributed by that part), in output order.
        parts = (
            (grid_embedding, 25),
            (self.onehot_fc(grid_onehot), 25),
            (inventory, 1),
            (goal, 1),
        )
        tokens = tuple(part.view(-1, count, width) for part, count in parts)
        return torch.cat(tokens, dim=1)

class AYHVisionTransformer(nn.Module):
    """ViT-style encoder over observation tokens with action, language and unsupervised heads.

    The encoder consumes ``num_img_tokens`` state tokens plus a CLS token. The CLS
    output feeds the action head and the unsupervised head; when ``labels`` are
    given, a transformer decoder over the encoder output predicts subgoal tokens
    and yields an auxiliary cross-entropy loss.

    Args:
        num_actions: size of the action head output (``<= 0`` replaces the head by identity).
        vocab: vocabulary object; only ``len(vocab)`` is used here.
        embed_weights: numpy array loaded as the (frozen) token embedding matrix.
        embed_dim: transformer width; a linear down-projection is added when it
            differs from the 300-wide embeddings.
        depth: number of encoder blocks.
        num_heads, mlp_ratio, qkv_bias, drop_rate, attn_drop_rate, drop_path_rate:
            forwarded to the encoder / decoder blocks.
        decoder_layers: number of decoder blocks.
        unsup_dim: width of the unsupervised head.
        unsup_proj: ``'mat'`` creates a learnable ``unsup_dim x unsup_dim`` matrix,
            anything else leaves ``unsup_proj`` as ``None``.
    """

    def __init__(self, num_actions, vocab, embed_weights, embed_dim=128, depth=6,
                 num_heads=2, mlp_ratio=2., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 decoder_layers=2, unsup_dim=64, unsup_proj='mat'):

        super().__init__()
        self.num_classes = num_actions
        self.glove_embed_dim = 300
        self.embed_dim = embed_dim
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
        self.num_img_tokens = 25 + 25 + 1 + 1 # Grid Emb, Grid Onehot, inventory, goal

        # Setup the embeddings; the last vocabulary index is the padding index.
        self.embedding = nn.Embedding(self.vocab_size, self.glove_embed_dim, self.vocab_size - 1)
        self.embedding.load_state_dict({'weight': torch.from_numpy(embed_weights)})
        self.embedding.weight.requires_grad = False
        self.state_encoder = StateEncoderFree(self.glove_embed_dim)

        # If we have a different dim n_embd, we need to down sample
        if self.embed_dim != self.glove_embed_dim:
            self.downsample = nn.Linear(self.glove_embed_dim, self.embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_img_tokens + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
            for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        # Classifier head(s)
        self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        # Now initialize the decoder
        self.pos_emb_trg = nn.Parameter(torch.zeros(1, 256, embed_dim))
        self.decoder = nn.ModuleList([
            DecoderBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
                for _ in range(decoder_layers)
                ])

        self.vocab_head = nn.Linear(embed_dim, self.vocab_size, bias=False)

        trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)

        # For unsupervised. These modules are created after `apply` above, so they
        # keep PyTorch's default initialisation.
        self.unsup_head = nn.Linear(embed_dim, unsup_dim)
        if unsup_proj == 'mat':
            self._unsup_proj = nn.Parameter(torch.rand(unsup_dim, unsup_dim))
        else:
            self._unsup_proj = None
        self.unsup_mlp = nn.Sequential(nn.Linear(unsup_dim, 2*unsup_dim), nn.ReLU(), nn.Linear(2*unsup_dim, unsup_dim))

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    @property
    def unsup_proj(self):
        """The learnable projection matrix for the unsupervised objective, or ``None``."""
        return self._unsup_proj

    def forward(self, obs, labels=None, is_target=False):
        """Encode an observation and compute the model heads.

        Args:
            obs: mapping with keys 'grid_embedding', 'grid_onehot', 'inventory' and
                'goal' (or their 'next_'-prefixed versions when ``is_target``), plus
                'subgoal' when ``labels`` is given.
            labels: target token ids for the language decoder; if ``None`` the
                decoder is skipped. Entries equal to ``vocab_size - 1`` are ignored
                by the loss.
            is_target: encode the 'next_*' observation and skip the projection MLP
                on the unsupervised output.

        Returns:
            ``(action_pred, aux, unsup_logits)`` where ``aux`` is the language
            cross-entropy loss or ``None``.
        """
        # For the target pass read the 'next_*' entries directly. The previous
        # version copied them into `obs` (mutating the caller's dict) and, due to
        # a misspelled key ('grid_one_hot'), never actually used 'next_grid_onehot'.
        prefix = 'next_' if is_target else ''
        grid_embedding = obs[prefix + 'grid_embedding']
        grid_onehot = obs[prefix + 'grid_onehot']
        inventory = obs[prefix + 'inventory']
        goal = obs[prefix + 'goal']

        encoder_out = self.state_encoder(grid_embedding, grid_onehot, inventory, goal)
        assert encoder_out.shape[1] == self.num_img_tokens, "Num img tokens did not match encoder out"
        cls_token = self.cls_token.expand(encoder_out.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        if self.embed_dim != self.glove_embed_dim:
            encoder_out = self.downsample(encoder_out) # Downsample the embeddings

        x = torch.cat((cls_token, encoder_out), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        action_pred = self.head(x[:, 0]) # Grab the first in the sequence

        if labels is not None:
            # Decoder input is the subgoal without its last token.
            trg = self.embedding(obs['subgoal'][:, :-1])
            if self.embed_dim != self.glove_embed_dim:
                trg = self.downsample(trg) # Downsample the embeddings
            pos_emb_trg = self.pos_emb_trg[:, :trg.shape[1], :]
            trg = self.pos_drop(trg + pos_emb_trg)
            for layer in self.decoder:
                trg = layer(trg, x)
            lang_logits = self.vocab_head(trg)
            # Padding positions (index vocab_size - 1) do not contribute to the loss.
            aux = F.cross_entropy(lang_logits.reshape(-1, lang_logits.size(-1)), labels.long().reshape(-1), ignore_index=self.vocab_size-1)
        else:
            aux = None

        # Now run the unsupervised prediction
        unsup_logits = self.unsup_head(x[:, 0]) # Grab the CLS Token
        if not is_target:
            unsup_logits = self.unsup_mlp(unsup_logits) # Forward through the projection MLP
        return action_pred, aux, unsup_logits # For unsupervised losses to be used later perhaps

class ViTEncoder(nn.Module):
    """Stack of transformer blocks over a token sequence with a prepended CLS token.

    ``num_tokens`` is the sequence length expected by ``forward`` (excluding CLS);
    the learned positional embedding has ``num_tokens + 1`` entries.
    """

    def __init__(self, num_tokens, embed_dim=128, depth=6, num_heads=2, mlp_ratio=2., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
        super().__init__()

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_tokens + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: drop-path rate grows linearly with depth.
        drop_path_schedule = torch.linspace(0, drop_path_rate, depth).tolist()
        layers = []
        for layer_drop_path in drop_path_schedule:
            layers.append(Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=layer_drop_path,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
            ))
        self.blocks = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
        trunc_normal_(self.cls_token, std=.02)

    def forward(self, x):
        """Prepend CLS, add positions, run the blocks and the final norm; returns (B, N + 1, C)."""
        batch_size = x.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, x), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        return self.norm(self.blocks(tokens))

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}


class AYHHierarchical(nn.Module):
    """Hierarchical model: a language branch predicts a subgoal instruction, an action
    branch predicts the action from the state tokens plus that instruction.

    Assumes that all the instructions are going to be length 16 (``instr_length``);
    the largest in the dataset is length 20, so longer subgoals are truncated.

    ``unsup_dim`` and ``unsup_proj`` are accepted for signature parity with
    ``AYHVisionTransformer`` but are not used; ``unsup_proj`` is always ``None``.
    """

    def __init__(self, num_actions, vocab, embed_weights, embed_dim=128, depth=6,
                 num_heads=2, mlp_ratio=2., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 decoder_layers=2, unsup_dim=64, unsup_proj='mat'):
        super().__init__()
        self.num_classes = num_actions
        self.glove_embed_dim = 300
        self.embed_dim = embed_dim
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
        self.num_img_tokens = 25 + 25 + 1 + 1 # Grid Emb, Grid Onehot, inventory, goal

        # Setup the embeddings; the last vocabulary index is the padding index.
        self.embedding = nn.Embedding(self.vocab_size, self.glove_embed_dim, self.vocab_size - 1)
        self.embedding.load_state_dict({'weight': torch.from_numpy(embed_weights)})
        self.embedding.weight.requires_grad = False

        self.lang_state_encoder = StateEncoderFree(self.glove_embed_dim)
        self.action_state_encoder = StateEncoderFree(self.glove_embed_dim)

        # If we have a different dim n_embd, we need to down sample
        if self.embed_dim != self.glove_embed_dim:
            self.lang_downsample = nn.Linear(self.glove_embed_dim, self.embed_dim)
            self.action_downsample = nn.Linear(self.glove_embed_dim, self.embed_dim)

        self.instr_length = 16
        self.lang_encoder = ViTEncoder(self.num_img_tokens, embed_dim=self.embed_dim, depth=depth, num_heads=num_heads,
                                       mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate)

        self.action_encoder = ViTEncoder(self.num_img_tokens + self.instr_length, embed_dim=self.embed_dim, depth=depth, num_heads=num_heads,
                                       mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate)

        # Now initialize the decoder
        self.pos_emb_trg = nn.Parameter(torch.zeros(1, 256, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.decoder = nn.ModuleList([
            DecoderBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
                attn_drop=attn_drop_rate, drop_path=drop_path_rate, norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU)
                for _ in range(decoder_layers)
                ])

        # Classifier head(s)
        self.head = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        self.vocab_head = nn.Linear(embed_dim, self.vocab_size, bias=False)

        # This model has no unsupervised projection; define the attribute so the
        # `unsup_proj` property returns None instead of raising AttributeError.
        self._unsup_proj = None

        self.apply(_init_vit_weights)

    @property
    def unsup_proj(self):
        """Always ``None``: the hierarchical model has no unsupervised projection."""
        return self._unsup_proj

    def _encode_states(self, obs):
        """Run both state encoders and the language encoder.

        Returns ``(lang_latent, action_state)``: the language-encoder output and
        the raw (not yet down-sampled) action-branch state tokens.
        """
        grid_embedding, grid_onehot, inventory, goal = obs['grid_embedding'], obs['grid_onehot'], obs['inventory'], obs['goal']
        lang_state = self.lang_state_encoder(grid_embedding, grid_onehot, inventory, goal)
        action_state = self.action_state_encoder(grid_embedding, grid_onehot, inventory, goal)
        assert lang_state.shape[1] == self.num_img_tokens, "Num img tokens did not match encoder out"

        if self.embed_dim != self.glove_embed_dim:
            lang_state = self.lang_downsample(lang_state) # Downsample the embeddings
        lang_latent = self.lang_encoder(lang_state)
        return lang_latent, action_state

    def _decode_language(self, tokens, lang_latent):
        """Run the causal decoder on ``tokens`` (B, T) and return vocabulary logits (B, T, vocab)."""
        trg = self.embedding(tokens)
        if self.embed_dim != self.glove_embed_dim:
            trg = self.lang_downsample(trg) # Downsample the embeddings
        pos_emb_trg = self.pos_emb_trg[:, :trg.shape[1], :]
        trg = self.pos_drop(trg + pos_emb_trg)
        for layer in self.decoder:
            trg = layer(trg, lang_latent)
        return self.vocab_head(trg)

    def _action_logits(self, action_state, instr):
        """Encode state tokens together with the instruction ids and return the action head output."""
        instr = self.embedding(instr)
        action_state_instr = torch.cat((action_state, instr), dim=1)
        if self.embed_dim != self.glove_embed_dim:
            action_state_instr = self.action_downsample(action_state_instr)
        action_latent = self.action_encoder(action_state_instr)
        return self.head(action_latent[:, 0]) # Grab the first in the sequence, corresponds to the CLS token.

    def forward(self, obs, labels=None, is_target=False):
        """Compute action predictions and the language auxiliary loss.

        Args:
            obs: mapping with keys 'grid_embedding', 'grid_onehot', 'inventory',
                'goal' and 'subgoal' (token ids).
            labels: target token ids for the language decoder (required). Entries
                equal to ``vocab_size - 1`` are ignored by the loss.
            is_target: must be False; kept for interface parity.

        Returns:
            ``(action_pred, aux, None)``.
        """
        assert not is_target, "HRL model does not support unsupervised."
        assert labels is not None, "HRL model requires the subgoal labels"

        lang_latent, action_state = self._encode_states(obs)

        # Run the decoder on the subgoal without its last token.
        subgoal = obs['subgoal']
        lang_logits = self._decode_language(subgoal[:, :-1], lang_latent)
        aux = F.cross_entropy(lang_logits.reshape(-1, lang_logits.size(-1)), labels.long().reshape(-1), ignore_index=self.vocab_size-1)

        # Now get the actions: truncate or right-pad the instruction to instr_length.
        subgoal_len = subgoal.shape[1]
        if subgoal_len > self.instr_length:
            instr = subgoal[:, :self.instr_length]
        elif subgoal_len < self.instr_length:
            # Fill with the padding index, then copy the real tokens in front.
            instr = torch.full((subgoal.shape[0], self.instr_length), self.vocab_size - 1,
                               device=subgoal.device, dtype=subgoal.dtype)
            instr[:, :subgoal_len] = subgoal
        else:
            instr = subgoal

        action_pred = self._action_logits(action_state, instr)

        return action_pred, aux, None # For unsupervised losses to be used later perhaps

    def predict(self, obs, deterministic=True, history=None):
        """Greedily decode an instruction, then return the arg-max action as a Python int.

        Only batch size 1 is supported. ``deterministic`` and ``history`` are
        accepted but not used.
        """
        assert obs['grid_embedding'].shape[0] == 1, "Only a batch size of 1 is currently supported"

        lang_latent, action_state = self._encode_states(obs)

        # Instruction buffer, initialised to the padding index everywhere.
        # NOTE(review): position 0 therefore holds the padding index rather than an
        # explicit start token - confirm this matches how training subgoals begin.
        instr = torch.full((obs['grid_embedding'].shape[0], self.instr_length), self.vocab_size - 1,
                           device=obs['grid_embedding'].device, dtype=torch.long)
        end_idx = self.vocab.word2idx['<end>']

        for i in range(1, self.instr_length):
            # Feed only the tokens produced so far; the logits at the last fed
            # position predict token i. (Feeding the full padded buffer and reading
            # its last position always predicted from position instr_length - 1.)
            lang_logits = self._decode_language(instr[:, :i], lang_latent)
            _, ix = torch.topk(lang_logits[:, -1], k=1, dim=-1)
            instr[:, i] = ix.squeeze(-1)
            if ix.item() == end_idx:
                break

        action_logits = self._action_logits(action_state, instr)
        action = torch.argmax(action_logits).item()
        return action # return the predicted action.


class AYHVisionTransformerInverseModel(nn.Module):
    """Inverse model: predicts action logits from a pair (observation, next observation).

    Both observations are tokenised with the same state encoder, concatenated
    along the token axis, and passed through a ViT-style encoder; the CLS output
    feeds a linear action head. ``decoder_layers`` is accepted but unused.
    """

    def __init__(self, num_actions, vocab, embed_weights,
                 embed_dim=128, depth=6, num_heads=2, mlp_ratio=2., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., decoder_layers=1):
        super().__init__()
        self.num_classes = num_actions
        self.glove_embed_dim = 300
        self.embed_dim = embed_dim
        self.vocab = vocab
        self.vocab_size = len(self.vocab)
        self.num_img_tokens = 25 + 25 + 1 + 1 # Grid Emb, Grid Onehot, inventory, goal

        # Frozen token embeddings loaded from `embed_weights`; last index is padding.
        self.embedding = nn.Embedding(self.vocab_size, self.glove_embed_dim, self.vocab_size - 1)
        self.embedding.load_state_dict({'weight': torch.from_numpy(embed_weights)})
        self.embedding.weight.requires_grad = False
        self.state_encoder = StateEncoderFree(self.glove_embed_dim)

        # Project down only when the transformer width differs from the embedding width.
        if self.embed_dim != self.glove_embed_dim:
            self.downsample = nn.Linear(self.glove_embed_dim, self.embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Two observations' worth of tokens plus the CLS token.
        self.pos_embed = nn.Parameter(torch.zeros(1, 2*self.num_img_tokens + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic depth decay rule: drop-path rate grows linearly with depth.
        drop_path_schedule = torch.linspace(0, drop_path_rate, depth).tolist()
        layers = []
        for layer_drop_path in drop_path_schedule:
            layers.append(Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=layer_drop_path,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
            ))
        self.blocks = nn.Sequential(*layers)
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        # Classifier head(s)
        self.head = nn.Linear(self.embed_dim, self.num_classes)

        trunc_normal_(self.cls_token, std=.02)
        self.apply(_init_vit_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def forward(self, obs, next_obs):
        """Return action logits for the transition ``obs -> next_obs``.

        Both arguments are mappings with keys 'grid_embedding', 'grid_onehot',
        'inventory' and 'goal'.
        """
        def tokenize(frame):
            return self.state_encoder(frame['grid_embedding'], frame['grid_onehot'], frame['inventory'], frame['goal'])

        tokens = torch.cat((tokenize(obs), tokenize(next_obs)), dim=1)
        assert tokens.shape[1] == 2*self.num_img_tokens, "Num img tokens did not match encoder out"

        cls_tokens = self.cls_token.expand(tokens.shape[0], -1, -1)
        if self.embed_dim != self.glove_embed_dim:
            tokens = self.downsample(tokens) # Downsample the embeddings

        x = torch.cat((cls_tokens, tokens), dim=1)
        x = self.pos_drop(x + self.pos_embed)
        x = self.norm(self.blocks(x))
        # The CLS position (index 0) summarises the pair.
        return self.head(x[:, 0])
