import torch.nn as nn
import torch as th
import numpy as np
from nn.residual import ResidualBlock, LinearResidual, LinearSkip, SkipConnection
from nn.eprop_gate_l0rd import EpropGateL0rd
from utils.utils import LambdaModule
from einops import rearrange, repeat, reduce
from typing import Tuple, Union, List


class UncertaintyMaskedPatchEmbedding(nn.Module):
    """Embed 16x16 image patches into tokens, zeroing patches with high uncertainty.

    Each non-overlapping ``PATCH_SIZE`` x ``PATCH_SIZE`` patch of the input is
    encoded by a small conv stack into a flat feature vector. For each patch:

    - If the patch's maximum uncertainty is not below ``uncertainty_threshold``,
      that feature vector is multiplied by zero.
    - A fixed sinusoidal encoding of the patch's grid position is appended.
    - A linear layer projects the result to ``out_channels``.

    Args:
        in_channels: number of channels of ``input``.
        out_channels: channel count of every output token.
        num_layers: not used by this module.
        latent_size: ``(H, W)`` of the patch grid; expected to match the input
            height and width divided by ``PATCH_SIZE``.
        uncertainty_threshold: patches whose max uncertainty is strictly below
            this value keep their conv features.
    """

    # Side length of one patch. The conv stack in ``down_conv`` maps exactly
    # this size to a 4x4 map (16 -> 6 with k=5/s=2, then 6 -> 4 with k=3).
    PATCH_SIZE = 16

    def __init__(
        self, 
        in_channels, 
        out_channels, 
        num_layers,
        latent_size,
        uncertainty_threshold,
    ):
        super(UncertaintyMaskedPatchEmbedding, self).__init__()
        assert out_channels % in_channels == 0

        self.latent_size = latent_size
        self.threshold   = uncertainty_threshold

        H, W = latent_size
        K = self.PATCH_SIZE

        # 'ij' is the behavior the un-annotated call always had; passing it
        # explicitly silences the torch>=1.10 deprecation warning.
        grid_y, grid_x = th.meshgrid(th.arange(H), th.arange(W), indexing='ij')

        # Normalise grid coordinates to [-1, 1]. max(..., 1) avoids 0/0 when a
        # side of the grid has length 1.
        grid_x = (grid_x / max(W - 1, 1)) * 2 - 1
        grid_y = (grid_y / max(H - 1, 1)) * 2 - 1

        grid_x = grid_x.reshape(1, 1, H, W) * np.pi / 2
        grid_y = grid_y.reshape(1, 1, H, W) * np.pi / 2

        embedding = []

        num_frequencies = int(np.log2(min(*latent_size)))*2
        print(f'UncertaintyMaskedPatchEmbedding{in_channels} -> {out_channels // 2} -> {out_channels} + {num_frequencies}')

        # Channel order: sin(x*2^0), sin(y*2^0), sin(x*2^1), sin(y*2^1), ...
        # (the order the projection weights in ``self.layers`` are tied to).
        for i in range(num_frequencies):
            embedding.append(th.sin(grid_x * 2**i))
            embedding.append(th.sin(grid_y * 2**i))

        if embedding:
            embedding = rearrange(th.cat(embedding, dim=1), '1 c h w -> (h w) c')
        else:
            # num_frequencies == 0 (a grid side of length 1): th.cat([]) would
            # raise, so use an empty (H*W, 0) positional block instead.
            embedding = th.zeros(H * W, 0)
        self.register_buffer('embedding', embedding, persistent=False)

        self.down_conv = nn.Sequential(
            # Split the image into K x K patches, one patch per batch row.
            LambdaModule(lambda x: rearrange(x, 'b c (h h2) (w w2) -> (b h w) c h2 w2', h2 = K, w2 = K)),
            nn.Conv2d(in_channels, out_channels // 4, kernel_size = 5, stride = 2, padding = 0),
            nn.SiLU(),
            nn.Conv2d(out_channels // 4, out_channels // 2, kernel_size = 3, stride = 1, padding = 0),
            nn.SiLU(),
            # Flatten each patch's remaining 4x4 map into one feature vector.
            LambdaModule(lambda x: rearrange(x, 'b c (h h2) (w w2) -> (b h w) (c h2 w2)', h2 = 4, w2 = 4))
        )
        
        # Conv features (out_channels//2 * 4*4) + positional features (2 per frequency).
        self.layers = nn.Linear((out_channels // 2) * 4**2 + num_frequencies*2, out_channels)


    def forward(self, input: th.Tensor, uncertainty: th.Tensor):
        """Encode ``input`` into one token per patch.

        Args:
            input: ``(B, in_channels, H, W)`` with ``H`` and ``W`` divisible by
                ``PATCH_SIZE``.
            uncertainty: ``(B, 1, H, W)`` map; its per-patch maximum is compared
                against the threshold.

        Returns:
            ``(B, (H // PATCH_SIZE) * (W // PATCH_SIZE), out_channels)`` tokens.
        """
        B, _, H, W = input.shape
        K = self.PATCH_SIZE

        # Positional features, tiled per batch element in the same (b h w)
        # row order that ``down_conv`` produces.
        embedding   = repeat(self.embedding, 'hw c -> (b hw) c', b = B)

        # 1.0 where the patch's maximum uncertainty is below threshold, else 0.0.
        uncertainty = reduce(uncertainty, 'b 1 (h h2) (w w2) -> (b h w) 1', 'max', h2=K, w2=K)
        uncertainty = (uncertainty < self.threshold).float().detach()

        latent = self.down_conv(input)
        latent = self.layers(th.cat((latent * uncertainty, embedding), dim=1))

        return rearrange(latent, '(b h w) c -> b (h w) c', h = H // K, w = W // K)


class AlphaAttention(nn.Module):
    """Batch-first multi-head self-attention added to its input through a learnable gate.

    The scalar gate ``alpha`` is initialised to 1e-12, so at construction the
    output is almost exactly the input.

    Args:
        num_hidden: embedding dimension of the attention.
        heads: number of attention heads.
        dropout: dropout probability used inside ``nn.MultiheadAttention``.
    """

    def __init__(self, num_hidden, heads, dropout = 0.0):
        super().__init__()

        # Tiny, non-zero residual gate.
        self.alpha = nn.Parameter(th.full((1,), 1e-12))
        self.attention = nn.MultiheadAttention(
            num_hidden,
            heads,
            dropout = dropout,
            batch_first = True,
        )

    def forward(self, x: th.Tensor):
        """Return ``x + alpha * SelfAttention(x)`` for ``x`` of shape (batch, seq, num_hidden)."""
        attended, _ = self.attention(x, x, x, need_weights=False)
        return x + self.alpha * attended

class AlphaAttentionSum(nn.Module):
    """Pool a sequence into one vector: its mean plus gated attention from a learned query.

    Args:
        num_hidden: embedding dimension of the input tokens.
        heads: number of attention heads.
        dropout: dropout probability used inside ``nn.MultiheadAttention``.
    """

    def __init__(self, num_hidden, heads, dropout = 0.0):
        super().__init__()

        # A single learned query token of shape (1, 1, num_hidden), shared by the whole batch.
        self.query = nn.Parameter(th.randn(1, 1, num_hidden))

        # Tiny, non-zero gate on the attention term.
        self.alpha = nn.Parameter(th.full((1,), 1e-12))
        self.attention = nn.MultiheadAttention(
            num_hidden,
            heads,
            dropout = dropout,
            batch_first = True,
        )

    def forward(self, x: th.Tensor):
        """Map ``x`` of shape (batch, seq, num_hidden) to (batch, num_hidden)."""
        batch_size = x.shape[0]

        queries = self.query.expand(batch_size, -1, -1)
        pooled, _ = self.attention(queries, x, x, need_weights=False)

        # pooled is (batch, 1, num_hidden); drop the length-1 query axis.
        return x.mean(dim=1) + self.alpha * pooled.squeeze(1)

class AttentionResidualLayer(nn.Module):
    """Token-wise two-layer SiLU MLP added to a (batch, seq, num_hidden) input via a learnable gate.

    Args:
        num_hidden: width of the input, hidden and output features.
    """

    def __init__(self, num_hidden):
        super().__init__()

        # Tiny, non-zero residual gate.
        self.alpha = nn.Parameter(th.full((1,), 1e-12))
        self.layer = nn.Sequential(
            nn.Linear(num_hidden, num_hidden),
            nn.SiLU(),
            nn.Linear(num_hidden, num_hidden),
        )

    def forward(self, x: th.Tensor):
        """Return ``x + alpha * MLP(x)``, applying the MLP to every token independently."""
        batch, seq, channels = x.shape

        tokens = x.reshape(batch * seq, channels)
        update = self.layer(tokens).reshape(batch, seq, -1)

        return x + self.alpha * update

class EpropAlphaGateL0rd(nn.Module):
    """Apply an ``EpropGateL0rd`` module to every token and add its output through a learnable gate.

    Args:
        num_hidden: used as the input, hidden and output width of the wrapped module.
        batch_size: forwarded unchanged to ``EpropGateL0rd``.
        reg_lambda: forwarded unchanged to ``EpropGateL0rd``.
    """

    def __init__(self, num_hidden, batch_size, reg_lambda):
        super().__init__()

        # Tiny, non-zero residual gate.
        self.alpha = nn.Parameter(th.full((1,), 1e-12))
        self.l0rd = EpropGateL0rd(
            num_inputs  = num_hidden,
            num_hidden  = num_hidden,
            num_outputs = num_hidden,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )

    def forward(self, input):
        """Return ``input + alpha * l0rd(tokens)`` for ``input`` of shape (batch, seq, num_hidden)."""
        batch, seq, channels = input.shape

        # NOTE(review): the wrapped module receives batch*seq rows, while it was
        # constructed with ``batch_size`` — confirm that is the intended value.
        tokens = input.reshape(batch * seq, channels)
        update = self.l0rd(tokens).reshape(batch, seq, -1)

        return input + self.alpha * update

class PatchUpscale(nn.Module):
    """Upscale a feature map as a skip path plus a per-channel-gated conv residual path.

    The residual path is ReLU -> 3x3 conv (same padding, ``in_channels``
    channels) -> ReLU -> transposed conv with kernel = stride =
    ``scale_factor``. That last layer multiplies height and width by
    ``scale_factor`` and maps to ``out_channels``.

    Args:
        in_channels: input channels; must be a multiple of ``out_channels``.
        out_channels: output channels.
        scale_factor: spatial upscaling factor, also passed to ``SkipConnection``.
        alpha: initial value of the per-output-channel residual gate.
    """

    def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1e-16):
        super().__init__()
        assert in_channels % out_channels == 0

        self.skip = SkipConnection(in_channels, out_channels, scale_factor=scale_factor)

        refine = nn.Conv2d(
            in_channels  = in_channels,
            out_channels = in_channels,
            kernel_size  = 3,
            padding      = 1,
        )
        upsample = nn.ConvTranspose2d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = scale_factor,
            stride       = scale_factor,
        )
        # Module order (and so the state_dict indices 1 and 3) is ReLU, conv, ReLU, deconv.
        self.residual = nn.Sequential(nn.ReLU(), refine, nn.ReLU(), upsample)

        # One gate per output channel, broadcast over batch and spatial dims.
        self.alpha = nn.Parameter(th.zeros(1, out_channels, 1, 1) + alpha)

    def forward(self, input):
        """Return ``skip(input) + alpha * residual(input)``."""
        shortcut = self.skip(input)
        correction = self.residual(input)
        return shortcut + self.alpha * correction
