import torch
import torch.nn as nn
from .autoencoder import *

from utils import binarize



class DIM(nn.Module):
    """Two-stage pipeline: an optional first model followed by an optional second model.

    Each stage is either a pretrained network restored from a state-dict file
    (selected by ``config['type']``) or a pass-through when its config is
    ``None``. The input to each stage can optionally be passed through
    ``utils.binarize`` first.

    Args:
        config1: ``None`` for pass-through, or a dict with keys ``'type'``
            (only ``'AutoEncoder'`` is supported), ``'architecture'`` and
            ``'model_path'``.
        config2: ``None`` for pass-through, or a dict with keys ``'type'``
            (only ``'ColumnAutoEncoder'`` is supported), ``'architecture'``
            and ``'model_path'``.
        device: used as ``map_location`` for ``torch.load`` and as the target
            of ``.to()``; also forwarded to the ``Autoencoder`` constructor.
        binary_cutoff_1: if true, binarize the input before stage 1.
        binary_cutoff_2: if true, binarize the stage-1 output before stage 2.
        pixel_threshold: forwarded to the ``ColumnAE`` constructor.

    Raises:
        ValueError: if a non-``None`` config names an unsupported ``'type'``.
    """

    def __init__(self, config1, config2, device,
                 binary_cutoff_1=False,
                 binary_cutoff_2=False,
                 pixel_threshold=0):
        super().__init__()
        self.binary_cutoff_1 = binary_cutoff_1
        self.binary_cutoff_2 = binary_cutoff_2

        # nn.Identity rather than a lambda: same pass-through behaviour, but the
        # module stays picklable (torch.save(model)) and, having no parameters,
        # leaves state_dict keys unchanged.
        if config1 is None:
            self.model1 = nn.Identity()
        elif config1['type'] == 'AutoEncoder':
            self.model1 = self._load_pretrained(
                Autoencoder(config1['architecture'], device),
                config1['model_path'],
                device,
            )
        else:
            # Without this, model1 would be left unset and forward() would
            # fail later with an unhelpful AttributeError.
            raise ValueError(f"Unsupported config1 type: {config1['type']!r}")

        if config2 is None:
            self.model2 = nn.Identity()
        elif config2['type'] == "ColumnAutoEncoder":
            self.model2 = self._load_pretrained(
                ColumnAE(config2['architecture'], pixel_threshold=pixel_threshold),
                config2['model_path'],
                device,
            )
        else:
            raise ValueError(f"Unsupported config2 type: {config2['type']!r}")

    @staticmethod
    def _load_pretrained(model, model_path, device):
        """Load the state dict at ``model_path`` into ``model``, move it to ``device``, set eval mode.

        Returns the same ``model`` object.
        """
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints; consider weights_only=True (torch >= 1.13) since only a
        # state dict is expected here.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        return model

    def forward(self, x):
        """Run the optional binarization, stage 1, optional binarization, then stage 2.

        Shape/dtype requirements of ``x`` depend on the configured models
        (TODO confirm, e.g. image batches); with both configs ``None`` and no
        cutoffs, ``x`` is returned unchanged.
        """
        if self.binary_cutoff_1:
            x = self.binary_bilateral_function(x)

        x1 = self.model1(x)

        if self.binary_cutoff_2:
            x1 = self.binary_bilateral_function(x1)

        return self.model2(x1)

    def binary_bilateral_function(self, x):
        """Apply ``utils.binarize`` to ``x`` and return the result."""
        return binarize(x)