import torch.nn as nn
import torch as th
import numpy as np
from torch.autograd import Function
from einops import rearrange, repeat, reduce
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms import Resize, InterpolationMode


from typing import Tuple, Union, List
import utils
import cv2


class Permute(nn.Module):
    """Module wrapper that reorders the dimensions of its input.

    The target order may be given either as separate integers
    (``Permute(0, 2, 1)``) or as one tuple (``Permute((0, 2, 1))``).
    """

    def __init__(self, *args):
        super(Permute, self).__init__()
        # Stored exactly as passed; unpacked again in forward().
        self.args = args

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return ``input`` with its dimensions permuted as configured."""
        # Tensor.permute accepts both the var-args and the single-tuple
        # spelling, whereas th.permute(input, *dims) only accepted the latter.
        return input.permute(*self.args)

class PrintShape(nn.Module):
    """Pass-through debugging layer that prints the shape of its input."""

    def __init__(self, msg = ""):
        super().__init__()
        # Optional label printed in front of the shape.
        self.msg = msg

    def forward(self, input: th.Tensor):
        """Print ``input.shape`` (prefixed by the label, if any) and return ``input``."""
        shape = input.shape
        if self.msg == "":
            print(shape)
        else:
            print(self.msg, shape)
        return input

class PrintStats(nn.Module):
    """Pass-through debugging layer that prints min / mean / max of its input."""

    def __init__(self):
        super().__init__()

    def forward(self, input: th.Tensor):
        """Print the global minimum, mean and maximum of ``input`` and return it unchanged."""
        # Each statistic is detached and moved to the CPU as a NumPy scalar
        # before printing.
        lowest, average, highest = (
            stat(input).detach().cpu().numpy()
            for stat in (th.min, th.mean, th.max)
        )
        print("min: ", lowest, ", mean: ", average, ", max: ", highest)
        return input

class PushToInfFunction(Function):
    """Identity in the forward pass, constant ``-1`` gradient in the backward pass.

    The incoming gradient is ignored entirely: whatever loss is attached
    downstream, the input always receives a gradient of ``-1`` per element.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Nothing is saved on ctx: the backward pass does not depend on the
        # input values, so keeping a reference would only hold memory.
        return tensor.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Only shape/dtype/device of grad_output are used.
        return -th.ones_like(grad_output)

class PushToInf(nn.Module):
    """Module wrapper around ``PushToInfFunction``."""

    def __init__(self):
        super().__init__()
        # Bound once so forward() is a plain call.
        self.fcn = PushToInfFunction.apply

    def forward(self, input: th.Tensor):
        """Apply ``PushToInfFunction`` to ``input`` and return the result."""
        result = self.fcn(input)
        return result

class ForcedAlpha(nn.Module):
    """Scales its input by ``tanh(init * speed)`` with a learnable ``init``.

    ``init`` starts at zero and is routed through ``PushToInf`` before the
    tanh is applied.
    """

    def __init__(self, speed = 1):
        super().__init__()

        self.init   = nn.Parameter(th.zeros(1))
        self.speed  = speed
        self.to_inf = PushToInf()

    def _alpha(self):
        # Current blending factor as a one-element tensor.
        return th.tanh(self.to_inf(self.init * self.speed))

    def item(self):
        """Return the current factor as a Python float."""
        return self._alpha().item()

    def forward(self, input: th.Tensor):
        """Return ``input`` multiplied by the current factor."""
        return input * self._alpha()

class AlphaThreshold(nn.Module):
    """Learnable threshold equal to ``tanh(init) * max_value``.

    ``init`` starts at zero and is routed through ``PushToInf`` before the
    tanh is applied.
    """

    def __init__(self, max_value = 1):
        super().__init__()

        self.init      = nn.Parameter(th.zeros(1))
        self.to_inf    = PushToInf()
        self.max_value = max_value

    def forward(self):
        """Return the current threshold as a one-element tensor."""
        squashed = th.tanh(self.to_inf(self.init))
        return squashed * self.max_value

class TanhAlpha(nn.Module):
    """Non-learned schedule: ``tanh(init) * max_value``, advanced on every call.

    ``init`` is a buffer that starts at ``start`` and grows by ``stepsize``
    each time the module is called.
    """

    def __init__(self, start = 0, stepsize = 1e-4, max_value = 1):
        super().__init__()

        self.stepsize  = stepsize
        self.max_value = max_value
        self.register_buffer('init', th.zeros(1) + start)

    def get(self):
        """Return the current value as a Python float without advancing."""
        scaled = th.tanh(self.init) * self.max_value
        return scaled.item()

    def forward(self):
        """Advance the schedule by one step and return the new value."""
        advanced = self.init.detach() + self.stepsize
        self.init = advanced
        return self.get()

class MultiArgSequential(nn.Sequential):
    """``nn.Sequential`` variant whose stages may take and return several values.

    The positional arguments given to ``forward`` are unpacked into the first
    stage. After each stage, a single tensor (or ``None``) is handed to the
    next stage as one argument; any other result is unpacked with ``*``.
    """

    def __init__(self, *args, **kwargs):
        super(MultiArgSequential, self).__init__(*args, **kwargs)

    def forward(self, *tensor):
        """Run all stages in order and return the output of the last one.

        With no stages, the tuple of positional arguments is returned as is.
        """
        for module in self:
            # Identity test for None: `== None` would invoke __eq__ on whatever
            # the previous stage returned.
            if tensor is None or isinstance(tensor, th.Tensor):
                tensor = module(tensor)
            else:
                tensor = module(*tensor)

        return tensor


