import torch
import torch.nn as nn
import torch.nn.functional as F

from nerv.training import BaseModel

from slotformer.base_slots.models import StoSAVi

import copy
import math

from nerv.models import deconv_out_shape, conv_norm_act, deconv_norm_act

from .utils import assert_shape, SoftPositionEmbed, torch_cat
from .predictor import ResidualMLPPredictor, TransformerPredictor, \
    RNNPredictorWrapper


def get_sin_pos_enc(seq_len, d_model):
    """Build a fixed sinusoidal positional-encoding table.

    Positions are enumerated in descending order, so row 0 encodes position
    `seq_len - 1` and the last row encodes position 0. The sine features
    occupy the first half of the channel axis and the cosine features the
    second half (concatenated, not interleaved).

    Args:
        seq_len: number of positions `L`.
        d_model: channel size; frequencies are generated for every second
            index in `[0, d_model)`.

    Returns:
        Tensor of shape [1, L, C].
    """
    exponents = torch.arange(0.0, d_model, 2.0) / d_model
    inv_freq = 1. / (10000**exponents)
    # descending positions: L - 1, ..., 1, 0
    positions = torch.arange(seq_len - 1, -1, -1).type_as(inv_freq)
    angles = torch.outer(positions, inv_freq)
    table = torch.cat([angles.sin(), angles.cos()], dim=-1)
    return table[None]  # [1, L, C]


def build_pos_enc(pos_enc, input_len, d_model):
    """Create a positional encoding of shape [1, L, D], or None.

    Args:
        pos_enc: encoding type. A falsy value disables the encoding,
            'learnable' gives a trainable zero-initialised table, and any
            string containing 'sin' (e.g. 'sin', 'sine') gives a frozen
            sinusoidal table.
        input_len: sequence length `L`.
        d_model: embedding size `D`.

    Returns:
        An `nn.Parameter`, or None when `pos_enc` is falsy.

    Raises:
        NotImplementedError: for any other `pos_enc` value.
    """
    if not pos_enc:
        return None
    # ViT, BEiT etc. all use zero-init learnable pos enc
    if pos_enc == 'learnable':
        return nn.Parameter(torch.zeros(1, input_len, d_model))
    # in SlotFormer, sine P.E. was found to be good enough
    if 'sin' in pos_enc:
        table = get_sin_pos_enc(input_len, d_model)
        return nn.Parameter(table, requires_grad=False)
    raise NotImplementedError(f'unsupported pos enc {pos_enc}')


class Rollouter(nn.Module):
    """Common interface for predictors that operate on slot embeddings."""

    def forward(self, x):
        """Concrete rollouters must override this."""
        raise NotImplementedError

    def burnin(self, x):
        """Optional hook; a no-op in the base class."""
        return None

    def reset(self):
        """Optional hook; a no-op in the base class."""
        return None


class SlotRollouter(Rollouter):
    """Autoregressive slot rollouter conditioned on an instruction.

    The burn-in slot tokens are the target sequence of an
    `nn.TransformerDecoder`; the projected instruction embedding is passed
    as its memory.
    """

    def __init__(
        self,
        num_slots,
        slot_size,
        history_len,  # burn-in steps
        t_pe='sin',  # temporal P.E.
        slots_pe='',  # slots P.E., None in SlotFormer
        inst_size=768,
        # Transformer-related configs
        d_model=128,
        num_layers=4,
        num_heads=8,
        ffn_dim=512,
        norm_first=True,
    ):
        super().__init__()

        self.num_slots = num_slots
        self.history_len = history_len

        # linear maps from slot / instruction space into the model width
        self.in_proj = nn.Linear(slot_size, d_model)
        self.inst_proj = nn.Linear(inst_size, d_model)

        layer_cfg = dict(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=ffn_dim,
            norm_first=norm_first,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer=nn.TransformerDecoderLayer(**layer_cfg),
            num_layers=num_layers)

        # positional encodings over time steps and (optionally) slots
        self.enc_t_pe = build_pos_enc(t_pe, history_len, d_model)
        self.enc_slots_pe = build_pos_enc(slots_pe, num_slots, d_model)
        self.out_proj = nn.Linear(d_model, slot_size)

    def forward(self, x, inst, pred_len):
        """Roll out `pred_len` future steps from the burn-in slots.

        Args:
            x: [B, history_len, num_slots, slot_size]
            inst: instruction embedding whose last dim is `inst_size`
                (presumably [B, inst_size] -- TODO confirm with caller).
            pred_len: int

        Returns:
            [B, pred_len, num_slots, slot_size]
        """
        assert x.shape[1] == self.history_len, 'wrong burn-in steps'

        batch_size = x.shape[0]
        window = x.flatten(1, 2)  # [B, T * N, slot_size]

        # instruction becomes the memory of the transformer decoder
        memory = self.inst_proj(inst).unsqueeze(1)

        # temporal P.E. is required here (must not be None):
        # [1, T, D] --> [B, T, N, D] --> [B, T * N, D]
        pos_enc = self.enc_t_pe.unsqueeze(2).repeat(
            batch_size, 1, self.num_slots, 1).flatten(1, 2)
        if self.enc_slots_pe is not None:
            # slot P.E. is shared by all timesteps
            per_slot = self.enc_slots_pe.unsqueeze(1).repeat(
                batch_size, self.history_len, 1, 1).flatten(1, 2)
            pos_enc = per_slot + pos_enc

        rollout = []
        for _ in range(pred_len):
            # project to latent space and add the positional encoding
            tokens = self.in_proj(window) + pos_enc
            # slot tokens attend to each other and to the instruction
            tokens = self.transformer_decoder(tokens, memory)
            # the last N output tokens give the next-step slots
            next_slots = self.out_proj(tokens[:, -self.num_slots:])
            rollout.append(next_slots)
            # slide the window: drop the oldest step, append the new one
            window = torch.cat(
                [window[:, self.num_slots:], next_slots], dim=1)

        return torch.stack(rollout, dim=1)

    @property
    def dtype(self):
        """Dtype of the input projection weights."""
        return self.in_proj.weight.dtype

    @property
    def device(self):
        """Device of the input projection weights."""
        return self.in_proj.weight.device


