import math
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

from utils.common import interpolateFeatures, _get_num_features
from utils.route_norm import RouteNorm
from utils.trans_norm import TransNorm2d


__all__ = ['ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50',
           'resnet101', 'resnet152']

# Module-level default device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torchvision ImageNet checkpoints used by the ``pretrained=True`` factories.
# There is no entry for ResNet-10, so it cannot be loaded pretrained.
model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    With padding 1 an input of spatial size H yields ``ceil(H / stride)``.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )


class BasicBlock(nn.Module):
    """Residual block of two 3x3 conv + BatchNorm layers with a shortcut.

    ``expansion`` is 1, so the block outputs ``planes`` channels. The first
    convolution carries ``stride``; ``downsample`` (if given) projects the
    shortcut so it matches the main path.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Attribute names mirror torchvision's ResNet so ImageNet checkpoints load.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        shortcut = x if self.downsample is None else self.downsample(x)
        main += shortcut
        return self.relu(main)


class Bottleneck(nn.Module):
    """Residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand, plus a shortcut.

    The block outputs ``4 * planes`` channels (``expansion`` is 4). ``downsample``
    (if given) projects the shortcut so it matches the main path.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        expanded = planes * 4  # equals planes * Bottleneck.expansion
        # Attribute names mirror torchvision's ResNet so ImageNet checkpoints load.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, expanded, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(expanded)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.relu(self.bn2(self.conv2(main)))
        main = self.bn3(self.conv3(main))

        shortcut = x if self.downsample is None else self.downsample(x)
        main += shortcut
        return self.relu(main)



class Identity(nn.Module):
    """Fusion op that ignores the second input and returns the first unchanged.

    ``planes`` is accepted so every fusion op shares one constructor signature;
    it is not used.
    """

    def __init__(self, tg_layer_id, src_layer_id, planes):
        super(Identity, self).__init__()
        self.tg_layer_id, self.src_layer_id = tg_layer_id, src_layer_id

    def forward(self, input1, input2):
        target = input1
        return target


class SimpleAdd(nn.Module):
    """Fusion op returning the element-wise sum of its two inputs.

    ``planes`` is accepted for a uniform constructor signature but not used.
    """

    def __init__(self, tg_layer_id, src_layer_id, planes):
        super(SimpleAdd, self).__init__()
        self.tg_layer_id, self.src_layer_id = tg_layer_id, src_layer_id

    def forward(self, input1, input2):
        return torch.add(input1, input2)


class WeightedAdd(nn.Module):
    """Fusion op computing ``wt[0] * input1 + wt[1] * input2`` with learnable ``wt``.

    Both weights start at 0.5. ``planes`` is accepted for a uniform constructor
    signature but not used.
    """

    def __init__(self, tg_layer_id, src_layer_id, planes):
        super(WeightedAdd, self).__init__()
        self.tg_layer_id, self.src_layer_id = tg_layer_id, src_layer_id
        self.wt = nn.Parameter(0.5 * torch.ones(2))

    def forward(self, input1, input2):
        w_first, w_second = self.wt[0], self.wt[1]
        return w_first * input1 + w_second * input2


class LinearCombine(nn.Module):
    """Fusion op that scales each input by a per-sample gate predicted from that input.

    Each input is global-average-pooled to a ``(batch, channels)`` descriptor.
    ``tu`` maps the descriptor of ``input1`` and ``su`` that of ``input2`` to one
    scalar per sample. The scaled inputs are then summed. Both gate biases start
    at 1.0. Both inputs are expected to have ``planes`` channels.
    """

    def __init__(self, tg_layer_id, src_layer_id, planes):
        super(LinearCombine, self).__init__()
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        self.su = nn.Linear(planes, 1)  # gate for input2
        self.tu = nn.Linear(planes, 1)  # gate for input1
        nn.init.constant_(self.su.bias, 1.0)
        nn.init.constant_(self.tu.bias, 1.0)

    def forward(self, input1, input2):
        # Adaptive pooling reduces any H x W map to 1x1; a square kernel of size H
        # would leave a 1 x (W / H) map for non-square inputs and break the reshape.
        f1 = F.adaptive_avg_pool2d(input1, 1).flatten(1)
        f2 = F.adaptive_avg_pool2d(input2, 1).flatten(1)
        gate1 = self.tu(f1).reshape(input1.size(0), 1, 1, 1)
        gate2 = self.su(f2).reshape(input2.size(0), 1, 1, 1)
        return gate1 * input1 + gate2 * input2


