import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.optim import lr_scheduler
from torchsummary import summary
from typing import Type, Any, Callable, Union, List, Optional, cast
from torch import Tensor
from collections import OrderedDict 

###############################################################################
# Helper Functions
###############################################################################

class Identity(nn.Module):
    """Pass-through module whose forward returns its argument unchanged."""

    def forward(self, x):
        # Deliberate no-op: the very same object that came in goes out.
        return x


def get_norm_layer(norm_type='instance'):
    """Return a factory that builds a normalization layer

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | instance3d | none

    For 'batch' the factory is nn.BatchNorm2d with learnable affine parameters.
    For 'instance' / 'instance3d' the factory is nn.InstanceNorm2d / nn.InstanceNorm3d
    without learnable affine parameters.
    All three factories are created with track_running_stats=True.
    For 'none' the returned callable ignores its argument and returns an Identity module.
    Any other name raises NotImplementedError.
    """
    # (name, layer class, affine flag); every entry tracks running statistics.
    known_layers = (
        ('batch', nn.BatchNorm2d, True),
        ('instance', nn.InstanceNorm2d, False),
        ('instance3d', nn.InstanceNorm3d, False),
    )
    for name, layer_cls, use_affine in known_layers:
        if norm_type == name:
            return functools.partial(layer_cls, affine=use_affine, track_running_stats=True)

    if norm_type == 'none':
        def norm_layer(x): return Identity()
        return norm_layer

    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Parameters:
        net (network)   -- network to be initialized
        init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)    -- scaling factor for normal, xavier and orthogonal.

    Modules whose class name contains 'Conv' or 'Linear' get their weight initialized with
    the selected method and their bias (if any) set to 0. Modules whose class name contains
    'BatchNorm2d' get weight ~ N(1, init_gain) and bias 0; layers built with affine=False
    (weight/bias are None) are skipped instead of raising.

    Raises NotImplementedError if init_type is unknown and a Conv/Linear layer is visited.

    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """
    def init_func(m):  # applied to every sub-module by net.apply
        classname = m.__class__.__name__
        # Parameters can be absent or None (e.g. bias=False, affine=False).
        weight = getattr(m, 'weight', None)
        bias = getattr(m, 'bias', None)
        if weight is not None and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if bias is not None:
                init.constant_(bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm's weight is not a matrix; only a normal distribution applies.
            # NOTE(review): only class names containing 'BatchNorm2d' match, so
            # BatchNorm3d layers keep PyTorch's default initialization.
            if weight is not None:
                init.normal_(weight.data, 1.0, init_gain)
            if bias is not None:
                init.constant_(bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function <init_func>


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=None):
    """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
    Parameters:
        net (network)      -- the network to be initialized
        init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
        gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2; None or empty keeps the network where it is

    Return an initialized network. When gpu_ids is non-empty the network is moved to
    gpu_ids[0] and the returned object is a torch.nn.DataParallel wrapper around it.
    """
    # None stands in for "no GPUs" so the signature carries no shared mutable default.
    if gpu_ids is None:
        gpu_ids = []
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available(), 'gpu_ids were given but CUDA is not available'
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
    # Module.apply recurses into the wrapped module, so initializing after wrapping is fine.
    init_weights(net, init_type, init_gain=init_gain)
    return net


class AttentionBlock(nn.Module):
    """Additive attention gate for 3D feature maps.

    Both inputs are projected to ``output_channels`` with a 1x1x1 convolution
    followed by batch normalization, summed, passed through ReLU, and reduced
    to a single-channel sigmoid map that rescales the skip features.

    Parameters:
        skip_channels (int)   -- channels of the skip-connection tensor
        input_channels (int)  -- channels of the gating (decoder) tensor
        output_channels (int) -- channels of the intermediate projection
    """

    def __init__(self, skip_channels, input_channels, output_channels):
        super().__init__()
        pointwise = dict(kernel_size=1, stride=1, padding=0)

        # Projection of the skip-connection features
        self.skip_conv = nn.Conv3d(skip_channels, output_channels, **pointwise)
        self.skip_bn = nn.BatchNorm3d(output_channels)

        # Projection of the gating features coming from the decoder
        self.input_conv = nn.Conv3d(input_channels, output_channels, **pointwise)
        self.input_bn = nn.BatchNorm3d(output_channels)

        # Collapse to one channel and squash into (0, 1)
        self.psi_conv = nn.Conv3d(output_channels, 1, **pointwise)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, skip_features, input_features):
        """Return ``skip_features`` scaled element-wise by the attention map.

        Tensors are expected as 5-D (N, C, D, H, W); the result has the same
        shape as ``skip_features``.
        """
        skip_proj = self.skip_bn(self.skip_conv(skip_features))
        gate_proj = self.input_bn(self.input_conv(input_features))

        # Resample the gate onto the skip grid when the spatial sizes differ.
        target_size = skip_proj.shape[2:]
        if gate_proj.shape[2:] != target_size:
            gate_proj = F.interpolate(gate_proj, size=target_size, mode='trilinear', align_corners=False)

        fused = self.relu(skip_proj + gate_proj)
        weights = self.sigmoid(self.psi_conv(fused))

        # The single-channel map broadcasts over the channel axis of the skip features.
        return skip_features * weights

class Attgenerator3D(nn.Module):
    """3D U-Net generator whose skip connections pass through attention gates.

    The encoder applies eight stride-2 3D convolutions (1 input channel, up to
    ``8 * d`` feature channels); the decoder mirrors it with eight stride-2
    transposed convolutions. Before each skip tensor is concatenated into the
    decoder it is re-weighted by an AttentionBlock gated with the decoder
    feature of the same stage. The output has 1 channel and is squashed by tanh.

    Parameters:
        d (int) -- base number of feature channels
    """

    # initializers
    def __init__(self, d=64):
        super(Attgenerator3D, self).__init__()
        c1, c2, c4, c8 = d, d * 2, d * 4, d * 8

        # Unet encoder
        self.conv1 = nn.Conv3d(1, c1, 4, 2, 1)
        self.conv2 = nn.Conv3d(c1, c2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm3d(c2)
        self.conv3 = nn.Conv3d(c2, c4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm3d(c4)
        self.conv4 = nn.Conv3d(c4, c8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm3d(c8)
        self.conv5 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv5_bn = nn.BatchNorm3d(c8)
        self.conv6 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv6_bn = nn.BatchNorm3d(c8)
        self.conv7 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv7_bn = nn.BatchNorm3d(c8)
        self.conv8 = nn.Conv3d(c8, c8, 4, 2, 1)  # bottleneck, no batch norm

        # Unet decoder; inputs from deconv2 on are [decoder feature, gated skip] concatenations
        self.deconv1 = nn.ConvTranspose3d(c8, c8, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm3d(c8)
        self.deconv2 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm3d(c8)
        self.deconv3 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm3d(c8)
        self.deconv4 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm3d(c8)
        self.deconv5 = nn.ConvTranspose3d(c8 * 2, c4, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm3d(c4)
        self.deconv6 = nn.ConvTranspose3d(c4 * 2, c2, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm3d(c2)
        self.deconv7 = nn.ConvTranspose3d(c2 * 2, c1, 4, 2, 1)
        self.deconv7_bn = nn.BatchNorm3d(c1)
        self.deconv8 = nn.ConvTranspose3d(c1 * 2, 1, 4, 2, 1)

        # Attention gates: AttentionBlock(skip_channels, gate_channels, inner_channels).
        # forward() gates stage k with the decoder output of that same stage, so the
        # gate width must equal the matching deconv's output width. The previous
        # values (8d / 4d / 2d for attention3 / 2 / 1) did not match deconv5/6/7
        # outputs (4d / 2d / d) and made forward() fail with a channel mismatch.
        # NOTE(review): if gating by the coarser decoder stage was intended
        # instead, forward() has to change rather than these widths - confirm.
        self.attention7 = AttentionBlock(c8, c8, c8)
        self.attention6 = AttentionBlock(c8, c8, c8)
        self.attention5 = AttentionBlock(c8, c8, c8)
        self.attention4 = AttentionBlock(c8, c8, c8)
        self.attention3 = AttentionBlock(c4, c4, c4)
        self.attention2 = AttentionBlock(c2, c2, c2)
        self.attention1 = AttentionBlock(c1, c1, c1)

    # weight_init
    def weight_init(self, mean, std):
        """Call normal_init3d(module, mean, std) on every direct child module.

        normal_init3d is a module-level helper that is not part of this class.
        NOTE(review): only direct children are visited; whether layers nested
        inside the attention gates get initialized depends on normal_init3d.
        """
        for layer in self._modules.values():
            normal_init3d(layer, mean, std)

    # forward method
    def forward(self, input):
        """Run the attention U-Net and return the tanh-activated 1-channel volume."""
        # 3D encoder: skips collects e1 .. e7, the bottleneck is e8.
        encoder_stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
            (self.conv5, self.conv5_bn),
            (self.conv6, self.conv6_bn),
            (self.conv7, self.conv7_bn),
        )
        skips = [self.conv1(input)]
        for conv, norm in encoder_stages:
            skips.append(norm(conv(F.leaky_relu(skips[-1], 0.2))))
        x = self.conv8(F.leaky_relu(skips[-1], 0.2))

        # 3D decoder with attention: stage k pairs with skip e(8 - k).
        decoder_stages = (
            (self.deconv1, self.deconv1_bn, self.attention7),
            (self.deconv2, self.deconv2_bn, self.attention6),
            (self.deconv3, self.deconv3_bn, self.attention5),
            (self.deconv4, self.deconv4_bn, self.attention4),
            (self.deconv5, self.deconv5_bn, self.attention3),
            (self.deconv6, self.deconv6_bn, self.attention2),
            (self.deconv7, self.deconv7_bn, self.attention1),
        )
        for (deconv, norm, gate), skip in zip(decoder_stages, reversed(skips)):
            up = norm(deconv(F.relu(x)))
            x = torch.cat([up, gate(skip, up)], 1)

        return torch.tanh(self.deconv8(F.relu(x)))

import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock3D(nn.Module):
    """3D residual down-sampling block.

    Main path: LeakyReLU(0.2) -> Conv3d -> BatchNorm3d.
    Shortcut:  1x1x1 Conv3d with the same stride, applied to the raw
    (non-activated) input and not followed by batch normalization.
    The two paths are summed.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super().__init__()

        # Main path
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm3d(out_channels, track_running_stats=True, momentum=0.1)

        # Shortcut path
        self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        """Return the sum of the main path and the shortcut for input ``x``."""
        # The activation is applied to the main path's input only.
        main = self.bn(self.conv(F.leaky_relu(x, 0.2)))

        # The shortcut is a bare 1x1x1 convolution with no BN; the original
        # author's note says this avoids problems with small batches / small
        # feature maps.
        skip = self.shortcut(x)

        return main + skip


