"""
This code is borrowed from https://github.com/NVlabs/stylegan2-ada-pytorch with a few modifications.

Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto.  Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
"""

import torch.nn as nn


import torch
import torch.nn.functional as F
import numpy as np

import utils.style_misc as misc
from utils.style_ops import conv2d_resample
from utils.style_ops import upfirdn2d
from utils.style_ops import bias_act
from utils.style_ops import fma


def normalize_2nd_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so that its mean square along ``dim`` is approximately one.

    ``eps`` is added to the mean square before taking the reciprocal square
    root, so an all-zero slice yields zeros instead of a division by zero.
    """
    mean_square = x.square().mean(dim=dim, keepdim=True)
    inv_rms = (mean_square + eps).rsqrt()
    return x * inv_rms


def modulated_conv2d(
        x,  # Input tensor of shape [batch_size, in_channels, in_height, in_width].
        weight,  # Weight tensor of shape [out_channels, in_channels, kernel_height, kernel_width].
        styles,  # Modulation coefficients of shape [batch_size, in_channels].
        noise=None,  # Optional noise tensor to add to the output activations.
        up=1,  # Integer upsampling factor.
        down=1,  # Integer downsampling factor.
        padding=0,  # Padding with respect to the upsampled image.
        resample_filter=None,  # Low-pass filter to apply when resampling activations. Must be prepared beforehand by calling upfirdn2d.setup_filter().
        demodulate=True,  # Apply weight demodulation?
        flip_weight=True,  # False = convolution, True = correlation (matches torch.nn.functional.conv2d).
        fused_modconv=True,  # Perform modulation, convolution, and demodulation as a single fused operation?
):
    """Convolve ``x`` with ``weight`` after scaling its input channels per sample by ``styles``.

    Two execution strategies are implemented:

    * ``fused_modconv=False``: the activations are multiplied by ``styles``
      before a regular convolution with the shared ``weight`` and, when
      ``demodulate`` is set, multiplied by the demodulation coefficients
      afterwards.
    * ``fused_modconv=True``: a separate (optionally demodulated) weight is
      built for every sample and the whole batch is run as one grouped
      convolution with ``groups=batch_size``.

    ``noise``, when given, is added to the convolution output in both paths.
    The result is returned with the batch dimension first; its spatial size is
    whatever ``conv2d_resample.conv2d_resample`` produces for the given
    ``up``/``down``/``padding`` settings.
    """
    batch_size = x.shape[0]
    out_channels, in_channels, kh, kw = weight.shape
    misc.assert_shape(weight, [out_channels, in_channels, kh, kw])  # [OIkk]
    misc.assert_shape(x, [batch_size, in_channels, None, None])  # [NIHW]
    misc.assert_shape(styles, [batch_size, in_channels])  # [NI]

    # Pre-normalize inputs to avoid FP16 overflow.
    # The weight is divided by sqrt(fan-in) and by its per-output-channel max magnitude,
    # the styles by their per-sample max magnitude. This is only done when demodulating.
    if x.dtype == torch.float16 and demodulate:
        weight = weight * (1 / np.sqrt(in_channels * kh * kw) / weight.norm(float("inf"), dim=[1, 2, 3], keepdim=True))  # max_Ikk
        styles = styles / styles.norm(float("inf"), dim=1, keepdim=True)  # max_I

    # Calculate per-sample weights and demodulation coefficients.
    # The per-sample weights are needed either to derive dcoefs or for the fused path.
    w = None
    dcoefs = None
    if demodulate or fused_modconv:
        w = weight.unsqueeze(0)  # [NOIkk]
        w = w * styles.reshape(batch_size, 1, -1, 1, 1)  # [NOIkk]
    if demodulate:
        dcoefs = (w.square().sum(dim=[2, 3, 4]) + 1e-8).rsqrt()  # [NO]
    if demodulate and fused_modconv:
        w = w * dcoefs.reshape(batch_size, -1, 1, 1, 1)  # [NOIkk]

    # Execute by scaling the activations before and after the convolution.
    if not fused_modconv:
        x = x * styles.to(x.dtype).reshape(batch_size, -1, 1, 1)
        x = conv2d_resample.conv2d_resample(x=x,
                                            w=weight.to(x.dtype),
                                            f=resample_filter,
                                            up=up,
                                            down=down,
                                            padding=padding,
                                            flip_weight=flip_weight)
        # Apply demodulation and noise in a single fma call when both are requested.
        if demodulate and noise is not None:
            x = fma.fma(x, dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1), noise.to(x.dtype))
        elif demodulate:
            x = x * dcoefs.to(x.dtype).reshape(batch_size, -1, 1, 1)
        elif noise is not None:
            x = x.add_(noise.to(x.dtype))
        return x

    # Execute as one fused op using grouped convolution.
    with misc.suppress_tracer_warnings():  # this value will be treated as a constant
        batch_size = int(batch_size)
    misc.assert_shape(x, [batch_size, in_channels, None, None])
    # Fold the batch into the channel axis so that each group of the grouped
    # convolution sees one sample together with that sample's own weights.
    x = x.reshape(1, -1, *x.shape[2:])
    w = w.reshape(-1, in_channels, kh, kw)
    x = conv2d_resample.conv2d_resample(x=x,
                                        w=w.to(x.dtype),
                                        f=resample_filter,
                                        up=up,
                                        down=down,
                                        padding=padding,
                                        groups=batch_size,
                                        flip_weight=flip_weight)
    # Restore the batch axis.
    x = x.reshape(batch_size, -1, *x.shape[2:])
    if noise is not None:
        x = x.add_(noise)
    return x


class FullyConnectedLayer(torch.nn.Module):
    """Dense layer whose weight and bias are rescaled at call time.

    The weight is initialised from a standard normal divided by
    ``lr_multiplier``; ``forward`` multiplies it by
    ``lr_multiplier / sqrt(in_features)`` and multiplies the bias by
    ``lr_multiplier``.
    """

    def __init__(
            self,
            in_features,  # Number of input features.
            out_features,  # Number of output features.
            bias=True,  # Apply additive bias before the activation function?
            activation="linear",  # Activation function: "relu", "lrelu", etc.
            lr_multiplier=1,  # Learning rate multiplier.
            bias_init=0,  # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation

        initial_weight = torch.randn([out_features, in_features]) / lr_multiplier
        self.weight = torch.nn.Parameter(initial_weight)

        if bias:
            self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
        else:
            self.bias = None

        # Runtime multipliers applied in forward().
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        """Apply the scaled affine map followed by the configured activation."""
        scaled_weight = self.weight.to(x.dtype) * self.weight_gain

        scaled_bias = self.bias
        if scaled_bias is not None:
            scaled_bias = scaled_bias.to(x.dtype)
            if self.bias_gain != 1:
                scaled_bias = scaled_bias * self.bias_gain

        # Linear layers with a bias are handled by a single addmm call.
        if self.activation == "linear" and scaled_bias is not None:
            return torch.addmm(scaled_bias.unsqueeze(0), x, scaled_weight.t())

        projected = x.matmul(scaled_weight.t())
        return bias_act.bias_act(projected, scaled_bias, act=self.activation)


class Conv2dLayer(torch.nn.Module):
    """2D convolution with optional resampling, bias, activation and clamping.

    The weight is drawn from a standard normal and multiplied by
    ``1 / sqrt(in_channels * kernel_size**2)`` at call time. When
    ``trainable`` is false the weight and bias are registered as buffers
    instead of parameters.
    """

    def __init__(
            self,
            in_channels,  # Number of input channels.
            out_channels,  # Number of output channels.
            kernel_size,  # Width and height of the convolution kernel.
            bias=True,  # Apply additive bias before the activation function?
            activation="linear",  # Activation function: "relu", "lrelu", etc.
            up=1,  # Integer upsampling factor.
            down=1,  # Integer downsampling factor.
            resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
            conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
            channels_last=False,  # Expect the input to have memory_format=channels_last?
            trainable=True,  # Update the weights of this layer during training?
    ):
        super().__init__()
        self.activation = activation
        self.up = up
        self.down = down
        self.conv_clamp = conv_clamp
        self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        if channels_last:
            weight_format = torch.channels_last
        else:
            weight_format = torch.contiguous_format
        initial_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        initial_weight = initial_weight.to(memory_format=weight_format)
        initial_bias = torch.zeros([out_channels]) if bias else None

        if not trainable:
            # Frozen layers keep their tensors as buffers.
            self.register_buffer("weight", initial_weight)
            if initial_bias is None:
                self.bias = None
            else:
                self.register_buffer("bias", initial_bias)
        else:
            self.weight = torch.nn.Parameter(initial_weight)
            if initial_bias is None:
                self.bias = None
            else:
                self.bias = torch.nn.Parameter(initial_bias)

    def forward(self, x, gain=1):
        """Convolve ``x``, then apply bias, activation, ``gain`` and clamping."""
        scaled_weight = self.weight * self.weight_gain
        bias = None
        if self.bias is not None:
            bias = self.bias.to(x.dtype)

        use_correlation = (self.up == 1)  # slightly faster
        x = conv2d_resample.conv2d_resample(
            x=x,
            w=scaled_weight.to(x.dtype),
            f=self.resample_filter,
            up=self.up,
            down=self.down,
            padding=self.padding,
            flip_weight=use_correlation,
        )

        # The clamp limit is scaled by the same extra gain as the activation.
        total_gain = self.act_gain * gain
        clamp_limit = None
        if self.conv_clamp is not None:
            clamp_limit = self.conv_clamp * gain
        return bias_act.bias_act(x, bias, act=self.activation, gain=total_gain, clamp=clamp_limit)


class MappingNetwork(torch.nn.Module):
    """Stack of fully connected layers mapping (z, c) to intermediate latents.

    The latent ``z`` and the embedded label ``c`` are each normalized,
    concatenated, and passed through ``num_layers`` layers stored as
    ``fc0`` .. ``fc{num_layers - 1}``. The result can be broadcast to
    ``num_ws`` copies and truncated towards the tracked average ``w_avg``.
    """

    def __init__(
            self,
            z_dim,  # Input latent (Z) dimensionality, 0 = no latent.
            c_dim,  # Conditioning label (C) dimensionality, 0 = no label.
            w_dim,  # Intermediate latent (W) dimensionality.
            num_ws,  # Number of intermediate latents to output, None = do not broadcast.
            num_layers=8,  # Number of mapping layers.
            embed_features=None,  # Label embedding dimensionality, None = same as w_dim.
            layer_features=None,  # Number of intermediate features in the mapping layers, None = same as w_dim.
            activation="lrelu",  # Activation function: "relu", "lrelu", etc.
            lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
            w_avg_beta=0.998,  # Decay for tracking the moving average of W during training, None = do not track.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta

        # Without labels there is no embedding; otherwise default to w_dim.
        if c_dim == 0:
            embed_features = 0
        elif embed_features is None:
            embed_features = w_dim
        if layer_features is None:
            layer_features = w_dim

        hidden_widths = [layer_features] * (num_layers - 1)
        widths = [z_dim + embed_features, *hidden_widths, w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx, fan_in, fan_out in zip(range(num_layers), widths, widths[1:]):
            fc = FullyConnectedLayer(fan_in, fan_out, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f"fc{idx}", fc)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer("w_avg", torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Map ``z`` and ``c`` to intermediate latents, optionally truncated."""
        # Embed, normalize, and concat inputs.
        x = None
        if self.z_dim > 0:
            misc.assert_shape(z, [None, self.z_dim])
            x = normalize_2nd_moment(z.to(torch.float32))
        if self.c_dim > 0:
            misc.assert_shape(c, [None, self.c_dim])
            label_embedding = normalize_2nd_moment(self.embed(c.to(torch.float32)))
            if x is None:
                x = label_embedding
            else:
                x = torch.cat([x, label_embedding], dim=1)

        # Main layers.
        for idx in range(self.num_layers):
            x = getattr(self, f"fc{idx}")(x)

        # Update moving average of W.
        if update_emas and self.w_avg_beta is not None:
            batch_mean = x.detach().mean(dim=0)
            self.w_avg.copy_(batch_mean.lerp(self.w_avg, self.w_avg_beta))

        # Broadcast.
        if self.num_ws is not None:
            x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation.
        if truncation_psi == 1:
            return x
        assert self.w_avg_beta is not None
        if self.num_ws is None or truncation_cutoff is None:
            return self.w_avg.lerp(x, truncation_psi)
        # Only the first `truncation_cutoff` latents are truncated, in place.
        x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x


