# This implementation is based on the ResNet implementation in torchvision
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

import math
import torch
from torch import nn
from torchvision.models.resnet import conv3x3


class BasicBlockWithDeathRate(nn.Module):
    '''Residual block (conv-bn-relu-conv-bn) that can be skipped at random.

    In training mode one uniform sample is drawn per forward call; the
    residual branch is evaluated only when that sample is <= survival_rate,
    otherwise the (optionally downsampled) input is returned as is.
    Outside training mode the branch is always evaluated and multiplied by
    survival_rate. The branch output is additionally multiplied by the
    ``scale`` buffer before being added to the shortcut.
    '''
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, survival_rate=1.,
                 downsample=None, scale=1.):
        super().__init__()
        # Plain (non-module) attributes.
        self.stride = stride
        self.survival_rate = survival_rate

        # First conv/bn/relu stage; the stride is applied here.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=True)

        # Second conv/bn stage; relu2 is applied after the shortcut addition.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=True)

        # Optional module applied to the shortcut path.
        self.downsample = downsample

        # Stored as a buffer so it follows the module across devices.
        self.register_buffer("scale", torch.Tensor([scale]))

    def forward(self, x):
        block_input = x
        shortcut = x if self.downsample is None else self.downsample(x)

        # TODO: fix the bug of original Stochastic depth
        # The random draw only happens in training mode (short-circuit).
        skip_block = self.training and not (
            torch.rand(1)[0] <= self.survival_rate)
        if skip_block:
            return shortcut

        branch = self.conv1(block_input)
        branch = self.relu1(self.bn1(branch))
        branch = self.bn2(self.conv2(branch))

        # We add the survival condition for sensitivity initialization:
        # with a rate equal to 1 the multiply is also done while training.
        if not self.training or self.survival_rate == 1.:
            branch *= self.survival_rate

        return self.relu2(shortcut + branch * self.scale)


class DownsampleB(nn.Module):
    '''Parameter-free shortcut: average-pool, then pad channels with zeros.

    The pooled tensor is followed along dim 1 by ``nOut // nIn - 1``
    zero-multiplied copies of itself, so the channel count is multiplied by
    ``nOut // nIn``.
    '''

    def __init__(self, nIn, nOut, stride):
        super().__init__()
        self.avg = nn.AvgPool2d(stride)
        # Integer factor by which the channel dimension grows.
        self.expand_ratio = nOut // nIn

    def forward(self, x):
        pooled = self.avg(x)
        n_extra = self.expand_ratio - 1
        # Zero blocks are produced via mul(0) so they match pooled's
        # shape, dtype and device.
        zero_blocks = [pooled.mul(0) for _ in range(n_extra)]
        return torch.cat([pooled, *zero_blocks], 1)