class InitialLatentStates(nn.Module):
    """Produces initial position / gestalt / priority codes for every object slot.

    For each slot a fresh random state is drawn (position sampled from the
    error map, gestalt sampled from a learned Gaussian, priority taken from a
    constant buffer). If previous states and masks are supplied, slots whose
    tracked mask value exceeds ``threshold`` keep their previous state and all
    other slots receive the fresh one.

    The module is stateful: ``last_mask`` is carried over between calls until
    ``reset_state()`` is invoked.
    """

    def __init__(
            self, 
            gestalt_size: int, 
            num_objects: int, 
            size: Tuple[int, int],
            object_permanence_strength: int,
            entity_pretraining_steps: int
        ):
        """
        Args:
            gestalt_size: number of gestalt channels per object.
            num_objects: number of object slots.
            size: full resolution passed to ``Gaus2D`` (which unpacks it as H, W).
            object_permanence_strength: upper bound for the permanence factor
                used when updating ``last_mask`` in ``forward``.
            entity_pretraining_steps: sets the start value of the ``TanhAlpha``
                schedule to ``-1e-4 * entity_pretraining_steps``.
        """
        super(InitialLatentStates, self).__init__()
        self.object_permanence_strength = object_permanence_strength

        self.num_objects  = num_objects
        self.gestalt_size = gestalt_size
        # Learned mean / std of the Gaussian from which new gestalt codes are drawn.
        self.gestalt_mean = nn.Parameter(th.zeros(1, gestalt_size))
        self.gestalt_std  = nn.Parameter(th.ones(1, gestalt_size))
        # Learned initial std and depth components of a new position code.
        self.std          = nn.Parameter(th.zeros(1)-5)
        self.depth        = nn.Parameter(th.zeros(1)+5)

        # Schedule that is advanced in forward() while it is below the
        # configured object_permanence_strength.
        self.init = TanhAlpha(start = -1e-4 * entity_pretraining_steps)
        self.register_buffer('priority', th.zeros(num_objects)-3, persistent=False)
        #self.register_buffer('priority', th.arange(num_objects).float() * -1 - 3, persistent=False) #TODO try
        self.register_buffer('threshold', th.ones(1) * 0.75)
        # Running per-slot mask maximum, shape (batch * num_objects, 1) once set.
        self.last_mask = None

        # Gaussian renderers at 1/16, 1/4 and full resolution; within this
        # class they are referenced only by the disabled debug snippet in forward().
        self.gaus2d = nn.Sequential(
            Gaus2D((size[0] // 16, size[1] // 16)),
            Gaus2D((size[0] //  4, size[1] //  4)),
            Gaus2D(size)
        )

        self.level = 2

        # Helpers converting between the "shared" layout b (o c) and the
        # per-object "batch" layout (b o) c.
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))

        # Counter used only by the disabled debug snippet in forward().
        self.i = 0

    def load_pretrained(self, state):
        """Copy any of ``gestalt_mean``, ``gestalt_std``, ``std`` and ``depth``
        that are present in the mapping ``state`` into the matching parameters."""
        if "gestalt_mean" in state:
            self.gestalt_mean.data = state['gestalt_mean']

        if "gestalt_std" in state:
            self.gestalt_std.data = state['gestalt_std']

        if "std" in state:
            self.std.data = state['std']

        if "depth" in state:
            self.depth.data = state['depth']

    def reset_state(self):
        """Forget the tracked per-slot mask so the next forward() starts fresh."""
        self.last_mask = None

    def set_level(self, level):
        """Select which entry of ``gaus2d`` is addressed by ``self.level``."""
        self.level = level

    def forward(
        self, 
        error: th.Tensor, 
        mask: th.Tensor = None, 
        position: th.Tensor = None,
        gestalt: th.Tensor = None,
        priority: th.Tensor = None
    ):
        """Compute the (re-)initialised slot states.

        Args:
            error: 4-D map (``b c h w``) used as sampling distribution for new
                positions. The index arithmetic below only divides by the
                width, so it presumably expects ``c == 1`` — TODO confirm.
            mask: optional ``b c h w`` masks; the last channel is dropped
                before the per-slot maximum is taken (presumably the
                background — verify against caller).
            position, gestalt, priority: optional previous states in the
                shared ``b (o c)`` layout.

        Returns:
            Tuple ``(position, gestalt, priority)``, each in ``b (o c)`` layout.
        """

        batch_size = error.shape[0]
        device     = error.device

        # While the schedule value is below both 1 and the configured strength,
        # use (and advance) the schedule instead of the configured strength.
        object_permanence_strength = self.object_permanence_strength
        if self.init.get() < 1 and self.init.get() < object_permanence_strength:
            object_permanence_strength = self.init()

        if self.last_mask is None:
            self.last_mask = th.zeros((batch_size * self.num_objects, 1), device = device)

        if mask is not None:
            # Per-slot maximum mask value, last channel excluded.
            mask = reduce(mask[:,:-1], 'b c h w -> (b c) 1' , 'max').detach()

            if object_permanence_strength <= 0:
                # No permanence: track the current mask only.
                self.last_mask = mask
            elif object_permanence_strength < 1:
                # Partial permanence: keep the running maximum, then lower it
                # by how far the current mask is below the threshold, scaled
                # by (1 - strength).
                self.last_mask = th.maximum(self.last_mask, mask)
                self.last_mask = self.last_mask - th.relu(-1 * (mask - self.threshold) * (1 - object_permanence_strength))
            else:
                # Full permanence: running maximum only.
                self.last_mask = th.maximum(self.last_mask, mask)

        # 1 for slots that keep their previous state, 0 for slots that are re-initialised.
        mask = (self.last_mask > self.threshold).float().detach()

        # Fresh gestalt codes: learned Gaussian squashed into (0, 1).
        gestalt_rand = th.randn((batch_size * self.num_objects, self.gestalt_size), device = device)
        gestalt_new  = th.sigmoid(gestalt_rand * self.gestalt_std + self.gestalt_mean)

        if gestalt is None:
            gestalt = gestalt_new
        else:
            gestalt = self.to_batch(gestalt) * mask + gestalt_new * (1 - mask)

        if priority is None:
            priority = repeat(self.priority, 'o -> (b o) 1', b = batch_size)
        else:
            priority = self.to_batch(priority) * mask + repeat(self.priority, 'o -> (b o) 1', b = batch_size) * (1 - mask)

        # FIXME rename error to uncertainty
        # if we have no significant error, we can sample positions everywhere
        error_mask = (reduce(error, 'b c h w -> b c 1 1', 'max') > 0.1).float()
        error = error * error_mask + th.rand_like(error) * (1 - error_mask) # TODO replace with actual prediction error
        
        # Normalize error map to form a probability distribution and flatten it. 
        # Sample 'num_objects' number of indices from this distribution with replacement, 
        # and convert these indices into image x and y positions.
        error_map_normalized = error / th.sum(error, dim=(1,2,3), keepdim=True)
        error_map_flat = error_map_normalized.view(batch_size, -1)
        sampled_indices = th.multinomial(error_map_flat, num_samples=self.num_objects, replacement=True)
        y_positions = sampled_indices // error.shape[3]
        x_positions = sampled_indices % error.shape[3]

        # Convert positions from range [0, error.shape] to range [-1, 1]
        x_positions = x_positions.float() / (error.shape[3] / 2.0) - 1
        y_positions = y_positions.float() / (error.shape[2] / 2.0) - 1

        x_positions = self.to_batch(x_positions)
        y_positions = self.to_batch(y_positions)

        std = repeat(self.std, '1 -> (b o) 1', b = batch_size, o = self.num_objects)

        # A position code is the concatenation (x, y, z, std).
        if position is None:
            z = repeat(self.depth, '1 -> (b o) 1', b = batch_size, o = self.num_objects)
            position = th.cat((x_positions, y_positions, z, std), dim=-1)
        else:
            z = repeat(self.depth, '1 -> (b o) 1', b = batch_size, o = self.num_objects)
            position = self.to_batch(position) * mask + th.cat((x_positions, y_positions, z, std), dim=1) * (1 - mask)

        # The string literal below is a disabled debug snippet; it is never executed.
        """
        print(mask.cpu().numpy()[:,0], priority.cpu().numpy()[:,0])
        # save debuggin info
        position2d = self.gaus2d[self.level](th.cat((position[:,:-1], th.maximum(std+3, position[:,-1:])), dim=-1))
        position2d = th.sum(position2d, dim=0, keepdim=True)

        img = th.zeros((1, 3, error.shape[2], error.shape[3]), device = device) + error
        img[0,0] = position2d[0,0]

        # save with cv2
        img = img.cpu().numpy()
        img = np.transpose(img, (0, 2, 3, 1))
        img = cv2.cvtColor(img[0], cv2.COLOR_RGB2BGR)
        cv2.imwrite(f"position-debug-{self.i:03d}.png", img * 255)
        self.i += 1
        """

        return self.to_shared(position), self.to_shared(gestalt), self.to_shared(priority)

class Gaus2D(nn.Module):
    """Renders a 2-D Gaussian blob per input row on a fixed coordinate grid.

    The grid spans ``[-1, 1]`` along both axes. ``min_std`` is ``0.1`` without
    a size and ``1 / min(size)`` once a size is known; ``max_std`` is ``0.5``.
    """

    def __init__(self, size = None, position_limit = 1):
        """
        Args:
            size: optional ``(H, W)`` of the output grid.
            position_limit: x and y are clipped to ``[-position_limit, position_limit]``.
        """
        super(Gaus2D, self).__init__()
        self.size = size
        self.position_limit = position_limit
        self.min_std = 0.1
        self.max_std = 0.5

        # Placeholder grids; replaced by update_grid() once a size is known.
        self.register_buffer("grid_x", th.zeros(1,1,1,1), persistent=False)
        self.register_buffer("grid_y", th.zeros(1,1,1,1), persistent=False)

        if size is not None:
            self.min_std = 1.0 / min(size)
            self.update_grid(size)

        print(f"Gaus2D: min std: {self.min_std}")

    def update_grid(self, size):
        """Rebuild the coordinate grids for ``size = (H, W)`` if it differs
        from the current grid shape."""

        if size != self.grid_x.shape[2:]:
            self.size    = size
            self.min_std = 1.0 / min(size)
            H, W = size
            device = self.grid_x.device

            # An axis with a single sample has no extent; dividing by
            # (length - 1) would yield NaN, so such an axis is placed at 0,
            # matching the placeholder grid.
            if W > 1:
                xs = (th.arange(W, device=device) / (W-1)) * 2 - 1
            else:
                xs = th.zeros(1, device=device)

            if H > 1:
                ys = (th.arange(H, device=device) / (H-1)) * 2 - 1
            else:
                ys = th.zeros(1, device=device)

            self.grid_x = xs.view(1, 1, 1, -1).expand(1, 1, H, W).clone()
            self.grid_y = ys.view(1, 1, -1, 1).expand(1, 1, H, W).clone()

    def forward(self, input: th.Tensor, compute_std = True):
        """Render one Gaussian per row of ``input``.

        Args:
            input: 2-D tensor with 2 to 4 columns. Columns 0 and 1 are x and y;
                with 3 columns the std is column 2, with 4 columns it is
                column 3. With 2 columns the std input is zero.
            compute_std: if true, the std input is mapped through a sigmoid
                into ``[min_std, max_std]``; otherwise it is clipped to that range.

        Returns:
            Tensor broadcast from ``(b, 1, 1, 1)`` against the ``(1, 1, H, W)`` grid.
        """
        assert input.shape[1] >= 2 and input.shape[1] <= 4
        H, W = self.size

        x   = rearrange(input[:,0:1], 'b c -> b c 1 1')
        y   = rearrange(input[:,1:2], 'b c -> b c 1 1')
        std = th.zeros_like(x)

        if input.shape[1] == 3:
            std = rearrange(input[:,2:3], 'b c -> b c 1 1')

        if input.shape[1] == 4:
            std = rearrange(input[:,3:4], 'b c -> b c 1 1')

        x   = th.clip(x, -self.position_limit, self.position_limit)
        y   = th.clip(y, -self.position_limit, self.position_limit)

        if compute_std:
            std = th.sigmoid(std) * (self.max_std - self.min_std) + self.min_std
        else:
            std = th.clip(std, self.min_std, self.max_std)
            
        # Both axes span [-1, 1] regardless of resolution, so the x std is
        # rescaled by the aspect ratio H / W.
        std_y = std.clone()
        std_x = std * (H / W)

        return th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)))

