import torch.nn as nn
import torch as th
from torch.autograd import Function
import nn as nn_modules
import utils
from nn.residual import ResidualBlock, SkipConnection
from nn.encoder import AggressiveDownConv
from nn.encoder import AggressiveConvTo1x1
from nn.decoder import AggressiveUpConv
from utils.utils import LambdaModule, ForcedAlpha, PrintShape
from nn.predictor import EpropAlphaGateL0rd
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce

from typing import Union, Tuple



class TimeDistanceLoss(nn.Module):
    """Learned distance between two image batches, usable as a frozen feature loss.

    ``output`` and ``target`` are each projected by ``to_channels[level]``,
    stacked along the batch axis and pushed through the encoder stages selected
    by ``self.level`` (``level >= 2`` runs ``level2``; ``level >= 1`` runs
    ``level1``).

    * Default mode (``self.loss`` is False): ``forward`` returns the ``level0``
      score for the pair plus the decoder reconstructions of both inputs.
    * Loss mode (after :meth:`activate_loss`): ``forward`` returns the summed
      L1 distance between the output and target features, taken before every
      layer of the selected stages and once more after the last one.
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        img_channels: int,
        level1_channels,
        latent_channels,
        deepth
    ):
        super().__init__()

        # NOTE(review): assumes level2 followed by level1 downsamples by 16 in
        # total, so AggressiveConvTo1x1 sees this spatial size — confirm.
        latent_size = [input_size[0] // 16, input_size[1] // 16]
        self.input_size = input_size
        self.l1loss = nn.L1Loss(reduction='sum')

        # Which encoder stages forward() runs: 2 -> level2 + level1, 1 -> level1, 0 -> none.
        self.level = 2
        # Set by activate_loss(); makes forward() return a feature-matching loss.
        self.loss = False

        self.level2 = nn.Sequential(
            AggressiveDownConv(img_channels, level1_channels, alpha = 1),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = False) for _ in range(deepth)]
        )

        self.level1 = nn.Sequential(
            AggressiveDownConv(level1_channels, latent_channels, alpha = 1),
            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for _ in range(deepth)]
        )

        # Input is the output/target latents concatenated on the channel axis
        # (hence latent_channels * 2); the last step flattens (b, 1, 1, 1) -> (b,).
        self.level0 = nn.Sequential(
            ResidualBlock(latent_channels * 2, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels) for _ in range(deepth)],
            AggressiveConvTo1x1(latent_channels, latent_size),
            ResidualBlock(latent_channels, 1, kernel_size=1),
            LambdaModule(lambda x: rearrange(x, 'b 1 1 1 -> b')),
        )

        self.decoder = nn.Sequential(
            ResidualBlock(latent_channels, latent_channels),
            AggressiveUpConv(latent_channels, level1_channels, alpha=1),
            ResidualBlock(level1_channels, level1_channels),
            AggressiveUpConv(level1_channels, img_channels, alpha=1),
        )

        # Indexed by self.level: projects raw images to the channel width that
        # the first stage run for that level expects (0 -> level0 input width,
        # 1 -> level1 input width, 2 -> level2 input width).
        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels, latent_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels),
        ])

    def activate_loss(self):
        """Freeze every parameter and switch ``forward`` into loss mode."""
        self.loss = True
        for param in self.parameters():
            param.requires_grad = False

    @staticmethod
    def _merge_pair(first, second):
        """Stack two equally shaped batches into one batch: (2 * b, ...), first half = ``first``."""
        return th.stack((first, second)).flatten(0, 1)

    @staticmethod
    def _split_pair(pair):
        """Inverse of ``_merge_pair``: split a (2 * b, ...) batch into its two halves."""
        return pair.reshape(2, -1, *pair.shape[1:]).unbind(0)

    def _pair_l1(self, pair):
        """Summed L1 distance between the two halves of a merged batch."""
        first, second = self._split_pair(pair)
        return self.l1loss(first, second)

    def _active_stages(self):
        """Encoder stages selected by ``self.level``, in execution order."""
        stages = []
        if self.level >= 2:
            stages.append(self.level2)
        if self.level >= 1:
            stages.append(self.level1)
        return stages

    def forward(self, output, target):
        """Compare ``output`` against ``target``.

        Args:
            output: image batch being judged; must have the same shape as
                ``target`` (they are stacked together).
            target: reference image batch.

        Returns:
            In loss mode, a scalar tensor: the summed L1 feature distance
            accumulated before each layer of the active stages and after the
            last one. Otherwise a tuple ``(score, decoded_output,
            decoded_target)``, where ``score`` comes from ``level0`` (one value
            per sample) and the decoded tensors are the decoder outputs for
            ``output`` and ``target`` respectively.
        """
        projection = self.to_channels[self.level]
        latent = self._merge_pair(projection(output), projection(target))
        stages = self._active_stages()

        if self.loss:
            loss = 0
            for stage in stages:
                # Compare features layer by layer, not only at the stage outputs.
                for layer in stage:
                    loss = loss + self._pair_l1(latent)
                    latent = layer(latent)
            return loss + self._pair_l1(latent)

        for stage in stages:
            latent = stage(latent)

        decoded_output, decoded_target = self._split_pair(self.decoder(latent))
        # level0 scores both latents jointly: concatenate them on the channel axis.
        joint = th.cat(self._split_pair(latent), dim=1)
        return self.level0(joint), decoded_output, decoded_target