class SynthesisLayer(torch.nn.Module):
    """Style-modulated convolution followed by noise, bias and activation.

    The latent ``w`` is turned into per-channel styles by ``affine``; those
    styles drive ``modulated_conv2d``. Noise, when enabled, is scaled by the
    learned ``noise_strength``.
    """

    def __init__(
            self,
            in_channels,  # Number of input channels.
            out_channels,  # Number of output channels.
            w_dim,  # Intermediate latent (W) dimensionality.
            resolution,  # Resolution of this layer.
            kernel_size=3,  # Convolution kernel size.
            up=1,  # Integer upsampling factor.
            use_noise=True,  # Enable noise input?
            activation="lrelu",  # Activation function: "relu", "lrelu", etc.
            resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
            conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
            channels_last=False,  # Use channels_last format for the weights?
    ):
        super().__init__()
        self.resolution = resolution
        self.up = up
        self.use_noise = use_noise
        self.activation = activation
        self.conv_clamp = conv_clamp
        self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))
        self.padding = kernel_size // 2
        self.act_gain = bias_act.activation_funcs[activation].def_gain

        # Bias initialised to 1, so the affine layer starts with an offset of 1 on every style.
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)

        if channels_last:
            weight_format = torch.channels_last
        else:
            weight_format = torch.contiguous_format
        initial_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        self.weight = torch.nn.Parameter(initial_weight.to(memory_format=weight_format))

        if use_noise:
            self.register_buffer("noise_const", torch.randn([resolution, resolution]))
            self.noise_strength = torch.nn.Parameter(torch.zeros([]))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))

    def forward(self, x, w, noise_mode="random", fused_modconv=True, gain=1):
        """Run the modulated convolution on ``x`` using styles derived from ``w``."""
        assert noise_mode in ["random", "const", "none"]
        expected_in_res = self.resolution // self.up
        misc.assert_shape(x, [None, self.weight.shape[1], expected_in_res, expected_in_res])
        styles = self.affine(w)

        # Pick the noise source: fresh per-call noise, the stored constant, or none.
        noise = None
        if self.use_noise:
            if noise_mode == "random":
                fresh = torch.randn([x.shape[0], 1, self.resolution, self.resolution], device=x.device)
                noise = fresh * self.noise_strength
            elif noise_mode == "const":
                noise = self.noise_const * self.noise_strength

        use_correlation = (self.up == 1)  # slightly faster
        x = modulated_conv2d(
            x=x,
            weight=self.weight,
            styles=styles,
            noise=noise,
            up=self.up,
            padding=self.padding,
            resample_filter=self.resample_filter,
            flip_weight=use_correlation,
            fused_modconv=fused_modconv,
        )

        total_gain = self.act_gain * gain
        clamp_limit = None
        if self.conv_clamp is not None:
            clamp_limit = self.conv_clamp * gain
        return bias_act.bias_act(x, self.bias.to(x.dtype), act=self.activation, gain=total_gain, clamp=clamp_limit)


