import torch.nn as nn
import torch as th
import numpy as np
import nn as nn_modules
import torchvision.transforms as transforms
from torch.autograd import Function
import torch.nn.functional as F
from einops import rearrange, repeat, reduce

from typing import Tuple, Union, List
import utils



class Buffer:
    """Fixed-length, newest-first history of per-object feature vectors.

    The backing tensor ``self.buffer`` has shape
    ``(batch_size, num_objects, size, channels)``; index 0 along the ``size``
    axis always holds the most recently pushed entry.
    """

    def __init__(self, size: int, channels: int, num_objects: int, batch_size: int, device):
        """Allocate a zero-filled history.

        Args:
            size: number of time steps kept per object; must be positive.
            channels: feature width of a single entry.
            num_objects: number of object slots per batch element.
            batch_size: number of batch elements.
            device: device the backing tensor is moved to.

        Raises:
            ValueError: if ``size`` is not strictly positive.
        """
        # Negative sizes are just as invalid as zero (the message already says so).
        if size <= 0:
            raise ValueError("Buffer size must be greater than 0")

        self.channels = channels
        self.size = size
        self.num_objects = num_objects
        self.batch_size = batch_size
        self.device = device

        # (b, o, s, c) -> (b, o*s*c): flatten everything but the batch axis.
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b o s c -> b (o s c)'))
        # ((b o), (s c)) -> (b, o, s, c); not used by the methods of this class.
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) (s c) -> b o s c', s=self.size, o=self.num_objects))
        # (b, o*c) -> (b, o, 1, c): a single new time step ready to be concatenated.
        self.to_buffer = LambdaModule(lambda x: rearrange(x, 'b (o c) -> b o 1 c', o=self.num_objects))

        self.reset()

    def reset(self):
        """Replace the history with zeros."""
        self.buffer = th.zeros(self.batch_size, self.num_objects, self.size, self.channels).float().to(self.device)

    def update(self, input):
        """Push ``input`` (laid out as ``b (o c)``) to the front and drop the oldest entry."""
        input = self.to_buffer(input)
        self.buffer = th.cat((input, self.buffer), dim=2)[:, :, :self.size]

    def get_mean(self):
        """Return the mean over the time axis, flattened to ``b (o c)``."""
        m = th.mean(self.buffer, dim=2, keepdim=True)
        return self.to_batch(m)

    def get_buffer(self):
        """Return the whole history flattened to ``b (o s c)``."""
        return self.to_batch(self.buffer)

    def detach(self):
        """Cut the stored history out of the autograd graph."""
        self.buffer = self.buffer.detach()

class PrintShape(nn.Module):
    """Debug pass-through layer: prints the shape of its input and returns it unchanged."""

    def __init__(self, msg = ""):
        super().__init__()
        self.msg = msg

    def forward(self, input: th.Tensor):
        # The label is only printed when one was configured.
        prefix = (self.msg,) if self.msg != "" else ()
        print(*prefix, input.shape)
        return input

class PrintStats(nn.Module):
    """Debug pass-through layer: prints min / mean / max of its input and returns it unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, input: th.Tensor):
        def as_numpy(value: th.Tensor):
            # Detach and move to host so the statistic prints as a plain number.
            return value.detach().cpu().numpy()

        lowest  = as_numpy(th.min(input))
        average = as_numpy(th.mean(input))
        highest = as_numpy(th.max(input))

        print("min: ", lowest, ", mean: ", average, ", max: ", highest)
        return input

class PushToInfFunction(Function):
    """Autograd function: identity forward, constant gradient of -1 backward.

    The incoming gradient is ignored entirely; every element of the input
    receives a gradient of -1, so a plain gradient-descent step would keep
    raising the input (hence the name).
    """

    @staticmethod
    def forward(ctx, tensor):
        """Return a copy of ``tensor``."""
        # Nothing is saved on ctx: backward does not depend on the input
        # values, and saving it would only keep the tensor alive longer.
        return tensor.clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Return a tensor of -1 with the shape, dtype and device of ``grad_output``."""
        return -th.ones_like(grad_output)

class PushToInf(nn.Module):
    """Module wrapper that routes its input through ``PushToInfFunction``."""

    def __init__(self):
        super().__init__()

        # Keep the bound ``apply`` so forward is a single call.
        self.fcn = PushToInfFunction.apply

    def forward(self, input: th.Tensor):
        push = self.fcn
        return push(input)

class ForcedAlpha(nn.Module):
    """Scales its input by ``tanh(init * speed)``, where ``init`` is a learnable scalar
    that is passed through ``PushToInf`` before the tanh."""

    def __init__(self, speed = 1):
        super().__init__()

        self.init   = nn.Parameter(th.zeros(1))
        self.speed  = speed
        self.to_inf = PushToInf()

    def _gate(self):
        # Current scaling factor as a one-element tensor.
        return th.tanh(self.to_inf(self.init * self.speed))

    def item(self):
        """Return the current scaling factor as a Python float."""
        return self._gate().item()

    def forward(self, input: th.Tensor):
        """Return ``input`` multiplied by the current scaling factor."""
        return input * self._gate()

class AlphaThreshold(nn.Module):
    """Learnable scalar threshold computed as ``tanh(init) * max_value``;
    ``init`` is passed through ``PushToInf`` before the tanh."""

    def __init__(self, max_value = 1):
        super().__init__()

        self.init      = nn.Parameter(th.zeros(1))
        self.to_inf    = PushToInf()
        self.max_value = max_value

    def forward(self):
        """Return the current threshold as a one-element tensor."""
        squashed = th.tanh(self.to_inf(self.init))
        return squashed * self.max_value

