from typing import final
import torch
from xad.models.bases import BatchNorm2d, ConditionalDiscriminator, ConditionalGenerator, ConditionalGAN
from xad.models.resnets.resgan_blocks import ResGenEncoderBlock, ResGenDecoderBlock, ResDisBlock, ResDisOptimizedBlock, ResConv2d
from xad.models.resnets.resgan_blocks import ResLinear, ResEmbedding


class ResNetGeneratorEncoder64(ConditionalGAN.Encoder):
    """Conditional residual encoder: a norm/activation/conv stem plus four residual blocks.

    Channel widths run ``in_channel -> ch -> 2ch -> 4ch -> 8ch -> 16ch``. All four
    residual blocks are created with ``downsample=True`` and receive the condition.
    """

    def __init__(self, n_classes, ch=64, activation=torch.nn.functional.relu, in_channel=3, xavier_gain=2**0.5):
        """Create the stem and the four residual stages.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied after the input normalisation; also handed to each block.
            in_channel: Channel count expected by the input normalisation and stem convolution.
            xavier_gain: Forwarded as ``xavier_gain`` to the stem convolution and to each block.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, downsample=True, n_classes=n_classes, xavier_gain=xavier_gain)
        # Stem: normalisation followed by a 3x3 convolution.
        self.b1 = BatchNorm2d(in_channel)
        self.c1 = ResConv2d(in_channel, ch, 3, stride=1, padding=1, xavier_gain=xavier_gain)
        # Residual stages with doubling width.
        self.block2 = ResGenEncoderBlock(ch, ch * 2, **stage_opts)
        self.block3 = ResGenEncoderBlock(ch * 2, ch * 4, **stage_opts)
        self.block4 = ResGenEncoderBlock(ch * 4, ch * 8, **stage_opts)
        self.block5 = ResGenEncoderBlock(ch * 8, ch * 16, **stage_opts)

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Run ``x`` through the stem and then through block2..block5, passing ``condition`` to each block."""
        features = self.c1(self.activation(self.b1(x)))
        for stage in (self.block2, self.block3, self.block4, self.block5):
            features = stage(features, condition)
        return features


class ResNetGeneratorDecoder64(ConditionalGAN.Decoder):
    """Conditional residual decoder: four residual blocks followed by norm, activation, conv and tanh.

    Channel widths run ``16ch -> 8ch -> 4ch -> 2ch -> ch -> out_channel``. All four
    residual blocks are created with ``upsample=True`` and receive the condition.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self,  n_classes, ch=64, activation=torch.nn.functional.relu, out_channel=3, xavier_gain=2**0.5):
        """Create the four residual stages and the output head.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied before the output convolution; also handed to each block.
            out_channel: Channel count produced by the output convolution.
            xavier_gain: Forwarded as ``xavier_gain`` to each block and to the output convolution.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, upsample=True, n_classes=n_classes, xavier_gain=xavier_gain)
        self.block2 = ResGenDecoderBlock(ch * 16, ch * 8, **stage_opts)
        self.block3 = ResGenDecoderBlock(ch * 8, ch * 4, **stage_opts)
        self.block4 = ResGenDecoderBlock(ch * 4, ch * 2, **stage_opts)
        self.block5 = ResGenDecoderBlock(ch * 2, ch, **stage_opts)
        # Output head: normalisation and a 3x3 convolution down to ``out_channel``.
        self.b6 = BatchNorm2d(ch)
        self.l6 = ResConv2d(ch, out_channel, 3, stride=1, padding=1, xavier_gain=xavier_gain)

    def forward(self, z: torch.Tensor, condition: torch.Tensor):
        """Run ``z`` through block2..block5 with ``condition``, then through the tanh-bounded output head."""
        features = z
        for stage in (self.block2, self.block3, self.block4, self.block5):
            features = stage(features, condition)
        features = self.activation(self.b6(features))
        return torch.tanh(self.l6(features))


