import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from nemo.models.unet import unet_res50
from nemo.models.upsampling_layer import DoubleConv
from nemo.models.upsampling_layer import Up
from nemo.models.vqvae_encoder import VQEncoder

from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models.feature_extraction import create_feature_extractor

# vgg16()/mod_vgg16() keep ``features[:n]`` of torchvision's VGG16, with n
# taken from here; the slice ends at the named max-pool layer.
vgg_layers = dict(pool4=24, pool5=31)

# One row per backbone name: (name, stride, channels).
#   stride   -- factor between input-image size and backbone feature-map size;
#               used to map image-space keypoints onto feature-map cells.
#   channels -- number of channels in the backbone's output feature map;
#               used as the input width of NetE2E's projection layer.
_BACKBONE_SPECS = (
    ("vgg_pool4", 16, 512),
    ("vgg_pool5", 32, 512),
    ("resnet50", 32, 2048),
    ("resnext50", 32, 2048),
    ("resnetext", 8, 256),
    ("resnetupsample", 8, 2048),
    ("resnetext2", 4, 256),
    ("resnetext3", 4, 256),
    ("unet_res50", 1, 64),
    ("frozen_resnetext", 8, 256),
    ("frozen_resnetext2", 8, 256),
    ("frozen_vgg_pool4", 16, 512),
    ("vqvae1", 8, 256),
    ("vqvae2", 8, 256),
)
net_stride = {name: stride for name, stride, _ in _BACKBONE_SPECS}
net_out_dimension = {name: channels for name, _, channels in _BACKBONE_SPECS}

class ResnetUpSample(nn.Module):
    """ResNet-50 trunk (stem through ``layer4``) followed by a fixed 4x bilinear upsample.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        # Parameter-free upsampling of the final feature map.
        self.upsample = nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True)
        # A positional Sequential names its children "0".."7".
        self.extractor = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )

    def forward(self, x):
        features = self.extractor(x)
        return self.upsample(features)


class ResNetExt2(nn.Module):
    """ResNet-50 encoder with a skip-connected decoder that ends at the ``layer1`` resolution.

    Encoder pieces: ``extractor`` (stem + layer1), ``extractor0`` (layer2),
    ``extractor1`` (layer3), ``extractor2`` (layer4). The decoder applies
    ``upsample3`` to the deepest map, then three ``Up`` stages, each fused with
    the next-shallower encoder output. ``Up``/``DoubleConv`` come from
    ``nemo.models.upsampling_layer``; their exact internals are not visible here.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        self.extractor = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
        )
        self.extractor0 = backbone.layer2
        self.extractor1 = backbone.layer3
        self.extractor2 = backbone.layer4

        self.upsample3 = DoubleConv(2048, 1024)
        self.upsample0 = Up(2048, 1024, 512)
        self.upsample1 = Up(1024, 512, 256)
        self.upsample2 = Up(512, 512, 256)

    def forward(self, x):
        feat_l1 = self.extractor(x)  # 256 channels
        feat_l2 = self.extractor0(feat_l1)  # 512 channels
        feat_l3 = self.extractor1(feat_l2)  # 1024 channels
        decoded = self.upsample3(self.extractor2(feat_l3))  # layer4: 2048 channels in
        stages = (
            (self.upsample0, feat_l3),
            (self.upsample1, feat_l2),
            (self.upsample2, feat_l1),
        )
        for up, skip in stages:
            decoded = up(decoded, skip)
        return decoded