class ResNetCifar(nn.Module):
    '''Small ResNet for CIFAR & SVHN with a survival rate per residual block.

    The network is a 3x3 stem convolution, three stages of ``N`` residual
    blocks (16, 32 and 64 planes, the last two with stride 2), an 8x8 average
    pool and a linear classifier, where ``depth = 6N + 2``.

    Args:
        depth: total depth; must satisfy ``(depth - 2) % 6 == 0``.
        survival_rates: sequence with one survival rate per residual block
            (``3 * N`` entries, ordered layer1, layer2, layer3). ``None``
            gives every block a rate of ``1.``.
        block: residual block class; it is called with
            ``(inplanes, planes, stride, downsample=, survival_rate=, scale=)``
            and must expose an ``expansion`` attribute.
        num_classes: size of the classifier output.
        scale: multiplier passed to every block. The special value ``0``
            selects ``1 / sqrt(3 * N)`` instead.
    '''

    def __init__(self, depth, survival_rates=None, block=BasicBlockWithDeathRate,
                 num_classes=10, scale=1.):
        assert (depth - 2) % 6 == 0, 'depth should be one of 6N+2'
        super(ResNetCifar, self).__init__()
        n = (depth - 2) // 6
        assert survival_rates is None or len(survival_rates) == 3 * n
        if survival_rates is None:
            survival_rates = [1.] * (3 * n)

        self.survival_rates = survival_rates

        # todo Delete, this is for stable resnet
        if scale == 0:
            # math is imported at the top of the file; this avoids relying on
            # a numpy import that only appears further down the module.
            block_scale = 1. / math.sqrt(3 * n)
        else:
            block_scale = scale

        self.register_buffer("block_scale", torch.Tensor([block_scale]))

        print("Scaling = ", self.block_scale)

        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, survival_rates[:n])
        self.layer2 = self._make_layer(block, 32, survival_rates[n:2 * n],
                                       stride=2)
        self.layer3 = self._make_layer(block, 64, survival_rates[2 * n:],
                                       stride=2)
        # NOTE(review): the fixed 8x8 pool presumably assumes 32x32 inputs
        # (two stride-2 stages) -- confirm against the data pipeline.
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Normal init with std sqrt(2 / (kh * kw * out_channels)).
                # A dedicated name keeps the blocks-per-stage count `n` intact.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, survival_rates, stride=1):
        '''Build one stage: a Sequential with one block per survival rate.

        Only the first block receives ``stride`` and, when the stride or the
        channel count changes, a ``DownsampleB`` shortcut. ``self.inplanes``
        is updated to the stage's output channel count.
        '''
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = DownsampleB(self.inplanes, planes * block.expansion,
                                     stride)

        # todo: Delete scale (stable resnet)
        layers = [block(self.inplanes, planes, stride, downsample=downsample,
                        survival_rate=survival_rates[0], scale=self.block_scale.float())]
        self.inplanes = planes * block.expansion
        for survival_rate in survival_rates[1:]:
            layers.append(block(self.inplanes, planes, survival_rate=survival_rate, scale=self.block_scale.float()))

        return nn.Sequential(*layers)

    def forward(self, x):
        '''Run stem, the three stages, pooling and the classifier on ``x``.'''
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        # Flatten everything except the batch dimension for the linear layer.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

    def update_rates(self, survival_rates):
        '''Assign a new survival rate to every residual block.

        Args:
            survival_rates: indexable sequence with exactly one entry per
                block, ordered layer1, layer2, layer3.

        Raises:
            ValueError: if the number of rates differs from the number of
                blocks. The check runs before anything is modified, so a bad
                input can no longer leave the model partially updated or
                silently misaligned.
        '''
        blocks = [blk
                  for layer in (self.layer1, self.layer2, self.layer3)
                  for blk in layer]
        if len(survival_rates) != len(blocks):
            raise ValueError(
                'expected {} survival rates (one per block), got {}'.format(
                    len(blocks), len(survival_rates)))

        self.survival_rates = survival_rates
        for idx, blk in enumerate(blocks):
            blk.survival_rate = survival_rates[idx]


import numpy as np 
def createModel(depth, data, num_classes, survival_mode='none', survival_prop=0.5, scale=1.,
                **kwargs):
    '''Build a ResNetCifar with survival rates chosen by ``survival_mode``.

    Args:
        depth: network depth; must satisfy ``(depth - 2) % 6 == 0``.
        data: dataset name, used only in the log message.
        num_classes: size of the classifier output.
        survival_mode: forwarded to ``get_survival_rates``.
        survival_prop: forwarded to ``get_survival_rates``.
        scale: residual scale forwarded to ``ResNetCifar``.
        **kwargs: accepted and ignored.

    Returns:
        The constructed ``ResNetCifar`` instance.
    '''
    assert (depth - 2) % 6 == 0, 'depth should be one of 6N+2'
    print('Create ResNet-{:d} for {}'.format(depth, data))
    nblocks = (depth - 2) // 2

    survival_rates = get_survival_rates(nblocks, survival_mode, survival_prop)

    # get_survival_rates yields None when the mode selects no schedule;
    # ResNetCifar then gives every block a rate of 1, so the expected number
    # of active blocks is simply nblocks. Summing None would not report that.
    if survival_rates is None:
        training_length = float(nblocks)
    else:
        training_length = np.sum(survival_rates)

    print("Training length = {}".format(training_length))
    print("Initial survival rates:")
    print(survival_rates)

    return ResNetCifar(depth, survival_rates, BasicBlockWithDeathRate,
                       num_classes, scale)

