import torch.nn as nn
import torch as th
import numpy as np
from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize, LambdaModule, ForcedAlpha, Binarize
from nn.eprop_lstm import EpropLSTM
from nn.residual import SkipConnection
from nn.eprop_gate_l0rd import EpropGateL0rd
from torch.autograd import Function
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce

from typing import Tuple, Union, List
import utils
import cv2


class ConvNeXtBlock(nn.Module):
    """Residual block: depthwise 7x7 convolution followed by a per-pixel MLP.

    The main branch runs a depthwise (``groups=in_channels``) 7x7 convolution,
    moves channels to the last axis, applies ``Linear -> SiLU -> Linear`` with
    a hidden width of ``4 * out_channels`` and moves channels back to axis 1.
    Its output is added to ``SkipConnection(in_channels, out_channels)``
    applied to the same input.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced; defaults to ``in_channels``
            when ``None``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int = None,
        ):
        super().__init__()

        # Fall back to a channel-preserving block when no output width is given.
        width    = in_channels if out_channels is None else out_channels
        expanded = width * 4

        main_branch = [
            # depthwise spatial mixing: every channel gets its own 7x7 filter
            nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels),
            # nn.Linear acts on the last axis, so go channels-last: (b, h, w, c)
            LambdaModule(lambda x: x.permute(0, 2, 3, 1)),
            nn.Linear(in_channels, expanded),
            nn.SiLU(),
            nn.Linear(expanded, width),
            # back to channels-first: (b, c, h, w)
            LambdaModule(lambda x: x.permute(0, 3, 1, 2)),
        ]
        self.layers = nn.Sequential(*main_branch)

        self.skip = SkipConnection(in_channels, width)

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return the sum of the skip branch and the main branch for ``input``."""
        shortcut = self.skip(input)
        return shortcut + self.layers(input)