class ResNetExt3(nn.Module):
    """Like ``ResNetExt2`` but without the ``DoubleConv`` on the deepest map.

    The layer4 output (2048 ch) goes straight into ``upsample0`` together with
    the layer3 skip (1024 ch), hence ``Up(3072, ...)``.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        self.extractor = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
        )
        self.extractor0 = backbone.layer2
        self.extractor1 = backbone.layer3
        self.extractor2 = backbone.layer4

        self.upsample0 = Up(3072, 1024, 512)
        self.upsample1 = Up(1024, 512, 256)
        self.upsample2 = Up(512, 512, 256)

    def forward(self, x):
        feat_l1 = self.extractor(x)
        feat_l2 = self.extractor0(feat_l1)
        feat_l3 = self.extractor1(feat_l2)
        feat_l4 = self.extractor2(feat_l3)
        decoded = self.upsample0(feat_l4, feat_l3)
        decoded = self.upsample1(decoded, feat_l2)
        return self.upsample2(decoded, feat_l1)


class ResNetExt(nn.Module):
    """ResNet-50 encoder with a decoder that ends at the ``layer2`` resolution.

    ``extractor`` covers stem..layer2, ``extractor1`` is layer3, ``extractor2``
    is layer4; the decoder refines layer4 with ``DoubleConv`` and then fuses the
    layer3 and layer2 outputs through two ``Up`` stages.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        self.extractor = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
        )
        self.extractor1 = backbone.layer3
        self.extractor2 = backbone.layer4

        self.upsample0 = DoubleConv(2048, 1024)
        self.upsample1 = Up(2048, 1024, 512)
        self.upsample2 = Up(1024, 512, 256)

    def forward(self, x):
        feat_l2 = self.extractor(x)
        feat_l3 = self.extractor1(feat_l2)
        feat_l4 = self.extractor2(feat_l3)
        decoded = self.upsample0(feat_l4)
        decoded = self.upsample1(decoded, feat_l3)
        return self.upsample2(decoded, feat_l2)

