""" base: https://github.com/chendaichao/VAE-pytorch/blob/master/Models/VAE/model.py
"""
import torch
import torch.nn as nn
from collections import OrderedDict


class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) dimension.

    ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``. ``reshape`` is used instead
    of ``view`` so that non-contiguous inputs are accepted (a copy is made when a
    view is impossible).
    """

    def forward(self, x):
        n_rows = x.shape[0]
        return x.reshape(n_rows, -1)


class MLP(nn.Module):
    """Stack of fully connected layers described by a list of widths.

    ``hidden_size = [d0, d1, ..., dk]`` builds ``k`` ``nn.Linear`` layers
    (``d0 -> d1``, ..., ``d(k-1) -> dk``). Every Linear is followed by
    ``BatchNorm1d`` and an in-place ``ReLU``, except the final Linear when
    ``last_activation`` is False. Submodules are named ``Linear_i``,
    ``BatchNorm_i`` and ``ReLU_i``.
    """

    def __init__(self, hidden_size, last_activation=True):
        super().__init__()
        final_idx = len(hidden_size) - 2
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(hidden_size[:-1], hidden_size[1:])):
            layers.append(("Linear_%d" % idx, nn.Linear(fan_in, fan_out)))
            # The last Linear stays "raw" (no BN/ReLU) unless explicitly requested.
            if idx == final_idx and not last_activation:
                continue
            layers += [
                ("BatchNorm_%d" % idx, nn.BatchNorm1d(fan_out)),
                ("ReLU_%d" % idx, nn.ReLU(inplace=True)),
            ]
        self.mlp = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        return self.mlp(x)


class Encoder(nn.Module):
    """Convolutional encoder with two MLP heads, ``calc_mean`` and ``calc_logvar``.

    The image goes through a conv/pool stack and an MLP to a 128-d feature. An
    optional condition vector (``action``) is concatenated to that feature before
    both heads. ``cVAE`` uses the two outputs as the mean and log-variance of the
    latent Gaussian.

    Args:
        shape: ``(channels, height, width)`` of the input images.
        dim_out_latent: output size of each head.
        dim_condition: length of the condition vector concatenated in ``forward``
            (0 for an unconditional encoder).

    Raises:
        ValueError: if height or width is below 20, which is too small for the
            conv/pool stack to leave at least one spatial cell.
    """

    def __init__(self, shape, dim_out_latent=16, dim_condition=0):
        super(Encoder, self).__init__()
        c, h, w = shape
        # Spatial size after conv5 (-4), conv5 (-4), pool (/2), conv3 (-2), conv3 (-2), pool (/2).
        ww = ((w - 8) // 2 - 4) // 2
        hh = ((h - 8) // 2 - 4) // 2
        if ww < 1 or hh < 1:
            raise ValueError(
                "Encoder input must be at least 20x20, got height=%s, width=%s" % (h, w))
        self.encode = nn.Sequential(nn.Conv2d(c, 16, 5, padding=0), nn.BatchNorm2d(16), nn.ReLU(inplace=True),
                                    nn.Conv2d(16, 32, 5, padding=0), nn.BatchNorm2d(32), nn.ReLU(inplace=True),
                                    nn.MaxPool2d(2, 2),
                                    nn.Conv2d(32, 64, 3, padding=0), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                                    nn.Conv2d(64, 64, 3, padding=0), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
                                    nn.MaxPool2d(2, 2),
                                    Flatten(), MLP([ww * hh * 64, 256, 128]))
        self.calc_mean = MLP([128 + dim_condition, 64, dim_out_latent], last_activation=False)
        self.calc_logvar = MLP([128 + dim_condition, 64, dim_out_latent], last_activation=False)

    def forward(self, x, action=None):
        """Return ``(calc_mean(h), calc_logvar(h))`` for images ``x`` (N, C, H, W).

        When ``action`` is given it must be ``(N, dim_condition)``; it is appended to
        the 128-d image feature along dim 1 before both heads.
        """
        features = self.encode(x)
        if action is not None:
            # Concatenate once; both heads consume the same conditioned feature.
            features = torch.cat((features, action), dim=1)
        return self.calc_mean(features), self.calc_logvar(features)


class Decoder(nn.Module):
    """MLP decoder mapping a latent code (plus an optional condition) to an image.

    ``shape`` is unpacked as ``(channels, height, width)``, the same order ``Encoder``
    uses. The MLP output passes through a sigmoid, so pixel values lie in (0, 1).
    The flat output is reshaped to ``(N, channels, height, width)``.
    """

    def __init__(self, shape, dim_out_latent=16, dim_condition=0):
        super(Decoder, self).__init__()
        channels, height, width = shape
        self.shape = shape
        widths = [dim_out_latent + dim_condition, 64, 128, 256, channels * height * width]
        self.decode = nn.Sequential(MLP(widths, last_activation=False), nn.Sigmoid())

    def forward(self, z, action=None):
        channels, height, width = self.shape
        # Append the condition vector (if any) to the latent code along dim 1.
        decoder_in = z if action is None else torch.cat((z, action), dim=1)
        return self.decode(decoder_in).view(-1, channels, height, width)


class cVAE(nn.Module):
    """Conditional VAE whose condition vector is computed from an action.

    The action vector is tiled over a ``_GRID_SIZE x _GRID_SIZE`` grid and
    concatenated with two normalized coordinate channels (see ``_coord_grid``).
    The result goes through ``a2img_encoder`` to obtain a ``dim_condition``-d
    condition. That condition is appended to the image features in ``Encoder``
    and to the latent code in ``Decoder``.

    Args:
        shape: image shape forwarded to ``Encoder`` / ``Decoder``.
        dim_action: length of each action vector.
        dim_out_latent: size of the latent code ``z``.
        dim_condition: size of the condition vector.
        device: device on which the coordinate grid is created. Sub-modules are
            not moved here. ``model.to(device)`` moves them, and also the grid,
            because the grid is registered as a (non-persistent) buffer.
    """

    # Spatial size of the tiled action map. The first Linear in a2img_encoder is
    # sized for a 32-channel map of exactly this size.
    _GRID_SIZE = 64

    def __init__(self, shape, dim_action, dim_out_latent=16, dim_condition=16, device="cpu"):
        super(cVAE, self).__init__()
        self.encoder = Encoder(shape, dim_out_latent, dim_condition=dim_condition)
        self.decoder = Decoder(shape, dim_out_latent, dim_condition=dim_condition)
        # Linear action embedding. forward() does not use it; a2img_encoder is used instead.
        self.label_embedding = nn.Sequential(nn.Linear(dim_action, dim_condition))
        self._device = device

        from gen_rl.policy.paintGym_critics import TReLU
        from torch.nn.utils import weight_norm as weightNorm

        # ==== Action to Image component; from paintGym_critics
        grid = self._GRID_SIZE
        self.a2img_encoder = nn.Sequential(
            weightNorm(nn.Conv2d(dim_action + 2, 64, 1, 1, 0)),
            TReLU(),
            weightNorm(nn.Conv2d(64, 64, 1, 1, 0)),
            TReLU(),
            weightNorm(nn.Conv2d(64, 32, 1, 1, 0)),
            TReLU(),
            Flatten(),
            nn.Linear(32 * grid * grid, 128), TReLU(), nn.Linear(128, dim_condition)
        )
        # Non-persistent buffer: follows model.to(...) but stays out of state_dict,
        # so existing checkpoints load unchanged.
        self.register_buffer("coord", self._coord_grid(grid).to(device), persistent=False)
        # ==== Action to Image component

        # Optimize every registered parameter, including a2img_encoder, which
        # forward() depends on. Registration order keeps encoder, decoder and
        # label_embedding first.
        self.optim = torch.optim.Adam(self.parameters(), lr=0.0003, weight_decay=0.0001)
        self._criterion = nn.BCELoss(reduction="sum")

    @staticmethod
    def _coord_grid(size):
        """Return a float32 tensor of shape ``(1, 2, size, size)``.

        Channel 0 holds ``i / (size - 1)`` (row index ``i``). Channel 1 holds
        ``j / (size - 1)`` (column index ``j``).
        """
        ramp = torch.arange(size, dtype=torch.float32) / max(size - 1, 1)
        rows = ramp.view(size, 1).expand(size, size)
        cols = ramp.view(1, size).expand(size, size)
        return torch.stack((rows, cols)).unsqueeze(0)

    def sampling(self, mean, logvar):
        """Reparameterization trick: return ``mean + eps * std`` with ``eps ~ N(0, I)``.

        ``logvar`` is a log-variance, so ``std = exp(0.5 * logvar)``. This matches the
        KL term in ``criterion``, which treats ``exp(logvar)`` as the variance.
        ``eps`` is drawn on ``mean``'s device and dtype.
        """
        eps = torch.randn_like(mean)
        std = torch.exp(0.5 * logvar)
        return mean + eps * std

    def forward(self, x, action, if_return_latent=False):
        """Encode ``x`` conditioned on ``action``, sample ``z``, and decode.

        Args:
            x: image batch passed to ``Encoder``.
            action: ``(batch, dim_action)`` tensor. It is tiled to
                ``(batch, dim_action, _GRID_SIZE, _GRID_SIZE)``.
            if_return_latent: if True, return ``(x_hat, z)`` instead of
                ``(x_hat, mean, logvar)``.
        """
        grid = self._GRID_SIZE
        coord = self.coord.expand(action.shape[0], 2, grid, grid)
        # (B, A) -> (grid, grid, B, A) -> (B, A, grid, grid): same action at every pixel.
        action_map = action.repeat(grid, grid, 1, 1).permute(2, 3, 0, 1)
        condition = self.a2img_encoder(torch.cat([action_map, coord], 1))
        mean, logvar = self.encoder(x, condition)
        z = self.sampling(mean, logvar)
        x_hat = self.decoder(z, condition)
        if if_return_latent:
            return x_hat, z
        return x_hat, mean, logvar

    def criterion(self, X, X_hat, mean, logvar):
        """Negative ELBO: summed BCE reconstruction loss plus the Gaussian KL term.

        The KL term is ``0.5 * sum(-1 - logvar + exp(logvar) + mean**2)``, i.e.
        ``KL(N(mean, exp(logvar)) || N(0, I))`` summed over all elements.
        ``BCELoss`` expects ``X_hat`` in [0, 1]; ``X`` should also be in [0, 1]
        for the loss to be meaningful.
        """
        reconst_loss = self._criterion(X_hat, X)
        KL = 0.5 * torch.sum(-1 - logvar + torch.exp(logvar) + mean ** 2)
        return reconst_loss + KL


if __name__ == '__main__':
    # Smoke test: build a cVAE for 1x28x28 images and run one forward + loss pass.
    device = "cpu"
    num_samples = 2
    dim_action = 10
    latent_size = 2

    model = cVAE(shape=(1, 28, 28), dim_action=dim_action, dim_out_latent=latent_size,
                 dim_condition=16, device=device).to(device)
    # BCELoss targets should lie in [0, 1], so draw images uniformly rather than from N(0, 1).
    x = torch.rand(num_samples, 1, 28, 28, device=device)
    a = torch.randn(num_samples, dim_action, device=device)

    X_hat, mean, logvar = model(x, a)
    loss = model.criterion(x, X_hat, mean, logvar)
    print(X_hat.shape, mean.shape, logvar.shape, loss.item())
