import torch.nn as nn
import math
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
import torch.utils.model_zoo as model_zoo
from torch.nn import init
from .gOctConv import gOctaveConv, gOctaveCBR
import os
# NOTE(review): not referenced anywhere in the visible part of this file —
# presumably read by other modules or kept for compatibility; confirm before
# removing.
affine_par = True



# NOTE(review): empty in this file; presumably a placeholder for
# pretrained-weight URLs (cf. the `model_zoo` import above) — confirm.
model_urls = {

}


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).

    Returns:
        The freshly constructed ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer


class PallMSBlock(nn.Module):
    """Run one independent ``MSBlock`` per input branch.

    Branch ``i`` gets ``round(in_channels * alpha[i])`` input channels and
    ``round(out_channels * alpha[i])`` output channels.

    Args:
        in_channels: total input channel count, split across branches by ``alpha``.
        out_channels: total output channel count, split across branches by ``alpha``.
        alpha: per-branch channel fractions; its length sets the number of branches.
        bias: accepted but not used by this class.
    """

    def __init__(self,in_channels, out_channels, alpha=[0.5,0.5], bias=False):
        super().__init__()
        self.std_conv = False
        # Branches are created in `alpha` order so that convs[i] pairs with
        # the i-th input of forward().
        branches = [
            MSBlock(int(round(in_channels * fraction)),
                    int(round(out_channels * fraction)))
            for fraction in alpha
        ]
        self.convs = nn.ModuleList(branches)
        self.outbranch = len(alpha)

    def forward(self, xset):
        """Apply the i-th block to the i-th input and return the results as a list.

        A bare tensor is treated as a one-element list of inputs.
        """
        branch_inputs = [xset] if isinstance(xset, torch.Tensor) else xset
        # Indexed access (rather than zip) so that supplying fewer inputs
        # than branches still raises IndexError.
        return [
            self.convs[idx](branch_inputs[idx])
            for idx in range(self.outbranch)
        ]


class MSBlock(nn.Module):
    """Multi-scale block built from parallel dilated 3x3 convolutions.

    Every dilation rate gets its own bias-free 3x3 convolution over the same
    input. The branch outputs are concatenated along the channel axis and
    passed through GroupNorm (32 groups) and a per-channel PReLU.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: total channel count after concatenation. Must be
            divisible by 32 because of the 32-group ``nn.GroupNorm``.
        dilations: dilation rate of each branch; must be non-empty.

    Raises:
        ValueError: if ``dilations`` is empty.
    """

    def __init__(self, in_channels, out_channels, dilations=(1, 2, 4, 8, 16)):
        super(MSBlock, self).__init__()
        # Private copy: the instance never aliases a caller-owned (or default)
        # sequence.
        self.dilations = list(dilations)
        if not self.dilations:
            raise ValueError("dilations must contain at least one rate")

        num_branches = len(self.dilations)
        # Split by the actual number of branches. A fixed divisor of 5 makes
        # the last branch's channel count negative for longer dilation lists.
        each_out_channels = out_channels // num_branches
        # The last branch absorbs the remainder so the total is out_channels.
        last_out_channels = out_channels - each_out_channels * (num_branches - 1)

        self.msconv = nn.ModuleList()
        for index, rate in enumerate(self.dilations):
            is_last = index == num_branches - 1
            branch_channels = last_out_channels if is_last else each_out_channels
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            self.msconv.append(
                nn.Conv2d(in_channels, branch_channels, 3,
                          padding=rate, dilation=rate, bias=False)
            )
        # Attribute is called `bn` although it is a GroupNorm; the name is
        # kept so parameter keys of existing checkpoints do not change.
        self.bn = nn.GroupNorm(32, out_channels)
        self.prelu = nn.PReLU(out_channels)

    def forward(self, x):
        """Return PReLU(GroupNorm(concat of all dilated branches applied to ``x``))."""
        out = torch.cat([conv(x) for conv in self.msconv], dim=1)
        return self.prelu(self.bn(out))


class CSFNet(nn.Module):
    """Fuse multi-branch backbone features and predict a sigmoid map.

    Pipeline: a 1x1 ``gOctaveCBR`` over the input branches, a ``PallMSBlock``
    over its outputs, a second 1x1 ``gOctaveCBR`` configured with
    ``alpha_out=[1,]``, then a 1x1 convolution and a sigmoid on element 0 of
    that result.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        ratio: multiplier applied to the total input channel count
            (128+256+512+1024). Set it according to the backbone in use;
            1.0 reproduces the previously hard-coded configuration.
    """

    def __init__(self, num_classes=1, ratio=1.0):
        super(CSFNet, self).__init__()
        # Input channels scale with the backbone width; the branch split
        # stays 1:2:4:8 (128:256:512:1024).
        fuse_in_channel = int((128+256+512+1024)*ratio)
        fuse_in_split = [1/15,2/15,4/15,8/15]
        # Output branches are 128:256:512:512, i.e. a 1:2:4:4 split.
        fuse_out_channel = 128+256+512+512
        fuse_out_split = [1/11,2/11,4/11,4/11]

        self.fuse = gOctaveCBR(fuse_in_channel, fuse_out_channel, kernel_size=(1,1), padding=0,
                                alpha_in = fuse_in_split, alpha_out = fuse_out_split, stride = 1)
        self.ms = PallMSBlock(fuse_out_channel, fuse_out_channel, alpha = fuse_out_split)
        self.fuse1x1 = gOctaveCBR(fuse_out_channel, fuse_out_channel, kernel_size=(1, 1), padding=0,
                                alpha_in = fuse_out_split, alpha_out = [1,], stride = 1)
        self.cls_layer = nn.Conv2d(fuse_out_channel, num_classes, kernel_size=1)

    def forward(self, x):
        """Run the fusion head on a mapping of feature tensors.

        Args:
            x: mapping whose values are the backbone features; they are
                consumed in the mapping's iteration order.
                NOTE(review): presumably four tensors matching
                ``fuse_in_split`` — confirm against the backbone.

        Returns:
            The sigmoid of the classifier output.
        """
        features = list(x.values())
        fuse = self.fuse(features)
        fuse = self.ms(fuse)
        fuse = self.fuse1x1(fuse)
        output = self.cls_layer(fuse[0])
        # torch.sigmoid computes the same values as F.sigmoid, which emits a
        # deprecation warning on older PyTorch releases.
        return torch.sigmoid(output)

def build_model():
    """Create and return a ``CSFNet`` built with its default constructor arguments."""
    model = CSFNet()
    return model

def weights_init(m):
    """Initialise an ``nn.Conv2d`` in place; any other module is left untouched.

    Weights are redrawn from a normal distribution with mean 0 and standard
    deviation 0.01, and the bias, when the layer has one, is set to zero.

    Args:
        m: the module to initialise.
    """
    if not isinstance(m, nn.Conv2d):
        return
    m.weight.data.normal_(0, 0.01)
    bias = m.bias
    if bias is None:
        return
    bias.data.zero_()