import torch
import torch.nn as nn

from .layers import RayEncoder, Transformer, SRTLinear

class RayPredictor(nn.Module):
    """Predict per-ray features by cross-attending from encoded query rays to a scene encoding.

    Args:
        num_att_blocks: depth of the cross-attention transformer.
        pos_start_octave: first octave used when positionally encoding camera positions.
        out_dims: output channels of the optional output MLP.
        input_mlp: if truthy, refine the ray encodings with a 180 -> 360 -> 180 MLP
            before attention (added with OSRT). ``None`` or ``False`` disable it.
        output_mlp: if truthy, map the transformer output through a
            180 -> 128 -> out_dims MLP. ``None`` or ``False`` disable it.
        z_dim: channel dimension of the scene encoding ``z`` (keys/values of the attention).

    NOTE(review): the 180-dim feature width is hard-coded and assumes that
    RayEncoder(pos_octaves=15, ray_octaves=15) emits 180-dim features — confirm
    against layers.RayEncoder if the octave counts are ever changed.
    """
    def __init__(self, num_att_blocks=2, pos_start_octave=0, out_dims=3, input_mlp=None, output_mlp=None,
                 z_dim=1536):
        super().__init__()
        # The flags are tested by truthiness: an `is not None` test would make
        # `input_mlp=False` / `output_mlp=False` silently *enable* the MLPs.
        if input_mlp:  # Input MLP added with OSRT
            self.input_mlp = nn.Sequential(
                SRTLinear(180, 360),
                nn.ReLU(),
                SRTLinear(360, 180))
        else:
            self.input_mlp = None

        self.query_encoder = RayEncoder(pos_octaves=15, pos_start_octave=pos_start_octave,
                                        ray_octaves=15)

        # Cross-attention only (selfatt=False): the ray queries attend to z.
        self.transformer = Transformer(180, depth=num_att_blocks, heads=12, dim_head=z_dim // 12,
                                       mlp_dim=z_dim * 2, selfatt=False, kv_dim=z_dim)
        if output_mlp:
            self.output_mlp = nn.Sequential(
                SRTLinear(180, 128),
                nn.ReLU(),
                SRTLinear(128, out_dims))
        else:
            self.output_mlp = None

    def forward(self, z, x, rays):
        """
        Args:
            z: scene encoding [batch_size, num_patches, patch_dim]
            x: query camera positions [batch_size, num_rays, 3]
            rays: query ray directions [batch_size, num_rays, 3]

        Returns:
            output: transformer output for each query ray, passed through the
                output MLP when it is enabled (then its last dim is out_dims).
            queries: the ray encodings actually fed to the transformer, i.e.
                after the input MLP when it is enabled.
        """
        queries = self.query_encoder(x, rays)
        if self.input_mlp is not None:
            queries = self.input_mlp(queries)

        output = self.transformer(queries, z)
        if self.output_mlp is not None:
            output = self.output_mlp(output)
        return output, queries
    
class MixingBlock(nn.Module):
    """Attention from query-ray features to slots that mixes the slot latents per ray.

    Args:
        input_dim: channel dimension of the query ray features ``x``.
        slot_dim: channel dimension of the slot latents.
        att_dim: dimension of the query/key projections used for the attention logits.
        layer_norm: if True, LayerNorm the ray features before projecting them to
            queries and LayerNorm the mixed slot output.
    """
    def __init__(self, input_dim=180, slot_dim=1536, att_dim=1536, layer_norm=False):
        super().__init__()
        self.to_q = SRTLinear(input_dim, att_dim, bias=False)
        self.to_k = SRTLinear(slot_dim, att_dim, bias=False)
        if layer_norm:
            self.norm1 = nn.LayerNorm(input_dim)
            self.norm2 = nn.LayerNorm(slot_dim)

        # Standard scaled dot-product attention scaling.
        self.scale = att_dim ** -0.5
        self.layer_norm = layer_norm

    def forward(self, x, slot_latents):
        """
        Args:
            x: query ray features [batch_size, num_rays, input_dim]
            slot_latents: slot scene representation [batch_size, num_slots, slot_dim]

        Returns:
            s: attention-weighted mixture of slot latents per ray
               [batch_size, num_rays, slot_dim] (LayerNorm'ed if layer_norm).
            w: softmax weights over slots [batch_size, num_rays, num_slots].
        """
        if self.layer_norm:
            x = self.norm1(x)
        q = self.to_q(x)             # [batch_size, num_rays, att_dim]
        k = self.to_k(slot_latents)  # [batch_size, num_slots, att_dim]

        dots = torch.einsum('bid,bsd->bis', q, k) * self.scale
        w = dots.softmax(dim=2)  # [batch_size, num_rays, num_slots]
        # A matmul contracts over the slot axis directly instead of building the
        # [batch_size, num_rays, num_slots, slot_dim] broadcast product first.
        s = torch.matmul(w, slot_latents)  # [batch_size, num_rays, slot_dim]

        if self.layer_norm:
            s = self.norm2(s)

        return s, w