class TanhAlpha(nn.Module):
    """Non-learned schedule: a counter buffer that grows by ``stepsize`` per call
    and is read out as ``tanh(counter) * max_value``."""

    def __init__(self, start = 0, stepsize = 1e-4, max_value = 1):
        super().__init__()

        self.stepsize  = stepsize
        self.max_value = max_value
        # Stored as a buffer so it follows the module across devices and checkpoints.
        self.register_buffer('init', th.zeros(1) + start)

    def get(self):
        """Return the current value as a Python float without advancing the counter."""
        squashed = th.tanh(self.init)
        return (squashed * self.max_value).item()

    def forward(self):
        """Advance the counter by one step and return the new value."""
        advanced = self.init.detach() + self.stepsize
        self.init = advanced
        return self.get()

class InitialLatentStates(nn.Module):
    """Supplies start values (position, gestalt, priority) for object slots.

    Slots that are considered "bound" keep the values passed in by the caller;
    all other slots are filled with freshly sampled gestalt codes, default
    priorities and a new position. The new position is either drawn at random
    and scored against the error map (``shuffleslots=True``) or placed at the
    maximum of the blurred error map (``shuffleslots=False``).

    The module is stateful across calls (``last_mask``, ``slots_assigned``,
    ``to_next_spawn``); call ``reset_state()`` at the start of a sequence and
    ``set_level()`` before using the non-shuffling spawn path, which needs
    ``self.to_position``.
    """

    def __init__(
            self, 
            gestalt_size: int, 
            num_objects: int, 
            bottleneck: str,
            size: Tuple[int, int],
            object_permanence_strength: int,
            teacher_forcing: int
        ):
        """
        Args:
            gestalt_size: width of one gestalt code.
            num_objects: number of object slots per batch element.
            bottleneck: ``"binar"``, ``"partial_binar"`` or anything else for
                the plain sigmoid sampling branch.
            size: (height, width) of the full-resolution maps.
            object_permanence_strength: initial strength; overwritten by the
                internal tanh schedule while that schedule is below 1.
            teacher_forcing: stored on the module; not read inside this class.
        """
        super(InitialLatentStates, self).__init__()
        self.bottleneck = bottleneck

        self.num_objects                = num_objects
        self.gestalt_mean               = nn.Parameter(th.zeros(1, gestalt_size))
        self.gestalt_std                = nn.Parameter(th.ones(1, gestalt_size))
        self.std                        = nn.Parameter(th.zeros(1))
        self.object_permanence_strength = object_permanence_strength
        self.teacher_forcing            = teacher_forcing

        self.init = TanhAlpha(start = -1)
        self.register_buffer('priority', th.arange(num_objects).float() * 25, persistent=False)
        self.register_buffer('threshold', th.ones(1) * 0.8)
        self.last_mask = None
        # Number of leading gestalt channels treated as binary in "partial_binar" mode.
        self.binarize_first = round(gestalt_size * 0.8)

        # One Gaussian renderer per resolution level (1/16, 1/4, full size).
        self.gaus2d = nn.Sequential(
            Gaus2D((size[0] // 16, size[1] // 16)),
            Gaus2D((size[0] //  4, size[1] //  4)),
            Gaus2D(size)
        )

        self.level = 1
        self.t     = 0
        # Spawn cool-down counter. Initialised here as well as in reset_state()
        # so that forward() does not fail when called before reset_state().
        self.to_next_spawn = 0

        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))

        self.blur = transforms.GaussianBlur(13)
        self.size = size

    def reset_state(self):
        """Forget the remembered mask and restart the step / spawn counters."""
        self.last_mask = None
        self.t = 0
        self.to_next_spawn = 0

    def set_level(self, level):
        """Select the resolution level and build the matching ``ErrorToPosition`` helper.

        NOTE(review): ``level == 0`` divides by zero in the factor computation below.
        """
        self.level = level
        factor = int(4 / (level ** 2))
        self.to_position = ErrorToPosition((self.size[0] //  factor, self.size[1] //  factor))

    def forward(
        self, 
        error: th.Tensor, 
        mask: th.Tensor = None, 
        position: th.Tensor = None,
        gestalt: th.Tensor = None,
        priority: th.Tensor = None,
        shuffleslots: bool = True, 
        slots_bounded_last: th.Tensor = None,
        slots_occlusionfactor_last: th.Tensor = None,
        allow_spawn: bool = True,
        clean_slots: bool = False
    ):
        """Compute start values for all slots.

        Args:
            error: error map; its first axis is the batch axis. It is multiplied
                with per-slot maps in the shuffle branch and blurred / masked in
                the other branch.
            mask: optional per-slot masks laid out ``b c h w``; the last channel
                is ignored.
            position, gestalt, priority: optional current slot codes laid out
                ``b (o c)``; kept for bound (or assigned) slots.
            shuffleslots: draw random candidate positions instead of spawning
                at the error maximum.
            slots_bounded_last: optional externally supplied bound indicator,
                combined element-wise with the ``(b o) 1`` assignment state.
                Must be given when ``allow_spawn`` is False.
            slots_occlusionfactor_last: optional occlusion factors (``b (o c)``);
                only read when ``clean_slots`` is True.
            allow_spawn: when False, assignment is restricted to
                ``slots_bounded_last`` and no new slot is placed.
            clean_slots: un-assign slots whose occlusion factor exceeds 0.1.

        Returns:
            Tuple ``(position, gestalt, priority, error)``; the first three are
            laid out ``b (o c)``, ``error`` is the (possibly blurred and
            masked) error map.
        """

        batch_size = error.shape[0]
        device     = error.device

        # While the tanh schedule is below 1 it overrides the configured
        # strength; calling it also advances the schedule by one step.
        if self.init.get() < 1:
            self.object_permanence_strength = self.init()

        if self.last_mask is None:
            self.last_mask = th.zeros((batch_size * self.num_objects, 1), device = device)
            if shuffleslots:
                self.slots_assigned = th.ones((batch_size * self.num_objects, 1), device = device)
            else:
                self.slots_assigned = th.zeros((batch_size * self.num_objects, 1), device = device)

        if not allow_spawn:
            unassigned = self.slots_assigned - slots_bounded_last
            self.slots_assigned = self.slots_assigned - unassigned

        if clean_slots and (slots_occlusionfactor_last is not None):
            occluded = self.slots_assigned * (self.to_batch(slots_occlusionfactor_last) > 0.1).float()
            self.slots_assigned = self.slots_assigned - occluded

        if (slots_bounded_last is None) or (self.object_permanence_strength < 1):

            if mask is not None:
                # Per-slot maximum of the mask decides whether a slot is bound (c == o).
                mask2 = reduce(mask[:,:-1], 'b c h w -> (b c) 1' , 'max').detach()

                if self.object_permanence_strength <= 0:
                    self.last_mask = mask2
                elif self.object_permanence_strength < 1:
                    self.last_mask = th.maximum(self.last_mask, mask2)
                    # Let the remembered mask decay for slots currently below the threshold.
                    self.last_mask = self.last_mask - th.relu(-1 * (mask2 - self.threshold) * (1 - self.object_permanence_strength))
                else:
                    self.last_mask = th.maximum(self.last_mask, mask2)
    
            slots_bounded = (self.last_mask > self.threshold).float().detach() * self.slots_assigned
        else:
            slots_bounded = slots_bounded_last * self.slots_assigned

        if self.bottleneck == "binar":
            gestalt_new = repeat(th.sigmoid(self.gestalt_mean), '1 c -> b c', b = batch_size * self.num_objects)
            gestalt_new = gestalt_new + gestalt_new * (1 - gestalt_new) * th.randn_like(gestalt_new)
        elif self.bottleneck == "partial_binar":
            bin_part = self.gestalt_mean[:, :self.binarize_first]
            gestalt_new_bin = repeat(th.sigmoid(bin_part), '1 c -> b c', b = batch_size * self.num_objects)
            gestalt_new_bin = gestalt_new_bin + gestalt_new_bin * (1 - gestalt_new_bin) * th.randn_like(gestalt_new_bin)

            gestalt_mean = repeat(self.gestalt_mean[:, self.binarize_first:], '1 c -> b c', b = batch_size * self.num_objects)
            # The noise scale comes from gestalt_std (the mean was used here by mistake).
            gestalt_std  = repeat(self.gestalt_std[:, self.binarize_first:], '1 c -> b c', b = batch_size * self.num_objects)
            gestalt_new_flex  = th.sigmoid(gestalt_mean + gestalt_std * th.randn_like(gestalt_std))

            gestalt_new = th.cat([gestalt_new_bin, gestalt_new_flex], dim = 1)
        else:
            gestalt_mean = repeat(self.gestalt_mean, '1 c -> b c', b = batch_size * self.num_objects)
            gestalt_std  = repeat(self.gestalt_std,  '1 c -> b c', b = batch_size * self.num_objects)
            gestalt_new  = th.sigmoid(gestalt_mean + gestalt_std * th.randn_like(gestalt_std))

        if gestalt is None:
            gestalt = gestalt_new
        else:
            gestalt = self.to_batch(gestalt) * slots_bounded + gestalt_new * (1 - slots_bounded)

        if priority is None:
            priority = repeat(self.priority, 'o -> (b o) 1', b = batch_size)
        else:
            priority = self.to_batch(priority) * slots_bounded + repeat(self.priority, 'o -> (b o) 1', b = batch_size) * (1 - slots_bounded)


        if shuffleslots:
            self.slots_assigned = th.ones_like(self.slots_assigned)

            # Draw 10 random candidate positions per slot and keep, for every
            # slot, the candidate whose Gaussian blob overlaps most error.
            xy_rand_new  = th.rand((batch_size * self.num_objects * 10, 2), device = device) * 2 - 1 
            std_new      = th.zeros((batch_size * self.num_objects * 10, 1), device = device)
            position_new = th.cat((xy_rand_new, std_new), dim=1) 

            position2d = self.gaus2d[self.level](position_new)
            position2d = rearrange(position2d, '(b o) 1 h w -> b o h w', b = batch_size)

            rand_error = reduce(position2d * error, 'b o h w -> (b o) 1', 'sum')

            xy_rand_new = rearrange(xy_rand_new, '(b r) c -> r b c', r = 10)
            rand_error  = rearrange(rand_error,  '(b r) c -> r b c', r = 10)

            max_error = th.argmax(rand_error, dim=0, keepdim=True)
            x, y = th.chunk(xy_rand_new, 2, dim=2)
            x = th.gather(x, dim=0, index=max_error).detach().squeeze(dim=0)
            y = th.gather(y, dim=0, index=max_error).detach().squeeze(dim=0)
            std  = repeat(self.std, '1 -> (b o) 1', b = batch_size, o=self.num_objects)

            if position is None:
                position = th.cat((x, y, std), dim=1) 
            else:
                position = self.to_batch(position) * slots_bounded + th.cat((x, y, std), dim=1) * (1 - slots_bounded)

        else:

            # set unassigned slots to empty position
            # NOTE(review): this tensor is integer-typed; when `position` is None
            # and nothing is spawned it is returned as-is - confirm callers cope.
            empty_position = th.tensor([-1,-1,0]).to(device)
            empty_position = repeat(empty_position, 'c -> (b o) c', b = batch_size, o=self.num_objects).detach()

            if position is None:
                position = empty_position
            else:
                position = self.to_batch(position) * self.slots_assigned + empty_position * (1 - self.slots_assigned)


            # blur error, and set masked areas to zero
            error = self.blur(error)
            if mask is not None:
                mask2 = mask[:,:-1] * rearrange(slots_bounded, '(b o) 1 -> b o 1 1', b = batch_size)
                mask2 = th.sum(mask2, dim=1, keepdim=True)
                error = error * (1-mask2)
            max_error = reduce(error, 'b o h w -> (b o) 1', 'max')

            if self.to_next_spawn <= 0 and allow_spawn:

                # Together with the decrement at the end of forward() this
                # allows a spawn attempt on every second call at most.
                self.to_next_spawn = 2

                # calculate the position with the highest error
                new_pos = self.to_position(error)
                std  = repeat(self.std, '1 -> b 1', b = batch_size)
                new_pos = repeat(th.cat((new_pos, std), dim=1), 'b c -> (b o) c', o = self.num_objects)
                
                #  calculate if an assigned slot is unbound (-->free)
                n_slots_assigned = self.to_shared(self.slots_assigned).sum(dim=1, keepdim=True)
                n_slots_bounded = self.to_shared(slots_bounded).sum(dim=1, keepdim=True)
                free_slot_given = th.clip(n_slots_assigned - n_slots_bounded, 0, 1)

                # either spawn a new slot or use the one that is free
                slots_new_index = n_slots_assigned * (1-free_slot_given) + n_slots_bounded * free_slot_given # reset the free slot each timespawn

                # new slot index (the extra class absorbs "all slots in use" and is cut off)
                free_slot_required = (max_error > 0).float()
                slots_new_index = F.one_hot(slots_new_index.long(), num_classes=self.num_objects+1).float().squeeze(dim=1)[:,:-1]
                slots_new_index = self.to_batch(slots_new_index * free_slot_required)

                # place new free slot
                position = new_pos * slots_new_index + position * (1 - slots_new_index)
                self.slots_assigned = th.clip(self.slots_assigned + slots_new_index, 0, 1)

        self.to_next_spawn -= 1
        return self.to_shared(position), self.to_shared(gestalt), self.to_shared(priority), error

    def get_slots_unassigned(self):
        """Return ``1 - slots_assigned`` laid out as ``b (o c)``."""
        return self.to_shared(1-self.slots_assigned)
    
    def get_slots_assigned(self):
        """Return the assignment indicator laid out as ``b (o c)``."""
        return self.to_shared(self.slots_assigned)
    
class ErrorToPosition(nn.Module):
    """Maps a single-channel map to the normalised (x, y) coordinate of its maximum.

    Coordinates run from -1 to 1 along each axis; ``grid_x`` follows the first
    spatial axis and ``grid_y`` the second, both flattened to ``(1, 1, h*w)``.
    """

    def __init__(self, size: Union[int, Tuple[int, int]]):
        super().__init__()

        self.size = size
        height, width = size[0], size[1]

        # Register first so the lookup tables stay non-persistent buffers;
        # the assignments below replace the buffer contents.
        self.register_buffer("grid_x", th.arange(height), persistent=False)
        self.register_buffer("grid_y", th.arange(width), persistent=False)

        coords_x = (self.grid_x / (height - 1)) * 2 - 1
        coords_y = (self.grid_y / (width - 1)) * 2 - 1

        # Broadcast each axis over the full map, then flatten the spatial axes.
        self.grid_x = coords_x.view(1, 1, -1, 1).expand(1, 1, *size).clone().view(1, 1, -1)
        self.grid_y = coords_y.view(1, 1, 1, -1).expand(1, 1, *size).clone().view(1, 1, -1)

    def forward(self, input: th.Tensor):
        """Return a ``(b, 2)`` tensor with the coordinates of the per-sample maximum."""
        assert input.shape[1] == 1

        flat = rearrange(input, 'b c h w -> b c (h w)')
        best = th.argmax(flat, dim=2, keepdim=True)

        x = self.grid_x[0, 0, best].squeeze(dim=2)
        y = self.grid_y[0, 0, best].squeeze(dim=2)

        return th.cat((x, y), dim=1)

class OcclusionTracker(nn.Module):
     """Keeps per-slot bookkeeping of which slots have been bound, are occluded or are fully visible.

     All state lives in plain tensor attributes laid out as ``(b o) 1``; results
     are returned reshaped to ``b (o c)`` through ``self.to_shared``.

     NOTE(review): ``slots_bounded_smooth_all`` and ``slots_fully_visible_all``
     are only created inside ``forward`` (first with ``reset_mask=True``), and
     ``slots_bounded_next_last`` starts as ``None`` - so the first call appears
     to require ``reset_mask=True`` (or ``update=False``); confirm with callers.
     """
     def __init__(self, batch_size, num_objects, device):
         """Store the sizes and allocate the zero-initialised ``slots_bounded_all`` state on ``device``."""
         super(OcclusionTracker, self).__init__()
         self.batch_size = batch_size
         self.num_objects = num_objects
         self.slots_bounded_all = th.zeros((batch_size * num_objects, 1)).to(device)
         # A slot counts as bound / a pixel as covered above this mask value.
         self.threshold = 0.8
         self.device = device
         self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))
         self.slots_bounded_next_last = None

     def forward(
         self, 
         mask: th.Tensor = None, 
         maskraw: th.Tensor = None,
         reset_mask: bool = False,
         update: bool = True
     ):
        """Update the tracking state from the current masks and return the slot indicators.

        Args:
            mask: per-slot masks laid out ``b o h w``; the last channel is dropped.
            maskraw: second set of masks with the same layout, thresholded the
                same way; ``maskraw - mask`` is counted as the hidden area.
                # presumably the masks before occlusion is applied - confirm
            reset_mask: overwrite the accumulated state with the current values.
            update: merge the current values into the accumulated state
                (element-wise maximum); when False only ``slots_bounded_next_last``
                is refreshed.

        Returns:
            Six tensors, each reshaped with ``to_shared``: accumulated bound
            slots, accumulated smooth (un-thresholded) bound values, currently
            occluded slots, partially occluded slots, fully visible slots and
            the occlusion factor.

        NOTE(review): when ``mask`` is None the return statement references
        locals that are only assigned inside the ``if`` block below.
        """

        if mask is not None:

            # compute bounding mask
            slots_bounded_smooth_cur = reduce(mask[:,:-1], 'b o h w -> (b o) 1' , 'max').detach()
            slots_bounded_cur = (slots_bounded_smooth_cur > self.threshold).float().detach()
            if reset_mask:
                self.slots_bounded_next_last = slots_bounded_cur # allow immediate spawn
        
            if update:
                # Only keep slots that were bound before or in the last non-update call.
                slots_bounded_cur = slots_bounded_cur * th.clip(self.slots_bounded_next_last + self.slots_bounded_all, 0, 1)
            else:
                self.slots_bounded_next_last = slots_bounded_cur
        
            if reset_mask:
                self.slots_bounded_smooth_all = slots_bounded_smooth_cur
                self.slots_bounded_all = slots_bounded_cur
            elif update:
                self.slots_bounded_all = th.maximum(self.slots_bounded_all, slots_bounded_cur)
                self.slots_bounded_smooth_all = th.maximum(self.slots_bounded_smooth_all, slots_bounded_smooth_cur)

            # compute occlusion mask
            # (bound at some point, but not bound right now)
            slots_occluded_cur = self.slots_bounded_all - slots_bounded_cur

            # compute partially occluded mask
            mask = (mask[:,:-1] > self.threshold).float().detach()
            maskraw = (maskraw[:,:-1] > self.threshold).float().detach()
            masked = maskraw - mask

            # Pixel counts per slot.
            masked = reduce(masked, 'b o h w -> (b o) 1' , 'sum')
            maskraw = reduce(maskraw, 'b o h w -> (b o) 1' , 'sum')

            # Hidden fraction (the +1 avoids division by zero); occluded slots are forced to 1.
            slots_occlusionfactor_cur = (masked / (maskraw + 1)) * (1-slots_occluded_cur) + slots_occluded_cur
            slots_partially_occluded = (slots_occlusionfactor_cur > 0.1).float() #* slots_bounded_cur
            slots_fully_visible = (slots_occlusionfactor_cur <= 0.1).float() * slots_bounded_cur

            if reset_mask:
                self.slots_fully_visible_all = slots_fully_visible
            elif update:
                self.slots_fully_visible_all = th.maximum(self.slots_fully_visible_all, slots_fully_visible)

        return self.to_shared(self.slots_bounded_all), self.to_shared(self.slots_bounded_smooth_all), self.to_shared(slots_occluded_cur), self.to_shared(slots_partially_occluded), self.to_shared(slots_fully_visible), self.to_shared(slots_occlusionfactor_cur)

     def get_slots_fully_visible_all(self):
         """Return the accumulated fully-visible indicator laid out as ``b (o c)``."""
         return self.to_shared(self.slots_fully_visible_all)


class UpdateGate(nn.Module):
    """Per-slot switch between the previous and the current value of a tensor."""

    def __init__(self, num_objects):
        super().__init__()
        # b (o c) <-> (b o) c conversions for the slot-wise blend.
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))

    def forward(
        self, 
        tensor_cur: th.Tensor = None,
        tensor_last: th.Tensor = None,
        slot_mask: th.Tensor = None
    ):
        """Blend ``tensor_last`` (weight ``slot_mask``) with ``tensor_cur`` (weight ``1 - slot_mask``).

        ``slot_mask`` is laid out ``b o``; both tensors are laid out ``b (o c)``,
        and so is the result.
        """
        keep_last = rearrange(slot_mask, 'b o -> (b o) 1')
        previous  = self.to_batch(tensor_last)
        current   = self.to_batch(tensor_cur)

        blended = keep_last * previous + (1 - keep_last) * current
        return self.to_shared(blended)

