import torch.nn as nn
import torch as th
import numpy as np
from nn.residual import SkipConnection, ResidualBlock, LinearResidual
from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize, LambdaModule
from torch.autograd import Function
from einops import rearrange, repeat, reduce

from typing import Tuple, Union, List
import utils



class PriorityEncoder(nn.Module):
    """Turn per-object priority scores into noisy, slot-index-offset keys.

    For an object in slot ``i`` with score ``p`` the output is::

        25 * (priority_factor * (num_objects * p + 0.1 * eps) + index_factor * i)

    where ``eps`` is fresh standard-normal noise drawn on every call and
    ``priority_factor`` / ``index_factor`` are learned scalars (init 1).

    Args:
        num_objects: number of object slots per sample.
        batch_size: batch size the index buffer is allocated for.  Kept for
            backward compatibility; ``forward`` accepts any batch size.
    """

    def __init__(self, num_objects, batch_size):
        super(PriorityEncoder, self).__init__()

        self.num_objects = num_objects
        # Slot indices 0..num_objects-1, one identical row per batch element
        # (same values/shape as before).  Non-persistent, so it is not saved
        # in the state_dict.
        self.register_buffer(
            "indices",
            th.arange(num_objects).repeat(batch_size, 1),
            persistent=False,
        )

        self.index_factor    = nn.Parameter(th.ones(1))
        self.priority_factor = nn.Parameter(th.ones(1))

    def forward(self, priority: Union[th.Tensor, None]) -> Union[th.Tensor, None]:
        """Encode ``priority`` scores; returns ``None`` when given ``None``.

        ``priority`` is expected to be ``(batch, num_objects)`` (TODO confirm
        with callers); any batch size is accepted.
        """
        if priority is None:
            return None

        priority = priority * self.num_objects + th.randn_like(priority) * 0.1
        priority = priority * self.priority_factor
        # Every buffer row is identical, so broadcast a single row instead of
        # the full (batch_size, num_objects) buffer: this removes the hard
        # coupling to the construction-time batch size.
        priority = priority + self.indices[:1] * self.index_factor

        return priority * 25


class GestaltPositionMerge(nn.Module):
    """Merge per-object positions and gestalt codes into per-object maps.

    Flattened per-sample position and gestalt vectors are split into one row
    per object.  Positions go through ``Gaus2D`` (presumably rendering a 2-D
    Gaussian map per object — confirm in utils), and are then combined via
    ``Prioritize`` using the encoded priorities.  The result is multiplied
    by the object's gestalt vector, which is reshaped to ``(b*o, c, 1, 1)``
    so it broadcasts over the last two (spatial) dimensions.
    """

    def __init__(
        self,
        latent_size: Union[int, Tuple[int, int]],
        num_objects: int,
        batch_size: int
    ):
        super().__init__()
        self.num_objects = num_objects

        # Keep this creation order: it fixes parameter registration order
        # and any init-time RNG consumption.
        self.gaus2d           = Gaus2D(size=latent_size)
        self.to_batch         = SharedObjectsToBatch(num_objects)
        self.to_shared        = BatchToSharedObjects(num_objects)
        self.prioritize       = Prioritize(num_objects)
        self.priority_encoder = PriorityEncoder(num_objects, batch_size)

    def forward(self, position, gestalt, priority):
        """Return the gestalt-weighted, prioritised position maps."""
        n_obj = self.num_objects

        obj_position = rearrange(position, 'b (o c) -> (b o) c', o = n_obj)
        obj_gestalt  = rearrange(gestalt, 'b (o c) -> (b o) c 1 1', o = n_obj)
        # Drawn before Gaus2D runs, exactly as in the original call order
        # (the encoder consumes random numbers).
        encoded = self.priority_encoder(priority)

        rendered = self.gaus2d(obj_position)
        grouped  = self.to_shared(rendered)
        ranked   = self.prioritize(grouped, encoded)
        weights  = self.to_batch(ranked)

        return weights * obj_gestalt

