import torch as th
import torchvision as tv
from torch import nn
from utils.utils import BatchToSharedObjects, SharedObjectsToBatch, LambdaModule, MultiArgSequential, Binarize, PrintShape, Prioritize
from einops import rearrange, repeat, reduce
from utils.optimizers import Ranger
from nn.convnext import MemoryEfficientBottleneck
import torch.nn.functional as F
from utils.loss import MaskedL1SSIMLoss
from nn.downscale import MemoryEfficientPatchDownScale
from nn.eprop_gate_l0rd import ReTanh
from nn.mask_decoder import MaskPretrainer, PositionPooling, Gaus2D, MaskCenter
from nn.residual import SkipConnection
from scipy.optimize import linear_sum_assignment
from nn.upscale import MemoryEfficientUpscaling

class ConvNeXtBlock(nn.Module):
    """Residual block: depthwise 7x7 conv -> GroupNorm -> bottleneck, plus a skip path.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced; defaults to ``in_channels``.
        channels_per_group: target number of channels per GroupNorm group
            (capped at ``in_channels``).
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int = None,
            channels_per_group = 32,
        ):
        super().__init__()

        out_channels = in_channels if out_channels is None else out_channels

        # Never ask for wider groups than there are channels.
        group_width = min(channels_per_group, in_channels)
        num_groups  = in_channels // group_width

        # groups=in_channels makes the 7x7 convolution depthwise.
        depthwise  = nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
        norm       = nn.GroupNorm(num_groups, in_channels)
        bottleneck = MemoryEfficientBottleneck(in_channels, out_channels)

        self.layers = nn.Sequential(depthwise, norm, bottleneck)

        # Project-local skip path mapping in_channels -> out_channels.
        self.skip = SkipConnection(in_channels, out_channels)

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return the sum of the convolutional branch and the skip branch."""
        transformed = self.layers(input)
        shortcut    = self.skip(input)
        return transformed + shortcut

class PixelToPosition(nn.Module):
    """Convert a single-channel score map into a soft-argmax (x, y) position.

    The map is softmax-normalised over all pixels and used as weights for a
    fixed coordinate grid whose rows (y) and columns (x) span [-1, 1].

    Args:
        size: ``(height, width)`` of the score maps passed to ``forward``;
            ``size[0]`` indexes tensor dim 2 and ``size[1]`` tensor dim 3.
    """

    def __init__(self, size): # FIXME add update grid !!!
        super(PixelToPosition, self).__init__()

        # Per-axis coordinates broadcast to a full (1, 1, H, W) grid. The
        # clone() materialises the expanded view so the buffer owns its memory.
        grid_y = self._axis_coordinates(size[0]).view(1, 1, -1, 1).expand(1, 1, *size).clone()
        grid_x = self._axis_coordinates(size[1]).view(1, 1, 1, -1).expand(1, 1, *size).clone()

        # Non-persistent: the grids are derived from `size`, so they are not
        # written to the state dict.
        self.register_buffer("grid_y", grid_y, persistent=False)
        self.register_buffer("grid_x", grid_x, persistent=False)

        self.size = size

    @staticmethod
    def _axis_coordinates(length):
        """Return `length` evenly spaced coordinates covering [-1, 1]."""
        if length == 1:
            # A single pixel has no extent; (0 / 0) would yield NaN, so place
            # it at the centre of the coordinate range instead.
            return th.zeros(1)
        return (th.arange(length) / (length - 1)) * 2 - 1

    def forward(self, input: th.Tensor):
        """Return the expected (x, y) coordinate of each map.

        Args:
            input: tensor of shape (b, 1, h, w) holding unnormalised scores.

        Returns:
            Tensor of shape (b, 2) with x in column 0 and y in column 1.
        """
        assert input.shape[1] == 1

        # Softmax over all pixels of each map, then back to the 2D layout.
        input = rearrange(input, 'b c h w -> b c (h w)')
        input = th.softmax(input, dim=2)
        input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1])

        # Probability-weighted mean of the coordinate grids.
        x = th.sum(input * self.grid_x, dim=(2,3))
        y = th.sum(input * self.grid_y, dim=(2,3))

        return th.cat((x, y), dim=1)