class SharedObjectsToBatch(nn.Module):
    """Moves the object axis out of the channel axis and into the batch axis."""

    def __init__(self, num_objects):
        super().__init__()
        self.num_objects = num_objects

    def forward(self, input: th.Tensor):
        """Rearrange ``b (o c) h w`` into ``(b o) c h w`` with ``o = num_objects``."""
        pattern = 'b (o c) h w -> (b o) c h w'
        return rearrange(input, pattern, o=self.num_objects)

class BatchToSharedObjects(nn.Module):
    """Moves the object axis out of the batch axis and into the channel axis."""

    def __init__(self, num_objects):
        super().__init__()
        self.num_objects = num_objects

    def forward(self, input: th.Tensor):
        """Rearrange ``(b o) c h w`` into ``b (o c) h w`` with ``o = num_objects``."""
        pattern = '(b o) c h w -> b (o c) h w'
        return rearrange(input, pattern, o=self.num_objects)

class LambdaModule(nn.Module):
    """Wraps a plain Python function so it can be used as an ``nn.Module``."""

    def __init__(self, lambd):
        super().__init__()
        import types
        # types.LambdaType is the ordinary function type, so both lambdas and
        # def-functions pass; other callables are rejected.
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def forward(self, *x):
        """Call the wrapped function with the given positional arguments."""
        wrapped = self.lambd
        return wrapped(*x)

class PrintGradientFunction(Function):
    """Identity function that prints value statistics on forward and gradient
    statistics on backward."""

    @staticmethod
    def _summary(values):
        # "min, mean +- std, max", each in scientific notation with two decimals.
        return (
            f"{th.min(values).item():.2e}, {th.mean(values).item():.2e}"
            f" +- {th.std(values).item():.2e}, {th.max(values).item():.2e}"
        )

    @staticmethod
    def forward(ctx, tensor, msg):
        # The label is kept on ctx so backward() can print it as well.
        ctx.msg = msg
        print(f"Forward: {msg}: {PrintGradientFunction._summary(tensor)}")
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        print(f"{ctx.msg}: {PrintGradientFunction._summary(grad_input)}")
        # No gradient for the message argument.
        return grad_input, None

class PrintGradient(nn.Module):
    """Module wrapper around ``PrintGradientFunction`` with a fixed label."""

    def __init__(self, msg = "PrintGradient"):
        super().__init__()
        self.msg = msg
        self.fcn = PrintGradientFunction.apply

    def forward(self, input: th.Tensor):
        """Pass ``input`` through ``PrintGradientFunction`` using the stored label."""
        label = self.msg
        return self.fcn(input, label)

