import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import torch.nn.functional as F

import math
from einops import rearrange

# Module-wide switch read by the custom Linear subclasses in this file: when
# True, their reset_parameters() defers to torch's stock nn.Linear
# initialization instead of applying their own scheme.
__USE_DEFAULT_INIT__ = False

class JaxLinear(nn.Linear):
    """Linear layer whose initialization is meant to mirror the Jax defaults.

    Weights are drawn from a truncated normal with std = sqrt(1 / fan_in),
    clipped at two standard deviations on either side; the bias, when
    present, is set to zero.
    """

    def reset_parameters(self):
        # The module-level flag opts out of the custom scheme entirely.
        if __USE_DEFAULT_INIT__:
            super().reset_parameters()
            return

        fan_in = self.weight.shape[-1]
        std = math.sqrt(1 / fan_in)
        init.trunc_normal_(self.weight, std=std, a=-2. * std, b=2. * std)
        if self.bias is not None:
            init.zeros_(self.bias)

class ViTLinear(nn.Linear):
    """Linear layer with the initialization used by ViT.

    Weights use Xavier-uniform; the bias, when present, is drawn from a
    normal distribution with a very small (1e-6) standard deviation.
    """

    def reset_parameters(self):
        # The module-level flag opts out of the custom scheme entirely.
        if __USE_DEFAULT_INIT__:
            super().reset_parameters()
            return

        init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.normal_(self.bias, std=1e-6)

class SRTLinear(nn.Linear):
    """Linear layer with the initialization used in the SRT decoder.

    Weights use Xavier-uniform; the bias, when present, is zeroed.
    """

    def reset_parameters(self):
        # The module-level flag opts out of the custom scheme entirely.
        if __USE_DEFAULT_INIT__:
            super().reset_parameters()
            return

        init.xavier_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

class PositionalEncoding(nn.Module):
    """Sinusoidal encoding of coordinates over a range of octaves.

    Every coordinate is multiplied by ``2**o * pi`` for each octave ``o`` in
    ``[start_octave, start_octave + num_octaves)`` and passed through both
    sine and cosine.
    """

    def __init__(self, num_octaves=8, start_octave=0):
        super().__init__()
        self.num_octaves = num_octaves
        self.start_octave = start_octave

    def forward(self, coords, rays=None):
        """
        Args:
            coords: tensor of shape [..., dim]. Any number of leading
                dimensions is accepted (the usual case is
                [batch_size, num_points, dim]).
            rays: ignored by this module; kept so existing call sites that
                pass it keep working.
        Returns:
            Tensor of shape [..., 2 * dim * num_octaves]. The first half holds
            all sines, the second half all cosines; within each half the
            features are grouped per coordinate, octaves varying fastest.
        """
        *leading, dim = coords.shape

        octaves = torch.arange(self.start_octave, self.start_octave + self.num_octaves)
        # `.to(coords)` matches both the dtype and the device of the input.
        octaves = octaves.float().to(coords)
        multipliers = 2**octaves * math.pi

        # [..., dim, 1] * [num_octaves] broadcasts to [..., dim, num_octaves].
        scaled_coords = coords.unsqueeze(-1) * multipliers

        flat_features = dim * self.num_octaves
        sines = torch.sin(scaled_coords).reshape(*leading, flat_features)
        cosines = torch.cos(scaled_coords).reshape(*leading, flat_features)

        return torch.cat((sines, cosines), -1)

