import math
import os
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

# BatchNorm momentum shared by every nn.BatchNorm2d built in this module.
bn_mom = 0.0003

class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution followed by BatchNorm.

    Pipeline: depthwise conv (groups=in_channels) -> BN -> 1x1 pointwise conv
    -> BN.

    With ``activate_first=True``, a ReLU (``relu0``) runs on the input before
    the depthwise conv, and no ReLU follows either BN. With
    ``activate_first=False``, ``relu0`` is skipped and a ReLU follows each BN
    instead.

    ``inplace`` only controls ``relu0``. When it is True (and
    ``activate_first`` is True), the caller's input tensor is overwritten with
    its ReLU. ``relu1``/``relu2`` are always in-place; they only touch tensors
    created inside this forward pass.
    """

    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,activate_first=True,inplace=True):
        super().__init__()
        # Registration order is kept stable: it fixes state_dict key order and
        # the order in which parameter initialisers consume the RNG.
        self.relu0 = nn.ReLU(inplace=inplace)
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=in_channels, bias=bias,
        )
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=bn_mom)
        self.relu1 = nn.ReLU(inplace=True)
        self.pointwise = nn.Conv2d(
            in_channels, out_channels, kernel_size=1,
            stride=1, padding=0, dilation=1, groups=1, bias=bias,
        )
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=bn_mom)
        self.relu2 = nn.ReLU(inplace=True)
        self.activate_first = activate_first

    def forward(self,x):
        pre_activation = self.activate_first
        if pre_activation:
            x = self.relu0(x)
        x = self.bn1(self.depthwise(x))
        if not pre_activation:
            x = self.relu1(x)
        x = self.bn2(self.pointwise(x))
        return x if pre_activation else self.relu2(x)

class Block(nn.Module):
    """Xception residual block: three SeparableConv2d layers plus a shortcut.

    The main path is sepconv1 -> sepconv2 -> sepconv3; only sepconv3 is
    strided. The shortcut is the identity when channels and stride are
    unchanged, otherwise a strided 1x1 conv followed by BatchNorm.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        strides: stride of sepconv3 and of the 1x1 shortcut conv.
        atrous: dilation of the three separable convs. ``None`` means 1 for
            all three; an int is used for all three; otherwise it is indexed
            as ``atrous[0]``, ``atrous[1]``, ``atrous[2]``.
        grow_first: if True, sepconv1 already maps to ``out_filters``;
            otherwise sepconv1 keeps ``in_filters`` and sepconv2 expands.
        activate_first: forwarded to every SeparableConv2d.
        inplace: in-place flag for sepconv3's leading ReLU. With
            ``activate_first=True``, pass False when ``hook_layer`` must keep
            its un-activated value after the forward pass.

    Attributes:
        hook_layer: output of sepconv2 from the most recent forward pass
            (None before the first call).
    """

    def __init__(self,in_filters,out_filters,strides=1,atrous=None,grow_first=True,activate_first=True,inplace=True):
        super(Block, self).__init__()
        # Normalise ``atrous`` to one dilation per separable conv.
        if atrous is None:
            atrous = [1] * 3
        elif isinstance(atrous, int):
            atrous = [atrous] * 3

        # head_relu: whether sepconv1's leading ReLU may overwrite ``inp``.
        self.head_relu = True
        if out_filters != in_filters or strides != 1:
            # Projection shortcut when the channel count or resolution changes.
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters, momentum=bn_mom)
            # NOTE(review): presumably disabled because the skip conv has
            # already consumed ``inp`` (and may have saved it for backward),
            # so ``inp`` must not be modified in place afterwards.
            self.head_relu = False
        else:
            self.skip = None

        self.hook_layer = None
        filters = out_filters if grow_first else in_filters
        # padding == dilation keeps the spatial size for a 3x3 kernel (stride 1).
        self.sepconv1 = SeparableConv2d(in_filters, filters, 3, stride=1, padding=atrous[0], dilation=atrous[0], bias=False, activate_first=activate_first, inplace=self.head_relu)
        self.sepconv2 = SeparableConv2d(filters, out_filters, 3, stride=1, padding=atrous[1], dilation=atrous[1], bias=False, activate_first=activate_first)
        self.sepconv3 = SeparableConv2d(out_filters, out_filters, 3, stride=strides, padding=atrous[2], dilation=atrous[2], bias=False, activate_first=activate_first, inplace=inplace)

    def forward(self,inp):
        """Return main path + shortcut; store sepconv2's output in hook_layer.

        NOTE(review): with an identity shortcut and activate_first=True,
        sepconv1's ReLU runs in place on ``inp``. Because ``skip`` is the same
        tensor, the value added back is relu(inp) rather than inp. Changing
        this would likely alter outputs for existing checkpoints, so it is
        left as is.
        """
        # The shortcut must be computed before sepconv1 can touch ``inp``.
        if self.skip is not None:
            skip = self.skipbn(self.skip(inp))
        else:
            skip = inp

        x = self.sepconv1(inp)
        x = self.sepconv2(x)
        self.hook_layer = x
        x = self.sepconv3(x)

        x += skip
        return x


class Xception(nn.Module):
    """Modified (dilated) Xception backbone, following
    https://arxiv.org/pdf/1610.02357.pdf.

    ``forward`` returns ``(low_level_features, features)``:

    * ``low_level_features``: output of block2's second separable conv, with
      ``self.low_level_channels`` channels at 1/4 of the input resolution.
    * ``features``: output of conv5, with ``self.in_channels`` channels at
      1/``downsample_factor`` of the input resolution.

    Each stride-2 stage rounds odd spatial sizes up.
    """

    def __init__(self, downsample_factor,image_C=3,AdaFactor=1):
        """Build the network.

        Args:
            downsample_factor: total output stride; only 8 and 16 are
                supported.
            image_C: number of input image channels.
            AdaFactor: width multiplier applied to every channel count.

        Raises:
            ValueError: if ``downsample_factor`` is not 8 or 16.
        """
        super(Xception, self).__init__()

        # Strides of block2, block3 and block20.
        if downsample_factor == 8:
            stride_list = [2, 1, 1]
        elif downsample_factor == 16:
            stride_list = [2, 2, 1]
        else:
            raise ValueError('xception.py: output stride=%s is not supported.' % (downsample_factor,))

        c32 = int(32 * AdaFactor)
        c64 = int(64 * AdaFactor)
        c128 = int(128 * AdaFactor)
        # Doubled after truncation; not always equal to int(256 * AdaFactor).
        c256 = c128 * 2
        c728 = int(728 * AdaFactor)
        c1024 = int(1024 * AdaFactor)
        c1536 = int(1536 * AdaFactor)
        c2048 = int(2048 * AdaFactor)

        # Stem: two plain convs; self.relu is reused after each BN in forward.
        self.conv1 = nn.Conv2d(image_C, c32, 3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(c32, momentum=bn_mom)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(c32, c64, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(c64, momentum=bn_mom)

        # Entry flow. block2 keeps sepconv3's leading ReLU out of place so its
        # hook_layer (returned as low-level features) is not overwritten.
        self.block1 = Block(c64, c128, 2)
        self.block2 = Block(c128, c256, stride_list[0], inplace=False)
        self.block3 = Block(c256, c728, stride_list[1])

        # Middle flow: block4..block19 are identical dilated 728-channel
        # blocks. state_dict keys derive from these attribute names, so keep
        # them stable for checkpoint loading.
        rate = 16 // downsample_factor
        for idx in range(4, 20):
            setattr(self, 'block%d' % idx, Block(c728, c728, 1, atrous=rate))

        # Exit flow.
        self.block20 = Block(c728, c1024, stride_list[2], atrous=rate, grow_first=False)
        self.conv3 = SeparableConv2d(c1024, c1536, 3, 1, rate, dilation=rate, activate_first=False)
        self.conv4 = SeparableConv2d(c1536, c1536, 3, 1, rate, dilation=rate, activate_first=False)
        self.conv5 = SeparableConv2d(c1536, c2048, 3, 1, rate, dilation=rate, activate_first=False)

        self.layers = []  # reset in forward; not otherwise used in this file
        self.in_channels = c2048
        self.low_level_channels = c256
        self._init_weights()

    def _init_weights(self):
        """Initialise conv weights ~ N(0, 2 / (kH*kW*out_channels)); set BN to identity affine."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, input):
        """Return ``(low_level_features, features)`` for an input batch."""
        self.layers = []
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))

        x = self.block1(x)
        x = self.block2(x)
        low_level_features = self.block2.hook_layer
        for idx in range(3, 21):
            x = getattr(self, 'block%d' % idx)(x)

        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        return low_level_features, x

