import torch
import torch.nn as nn
import torch.nn.functional as F

class Autoencoder(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder maps the flattened 784-value input through the widths in
    ``config['encoder_dims']``. The decoder maps a code through
    ``config['decoder_dims']`` and back to 784 values squashed by ``tanh``,
    which are reshaped to ``(batch, 1, 28, 28)``.

    Args:
        config: Mapping with keys ``'encoder_dims'`` and ``'decoder_dims'``
            (sequences of layer widths) and an optional ``'dropout'``
            probability. A missing key or ``None`` disables dropout layers.
        device: Stored on ``self.device``. This class does not move itself
            to that device.
    """

    def __init__(self, config, device='cuda:0'):
        super(Autoencoder, self).__init__()
        encoder_dims = config['encoder_dims']
        decoder_dims = config['decoder_dims']
        self.device = device
        # When True, forward() convolves the input with self.kernel first.
        # NOTE(review): self.kernel is never assigned in this class; it must be
        # set externally before add_filter is enabled.
        self.add_filter = False

        # A missing 'dropout' key means "no dropout". Any other error, e.g. an
        # out-of-range probability rejected by nn.Dropout, is left to surface
        # instead of silently disabling dropout.
        try:
            dropout = config['dropout']
        except KeyError:
            dropout = None

        encoder_list = [nn.Linear(28 * 28, encoder_dims[0])]
        for in_dim, out_dim in zip(encoder_dims[:-1], encoder_dims[1:]):
            if dropout is not None:
                encoder_list.append(nn.Dropout(dropout))
            encoder_list.append(nn.ReLU(True))
            encoder_list.append(nn.Linear(in_dim, out_dim))

        decoder_list = []
        for in_dim, out_dim in zip(decoder_dims[:-1], decoder_dims[1:]):
            # Dropout goes into the decoder, not the encoder.
            # NOTE: with dropout configured, this shifts the Sequential indices
            # of the decoder Linear layers relative to the old (buggy) layout,
            # so state_dicts saved with that layout need key remapping.
            if dropout is not None:
                decoder_list.append(nn.Dropout(dropout))
            decoder_list.append(nn.Linear(in_dim, out_dim))
            decoder_list.append(nn.ReLU(True))
        decoder_list.append(nn.Linear(decoder_dims[-1], 28 * 28))
        decoder_list.append(nn.Tanh())

        self.encoder = nn.Sequential(*encoder_list)
        self.decoder = nn.Sequential(*decoder_list)

    def forward(self, x):
        """Return the reconstruction of ``x`` shaped ``(batch, 1, 28, 28)``."""
        if self.add_filter:
            # Presumably a 3x3 kernel given padding=1 -- TODO confirm.
            x = F.conv2d(x, self.kernel, padding=1)
        return self._forward(x)[0]

    def _forward(self, x):
        """Return ``(reconstruction, code)``; reconstruction is ``(batch, 1, 28, 28)``."""
        code = self._encode(x)
        y = self._decode(code)
        y = y.view(x.size(0), 1, 28, 28)
        return y, code

    def _encode(self, x):
        """Flatten every non-batch dimension of ``x`` and run the encoder."""
        x = x.view(x.size(0), -1)
        return self.encoder(x)

    def _decode(self, x):
        """Run the decoder on a batch of codes; returns flat 784-wide outputs."""
        return self.decoder(x)


class ColumnAE(nn.Module):
    """A bank of independent Autoencoders, each scored by reconstruction intensity.

    Column ``i`` of the output is the sum of autoencoder ``i``'s reconstruction
    over every non-batch dimension. Pixels below ``pixel_threshold`` can
    optionally be zeroed before the sum.

    Args:
        config: Passed unchanged to every ``Autoencoder``.
        pixel_threshold: ``0`` disables thresholding, so all pixels (including
            negative ones) are summed. Any other value zeroes pixels strictly
            below it before summing.
        n_columns: Number of autoencoders, i.e. output columns (default 10).
    """

    def __init__(self, config, pixel_threshold=0, n_columns=10):
        super(ColumnAE, self).__init__()
        self.autoencoders = nn.ModuleList([Autoencoder(config) for _ in range(n_columns)])
        self.pixel_threshold = pixel_threshold

    def forward(self, x):
        """Return a ``(batch, n_columns)`` tensor of per-column intensities."""
        columns = []
        for autoencoder in self.autoencoders:
            y = autoencoder(x)
            # A threshold of exactly 0 means "no thresholding", not "drop negatives".
            if self.pixel_threshold != 0:
                y = torch.where(y >= self.pixel_threshold, y, torch.zeros_like(y))
            # Sum over dims 1-3; assumes 4-D (batch, C, H, W) reconstructions.
            columns.append(torch.sum(y, [1, 2, 3]))
        return torch.stack(columns, dim=1)

    def _forward(self, x, idx):
        """Return the reconstruction of ``x`` by autoencoder ``idx``."""
        return self.autoencoders[idx](x)

    def _encode(self, x, idx):
        """Encode ``x`` with autoencoder ``idx``."""
        return self.autoencoders[idx]._encode(x)

    def _decode(self, x, idx):
        """Decode codes ``x`` with autoencoder ``idx``."""
        return self.autoencoders[idx]._decode(x)