class ResNetGenerator64(ConditionalGenerator):
    """Conditional generator pairing ``ResNetGeneratorEncoder64`` with ``ResNetGeneratorDecoder64``.

    Only ``encoder`` and ``decoder`` are set here; how they are combined is left to
    ``ConditionalGenerator`` (not visible in this file).
    """

    def __init__(self, latent_dim: int, condition_shape: torch.Size, img_channels: int = 3, xavier_gain=2**0.5):
        """Build the encoder/decoder pair.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is used as the base channel width.
            condition_shape: Its ``numel()`` is passed to encoder and decoder as ``n_classes``.
            img_channels: Input channels of the encoder and output channels of the decoder.
            xavier_gain: Forwarded to encoder and decoder.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        n_classes = condition_shape.numel()
        base_ch = latent_dim // 16
        self.encoder = ResNetGeneratorEncoder64(n_classes, base_ch, in_channel=img_channels, xavier_gain=xavier_gain)
        self.decoder = ResNetGeneratorDecoder64(n_classes, base_ch, out_channel=img_channels, xavier_gain=xavier_gain)


class SNResNetProjectionDiscriminator64(ConditionalDiscriminator):
    """Residual projection discriminator with five blocks and a linear output.

    The score is ``l6(h)`` plus, when a condition is given, the inner product of an
    embedding of the condition (``l_y``) with the pooled features ``h``.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self, latent_dim: int, condition_shape: torch.Size, activation=torch.nn.functional.relu, ordinal=True):
        """Build the residual blocks, the linear output and (if conditioned) the embedding.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is the base channel width.
            condition_shape: Its ``numel()`` is stored as ``n_classes``.
            activation: Callable handed to blocks 2-5 and applied before pooling.
            ordinal: If true, ``l_y`` is built with 2 entries and applied per condition column in
                ``forward``; otherwise it is built with ``n_classes`` entries.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        ch = latent_dim // 16
        self.ordinal = ordinal
        self.n_classes = condition_shape.numel()
        self.activation = activation
        # The input channel count is fixed to 3 for this variant.
        self.block1 = ResDisOptimizedBlock(3, ch)
        self.block2 = ResDisBlock(ch, ch * 2, activation=activation, downsample=True)
        self.block3 = ResDisBlock(ch * 2, ch * 4, activation=activation, downsample=True)
        self.block4 = ResDisBlock(ch * 4, ch * 8, activation=activation, downsample=True)
        self.block5 = ResDisBlock(ch * 8, ch * 16, activation=activation, downsample=True)
        self.l6 = ResLinear(ch * 16, 1)

        # ``l_y`` only exists for a conditioned discriminator.
        if self.n_classes > 0:
            if self.ordinal:
                self.l_y = ResEmbedding(2, ch * 16)
            else:
                self.l_y = ResEmbedding(self.n_classes, ch * 16)

    def parameterize(self):
        """Call ``parameterize()`` on every block and wrap ``l6`` (and ``l_y`` if present) in spectral norm."""
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            block.parameterize()
        torch.nn.utils.spectral_norm(self.l6)
        if self.n_classes > 0:
            torch.nn.utils.spectral_norm(self.l_y)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Score ``x``, optionally conditioned on ``condition``.

        Args:
            x: Input passed through block1..block5; the result is summed over dims 2 and 3.
            condition: Condition tensor, or ``None`` to skip the projection term. In ordinal mode the
                columns ``condition[:, i]`` for ``i < n_classes - 1`` are embedded one by one.

        Returns:
            ``l6(h)`` plus the projection term(s), where ``h`` is the pooled feature tensor.
        """
        h = x
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5):
            h = block(h)
        h = self.activation(h)
        h = h.sum([2, 3])  # pool by summing over dims 2 and 3
        output = self.l6(h)
        # ``l_y`` is only created when n_classes > 0, so skip the projection otherwise
        # instead of failing with AttributeError in the non-ordinal branch.
        if condition is not None and self.n_classes > 0:
            if self.ordinal:
                # NOTE(review): only the first n_classes - 1 columns are used — confirm this matches the label encoding.
                for i in range(self.n_classes - 1):
                    output = output + (self.l_y(condition[:, i]) * h).sum(dim=1, keepdim=True)
            else:
                w_y = self.l_y(condition)
                output = output + (w_y * h).sum(dim=1, keepdim=True)
        return output