class SlotPromptRollouter(Rollouter):
    """Autoregressive slot rollouter built on a Transformer encoder.

    The instruction is projected but not yet fed to the transformer (see
    the TODO in `forward`).
    """

    def __init__(
        self,
        num_slots,
        slot_size,
        history_len,  # burn-in steps
        t_pe='sin',  # temporal P.E.
        slots_pe='',  # slots P.E., None in SlotFormer
        inst_size=768,
        # Transformer-related configs
        d_model=128,
        num_layers=4,
        num_heads=8,
        ffn_dim=512,
        norm_first=True,
    ):
        super().__init__()

        self.num_slots = num_slots
        self.history_len = history_len

        # linear maps from slot / instruction space into the model width
        self.in_proj = nn.Linear(slot_size, d_model)
        self.inst_proj = nn.Linear(inst_size, d_model)

        layer_cfg = dict(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=ffn_dim,
            norm_first=norm_first,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(**layer_cfg),
            num_layers=num_layers)

        # positional encodings over time steps and (optionally) slots
        self.enc_t_pe = build_pos_enc(t_pe, history_len, d_model)
        self.enc_slots_pe = build_pos_enc(slots_pe, num_slots, d_model)
        self.out_proj = nn.Linear(d_model, slot_size)

    def forward(self, x, inst, pred_len):
        """Roll out `pred_len` future steps from the burn-in slots.

        Args:
            x: [B, history_len, num_slots, slot_size]
            inst: instruction embedding whose last dim is `inst_size`;
                currently projected but unused.
            pred_len: int

        Returns:
            [B, pred_len, num_slots, slot_size]
        """
        assert x.shape[1] == self.history_len, 'wrong burn-in steps'

        batch_size = x.shape[0]
        window = x.flatten(1, 2)  # [B, T * N, slot_size]

        # NOTE: projected but not consumed yet -- see the TODO below
        inst = self.inst_proj(inst).unsqueeze(1)

        # temporal P.E. is required here (must not be None):
        # [1, T, D] --> [B, T, N, D] --> [B, T * N, D]
        pos_enc = self.enc_t_pe.unsqueeze(2).repeat(
            batch_size, 1, self.num_slots, 1).flatten(1, 2)
        if self.enc_slots_pe is not None:
            # slot P.E. is shared by all timesteps
            per_slot = self.enc_slots_pe.unsqueeze(1).repeat(
                batch_size, self.history_len, 1, 1).flatten(1, 2)
            pos_enc = per_slot + pos_enc

        rollout = []
        for _ in range(pred_len):
            # project to latent space and add the positional encoding
            tokens = self.in_proj(window) + pos_enc
            # TODO: concat the tokens with `inst` before the transformer
            tokens = self.transformer_encoder(tokens)
            # the last N output tokens give the next-step slots
            next_slots = self.out_proj(tokens[:, -self.num_slots:])
            rollout.append(next_slots)
            # slide the window: drop the oldest step, append the new one
            window = torch.cat(
                [window[:, self.num_slots:], next_slots], dim=1)

        return torch.stack(rollout, dim=1)

    @property
    def dtype(self):
        """Dtype of the input projection weights."""
        return self.in_proj.weight.dtype

    @property
    def device(self):
        """Device of the input projection weights."""
        return self.in_proj.weight.device