class Prioritize(nn.Module):
    """Suppresses each object channel by a priority-weighted sum of the others."""

    def __init__(self, num_objects):
        super().__init__()
        self.num_objects = num_objects

    def forward(self, input: th.Tensor, priority: th.Tensor):
        """Return ``relu(input - weighted sum of the other object channels)``.

        ``input`` is treated as ``b c h w`` and ``priority`` is indexed as
        ``(batch, object)``. With ``priority is None`` the input is returned
        unchanged.
        """
        if priority is None:
            return input

        # Scale the priorities up and jitter them with Gaussian noise.
        priority = priority * 250 + th.randn_like(priority) * 2.5

        batch_size  = input.shape[0]
        num_objects = self.num_objects

        # pairwise[b, o, i] = sigmoid(priority[b, i] - priority[b, o]);
        # the diagonal (an object against itself) is multiplied by zero.
        pairwise     = th.sigmoid(priority[:, None, :] - priority[:, :, None])
        off_diagonal = 1 - th.eye(num_objects, device=priority.device, dtype=priority.dtype)

        weights = th.zeros((batch_size, num_objects, num_objects, 1, 1), device=input.device)
        weights[:, :, :, 0, 0] = pairwise * off_diagonal

        # One 1x1 grouped convolution (one group per batch element) computes
        # the weighted sums for all objects at once.
        flat_input   = rearrange(input, 'b c h w -> 1 (b c) h w')
        flat_weights = rearrange(weights, 'b o i 1 1 -> (b o) i 1 1')
        suppression  = nn.functional.conv2d(flat_input, flat_weights, groups=batch_size)

        output = th.relu(flat_input - suppression)
        return rearrange(output, '1 (b c) h w -> b c h w ', b=batch_size)

class Binarize(nn.Module):
    """Sigmoid followed by rounding (eval mode) or a noisy relaxation (training mode)."""

    def __init__(self):
        super().__init__()

    def forward(self, input: th.Tensor):
        """Return rounded sigmoid values in eval mode; in training mode return
        ``p + p * (1 - p) * noise`` with standard-normal ``noise``."""
        probs = th.sigmoid(input)
        if self.training:
            # The noise amplitude p * (1 - p) is zero at p = 0 and p = 1.
            return probs + probs * (1 - probs) * th.randn_like(probs)

        return th.round(probs)

class ProbabilisticBinarize(nn.Module):
    """Binarisation layer that, during training, replaces part of the channels
    with noise drawn from running activation statistics."""

    def __init__(self, channels):
        super().__init__()

        # Decayed running sums; dividing by `numel` turns them into averages.
        for name in ('mean', 'bin_mean', 'numel'):
            self.register_buffer(name, th.zeros((1,channels)))

        # Per-channel cut points, evenly spaced in [0, 1].
        self.register_buffer('intervall', th.linspace(0, 1, channels))

    def forward(self, input: th.Tensor):
        """Return rounded sigmoid values in eval mode; in training mode update
        the running statistics and return a noisy relaxation.

        The training path unpacks ``input.shape`` into two values, so it
        requires a 2-D input.
        """
        input = th.sigmoid(input)
        if not self.training:
            return th.round(input)

        # Running sums, decayed by 0.99 per call.
        batch_sum     = th.sum(input, dim=0, keepdim=True).detach()
        batch_bin_sum = th.sum(th.minimum(input, 1-input), dim=0, keepdim=True).detach()
        self.mean     = self.mean * 0.99 + batch_sum
        self.bin_mean = self.bin_mean * 0.99 + batch_bin_sum
        self.numel    = self.numel * 0.99 + input.shape[0]

        mean     = self.mean  / self.numel
        bin_mean = self.bin_mean  / self.numel

        # Bernoulli draw per element, pulled away from exactly 0 / 1 by the
        # running average distance to the nearest binary value.
        bin_noise = th.bernoulli(mean.expand(input.shape))
        noise     = bin_noise - bin_mean * bin_noise + bin_mean * (1 - bin_noise)

        batch_size, channels = input.shape
        noise_prop = th.rand((batch_size,1), device=input.device)

        # Channels whose cut point lies below the per-sample draw keep their value.
        keep = (self.intervall < noise_prop).float()

        input = input * keep + noise * (1 - keep)
        return input + input * (1 - input) * th.randn_like(input)

class CapacityModulatedBinarize(nn.Module):
    """Binarisation layer in which, during training, only a ``capacity``
    fraction of the channels carries signal; the rest is replaced by noise."""

    def __init__(self, channels, min_capacity=4):
        super().__init__()

        # Per-channel cut points in [0, 1], shaped to broadcast over b c h w.
        cut_points = th.linspace(0, 1, channels).view(1, channels, 1, 1)
        self.register_buffer('intervall', cut_points)
        self.capacity = 1.0
        # Lower bound of the capacity, expressed as a fraction of the channels.
        self.min_capacity = min_capacity / channels

    def set_capacity(self, capacity):
        """Clamp ``capacity`` to ``[min_capacity, 1.0]``, store it and return it."""
        capped = min(capacity, 1.0)
        self.capacity = max(capped, self.min_capacity)
        return self.capacity

    def forward(self, input: th.Tensor):
        """Return rounded sigmoid values in eval mode; in training mode return
        a noisy relaxation with the channels at or above the capacity cut
        point replaced by noise."""
        input = th.sigmoid(input)
        if not self.training:
            return th.round(input)

        noise = th.sigmoid(th.rand_like(input))

        # 1 for channels whose cut point is below the capacity, 0 otherwise.
        keep = (self.intervall < self.capacity).float()

        input = input * keep + noise * (1 - keep)
        return input + input * (1 - input) * th.randn_like(input)


class MaskCenter(nn.Module):
    """Computes the mask-weighted centre and a size estimate for each mask."""

    def __init__(self, size, normalize=True, combine=False):
        """
        Args:
            size: ``(height, width)`` of the masks.
            normalize: if true, coordinates run from -1 to 1; otherwise from
                0 to ``width`` / ``height``.
            combine: if true, ``forward`` returns a single concatenated tensor.
        """
        super().__init__()
        self.combine = combine

        height, width = size

        # Coordinate range per axis.
        # NOTE(review): the non-normalised range ends at `width` / `height`
        # rather than the last pixel index — confirm this is intended.
        x_lo, x_hi = (-1, 1) if normalize else (0, width)
        y_lo, y_hi = (-1, 1) if normalize else (0, height)
        x_range = th.linspace(x_lo, x_hi, width)
        y_range = th.linspace(y_lo, y_hi, height)

        y_coords, x_coords = th.meshgrid(y_range, x_range)

        # Leading singleton axes let the grids broadcast against b c h w masks.
        self.register_buffer('x_coords', x_coords[None, None, :, :], persistent=False)
        self.register_buffer('y_coords', y_coords[None, None, :, :], persistent=False)

    def forward(self, mask):
        """Return the centre (x, y) and size estimate of every mask.

        Returns ``cat(center_x, center_y, std)`` if ``combine`` is set,
        otherwise the tuple ``(cat(center_x, center_y), std)``. ``std`` is the
        square root of the mask sum divided by the number of pixels.
        """
        spatial = (2, 3)
        mass = th.sum(mask, dim=spatial)

        # The small epsilon keeps empty masks from dividing by zero.
        center_x = th.sum(self.x_coords * mask, dim=spatial) / (mass + 1e-8)
        center_y = th.sum(self.y_coords * mask, dim=spatial) / (mass + 1e-8)
        std      = (mass / th.sum(th.ones_like(mask), dim=spatial))**0.5

        if not self.combine:
            return th.cat((center_x, center_y), dim=-1), std

        return th.cat((center_x, center_y, std), dim=-1)