class WideResNetGeneratorEncoder64(ConditionalGAN.Encoder):
    """Wide conditional residual encoder: a stem plus three residual blocks.

    Channel widths run ``in_channel -> ch -> 4ch -> 8ch -> 16ch``. All three
    residual blocks are created with ``downsample=True`` and receive the condition.
    """

    def __init__(self, n_classes, ch=64, activation=torch.nn.functional.relu, in_channel=3, xavier_gain=2**0.5):
        """Create the stem and the three residual stages.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied after the input normalisation; also handed to each block.
            in_channel: Channel count expected by the input normalisation and stem convolution.
            xavier_gain: Forwarded as ``xavier_gain`` to the stem convolution and to each block.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, downsample=True, n_classes=n_classes, xavier_gain=xavier_gain)
        self.b1 = BatchNorm2d(in_channel)
        self.c1 = ResConv2d(in_channel, ch, 3, stride=1, padding=1, xavier_gain=xavier_gain)
        self.block2 = ResGenEncoderBlock(ch, ch * 4, **stage_opts)
        self.block3 = ResGenEncoderBlock(ch * 4, ch * 8, **stage_opts)
        self.block4 = ResGenEncoderBlock(ch * 8, ch * 16, **stage_opts)

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Run ``x`` through the stem and then through block2..block4, passing ``condition`` to each block."""
        features = self.c1(self.activation(self.b1(x)))
        for stage in (self.block2, self.block3, self.block4):
            features = stage(features, condition)
        return features


class WideResNetGeneratorDecoder64(ConditionalGAN.Decoder):
    """Wide conditional residual decoder: three residual blocks and a configurable output activation.

    Channel widths run ``16ch -> 16ch -> 8ch -> 4ch -> out_channel``. All three
    residual blocks are created with ``upsample=True`` and receive the condition.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self,  n_classes, ch=64, activation=torch.nn.functional.relu, out_channel=3, final_act=torch.tanh, xavier_gain=2**0.5):
        """Create the three residual stages and the output head.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied before the output convolution; also handed to each block.
            out_channel: Channel count produced by the output convolution.
            final_act: Callable applied to the output convolution's result.
            xavier_gain: Forwarded as ``xavier_gain`` to each residual block.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        self.final_act = final_act
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, upsample=True, n_classes=n_classes, xavier_gain=xavier_gain)
        self.block2 = ResGenDecoderBlock(ch * 16, ch * 16, **stage_opts)
        self.block3 = ResGenDecoderBlock(ch * 16, ch * 8, **stage_opts)
        self.block4 = ResGenDecoderBlock(ch * 8, ch * 4, **stage_opts)
        self.b6 = BatchNorm2d(ch * 4)
        # NOTE(review): ``xavier_gain`` is not forwarded to this convolution, unlike the 64 and 64S
        # decoders in this file — confirm whether ResConv2d's default gain is intended here.
        self.l6 = ResConv2d(ch * 4, out_channel, 3, stride=1, padding=1)

    def forward(self, z: torch.Tensor, condition: torch.Tensor):
        """Run ``z`` through block2..block4 with ``condition``, then through the output head and ``final_act``."""
        features = z
        for stage in (self.block2, self.block3, self.block4):
            features = stage(features, condition)
        features = self.activation(self.b6(features))
        return self.final_act(self.l6(features))