class PixelToSTD(nn.Module):
    """Reduce a single-channel map to one non-negative scalar per sample.

    The value is the spatial mean of ``relu(tanh(input))`` and therefore lies
    in [0, 1).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: th.Tensor):
        """Map a (b, 1, h, w) tensor to a (b, 1) tensor."""
        assert input.shape[1] == 1

        # tanh bounds the activation, relu discards the negative half.
        squashed = th.tanh(input).relu()
        return reduce(squashed, 'b c h w -> b c', 'mean')

class PixelToPriority(nn.Module):
    """Derive a per-slot priority from single-channel maps plus a slot-index offset.

    The priority of a slot is ``tanh(mean(map)) * priority_factor +
    index_factor * index``, where ``index`` is the slot number rescaled to
    [-1, 1].

    Args:
        num_slots: number of slots folded into the batch dimension of the
            input to ``forward``.
    """

    def __init__(self, num_slots):
        super(PixelToPriority, self).__init__()

        self.num_slots = num_slots

        if num_slots == 1:
            # (0 / 0) would produce NaN; with one slot the offset is
            # irrelevant, so use the centre of the range.
            indices = th.zeros(1, 1)
        else:
            # Slot indices rescaled to [-1, 1], shaped (1, num_slots).
            indices = ((th.arange(num_slots) / (num_slots - 1)) * 2 - 1).unsqueeze(0)

        # Non-persistent: fully determined by num_slots.
        self.register_buffer("indices", indices, persistent=False)

        self.index_factor    = nn.Parameter(th.ones(1))
        # Starts at ~0 so the initial priority is given by the slot index alone.
        self.priority_factor = nn.Parameter(th.zeros(1)+1e-16)

    def forward(self, input: th.Tensor):
        """Return priorities of shape (b, num_slots, 1).

        Args:
            input: tensor of shape (b * num_slots, 1, h, w).
        """
        assert input.shape[1] == 1
        priority = th.tanh(reduce(input, '(b n) 1 h w -> b n', 'mean', n = self.num_slots))
        priority = priority * self.priority_factor + self.index_factor * self.indices
        priority = rearrange(priority, 'b n -> b n 1')
        return priority

class PixelAttention(nn.Module):
    """Self-attention over all pixels of a feature map, followed by a per-pixel MLP.

    Both residual branches are scaled by learnable scalars initialised to
    1e-12, so at initialisation the attention and MLP branches contribute
    almost nothing to the output.

    Args:
        channels: number of feature channels (also the attention embedding size).
        heads: number of attention heads; also used as the GroupNorm group count.
        dropout: dropout probability passed to ``nn.MultiheadAttention``.
    """

    def __init__(
        self,
        channels,
        heads,
        dropout = 0.0
    ):
        super().__init__()

        # Attention branch.
        self.norm1     = nn.GroupNorm(heads, channels)
        self.alpha1    = nn.Parameter(th.zeros(1)+1e-12)
        self.attention = nn.MultiheadAttention(
            channels,
            heads,
            dropout = dropout,
            batch_first = True
        )

        # Per-pixel feed-forward branch.
        self.norm2     = nn.GroupNorm(heads, channels)
        self.alpha2    = nn.Parameter(th.zeros(1)+1e-12)
        self.layers    = nn.Sequential(
            nn.Linear(channels, channels),
            nn.SiLU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x: th.Tensor):
        """Apply attention and MLP to a (b, c, h, w) tensor; the shape is preserved."""
        height, width = x.shape[2:]

        # Every pixel becomes one token of the sequence. Note that the
        # residual is taken from the normalised tensor.
        tokens = rearrange(self.norm1(x), 'b c h w -> b (h w) c')
        attended = self.attention(tokens, tokens, tokens, need_weights=False)[0]
        tokens = tokens + self.alpha1 * attended

        # Fold the pixels into the batch so norm and MLP act on each pixel alone.
        pixels = rearrange(tokens, 'b (h w) c -> (b h w) c', h=height, w=width)
        pixels = self.norm2(pixels)
        pixels = pixels + self.alpha2 * self.layers(pixels)

        return rearrange(pixels, '(b h w) c -> b c h w', h=height, w=width)
        

class SlotAttentionLayer(nn.Module):
    """Residual conv block whose output is gated by a per-pixel soft slot assignment.

    The channels are treated as ``num_slots`` consecutive groups of
    ``channels // num_slots`` channels; each group is multiplied by the gate
    value of its slot.

    Args:
        channels: number of feature channels.
        num_slots: number of slots competing for each pixel.
        channels_per_group: target GroupNorm group width (capped at ``channels``).
    """

    def __init__(self, channels, num_slots, channels_per_group = 32):
        super().__init__()
        group_width = min(channels_per_group, channels)
        num_groups  = channels // group_width
        slot_width  = channels // num_slots

        def drop_last_channel(x):
            # The softmax runs over num_slots + 1 channels; only the first
            # num_slots are kept as gates.
            # TODO regularize the sum of this towards the sum of the gaus2d positions!!!
            return x[:, :-1]

        def spread_over_slot_channels(x):
            # Repeat each slot's gate across that slot's block of channels.
            return repeat(x, 'b n h w -> b (n c) h w', n=num_slots, c=slot_width)

        self.gate = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels),
            nn.GroupNorm(num_groups, channels),
            MemoryEfficientBottleneck(channels, num_slots + 1),
            nn.Softmax(dim=1),
            LambdaModule(drop_last_channel),
            LambdaModule(spread_over_slot_channels),
        )
        self.layers = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=7, padding=3, groups=channels),
            nn.GroupNorm(num_groups, channels),
            MemoryEfficientBottleneck(channels, channels),
        )

        # Learnable per-channel output scale.
        self.alpha = nn.Parameter(th.ones(1, channels, 1, 1))

    def forward(self, input):
        """Return ``(input + layers(input)) * gate(input) * alpha``."""
        features = input + self.layers(input)
        gate     = self.gate(input)
        return features * gate * self.alpha

class UnceretaintyEstimator(nn.Module):
    """Predict per-slot position, gestalt code and priority from an image.

    A shared encoder downsamples the input by a factor of 16. From these
    features a position head yields (x, y, std) and a priority per slot,
    while a gestalt head downsamples once more and pools one feature vector
    per slot at the predicted position.

    Args:
        in_channels: channels of the input image.
        num_slots: number of slots; ``base_channels * 4`` must be divisible by it.
        size: input size, indexed as ``size[0]`` / ``size[1]``.
        base_channels: channel width of the first encoder stage.
        blocks: number of blocks per stage (four entries are read).
        gestalt_size: output size handed to ``PositionPooling``.
        reg_lambda: accepted but not used inside this class.
    """

    def __init__(self, in_channels, num_slots, size, base_channels = 16, blocks=[1,2,3,1], gestalt_size = 256, reg_lambda=0):
        super().__init__()
        assert (base_channels * 4) % num_slots == 0, f"base_channels * 4 ({base_channels * 4}) must be divisible by num_slots ({num_slots})"

        self.num_slots = num_slots
        latent_size = [size[1] // 32, size[0] // 32]

        width_x1 = base_channels
        width_x2 = base_channels * 2
        width_x4 = base_channels * 4
        width_x8 = base_channels * 8

        def attention_stage(channels):
            # One pixel-attention layer followed by one slot-gated conv layer.
            return nn.Sequential(
                PixelAttention(channels, 1),
                SlotAttentionLayer(channels, num_slots)
            )

        def pick_channel(index):
            # Keep a single channel while preserving the channel dimension.
            return LambdaModule(lambda x: x[:, index:index + 1])

        def unfold_slots():
            # Move the slots out of the batch dimension again.
            return LambdaModule(lambda x: rearrange(x, '(b n) c -> b n c', n=num_slots))

        # Shared encoder: total downscaling 4 * 2 * 2 = 16.
        shared_stages = [MemoryEfficientPatchDownScale(in_channels, width_x1, expand_ratio=16, scale_factor=4)]
        shared_stages += [ConvNeXtBlock(width_x1) for _ in range(blocks[0])]
        shared_stages.append(MemoryEfficientPatchDownScale(width_x1, width_x2, expand_ratio=16, scale_factor=2))
        shared_stages += [ConvNeXtBlock(width_x2) for _ in range(blocks[1])]
        shared_stages.append(MemoryEfficientPatchDownScale(width_x2, width_x4, expand_ratio=16, scale_factor=2))
        shared_stages += [attention_stage(width_x4) for _ in range(blocks[2])]
        self.base_encoder = nn.Sequential(*shared_stages)

        # Position head: split the channels into slots, then produce three
        # maps per slot (read below as xy / std / priority).
        position_stages = [
            LambdaModule(lambda x: rearrange(x, 'b (n c) h w -> (b n) c h w', n=num_slots)),
            ConvNeXtBlock(width_x4 // num_slots, width_x1),
        ]
        position_stages += [ConvNeXtBlock(width_x1) for _ in range(blocks[3]-1)]
        position_stages.append(ConvNeXtBlock(width_x1, 3))
        self.position_encoder = nn.Sequential(*position_stages)

        self.xy_encoder = nn.Sequential(
            pick_channel(0),
            PixelToPosition([size[0] // 16, size[1] // 16]),
            unfold_slots(),
        )
        self.std_encoder = nn.Sequential(
            pick_channel(1),
            PixelToSTD(),
            unfold_slots(),
        )
        self.priority_encoder = nn.Sequential(
            pick_channel(2),
            PixelToPriority(num_slots),
        )

        # Gestalt head: one further downscaling step (total factor 32).
        gestalt_stages = [attention_stage(width_x4) for _ in range(blocks[3])]
        gestalt_stages.append(MemoryEfficientPatchDownScale(width_x4, width_x8, expand_ratio=16, scale_factor=2))
        gestalt_stages += [attention_stage(width_x8) for _ in range(blocks[3]*2)]
        self.gestalt_encoder = nn.Sequential(*gestalt_stages)

        self.gestalt_pooling = MultiArgSequential(
            PositionPooling(latent_size, width_x8, gestalt_size),
            LambdaModule(lambda x: rearrange(x, '(b n) c 1 1 -> b n c', n=num_slots)),
        )

    def forward(self, x):
        """Return ``(position, gestalt, priority)`` for an image batch.

        ``position`` is (b, num_slots, 3) holding x, y and std; ``priority``
        is (b, num_slots, 1); ``gestalt`` is (b, num_slots, c).
        """
        shared = self.base_encoder(x)

        slot_maps = self.position_encoder(shared)
        xy        = self.xy_encoder(slot_maps)
        std       = self.std_encoder(slot_maps)

        position = th.cat((xy, std), dim=2)
        priority = self.priority_encoder(slot_maps)

        # Every slot pools from the same feature map at its own position.
        gestalt_maps  = self.gestalt_encoder(shared)
        per_slot_maps = repeat(gestalt_maps, 'b c h w -> (b n) c h w', n=self.num_slots)
        flat_position = rearrange(position, 'b n c -> (b n) c')
        gestalt       = self.gestalt_pooling(per_slot_maps, flat_position)

        return position, gestalt, priority


class UncertaintyPertrainer(nn.Module):
    """Pretraining wrapper that supervises ``UnceretaintyEstimator`` with ground-truth masks.

    Predicted slot positions are matched to ground-truth mask centres with
    the Hungarian algorithm. In addition, a foreground mask is decoded from
    the slot gestalt codes placed at their predicted positions and compared
    with the union of the ground-truth masks.

    Args:
        in_channels: channels of the RGB input.
        num_slots: number of slots (and of ground-truth mask channels).
        size: input size, forwarded to ``MaskCenter`` and the estimator.
        base_channels: base channel width of estimator and decoder.
        blocks: per-stage block counts forwarded to the estimator.
        gestalt_size: size of the per-slot gestalt code.
        reg_lambda: forwarded to the estimator.
    """

    def __init__(self, in_channels, num_slots, size, base_channels = 16, blocks=[1,2,3,1], gestalt_size=256, reg_lambda=0):
        super(UncertaintyPertrainer, self).__init__()
        self.num_slots = num_slots

        latent_size = [size[1] // 16, size[0] // 16]

        self.prioritize  = Prioritize(num_slots)
        self.gaus2d      = Gaus2D(latent_size)
        self.mask_center = MaskCenter(size)

        self.decoder = nn.Sequential(
            # The decoder consumes gestalt codes broadcast over the latent
            # grid, so its input width must follow gestalt_size (assuming
            # PositionPooling emits gestalt_size channels, as its constructor
            # argument suggests). With the defaults this equals the previous
            # hard-coded base_channels * 16.
            nn.Conv2d(gestalt_size, base_channels * 8, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels * 8, base_channels * 4, kernel_size=3, padding=1),
            nn.SiLU(),
            nn.Conv2d(base_channels * 4, base_channels * 4, kernel_size=3, padding=1),
            MemoryEfficientUpscaling(base_channels * 4, base_channels, 4, expand_ratio=16),
            MemoryEfficientUpscaling(base_channels,                 1, 4, expand_ratio=16),
            # Two-way softmax against a constant logit of 1; keep the first channel.
            LambdaModule(lambda x: th.softmax(th.cat((x, th.ones_like(x)), dim=1), dim=1)[:,0:1])
        )

        self.position_estimator = UnceretaintyEstimator(
            in_channels   = in_channels,
            num_slots     = num_slots,
            size          = size,
            base_channels = base_channels,
            blocks        = blocks,
            gestalt_size  = gestalt_size,
            reg_lambda    = reg_lambda
        )

    def forward(self, rgb_input, gt_masks):
        """Compute the pretraining losses for one batch.

        Args:
            rgb_input: image tensor of shape (B, C, H, W).
            gt_masks: ground-truth slot masks of shape (B, num_slots, H, W).

        Returns:
            Dict with the scalar losses ``position_loss``, ``gestalt_loss``
            and ``used_loss`` plus the tensors ``gestalt`` (b, n, c, 1, 1),
            ``used`` (b, n, 1), ``position`` (b, n, 3) and ``mask``.
        """
        B, _, H, W = rgb_input.shape
        assert gt_masks.shape == (B, self.num_slots, H, W)

        # Targets are derived from the ground truth only; no gradients needed.
        with th.no_grad():
            fg_mask  = reduce(gt_masks, 'b n h w -> b 1 h w', 'max', n=self.num_slots)
            gt_masks = rearrange(gt_masks, 'b n h w -> (b n) 1 h w', n=self.num_slots)
            gt_positions = self.mask_center(gt_masks)
            gt_positions = rearrange(gt_positions, '(b n) c -> b n c', n=self.num_slots)

            # A ground-truth slot counts as used when its mask exceeds 0.5 somewhere.
            gt_used = (reduce(gt_masks, '(b n) c h w -> b n 1', 'max', n=self.num_slots) > 0.5).float()

        positions, gestalts, priority = self.position_estimator(rgb_input)
        used = (positions[:, :, 2:3] > 0.001).float() # TODO use min_std

        # Pairwise prediction/ground-truth cost. It only drives the discrete
        # matching, so it is computed on detached predictions to avoid
        # building an unused autograd graph.
        mse_matrix = th.mean((positions.detach()[:, :, None, :] - gt_positions[:, None, :, :])**2, dim=-1)
        mse_matrix_numpy = mse_matrix.cpu().numpy()

        position_loss = []
        used_loss     = []

        for b, cost in enumerate(mse_matrix_numpy):

            # Hungarian algorithm to solve the linear sum assignment
            pred_indices, gt_indices = linear_sum_assignment(cost)

            # Used ground-truth slots supervise (x, y, std); unused ones only
            # weakly pull (x, y) towards the ground-truth centre.
            matched_loss = th.mean(
                th.sum((positions[b, pred_indices, 0:3] - gt_positions[b, gt_indices, 0:3])**2 * gt_used[b, gt_indices, 0:1], dim=1) / 3 +
                th.sum((positions[b, pred_indices, 0:2] - gt_positions[b, gt_indices, 0:2])**2 * (1 - gt_used[b, gt_indices, 0:1]) * 0.1 / self.num_slots, dim=1) / 3
            )
            position_loss.append(matched_loss)

            # Penalise a non-zero std for predictions matched to unused ground-truth slots.
            used_loss.append(th.abs(positions[b, pred_indices, 2]) * (1 - gt_used[b, gt_indices, 0]) / self.num_slots)

        # Render each slot as a 2D Gaussian and resolve overlaps by priority.
        positions   = rearrange(positions, 'b n c -> (b n) c')
        positions2d = rearrange(self.gaus2d(positions), '(b n) 1 h w -> b n h w', n=self.num_slots)
        positions2d = self.prioritize(positions2d, priority.squeeze(2))

        # Normalise only where the slot maps sum to more than one.
        positions2d_sum = th.sum(positions2d, dim=1, keepdim=True)
        positions2d     = positions2d / th.maximum(th.ones_like(positions2d_sum), positions2d_sum)

        positions2d = rearrange(positions2d, 'b n h w -> b n 1 h w', n=self.num_slots)
        gestalts    = rearrange(gestalts, 'b n c -> b n c 1 1')

        # Place every gestalt code at its slot's location and decode a foreground mask.
        mask = self.decoder(th.sum(gestalts * positions2d, dim=1))

        mask_loss = th.mean((mask - fg_mask)**2)

        return {
            'position_loss': th.stack(position_loss).mean(),
            'gestalt_loss': mask_loss,
            'used_loss': th.stack(used_loss).mean(),
            'gestalt': gestalts,
            'used': used,
            'position': rearrange(positions, '(b n) c -> b n c', n=self.num_slots),
            'mask': mask,
        }

        
