# coding=utf-8
import torch.nn as nn
from torchvision import models
from DataAug.Mixup.EFDMix import EFDMix

# Maps the ``args.net`` name to a torchvision VGG constructor; the "...bn"
# keys select the batch-normalized variants.
vgg_dict = dict(
    vgg11=models.vgg11,
    vgg13=models.vgg13,
    vgg16=models.vgg16,
    vgg19=models.vgg19,
    vgg11bn=models.vgg11_bn,
    vgg13bn=models.vgg13_bn,
    vgg16bn=models.vgg16_bn,
    vgg19bn=models.vgg19_bn,
)


class VGGBase(nn.Module):
    """VGG convolutional trunk (classifier head dropped) used as a feature extractor.

    Args:
        args: namespace providing ``net`` (a key of ``vgg_dict``) and
            ``no_pretrained``; the model is built with
            ``pretrained=not args.no_pretrained``.
    """

    def __init__(self, args):
        super().__init__()
        use_pretrained = not args.no_pretrained
        self.features = vgg_dict[args.net](pretrained=use_pretrained).features
        # Hard-coded for 3x32x32 inputs, where the trunk output flattens to 512
        # values per sample; other input sizes presumably flatten to a
        # different width and would not match this value.
        self.in_features = 512

    def forward(self, x):
        """Run the conv trunk and flatten each sample to a 1-D feature vector."""
        feature_map = self.features(x)
        return feature_map.view(feature_map.size(0), -1)


# Maps the ``args.net`` name to a torchvision ResNet / ResNeXt constructor.
res_dict = dict(
    resnet18=models.resnet18,
    resnet34=models.resnet34,
    resnet50=models.resnet50,
    resnet101=models.resnet101,
    resnet152=models.resnet152,
    resnext50=models.resnext50_32x4d,
    resnext101=models.resnext101_32x8d,
)


class ResBase(nn.Module):
    """ResNet/ResNeXt backbone with the final ``fc`` layer removed.

    Args:
        args: namespace providing ``net`` (a key of ``res_dict``) and
            ``no_pretrained``; the model is built with
            ``pretrained=not args.no_pretrained``.
    """

    # Sub-modules copied from the torchvision model, in registration and
    # forward order (keeps state_dict keys and parameter order stable).
    _STAGES = ("conv1", "bn1", "relu", "maxpool",
               "layer1", "layer2", "layer3", "layer4", "avgpool")

    def __init__(self, args):
        super().__init__()
        backbone = res_dict[args.net](pretrained=not args.no_pretrained)
        for stage_name in self._STAGES:
            setattr(self, stage_name, getattr(backbone, stage_name))
        # Width of the pooled feature vector, i.e. the input size of the dropped fc.
        self.in_features = backbone.fc.in_features

    def forward(self, x):
        """Apply every backbone stage in order and flatten to (batch, features)."""
        for stage_name in self._STAGES:
            x = getattr(self, stage_name)(x)
        return x.view(x.size(0), -1)

class ResMix(nn.Module):
    """ResNet backbone (no ``fc``) that can apply EFDMix between stages.

    When ``self.mix`` is False this behaves like a plain ResNet trunk. When it
    is True, the input batch is split after ``layer1``: rows
    ``[:args.batch_size]`` are the content half and the remaining rows are the
    style half. Both halves run through layer2..layer4. After each stage whose
    index (1..4) is in ``args.mix_layers``, the content half is replaced by
    ``self.efdmix(content, style)``. Only the content rows are returned in
    that mode.

    NOTE(review): the split point is ``args.batch_size``, not half of the
    incoming batch; if the content part of a batch is shorter than
    ``args.batch_size`` (e.g. a short final batch), the halves will be
    misaligned — confirm against the caller.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        backbone = res_dict[args.net](pretrained=not args.no_pretrained)
        # Copy the torchvision sub-modules in their original registration order.
        for part in ("conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4", "avgpool"):
            setattr(self, part, getattr(backbone, part))
        self.in_features = backbone.fc.in_features

        # Mixing operator (constructed with p=0.5, alpha=0.1); off until activate_mix().
        self.efdmix = EFDMix(args, p=0.5, alpha=0.1)
        self.mix = False

    def forward(self, x):
        '''
        when self.mix is True
        x[:batch_size]: current domain images as content images. torch tensor [batch, C, H ,W]
        x[batch_size:]: previous domain images as style images. torch tensor [batch, C, H ,W]
        '''
        # Stem plus layer1 always see the full (possibly concatenated) batch.
        for module in (self.conv1, self.bn1, self.relu, self.maxpool, self.layer1):
            x = module(x)

        later_stages = (self.layer2, self.layer3, self.layer4)
        if not self.mix:
            for stage in later_stages:
                x = stage(x)
        else:
            split_at = self.args.batch_size
            content, style = x[:split_at], x[split_at:]
            if 1 in self.args.mix_layers:
                content = self.efdmix(content, style)
            # Stage indices 2, 3, 4 correspond to layer2, layer3, layer4.
            for stage_idx, stage in enumerate(later_stages, start=2):
                content = stage(content)
                style = stage(style)
                if stage_idx in self.args.mix_layers:
                    content = self.efdmix(content, style)
            x = content

        pooled = self.avgpool(x)
        return pooled.view(pooled.size(0), -1)

    def activate_mix(self, mix=True):
        """Enable (default) or disable feature mixing in ``forward``."""
        self.mix = mix

class DTNBase(nn.Module):
    """Three-stage strided CNN feature extractor (conv -> BN -> dropout -> ReLU).

    Each stage is a 5x5 conv with stride 2 and padding 2, which halves an
    even spatial size; for 3x32x32 inputs the map goes 32 -> 16 -> 8 -> 4,
    giving the ``256 * 4 * 4`` flattened width stored in ``in_features``.
    """

    def __init__(self):
        super().__init__()
        layers = []
        channels_in = 3
        # (output channels, Dropout2d rate) per stage, in order.
        for channels_out, drop_rate in ((64, 0.1), (128, 0.3), (256, 0.5)):
            layers.extend([
                nn.Conv2d(channels_in, channels_out, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm2d(channels_out),
                nn.Dropout2d(drop_rate),
                nn.ReLU(),
            ])
            channels_in = channels_out
        self.conv_params = nn.Sequential(*layers)
        self.in_features = 256 * 4 * 4

    def forward(self, x):
        """Extract conv features and flatten each sample to a vector."""
        feats = self.conv_params(x)
        return feats.view(feats.size(0), -1)


class LeNetBase(nn.Module):
    """LeNet-style two-convolution feature extractor.

    For 3x28x28 inputs the spatial size goes 28 -conv5-> 24 -pool-> 12
    -conv5-> 8 -pool-> 4, so each sample flattens to ``50 * 4 * 4`` values,
    which is what ``in_features`` records.
    """

    def __init__(self):
        super().__init__()
        first_stage = [nn.Conv2d(3, 20, kernel_size=5), nn.MaxPool2d(2), nn.ReLU()]
        second_stage = [
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
        ]
        self.conv_params = nn.Sequential(*first_stage, *second_stage)
        self.in_features = 50 * 4 * 4

    def forward(self, x):
        """Extract conv features and flatten each sample to a vector."""
        feats = self.conv_params(x)
        return feats.view(feats.shape[0], -1)