class Gaus2D(nn.Module):
    """Renders ``(x, y, std)`` triples as axis-aligned 2D Gaussian blobs on a fixed grid.

    Grid coordinates run from -1 to 1; ``grid_x`` varies along the first
    spatial axis and ``grid_y`` along the second.
    """

    def __init__(self, size: Tuple[int, int]):
        super().__init__()

        self.size = size
        height, width = size[0], size[1]

        # Registered first so the coordinate maps stay non-persistent buffers.
        self.register_buffer("grid_x", th.arange(height), persistent=False)
        self.register_buffer("grid_y", th.arange(width), persistent=False)

        coords_x = (self.grid_x / (height - 1)) * 2 - 1
        coords_y = (self.grid_y / (width - 1)) * 2 - 1

        self.grid_x = coords_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
        self.grid_y = coords_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()

    def forward(self, input: th.Tensor):
        """Return a ``(b, 1, h, w)`` map for ``input`` laid out as ``(b, >=3)`` = x, y, std."""

        x   = th.clip(rearrange(input[:,0:1], 'b c -> b c 1 1'), -1, 1)
        y   = th.clip(rearrange(input[:,1:2], 'b c -> b c 1 1'), -1, 1)
        std = th.clip(rearrange(input[:,2:3], 'b c -> b c 1 1'), 0, 1)

        # Width of the blob per axis; the +1 keeps it positive for std == 0.
        longest_side = max(self.size)
        std_x = (1 + longest_side * std) / self.size[0]
        std_y = (1 + longest_side * std) / self.size[1]

        exponent = (self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)
        return th.exp(-1 * exponent)