class PatchUpscale(nn.Module):
    """Upscale a feature map with a skip path plus a learned residual path.

    ``forward(x) = skip(x) + alpha * residual(x)``, where ``residual`` is
    ReLU -> 3x3 conv (in_channels -> in_channels, padding 1) -> ReLU ->
    transposed conv (in_channels -> out_channels) with kernel size and stride
    both equal to ``scale_factor``.  ``skip`` is a project ``SkipConnection``
    (presumably matching that output resolution/channels — confirm in
    nn.residual).  ``alpha`` is a learned scalar initialised to ``alpha``.

    Raises:
        AssertionError: if ``in_channels`` is not a multiple of
            ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
        super().__init__()
        assert in_channels % out_channels == 0

        # Construction order (skip, 3x3 conv, transposed conv, alpha) and the
        # Sequential layout (0: ReLU, 1: conv, 2: ReLU, 3: conv-T) are kept
        # so parameter order, state_dict keys and init RNG draws match.
        self.skip = SkipConnection(in_channels, out_channels, scale_factor=scale_factor)

        refine = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        expand = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=scale_factor,
            stride=scale_factor,
        )
        self.residual = nn.Sequential(nn.ReLU(), refine, nn.ReLU(), expand)

        self.alpha = nn.Parameter(th.zeros(1) + alpha)

    def forward(self, input):
        shortcut = self.skip(input)
        branch   = self.residual(input)
        return shortcut + self.alpha * branch


class LociDecoder(nn.Module):
    """Decode per-object position/gestalt codes into mask and object maps.

    ``GestaltPositionMerge`` output is refined by ``layer0`` and fed to two
    parallel head families:

    * mask heads:   end in ``PatchUpscale(..., 1)``            -> 1 channel
    * object heads: end in ``PatchUpscale(..., img_channels)`` -> img_channels

    At ``level == 1`` (the default) only the single-``PatchUpscale`` level-1
    heads are used.  For ``level > 1`` the level-1 outputs are
    nearest-neighbour upsampled by ``_LEVEL2_UPSCALE``.  The level-2 heads
    (two stacked ``PatchUpscale``) are then added on top, scaled by the learned
    ``mask_alpha`` / ``object_alpha``.  These are initialised to 1e-16, so
    the level-2 contribution starts out negligible.

    Args:
        latent_size: spatial size passed to ``GestaltPositionMerge``.
        gestalt_size: channel count of the merged maps fed to ``layer0``.
        num_objects: object slots per sample.
        img_channels: channels of the decoded object image.
        hidden_channels: width of the residual trunk and heads.
        level1_channels: intermediate width of the level-2 object head; must
            be a multiple of ``img_channels``.
        num_layers: number of ResidualBlocks in ``layer0`` (>= 1 expected).
        batch_size: forwarded to ``GestaltPositionMerge``.
    """

    # Nearest-neighbour factor that lifts level-1 outputs to level-2
    # resolution.  Level-2 heads stack one more PatchUpscale than level-1
    # heads, each using PatchUpscale's default scale_factor of 4.
    _LEVEL2_UPSCALE = 4

    def __init__(
        self, 
        latent_size: Union[int, Tuple[int, int]],
        gestalt_size: int,
        num_objects: int, 
        img_channels: int,
        hidden_channels: int,
        level1_channels: int,
        num_layers: int,
        batch_size: int
    ): 

        super(LociDecoder, self).__init__()
        self.to_batch  = SharedObjectsToBatch(num_objects)
        self.to_shared = BatchToSharedObjects(num_objects)
        self.level     = 1

        assert(level1_channels % img_channels == 0)
        level1_factor   = level1_channels // img_channels
        print(f"Level1 channels: {level1_channels}")

        self.merge = GestaltPositionMerge(
            latent_size = latent_size,
            num_objects = num_objects,
            batch_size  = batch_size
        )

        self.layer0 = nn.Sequential(
            ResidualBlock(gestalt_size, hidden_channels, input_nonlinearity = False),
            *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers-1)],
        )

        self.to_mask_level0 = ResidualBlock(hidden_channels, hidden_channels)
        self.to_mask_level1 = PatchUpscale(hidden_channels, 1)

        self.to_mask_level2 = nn.Sequential(
            ResidualBlock(hidden_channels, hidden_channels),
            ResidualBlock(hidden_channels, hidden_channels),
            PatchUpscale(hidden_channels, level1_factor, alpha = 1),
            PatchUpscale(level1_factor, 1, alpha = 1)
        )

        self.to_object_level0 = ResidualBlock(hidden_channels, hidden_channels)
        self.to_object_level1 = PatchUpscale(hidden_channels, img_channels)

        self.to_object_level2 = nn.Sequential(
            ResidualBlock(hidden_channels, hidden_channels),
            ResidualBlock(hidden_channels, hidden_channels),
            PatchUpscale(hidden_channels, level1_channels, alpha = 1),
            PatchUpscale(level1_channels, img_channels, alpha = 1)
        )

        # Tiny initial gates so newly enabled level-2 heads barely perturb
        # the level-1 output at first.
        self.mask_alpha   = nn.Parameter(th.zeros(1)+1e-16)
        self.object_alpha = nn.Parameter(th.zeros(1)+1e-16)


    def set_level(self, level):
        """Select decoding level (1 = level-1 heads only, >1 adds level 2)."""
        self.level = level

    def forward(self, position, gestalt, priority = None):
        """Decode to ``(mask, object)``, each regrouped by ``to_shared``.

        ``priority`` may be ``None``; it is passed through to the merge layer.
        """
        maps  = self.layer0(self.merge(position, gestalt, priority))
        mask0 = self.to_mask_level0(maps)
        obj0  = self.to_object_level0(maps)

        mask = self.to_mask_level1(mask0)
        obj  = self.to_object_level1(obj0)

        if self.level > 1:
            up = self._LEVEL2_UPSCALE
            # Nearest-neighbour upsample so level-1 output matches level-2
            # resolution before adding the gated level-2 refinement.
            mask = repeat(mask, 'b c h w -> b c (h h2) (w w2)', h2 = up, w2 = up)
            obj  = repeat(obj,  'b c h w -> b c (h h2) (w w2)', h2 = up, w2 = up)

            mask = mask + self.to_mask_level2(mask0) * self.mask_alpha
            obj  = obj  + self.to_object_level2(obj0) * self.object_alpha

        return self.to_shared(mask), self.to_shared(obj)