class ResUpBlock3D(nn.Module):
    """3D residual up-sampling block intended for the decoder.

    Main path: ConvTranspose3d -> BatchNorm3d.
    Shortcut:  1x1x1 ConvTranspose3d with the same stride and output_padding=1,
    without batch normalization. The two paths are summed.
    """

    def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1):
        super().__init__()

        # Main path: transposed convolution followed by batch norm
        self.deconv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm3d(out_channels, track_running_stats=True, momentum=0.1)

        # Shortcut path (transposed convolution)
        self.shortcut = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=1, stride=stride, output_padding=1)

    def forward(self, x):
        """Return main path + shortcut for ``x``.

        No activation is applied to ``x`` here; per the original author's note
        the caller is expected to have applied it already.
        """
        upsampled = self.deconv(x)

        # Safety measure: for a single sample with a 1x1x1 feature map, run
        # the batch norm in eval mode for this call.
        was_training = self.bn.training
        single_voxel_sample = x.size(0) == 1 and x.shape[2:] == torch.Size([1, 1, 1])
        if single_voxel_sample:
            self.bn.eval()

        normalized = self.bn(upsampled)

        # Put the batch norm back into training mode if that is where it started.
        if was_training:
            self.bn.train()

        # Shortcut (no BN), then merge.
        return normalized + self.shortcut(x)