class PositionInMask(nn.Module):
    """
    Computes a random position that lies inside the mask
    """
    def __init__(self, size):
        """
        Args:
            size: ``(height, width)`` of the masks.
        """
        super(PositionInMask, self).__init__()

        # Get the mask dimensions
        height, width = size

        # Create meshgrid of coordinates
        x_range = th.linspace(-1, 1, width)
        y_range = th.linspace(-1, 1, height)

        y_coords, x_coords = th.meshgrid(y_range, x_range)

        # Broadcast the coordinates to match the mask shape
        self.register_buffer('x_coords', x_coords[None, None, :, :], persistent=False)
        self.register_buffer('y_coords', y_coords[None, None, :, :], persistent=False)

    def forward(self, mask):
        """Pick one position per mask, without tracking gradients.

        The mask is thresholded at 0.75 and eroded with a 5x5 window. If the
        eroded mask is empty, the mask-weighted centre is used; otherwise a
        random pixel of the eroded mask is selected.

        Returns:
            Tuple ``(position, center_pos, std)`` where ``position`` and
            ``center_pos`` are ``cat(x, y)`` along the last axis and ``std``
            is the square root of the mask sum divided by the number of pixels.
        """

        B, C, H, W = mask.shape

        with th.no_grad():
            bin_mask     = (mask > 0.75).float()
            # Erosion expressed as a max-pool over the inverted mask.
            erroded_mask = 1 - th.nn.functional.max_pool2d(1 - bin_mask, kernel_size=5, stride=1, padding=2)

            use_center = (th.sum(erroded_mask, dim=(2, 3)) < 0.1).float()

            # Large random logits make the softmax nearly one-hot; the final
            # multiplication removes any weight outside the eroded mask.
            rand_mask = th.randn_like(erroded_mask) * erroded_mask * 1000
            rand_pixel = th.softmax(rand_mask.view(B, C, -1), dim=-1).view(B, C, H, W) * erroded_mask

            # Compute the center of the mask for each instance in the batch
            mask_sum = th.sum(mask, dim=(2, 3))
            center_x = th.sum(self.x_coords * mask, dim=(2, 3)) / (mask_sum + 1e-6)
            center_y = th.sum(self.y_coords * mask, dim=(2, 3)) / (mask_sum + 1e-6)
            std      = (mask_sum / th.sum(th.ones_like(mask), dim=(2, 3)))**0.5

            # compute the random position inside the mask for each instance in the batch
            rand_x = th.sum(self.x_coords * rand_pixel, dim=(2, 3))
            rand_y = th.sum(self.y_coords * rand_pixel, dim=(2, 3))
            
            center_pos = th.cat((center_x, center_y), dim=-1)
            rand_pos   = th.cat((rand_x, rand_y), dim=-1)

            return use_center * center_pos + (1 - use_center) * rand_pos, center_pos, std


class RandomCropCentered(nn.Module):
    """Crops image, depth and mask with a random window placed around the mask centre."""

    def __init__(self, crop_size, img_size):
        """
        Args:
            crop_size: ``(crop_height, crop_width)``.
            img_size: ``(height, width)`` handed to ``MaskCenter``.
        """
        super(RandomCropCentered, self).__init__()
        self.crop_height, self.crop_width = crop_size
        # Non-normalised coordinates, so centres are on the pixel scale.
        self.mask_center_module = MaskCenter(img_size, normalize=False)

    def get_random_crop_coords(self, centers, img_height, img_width):
        """Sample the top-left corner of a crop window for each centre.

        ``centers[..., 0]`` is read as x and ``centers[..., 1]`` as y. Returns
        ``(y1, x1)`` as integer tensors of shape ``(-1, 1)``.
        """
        # Range of admissible top-left corners, clamped so the window stays
        # inside the image.
        crop_y_min = th.clamp(centers[..., 1] - self.crop_height, min=0, max=img_height - self.crop_height)
        crop_y_max = th.clamp(centers[..., 1],                    min=0, max=img_height - self.crop_height)
        crop_x_min = th.clamp(centers[..., 0] - self.crop_width,  min=0, max=img_width - self.crop_width)
        crop_x_max = th.clamp(centers[..., 0],                    min=0, max=img_width - self.crop_width)

        rand_y = th.rand_like(crop_y_min)
        rand_x = th.rand_like(crop_x_min)

        y1 = (rand_y * (crop_y_max - crop_y_min) + crop_y_min).long().view(-1, 1)
        x1 = (rand_x * (crop_x_max - crop_x_min) + crop_x_min).long().view(-1, 1)

        return y1, x1

    def crop(self, image, y1, x1):
        """Cut the ``crop_height x crop_width`` window at ``(y1[i], x1[i])``
        out of every batch element and stack the results."""
        cropped_image = []
        for i in range(image.size(0)):
            top, left = int(y1[i]), int(x1[i])
            cropped_image.append(image[i, :, top:top + self.crop_height, left:left + self.crop_width])

        return th.stack(cropped_image)

    def forward(self, rgb_image, depth_image, mask, crop_coords=None):
        """Crop all three inputs with the same window.

        If ``crop_coords`` is ``None``, a window is sampled around the mask
        centre; otherwise the given ``(y1, x1)`` pair is reused.

        Returns:
            ``(cropped_rgb, cropped_depth, cropped_mask, crop_coords)``.
        """
        if crop_coords is None:
            # With combine=False the centre module returns (centers, std);
            # only the centres are needed here.
            centers, _ = self.mask_center_module(mask)
            crop_coords = self.get_random_crop_coords(centers, rgb_image.size(2), rgb_image.size(3))

        cropped_rgb_image = self.crop(rgb_image, *crop_coords)
        cropped_depth_image = self.crop(depth_image, *crop_coords)
        cropped_mask = self.crop(mask, *crop_coords)

        return cropped_rgb_image, cropped_depth_image, cropped_mask, crop_coords




class RandomUpscaleFlip(nn.Module):
    """Randomly upscales, optionally flips, then centre-crops image, depth and mask."""

    def __init__(self, max_scale, size):
        super().__init__()
        self.max_scale = max_scale
        # Centre crop that produces the final output of size `size`.
        self.crop = transforms.CenterCrop(size)

    def get_random_scale_factor(self):
        """Draw a scale factor in ``[1, max_scale)``."""
        unit = th.rand(1).item()
        return unit * (self.max_scale - 1) + 1

    def scale_image(self, image, scale_factor):
        """Resize ``image`` (b c h w) to ``int(h * f) x int(w * f)``."""
        height, width = image.shape[2:]
        target = (int(height * scale_factor), int(width * scale_factor))
        return transforms.Resize(target)(image)

    def random_flip(self, images):
        """Flip every image along its last axis, or none of them.

        Returns the input list itself when no flip is applied.
        """
        if th.rand(1).item() < 0.5:
            return images

        return [image.flip(-1) for image in images]

    def forward(self, rgb_image, depth_image, mask):
        """Apply the same scale, flip decision and centre crop to all three inputs."""
        scale_factor = self.get_random_scale_factor()

        scaled = [
            self.scale_image(rgb_image, scale_factor),
            self.scale_image(depth_image, scale_factor),
            self.scale_image(mask, scale_factor),
        ]

        rgb_out, depth_out, mask_out = self.random_flip(scaled)

        return self.crop(rgb_out), self.crop(depth_out), self.crop(mask_out)

