import torch
import torch.nn as nn
import torch.nn.functional as F

from nerv.training import BaseModel
from nerv.models import deconv_out_shape, conv_norm_act, deconv_norm_act

from transformers import AutoImageProcessor

from ...base_slots.models.utils import assert_shape, SoftPositionEmbed, torch_cat

from slotformer.base_slots.models import StoSAVi
from .robotics_transformer import RoboticsTransformer
from slotformer.e2e.models import E2E
from .robotics_transformer import build_pos_enc

from einops import pack, unpack, repeat, reduce, rearrange
from einops.layers.torch import Rearrange, Reduce

import os

import copy

def normalize_act(action, min_val=-0.1, max_val=0.1):
    """Affinely map ``action`` from ``[min_val, max_val]`` onto ``[-1, 1]``.

    Pure element-wise arithmetic, so it accepts any operand that supports
    ``-``, ``*`` and ``/`` (Python numbers, tensors, ...). Values outside the
    source range land outside ``[-1, 1]``; nothing is clipped. This is the
    inverse of ``denormalize_act`` called with the same bounds.
    """
    span = max_val - min_val
    return 2 * (action - min_val) / span - 1.

def denormalize_act(action, min_val=-0.1, max_val=0.1):
    """Map ``action`` from ``[-1, 1]`` back onto ``[min_val, max_val]``.

    Inverse of ``normalize_act`` with the same bounds; element-wise and
    unclipped.
    """
    span = max_val - min_val
    unit = (action + 1) / 2  # position inside the target range, in [0, 1]
    return unit * span + min_val


# Custom initializer: draws nn.Linear weights and biases from a normal distribution (mean 0, std 0.05).
def weights_init_normal(m):
    """Re-initialize an ``nn.Linear`` from a normal distribution.

    Meant for ``module.apply(weights_init_normal)``: modules of any other
    type are left untouched. The weight is drawn first, then the bias
    (when the layer has one), both with mean 0.0 and std 0.05.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.05)
    if m.bias is None:
        return
    nn.init.normal_(m.bias, mean=0.0, std=0.05)


# Residual block definition
class ResidualBlock(nn.Module):
    """Three-layer MLP with an identity skip connection.

    Computes ``relu(fc3(relu(fc2(relu(fc1(x))))) + x)``. The input is added
    in place to ``fc3``'s output, so that output must be shape-compatible
    with the input (in practice ``hidden_size2 == input_size``).
    """

    def __init__(self, input_size, hidden_size1, hidden_size2):
        super().__init__()
        # Keep this registration order: it fixes the submodule names and the
        # order in which the initializer below visits the layers.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size1)
        self.fc3 = nn.Linear(hidden_size1, hidden_size2)
        self.relu = nn.ReLU()

        # Draw every Linear kernel and bias from a normal distribution.
        self.apply(weights_init_normal)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        residual = self.fc3(hidden)
        residual += x  # skip connection
        return self.relu(residual)

# Residual MLP model definition
class ResidualMLP(nn.Module):
    """Stack of ``ResidualBlock`` modules followed by a linear output layer.

    Every block is ``ResidualBlock(input_size, hidden_size1, hidden_size2)``,
    so chaining more than one block requires ``hidden_size2 == input_size``.
    ``final_fc`` maps ``hidden_size2`` features to ``output_size``.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, num_blocks, output_size):
        super().__init__()
        self.blocks = nn.ModuleList(
            ResidualBlock(input_size, hidden_size1, hidden_size2)
            for _ in range(num_blocks)
        )
        self.final_fc = nn.Linear(hidden_size2, output_size)

        # Draw every Linear kernel and bias (blocks included) from a normal
        # distribution.
        self.apply(weights_init_normal)

    def forward(self, x):
        out = x
        for res_block in self.blocks:
            out = res_block(out)
        return self.final_fc(out)