class Vector2D(nn.Module):
    """Draws a point, and optionally a line starting at it, into a 3-channel map.

    The start point is rendered into all three channels; the points sampled
    along ``vector`` are rendered into the last channel only.
    """

    def __init__(self, size: Tuple[int, int]):
        super().__init__()

        self.size = size
        height, width = size[0], size[1]

        # Registered first so the coordinate maps stay non-persistent buffers.
        self.register_buffer("grid_x", th.arange(height), persistent=False)
        self.register_buffer("grid_y", th.arange(width), persistent=False)

        coords_x = (self.grid_x / (height - 1)) * 2 - 1
        coords_y = (self.grid_y / (width - 1)) * 2 - 1

        # Same coordinates replicated over three channels.
        self.grid_x = coords_x.view(1, 1, -1, 1).expand(1, 3, *size).clone()
        self.grid_y = coords_y.view(1, 1, 1, -1).expand(1, 3, *size).clone()

    def forward(self, input: th.Tensor, vector: th.Tensor = None):
        """Render ``input`` (x, y in columns 0 and 1) and, if given, the offset ``vector``."""

        x = th.clip(rearrange(input[:,0:1], 'b c -> b c 1 1'), -1, 1)
        y = th.clip(rearrange(input[:,1:2], 'b c -> b c 1 1'), -1, 1)
        std = 0.01

        longest_side = max(self.size)
        std_x = (1 + longest_side * std) / self.size[0]
        std_y = (1 + longest_side * std) / self.size[1]
        grid = th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)))

        if vector is None:
            return grid

        x_vec = rearrange(vector[:,0:1], 'b c -> b c 1 1')
        y_vec = rearrange(vector[:,1:2], 'b c -> b c 1 1')

        # interpolating between start and end point
        for length in np.linspace(0, 1, 11):
            x_end = th.clip(x + x_vec * length, -1, 1)
            y_end = th.clip(y + y_vec * length, -1, 1)

            # Narrower blob than the start point (0.5 instead of 2 in the denominator).
            grid_point = th.exp(-1 * ((self.grid_x - x_end)**2/(0.5 * std_x**2) + (self.grid_y - y_end)**2/(0.5 * std_y**2)))
            # Line samples only show up in the last channel.
            grid_point[:, 0:2, :, :] = 0
            grid = th.max(grid, grid_point)

        return grid

