from cgitb import reset
from turtle import forward
import torch
from torch import nn
from torch.utils import model_zoo
import torchvision
from torchvision.models.resnet import BasicBlock, model_urls, Bottleneck
import os
from torch.nn import functional as F

class GroupNorm(torch.nn.GroupNorm):
    """``torch.nn.GroupNorm`` that takes the channel count first.

    The stock layer is constructed as ``GroupNorm(num_groups, num_channels)``;
    this subclass swaps the two leading arguments and defaults to 32 groups,
    so it can be built from a channel count alone, e.g. ``GroupNorm(64)``.
    Any extra keyword arguments (``eps``, ``affine``, ...) are forwarded
    unchanged to the parent constructor.
    """

    def __init__(self, num_channels, num_groups=32, **kwargs):
        super(GroupNorm, self).__init__(
            num_groups=num_groups,
            num_channels=num_channels,
            **kwargs,
        )



class MLP(nn.Module):
    """Fully-connected ReLU network.

    Layout: ``input`` Linear -> ``mlp_depth - 2`` hidden Linear layers ->
    ``output`` Linear. A dropout + ReLU follows the input layer and every
    hidden layer; the output layer has no activation.

    The hyper-parameters that would normally come from an ``hparams`` config
    (width / depth / dropout) are plain keyword arguments here.

    Args:
        n_inputs: size of the last dimension of the input tensor.
        n_outputs: size of the last dimension of the output tensor.
        mlp_width: width of the input and hidden layers (default 256).
        mlp_depth: total number of Linear layers, counting the input and
            output layers; must be >= 2 (default 3, i.e. one hidden layer).
            The previous code built ``256 - 2`` hidden layers because the
            width had been substituted for the depth; pass ``mlp_depth=256``
            only if you need to reproduce that old architecture.
        mlp_dropout: dropout probability applied after each non-output layer
            (default 0.0, i.e. dropout is a no-op, matching the previously
            commented-out dropout).

    Raises:
        ValueError: if ``mlp_depth < 2``.
    """

    def __init__(self, n_inputs, n_outputs, mlp_width=256, mlp_depth=3, mlp_dropout=0.0):
        super(MLP, self).__init__()
        if mlp_depth < 2:
            raise ValueError(f"mlp_depth must be >= 2, got {mlp_depth}")
        self.input = nn.Linear(n_inputs, mlp_width)
        # nn.Dropout has no parameters, so adding it does not change the state_dict keys.
        self.dropout = nn.Dropout(mlp_dropout)
        self.hiddens = nn.ModuleList([
            nn.Linear(mlp_width, mlp_width)
            for _ in range(mlp_depth - 2)])
        self.output = nn.Linear(mlp_width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        """Apply the network to ``x`` (Linear layers act on its last dimension)."""
        x = self.input(x)
        x = self.dropout(x)
        x = F.relu(x)
        for hidden in self.hiddens:
            x = hidden(x)
            x = self.dropout(x)
            x = F.relu(x)
        return self.output(x)

# bypass layer
class Identity(nn.Module):
    """Pass-through module: ``forward`` returns its argument unchanged.

    Unlike ``nn.Identity`` it stores ``in_features``, presumably so code that
    inspects ``<module>.fc.in_features`` still works after a classification
    head has been replaced by this bypass.

    Args:
        n_inputs: value recorded as ``self.in_features``.
    """

    def __init__(self, n_inputs):
        super().__init__()
        self.in_features = n_inputs

    def forward(self, x):
        """Return ``x`` itself (no copy)."""
        return x

class Model_for_DANNIN(nn.Module):
    """Container bundling the sub-networks used by the DANNIN setup.

    Only ``featurizer`` and ``classifier`` participate in :meth:`forward`.
    ``discriminator`` and ``class_embeddings`` are just stored as attributes
    (and registered as sub-modules when they are ``nn.Module`` instances, so
    their parameters are included in ``parameters()``); presumably they are
    driven by training code elsewhere.

    Args:
        feat: feature extractor applied to the raw input.
        fc: classifier head applied to the extracted features.
        disc: discriminator network (not used in ``forward``).
        embed: class-embedding module (not used in ``forward``).
    """

    def __init__(self, feat, fc, disc, embed):
        super().__init__()
        # Assignment order is kept so submodule / state_dict ordering is unchanged.
        self.featurizer = feat
        self.classifier = fc
        self.discriminator = disc
        self.class_embeddings = embed

    def forward(self, input_data):
        """Return ``classifier(featurizer(input_data))``."""
        return self.classifier(self.featurizer(input_data))

def get_model_for_DANNIN(model_name, classes, num_ch, pre_trained, os_env, num_disc_outputs=10):
    """Assemble a ``Model_for_DANNIN`` around a torchvision ResNet backbone.

    Args:
        model_name: backbone name forwarded to ``get_resnet``
            (e.g. ``'resnet18'`` or ``'resnet50'``).
        classes: number of classes; sizes both the linear classifier head
            and the class-embedding table. (Previously ignored in favour of a
            hard-coded 61.)
        num_ch: number of input channels, forwarded to ``get_resnet``.
        pre_trained: whether the backbone loads pretrained weights.
        os_env: forwarded to ``get_resnet``; selects where weights are
            loaded from.
        num_disc_outputs: output size of the discriminator MLP
            (previously hard-coded to 10, which remains the default).

    Returns:
        A ``Model_for_DANNIN`` whose featurizer is the ResNet with its ``fc``
        replaced by a pass-through, and whose classifier, discriminator and
        class embeddings are sized from the backbone's feature dimension.
    """
    # fc_layer=0: the backbone's head is replaced by Identity, so it outputs pooled features.
    feat = get_resnet(model_name, classes, 0, num_ch, pre_trained, os_env)
    # Feature size comes from the backbone (512 for resnet18, 2048 for resnet50)
    # instead of a hard-coded 512, so non-resnet18 backbones get matching heads.
    feat_dim = feat.fc.in_features
    # A fresh Linear head; the old code built a whole second ResNet just to take its new fc.
    fc = nn.Linear(feat_dim, classes)
    disc = MLP(feat_dim, num_disc_outputs)
    embed = nn.Embedding(classes, feat_dim)
    return Model_for_DANNIN(feat, fc, disc, embed)

def get_resnet(model_name, classes, fc_layer,num_ch, pre_trained, os_env):    
    if model_name == 'resnet18':
        if os_env:        
            model=  torchvision.models.resnet18()
            if pre_trained:
                model.load_state_dict(torch.load( os.getenv('PT_DATA_DIR') + '/checkpoints/resnet18-5c106cde.pth' ))
        else:
            model=  torchvision.models.resnet18(pre_trained)
            
        n_inputs = model.fc.in_features
        n_outputs= classes
        
        
    elif model_name == 'resnet50':
        if os_env:        
            model=  torchvision.models.resnet50()
            if pre_trained:
                model.load_state_dict(torch.load( os.getenv('PT_DATA_DIR') + '/checkpoints/resnet50-19c8e357.pth' ))
        else:
            model=  torchvision.models.resnet50(pre_trained)
            
        n_inputs = model.fc.in_features
        n_outputs= classes
        
        model.fc = nn.Linear(n_inputs, n_outputs)
    if fc_layer:
        model.fc = nn.Linear(n_inputs, n_outputs)
    else:
        print('Here')
        model.fc = Identity(n_inputs)   
    if num_ch==1:
        model.conv1 = nn.Conv2d(1, 64, 
                                kernel_size=(7, 7), 
                                stride=(2, 2), 
                                padding=(3, 3), 
                        bias=False)
    return model