import torch.nn as nn
import torch as th
from torch.autograd import Function
import nn as nn_modules
import utils
from nn.residual import ResidualBlock, SkipConnection
from nn.encoder import AggressiveDownConv
from nn.encoder import AggressiveConvTo1x1
from nn.decoder import AggressiveUpConv
from utils.utils import LambdaModule, ForcedAlpha, PrintShape
from nn.predictor import EpropAlphaGateL0rd
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce

from typing import Union, Tuple



class RunningNormalization(nn.Module):
    """Per-channel normalization driven by exponentially-decayed running statistics.

    While unfrozen, every forward pass folds the batch's per-channel sum,
    sum of squares and element count into decayed accumulators. The output is
    always ``gamma * (input - mean) / std + beta``, with ``mean``/``std``
    derived from those accumulators (``std`` uses an unbiased variance).

    The module starts frozen (``froozen = True``): its statistics only change
    after :meth:`unfreeze` (or after setting ``froozen = False``). A freshly
    constructed module whose statistics were never accumulated or loaded has
    ``counter == 0``, so ``mean = 0 / 0`` and the output is NaN.
    """

    def __init__(self, channels: int):
        super(RunningNormalization, self).__init__()

        # Decayed accumulators of sum(x), sum(x**2) and element count, per channel.
        self.register_buffer('mean', th.zeros(1, channels, 1, 1))
        self.register_buffer('mean2', th.zeros(1, channels, 1, 1))
        self.register_buffer('counter', th.zeros(1, channels, 1, 1))
        # Grows by 0.01 per update, so the decay factor exp(-1 / memory) increases toward 1.
        self.register_buffer('memory', th.ones(1))
        self.gamma = nn.Parameter(th.ones(1, channels, 1, 1))
        self.beta  = nn.Parameter(th.zeros(1, channels, 1, 1))
        # Attribute name (typo included) kept for code that sets it directly.
        self.froozen = True

    def freeze(self):
        """Stop updating the running statistics."""
        self.froozen = True

    def unfreeze(self):
        """Resume updating the running statistics on every forward pass."""
        self.froozen = False

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Normalize ``input`` per channel (assumed layout ``(b, c, h, w)`` — TODO confirm)."""
        if not self.froozen:
            # Statistics are bookkeeping only; never part of the autograd graph.
            with th.no_grad():
                self.memory = self.memory + 0.01
                factor = th.exp(-1 / self.memory)
                self.mean    = self.mean * factor + input.sum(dim=(0, 2, 3), keepdim=True)
                self.mean2   = self.mean2 * factor + (input ** 2).sum(dim=(0, 2, 3), keepdim=True)
                self.counter = self.counter * factor + input.shape[0] * input.shape[2] * input.shape[3]

        mean  = self.mean  / self.counter
        mean2 = self.mean2 / self.counter

        # Bessel correction n / (n - 1). For n == 1 it is inf and inf * 0 yields
        # NaN, so fall back to a factor of 1 whenever n <= 1.
        correction = th.where(
            self.counter > 1,
            self.counter / (self.counter - 1),
            th.ones_like(self.counter),
        )
        std = th.sqrt(th.relu(correction * (mean2 - mean ** 2))) + 1e-8

        return self.gamma * (input - mean) / std + self.beta

class SigmoidResidual(nn.Module):
    """Residual block built from sigmoid activations and running normalization.

    Main branch: conv -> RunningNormalization -> sigmoid -> conv.
    The main branch is summed with ``SkipConnection(input)`` and the sum goes
    through RunningNormalization -> sigmoid.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: Union[int, Tuple[int, int]] = (3, 3),
            bias: bool = True,
        ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        ksize = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        # Half-kernel padding keeps the spatial size for odd kernels.
        pad = (ksize[0] // 2, ksize[1] // 2)

        def make_conv(source_channels: int) -> nn.Conv2d:
            # Every conv in this block maps to out_channels with the same geometry.
            return nn.Conv2d(
                in_channels=source_channels,
                out_channels=out_channels,
                kernel_size=ksize,
                padding=pad,
                bias=bias,
            )

        first_conv = make_conv(in_channels)
        first_norm = RunningNormalization(out_channels)
        second_conv = make_conv(out_channels)
        self.pre_residual = nn.Sequential(first_conv, first_norm, nn.Sigmoid(), second_conv)

        self.post_residual = nn.Sequential(RunningNormalization(out_channels), nn.Sigmoid())

        self.skip = SkipConnection(in_channels=in_channels, out_channels=out_channels)

    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return ``post_residual(pre_residual(input) + skip(input))``."""
        branch = self.pre_residual(input)
        shortcut = self.skip(input)
        return self.post_residual(branch + shortcut)

class TimeDistanceLoss(nn.Module):
    """Learned comparison between two image batches, ``output`` and ``target``.

    Both inputs are projected with ``to_channels[level]``, stacked into one
    batch and encoded by ``level2`` (when ``level >= 2``) and ``level1``
    (when ``level >= 1``).

    * Default mode: returns ``(level0(channel-concatenated latents),
      decoded_output, decoded_target)``.
    * Loss mode (after :meth:`activate_loss`): returns the sum of MSE terms
      between the ``output`` and ``target`` features, taken before every
      encoder module that is run and once more after the last one.
    """

    def __init__(
        self, 
        input_size: Tuple[int, int], 
        img_channels: int, 
        level1_channels: int,
        latent_channels: int,
        deepth: int
    ):
        super(TimeDistanceLoss, self).__init__()

        # Spatial size after the four MaxPool2d(2) layers in level2 + level1.
        latent_size = [input_size[0] // 16, input_size[1] // 16]
        self.input_size = input_size
        # NOTE: despite its name this criterion is mean-squared error.
        self.l1loss = nn.MSELoss()

        # Entry stage: 2 -> level2 then level1, 1 -> level1 only, 0 -> no encoder.
        self.level = 2
        # When True, forward() returns the feature-matching loss instead.
        self.loss  = False
        self.level2 = nn.Sequential(
            SigmoidResidual(img_channels, level1_channels // 2),
            nn.MaxPool2d(2),
            *[SigmoidResidual(level1_channels // 2, level1_channels // 2) for i in range(deepth)],
            SigmoidResidual(level1_channels // 2, level1_channels),
            nn.MaxPool2d(2),
            *[SigmoidResidual(level1_channels, level1_channels) for i in range(deepth)],
        )

        self.level1 = nn.Sequential(
            SigmoidResidual(level1_channels, latent_channels // 2),
            nn.MaxPool2d(2),
            *[SigmoidResidual(latent_channels // 2, latent_channels // 2) for i in range(deepth)],
            SigmoidResidual(latent_channels // 2, latent_channels),
            nn.MaxPool2d(2),
            *[SigmoidResidual(latent_channels, latent_channels) for i in range(deepth)],
        )

        # Consumes both latents concatenated along channels, hence latent_channels * 2.
        self.level0 = nn.Sequential(
            ResidualBlock(latent_channels * 2, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels) for i in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            ResidualBlock(latent_channels, 1, kernel_size=1),
            LambdaModule(lambda x: rearrange(x, 'b 1 1 1 -> b')),
        )

        self.decoder = nn.Sequential(
            ResidualBlock(latent_channels, latent_channels),
            AggressiveUpConv(latent_channels, level1_channels, alpha=1),
            ResidualBlock(level1_channels, level1_channels),
            AggressiveUpConv(level1_channels, img_channels, alpha=1),
        )

        # Input projection per entry level (indexed by self.level).
        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels, latent_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels),
        ])

    def activate_loss(self):
        """Switch forward() to loss mode and disable gradients for every parameter."""
        self.loss = True
        for param in self.parameters():
            param.requires_grad = False

    def _pair_mse(self, latent: th.Tensor) -> th.Tensor:
        """Return the MSE between the two halves of a stacked ``(n b) c h w`` batch (n = 2)."""
        pair = rearrange(latent, '(n b) c h w -> n b c h w', n = 2)
        return self.l1loss(pair[0], pair[1])

    def _encoder_stages(self):
        """Return the encoder Sequentials that apply at the current ``level``, in order."""
        stages = []
        if self.level >= 2:
            stages.append(self.level2)
        if self.level >= 1:
            stages.append(self.level1)
        return stages

    def forward(
        self, output: th.Tensor, target: th.Tensor
    ) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor, th.Tensor]]:
        """Compare ``output`` with ``target``.

        Returns the scalar feature-matching loss in loss mode; otherwise a tuple
        ``(level0 prediction, decoded output, decoded target)``.
        """
        latent1 = self.to_channels[self.level](output)
        latent2 = self.to_channels[self.level](target)

        # Encode both inputs in one batch: first half from output, second from target.
        latent = rearrange(th.stack((latent1, latent2)), 'n b c h w -> (n b) c h w')
        stages = self._encoder_stages()

        if self.loss:
            # Only traverse the stages selected by self.level (same gating as the
            # default path), so the projected channel count matches the first module.
            loss = 0
            for stage in stages:
                for model in stage:
                    loss = loss + self._pair_mse(latent)
                    latent = model(latent)
            return loss + self._pair_mse(latent)

        for stage in stages:
            latent = stage(latent)

        decoded = self.decoder(latent)
        decoded = rearrange(decoded, '(n b) c h w -> n b c h w', n = 2)
        # Concatenate each sample's output- and target-latent along channels.
        latent = rearrange(latent, '(n b) c h w -> b (n c) h w', n = 2)
        return self.level0(latent), decoded[0], decoded[1]