class SharedObjectsToBatch(nn.Module):
    """Moves the object axis from the channel dimension into the batch dimension."""

    def __init__(self, num_objects):
        super().__init__()

        self.num_objects = num_objects

    def forward(self, input: th.Tensor):
        """Reshape ``b (o c) h w`` into ``(b o) c h w``."""
        n_objects = self.num_objects
        return rearrange(input, 'b (o c) h w -> (b o) c h w', o=n_objects)

class BatchToSharedObjects(nn.Module):
    """Moves the object axis from the batch dimension back into the channel dimension."""

    def __init__(self, num_objects):
        super().__init__()

        self.num_objects = num_objects

    def forward(self, input: th.Tensor):
        """Reshape ``(b o) c h w`` into ``b (o c) h w``."""
        n_objects = self.num_objects
        return rearrange(input, '(b o) c h w -> b (o c) h w', o=n_objects)

class LambdaModule(nn.Module):
    """Wraps a plain Python function so it can be used as an ``nn.Module`` layer."""

    def __init__(self, lambd):
        from types import LambdaType

        super().__init__()
        # LambdaType is the ordinary function type, so `def` functions pass as well.
        assert isinstance(lambd, LambdaType)
        self.lambd = lambd

    def forward(self, *x):
        """Call the wrapped function with all positional arguments."""
        wrapped = self.lambd
        return wrapped(*x)

