import torch
import einops
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from .attention import Block
from .vision_transformer import restart_from_checkpoint
from .dconv import DeformableConv2d
from .swin_transformer_v2 import _create_swin_transformer_v2, _create_swin_transformer_v2_up
#from timm.models.swin_transformer_v2 import _create_swin_transformer_v2

def build_grid(resolution):
    """Build a normalized coordinate grid together with its complement.

    Args:
        resolution: Sequence of integers, one grid size per axis.

    Returns:
        A float32 tensor of shape ``[1, *resolution, 2 * len(resolution)]``.
        The first ``len(resolution)`` channels hold coordinates in ``[0, 1]``
        (``"ij"`` indexing); the remaining channels hold ``1 - coordinate``.
        The tensor is placed on CUDA when available, otherwise on CPU.
    """
    # One evenly spaced axis in [0, 1] per requested dimension.
    axes = [np.linspace(0., 1., num=size) for size in resolution]
    coords = np.stack(np.meshgrid(*axes, sparse=False, indexing="ij"), axis=-1)
    # Prepend a broadcastable batch axis and switch to float32.
    coords = coords.reshape(list(resolution) + [-1])[np.newaxis].astype(np.float32)
    with_complement = np.concatenate([coords, 1.0 - coords], axis=-1)
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.tensor(with_complement).to(target)


class SoftPositionEmbed(nn.Module):
    """Adds a soft positional embedding: a learnable linear projection of a coordinate grid."""

    def __init__(self, hidden_size, resolution):
        """Builds the soft position embedding layer.

        Args:
          hidden_size: Size of input feature dimension (output size of the projection).
          resolution: Sequence of integers, one grid size per axis. The grid
            produced by `build_grid` carries `2 * len(resolution)` channels,
            which is the input size of the projection.
        """
        super(SoftPositionEmbed, self).__init__()
        self.proj = nn.Linear(len(resolution) * 2, hidden_size)
        # Registered as a buffer so `.to()`, `.cuda()`, `.half()` etc. move it
        # together with the module. `persistent=False` keeps it out of the
        # state_dict, so existing checkpoints still load unchanged.
        self.register_buffer("grid", build_grid(resolution), persistent=False)

    def forward(self, inputs):
        """Return `inputs` plus the projected grid (broadcast over the batch axis)."""
        weight = self.proj.weight
        # No-op when the grid already lives next to the projection weights;
        # guards against the grid having been created on another device.
        grid = self.grid.to(device=weight.device, dtype=weight.dtype)
        return inputs + self.proj(grid)


def spatial_broadcast(slots, resolution):
    """Tile every slot vector over a 2D grid, merging batch and slot axes.

    Args:
        slots: Tensor whose last axis is the slot feature size; all leading
            axes are flattened together (e.g. [batch_size, num_slots, slot_size]).
        resolution: Pair (height, width) of the target grid.

    Returns:
        Tensor of shape [batch_size*num_slots, height, width, slot_size].
    """
    height, width = resolution[0], resolution[1]
    slot_size = slots.shape[-1]
    # Collapse all leading axes, then add two singleton spatial axes to tile.
    flat_slots = torch.reshape(slots, [-1, slot_size])
    seeds = flat_slots[:, None, None, :]
    return einops.repeat(
        seeds,
        'b_n i j d -> b_n (tilei i) (tilej j) d',
        tilei=height,
        tilej=width,
    )


# def spatial_flatten(x):
#     return torch.reshape(x, [-1, x.shape[1] * x.shape[2], x.shape[-1]])


def unstack_and_split(x, batch_size, num_channels=3):
    """Separate the fused (batch, slot) axis, then split content from alpha.

    Args:
        x: Tensor of shape [(batch_size * slots), num_channels + 1, H, W].
        batch_size: Size of the batch factor of the leading axis.
        num_channels: Number of content channels preceding the single
            alpha-mask channel.

    Returns:
        Tuple (channels, masks) with shapes
        [batch_size, slots, num_channels, H, W] and [batch_size, slots, 1, H, W].
    """
    per_slot = einops.rearrange(x, '(b s) c h w -> b s c h w', b=batch_size)
    # Channel axis is dim 2 after un-stacking; the last channel is the mask.
    content, alpha = torch.split(per_slot, [num_channels, 1], dim=2)
    return content, alpha


class SlotAttention(nn.Module):
    """Slot Attention module with learnable per-slot initial embeddings."""

    def __init__(self, num_slots, encoder_dims, iters=3, hidden_dim=128, eps=1e-8):
        """Builds the Slot Attention module.
        Args:
            num_slots: Number of slots (size of the learnable slot embedding table).
            encoder_dims: Dimensionality of both input features and slot vectors.
            iters: Number of attention iterations. Must be >= 1, because
                `forward` returns the attention map of the last iteration.
            hidden_dim: Hidden layer size of MLP (raised to at least `encoder_dims`).
            eps: Offset for attention coefficients before normalization.
        """
        super(SlotAttention, self).__init__()

        self.eps = eps
        self.iters = iters
        self.num_slots = num_slots
        self.scale = encoder_dims ** -0.5
        # Kept only for backward compatibility with code that reads this
        # attribute; `forward` follows the device of its input instead.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.norm_input = nn.LayerNorm(encoder_dims)
        self.norm_slots = nn.LayerNorm(encoder_dims)
        self.norm_pre_ff = nn.LayerNorm(encoder_dims)

        # Learnable initial slot vectors, one row per slot index.
        self.slots_embedding = nn.Embedding(num_slots, encoder_dims)

        # Linear maps for the attention module.
        self.project_q = nn.Linear(encoder_dims, encoder_dims)
        self.project_k = nn.Linear(encoder_dims, encoder_dims)
        self.project_v = nn.Linear(encoder_dims, encoder_dims)

        # Slot update functions.
        self.gru = nn.GRUCell(encoder_dims, encoder_dims)

        hidden_dim = max(encoder_dims, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(encoder_dims, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, encoder_dims)
        )

    def forward(self, inputs, init_slots=None, num_slots=None):
        """Run iterative slot attention over `inputs`.

        Args:
            inputs: Tensor of shape [batch_size, num_inputs, encoder_dims].
            init_slots: Optional initial slots of shape
                [batch_size, num_slots, encoder_dims]. When omitted, the
                learnable slot embeddings are used.
            num_slots: Optional override for the number of slots taken from
                the embedding table; it must not exceed the table size.

        Returns:
            Tuple (slots, attn). `slots` has shape
            [batch_size, num_slots, encoder_dims]. `attn` has shape
            [batch_size, num_slots, num_inputs] and is the last iteration's
            softmax over the slot axis plus `eps`, before the per-slot
            normalization over inputs.
        """
        inputs = self.norm_input(inputs)  # Apply layer norm to the input.
        k = self.project_k(inputs)  # Shape: [batch_size, num_inputs, slot_size].
        v = self.project_v(inputs)  # Shape: [batch_size, num_inputs, slot_size].

        b, _, d = inputs.shape
        n_s = num_slots if num_slots is not None else self.num_slots

        # Initialize the slots. Shape: [batch_size, num_slots, slot_size].
        if init_slots is None:
            # Build the indices on the input's device so the module works
            # wherever it has been moved (CPU on a CUDA host, a specific GPU, ...).
            slot_ids = torch.arange(n_s, device=inputs.device).expand(b, n_s)
            slots = self.slots_embedding(slot_ids)
        else:
            slots = init_slots

        # Multiple rounds of attention.
        for _ in range(self.iters):
            slots_prev = slots
            slots = self.norm_slots(slots)

            # Attention: slots compete for inputs (softmax over the slot axis).
            q = self.project_q(slots)  # Shape: [batch_size, num_slots, slot_size].
            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            attn_raw = dots.softmax(dim=1) + self.eps
            attn = attn_raw / attn_raw.sum(dim=-1, keepdim=True)  # weighted mean.

            updates = torch.einsum('bjd,bij->bid', v, attn)
            # `updates` has shape: [batch_size, num_slots, slot_size].

            # Slot update.
            slots = self.gru(
                updates.reshape(-1, d),
                slots_prev.reshape(-1, d)
            )
            slots = slots.reshape(b, -1, d)
            slots = slots + self.mlp(self.norm_pre_ff(slots))

        # Reuse the un-normalized attention of the last iteration instead of
        # recomputing the same softmax a second time.
        return slots, attn_raw


