import torch.nn as nn
import torch as th
from torch.autograd import Function
import nn as nn_modules
import utils
from nn.residual import ResNeXtBlock, ResidualBlock
from nn.encoder import PatchDownConv
from nn.encoder import AggressiveConvToGestalt
from nn.decoder import PatchUpscale
from utils.utils import LambdaModule, ForcedAlpha, PrintShape, Binarize
from nn.predictor import EpropAlphaGateL0rd
from nn.vae import VariationalFunction
from einops import rearrange, repeat, reduce

from typing import Union, Tuple



class BackgroundEnhancer(nn.Module):
    """Bottleneck autoencoder with a single-channel latent.

    The encoder maps ``img_channels`` input channels through ``depth + 2``
    ``ResidualBlock``s down to ONE latent channel and then applies
    ``Binarize(1e-8)``; the decoder mirrors that stack back up to
    ``img_channels``.  After every forward pass the mean distance of the
    latent to the nearest of {0, 1} is cached and exposed via
    ``get_openings()``.

    The remaining methods (``get_init``, ``step_init``, ``set_level``,
    ``get_last_latent_grid``) are trivial stubs -- presumably so this class can
    stand in where a richer model with that interface is expected
    (NOTE(review): confirm against callers).
    """

    def __init__(
        self,
        input_size: Tuple[int, int],
        img_channels: int,
        level1_channels,
        latent_channels,
        gestalt_size,
        batch_size,
        reg_lambda,
        vae_factor,
        depth
    ):
        """Build the encoder and decoder stacks.

        Args:
            input_size: stored as ``self.input_size``; not otherwise used here.
            img_channels: channel count of the input and of the reconstruction.
            level1_channels: unused in this class; kept for constructor
                compatibility.
            latent_channels: width of the intermediate residual blocks.
            gestalt_size: unused; kept for constructor compatibility.
            batch_size: unused; kept for constructor compatibility.
            reg_lambda: unused; kept for constructor compatibility.
            vae_factor: unused; kept for constructor compatibility.
            depth: number of ``latent_channels -> latent_channels`` residual
                blocks inserted in each of the encoder and the decoder.
        """
        super().__init__()

        self.input_size = input_size

        self.encoder = nn.Sequential(
            ResidualBlock(img_channels, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels) for _ in range(depth)],
            # Squeeze down to a single latent channel.
            ResidualBlock(latent_channels, 1),
            # NOTE(review): presumably pushes the code toward {0, 1};
            # see utils.utils.Binarize to confirm what 1e-8 controls.
            Binarize(1e-8),
        )

        self.decoder = nn.Sequential(
            ResidualBlock(1, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels) for _ in range(depth)],
            ResidualBlock(latent_channels, img_channels),
        )

        # Mean distance of the last latent to the nearest of {0, 1};
        # 0 until the first forward(), a Python float afterwards.
        self.binarized = 0

    def get_openings(self):
        """Return the binarization statistic cached by the last ``forward``."""
        return self.binarized

    def get_init(self):
        """Stub: always returns 0."""
        return 0

    def step_init(self):
        """Stub: always returns 0."""
        return 0

    def set_level(self, level):
        """Stub: stores nothing and returns ``level`` unchanged."""
        return level

    def get_last_latent_grid(self):
        """Stub: this model keeps no latent grid, so always returns None."""
        return None

    def get_last_latent_gird(self):
        """Backward-compatible alias of the misspelled original name."""
        return self.get_last_latent_grid()

    def forward(
        self,
        input: th.Tensor,
        error: th.Tensor = None,
        mask: th.Tensor = None,
    ) -> Tuple[th.Tensor, th.Tensor]:
        """Encode ``input`` to a one-channel latent and decode it back.

        Args:
            input: image tensor with ``img_channels`` channels; presumably
                (B, C, H, W) -- TODO confirm against ResidualBlock.
            error: unused; accepted for interface compatibility.
            mask: unused; accepted for interface compatibility.

        Returns:
            ``(latent, output)``: the encoder output and its reconstruction.

        Side effects:
            Sets ``self.binarized`` to the mean of
            ``min(|latent|, |1 - latent|)`` as a Python float.
        """
        latent = self.encoder(input)
        output = self.decoder(latent)

        # Monitoring statistic only: compute it outside autograd so no
        # throwaway graph is built on every step (.item() discards it anyway).
        with th.no_grad():
            distance = th.minimum(th.abs(latent), th.abs(1 - latent))
            self.binarized = th.mean(distance).item()

        return latent, output

        