class RandomHorizontalFlip(nn.Module):
    """Flips image, depth and mask together along the last axis with a given probability."""

    def __init__(self, flip_probability=0.5):
        super().__init__()
        self.flip_probability = flip_probability

    def __getstate__(self):
        # Only the flip probability is serialised.
        return dict(flip_probability=self.flip_probability)

    def __setstate__(self, state):
        self.flip_probability = state["flip_probability"]

    def flip(self, image):
        """Reverse axis 2 of a 3-D ``image``."""
        assert image.dim() == 3
        return th.flip(image, dims=(2,))

    def forward(self, rgb_image, depth_image, mask, seed:int=-1):
        """Return the three inputs, all flipped or all untouched.

        A positive ``seed`` re-seeds the global torch RNG before the draw.
        """
        if seed > 0:
            th.manual_seed(seed)
            print(f"Setting seed to {seed}")

        do_flip = th.rand(1).item() < self.flip_probability
        if not do_flip:
            return rgb_image, depth_image, mask

        return tuple(self.flip(image) for image in (rgb_image, depth_image, mask))

class RandomScale(nn.Module):
    """Rescales image, depth and mask by one random factor, never below the crop size."""

    def __init__(self, crop_size, max_scale_factor=1.15):
        super().__init__()
        self.max_scale_factor = max_scale_factor
        self.crop_size = crop_size

    def __getstate__(self):
        # Only the two configuration values are serialised.
        return dict(max_scale_factor=self.max_scale_factor, crop_size=self.crop_size)
    
    def __setstate__(self, state):
        self.max_scale_factor = state["max_scale_factor"]
        self.crop_size = state["crop_size"]

    def scale(self, image, scale_factor):
        """Bilinearly resize a 3-D ``image`` by ``scale_factor``; each output
        side is at least the corresponding crop size."""
        target = (
            max(int(image.shape[1] * scale_factor), self.crop_size[0]),
            max(int(image.shape[2] * scale_factor), self.crop_size[1]),
        )
        batched = image.unsqueeze(0)
        return F.interpolate(batched, target, mode='bilinear', align_corners=True)[0] # TODO use antialias ???

    def forward(self, rgb_image, depth_image, mask, bbox=None, seed:int=-1):
        """Scale all inputs by one factor drawn between the smallest factor
        that still covers the crop size and ``max_scale_factor``.

        A positive ``seed`` re-seeds the global torch RNG first. ``bbox``, if
        given, is multiplied by the same factor.

        Returns:
            ``(scaled_rgb, scaled_depth, scaled_mask, bbox)``.
        """
        assert rgb_image.dim() == 3

        if seed > 0:
            th.manual_seed(seed)
            print(f"Setting seed to {seed}")

        # Smallest factor for which both sides still cover the crop size.
        min_scale_factor = max(
            self.crop_size[0] / rgb_image.shape[1],
            self.crop_size[1] / rgb_image.shape[2],
        )

        draw = th.rand(1, device=rgb_image.device)
        scale_factor = (self.max_scale_factor - min_scale_factor) * draw + min_scale_factor

        scaled_rgb_image, scaled_depth_image, scaled_mask = (
            self.scale(image, scale_factor)
            for image in (rgb_image, depth_image, mask)
        )

        if bbox is not None:
            bbox = bbox * scale_factor

        return scaled_rgb_image, scaled_depth_image, scaled_mask, bbox

class BBoxCropTorch(nn.Module):
    """Crops image, depth and mask with one random window chosen relative to a bounding box."""

    def __init__(self, crop_size):
        super().__init__()
        self.crop_height, self.crop_width = crop_size

    def __getstate__(self):
        # Only the crop dimensions are serialised.
        return dict(crop_height=self.crop_height, crop_width=self.crop_width)

    def __setstate__(self, state):
        self.crop_height = state["crop_height"]
        self.crop_width = state["crop_width"]

    def _sample_start(self, box_lo: int, box_hi: int, crop: int, extent: int):
        # Draw the window start along one axis.
        start_lo = min(max(box_hi - crop + 1, 0), extent - crop)
        start_hi = max(min(box_lo, extent - crop), start_lo)

        # A box longer than the window cannot be covered; slide the window
        # inside the box instead.
        if box_hi - box_lo > crop:
            start_lo = box_lo
            start_hi = box_hi - crop + 1

        span = start_hi - start_lo
        offset = th.randint(0, span, (1,)).item() if span > 0 else 0
        return int(offset + start_lo)

    def get_random_crop_coords(self, bbox, img_height: int, img_width: int):
        """Return the top-left corner ``(y1, x1)`` of a random crop window.

        ``bbox`` is read as ``(x_min, y_min, x_max, y_max)`` and truncated to ints.
        """
        bbox_x_min, bbox_y_min, bbox_x_max, bbox_y_max = (int(value) for value in bbox)

        # y is drawn before x, which fixes the order of the random draws.
        y1 = self._sample_start(bbox_y_min, bbox_y_max, self.crop_height, img_height)
        x1 = self._sample_start(bbox_x_min, bbox_x_max, self.crop_width, img_width)

        return y1, x1

    def crop(self, image, y1: int, x1: int):
        """Slice the crop window starting at ``(y1, x1)`` out of a 3-D ``image``."""
        rows = slice(y1, y1 + self.crop_height)
        cols = slice(x1, x1 + self.crop_width)
        return image[:, rows, cols]

    def forward(self, rgb_image, depth_image, mask, bbox):
        """Crop all three inputs with the same window.

        Inputs that already have the crop size are returned unchanged.

        Raises:
            ValueError: if the image is smaller than the crop size.
        """
        assert rgb_image.dim() == 3
        height, width = rgb_image.shape[1], rgb_image.shape[2]

        if height == self.crop_height and width == self.crop_width:
            return rgb_image, depth_image, mask

        if height < self.crop_height or width < self.crop_width:
            raise ValueError('Crop size is larger than image size')

        y1, x1 = self.get_random_crop_coords(bbox, height, width)

        return (
            self.crop(rgb_image, y1, x1),
            self.crop(depth_image, y1, x1),
            self.crop(mask, y1, x1),
        )