class WideResNetGenerator64(ConditionalGenerator):
    """Conditional generator pairing ``WideResNetGeneratorEncoder64`` with ``WideResNetGeneratorDecoder64``.

    Only ``encoder`` and ``decoder`` are set here; how they are combined is left to
    ``ConditionalGenerator`` (not visible in this file).
    """

    def __init__(self, latent_dim: int, condition_shape: torch.Size, img_channels: int = 3, final_act=torch.tanh, xavier_gain=2**0.5):
        """Build the encoder/decoder pair.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is used as the base channel width.
            condition_shape: Its ``numel()`` is passed to encoder and decoder as ``n_classes``.
            img_channels: Input channels of the encoder and output channels of the decoder.
            final_act: Output activation handed to the decoder.
            xavier_gain: Forwarded to encoder and decoder.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        n_classes = condition_shape.numel()
        base_ch = latent_dim // 16
        self.encoder = WideResNetGeneratorEncoder64(n_classes, base_ch, in_channel=img_channels, xavier_gain=xavier_gain)
        self.decoder = WideResNetGeneratorDecoder64(n_classes, base_ch, out_channel=img_channels, final_act=final_act, xavier_gain=xavier_gain)


class WideSNResNetProjectionDiscriminator64(ConditionalDiscriminator):
    """Residual projection discriminator with four blocks and a linear output.

    The score is ``l6(h)`` plus, when a condition is given, the inner product of an
    embedding of the condition (``l_y``) with the pooled features ``h``.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self, latent_dim: int, condition_shape: torch.Size, activation=torch.nn.functional.relu,
                 ordinal=True, grayscale=False):
        """Build the residual blocks, the linear output and (if conditioned) the embedding.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is the base channel width.
            condition_shape: Its ``numel()`` is stored as ``n_classes``.
            activation: Callable handed to blocks 2-4 and applied before pooling.
            ordinal: If true, ``l_y`` is built with 2 entries and applied per condition column in
                ``forward``; otherwise it is built with ``n_classes`` entries.
            grayscale: If true the first block takes 1 input channel, otherwise 3.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        ch = latent_dim // 16
        self.ordinal = ordinal
        self.n_classes = condition_shape.numel()
        self.activation = activation
        self.block1 = ResDisOptimizedBlock(1 if grayscale else 3, ch)
        self.block2 = ResDisBlock(ch, ch * 2, activation=activation, downsample=True)
        self.block3 = ResDisBlock(ch * 2, ch * 4, activation=activation, downsample=True)
        self.block4 = ResDisBlock(ch * 4, ch * 8, activation=activation, downsample=True)
        self.l6 = ResLinear(ch * 8, 1)

        # ``l_y`` only exists for a conditioned discriminator.
        if self.n_classes > 0:
            if self.ordinal:
                self.l_y = ResEmbedding(2, ch * 8)
            else:
                self.l_y = ResEmbedding(self.n_classes, ch * 8)

    def parameterize(self):
        """Call ``parameterize()`` on every block and wrap ``l6`` (and ``l_y`` if present) in spectral norm."""
        for block in (self.block1, self.block2, self.block3, self.block4):
            block.parameterize()
        torch.nn.utils.spectral_norm(self.l6)
        if self.n_classes > 0:
            torch.nn.utils.spectral_norm(self.l_y)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Score ``x``, optionally conditioned on ``condition``.

        Args:
            x: Input passed through block1..block4; the result is summed over dims 2 and 3.
            condition: Condition tensor, or ``None`` to skip the projection term. In ordinal mode the
                columns ``condition[:, i]`` for ``i < n_classes - 1`` are embedded one by one.

        Returns:
            ``l6(h)`` plus the projection term(s), where ``h`` is the pooled feature tensor.
        """
        h = x
        for block in (self.block1, self.block2, self.block3, self.block4):
            h = block(h)
        h = self.activation(h)
        h = h.sum([2, 3])  # pool by summing over dims 2 and 3
        output = self.l6(h)
        # ``l_y`` is only created when n_classes > 0, so skip the projection otherwise
        # instead of failing with AttributeError in the non-ordinal branch.
        if condition is not None and self.n_classes > 0:
            if self.ordinal:
                # NOTE(review): only the first n_classes - 1 columns are used — confirm this matches the label encoding.
                for i in range(self.n_classes - 1):
                    output = output + (self.l_y(condition[:, i]) * h).sum(dim=1, keepdim=True)
            else:
                w_y = self.l_y(condition)
                output = output + (w_y * h).sum(dim=1, keepdim=True)
        return output


class WideResNetGeneratorEncoder32(ConditionalGAN.Encoder):
    """Wide conditional residual encoder: a stem plus three residual blocks.

    Channel widths run ``in -> ch -> 4ch -> 8ch -> 16ch``. The first two residual
    blocks are created with ``downsample=True``, the last with ``downsample=False``.
    """

    def __init__(self, n_classes, ch=64, activation=torch.nn.functional.relu, grayscale=False):
        """Create the stem and the three residual stages.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied after the input normalisation; also handed to each block.
            grayscale: If true the stem takes 1 input channel, otherwise 3.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        in_channel = 1 if grayscale else 3
        # Options shared by every residual stage (the blocks use their own default gain).
        stage_opts = dict(activation=activation, n_classes=n_classes)
        self.b1 = BatchNorm2d(in_channel)
        self.c1 = ResConv2d(in_channel, ch, 3, stride=1, padding=1, xavier_gain=2**0.5)
        self.block2 = ResGenEncoderBlock(ch, ch * 4, downsample=True, **stage_opts)
        self.block3 = ResGenEncoderBlock(ch * 4, ch * 8, downsample=True, **stage_opts)
        self.block4 = ResGenEncoderBlock(ch * 8, ch * 16, downsample=False, **stage_opts)

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Run ``x`` through the stem and then through block2..block4, passing ``condition`` to each block."""
        features = self.c1(self.activation(self.b1(x)))
        for stage in (self.block2, self.block3, self.block4):
            features = stage(features, condition)
        return features


