import torch.nn as nn
import torch as th
import numpy as np
import torch.nn.functional as F
from nn.vit_v2 import TopKUncertaintyMaskedPatchEmbedding, AttentionLayer, AttentionSum, MemoryEfficientBottleneck, CrossAttentionLayer
from utils.utils import LambdaModule, Binarize, MultiArgSequential
from nn.residual import ResidualBlock, LinearResidual, SkipConnection
from nn.manifold import LinearManifold, HyperSequential
from nn.convnext_v2 import ConvNeXtBlock
from nn.manifold import HyperConvNextBlock, HyperSequential, NonHyperWrapper, LinearManifold
from nn.upscale import MemoryEfficientUpscaling
from nn.downscale import MemoryEfficientPatchDownScale
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange
import cv2
from utils.io import Timer

from typing import Union, Tuple


class ContextEncoding(nn.Module):
    """Encode a token sequence into a single binarized context vector.

    The input tokens (``b n c`` layout, as required by the skip path's
    ``repeat`` pattern) are widened from ``in_channels`` to ``out_channels``
    by a residual expansion: a channel-repeat skip plus a
    ``MemoryEfficientBottleneck``. The result is then aggregated by
    ``AttentionSum`` (presumably attention-weighted pooling over the token
    axis — confirm), refined by a residual MLP, and passed through
    ``Binarize``.

    Args:
        in_channels: channel count of the input tokens (last axis).
        out_channels: channel count of the produced context. Must be a
            positive multiple of ``in_channels``, because the skip path
            repeats every input channel ``out_channels // in_channels`` times.

    Raises:
        ValueError: if ``in_channels`` is not positive or ``out_channels`` is
            not a positive multiple of ``in_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super(ContextEncoding, self).__init__()

        # The skip path can only repeat channels by an integer factor. Any
        # other ratio would surface much later inside forward(): as a shape
        # mismatch, as a zero-width skip tensor when the factor rounds down
        # to 0, or as a ZeroDivisionError when in_channels == 0.
        if in_channels <= 0 or out_channels < in_channels or out_channels % in_channels != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be a positive multiple of "
                f"in_channels ({in_channels})"
            )
        repeat_factor = out_channels // in_channels

        self.skip      = LambdaModule(lambda x: repeat(x, 'b n c -> b n (c c2)', c2 = repeat_factor))
        self.expand    = MemoryEfficientBottleneck(in_channels, out_channels)
        self.aggregate = AttentionSum(out_channels)
        self.process   = nn.Sequential(nn.Linear(out_channels, out_channels*4), nn.SiLU(), nn.Linear(out_channels*4, out_channels))
        self.binarize  = Binarize()

    def forward(self, x):
        """Encode tokens ``x`` (``b n in_channels``) into a binarized context.

        Returns:
            The output of ``Binarize`` applied to the aggregated and
            MLP-refined ``out_channels``-wide features.
        """
        # Residual widening: repeated input channels + learned expansion.
        x = self.skip(x) + self.expand(x)
        x = self.aggregate(x)
        # Residual MLP refinement before binarization.
        x = self.binarize(x + self.process(x))
        return x

class AffineGridGenerator(nn.Module):
    """Turn residual affine parameters into a dense sampling grid.

    The parameters are treated as an offset from the identity transform
    ``[[1, 0, 0], [0, 1, 0]]``. The resulting 2x3 matrices are passed to
    ``F.affine_grid`` and the grid is returned channel-first.
    """
    def __init__(self):
        super(AffineGridGenerator, self).__init__()

    def forward(self, params, size):
        """Build a sampling grid from residual affine parameters.

        Args:
            params: residual affine parameters, either ``(B, 2, 3)`` or the
                same values flattened row-major as ``(B, 6)``.
            size: target size ``(B, C, H, W)`` as accepted by
                ``F.affine_grid``.

        Returns:
            Grid of shape ``(B, 2, H, W)``. Channel 0 holds x and channel 1
            holds y, in normalized ``[-1, 1]`` coordinates. The grid has the
            same dtype and device as ``params``.
        """
        # Accept both (B, 2, 3) and the flat (B, 6) layout. The reshape is a
        # no-op for the former.
        residual = params.reshape(params.shape[0], 2, 3)

        # Identity + residual, built out of place in params' dtype/device so
        # that float64/float16 inputs are not silently cast to float32.
        identity = th.eye(2, 3, dtype=params.dtype, device=params.device)
        theta = identity + residual

        # align_corners=False is PyTorch's default. Passing it explicitly
        # keeps the result identical and avoids the default-behaviour warning.
        grid = F.affine_grid(theta, size, align_corners=False)

        # (B, H, W, 2) -> (B, 2, H, W)
        return grid.permute(0, 3, 1, 2)

class MotionEncoder(nn.Module):
    """Recurrent encoder that predicts a per-sample affine motion and embeds it.

    Pipeline, as visible in ``forward``:
      1. The recurrent ``hidden`` slot attends over the normalized input tokens.
      2. The result is mixed with a mean-pooled skip of the input, with a
         learned sigmoid gate ``alpha``.
      3. An MLP and the ``EpropGateL0rd`` recurrent gate update ``hidden``.
      4. ``compute_theta`` predicts 6 affine parameters, zeroed where
         ``|delta_t| <= 0.5``.
      5. The parameters become a sampling grid, and every grid position
         (2 coordinates) is embedded into ``latent_channels`` features.

    NOTE(review): this class cannot be constructed as written.
    ``num_hidden``, ``out_channels``, ``head_size`` and ``dropout`` are
    neither parameters nor defined in the visible source, and
    ``EpropGateL0rd`` is not imported here. Presumably
    ``num_hidden``/``out_channels`` should be ``latent_channels`` —
    TODO confirm and wire up.

    Args:
        latent_channels: feature size of the produced motion embedding.
        batch_size: batch size used to allocate the recurrent ``hidden`` buffer.
        reg_lambda: forwarded to ``EpropGateL0rd`` (presumably a
            regularization weight — confirm).
    """
    def __init__(self, latent_channels, batch_size, reg_lambda = 0):
        super(MotionEncoder, self).__init__()

        # Per-position MLP: 2 grid coordinates -> latent_channels features,
        # applied to every (b, h, w) location independently.
        embedd_hidden = 2 * latent_channels
        self.motion_embedding = nn.Sequential(
            Rearrange('b c h w -> (b h w) c'),
            nn.Linear(2, embedd_hidden),
            nn.SiLU(),
            nn.Linear(embedd_hidden, embedd_hidden),
            nn.SiLU(),
            nn.Linear(embedd_hidden, embedd_hidden),
            nn.SiLU(),
            nn.Linear(embedd_hidden, latent_channels)
        )

        # NOTE(review): num_hidden / out_channels are undefined in this scope.
        self.norm1 = nn.LayerNorm(num_hidden)
        # Mixing weight between mean-pooled skip and attention output
        # (sigmoid(0) = 0.5 at init).
        self.alpha = nn.Parameter(th.zeros(1))

        self.norm2 = nn.LayerNorm(num_hidden)
        self.mlp   = nn.Sequential(nn.Linear(out_channels, out_channels*4), nn.SiLU(), nn.Linear(out_channels*4, out_channels))

        self.attention = nn.MultiheadAttention(
            num_hidden, 
            # NOTE(review): min(1, ...) gives at most one head, and 0 heads
            # (unusable by MultiheadAttention) when num_hidden < head_size;
            # max(...) was probably intended. head_size and dropout are also
            # undefined here — confirm.
            min(1, num_hidden // head_size), 
            dropout = dropout, 
            batch_first = True
        )

        # NOTE(review): EpropGateL0rd is not imported in the visible source.
        self.l0rd = EpropGateL0rd(num_hidden, num_hidden, num_hidden, reg_lambda)

        # Recurrent state, one slot per sample. Non-persistent, so it is
        # excluded from state_dict.
        self.register_buffer('hidden', th.zeros(batch_size, 1, num_hidden), persistent = False)
        # Learned initial state copied into `hidden` by reset_state().
        self.t0_hidden = nn.Parameter(th.randn(1, 1, num_hidden))
            
        # Predicts 6 values per sample, i.e. a flattened 2x3 affine residual.
        self.compute_theta = nn.Sequential(
            nn.Linear(num_hidden, num_hidden*4),
            nn.SiLU(),
            nn.Linear(num_hidden*4, 6)
        )

        self.grid_generator = AffineGridGenerator()

    def reset_state(self):
        """Reset `hidden` to the learned `t0_hidden`, repeated over the batch."""
        # Derived from a Parameter, so gradients flow back into t0_hidden.
        self.hidden = repeat(self.t0_hidden, '1 1 c -> b 1 c', b = self.hidden.shape[0])

    def detach(self):
        """Detach `hidden` from the autograd graph (presumably for truncated BPTT)."""
        self.hidden = self.hidden.detach()

    def forward(self, x, delta_t):
        """Advance the recurrent state and return the motion embedding.

        Returns:
            ``(motion_embedding, theta, openings)``, where ``openings`` is the
            Python scalar obtained via ``.item()`` on ``self.l0rd.openings``.

        NOTE(review): ``x`` is unpacked as 4-D ``(B, C, H, W)`` here, but
        ``norm1`` (LayerNorm over the last axis), the ``'b s c -> b c'``
        reduce and the attention call all treat it as 3-D ``(b, s, c)``
        tokens. These are inconsistent — confirm the intended input layout.
        """
        B, C, H, W = x.shape
            
        norm_x   = self.norm1(x)
        # Mean-pooled skip over the token axis.
        skip     = reduce(x, 'b s c -> b c', 'mean')
        # The hidden slot (b, 1, c) is the single attention query over the tokens.
        residual = rearrange(self.attention(self.hidden, norm_x, norm_x, need_weights=False)[0], 'b 1 c -> b c')
        alpha    = th.sigmoid(self.alpha)

        # NOTE(review): x is already (b, c) here, so squeeze(1) only has an
        # effect when c == 1.
        x = (skip * alpha + residual * (1 - alpha)).squeeze(1)
        x = self.mlp(self.norm2(x))

        x, self.hidden = self.l0rd(x, self.hidden)

        # Motion is zeroed for samples with |delta_t| <= 0.5. delta_t must
        # broadcast against the (…, 6) theta.
        theta = self.compute_theta(x) * (th.abs(delta_t) > 0.5).float()

        # NOTE(review): theta has last dim 6, while AffineGridGenerator adds it
        # to a (B, 2, 3) identity. x.shape[-2:] is also taken from the l0rd
        # output rather than the input and is not the (B, C, H, W) size that
        # F.affine_grid expects — confirm.
        grid = self.grid_generator(theta, x.shape[-2:])
        motion_embedding = rearrange(self.motion_embedding(grid), '(b h w) c -> b c h w', h = H, w = W)

        return motion_embedding, theta, self.l0rd.openings.item()

class DepthContextDecoder(nn.Module):
    """Decode a depth map from a depth context vector and a motion embedding.

    The context vector is broadcast over the spatial grid of the motion
    embedding and added to it. The sum runs through a ConvNeXt stack and is
    upscaled 16x into a single output channel.

    Args:
        depth_context_size: channel count of the depth context; also the
            input width of the first ConvNeXt block.
        motion_context_size: accepted for interface symmetry; unused here.
        channels: working width of the ConvNeXt stack.
        num_layers: number of ConvNeXt blocks (the first one changes width).
    """
    def __init__(self, depth_context_size, motion_context_size, channels, num_layers):
        super().__init__()

        # The first block maps depth_context_size -> channels; the rest keep the width.
        blocks = [ConvNeXtBlock(depth_context_size, channels)]
        blocks.extend(ConvNeXtBlock(channels) for _ in range(num_layers - 1))
        self.layers = nn.Sequential(*blocks)

        self.to_patches = MemoryEfficientUpscaling(channels, 1, scale_factor = 16)

    def forward(self, depth_context, motion_embedding):
        """Return the decoded depth map for ``depth_context`` under ``motion_embedding``."""
        # Append two singleton spatial axes so the vector broadcasts over (h, w).
        spatial_context = depth_context[..., None, None]
        features = self.layers(spatial_context + motion_embedding)
        return self.to_patches(features)

class RGBContextDecoder(nn.Module):
    """Decode RGB patches from depth and motion, attending to encoded RGB patches.

    The depth map is downscaled 16x into ``context_size`` channels and added
    to the motion embedding. The sum is lifted to ``channels`` by a ConvNeXt
    block and flattened into tokens. The tokens cross-attend to
    ``rgb_patches``, are folded back to the motion embedding's spatial grid,
    refined by ConvNeXt blocks, upscaled 16x to 3 channels, and squashed
    with a sigmoid.

    Args:
        context_size: channel count of the depth embedding / motion embedding.
        channels: working width of the attention and ConvNeXt stages.
        cross_attention_layers: number of ``CrossAttentionLayer`` blocks.
        convnext_layers: number of ConvNeXt refinement blocks.
    """
    def __init__(self, context_size, channels, cross_attention_layers, convnext_layers):
        super().__init__()

        attention_blocks = [CrossAttentionLayer(channels) for _ in range(cross_attention_layers)]
        self.cross_attention_layers = nn.Sequential(*attention_blocks)

        refinement_blocks = [ConvNeXtBlock(channels) for _ in range(convnext_layers)]
        self.convnext_layers = nn.Sequential(*refinement_blocks)

        upscale = MemoryEfficientUpscaling(channels, 3, scale_factor = 16)
        self.to_patches = nn.Sequential(upscale, nn.Sigmoid())

        self.depth_embedding = MemoryEfficientPatchDownScale(1, context_size, scale_factor = 16)

        lift = ConvNeXtBlock(context_size, channels)
        self.preprocess = nn.Sequential(lift, Rearrange('b c h w -> b (h w) c'))

    def forward(self, rgb_patches, depth, motion_embedding):
        """Return sigmoid-activated RGB output for ``depth`` under ``motion_embedding``."""
        height, width = motion_embedding.shape[-2:]

        tokens = self.depth_embedding(depth) + motion_embedding
        tokens = self.preprocess(tokens)

        # Each layer takes (queries, context); rgb_patches is the shared context.
        for cross_attention in self.cross_attention_layers:
            tokens = cross_attention(tokens, rgb_patches)

        feature_map = rearrange(tokens, 'b (h w) c -> b c h w', h = height, w = width)
        return self.to_patches(self.convnext_layers(feature_map))

class UncertantyBackground(nn.Module):
    """Uncertainty-masked background model that predicts RGB and depth.

    ``source`` is patch-embedded together with ``source_uncertainty`` by
    ``TopKUncertaintyMaskedPatchEmbedding``, then processed by a shared
    attention trunk (``base_encoder``). The trunk output feeds two branches:
    an RGB-patch branch (``rgb_encoder``) and a depth branch ending in a
    ``ContextEncoding`` (``depth_encoder``). ``motion_encoder`` consumes
    source, target, both uncertainty maps and ``delta_t``, and returns a
    motion embedding that conditions ``depth_decoder`` and ``rgb_decoder``.

    NOTE(review): this class does not construct as written.
      * ``batch_size`` and ``reg_lambda`` in the first ``self.motion_encoder``
        assignment are not parameters of ``__init__`` and are not defined in
        the visible source.
      * That first assignment is immediately overwritten by
        ``MotionContextEncoder(...)``, which is not imported here.
      * Wrapping ``MotionEncoder`` in ``nn.Sequential`` could not pass the
        extra ``delta_t`` argument its ``forward`` requires.
    ``self.uncertainty_estimation`` is built but never used by ``forward``.

    Args:
        masking_ratio: forwarded to ``TopKUncertaintyMaskedPatchEmbedding``.
        uncertainty_noise_ratio: forwarded to ``TopKUncertaintyMaskedPatchEmbedding``.
        uncertainty_threshold: forwarded to the motion context encoder.
        motion_context_size: context size of the motion encoder and RGB decoder.
        depth_context_size: output width of the depth ``ContextEncoding``.
        latent_channels: token/feature width used throughout.
        num_layers: base depth; the other stages derive their depth from it.
        hyper_channels: forwarded to the motion context encoder.
        hyper_layers: forwarded to the motion context encoder.
        depth_input: if True the inputs have 4 channels instead of 3
            (presumably RGB + depth — confirm).
        loci: selects the output tuple returned by ``forward``.
    """
    def __init__(
        self, 
        masking_ratio,
        uncertainty_noise_ratio,
        uncertainty_threshold,
        motion_context_size,
        depth_context_size,
        latent_channels,
        num_layers,
        hyper_channels,
        hyper_layers,
        depth_input = False,
        loci = False,
    ):
        super(UncertantyBackground, self).__init__()
        # Learned scalar, broadcast to the output size when loci=True.
        self.priority = nn.Parameter(th.ones(1, 1, 1, 1))
        self.loci = loci

        # Shared trunk: masked patch embedding followed by num_layers attention layers.
        self.base_encoder = MultiArgSequential(
            TopKUncertaintyMaskedPatchEmbedding(
                input_channels          = 4 if depth_input else 3,
                latent_channels         = latent_channels,
                masking_ratio           = masking_ratio,
                uncertainty_noise_ratio = uncertainty_noise_ratio,
            ),
            *[AttentionLayer(latent_channels) for _ in range(num_layers)]
        )

        self.rgb_encoder   = nn.Sequential(
            *[AttentionLayer(latent_channels) for _ in range(num_layers//2)],
        )
        self.depth_encoder = nn.Sequential(
            *[AttentionLayer(latent_channels) for _ in range(num_layers//2)],
            ContextEncoding(latent_channels, depth_context_size),
        )
        # NOTE(review): batch_size / reg_lambda are undefined here, and this
        # module is overwritten by the assignment just below.
        self.motion_encoder = nn.Sequential(
            *[AttentionLayer(latent_channels) for _ in range(num_layers//2)],
            MotionEncoder(latent_channels, batch_size, reg_lambda),
        )

        # NOTE(review): MotionContextEncoder is not imported in the visible source.
        self.motion_encoder = MotionContextEncoder(
            context_size            = motion_context_size,
            in_channels             = 4 if depth_input else 3,
            latent_channels         = latent_channels,
            num_layers              = num_layers + num_layers // 2,
            hyper_channels          = hyper_channels,
            hyper_layers            = hyper_layers,
            uncertainty_threshold   = uncertainty_threshold,
        )

        self.depth_decoder = DepthContextDecoder(
            depth_context_size  = depth_context_size,
            motion_context_size = motion_context_size,
            channels            = latent_channels,
            num_layers          = num_layers,
        )

        self.rgb_decoder = RGBContextDecoder(
            context_size           = motion_context_size,
            channels               = latent_channels,
            cross_attention_layers = num_layers,
            convnext_layers        = num_layers // 2,
        )

        # Not used by forward() in this file.
        # Returns (p, p + p * (1 - p) * N(0, 1)): the sigmoid map and a noisy
        # copy whose noise scale p * (1 - p) vanishes at p = 0 and p = 1.
        self.uncertainty_estimation = nn.Sequential(
            MemoryEfficientPatchDownScale(4 if depth_input else 3, latent_channels, scale_factor = 16),
            *[ConvNeXtBlock(latent_channels) for _ in range((num_layers + num_layers // 2) * 2)],
            MemoryEfficientUpscaling(latent_channels, 1, scale_factor = 16),
            nn.Sigmoid(),
            LambdaModule(lambda x: (x, x + x * (1 - x) * th.randn_like(x))),
        )

    def get_last_layer(self):
        """Return None (no designated last layer for this model)."""
        return None

    def forward(self, source, target, source_uncertainty, target_uncertainty, delta_t):
        """Encode ``source`` and decode RGB and depth under the predicted motion.

        ``target`` and ``target_uncertainty`` are used only by the motion encoder.

        Returns:
            If ``self.loci`` is False:
                ``(rgb, depth, motion_context, depth_context, gate_l0rd_openings)``.
            Otherwise:
                ``(priority, rgb, depth)``.
        """

        latent = self.base_encoder(source, source_uncertainty)
        rgb_patches, depth_context = self.rgb_encoder(latent), self.depth_encoder(latent)

        motion_embedding, motion_context, gate_l0rd_openings = self.motion_encoder(source, target, source_uncertainty, target_uncertainty, delta_t)

        # Depth is decoded first; the RGB decoder is conditioned on it.
        depth = self.depth_decoder(depth_context, motion_embedding)
        rgb   = self.rgb_decoder(rgb_patches, depth, motion_embedding)

        if not self.loci:
            return rgb, depth, motion_context, depth_context, gate_l0rd_openings

        # NOTE(review): `uncertainty` is undefined in this scope, so the loci
        # path raises NameError. Presumably source_uncertainty or
        # target_uncertainty was meant — confirm.
        priority = self.priority.expand(*uncertainty.shape)

        return priority, rgb, depth