# RGB to YCbCr
class RGB2YCbCr(nn.Module):
    """Converts a batch of RGB images (b 3 h w) to YCbCr."""

    def __init__(self):
        super().__init__()

        kr = 0.299
        kg = 0.587
        kb = 0.114

        # Rows of the RGB -> YCbCr matrix (ITU-R BT.601 conversion)
        luma      = [                  kr,                  kg,                    kb]
        blue_diff = [-0.5 * kr / (1 - kb), -0.5 * kg / (1 - kb),                  0.5]
        red_diff  = [                 0.5, -0.5 * kg / (1 - kr), -0.5 * kb / (1 - kr)]

        # Stored transposed so it can be applied from the right in forward().
        matrix = th.tensor([luma, blue_diff, red_diff]).t()
        self.register_buffer("matrix", matrix, persistent=False)

        # Offset added to the two chroma channels.
        self.register_buffer("shift", th.tensor([0., 0.5, 0.5]), persistent=False)

    def forward(self, img):
        """Return the YCbCr version of ``img``.

        Raises:
            ValueError: if ``img`` is not 4-D with exactly 3 channels.
        """
        if img.dim() != 4 or img.shape[1] != 3:
            raise ValueError('Input image must be 4D tensor with a size of 3 in the second dimension.')

        channels_last = img.permute(0, 2, 3, 1)
        converted = th.tensordot(channels_last, self.matrix, dims=1)
        return converted.permute(0, 3, 1, 2) + self.shift[None, :, None, None]

# RGBD to YCbCr
class RGBD2YCbCr(nn.Module):
    """Converts the RGB part of RGB-D images (b 4 h w) to YCbCr; the fourth
    channel is passed through unchanged."""

    def __init__(self):
        super().__init__()

        kr = 0.299
        kg = 0.587
        kb = 0.114

        # Rows of the RGB -> YCbCr matrix (ITU-R BT.601 conversion)
        luma      = [                  kr,                  kg,                    kb]
        blue_diff = [-0.5 * kr / (1 - kb), -0.5 * kg / (1 - kb),                  0.5]
        red_diff  = [                 0.5, -0.5 * kg / (1 - kr), -0.5 * kb / (1 - kr)]

        # Stored transposed so it can be applied from the right in forward().
        matrix = th.tensor([luma, blue_diff, red_diff]).t()
        self.register_buffer("matrix", matrix, persistent=False)

        # Offset added to the two chroma channels.
        self.register_buffer("shift", th.tensor([0., 0.5, 0.5]), persistent=False)

    def forward(self, img):
        """Return ``img`` with its first three channels converted to YCbCr.

        Raises:
            ValueError: if ``img`` is not 4-D with exactly 4 channels.
        """
        if img.dim() != 4 or img.shape[1] != 4:
            raise ValueError('Input image must be 4D tensor with a size of 4 in the second dimension.')

        rgb, remainder = img[:,:3], img[:,3:]
        converted = th.tensordot(rgb.permute(0, 2, 3, 1), self.matrix, dims=1)
        ycbcr = converted.permute(0, 3, 1, 2) + self.shift[None, :, None, None]
        return th.cat((ycbcr, remainder), dim=1)


# YCbCr to RGB
class YCbCr2RGB(nn.Module):
    """Converts a batch of YCbCr images (b 3 h w) back to RGB, clamped to [0, 1]."""

    def __init__(self):
        super().__init__()

        kr = 0.299
        kg = 0.587
        kb = 0.114

        # Rows of the YCbCr -> RGB matrix (ITU-R BT.601 conversion)
        red   = [1,                       0,              2 - 2 * kr]
        green = [1, -kb / kg * (2 - 2 * kb), -kr / kg * (2 - 2 * kr)]
        blue  = [1,              2 - 2 * kb,                       0]

        # Stored transposed so it can be applied from the right in forward().
        matrix = th.tensor([red, green, blue]).t()
        self.register_buffer("matrix", matrix, persistent=False)

        # Offset removed from the two chroma channels before the conversion.
        self.register_buffer("shift", th.tensor([0., 0.5, 0.5]), persistent=False)

    def forward(self, img):
        """Return the RGB version of ``img``.

        Raises:
            ValueError: if ``img`` is not 4-D with exactly 3 channels.
        """
        if img.dim() != 4 or img.shape[1] != 3:
            raise ValueError('Input image must be 4D tensor with a size of 3 in the second dimension.')

        centered = (img - self.shift[None, :, None, None]).permute(0, 2, 3, 1)
        result = th.tensordot(centered, self.matrix, dims=1).permute(0, 3, 1, 2)

        # Clamp the results to the valid range for RGB [0, 1]
        return th.clamp(result, 0, 1)