class FactorizedReduce(nn.Module):
    """Fusion op that projects each input to half the channels and concatenates them.

    Both inputs pass through an out-of-place ReLU. ``input1`` is then 1x1-projected
    to ``ceil(planes / 2)`` channels and ``input2`` to ``floor(planes / 2)``
    channels. The concatenation (``planes`` channels) is batch-normalized. Both
    inputs must have ``planes`` channels and matching batch/spatial sizes.
    """

    def __init__(self, tg_layer_id, src_layer_id, planes):
        super(FactorizedReduce, self).__init__()
        assert planes % 2 == 0
        self.tg_layer_id = tg_layer_id
        self.src_layer_id = src_layer_id
        # Out-of-place so the caller's input tensors are not modified.
        self.relu = nn.ReLU(inplace=False)
        # Exact integer ceil/floor split of planes, avoiding a float round-trip
        # that would also leave numpy integers as channel counts.
        self.conv_1 = nn.Conv2d(planes, planes - planes // 2, 1, bias=False)
        self.conv_2 = nn.Conv2d(planes, planes // 2, 1, bias=False)
        self.bn = nn.BatchNorm2d(planes)

    def forward(self, input1, input2):
        input1 = self.relu(input1)
        input2 = self.relu(input2)
        out = torch.cat([self.conv_1(input1), self.conv_2(input2)], dim=1)
        return self.bn(out)

# Registry of fusion ops. Each factory is called as
# ``OPS[name](tg_layer_id, src_layer_id, planes)``; iteration order is the
# order in which ResidualLayer builds its ``cs`` modules.
OPS = dict(
    iden=lambda tg_id, src_id, planes: Identity(tg_id, src_id, planes),
    simpleadd=lambda tg_id, src_id, planes: SimpleAdd(tg_id, src_id, planes),
    wtadd=lambda tg_id, src_id, planes: WeightedAdd(tg_id, src_id, planes),
    lincombine=lambda tg_id, src_id, planes: LinearCombine(tg_id, src_id, planes),
    factred=lambda tg_id, src_id, planes: FactorizedReduce(tg_id, src_id, planes),
)

class CrossStitch(nn.Module):
    """Cross-stitch fusion: ``wt[0] * input1 + wt[1] * input2`` with learnable ``wt``.

    Both weights start at 0.5. ``planes`` is accepted for a uniform constructor
    signature but not used.
    """

    def __init__(self, tg_layer_id, src_layer_id, planes):
        super(CrossStitch, self).__init__()
        self.tg_layer_id, self.src_layer_id = tg_layer_id, src_layer_id
        self.wt = nn.Parameter(0.5 * torch.ones(2))

    def forward(self, input1, input2):
        alpha, beta = self.wt[0], self.wt[1]
        return alpha * input1 + beta * input2


class ResidualLayer(nn.Sequential):  # Shareable Residual Layer
    """A ResNet stage whose output can be fused with a transferred source feature.

    The residual blocks are registered as children ``'0'`` .. ``str(blocks - 1)``
    exactly as ``nn.Sequential`` would, and :meth:`forward` applies them in order.
    The stage outputs ``planes * block.expansion`` channels.

    When ``source_input_info`` is given, the layer also owns the following
    modules, one per source feature:
    - one fusion module for every ``OPS`` entry, in ``self.cs[name]``;
    - a ``CrossStitch``, in ``self.cs['xstitch']``;
    - if ``ru_units`` is True, a 1x1 conv + BatchNorm projection in ``self.ru``
      that maps the source feature to the stage's output width.

    Args:
        id: stage index (``ResNet`` passes 1-4). It is also used to pick the fusion
            module after an expectation over all sources.
        block: residual block class exposing an ``expansion`` attribute.
        inplanes: channels of the stage input.
        planes: base block width.
        blocks: number of residual blocks.
        block_hidden_dim: stored on the instance; not used inside this class.
        stride: stride of the first block and of its downsample shortcut.
        decisioner: object whose ``probabilities`` weight the candidate sources
            when all sources are fused at once (``src_id == -1``).
        source_input_info: optional ``(source_feature_ids, source_input_channels)``.
        ru_units: whether source features are projected through ``self.ru``.
    """

    def __init__(self, id, block, inplanes, planes, blocks, block_hidden_dim, stride=1, decisioner=None,
                 source_input_info=None, ru_units=True):
        out_planes = planes * block.expansion

        downsample = None
        if stride != 1 or inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(inplanes, planes, stride, downsample)]
        layers.extend(block(out_planes, planes) for _ in range(1, blocks))
        # Initialize nn.Sequential before assigning any other attribute.
        super(ResidualLayer, self).__init__(*layers)

        self.id = id
        self.inplanes = out_planes
        self.num_blocks = blocks
        self.block_hidden_dim = block_hidden_dim
        self.decisioner = decisioner
        self.route_wts = []
        self.ru_units = ru_units

        if self.ru_units:
            self.ru = nn.ModuleList([])
        self.cs = nn.ModuleDict({})
        self.id2mod = None
        if source_input_info:
            source_feature_ids, source_input_channels = source_input_info
            # Map each source feature id to its index in self.ru / self.cs[...].
            # -1 ("all sources") is resolved in forward().
            self.id2mod = {ids: i for i, ids in enumerate(source_feature_ids)}
            self.id2mod[-1] = -1
            self.source_input_channels = source_input_channels
            n_sources = len(self.source_input_channels)

            if self.ru_units:
                # Project each source feature to the width of this stage's *output*
                # (planes * expansion), which is what the fusion ops combine it with.
                self.ru.extend([nn.Sequential(
                    nn.Conv2d(src_planes, out_planes, 1),
                    nn.BatchNorm2d(out_planes),
                ) for src_planes in self.source_input_channels])

            for name in OPS:
                self.cs[name] = nn.ModuleList([OPS[name](id, i, out_planes) for i in range(n_sources)])
            # NOTE(review): forward() only dispatches names in OPS, so 'xstitch' is
            # never selected there. CrossStitch ignores its planes argument.
            self.cs['xstitch'] = nn.ModuleList([CrossStitch(id, i, out_planes) for i in range(n_sources)])

    def forward(self, input, src_input=None):
        """Run the stage's blocks, then optionally fuse a source feature into the result.

        Args:
            input: input feature map for the first block.
            src_input: optional ``(src_id, transfer_type, ref_input)``.
                - ``src_id`` is a source feature id, or -1 for "all sources".
                - ``transfer_type`` is a key of ``OPS``.
                - ``ref_input`` is a tensor, a list of tensors (for -1), or None
                  to skip fusion.

        Returns:
            The stage output, fused with ``ref_input`` when one is given.

        Raises:
            ValueError: fusion was requested with a ``transfer_type`` not in ``OPS``.
        """
        ref_id, transfer_type, ref_input = -1, '', None
        if src_input is not None:
            src_id, transfer_type, ref_input = src_input
            ref_id = self.id2mod[src_id] if self.id2mod else -1

        for idx in range(self.num_blocks):
            input = self._modules[str(idx)](input)

        if ref_input is None:
            return input

        if self.ru_units:
            # Resize the source feature(s) to this stage's spatial size (presumably
            # interpolateFeatures matches both spatial dims to input.size(3)), then
            # project them to the stage's channel width.
            if isinstance(ref_input, list):
                ref_input = [interpolateFeatures(rin, input.size(3)) for rin in ref_input]
                ref_input = [rproj(rin) for rin, rproj in zip(ref_input, self.ru)]
                ref_input.append(torch.zeros_like(input))  # PASS option / attention sentinel
            else:
                ref_input = interpolateFeatures(ref_input, input.size(3))
                ref_input = self.ru[ref_id](ref_input)

        if ref_id == -1 and transfer_type != 'attention':
            # Expectation over all candidate sources, weighted by the decisioner.
            # NOTE(review): with ru_units=False the list is neither resized nor
            # projected, so torch.stack requires the sources to already share a shape.
            stacked = torch.stack(ref_input, dim=-1)
            probs = torch.as_tensor(self.decisioner.probabilities,
                                    dtype=stacked.dtype, device=input.device)
            ref_input = torch.matmul(stacked, probs)
            ref_id = self.id - 1

        if transfer_type not in OPS:
            raise ValueError("Unsupported transfer type {!r}; supported types: {}".format(
                transfer_type, list(OPS)))
        return self.cs[transfer_type][ref_id](input, ref_input)


class ResNet(nn.Module):
    """ResNet classifier with optional per-stage transfer from a source network.

    Without ``source_info`` the four stages are plain ``nn.Sequential`` stacks and
    the model behaves like a standard ResNet. With ``source_info`` each stage is a
    ``ResidualLayer`` that can fuse a source-network feature into its output.
    In that case :meth:`forward` expects ``ref_features`` and ``pairs``.
    """

    source_feature_ids = [0, 1, 2, 3, 4]  # one for each internal representation
    target_route_ids = [0, 1, 2, 3]  # one for each ResidualLayer

    def __init__(self, block, layers, num_classes=1000, name="resnet", init_weights=True,
                 source_info=None, ru_units=True):
        """Build the network.

        Args:
            block: residual block class (``BasicBlock`` or ``Bottleneck``).
            layers: number of blocks in each of the four stages.
            num_classes: output size of the final linear layer.
            name: architecture name passed to ``_get_num_features``.
            init_weights: if True, re-initialize conv and BatchNorm parameters.
            source_info: optional ``(source_feature_ids, source_input_pass_id,
                decisioners, src_channels)``. Enables the transfer stages.
            ru_units: forwarded to each ``ResidualLayer`` when ``source_info`` is given.
        """
        super(ResNet, self).__init__()

        self.num_classes = num_classes
        self.inplanes = 64
        self.nlayers = len(layers)
        self.channels = _get_num_features(name)

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        if source_info:
            self.source_feature_ids, self.source_input_pass_id, decisioners, src_channels = source_info
            self.source_input_channels = [src_channels[ids] for ids in self.source_feature_ids]
            source_input_info = (self.source_feature_ids, self.source_input_channels)
            self.ru_units = ru_units

            self.layer1 = ResidualLayer(1, block, self.inplanes, 64, layers[0], block_hidden_dim=56,
                                        decisioner=decisioners,
                                        source_input_info=source_input_info,
                                        ru_units=self.ru_units)
            self.inplanes = 64 * block.expansion
            self.layer2 = ResidualLayer(2, block, self.inplanes, 128, layers[1], stride=2, block_hidden_dim=56,
                                        decisioner=decisioners,
                                        source_input_info=source_input_info,
                                        ru_units=self.ru_units)
            self.inplanes = 128 * block.expansion
            self.layer3 = ResidualLayer(3, block, self.inplanes, 256, layers[2], stride=2, block_hidden_dim=28,
                                        decisioner=decisioners,
                                        source_input_info=source_input_info,
                                        ru_units=self.ru_units)
            self.inplanes = 256 * block.expansion
            self.layer4 = ResidualLayer(4, block, self.inplanes, 512, layers[3], stride=2, block_hidden_dim=14,
                                        decisioner=decisioners,
                                        source_input_info=source_input_info,
                                        ru_units=self.ru_units)
            self.inplanes = 512 * block.expansion
            # NOTE: a SpotTune variant (parallel fine-tuned stages selected per sample
            # by `weights`) was sketched here but is not implemented.
        else:
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.layers = [self.layer1, self.layer2, self.layer3, self.layer4]

        # Adaptive pooling reduces the final map to 1x1 for any input size. The
        # network's total stride is 32, so only 224x224 inputs end at exactly 7x7.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        """He-normal init for convolutions; unit weight / zero bias for BatchNorm2d."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build a plain ``nn.Sequential`` stage and advance ``self.inplanes``."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x, ref_features=None, pairs=None, weights=None):
        """Classify ``x``, optionally fusing source features into each stage.

        Args:
            x: input image batch.
            ref_features: optional source-network features, indexed by source
                feature id. When given, transfer is applied according to ``pairs``.
            pairs: one ``(target_layer_id, source_feature_id, transfer_type)`` per
                stage, consumed in order. Required with ``ref_features``.
            weights: not supported. Passing it without ``ref_features`` raises.

        Returns:
            ``(logits, [r0, f1, f2, f3, f4])``. ``r0`` is the stem activation
            before max-pooling and ``f1``..``f4`` are the stage outputs.

        Raises:
            NotImplementedError: ``weights`` given without ``ref_features``.
            ValueError: ``ref_features`` given without ``pairs``.
        """
        c0 = self.conv1(x)
        b0 = self.bn1(c0)
        r0 = self.relu(b0)
        p0 = self.maxpool(r0)

        if ref_features is None:
            if weights is not None:
                raise NotImplementedError(
                    "Routing with `weights` (SpotTune) is not implemented; "
                    "call forward() without `weights`.")
            f1 = self.layer1(p0)
            f2 = self.layer2(f1)
            f3 = self.layer3(f2)
            f4 = self.layer4(f3)
        else:  # Auto-transfer
            f1, f2, f3, f4 = self._transfer_forward(p0, ref_features, pairs)

        f5 = torch.flatten(self.avgpool(f4), 1)
        out = self.fc(f5)
        return out, [r0, f1, f2, f3, f4]

    def _transfer_forward(self, feat, ref_features, pairs):
        """Run the transfer stages on the stem output ``feat`` and return their outputs."""
        if pairs is None:
            raise ValueError("`pairs` is required when `ref_features` is given")
        stage_outputs = []
        # The entry index selects the stage; target_layer_id itself is not used.
        for stage_idx, (_target_layer_id, source_feature_id, transfer_type) in enumerate(pairs):
            ref_input = None
            if source_feature_id != self.source_input_pass_id:
                if source_feature_id == -1:
                    # All source features; the stage takes an expectation over them.
                    ref_input = [ref_features[sid] for sid in self.source_feature_ids]
                else:
                    ref_input = ref_features[source_feature_id]
            else:
                source_feature_id = self.source_feature_ids[stage_idx]
                if transfer_type in ['combine', 'shared', 'block']:
                    # Zero placeholder on the same device/dtype as the activations.
                    ref_input = torch.zeros(
                        [feat.size(0), self.source_input_channels[stage_idx], feat.size(2), feat.size(3)],
                        dtype=feat.dtype, device=feat.device)
            feat = self.layers[stage_idx](feat, (source_feature_id, transfer_type, ref_input))
            stage_outputs.append(feat)
        return stage_outputs

    def forward_with_features(self, x):
        """Alias for ``forward(x)``; the feature list is part of its return value."""
        return self.forward(x)


def resnet10(pretrained=False, **kwargs):
    """Constructs a ResNet-10 model (one BasicBlock per stage).

    Args:
        pretrained (bool): No ImageNet checkpoint exists for ResNet-10. If True, a
            warning is issued and the randomly initialized model is returned.
        **kwargs: forwarded to :class:`ResNet`.
    """
    name = "resnet10"
    model = ResNet(BasicBlock, [1, 1, 1, 1], name=name, **kwargs)
    if pretrained:
        warnings.warn("Pretrained weights are not available for resnet10; "
                      "returning a randomly initialized model.", stacklevel=2)
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, load torchvision ImageNet weights. With
            ``source_info`` the extra transfer modules have no pretrained weights,
            so the checkpoint is loaded non-strictly; shape mismatches
            (e.g. ``fc`` when ``num_classes != 1000``) still raise.
        **kwargs: forwarded to :class:`ResNet`.
    """
    name = "resnet18"
    model = ResNet(BasicBlock, [2, 2, 2, 2], name=name, **kwargs)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet18'])
        # Transfer modules (ru/cs) exist only with source_info and are absent from the checkpoint.
        model.load_state_dict(state_dict, strict=not kwargs.get('source_info'))
    return model


def resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.

    Args:
        pretrained (bool): If True, load torchvision ImageNet weights. With
            ``source_info`` the extra transfer modules have no pretrained weights,
            so the checkpoint is loaded non-strictly; shape mismatches
            (e.g. ``fc`` when ``num_classes != 1000``) still raise.
        **kwargs: forwarded to :class:`ResNet`.
    """
    name = "resnet34"
    model = ResNet(BasicBlock, [3, 4, 6, 3], name=name, **kwargs)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet34'])
        # Transfer modules (ru/cs) exist only with source_info and are absent from the checkpoint.
        model.load_state_dict(state_dict, strict=not kwargs.get('source_info'))
    return model


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, load torchvision ImageNet weights. With
            ``source_info`` the extra transfer modules have no pretrained weights,
            so the checkpoint is loaded non-strictly; shape mismatches
            (e.g. ``fc`` when ``num_classes != 1000``) still raise.
        **kwargs: forwarded to :class:`ResNet`.
    """
    name = "resnet50"
    model = ResNet(Bottleneck, [3, 4, 6, 3], name=name, **kwargs)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet50'])
        # Transfer modules (ru/cs) exist only with source_info and are absent from the checkpoint.
        model.load_state_dict(state_dict, strict=not kwargs.get('source_info'))
    return model


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, load torchvision ImageNet weights. With
            ``source_info`` the extra transfer modules have no pretrained weights,
            so the checkpoint is loaded non-strictly; shape mismatches
            (e.g. ``fc`` when ``num_classes != 1000``) still raise.
        **kwargs: forwarded to :class:`ResNet`.
    """
    name = "resnet101"
    model = ResNet(Bottleneck, [3, 4, 23, 3], name=name, **kwargs)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet101'])
        # Transfer modules (ru/cs) exist only with source_info and are absent from the checkpoint.
        model.load_state_dict(state_dict, strict=not kwargs.get('source_info'))
    return model


def resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, load torchvision ImageNet weights. With
            ``source_info`` the extra transfer modules have no pretrained weights,
            so the checkpoint is loaded non-strictly; shape mismatches
            (e.g. ``fc`` when ``num_classes != 1000``) still raise.
        **kwargs: forwarded to :class:`ResNet`.
    """
    name = "resnet152"
    model = ResNet(Bottleneck, [3, 8, 36, 3], name=name, **kwargs)
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['resnet152'])
        # Transfer modules (ru/cs) exist only with source_info and are absent from the checkpoint.
        model.load_state_dict(state_dict, strict=not kwargs.get('source_info'))
    return model
