import torch.nn as nn
import torch as th
import numpy as np
from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize, LambdaModule, ForcedAlpha, Binarize, PartialBinarize
from nn.eprop_lstm import EpropLSTM
from nn.residual import ResidualBlock, SkipConnection
from nn.eprop_gate_l0rd import EpropGateL0rd
from torch.autograd import Function
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce

from typing import Tuple, Union, List
import utils
import cv2



class NeighbourChannels(nn.Module):
    def __init__(self, channels):
        super(NeighbourChannels, self).__init__()

        self.register_buffer("weights", th.ones(channels, channels, 1, 1), persistent=False)

        for i in range(channels):
            self.weights[i,i,0,0] = 0

    def forward(self, input: th.Tensor):
        return nn.functional.conv2d(input, self.weights)

class InputPreprocessing(nn.Module):
    def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]): 
        super(InputPreprocessing, self).__init__()
        self.num_objects = num_objects
        self.neighbours  = NeighbourChannels(num_objects)
        self.gaus2d      = Gaus2D(size)
        self.to_batch    = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared   = BatchToSharedObjects(num_objects)

    def forward(
        self, 
        input: th.Tensor, 
        error: th.Tensor, 
        mask: th.Tensor,
        object: th.Tensor,
        position: th.Tensor,
        maskraw: th.Tensor
    ):
        batch_size = input.shape[0]
        size       = input.shape[2:]
        
        bg_mask     = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects)
        mask        = mask[:,:-1]
        mask_others = self.neighbours(mask)
        maskraw     = maskraw[:,:-1]

        own_gaus2d    = self.to_shared(self.gaus2d(self.to_batch(position)))
        
        input         = repeat(input,            'b c h w -> b o c h w', o = self.num_objects)
        error         = repeat(error,            'b 1 h w -> b o 1 h w', o = self.num_objects)
        bg_mask       = rearrange(bg_mask,       'b o h w -> b o 1 h w')
        mask_others   = rearrange(mask_others,   'b o h w -> b o 1 h w')
        mask          = rearrange(mask,          'b o h w -> b o 1 h w')
        object        = rearrange(object,        'b (o c) h w -> b o c h w', o = self.num_objects)
        own_gaus2d    = rearrange(own_gaus2d,    'b o h w -> b o 1 h w')
        maskraw       = rearrange(maskraw,       'b o h w -> b o 1 h w')
        
        output = th.cat((input, error, mask, mask_others, bg_mask, object, own_gaus2d, maskraw), dim=2) 
        output = rearrange(output, 'b o c h w -> (b o) c h w')

        return output

class PatchDownConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
        super(PatchDownConv, self).__init__()
        assert out_channels % in_channels == 0
        
        self.layers = nn.Conv2d(
            in_channels  = in_channels, 
            out_channels = out_channels, 
            kernel_size  = kernel_size,
            stride       = kernel_size,
        )

        self.alpha = nn.Parameter(th.zeros(1) + alpha)
        self.kernel_size = 4
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        k = self.kernel_size
        c = self.channels_factor
        skip = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=k, w2=k)
        skip = repeat(skip, 'b c h w -> b (c n) h w', n=c)
        return skip + self.alpha * self.layers(input)

class AggressiveConvToGestalt(nn.Module):
    def __init__(self, channels, gestalt_size, size: Union[int, Tuple[int, int]]):
        super(AggressiveConvToGestalt, self).__init__()

        assert gestalt_size % channels == 0 or channels % gestalt_size == 0
        
        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels  = channels, 
                out_channels = gestalt_size, 
                kernel_size  = 5,
                stride       = 3,
                padding      = 3
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels  = gestalt_size, 
                out_channels = gestalt_size, 
                kernel_size  = ((size[0] + 1)//3 + 1, (size[1] + 1)//3 + 1)
            )
        )
        if gestalt_size > channels:
            self.skip = nn.Sequential(
                LambdaModule(lambda x: reduce(x, 'b c h w -> b c 1 1', 'mean')),
                LambdaModule(lambda x: repeat(x, 'b c 1 1 -> b (c n) 1 1', n = gestalt_size // channels))
            )
        else:
            self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c 1 1', 'mean', n = channels // gestalt_size))


    def forward(self, input: th.Tensor):
        return self.skip(input) + self.layers(input)

class PixelToPosition(nn.Module):
    def __init__(self, size: Union[int, Tuple[int, int]]):
        super(PixelToPosition, self).__init__()

        self.register_buffer("grid_x", th.arange(size[0]), persistent=False)
        self.register_buffer("grid_y", th.arange(size[1]), persistent=False)

        self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1
        self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1

        self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
        self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()

        self.size = size

    def forward(self, input: th.Tensor):
        assert input.shape[1] == 1

        input = rearrange(input, 'b c h w -> b c (h w)')
        input = th.softmax(input, dim=2)
        input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1])

        x = th.sum(input * self.grid_x, dim=(2,3))
        y = th.sum(input * self.grid_y, dim=(2,3))

        return th.cat((x,y),dim=1)

