# PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN
# The MIT License (MIT)
# See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details

# models/big_resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

import utils.ops as ops
import utils.misc as misc

from torch.autograd import Variable


class GenBlock(nn.Module):
    """Residual up-sampling block of the generator.

    Main path: ``bn1(x, affine) -> act -> 2x nearest upsample -> conv3x3 -> bn2(., affine) -> act -> conv3x3``.
    Shortcut:  ``2x nearest upsample -> conv1x1``.
    The block returns the sum of both paths.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        g_cond_mtd: generator conditioning method (stored on the module; not used in forward).
        affine_input_dim: width of the ``affine`` vector passed to both ``g_bn`` layers.
        MODULES: factory namespace providing ``g_bn``, ``g_act_fn`` and ``g_conv2d``.
    """
    def __init__(self, in_channels, out_channels, g_cond_mtd, affine_input_dim, MODULES):
        super(GenBlock, self).__init__()
        self.g_cond_mtd = g_cond_mtd

        # Submodules are created in a fixed order (bn1, bn2, activation, conv0..2) so that
        # seeded initialisation and state_dict layout do not depend on this rewrite.
        make_bn = MODULES.g_bn
        self.bn1 = make_bn(affine_input_dim, in_channels, MODULES)
        self.bn2 = make_bn(affine_input_dim, out_channels, MODULES)

        self.activation = MODULES.g_act_fn

        make_conv = MODULES.g_conv2d
        self.conv2d0 = make_conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = make_conv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = make_conv(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

    @staticmethod
    def _upsample(t):
        """Double the spatial size with nearest-neighbour interpolation."""
        return F.interpolate(t, scale_factor=2, mode="nearest")

    def forward(self, x, affine):
        """Apply the conditional residual block.

        Args:
            x: input feature map.
            affine: conditioning vector forwarded to both ``g_bn`` layers.

        Returns:
            Feature map with ``out_channels`` channels at twice the spatial size.
        """
        residual = self.activation(self.bn1(x, affine))
        residual = self.conv2d1(self._upsample(residual))
        residual = self.conv2d2(self.activation(self.bn2(residual, affine)))

        shortcut = self.conv2d0(self._upsample(x))
        return residual + shortcut


class Generator(nn.Module):
    """BigGAN-style ResNet generator with hierarchical latent injection.

    The latent ``z`` is cut into ``num_blocks + 1`` chunks of width
    ``z_dim // (num_blocks + 1)``. The first chunk goes through ``linear0`` and is
    reshaped to a ``bottom x bottom`` map. Each remaining chunk, concatenated with any
    class / InfoGAN embeddings, is the affine input of one ``GenBlock``.
    """
    def __init__(self, z_dim, g_shared_dim, img_size, g_conv_dim, apply_attn, attn_g_loc, g_cond_mtd, num_classes, g_init, g_depth,
                 mixed_precision, MODULES, MODEL):
        super(Generator, self).__init__()
        # Per-block channel multipliers of g_conv_dim, keyed by output resolution.
        in_multipliers = {
            "32": [4, 4, 4],
            "64": [16, 8, 4, 2],
            "128": [16, 16, 8, 4, 2],
            "256": [16, 16, 8, 8, 4, 2],
            "512": [16, 16, 8, 8, 4, 2, 1],
        }
        out_multipliers = {
            "32": [4, 4, 4],
            "64": [8, 4, 2, 1],
            "128": [16, 8, 4, 2, 1],
            "256": [16, 8, 8, 4, 2, 1],
            "512": [16, 8, 8, 4, 2, 1, 1],
        }
        resolution = str(img_size)

        self.z_dim = z_dim
        self.g_shared_dim = g_shared_dim
        self.g_cond_mtd = g_cond_mtd
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.MODEL = MODEL
        self.in_dims = [g_conv_dim * m for m in in_multipliers[resolution]]
        self.out_dims = [g_conv_dim * m for m in out_multipliers[resolution]]
        self.bottom = 4  # every supported resolution starts from a 4x4 map
        self.num_blocks = len(self.in_dims)
        self.chunk_size = z_dim // (self.num_blocks + 1)
        self.affine_input_dim = self.chunk_size
        assert self.z_dim % (self.num_blocks + 1) == 0, "z_dim should be divided by the number of blocks"

        info_type = MODEL.info_type
        info_dim = 0
        if info_type in ["discrete", "both"]:
            info_dim += MODEL.info_num_discrete_c * MODEL.info_dim_discrete_c
        if info_type in ["continuous", "both"]:
            info_dim += MODEL.info_num_conti_c

        if info_type != "N/A":
            injection = MODEL.g_info_injection
            if injection == "concat":
                # mixes [z, c] back down to z_dim before chunking
                self.info_mix_linear = MODULES.g_linear(in_features=self.z_dim + info_dim, out_features=self.z_dim, bias=True)
            elif injection == "cBN":
                # projects c to g_shared_dim; it is prepended to every block's affine input
                self.affine_input_dim += self.g_shared_dim
                self.info_proj_linear = MODULES.g_linear(in_features=info_dim, out_features=self.g_shared_dim, bias=True)

        self.linear0 = MODULES.g_linear(in_features=self.chunk_size, out_features=self.in_dims[0] * self.bottom * self.bottom, bias=True)

        if self.g_cond_mtd != "W/O":
            self.affine_input_dim += self.g_shared_dim
            self.shared = ops.embedding(num_embeddings=self.num_classes, embedding_dim=self.g_shared_dim)

        stages = []
        for stage_idx, (c_in, c_out) in enumerate(zip(self.in_dims, self.out_dims)):
            stages.append([
                GenBlock(in_channels=c_in,
                         out_channels=c_out,
                         g_cond_mtd=self.g_cond_mtd,
                         affine_input_dim=self.affine_input_dim,
                         MODULES=MODULES)
            ])
            # attn_g_loc holds 1-based block numbers after which self-attention is inserted
            if stage_idx + 1 in attn_g_loc and apply_attn:
                stages.append([ops.SelfAttention(c_out, is_generator=True, MODULES=MODULES)])
        self.blocks = nn.ModuleList([nn.ModuleList(stage) for stage in stages])

        self.bn4 = ops.batchnorm_2d(in_features=self.out_dims[-1])
        self.activation = MODULES.g_act_fn
        self.conv2d5 = MODULES.g_conv2d(in_channels=self.out_dims[-1], out_channels=3, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

        ops.init_weights(self.modules, g_init)

    def forward(self, z, label, shared_label=None, eval=False):
        """Generate images from latents.

        Args:
            z: latent batch of width ``z_dim``. With ``g_info_injection`` set, it also
                carries InfoGAN codes, which are appended after the first ``z_dim`` columns.
            label: class labels, embedded with ``self.shared`` when ``shared_label`` is None.
            shared_label: optional precomputed class embedding.
            eval: when True, autocast is not used even if mixed precision is enabled.

        Returns:
            ``tanh`` of the final 3-channel convolution.
        """
        affine_list = []
        with torch.cuda.amp.autocast() if self.mixed_precision and not eval else misc.dummy_context_mgr():
            if self.MODEL.info_type != "N/A":
                if self.MODEL.g_info_injection == "concat":
                    z = self.info_mix_linear(z)
                elif self.MODEL.g_info_injection == "cBN":
                    z_info = z[:, self.z_dim:]
                    z = z[:, :self.z_dim]
                    affine_list.append(self.info_proj_linear(z_info))

            z_head, *z_tail = torch.split(z, self.chunk_size, 1)
            if self.g_cond_mtd != "W/O":
                if shared_label is None:
                    shared_label = self.shared(label)
                affine_list.append(shared_label)

            if affine_list:
                affines = [torch.cat(affine_list + [chunk], 1) for chunk in z_tail]
            else:
                affines = list(z_tail)

            act = self.linear0(z_head).view(-1, self.in_dims[0], self.bottom, self.bottom)
            affine_idx = 0
            for stage in self.blocks:
                for layer in stage:
                    if isinstance(layer, ops.SelfAttention):
                        act = layer(act)
                    else:
                        act = layer(act, affines[affine_idx])
                        affine_idx += 1

            out = self.tanh(self.conv2d5(self.activation(self.bn4(act))))
        return out


class DiscOptBlock(nn.Module):
    """First residual down-sampling block of the discriminator.

    Main path: ``conv3x3 -> [bn1] -> act -> conv3x3 -> avgpool(2)``.
    Shortcut:  ``avgpool(2) -> [bn0] -> conv1x1``.
    The bracketed batch-norm layers exist only when ``apply_d_sn`` is False.

    Args:
        in_channels: channels of the input image / feature map.
        out_channels: channels produced by the block.
        apply_d_sn: whether spectral norm is used (disables the BN layers here).
        MODULES: factory namespace providing ``d_conv2d``, ``d_bn`` and ``d_act_fn``.
    """
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES):
        super(DiscOptBlock, self).__init__()
        self.apply_d_sn = apply_d_sn

        make_conv = MODULES.d_conv2d
        self.conv2d0 = make_conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2d1 = make_conv(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = make_conv(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        if not apply_d_sn:
            self.bn0 = MODULES.d_bn(in_features=in_channels)
            self.bn1 = MODULES.d_bn(in_features=out_channels)

        self.activation = MODULES.d_act_fn
        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        """Return the sum of the main path and the pooled 1x1 shortcut."""
        use_bn = not self.apply_d_sn

        main = self.conv2d1(x)
        if use_bn:
            main = self.bn1(main)
        main = self.average_pooling(self.conv2d2(self.activation(main)))

        skip = self.average_pooling(x)
        if use_bn:
            skip = self.bn0(skip)
        skip = self.conv2d0(skip)
        return main + skip


class DiscBlock(nn.Module):
    """Pre-activation residual block of the discriminator.

    Main path: ``[bn1] -> act -> conv3x3 -> [bn2] -> act -> conv3x3 -> [avgpool(2)]``.
    The shortcut is the identity unless the channel count changes or the block
    down-samples. In that case it is ``[bn0] -> conv1x1 -> [avgpool(2)]``.
    BN layers exist only when ``apply_d_sn`` is False.

    Note: if ``MODULES.d_act_fn`` works in place, the first activation also modifies
    the tensor the identity / 1x1 shortcut reads (when ``apply_d_sn`` is True).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        apply_d_sn: whether spectral norm is used (disables the BN layers here).
        MODULES: factory namespace providing ``d_conv2d``, ``d_bn`` and ``d_act_fn``.
        downsample: apply 2x average pooling on both paths.
    """
    def __init__(self, in_channels, out_channels, apply_d_sn, MODULES, downsample=True):
        super(DiscBlock, self).__init__()
        self.apply_d_sn = apply_d_sn
        self.downsample = downsample

        self.activation = MODULES.d_act_fn

        self.ch_mismatch = in_channels != out_channels

        needs_projection = self.ch_mismatch or downsample
        if needs_projection:
            self.conv2d0 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0)
            if not apply_d_sn:
                self.bn0 = MODULES.d_bn(in_features=in_channels)

        self.conv2d1 = MODULES.d_conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2d2 = MODULES.d_conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1, padding=1)

        if not apply_d_sn:
            self.bn1 = MODULES.d_bn(in_features=in_channels)
            self.bn2 = MODULES.d_bn(in_features=out_channels)

        self.average_pooling = nn.AvgPool2d(2)

    def forward(self, x):
        """Return the sum of the residual path and the (possibly projected) shortcut."""
        use_bn = not self.apply_d_sn

        main = self.bn1(x) if use_bn else x
        main = self.conv2d1(self.activation(main))
        if use_bn:
            main = self.bn2(main)
        main = self.conv2d2(self.activation(main))
        if self.downsample:
            main = self.average_pooling(main)

        # `x` is read here, after the main path, so in-place activations behave as before
        skip = x
        if self.downsample or self.ch_mismatch:
            if use_bn:
                skip = self.bn0(skip)
            skip = self.conv2d0(skip)
            if self.downsample:
                skip = self.average_pooling(skip)
        return main + skip


class Discriminator(nn.Module):
    """BigGAN-style ResNet discriminator with pluggable adversarial / conditioning heads.

    The backbone is one ``DiscOptBlock`` followed by ``DiscBlock``s, with optional
    self-attention. It produces a spatially sum-pooled feature ``h``. The heads built on
    ``h`` depend on:

    * ``d_cond_mtd``: "AC", "PD", "2C", "D2DCE", "MD", "MH" or "W/O";
    * ``aux_cls_type``: e.g. "TAC", "ADC", "IMA", "IADC", "TIMA", "LM" (label remapping);
    * ``d_adv_mtd``: e.g. "SOFTMAX", "ECGAN", "IsoMax";
    * ``MODEL.info_type``: InfoGAN Q heads ("discrete", "continuous", "both").

    Submodule names and their construction order are kept stable, so seeded
    initialisation and state_dict keys match earlier versions of this class.
    """
    def __init__(self, img_size, d_conv_dim, apply_d_sn, apply_attn, attn_d_loc, d_cond_mtd, aux_cls_type, d_adv_mtd, normalize_adv_embed, class_adv_model, class_center, ETF_fc, d_embed_dim, normalize_d_embed,
                 num_classes, d_init, d_depth, mixed_precision, MODULES, MODEL):
        super(Discriminator, self).__init__()
        d_in_dims_collection = {
            "32": [3] + [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8],
            "128": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "256": [3] + [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16],
            "512": [3] + [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16]
        }

        d_out_dims_collection = {
            "32": [d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2, d_conv_dim * 2],
            "64": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16],
            "128": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "256": [d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16],
            "512":
            [d_conv_dim, d_conv_dim, d_conv_dim * 2, d_conv_dim * 4, d_conv_dim * 8, d_conv_dim * 8, d_conv_dim * 16, d_conv_dim * 16]
        }

        d_down = {
            "32": [True, True, False, False],
            "64": [True, True, True, True, False],
            "128": [True, True, True, True, True, False],
            "256": [True, True, True, True, True, True, False],
            "512": [True, True, True, True, True, True, True, False]
        }

        self.d_cond_mtd = d_cond_mtd
        self.aux_cls_type = aux_cls_type
        self.d_adv_mtd = d_adv_mtd
        self.class_adv_model = class_adv_model
        self.class_center = class_center
        self.ETF_fc = ETF_fc
        self.normalize_adv_embed = normalize_adv_embed
        self.normalize_d_embed = normalize_d_embed
        self.num_classes = num_classes
        self.mixed_precision = mixed_precision
        self.in_dims = d_in_dims_collection[str(img_size)]
        self.out_dims = d_out_dims_collection[str(img_size)]
        self.MODEL = MODEL
        down = d_down[str(img_size)]

        blocks = []
        for index in range(len(self.in_dims)):
            if index == 0:
                block = DiscOptBlock(in_channels=self.in_dims[index], out_channels=self.out_dims[index], apply_d_sn=apply_d_sn, MODULES=MODULES)
            else:
                block = DiscBlock(in_channels=self.in_dims[index],
                                  out_channels=self.out_dims[index],
                                  apply_d_sn=apply_d_sn,
                                  MODULES=MODULES,
                                  downsample=down[index])
            blocks.append([block])

            # attn_d_loc holds 1-based block numbers after which self-attention is inserted
            if index + 1 in attn_d_loc and apply_attn:
                blocks.append([ops.SelfAttention(self.out_dims[index], is_generator=False, MODULES=MODULES)])

        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in blocks])

        self.activation = MODULES.d_act_fn
        feat_dim = self.out_dims[-1]

        # adversarial (real/fake) head
        if self.d_cond_mtd == "MH":
            self.linear1 = MODULES.d_linear(in_features=feat_dim, out_features=1 + num_classes, bias=True)
        elif self.d_cond_mtd == "MD":
            self.linear1 = MODULES.d_linear(in_features=feat_dim, out_features=num_classes, bias=True)
        else:
            # NOTE(review): the d_adv_mtd == "SOFTMAX" path in forward() reads two rows of
            # linear1 (and of linear2), but no two-output adversarial head is built here.
            self.linear1 = MODULES.d_linear(in_features=feat_dim, out_features=1, bias=True)

        # ADC doubles the label space (forward maps real -> 2y, fake -> 2y + 1);
        # IMA adds one extra index (== self.num_classes) that forward assigns to fakes.
        if self.aux_cls_type == "ADC":
            num_classes = num_classes * 2
        if self.aux_cls_type == "IMA":
            num_classes = num_classes + 1

        # linear and embedding layers for discriminator conditioning
        if self.d_cond_mtd == "AC":
            self.linear2 = MODULES.d_linear(in_features=feat_dim, out_features=num_classes, bias=False)
            if self.class_adv_model:
                # linear3 is not used in forward(); it is kept so existing state_dict keys still load
                self.linear3 = MODULES.d_linear(in_features=feat_dim, out_features=num_classes, bias=False)
                self.linear4 = MODULES.d_linear(in_features=num_classes, out_features=1, bias=False)
            if self.d_adv_mtd == "IsoMax":
                # parameter-free head; forward() passes linear2.weight as the prototypes
                self.iso_classifier = IsoMaxLossFirstPart(feat_dim, num_classes)
        elif self.d_cond_mtd == "PD":
            self.embedding = MODULES.d_embedding(num_classes, feat_dim)
        elif self.d_cond_mtd in ["2C", "D2DCE"]:
            self.linear2 = MODULES.d_linear(in_features=feat_dim, out_features=d_embed_dim, bias=True)
            self.embedding = MODULES.d_embedding(num_classes, d_embed_dim)

        # linear and embedding layers for evolved classifier-based GAN
        if self.aux_cls_type == "TAC":
            if self.d_cond_mtd == "AC":
                self.linear_mi = MODULES.d_linear(in_features=feat_dim, out_features=num_classes, bias=False)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                self.linear_mi = MODULES.d_linear(in_features=feat_dim, out_features=d_embed_dim, bias=True)
                self.embedding_mi = MODULES.d_embedding(num_classes, d_embed_dim)
            else:
                raise NotImplementedError

        # Q head network for infoGAN
        if self.MODEL.info_type in ["discrete", "both"]:
            out_features = self.MODEL.info_num_discrete_c * self.MODEL.info_dim_discrete_c
            self.info_discrete_linear = MODULES.d_linear(in_features=feat_dim, out_features=out_features, bias=False)
        if self.MODEL.info_type in ["continuous", "both"]:
            out_features = self.MODEL.info_num_conti_c
            self.info_conti_mu_linear = MODULES.d_linear(in_features=feat_dim, out_features=out_features, bias=False)
            self.info_conti_var_linear = MODULES.d_linear(in_features=feat_dim, out_features=out_features, bias=False)

        if d_init:
            ops.init_weights(self.modules, d_init)

    @staticmethod
    def _is_fake(adc_fake):
        """Interpret ``adc_fake`` uniformly across all label-remapping variants.

        A string counts as fake only if it equals ``"fake"`` (so ``"real"`` means real).
        Any other value is interpreted by truthiness, so ``True`` means fake.
        """
        if isinstance(adc_fake, str):
            return adc_fake == "fake"
        return bool(adc_fake)

    def _relabel(self, x, label, adc_fake):
        """Return ``(label, adv_label)`` after the remapping required by d_adv_mtd / aux_cls_type.

        ``adv_label`` stays None unless SOFTMAX, IADC or TIMA defines it.
        """
        is_fake = self._is_fake(adc_fake)
        adv_label = None

        if self.d_adv_mtd == "SOFTMAX":
            # two-way real/fake target (0 = fake, 1 = real), created on the input's device
            adv_label = torch.full((x.size(0),), 0 if is_fake else 1, dtype=torch.long, device=x.device)

        if self.aux_cls_type == "IADC":
            adv_label = label
            label = label * 2 + 1 if is_fake else label * 2
        elif self.aux_cls_type == "ADC":
            # odd class indices for fake samples, even for real ones
            label = label * 2 + 1 if is_fake else label * 2
        elif self.aux_cls_type == "LM":
            label = label * 2
        elif self.aux_cls_type == "IMA":
            # fakes go to the extra out-of-distribution index
            if is_fake:
                label = torch.full_like(label, self.num_classes)
        elif self.aux_cls_type == "TIMA":
            if is_fake:
                label = torch.full_like(label, self.num_classes)
            if adc_fake == "real":
                adv_label = torch.full_like(label, self.num_classes)
            else:
                adv_label = label
        return label, adv_label

    def _softmax_cos_dict(self):
        """Cosine-similarity / norm diagnostics between ``linear1`` and ``linear2`` weight rows.

        NOTE(review): this needs at least two rows in both ``linear1`` and ``linear2``. In
        __init__, ``linear1`` has two or more rows only for "MH" / "MD", and neither of those
        builds ``linear2``. In the current configuration this therefore raises; confirm intent.
        """
        w_adv = self.linear1.weight
        w_cls = self.linear2.weight
        adv_norms = w_adv.norm(2, dim=1)
        return {
            'adv_cos': torch.cosine_similarity(w_adv[0], w_adv[1], dim=0),
            'fake_cls_cos': torch.cosine_similarity(w_adv[0], w_cls[0], dim=0),
            'real_cls_cos': torch.cosine_similarity(w_adv[1], w_cls[0], dim=0),
            'cls_cos': torch.cosine_similarity(w_cls[1], w_cls[0], dim=0),
            'adv_norm0': adv_norms[0],
            'adv_norm1': adv_norms[1],
        }

    def forward(self, x, label, eval=False, adc_fake=False):
        """Run the discriminator.

        Args:
            x: input image batch.
            label: class labels; remapped for ADC / IADC / IMA / TIMA / LM (see ``_relabel``).
            eval: when True, autocast is not used even if mixed precision is enabled.
            adc_fake: marks the batch as real or fake. Accepts a bool, or the strings
                "fake" / "real".

        Returns:
            dict of head outputs (keys listed below); heads that are not configured are None.
            ``"label"`` holds the remapped labels.
        """
        use_amp = self.mixed_precision and not eval
        with torch.cuda.amp.autocast() if use_amp else misc.dummy_context_mgr():
            embed, proxy, cls_output = None, None, None
            mi_embed, mi_proxy, mi_cls_output = None, None, None
            iadc_cls_output, cls_adv_output = None, None
            cos_dict, adv_output_s = None, None
            info_discrete_c_logits, info_conti_mu, info_conti_var = None, None, None

            # backbone followed by spatial sum pooling
            h = x
            for blocklist in self.blocks:
                for block in blocklist:
                    h = block(h)
            bottom_h, bottom_w = h.shape[2], h.shape[3]
            h = self.activation(h)
            h = torch.sum(h, dim=[2, 3])

            label, adv_label = self._relabel(x, label, adc_fake)

            # InfoGAN Q head on the spatial mean of the (not yet normalised) feature
            if self.MODEL.info_type in ["discrete", "both", "continuous"]:
                h_mean = h / (bottom_h * bottom_w)
                if self.MODEL.info_type in ["discrete", "both"]:
                    info_discrete_c_logits = self.info_discrete_linear(h_mean)
                if self.MODEL.info_type in ["continuous", "both"]:
                    info_conti_mu = self.info_conti_mu_linear(h_mean)
                    info_conti_var = torch.exp(self.info_conti_var_linear(h_mean))

            # adversarial output; `h_normalized` records whether h already has unit L2 norm
            h_normalized = False
            if self.d_adv_mtd == "SOFTMAX":
                cos_dict = self._softmax_cos_dict()
                if self.normalize_adv_embed and self.normalize_d_embed:
                    h = F.normalize(h, dim=1)
                    h_normalized = True
                    adv_output = torch.squeeze(self.linear1(h))
                elif self.normalize_adv_embed:
                    adv_output = torch.squeeze(self.linear1(F.normalize(h, dim=1)))
                else:
                    adv_output = torch.squeeze(self.linear1(h))
                    adv_output_s = adv_output
            else:
                adv_output = torch.squeeze(self.linear1(h))

            # class conditioning
            if self.d_cond_mtd == "AC":
                if self.normalize_d_embed and not h_normalized:
                    h = F.normalize(h, dim=1)
                    h_normalized = True

                if self.aux_cls_type in ("IADC", "TIMA"):
                    # NOTE(review): linear5 is never created in __init__, so this raises
                    # AttributeError until that head is (re)introduced.
                    iadc_cls_output = self.linear5(h)

                cls_output = self.linear2(h)
                if self.d_adv_mtd == "IsoMax":
                    # linear2(h) is still evaluated first, as in the original ordering
                    cls_output = self.iso_classifier(h, self.linear2.weight)
                if self.d_adv_mtd == "ECGAN":
                    adv_output_s = torch.logsumexp(cls_output, dim=1)
                if self.class_adv_model:
                    cls_adv_output = self.linear4(cls_output)
                if self.class_center and not h_normalized:
                    h = F.normalize(h, dim=1)
            elif self.d_cond_mtd == "PD":
                adv_output = adv_output + torch.sum(torch.mul(self.embedding(label), h), 1)
            elif self.d_cond_mtd in ["2C", "D2DCE"]:
                embed = self.linear2(h)
                proxy = self.embedding(label)
                if self.normalize_d_embed:
                    embed = F.normalize(embed, dim=1)
                    proxy = F.normalize(proxy, dim=1)
            elif self.d_cond_mtd == "MD":
                idx = torch.arange(label.size(0), device=label.device)
                adv_output = adv_output[idx, label]
            elif self.d_cond_mtd in ["W/O", "MH"]:
                pass
            else:
                raise NotImplementedError

            # extra conditioning for TACGAN
            if self.aux_cls_type == "TAC":
                if self.d_cond_mtd == "AC":
                    mi_cls_output = self.linear_mi(h)
                elif self.d_cond_mtd in ["2C", "D2DCE"]:
                    mi_embed = self.linear_mi(h)
                    mi_proxy = self.embedding_mi(label)
                    if self.normalize_d_embed:
                        mi_embed = F.normalize(mi_embed, dim=1)
                        mi_proxy = F.normalize(mi_proxy, dim=1)
        return {
            "h": h,
            "adv_output": adv_output,
            "adv_output_s": adv_output_s,
            "adv_label": adv_label,
            "iadc_cls_output": iadc_cls_output,
            "cls_adv_output": cls_adv_output,
            "cos_dict": cos_dict,
            "embed": embed,
            "proxy": proxy,
            "cls_output": cls_output,
            "label": label,
            "mi_embed": mi_embed,
            "mi_proxy": mi_proxy,
            "mi_cls_output": mi_cls_output,
            "info_discrete_c_logits": info_discrete_c_logits,
            "info_conti_mu": info_conti_mu,
            "info_conti_var": info_conti_var
        }

class IsoMaxLossFirstPart(nn.Module):
    """IsoMax classifier head: logits are negative Euclidean distances to prototypes.

    This replaces a model's ``nn.Linear`` output layer. The prototypes are not owned
    by this module; they are passed to ``forward`` on every call (e.g. a linear
    layer's weight matrix).

    Args:
        num_features: feature dimensionality (stored for reference).
        num_classes: number of prototypes (stored for reference).
        temperature: divisor applied to the logits.
    """
    def __init__(self, num_features, num_classes, temperature=1.0):
        super(IsoMaxLossFirstPart, self).__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.temperature = temperature

    def forward(self, features, prototypes):
        """Return ``-cdist(features, prototypes) / temperature``.

        For 2-D inputs of shape (N, D) and (C, D), the result has shape (N, C).
        ``prototypes`` is used only locally. Storing it on ``self`` would make
        ``nn.Module.__setattr__`` register a passed-in ``nn.Parameter`` as a parameter of
        this module, adding a spurious ``prototypes`` entry to ``parameters()`` and to the
        state_dict.
        """
        distances = torch.cdist(features, prototypes, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")
        # The temperature may be calibrated after training to improve uncertainty estimation.
        return -distances / self.temperature