class PrintGradientFunction(Function):
    """Autograd function: identity in both directions, printing gradient statistics on backward."""

    @staticmethod
    def forward(ctx, tensor, msg):
        # Remember the label for the backward pass.
        ctx.msg = msg
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        passed_on = grad_output.clone()

        mean   = th.mean(grad_output).item()
        spread = th.std(grad_output).item()
        print(f"{ctx.msg}: {mean} +- {spread}")

        # No gradient for the (non-tensor) message argument.
        return passed_on, None

class PrintGradient(nn.Module):
    """Pass-through layer that prints mean and std of the gradient flowing back through it."""

    def __init__(self, msg = "PrintGradient"):
        super().__init__()

        self.msg = msg
        self.fcn = PrintGradientFunction.apply

    def forward(self, input: th.Tensor):
        report = self.fcn
        return report(input, self.msg)

class Prioritize(nn.Module):
    """Suppresses each object's map by the maps of the other objects, weighted by relative priority."""

    def __init__(self, num_objects):
        super().__init__()

        self.num_objects = num_objects
        self.to_batch    = SharedObjectsToBatch(num_objects)

    def forward(self, input: th.Tensor, priority: th.Tensor):
        """Return ``relu(input - weighted sum of the other objects' maps)``.

        ``input`` is laid out ``b c h w`` with one channel per object and
        ``priority`` is indexed as ``(b, o)``. With ``priority=None`` the input
        is returned untouched.
        """

        if priority is None:
            return input

        n_slots    = self.num_objects
        batch_size = input.shape[0]
        weights    = th.zeros((batch_size, n_slots, n_slots, 1, 1), device=input.device)

        for slot in range(n_slots):
            own_priority = priority[:,slot:slot+1]
            # Row `slot`: how strongly every other object suppresses this one.
            weights[:,slot,:,0,0] = th.sigmoid(priority[:,:] - own_priority)
            # An object never suppresses itself.
            weights[:,slot,slot,0,0] = weights[:,slot,slot,0,0] * 0

        # One grouped convolution applies each batch element's own 1x1 weights.
        stacked = rearrange(input, 'b c h w -> 1 (b c) h w')
        kernels = rearrange(weights, 'b o i 1 1 -> (b o) i 1 1')

        suppression = nn.functional.conv2d(stacked, kernels, groups=batch_size)
        output = th.relu(stacked - suppression)

        return rearrange(output, '1 (b c) h w -> b c h w ', b=batch_size)