class SlotAttentionAutoEncoder(nn.Module):
    """Slot Attention-based auto-encoder for object discovery."""

    def __init__(self, resolution,
                 num_slots,
                 num_o=3,
                 num_t=3,
                 in_channels=3,
                 out_channels=3,
                 hid_dim=32,
                 iters=5,
                 attn_drop=0.1,
                 replicate=False,
                 num_frames=7
                 ):
        """Builds the Slot Attention-based Auto-encoder.
        Args:
            resolution: Pair of integers giving the spatial size of the input
                image; passed to the Swin models as
                input_size=(3, resolution[0], resolution[1]).
            num_slots: Number of slots in Slot Attention.
            num_o: Number of `flow_transformer` blocks applied in `decode`.
            num_t: Number of `temporal_transformer` blocks applied in `forward`.
            in_channels: Stored on the instance; not otherwise used here.
            out_channels: Number of reconstructed channels (the decoder conv
                emits one extra alpha-mask channel).
            hid_dim: Unused; kept for interface compatibility.
            iters: Number of iterations in Slot Attention.
            attn_drop: Attention dropout passed to the transformer blocks.
            replicate: When True, `forward` additionally decodes every frame
                paired with itself.
            num_frames: Number of frames T per clip.
        """
        super(SlotAttentionAutoEncoder, self).__init__()

        self.iters = iters
        self.num_slots = num_slots
        self.resolution = resolution
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.replicate = replicate

        # The encoder output is assumed to be 2**down_times smaller than the
        # input along each spatial axis -- TODO confirm against the Swin config.
        down_times = 4
        self.encoder_dims = 384

        model_kwargs = dict(window_size=12, embed_dim=96, depths=(2, 2, 6), num_heads=(3, 6, 12), input_size=(3, resolution[0], resolution[1]))
        self.encoder = _create_swin_transformer_v2('swinv2_tiny_window16_256', **model_kwargs)

        self.end_size = (resolution[0] // 2**down_times, resolution[1] // 2**down_times)
        # Learnable positional embedding added to the broadcast slots in `decode`.
        self.decoder_pos = nn.Parameter(torch.randn(1, self.end_size[0], self.end_size[1], self.encoder_dims) * .02)
        self.num_t = num_t
        self.T = num_frames
        # Blocks over the tokens of all frames at once (T * h * w tokens).
        self.temporal_transformer = nn.ModuleList([Block(
                                                        dim=self.encoder_dims,
                                                        num_heads=8,
                                                        n_token=num_frames*np.prod(self.end_size),
                                                        attn_drop=attn_drop)
                                                        for j in range(self.num_t)])
        self.num_o = num_o
        # Blocks over the tokens of a single frame pair (h * w tokens).
        self.flow_transformer = nn.ModuleList([Block(
                                                    dim=self.encoder_dims,
                                                    num_heads=8,
                                                    n_token=np.prod(self.end_size),
                                                    attn_drop=attn_drop)
                                                    for j in range(self.num_o)])
        self.decoder_initial_size = self.end_size
        self.norm = nn.LayerNorm(2*self.encoder_dims)
        # Maps concatenated pair features (2C) back to C channels.
        self.mlp = nn.Sequential(
            DeformableConv2d(2*self.encoder_dims, 2*self.encoder_dims),
            nn.ReLU(inplace=True),
            DeformableConv2d(2*self.encoder_dims, self.encoder_dims))

        self.slot_attention = SlotAttention(
            iters=self.iters,
            num_slots=self.num_slots,
            encoder_dims=self.encoder_dims,
            hidden_dim=self.encoder_dims)
        self.decoder_dims = self.encoder_dims // 4

        model_kwargs = dict(window_size=12, embed_dim=96, depths=(2, 2, 2), num_heads=(12, 6, 3), input_size=(3, resolution[0], resolution[1]))
        self.decoder_block = _create_swin_transformer_v2_up('swinv2_up_tiny_window16_256', **model_kwargs)

        self.decoder_conv = nn.Conv2d(self.decoder_dims, self.out_channels + 1, kernel_size=5, padding=2, stride=1)

    def make_encoder(self, in_channels, encoder_arch):
        """Build a plain convolutional encoder from an architecture spec.

        Args:
            in_channels: Number of channels of the encoder input.
            encoder_arch: Iterable whose items are either the string 'MP'
                (a 2x max-pool) or an integer channel width (two 5x5 conv +
                InstanceNorm + ReLU stages at that width).

        Returns:
            Tuple (encoder, down_factor) where `encoder` is an `nn.Sequential`
            and `down_factor` is 2 ** (number of 'MP' entries).
        """
        layers = []
        down_factor = 0
        for v in encoder_arch:
            if v == 'MP':
                layers.append(nn.MaxPool2d(2, stride=2, ceil_mode=True))
                down_factor += 1
            else:
                conv1 = nn.Conv2d(in_channels, v, kernel_size=5, padding=2)
                conv2 = nn.Conv2d(v, v, kernel_size=5, padding=2)

                layers.extend([
                    conv1, nn.InstanceNorm2d(v, affine=True), nn.ReLU(inplace=True),
                    conv2, nn.InstanceNorm2d(v, affine=True), nn.ReLU(inplace=True),
                ])
                in_channels = v
        return nn.Sequential(*layers), 2 ** down_factor

    def kl_loss(self, src, target):
        """Return the symmetric sum kl_div(log src, target) + kl_div(log target, src).

        Both terms use `F.kl_div`'s default reduction. Inputs are passed
        through `.log()`, so they are expected to be strictly positive.
        """
        loss = F.kl_div(src.log(), target) + F.kl_div(target.log(), src)
        return loss

    def forward(self, image, p_s=None, flow=None):
        """Encode a clip, pair up frame features and decode them with slots.

        Args:
            image: Tensor of shape [B, T, C, H, W] with T == num_frames.
            p_s: Optional frame-index tensor, indexed as `p_s[:, p]` for
                p in range(2*T) -- assumed shape [B, 2*T], TODO confirm with
                the caller. When None (evaluation), each frame is paired
                with itself.
            flow: Unused.

        Returns:
            Tuple (recon_combined, recons, masks, slots, out_static); the
            first four come from `decode`. `out_static` is the 4-tuple
            returned by `decode` for self-paired frames when `replicate` is
            set, otherwise None.
        """
        bs = image.shape[0]
        image_t = einops.rearrange(image, 'b t c h w -> (b t) c h w')
        x = self.encoder(image_t)  # backbone features, read as [(B*T), C, h, w]

        x = einops.rearrange(x, '(b t) c h w -> b (t h w) c', t=self.T)  # spatial flatten
        # Spatio-temporal fusion over the tokens of all frames.
        if self.num_t > 0:
            for block in self.temporal_transformer:
                x = block(x)
        x = einops.rearrange(x, 'b (t hw) c -> b t hw c', t=self.T)  # per-frame token maps

        if p_s is None:
            # Evaluation: pair every frame with itself -> [B, T, N, 2C].
            pairs = [torch.cat([x[:, t], x[:, t]], dim=-1) for t in range(self.T)]
        else:
            # Training: entry p pairs frame p // 2 with the per-sample frame
            # index p_s[:, p] -> [B, 2*T, N, 2C].
            idx = torch.arange(0, bs)
            pairs = [torch.cat([x[idx, p // 2], x[idx, p_s[:, p]]], dim=-1)
                     for p in range(2 * self.T)]
        x_regroup = torch.stack(pairs, dim=1)
        recon_combined, recons, masks, slots = self.decode(x_regroup, ts=self.T)

        out_static = None
        if self.replicate:
            # Self-paired features for every frame: [B, T, N, 2C].
            x_static = torch.cat([x, x], dim=-1)
            out_static = self.decode(x_static, ts=self.T)

        return recon_combined, recons, masks, slots, out_static

    def decode(self, x, ts=7):
        """Decode paired frame features into per-slot reconstructions and masks.

        Args:
            x: Tensor of shape [B, S, N, 2C], where S is a multiple of `ts`
                and N == end_size[0] * end_size[1].
            ts: Number of frames T; S is factored as T * P.

        Returns:
            Tuple (recon_combined, recons, masks, slots):
              recon_combined: [B, T, P, out_channels, H, W]
              recons:         [B, T, P, num_slots, out_channels, H, W]
              masks:          [B, T, P, num_slots, 1, H, W], softmax over slots
              slots:          [B*S, num_slots, C]
        """
        bs = x.shape[0]

        x = einops.rearrange(x, 'b s n c -> (b s) n c')
        # First half of the channels, i.e. the first feature of each pair.
        x_static = x.chunk(2, dim=-1)[0]
        x = self.norm(x)
        # Deformable-conv MLP on the 2D map; reduces 2C channels to C.
        x = self.mlp(einops.rearrange(x, 'bs (h w) c -> bs c h w', h=self.end_size[0]))
        x = einops.rearrange(x, 'bs c h w -> bs (h w) c')
        if self.num_o > 0:
            for block in self.flow_transformer:
                x = block(x)
        x = x + x_static  # highlight RGB cues

        # `slots` has shape: [B*S, num_slots, C]; `attn` is currently unused.
        slots, attn = self.slot_attention(x)
        # Spatial broadcast decoder.
        x = spatial_broadcast(slots, self.decoder_initial_size)
        # `x` has shape: [B*S*num_slots, height_init, width_init, C].
        x = x + self.decoder_pos
        x = self.decoder_block(x)
        x = self.decoder_conv(x)
        # `x` has shape: [B*S*num_slots, out_channels+1, height, width].

        # Undo combination of slot and batch dimension; split alpha masks.
        # The slot count is read from `slots` rather than assumed to be 2.
        recons, masks = unstack_and_split(
            x, batch_size=x.shape[0] // slots.shape[1], num_channels=self.out_channels)
        # `recons` has shape: [B*S, num_slots, out_channels, height, width].
        # `masks` has shape: [B*S, num_slots, 1, height, width].
        recons = einops.rearrange(recons, '(b t p) s c h w -> b t p s c h w', b=bs, t=ts)
        masks = einops.rearrange(masks, '(b t p) s c h w -> b t p s c h w', b=bs, t=ts)
        # Normalize alpha masks over slots.
        masks = torch.softmax(masks, dim=3)
        recon_combined = torch.sum(recons * masks, dim=3)  # Recombine image.
        return recon_combined, recons, masks, slots


