import torch
import torch.nn as nn

from xad.models.resnets.resgan_blocks import ResDisOptimizedBlock, ResDisBlock
from xad.models.bases import ConceptNN


class ConceptResNet64(ConceptNN):
    """Residual network that maps an input pair ``(x, x_hat)`` to concept logits.

    The two inputs are stacked along dim 1 and passed through an optimized
    residual block followed by a downsampling residual block. The result is
    passed through ``activation``, summed over dims 2 and 3, and projected by
    a linear layer to ``self.n_concepts`` outputs.

    Args:
        concepts: Forwarded to ``ConceptNN.__init__``. The output width is
            ``self.n_concepts``, presumably derived from this value (confirm
            in ``ConceptNN``).
        latent_dim: Passed as the second channel argument of the second block.
            It is also the input width of the final linear layer.
        activation: Callable handed to both residual blocks and applied again
            after the second block.
        grayscale: When True the first block is built for 2 input channels,
            otherwise 6. Presumably this is 1+1 (grayscale) or 3+3 (RGB)
            channels from ``x`` and ``x_hat``; TODO confirm.
    """

    def __init__(self, concepts: int, latent_dim=128, activation=torch.nn.functional.relu, grayscale=False):
        super().__init__(concepts)
        self.latent_dim = latent_dim
        self.activation = activation
        # forward() concatenates x and x_hat channel-wise, so the first block
        # sees the combined channel count of both tensors.
        stacked_channels = 2 if grayscale else 6
        self.block1 = ResDisOptimizedBlock(stacked_channels, 64, activation=activation)
        self.block2 = ResDisBlock(64, latent_dim, activation=activation, downsample=True)
        self.l3 = torch.nn.Linear(latent_dim, self.n_concepts)

    def parameterize(self):
        """Switch the network's layers to their normalized parameterization.

        Calls ``parameterize()`` on both residual blocks, in order. Their
        exact effect is defined in ``resgan_blocks``; presumably it is
        spectral norm. It then wraps ``self.l3`` with
        ``torch.nn.utils.spectral_norm``, which registers the normalization
        on the module in place.
        """
        for block in (self.block1, self.block2):
            block.parameterize()
        torch.nn.utils.spectral_norm(self.l3)

    def forward(self, x, x_hat):
        """Compute concept logits for the pair ``(x, x_hat)``.

        Args:
            x: First input tensor. It is concatenated with ``x_hat`` along
                dim 1.
            x_hat: Second input tensor. Its sizes must match ``x`` on every
                dim except dim 1.

        Returns:
            The output of ``self.l3``: ``self.n_concepts`` logits per sample.
        """
        features = self.block2(self.block1(torch.cat((x, x_hat), dim=1)))
        # Sum-pool over dims 2 and 3 (presumably the spatial axes) after the
        # final non-linearity. This leaves latent_dim features for l3.
        pooled = self.activation(features).sum(dim=(2, 3))
        return self.l3(pooled)
