import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial

# Only names that are actually defined in this module may be listed here:
# ``from <module> import *`` raises AttributeError for any ``__all__`` entry
# that does not exist.
__all__ = [
    'ResNet', 'ResNet_SF', 'BasicBlock', 'Bottleneck', 'BasicBlock_SF',
    'resnet18', 'get_fine_tuning_parameters'
]


def conv3x3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3D convolution with a 3x3x3 kernel.

    A padding of 1 on every side offsets the kernel extent, so with
    ``stride=1`` the output has the same temporal/spatial size as the input.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: int or 3-tuple forwarded to ``nn.Conv3d``.

    Returns:
        The configured ``nn.Conv3d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)

def conv1x1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1x1) 3D convolution.

    No padding is applied; a 1x1x1 kernel does not need any to keep the
    input size when ``stride=1``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: int or 3-tuple forwarded to ``nn.Conv3d``.

    Returns:
        The configured ``nn.Conv3d`` module.
    """
    conv_options = dict(kernel_size=1, stride=stride, padding=0, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)



def downsample_basic_block(x, planes, stride):
    """Parameter-free shortcut: subsample ``x`` and zero-pad its channels.

    Used by ``_make_layer`` when ``shortcut_type == 'A'``.

    Args:
        x: 5-D tensor indexed as (batch, channel, dim2, dim3, dim4).
        planes: number of channels the result must have; must be at least
            the channel count of ``x``.
        stride: stride of the kernel-size-1 average pooling applied to the
            three trailing dimensions (i.e. plain subsampling).

    Returns:
        A tensor with ``planes`` channels whose leading channels are the
        subsampled input and whose remaining channels are zero.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)

    # Allocate the padding from ``out`` so it always shares dtype and device
    # with the data (the old float32/CUDA-only check broke half/double
    # inputs and non-default devices).
    zero_pads = out.new_zeros(
        (out.size(0), planes - out.size(1), out.size(2), out.size(3),
         out.size(4)))

    # Concatenate the tensor itself rather than ``out.data`` so gradients
    # keep flowing through the shortcut branch.
    return torch.cat([out, zero_pads], dim=1)


class BasicBlock(nn.Module):
    """Residual block made of two 3x3x3 convolutions.

    The first convolution applies ``stride``; each convolution is followed by
    batch normalization. The block input (optionally transformed by
    ``downsample``) is added before the final ReLU.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: channels of the block input.
            planes: channels produced by both convolutions.
            stride: stride of the first convolution.
            downsample: optional callable applied to the input to produce the
                shortcut; ``None`` means the input is used unchanged.
        """
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # The shortcut is evaluated after the main branch, as before.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)


