from typing import Optional
from torch import Tensor
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
import numpy as np
import os

########################################### From Mask2Former Repo #############################################
class SelfAttentionLayer(nn.Module):
    """Multi-head self-attention with a residual connection and one LayerNorm.

    ``normalize_before`` selects where the LayerNorm sits: on the input of the
    attention (pre-norm, ``forward_pre``) or on the residual sum (post-norm,
    ``forward_post``).

    Args:
        d_model: Embedding size given to ``nn.MultiheadAttention`` and
            ``nn.LayerNorm``.
        nhead: Number of attention heads.
        dropout: Dropout probability used inside the attention and on its
            output before the residual addition.
        activation: Accepted but not consulted; ``F.relu`` is always stored in
            ``self.activation`` and no method of this class applies it.
        normalize_before: Use the pre-norm variant when true.
    """

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = F.relu
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        weight_matrices = (w for w in self.parameters() if w.dim() > 1)
        for w in weight_matrices:
            nn.init.xavier_uniform_(w)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: attention, residual addition, then LayerNorm."""
        # Queries and keys carry the positional embedding; values do not.
        attn_in = self.with_pos_embed(tgt, query_pos)
        attended, _ = self.self_attn(attn_in, attn_in, value=tgt,
                                     attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        return self.norm(tgt + self.dropout(attended))

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm, attention, then residual addition."""
        normed = self.norm(tgt)
        attn_in = self.with_pos_embed(normed, query_pos)
        attended, _ = self.self_attn(attn_in, attn_in, value=normed,
                                     attn_mask=tgt_mask,
                                     key_padding_mask=tgt_key_padding_mask)
        # The residual uses the un-normalised input.
        return tgt + self.dropout(attended)

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Dispatch to ``forward_pre`` or ``forward_post`` per ``normalize_before``."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, tgt_mask, tgt_key_padding_mask, query_pos)

