import torch.nn as nn
import torch as th
import numpy as np
import torch.nn.functional as F
from nn.vit import UncertaintyMaskedPatchEmbedding, AlphaAttention, AlphaAttentionSum, AttentionResidualLayer, PatchUpscale, EpropAlphaGateL0rd
from utils.utils import LambdaModule, Binarize, MultiArgSequential
from nn.residual import ResidualBlock, LinearResidual, SkipConnection
from nn.manifold import LinearManifold, HyperSequential
from nn.convnext import ConvNeXtStem, ConvNeXtBlock, ConvNeXtPatchUp, ConvNeXtUnet, ConvNeXtDecoder
from einops import rearrange, repeat, reduce
import cv2
from utils.io import Timer

from typing import Union, Tuple


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal 2-D positional embedding over an ``(H, W)`` grid.

    Column and row indices are mapped linearly to ``[-1, 1]`` (a size-1 axis maps
    to ``-1``) and scaled by ``pi / 2``. For every frequency index ``i`` in
    ``range(num_frequencies)`` two channels are produced, ``sin(x * 2**i)`` and
    ``sin(y * 2**i)``. The result is kept in a non-persistent buffer of shape
    ``(1, 2 * num_frequencies, H, W)``.

    Args:
        num_frequencies: number of octaves; each contributes two channels.
        size: ``(H, W)`` of the grid.
    """

    def __init__(self, num_frequencies, size):
        super(PositionalEmbedding, self).__init__()

        H, W = size

        # Row index varies along dim 0 and column index along dim 1, the same
        # layout as th.meshgrid(..., indexing='ij'). Building the grid directly
        # avoids meshgrid's indexing deprecation warning on newer PyTorch.
        grid_y = th.arange(H).view(H, 1).expand(H, W)
        grid_x = th.arange(W).view(1, W).expand(H, W)

        # Map indices to [-1, 1]. max(..., 1) prevents 0/0 = NaN for a size-1 axis.
        grid_x = (grid_x / max(W - 1, 1)) * 2 - 1
        grid_y = (grid_y / max(H - 1, 1)) * 2 - 1

        grid_x = grid_x.reshape(1, 1, H, W) * np.pi / 2
        grid_y = grid_y.reshape(1, 1, H, W) * np.pi / 2

        embedding = []

        for i in range(num_frequencies):
            embedding.append(th.sin(grid_x * 2**i))
            embedding.append(th.sin(grid_y * 2**i))

        # Not saved in state_dict: the buffer is fully determined by the arguments.
        self.register_buffer('embedding', th.cat(embedding, dim=1), persistent=False)

    def forward(self, input: th.Tensor):
        """Return the embedding repeated along a batch dim of size ``input.shape[0]``.

        Only the batch size of ``input`` is used; its contents are ignored.
        """
        return repeat(self.embedding, '1 c h w -> b c h w', b = input.shape[0])

class ViTDepthUncertantyBackground(nn.Module):
    """Background model: ViT-style encoder with RGB(-D) and depth decoders.

    The encoder embeds the input (RGB or RGB-D) together with an uncertainty map,
    then applies ``num_attention_layers`` (attention, gate) blocks to the latent
    tokens. RGB comes from a ConvNeXt decoder. Depth is produced in one of two
    ways:

    * ``rgbd_decoder=True``: taken from channel 0 of the 4-channel decoder output.
    * ``rgbd_decoder=False``: ``hyper_weights`` predicts the weights of
      ``depth_decoder0``. That network maps a fixed positional embedding to
      features, and ``depth_decoder1`` upsamples them 16x to a one-channel map.

    When ``loci`` is True, ``forward`` also returns a learned scalar priority
    broadcast to the shape of the uncertainty map.
    """

    def __init__(
        self, 
        latent_size: Tuple[int, int],
        hidden_channels,
        hyper_channels,
        num_embedding_layers,
        num_attention_layers,
        num_hyper_layers,
        uncertainty_base_channels,
        uncertainty_blocks,
        uncertainty_threshold,
        batch_size,
        num_heads,
        reg_lambda,
        depth_input,
        rgbd_decoder = False,
        loci = False,
        dropout = 0.0
    ):
        super(ViTDepthUncertantyBackground, self).__init__()
        # Learned scalar; broadcast to the uncertainty shape in forward() when loci=True.
        self.priority = nn.Parameter(th.ones(1, 1, 1, 1))
        self.loci = loci
        self.rgbd_decoder = rgbd_decoder

        H, W = latent_size
        # Batch size handed to the gates is scaled by the number of latent tokens
        # (presumably the gates keep per-token state -- TODO confirm).
        batch_size = batch_size * H * W

        # encoder[0] is the patch embedding; encoder[1:] are (attention, gate) pairs.
        self.encoder = MultiArgSequential(
            UncertaintyMaskedPatchEmbedding(
                in_channels           = 4 if depth_input else 3,
                out_channels          = hidden_channels,
                num_layers            = num_embedding_layers,
                latent_size           = latent_size,
                uncertainty_threshold = uncertainty_threshold,
            ),
            *[nn.Sequential(
                AlphaAttention(hidden_channels, num_heads, dropout),
                EpropAlphaGateL0rd(hidden_channels, batch_size, reg_lambda)
            ) for _ in range(num_attention_layers)]
        )

        # Tokens -> (b, c, H, W) map -> 1x1 conv -> ConvNeXt decoder -> [0, 1].
        # In rgbd mode the decoder emits 4 channels (depth first, then RGB).
        self.rgb_decoder0 = nn.Sequential(
            LambdaModule(lambda x: rearrange(x, 'b (h w) c -> b c h w', h = H, w = W)),
            nn.Conv2d(hidden_channels, 128, kernel_size = 1),
            ConvNeXtDecoder(
                out_channels = 4 if rgbd_decoder else 3,
                base_channels = 32,
                blocks = [1,2,3,0],
            ),
            nn.Sigmoid(),
        )

        if not rgbd_decoder:
            num_frequencies = int(np.log2(min(*latent_size)))*2
            self.embedding  = PositionalEmbedding(num_frequencies, size=latent_size)

            # Input width num_frequencies*2 matches the two channels (sin x, sin y)
            # per frequency produced by the positional embedding.
            self.depth_decoder0 = HyperSequential(
                LinearManifold(num_frequencies*2, hidden_channels, act = F.silu),
                *[LinearManifold(hidden_channels, hidden_channels, act = F.silu) for _ in range(num_hyper_layers - 1)],
            )

            # Predicts the depth_decoder0 weights from the latent tokens
            # (AlphaAttentionSum presumably pools tokens per sample -- TODO confirm).
            self.hyper_weights = nn.Sequential(
                AlphaAttentionSum(hidden_channels, num_heads, dropout),
                nn.Linear(hidden_channels, hyper_channels),
                nn.SiLU(),
                *[nn.Sequential(
                    nn.Linear(hyper_channels, hyper_channels),
                    nn.SiLU()
                ) for _ in range(num_hyper_layers-2)],
                nn.Linear(hyper_channels, self.depth_decoder0.num_weights())
            )

            # Two 4x bilinear upsamplings (16x total) down to a single [0, 1] channel.
            self.depth_decoder1 = nn.Sequential(
                nn.Conv2d(hidden_channels, hidden_channels, kernel_size = 3, padding = 1),
                nn.SiLU(),
                LambdaModule(lambda x: F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)),
                nn.Conv2d(hidden_channels, hidden_channels // 2, kernel_size = 3, padding = 1),
                nn.SiLU(),
                LambdaModule(lambda x: F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)),
                nn.Conv2d(hidden_channels // 2, 1, kernel_size = 3, padding = 1),
                nn.Sigmoid() 
            )
            

        # Returns (u, u + u*(1-u)*noise). The noise scale u*(1-u) vanishes where
        # u is 0 or 1, so confident estimates are perturbed least.
        self.uncertainty_estimation = nn.Sequential(
            ConvNeXtUnet(
                in_channels   = 4 if depth_input else 3,
                out_channels  = 1, 
                base_channels = uncertainty_base_channels,
                blocks        = uncertainty_blocks,
            ),
            nn.Sigmoid(),
            LambdaModule(lambda x: (x, x + x * (1 - x) * th.randn_like(x))),
        )

    def get_last_layer(self):
        """Return the weight of the final layer inside the RGB ConvNeXt decoder."""
        # rgb_decoder0[-1] is the Sigmoid, so [-2] is the ConvNeXtDecoder; the
        # layer3/layers attribute path depends on that project class.
        return self.rgb_decoder0[-2].layer3[-1].layers.weight

    def openings(self):
        """Return the mean gate openings over all attention layers.

        Returns 0.0 when the encoder has no attention layers.
        """
        # encoder[0] is the patch embedding; each later entry is (attention, gate).
        gates = [self.encoder[i][1] for i in range(1, len(self.encoder))]
        if not gates:
            return 0.0

        # Accumulate over every layer so the division yields a mean.
        return sum(gate.l0rd.openings.item() for gate in gates) / len(gates)

    def forward(self, input: th.Tensor, uncertainty: th.Tensor, compute_rbg = True):
        """Encode ``input`` guided by ``uncertainty`` and decode RGB and depth.

        Args:
            input: image passed to the patch embedding (3 or 4 channels depending
                on ``depth_input``).
            uncertainty: uncertainty map passed to the encoder. In loci mode its
                shape also sets the returned priority shape.
            compute_rbg: if False, RGB is skipped and returned as None. This is
                ignored in rgbd mode, where the decoder also produces depth.

        Returns:
            ``(rgb, depth)``, or ``(priority, rgb, depth)`` when ``loci`` is True.
        """
        latent = self.encoder(input, uncertainty)

        # In rgbd mode the joint decoder is the only depth source, so it must run.
        rgb = self.rgb_decoder0(latent) if compute_rbg or self.rgbd_decoder else None

        if self.rgbd_decoder:
            depth = rgb[:,:1]
            rgb   = rgb[:,1:]
        else:
            hyper_weights = self.hyper_weights(latent)

            # The positional embedding only uses the batch size of `latent`.
            depth = self.depth_decoder0(self.embedding(latent), hyper_weights)
            depth = self.depth_decoder1(depth)

        if not self.loci:
            return rgb, depth

        priority = self.priority.expand(*uncertainty.shape)

        return priority, rgb, depth


    def detach(self):
        """Call ``detach()`` on every descendant module that defines one."""
        for module in self.modules():
            # Skip self by identity; calling our own detach here would recurse forever.
            if module is not self and callable(getattr(module, "detach", None)):
                module.detach()

    def reset_state(self):
        """Call ``reset_state()`` on every descendant module that defines one."""
        for module in self.modules():
            # Skip self by identity; calling our own reset_state would recurse forever.
            if module is not self and callable(getattr(module, "reset_state", None)):
                module.reset_state()

    def pretraining(self, trainloader, device, max_steps=10000, depth_index = 3):
        """Pretrain the hypernetwork depth path with the hyper weights fixed to zero.

        Only ``depth_decoder0.base_weights`` and the ``depth_decoder1`` parameters
        are optimized, using AdamW and an L1 loss. The decoder input (positional
        embedding plus zero hyper weights) does not depend on the data, so this
        fits one data-independent depth map to the targets. It requires
        ``rgbd_decoder=False``, because the depth modules only exist in that mode.

        Args:
            trainloader: iterable of batches. ``batch[depth_index]`` is a depth
                tensor whose dim 1 is time. Frames 1..T-1 are used as targets;
                frame 0 is not used.
            device: device for the targets and the zero hyper weights.
            max_steps: training stops after the batch during which the update
                count exceeds this value.
            depth_index: position of the depth tensor inside each batch.
        """
        params    = [self.depth_decoder0.base_weights] + list(self.depth_decoder1.parameters())
        optimizer = th.optim.AdamW(params, lr=0.001, weight_decay=0.01)

        num_updates = 0
        timer       = Timer()
        # Exponential moving average of the loss; avg_sum is its normalizer.
        avg_loss    = 0
        avg_sum     = 0

        for batch_index, batch in enumerate(trainloader):

            depth = batch[depth_index]

            hyper_weights = th.zeros((depth.shape[0], self.depth_decoder0.num_weights()), device=device)
            output_depth  = None

            for t in range(1, depth.shape[1]):
                num_updates += 1

                target = depth[:,t].to(device)

                output_depth = self.depth_decoder1(self.depth_decoder0(self.embedding(hyper_weights), hyper_weights))

                optimizer.zero_grad()

                loss = th.mean(th.abs(output_depth - target))
                loss.backward()

                optimizer.step()

                avg_loss = avg_loss * 0.99 + loss.item()
                avg_sum  = avg_sum  * 0.99 + 1

            # Write the preview once per logging batch, from the last step.
            if batch_index % 100 == 99 and output_depth is not None:
                cv2.imwrite(f"/anon.png", output_depth[0,0].detach().cpu().numpy()*255)

            # Guard against batches with a single frame before any update happened.
            mean_loss = avg_loss / avg_sum if avg_sum > 0 else float('nan')
            print(f"Pretraining[{batch_index}|{num_updates}|{(num_updates + 1)/max_steps * 100:.2f}%] {timer}, Loss: {mean_loss}", flush=True)
            if num_updates > max_steps:
                break


