import torch as th
import torch.nn as nn
import numpy as np
from utils.utils import TanhAlpha, LambdaModule, Gaus2D
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from nn.downscale import MemoryEfficientPatchDownScale
from nn.vit_v2 import AttentionLayer, AttentionSum


class ObjectDiscovery(nn.Module):
    """Keep or re-initialise per-slot object states (position, gestalt, priority).

    A running per-slot visibility value (``last_mask``) is carried across
    calls until :meth:`reset_state`. Slots whose value is above ``threshold``
    keep the position / gestalt / priority supplied by the caller; all other
    slots are reset to a zero gestalt, the default ``priority`` buffer value
    and a fresh position whose (x, y) is sampled from the ``error`` map and
    whose remaining two entries come from the learned ``depth`` / ``std``
    parameters.

    All per-slot tensors exchanged with the caller use the "shared" layout
    ``b (o c)`` (slots concatenated along dim 1). Internally they are
    rearranged to ``(b o) c``.
    """

    def __init__(
            self, 
            num_objects: int, 
            gestalt_size: int,
            object_permanence_strength: float,
            entity_pretraining_steps: int
        ):
        super(ObjectDiscovery, self).__init__()

        # Validate first. The messages state the closed ranges the checks enforce.
        if object_permanence_strength < 0 or object_permanence_strength > 1:
            raise ValueError("object_permanence_strength must be in [0, 1]")

        if entity_pretraining_steps < 0:
            raise ValueError("entity_pretraining_steps must be >= 0")

        if entity_pretraining_steps > 1e4:
            raise ValueError("entity_pretraining_steps must be <= 1e4")

        self.object_permanence_strength = object_permanence_strength
        self.gestalt_size = gestalt_size
        self.num_objects  = num_objects

        # Learned values used for the last two position entries of re-initialised slots
        # (initialised to -5 and +5 respectively).
        self.std          = nn.Parameter(th.zeros(1)-5)
        self.depth        = nn.Parameter(th.zeros(1)+5)

        # Schedule (TanhAlpha from utils) that can override the permanence strength
        # during entity pretraining. Given the checks above, start lies in [-1, 0].
        self.init = TanhAlpha(start = -1e-4 * entity_pretraining_steps)
        # Default priority of re-initialised slots (not stored in the state_dict).
        self.register_buffer('priority', th.zeros(num_objects)-5, persistent=False)
        # A slot keeps its incoming state while its running mask value exceeds this.
        self.register_buffer('threshold', th.ones(1) * 0.8)
        # Running per-slot mask of shape ((b o), 1). None until the first forward().
        self.last_mask = None

        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))

    def load_pretrained(self, state):
        """Load ``std`` / ``depth`` from ``state`` (a mapping of tensors) when present.

        Values are copied in place, so the parameters keep their device, dtype
        and identity. Assigning ``.data`` would silently move or recast them to
        whatever the checkpoint tensor uses.
        """
        with th.no_grad():
            if "std" in state:
                self.std.copy_(state['std'])

            if "depth" in state:
                self.depth.copy_(state['depth'])

    def reset_state(self):
        """Forget the running per-slot mask (e.g. at the start of a new sequence)."""
        self.last_mask = None

    def forward(
        self, 
        error: th.Tensor, 
        mask: th.Tensor = None, 
        position: th.Tensor = None,
        gestalt: th.Tensor = None,
        priority: th.Tensor = None
    ):
        """Return ``(position, gestalt, priority, mask)`` in the shared ``b (o c)`` layout.

        Args:
            error: ``(b, c, h, w)`` map used as sampling weights for new
                positions. It must be non-negative, as ``th.multinomial`` requires.
                (FIXME upstream: this is really an uncertainty map.)
            mask: optional 4-D mask whose channels, minus the last one, are one
                per slot, i.e. ``num_objects + 1`` channels. The dropped last
                channel is presumably the background (confirm with callers).
            position: optional ``(b, num_objects * 4)`` current positions.
            gestalt: optional ``(b, num_objects * gestalt_size)`` current gestalts.
            priority: optional ``(b, num_objects)`` current priorities.

        Returns:
            position ``(b, o*4)``, gestalt ``(b, o*gestalt_size)``,
            priority ``(b, o)`` and mask ``(b, o)``. The mask is 1 for slots that
            kept their incoming state and 0 for re-initialised slots.
        """
        batch_size = error.shape[0]
        device     = error.device

        # While the schedule is below both 1 and the configured strength, use
        # the schedule's value instead. Calling self.init() presumably also
        # advances the schedule (see TanhAlpha).
        object_permanence_strength = self.object_permanence_strength
        if self.init.get() < 1 and self.init.get() < object_permanence_strength:
            object_permanence_strength = self.init()

        if self.last_mask is None:
            self.last_mask = th.zeros((batch_size * self.num_objects, 1), device = device)

        if mask is not None:
            # Per-slot peak mask value. The last channel is excluded.
            mask = reduce(mask[:,:-1], 'b c h w -> (b c) 1' , 'max').detach()

            if object_permanence_strength <= 0:
                # No permanence: follow the current mask exactly.
                self.last_mask = mask
            elif object_permanence_strength < 1:
                # Partial permanence: keep the peak, but subtract how far the current
                # mask is below the threshold, scaled by (1 - strength).
                self.last_mask = th.maximum(self.last_mask, mask)
                self.last_mask = self.last_mask - th.relu(-1 * (mask - self.threshold) * (1 - object_permanence_strength))
            else:
                # Full permanence: last_mask only grows until reset_state().
                self.last_mask = th.maximum(self.last_mask, mask)

        # 1 = slot keeps its incoming state, 0 = slot is re-initialised.
        mask = (self.last_mask > self.threshold).float().detach()

        gestalt_new = th.zeros((batch_size * self.num_objects, self.gestalt_size), device = device)
        if gestalt is None:
            gestalt = gestalt_new
        else:
            gestalt = self.to_batch(gestalt) * mask + gestalt_new * (1 - mask)

        priority_new = repeat(self.priority, 'o -> (b o) 1', b = batch_size)
        if priority is None:
            priority = priority_new
        else:
            priority = self.to_batch(priority) * mask + priority_new * (1 - mask)

        x_positions, y_positions = self._sample_positions(error)
        z   = repeat(self.depth, '1 -> (b o) 1', b = batch_size, o = self.num_objects)
        std = repeat(self.std,   '1 -> (b o) 1', b = batch_size, o = self.num_objects)
        position_new = th.cat((x_positions, y_positions, z, std), dim=1)

        if position is None:
            position = position_new
        else:
            position = self.to_batch(position) * mask + position_new * (1 - mask)

        return self.to_shared(position), self.to_shared(gestalt), self.to_shared(priority), self.to_shared(mask)

    def _sample_positions(self, error: th.Tensor):
        """Sample one (x, y) per slot from ``error`` ``(b, c, h, w)``.

        Returns two ``((b o), 1)`` tensors. A pixel index ``i`` along an axis of
        length ``size`` maps to ``i / (size / 2) - 1``, i.e. into
        ``[-1, 1 - 2 / size]``.
        """
        batch_size = error.shape[0]
        height, width = error.shape[2], error.shape[3]

        # FIXME rename error to uncertainty
        # A channel with no significant error (max <= 0.1) is replaced by uniform
        # noise, so positions can be sampled anywhere.
        error_mask = (reduce(error, 'b c h w -> b c 1 1', 'max') > 0.1).float()
        error = error * error_mask + th.rand_like(error) * (1 - error_mask) # TODO replace with actual prediction error

        # Normalise over all (c, h, w) entries and draw num_objects indices with
        # replacement. reshape (not view) also accepts non-contiguous maps such
        # as channels_last.
        error_map_normalized = error / th.sum(error, dim=(1,2,3), keepdim=True)
        error_map_flat = error_map_normalized.reshape(batch_size, -1)
        sampled_indices = th.multinomial(error_map_flat, num_samples=self.num_objects, replacement=True)

        # Drop the channel part of the flat index before decoding (row, col).
        # Without this, multi-channel maps yield rows >= height. No-op when c == 1.
        pixel_indices = sampled_indices % (height * width)
        y_positions = pixel_indices // width
        x_positions = pixel_indices % width

        x_positions = x_positions.float() / (width / 2.0) - 1
        y_positions = y_positions.float() / (height / 2.0) - 1

        return self.to_batch(x_positions), self.to_batch(y_positions)