class PositionEmbeddingSine3D(nn.Module):
    """
    Sinusoidal position embedding in the style of "Attention is all you need",
    extended to video: one embedding each for the time, height and width axes.

    The height and width embeddings have ``num_pos_feats`` channels each and are
    concatenated; the time embedding has ``2 * num_pos_feats`` channels and is
    added on top, so the output has ``2 * num_pos_feats`` channels.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        # A custom scale only makes sense when the coordinates are normalised.
        if normalize is False and scale is not None:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    def forward(self, x, mask=None):
        """Return the position embedding for ``x``.

        Args:
            x: 5-dimensional tensor laid out as (b, t, c, h, w).
            mask: Optional boolean tensor of shape (b, t, h, w); positions set
                to True are excluded from the cumulative coordinate counts.
                Defaults to an all-False mask.

        Returns:
            Tensor of shape (b, t, 2 * num_pos_feats, h, w).
        """
        assert x.dim() == 5, f"{x.shape} should be a 5-dimensional Tensor, got {x.dim()}-dimensional Tensor instead"
        if mask is None:
            batch, frames, _, height, width = x.shape
            mask = torch.zeros((batch, frames, height, width), device=x.device, dtype=torch.bool)
        valid = ~mask

        # Running count of valid positions along time (1), height (2), width (3).
        z_embed, y_embed, x_embed = (
            valid.cumsum(axis, dtype=torch.float32) for axis in (1, 2, 3)
        )
        if self.normalize:
            eps = 1e-6
            # Divide by the last (largest) count on each axis, then rescale.
            z_embed = z_embed / (z_embed[:, -1:] + eps) * self.scale
            y_embed = y_embed / (y_embed[:, :, -1:] + eps) * self.scale
            x_embed = x_embed / (x_embed[..., -1:] + eps) * self.scale

        def _frequencies(n_feats):
            # Consecutive channel pairs share one frequency.
            idx = torch.arange(n_feats, dtype=torch.float32, device=x.device)
            return self.temperature ** (2 * (idx // 2) / n_feats)

        def _sin_cos(embed, freqs):
            # sin on even channels, cos on odd channels, interleaved back together.
            angles = embed.unsqueeze(-1) / freqs
            paired = torch.stack((angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=5)
            return paired.flatten(4)

        spatial_freqs = _frequencies(self.num_pos_feats)
        temporal_freqs = _frequencies(self.num_pos_feats * 2)

        pos_x = _sin_cos(x_embed, spatial_freqs)
        pos_y = _sin_cos(y_embed, spatial_freqs)
        pos_z = _sin_cos(z_embed, temporal_freqs)

        combined = torch.cat((pos_y, pos_x), dim=4) + pos_z
        return combined.permute(0, 1, 4, 2, 3)  # b, t, c, h, w

class CrossAttentionLayer(nn.Module):
    """Multi-head cross-attention with a residual connection and one LayerNorm.

    Queries come from ``tgt`` and keys/values from ``memory``. Both the
    pre-norm and the post-norm variant return ``(features, attention_weights)``
    so callers can unpack the result regardless of ``normalize_before``.

    Args:
        d_model: Embedding size given to ``nn.MultiheadAttention`` and
            ``nn.LayerNorm``.
        nhead: Number of attention heads.
        dropout: Dropout probability used inside the attention and on its
            output before the residual addition.
        normalize_before: Use the pre-norm variant when true.
    """

    def __init__(self, d_model, nhead, dropout=0.0, normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.nhead = nhead

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        return tensor if pos is None else tensor + pos

    def _attend(self, query_src, memory, memory_mask, memory_key_padding_mask, pos, query_pos):
        """Run the attention; positional embeddings go on query and key only."""
        return self.multihead_attn(query=self.with_pos_embed(query_src, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        """Post-norm variant: attention, residual addition, then LayerNorm.

        Returns:
            ``(tgt, attention_weights)``.
        """
        tgt2, tgt2_attention = self._attend(tgt, memory, memory_mask,
                                            memory_key_padding_mask, pos, query_pos)
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt, tgt2_attention

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        """Pre-norm variant: LayerNorm, attention, then residual addition.

        Returns:
            ``(tgt, attention_weights)`` -- the same pair as ``forward_post``.
            Returning a bare tensor here would make ``forward``'s result type
            depend on ``normalize_before`` and break callers that unpack it.
        """
        tgt2, tgt2_attention = self._attend(self.norm(tgt), memory, memory_mask,
                                            memory_key_padding_mask, pos, query_pos)
        # The residual uses the un-normalised input.
        tgt = tgt + self.dropout(tgt2)

        return tgt, tgt2_attention

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """Apply cross-attention from ``tgt`` (queries) to ``memory`` (keys/values).

        Args:
            tgt: Query features.
            memory: Key/value features.
            memory_mask: Optional attention mask. A 3-D mask is treated as one
                mask per batch element and is repeated for every head before
                being handed to ``nn.MultiheadAttention``; a mask of any other
                rank (e.g. a 2-D mask shared by the whole batch) is passed
                through unchanged.
            memory_key_padding_mask: Forwarded as ``key_padding_mask``.
            pos: Optional positional embedding added to the keys.
            query_pos: Optional positional embedding added to the queries.

        Returns:
            ``(features, attention_weights)``.
        """
        if memory_mask is not None and memory_mask.dim() == 3:
            # (N, L, S) -> (N, nhead, L, S) -> (N * nhead, L, S)
            memory_mask = memory_mask.unsqueeze(1).repeat(1, self.nhead, 1, 1)
            memory_mask = memory_mask.view(-1, *memory_mask.shape[-2:])

        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)

######################################################################################################

class MultiscaleTransformerDecoder(nn.Module):
    """Decoder that refines multiscale video features by attending to learned slots.

    Pixel features act as queries and a set of ``num_queries`` learned
    embeddings ("slots") act as keys/values. Decoder layer ``i`` works on
    feature level ``i % 3``; between layers the running output is bilinearly
    resized to the next level's resolution and added to that level's features.
    A two-layer convolutional head turns the final features into 2-channel
    logits.

    Args:
        in_channels: Not used; the per-level input projection is an identity.
        hidden_dim: Channel width of the slots, layers and head.
        num_queries: Number of learned slots.
        nheads: Number of attention heads per layer.
        dim_feedforward: Not used by this class.
        dec_layers: Number of cross-attention layers.
        pre_norm: Passed to the attention layers as ``normalize_before``.
        use_sa: When true, a self-attention layer runs before the
            cross-attention on levels 0 and 1.
    """

    def __init__(self,  in_channels,
                 hidden_dim: int,
                 num_queries: int,
                 nheads: int,
                 dim_feedforward: int,
                 dec_layers: int,
                 pre_norm: bool,
                 use_sa: bool = False):
        super().__init__()

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine3D(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_self_attention_layers = nn.ModuleList()
        self.use_sa = use_sa

        for _ in range(self.num_layers):
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        if self.use_sa:
            for _ in range(self.num_layers):
                self.transformer_self_attention_layers.append(
                    SelfAttentionLayer(
                        d_model=hidden_dim,
                        nhead=nheads,
                        dropout=0.0,
                        normalize_before=pre_norm,
                    )
                )

        self.num_queries = num_queries
        # learnable key features
        self.key_feat = nn.Embedding(num_queries, hidden_dim)

        # learnable key p.e.
        self.key_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding (we always use 3 scales)
        self.num_feature_levels = 3
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)
        self.input_proj = nn.ModuleList()
        for _ in range(self.num_feature_levels):
            # Identity projection: an empty Sequential returns its input.
            self.input_proj.append(nn.Sequential())

        # output FFNs
        outch2 = hidden_dim
        self.decoder_norm = nn.LayerNorm(hidden_dim)
        self.decoder2 = nn.Sequential(nn.Conv2d(outch2, outch2, (3, 3), padding=(1, 1), bias=True),
                                      nn.ReLU(),
                                      nn.Conv2d(outch2, 2, (3, 3), padding=(1, 1), bias=True))

    @staticmethod
    def _first_output(result):
        """Return the feature tensor of a layer result that may be ``(features, attention)``."""
        return result[0] if isinstance(result, tuple) else result

    @staticmethod
    def _build_memory_mask(invalid_pixels, size, bs, t, num_slots):
        """Resize the invalid-pixel map to ``size`` and lay it out as (bs, t*h*w, num_slots).

        Returns None when ``invalid_pixels`` is None. The result stays floating
        point, so ``nn.MultiheadAttention`` adds it to the attention logits
        instead of using it as a boolean exclusion mask.
        NOTE(review): every query row holds one value repeated across all
        slots, so this additive term looks like it cancels in the softmax --
        confirm whether a boolean mask was intended.
        """
        if invalid_pixels is None:
            return None
        resized = F.interpolate(invalid_pixels.float().view(-1, 1, *invalid_pixels.shape[-2:]),
                                size,
                                mode='bilinear').flatten(2, 3)
        # (bs*t, 1, hw) -> (bs, t, Q, hw) -> (bs, t, hw, Q) -> (bs, t*hw, Q)
        resized = resized.view(bs, t, 1, -1).repeat(1, 1, num_slots, 1).permute(0, 1, 3, 2)
        return resized.flatten(1, 2)

    def _mix_levels(self, low_res_f, high_res_f, size_low, size_high, num_frames):
        """Resize ``low_res_f`` from ``size_low`` to ``size_high`` and add it to ``high_res_f``.

        Both feature tensors are laid out as (t*h*w, N, C); the result has the
        layout and spatial size of ``high_res_f``.
        """
        # (t*h*w, N, C) -> (t, h, w, N, C) -> (N, t, C, h, w)
        low_res_f = low_res_f.view(num_frames, *size_low, *low_res_f.shape[-2:])
        low_res_f = low_res_f.permute(3, 0, 4, 1, 2)

        bs, t, c, h, w = low_res_f.shape
        low_res_f = F.interpolate(low_res_f.flatten(0, 1), size_high, mode='bilinear')

        # (N*t, C, H, W) -> (N, t, C, H, W) -> (t, H, W, N, C) -> (t*H*W, N, C)
        low_res_f = low_res_f.view(bs, t, c, *size_high).permute(1, 3, 4, 0, 2)
        mixed_feats = high_res_f + low_res_f.flatten(0, 2)
        return mixed_feats

    def _extract_feats_and_pos(self, x, bs, t):
        """Build per-level features, positional embeddings and spatial sizes.

        Args:
            x: Sequence with one (bs*t, C, H, W) tensor per feature level.
            bs: Batch size.
            t: Number of frames per sample.

        Returns:
            ``(src, pos, size_list)``; ``src[i]`` and ``pos[i]`` are laid out
            as (t*H*W, bs, C) and ``size_list[i]`` is the (H, W) of level i.
        """
        size_list = []
        pos = []
        src = []
        for i in range(self.num_feature_levels):
            size_list.append(x[i].shape[-2:])
            pos.append(self.pe_layer(x[i].view(bs, t, -1, size_list[-1][0], size_list[-1][1]), None).flatten(3))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # NTxCxHW => NxTxCxHW => (TxHW)xNxC
            _, c, hw = src[-1].shape
            pos[-1] = pos[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
            src[-1] = src[-1].view(bs, t, c, hw).permute(1, 3, 0, 2).flatten(0, 1)
        return src, pos, size_list

    def _do_st_cross_attention(self, src, pos, size_list, bs, t, invalid_pixels, slot_feat, slot_pos):
        """Run the decoder layers, cycling through the feature levels.

        Returns:
            ``(output, None)``; ``output`` is laid out as (t*H*W, bs, C) at the
            resolution of the last level visited. The second element is a
            placeholder for attention maps, which this class does not collect.
        """
        output = None
        num_slots = slot_feat.shape[0]
        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            if output is None:
                output = src[level_index]
            else:
                # Mix between attended features w.r.t slots + higher level feats.
                # level_index - 1 is the level of the previous layer (index -1
                # wraps around to the last level when level_index is 0).
                output = self._mix_levels(output, src[level_index], size_list[level_index - 1],
                                          size_list[level_index], t)

            memory_mask = self._build_memory_mask(invalid_pixels, size_list[level_index],
                                                  bs, t, num_slots)

            if self.use_sa and level_index in [0, 1]:
                output = self.transformer_self_attention_layers[i](
                    output, tgt_mask=None,
                    tgt_key_padding_mask=None,
                    query_pos=pos[level_index]
                )

            # A cross-attention layer may hand back (features, attention) or
            # only the features; unpacking unconditionally would fail on the
            # latter.
            output = self._first_output(
                self.transformer_cross_attention_layers[i](
                    output, slot_feat,
                    memory_mask=memory_mask,
                    memory_key_padding_mask=None,  # here we do not apply masking on padded region
                    pos=slot_pos, query_pos=pos[level_index]
                )
            )

        return output, None

    def forward(self, x, invalid_pixels, num_frames=-1, seq_name=''):
        """Decode multiscale features into logits.

        Args:
            x: Sequence with one (bs*num_frames, C, H, W) tensor per level.
            invalid_pixels: Optional map of invalid pixels used to build the
                attention mask, or None.
            num_frames: Number of frames per sample; the batch size is derived
                as ``x[0].shape[0] // num_frames``.
            seq_name: Not used.

        Returns:
            ``(logits, interm_feats)``: the 2-channel output of the head and
            the activations after its first convolution + ReLU, both with
            leading dimension bs*num_frames.
        """
        t = num_frames
        bs = x[0].shape[0] // t

        # Project features and get Positional embeddings
        src, pos, size_list = self._extract_feats_and_pos(x, bs, t)

        # QxNxC
        # NOTE(review): key_embed supplies the slot features and key_feat the
        # slot positional embedding, the reverse of the comments in __init__.
        # Left as is because swapping would change what trained weights mean.
        slot_feat = self.key_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        slot_pos = self.key_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        # Perform Spatiotemporal Cross Attention
        output, _ = self._do_st_cross_attention(
            src, pos, size_list, bs, t, invalid_pixels, slot_feat, slot_pos
        )

        output = self.decoder_norm(output)
        # (t*H*W, N, C) -> (t, H, W, N, C) -> (N, t, C, H, W) -> (N*t, C, H, W)
        output = output.view(t, *size_list[-1], *output.shape[-2:]).permute(3, 0, 4, 1, 2)
        output = output.flatten(0, 1)

        # Run the first conv + ReLU once and reuse it for the logits.
        interm_feats = self.decoder2[1](self.decoder2[0](output))
        logits = self.decoder2[2](interm_feats)

        return logits, interm_feats

######################################################################################################
class HierarchMultiscaleTransformerDecoder(MultiscaleTransformerDecoder):
    """Multiscale decoder that first refines the slots against the features.

    Before the inherited spatiotemporal cross-attention (pixels attend to
    slots), an extra stack of cross-attention layers lets the slots attend to
    the per-level features. The refined slots then replace the raw slot
    features in the inherited stage.

    Constructor arguments are forwarded to ``MultiscaleTransformerDecoder``;
    ``use_sa`` is left at the parent's default.
    """

    def __init__(self,  in_channels,
                 hidden_dim: int,
                 num_queries: int,
                 nheads: int,
                 dim_feedforward: int,
                 dec_layers: int,
                 pre_norm: bool):

        super().__init__(in_channels=in_channels, hidden_dim=hidden_dim, num_queries=num_queries,
                         nheads=nheads, dim_feedforward=dim_feedforward, dec_layers=dec_layers,
                         pre_norm=pre_norm)

        self.transformer_proto_cross_attention_layers = nn.ModuleList()
        for _ in range(self.num_layers):
            self.transformer_proto_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

    def _do_proto_cross_attention(self, src, pos, size_list, bs, t, invalid_pixels, slot_feat, slot_pos):
        """Refine the slots by letting them attend to each feature level in turn.

        Layer ``i`` uses feature level ``i % num_feature_levels``. No attention
        mask is applied; ``size_list``, ``bs``, ``t`` and ``invalid_pixels``
        are accepted for signature parity but not used.

        Returns:
            The refined slot tensor, with the same layout as ``slot_feat``.
        """
        output = slot_feat
        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            result = self.transformer_proto_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=None,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=pos[level_index], query_pos=slot_pos
            )
            # The layer may return (features, attention); only the features
            # are fed to the next layer. Passing the tuple on would crash.
            output = result[0] if isinstance(result, tuple) else result
        return output

    def forward(self, x, invalid_pixels, num_frames=-1, seq_name=''):
        """Decode multiscale features into logits using refined slots.

        Args:
            x: Sequence with one (bs*num_frames, C, H, W) tensor per level.
            invalid_pixels: Optional map of invalid pixels used to build the
                attention mask of the spatiotemporal stage, or None.
            num_frames: Number of frames per sample; the batch size is derived
                as ``x[0].shape[0] // num_frames``.
            seq_name: Not used; accepted so the signature matches the parent's
                ``forward``.

        Returns:
            ``(logits, interm_feats)``: the 2-channel output of the head and
            the activations after its first convolution + ReLU.
        """
        t = num_frames
        bs = x[0].shape[0] // t

        # Project features and get Positional embeddings
        src, pos, size_list = self._extract_feats_and_pos(x, bs, t)

        # QxNxC
        slot_feat = self.key_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        slot_pos = self.key_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        # Perform first Prototype Cross Attention
        slot_feat = self._do_proto_cross_attention(src, pos, size_list, bs, t, invalid_pixels, slot_feat, slot_pos)

        # Perform Spatiotemporal Cross Attention.
        # The stage returns (features, attention_maps); only the features go
        # through the head.
        output, _ = self._do_st_cross_attention(src, pos, size_list, bs, t, invalid_pixels, slot_feat, slot_pos)

        output = self.decoder_norm(output)
        # (t*H*W, N, C) -> (t, H, W, N, C) -> (N, t, C, H, W) -> (N*t, C, H, W)
        output = output.view(t, *size_list[-1], *output.shape[-2:]).permute(3, 0, 4, 1, 2)
        output = output.flatten(0, 1)

        # Run the first conv + ReLU once and reuse it for the logits.
        interm_feats = self.decoder2[1](self.decoder2[0](output))
        logits = self.decoder2[2](interm_feats)

        return logits, interm_feats

######################################################################################################
class BidirMultiscaleTransformerDecoder(MultiscaleTransformerDecoder):
    """Multiscale decoder that walks the feature levels back and forth.

    Instead of cycling 0, 1, 2, 0, ... the spatiotemporal cross-attention
    visits the levels in the fixed order 0, 1, 2, 1, 0, 1, 2, using one
    cross-attention layer per step, and collects the attention map of every
    step. ``dec_layers`` must therefore be at least 7.

    Constructor arguments are forwarded to ``MultiscaleTransformerDecoder``.
    """

    def __init__(self,  in_channels,
                 hidden_dim: int,
                 num_queries: int,
                 nheads: int,
                 dim_feedforward: int,
                 dec_layers: int,
                 pre_norm: bool,
                 use_sa: bool = False):

        # NOTE(review): `use_sa` is accepted but self-attention is always
        # disabled here (use_sa=False is forwarded). Kept as is so existing
        # configurations keep the same parameters; confirm whether the
        # argument should be honoured.
        super().__init__(in_channels=in_channels, hidden_dim=hidden_dim, num_queries=num_queries,
                         nheads=nheads, dim_feedforward=dim_feedforward, dec_layers=dec_layers,
                         pre_norm=pre_norm, use_sa=False)

    def _do_st_cross_attention(self, src, pos, size_list, bs, t, invalid_pixels, slot_feat, slot_pos):
        """Run cross-attention over the level schedule 0, 1, 2, 1, 0, 1, 2.

        Returns:
            ``(output, output_attention_maps)``. ``output`` is laid out as
            (t*H*W, bs, C) at the resolution of level 2. The list holds one
            tensor of shape (bs*t, H, W, num_slots) per step whose layer
            returned attention weights.

        Raises:
            ValueError: If fewer cross-attention layers exist than schedule
                steps.
        """
        iterate_levels = [0, 1, 2, 1, 0, 1, 2]
        if len(self.transformer_cross_attention_layers) < len(iterate_levels):
            raise ValueError(
                f"BidirMultiscaleTransformerDecoder needs at least {len(iterate_levels)} "
                f"decoder layers, got {len(self.transformer_cross_attention_layers)}"
            )

        output = None
        output_attention_maps = []

        for it, level_index in enumerate(iterate_levels):
            if output is None:
                output = src[level_index]
            else:
                # Mix between attended features w.r.t slots + the features of
                # the level visited now; the previous step's level gives the
                # current spatial size of `output`.
                level_index_bar = iterate_levels[it - 1]
                output = self._mix_levels(output, src[level_index], size_list[level_index_bar],
                                          size_list[level_index], t)

            # Resize the invalid-pixel map to this level and lay it out as
            # (bs, t*h*w, num_slots) for use as the attention mask.
            if invalid_pixels is not None:
                memory_mask = F.interpolate(invalid_pixels.float().view(-1, 1, *invalid_pixels.shape[-2:]),
                                            size_list[level_index],
                                            mode='bilinear').flatten(2, 3)
                memory_mask = memory_mask.view(bs, t, 1, -1).repeat(1, 1, slot_feat.shape[0], 1).permute(0, 1, 3, 2)
                memory_mask = memory_mask.flatten(1, 2)
            else:
                memory_mask = None

            if self.use_sa and level_index in [0, 1]:
                # Indexed by the schedule step; this previously used an
                # undefined name.
                output = self.transformer_self_attention_layers[it](
                    output, tgt_mask=None,
                    tgt_key_padding_mask=None,
                    query_pos=pos[level_index]
                )

            result = self.transformer_cross_attention_layers[it](
                output, slot_feat,
                memory_mask=memory_mask,
                memory_key_padding_mask=None,  # here we do not apply masking on padded region
                pos=slot_pos, query_pos=pos[level_index]
            )
            # A layer may return (features, attention) or only the features.
            if isinstance(result, tuple):
                output, output_attn = result
            else:
                output, output_attn = result, None

            if output_attn is not None:
                # (bs, t*h*w, Q) -> (bs*t, h, w, Q). For bs == 1 this equals
                # the (t, h, w, Q) layout; it also works for larger batches.
                output_attention_maps.append(
                    output_attn.reshape(bs * t, *size_list[level_index], output_attn.shape[-1])
                )

        return output, output_attention_maps
