import torch
import torch.nn as nn
from abc import ABC, abstractmethod
from models.autoencoders_utils import BasicEncoderBlock, DecoderBlockWithUpsampling, \
    DecoderBlockWithTransposedConvolution


class BasicAEResnet18(nn.Module, ABC):
    """Abstract base for label-conditioned variational autoencoders.

    Subclasses supply ``encode`` and ``decode``; this class provides the
    reparameterization step, the conditional forward pass and factory helpers
    that stack encoder/decoder blocks into ``nn.Sequential`` layers.

    The constructor only stores its arguments. Within this class only
    ``use_shortcut_in_decoder`` is read (by the decoder layer builders);
    ``latent_size``, ``num_classes`` and ``use_upsampling`` are presumably
    consumed by subclasses -- verify against them.
    """

    def __init__(self, latent_size=512, num_classes=1000, use_shortcut_in_decoder=True, use_upsampling=True):
        super().__init__()
        self.latent_size = latent_size
        self.num_classes = num_classes
        self.use_shortcut_in_decoder = use_shortcut_in_decoder
        self.use_upsampling = use_upsampling

    @abstractmethod
    def encode(self, x):
        """Map ``x`` to a ``(mu, logvar)`` pair (unpacked as such by ``forward``)."""

    @abstractmethod
    def decode(self, z):
        """Map the latent code concatenated with the one-hot labels to the output."""

    def reparameterize(self, mu, logvar):
        """Sample a latent code from the distribution described by ``mu``/``logvar``.

        In training mode returns ``mu + eps * exp(0.5 * logvar)`` with ``eps``
        drawn from a standard normal; in eval mode returns ``mu`` unchanged so
        that inference is deterministic.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x, y_one_hot=None):
        """Encode ``x``, sample a latent code, append the labels and decode.

        Args:
            x: input passed to ``encode``.
            y_one_hot: label tensor concatenated to the latent code along
                ``dim=1``. Required despite the ``None`` default.

        Returns:
            Tuple ``(decode(z), mu, logvar)``.

        Raises:
            ValueError: if ``y_one_hot`` is not provided.
        """
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if y_one_hot is None:
            raise ValueError("one-hot labels should be provided to the forward method")
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        z = torch.cat((z, y_one_hot), dim=1)
        return self.decode(z), mu, logvar

    @staticmethod
    def make_encoder_layer(inplanes, outplanes, nr_blocks, stride=1):
        """Stack encoder blocks: the first maps ``inplanes -> outplanes`` with
        ``stride``, the remaining ``nr_blocks - 1`` keep ``outplanes``."""
        layers = [BasicEncoderBlock(inplanes, outplanes, stride)]
        layers.extend(BasicEncoderBlock(outplanes, outplanes) for _ in range(nr_blocks - 1))
        return nn.Sequential(*layers)

    def _make_decoder_layer(self, block_cls, inplanes, outplanes, nr_blocks, stride=1):
        """Stack ``nr_blocks`` decoder blocks of type ``block_cls``.

        All blocks but the last keep ``inplanes`` channels and rely on the
        block class's own default stride; only the last block maps
        ``inplanes -> outplanes`` and receives ``stride``. At least one block
        (the final one) is always created, so ``nr_blocks=1`` yields exactly
        one block.
        """
        shortcut = self.use_shortcut_in_decoder
        layers = [block_cls(inplanes, inplanes, use_shortcut=shortcut) for _ in range(nr_blocks - 1)]
        layers.append(block_cls(inplanes, outplanes, stride, use_shortcut=shortcut))
        return nn.Sequential(*layers)

    def make_decoder_layer_upsampling(self, inplanes, outplanes, nr_blocks, stride=1):
        """Build a decoder layer out of ``DecoderBlockWithUpsampling`` blocks."""
        return self._make_decoder_layer(DecoderBlockWithUpsampling, inplanes, outplanes, nr_blocks, stride)

    def make_decoder_layer_deconv(self, inplanes, outplanes, nr_blocks, stride=1):
        """Build a decoder layer out of ``DecoderBlockWithTransposedConvolution`` blocks."""
        return self._make_decoder_layer(DecoderBlockWithTransposedConvolution, inplanes, outplanes, nr_blocks,
                                        stride)