class WideResNetGeneratorDecoder32(ConditionalGAN.Decoder):
    """Wide conditional residual decoder: three residual blocks followed by norm, activation, conv and tanh.

    Channel widths run ``16ch -> 16ch -> 8ch -> 4ch -> out``. The first residual block
    is created with ``upsample=False``, the other two with ``upsample=True``.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self,  n_classes, ch=64, activation=torch.nn.functional.relu, grayscale=False):
        """Create the three residual stages and the output head.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied before the output convolution; also handed to each block.
            grayscale: If true the output convolution produces 1 channel, otherwise 3.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        out_channel = 1 if grayscale else 3
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, n_classes=n_classes)
        self.block2 = ResGenDecoderBlock(ch * 16, ch * 16, upsample=False, **stage_opts)
        self.block3 = ResGenDecoderBlock(ch * 16, ch * 8, upsample=True, **stage_opts)
        self.block4 = ResGenDecoderBlock(ch * 8, ch * 4, upsample=True, **stage_opts)
        self.b6 = BatchNorm2d(ch * 4)
        self.l6 = ResConv2d(ch * 4, out_channel, 3, stride=1, padding=1)

    def forward(self, z: torch.Tensor, condition: torch.Tensor):
        """Run ``z`` through block2..block4 with ``condition``, then through the tanh-bounded output head."""
        features = z
        for stage in (self.block2, self.block3, self.block4):
            features = stage(features, condition)
        features = self.activation(self.b6(features))
        return torch.tanh(self.l6(features))


class WideResNetGenerator32(ConditionalGenerator):
    """Conditional generator pairing ``WideResNetGeneratorEncoder32`` with ``WideResNetGeneratorDecoder32``.

    Only ``encoder`` and ``decoder`` are set here; how they are combined is left to
    ``ConditionalGenerator`` (not visible in this file).
    """

    def __init__(self, latent_dim: int, condition_shape: torch.Size, grayscale=False):
        """Build the encoder/decoder pair.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is used as the base channel width.
            condition_shape: Its ``numel()`` is passed to encoder and decoder as ``n_classes``.
            grayscale: Forwarded to encoder and decoder.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        n_classes = condition_shape.numel()
        base_ch = latent_dim // 16
        self.encoder = WideResNetGeneratorEncoder32(n_classes, base_ch, grayscale=grayscale)
        self.decoder = WideResNetGeneratorDecoder32(n_classes, base_ch, grayscale=grayscale)


class WideSNResNetProjectionDiscriminator32(ConditionalDiscriminator):
    """Residual projection discriminator with four blocks (the last without downsampling) and a linear output.

    The score is ``l6(h)`` plus, when a condition is given, the inner product of an
    embedding of the condition (``l_y``) with the pooled features ``h``.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self, latent_dim: int, condition_shape: torch.Size, activation=torch.nn.functional.relu,
                 ordinal=True, grayscale=False):
        """Build the residual blocks, the linear output and (if conditioned) the embedding.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is the base channel width.
            condition_shape: Its ``numel()`` is stored as ``n_classes``.
            activation: Callable handed to blocks 2-4 and applied before pooling.
            ordinal: If true, ``l_y`` is built with 2 entries and applied per condition column in
                ``forward``; otherwise it is built with ``n_classes`` entries.
            grayscale: If true the first block takes 1 input channel, otherwise 3.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        ch = latent_dim // 16
        self.ordinal = ordinal
        self.n_classes = condition_shape.numel()
        self.activation = activation
        self.block1 = ResDisOptimizedBlock(1 if grayscale else 3, ch)
        self.block2 = ResDisBlock(ch, ch * 2, activation=activation, downsample=True)
        self.block3 = ResDisBlock(ch * 2, ch * 4, activation=activation, downsample=True)
        self.block4 = ResDisBlock(ch * 4, ch * 8, activation=activation, downsample=False)
        self.l6 = ResLinear(ch * 8, 1)

        # ``l_y`` only exists for a conditioned discriminator.
        if self.n_classes > 0:
            if self.ordinal:
                self.l_y = ResEmbedding(2, ch * 8)
            else:
                self.l_y = ResEmbedding(self.n_classes, ch * 8)

    def parameterize(self):
        """Call ``parameterize()`` on every block and wrap ``l6`` (and ``l_y`` if present) in spectral norm."""
        for block in (self.block1, self.block2, self.block3, self.block4):
            block.parameterize()
        torch.nn.utils.spectral_norm(self.l6)
        if self.n_classes > 0:
            torch.nn.utils.spectral_norm(self.l_y)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Score ``x``, optionally conditioned on ``condition``.

        Args:
            x: Input passed through block1..block4; the result is summed over dims 2 and 3.
            condition: Condition tensor, or ``None`` to skip the projection term. In ordinal mode the
                columns ``condition[:, i]`` for ``i < n_classes - 1`` are embedded one by one.

        Returns:
            ``l6(h)`` plus the projection term(s), where ``h`` is the pooled feature tensor.
        """
        h = x
        for block in (self.block1, self.block2, self.block3, self.block4):
            h = block(h)
        h = self.activation(h)
        h = h.sum([2, 3])  # pool by summing over dims 2 and 3
        output = self.l6(h)
        # ``l_y`` is only created when n_classes > 0, so skip the projection otherwise
        # instead of failing with AttributeError in the non-ordinal branch.
        if condition is not None and self.n_classes > 0:
            if self.ordinal:
                # NOTE(review): only the first n_classes - 1 columns are used — confirm this matches the label encoding.
                for i in range(self.n_classes - 1):
                    output = output + (self.l_y(condition[:, i]) * h).sum(dim=1, keepdim=True)
            else:
                w_y = self.l_y(condition)
                output = output + (w_y * h).sum(dim=1, keepdim=True)
        return output


