# Based on the ResNet implementation in torchvision
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
import math
import torch
from torch import nn
from torchvision.models.resnet import conv3x3
import torch.nn.functional as F
from torchvision import models
from numbers import Number
from torch.autograd import Variable
import torch.nn.init as init
from utils.utils import *
from .models import register_model
from .utils import *

class GlobalPooling(nn.Module):
    """Average each channel over its spatial extent and flatten the result.

    The input is reduced to a 1x1 spatial map by adaptive average pooling and
    then reshaped to ``(x.shape[0], -1)``.
    """

    def __init__(self):
        super().__init__()
        self.ada_avg_pool = nn.AdaptiveAvgPool2d([1, 1])

    def forward(self, x):
        """Return the pooled input with one row per leading-dimension entry."""
        batch_size = x.shape[0]
        pooled = self.ada_avg_pool(x)
        return pooled.view(batch_size, -1)
    
class BasicBlock(nn.Module):
    """Residual block that applies BatchNorm -> ReLU -> 3x3 conv twice.

    The result of the two stages is added to the block input, which is first
    passed through ``downsample`` when one is supplied.

    Args:
        inplanes: Channel count expected by the first BatchNorm/conv.
        planes: Channel count produced by both convolutions.
        stride: Stride of the first 3x3 convolution.
        downsample: Optional callable applied to the input before the
            addition; ``None`` leaves the input as is.
    """

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Assigned first so that a Module passed here keeps its original
        # position among the registered submodules.
        self.downsample = downsample
        self.stride = stride

        # First stage: normalise, activate, then the (possibly strided) conv.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)

        # Second stage: same pattern at the output width, stride 1.
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)

    def forward(self, x):
        """Return ``shortcut(x) + conv2(relu(bn2(conv1(relu(bn1(x))))))``."""
        out = self.conv1(self.relu1(self.bn1(x)))
        out = self.conv2(self.relu2(self.bn2(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return shortcut + out

class Downsample(nn.Module):
    """Parameter-free shortcut: average-pool, then pad channels with zeros.

    Args:
        nIn: Number of input channels.
        nOut: Number of output channels; must be a multiple of ``nIn``.
        stride: Kernel size handed to ``nn.AvgPool2d``.
    """

    def __init__(self, nIn, nOut, stride):
        super().__init__()
        self.avg = nn.AvgPool2d(stride)
        assert nOut % nIn == 0
        # How many channel-sized copies the output consists of in total.
        self.expand_ratio = nOut // nIn

    def forward(self, x):
        """Pool ``x`` and append ``expand_ratio - 1`` zeroed copies on dim 1."""
        pooled = self.avg(x)
        # Multiplying by zero (rather than allocating zeros) matches the
        # pooled tensor's shape, dtype and device automatically.
        zero_pad = pooled.mul(0)
        pieces = [pooled]
        pieces.extend(zero_pad for _ in range(self.expand_ratio - 1))
        return torch.cat(pieces, 1)
    
# Maps the architecture labels used by the wrapper classes in this file to the
# corresponding torchvision ResNet constructors.
resnet_dict = dict(
    ResNet34=models.resnet34,
    ResNet50=models.resnet50,
    ResNet101=models.resnet101,
    ResNet152=models.resnet152,
)

@register_model('resnet34')
class ResNet_34Fc(nn.Module):
    """Pretrained torchvision ResNet-34 backbone with a bottleneck and classifier.

    The backbone stages are taken from ``models.resnet34(pretrained=True)``.
    The original ``fc`` layer is replaced by ``bottleneck`` (Linear to
    ``bottleneck_dim``) followed by ``fc`` (Linear to ``class_num``).

    Args:
        bottleneck_dim: Output width of the bottleneck layer.
        class_num: Number of classifier outputs.
        frozen: Names of backbone modules (entries of
            ``ordered_module_names``) to freeze. Their parameters get
            ``requires_grad = False`` and their outputs are detached in
            ``forward``. ``None`` (the default) freezes nothing.
    """

    def __init__(self,
                 bottleneck_dim=256,
                 class_num=1000,
                 frozen=None
                ):
        super(ResNet_34Fc, self).__init__()
        model_resnet = resnet_dict['ResNet34'](pretrained=True)
        # ``None`` replaces the former mutable ``[]`` default, which was one
        # list object shared by every instance built without ``frozen``.
        self.frozen = [] if frozen is None else frozen
        print('Frozen Layer: ', self.frozen)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.num_cls = class_num

        self.bottleneck = nn.Linear(
            model_resnet.fc.in_features,
            bottleneck_dim
        )
        self.fc = nn.Linear(bottleneck_dim, class_num)
        self.bottleneck.apply(init_weights)

        self.fc.apply(init_weights)
        self.__in_features = bottleneck_dim

        self.flattening = Flattening()
        self.feature_layers = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4, self.bottleneck]

        # Execution order of the backbone in ``forward``. This list must be
        # filled *before* the freeze loop below: it used to be populated
        # afterwards, so the membership test never matched and no parameter
        # was ever frozen.
        self.ordered_module_names = [
            'conv1', 'bn1', 'relu', 'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4', 'avgpool'
        ]

        for name in self.frozen:
            if name in self.ordered_module_names:
                for param in self._modules[name].parameters():
                    param.requires_grad = False

    def forward(self, x, temp=1, dropout=True, cosine=False, reverse=False):
        """Run the backbone and head, returning features and logits.

        Args:
            x: Input batch fed to ``conv1`` (presumably ``(N, 3, H, W)``
                images -- confirm against the data pipeline).
            temp: Divisor applied to the classifier output.
            dropout: If true, apply ``p=0.5`` dropout before the bottleneck
                and before the classifier (only active in training mode).
            cosine: If true, L2-normalise the classifier input along dim 1.
            reverse: If true, pass the backbone features through
                ``grad_reverse(..., 1.0)`` before the head.

        Returns:
            dict with ``'features'`` (flattened backbone output),
            ``'adapted_layer'`` (ReLU of the bottleneck output, before the
            second dropout) and ``'output_logits'``.
        """
        for name in self.ordered_module_names:
            x = self._modules[name](x)
            # Detaching right after a frozen module stops gradients from
            # reaching it and every module that ran before it.
            if name in self.frozen:
                x = x.detach()

        embedding_coding = self.flattening(x)
        head_input = grad_reverse(embedding_coding, 1.0) if reverse else embedding_coding
        if dropout:
            head_input = F.dropout(head_input, training=self.training, p=0.5)
        encodes = F.relu(self.bottleneck(head_input), inplace=False)

        classifier_input = encodes
        if dropout:
            classifier_input = F.dropout(encodes, training=self.training, p=0.5)
        if cosine:
            classifier_input = F.normalize(classifier_input, p=2, dim=1)
        logits = self.fc(classifier_input) / temp
        return {
            'features': embedding_coding,
            'adapted_layer': encodes,
            'output_logits': logits
        }

    @classmethod
    def create(cls, dicts):
        """Build the model from a config mapping.

        Reads ``'num_cls'``, ``'adapted_dim'`` and ``'frozen'``; a missing or
        falsy entry falls back to ``10``, ``256`` and ``[]`` respectively.
        """
        return cls(
            class_num=dicts.get('num_cls') or 10,
            bottleneck_dim=dicts.get('adapted_dim') or 256,
            frozen=dicts.get('frozen') or []
        )

    def output_num(self):
        """Return the bottleneck output width."""
        return self.__in_features

    def get_classifer_in_features(self):
        """Return the input width of the bottleneck layer."""
        return self.bottleneck.in_features

    def get_parameters(self):
        """Return optimizer parameter groups as ``(feature_groups, fc_groups)``.

        ``feature_groups`` holds the bottleneck (``lr_mult`` 10) followed by
        every non-frozen backbone module that owns parameters (``lr_mult``
        1). ``fc_groups`` holds the classifier (``lr_mult`` 10).
        """
        parameter_list = [
            {"params": self.bottleneck.parameters(), 'lr_mult': 10},
        ]
        for name in self.ordered_module_names:
            if name in self.frozen:
                continue
            module = self._modules[name]
            # Skip parameter-free modules such as ReLU and the pooling layers.
            if next(module.parameters(), None) is not None:
                parameter_list.append({"params": module.parameters(), 'lr_mult': 1})
        return parameter_list, [{"params": self.fc.parameters(), 'lr_mult': 10}]
    
@register_model('resnet50')
class ResNet_50Fc(nn.Module):
    """Pretrained torchvision ResNet-50 backbone with a bottleneck and classifier.

    The backbone stages are taken from ``models.resnet50(pretrained=True)``.
    The original ``fc`` layer is replaced by ``bottleneck`` (Linear to
    ``bottleneck_dim``) followed by ``fc`` (Linear to ``class_num``).

    Args:
        bottleneck_dim: Output width of the bottleneck layer.
        class_num: Number of classifier outputs.
        frozen: Names of backbone modules (entries of
            ``ordered_module_names``) to freeze. Their parameters get
            ``requires_grad = False`` and their outputs are detached in
            ``forward``. ``None`` (the default) freezes nothing.
    """

    def __init__(self,
                 bottleneck_dim=256,
                 class_num=1000,
                 frozen=None
                ):
        super(ResNet_50Fc, self).__init__()
        model_resnet = resnet_dict['ResNet50'](pretrained=True)
        # ``None`` replaces the former mutable ``[]`` default, which was one
        # list object shared by every instance built without ``frozen``.
        self.frozen = [] if frozen is None else frozen
        print('Frozen Layer: ', self.frozen)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.num_cls = class_num

        self.bottleneck = nn.Linear(
            model_resnet.fc.in_features,
            bottleneck_dim
        )
        self.fc = nn.Linear(bottleneck_dim, class_num)
        self.bottleneck.apply(init_weights)

        self.fc.apply(init_weights)
        self.__in_features = bottleneck_dim

        self.flattening = Flattening()
        self.feature_layers = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4, self.bottleneck]

        # Execution order of the backbone in ``forward``. This list must be
        # filled *before* the freeze loop below: it used to be populated
        # afterwards, so the membership test never matched and no parameter
        # was ever frozen.
        self.ordered_module_names = [
            'conv1', 'bn1', 'relu', 'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4', 'avgpool'
        ]

        for name in self.frozen:
            if name in self.ordered_module_names:
                for param in self._modules[name].parameters():
                    param.requires_grad = False

    def forward(self, x, temp=1, dropout=True, cosine=False, reverse=False):
        """Run the backbone and head, returning features and logits.

        Args:
            x: Input batch fed to ``conv1`` (presumably ``(N, 3, H, W)``
                images -- confirm against the data pipeline).
            temp: Divisor applied to the classifier output.
            dropout: If true, apply ``p=0.5`` dropout before the bottleneck
                and before the classifier (only active in training mode).
            cosine: If true, L2-normalise the classifier input along dim 1.
            reverse: If true, pass the backbone features through
                ``grad_reverse(..., 1.0)`` before the head.

        Returns:
            dict with ``'features'`` (flattened backbone output),
            ``'adapted_layer'`` (ReLU of the bottleneck output, before the
            second dropout) and ``'output_logits'``.
        """
        for name in self.ordered_module_names:
            x = self._modules[name](x)
            # Detaching right after a frozen module stops gradients from
            # reaching it and every module that ran before it.
            if name in self.frozen:
                x = x.detach()

        embedding_coding = self.flattening(x)
        head_input = grad_reverse(embedding_coding, 1.0) if reverse else embedding_coding
        if dropout:
            head_input = F.dropout(head_input, training=self.training, p=0.5)
        encodes = F.relu(self.bottleneck(head_input), inplace=False)

        classifier_input = encodes
        if dropout:
            classifier_input = F.dropout(encodes, training=self.training, p=0.5)
        if cosine:
            classifier_input = F.normalize(classifier_input, p=2, dim=1)
        logits = self.fc(classifier_input) / temp
        return {
            'features': embedding_coding,
            'adapted_layer': encodes,
            'output_logits': logits
        }

    @classmethod
    def create(cls, dicts):
        """Build the model from a config mapping.

        Reads ``'num_cls'``, ``'adapted_dim'`` and ``'frozen'``; a missing or
        falsy entry falls back to ``10``, ``256`` and ``[]`` respectively.
        """
        return cls(
            class_num=dicts.get('num_cls') or 10,
            bottleneck_dim=dicts.get('adapted_dim') or 256,
            frozen=dicts.get('frozen') or []
        )

    def output_num(self):
        """Return the bottleneck output width."""
        return self.__in_features

    def get_classifer_in_features(self):
        """Return the input width of the bottleneck layer."""
        return self.bottleneck.in_features

    def get_parameters(self):
        """Return optimizer parameter groups as ``(feature_groups, fc_groups)``.

        ``feature_groups`` holds the bottleneck (``lr_mult`` 10) followed by
        every non-frozen backbone module that owns parameters (``lr_mult``
        1). ``fc_groups`` holds the classifier (``lr_mult`` 10).
        """
        parameter_list = [
            {"params": self.bottleneck.parameters(), 'lr_mult': 10},
        ]
        for name in self.ordered_module_names:
            if name in self.frozen:
                continue
            module = self._modules[name]
            # Skip parameter-free modules such as ReLU and the pooling layers.
            if next(module.parameters(), None) is not None:
                parameter_list.append({"params": module.parameters(), 'lr_mult': 1})
        return parameter_list, [{"params": self.fc.parameters(), 'lr_mult': 10}]
    
    
@register_model('resnet101')
class ResNet_101Fc(nn.Module):
    """Pretrained torchvision ResNet-101 backbone with a bottleneck and classifier.

    The backbone stages are taken from ``models.resnet101(pretrained=True)``.
    The original ``fc`` layer is replaced by ``bottleneck`` (Linear to
    ``bottleneck_dim``) followed by ``fc`` (Linear to ``class_num``).

    Args:
        bottleneck_dim: Output width of the bottleneck layer.
        class_num: Number of classifier outputs.
        frozen: Names of backbone modules (entries of
            ``ordered_module_names``) to freeze. Their parameters get
            ``requires_grad = False`` and their outputs are detached in
            ``forward``. ``None`` (the default) freezes nothing.
    """

    def __init__(self,
                 bottleneck_dim=256,
                 class_num=1000,
                 frozen=None
                ):
        super(ResNet_101Fc, self).__init__()
        model_resnet = resnet_dict['ResNet101'](pretrained=True)
        # ``None`` replaces the former mutable ``[]`` default, which was one
        # list object shared by every instance built without ``frozen``.
        self.frozen = [] if frozen is None else frozen
        print('Frozen Layer: ', self.frozen)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.num_cls = class_num

        self.bottleneck = nn.Linear(
            model_resnet.fc.in_features,
            bottleneck_dim
        )
        self.fc = nn.Linear(bottleneck_dim, class_num)
        self.bottleneck.apply(init_weights)

        self.fc.apply(init_weights)
        self.__in_features = bottleneck_dim

        self.flattening = Flattening()
        self.feature_layers = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4, self.bottleneck]

        # Execution order of the backbone in ``forward``. This list must be
        # filled *before* the freeze loop below: it used to be populated
        # afterwards, so the membership test never matched and no parameter
        # was ever frozen.
        self.ordered_module_names = [
            'conv1', 'bn1', 'relu', 'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4', 'avgpool'
        ]

        for name in self.frozen:
            if name in self.ordered_module_names:
                for param in self._modules[name].parameters():
                    param.requires_grad = False

    def forward(self, x, temp=1, dropout=True, cosine=False, reverse=False):
        """Run the backbone and head, returning features and logits.

        Args:
            x: Input batch fed to ``conv1`` (presumably ``(N, 3, H, W)``
                images -- confirm against the data pipeline).
            temp: Divisor applied to the classifier output.
            dropout: If true, apply ``p=0.5`` dropout before the bottleneck
                and before the classifier (only active in training mode).
            cosine: If true, L2-normalise the classifier input along dim 1.
            reverse: If true, pass the backbone features through
                ``grad_reverse(..., 1.0)`` before the head.

        Returns:
            dict with ``'features'`` (flattened backbone output),
            ``'adapted_layer'`` (ReLU of the bottleneck output, before the
            second dropout) and ``'output_logits'``.
        """
        for name in self.ordered_module_names:
            x = self._modules[name](x)
            # Detaching right after a frozen module stops gradients from
            # reaching it and every module that ran before it.
            if name in self.frozen:
                x = x.detach()

        embedding_coding = self.flattening(x)
        head_input = grad_reverse(embedding_coding, 1.0) if reverse else embedding_coding
        if dropout:
            head_input = F.dropout(head_input, training=self.training, p=0.5)
        encodes = F.relu(self.bottleneck(head_input), inplace=False)

        classifier_input = encodes
        if dropout:
            classifier_input = F.dropout(encodes, training=self.training, p=0.5)
        if cosine:
            classifier_input = F.normalize(classifier_input, p=2, dim=1)
        logits = self.fc(classifier_input) / temp
        return {
            'features': embedding_coding,
            'adapted_layer': encodes,
            'output_logits': logits
        }

    @classmethod
    def create(cls, dicts):
        """Build the model from a config mapping.

        Reads ``'num_cls'``, ``'adapted_dim'`` and ``'frozen'``; a missing or
        falsy entry falls back to ``10``, ``256`` and ``[]`` respectively.
        """
        return cls(
            class_num=dicts.get('num_cls') or 10,
            bottleneck_dim=dicts.get('adapted_dim') or 256,
            frozen=dicts.get('frozen') or []
        )

    def output_num(self):
        """Return the bottleneck output width."""
        return self.__in_features

    def get_classifer_in_features(self):
        """Return the input width of the bottleneck layer."""
        return self.bottleneck.in_features

    def get_parameters(self):
        """Return optimizer parameter groups as ``(feature_groups, fc_groups)``.

        ``feature_groups`` holds the bottleneck (``lr_mult`` 10) followed by
        every non-frozen backbone module that owns parameters (``lr_mult``
        1). ``fc_groups`` holds the classifier (``lr_mult`` 10).
        """
        parameter_list = [
            {"params": self.bottleneck.parameters(), 'lr_mult': 10},
        ]
        for name in self.ordered_module_names:
            if name in self.frozen:
                continue
            module = self._modules[name]
            # Skip parameter-free modules such as ReLU and the pooling layers.
            if next(module.parameters(), None) is not None:
                parameter_list.append({"params": module.parameters(), 'lr_mult': 1})
        return parameter_list, [{"params": self.fc.parameters(), 'lr_mult': 10}]
    
    