class SlotAttention(nn.Module):
    """Slot attention module that iteratively performs cross-attention.

    Slots act as queries and the input features as keys/values. The
    attention is normalised over slots (softmax) and then over inputs
    (weighted mean); each slot is updated with a GRU cell followed by a
    residual MLP.
    """

    def __init__(
        self,
        in_features,
        num_iterations,
        num_slots,
        slot_size,
        mlp_hidden_size,
        eps=1e-6,
    ):
        super().__init__()
        self.in_features = in_features
        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.eps = eps
        self.attn_scale = self.slot_size**-0.5

        self.norm_inputs = nn.LayerNorm(self.in_features)

        # Linear maps for the attention module.
        self.project_q = nn.Sequential(
            nn.LayerNorm(self.slot_size),
            nn.Linear(self.slot_size, self.slot_size, bias=False),
        )
        self.project_k = nn.Linear(in_features, self.slot_size, bias=False)
        self.project_v = nn.Linear(in_features, self.slot_size, bias=False)

        # Slot update functions.
        self.gru = nn.GRUCell(self.slot_size, self.slot_size)
        self.mlp = nn.Sequential(
            nn.LayerNorm(self.slot_size),
            nn.Linear(self.slot_size, self.mlp_hidden_size),
            nn.ReLU(),
            nn.Linear(self.mlp_hidden_size, self.slot_size),
        )

    def forward(self, inputs, slots):
        """Refine `slots` by attending to `inputs`.

        Args:
            inputs (torch.Tensor): [B, N, C], flattened per-pixel features.
            slots (torch.Tensor): [B, num_slots, slot_size] slot inits;
                may be non-contiguous.

        Returns:
            torch.Tensor: updated slots, [B, num_slots, slot_size].
        """
        # `inputs` has shape [B, num_inputs, inputs_size], where
        # `num_inputs` is the spatial dim of the feature map (H*W).
        # The unpacking also enforces a rank-3 input.
        bs, _, _ = inputs.shape
        inputs = self.norm_inputs(inputs)  # Apply layer norm to the input.
        # Both have shape: [B, num_inputs, slot_size].
        k = self.project_k(inputs)
        v = self.project_v(inputs)

        assert slots.dim() == 3

        # Multiple rounds of attention.
        for _ in range(self.num_iterations):
            slots_prev = slots

            # Queries. Shape: [B, num_slots, slot_size].
            q = self.project_q(slots)

            attn_logits = self.attn_scale * torch.einsum('bnc,bmc->bnm', k, q)
            # softmax over slots, so slots compete for each input location
            attn = F.softmax(attn_logits, dim=-1)
            # `attn` has shape: [B, num_inputs, num_slots].

            # Normalize along spatial dim and do weighted mean.
            attn = attn + self.eps
            attn = attn / torch.sum(attn, dim=1, keepdim=True)
            updates = torch.einsum('bnm,bnc->bmc', attn, v)
            # `updates` has shape: [B, num_slots, slot_size].

            # GRUCell expects 2D inputs, so flatten batch and slot dims.
            # `reshape` (not `view`) so non-contiguous slot inits work too.
            slots = self.gru(
                updates.reshape(bs * self.num_slots, self.slot_size),
                slots_prev.reshape(bs * self.num_slots, self.slot_size),
            )
            slots = slots.view(bs, self.num_slots, self.slot_size)
            slots = slots + self.mlp(slots)

        return slots

    @property
    def dtype(self):
        """Dtype of the key projection weights."""
        return self.project_k.weight.dtype

    @property
    def device(self):
        """Device of the key projection weights."""
        return self.project_k.weight.device


