import torch.nn as nn
import torch as th
import numpy as np
from utils.utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize, ForcedAlpha, Binarize, MultiArgSequential, ProbabilisticBinarize, CapacityModulatedBinarize
from nn.eprop_lstm import EpropLSTM
from nn.embedding import PositionPooling
from nn.residual import SkipConnection, LinearSkip
from nn.manifold import HyperSequential, HyperConvNextBlock, NonHyperWrapper
from nn.eprop_gate_l0rd import EpropGateL0rd
from torch.autograd import Function
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange, Reduce
from nn.downscale import MemoryEfficientPatchDownScale

from typing import Tuple, Union, List
import utils
import cv2


class NeighbourChannels(nn.Module):
    """Replace every channel by the sum of all *other* channels.

    Implemented as a fixed (non-trainable) 1x1 convolution whose weight
    matrix is all ones except for a zero diagonal, so output channel ``i``
    equals ``sum_{j != i} input[:, j]``.
    """

    def __init__(self, channels):
        super().__init__()

        # ones everywhere, zeros on the diagonal -> each channel excludes itself
        off_diagonal = th.ones(channels, channels) - th.eye(channels)
        self.register_buffer(
            "weights",
            off_diagonal.view(channels, channels, 1, 1),
            persistent=False,
        )

    def forward(self, input: th.Tensor):
        return nn.functional.conv2d(input, self.weights)

class ObjectTracker(nn.Module):
    """Assemble the per-slot input tensor for the encoder.

    For every object slot, the shared frame inputs (RGB, depth, last error)
    are concatenated along the channel axis with:
      * that slot's mask (scaled by ``slot_reset``),
      * the sum of all other slots' masks (via ``NeighbourChannels``),
      * the last mask channel (``bg_mask``, presumably the background),
      * the slot's RGB / depth / flow tensors multiplied by ``mask_raw`` and ``slot_reset``,
      * the ``Gaus2D`` rendering of the slot position,
      * ``mask_raw * slot_reset``.
    The result is flattened so that each (batch, slot) pair is one sample.
    """
    def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]): 
        super(ObjectTracker, self).__init__()
        self.num_objects = num_objects
        self.neighbours  = NeighbourChannels(num_objects)
        # Gaus2D is defined in utils.utils; presumably renders a 2D Gaussian map of spatial size `size`
        self.gaus2d      = Gaus2D(size)
        self.to_batch    = Rearrange('b (o c) -> (b o) c', o = num_objects)
        self.to_shared   = BatchToSharedObjects(num_objects)

    def forward(
        self, 
        input_rgb: th.Tensor, 
        input_depth: th.Tensor,
        error_last: th.Tensor, 
        mask: th.Tensor,
        mask_raw: th.Tensor,
        object_rgb: th.Tensor,
        object_depth: th.Tensor,
        object_flow: th.Tensor,
        position: th.Tensor,
        slot_reset: th.Tensor,
    ):
        """Return the stacked per-slot features with shape ``(b * num_objects, C, h, w)``.

        Shape assumptions, derived from the einops patterns below:
          * input_rgb / input_depth: (b, c, h, w), shared by all slots.
          * error_last: (b, 1, h, w).
          * mask: (b, num_objects + 1, h, w); the last channel is split off as bg_mask.
          * mask_raw: (b, num_objects, h, w).
          * object_rgb: (b, groups * c, h, w) with groups = num_objects + 1 when
            num_objects > 1 (the last group is dropped), else groups = 1.
          * object_depth / object_flow: (b, num_objects * c, h, w).
          * position: (b, num_objects * c) — split per slot by self.to_batch.
          * slot_reset: (b, num_objects) or broadcastable (b, 1); multiplies slot-specific features.
        """

        # last mask channel, copied once per slot
        bg_mask     = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects)
        # per-slot masks, zeroed for slots whose slot_reset is 0
        mask        = mask[:,:-1] * slot_reset.unsqueeze(-1).unsqueeze(-1)
        # for each slot: sum of all other slots' masks
        mask_others = self.neighbours(mask)

        # render one map per slot from its position, then regroup slots into the channel axis
        own_gaus2d    = self.to_shared(self.gaus2d(self.to_batch(position)))

        # bring every input to (b, o, c, h, w) so they can be concatenated per slot
        input_rgb     = repeat(input_rgb,        'b c h w -> b o c h w', o = self.num_objects)
        input_depth   = repeat(input_depth,      'b c h w -> b o c h w', o = self.num_objects)
        error_last    = repeat(error_last,       'b 1 h w -> b o 1 h w', o = self.num_objects)
        bg_mask       = rearrange(bg_mask,       'b o h w -> b o 1 h w')
        mask_others   = rearrange(mask_others,   'b o h w -> b o 1 h w')
        mask          = rearrange(mask,          'b o h w -> b o 1 h w')
        # object_rgb carries one extra group when num_objects > 1 (dropped in the cat below)
        object_rgb    = rearrange(object_rgb,    'b (o c) h w -> b o c h w', o = self.num_objects+1 if self.num_objects > 1 else 1)
        object_depth  = rearrange(object_depth,  'b (o c) h w -> b o c h w', o = self.num_objects)
        object_flow   = rearrange(object_flow,   'b (o c) h w -> b o c h w', o = self.num_objects)
        own_gaus2d    = rearrange(own_gaus2d,    'b o h w -> b o 1 h w')
        mask_raw      = rearrange(mask_raw,      'b o h w -> b o 1 h w')
        slot_reset    = rearrange(slot_reset,    'b o -> b o 1 1 1')

        # concatenate along the per-slot channel axis; the order defines the encoder's input channels
        output = th.cat((
            input_rgb, 
            input_depth, 
            error_last,
            mask, 
            mask_others,
            bg_mask,
            object_rgb[:,:-1] * mask_raw * slot_reset if self.num_objects > 1 else object_rgb * mask_raw * slot_reset,
            object_depth * mask_raw * slot_reset,
            object_flow * mask_raw * slot_reset,
            own_gaus2d,
            mask_raw * slot_reset,
        ), dim=2) 
        # fold slots into the batch dimension: one sample per (batch, slot)
        output = rearrange(output, 'b o c h w -> (b o) c h w')

        return output