class NeighbourChannels(nn.Module):
    """Sum, for every channel, the values of all *other* channels.

    Implemented as a fixed 1x1 convolution whose kernel is one everywhere
    except on the diagonal, where it is zero. Output channel ``i`` is therefore
    the per-pixel sum of every input channel ``j != i``.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()

        # all-ones matrix with a zeroed diagonal: a channel never sees itself
        off_diagonal = th.ones(channels, channels) - th.eye(channels)

        # non-persistent: the kernel is a constant and is kept out of the state dict
        self.register_buffer(
            "weights",
            off_diagonal.view(channels, channels, 1, 1),
            persistent=False,
        )

    def forward(self, input: th.Tensor):
        """Apply the fixed 1x1 kernel to a ``(b, channels, h, w)`` tensor."""
        kernel = self.weights
        return nn.functional.conv2d(input, kernel)

class ObjectTracker(nn.Module):
    """Assemble the per-object input stack that is fed to the encoder.

    For each of ``num_objects`` slots the current frame, the previous error
    map, the slot's own mask, the summed masks of the other slots, the
    background mask, the slot's previous rgb/depth/flow output and two maps
    produced by ``Gaus2D`` (own and prioritised others) are concatenated along
    the channel axis. Objects are then folded into the batch axis.

    Args:
        num_objects: number of object slots.
        size: spatial size forwarded to ``Gaus2D``.
    """

    def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]):
        super().__init__()
        self.num_objects = num_objects
        self.neighbours  = NeighbourChannels(num_objects)
        self.prioritize  = Prioritize(num_objects)
        self.gaus2d      = Gaus2D(size, position_limit=2)
        # 'b (o c)' -> '(b o) c': give every object its own batch row
        self.to_batch    = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared   = BatchToSharedObjects(num_objects)

    def forward(
        self,
        input_rgb: th.Tensor,
        input_depth: th.Tensor,
        error_last: th.Tensor,
        mask: th.Tensor,
        object_rgb: th.Tensor,
        object_depth: th.Tensor,
        object_flow: th.Tensor,
        position: th.Tensor,
        priority: th.Tensor
    ):
        """Build the ``((b o), c, h, w)`` encoder input.

        ``mask`` holds one channel per object followed by a final background
        channel. ``error_last`` must have exactly one channel. ``priority``
        may be ``None``. ``mask``, ``position``, ``priority`` and
        ``error_last`` are detached, so no gradient flows back through them.
        """
        slots = self.num_objects

        # Previous-step state is treated as constant input.
        error_last = error_last.detach()
        position   = position.detach()
        mask       = mask.detach()
        if priority is not None:
            priority = priority.detach()

        # Last mask channel is the background, the rest are the object slots.
        background   = mask[:, -1:]
        object_masks = mask[:, :-1]
        rival_masks  = self.neighbours(object_masks)

        own_blobs   = self.to_shared(self.gaus2d(self.to_batch(position)))
        rival_blobs = self.neighbours(self.prioritize(own_blobs, priority))

        def per_object(frame):
            # copy a frame-level tensor once for every object slot
            return repeat(frame, 'b c h w -> b o c h w', o = slots)

        def single_channel(frame):
            # same as per_object, but insists on exactly one input channel
            return repeat(frame, 'b 1 h w -> b o 1 h w', o = slots)

        def with_channel_axis(maps):
            # one map per object -> one single-channel image per object
            return rearrange(maps, 'b o h w -> b o 1 h w')

        def split_objects(stacked):
            # objects are stacked along the channel axis; pull them apart
            return rearrange(stacked, 'b (o c) h w -> b o c h w', o = slots)

        # Channel order below is part of the contract with the encoder stem.
        parts = [
            per_object(input_rgb),
            per_object(input_depth),
            single_channel(error_last),
            with_channel_axis(object_masks),
            with_channel_axis(rival_masks),
            single_channel(background),
            split_objects(object_rgb),
            split_objects(object_depth),
            split_objects(object_flow),
            with_channel_axis(own_blobs),
            with_channel_axis(rival_blobs),
        ]
        stacked = th.cat(parts, dim=2)

        return rearrange(stacked, 'b o c h w -> (b o) c h w')

class PatchDownConv(nn.Module):
    """Downsample by ``kernel_size`` using non-overlapping patches.

    Every ``kernel_size x kernel_size`` patch is flattened (together with its
    channels) and mapped to ``out_channels`` by one shared linear layer. A
    parameter-free skip path averages each patch and repeats every channel
    ``out_channels // in_channels`` times; both paths are added.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced; must be a multiple of ``in_channels``.
        kernel_size: side length of a patch, i.e. the downsampling factor.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 4):
        super().__init__()
        assert out_channels % in_channels == 0

        # one flattened patch holds in_channels * K * K values
        patch_features = in_channels * kernel_size**2
        self.layers = nn.Linear(patch_features, out_channels)

        self.kernel_size     = kernel_size
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        """Map ``(b, c, H, W)`` to ``(b, out_channels, H // K, W // K)``."""
        height, width = input.shape[2:]
        patch  = self.kernel_size
        copies = self.channels_factor

        # skip path: patch-wise mean, each channel repeated to reach out_channels
        pooled   = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=patch, w2=patch)
        shortcut = repeat(pooled, 'b c h w -> b (c n) h w', n=copies)

        # learned path: every patch becomes one row for the shared linear layer
        rows      = rearrange(input, 'b c (h h2) (w w2) -> (b h w) (c h2 w2)', h2=patch, w2=patch)
        projected = self.layers(rows)
        projected = rearrange(
            projected, '(b h w) c -> b c h w', h = height // patch, w = width // patch
        )

        return shortcut + projected