class Bottleneck(nn.Module):
    """Three-convolution residual block: 1x1x1 -> 3x3x3 -> 1x1x1.

    The middle 3x3x3 convolution applies ``stride``; the last convolution
    widens the result to ``planes * 4`` channels. The block input (optionally
    transformed by ``downsample``) is added before the final ReLU.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: channels of the block input.
            planes: channels of the two inner convolutions.
            stride: stride of the 3x3x3 convolution.
            downsample: optional callable applied to the input to produce the
                shortcut; ``None`` means the input is used unchanged.
        """
        super(Bottleneck, self).__init__()
        widened = planes * 4
        self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(
            planes, planes,
            kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)
        self.conv3 = nn.Conv3d(planes, widened, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(widened)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv/BN stages and add the shortcut before ReLU."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        # The shortcut is evaluated after the main branch, as before.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)


class ResNet(nn.Module):
    """3D ResNet trunk: stem convolution, max-pool and four residual stages.

    No pooling/classification head is created by ``__init__`` (no ``avgpool``
    or ``fc`` sub-module), so the model's parameters are exactly those of the
    stem and the four stages. ``forward`` therefore returns the ``layer4``
    feature map unless a caller has attached both ``avgpool`` and ``fc``.

    Args:
        block: residual block class exposing an ``expansion`` attribute and
            the constructor ``block(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks in each stage.
        sample_size: accepted for interface compatibility; not used here.
        sample_duration: accepted for interface compatibility; not used here.
        shortcut_type: ``'A'`` selects the parameter-free zero-padding
            shortcut, any other value a 1x1x1 convolution + batch norm.
        num_classes: accepted for interface compatibility; not used here.
    """

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=400):
        super(ResNet, self).__init__()
        # Channel count feeding the next stage; advanced by _make_layer.
        self.inplanes = 64
        self.conv1 = nn.Conv3d(
            3,
            64,
            kernel_size=7,
            stride=(1, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, 128, layers[1], shortcut_type, stride=2)
        # Stages 3 and 4 keep the resolution produced by stage 2 (stride 1).
        self.layer3 = self._make_layer(
            block, 256, layers[2], shortcut_type, stride=1)
        self.layer4 = self._make_layer(
            block, 512, layers[3], shortcut_type, stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # In-place initializer; the un-suffixed name is deprecated.
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, shortcut_type, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the shortcut transform;
        ``self.inplanes`` is updated to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the trunk and, if a head has been attached, the head.

        Returns the ``layer4`` feature map when the model has no ``avgpool``
        or no ``fc`` attribute (the state ``__init__`` leaves it in).
        Otherwise the features are pooled, flattened per sample and passed
        through ``fc``.
        """
        x = self.forward_cnn(x)

        # __init__ does not create a head; previously this method failed
        # with AttributeError on self.avgpool for every input.
        avgpool = getattr(self, 'avgpool', None)
        fc = getattr(self, 'fc', None)
        if avgpool is None or fc is None:
            return x

        x = avgpool(x)
        x = x.view(x.size(0), -1)
        return fc(x)

    def forward_cnn(self, x):
        """Return the feature map after ``layer4`` (no pooling, no flatten)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x


def get_fine_tuning_parameters(model, ft_begin_index):
    """Select which parameters of ``model`` are fine-tuned.

    Args:
        model: object providing ``parameters()`` and ``named_parameters()``.
        ft_begin_index: ``0`` fine-tunes everything; otherwise parameters
            whose name contains ``'layer<i>'`` for ``ft_begin_index <= i < 5``
            or ``'fc'`` are fine-tuned.

    Returns:
        ``model.parameters()`` when ``ft_begin_index`` is 0. Otherwise a list
        with one dict per parameter: ``{'params': p}`` for fine-tuned
        parameters and ``{'params': p, 'lr': 0.0}`` for all others.
    """
    if ft_begin_index == 0:
        return model.parameters()

    tuned_tags = ['layer{}'.format(stage) for stage in range(ft_begin_index, 5)]
    tuned_tags.append('fc')

    groups = []
    for name, param in model.named_parameters():
        if any(tag in name for tag in tuned_tags):
            groups.append({'params': param})
        else:
            # Everything outside the fine-tuned modules gets a zero rate.
            groups.append({'params': param, 'lr': 0.0})

    return groups


def resnet18(**kwargs):
    """Construct a ResNet-18: ``BasicBlock`` with two blocks per stage.

    All keyword arguments are forwarded unchanged to ``ResNet``.
    """
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, **kwargs)


def downsample_basic_block_sf(x, planes, stride):
    """Parameter-free shortcut that subsamples only the last two dimensions.

    Same idea as ``downsample_basic_block`` but the pooling stride is
    ``(1, stride, stride)``, so the first non-channel dimension keeps its
    length. Used by ``ResNet_SF._make_layer`` when ``shortcut_type == 'A'``.

    Args:
        x: 5-D tensor indexed as (batch, channel, dim2, dim3, dim4).
        planes: number of channels the result must have; must be at least
            the channel count of ``x``.
        stride: subsampling stride applied to the last two dimensions.

    Returns:
        A tensor with ``planes`` channels whose leading channels are the
        subsampled input and whose remaining channels are zero.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=(1, stride, stride))

    # Allocate the padding from ``out`` so it always shares dtype and device
    # with the data (the old float32/CUDA-only check broke half/double
    # inputs and non-default devices).
    zero_pads = out.new_zeros(
        (out.size(0), planes - out.size(1), out.size(2), out.size(3),
         out.size(4)))

    # Concatenate the tensor itself rather than ``out.data`` so gradients
    # keep flowing through the shortcut branch.
    return torch.cat([out, zero_pads], dim=1)


class BasicBlock_SF(nn.Module):
    """Two-convolution residual block mixing a 3x3x3 and a 1x1x1 convolution.

    ``ksize`` decides the order: with ``ksize == 3`` the first convolution is
    3x3x3 and the second 1x1x1; for any other value the order is reversed.
    The first convolution uses stride ``(1, stride, stride)``, i.e. it never
    strides along the first non-channel dimension.
    """

    # Output channels = planes * expansion.
    expansion = 1

    def __init__(self, ksize, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            ksize: ``3`` puts the 3x3x3 convolution first, anything else
                puts the 1x1x1 convolution first.
            inplanes: channels of the block input.
            planes: channels produced by both convolutions.
            stride: stride applied to the last two dimensions by the first
                convolution.
            downsample: optional callable applied to the input to produce the
                shortcut; ``None`` means the input is used unchanged.
        """
        super(BasicBlock_SF, self).__init__()

        # Pick the two factories once; modules are still created in the
        # original order (conv1, bn1, relu, conv2, bn2).
        if ksize == 3:
            make_first, make_second = conv3x3x3, conv1x1x1
        else:
            make_first, make_second = conv1x1x1, conv3x3x3

        self.conv1 = make_first(inplanes, planes, (1, stride, stride))
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = make_second(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # The shortcut is evaluated after the main branch, as before.
        shortcut = x if self.downsample is None else self.downsample(x)

        branch += shortcut
        return self.relu(branch)




class ResNet_SF(nn.Module):
    """3D ResNet trunk with two width/kernel configurations.

    ``sf == 'slow'`` builds a 64-channel stem with a (1, 7, 7) kernel and
    stages of 64/128/256/512 planes using block ``ksize`` 1, 1, 3, 3.
    Any other ``sf`` value builds a stem with ``64 // 8`` channels and a
    (5, 7, 7) kernel, and stages of 8/16/32/64 planes all using ``ksize`` 3.
    (NOTE(review): the names suggest the slow/fast pathways of a SlowFast
    network - confirm against the caller.)

    Stem, max-pool and stages only ever stride over the last two dimensions.
    No pooling/classification head is created; ``forward`` returns the
    ``layer4`` feature map.

    Args:
        block: residual block class exposing ``expansion`` and the
            constructor ``block(ksize, inplanes, planes, stride, downsample)``.
        layers: sequence of four ints, the number of blocks in each stage.
        sample_size: accepted for interface compatibility; not used here.
        sample_duration: accepted for interface compatibility; not used here.
        shortcut_type: ``'A'`` selects the parameter-free zero-padding
            shortcut, any other value a 1x1x1 convolution + batch norm.
        num_classes: accepted for interface compatibility; not used here.
        sf: ``'slow'`` or anything else (see above).
    """

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 num_classes=400,
                 sf='slow'):
        super(ResNet_SF, self).__init__()

        # Channel reduction factor of the non-'slow' configuration.
        beta_inv = 8

        if sf == 'slow':
            base_planes = 64
            stem_depth = 1
            ksizes = (1, 1, 3, 3)
        else:
            base_planes = 64 // beta_inv
            stem_depth = 5
            ksizes = (3, 3, 3, 3)

        # Channel count feeding the next stage; advanced by _make_layer.
        self.inplanes = base_planes

        self.conv1 = nn.Conv3d(
            3,
            base_planes,
            kernel_size=(stem_depth, 7, 7),
            stride=(1, 2, 2),
            # Half the (odd) kernel depth keeps the first dimension's length.
            padding=(stem_depth // 2, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(base_planes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(
            kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))

        self.layer1 = self._make_layer(
            block, ksizes[0], base_planes, layers[0], shortcut_type)
        self.layer2 = self._make_layer(
            block, ksizes[1], base_planes * 2, layers[1], shortcut_type,
            stride=2)
        self.layer3 = self._make_layer(
            block, ksizes[2], base_planes * 4, layers[2], shortcut_type,
            stride=2)
        self.layer4 = self._make_layer(
            block, ksizes[3], base_planes * 8, layers[3], shortcut_type,
            stride=1)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                # In-place initializer; the un-suffixed name is deprecated.
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, ksize, planes, blocks, shortcut_type, stride=1, addc=0):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block receives ``stride`` and the shortcut transform.
        ``addc`` is a number of extra input channels added to
        ``self.inplanes`` for that first block (0 at every call site in
        ``__init__``). ``self.inplanes`` is updated to the stage's output
        channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes + addc != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block_sf,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes + addc,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=(1, stride, stride),
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(block(ksize, self.inplanes + addc, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(ksize, self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward_stem(self, x):
        """Apply conv1, batch norm, ReLU and max-pool."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x

    def forward_stage1(self, x):
        """Apply ``layer1``."""
        x = self.layer1(x)
        return x

    def forward_stage2(self, x):
        """Apply ``layer2``."""
        x = self.layer2(x)
        return x

    def forward_stage3(self, x):
        """Apply ``layer3``."""
        x = self.layer3(x)
        return x

    def forward_stage4(self, x):
        """Apply ``layer4``."""
        x = self.layer4(x)
        return x

    def forward(self, x):
        """Run the stem followed by the four stages; return the feature map."""
        x = self.forward_stem(x)
        x = self.forward_stage1(x)
        x = self.forward_stage2(x)
        x = self.forward_stage3(x)
        x = self.forward_stage4(x)
        return x