# Input: images + language instruction -> output: action.
# Language-conditioned robotics transformer.
class RoboticsLSlotFormer(RoboticsTransformer):
    """Language-conditioned robotics transformer for behavioral cloning.

    A frozen end-to-end StoSAVi + SlotFormer model (``E2E``) turns the input
    clip into object slots. The last ``prev_len`` observed (posterior) slots
    and the first ``next_len`` predicted slots are concatenated along time
    and decoded into one action by the head selected with
    ``act_dec_dict['dec_type']``:

    * ``'res_mlp'``: ``ResidualMLP`` over all flattened slots, optionally
      concatenated with the projected instruction.
    * ``'transformer'``, ``'transformer_time_fuse_inst'``,
      ``'transformer_time_fuse_inst_rev'``: a transformer runs once per time
      step over that step's slots; the per-step action tokens are then
      decoded into the action.
    * ``'transformer_slot'``, ``'transformer_slot_pool'``,
      ``'transformer_slot_fc'``, ``'transformer_slot_fc_inst'``,
      ``'transformer_slot_fuse_inst'``, ``'transformer_slot_fuse_inst_rev'``:
      a transformer runs once per slot over that slot's trajectory; the
      per-slot action tokens are then decoded into the action.

    Besides the keys in the default ``act_dec_dict``, the heads read
    ``prev_len``, ``next_len``, ``dec_type`` and ``inst``. The transformer
    heads also read ``mask``, and ``'transformer_slot_pool'`` reads
    ``pool_inst``.
    """

    def __init__(
            self,
            resolution,
            clip_len,
            slot_dict=dict(
                num_slots=7,
                slot_size=128,
                slot_mlp_size=256,
                num_iterations=2,
                kernel_mlp=True,
            ),
            enc_dict=dict(
                enc_channels=(3, 64, 64, 64, 64),
                enc_ks=5,
                enc_out_channels=128,
                enc_norm='',
            ),
            dec_dict=dict(
                dec_channels=(128, 64, 64, 64, 64),
                dec_resolution=(8, 8),
                dec_ks=5,
                dec_norm='',
                dec_ckp_path='',
            ),
            rollout_dict=dict(
                num_slots=7,
                slot_size=128,
                history_len=6,
                t_pe='sin',
                slots_pe='',
                d_model=128,
                num_layers=4,
                num_heads=8,
                ffn_dim=512,
                norm_first=True,
            ),
            pred_dict=dict(
                pred_type='transformer',
                pred_rnn=True,
                pred_norm_first=True,
                pred_num_layers=2,
                pred_num_heads=4,
                pred_ffn_dim=512,
                pred_sg_every=None,
            ),
            act_dec_dict=dict(
                history_len=6,
                t_pe='sin',
                d_model=128,
                num_layers=4,
                num_heads=8,
                ffn_dim=512,
                inst_size=768,
                act_size=2,
                norm_first=True,
            ),
            loss_dict=dict(
                use_post_recon_loss=True,
                kld_method='var-0.01',
                rollout_len=6,
                use_img_recon_loss=False,
            ),
            eps=1e-6
    ):
        """Store the configs, then build the frozen encoder and the action head.

        Args:
            resolution: frame resolution, passed to ``E2E`` and the parent.
            clip_len: clip length, passed to ``E2E``.
            slot_dict, enc_dict, dec_dict, pred_dict, rollout_dict: configs
                forwarded to ``E2E``. ``rollout_dict`` also supplies
                ``num_slots``, ``slot_size`` and ``history_len`` to the head.
            act_dec_dict: action-head config (see the class docstring).
            loss_dict: loss config. ``rollout_len`` is read here, and
                ``use_post_recon_loss`` is temporarily overridden while
                ``E2E`` is built.
            eps: passed to the parent constructor.
        """
        self.resolution = resolution
        self.clip_len = clip_len
        self.eps = eps

        self.slot_dict = slot_dict
        self.enc_dict = enc_dict
        self.dec_dict = dec_dict
        self.pred_dict = pred_dict
        self.rollout_dict = rollout_dict
        self.loss_dict = loss_dict

        self.act_dec_dict = act_dec_dict
        # Distributed rank from the environment; None when RANK is unset or
        # not an integer.
        try:
            self.rank = int(os.environ['RANK'])
        except (KeyError, ValueError):
            self.rank = None

        super().__init__(
            resolution, enc_dict, act_dec_dict, loss_dict, eps
        )

        self._build_encoder()
        self._build_dt()

        self.loss_decay_factor = 1.  # temporal loss weighting; 1. disables it
        self.rollout_len = self.loss_dict['rollout_len']

    def _build_dt(self):
        """Build the positional encodings and the action head for ``dec_type``.

        Module creation order inside each branch is significant: it fixes the
        parameter-initialization sequence. In the ``transformer_slot*``
        branches, ``final_decoder`` is built from the same ``encoder_layer``
        object as ``action_decoder``.
        """
        num_slots = self.rollout_dict['num_slots']
        prev_len = self.act_dec_dict['prev_len']
        next_len = self.act_dec_dict['next_len']
        T = prev_len + next_len
        d_slot = self.rollout_dict['slot_size']
        d_model = self.act_dec_dict['d_model']
        d_action = self.act_dec_dict['act_size']
        num_layers = self.act_dec_dict['num_layers']
        num_heads = self.act_dec_dict['num_heads']
        ffn_dim = self.act_dec_dict['ffn_dim']

        # Temporal PE over T steps in slot space (used by the MLP heads).
        self.act_enc_t_pe = build_pos_enc(self.act_dec_dict['t_pe'], T, d_slot)
        # Two-position encoding that forward() uses as a token-type embedding:
        # index 0 marks slot tokens, index 1 marks the instruction token.
        self.t_pe = build_pos_enc(self.act_dec_dict['t_pe'], 2, d_model)
        # Temporal PE over T steps in model space (transformer heads).
        self.time_pe = build_pos_enc(self.act_dec_dict['t_pe'], T, d_model)

        if self.act_dec_dict['dec_type'] == 'res_mlp':
            input_size = T * num_slots * d_slot
            if self.act_dec_dict['inst']:
                input_size += d_model
            hidden_size1 = 256
            hidden_size2 = input_size  # residual blocks need in == out
            num_blocks = 2
            output_size = d_action

            self.action_decoder = ResidualMLP(input_size, hidden_size1, hidden_size2, num_blocks, output_size)
            self.q_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)  # instruction projection

        elif self.act_dec_dict['dec_type'] == 'transformer':
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.action_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.inst_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)
            self.slot_proj = nn.Linear(d_slot, d_model)
            self.fc = nn.Linear(T*d_model, d_action)
            self.action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))

        elif self.act_dec_dict['dec_type'] == 'transformer_slot':
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.action_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.inst_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)
            self.slot_proj = nn.Linear(d_slot, d_model)
            self.final_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))
            self.final_action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))
            self.fc = nn.Linear(d_model, d_action)

        elif self.act_dec_dict['dec_type'] == 'transformer_slot_pool':
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.action_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.inst_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)
            self.slot_proj = nn.Linear(d_slot, d_model)
            self.final_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))
            self.final_action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))
            self.fc = nn.Linear(num_slots * d_model, d_action)
            if self.act_dec_dict['pool_inst']:
                # pooled slot tokens are concatenated with the instruction token
                self.fc = nn.Linear(num_slots * d_model + d_model, d_action)

        elif self.act_dec_dict['dec_type'] == 'transformer_slot_fc':
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.action_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.inst_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)
            self.slot_proj = nn.Linear(d_slot, d_model)
            self.action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))
            self.fc = nn.Linear(num_slots*d_model, d_action)

        elif self.act_dec_dict['dec_type'] == 'transformer_slot_fc_inst':
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.action_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.inst_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)
            self.slot_proj = nn.Linear(d_slot, d_model)
            self.action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))
            self.fc = nn.Linear((num_slots+1)*d_model, d_action)

        elif self.act_dec_dict['dec_type'] in ['transformer_slot_fuse_inst', 'transformer_slot_fuse_inst_rev']:
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.action_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.inst_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)
            self.slot_proj = nn.Linear(d_slot, d_model)
            self.action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))
            if self.act_dec_dict['dec_type'] == 'transformer_slot_fuse_inst':
                self.fc = nn.Linear(num_slots*d_model, d_action)
            elif self.act_dec_dict['dec_type'] == 'transformer_slot_fuse_inst_rev':
                self.fc = nn.Linear(d_model, d_action)

            decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.fuse_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

        elif self.act_dec_dict['dec_type'] in ['transformer_time_fuse_inst', 'transformer_time_fuse_inst_rev']:
            encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.action_decoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            self.inst_proj = nn.Linear(self.act_dec_dict['inst_size'], d_model)
            self.slot_proj = nn.Linear(d_slot, d_model)
            if self.act_dec_dict['dec_type'] == 'transformer_time_fuse_inst':
                self.fc = nn.Linear(T*d_model, d_action)
            elif self.act_dec_dict['dec_type'] == 'transformer_time_fuse_inst_rev':
                self.fc = nn.Linear(d_model, d_action)

            self.action_token = nn.Parameter(nn.init.normal_(torch.empty(1, d_model)))

            decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=ffn_dim, batch_first=True, dropout=0.1)
            self.fuse_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)

    def _build_encoder(self):
        """Build the end-to-end StoSAVi + SlotFormer encoder (``E2E``) and freeze it.

        ``loss_dict['use_post_recon_loss']`` is forced to True while the
        encoder is built (the original note says this avoids an assertion).
        The caller's value is restored afterwards, even if construction fails.
        """
        orig_use_post_recon = self.loss_dict['use_post_recon_loss']
        self.loss_dict['use_post_recon_loss'] = True  # to avoid assertion
        try:
            self.encoder = E2E(
                    resolution=self.resolution,
                    clip_len=self.clip_len,
                    slot_dict=self.slot_dict,
                    enc_dict=self.enc_dict,
                    dec_dict=self.dec_dict,
                    pred_dict=self.pred_dict,
                    rollout_dict=self.rollout_dict,
                    loss_dict=self.loss_dict,
                )

            self.encoder.eval()
            for p in self.encoder.parameters():
                p.requires_grad = False
            print("INFO: Encoder is frozen!! ")
        finally:
            self.loss_dict['use_post_recon_loss'] = orig_use_post_recon

    def _get_encoder_out(self, data_dict):
        """Run the frozen encoder on ``data_dict`` and return its ``'pred_slots'``."""
        encoder_out = self.encoder(data_dict)
        return encoder_out['pred_slots']

    @staticmethod
    def _build_token_mask(seq_len, device):
        """Boolean attention mask (True = attention blocked) of shape ``[seq_len, seq_len]``.

        Every token except the last may attend only to itself. The last token
        (the action token appended by ``forward``) may attend to the whole
        sequence.
        """
        mask = ~torch.eye(seq_len, dtype=torch.bool, device=device)
        mask[-1, :] = False
        return mask

    def forward(self, data_dict):
        """Predict one (normalized) action from a clip and an instruction.

        Args:
            data_dict: batch dict. Reads ``'img'`` (``[B, T, 3, H, W]``) and
                ``'instruction'`` (``[B, D_inst]``). The whole dict is also
                passed to the frozen encoder.

        Returns:
            dict with ``'actions'`` (``[B, act_size]``) and ``'slots'``: the
            ``[B, prev_len + next_len, num_slots, d_slot]`` slots fed to the
            head, taken before any positional encoding is added.

        Raises:
            ValueError: if ``act_dec_dict['dec_type']`` has no decoding path
                here.
        """
        frames = data_dict['img']  # [B, T, 3, H, W]
        inst = data_dict['instruction']  # [B, D_inst]
        B = frames.shape[0]
        num_slots = self.rollout_dict['num_slots']
        history_len = self.rollout_dict['history_len']
        prev_len = self.act_dec_dict['prev_len']
        next_len = self.act_dec_dict['next_len']
        d_model = self.act_dec_dict['d_model']
        dec_type = self.act_dec_dict['dec_type']
        T = prev_len + next_len

        # Slots predicted by the frozen SlotFormer rollout.
        slots = self._get_encoder_out(data_dict)  # [B, pred_len, num_slots, d_slot]
        next_slot = slots[:, :next_len, :, :]  # [B, next_len, num_slots, d_slot]
        # Posterior slots of the observed frames; keep the last prev_len
        # steps of the history window.
        prev_slots = self.encoder.extract_slot(data_dict)  # [B, history + rollout, num_slots, d_slot]
        prev_slot = prev_slots['post_slots'][:, history_len - prev_len:history_len, :, :]  # [B, prev_len, num_slots, d_slot]
        assert prev_slot.shape[1] == prev_len

        if next_len > 0:
            slots = torch.cat((prev_slot, next_slot), dim=1)  # [B, T, num_slots, d_slot]
        else:
            slots = prev_slot

        assert slots.shape[1] == T

        # On every even time step, swap slot 1 with the last slot, in place.
        # NOTE(review): the intent is not documented here — confirm it is
        # deliberate. With next_len == 0 this writes into the encoder output.
        for t in range(0, T, 2):
            slots[:, t, [1, -1], :] = slots[:, t, [-1, 1], :]

        # Snapshot returned to the caller, taken before positional encodings.
        slots_out = copy.copy(slots)

        # Temporal PE for the MLP heads: [1, T, d_slot] -> [B, T, num_slots, d_slot]
        if dec_type in ('mlp', 'res_mlp'):
            act_enc_pe = self.act_enc_t_pe.unsqueeze(2).repeat(B, 1, num_slots, 1)
            slots = slots + act_enc_pe

        assert slots.shape[1] == prev_len + next_len

        if dec_type in ('transformer', 'transformer_time_fuse_inst', 'transformer_time_fuse_inst_rev'):
            # One transformer pass per time step over that step's slots.
            is_mask = self.act_dec_dict['mask']
            is_inst = self.act_dec_dict['inst']
            slot_pe = self.t_pe[:, 0, :].unsqueeze(1).repeat(B, num_slots, 1)  # [B, num_slots, d_model]
            lang_pe = self.t_pe[:, 1, :].unsqueeze(1).repeat(B, 1, 1)  # [B, 1, d_model]

            if is_inst:
                # slot tokens, instruction token, then the action token (no PE)
                pe = torch.cat([slot_pe, lang_pe, torch.zeros_like(lang_pe)], dim=1)  # [B, num_slots + 2, d_model]
            else:
                # without the instruction no type embedding is added at all
                pe = torch.zeros([B, num_slots + 1, d_model], device=slots.device)  # [B, num_slots + 1, d_model]

            action_token = self.action_token.unsqueeze(0).repeat(B, 1, 1)  # [B, 1, d_model]

            slots = slots.permute([1, 0, 2, 3])  # [T, B, num_slots, d_slot]
            slots = self.slot_proj(slots)  # [T, B, num_slots, d_model]
            inst = self.inst_proj(inst).unsqueeze(1)  # [B, d_inst] -> [B, 1, d_model]
            # the sequence length equals pe.shape[1] for every step
            mask = self._build_token_mask(pe.shape[1], slots.device) if is_mask else None

            outs = []
            for slot in slots:  # slot: [B, num_slots, d_model]
                if is_inst:
                    slot = torch.cat([slot, inst], dim=1)  # [B, num_slots + 1, d_model]
                decoder_in = torch.cat([slot, action_token], dim=1) + pe
                out = self.action_decoder(decoder_in, mask=mask)[:, -1]  # action-token output [B, d_model]
                outs.append(out)

            if dec_type == 'transformer':
                outs = torch.cat(outs, dim=1)  # [B, T * d_model]
                actions = self.fc(outs)  # [B, T * d_model] -> [B, D_a]
            else:
                outs = torch.stack(outs, dim=1)  # [B, T, d_model]
                time_pe = self.time_pe.repeat(B, 1, 1)  # [1, T, d_model] -> [B, T, d_model]
                outs = outs + time_pe
                if dec_type == 'transformer_time_fuse_inst':
                    # time tokens (tgt) cross-attend to the instruction (memory)
                    fused = self.fuse_decoder(outs, inst)  # [B, T, d_model]
                    actions = self.fc(fused.flatten(1, 2))  # [B, T * d_model] -> [B, D_a]
                else:  # 'transformer_time_fuse_inst_rev'
                    # the instruction (tgt) cross-attends to the time tokens (memory)
                    fused = self.fuse_decoder(inst, outs).squeeze(1)  # [B, d_model]
                    actions = self.fc(fused)  # [B, d_model] -> [B, D_a]

        elif dec_type in ('transformer_slot', 'transformer_slot_pool', 'transformer_slot_fc',
                          'transformer_slot_fc_inst', 'transformer_slot_fuse_inst',
                          'transformer_slot_fuse_inst_rev'):
            # One transformer pass per slot over that slot's trajectory.
            is_mask = self.act_dec_dict['mask']
            is_inst = self.act_dec_dict['inst']
            is_pool_inst = self.act_dec_dict.get('pool_inst', False)

            time_pe = self.time_pe.repeat(B, 1, 1)  # [1, T, d_model] -> [B, T, d_model]
            slots = slots.permute(2, 0, 1, 3)  # [num_slots, B, T, d_slot]
            slots = self.slot_proj(slots)  # [num_slots, B, T, d_model]
            inst = self.inst_proj(inst).unsqueeze(1)  # [B, d_inst] -> [B, 1, d_model]
            action_token = self.action_token.unsqueeze(0).repeat(B, 1, 1)  # [B, 1, d_model]
            slot_pe = self.t_pe[:, 0, :].unsqueeze(1).repeat(B, T, 1)  # [B, T, d_model]
            lang_pe = self.t_pe[:, 1, :].unsqueeze(1).repeat(B, 1, 1)  # [B, 1, d_model]
            device = slots.device

            # slot_pe separates slot, instruction and action tokens; time_pe
            # gives the time step of each slot token. The trailing tokens
            # receive zeros where they have no such encoding.
            if is_inst:
                time_pe = torch.cat([time_pe, torch.zeros(B, 2, d_model, device=device)], dim=1)  # [B, T + 2, d_model]
                slot_pe = torch.cat([slot_pe, lang_pe, torch.zeros(B, 1, d_model, device=device)], dim=1)  # [B, T + 2, d_model]
            else:
                time_pe = torch.cat([time_pe, torch.zeros(B, 1, d_model, device=device)], dim=1)  # [B, T + 1, d_model]
                slot_pe = torch.zeros([B, T + 1, d_model], device=device)  # [B, T + 1, d_model]
            # the sequence length equals slot_pe.shape[1] for every slot
            mask = self._build_token_mask(slot_pe.shape[1], device) if is_mask else None

            outs = []
            for slot in slots:  # slot: [B, T, d_model]
                if is_inst:
                    slot = torch.cat([slot, inst], dim=1)  # [B, T + 1, d_model]
                decoder_in = torch.cat([slot, action_token], dim=1) + slot_pe + time_pe
                out = self.action_decoder(decoder_in, mask=mask)[:, -1]  # action-token output [B, d_model]
                outs.append(out)

            outs = torch.stack(outs, dim=1)  # [B, num_slots, d_model]
            assert outs.shape[1] == num_slots
            if dec_type == 'transformer_slot':
                final_action_token = self.final_action_token.unsqueeze(0).repeat(B, 1, 1)  # [B, 1, d_model]
                outs = torch.cat([outs, final_action_token], dim=1)  # [B, num_slots + 1, d_model]
                outs = self.final_decoder(outs)[:, -1]  # [B, d_model]
                actions = self.fc(outs)  # [B, d_action]
            elif dec_type == 'transformer_slot_pool':
                outs = outs.flatten(1, 2)  # [B, num_slots * d_model]
                if is_pool_inst:
                    outs = torch.cat([outs, inst.squeeze(1)], dim=1)  # [B, num_slots * d_model + d_model]
                actions = self.fc(outs)  # [B, d_action]
            elif dec_type == 'transformer_slot_fc':
                outs = outs.flatten(1, 2)  # [B, num_slots * d_model]
                actions = self.fc(outs)  # [B, d_action]
            elif dec_type == 'transformer_slot_fc_inst':
                outs = outs.flatten(1, 2)  # [B, num_slots * d_model]
                outs = torch.cat([outs, inst.squeeze(1)], dim=1)  # [B, (num_slots + 1) * d_model]
                actions = self.fc(outs)  # [B, d_action]
            elif dec_type == 'transformer_slot_fuse_inst':
                # slot tokens (tgt) cross-attend to the instruction (memory)
                fused = self.fuse_decoder(outs, inst)  # [B, num_slots, d_model]
                actions = self.fc(fused.flatten(1, 2))  # [B, num_slots * d_model] -> [B, d_action]
            else:  # 'transformer_slot_fuse_inst_rev'
                # the instruction (tgt) cross-attends to the slot tokens (memory)
                fused = self.fuse_decoder(inst, outs).squeeze(1)  # [B, d_model]
                actions = self.fc(fused)  # [B, d_model] -> [B, d_action]

        elif dec_type == 'res_mlp':
            decoder_in = slots.flatten(1, 3)  # [B, T * num_slots * d_slot]
            if self.act_dec_dict['inst']:
                inst = self.q_proj(inst)  # [B, D_inst] -> [B, d_model]
                decoder_in = torch.cat([decoder_in, inst], dim=1)  # [B, T * num_slots * d_slot + d_model]
            actions = self.action_decoder(decoder_in)  # [B, D_a]

        else:
            raise ValueError(f"Unsupported act_dec_dict['dec_type']: {dec_type!r}")

        assert actions.shape[0] == B and actions.shape[1] == self.act_dec_dict['act_size']

        out_dict = {
            'actions': actions,  # [B, D_a]
            'slots': slots_out,  # [B, T, num_slots, d_slot]
        }
        return out_dict

    def calc_train_loss(self, data_dict, out_dict):
        """Compute the behavioral-cloning loss.

        Ground-truth ``data_dict['actions']`` are mapped with ``normalize_act``
        (default range [-0.1, 0.1] -> [-1, 1]) and compared with
        ``out_dict['actions']`` by element-wise MSE.

        Returns:
            dict with ``'action_loss'``, the mean (optionally weighted) MSE.
        """
        loss_dict = {}
        gt_actions = normalize_act(data_dict['actions'])
        pred_actions = out_dict['actions']
        actions_loss = F.mse_loss(pred_actions, gt_actions, reduction='none')

        # Optional temporal weighting as done in RPIN: penalize early steps
        # more than later ones along dim 1. Inactive while
        # loss_decay_factor == 1. (the value set in __init__). For the
        # [B, act_size] actions produced by forward(), dim 1 is the action
        # dimension.
        if self.loss_decay_factor < 1.:
            num_steps = actions_loss.shape[1]
            w = self.loss_decay_factor**torch.arange(num_steps)
            w = w.type_as(actions_loss)
            # w sums to num_steps so the overall loss scale is unchanged
            w = w / w.sum() * num_steps
            # reshape so w broadcasts along dim 1 for any loss rank
            w = w.view(1, num_steps, *([1] * (actions_loss.dim() - 2)))
            actions_loss = actions_loss * w

        loss_dict['action_loss'] = actions_loss.mean()

        return loss_dict

    @property
    def dtype(self):
        """dtype of the action head's parameters."""
        # nn.Linear has no ``.dtype``/``.device`` attribute, so read a parameter.
        return next(self.action_decoder.parameters()).dtype

    @property
    def device(self):
        """Device of the action head's parameters."""
        return next(self.action_decoder.parameters()).device

    def train(self, mode=True):
        """Set train/eval mode while keeping the frozen encoder in eval mode."""
        super().train(mode)
        self.encoder.eval()
        return self