class AggressiveConvTo1x1(nn.Module):
    """Collapse a ``(b, c, h, w)`` feature map into a ``(b, out_channels)`` vector.

    The learned path is a stride-3 convolution, a ReLU and a second
    convolution whose kernel covers the whole remaining map, leaving a 1x1
    output that is squeezed to two dimensions. A parameter-free skip path
    averages over space and either repeats channels (when widening) or
    averages groups of channels (when narrowing).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: length of the produced vector; one of the two channel
            counts must divide the other.
        size: ``(height, width)`` of the incoming feature map.
    """

    def __init__(self, in_channels, out_channels, size: Union[int, Tuple[int, int]]):
        super().__init__()

        assert out_channels % in_channels == 0 or in_channels % out_channels == 0

        # kernel 5 / stride 3 / padding 3 turns an extent s into (s + 1) // 3 + 1;
        # the second kernel spans exactly that, so its output is 1x1.
        remaining_h = (size[0] + 1)//3 + 1
        remaining_w = (size[1] + 1)//3 + 1

        self.layers = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=3, padding=3),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=(remaining_h, remaining_w)),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
        )

        if out_channels > in_channels:
            # widening: spatial mean, then every channel repeated `widen` times
            widen = out_channels // in_channels
            self.skip = nn.Sequential(
                LambdaModule(lambda x: reduce(x, 'b c h w -> b c', 'mean')),
                LambdaModule(lambda x: repeat(x, 'b c -> b (c n)', n = widen)),
            )
        else:
            # narrowing (or equal): mean over space and over groups of `shrink` channels
            shrink = in_channels // out_channels
            self.skip = LambdaModule(
                lambda x: reduce(x, 'b (c n) h w -> b c', 'mean', n = shrink)
            )

    def forward(self, input: th.Tensor):
        """Return skip path plus learned path, shape ``(b, out_channels)``."""
        shortcut = self.skip(input)
        return shortcut + self.layers(input)

class PixelToPosition(nn.Module):
    """Turn a single-channel score map into an ``(x, y)`` position.

    The map is converted to a probability distribution with a softmax over
    all pixels; the position is the expectation of fixed coordinate grids
    that run from -1 to 1 along each axis.

    Args:
        size: ``(height, width)`` of the maps passed to ``forward``.
    """

    def __init__(self, size):
        super(PixelToPosition, self).__init__()

        # Column vector of y coordinates / row vector of x coordinates,
        # each expanded to the full (1, 1, height, width) grid.
        grid_y = self._axis_coordinates(size[0]).view(1, 1, -1, 1).expand(1, 1, *size).clone()
        grid_x = self._axis_coordinates(size[1]).view(1, 1, 1, -1).expand(1, 1, *size).clone()

        # non-persistent: the grids are constants and stay out of the state dict
        self.register_buffer("grid_y", grid_y, persistent=False)
        self.register_buffer("grid_x", grid_x, persistent=False)

        self.size = size

    @staticmethod
    def _axis_coordinates(length):
        """Return ``length`` evenly spaced coordinates covering [-1, 1].

        An axis with a single element has no extent; dividing by
        ``length - 1`` would yield NaN there, so its only coordinate is
        placed at the centre (0) instead.
        """
        if length > 1:
            return (th.arange(length) / (length - 1)) * 2 - 1
        return th.zeros(length)

    def forward(self, input: th.Tensor):
        """Map ``(b, 1, height, width)`` scores to ``(b, 2)`` positions ``(x, y)``."""
        assert input.shape[1] == 1

        # softmax over all pixels of each map, not per row or column
        input = rearrange(input, 'b c h w -> b c (h w)')
        input = th.softmax(input, dim=2)
        input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1])

        # expected coordinate under the pixel distribution
        x = th.sum(input * self.grid_x, dim=(2,3))
        y = th.sum(input * self.grid_y, dim=(2,3))

        return th.cat((x, y), dim=1)

class PixelToSTD(nn.Module):
    """Reduce a single-channel map to one value per sample by spatial averaging."""

    def forward(self, input: th.Tensor):
        """Map ``(b, 1, h, w)`` to ``(b, 1)``; more than one channel is rejected."""
        channels = input.shape[1]
        assert channels == 1

        pooled = reduce(input, 'b c h w -> b c', 'mean')
        return pooled

