import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from collections import defaultdict

# Names exported by ``from <module> import *``: the backbone class and its ResNet-50 factory.
__all__ = ['ResNet_IBN_local', 'resnet50_ibn_a_local']


class IBN(nn.Module):
    """Instance-Batch Normalization layer (IBN-a style).

    The first ``planes // 2`` channels are normalized with an affine
    ``InstanceNorm2d``; the remaining ``planes - planes // 2`` channels are
    normalized with ``BatchNorm2d``. The two results are concatenated back
    along the channel axis.

    Args:
        planes: total number of input channels.
    """

    def __init__(self, planes):
        super(IBN, self).__init__()
        half1 = planes // 2
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = nn.BatchNorm2d(half2)

    def forward(self, x):
        """Normalize ``x`` (N, C, H, W): IN on the first ``self.half`` channels, BN on the rest."""
        # Split into exactly two chunks. ``torch.split(x, self.half, 1)`` yields three
        # chunks when ``planes`` is odd, which feeds BN the wrong number of channels.
        first, second = torch.split(x, [self.half, x.size(1) - self.half], dim=1)
        out1 = self.IN(first.contiguous())
        out2 = self.BN(second.contiguous())
        return torch.cat((out1, out2), 1)


class SpatialTransformer(nn.Module):
    """Warp a feature map and its residual with a learned affine transform, then resize both.

    A localization network regresses a 2x3 affine matrix from ``feat``. The
    matrix is initialized to the identity transform. The same matrix warps
    ``feat`` and ``resi``, and both results are bilinearly resized to
    ``(int(H * dsr[0]), int(W * dsr[1]))``.

    Note: ``fc_loc`` expects a 16x7x7 localization output. With the two strided
    convolutions below, that holds for input heights/widths in 39..42.

    Args:
        downsample: configuration whose last entry is the ``(h_rate, w_rate)``
            pair, e.g. ``('SpatialTransformer', 3, (0.5, 0.5))``. A bare
            ``(h_rate, w_rate)`` pair (the default) is also accepted.
        in_channels: number of channels of ``feat`` fed to the localization network.
    """

    def __init__(self, downsample=(0.5,0.5), in_channels=256):
        super(SpatialTransformer, self).__init__()
        dsr = downsample[-1]
        if np.isscalar(dsr):
            # ``downsample`` is itself the (h_rate, w_rate) pair (e.g. the default);
            # indexing ``dsr[0]`` on a float used to crash here.
            dsr = downsample
        self.dsr = tuple(dsr)
        print('Using Spatial Transformer Network => Downsample rate: (%.2f, %.2f)' % (self.dsr[0], self.dsr[1]))

        self.localization = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2),
            nn.ReLU(True),
            nn.Conv2d(64, 16, kernel_size=5, stride=2),
            nn.ReLU(True),
        )
        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(16 * 7 * 7, 64),
            nn.ReLU(True),
            nn.Linear(64, 3 * 2)
        )
        # Zero weights + identity bias => the initial transform is the identity.
        self.fc_loc[-1].weight.data.zero_()
        self.fc_loc[-1].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, feat, resi):
        """Warp ``feat`` and ``resi`` with the affine matrix predicted from ``feat`` and resize them.

        Args:
            feat: (N, C, H, W) feature map; also the localization input.
            resi: (N, C', H', W') tensor warped with the same matrix.

        Returns:
            ``(feat, resi)`` after warping, both resized to
            ``(int(H * dsr[0]), int(W * dsr[1]))`` of ``feat``'s spatial size.
        """
        h, w = feat.shape[-2:]
        theta = self.fc_loc(self.localization(feat)).view(-1, 2, 3)

        # align_corners=False is PyTorch's default; passing it explicitly keeps the
        # behavior and silences the default-change warning.
        grid_feat = F.affine_grid(theta, feat.size(), align_corners=False)
        grid_resi = F.affine_grid(theta, resi.size(), align_corners=False)
        warped_feat = F.grid_sample(feat, grid_feat, align_corners=False)
        warped_resi = F.grid_sample(resi, grid_resi, align_corners=False)

        # nn.UpsamplingBilinear2d is deprecated; it is equivalent to bilinear
        # interpolate with align_corners=True.
        size = (int(h * self.dsr[0]), int(w * self.dsr[1]))
        feat = F.interpolate(warped_feat, size=size, mode='bilinear', align_corners=True)
        resi = F.interpolate(warped_resi, size=size, mode='bilinear', align_corners=True)
        return feat, resi