class Conv1x1(nn.Module):
    """Pointwise (1x1) convolution realised as a linear layer over channels."""

    def __init__(self, in_channels, out_channels):
        """Create the layer.

        Args:
            in_channels: Number of channels in the input feature map.
            out_channels: Number of channels in the output feature map.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        """Apply the linear layer independently at every spatial location.

        Args:
            x: 4D tensor laid out as (batch, channels, height, width).

        Returns:
            Contiguous tensor of shape (batch, out_channels, height, width).
        """
        n, _, h, w = x.shape

        # One row per pixel, channels as features.
        pixels = x.permute(0, 2, 3, 1).reshape(-1, self.in_channels)
        mixed = self.linear(pixels)

        # Fold the rows back into an image and restore channels-first layout.
        return mixed.reshape(n, h, w, self.out_channels).permute(0, 3, 1, 2).contiguous()

class GradientScaler(nn.Module):
    """Identity in the forward pass; scales the gradient in the backward pass."""

    class ScaleGrad(Function):
        """Autograd function that copies its input and rescales its gradient."""

        @staticmethod
        def forward(ctx, input_tensor, scale):
            # Remember the factor for the backward pass; the output is a copy.
            ctx.scale = scale
            return th.clone(input_tensor)

        @staticmethod
        def backward(ctx, grad_output):
            # The second return value is for `scale`, which gets no gradient.
            return th.mul(grad_output, ctx.scale), None

    def __init__(self, scale):
        """Store the factor the incoming gradient is multiplied by."""
        super().__init__()
        self.scale = scale

    def forward(self, input):
        """Return a copy of ``input`` whose gradient is scaled by ``self.scale``."""
        scale_grad = self.ScaleGrad.apply
        return scale_grad(input, self.scale)

class RGB2Structure(nn.Module):
    """Replace the colour channels of an image by one channel plus its edges.

    The first three input channels are passed through ``RGB2YCbCr`` and the
    first output channel is kept (presumably luma; confirm against
    ``RGB2YCbCr``). That channel is filtered with two 5x5 Sobel-style kernels.
    Any input channels beyond the third are appended unchanged.
    """

    def __init__(self):
        super().__init__()

        # 5x5 Sobel-style kernel that responds to horizontal intensity changes.
        grad_x = th.tensor([
            [-2., -1., 0., 1., 2.],
            [-3., -2., 0., 2., 3.],
            [-4., -3., 0., 3., 4.],
            [-3., -2., 0., 2., 3.],
            [-2., -1., 0., 1., 2.],
        ], dtype=th.float32)

        # The vertical kernel is exactly the transpose of the horizontal one.
        grad_y = grad_x.t().contiguous()

        # Scale by 1/16 and shape as (out_channels, in_channels, 5, 5).
        self.register_buffer('sobel_kernel_x', grad_x.view(1, 1, 5, 5) / 16.)
        self.register_buffer('sobel_kernel_y', grad_y.view(1, 1, 5, 5) / 16.)

        self.to_ycbcr = RGB2YCbCr()

    def forward(self, x):
        """Return ``[y, sobel_x(y), sobel_y(y), extra channels]`` along dim 1.

        Args:
            x: 4D tensor with at least three channels in dimension 1.

        Returns:
            Tensor with the same batch and spatial size as ``x`` and
            ``3 + (channels - 3)`` channels.
        """
        rgb, extra = x[:, :3], x[:, 3:]

        # Keep only the first channel of the converted image.
        luma = self.to_ycbcr(rgb)[:, 0:1]

        # padding=2 keeps the spatial size for the 5x5 kernels.
        edges_x = F.conv2d(luma, self.sobel_kernel_x, padding=2)
        edges_y = F.conv2d(luma, self.sobel_kernel_y, padding=2)

        return th.cat([luma, edges_x, edges_y, extra], dim=1)

class RadomSimilarityBasedMaskDrop(nn.Module):
    """Suppress the masks of slots that duplicate another slot.

    Slots are compared through the correlation of their gestalt codes,
    weighted by a position-based proximity term. When a pair is more similar
    than the batch-wide mean, the lower-indexed slot of the pair is dropped:
    its mask is replaced by the constant -10.
    """

    def __init__(self, sigma_scale = 25):
        """Store ``sigma_scale``, the factor applied to the pairwise sigma product."""
        super(RadomSimilarityBasedMaskDrop, self).__init__()
        self.sigma_scale = sigma_scale

    def distance_weights(self, positions):
        """Return pairwise proximity weights between slots.

        Args:
            positions: Tensor indexed as (batch, slot, feature). The last
                feature is used as a per-slot sigma, the others as
                coordinates. At least three coordinates are required.

        Returns:
            (batch, slot, slot) tensor with entries
            ``exp(-d_ij / (2 * sigma_i * sigma_j * sigma_scale + 1e-5))``.
        """
        sigma = positions[:,:,-1]
        positions = positions[:,:,:-1]

        # Duplicate the coordinate at index 2 so it counts twice in the
        # squared distance (the original note calls this the z distance).
        positions = th.cat((positions, positions[:,:,2:3]), dim=2)

        # Broadcast to all slot pairs.
        p1 = positions[:, :, None, :]
        p2 = positions[:, None, :, :]
        sigma1 = sigma[:,:,None]
        sigma2 = sigma[:,None,:]

        squared_distances = th.sum((p1 - p2) ** 2, dim=-1)
        distances = th.sqrt(squared_distances)

        var = sigma1 * sigma2 * self.sigma_scale

        # The 1e-5 keeps the denominator positive when a sigma is zero.
        return th.exp(-distances / (2 * var + 1e-5))

    def batch_covariance(self, slots):
        """Return the unbiased covariance between slots.

        Args:
            slots: Tensor indexed as (batch, feature, slot). Needs at least
                two features because the result is divided by
                ``slots.size(1) - 1``.

        Returns:
            (batch, slot, slot) covariance matrices.
        """
        mean_slots = th.mean(slots, dim=1, keepdim=True)
        centered_slots = slots - mean_slots
        cov_matrix = th.bmm(centered_slots.transpose(1, 2), centered_slots) / (slots.size(1) - 1)
        return cov_matrix

    def batch_correlation(self, slots):
        """Return the Pearson correlation between slots.

        Pairs that involve a zero-variance slot get correlation 0 instead of
        NaN.

        Args:
            slots: Tensor indexed as (batch, feature, slot).

        Returns:
            (batch, slot, slot) correlation matrices.
        """
        cov_matrix = self.batch_covariance(slots)
        variances = th.diagonal(cov_matrix, dim1=-2, dim2=-1)
        std_matrix = th.sqrt(variances[:, :, None] * variances[:, None, :])

        # A constant slot has zero variance; an unguarded division would give
        # 0/0 = NaN, which would then poison the batch-wide mean taken in
        # get_drop_mask and drop every slot.
        eps = th.finfo(std_matrix.dtype).eps
        return cov_matrix / th.clamp(std_matrix, min=eps)

    def get_drop_mask(self, similarity_matrix):
        """Return a keep mask: 1 for slots to keep, 0 for slots to drop.

        Args:
            similarity_matrix: (batch, slot, slot) pairwise similarities.

        Returns:
            (batch, slot) float tensor. Slot ``n`` is 0 when its similarity to
            some slot ``m > n`` exceeds the mean of all rectified similarities
            in the batch.
        """
        similarity_matrix = th.relu(similarity_matrix)
        mean_similarity   = th.mean(similarity_matrix)

        # When every similarity is 1 (always the case for a single slot) the
        # headroom is 0; clamping avoids 0/0 = NaN, which would otherwise
        # mark every slot as dropped.
        eps      = th.finfo(similarity_matrix.dtype).eps
        headroom = th.clamp(1 - mean_similarity, min=eps)
        similarity_matrix = th.relu(similarity_matrix - mean_similarity) / headroom

        # Strict upper triangle: each pair is counted once and self-similarity
        # is ignored.
        similarity_matrix = th.triu(similarity_matrix, diagonal=1)
        drop_probability  = similarity_matrix.amax(dim=-1)

        # NOTE: a stochastic variant that compared against th.rand_like was
        # disabled; the threshold below is deterministic.
        return (drop_probability < 0.00001).float()

    def forward(self, position, gestalt, mask):
        """Drop the masks of slots that duplicate another slot.

        Args:
            position: (batch, slots * position_features) tensor; the last
                feature of each slot is used as sigma.
            gestalt: (batch, slots * gestalt_features) tensor.
            mask: (batch, slots, height, width) tensor.

        Returns:
            Tensor shaped like ``mask`` in which every dropped slot is set to
            -10 and every other slot is unchanged. No gradient flows through
            the drop decision.
        """
        num_objects = mask.shape[1]

        gestalt  = rearrange(gestalt,  'b (o c) -> b c o', o = num_objects)
        position = rearrange(position, 'b (o c) -> b c o', o = num_objects)

        # A slot is visible when its softmax share (against an extra constant
        # channel of ones) exceeds 0.75 at some pixel.
        visible  = th.softmax(th.cat((mask, th.ones_like(mask[:,:1])), dim=1), dim=1)
        visible  = (reduce(visible[:,:-1], 'b c h w -> b 1 c', 'max') > 0.75).float().detach()

        # Invisible slots get a near-constant noisy gestalt and a zero position.
        gestalt  = gestalt.detach()  * visible  + (1 - visible) * (0.5 + th.randn_like(gestalt)*0.01)
        position = position.detach() * visible

        weights      = self.distance_weights(rearrange(position, 'b c o -> b o c'))
        slot_corr    = self.batch_correlation(gestalt) * weights
        drop_mask    = self.get_drop_mask(slot_corr).unsqueeze(-1).unsqueeze(-1).detach()

        return mask * drop_mask - 10 * (1 - drop_mask)
