import numpy as np
import torch
import torch.nn as nn
from .layers import RayEncoder, Transformer, SlotAttention, BlockSlotAttention

class SRTConvBlock(nn.Module):
    """Downsampling block: two bias-free 3x3 convolutions, each followed by a ReLU.

    The first convolution keeps the spatial resolution (stride 1); the second
    uses stride 2 and therefore halves height and width (rounded up).

    Args:
        idim: number of input channels.
        hdim: channels after the first convolution; defaults to ``idim``.
        odim: channels after the second convolution; defaults to ``2 * hdim``.
    """
    def __init__(self, idim, hdim=None, odim=None):
        super().__init__()
        hidden = idim if hdim is None else hdim
        out = 2 * hidden if odim is None else odim

        def conv3x3(in_channels, out_channels, stride):
            # padding=1 with a 3x3 kernel preserves the size when stride is 1.
            return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                             stride=stride, padding=1, bias=False)

        # Convolutions are created in forward order so parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.layers = nn.Sequential(
            conv3x3(idim, hidden, 1),
            nn.ReLU(),
            conv3x3(hidden, out, 2),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply both conv + ReLU stages to ``x`` ([N, idim, H, W])."""
        out = self.layers(x)
        return out


class ImprovedSRTEncoder(nn.Module):
    """
    Scene Representation Transformer Encoder with the improvements from Appendix A.4 in the OSRT paper.

    Each view is concatenated with its ray encoding, downsampled by a stack of
    SRTConvBlocks, projected per patch with a 1x1 convolution, and the patches
    of all views are then processed jointly by a transformer.

    Args:
        num_conv_blocks: number of SRTConvBlocks. At least one block is always
            built, so values below 1 behave like 1.
        num_att_blocks: passed to the transformer as ``depth``.
        pos_start_octave: forwarded to RayEncoder.
        idim: channels entering the first conv block, i.e. 3 image channels
            plus the channels produced by the ray encoder.
        hdim: hidden width of the first conv block, which outputs ``2 * hdim``
            channels.
        cur_hdim: channels entering the second conv block. Must equal
            ``2 * hdim``; it doubles with every additional block.

    Raises:
        ValueError: if ``cur_hdim`` does not match the first block's output
            width, which would otherwise only surface as a channel mismatch
            inside ``forward``.
    """
    def __init__(self, num_conv_blocks=3, num_att_blocks=5, pos_start_octave=0, idim=183, hdim=96, cur_hdim=192):
        super().__init__()
        if cur_hdim != 2 * hdim:
            raise ValueError(
                f"cur_hdim ({cur_hdim}) must equal 2 * hdim ({2 * hdim}), "
                "the number of channels produced by the first conv block")

        self.ray_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                      ray_octaves=15)

        conv_blocks = [SRTConvBlock(idim=idim, hdim=hdim)]
        for _ in range(1, num_conv_blocks):
            # With hdim/odim unset, each block maps cur_hdim -> 2 * cur_hdim channels.
            conv_blocks.append(SRTConvBlock(idim=cur_hdim, odim=None))
            cur_hdim *= 2

        self.conv_blocks = nn.Sequential(*conv_blocks)

        # cur_hdim is now the channel count leaving the last conv block
        # (768 with the default arguments); the 1x1 conv keeps that width.
        self.per_patch_linear = nn.Conv2d(cur_hdim, cur_hdim, kernel_size=1)

        self.transformer = Transformer(cur_hdim, depth=num_att_blocks, heads=12, dim_head=64,
                                       mlp_dim=1536, selfatt=True)

    def forward(self, images, camera_pos, rays):
        """
        Args:
            images: [batch_size, num_images, 3, height, width]. Assume the first image is canonical.
            camera_pos: [batch_size, num_images, 3]
            rays: [batch_size, num_images, height, width, 3]
        Returns:
            scene representation: [batch_size, num_patches, channels_per_patch]
        """

        batch_size, num_images = images.shape[:2]

        # Fold the view axis into the batch axis so the CNN sees one image per row.
        x = images.flatten(0, 1)                        # [B * T, 3, H, W]
        camera_pos = camera_pos.flatten(0, 1)           # [B * T, 3]
        rays = rays.flatten(0, 1)                       # [B * T, H, W, 3]

        # RayEncoder is defined elsewhere; the concatenation below requires
        # [B * T, idim - 3, H, W] (180 channels with the default idim=183).
        ray_enc = self.ray_encoder(camera_pos, rays)
        x = torch.cat((x, ray_enc), 1)                  # [B * T, idim, H, W]

        # Every conv block halves H and W (stride 2); C is the final cur_hdim.
        x = self.conv_blocks(x)                         # [B * T, C, H', W']
        x = self.per_patch_linear(x)                    # [B * T, C, H', W']
        x = x.flatten(2, 3).permute(0, 2, 1)            # [B * T, H' * W', C]

        # Unfold the view axis again and merge it with the patch axis, so the
        # transformer receives the patches of all views as one sequence.
        patches_per_image, channels_per_patch = x.shape[1:]
        x = x.reshape(batch_size, num_images * patches_per_image, channels_per_patch)   # [B, T * H' * W', C]

        x = self.transformer(x)

        return x


class OSRTEncoder(nn.Module):
    """Encoder that runs ImprovedSRTEncoder and passes its output to SlotAttention.

    Args:
        pos_start_octave: forwarded to ImprovedSRTEncoder.
        num_slots: first positional argument of SlotAttention.
        slot_dim: forwarded to SlotAttention as ``slot_dim``.
        slot_iters: forwarded to SlotAttention as ``iters``.
        mlp_hidden_size: forwarded to SlotAttention as ``hidden_dim``.
        randomize_initial_slots: forwarded to SlotAttention unchanged.
    """
    def __init__(self, pos_start_octave=0, num_slots=6, slot_dim=64, slot_iters=1, mlp_hidden_size=512, randomize_initial_slots=False):
        super().__init__()
        # The backbone is built first, keeping submodule registration order.
        self.srt_encoder = ImprovedSRTEncoder(
            num_conv_blocks=3,
            num_att_blocks=5,
            pos_start_octave=pos_start_octave,
        )

        slot_kwargs = {
            'slot_dim': slot_dim,
            'iters': slot_iters,
            'hidden_dim': mlp_hidden_size,
            'randomize_initial_slots': randomize_initial_slots,
        }
        self.slot_attention = SlotAttention(num_slots, **slot_kwargs)

    def forward(self, images, camera_pos, rays):
        """Encode the input views, then return the slot attention output."""
        return self.slot_attention(self.srt_encoder(images, camera_pos, rays))

class BSAEncoder(nn.Module):
    """Encoder that runs ImprovedSRTEncoder and passes its output to BlockSlotAttention.

    Args:
        pos_start_octave: forwarded to ImprovedSRTEncoder.
        num_slots: forwarded to BlockSlotAttention as ``num_slots``.
        num_bg: forwarded to BlockSlotAttention as ``num_bg``.
        slot_dim: forwarded to BlockSlotAttention as ``slot_size``.
        slot_iters: forwarded to BlockSlotAttention as ``num_iterations``.
        num_blocks: forwarded to BlockSlotAttention as ``num_blocks``.
        mlp_hidden_size: forwarded to BlockSlotAttention as ``mlp_hidden_size``.
        num_prototypes: forwarded to BlockSlotAttention as ``num_prototypes``.
    """
    def __init__(self, pos_start_octave=0, num_slots=6, num_bg=1, slot_dim=64, slot_iters=1, num_blocks=8, mlp_hidden_size=512, num_prototypes=16):
        super().__init__()
        # The backbone is built first, keeping submodule registration order.
        self.srt_encoder = ImprovedSRTEncoder(
            num_conv_blocks=3,
            num_att_blocks=5,
            pos_start_octave=pos_start_octave,
        )

        bsa_kwargs = {
            'num_iterations': slot_iters,
            'num_slots': num_slots,
            'num_bg': num_bg,
            'num_prototypes': num_prototypes,
            'slot_size': slot_dim,
            'num_blocks': num_blocks,
            'mlp_hidden_size': mlp_hidden_size,
        }
        self.slot_attention = BlockSlotAttention(**bsa_kwargs)

    def forward(self, images, camera_pos, rays):
        """Encode the input views, then return the block slot attention output."""
        return self.slot_attention(self.srt_encoder(images, camera_pos, rays))