class E2E(BaseModel):
    """StoSAVi + SlotFormer to enable end-to-end training or fine-tuning from slot extraction to prediction.
    """
    def __init__(
        self,
        resolution,
        clip_len,
        slot_dict=dict(
            num_slots=7,
            slot_size=128,
            slot_mlp_size=256,
            num_iterations=2,
            kernel_mlp=True,
        ),
        enc_dict=dict(
            enc_channels=(3, 64, 64, 64, 64),
            enc_ks=5,
            enc_out_channels=128,
            enc_norm='',
        ),
        dec_dict=dict(
            dec_channels=(128, 64, 64, 64, 64),
            dec_resolution=(8, 8),
            dec_ks=5,
            dec_norm='',
            dec_ckp_path='',
        ),
        rollout_dict=dict(
            num_slots=7,
            slot_size=128,
            history_len=6,
            t_pe='sin',
            slots_pe='',
            d_model=128,
            num_layers=4,
            num_heads=8,
            ffn_dim=512,
            norm_first=True,
        ),
        pred_dict=dict(
            pred_type='transformer',
            pred_rnn=True,
            pred_norm_first=True,
            pred_num_layers=2,
            pred_num_heads=4,
            pred_ffn_dim=512,
            pred_sg_every=None,
        ),
        loss_dict=dict(
            use_post_recon_loss=True,
            kld_method='var-0.01',
            rollout_len=6,
            use_img_recon_loss=False,
            # `calc_train_loss` indexes this key unconditionally
            # ('none' or 'inst'); without it the defaults hit a KeyError
            contrastive='none',
        ),
        eps=1e-6,
    ):
        """Build the end-to-end slot extraction + rollout model.

        Args:
            resolution: input image resolution; indexed as `(H, W)`.
            clip_len: max number of frames handled in one pass when not
                training (see `extract_slot`).
            slot_dict: slot-attention settings.
            enc_dict: CNN encoder settings.
            dec_dict: spatial-broadcast decoder settings.
            rollout_dict: keyword arguments forwarded to `SlotRollouter`.
            pred_dict: per-step slot predictor settings.
            loss_dict: loss settings.
            eps: epsilon passed to the slot-attention module.

        NOTE: the dict defaults are shared between instances (mutable
        default arguments) and are stored by reference, so they must be
        treated as read-only.
        """
        super().__init__()

        self.resolution = resolution
        self.clip_len = clip_len
        self.eps = eps

        self.slot_dict = slot_dict
        self.enc_dict = enc_dict
        self.dec_dict = dec_dict
        self.pred_dict = pred_dict
        self.rollout_dict = rollout_dict
        self.loss_dict = loss_dict

        # order matters: later builders read attributes set by earlier ones
        # (e.g. `slot_size` / `slot_mlp_size` from `_build_slot_attention`)
        self._build_slot_attention()
        self._build_encoder()
        self._build_decoder()
        self._build_predictor()
        self._build_rollouter()
        self._build_loss()

        self.testing = False  # for compatibility
        self.loss_decay_factor = 1.  # temporal loss weighting
    def _build_slot_attention(self):
        """Create slot inits, the kernel distribution head and SlotAttention.

        kernels x img_feats --> posterior_slots
        """
        slot_cfg = self.slot_dict
        self.enc_out_channels = self.enc_dict['enc_out_channels']
        self.num_slots = slot_cfg['num_slots']
        self.slot_size = slot_cfg['slot_size']
        self.slot_mlp_size = slot_cfg['slot_mlp_size']
        self.num_iterations = slot_cfg['num_iterations']

        # one learnable, normally-initialised embedding per slot
        latents = torch.empty(1, self.num_slots, self.slot_size)
        self.init_latents = nn.Parameter(nn.init.normal_(latents))

        # head predicting (\mu, \sigma) used to sample the SA `kernels`
        dist_size = self.slot_size * 2
        if slot_cfg.get('kernel_mlp', True):
            kernel_layers = [
                nn.Linear(self.slot_size, dist_size),
                nn.LayerNorm(dist_size),
                nn.ReLU(),
                nn.Linear(dist_size, dist_size),
            ]
        else:
            kernel_layers = [nn.Linear(self.slot_size, dist_size)]
        self.kernel_dist_layer = nn.Sequential(*kernel_layers)

        # `prior_slots` head: unused, kept only so that pre-trained
        # weights can still be loaded
        self.prior_slot_layer = nn.Sequential(
            nn.Linear(self.slot_size, self.slot_size),
            nn.LayerNorm(self.slot_size),
            nn.ReLU(),
            nn.Linear(self.slot_size, self.slot_size),
        )

        sa_kwargs = dict(
            in_features=self.enc_out_channels,
            num_iterations=self.num_iterations,
            num_slots=self.num_slots,
            slot_size=self.slot_size,
            mlp_hidden_size=self.slot_mlp_size,
            eps=self.eps,
        )
        self.slot_attention = SlotAttention(**sa_kwargs)

    def _build_encoder(self):
        """Create the image encoder: Conv CNN --> PosEnc --> MLP."""
        enc_cfg = self.enc_dict
        self.enc_channels = list(enc_cfg['enc_channels'])  # CNN channels
        self.enc_ks = enc_cfg['enc_ks']  # kernel size in CNN
        self.enc_norm = enc_cfg['enc_norm']  # norm in CNN
        self.visual_resolution = (64, 64)  # CNN out visual resolution
        self.visual_channels = self.enc_channels[-1]  # CNN out channels

        enc_layers = len(self.enc_channels) - 1
        convs = []
        for i in range(enc_layers):
            is_last = i == (enc_layers - 1)
            convs.append(
                conv_norm_act(
                    self.enc_channels[i],
                    self.enc_channels[i + 1],
                    kernel_size=self.enc_ks,
                    # 2x downsampling for 128x128 image
                    stride=2 if (i == 0 and self.resolution[0] == 128) else 1,
                    norm=self.enc_norm,
                    # relu except for the last layer
                    act='' if is_last else 'relu'))
        self.encoder = nn.Sequential(*convs)

        # modules applied on top of the CNN feature map
        self.encoder_pos_embedding = SoftPositionEmbed(
            self.visual_channels, self.visual_resolution)
        self.encoder_out_layer = nn.Sequential(
            nn.LayerNorm(self.visual_channels),
            nn.Linear(self.visual_channels, self.enc_out_channels),
            nn.ReLU(),
            nn.Linear(self.enc_out_channels, self.enc_out_channels),
        )

    def _build_decoder(self):
        """Create the slot decoder.

        Spatial broadcast --> PosEnc --> DeConv CNN, followed by a 1x1 conv
        producing 4 channels (RGB and seg mask).

        Raises:
            AssertionError: if the first decoder channel count differs from
                `slot_size`.
        """
        # Build Decoder
        # Spatial broadcast --> PosEnc --> DeConv CNN
        self.dec_channels = self.dec_dict['dec_channels']  # CNN channels
        self.dec_resolution = self.dec_dict['dec_resolution']  # broadcast size
        self.dec_ks = self.dec_dict['dec_ks']  # kernel size
        self.dec_norm = self.dec_dict['dec_norm']  # norm in CNN
        assert self.dec_channels[0] == self.slot_size, \
            'wrong in_channels for Decoder'
        modules = []
        in_size = self.dec_resolution[0]
        out_size = in_size
        # upsample with stride 2 until the tracked size reaches the target
        # height; after that the stride stays at 1 for all remaining layers
        stride = 2
        for i in range(len(self.dec_channels) - 1):
            if out_size == self.resolution[0]:
                stride = 1
            modules.append(
                deconv_norm_act(
                    self.dec_channels[i],
                    self.dec_channels[i + 1],
                    kernel_size=self.dec_ks,
                    stride=stride,
                    norm=self.dec_norm,
                    act='relu'))
            # track the spatial size after this layer; presumably mirrors
            # the padding used inside `deconv_norm_act` -- TODO confirm
            out_size = deconv_out_shape(out_size, stride, self.dec_ks // 2,
                                        self.dec_ks, stride - 1)

        # only `out_size` (derived from `dec_resolution[0]`) is tracked, so
        # the check compares a square size against `self.resolution`
        assert_shape(
            self.resolution,
            (out_size, out_size),
            message="Output shape of decoder did not match input resolution. "
            "Try changing `decoder_resolution`.",
        )

        # out Conv for RGB and seg mask
        modules.append(
            nn.Conv2d(
                self.dec_channels[-1], 4, kernel_size=1, stride=1, padding=0))

        self.decoder = nn.Sequential(*modules)
        self.decoder_pos_embedding = SoftPositionEmbed(self.slot_size,
                                                       self.dec_resolution)

        # Disabled: loading a pretrained decoder from `dec_ckp_path`.
        # # load pretrained weight
        # ckp_path = self.dec_dict['dec_ckp_path']
        # assert ckp_path, 'Please provide pretrained decoder weight'
        # w = torch.load(ckp_path, map_location='cpu')['state_dict']
        #
        # dec_w = {k[8:]: v for k, v in w.items() if k.startswith('decoder.')}
        # dec_pe_w = {
        #     k[22:]: v
        #     for k, v in w.items() if k.startswith('decoder_pos_embedding.')
        # }
        # self.decoder.load_state_dict(dec_w)
        # self.decoder_pos_embedding.load_state_dict(dec_pe_w)

        # print("INFO: pretrained decoder weight loaded from", ckp_path)

        # Disabled: freezing the decoder.
        # for p in self.decoder.parameters():
        #     p.requires_grad = False
        # for p in self.decoder_pos_embedding.parameters():
        #     p.requires_grad = False
        # self.decoder.eval()
        # self.decoder_pos_embedding.eval()

    def _build_rollouter(self):
        """Create the `SlotRollouter` used to predict future slots.

        All entries of `rollout_dict` are forwarded as keyword arguments.
        """
        # Build Rollouter
        self.history_len = self.rollout_dict['history_len']  # burn-in steps
        self.rollouter = SlotRollouter(**self.rollout_dict)

    def _build_predictor(self):
        """Predictor as in SAVi to transition slot from time t to t+1.

        Transformer (object interaction) --> LSTM (scene dynamic).
        """
        cfg = self.pred_dict
        if cfg.get('pred_type', 'transformer') == 'mlp':
            core = ResidualMLPPredictor(
                [self.slot_size, self.slot_size * 2, self.slot_size],
                norm_first=cfg['pred_norm_first'],
            )
        else:
            core = TransformerPredictor(
                self.slot_size,
                cfg['pred_num_layers'],
                cfg['pred_num_heads'],
                cfg['pred_ffn_dim'],
                norm_first=cfg['pred_norm_first'],
            )
        # optionally wrap the predictor with an LSTM
        if cfg['pred_rnn']:
            core = RNNPredictorWrapper(
                core,
                self.slot_size,
                self.slot_mlp_size,
                num_layers=1,
                rnn_cell='LSTM',
                sg_every=cfg['pred_sg_every'],
            )
        self.predictor = core

    def _build_loss(self):
        """Loss calculation settings."""
        
        self.use_post_recon_loss = self.loss_dict['use_post_recon_loss']
        assert self.use_post_recon_loss
        # stochastic SAVi by sampling the kernels
        kld_method = self.loss_dict['kld_method']
        # a smaller sigma for the prior distribution
        if '-' in kld_method:
            kld_method, kld_var = kld_method.split('-')
            self.kld_log_var = math.log(float(kld_var))
        else:
            self.kld_log_var = math.log(1.)
        self.kld_method = kld_method
        assert self.kld_method in ['var', 'none']

        self.rollout_len = self.loss_dict['rollout_len']  # rollout steps
        self.use_img_recon_loss = self.loss_dict['use_img_recon_loss']

    def _kld_loss(self, prior_dist, post_slots):
        """KLD between (mu, sigma) and (0 or mu, 1)."""
        if self.kld_method == 'none':
            return torch.tensor(0.).type_as(prior_dist)
        assert prior_dist.shape[-1] == self.slot_size * 2
        mu1 = prior_dist[..., :self.slot_size]
        log_var1 = prior_dist[..., self.slot_size:]
        mu2 = mu1.detach().clone()  # no penalty for mu
        log_var2 = torch.ones_like(log_var1).detach() * self.kld_log_var
        sigma1 = torch.exp(log_var1 * 0.5)
        sigma2 = torch.exp(log_var2 * 0.5)
        kld = torch.log(sigma2 / sigma1) + \
            (torch.exp(log_var1) + (mu1 - mu2)**2) / \
            (2. * torch.exp(log_var2)) - 0.5
        return kld.sum(-1).mean()

    def _sample_dist(self, dist):
        """Sample values from Gaussian distribution."""
        assert dist.shape[-1] == self.slot_size * 2
        mu = dist[..., :self.slot_size]
        # not doing any stochastic
        if self.kld_method == 'none':
            return mu
        log_var = dist[..., self.slot_size:]
        eps = torch.randn_like(mu).detach()
        sigma = torch.exp(log_var * 0.5)
        return mu + eps * sigma

    def _get_encoder_out(self, img):
        """Encode image, potentially add pos enc, apply MLP."""
        encoder_out = self.encoder(img).type(self.dtype)
        encoder_out = self.encoder_pos_embedding(encoder_out)
        # `encoder_out` has shape: [B, C, H, W]
        encoder_out = torch.flatten(encoder_out, start_dim=2, end_dim=3)
        # `encoder_out` has shape: [B, C, H*W]
        encoder_out = encoder_out.permute(0, 2, 1).contiguous()
        encoder_out = self.encoder_out_layer(encoder_out)
        # `encoder_out` has shape: [B, H*W, enc_out_channels]
        return encoder_out

    def encode(self, img, prev_slots=None):
        """Encode from img to slots."""
        B, T, C, H, W = img.shape
        img = img.flatten(0, 1)

        encoder_out = self._get_encoder_out(img)
        encoder_out = encoder_out.unflatten(0, (B, T))
        # `encoder_out` has shape: [B, T, H*W, out_features]

        # init slots
        init_latents = self.init_latents.repeat(B, 1, 1)  # [B, N, C]

        # apply SlotAttn on video frames via reusing slots
        all_kernel_dist, all_post_slots = [], []
        for idx in range(T):
            # init
            if prev_slots is None:
                latents = init_latents  # [B, N, C]
            else:
                latents = self.predictor(prev_slots)  # [B, N, C]

            # stochastic `kernels` as SA input
            kernel_dist = self.kernel_dist_layer(latents)
            kernels = self._sample_dist(kernel_dist)
            all_kernel_dist.append(kernel_dist)

            # perform SA to get `post_slots`
            post_slots = self.slot_attention(encoder_out[:, idx], kernels)
            all_post_slots.append(post_slots)

            # next timestep
            prev_slots = post_slots

        # (B, T, self.num_slots, self.slot_size)
        kernel_dist = torch.stack(all_kernel_dist, dim=1)
        post_slots = torch.stack(all_post_slots, dim=1)

        return kernel_dist, post_slots, encoder_out

    def _reset_rnn(self):
        """Call `reset()` on the slot predictor.

        `_forward` invokes this when it starts from the first frame (no
        `prev_slots` given).
        """
        self.predictor.reset()

    def extract_slot(self, data_dict):
        """A wrapper for extract slot.

        If the input video is too long in testing, we manually cut it.
        """
        img = data_dict['img']
        
        T = img.shape[1]
        if T <= self.clip_len or self.training:

            return self._forward(img, None)

        # try to find the max len to input each time
        clip_len = T
        while True:
            try:
                _ = self._forward(img[:, :clip_len], None)
                del _
                torch.cuda.empty_cache()
                break
            except RuntimeError:  # CUDA out of memory
                clip_len = clip_len // 2 + 1
        # update `clip_len`
        self.clip_len = max(self.clip_len, clip_len)
        # no need to split
        if clip_len == T:
            return self._forward(img, None)

        # split along temporal dim
        cat_dict = None
        prev_slots = None
        for clip_idx in range(0, T, clip_len):
            out_dict = self._forward(img[:, clip_idx:clip_idx + clip_len],
                                     prev_slots)
            # because this should be in test mode, we detach the outputs
            if cat_dict is None:
                cat_dict = {k: [v.detach()] for k, v in out_dict.items()}
            else:
                for k, v in out_dict.items():
                    cat_dict[k].append(v.detach())
            prev_slots = cat_dict['post_slots'][-1][:, -1].detach().clone()
            del out_dict
            torch.cuda.empty_cache()
        cat_dict = {k: torch_cat(v, dim=1) for k, v in cat_dict.items()}
        # print("INFO: slots have gradients!!! ", cat_dict['post_slots'].requires_grad)
        return cat_dict


    def forward(self, data_dict):
        """A wrapper for model forward.

        If the input video is too long in testing, we manually cut it.
        """
        cat_dict = self.extract_slot(data_dict)

        slots = cat_dict['post_slots']
        kernel_dist = cat_dict['kernel_dist']

        inst = data_dict['instruction']

        assert self.rollout_len + self.history_len == slots.shape[1], \
            f'wrong SlotFormer training length {slots.shape[1]}'    

        past_slots = slots[:, :self.history_len]
        gt_slots = slots[:, self.history_len:]
        if self.use_img_recon_loss:
            out_dict = self.rollout(
                past_slots, inst, self.rollout_len, decode=True, with_gt=False)
            out_dict['pred_slots'] = out_dict.pop('slots')
            out_dict['gt_slots'] = gt_slots  # both slots [B, pred_len, N, C]
            out_dict['past_slots'] = past_slots
        else:
            pred_slots = self.rollout(
                past_slots, inst, self.rollout_len, decode=False)
            out_dict = {
                'past_slots': past_slots,
                'gt_slots': gt_slots,  # both slots [B, pred_len, N, C]
                'pred_slots': pred_slots,
            }
        
    
        return out_dict

    def _forward(self, img, prev_slots=None):
        """Forward function.

        Args:
            img: [B, T, C, H, W]
            prev_slots: [B, num_slots, slot_size] or None,
                the `post_slots` from last timestep.
        """
        # reset RNN states if this is the first frame
        if prev_slots is None:
            self._reset_rnn()

        B, T = img.shape[:2]

        # history slots 
        kernel_dist, post_slots, encoder_out = \
            self.encode(img[:,:self.history_len], prev_slots=prev_slots)

        # gt slots w/o gradient
        with torch.no_grad():
            prev_slots_ = post_slots[:,-1]
            kernel_dist_gt, post_slots_gt, encoder_out_gt = \
                self.encode(img[:,self.history_len:], prev_slots=prev_slots_)

        kernel_dist = torch.cat([kernel_dist, kernel_dist_gt], dim=1)
        post_slots = torch.cat([post_slots, post_slots_gt], dim=1)
        encoder_out = torch.cat([encoder_out, encoder_out_gt], dim=1)

        # `slots` has shape: [B, T, self.num_slots, self.slot_size]

        out_dict = {
            'post_slots': post_slots,  # [B, T, num_slots, C]
            'kernel_dist': kernel_dist,  # [B, T, num_slots, 2C]
            'img': img,  # [B, T, 3, H, W]
        }
        if self.testing:
            return out_dict

        if self.use_post_recon_loss:
            with torch.no_grad():
                slots_dec = post_slots
            post_recon_img, post_recons, post_masks, _ = \
                self.decode(slots_dec.flatten(0, 1))
            post_dict = {
                'post_recon_combined': post_recon_img,  # [B*T, 3, H, W]
                'post_recons': post_recons,  # [B*T, num_slots, 3, H, W]
                'post_masks': post_masks,  # [B*T, num_slots, 1, H, W]
            }
            out_dict.update(
                {k: v.unflatten(0, (B, T))
                 for k, v in post_dict.items()})

        return out_dict

    def decode(self, slots):
        """Decode from slots to reconstructed images and masks."""
        # `slots` has shape: [B, self.num_slots, self.slot_size].
        bs, num_slots, slot_size = slots.shape
        height, width = self.resolution
        num_channels = 3

        # spatial broadcast
        decoder_in = slots.view(bs * num_slots, slot_size, 1, 1)
        decoder_in = decoder_in.repeat(1, 1, self.dec_resolution[0],
                                       self.dec_resolution[1])

        out = self.decoder_pos_embedding(decoder_in)
        out = self.decoder(out)
        # `out` has shape: [B*num_slots, 4, H, W].

        out = out.view(bs, num_slots, num_channels + 1, height, width)
        recons = out[:, :, :num_channels, :, :]  # [B, num_slots, 3, H, W]
        masks = out[:, :, -1:, :, :]
        masks = F.softmax(masks, dim=1)  # [B, num_slots, 1, H, W]
        recon_combined = torch.sum(recons * masks, dim=1)  # [B, 3, H, W]
        return recon_combined, recons, masks, slots

    def rollout(self, past_slots, inst, pred_len, decode=False, with_gt=True):
        """Unroll slots for `pred_len` steps, potentially decode images."""
        B = past_slots.shape[0]  # [B, T, N, C]
        pred_slots = self.rollouter(past_slots[:, -self.history_len:], inst,
                                    pred_len)
        # `decode` is usually called from outside
        # used to visualize an entire video (burn-in + rollout)
        # i.e. `with_gt` is True
        if decode:
            if with_gt:
                T = pred_len + past_slots.shape[1]
                slots = torch.cat([past_slots, pred_slots], dim=1)
            else:
                T = pred_len
                slots = pred_slots
            with torch.no_grad():
                slots_dec = slots
            recon_img, recons, masks, _ = self.decode(slots_dec.flatten(0, 1))
            out_dict = {
                'recon_combined': recon_img,  # [B*T, 3, H, W]
                'recons': recons,  # [B*T, num_slots, 3, H, W]
                'masks': masks,  # [B*T, num_slots, 1, H, W]
            }
            out_dict = {k: v.unflatten(0, (B, T)) for k, v in out_dict.items()}
            out_dict['slots'] = slots
            return out_dict
        # [B, pred_len, N, C]
        return pred_slots

    def calc_train_loss(self, data_dict, out_dict):
        """Compute training loss.

        The loss type is selected by `self.loss_dict['contrastive']`:

        - 'none': element-wise MSE between predicted and GT slots (and
          images, if `self.use_img_recon_loss`), with optional temporal
          weighting and masking of padded timesteps.
        - 'inst': margin-based contrastive loss over the batch. Sample i's
          prediction is pulled towards sample i's target (positive pair)
          and pushed at least `margin` away from every other sample's
          target (negative pairs).

        Any other value adds no loss terms.

        Args:
            data_dict: batch data. Reads the optional key 'vid_len' ([B])
                and, if `self.use_img_recon_loss`, the key 'img'.
            out_dict: model outputs. Reads 'gt_slots' and 'pred_slots'
                ([B, rollout_T, N, C]) and, if `self.use_img_recon_loss`,
                'recon_combined'.

        Returns:
            dict mapping loss names to scalar tensors.
        """

        def _contrastive_terms(target, pred, margin):
            """Return (positive, negative) contrastive terms over the batch.

            Both inputs are flattened to [B, D]. The positive term is the
            mean squared distance of matched pairs; the negative term is
            the mean squared hinge `max(margin - dist, 0)` over all B*(B-1)
            mismatched pairs. Both are normalized by D.
            """
            n = target.shape[0]
            common = torch.promote_types(target.dtype, pred.dtype)
            target = target.reshape(n, -1).to(common)
            pred = pred.reshape(n, -1).to(common)
            feat_dim = target.shape[1]
            # full [B, B] distance matrix: dist[i, j] = ||target_i - pred_j||
            # `F.pairwise_distance` would only give the B matched-pair
            # distances, which makes the "negative" term repel positives.
            # The direct (non-matmul) mode avoids cancellation error on the
            # near-zero diagonal.
            dist = torch.cdist(
                target, pred, compute_mode='donot_use_mm_for_euclid_dist')
            is_pos = torch.eye(n, dtype=dist.dtype, device=dist.device)
            pos = torch.sum(is_pos * torch.pow(dist, 2)) / n / feat_dim
            hinge = torch.clamp(margin - dist, min=0.0)
            # max(n - 1, 1): with a single sample there are no negatives and
            # the sum is 0; avoid 0 / 0 = NaN.
            neg = torch.sum((1 - is_pos) * torch.pow(hinge, 2)) \
                / n / max(n - 1, 1) / feat_dim
            return pos, neg

        loss_dict = {}
        gt_slots = out_dict['gt_slots']  # [B, rollout_T, N, C]
        pred_slots = out_dict['pred_slots']
        loss_mode = self.loss_dict['contrastive']
        # set in 'none' mode when some videos are shorter than the rollout
        valid_mask = None

        if loss_mode == 'none':
            slots_loss = F.mse_loss(pred_slots, gt_slots, reduction='none')
            # compute per-step slot loss in eval time
            if not self.training:
                for step in range(min(6, gt_slots.shape[1])):
                    loss_dict[f'slot_recon_loss_{step+1}'] = \
                        slots_loss[:, step].mean()
            # apply temporal loss weighting as done in RPIN
            # penalize more for early steps, less for later steps
            if self.loss_decay_factor < 1.:
                w = self.loss_decay_factor**torch.arange(gt_slots.shape[1])
                w = w.type_as(slots_loss)
                # w should sum up to rollout_T
                w = w / w.sum() * gt_slots.shape[1]
                slots_loss = slots_loss * w[None, :, None, None]

            # only compute loss on valid slots/imgs
            # e.g. in PHYRE, some videos are short, so we pad zero slots
            vid_len = data_dict.get('vid_len', None)
            if (vid_len is not None) and \
                    (vid_len < (self.history_len + self.rollout_len)).any():
                valid_mask = torch.arange(gt_slots.shape[1]).to(
                    gt_slots.device) + self.history_len
                # [B, rollout_T]
                valid_mask = valid_mask[None] < vid_len[:, None]
                valid_mask = valid_mask.flatten(0, 1)
                slots_loss = slots_loss.flatten(0, 1)[valid_mask]
            loss_dict['slot_recon_loss'] = slots_loss.mean()
        elif loss_mode == 'inst':
            # contrastive loss computed on batch
            # (pos/neg according to the instruction)
            pos_slots_loss, neg_slots_loss = _contrastive_terms(
                gt_slots, pred_slots, margin=200.)
            loss_dict['slot_recon_loss'] = pos_slots_loss + neg_slots_loss
            loss_dict['positive_slot_loss'] = pos_slots_loss
            loss_dict['negative_slot_loss'] = neg_slots_loss

        if self.use_img_recon_loss:
            recon_combined = out_dict['recon_combined']
            gt_img = data_dict['img'][:, self.history_len:]
            if loss_mode == 'none':
                imgs_loss = F.mse_loss(
                    recon_combined, gt_img, reduction='none')
                # in case of truncated loss, we need to mask out invalid imgs
                if valid_mask is not None:
                    imgs_loss = imgs_loss.flatten(0, 1)[valid_mask]
                loss_dict['img_recon_loss'] = imgs_loss.mean()
            elif loss_mode == 'inst':
                pos_img_loss, neg_img_loss = _contrastive_terms(
                    gt_img, recon_combined, margin=800.)
                loss_dict['positive_img_loss'] = pos_img_loss
                loss_dict['negative_img_loss'] = neg_img_loss
                loss_dict['img_recon_loss'] = pos_img_loss + neg_img_loss

        return loss_dict

    @property
    def dtype(self):
        """Data type of the model, as reported by `self.slot_attention`."""
        slot_attn = self.slot_attention
        return slot_attn.dtype

    @property
    def device(self):
        """Device of the model, as reported by `self.slot_attention`."""
        slot_attn = self.slot_attention
        return slot_attn.device

    def train(self, mode=True):
        """Set train/eval mode through the parent class and return `self`.

        Args:
            mode: forwarded unchanged to the parent class's `train`.

        Returns:
            `self`.
        """
        super().train(mode)
        # NOTE: the calls below, which would keep the decoder part in eval
        # mode, are currently disabled, so this method only delegates to
        # the parent class's `train`.
        # self.decoder.eval()
        # self.decoder_pos_embedding.eval()
        return self