class RayEncoder(nn.Module):
    """Encodes a camera position together with ray directions.

    Positions and rays each go through their own PositionalEncoding and the
    two results are concatenated along the feature axis.
    """

    def __init__(self, pos_octaves=8, pos_start_octave=0, ray_octaves=4, ray_start_octave=0):
        super().__init__()
        self.pos_encoding = PositionalEncoding(num_octaves=pos_octaves, start_octave=pos_start_octave)
        self.ray_encoding = PositionalEncoding(num_octaves=ray_octaves, start_octave=ray_start_octave)

    def forward(self, pos, rays):
        """
        Args:
            pos: positions. With 4-D `rays` one position per batch element is
                expected (it is unsqueezed and tiled over the image);
                otherwise it is encoded as given.
            rays: either [batch, height, width, dims] (image layout) or a
                tensor the ray encoder accepts directly.
        Returns:
            Image layout: channels-first tensor [batch, C, height, width].
            Otherwise: encodings concatenated along the last axis.
        """
        if len(rays.shape) != 4:
            # Flat layout: encode both as-is and join the features.
            pos_feats = self.pos_encoding(pos)
            ray_feats = self.ray_encoding(rays)
            return torch.cat((pos_feats, ray_feats), -1)

        batchsize, height, width, _ = rays.shape

        # One position encoding per batch element, tiled over every pixel.
        pos_feats = self.pos_encoding(pos.unsqueeze(1))
        pos_feats = pos_feats.view(batchsize, pos_feats.shape[-1], 1, 1)
        pos_feats = pos_feats.repeat(1, 1, height, width)

        # Encode rays as a flat list of pixels, then restore the image grid
        # and move the feature axis in front of the spatial axes.
        ray_feats = self.ray_encoding(rays.flatten(1, 2))
        ray_feats = ray_feats.view(batchsize, height, width, ray_feats.shape[-1])
        ray_feats = ray_feats.permute((0, 3, 1, 2))

        return torch.cat((pos_feats, ray_feats), 1)

# Transformer implementation based on ViT
# https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py