class RenderMLP(nn.Module):
    """Map a concatenated ``[slot feature, ray encoding]`` vector to a color.

    Four SRTLinear layers of width ``hidden_dim``, each followed by a LeakyReLU,
    then a projection to 3 channels squashed into (0, 1) by a Sigmoid.
    According to Mehdi, this uses Leaky ReLUs and a Sigmoid at the end.

    Args:
        input_dim: width of the concatenated input feature.
        hidden_dim: width of the hidden layers.
    """
    def __init__(self, input_dim=1536+180, hidden_dim=1536):
        super().__init__()
        widths = [input_dim] + [hidden_dim] * 4
        layers = []
        # Hidden stack; built in the same order as an explicit listing so the
        # Sequential indices (and hence state_dict keys) are net.0 ... net.9.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(SRTLinear(fan_in, fan_out))
            layers.append(nn.LeakyReLU())
        layers.append(SRTLinear(hidden_dim, 3))
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        rgb = self.net(x)
        return rgb
        
class SlotMixerDecoder(nn.Module):
    """ The Slot Mixer Decoder proposed in the OSRT paper.

    An allocation transformer (RayPredictor) encodes the query rays and attends
    to the slot latents. A MixingBlock turns the resulting per-ray features into
    softmax weights over slots and a weighted slot mixture. A RenderMLP then maps
    ``[slot mixture, ray encoding]`` to a color.

    Args:
        num_att_blocks: depth of the allocation transformer.
        pos_start_octave: first positional-encoding octave for camera positions.
        slot_dim: channel dimension of the slot latents.
        layer_norm: forwarded to MixingBlock (LayerNorm on its input and output).
        att_dim: attention dimension of the MixingBlock.
        hidden_dim: hidden width of the RenderMLP.
    """
    def __init__(self, num_att_blocks=2, pos_start_octave=0, slot_dim=1536, layer_norm=False, att_dim=1536, hidden_dim=1536):
        super().__init__()
        # Input and output MLPs of the allocation transformer are left disabled (None).
        self.allocation_transformer = RayPredictor(num_att_blocks=num_att_blocks,
                                                   pos_start_octave=pos_start_octave,
                                                   input_mlp=None, z_dim=slot_dim)

        self.mixing_block = MixingBlock(input_dim=180, att_dim=att_dim, layer_norm=layer_norm, slot_dim=slot_dim)
        # Input = slot_dim-dim slot feature concatenated with the 180-dim ray
        # encoding returned by the allocation transformer.
        self.render_mlp = RenderMLP(input_dim=slot_dim+180, hidden_dim=hidden_dim)

    def forward(self, slot_latents, camera_pos, rays, **kwargs):
        """Render query rays from a set of slot latents.

        Args:
            slot_latents: slot scene representation [batch_size, num_slots, slot_dim]
            camera_pos: query camera positions [batch_size, num_rays, 3]
            rays: query ray directions [batch_size, num_rays, 3]
            **kwargs: accepted for call-site compatibility; unused.

        Returns:
            pixels: colors rendered from the per-ray slot mixture [batch_size, num_rays, 3].
            pixels_k: colors rendered from each slot on its own
                [batch_size, num_slots, num_rays, 3].
            slot_weights: per-ray mixing weights over slots [batch_size, num_rays, num_slots].
            slots: each slot latent scaled by its per-ray weight
                [batch_size, num_slots, num_rays, slot_dim].
        """
        x, query_rays = self.allocation_transformer(slot_latents, camera_pos, rays)
        slot_mix, slot_weights = self.mixing_block(x, slot_latents)
        pixels = self.render_mlp(torch.cat((slot_mix, query_rays), -1))

        num_slots = slot_latents.size(1)
        num_rays = query_rays.size(1)

        # Per-slot renderings: each slot is rendered as if it received all of the
        # mixing weight, i.e. what the mixing block outputs for a one-hot weight
        # vector (including its optional output LayerNorm). Rendering `slot_mix`
        # here would only reproduce `pixels` once per slot.
        solo_slots = slot_latents
        if self.mixing_block.layer_norm:
            solo_slots = self.mixing_block.norm2(solo_slots)
        pixels_k = torch.stack([
            self.render_mlp(torch.cat(
                (solo_slots[:, k:k + 1].expand(-1, num_rays, -1), query_rays), -1))
            for k in range(num_slots)
        ], dim=1)

        # slots[b, k, n] = slot_weights[b, n, k] * slot_latents[b, k]
        slots = slot_weights.transpose(1, 2).unsqueeze(-1) * slot_latents.unsqueeze(2)

        return pixels, pixels_k, slot_weights, slots