# -*- coding: utf-8 -*-

"""
CIFAR Generator

credits:
    https://github.com/GT-RIPL/AlwaysBeDreaming-DFCIL
    https://github.com/VainF/Data-Free-Adversarial-Distillation
"""

from typing import List, Union

import torch
import torch.nn as nn

class Generator(nn.Module):
    """Prototype-conditioned convolutional image generator.

    A class prototype (selected by target index) is perturbed with Gaussian
    noise, projected to a ``zdim`` latent code, reshaped to a
    ``128 x (img_sz // 4) x (img_sz // 4)`` feature map and decoded through
    two nearest-neighbour 2x upsampling stages.

    Args:
        zdim: size of the latent code fed to the decoder.
        in_channel: number of channels of the generated images.
        img_sz: target spatial size. The output side is ``4 * (img_sz // 4)``,
            so ``img_sz`` should be a multiple of 4.
        protos: class prototypes, either a tensor of shape
            ``(num_classes, proto_dim)`` or a list/tuple of ``proto_dim``-sized
            tensors. It is stored as a plain attribute (not a buffer), so it
            is not part of ``state_dict`` and is not moved by ``.to()``.
        proto_dim: feature size of each prototype (input width of ``l0``).
        noise_std: standard deviation of the Gaussian noise added to the
            selected prototypes before projection.
    """

    def __init__(
        self,
        zdim,
        in_channel,
        img_sz,
        protos: Union[torch.Tensor, List[torch.Tensor]],
        proto_dim: int = 512,
        noise_std: float = 0.8,
    ):
        super().__init__()

        self.z_dim = zdim
        self.init_size = img_sz // 4
        self.noise_std = noise_std
        self.l0 = nn.Linear(proto_dim, zdim)
        self.l1 = nn.Sequential(nn.Linear(zdim, 128 * self.init_size ** 2))

        self.conv_blocks0 = nn.Sequential(
            nn.BatchNorm2d(128),
        )
        # NOTE: the second positional argument of BatchNorm2d is ``eps``, so
        # the layers below use eps=0.8 (not momentum=0.8). Presumably
        # inherited from the credited upstream generators; kept unchanged so
        # outputs stay comparable.
        self.conv_blocks1 = nn.Sequential(
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv_blocks2 = nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, in_channel, 3, stride=1, padding=1),
            nn.Tanh(),
            nn.BatchNorm2d(in_channel, affine=False),
        )
        self.protos = protos

    def forward(self, targets: torch.Tensor):
        """Generate one image per entry of ``targets``.

        Args:
            targets: integer class indices into ``protos`` (a 1-D tensor, or
                anything ``torch.as_tensor`` accepts).

        Returns:
            Tensor of shape ``(len(targets), in_channel, S, S)`` with
            ``S = 4 * (img_sz // 4)``.
        """
        protos = self.protos
        if isinstance(protos, (list, tuple)):
            # A Python list cannot be indexed by a batch of indices; stack it.
            protos = torch.stack(protos)
        # Index on the prototypes' device, then move the result to the
        # weights' device/dtype: ``protos`` is a plain attribute, so moving
        # the module with ``.to()`` / ``.cuda()`` does not move it.
        index = torch.as_tensor(targets, device=protos.device)
        z = protos[index].to(device=self.l0.weight.device, dtype=self.l0.weight.dtype)
        # randn_like samples directly on z's device and dtype (no CPU round-trip).
        z = z + self.noise_std * torch.randn_like(z)
        z = self.l0(z)
        out = self.l1(z)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img = self.conv_blocks0(out)
        img = nn.functional.interpolate(img, scale_factor=2)
        img = self.conv_blocks1(img)
        img = nn.functional.interpolate(img, scale_factor=2)
        img = self.conv_blocks2(img)
        return img

    def sample(self, targets: torch.Tensor):
        """Alias for :meth:`forward`: generate images for ``targets``."""
        X = self.forward(targets)
        return X


def CIFAR_GEN(protos: List[torch.Tensor]):
    """Build the generator configured for 32x32 RGB CIFAR images.

    Args:
        protos: class prototypes, passed through unchanged to ``Generator``.

    Returns:
        A ``Generator`` with a 1000-dimensional latent code, 3 output
        channels and a 32x32 output resolution.
    """
    cifar_config = dict(zdim=1000, in_channel=3, img_sz=32)
    return Generator(protos=protos, **cifar_config)