class PreNorm(nn.Module):
    """Wraps a callable so that its input is layer-normalized first.

    Args:
        dim: size of the last axis, passed to nn.LayerNorm.
        fn: callable invoked as ``fn(normalized_x, **kwargs)``.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Extra keyword arguments are forwarded untouched to the wrapped fn.
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: dim -> hidden_dim -> dim with GELU and dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout: dropout probability applied after the activation and after
            the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            ViTLinear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            ViTLinear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        dim: feature size of the query input (and of the output).
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability applied after the output projection.
        selfatt: if True, queries, keys and values all come from `x` via a
            single fused projection. If False, queries come from `x` and
            keys/values from the second input `z`.
        kv_dim: feature size of `z` for cross-attention. Defaults to `dim`
            when omitted (previously a missing `kv_dim` made construction
            fail inside nn.Linear).
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., selfatt=True, kv_dim=None):
        super().__init__()
        inner_dim = dim_head * heads
        # With a single head of the right width the output already has size
        # `dim`, so no output projection is needed.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        if selfatt:
            self.to_qkv = JaxLinear(dim, inner_dim * 3, bias=False)
        else:
            if kv_dim is None:
                kv_dim = dim
            self.to_q = JaxLinear(dim, inner_dim, bias=False)
            self.to_kv = JaxLinear(kv_dim, inner_dim * 2, bias=False)

        self.to_out = nn.Sequential(
            JaxLinear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, z=None):
        """
        Args:
            x: query source, [batch, n, dim].
            z: key/value source, [batch, m, kv_dim]. Pass None for modules
                built with selfatt=True and a tensor for selfatt=False.
        Returns:
            Tensor [batch, n, dim] (or [batch, n, heads * dim_head] when no
            output projection is used).
        """
        if z is None:
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        else:
            q = self.to_q(x)
            k, v = self.to_kv(z).chunk(2, dim=-1)

        # Split the fused feature axis into (heads, per-head features).
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in (q, k, v))

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    """Stack of pre-norm attention + feed-forward blocks with residuals.

    Args:
        dim: feature size of the token stream.
        depth: number of (attention, feed-forward) blocks.
        heads, dim_head: attention head count and per-head width.
        mlp_dim: hidden width of each feed-forward block.
        dropout: dropout probability used in both sub-layers.
        selfatt, kv_dim: forwarded to every Attention layer.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., selfatt=True, kv_dim=None):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head,
                                               dropout=dropout, selfatt=selfatt, kv_dim=kv_dim))
            feed_forward = PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            blocks.append(nn.ModuleList([attention, feed_forward]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, z=None):
        """Run `x` through every block; `z` is handed to each attention layer."""
        for attention, feed_forward in self.layers:
            attended = attention(x, z=z)
            x = attended + x
            transformed = feed_forward(x)
            x = transformed + x
        return x

class SlotAttention(nn.Module):
    """
    Slot Attention as introduced by Locatello et al.

    Args:
        num_slots: number of slots produced.
        input_dim: feature size of the inputs.
        slot_dim: feature size of each slot.
        hidden_dim: hidden width of the per-slot residual MLP.
        iters: number of attention/update iterations.
        eps: small constant added to the attention weights before they are
            renormalized over the inputs.
        randomize_initial_slots: if True, the initial slots are sampled from a
            Gaussian centred on the learned `initial_slots`; otherwise the
            learned values are used directly.
    """
    def __init__(self, num_slots, input_dim=768, slot_dim=64, hidden_dim=512, iters=3, eps=1e-8,
                 randomize_initial_slots=False):
        super().__init__()

        self.num_slots = num_slots
        self.iters = iters
        self.scale = slot_dim ** -0.5
        self.slot_dim = slot_dim

        self.randomize_initial_slots = randomize_initial_slots
        self.initial_slots = nn.Parameter(torch.randn(num_slots, slot_dim))
        # Standard deviation of the sampling noise used when
        # `randomize_initial_slots` is True. forward() always read this
        # attribute but it was never defined, so that mode raised
        # AttributeError. It is a plain float (not a parameter or buffer) so
        # the state_dict layout is unchanged.
        # NOTE(review): 1.0 mirrors the unit-variance init of `initial_slots`;
        # confirm the intended value.
        self.embedding_stdev = 1.0

        self.eps = eps
        self.to_q = JaxLinear(slot_dim, slot_dim, bias=False)
        self.to_k = JaxLinear(input_dim, slot_dim, bias=False)
        self.to_v = JaxLinear(input_dim, slot_dim, bias=False)

        self.gru = nn.GRUCell(slot_dim, slot_dim)

        self.mlp = nn.Sequential(
            JaxLinear(slot_dim, hidden_dim),
            nn.ReLU(inplace=True),
            JaxLinear(hidden_dim, slot_dim)
        )

        self.norm_input   = nn.LayerNorm(input_dim)
        self.norm_slots   = nn.LayerNorm(slot_dim)
        self.norm_pre_mlp = nn.LayerNorm(slot_dim)

    def forward(self, inputs):
        """
        Args:
            inputs: set-latent representation [batch_size, num_inputs, dim]
        Returns:
            slots: [batch_size, num_slots, slot_dim]
        """
        batch_size, _, _ = inputs.shape

        inputs = self.norm_input(inputs)
        slots = self.initial_slots.unsqueeze(0).expand(batch_size, -1, -1)
        if self.randomize_initial_slots:
            # Reparameterized sample so gradients reach `initial_slots`.
            slots = torch.distributions.Normal(slots, self.embedding_stdev).rsample()

        # Keys and values depend only on the inputs, so compute them once.
        k, v = self.to_k(inputs), self.to_v(inputs)

        for _ in range(self.iters):
            slots_prev = slots
            norm_slots = self.norm_slots(slots)

            q = self.to_q(norm_slots)

            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale
            # shape: [batch_size, num_slots, num_inputs]
            # Softmax over the slot axis: slots compete for each input.
            attn = dots.softmax(dim=1) + self.eps
            # Renormalize over the inputs so each slot takes a weighted mean.
            attn = attn / attn.sum(dim=-1, keepdim=True)
            updates = torch.einsum('bjd,bij->bid', v, attn)

            # The GRU cell works on 2-D input, so fold slots into the batch.
            slots = self.gru(updates.flatten(0, 1), slots_prev.flatten(0, 1))
            slots = slots.reshape(batch_size, self.num_slots, self.slot_dim)
            slots = slots + self.mlp(self.norm_pre_mlp(slots))

        return slots

class BlockPrototypeMemory(nn.Module):
    """Learned per-block prototype memory that queries read from via attention.

    Args:
        num_prototypes: number of prototypes stored per block.
        num_blocks: number of blocks the feature vector is split into.
        d_model: total feature size (d_model // num_blocks per block).
    """

    def __init__(self, num_prototypes, num_blocks, d_model):
        super().__init__()

        self.num_prototypes = num_prototypes
        self.num_blocks = num_blocks
        self.d_model = d_model
        self.d_block = self.d_model // self.num_blocks

        # Raw prototype parameters, one d_block vector per (prototype, block).
        self.mem_params = nn.Parameter(torch.zeros(1, num_prototypes, num_blocks, self.d_block), requires_grad=True)
        nn.init.trunc_normal_(self.mem_params)

        # MLP applied to each block vector independently (it acts on the
        # last axis only) to turn raw parameters into memory entries.
        wide = 4 * self.d_block
        self.mem_proj = nn.Sequential(
            linear(self.d_block, wide),
            nn.ReLU(),
            linear(wide, wide),
            nn.ReLU(),
            linear(wide, wide),
            nn.ReLU(),
            linear(wide, self.d_block),
        )

        # Block-wise normalization for memory entries and incoming queries.
        self.norm_mem = BlockLayerNorm(d_model, num_blocks)
        self.norm_query = BlockLayerNorm(d_model, num_blocks)

        # Block-wise attention used to read from the memory.
        self.attn = BlockAttention(d_model, num_blocks)

    def forward(self, queries):
        '''
        queries: B, N, d_model
        return: B, N, d_model
        '''
        B, _, _ = queries.shape

        # Project the raw parameters and merge the block axis back into the
        # feature axis: 1, num_prototypes, d_model
        prototypes = self.mem_proj(self.mem_params).reshape(1, self.num_prototypes, -1)

        prototypes = self.norm_mem(prototypes)  # 1, num_prototypes, d_model
        queries = self.norm_query(queries)      # B, N, d_model

        # Share the same memory across the batch: B, num_prototypes, d_model
        prototypes = prototypes.expand(B, -1, -1)

        # The memory serves as both keys and values: B, N, d_model
        return self.attn(queries, prototypes, prototypes)

def linear(in_features, out_features, bias=True, weight_init='xavier', gain=1.):
    """Build an nn.Linear with explicit weight initialization.

    Args:
        in_features: input feature size.
        out_features: output feature size.
        bias: whether the layer has a bias; when it does, it is zeroed.
        weight_init: 'kaiming' selects Kaiming-uniform (ReLU nonlinearity);
            any other value selects Xavier-uniform.
        gain: gain for Xavier-uniform (unused for 'kaiming').
    Returns:
        The initialized nn.Linear.
    """
    layer = nn.Linear(in_features, out_features, bias)

    if weight_init != 'kaiming':
        nn.init.xavier_uniform_(layer.weight, gain)
    else:
        nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')

    if bias:
        nn.init.zeros_(layer.bias)

    return layer

class BlockLayerNorm(nn.Module):
    """Layer normalization applied independently to each block of features.

    The last axis (of size `size`) is split into `k` equal blocks and each
    block is normalized on its own, without learnable scale or shift.

    Args:
        size: total feature size; must be divisible by `k`.
        k: number of blocks.
    """

    def __init__(self, size, k):
        super(BlockLayerNorm, self).__init__()

        assert size % k == 0
        self.size = size
        self.k = k
        self.g = size // k
        self.norm = nn.LayerNorm(self.g, elementwise_affine=False)

    def forward(self, x):
        """
        Args:
            x: tensor of shape [..., size]; a bare 1-D vector is accepted too.
        Returns:
            Tensor with the same shape as `x`.
        """
        *OTHER, D = x.shape
        # math.prod yields the int 1 for an empty list, whereas np.prod([])
        # yields the float 1.0, which reshape rejects for 1-D input.
        x = x.reshape(math.prod(OTHER), self.k, -1)
        x = self.norm(x)
        x = x.reshape(*OTHER, -1)
        return x

class BSAGRU(nn.Module):
    """
        GRU cell whose weight matrices are forced to be block-diagonal, so that
        each block of the input/hidden state only influences the matching block
        of the new hidden state.
        Data is assumed to come in [block1, block2, ..., block_n].

        Args:
            num_input: input feature size; must be divisible by `k`.
            num_hidden: hidden feature size; must be divisible by `k`.
            k: number of blocks.
    """

    def __init__(self, num_input, num_hidden, k):
        super(BSAGRU, self).__init__()

        assert num_input % k == 0
        assert num_hidden % k == 0

        self.k = k
        self.gru = nn.GRUCell(num_input, num_hidden)

        self.num_hidden = num_hidden
        self.block_hidden = self.num_hidden // k

        self.num_input = num_input
        self.block_input = self.num_input // k

        # Fixed 0/1 masks, stored as frozen parameters so they follow the
        # module across devices and appear in the state_dict.
        self.mask_hidden_input = nn.Parameter(
            self._gate_mask(self.block_hidden, self.block_input),
            requires_grad=False
        )
        self.mask_hidden_hidden = nn.Parameter(
            self._gate_mask(self.block_hidden, self.block_hidden),
            requires_grad=False
        )

    def _gate_mask(self, rows_per_block, cols_per_block):
        """Block-diagonal 0/1 mask, stacked three times along the row axis to
        line up with the GRU cell's stacked gate weights."""
        block_diagonal = torch.eye(self.k, self.k)
        block_diagonal = block_diagonal.repeat_interleave(rows_per_block, dim=0)
        block_diagonal = block_diagonal.repeat_interleave(cols_per_block, dim=1)
        return block_diagonal.repeat(3, 1)

    def blockify_params(self):
        """Zero every off-block entry of the GRU weight matrices in place.

        Biases are left untouched. The update goes through `.data` so the
        in-place edit is not tracked by autograd.
        """
        self.gru.weight_ih.data.mul_(self.mask_hidden_input)
        self.gru.weight_hh.data.mul_(self.mask_hidden_hidden)

    def forward(self, input, h):
        # Re-apply the masks on every call, since an optimizer step may have
        # written non-zero values into the off-block entries.
        self.blockify_params()
        return self.gru(input, h)

class BlockLinear(nn.Module):
    """Linear layer with a block-diagonal weight.

    The last axis of the input is split into `k` equal blocks; block `i` is
    multiplied by its own weight matrix `w[i]` and the results are
    concatenated again before a bias is added.

    Args:
        ninp: input feature size; must be divisible by `k`.
        nout: output feature size; must be divisible by `k`.
        k: number of blocks.
        bias: whether the bias is trainable. The bias tensor always exists
            and starts at zero; with bias=False it simply never receives
            gradients.
    """

    def __init__(self, ninp, nout, k, bias=True):
        super(BlockLinear, self).__init__()

        assert ninp % k == 0
        assert nout % k == 0
        self.k = k

        # torch.empty replaces the legacy torch.Tensor(...) constructor; both
        # tensors are fully initialized right below.
        self.w = nn.Parameter(torch.empty(self.k, ninp // k, nout // k))
        self.b = nn.Parameter(torch.empty(1, nout), requires_grad=bias)

        nn.init.xavier_uniform_(self.w)
        nn.init.zeros_(self.b)

    def forward(self, x):
        """

        :param x: Tensor, (..., ninp); a bare 1-D vector is accepted too
        :return: Tensor, (..., nout)
        """

        *OTHER, D = x.shape
        # math.prod yields the int 1 for an empty list, whereas np.prod([])
        # yields the float 1.0, which reshape rejects for 1-D input.
        x = x.reshape(math.prod(OTHER), self.k, -1)
        # bmm wants the block axis first: (k, rows, ninp // k).
        x = x.permute(1, 0, 2)
        x = torch.bmm(x, self.w)
        x = x.permute(1, 0, 2).reshape(*OTHER, -1)
        # Add the bias as a 1-D vector so it broadcasts against any number of
        # leading axes, including none.
        return x + self.b[0]

class BlockAttention(nn.Module):
    """Parameter-free dot-product attention computed separately per block.

    The feature axis is split into `num_blocks` equal blocks; each block of
    the queries attends over the matching block of the keys and values.
    """

    def __init__(self, d_model, num_blocks):
        super().__init__()

        assert d_model % num_blocks == 0, "d_model must be divisible by num_blocks"
        self.d_model = d_model
        self.num_blocks = num_blocks

    def forward(self, q, k, v):
        """
        q: batch_size x target_len x d_model
        k: batch_size x source_len x d_model
        v: batch_size x source_len x d_model
        return: batch_size x target_len x d_model
        """
        B, T, _ = q.shape
        _, S, _ = k.shape

        def split_blocks(tensor, length):
            # [B, length, d_model] -> [B, num_blocks, length, d_block]
            return tensor.view(B, length, self.num_blocks, -1).transpose(1, 2)

        q_blocks = split_blocks(q, T)
        k_blocks = split_blocks(k, S)
        v_blocks = split_blocks(v, S)

        # Scale the queries by 1 / sqrt(d_block) before taking dot products.
        scaled_q = q_blocks * (q_blocks.shape[-1] ** (-0.5))
        scores = torch.matmul(scaled_q, k_blocks.transpose(-1, -2))
        weights = F.softmax(scores, dim=-1)

        # Merge the block axis back into the feature axis.
        mixed = torch.matmul(weights, v_blocks)
        return mixed.transpose(1, 2).reshape(B, T, -1)

class BlockSlotAttention(nn.Module):
    """Slot attention with three slot groups that use separate update paths.

    Slots are laid out along the slot axis as [bg | fg | ag]:
      * `num_bg` "bg" slots, updated by a regular GRU cell and MLP;
      * `num_slots - num_bg - 1` "fg" slots, updated by a block-structured
        GRU and MLP and then read through a block prototype memory;
      * one "ag" slot, updated by its own regular GRU cell and MLP.
    (bg/fg presumably stand for background/foreground; the meaning of "ag"
    is not evident from this module.)

    Args:
        num_iterations: number of attention/update iterations.
        num_slots: total number of slots across the three groups.
        input_size: feature size of the inputs.
        slot_size: feature size of each slot.
        mlp_hidden_size: hidden width of the slot MLPs.
        num_prototypes: number of prototypes in the prototype memory.
        num_blocks: number of blocks used by the block-structured modules.
        num_bg: number of "bg" slots.
        epsilon: small constant added to the attention weights before they
            are renormalized over the inputs.
    """

    def __init__(self, num_iterations, num_slots,
                 input_size=768, slot_size=64, mlp_hidden_size=512, num_prototypes=16, num_blocks=8, num_bg=1,
                 epsilon=1e-8):
        super().__init__()

        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.input_size = input_size
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.num_blocks = num_blocks
        self.epsilon = epsilon
        self.scale = slot_size ** -0.5
        self.num_bg = num_bg

        # Learned initial slots, one parameter tensor per group.
        self.initial_bg_slots = nn.Parameter(torch.randn(num_bg, slot_size))
        self.initial_ag_slots = nn.Parameter(torch.randn(1, slot_size))
        self.initial_fg_slots = nn.Parameter(torch.randn(num_slots - num_bg - 1, slot_size))

        # Normalization layers; bg and ag share the same pre-MLP norm.
        self.norm_inputs = nn.LayerNorm(input_size)
        self.norm_slots = nn.LayerNorm(slot_size)
        self.norm_mlp_fg = BlockLayerNorm(slot_size, num_blocks)
        self.norm_mlp_bgag = nn.LayerNorm(slot_size)

        # Query/key/value projections for the attention step.
        self.project_q = linear(slot_size, slot_size, bias=False)
        self.project_k = linear(input_size, slot_size, bias=False)
        self.project_v = linear(input_size, slot_size, bias=False)

        def dense_slot_mlp():
            # Residual MLP shared in form (not in weights) by bg and ag.
            return nn.Sequential(
                JaxLinear(slot_size, mlp_hidden_size),
                nn.ReLU(inplace=True),
                JaxLinear(mlp_hidden_size, slot_size)
            )

        # Per-group slot update functions.
        self.gru_bg = nn.GRUCell(slot_size, slot_size)
        self.mlp_bg = dense_slot_mlp()

        self.gru_ag = nn.GRUCell(slot_size, slot_size)
        self.mlp_ag = dense_slot_mlp()

        self.gru_fg = BSAGRU(slot_size, slot_size, num_blocks)
        self.mlp_fg = nn.Sequential(
            BlockLinear(slot_size, mlp_hidden_size, num_blocks),
            nn.ReLU(),
            BlockLinear(mlp_hidden_size, slot_size, num_blocks))

        self.prototype_memory = BlockPrototypeMemory(num_prototypes, num_blocks, slot_size)

    def forward(self, inputs):
        """
        Args:
            inputs: [batch_size, num_inputs, input_size]
        Returns:
            slots: [batch_size, num_slots, slot_size], ordered [bg | fg | ag]
        """
        B, _, _ = inputs.size()
        num_fg = self.num_slots - self.num_bg - 1

        # Broadcast the learned initial slots over the batch.
        init_bg = self.initial_bg_slots.unsqueeze(0).expand(B, -1, -1)
        init_fg = self.initial_fg_slots.unsqueeze(0).expand(B, -1, -1)
        init_ag = self.initial_ag_slots.unsqueeze(0).expand(B, -1, -1)
        slots = torch.cat((init_bg, init_fg, init_ag), dim=1)

        # Keys and values depend only on the inputs, so compute them once.
        inputs = self.norm_inputs(inputs)
        k = self.project_k(inputs)                  # Shape: [batch_size, num_inputs, slot_size].
        v = self.project_v(inputs)                  # Shape: [batch_size, num_inputs, slot_size].

        for _ in range(self.num_iterations):
            # The GRUs consume the un-normalized slots of the previous step.
            slots_prev = slots
            q = self.project_q(self.norm_slots(slots))              # Shape: [batch_size, num_slots, slot_size].

            # Attention: softmax over slots, then renormalize over inputs.
            dots = torch.einsum('bid,bjd->bij', q, k) * self.scale  # Shape: [batch_size, num_slots, num_inputs].
            attn = dots.softmax(dim=1) + self.epsilon
            attn = attn / attn.sum(dim=-1, keepdim=True)
            updates = torch.einsum('bjd,bij->bid', v, attn)         # Shape: [batch_size, num_slots, slot_size].

            # bg slots: leading `num_bg` entries.
            new_bg = self.gru_bg(
                updates[:, :self.num_bg].flatten(0, 1),
                slots_prev[:, :self.num_bg].flatten(0, 1)
            )
            new_bg = new_bg.reshape(B, self.num_bg, self.slot_size)
            new_bg = new_bg + self.mlp_bg(self.norm_mlp_bgag(new_bg))

            # ag slot: the single trailing entry, already [batch, slot_size].
            new_ag = self.gru_ag(updates[:, -1], slots_prev[:, -1])
            new_ag = new_ag.reshape(B, 1, self.slot_size)
            new_ag = new_ag + self.mlp_ag(self.norm_mlp_bgag(new_ag))

            # fg slots: everything in between, followed by a memory read.
            new_fg = self.gru_fg(
                updates[:, self.num_bg:-1].flatten(0, 1),
                slots_prev[:, self.num_bg:-1].flatten(0, 1)
            )
            new_fg = new_fg.view(-1, num_fg, self.slot_size)
            new_fg = new_fg + self.mlp_fg(self.norm_mlp_fg(new_fg))
            new_fg = self.prototype_memory(new_fg)

            slots = torch.cat((new_bg, new_fg, new_ag), dim=1)

        return slots