class RegionPool2d(nn.Module):
    """Crop, per sample, the window with the highest weighted response and resize it.

    For each sample a map ``kpt = sigmoid(10 * channel-mean(instance_norm(x)))``
    (called "keypoint" map in this code) is computed. A Gaussian-shaped window
    of size ``(int(H * region_dsr[0]), int(W * region_dsr[1]))`` is correlated
    with that map. The crop at the highest response is taken from both ``feat``
    and ``resi``. Crops are bilinearly resized to
    ``(int(H * total_dsr[0]), int(W * total_dsr[1]))``.

    Args:
        downsample: configuration whose last two entries are the region
            coefficient/downsample ``(h, w)`` pair and the total downsample
            ``(h, w)`` pair. If the string ``'AdaptiveRegionPool'`` is one of its
            entries, the region rate is estimated per sample from the spread
            of the keypoint map, clamped to ``[total_dsr, 1.0]``.
    """

    def __init__(self, downsample=((0.7,0.7),(0.5,0.5))):
        super(RegionPool2d, self).__init__()
        self.adaptive = 'AdaptiveRegionPool' in downsample
        self.region_coef = downsample[-2]
        self.total_dsr = downsample[-1]
        rates = (self.region_coef[0], self.region_coef[1], self.total_dsr[0], self.total_dsr[1])
        if self.adaptive:
            print('Using Adaptive Region Pooling => Region coefficient: (%.2f, %.2f); Total downsample rate: (%.2f, %.2f)' % rates)
        else:
            print('Using Region Pooling => Region downsample: (%.2f, %.2f); Total downsample rate: (%.2f, %.2f)' % rates)

    @staticmethod
    def _keypoint_map(x):
        """Return sigmoid(10 * channel-mean of instance-normalized ``x``) with shape (1, 1, H, W)."""
        # Same computation as a fresh nn.InstanceNorm2d(c, affine=False) (eps=1e-5),
        # without constructing a module on every call.
        normed = F.instance_norm(x, eps=1e-5)
        return torch.sigmoid(10 * torch.mean(normed, dim=1, keepdim=True))

    @staticmethod
    def _distribution_std(dis):
        """Standard deviation of the index variable under 1-D weights ``dis`` (expected to sum to 1)."""
        idx = torch.arange(dis.numel(), device=dis.device, dtype=dis.dtype)
        mean = torch.sum(idx * dis)
        var = torch.sum((idx - mean) ** 2 * dis)
        return var ** 0.5

    def _region_dsr(self, kpt, h, w):
        """Return the (h, w) region downsample rate for one sample's keypoint map."""
        if not self.adaptive:
            return self.region_coef
        kmap = kpt.squeeze()
        total = torch.sum(kpt)
        h_std = self._distribution_std(torch.sum(kmap, dim=1) / total)
        w_std = self._distribution_std(torch.sum(kmap, dim=0) / total)
        # sqrt(12) * std is the width of a uniform distribution with that std.
        coef = 12 ** 0.5
        rate_h = float(self.region_coef[0] * coef * h_std / h)
        rate_w = float(self.region_coef[1] * coef * w_std / w)
        return (max(min(rate_h, 1.0), self.total_dsr[0]),
                max(min(rate_w, 1.0), self.total_dsr[1]))

    @staticmethod
    def _gaussian_kernel(kh, kw, like):
        """Build a (1, 1, kh, kw) Gaussian-shaped window on ``like``'s device and dtype."""
        gy, gx = np.mgrid[-kh // 2:kh // 2, -kw // 2:kw // 2] / min(kh, kw)
        kernel = np.exp(-(gy ** 2 + gx ** 2))
        return torch.from_numpy(kernel).to(device=like.device, dtype=like.dtype).view(1, 1, kh, kw)

    def forward(self, feat, resi):
        """Region-pool every sample of ``feat`` and ``resi``.

        Args:
            feat: (N, C, H, W) features used to locate the region.
            resi: (N, C', H, W) tensor cropped at the same location; assumed to
                share H, W with ``feat`` (TODO confirm at call sites).

        Returns:
            ``(feats, resis)``, each resized to
            ``(int(H * total_dsr[0]), int(W * total_dsr[1]))``.
        """
        h, w = feat.shape[-2:]
        out_size = (int(h * self.total_dsr[0]), int(w * self.total_dsr[1]))
        feats, resis = [], []
        for i in range(feat.size(0)):
            x = feat[i:i + 1]
            y = resi[i:i + 1]

            kpt = self._keypoint_map(x)
            region_dsr = self._region_dsr(kpt, h, w)
            # Clamp to [1, size]: a tiny rate used to produce an empty (invalid) window.
            kh = min(h, max(1, int(h * region_dsr[0])))
            kw = min(w, max(1, int(w * region_dsr[1])))

            # Window response at every valid placement; take the strongest one.
            response = F.conv2d(kpt, self._gaussian_kernel(kh, kw, kpt))
            top, left = divmod(int(torch.argmax(response)), response.shape[-1])

            x_ds = x[:, :, top:top + kh, left:left + kw]
            y_ds = y[:, :, top:top + kh, left:left + kw]
            # Equivalent to the deprecated nn.UpsamplingBilinear2d.
            feats.append(F.interpolate(x_ds, size=out_size, mode='bilinear', align_corners=True))
            resis.append(F.interpolate(y_ds, size=out_size, mode='bilinear', align_corners=True))

        return torch.cat(feats), torch.cat(resis)


class Bottleneck_IBN(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1, expansion 4) with optional IBN.

    Args:
        inplanes: input channels.
        planes: bottleneck width; the block outputs ``planes * 4`` channels.
        ibn: if True, the first normalization is ``IBN`` instead of BatchNorm2d.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut.
        ds: accepted for signature parity with ``Bottleneck_IBN_local``; unused here.
    """
    expansion = 4

    def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None, ds=((1.0,1.0),(1.0,1.0))):
        super(Bottleneck_IBN, self).__init__()
        width_out = planes * self.expansion
        # Registration order below is kept deliberately: it fixes the order in which
        # sub-modules are visited (and thus randomly initialized) by the parent.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = IBN(planes) if ibn else nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_branch(x) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)


class Bottleneck_IBN_local(nn.Module):
    """Bottleneck block with an optional region-pooling step between conv2 and conv3.

    If ``ds[0] == "SpatialTransformer"`` a ``SpatialTransformer`` is inserted.
    Otherwise, if the last ``ds`` entry differs from ``(1.0, 1.0)``, a
    ``RegionPool2d`` is inserted. The region-pooling step transforms both the
    main branch and the residual before the shortcut addition.

    Args:
        inplanes: input channels.
        planes: bottleneck width; the block outputs ``planes * 4`` channels.
        ibn: if True, the first normalization is ``IBN`` instead of BatchNorm2d.
        stride: stored on the block; conv2 here always uses stride 1.
        downsample: optional module applied to the (region-pooled) residual.
        ds: region-pooling configuration (see above).
    """
    expansion = 4

    def __init__(self, inplanes, planes, ibn=False, stride=1, downsample=None, ds=((1.0,1.0),(1.0,1.0))):
        super(Bottleneck_IBN_local, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        if ibn:
            self.bn1 = IBN(planes)
        else:
            self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.region_pool = None
        if ds[0] == "SpatialTransformer":
            self.region_pool = SpatialTransformer(ds)
        elif tuple(ds[-1]) != (1.0, 1.0):
            # tuple() so a list-valued config such as [1.0, 1.0] (e.g. loaded from
            # JSON/YAML) is still recognized as "no pooling".
            self.region_pool = RegionPool2d(ds)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return relu(main_branch + shortcut), with optional joint region pooling."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        if self.region_pool is not None:
            # Pool the main branch and the residual together so they stay aligned.
            out, residual = self.region_pool(out, residual)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(residual)

        out += residual
        out = self.relu(out)

        return out


class ResNet_IBN_local(nn.Module):
    """ResNet-IBN backbone where layer2 or layer3 may be built from a local block.

    Args:
        block: block class for the regular stages.
        local_block: block class for the stage selected by ``downsample[1]``
            (2 or 3); it receives ``downsample`` as its ``ds`` argument and is
            built with stride 1.
        layers: number of blocks in each of the four stages.
        downsample: region-pooling configuration. A configuration with fewer
            than two entries (e.g. the default ``()``) means no local stage.
        frozen_stages: used by ``_freeze_stages``; the stem is frozen when
            ``>= 0`` and ``layer1..layer{frozen_stages}`` are frozen.
    """

    def __init__(self, block, local_block, layers, downsample=(), frozen_stages=-1):
        super(ResNet_IBN_local, self).__init__()
        scale = 64
        self.inplanes = scale
        self.frozen_stages = frozen_stages
        # ``downsample[1]`` used to raise IndexError for the default ``()``.
        self.start_at = downsample[1] if downsample is not None and len(downsample) > 1 else None
        self.conv1 = nn.Conv2d(3, scale, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(scale)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=None)

        self.layer1 = self._make_layer(block, scale, layers[0])
        if self.start_at == 2:
            self.layer2 = self._make_layer(local_block, scale * 2, layers[1], ds=downsample)
        else:
            self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2)
        if self.start_at == 3:
            self.layer3 = self._make_layer(local_block, scale * 4, layers[2], ds=downsample)
        else:
            self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=1)

        self._init_weights()

    def _init_weights(self):
        """He-normal init for convolutions; weight=1, bias=0 for affine norm layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Non-affine norm layers have no weight/bias to initialize.
                if m.weight is not None:
                    m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, ds=((1.0,1.0),(1.0,1.0))):
        """Build one stage of ``blocks`` blocks; only the first gets ``stride``, shortcut and ``ds``."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        # IBN is used in every stage except the one with planes == 512.
        ibn = planes != 512
        layers = [block(self.inplanes, planes, ibn, stride, downsample, ds)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, ibn))

        return nn.Sequential(*layers)

    def _freeze_stages(self):
        """Put frozen parts in eval mode and disable their gradients."""
        if self.frozen_stages >= 0:
            self.bn1.eval()
            for m in [self.conv1, self.bn1]:
                for param in m.parameters():
                    param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, 'layer{}'.format(i))
            print('layer{}'.format(i))
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Run the backbone.

        If ``x`` is an image tensor, returns ``(layer4_feat, {1: layer1, 2: layer2, 3: layer3})``.
        Otherwise ``x`` is indexed like that dict: the pass resumes from ``x[1]``
        (when ``start_at == 2``) or ``x[2]``, and only the layer4 output is returned.
        """
        if isinstance(x, torch.Tensor):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            layer1_feat = self.layer1(x)
            layer2_feat = self.layer2(layer1_feat)
            layer3_feat = self.layer3(layer2_feat)
            layer4_feat = self.layer4(layer3_feat)
            return layer4_feat, {1: layer1_feat, 2: layer2_feat, 3: layer3_feat}

        if self.start_at == 2:
            layer3_feat = self.layer3(self.layer2(x[1]))
        else:
            layer3_feat = self.layer3(x[2])
        return self.layer4(layer3_feat)

    def load_param(self, param_dict):
        """Copy tensors from ``param_dict`` (or its ``'state_dict'`` entry) into this model.

        Keys containing ``'fc'`` are skipped and a ``'module.'`` prefix is removed.
        Keys not present in this model raise ``KeyError``.
        """
        if 'state_dict' in param_dict:
            param_dict = param_dict['state_dict']
        # Build the state dict once; its tensors reference the model's parameters
        # and buffers, so copy_ writes into the model.
        own_state = self.state_dict()
        for key, value in param_dict.items():
            if 'fc' in key:
                continue
            own_state[key.replace('module.', '')].copy_(value)

def resnet50_ibn_a_local(**kwargs):
    """Build a ResNet-50-depth ``ResNet_IBN_local`` (stage depths 3, 4, 6, 3).

    Regular stages use ``Bottleneck_IBN`` and the local stage uses
    ``Bottleneck_IBN_local``. Keyword arguments (e.g. ``downsample``,
    ``frozen_stages``) are forwarded to ``ResNet_IBN_local``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet_IBN_local(Bottleneck_IBN, Bottleneck_IBN_local, stage_depths, **kwargs)