class Mod_ResNetExt(nn.Module):
    """``ResNetExt`` layout with a frozen ResNet-50 encoder; only the decoder trains.

    Every parameter of ``extractor`` (stem..layer2), ``extractor1`` (layer3) and
    ``extractor2`` (layer4) has ``requires_grad=False``; ``upsample0/1/2`` stay trainable.

    NOTE(review): ``requires_grad=False`` does not stop the encoder's BatchNorm
    layers from updating their running statistics while the module is in
    ``train()`` mode -- confirm whether the encoder should also be kept in ``eval()``.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        self.extractor = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
        )
        self.extractor1 = backbone.layer3
        self.extractor2 = backbone.layer4
        for encoder_part in (self.extractor, self.extractor1, self.extractor2):
            encoder_part.requires_grad_(False)

        self.upsample0 = DoubleConv(2048, 1024)
        self.upsample1 = Up(2048, 1024, 512)
        self.upsample2 = Up(1024, 512, 256)

    def forward(self, x):
        feat_l2 = self.extractor(x)
        feat_l3 = self.extractor1(feat_l2)
        feat_l4 = self.extractor2(feat_l3)
        decoded = self.upsample1(self.upsample0(feat_l4), feat_l3)
        return self.upsample2(decoded, feat_l2)
    
class Mod_ResNetExt2(nn.Module):
    """Frozen-encoder ``ResNetExt`` plus a trainable per-pixel MLP head.

    The head (``doubleconv1``) is three 1x1 conv -> BatchNorm -> ReLU stages,
    256 -> 128 -> 256 -> 256 channels, applied to the decoder output.

    NOTE(review): as in ``Mod_ResNetExt``, frozen encoder BatchNorm layers still
    update running statistics in ``train()`` mode -- confirm this is intended.

    Args:
        pretrained: forwarded to ``torchvision.models.resnet50``.
    """

    def __init__(self, pretrained):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        self.extractor = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
        )
        self.extractor1 = backbone.layer3
        self.extractor2 = backbone.layer4
        for encoder_part in (self.extractor, self.extractor1, self.extractor2):
            encoder_part.requires_grad_(False)

        self.upsample0 = DoubleConv(2048, 1024)
        self.upsample1 = Up(2048, 1024, 512)
        self.upsample2 = Up(1024, 512, 256)

        # Children are "0".."8": (conv, bn, relu) per (c_in, c_out) pair.
        head_layers = []
        for c_in, c_out in ((256, 128), (128, 256), (256, 256)):
            head_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=1, padding=0),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.doubleconv1 = nn.Sequential(*head_layers)

    def forward(self, x):
        feat_l2 = self.extractor(x)
        feat_l3 = self.extractor1(feat_l2)
        feat_l4 = self.extractor2(feat_l3)
        decoded = self.upsample1(self.upsample0(feat_l4), feat_l3)
        decoded = self.upsample2(decoded, feat_l2)
        return self.doubleconv1(decoded)


def resnetupsample(pretrain):
    """Build a ``ResnetUpSample`` backbone; ``pretrain`` is passed as ``pretrained``."""
    return ResnetUpSample(pretrained=pretrain)


def resnetext(pretrain):
    """Build a ``ResNetExt`` backbone; ``pretrain`` is passed as ``pretrained``."""
    return ResNetExt(pretrained=pretrain)


def resnetext2(pretrain):
    """Build a ``ResNetExt2`` backbone; ``pretrain`` is passed as ``pretrained``."""
    return ResNetExt2(pretrained=pretrain)


def resnetext3(pretrain):
    """Build a ``ResNetExt3`` backbone; ``pretrain`` is passed as ``pretrained``."""
    return ResNetExt3(pretrained=pretrain)


def vgg16(layer="pool4"):
    """Return torchvision's pretrained VGG16 ``features`` trunk, truncated at ``layer``.

    The result is an ``nn.Sequential`` with a single child ``"features"`` holding
    the first ``vgg_layers[layer]`` modules of ``vgg16().features`` (children
    named "0", "1", ...).

    Args:
        layer: a key of ``vgg_layers`` ("pool4" or "pool5").

    Raises:
        KeyError: if ``layer`` is not a key of ``vgg_layers``.
    """
    backbone = models.vgg16(pretrained=True)
    keep = vgg_layers[layer]
    trunk = nn.Sequential(*(backbone.features[idx] for idx in range(keep)))
    model = nn.Sequential()
    model.add_module("features", trunk)
    return model

def frozen_resnetext(pretrain):
    """Build a ``Mod_ResNetExt`` (frozen encoder); ``pretrain`` is passed as ``pretrained``."""
    return Mod_ResNetExt(pretrained=pretrain)

def frozen_resnetext2(pretrain):
    """Build a ``Mod_ResNetExt2`` (frozen encoder + MLP head); ``pretrain`` is passed as ``pretrained``."""
    return Mod_ResNetExt2(pretrained=pretrain)

def mod_vgg16(layer="pool4"):
    #! not done
    net = models.vgg16(pretrained=True)
    model = nn.Sequential()
    features = nn.Sequential()
    for i in range(0, vgg_layers[layer]):
        features.add_module("{}".format(i), net.features[i])
    model.add_module("features", features)
    for param in model.parameters():
        param.requires_grad = False
    if layer == "pool4":
        doubleconv1 = nn.Sequential(
                nn.Conv2d(512, 256, kernel_size=1, padding=0),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 256, kernel_size=1, padding=0),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
                nn.Conv2d(256, 512, kernel_size=1, padding=0),
                nn.BatchNorm2d(512),
                nn.ReLU(inplace=True),
            )
    elif layer == "pool5":
        doubleconv1 = nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=1, padding=0),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                nn.Conv2d(128, 256, kernel_size=1, padding=0),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            )
    else:
        raise(RuntimeError)
    model.add_module("trainable", doubleconv1)
    return model

class Vqvae_Plus(nn.Module):
    def __init__(self, type=1):
        super().__init__()
        # self.extractor = ResNetExt(pretrained=pretrained) #models.resnet50(pretrained=pretrained)
        self.net = VQEncoder()
        for param in self.net.parameters():
            if param.requires_grad == True:
                raise RuntimeError
        self.extractor = nn.Sequential()
        if type == 1:
            #* 1x1 conv
            doubleconv1 = nn.Sequential(
                    nn.Conv2d(256, 512, kernel_size=1, padding=0),
                    nn.BatchNorm2d(512),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(512, 512, kernel_size=1, padding=0),
                    nn.BatchNorm2d(512),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(512, 256, kernel_size=1, padding=0),
                    nn.BatchNorm2d(256),
                    nn.ReLU(inplace=True),
                )
        elif type == 2:
            #* Simple Cifar Autoencoder
            encoder = nn.Sequential(
            nn.Conv2d(256, 512, 4, stride=2, padding=1),            # [batch, 12, 16, 16]
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 4, stride=2, padding=1),           # [batch, 24, 8, 8]
            nn.ReLU(inplace=True),
			nn.Conv2d(512,1024, 4, stride=2, padding=1),           # [batch, 48, 4, 4]
            nn.ReLU(inplace=True),
# 			nn.Conv2d(48, 96, 4, stride=2, padding=1),           # [batch, 96, 2, 2]
#             nn.ReLU(),
            )
            decoder = nn.Sequential(
    #             nn.ConvTranspose2d(96, 48, 4, stride=2, padding=1),  # [batch, 48, 4, 4]
    #             nn.ReLU(),
                nn.ConvTranspose2d(1024, 512, 4, stride=2, padding=1),  # [batch, 24, 8, 8]
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(512, 512, 4, stride=2, padding=1, output_padding=(0,1)),  # [batch, 12, 16, 16]
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1,0)),   # [batch, 3, 32, 32]
                # nn.Sigmoid(),
                nn.ReLU(inplace=True),
            )
            doubleconv1 = nn.Sequential()
            doubleconv1.add_module("simple_encoder", encoder)
            doubleconv1.add_module("simple_decoder", decoder)
        else:
            raise(RuntimeError)

        self.extractor.add_module("trainable", doubleconv1)

    def forward(self, x):
        x1 = self.net(x) # 3x320x672 -> 256x40x84
        x2 = self.extractor(x1)
        # print(x1.shape, x2.shape)
        # for param in self.extractor.parameters():
        #     if param.requires_grad == True:
        #         raise(RuntimeError)
        #         exit()
        #     else:
        #         print(param.requires_grad)
        # exit()
        return x2

def resnet50(pretrain):
    """Return torchvision ResNet-50 without its pooling/classifier head.

    The result is an ``nn.Sequential`` of conv1, bn1, relu, maxpool and
    layer1..layer4, with children named "0".."7".

    Args:
        pretrain: forwarded to ``torchvision.models.resnet50`` as ``pretrained``.
    """
    backbone = models.resnet50(pretrained=pretrain)
    stage_names = ("conv1", "bn1", "relu", "maxpool", "layer1", "layer2", "layer3", "layer4")
    return nn.Sequential(*[getattr(backbone, name) for name in stage_names])


def keypoints_to_pixel_index(keypoints, downsample_rate, original_img_size=(480, 640)):
    """Map image-space keypoints to flat row-major indices into the downsampled feature map.

    Each coordinate is floor-divided by ``downsample_rate`` to get its cell, and the
    cell is flattened as ``row * line_size + col`` with
    ``line_size = original_img_size[1] // downsample_rate``.
    Example: image size (224, 300), rate 32 -> line_size 9; keypoint (36, 40)
    lands in cell (1, 1) -> index 1 * 9 + 1 = 10.

    Args:
        keypoints: tensor indexed as ``keypoints[:, :, 0]`` (row / y) and
            ``keypoints[:, :, 1]`` (col / x); presumably shape (n, k, 2) -- confirm.
        downsample_rate: backbone stride (see ``net_stride``).
        original_img_size: (height, width) of the input image; only the width is used.

    Returns:
        Tensor of flat cell indices with the shape of ``keypoints[:, :, 0]``.

    NOTE(review): ``line_size`` uses floor division. A backbone whose feature map
    rounds up on widths that are not multiples of ``downsample_rate`` would have a
    different row length -- confirm inputs are sized as multiples of the stride.
    """
    cells_per_row = original_img_size[1] // downsample_rate
    row = keypoints[:, :, 0] // downsample_rate
    col = keypoints[:, :, 1] // downsample_rate
    return row * cells_per_row + col


def get_noise_pixel_index(keypoints, max_size, n_samples, obj_mask=None):
    """Sample non-keypoint feature-map cells to use as noise/background points.

    Every row starts with uniform weights over ``max_size`` cells; the cells listed
    in that row of ``keypoints`` get weight 0, and ``obj_mask`` (if given) is
    multiplied in. ``n_samples`` cells are then drawn per row without replacement.

    Args:
        keypoints: tensor of flat cell indices, one row per image (e.g. from
            ``keypoints_to_pixel_index``); cast to long for ``scatter``.
        max_size: number of candidate cells per row (H * W of the feature map).
        n_samples: number of cells to draw per row.
        obj_mask: optional weights broadcastable to ``(n, max_size)``; cells with
            weight 0 are never drawn.

    Returns:
        LongTensor ``(n, n_samples)`` of sampled indices, or ``None`` when
        ``torch.multinomial`` rejects the request (e.g. ``n_samples > max_size``,
        or a row whose weights are all zero).
    """
    n = keypoints.shape[0]

    mask = torch.ones((n, max_size), dtype=torch.float32, device=keypoints.device)
    mask = mask.scatter(1, keypoints.type(torch.long), 0.0)
    if obj_mask is not None:
        mask *= obj_mask

    try:
        return torch.multinomial(mask, n_samples)
    except (RuntimeError, ValueError):
        # Only impossible sampling requests are turned into None (NetE2E.forward
        # propagates it); programming errors such as a wrong argument type or
        # KeyboardInterrupt are no longer swallowed.
        return None


class GlobalLocalConverter(nn.Module):
    """Gather the ``local_size`` neighbourhood around every feature-map cell.

    Turns a (N, C, H, W) map into (N, C * kh * kw, H * W), where
    ``(kh, kw) = local_size``: the input is zero-padded so that every cell keeps
    its own window, then ``F.unfold`` extracts the windows. With ``local_size``
    (1, 1) this is just a flatten of the spatial dims. NetE2E uses it to build
    per-cell feature vectors.
    """

    def __init__(self, local_size):
        super().__init__()
        self.local_size = local_size
        # F.pad wants (left, right, top, bottom), so walk the sizes last-dim first.
        # Each dim gets k - 1 padding in total; even sizes put the extra cell after.
        pad = []
        for k in local_size[::-1]:
            pad += [k - 1 - k // 2, k // 2]
        self.padding = pad

    def forward(self, X):
        # Unpacking also rejects inputs that are not 4-D.
        _n, _c, _h, _w = X.shape
        padded = F.pad(X, self.padding)
        # (N, C, H + kh - 1, W + kw - 1) -> (N, C * kh * kw, H * W)
        return F.unfold(padded, kernel_size=self.local_size)


class MergeReduce(nn.Module):
    """Collapse the local-neighbourhood axis produced by ``GlobalLocalConverter``.

    ``forward`` reduces an unfolded (N, C * k, L) tensor to (N, C, L) by a mean or
    max over the ``k = kh * kw`` window entries of each channel. ``forward_test``
    applies the corresponding stride-1 pooling directly to a (N, C, H, W) map.
    ``register_local_size`` must be called before either.

    Args:
        reduce_method: ``"mean"`` or ``"max"``. For any other value ``forward``
            returns ``None`` and ``forward_test`` is unavailable.
    """

    def __init__(self, reduce_method="mean"):
        super().__init__()
        self.reduce_method = reduce_method
        # Window element count; -1 until register_local_size() is called.
        self.local_size = -1

    def register_local_size(self, local_size):
        """Record the (kh, kw) window and build the stride-1 pooling used by ``forward_test``."""
        self.local_size = local_size[0] * local_size[1]
        # Pad each spatial dim by half of its own kernel size so non-square
        # (odd-sized) windows keep H and W unchanged; previously both dims used
        # local_size[0] // 2.
        padding = (local_size[0] // 2, local_size[1] // 2)
        if self.reduce_method == "mean":
            self.foo_test = torch.nn.AvgPool2d(
                local_size,
                stride=1,
                padding=padding,
            )
        elif self.reduce_method == "max":
            self.foo_test = torch.nn.MaxPool2d(
                local_size,
                stride=1,
                padding=padding,
            )

    def forward(self, X):
        # (N, C * k, L) -> (N, C, k, L): F.unfold stores each channel's k window
        # entries contiguously, channel-major.
        X = X.view(X.shape[0], -1, self.local_size, X.shape[2])
        if self.reduce_method == "mean":
            return torch.mean(X, dim=2)
        elif self.reduce_method == "max":
            # torch.max(..., dim=) returns (values, indices); callers need a tensor.
            return torch.max(X, dim=2).values

    def forward_test(self, X):
        return self.foo_test(X)


def batched_index_select(t, dim, inds):
    """Per-batch row lookup: ``out[b, j, :] = t[b, inds[b, j], :]`` when ``dim == 1``.

    Args:
        t: tensor of shape (B, L, F).
        dim: dimension passed to ``Tensor.gather`` (callers here use 1).
        inds: (B, K) integer index tensor.

    Returns:
        Tensor of shape (B, K, F).
    """
    feature_dim = t.size(2)
    gather_index = inds.unsqueeze(2).expand(-1, -1, feature_dim)
    return t.gather(dim, gather_index)


class NetE2E(nn.Module):
    """Backbone plus a per-cell linear projection giving L2-normalised keypoint features.

    Args:
        pretrain: forwarded to the ResNet-based backbone builders.
        net_type: backbone key; must also be a key of ``net_stride`` and
            ``net_out_dimension``.
        local_size: (kh, kw) neighbourhood gathered around each feature-map cell.
        output_dimension: width of ``out_layer``; ``-1`` disables the projection
            (ablation study) and features are only L2-normalised.
        reduce_function: optional ``MergeReduce`` that collapses the neighbourhood
            axis; when given, ``size_number`` becomes 1.
        n_noise_points: number of non-keypoint cells sampled per image in ``forward``.
        num_stacks, num_blocks, **kwargs: accepted for interface compatibility;
            not used by this class.
        noise_on_mask: when True, noise cells are restricted to ``obj_mask``.

    Raises:
        RuntimeError: if ``net_type`` has no backbone builder.
    """

    def __init__(
        self,
        pretrain,
        net_type,
        local_size,
        output_dimension,
        reduce_function=None,
        n_noise_points=0,
        num_stacks=8,
        num_blocks=1,
        noise_on_mask=True,
        **kwargs
    ):
        super().__init__()
        # Lazy constructors so only the selected backbone is built.
        builders = {
            "vgg_pool4": lambda: vgg16("pool4"),
            "vgg_pool5": lambda: vgg16("pool5"),
            "resnet50": lambda: resnet50(pretrain),
            "resnetext": lambda: resnetext(pretrain),
            "resnetext2": lambda: resnetext2(pretrain),
            "resnetext3": lambda: resnetext3(pretrain),
            "resnetupsample": lambda: resnetupsample(pretrain),
            "unet_res50": lambda: unet_res50(pretrain),
            # The two frozen_resnetext variants were marked "not done" upstream.
            "frozen_resnetext": lambda: frozen_resnetext(pretrain),
            "frozen_resnetext2": lambda: frozen_resnetext2(pretrain),
            "frozen_vgg_pool4": lambda: mod_vgg16("pool4"),
            "vqvae1": lambda: Vqvae_Plus(type=1),
            "vqvae2": lambda: Vqvae_Plus(type=2),
        }
        if net_type not in builders:
            raise RuntimeError(
                f"NetE2E: unsupported net_type {net_type!r}; "
                f"expected one of {sorted(builders)}"
            )
        self.net = builders[net_type]()

        # Number of neighbourhood entries per channel in each cell's feature vector.
        self.size_number = local_size[0] * local_size[1]
        self.output_dimension = output_dimension
        if reduce_function:
            reduce_function.register_local_size(local_size)
            self.size_number = 1

        self.reduce_function = reduce_function
        self.net_type = net_type
        self.net_stride = net_stride[net_type]
        self.converter = GlobalLocalConverter(local_size)
        self.noise_on_mask = noise_on_mask

        # output_dimension == -1 is used for the ablation study: no projection.
        if self.output_dimension == -1:
            self.out_layer = None
        else:
            self.out_layer = nn.Linear(
                net_out_dimension[net_type] * self.size_number, self.output_dimension
            )

        self.n_noise_points = n_noise_points

    def forward_test(self, X):
        """Dense inference: project every feature-map cell and L2-normalise over channels.

        The linear ``out_layer`` weight is applied as a bias-free 1x1 convolution.
        With ``size_number > 1`` the weight is split per neighbourhood entry, giving
        ``size_number * output_dimension`` output channels.

        NOTE(review): ``forward`` applies ``out_layer`` including its bias, but this
        path (and the ``return_map`` output of ``forward``) ignores the bias --
        confirm the two are meant to differ.
        """
        X = self.net.forward(X)

        if self.reduce_function:
            X = self.reduce_function.forward_test(X)

        if self.output_dimension == -1:
            return F.normalize(X, p=2, dim=1)
        if self.size_number == 1:
            X = F.conv2d(X, self.out_layer.weight.unsqueeze(2).unsqueeze(3))
        elif self.size_number > 1:
            in_channels = net_out_dimension[self.net_type]
            weight = (
                self.out_layer.weight.view(
                    self.output_dimension, in_channels, self.size_number
                )
                .permute(2, 0, 1)
                .reshape(self.size_number * self.output_dimension, in_channels)
                .unsqueeze(2)
                .unsqueeze(3)
            )
            X = F.conv2d(X, weight)
        return F.normalize(X, p=2, dim=1)

    def forward(self, X, keypoint_positions, obj_mask=None, return_map=False):
        """Features for each keypoint (plus sampled noise cells).

        Args:
            X: image batch; ``X.shape[2:]`` is used as the image size.
            keypoint_positions: image-space keypoint coordinates passed to
                ``keypoints_to_pixel_index`` -- presumably (n, k, 2) as (row, col).
            obj_mask: optional per-pixel object mask at image resolution; it is
                max-pooled down to the feature-map grid.
            return_map: also return the dense, L2-normalised projected feature map.

        Returns:
            (n, k + n_noise_points, D) L2-normalised features, where D is
            ``output_dimension`` (or the backbone channel count when
            ``output_dimension == -1``); with ``return_map`` a tuple
            ``(features, map)``. Returns ``None`` when noise sampling fails.
        """
        n = X.shape[0]
        img_shape = X.shape[2::]

        m = self.net.forward(X)

        # (N, C * kh * kw, H * W)
        X = self.converter(m)

        # Map image-space keypoints to flat feature-map cell indices.
        keypoint_idx = keypoints_to_pixel_index(
            keypoints=keypoint_positions,
            downsample_rate=self.net_stride,
            original_img_size=img_shape,
        ).type(torch.long)

        if self.reduce_function:
            X = self.reduce_function(X)

        if self.n_noise_points == 0:
            keypoint_all = keypoint_idx
        else:
            if obj_mask is not None:
                # Bring the mask down to the feature-map grid and flatten it.
                obj_mask = F.max_pool2d(
                    obj_mask.unsqueeze(dim=1),
                    kernel_size=self.net_stride,
                    stride=self.net_stride,
                    padding=(self.net_stride - 1) // 2,
                )
                obj_mask = obj_mask.view(obj_mask.shape[0], -1)
                assert obj_mask.shape[1] == X.shape[2], (
                    "mask_: " + str(obj_mask.shape) + " fearture_: " + str(X.shape)
                )
            keypoint_noise = get_noise_pixel_index(
                keypoint_idx,
                max_size=X.shape[2],
                n_samples=self.n_noise_points,
                obj_mask=obj_mask if self.noise_on_mask else None,
            )

            if keypoint_noise is None:
                return None

            keypoint_all = torch.cat((keypoint_idx, keypoint_noise), dim=1)

        # (N, C * k, H * W) -> (N, H * W, C * k)
        X = torch.transpose(X, 1, 2)

        # (N, H * W, C * k) -> (N, len(keypoint_all), C * k)
        X = batched_index_select(X, dim=1, inds=keypoint_all)

        if self.out_layer is None:
            X = F.normalize(X, p=2, dim=2)
            X = X.view(n, -1, net_out_dimension[self.net_type])
        else:
            X = F.normalize(self.out_layer(X), p=2, dim=2)
            X = X.view(n, -1, self.out_layer.weight.shape[0])

        if return_map:
            if self.out_layer is None:
                # No projection (ablation): mirror forward_test and normalise the raw map.
                feature_map = m
            else:
                feature_map = F.conv2d(m, self.out_layer.weight.unsqueeze(2).unsqueeze(3))
            return X, F.normalize(feature_map, p=2, dim=1)
        return X

    def cuda(self, device=None):
        """Move the backbone and (if present) the projection layer to ``device``; returns self."""
        self.net.cuda(device=device)
        # out_layer is None when output_dimension == -1.
        if self.out_layer is not None:
            self.out_layer.cuda(device=device)
        return self
    
class NetE2E_mod(nn.Module):
    """Feature-dumping variant of ``NetE2E``: ``forward`` returns the raw backbone map.

    Accepts the same constructor arguments as ``NetE2E`` and builds the same
    attributes (converter, reduce function, ``out_layer``), but ``forward`` and
    ``forward_feature`` only run the backbone.

    Raises:
        RuntimeError: if ``net_type`` has no backbone builder.
    """

    def __init__(
        self,
        pretrain,
        net_type,
        local_size,
        output_dimension,
        reduce_function=None,
        n_noise_points=0,
        num_stacks=8,
        num_blocks=1,
        noise_on_mask=True,
        **kwargs
    ):
        super().__init__()
        # Lazy constructors so only the selected backbone is built; the set now
        # matches NetE2E (the vqvae backbones were previously rejected here).
        builders = {
            "vgg_pool4": lambda: vgg16("pool4"),
            "vgg_pool5": lambda: vgg16("pool5"),
            "resnet50": lambda: resnet50(pretrain),
            "resnetext": lambda: resnetext(pretrain),
            "resnetext2": lambda: resnetext2(pretrain),
            "resnetext3": lambda: resnetext3(pretrain),
            "resnetupsample": lambda: resnetupsample(pretrain),
            "unet_res50": lambda: unet_res50(pretrain),
            # The two frozen_resnetext variants were marked "not done" upstream.
            "frozen_resnetext": lambda: frozen_resnetext(pretrain),
            "frozen_resnetext2": lambda: frozen_resnetext2(pretrain),
            "frozen_vgg_pool4": lambda: mod_vgg16("pool4"),
            "vqvae1": lambda: Vqvae_Plus(type=1),
            "vqvae2": lambda: Vqvae_Plus(type=2),
        }
        if net_type not in builders:
            raise RuntimeError(
                f"NetE2E_mod: unsupported net_type {net_type!r}; "
                f"expected one of {sorted(builders)}"
            )
        self.net = builders[net_type]()

        self.size_number = local_size[0] * local_size[1]
        self.output_dimension = output_dimension
        if reduce_function:
            reduce_function.register_local_size(local_size)
            self.size_number = 1

        self.reduce_function = reduce_function
        self.net_type = net_type
        self.net_stride = net_stride[net_type]
        self.converter = GlobalLocalConverter(local_size)
        self.noise_on_mask = noise_on_mask

        # output_dimension == -1 is used for the ablation study: no projection.
        if self.output_dimension == -1:
            self.out_layer = None
        else:
            self.out_layer = nn.Linear(
                net_out_dimension[net_type] * self.size_number, self.output_dimension
            )

        self.n_noise_points = n_noise_points

    def forward(self, X, keypoint_positions=None, obj_mask=None, return_map=False):
        """Return the backbone feature map; the extra arguments are accepted and ignored."""
        return self.net.forward(X)

    def cuda(self, device=None):
        """Move the backbone and (if present) the projection layer to ``device``; returns self."""
        self.net.cuda(device=device)
        # out_layer is None when output_dimension == -1.
        if self.out_layer is not None:
            self.out_layer.cuda(device=device)
        return self

    def forward_feature(self, X):
        """Return the backbone feature map for ``X``."""
        return self.net.forward(X)