class ToRGBLayer(torch.nn.Module):
    """Modulated convolution without demodulation that produces image channels.

    The styles from ``affine`` are pre-multiplied by
    ``1 / sqrt(in_channels * kernel_size**2)`` before modulating the weight.
    """

    def __init__(self, in_channels, out_channels, w_dim, kernel_size=1, conv_clamp=None, channels_last=False):
        super().__init__()
        self.conv_clamp = conv_clamp
        self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)

        if channels_last:
            weight_format = torch.channels_last
        else:
            weight_format = torch.contiguous_format
        initial_weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size])
        self.weight = torch.nn.Parameter(initial_weight.to(memory_format=weight_format))
        self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))

    def forward(self, x, w, fused_modconv=True):
        """Project features ``x`` to image channels using styles derived from ``w``."""
        styles = self.affine(w) * self.weight_gain
        projected = modulated_conv2d(
            x=x,
            weight=self.weight,
            styles=styles,
            demodulate=False,
            fused_modconv=fused_modconv,
        )
        return bias_act.bias_act(projected, self.bias.to(projected.dtype), clamp=self.conv_clamp)


class SynthesisBlock(torch.nn.Module):
    """One resolution level of the synthesis network.

    Depending on the configuration the block owns a learned constant input
    (``in_channels == 0``), an upsampling ``conv0`` (``in_channels != 0``),
    ``conv1``, an optional ``torgb`` layer (last block or "skip"
    architecture) and an optional residual ``skip`` convolution ("resnet"
    architecture). ``num_conv`` and ``num_torgb`` count the layers created,
    which is also the number of latents ``forward`` expects.
    """

    def __init__(
            self,
            in_channels,  # Number of input channels, 0 = first block.
            out_channels,  # Number of output channels.
            w_dim,  # Intermediate latent (W) dimensionality.
            resolution,  # Resolution of this block.
            img_channels,  # Number of output color channels.
            is_last,  # Is this the last block?
            architecture="skip",  # Architecture: "orig", "skip", "resnet".
            resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
            conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
            use_fp16=False,  # Use FP16 for this block?
            fp16_channels_last=False,  # Use channels-last memory format with FP16?
            **layer_kwargs,  # Arguments for SynthesisLayer.
    ):
        assert architecture in ["orig", "skip", "resnet"]
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        # channels_last is only ever enabled together with FP16.
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))
        self.num_conv = 0
        self.num_torgb = 0

        # The first block has no incoming features and starts from a learned constant.
        if in_channels == 0:
            self.const = torch.nn.Parameter(torch.randn([out_channels, resolution, resolution]))

        # Every other block begins with a 2x upsampling convolution.
        if in_channels != 0:
            self.conv0 = SynthesisLayer(in_channels,
                                        out_channels,
                                        w_dim=w_dim,
                                        resolution=resolution,
                                        up=2,
                                        resample_filter=resample_filter,
                                        conv_clamp=conv_clamp,
                                        channels_last=self.channels_last,
                                        **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels,
                                    out_channels,
                                    w_dim=w_dim,
                                    resolution=resolution,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last,
                                    **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == "skip":
            self.torgb = ToRGBLayer(out_channels, img_channels, w_dim=w_dim, conv_clamp=conv_clamp, channels_last=self.channels_last)
            self.num_torgb += 1

        # Residual shortcut: 1x1 convolution without bias, upsampled 2x like conv0.
        if in_channels != 0 and architecture == "resnet":
            self.skip = Conv2dLayer(in_channels,
                                    out_channels,
                                    kernel_size=1,
                                    bias=False,
                                    up=2,
                                    resample_filter=resample_filter,
                                    channels_last=self.channels_last)

    def forward(self, x, img, ws, force_fp32=False, fused_modconv=None, update_emas=False, **layer_kwargs):
        """Run the block and return ``(features, image)``.

        ``x`` holds the incoming features (ignored when ``in_channels == 0``),
        ``img`` the image accumulated so far or ``None``, and ``ws`` one latent
        per conv/torgb layer of this block. Extra ``layer_kwargs`` are
        forwarded to the ``SynthesisLayer`` calls.
        """
        _ = update_emas # unused
        misc.assert_shape(ws, [None, self.num_conv + self.num_torgb, self.w_dim])
        # Each layer call below consumes the next latent in order.
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        memory_format = torch.channels_last if self.channels_last and not force_fp32 else torch.contiguous_format
        if fused_modconv is None:
            # Default: fuse only outside training, and only for FP32 or a batch of one.
            # NOTE(review): x.shape is evaluated only when dtype is not float32; a block with
            # in_channels == 0 that receives x=None and runs in FP16 would fail here.
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = (not self.training) and (dtype == torch.float32 or int(x.shape[0]) == 1)

        # Input.
        if self.in_channels == 0:
            # Replicate the learned constant once per sample in ws.
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(x, [None, self.in_channels, self.resolution // 2, self.resolution // 2])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if self.in_channels == 0:
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
        elif self.architecture == "resnet":
            # Both the shortcut and conv1 use gain sqrt(0.5) before being summed.
            y = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, gain=np.sqrt(0.5), **layer_kwargs)
            x = y.add_(x)
        else:
            x = self.conv0(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)
            x = self.conv1(x, next(w_iter), fused_modconv=fused_modconv, **layer_kwargs)

        # ToRGB.
        if img is not None:
            # Bring the image from the previous (half) resolution up to this block's resolution.
            misc.assert_shape(img, [None, self.img_channels, self.resolution // 2, self.resolution // 2])
            img = upfirdn2d.upsample2d(img, self.resample_filter)
        if self.is_last or self.architecture == "skip":
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            # The image is always accumulated in FP32, contiguous format.
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img


class SynthesisNetwork(torch.nn.Module):
    """Chain of ``SynthesisBlock`` modules from 4x4 up to ``img_resolution``.

    Blocks are stored as attributes ``b4``, ``b8``, ... and ``num_ws`` is the
    total number of latents the network consumes.
    """

    def __init__(
            self,
            w_dim,  # Intermediate latent (W) dimensionality.
            img_resolution,  # Output image resolution.
            img_channels,  # Number of color channels.
            channel_base=32768,  # Overall multiplier for the number of channels.
            channel_max=512,  # Maximum number of channels in any layer.
            num_fp16_res=0,  # Use FP16 for the N highest resolutions.
            **block_kwargs,  # Arguments for SynthesisBlock.
    ):
        # The resolution must be a power of two and at least 4.
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2**exponent for exponent in range(2, self.img_resolution_log2 + 1)]

        channels_by_res = {}
        for res in self.block_resolutions:
            channels_by_res[res] = min(channel_base // res, channel_max)
        # FP16 is never used below resolution 8.
        fp16_threshold = max(2**(self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(
                channels_by_res[res // 2] if res > 4 else 0,
                channels_by_res[res],
                w_dim=w_dim,
                resolution=res,
                img_channels=img_channels,
                is_last=is_last,
                use_fp16=(res >= fp16_threshold),
                **block_kwargs,
            )
            # Only the last block's torgb latent is counted in the total.
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f"b{res}", block)

    def forward(self, ws, **block_kwargs):
        """Generate an image from latents ``ws`` of shape [batch, num_ws, w_dim]."""
        misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
        ws = ws.to(torch.float32)
        blocks = [getattr(self, f"b{res}") for res in self.block_resolutions]

        # Slice out each block's latents. The offset advances by num_conv only,
        # so a block's torgb latent is shared with the next block's first conv.
        per_block_ws = []
        offset = 0
        for block in blocks:
            per_block_ws.append(ws.narrow(1, offset, block.num_conv + block.num_torgb))
            offset += block.num_conv

        x = img = None
        for block, cur_ws in zip(blocks, per_block_ws):
            x, img = block(x, img, cur_ws, **block_kwargs)
        return img


class Generator(torch.nn.Module):
    """Mapping network followed by a synthesis network.

    When ``MODEL.info_type`` enables infoGAN codes, the latent dimensionality
    seen by the mapping network (``self.z_dim``) is enlarged by the size of
    the discrete and/or continuous codes.
    """

    def __init__(
            self,
            z_dim,  # Input latent (Z) dimensionality.
            c_dim,  # Conditioning label (C) dimensionality.
            w_dim,  # Intermediate latent (W) dimensionality.
            img_resolution,  # Output resolution.
            img_channels,  # Number of output color channels.
            MODEL,  # MODEL config required for applying infoGAN
            mapping_kwargs=None,  # Arguments for MappingNetwork, None = no extra arguments.
            synthesis_kwargs=None,  # Arguments for SynthesisNetwork, None = no extra arguments.
    ):
        super().__init__()
        # Use None sentinels instead of shared mutable dict defaults.
        if mapping_kwargs is None:
            mapping_kwargs = {}
        if synthesis_kwargs is None:
            synthesis_kwargs = {}

        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.MODEL = MODEL
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Extra latent entries reserved for the infoGAN codes.
        z_extra_dim = 0
        if self.MODEL.info_type in ["discrete", "both"]:
            z_extra_dim += self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
        if self.MODEL.info_type in ["continuous", "both"]:
            z_extra_dim += self.MODEL.info_num_conti_c

        if self.MODEL.info_type != "N/A":
            self.z_dim += z_extra_dim

        # The synthesis network is built first because it determines num_ws.
        self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
        self.num_ws = self.synthesis.num_ws
        self.mapping = MappingNetwork(z_dim=self.z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)

    def forward(self, z, c, eval=False, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Generate images from latents ``z`` and labels ``c``.

        ``eval`` is accepted but not used here. Remaining keyword arguments
        are forwarded to the synthesis network.
        """
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        return img


class DiscriminatorBlock(torch.nn.Module):
    """One resolution level of the discriminator.

    The block optionally converts an image to features (``fromrgb``), applies
    ``conv0`` and a 2x downsampling ``conv1``, and for the "resnet"
    architecture adds a downsampling ``skip`` shortcut. Layers whose global
    index is below ``freeze_layers`` are created as non-trainable.
    """

    def __init__(
            self,
            in_channels,  # Number of input channels, 0 = first block.
            tmp_channels,  # Number of intermediate channels.
            out_channels,  # Number of output channels.
            resolution,  # Resolution of this block.
            img_channels,  # Number of input color channels.
            first_layer_idx,  # Index of the first layer.
            architecture="resnet",  # Architecture: "orig", "skip", "resnet".
            activation="lrelu",  # Activation function: "relu", "lrelu", etc.
            resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
            conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
            use_fp16=False,  # Use FP16 for this block?
            fp16_channels_last=False,  # Use channels-last memory format with FP16?
            freeze_layers=0,  # Freeze-D: Number of layers to freeze.
    ):
        assert in_channels in [0, tmp_channels]
        assert architecture in ["orig", "skip", "resnet"]
        super().__init__()
        self.in_channels = in_channels
        self.resolution = resolution
        self.img_channels = img_channels
        self.first_layer_idx = first_layer_idx
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.register_buffer("resample_filter", upfirdn2d.setup_filter(resample_filter))

        self.num_layers = 0

        def claim_layer():
            """Count one more layer and report whether it should be trainable."""
            layer_idx = self.first_layer_idx + self.num_layers
            self.num_layers += 1
            return layer_idx >= freeze_layers

        if in_channels == 0 or architecture == "skip":
            self.fromrgb = Conv2dLayer(
                img_channels,
                tmp_channels,
                kernel_size=1,
                activation=activation,
                trainable=claim_layer(),
                conv_clamp=conv_clamp,
                channels_last=self.channels_last,
            )

        self.conv0 = Conv2dLayer(
            tmp_channels,
            tmp_channels,
            kernel_size=3,
            activation=activation,
            trainable=claim_layer(),
            conv_clamp=conv_clamp,
            channels_last=self.channels_last,
        )

        self.conv1 = Conv2dLayer(
            tmp_channels,
            out_channels,
            kernel_size=3,
            activation=activation,
            down=2,
            trainable=claim_layer(),
            resample_filter=resample_filter,
            conv_clamp=conv_clamp,
            channels_last=self.channels_last,
        )

        if architecture == "resnet":
            self.skip = Conv2dLayer(
                tmp_channels,
                out_channels,
                kernel_size=1,
                bias=False,
                down=2,
                trainable=claim_layer(),
                resample_filter=resample_filter,
                channels_last=self.channels_last,
            )

    def forward(self, x, img, force_fp32=False):
        """Process features ``x`` and/or image ``img``; return ``(features, image)``."""
        dtype = torch.float16 if self.use_fp16 and not force_fp32 else torch.float32
        if self.channels_last and not force_fp32:
            memory_format = torch.channels_last
        else:
            memory_format = torch.contiguous_format

        # Input.
        if x is not None:
            misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

        # FromRGB.
        if self.in_channels == 0 or self.architecture == "skip":
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            img = img.to(dtype=dtype, memory_format=memory_format)
            rgb_features = self.fromrgb(img)
            if x is None:
                x = rgb_features
            else:
                x = x + rgb_features
            # Only the "skip" architecture passes a downsampled image on to the next block.
            if self.architecture == "skip":
                img = upfirdn2d.downsample2d(img, self.resample_filter)
            else:
                img = None

        # Main layers.
        if self.architecture == "resnet":
            shortcut = self.skip(x, gain=np.sqrt(0.5))
            x = self.conv0(x)
            x = self.conv1(x, gain=np.sqrt(0.5))
            x = shortcut.add_(x)
        else:
            x = self.conv1(self.conv0(x))

        assert x.dtype == dtype
        return x, img


class MinibatchStdLayer(torch.nn.Module):
    """Append per-group standard-deviation statistics as extra channels.

    The batch is split into groups of at most ``group_size`` samples (the
    whole batch when ``group_size`` is None). For each group the standard
    deviation is averaged over channels and pixels into ``num_channels``
    values, which are broadcast and concatenated to the input.
    """

    def __init__(self, group_size, num_channels=1):
        super().__init__()
        self.group_size = group_size
        self.num_channels = num_channels

    def forward(self, x):
        """Return ``x`` with ``num_channels`` statistic channels appended."""
        batch, channels, height, width = x.shape
        with misc.suppress_tracer_warnings():  # as_tensor results are registered as constants
            if self.group_size is not None:
                group = torch.min(torch.as_tensor(self.group_size), torch.as_tensor(batch))
            else:
                group = batch
        num_stats = self.num_channels
        chans_per_stat = channels // num_stats

        # [GnFcHW] Split the batch into n groups of size G and the channels into F sets of size c.
        stats = x.reshape(group, -1, num_stats, chans_per_stat, height, width)
        stats = stats - stats.mean(dim=0)  # [GnFcHW] Centre within each group.
        stats = stats.square().mean(dim=0)  # [nFcHW] Variance over the group.
        stats = (stats + 1e-8).sqrt()  # [nFcHW] Standard deviation over the group.
        stats = stats.mean(dim=[2, 3, 4])  # [nF] Average over channels and pixels.
        stats = stats.reshape(-1, num_stats, 1, 1)  # [nF11] Restore spatial dimensions.
        stats = stats.repeat(group, 1, height, width)  # [NFHW] Replicate over group and pixels.
        return torch.cat([x, stats], dim=1)  # [NCHW] Append as new channels.


class DiscriminatorEpilogue(torch.nn.Module):
    """Last discriminator stage, producing a flat feature vector.

    Pipeline: optional ``fromrgb`` skip input ("skip" architecture only),
    optional minibatch-stddev channels, one 3x3 convolution, then a fully
    connected layer over the flattened feature map. The final scalar/cmap
    output projection is not applied here; ``forward`` returns the
    ``in_channels``-wide output of ``fc``.
    """

    def __init__(
            self,
            in_channels,  # Number of input channels.
            cmap_dim,  # Dimensionality of mapped conditioning label, 0 = no label.
            resolution,  # Resolution of this block.
            img_channels,  # Number of input color channels.
            architecture="resnet",  # Architecture: "orig", "skip", "resnet".
            mbstd_group_size=4,  # Group size for the minibatch standard deviation layer, None = entire minibatch.
            mbstd_num_channels=1,  # Number of features for the minibatch standard deviation layer, 0 = disable.
            activation="lrelu",  # Activation function: "relu", "lrelu", etc.
            conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
    ):
        assert architecture in ("orig", "skip", "resnet")
        super().__init__()
        self.in_channels = in_channels
        self.cmap_dim = cmap_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.architecture = architecture

        # Sub-modules are created in a fixed order: fromrgb, mbstd, conv, fc.
        if architecture == "skip":
            self.fromrgb = Conv2dLayer(img_channels, in_channels, kernel_size=1, activation=activation)

        if mbstd_num_channels > 0:
            self.mbstd = MinibatchStdLayer(group_size=mbstd_group_size, num_channels=mbstd_num_channels)
        else:
            self.mbstd = None

        # The conv consumes the extra minibatch-stddev channels and maps back to in_channels.
        conv_in_channels = in_channels + mbstd_num_channels
        self.conv = Conv2dLayer(conv_in_channels, in_channels, kernel_size=3, activation=activation, conv_clamp=conv_clamp)

        flat_features = in_channels * (resolution**2)
        self.fc = FullyConnectedLayer(flat_features, in_channels, activation=activation)

    def forward(self, x, img, force_fp32=False):
        """Return the fc features for feature map ``x`` (and ``img`` when architecture is "skip")."""
        del force_fp32  # unused: this stage always casts to float32 below
        misc.assert_shape(x, [None, self.in_channels, self.resolution, self.resolution])  # [NCHW]
        cast_kwargs = dict(dtype=torch.float32, memory_format=torch.contiguous_format)

        # FromRGB.
        feat = x.to(**cast_kwargs)
        if self.architecture == "skip":
            misc.assert_shape(img, [None, self.img_channels, self.resolution, self.resolution])
            feat = feat + self.fromrgb(img.to(**cast_kwargs))

        # Main layers.
        if self.mbstd is not None:
            feat = self.mbstd(feat)
        feat = self.conv(feat)
        return self.fc(torch.flatten(feat, 1))


class Discriminator(torch.nn.Module):
    """Discriminator made of per-resolution blocks, an epilogue, and conditioning heads.

    The image is passed through one DiscriminatorBlock per resolution (from
    ``img_resolution`` down to 8) and a DiscriminatorEpilogue at resolution 4,
    yielding a feature vector ``h``. ``d_cond_mtd`` selects how the
    adversarial output is computed from ``h`` and the label, ``aux_cls_type``
    selects auxiliary classifier variants, and ``MODEL.info_type`` (when a
    MODEL config is given) enables InfoGAN Q heads.
    """

    def __init__(
            self,
            c_dim,  # Conditioning label (C) dimensionality.
            img_resolution,  # Input resolution.
            img_channels,  # Number of input color channels.
            architecture="resnet",  # Architecture: "orig", "skip", "resnet".
            channel_base=32768,  # Overall multiplier for the number of channels.
            channel_max=512,  # Maximum number of channels in any layer.
            num_fp16_res=0,  # Use FP16 for the N highest resolutions.
            conv_clamp=None,  # Clamp the output of convolution layers to +-X, None = disable clamping.
            cmap_dim=None,  # Dimensionality of mapped conditioning label, None = default.
            d_cond_mtd=None,  # conditioning method of the discriminator
            aux_cls_type=None,  # type of auxiliary classifier
            d_embed_dim=None,  # dimension of feature maps after convolution operations
            num_classes=None,  # number of classes
            normalize_d_embed=None,  # whether to normalize the feature maps or not
            block_kwargs={},  # Arguments for DiscriminatorBlock.
            mapping_kwargs={},  # Arguments for MappingNetwork.
            epilogue_kwargs={},  # Arguments for DiscriminatorEpilogue.
            MODEL=None,  # Config read for infoGAN options (info_type, ...); None = no infoGAN heads.
    ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.cmap_dim = cmap_dim
        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.num_classes = num_classes
        self.normalize_d_embed = normalize_d_embed
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.block_resolutions = [2**i for i in range(self.img_resolution_log2, 2, -1)]
        self.MODEL = MODEL
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2**(self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if self.cmap_dim is None:
            self.cmap_dim = channels_dict[4]
        if c_dim == 0:
            self.cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            # The highest-resolution block has no incoming feature map (in_channels == 0).
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels,
                                       tmp_channels,
                                       out_channels,
                                       resolution=res,
                                       first_layer_idx=cur_layer_idx,
                                       use_fp16=use_fp16,
                                       **block_kwargs,
                                       **common_kwargs)
            setattr(self, f"b{res}", block)
            cur_layer_idx += block.num_layers

        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=self.cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

        # linear layer for adversarial training
        if self.d_cond_mtd == "MH":
            self.linear1 = FullyConnectedLayer(channels_dict[4], 1 + self.num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = FullyConnectedLayer(channels_dict[4], self.num_classes, bias=True)
        elif self.d_cond_mtd == "SPD":
            self.linear1 = FullyConnectedLayer(channels_dict[4], 1 if self.cmap_dim == 0 else self.cmap_dim, bias=True)
        else:
            self.linear1 = FullyConnectedLayer(channels_dict[4], 1, bias=True)

        # double num_classes for Auxiliary Discriminative Classifier
        # (only the local num_classes / c_dim change; self.num_classes keeps the original value)
        if self.aux_cls_type == "ADC" or self.aux_cls_type == "AOVA" or self.aux_cls_type == "TADC":
            num_classes, c_dim = num_classes * 2, c_dim * 2

        # one extra class for IMA (forward() assigns label == num_classes to fake samples)
        if self.aux_cls_type == "IMA":
            num_classes, c_dim = num_classes + 1, c_dim + 1

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = FullyConnectedLayer(channels_dict[4], num_classes, bias=False)
        elif self.d_cond_mtd == "PD":
            # One learned vector per class, same width as the epilogue features.
            self.embedding = nn.Embedding(num_classes, channels_dict[4])
        elif self.d_cond_mtd == "SPD":
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=self.cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = FullyConnectedLayer(channels_dict[4], d_embed_dim, bias=True)
            self.embedding = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=d_embed_dim, num_ws=None, w_avg_beta=None, num_layers=1, **mapping_kwargs)
        else:
            pass

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = FullyConnectedLayer(channels_dict[4], num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = FullyConnectedLayer(channels_dict[4], d_embed_dim, bias=True)
                self.embedding_mi = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=d_embed_dim, num_ws=None, w_avg_beta=None, num_layers=1, **mapping_kwargs)
            else:
                raise NotImplementedError

        # Q head network for infoGAN (skipped entirely when no MODEL config was given)
        info_type = self._info_type()
        if info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c*self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = FullyConnectedLayer(in_features=channels_dict[4], out_features=out_features, bias=False)
        if info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = FullyConnectedLayer(in_features=channels_dict[4], out_features=out_features, bias=False)
            self.info_conti_var_linear = FullyConnectedLayer(in_features=channels_dict[4], out_features=out_features, bias=False)

    def _info_type(self):
        """Return ``MODEL.info_type``, or None when no MODEL config was supplied."""
        return None if self.MODEL is None else self.MODEL.info_type

    def forward(self, img, label, eval=False, adc_fake=False, update_emas=False, **block_kwargs):
        """Run the discriminator and return all head outputs in a dict.

        Args:
            img: Image batch fed to the first block.
            label: Integer class labels; used with ``F.one_hot`` and as an
                embedding / gather index (presumably a LongTensor of shape [N]
                -- confirm against the caller).
            eval: Unused.
            adc_fake: True when the batch is fake. Affects the label remapping
                for ``aux_cls_type`` "ADC" (odd labels) and "IMA" (label set to
                ``num_classes``).
            update_emas: Unused.
            **block_kwargs: Forwarded to every DiscriminatorBlock.

        Returns:
            dict with keys "h", "adv_output", "embed", "proxy", "cls_output",
            "label" (possibly remapped), "mi_embed", "mi_proxy",
            "mi_cls_output", "info_discrete_c_logits", "info_conti_mu" and
            "info_conti_var". Heads that are not active are None.

        Raises:
            NotImplementedError: if ``d_cond_mtd`` is not a supported method.
        """
        _ = update_emas # unused
        x, embed, proxy, cls_output = None, None, None, None
        mi_embed, mi_proxy, mi_cls_output = None, None, None
        info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None
        for res in self.block_resolutions:
            block = getattr(self, f"b{res}")
            x, img = block(x, img, **block_kwargs)
        h = self.b4(x, img)

        # adversarial training ("SPD" computes adv_output in its conditioning branch below)
        if self.d_cond_mtd != "SPD":
            # NOTE(review): squeeze() without a dim also drops the batch axis when the batch size is 1.
            adv_output = torch.squeeze(self.linear1(h))

        # make class labels odd (for fake) or even (for real) for ADC
        if self.aux_cls_type == "ADC":
            if adc_fake:
                label = label*2 + 1
            else:
                label = label*2

        # IMA: every fake sample gets the extra class index num_classes; real labels are unchanged
        if self.aux_cls_type == "IMA" and adc_fake:
            label = torch.full_like(label, self.num_classes)

        # one-hot width must match the class count used when the heads were built
        if self.aux_cls_type=="ADC" or self.aux_cls_type == "TADC":
            oh_label = F.one_hot(label, self.num_classes * 2)
        elif self.aux_cls_type=="IMA":
            oh_label = F.one_hot(label, self.num_classes + 1)
        else:
            oh_label = F.one_hot(label, self.num_classes)

        # forward pass through InfoGAN Q head
        info_type = self._info_type()
        if info_type in ["discrete", "both"]:
            info_discrete_c_logits = self.info_discrete_linear(h)
        if info_type in ["continuous", "both"]:
            info_conti_mu = self.info_conti_mu_linear(h)
            # exp() keeps the predicted variance positive
            info_conti_var = torch.exp(self.info_conti_var_linear(h))

        # class conditioning
        if self.d_cond_mtd == "AC":
            if self.normalize_d_embed:
                # NOTE(review): rebinding the loop variable does not modify the layer's
                # weights, so in effect only h is normalized here. Kept as-is to
                # preserve numerical behaviour.
                for W in self.linear2.parameters():
                    W = F.normalize(W, dim=1)
                h = F.normalize(h, dim=1)
            cls_output = self.linear2(h)
        elif self.d_cond_mtd == "PD":
            # add the inner product between the class embedding and h to the unconditional logit
            adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
        elif self.d_cond_mtd == "SPD":
            embed = self.linear1(h)
            cmap = self.mapping(None, oh_label)
            adv_output = (embed * cmap).sum(dim=1, keepdim=True) * (1 / np.sqrt(self.cmap_dim))
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            embed = self.linear2(h)
            proxy = self.embedding(None, oh_label)
            if self.normalize_d_embed:
                embed = F.normalize(embed, dim=1)
                proxy = F.normalize(proxy, dim=1)
        elif self.d_cond_mtd == "MD":
            # pick, for every sample, the logit of its own class
            idx = torch.arange(label.size(0), device=label.device)
            adv_output = adv_output[idx, label]
        elif self.d_cond_mtd in ["W/O", "MH"]:
            pass
        else:
            raise NotImplementedError

        # extra conditioning for TACGAN and ADCGAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed:
                    # NOTE(review): same no-op as for linear2 above; the weights are not changed.
                    for W in self.linear_mi.parameters():
                        W = F.normalize(W, dim=1)
                mi_cls_output = self.linear_mi(h)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                mi_embed = self.linear_mi(h)
                mi_proxy = self.embedding_mi(None, oh_label)
                if self.normalize_d_embed:
                    mi_embed = F.normalize(mi_embed, dim=1)
                    mi_proxy = F.normalize(mi_proxy, dim=1)
        return {
            "h": h,
            "adv_output": adv_output,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }
