"""Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.

Portions of the source code are from the OLTR project which
notice below and in LICENSE in the root directory of
this source tree.

Copyright (c) 2019, Zhongqi Miao
All rights reserved.
"""


from utils import load_state_dict
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import autocast

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    With padding 1 the spatial size is kept for ``stride=1`` and becomes
    ``ceil(size / stride)`` otherwise.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

# This class is from LDAM: https://github.com/kaidic/LDAM-DRW.
class NormedLinear(nn.Module):
    """Cosine-similarity classifier head (adapted from LDAM-DRW).

    Input rows and weight columns are both L2-normalised before the matrix
    product, so every output entry is a cosine similarity in [-1, 1].

    Args:
        in_features: size of each input sample.
        out_features: number of scores produced per sample.
    """

    def __init__(self, in_features, out_features):
        super(NormedLinear, self).__init__()
        init_weight = torch.empty(in_features, out_features)
        # Uniform init, then rescale every column (the dim=1 slices): renorm caps
        # each column's L2 norm at 1e-5 and the multiply lifts it back to ~1.
        init_weight.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.weight = nn.Parameter(init_weight)

    def forward(self, x):
        unit_inputs = F.normalize(x, dim=1)
        unit_columns = F.normalize(self.weight, dim=0)
        return unit_inputs.mm(unit_columns)

class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and no channel expansion.

    Args:
        inplanes: channels of the block input.
        planes: channels produced by both convolutions (and by the block).
        stride: stride of the first convolution.
        downsample: optional module applied to the input so the shortcut matches
            the residual branch's shape; ``None`` keeps an identity shortcut.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Attribute names and registration order define the state_dict keys and
        # parameters() order, so they are kept stable.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)
    
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block outputs ``planes * 4`` channels (see ``expansion``).

    Args:
        inplanes: channels of the block input.
        planes: width of the inner 1x1 / 3x3 convolutions.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input so the shortcut matches
            the residual branch's shape; ``None`` keeps an identity shortcut.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        widened = planes * 4  # mirrors ``expansion``
        # Attribute names and registration order define the state_dict keys and
        # parameters() order, so they are kept stable.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        # First two conv-BN pairs are each followed by ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        # The last pair is left linear until after the residual addition.
        out = self.bn3(self.conv3(out))

        identity = x if self.downsample is None else self.downsample(x)

        out += identity
        return self.relu(out)

class ResNet(nn.Module):
    """ResNet backbone followed by a linear (optionally cosine-normalised) classifier.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``). It must
            expose an ``expansion`` attribute and accept
            ``(inplanes, planes, stride, downsample)``.
        layers: number of blocks in each of the four stages.
        dropout: dropout probability applied to the pooled features before the
            classifier; any falsy value (``None``, ``0``) disables dropout.
        num_classes: number of classifier outputs.
        use_norm: use ``NormedLinear`` instead of ``nn.Linear`` as the classifier.
        reduce_dimension: when ``layer3_output_dim`` / ``layer4_output_dim`` are not
            given, use 192 / 384 planes instead of 256 / 512.
        layer3_output_dim: explicit ``planes`` for ``layer3``.
        layer4_output_dim: explicit ``planes`` for ``layer4``.
        load_pretrained_weights: load ``./data/caffe_resnet152.pth`` (dropping keys
            that start with ``fc``) and freeze every parameter whose name does not
            start with ``layer4`` or ``linear``.
        returns_feat: make ``forward`` return a dict with the logits and features.
        s: scale multiplied into the classifier output. Only honoured with
            ``use_norm=True``; otherwise it is forced to 1.
        reduce_first_kernel: use a 3x3 / stride-1 stem convolution instead of the
            7x7 / stride-2 one.
        zero_init_residual: zero the weight of the last BN in every ``BasicBlock``
            / ``Bottleneck`` residual branch.
    """

    def __init__(self, block, layers, dropout=None, num_classes=1000, use_norm=False,
                 reduce_dimension=False, layer3_output_dim=None, layer4_output_dim=None, load_pretrained_weights=False,
                 returns_feat=False, s=30, reduce_first_kernel=False, zero_init_residual=False):
        super(ResNet, self).__init__()
        # Channel count fed into the next stage; updated by _make_layer.
        self.inplanes = 64

        # NOTE: submodules are registered in a fixed order; it determines the
        # state_dict / parameters() order and the order of random init draws.
        if reduce_first_kernel:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)

        if layer3_output_dim is None:
            layer3_output_dim = 192 if reduce_dimension else 256
        if layer4_output_dim is None:
            layer4_output_dim = 384 if reduce_dimension else 512

        self.layer3 = self._make_layer(block, layer3_output_dim, layers[2], stride=2)
        self.layer4 = self._make_layer(block, layer4_output_dim, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.use_dropout = bool(dropout)
        if self.use_dropout:
            print('Using dropout.')
            self.dropout = nn.Dropout(p=dropout)

        self._initialize_weights(zero_init_residual)

        in_features = layer4_output_dim * block.expansion
        if use_norm:
            self.linear = NormedLinear(in_features, num_classes)
        else:
            # The scale is only applied with the cosine classifier; plain logits stay unscaled.
            s = 1
            self.linear = nn.Linear(in_features, num_classes)

        self.returns_feat = returns_feat
        self.s = s

        if load_pretrained_weights:
            self._load_pretrained_and_freeze()

    def _initialize_weights(self, zero_init_residual):
        """He (fan-out) normal init for convolutions; weight=1 / bias=0 for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _load_pretrained_and_freeze(self):
        """Load the pretrained checkpoint and freeze all parameters outside ``layer4`` / ``linear``."""
        # The Places-LT branch is currently disabled; flip this flag to use it.
        caffe_model = True
        if caffe_model:
            print('Loading Caffe Pretrained ResNet 152 Weights.')
            # map_location lets the checkpoint load even when the device it was
            # saved from (e.g. a GPU) is not available on this machine.
            pretrained_weights_state_dict = torch.load('./data/caffe_resnet152.pth', map_location='cpu')
        else:
            print('Loading Places-LT Pretrained ResNet 152 Weights.')
            pretrained_weights_state_dict = torch.load(
                './data/places_lt_pretrained.pth', map_location='cpu')['state_dict_best']['feat_model']
            # Strip the leading "module." (7 characters) added by DataParallel-style wrappers.
            pretrained_weights_state_dict = {k[7:]: v for k, v in pretrained_weights_state_dict.items()}

        # The classifier is called fc in the caffe model; it is replaced by self.linear here.
        ignored_keys = [k for k in pretrained_weights_state_dict if k.startswith('fc')]
        for k in ignored_keys:
            pretrained_weights_state_dict.pop(k)
            print("Ignored when loading the model:", k)

        # The number of parameters may mismatch since we don't have num_batches_tracked in the caffe model.
        load_state_dict(self, pretrained_weights_state_dict, no_ignore=True)

        # Add 'layer3' here to also fine-tune the third stage.
        trainable_prefixes = ('layer4', 'linear')
        # Derived from trainable_prefixes so the message cannot drift from the actual behaviour.
        print("Warning: We allow training only on parameters starting with:", ", ".join(trainable_prefixes))
        for name, param in self.named_parameters():
            if name.startswith(trainable_prefixes):
                print("Allow gradient on:", name)
            else:
                param.requires_grad_(False)

    def _hook_before_iter(self):
        """Switch BatchNorm layers whose affine weight is frozen into eval mode.

        ``train()`` puts every submodule back into training mode, so this must be
        called after it; in eval mode those BN layers stop updating their running
        statistics. BN layers without an affine weight (``affine=False``) are skipped.
        """
        assert self.training, "_hook_before_iter should be called at training time only, after train() is called"
        count = 0
        for module in self.modules():
            if not isinstance(module, nn.BatchNorm2d) or module.weight is None:
                continue
            if not module.weight.requires_grad:
                module.eval()
                count += 1

        if count > 0:
            print("Warning: detected at least one frozen BN, set them to eval state. Count:", count)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks; only the first one may downsample.

        Side effect: ``self.inplanes`` is set to the stage's output channel count.
        """
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # 1x1 projection so the shortcut matches the residual branch's shape.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the scaled logits, or ``{"output": logits, "feat": features}`` if ``returns_feat``.

        The pooled, flattened features (taken before dropout) are also stored on ``self.feat``.
        """
        with autocast():
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)

            x = self.avgpool(x)
            # flatten (unlike view(size(0), -1)) also works for an empty batch.
            x = torch.flatten(x, 1)
            self.feat = x

            if self.use_dropout:
                x = self.dropout(x)

            x = self.linear(x)

            # This hyperparam s is originally in the loss function, but is applied here
            # to prevent using s multiple times in distillation.
            x = x * self.s

        if self.returns_feat:
            return {
                "output": x,
                "feat": self.feat
            }
        return x