class ResNetGeneratorEncoder224(ConditionalGAN.Encoder):
    """Conditional residual encoder: a 3-channel stem plus four residual blocks.

    Channel widths run ``3 -> ch -> 2ch -> 4ch -> 8ch -> 16ch``. All four residual
    blocks are created with ``downsample=True`` and receive the condition.
    """

    def __init__(self, n_classes, ch=64, activation=torch.nn.functional.relu):
        """Create the stem and the four residual stages.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied after the input normalisation; also handed to each block.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, downsample=True, n_classes=n_classes)
        self.b1 = BatchNorm2d(3)
        self.c1 = ResConv2d(3, ch, 3, stride=1, padding=1, xavier_gain=2**0.5)
        self.block2 = ResGenEncoderBlock(ch, ch * 2, **stage_opts)
        self.block3 = ResGenEncoderBlock(ch * 2, ch * 4, **stage_opts)
        self.block4 = ResGenEncoderBlock(ch * 4, ch * 8, **stage_opts)
        self.block5 = ResGenEncoderBlock(ch * 8, ch * 16, **stage_opts)
        # A further stage (block6, 16ch -> 16ch, downsample=True) is currently disabled.

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Run ``x`` through the stem and then through block2..block5, passing ``condition`` to each block."""
        features = self.c1(self.activation(self.b1(x)))
        for stage in (self.block2, self.block3, self.block4, self.block5):
            features = stage(features, condition)
        return features


class ResNetGeneratorDecoder224(ConditionalGAN.Decoder):
    """Conditional residual decoder: four residual blocks followed by norm, activation, conv and tanh.

    Channel widths run ``16ch -> 8ch -> 4ch -> 2ch -> ch -> 3``. All four residual
    blocks are created with ``upsample=True`` and receive the condition.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self,  n_classes, ch=64, activation=torch.nn.functional.relu,):
        """Create the four residual stages and the 3-channel output head.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied before the output convolution; also handed to each block.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, upsample=True, n_classes=n_classes)
        # A leading stage (block2, 16ch -> 16ch, upsample=True) is currently disabled.
        self.block3 = ResGenDecoderBlock(ch * 16, ch * 8, **stage_opts)
        self.block4 = ResGenDecoderBlock(ch * 8, ch * 4, **stage_opts)
        self.block5 = ResGenDecoderBlock(ch * 4, ch * 2, **stage_opts)
        self.block6 = ResGenDecoderBlock(ch * 2, ch, **stage_opts)
        self.b7 = BatchNorm2d(ch)
        self.l7 = ResConv2d(ch, 3, 3, stride=1, padding=1)

    def forward(self, z: torch.Tensor, condition: torch.Tensor):
        """Run ``z`` through block3..block6 with ``condition``, then through the tanh-bounded output head."""
        features = z
        for stage in (self.block3, self.block4, self.block5, self.block6):
            features = stage(features, condition)
        features = self.activation(self.b7(features))
        return torch.tanh(self.l7(features))