class PositionProposalVit(nn.Module):
    """Regress a 4-value position proposal and a 1-value validity output from an image.

    ``input`` and ``positions2d`` are concatenated along channels and cut into
    patches by ``MemoryEfficientPatchDownScale`` (scale factor ``_PATCH_SIZE``).
    Each patch token gets a learned MLP embedding of its grid coordinate, and the
    tokens pass through ``num_layers`` attention layers, ``AttentionSum`` and a
    linear head with 5 outputs. The first four outputs are returned as
    ``position`` and the last as ``valid``.

    ``input_channels`` is passed to the patch embedding, so it presumably has to
    equal the channel count of ``input`` plus ``positions2d`` (confirm with callers).
    """

    # Downscaling factor of the patch embedding. forward() uses the same value
    # to size the coordinate grid.
    _PATCH_SIZE = 16

    def __init__(self, input_channels, latent_channels, num_layers):
        super(PositionProposalVit, self).__init__()

        hidden = latent_channels * 2
        mlp = [nn.Linear(2, hidden), nn.SiLU()]
        for _ in range(2):
            mlp += [nn.Linear(hidden, hidden), nn.SiLU()]
        mlp.append(nn.Linear(hidden, latent_channels))
        self.embedding = nn.Sequential(*mlp)

        self.to_patches = nn.Sequential(
            MemoryEfficientPatchDownScale(input_channels, latent_channels, scale_factor = self._PATCH_SIZE, expand_ratio = 4),
            Rearrange('b c h w -> b (h w) c')
        )

        blocks = [AttentionLayer(latent_channels) for _ in range(num_layers)]
        blocks += [AttentionSum(latent_channels), nn.Linear(latent_channels, 5)]
        self.layers = nn.Sequential(*blocks)

    def compute_embedding(self, B, H, W, device):
        """Embed the (x, y) coordinates of an H x W grid spanning [-1, 1]^2.

        Rows are in row-major (h, w) order. Returns ``(B, H*W, latent_channels)``.
        """
        ys = th.linspace(-1, 1, H, device=device)
        xs = th.linspace(-1, 1, W, device=device)
        grid_y, grid_x = th.meshgrid(ys, xs, indexing='ij')

        coords = th.stack((grid_x, grid_y), dim=-1).reshape(H * W, 2)
        return repeat(self.embedding(coords), 'n c -> b n c', b=B)

    def forward(self, input, positions2d):
        """Return ``(position, valid)``: the first 4 and the last 1 of the 5 head outputs."""
        height, width = input.shape[-2:]
        patch = self._PATCH_SIZE

        tokens = self.to_patches(th.cat((input, positions2d), dim=1))
        tokens = tokens + self.compute_embedding(input.shape[0], height // patch, width // patch, input.device)

        out = self.layers(tokens)
        return out[:, :-1], out[:, -1:]
"""
class SobelFilter(nn.Module):
    def __init__(self):
        super(SobelFilter, self).__init__()

        # Define Sobel kernels
        sobel_kernel_x = torch.tensor([[-2., -1., 0., 1., 2.],
                                       [-3., -2., 0., 2., 3.],
                                       [-4., -3., 0., 3., 4.],
                                       [-3., -2., 0., 2., 3.],
                                       [-2., -1., 0., 1., 2.]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        sobel_kernel_y = torch.tensor([[-2., -3., -4., -3., -2.],
                                       [-1., -2., -3., -2., -1.],
                                       [ 0.,  0.,  0.,  0.,  0.],
                                       [ 1.,  2.,  3.,  2.,  1.],
                                       [ 2.,  3.,  4.,  3.,  2.]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        sobel_kernel_diag1 = torch.tensor([[ 4.,  3.,  2.,  1.,  0.],
                                           [ 3.,  2.,  1.,  0., -1.],
                                           [ 2.,  1.,  0., -1., -2.],
                                           [ 1.,  0., -1., -2., -3.],
                                           [ 0., -1., -2., -3., -4.]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        sobel_kernel_diag2 = torch.tensor([[ 0.,  1.,  2.,  3.,  4.],
                                           [-1.,  0.,  1.,  2.,  3.],
                                           [-2., -1.,  0.,  1.,  2.],
                                           [-3., -2., -1.,  0.,  1.],
                                           [-4., -3., -2., -1.,  0.]], dtype=torch.float32).unsqueeze(0).unsqueeze(0)

        # normalize kernels
        self.register_buffer('sobel_kernel_x', sobel_kernel_x / 16.)
        self.register_buffer('sobel_kernel_y', sobel_kernel_y / 16.)
        self.register_buffer('sobel_kernel_diag1', sobel_kernel_diag1 / 16.)
        self.register_buffer('sobel_kernel_diag2', sobel_kernel_diag2 / 16.)

    def forward(self, x):
        # Apply Sobel filters
        sobel_x = F.conv2d(x, self.sobel_kernel_x, padding=2)
        sobel_y = F.conv2d(x, self.sobel_kernel_y, padding=2)
        sobel_diag1 = F.conv2d(x, self.sobel_kernel_diag1, padding=2)
        sobel_diag2 = F.conv2d(x, self.sobel_kernel_diag2, padding=2)

        # Stack the results in the channel dimension
        return torch.cat([sobel_x, sobel_y, sobel_diag1, sobel_diag2], dim=1)


class SobelPooling(nn.Module):
    def __init__(self, pool_size):
        super(SobelPooling, self).__init__()
        self.pool_size = pool_size
        self.sobel_filter = SobelFilter()
        
        # Initialize learnable weights for each Sobel kernel
        self.kernel_weights = nn.Parameter(torch.ones(4) / 4.0)
        
    def forward(self, x):
        sobel_output = self.sobel_filter(x)
        
        sobel_weighted  = torch.sum(sobel_output * self.kernel_weights.view(1, 4, 1, 1), dim=1, keepdim=True)
        pooling_weights = torch.exp(-1 * sobel_weighted**2)

        # normalize pooling weights within each window
        sum_weights = reduce(pooling_weights, 'b c (h h2) (w w2) -> b c h w', 'sum', h2=self.pool_size, w2=self.pool_size)
        sum_weights = repeat(sum_weights, 'b c h w -> b c (h h2) (w w2)', h2=self.pool_size, w2=self.pool_size)
        
        pooling_weights_normalized = pooling_weights / th.maximum(sum_weights, 1e-8 * th.ones_like(sum_weights))
        
        # Apply weighted pooling
        pooled_output = reduce(x * pooling_weights_normalized, 'b c (h h2) (w w2) -> b c h w', 'sum', h2=self.pool_size, w2=self.pool_size)
        
        return pooled_output

class LearnedPooling(nn.Module):
    def __init__(self, pool_size, in_channels, upscale_factor = -1):
        super(SobelPooling, self).__init__()
        self.pool_size = pool_size

        upscale_factor = pool_size if upscale_factor < 1 else upscale_factor

        self.layers = MemoryEfficientUpscaling(in_channels, 1, upscale_factor)

    def forward(self, x, depth):
        weights = self.layers(x)
        weights = reduce(weights, 'b c (h h2) (w w2) -> b c h w', 'sum', h2=self.pool_size, w2=self.pool_size)
        weights = weights / th.maximum(weights, 1e-8 * th.ones_like(weights))
        weights = repeat(weights, 'b c h w -> b c (h h2) (w w2)', h2=self.pool_size, w2=self.pool_size)
        
        return reduce(depth * weights, 'b c (h h2) (w w2) -> b c h w', 'sum', h2=self.pool_size, w2=self.pool_size)

class Gaus3D(nn.Module):
    def __init__(self, size = None, position_limit = 1):
        super(Gaus3D, self).__init__()
        self.size = size
        self.position_limit = position_limit
        self.min_std = 0.1
        self.max_std = 0.5

        self.register_buffer("grid_x", th.zeros(1,1,1,1), persistent=False)
        self.register_buffer("grid_y", th.zeros(1,1,1,1), persistent=False)

        if size is not None:
            self.min_std = 1.0 / min(size)
            self.update_grid(size)

        print(f"Gaus2D: min std: {self.min_std}")

    def update_grid(self, size):

        if size != self.grid_x.shape[2:]:
            self.size    = size
            self.min_std = 1.0 / min(size)
            H, W = size

            self.grid_x = th.arange(W, device=self.grid_x.device)
            self.grid_y = th.arange(H, device=self.grid_x.device)

            self.grid_x = (self.grid_x / (W-1)) * 2 - 1
            self.grid_y = (self.grid_y / (H-1)) * 2 - 1

            self.grid_x = self.grid_x.view(1, 1, 1, -1).expand(1, 1, H, W).clone()
            self.grid_y = self.grid_y.view(1, 1, -1, 1).expand(1, 1, H, W).clone()

    def forward(self, input: th.Tensor, depth: th.Tensor):
        assert input.shape[1] == 5
        H, W = self.size

        x      = rearrange(input[:,0:1], 'b c -> b c 1 1')
        y      = rearrange(input[:,1:2], 'b c -> b c 1 1')
        z      = rearrange(input[:,2:3], 'b c -> b c 1 1')
        std_xy = rearrange(input[:,3:4], 'b c -> b c 1 1')
        std_z  = rearrange(input[:,4:5], 'b c -> b c 1 1')

        x      = th.clip(x, -self.position_limit, self.position_limit)
        y      = th.clip(y, -self.position_limit, self.position_limit)
        std_xy = th.clip(std, self.min_std, self.max_std)
        std_z  = th.clip(std, self.min_std, self.max_std)
            
        std_y = std.clone()
        std_x = std * (H / W)

        return th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2) + (depth - z)**2/(2 * std_z**2)))

class PositionAttention(nn.Module):
    def __init__(self, size, in_channels, out_channels):
        super(PositionAttention, self).__init__()
        self.gaus3d = Gaus3D(size)

        # TODO learn std_z here ????
        #TODO yess !!! and initialize to 0.5 (use sigmoid)
        # TODO + squeeze and exite (hyper net like learning of std)

    def forward(self, feature_maps, depth_map, position):
        mask = self.gaus3d(position, depth_map)
        return mask * feature_maps  # unsqueeze to match the channel dimension

class PositionAttention(nn.Module):
    def __init__(self, size, in_channels, out_channels, compute_std=True):
        super(PositionAttention, self).__init__()
        self.compute_std = compute_std
        self.gaus2d = Gaus2D(size)
        
    def forward(self, feature_maps, position):
        mask = self.gaus2d(position, compute_std=self.compute_std)
        mask = mask / (reduce(mask, 'b c h w -> b 1 1 1', 'sum') + 1e-8)

        x = mask * feature_maps

class PositionSoftmax(nn.Module):
    def __init__(self):
        super(PositionSoftmax, self).__init__()
		
		self.scale = nn.Parameter(th.ones(1))

	def bhattacharyya_distance(mu1, mu2, Sigma1, Sigma2):
		det_Sigma1 = th.det(Sigma1)
		det_Sigma2 = th.det(Sigma2)
		Sigma_avg = (Sigma1 + Sigma2) / 2
		det_Sigma_avg = th.det(Sigma_avg)
		
		diff = mu1 - mu2
		inv_Sigma_avg = th.inverse(Sigma_avg)
		matmul_result = th.matmul(inv_Sigma_avg, diff.unsqueeze(-1)).squeeze(-1)
		
		term1 = 0.125 * th.einsum('bijk,bijk->bij', diff, matmul_result)
		term2 = 0.5 * th.log(det_Sigma_avg / th.sqrt(det_Sigma1 * det_Sigma2))
		
		distance = term1 + term2
		return distance

	def distance_weights(positions):
		xyz = positions[:, :, :3]
		sigma_xy = positions[:, :, 3:4]
		sigma_z = positions[:, :, 4:5]
		cov = th.diag_embed(th.cat((sigma_xy, sigma_xy, sigma_z), dim=2))

		# Expand dims to compute pairwise differences
		mu1 = xyz[:, :, None, :]
		mu2 = xyz[:, None, :, :]
		Sigma1 = cov[:, :, None, :, :]
		Sigma2 = cov[:, None, :, :, :]

		# Compute Bhattacharyya distance
		B_distances = bhattacharyya_distance(mu1, mu2, Sigma1, Sigma2)
		return B_distances

    def forward(self, positions):

        # Compute Bhattacharyya distance
        B_distances = distance_weights(positions)

        # Compute weights
        weights = 1 - trilu(th.exp(-B_distances * self.scale), diagonal=0)
        weights = reduce(weights, 'b i j -> b i', 'max')

        return positions * weights


class PositionProposal(nn.Module):
    def __init__(self, input_size, num_objects, latent_channels, num_layers):
        super(PositionProposalVit, self).__init__()

        latent_size = [input_size[0] // 16, input_size[1] // 16]

        embedd_hidden = 2 * latent_channels
        self.embedding = nn.Sequential(
            nn.Linear(2, embedd_hidden),
            nn.SiLU(),
            nn.Linear(embedd_hidden, embedd_hidden),
            nn.SiLU(),
            nn.Linear(embedd_hidden, embedd_hidden),
            nn.SiLU(),
            nn.Linear(embedd_hidden, latent_channels),
        )

        self.position_pooling = PositionPooling(latent_size, latent_channels, , compute_std=False)

        self.to_patches = nn.Sequential(
            MemoryEfficientPatchDownScale(1, latent_channels, scale_factor = 16, expand_ratio = 4),
            Rearrange('b c h w -> b (h w) c')
        )

        self.layers = nn.Sequential(
            *[AttentionLayer(latent_channels) for _ in range(num_layers)],
            AttentionSum(latent_channels),
            nn.Linear(latent_channels, 5),
        )

    def compute_embedding(self, B, H, W, device):

        grid_y, grid_x = th.meshgrid(
            th.linspace(-1, 1, H, device=device), 
            th.linspace(-1, 1, W, device=device),
            indexing='ij'
        )

        grid_x = grid_x.reshape(1, 1, H, W).clone()
        grid_y = grid_y.reshape(1, 1, H, W).clone()

        grid = rearrange(th.cat((grid_x, grid_y), dim=1), '1 c h w -> (h w) c')

        return repeat(self.embedding(grid), 'n c -> b n c', b=B)


    def forward(self, input, positions2d):
        H, W = input.shape[-2:]
        latent = self.to_patches(th.cat((input, positions2d), dim=1))
        embedding = self.compute_embedding(input.shape[0], H // 16, W // 16, input.device)

        out = self.layers(latent + embedding)
        position = out[:, :-1]
        valid    = out[:, -1:]

        return position, valid
"""
