""" 
Creates an Xception Model as defined in:

Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdf

These weights were ported from the Keras implementation. The model achieves the following performance on the validation set:

Loss:0.9173 Prec@1:78.892 Prec@5:94.292

REMEMBER to set your image size to 3x299x299 for both test and validation

normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])

The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
import math
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init
import torch
from .ozan_rep_fun import ozan_rep_function,trevor_rep_function,OzanRepFunction,TrevorRepFunction,gradnorm_rep_function,GradNormRepFunction

__all__ = ['xception_taskonomy_small','xception_taskonomy_small_gradnorm','xception_taskonomy_small_ozan','xception_taskonomy_small_half','xception_taskonomy_small_quad','xception_taskonomy_small_double','xception_taskonomy_small_double_ozan','xception_taskonomy_small_half_ozan','xception_taskonomy_small_quad_ozan']

# model_urls = {
#     'xception_taskonomy':'file:///home/tstand/Dropbox/taskonomy/xception_taskonomy-a4b32ef7.pth.tar'
# }


class SeparableConv2d(nn.Module):
    """Grouped spatial convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the pointwise convolution.
        kernel_size: kernel size of the first (grouped) convolution.
        stride: stride of the first convolution.
        padding: padding of the first convolution.
        dilation: dilation of the first convolution.
        bias: whether both convolutions carry a bias term.
        groupsize: input channels per group in the first convolution; the
            group count is ``max(1, in_channels // groupsize)``, so the
            default of 1 gives one group per channel.
    """
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,groupsize=1):
        super(SeparableConv2d,self).__init__()

        n_groups = max(1, in_channels // groupsize)
        # Spatial filtering keeps the channel count; channels are mixed only
        # by the pointwise convolution below.
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=n_groups,
            bias=bias,
        )
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self,x):
        """Apply the grouped convolution, then the pointwise convolution."""
        return self.pointwise(self.conv1(x))


class Block(nn.Module):
    """Residual block made of ReLU -> separable 3x3 conv -> BatchNorm units.

    Args:
        in_filters: channels of the block input.
        out_filters: channels of the block output.
        reps: number of ReLU/conv/BN units on the main path.
        strides: when not 1, the main path ends with a strided convolution
            and the shortcut is a strided 1x1 convolution.
        start_with_relu: when false, the leading ReLU of the main path is
            dropped (the caller has already applied one).
        grow_first: when true, the change from ``in_filters`` to
            ``out_filters`` happens in the first unit, otherwise in the last.
    """
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        # A projection shortcut is needed whenever the main path changes the
        # tensor shape; otherwise the input is added back unchanged.
        needs_projection = strides != 1 or out_filters != in_filters
        if needs_projection:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)
        layers = []

        def add_unit(c_in, c_out):
            # One ReLU -> separable 3x3 conv -> BatchNorm unit.
            layers.extend([
                self.relu,
                SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
            ])

        width = out_filters if grow_first else in_filters
        if grow_first:
            add_unit(in_filters, out_filters)

        for _ in range(reps - 1):
            add_unit(width, width)

        if not grow_first:
            add_unit(in_filters, out_filters)
            width = out_filters

        if start_with_relu:
            # The first activation must not be in-place: the block input is
            # still needed afterwards for the shortcut branch.
            layers[0] = nn.ReLU(inplace=False)
        else:
            del layers[0]

        if strides != 1:
            # Learned downsampling; kernel 2 / stride 2 regardless of the
            # actual value of `strides`.
            layers.append(nn.Conv2d(width, width, 2, 2))
        self.rep = nn.Sequential(*layers)

    def forward(self,inp):
        """Return main path output plus the (optionally projected) input."""
        x = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        x += shortcut
        return x

class Encoder(nn.Module):
    """Base-width encoder: conv stem, three strided blocks, four blocks at 512 channels.

    Maps a 3-channel image to a 256-channel representation. The stem and the
    three strided blocks each use stride 2.
    """
    def __init__(self):
        super(Encoder, self).__init__()
        # Stem.
        self.conv1 = nn.Conv2d(3, 24, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)

        self.conv2 = nn.Conv2d(24, 48, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)

        # Strided blocks. block1 skips its leading ReLU because forward()
        # applies one right before it.
        self.block1 = Block(48, 96, reps=2, strides=2, start_with_relu=False, grow_first=True)
        self.block2 = Block(96, 192, reps=2, strides=2, start_with_relu=True, grow_first=True)
        self.block3 = Block(192, 512, reps=2, strides=2, start_with_relu=True, grow_first=True)

        # Constant-width blocks. There is no block4-7 or block12 in this
        # reduced network; the attribute names are part of the state-dict keys.
        self.block8 = Block(512, 512, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block9 = Block(512, 512, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block10 = Block(512, 512, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block11 = Block(512, 512, reps=2, strides=1, start_with_relu=True, grow_first=True)

        # Projection to the 256-channel representation.
        self.conv3 = SeparableConv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self,input):
        """Encode an image batch into the 256-channel representation."""
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))

        stages = (
            self.block1, self.block2, self.block3,
            self.block8, self.block9, self.block10, self.block11,
        )
        for stage in stages:
            x = stage(x)

        x = self.bn3(self.conv3(x))
        # Non in-place activation for the tensor that is handed to callers.
        return self.relu2(x)

class EncoderHalf(nn.Module):
    """Narrower encoder variant: block widths 64/128/360 instead of 96/192/512.

    Maps a 3-channel image to a 256-channel representation. The stem and the
    three strided blocks each use stride 2.
    """
    def __init__(self):
        super(EncoderHalf, self).__init__()
        # Stem.
        self.conv1 = nn.Conv2d(3, 24, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)

        self.conv2 = nn.Conv2d(24, 48, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(48)

        # Strided blocks. block1 skips its leading ReLU because forward()
        # applies one right before it.
        self.block1 = Block(48, 64, reps=2, strides=2, start_with_relu=False, grow_first=True)
        self.block2 = Block(64, 128, reps=2, strides=2, start_with_relu=True, grow_first=True)
        self.block3 = Block(128, 360, reps=2, strides=2, start_with_relu=True, grow_first=True)

        # Constant-width blocks. There is no block4-7 or block12 in this
        # reduced network; the attribute names are part of the state-dict keys.
        self.block8 = Block(360, 360, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block9 = Block(360, 360, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block10 = Block(360, 360, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block11 = Block(360, 360, reps=2, strides=1, start_with_relu=True, grow_first=True)

        # Projection to the 256-channel representation.
        self.conv3 = SeparableConv2d(360, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self,input):
        """Encode an image batch into the 256-channel representation."""
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))

        stages = (
            self.block1, self.block2, self.block3,
            self.block8, self.block9, self.block10, self.block11,
        )
        for stage in stages:
            x = stage(x)

        x = self.bn3(self.conv3(x))
        # Non in-place activation for the tensor that is handed to callers.
        return self.relu2(x)


class EncoderQuad(nn.Module):
    """Widest encoder variant: stem 48/96, block widths 192/384/1024.

    Maps a 3-channel image to a 256-channel representation. The stem and the
    three strided blocks each use stride 2.
    """
    def __init__(self):
        super(EncoderQuad, self).__init__()
        print('entering quad constructor')
        # Stem.
        self.conv1 = nn.Conv2d(3, 48, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(48)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)

        self.conv2 = nn.Conv2d(48, 96, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(96)

        # Strided blocks. block1 skips its leading ReLU because forward()
        # applies one right before it.
        self.block1 = Block(96, 192, reps=2, strides=2, start_with_relu=False, grow_first=True)
        self.block2 = Block(192, 384, reps=2, strides=2, start_with_relu=True, grow_first=True)
        self.block3 = Block(384, 1024, reps=2, strides=2, start_with_relu=True, grow_first=True)

        # Constant-width blocks. There is no block4-7 or block12 in this
        # reduced network; the attribute names are part of the state-dict keys.
        self.block8 = Block(1024, 1024, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block9 = Block(1024, 1024, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block10 = Block(1024, 1024, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block11 = Block(1024, 1024, reps=2, strides=1, start_with_relu=True, grow_first=True)

        # Projection to the 256-channel representation.
        self.conv3 = SeparableConv2d(1024, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self,input):
        """Encode an image batch into the 256-channel representation."""
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))

        stages = (
            self.block1, self.block2, self.block3,
            self.block8, self.block9, self.block10, self.block11,
        )
        for stage in stages:
            x = stage(x)

        x = self.bn3(self.conv3(x))
        # Non in-place activation for the tensor that is handed to callers.
        return self.relu2(x)

class EncoderDouble(nn.Module):
    """Wider encoder variant: stem 32/64, block widths 128/256/728.

    Maps a 3-channel image to a 256-channel representation. The stem and the
    three strided blocks each use stride 2.
    """
    def __init__(self):
        super(EncoderDouble, self).__init__()
        print('entering double constructor')
        # Stem.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=False)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Strided blocks. block1 skips its leading ReLU because forward()
        # applies one right before it.
        self.block1 = Block(64, 128, reps=2, strides=2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, reps=2, strides=2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, reps=2, strides=2, start_with_relu=True, grow_first=True)

        # Constant-width blocks. There is no block4-7 or block12 in this
        # reduced network; the attribute names are part of the state-dict keys.
        self.block8 = Block(728, 728, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, reps=2, strides=1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, reps=2, strides=1, start_with_relu=True, grow_first=True)

        # Projection to the 256-channel representation.
        self.conv3 = SeparableConv2d(728, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)

    def forward(self,input):
        """Encode an image batch into the 256-channel representation."""
        x = self.relu(self.bn1(self.conv1(input)))
        x = self.relu(self.bn2(self.conv2(x)))

        stages = (
            self.block1, self.block2, self.block3,
            self.block8, self.block9, self.block10, self.block11,
        )
        for stage in stages:
            x = stage(x)

        x = self.bn3(self.conv3(x))
        # Non in-place activation for the tensor that is handed to callers.
        return self.relu2(x)

def interpolate(inp,size):
    """Bilinearly resize ``inp`` to ``size`` and return it in the input's dtype.

    The resize itself is always computed in float32; the result is then cast
    back to the dtype of ``inp``. (Previously every non-float32 input was
    converted to half on return, so e.g. a float64 input came back as float16.)

    Args:
        inp: tensor to resize; passed to ``nn.functional.interpolate`` in
            ``'bilinear'`` mode.
        size: target spatial size, forwarded to ``nn.functional.interpolate``.

    Returns:
        The resized tensor with the same dtype as ``inp``.
    """
    orig_dtype = inp.dtype
    out = nn.functional.interpolate(inp.float(),size=size,mode='bilinear',align_corners=False)
    if out.dtype != orig_dtype:
        out = out.to(orig_dtype)
    return out


class Decoder(nn.Module):
    """Task head operating on the 256-channel encoder representation.

    With ``num_classes`` set, the head is global average pooling followed by
    a linear classifier. Otherwise it is a dense decoder that upsamples the
    representation 16x through four stride-2 transposed convolutions and
    concatenates a bilinearly resized 16-channel shortcut before the final
    convolution.

    Args:
        output_channels: channels of the dense output map (unused by the
            classification head).
        num_classes: when not ``None``, build the classification head with
            this many outputs instead of the dense decoder.
    """
    def __init__(self, output_channels=32,num_classes=None):
        super(Decoder, self).__init__()

        self.output_channels = output_channels
        self.num_classes = num_classes

        if num_classes is not None:
            self.fc = nn.Linear(256, num_classes)
        else:
            self.relu = nn.ReLU(inplace=True)

            # Shortcut branch (16 channels) and main branch entry (96 channels).
            self.conv_decode_res = SeparableConv2d(256, 16, 3,padding=1)
            self.conv_decode_res2 = SeparableConv2d(256, 96, 3,padding=1)
            self.bn_conv_decode_res = nn.BatchNorm2d(16)
            self.bn_conv_decode_res2 = nn.BatchNorm2d(96)
            # Four (transposed conv x2 upsampling -> separable conv) stages.
            self.upconv1 = nn.ConvTranspose2d(96,96,2,2)
            self.bn_upconv1 = nn.BatchNorm2d(96)
            self.conv_decode1 = SeparableConv2d(96, 64, 3,padding=1)
            self.bn_decode1 = nn.BatchNorm2d(64)
            self.upconv2 = nn.ConvTranspose2d(64,64,2,2)
            self.bn_upconv2 = nn.BatchNorm2d(64)
            self.conv_decode2 = SeparableConv2d(64, 64, 5,padding=2)
            self.bn_decode2 = nn.BatchNorm2d(64)
            self.upconv3 = nn.ConvTranspose2d(64,32,2,2)
            self.bn_upconv3 = nn.BatchNorm2d(32)
            self.conv_decode3 = SeparableConv2d(32, 32, 5,padding=2)
            self.bn_decode3 = nn.BatchNorm2d(32)
            self.upconv4 = nn.ConvTranspose2d(32,32,2,2)
            self.bn_upconv4 = nn.BatchNorm2d(32)
            # 48 = 32 main-branch channels + 16 shortcut channels.
            self.conv_decode4 = SeparableConv2d(48, output_channels, 5,padding=2)


    def forward(self,representation):
        """Decode ``representation`` into class scores or a dense output map."""
        if self.num_classes is not None:
            pooled = F.adaptive_avg_pool2d(representation, (1, 1))
            return self.fc(pooled.view(pooled.size(0), -1))

        # Main branch: no activation between the entry conv/BN and the first
        # upsampling, matching the original layer order.
        x = self.bn_conv_decode_res2(self.conv_decode_res2(representation))
        x = self.relu(self.bn_upconv1(self.upconv1(x)))
        x = self.relu(self.bn_decode1(self.conv_decode1(x)))
        x = self.relu(self.bn_upconv2(self.upconv2(x)))
        x = self.relu(self.bn_decode2(self.conv_decode2(x)))
        x = self.relu(self.bn_upconv3(self.upconv3(x)))
        x = self.relu(self.bn_decode3(self.conv_decode3(x)))
        x = self.bn_upconv4(self.upconv4(x))

        # Shortcut branch, resized to whatever spatial size the main branch
        # produced. This was hard-coded to (256, 256), which is the same thing
        # for a 16x16 representation but made every other input size fail at
        # the concatenation below.
        x2 = self.bn_conv_decode_res(self.conv_decode_res(representation))
        x2 = interpolate(x2,size=tuple(x.shape[-2:]))

        x = torch.cat([x,x2],1)
        x = self.relu(x)
        return self.conv_decode4(x)

class XceptionTaskonomySmall(nn.Module):
    """Shared encoder with one decoder per task.

    Without ``tasks`` the model is a single 1000-way classifier and
    ``forward`` returns its output tensor. With ``tasks`` it returns a dict
    holding the shared representation under ``'rep'`` and one entry per task.
    """

    # Output channels of the dense decoder for every supported task name.
    _TASK_OUTPUT_CHANNELS = {
        'segment_semantic': 18,
        'depth_zbuffer': 1,
        'normal': 3,
        'normal2': 3,
        'edge_occlusion': 1,
        'reshading': 3,
        'keypoints2d': 1,
        'edge_texture': 1,
    }

    def __init__(self, tasks=None,num_classes=None, ozan=False, half=False):
        """ Constructor
        Args:
            tasks: iterable of task names, or ``None`` for a single
                1000-class classification head.
            num_classes: forwarded to every task decoder; when not ``None``
                each decoder becomes a classifier with this many outputs.
            ozan: ``False`` to feed the representation to the decoders
                directly, ``'gradnorm'`` to route it through
                ``gradnorm_rep_function``, any other truthy value to route it
                through ``ozan_rep_function``.
            half: encoder selector: ``'Quad'`` -> ``EncoderQuad``,
                ``'Double'`` -> ``EncoderDouble``, any other truthy value ->
                ``EncoderHalf``, falsy -> ``Encoder``.
        Raises:
            ValueError: if a task name is unknown and a dense decoder (which
                needs the task's channel count) has to be built for it.
        """
        super(XceptionTaskonomySmall, self).__init__()
        print('half is',half)
        if half=='Quad':
            print('running quad code')
            self.encoder=EncoderQuad()
        elif half == 'Double':
            self.encoder=EncoderDouble()
        elif half:
            self.encoder=EncoderHalf()
        else:
            self.encoder=Encoder()
        self.tasks=tasks
        self.ozan=ozan
        self.task_to_decoder = {}
        if tasks is not None:
            for task in tasks:
                output_channels = self._TASK_OUTPUT_CHANNELS.get(task)
                if output_channels is None:
                    # Previously an unknown task either crashed with an
                    # UnboundLocalError or silently reused the channel count
                    # of the preceding task.
                    if num_classes is None:
                        raise ValueError(
                            'unknown task %r; expected one of %s'
                            % (task, sorted(self._TASK_OUTPUT_CHANNELS))
                        )
                    # The classification head ignores the channel count.
                    output_channels = 0
                self.task_to_decoder[task]=Decoder(output_channels,num_classes)
        else:
            self.task_to_decoder['classification']=Decoder(output_channels=0,num_classes=1000)

        # Registered as a ModuleList so the decoders' parameters belong to the
        # model; order matches the insertion order of task_to_decoder.
        self.decoders = nn.ModuleList(self.task_to_decoder.values())

        #------- init weights --------
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        #-----------------------------

    # Call counter for input_per_task_losses; stored on the class, so it is
    # shared by every instance.
    count=0

    def input_per_task_losses(self,losses):
        """Hand the current per-task losses to ``GradNormRepFunction``.

        While the shared call counter is below 200, the initial losses are
        overwritten with ``losses`` and the weights are reset to ones. From
        then on, every 20th call appends the current weights to
        ``gradnorm_weights.txt`` in the working directory.

        Args:
            losses: per-task loss values; one weight is kept per entry.
        """
        XceptionTaskonomySmall.count+=1
        if XceptionTaskonomySmall.count < 200:
            # Attribute name (including its spelling) is defined by
            # GradNormRepFunction outside this file and must not be changed here.
            GradNormRepFunction.inital_task_losses=losses
            GradNormRepFunction.current_weights=[1 for i in losses]
        elif XceptionTaskonomySmall.count%20==0:
            with open("gradnorm_weights.txt", "a") as myfile:
                myfile.write(str(XceptionTaskonomySmall.count)+': '+str(GradNormRepFunction.current_weights)+'\n')
        GradNormRepFunction.current_task_losses=losses


    def forward(self, input):
        """Encode ``input`` and run every decoder on the representation.

        Returns:
            The classifier output when the model was built without tasks,
            otherwise a dict with the representation under ``'rep'`` and each
            task's decoder output under the task name.
        """
        rep = self.encoder(input)

        if self.tasks is None:
            return self.decoders[0](rep)

        outputs={'rep':rep}

        # The rep functions return something indexed per decoder; without
        # them every decoder reads the same shared representation.
        if self.ozan=='gradnorm':
            GradNormRepFunction.n=len(self.decoders)
            per_decoder = gradnorm_rep_function(rep)
        elif self.ozan:
            OzanRepFunction.n=len(self.decoders)
            per_decoder = ozan_rep_function(rep)
        else:
            per_decoder = None

        for i,(task,decoder) in enumerate(zip(self.task_to_decoder.keys(),self.decoders)):
            outputs[task]=decoder(rep if per_decoder is None else per_decoder[i])

        return outputs



def xception_taskonomy_small(pretrained=False,**kwargs):
    """Build the base-width ``XceptionTaskonomySmall`` model.

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded unchanged to ``XceptionTaskonomySmall``.

    Returns:
        The constructed model.
    """
    model = XceptionTaskonomySmall(**kwargs)

    if pretrained:
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model


def xception_taskonomy_small_ozan(pretrained=False,**kwargs):
    """Build the base-width model with ``ozan=True``.

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``ozan``, which is set here.

    Returns:
        The constructed model.
    """
    model = XceptionTaskonomySmall(ozan=True,**kwargs)

    if pretrained:
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model


def xception_taskonomy_small_gradnorm(pretrained=False,**kwargs):
    """Build the base-width model with ``ozan='gradnorm'``.

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``ozan``, which is set here.

    Returns:
        The constructed model.
    """
    model = XceptionTaskonomySmall(ozan='gradnorm',**kwargs)

    if pretrained:
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model


def xception_taskonomy_small_half(pretrained=False,**kwargs):
    """Build the model with the ``EncoderHalf`` encoder (``half=True``).

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``half``, which is set here.

    Returns:
        The constructed model.
    """
    model = XceptionTaskonomySmall(half=True,**kwargs)

    if pretrained:
        # NOTE(review): this is the same checkpoint file the base-width
        # factory loads, but the encoder built here has different layer
        # widths -- confirm the file actually matches this encoder.
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model

def xception_taskonomy_small_quad(pretrained=False,**kwargs):
    """Build the model with the ``EncoderQuad`` encoder (``half='Quad'``).

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``half``, which is set here.

    Returns:
        The constructed model.
    """
    print('got quad')
    model = XceptionTaskonomySmall(half='Quad',**kwargs)

    if pretrained:
        # NOTE(review): this is the same checkpoint file the base-width
        # factory loads, but the encoder built here has different layer
        # widths -- confirm the file actually matches this encoder.
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model

def xception_taskonomy_small_double(pretrained=False,**kwargs):
    """Build the model with the ``EncoderDouble`` encoder (``half='Double'``).

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``half``, which is set here.

    Returns:
        The constructed model.
    """
    print('got double')
    model = XceptionTaskonomySmall(half='Double',**kwargs)

    if pretrained:
        # NOTE(review): this is the same checkpoint file the base-width
        # factory loads, but the encoder built here has different layer
        # widths -- confirm the file actually matches this encoder.
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model

def xception_taskonomy_small_quad_ozan(pretrained=False,**kwargs):
    """Build the model with the ``EncoderQuad`` encoder and ``ozan=True``.

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``ozan`` or ``half``, which are set here.

    Returns:
        The constructed model.
    """
    print('got quad ozan')
    model = XceptionTaskonomySmall(ozan=True,half='Quad',**kwargs)

    if pretrained:
        # NOTE(review): this is the same checkpoint file the base-width
        # factory loads, but the encoder built here has different layer
        # widths -- confirm the file actually matches this encoder.
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model


def xception_taskonomy_small_double_ozan(pretrained=False,**kwargs):
    """Build the model with the ``EncoderDouble`` encoder and ``ozan=True``.

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``ozan`` or ``half``, which are set here.

    Returns:
        The constructed model.
    """
    print('got double')
    model = XceptionTaskonomySmall(ozan=True,half='Double',**kwargs)

    if pretrained:
        # NOTE(review): this is the same checkpoint file the base-width
        # factory loads, but the encoder built here has different layer
        # widths -- confirm the file actually matches this encoder.
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model


def xception_taskonomy_small_half_ozan(pretrained=False,**kwargs):
    """Build the model with the ``EncoderHalf`` encoder and ``ozan=True``.

    Args:
        pretrained: when true, load encoder weights from
            ``xception_taskonomy_small2.encoder.pth.tar`` (resolved relative
            to the current working directory).
        **kwargs: forwarded to ``XceptionTaskonomySmall``; must not contain
            ``ozan`` or ``half``, which are set here.

    Returns:
        The constructed model.
    """
    model = XceptionTaskonomySmall(ozan=True,half=True,**kwargs)

    if pretrained:
        # NOTE(review): this is the same checkpoint file the base-width
        # factory loads, but the encoder built here has different layer
        # widths -- confirm the file actually matches this encoder.
        # map_location='cpu' lets a checkpoint saved from a GPU be read on a
        # CPU-only host; load_state_dict copies the values into the encoder's
        # existing parameters, so the resulting weights are the same.
        state_dict = torch.load('xception_taskonomy_small2.encoder.pth.tar', map_location='cpu')
        model.encoder.load_state_dict(state_dict)

    return model