class MultiArgSequential(nn.Sequential):
    """``nn.Sequential`` variant whose layers may take and return several values.

    A layer's result is passed to the next layer as a single argument when it
    is a tensor or ``None``, and unpacked into positional arguments otherwise
    (e.g. when a layer returns a tuple).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *tensor):
        """Run all layers in order and return the last layer's result.

        With no layers, the tuple of inputs is returned unchanged.
        """
        for module in self:
            # Identity check: `== None` would defer to the operand's __eq__.
            if isinstance(tensor, th.Tensor) or tensor is None:
                tensor = module(tensor)
            else:
                tensor = module(*tensor)

        return tensor

class SobelSecondOrder(nn.Module):
    """Applies three fixed 5x5 second-order derivative kernels to every channel.

    * ``kernel_x``: second difference ``[1, 0, -2, 0, 1]`` along the last
      axis, smoothed with ``[1, 4, 6, 4, 1]`` along the other one.
    * ``kernel_y``: the transpose of ``kernel_x``.
    * ``kernel_xy``: outer product of ``[1, 2, 0, -2, -1]`` with itself
      (mixed derivative).
    """

    def __init__(self):
        super(SobelSecondOrder, self).__init__()
        
        # Rows follow the binomial smoothing weights 1, 4, 6, 4, 1
        # (the last two rows were swapped before).
        self.register_buffer('kernel_x', th.tensor([[[
            [1, 0,  -2, 0, 1],
            [4, 0,  -8, 0, 4],
            [6, 0, -12, 0, 6],
            [4, 0,  -8, 0, 4],
            [1, 0,  -2, 0, 1]
        ]]]).float())

        self.register_buffer('kernel_y', th.tensor([[[
            [ 1,  4,   6,  4,  1],
            [ 0,  0,   0,  0,  0],
            [-2, -8, -12, -8, -2],
            [ 0,  0,   0,  0,  0],
            [ 1,  4,   6,  4,  1]
        ]]]).float())

        self.register_buffer('kernel_xy', th.tensor([[[
            [ 1,  2,   0, -2, -1],
            [ 2,  4,   0, -4, -2],
            [ 0,  0,   0,  0,  0],
            [-2, -4,   0,  4,  2],
            [-1, -2,   0,  2,  1]
        ]]]).float())


    def forward(self, input: th.Tensor):
        """Return the three filter responses ``(x, y, xy)``, each shaped like ``input`` (``b c h w``)."""
        c = input.shape[1]
        # Treat every channel as its own single-channel image.
        input = rearrange(input, 'b c h w -> (b c) 1 h w')
        # Replicate-pad by 2 so the 5x5 kernels keep the spatial size.
        input = th.nn.functional.pad(input, (2, 2, 2, 2), mode='replicate')
        
        x  = th.nn.functional.conv2d(input, self.kernel_x)
        y  = th.nn.functional.conv2d(input, self.kernel_y)
        xy = th.nn.functional.conv2d(input, self.kernel_xy)

        x  = rearrange(x,  '(b c) 1 h w -> b c h w', c = c)
        y  = rearrange(y,  '(b c) 1 h w -> b c h w', c = c)
        xy = rearrange(xy, '(b c) 1 h w -> b c h w', c = c)

        return x, y, xy

class GausBlur(nn.Module):
    """Blurs every channel with a fixed, normalised ``size`` x ``size`` Gaussian kernel."""

    def __init__(self, size=5):
        super().__init__()

        self.size = size

        # Kernel support spans -3 .. 3 (in units of the Gaussian's std) along each axis.
        axis = ((th.arange(end=size) / (size - 1)) * 2 - 1) * 3
        grid_x = axis.view(1, 1, -1, 1).expand(1, 1, size, size)
        grid_y = axis.view(1, 1, 1, -1).expand(1, 1, size, size)

        unnormalised = th.exp(grid_x**2/-2 + grid_y**2/-2)
        self.register_buffer("kernel", unnormalised)
        # Normalise so the weights sum to one.
        self.kernel = self.kernel / th.sum(self.kernel)

    def forward(self, input: th.Tensor):
        """Return the blurred ``b c h w`` tensor (replicate padding of ``size // 2`` per side)."""
        c = input.shape[1]
        border = self.size // 2

        per_channel = rearrange(input, 'b c h w -> (b c) 1 h w')
        per_channel = th.nn.functional.pad(per_channel, (border,)*4, mode='replicate')
        blurred = nn.functional.conv2d(per_channel, self.kernel)

        return rearrange(blurred, '(b c) 1 h w -> b c h w', c = c)


def create_grid(size):
    """Build a ``(1, 2, size[0], size[1])`` coordinate grid with values in [-1, 1].

    Channel 0 varies along the last axis, channel 1 along the second-to-last axis.
    """
    def normalised_axis(length):
        # length evenly spaced values from -1 to 1.
        return (th.arange(length) / (length - 1)) * 2 - 1

    grid_x = normalised_axis(size[0]).view(1, 1, -1, 1).expand(1, 1, *size).clone()
    grid_y = normalised_axis(size[1]).view(1, 1, 1, -1).expand(1, 1, *size).clone()

    return th.cat((grid_y, grid_x), dim=1)

class Warp(nn.Module):
    """Warps an image with a sampling grid, using a replicated border to soften edge effects.

    The image is replicate-padded, the flow is resized to ``size`` and
    re-expressed on the padded grid, the padded image is sampled with
    ``grid_sample`` and the border is cropped away again.
    """

    def __init__(self, size, padding = 0.1):
        """
        Args:
            size: (height, width) of the images to warp.
            padding: border width as a fraction of the longer image side;
                the pixel count is truncated to an int and may be 0.
        """
        super(Warp, self).__init__()

        padding = int(max(size) * padding)
        padded_size = (size[0] + 2 * padding, size[1] + 2 * padding)

        self.register_buffer('grid', create_grid(size))
        self.register_buffer('padded_grid', create_grid(padded_size))

        self.replication_pad = nn.ReplicationPad2d(padding)
        self.interpolate = nn.Sequential(
            # Resize the flow to the image resolution.
            LambdaModule(lambda x:
                th.nn.functional.interpolate(x, size=size, mode='bicubic', align_corners = True)
            ),
            # Absolute sampling positions -> offsets relative to the identity grid.
            LambdaModule(lambda x: x - self.grid),
            # Zero offset in the border region ...
            nn.ConstantPad2d(padding, 0),
            # ... and back to absolute positions on the padded grid.
            LambdaModule(lambda x: x + self.padded_grid),
            # grid_sample expects the coordinate pair in the last dimension.
            LambdaModule(lambda x: rearrange(x, 'b c h w -> b h w c'))
        )

        self.warp = LambdaModule(lambda input, flow:
            th.nn.functional.grid_sample(input, flow, mode='bicubic', align_corners=True)
        )

        # Explicit end indices: `padding:-padding` would yield an empty tensor
        # when the computed padding is 0.
        self.un_pad = LambdaModule(lambda x:
            x[:, :, padding:x.shape[2] - padding, padding:x.shape[3] - padding]
        )
    
    def get_raw_flow(self, flow):
        """Return ``flow`` as offsets relative to the identity grid."""
        return flow - self.grid

    def forward(self, input, flow):
        """Warp ``input`` (``b c h w``) with ``flow`` (``b 2 h' w'``) and return a map of the original size."""
        input = self.replication_pad(input)
        flow  = self.interpolate(flow)
        return self.un_pad(self.warp(input, flow))

class Binarize(nn.Module):
    """Sigmoid bottleneck: hard rounding in eval mode, noisy sigmoid in training mode."""

    def __init__(self):
        super().__init__()

    def forward(self, input: th.Tensor):
        probs = th.sigmoid(input)

        if self.training:
            # Noise is scaled by p * (1 - p), so it vanishes as p approaches 0 or 1.
            return probs + probs * (1 - probs) * th.randn_like(probs)

        return th.round(probs)

class PartialBinarize(nn.Module):
    """Sigmoid bottleneck that binarises only the leading 80% of the channels.

    The first ``round(gestalt_size * 0.8)`` channels (axis 1) are rounded in
    eval mode and perturbed with noise in training mode; the remaining
    channels are passed on as plain sigmoid outputs.
    """

    def __init__(self, gestalt_size: int):
        # Initialise nn.Module before assigning any attribute on the instance.
        super(PartialBinarize, self).__init__()
        self.binarize_first = round(gestalt_size * 0.8)
        print(f'Partial Binarize: binarize first {self.binarize_first} channels of {gestalt_size} channels')

    def forward(self, input: th.Tensor):
        """Return the partially binarised sigmoid of ``input`` with unchanged shape."""
        input = th.sigmoid(input)
        bin_part = input[:, :self.binarize_first]
        flex_part = input[:, self.binarize_first:]

        if not self.training:
            bin_part = th.round(bin_part)
        else:
            # Noise is scaled by p * (1 - p), so it vanishes as p approaches 0 or 1.
            bin_part = bin_part + bin_part * (1 - bin_part) * th.randn_like(bin_part)
            
        return th.cat((bin_part, flex_part), dim=1)

class Shape2D(nn.Module):
    """Renders a position as a Gaussian blob on a grid that is first deformed by the gestalt code.

    ``shape`` maps the centred grid coordinates to 32 candidate 2D offsets,
    ``selection`` turns the gestalt code into softmax weights over those
    candidates, and their weighted sum is added to the grid before the
    Gaussian is evaluated.
    """

    def __init__(self, gestalt_size, size: Tuple[int, int]):
        super().__init__()

        self.size = size
        height, width = size[0], size[1]

        # Registered first so the coordinate maps stay non-persistent buffers.
        self.register_buffer("grid_x", th.arange(height), persistent=False)
        self.register_buffer("grid_y", th.arange(width), persistent=False)

        coords_x = (self.grid_x / (height - 1)) * 2 - 1
        coords_y = (self.grid_y / (width - 1)) * 2 - 1

        self.grid_x = coords_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
        self.grid_y = coords_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()
        
        # Per-pixel MLP (1x1 convolutions) producing 32 candidate (dx, dy) offsets.
        self.shape = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size = 1),
            nn.Tanh(),
            nn.Conv2d(32, 32, kernel_size = 1),
            nn.Tanh(),
            nn.Conv2d(32, 64, kernel_size = 1),
            nn.Tanh(),
            LambdaModule(lambda x: rearrange(x, 'b (c n) h w -> b c n h w', n = 2)),
        )

        # Gestalt code -> softmax weights over the 32 candidates (per offset component).
        self.selection = nn.Sequential(
            nn.Conv2d(gestalt_size, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=1),
            LambdaModule(lambda x: rearrange(x, 'b (c n) 1 1 -> b c n 1 1', n = 2)),
            nn.Softmax(dim=1)
        )

    def forward(self, position: th.Tensor, gestalt: th.Tensor):
        """Return a ``(b, 1, h, w)`` map for ``position`` (x, y, std columns) and ``gestalt``.

        The rearrange in ``selection`` requires ``gestalt`` to have spatial size 1x1.
        """

        x   = th.clip(rearrange(position[:,0:1], 'b c -> b c 1 1'), -1, 1)
        y   = th.clip(rearrange(position[:,1:2], 'b c -> b c 1 1'), -1, 1)
        std = th.clip(rearrange(position[:,2:3], 'b c -> b c 1 1'), 0, 1)

        longest_side = max(self.size)
        std_x = (1 + longest_side * std) / self.size[0]
        std_y = (1 + longest_side * std) / self.size[1]

        # Grid coordinates relative to the object centre.
        centred = th.cat((self.grid_x - x, self.grid_y - y), dim=1)

        candidates = self.shape(centred) * self.selection(gestalt)
        shape = reduce(candidates, 'b c n h w -> b n h w', 'sum')
        print(f'Shape: {th.mean(shape).item():.2e} +- {th.std(shape).item():.2e}')

        deformed = centred + shape
        grid_x, grid_y = th.chunk(deformed, 2, dim=1)

        return th.exp(-1 * (grid_x**2/(2 * std_x**2) + grid_y**2/(2 * std_y**2)))