def createModelWithSensitivity(data_loader, loss_fn, depth, num_classes, data, survival_mode='none', survival_prop=0.5, scale=1.,
                               min_survival=0.1, **kwargs):
    '''Build a ResNetCifar whose survival rates come from a gradient probe.

    A model is created with all survival rates set to a differentiable
    tensor of ones. One batch is pushed through it, the mean loss is
    back-propagated, and the absolute gradient with respect to each rate is
    turned into that block's survival rate.

    Args:
        data_loader: iterable yielding ``(inputs, targets)`` batches; only
            the first batch is consumed.
        loss_fn: callable ``loss_fn(pred, targets)``; its result is averaged.
        depth: network depth; must satisfy ``(depth - 2) % 6 == 0``.
        num_classes: size of the classifier output.
        data: dataset name, used only in the log message.
        survival_mode: accepted for signature parity; not used here.
        survival_prop: target mean survival rate.
        scale: residual scale forwarded to ``ResNetCifar``.
        min_survival: floor added to every rate; forced to 0.05 when
            ``survival_prop < 0.15``.
        **kwargs: accepted and ignored.

    Returns:
        The model, on CUDA when available (otherwise CPU), with the
        computed survival rates installed.

    Raises:
        ValueError: if ``data_loader`` yields no batch.
    '''
    assert (depth - 2) % 6 == 0, 'depth should be one of 6N+2'
    print('Create ResNet-{:d} for {}'.format(depth, data))

    if survival_prop < 0.15:
        min_survival = 0.05

    nblocks = (depth - 2) // 2

    A = get_rates_for_sensitivity(nblocks)

    # Use the GPU when there is one, but keep working on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNetCifar(depth, A, BasicBlockWithDeathRate,
                        num_classes, scale).to(device)

    # Only one batch is needed; do not materialise the whole loader.
    first_batch = next(iter(data_loader), None)
    if first_batch is None:
        raise ValueError('data_loader yielded no batches')
    xb, yb = first_batch
    xb = xb.to(device)
    yb = yb.to(device)

    pred = model(xb)
    loss = loss_fn(pred, yb)
    output = torch.mean(loss)

    output.backward()

    grads_abs = np.abs(A.grad.cpu().numpy())
    grads_total = np.sum(grads_abs)
    if grads_total == 0:
        # No gradient signal at all: dividing would give NaN rates. Equal
        # sensitivities make the formula below collapse to survival_prop.
        survival_rates = np.full_like(grads_abs, survival_prop)
    else:
        survival_rates = grads_abs/grads_total*nblocks*(survival_prop-min_survival) + min_survival

    # Alternate rescaling towards a total of survival_prop * nblocks with
    # clipping into [0, 1]; clipping can undo part of the rescale, hence the
    # repeated passes.
    for _ in range(20):
        survival_rates *= (survival_prop*nblocks)/np.sum(survival_rates)
        survival_rates = np.clip(survival_rates, 0, 1)

    # Discard the parameter gradients produced by the probe pass.
    model.zero_grad()

    model.update_rates(survival_rates)

    print("Training length = {}".format(np.sum(survival_rates)))
    print(survival_rates)

    return model

def get_survival_rates(depth, mode='none', survival_prop=0.5):
    '''Return a list of ``depth`` survival rates, or None for no schedule.

    mode == 'uniform': every entry equals ``survival_prop``.
    mode == 'linear':  entry i (0-based) is
                       ``1 - 2 * (i + 1) * (1 - survival_prop) / (depth + 1)``.
    any other mode:    None.
    '''
    if mode == 'uniform':
        return [survival_prop for _ in range(depth)]

    if mode == 'linear':
        drop_prop = 1 - survival_prop
        denominator = float(depth + 1)
        return [1 - 2 * float(position) * drop_prop / denominator
                for position in range(1, depth + 1)]

    return None

def get_rates_for_sensitivity(depth):
    '''Return a 1-D tensor of ``depth`` ones that tracks gradients.'''
    rates = torch.ones(depth)
    # Mark the fresh leaf tensor as differentiable in place.
    return rates.requires_grad_(True)
