import torch
import torch.nn as nn



class ConditionalAutoencoder(nn.Module):
    """Convolutional autoencoder conditioned on an integer class label.

    The label is embedded, projected by a linear layer onto a single
    ``input_dim x input_dim`` plane and concatenated to the 3-channel image as
    a fourth channel.  Four stride-2 convolutions encode the result and four
    stride-2 transposed convolutions decode it back to 3 channels, with a
    final ``Tanh`` so every output value lies in ``[-1, 1]``.

    Each encoder stage halves the spatial size (rounding down) and each
    decoder stage doubles it, so the output has the same height/width as the
    input only when ``input_dim`` is a multiple of 16.

    Args:
        n_classes: Number of distinct labels; also used as the embedding width.
        input_dim: Height and width of the (square) input images.
    """

    def __init__(self, n_classes, input_dim):
        super().__init__()

        # Label pathway: class index -> embedding -> one flattened plane that
        # forward() reshapes into a single conditioning channel.
        self.label_emb = nn.Embedding(
            num_embeddings=n_classes, embedding_dim=n_classes
        )
        self.linear_c = nn.Linear(
            in_features=n_classes, out_features=input_dim * input_dim
        )
        self.n_classes = n_classes
        self.input_dim = input_dim

        # Encoder: (3 image channels + 1 label channel) -> 16 -> 32 -> 64 -> 128.
        # Layers are created in stack order so Sequential indices (and hence
        # state_dict keys) stay 0..11.
        enc_widths = (3 + 1, 16, 32, 64, 128)
        down_layers = []
        for c_in, c_out in zip(enc_widths, enc_widths[1:]):
            down_layers.extend((
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ))
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: 128 -> 64 -> 32 -> 16 with BatchNorm + ReLU, then a final
        # transposed convolution to 3 channels squashed by Tanh.
        dec_widths = (128, 64, 32, 16)
        up_layers = []
        for c_in, c_out in zip(dec_widths, dec_widths[1:]):
            up_layers.extend((
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ))
        up_layers.append(
            nn.ConvTranspose2d(dec_widths[-1], 3, kernel_size=4, stride=2, padding=1)
        )
        up_layers.append(nn.Tanh())
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x, c):
        """Reconstruct ``x`` conditioned on the labels ``c``.

        Args:
            x: Image batch with 3 channels and spatial size
                ``input_dim x input_dim`` (it is concatenated along dim 1 with
                a one-channel plane of that size).
            c: Tensor of integer class indices, one per image, used to index
                the label embedding.

        Returns:
            A 3-channel tensor with values in ``[-1, 1]``.
        """
        side = self.input_dim
        # Turn each label into a (1, side, side) conditioning plane.
        label_plane = self.linear_c(self.label_emb(c))
        label_plane = label_plane.view(-1, 1, side, side)

        conditioned = torch.cat((x, label_plane), dim=1)
        return self.decoder(self.encoder(conditioned))
    
    
if __name__ == '__main__':
    # Smoke test: generate one label-conditioned trigger for a random image
    # and print its dtype/shape/peak value.
    eps = 0.05  # scale applied to the Tanh-bounded generator output
    trigger_generator = ConditionalAutoencoder(10, 32)
    image = torch.randn(1, 3, 32, 32)
    # Class labels are integer indices, so use an integer literal rather than
    # a float that is silently truncated to long.
    label = torch.tensor([1], dtype=torch.long)
    trigger = trigger_generator(image, label) * eps
    # trigger is within [-1, 1] * eps
    print(image.dtype)
    print(label.dtype)
    print(trigger.shape)
    # Reduce with one vectorized call instead of iterating the flattened
    # array element by element with the Python builtin max().
    print(trigger.detach().numpy().max())