@register_model('resnet152')
class ResNet_152Fc(nn.Module):
    """Pretrained torchvision ResNet-152 backbone with a bottleneck and classifier.

    The backbone stages are taken from ``models.resnet152(pretrained=True)``.
    The original ``fc`` layer is replaced by ``bottleneck`` (Linear to
    ``bottleneck_dim``) followed by ``fc`` (Linear to ``class_num``).

    Args:
        bottleneck_dim: Output width of the bottleneck layer.
        class_num: Number of classifier outputs.
        frozen: Names of backbone modules (entries of
            ``ordered_module_names``) to freeze. Their parameters get
            ``requires_grad = False`` and their outputs are detached in
            ``forward``. ``None`` (the default) freezes nothing.
    """

    def __init__(self,
                 bottleneck_dim=256,
                 class_num=1000,
                 frozen=None
                ):
        super(ResNet_152Fc, self).__init__()
        model_resnet = resnet_dict['ResNet152'](pretrained=True)
        # ``None`` replaces the former mutable ``[]`` default, which was one
        # list object shared by every instance built without ``frozen``.
        self.frozen = [] if frozen is None else frozen
        print('Frozen Layer: ', self.frozen)
        self.conv1 = model_resnet.conv1
        self.bn1 = model_resnet.bn1
        self.relu = model_resnet.relu
        self.maxpool = model_resnet.maxpool
        self.layer1 = model_resnet.layer1
        self.layer2 = model_resnet.layer2
        self.layer3 = model_resnet.layer3
        self.layer4 = model_resnet.layer4
        self.avgpool = model_resnet.avgpool
        self.num_cls = class_num

        self.bottleneck = nn.Linear(
            model_resnet.fc.in_features,
            bottleneck_dim
        )
        self.fc = nn.Linear(bottleneck_dim, class_num)
        self.bottleneck.apply(init_weights)

        self.fc.apply(init_weights)
        self.__in_features = bottleneck_dim

        self.flattening = Flattening()
        self.feature_layers = [self.conv1, self.bn1, self.layer1, self.layer2, self.layer3, self.layer4, self.bottleneck]

        # Execution order of the backbone in ``forward``. This list must be
        # filled *before* the freeze loop below: it used to be populated
        # afterwards, so the membership test never matched and no parameter
        # was ever frozen.
        self.ordered_module_names = [
            'conv1', 'bn1', 'relu', 'maxpool',
            'layer1', 'layer2', 'layer3', 'layer4', 'avgpool'
        ]

        for name in self.frozen:
            if name in self.ordered_module_names:
                for param in self._modules[name].parameters():
                    param.requires_grad = False

    def forward(self, x, temp=1, dropout=True, cosine=False, reverse=False):
        """Run the backbone and head, returning features and logits.

        Args:
            x: Input batch fed to ``conv1`` (presumably ``(N, 3, H, W)``
                images -- confirm against the data pipeline).
            temp: Divisor applied to the classifier output.
            dropout: If true, apply ``p=0.5`` dropout before the bottleneck
                and before the classifier (only active in training mode).
            cosine: If true, L2-normalise the classifier input along dim 1.
            reverse: If true, pass the backbone features through
                ``grad_reverse(..., 1.0)`` before the head.

        Returns:
            dict with ``'features'`` (flattened backbone output),
            ``'adapted_layer'`` (ReLU of the bottleneck output, before the
            second dropout) and ``'output_logits'``.
        """
        for name in self.ordered_module_names:
            x = self._modules[name](x)
            # Detaching right after a frozen module stops gradients from
            # reaching it and every module that ran before it.
            if name in self.frozen:
                x = x.detach()

        embedding_coding = self.flattening(x)
        head_input = grad_reverse(embedding_coding, 1.0) if reverse else embedding_coding
        if dropout:
            head_input = F.dropout(head_input, training=self.training, p=0.5)
        encodes = F.relu(self.bottleneck(head_input), inplace=False)

        classifier_input = encodes
        if dropout:
            classifier_input = F.dropout(encodes, training=self.training, p=0.5)
        if cosine:
            classifier_input = F.normalize(classifier_input, p=2, dim=1)
        logits = self.fc(classifier_input) / temp
        return {
            'features': embedding_coding,
            'adapted_layer': encodes,
            'output_logits': logits
        }

    @classmethod
    def create(cls, dicts):
        """Build the model from a config mapping.

        Reads ``'num_cls'``, ``'adapted_dim'`` and ``'frozen'``; a missing or
        falsy entry falls back to ``10``, ``256`` and ``[]`` respectively.
        """
        return cls(
            class_num=dicts.get('num_cls') or 10,
            bottleneck_dim=dicts.get('adapted_dim') or 256,
            frozen=dicts.get('frozen') or []
        )

    def output_num(self):
        """Return the bottleneck output width."""
        return self.__in_features

    def get_classifer_in_features(self):
        """Return the input width of the bottleneck layer."""
        return self.bottleneck.in_features

    def get_parameters(self):
        """Return optimizer parameter groups as ``(feature_groups, fc_groups)``.

        ``feature_groups`` holds the bottleneck (``lr_mult`` 10) followed by
        every non-frozen backbone module that owns parameters (``lr_mult``
        1). ``fc_groups`` holds the classifier (``lr_mult`` 10).
        """
        parameter_list = [
            {"params": self.bottleneck.parameters(), 'lr_mult': 10},
        ]
        for name in self.ordered_module_names:
            if name in self.frozen:
                continue
            module = self._modules[name]
            # Skip parameter-free modules such as ReLU and the pooling layers.
            if next(module.parameters(), None) is not None:
                parameter_list.append({"params": module.parameters(), 'lr_mult': 1})
        return parameter_list, [{"params": self.fc.parameters(), 'lr_mult': 10}]