class ResUnetgenerator3D(nn.Module):
    """3D U-Net generator whose encoder stages 2-7 are residual blocks.

    The encoder is conv1 -> res_block2..res_block7 -> conv8 (bottleneck); the
    decoder is eight stride-2 transposed convolutions with skip concatenation.
    The input must have 1 channel; the output has 1 channel and is squashed by
    tanh.

    Parameters:
        d (int) -- base number of feature channels
    """

    def __init__(self, d=64):
        super(ResUnetgenerator3D, self).__init__()

        # Encoder - residual blocks
        self.conv1 = nn.Conv3d(1, d, 4, 2, 1)  # the first layer has no residual path
        self.res_block2 = ResBlock3D(d, d * 2, 4, 2, 1)
        self.res_block3 = ResBlock3D(d * 2, d * 4, 4, 2, 1)
        self.res_block4 = ResBlock3D(d * 4, d * 8, 4, 2, 1)
        self.res_block5 = ResBlock3D(d * 8, d * 8, 4, 2, 1)
        self.res_block6 = ResBlock3D(d * 8, d * 8, 4, 2, 1)
        self.res_block7 = ResBlock3D(d * 8, d * 8, 4, 2, 1)
        self.conv8 = nn.Conv3d(d * 8, d * 8, 4, 2, 1)  # bottleneck

        # Keep references under the original layer names for compatibility.
        # These register the same module objects a second time, so their
        # tensors appear under both names in state_dict().
        self.conv2 = self.res_block2.conv
        self.conv2_bn = self.res_block2.bn
        self.conv3 = self.res_block3.conv
        self.conv3_bn = self.res_block3.bn
        self.conv4 = self.res_block4.conv
        self.conv4_bn = self.res_block4.bn
        self.conv5 = self.res_block5.conv
        self.conv5_bn = self.res_block5.bn
        self.conv6 = self.res_block6.conv
        self.conv6_bn = self.res_block6.bn
        self.conv7 = self.res_block7.conv
        self.conv7_bn = self.res_block7.bn

        # Decoder - original (non-residual) structure combined with a guarded batch norm.
        # ResUpBlock3D is deliberately not used here, to avoid batch-norm
        # problems at the first up-sampling block.
        self.deconv1 = nn.ConvTranspose3d(d * 8, d * 8, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm3d(d * 8, track_running_stats=True, momentum=0.1)
        self.deconv2 = nn.ConvTranspose3d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm3d(d * 8, track_running_stats=True, momentum=0.1)
        self.deconv3 = nn.ConvTranspose3d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm3d(d * 8, track_running_stats=True, momentum=0.1)
        self.deconv4 = nn.ConvTranspose3d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm3d(d * 8, track_running_stats=True, momentum=0.1)
        self.deconv5 = nn.ConvTranspose3d(d * 8 * 2, d * 4, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm3d(d * 4, track_running_stats=True, momentum=0.1)
        self.deconv6 = nn.ConvTranspose3d(d * 4 * 2, d * 2, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm3d(d * 2, track_running_stats=True, momentum=0.1)
        self.deconv7 = nn.ConvTranspose3d(d * 2 * 2, d, 4, 2, 1)
        self.deconv7_bn = nn.BatchNorm3d(d, track_running_stats=True, momentum=0.1)
        self.deconv8 = nn.ConvTranspose3d(d * 2, 1, 4, 2, 1)

    def weight_init(self, mean, std):
        """Initialize convolution weights through the module-level helper normal_init3d.

        Direct Conv3d / ConvTranspose3d children are passed to normal_init3d;
        for residual blocks their conv / deconv / shortcut layers are passed.
        normal_init3d is not defined in this class.
        """
        for m in self._modules:
            module = self._modules[m]
            if isinstance(module, (nn.Conv3d, nn.ConvTranspose3d)):
                normal_init3d(module, mean, std)
            elif isinstance(module, (ResBlock3D, ResUpBlock3D)):
                if hasattr(module, 'conv'):
                    normal_init3d(module.conv, mean, std)
                if hasattr(module, 'shortcut') and isinstance(module.shortcut, nn.Conv3d):
                    normal_init3d(module.shortcut, mean, std)
                if hasattr(module, 'deconv'):
                    normal_init3d(module.deconv, mean, std)
                if hasattr(module, 'shortcut') and isinstance(module.shortcut, nn.ConvTranspose3d):
                    normal_init3d(module.shortcut, mean, std)

    def _run_encoder(self, input):
        """Return ([e1, ..., e7], e8): the skip tensors and the bottleneck tensor."""
        skips = [self.conv1(input)]  # first layer: no BN, no residual
        for block in (self.res_block2, self.res_block3, self.res_block4,
                      self.res_block5, self.res_block6, self.res_block7):
            skips.append(block(skips[-1]))
        bottleneck = self.conv8(F.leaky_relu(skips[-1], 0.2))
        return skips, bottleneck

    def _upsample_bottleneck(self, e8):
        """Apply deconv1 + deconv1_bn to ``e8`` with the small-feature-map guard.

        While this generator is training and ``e8`` is a single sample with a
        1x1x1 feature map, deconv1_bn is run in eval mode for this call. The
        batch norm's own previous mode is restored afterwards, even on error.
        NOTE(review): deconv1 (kernel 4, stride 2, padding 1) turns 1x1x1 into
        2x2x2, so the batch norm sees 8 values per channel; the guard may be
        unnecessary - confirm before removing it.
        """
        bn = self.deconv1_bn
        bn_was_training = bn.training
        if self.training and e8.size()[0] == 1 and e8.size()[2:] == torch.Size([1, 1, 1]):
            bn.eval()  # temporarily use running statistics
        try:
            return bn(self.deconv1(F.relu(e8)))
        finally:
            bn.train(bn_was_training)

    def _run_decoder(self, skips, e8, keep_features=False):
        """Decode from ``e8`` using ``skips`` = [e1, ..., e7].

        Returns (features, o). ``features`` is [d2_, ..., d7_], the outputs of
        deconv2..deconv7 after batch norm and before skip concatenation, or an
        empty list when ``keep_features`` is False. ``o`` is the tanh output.
        """
        x = torch.cat([self._upsample_bottleneck(e8), skips[-1]], 1)
        stages = (
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
            (self.deconv5, self.deconv5_bn),
            (self.deconv6, self.deconv6_bn),
            (self.deconv7, self.deconv7_bn),
        )
        features = []
        # deconv2 pairs with e6, deconv3 with e5, ..., deconv7 with e1.
        for (deconv, norm), skip in zip(stages, reversed(skips[:-1])):
            up = norm(deconv(F.relu(x)))
            if keep_features:
                features.append(up)
            x = torch.cat([up, skip], 1)
        return features, torch.tanh(self.deconv8(F.relu(x)))

    def forward(self, input):
        """Forward pass; returns the tanh-activated 1-channel output."""
        skips, e8 = self._run_encoder(input)
        _, o = self._run_decoder(skips, e8)
        return o

    def extract_encoder_feature(self, input):
        """Return ([e2, e3, e4, e5, e6, e7], o): encoder features and the output."""
        skips, e8 = self._run_encoder(input)
        _, o = self._run_decoder(skips, e8)
        return skips[1:], o

    def extract_decoder_feature(self, input):
        """Return ([d2_, ..., d7_], o): pre-concatenation decoder features and the output."""
        skips, e8 = self._run_encoder(input)
        features, o = self._run_decoder(skips, e8, keep_features=True)
        return features, o

class Unetgenerator3D(nn.Module):
    """3D U-Net generator with eight down-sampling and eight up-sampling stages.

    The encoder uses stride-2 Conv3d layers (kernel 4, padding 1) with
    LeakyReLU(0.2) and BatchNorm3d; the decoder uses matching ConvTranspose3d
    layers with ReLU, BatchNorm3d and skip concatenation. The input must have
    1 channel; the output has 1 channel and is squashed by tanh.

    Parameters:
        d (int) -- base number of feature channels
    """

    # initializers
    def __init__(self, d=64):
        super(Unetgenerator3D, self).__init__()
        c1, c2, c4, c8 = d, d * 2, d * 4, d * 8

        # Unet encoder (conv1 and the conv8 bottleneck have no batch norm)
        self.conv1 = nn.Conv3d(1, c1, 4, 2, 1)
        self.conv2 = nn.Conv3d(c1, c2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm3d(c2)
        self.conv3 = nn.Conv3d(c2, c4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm3d(c4)
        self.conv4 = nn.Conv3d(c4, c8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm3d(c8)
        self.conv5 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv5_bn = nn.BatchNorm3d(c8)
        self.conv6 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv6_bn = nn.BatchNorm3d(c8)
        self.conv7 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv7_bn = nn.BatchNorm3d(c8)
        self.conv8 = nn.Conv3d(c8, c8, 4, 2, 1)

        # Unet decoder; inputs from deconv2 on are [decoder feature, skip] concatenations
        self.deconv1 = nn.ConvTranspose3d(c8, c8, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm3d(c8)
        self.deconv2 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm3d(c8)
        self.deconv3 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm3d(c8)
        self.deconv4 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm3d(c8)
        self.deconv5 = nn.ConvTranspose3d(c8 * 2, c4, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm3d(c4)
        self.deconv6 = nn.ConvTranspose3d(c4 * 2, c2, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm3d(c2)
        self.deconv7 = nn.ConvTranspose3d(c2 * 2, c1, 4, 2, 1)
        self.deconv7_bn = nn.BatchNorm3d(c1)
        self.deconv8 = nn.ConvTranspose3d(c1 * 2, 1, 4, 2, 1)

    # weight_init
    def weight_init(self, mean, std):
        """Call normal_init3d(module, mean, std) on every direct child module.

        normal_init3d is a module-level helper that is not part of this class.
        """
        for m in self._modules:
            normal_init3d(self._modules[m], mean, std)

    def _run_encoder(self, input):
        """Return ([e1, ..., e7], e8): the skip tensors and the bottleneck tensor."""
        stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
            (self.conv5, self.conv5_bn),
            (self.conv6, self.conv6_bn),
            (self.conv7, self.conv7_bn),
        )
        skips = [self.conv1(input)]
        for conv, norm in stages:
            skips.append(norm(conv(F.leaky_relu(skips[-1], 0.2))))
        bottleneck = self.conv8(F.leaky_relu(skips[-1], 0.2))
        return skips, bottleneck

    def _run_decoder(self, skips, e8, keep_features=False):
        """Decode from ``e8`` using ``skips`` = [e1, ..., e7].

        Returns (features, o). ``features`` is [d2_, ..., d7_], the outputs of
        deconv2..deconv7 after batch norm and before skip concatenation, or an
        empty list when ``keep_features`` is False. ``o`` is the tanh output.
        """
        stages = (
            (self.deconv1, self.deconv1_bn),
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
            (self.deconv5, self.deconv5_bn),
            (self.deconv6, self.deconv6_bn),
            (self.deconv7, self.deconv7_bn),
        )
        x = e8
        features = []
        # deconv1 pairs with e7, deconv2 with e6, ..., deconv7 with e1.
        for index, ((deconv, norm), skip) in enumerate(zip(stages, reversed(skips))):
            up = norm(deconv(F.relu(x)))
            if keep_features and index > 0:  # the first stage (d1) is not reported
                features.append(up)
            x = torch.cat([up, skip], 1)
        return features, torch.tanh(self.deconv8(F.relu(x)))

    # forward method
    def forward(self, input):
        """Forward pass; returns the tanh-activated 1-channel output."""
        skips, e8 = self._run_encoder(input)
        _, o = self._run_decoder(skips, e8)
        return o

    def extract_encoder_feature(self, input):
        """Return ([e2, e3, e4, e5, e6, e7], o): encoder features and the output."""
        skips, e8 = self._run_encoder(input)
        _, o = self._run_decoder(skips, e8)
        return skips[1:], o

    def extract_decoder_feature(self, input):
        """Return ([d2_, ..., d7_], o): pre-concatenation decoder features and the output."""
        skips, e8 = self._run_encoder(input)
        features, o = self._run_decoder(skips, e8, keep_features=True)
        return features, o

class generator3D(nn.Module):
    """3D U-Net generator with eight down-sampling and eight up-sampling stages.

    The encoder uses stride-2 Conv3d layers (kernel 4, padding 1) with
    LeakyReLU(0.2) and BatchNorm3d; the decoder uses matching ConvTranspose3d
    layers with ReLU, BatchNorm3d and skip concatenation. The input must have
    1 channel; the output has 1 channel and is squashed by tanh.

    Parameters:
        d (int) -- base number of feature channels
    """

    # initializers
    def __init__(self, d=64):
        super(generator3D, self).__init__()
        c1, c2, c4, c8 = d, d * 2, d * 4, d * 8

        # Unet encoder (conv1 and the conv8 bottleneck have no batch norm)
        self.conv1 = nn.Conv3d(1, c1, 4, 2, 1)
        self.conv2 = nn.Conv3d(c1, c2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm3d(c2)
        self.conv3 = nn.Conv3d(c2, c4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm3d(c4)
        self.conv4 = nn.Conv3d(c4, c8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm3d(c8)
        self.conv5 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv5_bn = nn.BatchNorm3d(c8)
        self.conv6 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv6_bn = nn.BatchNorm3d(c8)
        self.conv7 = nn.Conv3d(c8, c8, 4, 2, 1)
        self.conv7_bn = nn.BatchNorm3d(c8)
        self.conv8 = nn.Conv3d(c8, c8, 4, 2, 1)

        # Unet decoder; inputs from deconv2 on are [decoder feature, skip] concatenations
        self.deconv1 = nn.ConvTranspose3d(c8, c8, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm3d(c8)
        self.deconv2 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm3d(c8)
        self.deconv3 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm3d(c8)
        self.deconv4 = nn.ConvTranspose3d(c8 * 2, c8, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm3d(c8)
        self.deconv5 = nn.ConvTranspose3d(c8 * 2, c4, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm3d(c4)
        self.deconv6 = nn.ConvTranspose3d(c4 * 2, c2, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm3d(c2)
        self.deconv7 = nn.ConvTranspose3d(c2 * 2, c1, 4, 2, 1)
        self.deconv7_bn = nn.BatchNorm3d(c1)
        self.deconv8 = nn.ConvTranspose3d(c1 * 2, 1, 4, 2, 1)

    # weight_init
    def weight_init(self, mean, std):
        """Call normal_init3d(module, mean, std) on every direct child module.

        normal_init3d is a module-level helper that is not part of this class.
        """
        for m in self._modules:
            normal_init3d(self._modules[m], mean, std)

    def _run_encoder(self, input):
        """Return ([e1, ..., e7], e8): the skip tensors and the bottleneck tensor."""
        stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
            (self.conv5, self.conv5_bn),
            (self.conv6, self.conv6_bn),
            (self.conv7, self.conv7_bn),
        )
        skips = [self.conv1(input)]
        for conv, norm in stages:
            skips.append(norm(conv(F.leaky_relu(skips[-1], 0.2))))
        bottleneck = self.conv8(F.leaky_relu(skips[-1], 0.2))
        return skips, bottleneck

    def _run_decoder(self, skips, e8, keep_features=False):
        """Decode from ``e8`` using ``skips`` = [e1, ..., e7].

        Returns (features, o). ``features`` is [d2_, ..., d7_], the outputs of
        deconv2..deconv7 after batch norm and before skip concatenation, or an
        empty list when ``keep_features`` is False. ``o`` is the tanh output.
        """
        stages = (
            (self.deconv1, self.deconv1_bn),
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
            (self.deconv5, self.deconv5_bn),
            (self.deconv6, self.deconv6_bn),
            (self.deconv7, self.deconv7_bn),
        )
        x = e8
        features = []
        # deconv1 pairs with e7, deconv2 with e6, ..., deconv7 with e1.
        for index, ((deconv, norm), skip) in enumerate(zip(stages, reversed(skips))):
            up = norm(deconv(F.relu(x)))
            if keep_features and index > 0:  # the first stage (d1) is not reported
                features.append(up)
            x = torch.cat([up, skip], 1)
        return features, torch.tanh(self.deconv8(F.relu(x)))

    # forward method
    def forward(self, input):
        """Forward pass; returns the tanh-activated 1-channel output."""
        skips, e8 = self._run_encoder(input)
        _, o = self._run_decoder(skips, e8)
        return o

    def extract_encoder_feature(self, input):
        """Return ([e2, e3, e4, e5, e6, e7], o): encoder features and the output."""
        skips, e8 = self._run_encoder(input)
        _, o = self._run_decoder(skips, e8)
        return skips[1:], o

    def extract_decoder_feature(self, input):
        """Return ([d2_, ..., d7_], o): pre-concatenation decoder features and the output."""
        skips, e8 = self._run_encoder(input)
        features, o = self._run_decoder(skips, e8, keep_features=True)
        return features, o

class generator(nn.Module):
    """U-Net style 2D generator with eight down-sampling and eight up-sampling stages.

    Every (transposed) convolution uses kernel 4, stride 2, padding 1, so each encoder stage
    halves the spatial size and each decoder stage doubles it. Encoder activations are
    concatenated onto the decoder as skip connections, which means the input height and width
    need to be multiples of 256 for the shapes to line up. Input and output are single-channel.

    Parameters:
        d (int) -- base number of filters; deeper stages use 2x, 4x and 8x this width
    """

    # initializers
    def __init__(self, d=64):
        super(generator, self).__init__()
        # Unet encoder (attribute names and creation order are part of the checkpoint layout)
        self.conv1 = nn.Conv2d(1, d, 4, 2, 1)
        self.conv2 = nn.Conv2d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm2d(d * 4)
        self.conv4 = nn.Conv2d(d * 4, d * 8, 4, 2, 1)
        self.conv4_bn = nn.BatchNorm2d(d * 8)
        self.conv5 = nn.Conv2d(d * 8, d * 8, 4, 2, 1)
        self.conv5_bn = nn.BatchNorm2d(d * 8)
        self.conv6 = nn.Conv2d(d * 8, d * 8, 4, 2, 1)
        self.conv6_bn = nn.BatchNorm2d(d * 8)
        self.conv7 = nn.Conv2d(d * 8, d * 8, 4, 2, 1)
        self.conv7_bn = nn.BatchNorm2d(d * 8)
        # The bottleneck conv deliberately has no batch norm.
        self.conv8 = nn.Conv2d(d * 8, d * 8, 4, 2, 1)

        # Unet decoder; inputs of stages 2-8 are doubled by the skip concatenation
        self.deconv1 = nn.ConvTranspose2d(d * 8, d * 8, 4, 2, 1)
        self.deconv1_bn = nn.BatchNorm2d(d * 8)
        self.deconv2 = nn.ConvTranspose2d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv2_bn = nn.BatchNorm2d(d * 8)
        self.deconv3 = nn.ConvTranspose2d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d * 8)
        self.deconv4 = nn.ConvTranspose2d(d * 8 * 2, d * 8, 4, 2, 1)
        self.deconv4_bn = nn.BatchNorm2d(d * 8)
        self.deconv5 = nn.ConvTranspose2d(d * 8 * 2, d * 4, 4, 2, 1)
        self.deconv5_bn = nn.BatchNorm2d(d * 4)
        self.deconv6 = nn.ConvTranspose2d(d * 4 * 2, d * 2, 4, 2, 1)
        self.deconv6_bn = nn.BatchNorm2d(d * 2)
        self.deconv7 = nn.ConvTranspose2d(d * 2 * 2, d, 4, 2, 1)
        self.deconv7_bn = nn.BatchNorm2d(d)
        self.deconv8 = nn.ConvTranspose2d(d * 2, 1, 4, 2, 1)

    # weight_init
    def weight_init(self, mean, std):
        """Apply normal_init(module, mean, std) to every direct child module."""
        for module in self.children():
            normal_init(module, mean, std)

    def _unet_pass(self, input):
        """Run encoder and decoder once.

        Returns:
            (skips, ups, o): skips holds encoder outputs e1..e7, ups holds decoder outputs of
            stages 1..7 after batch norm and before the skip concatenation, o is the tanh output.
        """
        encoder_stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
            (self.conv5, self.conv5_bn),
            (self.conv6, self.conv6_bn),
            (self.conv7, self.conv7_bn),
        )
        decoder_stages = (
            (self.deconv1, self.deconv1_bn),
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
            (self.deconv5, self.deconv5_bn),
            (self.deconv6, self.deconv6_bn),
            (self.deconv7, self.deconv7_bn),
        )

        # Encoder: stage 1 is a bare conv; stages 2-7 are LeakyReLU(0.2) -> conv -> batch norm.
        skips = [self.conv1(input)]
        for conv, bn in encoder_stages:
            skips.append(bn(conv(F.leaky_relu(skips[-1], 0.2))))
        x = self.conv8(F.leaky_relu(skips[-1], 0.2))

        # Decoder: ReLU -> deconv -> batch norm, then concatenate the matching encoder output.
        # Skips are consumed deepest-first (stage 1 with e7, ..., stage 7 with e1).
        # Dropout on the early decoder stages is intentionally not applied.
        ups = []
        for (deconv, bn), skip in zip(decoder_stages, reversed(skips)):
            up = bn(deconv(F.relu(x)))
            ups.append(up)
            x = torch.cat([up, skip], 1)

        o = torch.tanh(self.deconv8(F.relu(x)))
        return skips, ups, o

    # forward method
    def forward(self, input):
        """Return the generated single-channel image, squashed to [-1, 1] by tanh."""
        return self._unet_pass(input)[2]

    def extract_encoder_feature(self, input):
        """Return ([e2, e3, e4, e5, e6, e7], o): normalized encoder activations plus the output."""
        skips, _, o = self._unet_pass(input)
        return skips[1:], o

    def extract_decoder_feature(self, input):
        """Return ([d2, ..., d7], o): decoder activations before skip concatenation plus the output."""
        _, ups, o = self._unet_pass(input)
        return ups[1:], o





def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Build a 2D discriminator and hand it to init_net.

    Parameters:
        input_nc (int)     -- number of channels of the discriminator input
        ndf (int)          -- filter count passed on to the discriminator
        netD (str)         -- architecture name: basic | n_layers | pixel
        n_layers_D (int)   -- depth of the discriminator; only used when netD == 'n_layers'
        norm (str)         -- normalization name forwarded to get_norm_layer
        init_type (str)    -- forwarded to init_net
        init_gain (float)  -- forwarded to init_net
        gpu_ids (int list) -- forwarded to init_net

    Returns whatever init_net(net, init_type, init_gain, gpu_ids) returns.
    Raises NotImplementedError for an unknown netD.
    """
    # Resolve the normalization first so that an invalid `norm` is reported before `netD`.
    norm_layer = get_norm_layer(norm_type=norm)

    if netD == 'pixel':
        # classify if each pixel is real or fake
        discriminator_net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
    elif netD in ('basic', 'n_layers'):
        # 'basic' is the default PatchGAN classifier with a fixed depth of 3
        depth = 3 if netD == 'basic' else n_layers_D
        discriminator_net = NLayerDiscriminator(input_nc, ndf, n_layers=depth, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)

    return init_net(discriminator_net, init_type, init_gain, gpu_ids)


class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from a stack of 4x4 convolutions."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Assemble the convolution stack.

        Parameters:
            input_nc (int)  -- number of channels of the input images
            ndf (int)       -- number of filters of the first conv layer; later layers use
                               multiples of it, capped at 8 * ndf
            n_layers (int)  -- number of stride-2 conv layers (including the first one)
            norm_layer      -- normalization layer class or functools.partial producing one
        """
        super(NLayerDiscriminator, self).__init__()

        # A conv bias is only kept in front of InstanceNorm2d; for any other norm layer
        # (e.g. BatchNorm2d, which has its own affine parameters) it is dropped.
        norm_cls = norm_layer.func if type(norm_layer) is functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        kernel, pad = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Stride-2 blocks: the channel multiplier doubles each time, up to 8.
        mult = 1
        for depth in range(1, n_layers):
            prev_mult, mult = mult, min(2 ** depth, 8)
            layers.extend([
                nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kernel, stride=2, padding=pad, bias=use_bias),
                norm_layer(ndf * mult),
                nn.LeakyReLU(0.2, True),
            ])

        # One more block at stride 1, so the resolution only shrinks by the kernel border.
        prev_mult, mult = mult, min(2 ** n_layers, 8)
        layers.extend([
            nn.Conv2d(ndf * prev_mult, ndf * mult, kernel_size=kernel, stride=1, padding=pad, bias=use_bias),
            norm_layer(ndf * mult),
            nn.LeakyReLU(0.2, True),
        ])

        # Final projection to a 1-channel prediction map.
        layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kernel, stride=1, padding=pad))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Return the patch-wise prediction map for `input`."""
        return self.model(input)

def define_D_3D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], use_sigmoid=False):
    """Build a 3D PatchGAN discriminator and hand it to init_net.

    Parameters:
        input_nc (int)         -- number of channels of the discriminator input
        ndf (int)              -- filter count passed on to NLayerDiscriminator3D
        which_model_netD (str) -- architecture name: basic | n_layers
        n_layers_D (int)       -- depth; only used when which_model_netD == 'n_layers'
        norm (str)             -- normalization name forwarded to get_norm_layer
        init_type (str)        -- forwarded to init_net
        init_gain (float)      -- forwarded to init_net
        gpu_ids (int list)     -- when non-empty, CUDA must be available and the network is
                                  moved to gpu_ids[0] before init_net is called
        use_sigmoid (bool)     -- append a Sigmoid to the discriminator

    Returns whatever init_net(netD, init_type, init_gain, gpu_ids) returns.
    Raises NotImplementedError for an unknown which_model_netD.
    """
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)
    if use_gpu:
        assert torch.cuda.is_available()

    # 'basic' pins the depth to 3; 'n_layers' takes it from the caller.
    if which_model_netD == 'basic':
        depth = 3
    elif which_model_netD == 'n_layers':
        depth = n_layers_D
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)

    netD = NLayerDiscriminator3D(input_nc, ndf, n_layers=depth, norm_layer=norm_layer,
                                 use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    if use_gpu:
        netD.cuda(device=gpu_ids[0])
    return init_net(netD, init_type, init_gain, gpu_ids) #new

    #netD.apply(weights_init) #old
    #return netD # old
'''
class NLayerDiscriminator3D(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm3d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator3D, self).__init__()
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm3d 
        else:
            use_bias = norm_layer == nn.InstanceNorm3d

        kw = 4
        #padw = int(np.ceil((kw-1)/2))
        padw=1
        sequence = [
            nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        return self.model(input)
'''
class NLayerDiscriminator3D(nn.Module):
    """3D PatchGAN discriminator built from a stack of 4x4x4 convolutions."""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm3d, use_sigmoid=False, gpu_ids=[]):
        """Assemble the convolution stack.

        Parameters:
            input_nc (int)     -- number of channels of the input volumes
            ndf (int)          -- number of filters of the first conv layer; later layers use
                                  multiples of it, capped at 8 * ndf
            n_layers (int)     -- number of stride-2 conv layers (including the first one)
            norm_layer         -- normalization layer class or functools.partial producing one
            use_sigmoid (bool) -- append a Sigmoid after the final 1-channel convolution
            gpu_ids (int list) -- devices used for data-parallel execution in forward()
        """
        super(NLayerDiscriminator3D, self).__init__()
        # Store a private copy: the default list object is shared between calls, so keeping a
        # reference to it would let a mutation on one instance leak into every other instance.
        self.gpu_ids = list(gpu_ids) if gpu_ids else []

        # A conv bias is only kept in front of InstanceNorm3d.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm3d
        else:
            use_bias = norm_layer == nn.InstanceNorm3d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        # Stride-2 blocks: the channel multiplier doubles each time, up to 8.
        nf_mult = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # One more block at stride 1.
        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        # Final projection to a 1-channel prediction map.
        sequence += [nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        """Return the prediction map; CUDA inputs are split across gpu_ids when any are set."""
        # `is_cuda` covers every CUDA dtype, not only float32 like the legacy
        # torch.cuda.FloatTensor type check did.
        if self.gpu_ids and input.is_cuda:
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        return self.model(input)

class discriminator(nn.Module):
    """Five-layer 2D convolutional discriminator ending in a sigmoid score map.

    The first convolution takes 2 input channels, so the caller is expected to pass an
    already-concatenated 2-channel tensor (presumably image plus condition -- confirm with callers).

    Parameters:
        d (int) -- base number of filters; later layers use 2x, 4x and 8x this width
    """

    # initializers
    def __init__(self, d=64):
        super(discriminator, self).__init__()
        # Three stride-2 convolutions followed by two stride-1 convolutions, all 4x4 with padding 1.
        self.conv1 = nn.Conv2d(2, d, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(d, d * 2, kernel_size=4, stride=2, padding=1)
        self.conv2_bn = nn.BatchNorm2d(d * 2)
        self.conv3 = nn.Conv2d(d * 2, d * 4, kernel_size=4, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(d * 4)
        self.conv4 = nn.Conv2d(d * 4, d * 8, kernel_size=4, stride=1, padding=1)
        self.conv4_bn = nn.BatchNorm2d(d * 8)
        self.conv5 = nn.Conv2d(d * 8, 1, kernel_size=4, stride=1, padding=1)

    # weight_init
    def weight_init(self, mean, std):
        """Apply normal_init(module, mean, std) to every direct child module."""
        for child in self._modules.values():
            normal_init(child, mean, std)

    # forward method
    def forward(self, input):
        """Return per-patch scores in (0, 1) for a 2-channel input."""
        normalized_stages = (
            (self.conv2, self.conv2_bn),
            (self.conv3, self.conv3_bn),
            (self.conv4, self.conv4_bn),
        )
        # The first layer has no batch norm; the middle layers are conv -> batch norm -> LeakyReLU.
        hidden = F.leaky_relu(self.conv1(input), 0.2)
        for conv, bn in normalized_stages:
            hidden = F.leaky_relu(bn(conv(hidden)), 0.2)
        return F.sigmoid(self.conv5(hidden))

class discriminator3D(nn.Module):
    """Five-layer 3D convolutional discriminator ending in a sigmoid score map.

    The first convolution takes 2 input channels, so the caller is expected to pass an
    already-concatenated 2-channel volume (presumably volume plus condition -- confirm with callers).

    Parameters:
        d (int) -- base number of filters; later layers use 2x, 4x and 8x this width
    """

    # initializers
    def __init__(self, d=64):
        # Must name this class: super(discriminator, self) raised TypeError because
        # discriminator3D is not a subclass of discriminator.
        super(discriminator3D, self).__init__()
        # Three stride-2 convolutions followed by two stride-1 convolutions, all 4x4x4 with padding 1.
        self.conv1 = nn.Conv3d(2, d, 4, 2, 1)
        self.conv2 = nn.Conv3d(d, d * 2, 4, 2, 1)
        self.conv2_bn = nn.BatchNorm3d(d * 2)
        self.conv3 = nn.Conv3d(d * 2, d * 4, 4, 2, 1)
        self.conv3_bn = nn.BatchNorm3d(d * 4)
        self.conv4 = nn.Conv3d(d * 4, d * 8, 4, 1, 1)
        self.conv4_bn = nn.BatchNorm3d(d * 8)
        self.conv5 = nn.Conv3d(d * 8, 1, 4, 1, 1)

    # weight_init
    def weight_init(self, mean, std):
        """Apply normal_init3d(module, mean, std) to every direct child module."""
        for module in self.children():
            normal_init3d(module, mean, std)

    # forward method
    def forward(self, input):
        """Return per-patch scores in (0, 1) for a 2-channel input volume."""
        # The first layer has no batch norm; the middle layers are conv -> batch norm -> LeakyReLU.
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        # torch.sigmoid matches the torch.tanh spelling used by the generators in this file.
        x = torch.sigmoid(self.conv5(x))

        return x

def normal_init(m, mean, std):
    """Re-initialize a 2D (transposed) convolution in place; leave any other module untouched.

    Parameters:
        m (nn.Module) -- candidate module; only nn.Conv2d / nn.ConvTranspose2d are modified
        mean (float)  -- mean of the normal distribution used for the weights
        std (float)   -- standard deviation of that distribution

    Weights are drawn from N(mean, std); the bias, when the layer has one, is set to zero.
    """
    if not isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        return
    nn.init.normal_(m.weight, mean, std)
    # Layers built with bias=False expose bias=None; skip them instead of raising AttributeError.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

def normal_init3d(m, mean, std):
    """Re-initialize a 3D (transposed) convolution in place; leave any other module untouched.

    Parameters:
        m (nn.Module) -- candidate module; only nn.Conv3d / nn.ConvTranspose3d are modified
        mean (float)  -- mean of the normal distribution used for the weights
        std (float)   -- standard deviation of that distribution

    Weights are drawn from N(mean, std); the bias, when the layer has one, is set to zero.
    """
    if not isinstance(m, (nn.ConvTranspose3d, nn.Conv3d)):
        return
    nn.init.normal_(m.weight, mean, std)
    # Layers built with bias=False expose bias=None; skip them instead of raising AttributeError.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

def normal_initX2(m, mean, std):
    """Re-initialize a 2D or 3D (transposed) convolution in place; leave other modules untouched.

    Parameters:
        m (nn.Module) -- candidate module; only nn.Conv2d, nn.ConvTranspose2d, nn.Conv3d and
                         nn.ConvTranspose3d are modified
        mean (float)  -- mean of the normal distribution used for the weights
        std (float)   -- standard deviation of that distribution

    Weights are drawn from N(mean, std); the bias, when the layer has one, is set to zero.
    """
    conv_types = (nn.ConvTranspose2d, nn.Conv2d, nn.ConvTranspose3d, nn.Conv3d)
    if not isinstance(m, conv_types):
        return
    nn.init.normal_(m.weight, mean, std)
    # Layers built with bias=False expose bias=None; skip them instead of raising AttributeError.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.0)

class GANLoss(nn.Module):
    """GAN objective that builds the real/fake target tensor for the caller.

    The target labels are registered as buffers, so they follow the module across
    .to(device) / .cuda() calls.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        """Select the criterion and store the label values.

        Parameters:
            use_lsgan (bool)          -- True selects nn.MSELoss (least-squares GAN), False selects
                                         nn.BCELoss, which expects probabilities as input
            target_real_label (float) -- label value for real samples
            target_fake_label (float) -- label value for fake samples
            tensor                    -- unused; kept so existing call sites keep working
        """
        super(GANLoss, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        if use_lsgan:
            self.loss = nn.MSELoss()
            print("MSELoss")
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return the real or fake label broadcast to the shape of `input`.

        Parameters:
            input (tensor)        -- discriminator prediction whose shape the target must match
            target_is_real (bool) -- pick the real label when truthy, the fake label otherwise

        The result is an expanded view of the label buffer, not a fresh copy.
        """
        target_tensor = self.real_label if target_is_real else self.fake_label
        return target_tensor.expand_as(input)

    def forward(self, input, target_is_real):
        """Compute the loss between `input` and the matching label tensor.

        Defined as forward (instead of overriding __call__) so that nn.Module's call machinery,
        including registered hooks, keeps working; `criterion(input, flag)` behaves as before.
        """
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)

def test():
    """Smoke test: push a random 256^3 volume through generator3D on the GPU and print a summary.

    Requires a CUDA device; nothing is returned.
    """
    device = 'cuda'

    # The 2D sample is not used by the current check. It is still drawn first, so it advances
    # the random stream before the volume is sampled.
    x = torch.randn(1, 1, 256, 256).to(device)
    x3 = torch.randn(1, 1, 256, 256, 256).to(device)

    model_G = generator3D(64).to(device)
    preds_G = model_G(x3)

    print(model_G)
    summary(model_G, (1, 256, 256, 256))

# Script entry point: run the GPU smoke test only when this file is executed directly,
# never when it is imported as a module.
if __name__ == "__main__":
    test()