class ResNetGenerator224(ConditionalGenerator):
    """Conditional generator pairing ``ResNetGeneratorEncoder224`` with ``ResNetGeneratorDecoder224``.

    Only ``encoder`` and ``decoder`` are set here; how they are combined is left to
    ``ConditionalGenerator`` (not visible in this file).
    """

    def __init__(self, latent_dim: int, condition_shape: torch.Size):
        """Build the encoder/decoder pair.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is used as the base channel width.
            condition_shape: Its ``numel()`` is passed to encoder and decoder as ``n_classes``.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        n_classes = condition_shape.numel()
        base_ch = latent_dim // 16
        self.encoder = ResNetGeneratorEncoder224(n_classes, base_ch)
        self.decoder = ResNetGeneratorDecoder224(n_classes, base_ch)


class SNResNetProjectionDiscriminator224(ConditionalDiscriminator):
    """Residual projection discriminator with six blocks (the last without downsampling) and a linear output.

    The score is ``l7(h)`` plus, when a condition is given, the inner product of an
    embedding of the condition (``l_y``) with the pooled features ``h``.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self, latent_dim: int, condition_shape: torch.Size, activation=torch.nn.functional.relu, ordinal=True):
        """Build the residual blocks, the linear output and (if conditioned) the embedding.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is the base channel width.
            condition_shape: Its ``numel()`` is stored as ``n_classes``.
            activation: Callable handed to blocks 2-6 and applied before pooling.
            ordinal: If true, ``l_y`` is built with 2 entries and applied per condition column in
                ``forward``; otherwise it is built with ``n_classes`` entries.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        ch = latent_dim // 16
        self.ordinal = ordinal
        self.n_classes = condition_shape.numel()
        self.activation = activation
        # The input channel count is fixed to 3 for this variant.
        self.block1 = ResDisOptimizedBlock(3, ch)
        self.block2 = ResDisBlock(ch, ch * 2, activation=activation, downsample=True)
        self.block3 = ResDisBlock(ch * 2, ch * 4, activation=activation, downsample=True)
        self.block4 = ResDisBlock(ch * 4, ch * 8, activation=activation, downsample=True)
        self.block5 = ResDisBlock(ch * 8, ch * 16, activation=activation, downsample=True)
        self.block6 = ResDisBlock(ch * 16, ch * 16, activation=activation, downsample=False)
        self.l7 = ResLinear(ch * 16, 1)

        # ``l_y`` only exists for a conditioned discriminator.
        if self.n_classes > 0:
            if self.ordinal:
                self.l_y = ResEmbedding(2, ch * 16)
            else:
                self.l_y = ResEmbedding(self.n_classes, ch * 16)

    def parameterize(self):
        """Call ``parameterize()`` on every block and wrap ``l7`` (and ``l_y`` if present) in spectral norm."""
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5, self.block6):
            block.parameterize()
        torch.nn.utils.spectral_norm(self.l7)
        if self.n_classes > 0:
            torch.nn.utils.spectral_norm(self.l_y)

    def forward(self, x: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """Score ``x``, optionally conditioned on ``condition``.

        Args:
            x: Input passed through block1..block6; the result is summed over dims 2 and 3.
            condition: Condition tensor, or ``None`` to skip the projection term. In ordinal mode the
                columns ``condition[:, i]`` for ``i < n_classes - 1`` are embedded one by one.

        Returns:
            ``l7(h)`` plus the projection term(s), where ``h`` is the pooled feature tensor.
        """
        h = x
        for block in (self.block1, self.block2, self.block3, self.block4, self.block5, self.block6):
            h = block(h)
        h = self.activation(h)
        h = h.sum([2, 3])  # pool by summing over dims 2 and 3
        output = self.l7(h)
        # ``l_y`` is only created when n_classes > 0, so skip the projection otherwise
        # instead of failing with AttributeError in the non-ordinal branch.
        if condition is not None and self.n_classes > 0:
            if self.ordinal:
                # NOTE(review): only the first n_classes - 1 columns are used — confirm this matches the label encoding.
                for i in range(self.n_classes - 1):
                    output = output + (self.l_y(condition[:, i]) * h).sum(dim=1, keepdim=True)
            else:
                w_y = self.l_y(condition)
                output = output + (w_y * h).sum(dim=1, keepdim=True)
        return output



class WideResNetGeneratorEncoder64S(ConditionalGAN.Encoder):
    """Wide conditional residual encoder: a stem plus two residual blocks.

    Channel widths run ``in_channel -> ch -> 4ch -> 8ch``. Both residual blocks
    are created with ``downsample=True`` and receive the condition.
    """

    def __init__(self, n_classes, ch=64, activation=torch.nn.functional.relu, in_channel=3, xavier_gain=2**0.5):
        """Create the stem and the two residual stages.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied after the input normalisation; also handed to each block.
            in_channel: Channel count expected by the input normalisation and stem convolution.
            xavier_gain: Forwarded as ``xavier_gain`` to the stem convolution and to each block.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, downsample=True, n_classes=n_classes, xavier_gain=xavier_gain)
        self.b1 = BatchNorm2d(in_channel)
        self.c1 = ResConv2d(in_channel, ch, 3, stride=1, padding=1, xavier_gain=xavier_gain)
        self.block2 = ResGenEncoderBlock(ch, ch * 4, **stage_opts)
        self.block3 = ResGenEncoderBlock(ch * 4, ch * 8, **stage_opts)

    def forward(self, x: torch.Tensor, condition: torch.Tensor):
        """Run ``x`` through the stem and then through block2 and block3, passing ``condition`` to each block."""
        features = self.c1(self.activation(self.b1(x)))
        for stage in (self.block2, self.block3):
            features = stage(features, condition)
        return features