def load_url(url, model_dir='./model_data', map_location=None):
    """Load a checkpoint from ``url``, preferring a copy cached in ``model_dir``.

    Args:
        url: download URL. The cache file name is the last component of the
            URL path (query string and fragment ignored), the same rule
            torch.hub uses when it saves downloads.
        model_dir: cache directory; created if missing.
        map_location: forwarded to ``torch.load`` (cache hit) and to
            ``model_zoo.load_url`` (download), so both paths place tensors
            identically.

    Returns:
        Whatever the checkpoint deserialises to (xception() uses it as a
        state dict).
    """
    from urllib.parse import urlparse

    # exist_ok avoids a check-then-create race between concurrent processes.
    os.makedirs(model_dir, exist_ok=True)
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(model_dir, filename)
    # isfile, not exists: never hand a directory to torch.load.
    if os.path.isfile(cached_file):
        return torch.load(cached_file, map_location=map_location)
    return model_zoo.load_url(url, model_dir=model_dir, map_location=map_location)

def xception(pretrained=True, downsample_factor=16,image_C=3,AdaFactor=1):
    """Build an Xception backbone, optionally loading released weights.

    Args:
        pretrained: if True, fetch ``xception_pytorch_imagenet.pth`` through
            load_url (cached under ./model_data by default) and load it with
            ``strict=False``.
        downsample_factor: output stride, 8 or 16.
        image_C: number of input image channels.
        AdaFactor: channel width multiplier.

    Returns:
        The constructed Xception module.

    NOTE(review): strict=False only tolerates missing/unexpected keys. PyTorch
    still raises on tensor shape mismatches, so pretrained=True combined with
    image_C != 3 or AdaFactor != 1 is expected to fail — confirm against the
    checkpoint.
    """
    net = Xception(downsample_factor=downsample_factor, image_C=image_C, AdaFactor=AdaFactor)
    if not pretrained:
        return net
    weights_url = 'https://github.com/bubbliiiing/deeplabv3-plus-pytorch/releases/download/v1.0/xception_pytorch_imagenet.pth'
    net.load_state_dict(load_url(weights_url), strict=False)
    return net
