import os
from copy import deepcopy

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

class Identity(nn.Module):
    """Pass-through module that returns its input unchanged.

    It is used in place of a classifier head. It stores ``in_features`` under
    the same attribute name ``nn.Linear`` exposes, so the input width of the
    replaced head stays queryable.

    Args:
        n_inputs: Width of the features that would have entered the head.
    """

    def __init__(self, n_inputs):
        super().__init__()
        self.in_features = n_inputs

    def forward(self, x):
        """Return ``x`` as-is (same object, no copy)."""
        return x

def get_resnet(classes, fc_layer, num_ch, pre_trained, os_env):
    """Build a torchvision ResNet-18 and adapt its head and input stem.

    Args:
        classes: Number of output units of the new linear head. It is only
            used when ``fc_layer`` is truthy.
        fc_layer: If truthy, replace ``model.fc`` with a freshly initialised
            ``nn.Linear(in_features, classes)``. Otherwise replace it with an
            ``Identity`` head, which records ``in_features``.
        num_ch: Number of input channels. When it equals 1, ``conv1`` is
            replaced by a single-channel convolution with the same geometry.
        pre_trained: Whether to start from pretrained weights.
        os_env: If truthy, build an uninitialised ResNet-18 and (when
            ``pre_trained``) load weights from
            ``$PT_DATA_DIR/checkpoints/resnet18-5c106cde.pth``. Otherwise
            delegate weight handling to torchvision.

    Returns:
        The configured ``nn.Module``.

    Raises:
        KeyError: if ``os_env`` and ``pre_trained`` are set but the
            ``PT_DATA_DIR`` environment variable is not.
    """
    if os_env:
        model = torchvision.models.resnet18()
        if pre_trained:
            # os.environ[...] fails with a KeyError naming the variable, not an
            # opaque "NoneType + str" TypeError from string concatenation.
            ckpt_path = os.path.join(
                os.environ['PT_DATA_DIR'], 'checkpoints', 'resnet18-5c106cde.pth'
            )
            # The model was just built on the CPU, so load the tensors there
            # regardless of the device the checkpoint was saved from.
            model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    else:
        # Pass the flag by keyword: older torchvision names it `pretrained`.
        # Newer releases made the parameter keyword-only and accept
        # `pretrained=` through their legacy-interface shim.
        model = torchvision.models.resnet18(pretrained=pre_trained)

    n_inputs = model.fc.in_features

    if fc_layer:
        model.fc = nn.Linear(n_inputs, classes)
    else:
        # Headless encoder: the forward pass returns the pooled features.
        model.fc = Identity(n_inputs)

    if num_ch == 1:
        # Swap the 3-channel stem for a 1-channel one with identical geometry.
        # NOTE(review): the new layer is randomly initialised, so pretrained
        # conv1 weights are discarded here even when pre_trained is set.
        stem = model.conv1
        model.conv1 = nn.Conv2d(
            1,
            stem.out_channels,
            kernel_size=stem.kernel_size,
            stride=stem.stride,
            padding=stem.padding,
            bias=False,
        )
    return model


def get_model_for_fish(classes, num_ch, pre_trained, os_env):
    """Build a ``Model_for_fish``: a headless ResNet-18 encoder plus a linear head.

    Args:
        classes: Number of output units of the linear head.
        num_ch: Number of input channels, forwarded to ``get_resnet``.
        pre_trained: Whether the encoder starts from pretrained weights.
        os_env: Whether to load the checkpoint from ``$PT_DATA_DIR``
            (see ``get_resnet``).

    Returns:
        A ``Model_for_fish`` wrapping the encoder and the new head.
    """
    encoder = get_resnet(classes, 0, num_ch, pre_trained, os_env)
    # get_resnet(..., fc_layer=1, ...) only contributes a freshly initialised
    # nn.Linear(in_features, classes) as its head. Build that directly rather
    # than constructing a second ResNet (and possibly reloading the
    # checkpoint) just to discard everything but its head. The encoder's
    # Identity head records in_features for exactly this purpose.
    # NOTE: the head's initial weights are now drawn from the RNG at a
    # different point, so seeded runs get a different (identically
    # distributed) head initialisation than before.
    head = nn.Linear(encoder.fc.in_features, classes)
    return Model_for_fish(encoder, head, weights=None)


class Model_for_fish(nn.Module):
    """Encoder followed by a head: ``fc(enc(x))``.

    Args:
        enc: Module mapping an input batch to features.
        fc: Module mapping those features to outputs.
        weights: Optional state dict. When given, it is deep-copied and
            loaded into the freshly assembled model.
    """

    def __init__(self, enc, fc, weights):
        super(Model_for_fish, self).__init__()
        self.enc = enc
        self.fc = fc
        # Feature width fed to ``fc``: 512 for resnet18, 2048 for resnet50.
        if weights is not None:
            self.load_state_dict(deepcopy(weights))

    def reset_weights(self, weights):
        """Load ``weights`` (a state dict, deep-copied first) into this model."""
        self.load_state_dict(deepcopy(weights))

    def forward(self, x):
        """Run ``enc`` then ``fc``. A 3-D input gets a leading batch dim of 1.

        Returns:
            The output of ``fc``.
        """
        if x.dim() == 3:
            # Out-of-place: the previous ``unsqueeze_`` changed the shape of
            # the caller's tensor as a side effect.
            x = x.unsqueeze(0)
        features = self.enc(x)
        # Drop trailing size-1 dims (e.g. (N, C, 1, 1) -> (N, C)) in case the
        # encoder leaves them; this is a no-op on already-flat features.
        return self.fc(features.squeeze(-1).squeeze(-1))