class WideResNetGeneratorDecoder64S(ConditionalGAN.Decoder):
    """Wide conditional residual decoder: two residual blocks and a configurable output activation.

    Channel widths run ``8ch -> 4ch -> ch -> out_channel``. Both residual blocks
    are created with ``upsample=True`` and receive the condition.
    """
    # Reference: https://github.com/t-vi/pytorch-tvmisc/blob/master/wasserstein-distance/sn_projection_cgan_64x64_143c.ipynb

    def __init__(self,  n_classes, ch=64, activation=torch.nn.functional.relu, out_channel=3, final_act=torch.tanh, xavier_gain=2**0.5):
        """Create the two residual stages and the output head.

        Args:
            n_classes: Stored on the module and forwarded to every residual block.
            ch: Base channel width; the stage widths are multiples of it.
            activation: Callable applied before the output convolution; also handed to each block.
            out_channel: Channel count produced by the output convolution.
            final_act: Callable applied to the output convolution's result.
            xavier_gain: Forwarded as ``xavier_gain`` to each block and to the output convolution.
        """
        super().__init__()
        self.activation = activation
        self.n_classes = n_classes
        self.final_act = final_act
        # Options shared by every residual stage.
        stage_opts = dict(activation=activation, upsample=True, n_classes=n_classes, xavier_gain=xavier_gain)
        self.block3 = ResGenDecoderBlock(ch * 8, ch * 4, **stage_opts)
        self.block4 = ResGenDecoderBlock(ch * 4, ch, **stage_opts)
        self.b6 = BatchNorm2d(ch)
        self.l6 = ResConv2d(ch, out_channel, 3, stride=1, padding=1, xavier_gain=xavier_gain)

    def forward(self, z: torch.Tensor, condition: torch.Tensor):
        """Run ``z`` through block3 and block4 with ``condition``, then through the output head and ``final_act``."""
        features = z
        for stage in (self.block3, self.block4):
            features = stage(features, condition)
        features = self.activation(self.b6(features))
        return self.final_act(self.l6(features))


class WideResNetGenerator64S(ConditionalGenerator):
    """Conditional generator pairing ``WideResNetGeneratorEncoder64S`` with ``WideResNetGeneratorDecoder64S``.

    Only ``encoder`` and ``decoder`` are set here; how they are combined is left to
    ``ConditionalGenerator`` (not visible in this file).
    """

    def __init__(self, latent_dim: int, condition_shape: torch.Size, img_channels: int = 3, final_act=torch.tanh, xavier_gain=2**0.5):
        """Build the encoder/decoder pair.

        Args:
            latent_dim: Must be a multiple of 16; ``latent_dim // 16`` is used as the base channel width.
            condition_shape: Its ``numel()`` is passed to encoder and decoder as ``n_classes``.
            img_channels: Input channels of the encoder and output channels of the decoder.
            final_act: Output activation handed to the decoder.
            xavier_gain: Forwarded to encoder and decoder.

        Raises:
            AssertionError: If ``latent_dim`` is not a multiple of 16.
        """
        super().__init__(latent_dim, condition_shape)
        # Raised explicitly (instead of ``assert``) so the check is not stripped under ``python -O``;
        # the exception type stays AssertionError for backward compatibility.
        if latent_dim % 16 != 0:
            raise AssertionError(
                f'LatentDim for ResNets needs to be a multiple of 16, but {latent_dim} % 16 = {latent_dim % 16}'
            )
        n_classes = condition_shape.numel()
        base_ch = latent_dim // 16
        self.encoder = WideResNetGeneratorEncoder64S(n_classes, base_ch, in_channel=img_channels, xavier_gain=xavier_gain)
        self.decoder = WideResNetGeneratorDecoder64S(n_classes, base_ch, out_channel=img_channels, final_act=final_act, xavier_gain=xavier_gain)
