import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.network_utils import Flatten, Unflatten
from models.latent_resnet import resnet18
from models.basic_AE_resnet_layer4 import BasicAEResnet18


class CVAEResnet18(BasicAEResnet18):
    """Conditional VAE whose encoder/decoder operate at the ResNet-18 ``layer4`` stage.

    Encoder: ``layer4`` -> global average pool -> flatten -> ``tanh(Linear(512, 256))``
    -> two linear heads producing ``mu`` and ``logvar`` of size ``latent_size``.

    Decoder: ``Linear(latent_size + num_classes, 512)`` -> ``Unflatten(512, 1, 1)``
    -> bilinear upsample x7 -> ReLU -> a decoder layer built with
    ``(512, 256, nr_blocks=2, stride=2)``.

    The encoder input is presumably a 256-channel ``layer3`` feature map (the
    freshly built encoder layer uses ``inplanes=256``) — TODO confirm against
    the data pipeline.
    """

    # Checkpoint used when ``pretrained_encoder=True`` and no other path is given.
    DEFAULT_PRETRAINED_PATH = "../data/first_batch_models/seed_1993_imagenet_40_batch0_45ep_noBias.pth"

    def __init__(self, latent_size=512, num_classes=1000, pretrained_encoder=False, use_shortcut_in_decoder=True,
                 use_upsampling=True, pretrained_path=DEFAULT_PRETRAINED_PATH, pretrained_num_classes=40):
        """Build the encoder trunk, the variational heads and the conditional decoder.

        Args:
            latent_size: width of ``mu``/``logvar`` and of the latent vector.
            num_classes: width of the conditioning vector; ``ext_dec`` expects
                inputs of width ``latent_size + num_classes``.
            pretrained_encoder: if True, take ``layer4`` from a ResNet-18 loaded
                from ``pretrained_path`` instead of building a new one.
            use_shortcut_in_decoder: forwarded to ``BasicAEResnet18``.
            use_upsampling: select the upsampling-based (True) or deconvolution-based
                (False) decoder layer; also forwarded to ``BasicAEResnet18``.
            pretrained_path: checkpoint read when ``pretrained_encoder`` is True.
            pretrained_num_classes: classifier size of that checkpoint; it must
                match the checkpoint or ``load_state_dict`` raises.
        """
        super(CVAEResnet18, self).__init__(latent_size=latent_size, num_classes=num_classes,
                                           use_shortcut_in_decoder=use_shortcut_in_decoder,
                                           use_upsampling=use_upsampling)
        if pretrained_encoder:
            print("pretrained encoder")
            resnet_model = resnet18(pretrained=False, num_classes=pretrained_num_classes)
            # The freshly built model lives on CPU; mapping the checkpoint to CPU lets
            # GPU-saved weights load on hosts without CUDA.
            resnet_model.load_state_dict(torch.load(pretrained_path, map_location="cpu"))
            layer4_enc = resnet_model.layer4
        else:
            layer4_enc = CVAEResnet18.make_encoder_layer(inplanes=256, outplanes=512, nr_blocks=2, stride=2)

        # Encoder trunk: layer4 -> global average pool -> flat 512-wide feature vector.
        self.encoder = nn.Sequential(
            layer4_enc,
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten()
        )
        self.latent_size = latent_size
        # Variational heads on a 256-wide bottleneck.
        self.ext_enc = nn.Linear(512, 256)
        self.mu = nn.Linear(256, self.latent_size)
        self.logvar = nn.Linear(256, self.latent_size)
        # Conditional decoder input: latent concatenated with the conditioning vector.
        self.ext_dec = nn.Linear(self.latent_size + num_classes, 512)

        if use_upsampling:
            layer4_dec = self.make_decoder_layer_upsampling(512, 256, nr_blocks=2, stride=2)
        else:
            layer4_dec = self.make_decoder_layer_deconv(512, 256, nr_blocks=2, stride=2)
        self.decoder = nn.Sequential(
            # Presumably reshapes to (batch, 512, 1, 1) — Unflatten is a project helper.
            Unflatten(512, 1, 1),
            # Same output as the deprecated nn.UpsamplingBilinear2d (which uses
            # align_corners=True); on a 1x1 map this yields a constant 7x7 map.
            nn.Upsample(scale_factor=7, mode="bilinear", align_corners=True),
            nn.ReLU(),
            layer4_dec
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` for input ``x``.

        ``x`` goes through ``self.encoder`` (layer4 -> pool -> flatten). It is
        presumably a 256-channel feature map — TODO confirm. Both outputs have
        ``latent_size`` features.
        """
        h = self.encoder(x)
        # torch.tanh replaces the deprecated F.tanh; the result is identical.
        h = torch.tanh(self.ext_enc(h))
        return self.mu(h), self.logvar(h)

    def decode(self, z):
        """Map a conditioned latent vector back to a feature map via ``self.decoder``.

        ``z`` must already be the latent concatenated with the conditioning vector,
        i.e. ``latent_size + num_classes`` wide (the input width of ``ext_dec``).
        The concatenation itself is not performed here.
        """
        return self.decoder(self.ext_dec(z))

    def load_encoder(self, weight_file, map_location="cpu"):
        """Load a state dict previously saved from ``self.encoder``.

        ``map_location`` is passed to ``torch.load``. The CPU default lets
        GPU-saved weights load anywhere; ``load_state_dict`` then copies the
        values onto the module's own device.
        """
        self.encoder.load_state_dict(torch.load(weight_file, map_location=map_location))

    def load_decoder(self, weight_file, map_location="cpu"):
        """Load a state dict previously saved from ``self.decoder``.

        ``map_location`` is passed to ``torch.load``. The CPU default lets
        GPU-saved weights load anywhere; ``load_state_dict`` then copies the
        values onto the module's own device.
        """
        self.decoder.load_state_dict(torch.load(weight_file, map_location=map_location))