class PixelToDepth(nn.Module):
    """Reduce a single-channel map to one value in (-1, 1) per sample.

    Every pixel is squashed with ``tanh`` first; the squashed map is then
    averaged over space.
    """

    def forward(self, input):
        """Map ``(b, 1, h, w)`` to ``(b, 1)``; more than one channel is rejected."""
        channels = input.shape[1]
        assert channels == 1

        squashed = th.tanh(input)
        return reduce(squashed, 'b c h w -> b c', 'mean')

class PixelToPriority(nn.Module):
    """Compute one priority value per object from a single-channel map.

    The priority is the sum of three terms, each with a learnable scale:
    a fixed per-slot index spread evenly over [-1, 1], the ``tanh`` of the
    spatial mean of the input map, and the supplied depth.

    Args:
        num_objects: number of object slots folded into the batch axis.
    """

    def __init__(self, num_objects):
        super(PixelToPriority, self).__init__()

        self.num_objects = num_objects

        # Slot indices mapped to [-1, 1]. With a single slot the usual
        # normalisation divides 0 by 0 and would poison every priority with
        # NaN, so the lone slot gets the neutral index 0.
        if num_objects > 1:
            indices = (th.arange(num_objects) / (num_objects - 1)) * 2 - 1
        else:
            indices = th.zeros(num_objects)

        # shape (1, num_objects); non-persistent constant, not part of the state dict
        self.register_buffer("indices", indices.unsqueeze(0), persistent=False)

        # The index term starts at full weight; the other two start at a
        # tiny non-zero value.
        self.index_factor    = nn.Parameter(th.ones(1))
        self.priority_factor = nn.Parameter(th.zeros(1)+1e-16)
        self.depth_factor    = nn.Parameter(th.zeros(1)+1e-16)

    def forward(self, input: th.Tensor, depth: th.Tensor):
        """Return priorities of shape ``((b o), 1)``.

        Args:
            input: ``((b o), 1, h, w)`` map with ``o == num_objects``.
            depth: tensor added to the ``((b o), 1)`` result after scaling by
                ``|depth_factor|``; must broadcast against that shape.
        """
        assert input.shape[1] == 1
        #print(f"index: {self.index_factor.item():.2e}, priority: {self.priority_factor.item():.2e}, depth: {self.depth_factor.item():.2e}")
        priority = th.tanh(reduce(input, '(b o) 1 h w -> b o', 'mean', o = self.num_objects))
        priority = priority * self.priority_factor + self.index_factor * self.indices
        # abs() keeps the sign of the depth contribution fixed during training
        priority = rearrange(priority, 'b o -> (b o) 1') + depth * th.abs(self.depth_factor)
        return priority

class LociEncoder(nn.Module):
    """Encode a frame plus per-object feedback into position, gestalt and priority.

    The pipeline is: ``ObjectTracker`` (builds a per-object input stack) ->
    patch-downsampling stem -> shared ConvNeXt trunk -> two heads. The
    position head emits four single-channel maps that are reduced to xy,
    depth, std and priority; the gestalt head collapses the feature map to a
    vector and optionally passes it through a bottleneck.

    Args:
        input_size: spatial size of the input frames (passed to the tracker).
        latent_size: spatial size of the feature map after the stem.
        num_objects: number of object slots.
        base_channels: channel width after the first downsampling step; the
            trunk runs at ``4 * base_channels``.
        blocks: indexable with at least four entries giving the number of
            ConvNeXt blocks per stage; entry 3 is used for both heads.
        gestalt_size: length of the gestalt code.
        vae_factor: forwarded to ``VariationalFunction`` for the "vae" bottleneck.
        batch_size: accepted but not used by this class.
        bottleneck: "binar", "vae", "sigmoid" or anything else for none.
    """

    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]],
        latent_size: Union[int, Tuple[int, int]],
        num_objects: int,
        base_channels: int,
        blocks,
        gestalt_size: int,
        vae_factor: float,
        batch_size: int,
        bottleneck: str,
    ):
        super().__init__()
        self.num_objects      = num_objects
        self.latent_size      = latent_size

        # Channel count of the tracker output; presumably 3 rgb + 1 depth +
        # 2 flow channels per image plus the single-channel maps. TODO confirm.
        img_channels = 16

        # '(b o) c' -> 'b (o c)': gather all objects of a sample into one row
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects))

        self.tracker = ObjectTracker(num_objects, input_size)

        # Stem: downsample by 4, then twice by 2, doubling the width each time.
        stem_modules = [PatchDownConv(img_channels, base_channels)]
        stem_modules.extend(ConvNeXtBlock(base_channels) for _ in range(blocks[0]))
        stem_modules.append(PatchDownConv(base_channels, base_channels * 2, kernel_size=2))
        stem_modules.extend(ConvNeXtBlock(base_channels * 2) for _ in range(blocks[1]))
        stem_modules.append(PatchDownConv(base_channels * 2, base_channels * 4, kernel_size=2))
        self.stem = nn.Sequential(*stem_modules)

        hidden_channels = base_channels * 4

        trunk = [ConvNeXtBlock(hidden_channels) for _ in range(blocks[2])]
        self.layers0 = nn.Sequential(*trunk)

        # Position head ends in four maps: xy, depth, std, priority.
        position_modules = [ConvNeXtBlock(hidden_channels) for _ in range(blocks[3])]
        position_modules.append(ConvNeXtBlock(hidden_channels, 4))
        self.position_encoder = nn.Sequential(*position_modules)

        self.xy_encoder       = PixelToPosition(latent_size)
        self.std_encoder      = PixelToSTD()
        self.depth_encoder    = PixelToDepth()
        self.priority_encoder = PixelToPriority(num_objects)

        # Gestalt head; the VAE bottleneck receives twice the code size as input.
        gestalt_modules = [ConvNeXtBlock(hidden_channels) for _ in range(blocks[3])]
        gestalt_modules.append(
            AggressiveConvTo1x1(
                in_channels  = hidden_channels,
                out_channels = gestalt_size * 2 if bottleneck == "vae" else gestalt_size,
                size         = latent_size
            )
        )

        if bottleneck == "binar":
            gestalt_modules.append(Binarize())
            print("Encoder: Binary Bottleneck")
        elif bottleneck == "vae":
            gestalt_modules.append(VariationalFunction(factor = vae_factor))
            print("Encoder: VAE Bottleneck")
        elif bottleneck == "sigmoid":
            gestalt_modules.append(nn.Sigmoid())
            print("Encoder: Sigmoid Bottleneck")
        else:
            print("Encoder: Unrestricted Bottleneck")

        self.gestalt_encoder = nn.Sequential(*gestalt_modules)

    def set_level(self, level):
        """Store ``level`` on the instance; it is not read anywhere in this class."""
        self.level = level

    def forward(
        self,
        input_rgb: th.Tensor,
        input_depth: th.Tensor,
        error_last: th.Tensor,
        mask: th.Tensor,
        object_rgb: th.Tensor,
        object_depth: th.Tensor,
        object_flow: th.Tensor,
        position: th.Tensor,
        gestalt: th.Tensor,
        priority: th.Tensor
    ):
        """Return ``(position, gestalt, priority)``, each shaped ``b (o c)``.

        The incoming ``gestalt`` is ignored; the incoming ``priority`` is only
        used by the tracker. Per object, ``position`` holds ``x, y, z, std``.
        """
        tracked = self.tracker(
            input_rgb,
            input_depth,
            error_last,
            mask,
            object_rgb,
            object_depth,
            object_flow,
            position,
            priority
        )
        features = self.layers0(self.stem(tracked))

        # The gestalt head runs first, matching the original evaluation order.
        gestalt_code = self.gestalt_encoder(features)

        maps = self.position_encoder(features)
        xy   = self.xy_encoder(maps[:,0:1])
        z    = self.depth_encoder(maps[:,1:2])
        std  = self.std_encoder(maps[:,2:3])
        new_priority = self.priority_encoder(maps[:,3:4], z)
        new_position = th.cat((xy, z, std), dim=1)

        return (
            self.to_shared(new_position),
            self.to_shared(gestalt_code),
            self.to_shared(new_priority),
        )