class PixelToSTD(nn.Module):
    def __init__(self):
        super(PixelToSTD, self).__init__()
        self.alpha = ForcedAlpha()

    def forward(self, input: th.Tensor):
        assert input.shape[1] == 1
        return self.alpha(reduce(th.sigmoid(input - 10), 'b c h w -> b c', 'mean'))

class PixelToPriority(nn.Module):
    def __init__(self):
        super(PixelToPriority, self).__init__()

    def forward(self, input: th.Tensor):
        assert input.shape[1] == 1
        return reduce(th.tanh(input), 'b c h w -> b c', 'mean')

class LociEncoder(nn.Module):
    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]], 
        latent_size: Union[int, Tuple[int, int]],
        num_objects: int, 
        img_channels: int,
        hidden_channels: int,
        level1_channels: int,
        num_layers: int,
        gestalt_size: int,
        vae_factor: float,
        batch_size: int,
        bottleneck: str
    ):
        super(LociEncoder, self).__init__()

        self.num_objects  = num_objects
        self.latent_size  = latent_size
        self.level        = 1

        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects))

        print(f"Level1 channels: {level1_channels}")

        self.preprocess = nn.ModuleList([
            InputPreprocessing(num_objects, (input_size[0] // 16, input_size[1] // 16)),
            InputPreprocessing(num_objects, (input_size[0] //  4, input_size[1] //  4)),
            InputPreprocessing(num_objects, (input_size[0], input_size[1]))
        ])

        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels, hidden_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels)
        ])

        self.layers2 = nn.Sequential(
            PatchDownConv(img_channels, level1_channels, alpha = 1e-16),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual=True) for _ in range(num_layers)]
        )

        self.layers1 = PatchDownConv(level1_channels, hidden_channels)

        self.layers0 = nn.Sequential(
            *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)]
        )

        self.position_encoder = nn.Sequential(
            *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
            ResidualBlock(hidden_channels, 3),
        )

        self.xy_encoder = PixelToPosition(latent_size)
        self.std_encoder = PixelToSTD()
        self.priority_encoder = PixelToPriority()

        if bottleneck == "binar":
            print("Binary bottleneck")
            self.gestalt_encoder = nn.Sequential(
                *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
                AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
                LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
                Binarize(),
            )

        elif bottleneck == "partial_binar":
            print("Partial Binar bottleneck")
            self.gestalt_encoder = nn.Sequential(
                *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
                AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
                LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
                PartialBinarize(gestalt_size = gestalt_size),
            )

        elif bottleneck == "vae":
            print("VAE bottleneck")
            self.gestalt_encoder = nn.Sequential(
                *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
                AggressiveConvToGestalt(hidden_channels, gestalt_size * 2, latent_size),
                LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
                VariationalFunction(factor = vae_factor),
                nn.Sigmoid(),
            )
        else:
            print("unrestricted bottleneck")
            self.gestalt_encoder = nn.Sequential(
                *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
                AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
                LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
                nn.Sigmoid(),
            )

    def set_level(self, level):
        self.level = level

    def forward(
        self, 
        input: th.Tensor,
        error: th.Tensor,
        mask: th.Tensor,
        object: th.Tensor,
        position: th.Tensor,
        maskraw: th.Tensor
    ):
        
        latent = self.preprocess[self.level](input, error, mask, object, position, maskraw)
        latent = self.to_channels[self.level](latent)

        if self.level >= 2:
            latent = self.layers2(latent)

        if self.level >= 1:
            latent = self.layers1(latent)

        latent  = self.layers0(latent)
        gestalt = self.gestalt_encoder(latent)

        latent   = self.position_encoder(latent)
        std      = self.std_encoder(latent[:,0:1])
        xy       = self.xy_encoder(latent[:,1:2])
        priority = self.priority_encoder(latent[:,2:3])

        position = self.to_shared(th.cat((xy, std), dim=1))
        gestalt  = self.to_shared(gestalt)
        priority = self.to_shared(priority)

        return position, gestalt, priority

