import torch
import torch.nn as nn
from models.latent_resnet import resnet18
from utils.network_utils import Flatten, Unflatten
from models.basic_AE_resnet_layer4 import BasicAEResnet18


class AEResnet18(BasicAEResnet18):
    """Autoencoder built around a ResNet-18 ``layer4``-style stage.

    The encoder is ``layer4`` followed by global average pooling to 1x1 and
    the project's ``Flatten``. The decoder is the project's
    ``Unflatten(512, 1, 1)``, a x7 bilinear upsampling, a ReLU and a decoder
    layer (512 -> 256 planes) created by the base class.

    NOTE(review): the layers built here use a fixed width of 512; the
    ``latent_size`` argument is only forwarded to ``BasicAEResnet18`` and is
    not used when sizing ``self.encoder`` / ``self.decoder`` in this class.
    """

    # Checkpoint loaded when ``pretrained_encoder=True`` and no explicit path is supplied.
    DEFAULT_ENCODER_CHECKPOINT = "../data/first_batch_models/seed_1993_imagenet_40_batch0_45ep_noBias.pth"

    def __init__(self, latent_size=512, num_classes=1000, pretrained_encoder=False, use_shortcut_in_decoder=True,
                 use_upsampling=True, encoder_checkpoint_path=None):
        """Build the encoder and decoder stacks.

        Args:
            latent_size: Forwarded to ``BasicAEResnet18.__init__``.
            num_classes: Forwarded to ``BasicAEResnet18.__init__``.
            pretrained_encoder: If true, take ``layer4`` from a ``resnet18``
                whose weights are loaded from a checkpoint file; otherwise
                create a fresh encoder layer (256 -> 512 planes, 2 blocks,
                stride 2).
            use_shortcut_in_decoder: Forwarded to ``BasicAEResnet18.__init__``.
            use_upsampling: Forwarded to the base class and also selects the
                decoder layer: the upsampling variant when true, the deconv
                variant otherwise.
            encoder_checkpoint_path: Optional path of the state dict used when
                ``pretrained_encoder`` is true. Defaults to
                ``DEFAULT_ENCODER_CHECKPOINT``. The file must be loadable into
                ``resnet18(pretrained=False, num_classes=40)``.
        """
        super(AEResnet18, self).__init__(latent_size=latent_size, num_classes=num_classes,
                                         use_shortcut_in_decoder=use_shortcut_in_decoder, use_upsampling=use_upsampling)
        if pretrained_encoder:
            print("pretrained encoder")
            if encoder_checkpoint_path is None:
                encoder_checkpoint_path = self.DEFAULT_ENCODER_CHECKPOINT
            resnet_model = resnet18(pretrained=False, num_classes=40)
            # Deserialize onto the CPU so a checkpoint written from CUDA tensors
            # can still be read on a machine without a GPU. load_state_dict
            # copies the values into the freshly built (CPU) model's parameters.
            state_dict = torch.load(encoder_checkpoint_path, map_location="cpu")
            resnet_model.load_state_dict(state_dict)
            # Only the last residual stage of the loaded network is kept.
            layer4_enc = resnet_model.layer4
        else:
            layer4_enc = AEResnet18.make_encoder_layer(inplanes=256, outplanes=512, nr_blocks=2, stride=2)

        self.encoder = nn.Sequential(
            layer4_enc,
            nn.AdaptiveAvgPool2d((1, 1)),  # collapse the spatial dimensions to 1x1
            Flatten()
        )

        if use_upsampling:
            layer4_dec = self.make_decoder_layer_upsampling(512, 256, nr_blocks=2, stride=2)
        else:
            layer4_dec = self.make_decoder_layer_deconv(512, 256, nr_blocks=2, stride=2)
        self.decoder = nn.Sequential(
            Unflatten(512, 1, 1),
            # Presumably restores a 7x7 map from the 1x1 code (assumes Unflatten
            # yields an (N, 512, 1, 1) tensor - TODO confirm against utils).
            nn.UpsamplingBilinear2d(scale_factor=7),
            nn.ReLU(),
            layer4_dec
        )

    def encode(self, x):
        """Run ``x`` through ``self.encoder`` and return the result."""
        return self.encoder(x)

    def decode(self, z):
        """Run the code ``z`` through ``self.decoder`` and return the result."""
        return self.decoder(z)

    def forward(self, x, y_one_hot=None):
        """Encode ``x`` and decode the resulting code.

        ``y_one_hot`` is accepted but not used by this implementation.
        """
        return self.decode(self.encode(x))

    def save_encoder(self, path):
        """Save the encoder's ``state_dict`` to ``path`` via ``torch.save``."""
        torch.save(self.encoder.state_dict(), path)

    def save_decoder(self, path):
        """Save the decoder's ``state_dict`` to ``path`` via ``torch.save``."""
        torch.save(self.decoder.state_dict(), path)
