import functools
import torch.nn as nn
from taming.modules.util import ActNorm
import torch

def weights_init(m):
    """Initialize a single module's parameters in place (intended for ``model.apply``).

    Dispatch is by class name:

    - names containing ``'Conv'``: weight ~ N(0, 0.02); the bias is left untouched.
    - names containing ``'GroupNorm'``: weight ~ N(1, 0.02), bias = 0.

    Any other module is left unchanged. A matched module whose parameter is
    absent or ``None`` (e.g. ``GroupNorm(affine=False)``) is skipped for that
    parameter instead of raising ``AttributeError``.

    Args:
        m: the module visited by ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    # GroupNorm(affine=False) registers weight/bias as None, so look them up defensively.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            # nn.init already runs under torch.no_grad(), so no `.data` access is needed.
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'GroupNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias, 0)

def Normalize(in_channels, num_groups=32):
    """Build the GroupNorm layer used as the default normalization in this module.

    Args:
        in_channels: number of channels to normalize; ``nn.GroupNorm`` requires it
            to be divisible by ``num_groups``.
        num_groups: number of channel groups (default 32).

    Returns:
        An affine ``nn.GroupNorm`` with ``eps=1e-6``.
    """
    norm = nn.GroupNorm(
        num_groups,
        in_channels,
        eps=1e-6,
        affine=True,
    )
    return norm

class NLayerDiscriminator(nn.Module):
    """3D PatchGAN discriminator.

    A stack of ``Conv3d`` layers mapping an input volume to a single-channel map
    of per-patch logits (no final activation is applied).
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a 3D PatchGAN discriminator.

        Parameters:
            input_nc (int)     -- number of channels in the input volumes
            ndf (int)          -- number of filters in the FIRST conv layer; later
                                  layers use ndf * min(2 ** n, 8) filters
            n_layers (int)     -- depth: the first stride-2 conv is followed by
                                  n_layers - 1 further stride-2 conv/norm/LeakyReLU
                                  stages, then one stride-1 stage and the output conv
            use_actnorm (bool) -- use ``ActNorm`` instead of ``Normalize`` (GroupNorm)

        With the default ``Normalize`` (32 groups), every normalized channel count
        ``ndf * nf_mult`` must be divisible by 32.
        NOTE(review): confirm ``ActNorm`` from taming.modules.util accepts 5D
        (b, c, d, h, w) inputs; it may have been written for 4D feature maps.
        """
        super().__init__()

        norm_layer = ActNorm if use_actnorm else Normalize
        # GroupNorm has its own affine shift, so a conv bias right before it is
        # redundant; the ActNorm path keeps the conv bias.
        use_bias = norm_layer is not Normalize

        kw = 4
        padw = 1

        # The layer order below defines the `main.<idx>` state_dict keys; keep it stable.
        sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1

        for n in range(1, n_layers):  # gradually increase the number of filters (capped at 8 * ndf)
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Output a 1-channel prediction map.
        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Return the per-patch logit map for ``input``."""
        return self.main(input)
    

class MultiClass_NLayerDiscriminator(nn.Module):
    """Class-conditional 3D PatchGAN: one independent ``NLayerDiscriminator`` per class."""

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False, num_classes=None):
        """Construct one 3D PatchGAN discriminator per class.

        Parameters:
            input_nc (int)     -- number of channels in the input volumes
            ndf (int)          -- number of filters in the first conv layer
            n_layers (int)     -- depth of each per-class discriminator
            use_actnorm (bool) -- use ActNorm instead of GroupNorm in each discriminator
            num_classes (int)  -- number of classes (required, >= 1)

        Raises:
            ValueError: if ``num_classes`` is missing or smaller than 1.
        """
        super().__init__()

        if num_classes is None or num_classes < 1:
            raise ValueError(f"num_classes must be a positive integer, got {num_classes!r}")

        self.num_classes = num_classes
        self.discriminators = nn.ModuleList(
            [NLayerDiscriminator(input_nc, ndf, n_layers, use_actnorm) for _ in range(num_classes)]
        )

    def forward(self, input, y=None):
        """Score ``input`` with the discriminator(s).

        Args:
            input: input tensor (b, c, d, h, w)
            y: optional class labels with one label per sample (e.g. shape (b,))

        Returns:
            If ``y`` is None, the element-wise mean of every class discriminator's
            output. Otherwise, each sample is scored by the discriminator of its own
            class, and the outputs are returned in the original batch order.

        Raises:
            ValueError: if ``y`` does not hold exactly one label per sample.
        """
        if y is None:
            # No labels: average the outputs of all class discriminators.
            outputs = [disc(input) for disc in self.discriminators]
            return torch.mean(torch.stack(outputs), dim=0)

        batch_size = input.size(0)
        labels = torch.as_tensor(y, device=input.device).reshape(-1)
        if labels.numel() != batch_size:
            raise ValueError(
                f"expected {batch_size} labels for a batch of {batch_size}, got {labels.numel()}"
            )

        # Run each class discriminator once on its sub-batch (one host sync for
        # torch.unique instead of one .item() per sample). Each discriminator
        # processes samples independently (GroupNorm is per-sample), so grouping
        # does not change the per-sample outputs.
        group_outputs = []
        group_indices = []
        for class_idx in torch.unique(labels).tolist():
            idx = torch.nonzero(labels == class_idx, as_tuple=True)[0]
            group_outputs.append(self.discriminators[class_idx](input.index_select(0, idx)))
            group_indices.append(idx)

        # Rows are concatenated in class order; argsort of that permutation restores
        # the original batch order.
        order = torch.cat(group_indices)
        return torch.cat(group_outputs, dim=0)[torch.argsort(order)]