class PatchDownConv(nn.Module):
    """Downsample by ``kernel_size`` with non-overlapping patches plus a skip path.

    * Skip path: every ``K x K`` patch is averaged and each input channel is
      repeated ``out_channels // in_channels`` times.
    * Residual path: each patch is flattened and projected to ``out_channels``
      by a linear layer.

    The output is the sum of both paths, with spatial size ``(H // K, W // K)``.
    The einops patterns require H and W to be divisible by ``kernel_size``.
    """
    def __init__(self, in_channels, out_channels, kernel_size = 4):
        super().__init__()
        assert out_channels % in_channels == 0

        self.layers = nn.Linear(in_channels * kernel_size**2, out_channels)

        self.kernel_size     = kernel_size
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        height, width = input.shape[2:]
        patch = self.kernel_size

        # skip path: per-patch mean, each channel repeated channels_factor times (c-major order)
        pooled   = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=patch, w2=patch)
        shortcut = pooled.repeat_interleave(self.channels_factor, dim=1)

        # residual path: flatten every patch into a vector and project it linearly
        patches   = rearrange(input, 'b c (h h2) (w w2) -> (b h w) (c h2 w2)', h2=patch, w2=patch)
        projected = rearrange(
            self.layers(patches),
            '(b h w) c -> b c h w',
            h = height // patch,
            w = width // patch,
        )

        return shortcut + projected

class PixelToPosition(nn.Module):
    """Turn a single-channel logit map into an ``(x, y)`` position via soft-argmax.

    The map is softmax-normalised over all spatial locations and the expected
    pixel coordinate under that distribution is returned. Both axes are scaled
    to ``[-1, 1]``; an axis of length 1 maps to the single coordinate ``0``.

    Args:
        size: spatial size ``(height, width)`` of the maps passed to ``forward``.
            A single ``int`` is interpreted as a square ``(size, size)`` map.
    """
    def __init__(self, size): # FIXME add update grid !!!
        super(PixelToPosition, self).__init__()

        if isinstance(size, int):
            size = (size, size)

        grid_y = self._normalized_axis(size[0]).view(1, 1, -1, 1).expand(1, 1, *size).clone()
        grid_x = self._normalized_axis(size[1]).view(1, 1, 1, -1).expand(1, 1, *size).clone()

        # derived from `size`, so they are not stored in the state_dict
        self.register_buffer("grid_y", grid_y, persistent=False)
        self.register_buffer("grid_x", grid_x, persistent=False)

        self.size = size

    @staticmethod
    def _normalized_axis(length: int) -> th.Tensor:
        """Return ``length`` evenly spaced coordinates spanning ``[-1, 1]``.

        ``length == 1`` yields ``[0.]`` instead of the NaN that ``0 / 0``
        would produce.
        """
        if length == 1:
            return th.zeros(1)
        return th.arange(length) / (length - 1) * 2 - 1

    def forward(self, input: th.Tensor):
        """Map ``(B, 1, H, W)`` logits to ``(B, 2)`` positions ordered ``(x, y)``."""
        assert input.shape[1] == 1

        batch, channels = input.shape[:2]

        # softmax over all H*W locations -> spatial probability distribution
        weights = th.softmax(input.reshape(batch, channels, -1), dim=2)
        weights = weights.reshape(batch, channels, self.size[0], self.size[1])

        # expected coordinate under that distribution
        x = th.sum(weights * self.grid_x, dim=(2,3))
        y = th.sum(weights * self.grid_y, dim=(2,3))

        return th.cat((x, y), dim=1)

class PixelToSTD(nn.Module):
    """Collapse a single-channel map to one scalar per sample by spatial averaging.

    ``(B, 1, H, W) -> (B, 1)``; no activation is applied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: th.Tensor):
        assert input.shape[1] == 1
        spatial_mean = reduce(input, 'b c h w -> b c', 'mean')
        return spatial_mean

class PixelToDepth(nn.Module):
    """Map a single-channel map to one value in ``(0, 1)`` per sample.

    The map is averaged over its spatial dimensions and passed through a
    sigmoid: ``(B, 1, H, W) -> (B, 1)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        assert input.shape[1] == 1
        averaged = reduce(input, 'b c h w -> b c', 'mean')
        return th.sigmoid(averaged)

class PixelToPriority(nn.Module):
    """Compute a scalar priority for every object slot.

    For slot ``k`` of sample ``b``::

        priority = tanh(mean(map)) * priority_factor
                   + index_factor * index[k]
                   + depth * |depth_factor|

    where ``index`` spreads the slot indices evenly over ``[-1, 1]``. A single
    slot gets index ``0``; computing it as ``0 / 0`` would give NaN.

    Args:
        num_objects: number of slots folded into the batch dimension of ``input``.
    """
    def __init__(self, num_objects):
        super(PixelToPriority, self).__init__()

        self.num_objects = num_objects

        # slot indices scaled to [-1, 1], shape (1, num_objects)
        if num_objects > 1:
            indices = th.arange(num_objects) / (num_objects - 1) * 2 - 1
        else:
            indices = th.zeros(num_objects)
        self.register_buffer("indices", indices.unsqueeze(0), persistent=False)

        self.index_factor    = nn.Parameter(th.ones(1))
        # tiny non-zero initial values for the learned factors
        self.priority_factor = nn.Parameter(th.zeros(1)+1e-16)
        self.depth_factor    = nn.Parameter(th.zeros(1)+1e-16)

    def forward(self, input: th.Tensor, depth: th.Tensor):
        """Return priorities of shape ``(B * num_objects, 1)``.

        Args:
            input: ``(B * num_objects, 1, H, W)`` map; slots vary fastest along dim 0.
            depth: per-slot depth, assumed ``(B * num_objects, 1)`` (TODO confirm).
        """
        assert input.shape[1] == 1

        # (B*O, 1, H, W) -> (B, O): spatial mean of every slot's map
        pooled   = input.mean(dim=(1, 2, 3)).reshape(-1, self.num_objects)
        priority = th.tanh(pooled)
        priority = priority * self.priority_factor + self.index_factor * self.indices
        # back to one row per (batch, slot); abs keeps the depth contribution non-negative-weighted
        priority = priority.reshape(-1, 1) + depth * th.abs(self.depth_factor)
        return priority

class LociEncoder(nn.Module):
    """Slot-wise encoder whose convolutional weights are modulated by a hyper-network.

    Object slots are folded into the batch dimension and encoded independently.
    The previous gestalt code of a slot (without its last entry) drives an MLP
    (``self.hyper_weights``) that emits per-slot weights for every
    ``HyperSequential`` sub-encoder. For every slot the encoder returns a
    position ``(x, y, z, std)``, a gestalt code (mask, depth and object codes
    followed by a depth scale) and a scalar priority.

    Args:
        input_size: spatial size of the input frames (forwarded to ``ObjectTracker``).
        latent_size: spatial size of the stem output; it is indexed as a
            sequence below, so a tuple/list is required in practice.
        num_objects: number of object slots.
        base_channels: channel width of the first stage (x2, x4, x8 in later stages).
        blocks: per-stage block counts; entries ``0`` to ``3`` are read.
        hyper_channels: hidden width of the hyper-network MLP.
        gestalt_size: size of each of the mask / depth / object gestalt codes.
        vae_factor: forwarded to ``VariationalFunction`` when ``bottleneck == "vae"``.
        batch_size: not used in this class.
        bottleneck: ``"binar"``, ``"prob-bin"``, ``"mod-bin"``, ``"vae"`` or
            ``"sigmoid"``; any other value leaves the gestalt codes unrestricted.
    """
    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]], 
        latent_size: Union[int, Tuple[int, int]],
        num_objects: int, 
        base_channels: int,
        blocks,
        hyper_channels: int,
        gestalt_size: int,
        vae_factor: float,
        batch_size: int,
        bottleneck: str,
    ):
        super(LociEncoder, self).__init__()
        self.num_objects = num_objects
        self.latent_size = latent_size

        # channel count of the ObjectTracker output fed into the stem
        img_channels = 16

        self.to_shared = Rearrange('(b o) c -> b (o c)', o = self.num_objects)
        self.to_batch  = Rearrange('b (o c) -> (b o) c', o = self.num_objects)

        self.tracker = ObjectTracker(num_objects, input_size)

        self.stem = HyperSequential(
            NonHyperWrapper(PatchDownConv(img_channels, base_channels)),
            *[HyperConvNextBlock(base_channels) for _ in range(blocks[0])],
            NonHyperWrapper(PatchDownConv(base_channels, base_channels * 2, kernel_size=2)),
            *[HyperConvNextBlock(base_channels * 2) for _ in range(blocks[1])],
            NonHyperWrapper(PatchDownConv(base_channels * 2, base_channels * 4, kernel_size=2)),
            *[HyperConvNextBlock(base_channels * 4) for _ in range(blocks[2])]
        )

        # 4 output maps: x/y logits, depth, std, priority (consumed in forward)
        self.position_encoder = HyperSequential(
            *[HyperConvNextBlock(base_channels * 4) for _ in range(blocks[3])],
            HyperConvNextBlock(base_channels * 4, 4)
        )

        self.xy_encoder       = PixelToPosition(latent_size)
        self.std_encoder      = PixelToSTD()
        self.depth_encoder    = PixelToDepth()
        self.priority_encoder = PixelToPriority(num_objects)

        self.gestalt_base_encoder = HyperSequential(
            *[HyperConvNextBlock(base_channels * 4) for _ in range(blocks[3])],
        )

        self.mask_gestalt_encoder = HyperSequential(
            NonHyperWrapper(PatchDownConv(base_channels * 4, base_channels * 8, kernel_size=2)),
            *[HyperConvNextBlock(base_channels * 8) for _ in range(blocks[3] * 2)]
        )

        self.depth_gestalt_encoder = HyperSequential(
            NonHyperWrapper(PatchDownConv(base_channels * 4, base_channels * 8, kernel_size=2)),
            *[HyperConvNextBlock(base_channels * 8) for _ in range(blocks[3] * 2)]
        )

        self.object_gestalt_encoder = HyperSequential(
            NonHyperWrapper(PatchDownConv(base_channels * 4, base_channels * 8, kernel_size=2)),
            *[HyperConvNextBlock(base_channels * 8) for _ in range(blocks[3] * 2)]
        )

        mask_gestalt_pooling = [
            PositionPooling(
                in_channels  = base_channels * 8, 
                out_channels = gestalt_size * 2 if bottleneck == "vae" else gestalt_size, 
                size         = [latent_size[0] // 2, latent_size[1] // 2]
            )
        ]
        depth_gestalt_pooling = [
            PositionPooling(
                in_channels  = base_channels * 8, 
                out_channels = gestalt_size * 2 if bottleneck == "vae" else gestalt_size, 
                size         = [latent_size[0] // 2, latent_size[1] // 2]
            )
        ]
        # NOTE(review): for bottleneck == "vae" this pooling emits 2 channels and, unlike the
        # gestalt poolings (2 * gestalt_size channels + VariationalFunction, which presumably
        # halves them), gets no VariationalFunction. The concatenated gestalt may therefore be
        # one channel wider in VAE mode than hyper_weights expects. Confirm this is intended.
        depth_scale_pooling = [
            PositionPooling(
                in_channels  = base_channels * 8, 
                out_channels = 2 if bottleneck == "vae" else 1, 
                size         = [latent_size[0] // 2, latent_size[1] // 2]
            ),
            nn.Sigmoid()
        ]
        object_gestalt_pooling = [
            PositionPooling(
                in_channels  = base_channels * 8, 
                out_channels = gestalt_size * 2 if bottleneck == "vae" else gestalt_size, 
                size         = [latent_size[0] // 2, latent_size[1] // 2]
            )
        ]
        if bottleneck == "binar":
            mask_gestalt_pooling.append(Binarize())
            depth_gestalt_pooling.append(Binarize())
            object_gestalt_pooling.append(Binarize())
            print("Encoder: Binary Bottleneck")
        elif bottleneck == "prob-bin":
            mask_gestalt_pooling.append(ProbabilisticBinarize(gestalt_size))
            depth_gestalt_pooling.append(ProbabilisticBinarize(gestalt_size))
            object_gestalt_pooling.append(ProbabilisticBinarize(gestalt_size))
            print("Encoder: Probabilistic Binary Bottleneck")
        elif bottleneck == "mod-bin":
            mask_gestalt_pooling.append(CapacityModulatedBinarize(gestalt_size))
            depth_gestalt_pooling.append(CapacityModulatedBinarize(gestalt_size))
            object_gestalt_pooling.append(CapacityModulatedBinarize(gestalt_size))
            print("Encoder: Capacity Modulated Binary Bottleneck")
        elif bottleneck == "vae":
            mask_gestalt_pooling.append(VariationalFunction(factor = vae_factor))
            depth_gestalt_pooling.append(VariationalFunction(factor = vae_factor))
            object_gestalt_pooling.append(VariationalFunction(factor = vae_factor))
            print("Encoder: VAE Bottleneck")
        elif bottleneck == "sigmoid":
            mask_gestalt_pooling.append(nn.Sigmoid())
            depth_gestalt_pooling.append(nn.Sigmoid())
            object_gestalt_pooling.append(nn.Sigmoid())
            print("Encoder: Sigmoid Bottleneck")
        else:
            print("Encoder: Unrestricted Bottleneck")

        self.mask_gestalt_pooling = MultiArgSequential(*mask_gestalt_pooling)
        self.depth_gestalt_pooling = MultiArgSequential(*depth_gestalt_pooling)
        self.depth_scale_pooling = MultiArgSequential(*depth_scale_pooling)
        self.object_gestalt_pooling = MultiArgSequential(*object_gestalt_pooling)

        # hyper-network: previous gestalt (3 codes) -> flat vector of all sub-encoder weights
        # NOTE(review): the MLP depth is tied to blocks[2] (blocks[2] - 2 hidden layers). Confirm intended.
        self.hyper_weights = nn.Sequential(
            nn.Linear(gestalt_size * 3, hyper_channels),
            nn.SiLU(),
            *[nn.Sequential(
                nn.Linear(hyper_channels, hyper_channels),
                nn.SiLU()
            ) for _ in range(blocks[2]-2)],
            nn.Linear(hyper_channels, self.num_hyper_weights())
        )

    def _hyper_modules(self):
        """Return the hyper-modulated sub-encoders in the order their weights are laid out."""
        return (
            self.stem,
            self.position_encoder,
            self.gestalt_base_encoder,
            self.mask_gestalt_encoder,
            self.depth_gestalt_encoder,
            self.object_gestalt_encoder,
        )

    def num_hyper_weights(self):
        """Total number of hyper-network outputs needed by all sub-encoders."""
        return sum(module.num_weights() for module in self._hyper_modules())

    def _split_hyper_weights(self, hyper_weights: th.Tensor):
        """Slice the flat hyper-network output into one chunk per sub-encoder (same order as ``_hyper_modules``)."""
        chunks = []
        offset = 0
        for module in self._hyper_modules():
            size = module.num_weights()
            chunks.append(hyper_weights[:, offset:offset + size])
            offset += size
        return chunks

    def forward(
        self, 
        input_rgb: th.Tensor,
        input_depth: th.Tensor,
        error_last: th.Tensor,
        mask: th.Tensor,
        mask_raw: th.Tensor,
        object_rgb: th.Tensor,
        object_depth: th.Tensor,
        object_flow: th.Tensor,
        position: th.Tensor,
        gestalt: th.Tensor,
        slot_reset: th.Tensor = None,
        use_hyper_weights: bool = True,
    ):
        """Encode the current frame for every slot.

        Args:
            input_rgb ... object_flow: frame and per-slot tensors forwarded to ``ObjectTracker``.
            position: previous per-slot positions, shared layout ``(b, num_objects * c)``.
            gestalt: previous per-slot gestalt codes, shared layout; the last entry of
                every slot is dropped before it feeds the hyper-network.
            slot_reset: per-slot multiplier of shape ``(b, num_objects)``. It scales the
                tracker features and the hyper weights. Defaults to all ones when
                ``use_hyper_weights`` is true, otherwise to all zeros.
            use_hyper_weights: if false, every sub-encoder receives ``None`` hyper weights.

        Returns:
            ``(position, gestalt, priority)``, each in shared layout ``(b, num_objects * c)``.
            Per slot, position is ``(x, y, z, std)`` and priority is a single value.
        """
        if slot_reset is None:
            # one flag per slot: self.to_batch below splits dim 1 into num_objects groups,
            # so a single-column default would fail whenever num_objects > 1
            reset_shape = (position.shape[0], self.num_objects)
            if use_hyper_weights:
                slot_reset = position.new_ones(reset_shape)
            else:
                slot_reset = position.new_zeros(reset_shape)

        latent = self.tracker(
            input_rgb, 
            input_depth, 
            error_last, 
            mask, 
            mask_raw,
            object_rgb, 
            object_depth, 
            object_flow, 
            position,
            slot_reset
        )

        gestalt = self.to_batch(gestalt)[:,:-1] # remove z-scale
        slot_reset = self.to_batch(slot_reset)

        if use_hyper_weights:
            # zeroed for slots with slot_reset == 0
            hyper_weights = self.hyper_weights(gestalt) * slot_reset
            (
                stem_hyper_weights,
                position_hyper_weights,
                base_gestalt_hyper_weights,
                mask_gestalt_hyper_weights,
                depth_gestalt_hyper_weights,
                object_gestalt_hyper_weights,
            ) = self._split_hyper_weights(hyper_weights)
        else:
            stem_hyper_weights           = None
            position_hyper_weights       = None
            base_gestalt_hyper_weights   = None
            mask_gestalt_hyper_weights   = None
            depth_gestalt_hyper_weights  = None
            object_gestalt_hyper_weights = None

        latent         = self.stem(latent,                           stem_hyper_weights)
        latent_gestalt = self.gestalt_base_encoder(latent,           base_gestalt_hyper_weights)
        mask_gestalt   = self.mask_gestalt_encoder(latent_gestalt,   mask_gestalt_hyper_weights)
        depth_gestalt  = self.depth_gestalt_encoder(latent_gestalt,  depth_gestalt_hyper_weights)
        object_gestalt = self.object_gestalt_encoder(latent_gestalt, object_gestalt_hyper_weights)

        # the 4 position-encoder maps: xy logits, depth, std, priority
        latent   = self.position_encoder(latent, position_hyper_weights)
        xy       = self.xy_encoder(latent[:,0:1])
        z        = self.depth_encoder(latent[:,1:2])
        std      = self.std_encoder(latent[:,2:3])
        priority = self.priority_encoder(latent[:,3:4], z)
        position = th.cat((xy, z, std), dim=1)

        mask_gestalt   = self.mask_gestalt_pooling(mask_gestalt, position)
        depth_scale    = self.depth_scale_pooling(depth_gestalt, position)
        depth_gestalt  = self.depth_gestalt_pooling(depth_gestalt, position)
        object_gestalt = self.object_gestalt_pooling(object_gestalt, position)
        gestalt        = th.cat((mask_gestalt, depth_gestalt, object_gestalt, depth_scale), dim=1)

        position = self.to_shared(position)
        gestalt  = self.to_shared(gestalt)
        priority = self.to_shared(priority)

        return position, gestalt, priority

