
import torch
import torch.nn as nn
# from  torch import
import torchvision
import timm
from Vits import define_Vit
import torchvision.models as models
import dill
import pickle


class Softmaxnet(torch.nn.Module):
    """Linear classification head that outputs class probabilities.

    A single ``torch.nn.Linear`` layer, wrapped in ``nn.Sequential`` so the
    state-dict keys are ``linears.0.weight`` / ``linears.0.bias``, maps the
    input features to ``out_feature`` scores. The maximum score along dim 1 is
    subtracted before ``Softmax(dim=1)`` is applied.
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        head = torch.nn.Linear(in_features=in_feature, out_features=out_feature)
        self.linears = torch.nn.Sequential(head)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        scores = self.linears(x)
        # Max over dim 1, re-expanded on the LAST axis (same broadcasting as
        # before). Softmax is shift-invariant, so this only affects rounding.
        peak = torch.max(scores, 1).values.unsqueeze(-1)
        return self.softmax(scores - peak)
    

    
class Softmaxnet_new(torch.nn.Module):
    """Linear + BatchNorm1d classification head.

    ``forward`` returns a pair ``(probabilities, logits)``:
    - ``probabilities`` is ``Softmax(dim=1)`` applied to the logits.
    - ``logits`` is the batch-normalised output of the linear layer.
    """

    def __init__(self, in_feature, out_feature):
        super().__init__()
        layers = [
            torch.nn.Linear(in_features=in_feature, out_features=out_feature),
            torch.nn.BatchNorm1d(out_feature),
        ]
        self.linears = torch.nn.Sequential(*layers)
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        logits = self.linears(x)
        probs = self.softmax(logits)
        return probs, logits


# Tasks for which a ResNet-18 backbone gets a Softmaxnet head, unless the
# strategy is "CrossEntropy" (see _attach_head).
_RESNET18_SOFTMAX_TASKS = (
    "CIFAR10", "CIFAR10_noresize", "CIFAR100", "ImageNet", "car196",
    "Flowers102",
)

# Tasks for which a ResNet-50-family backbone gets a Softmaxnet head.
_RESNET50_SOFTMAX_TASKS = (
    "CIFAR10", "CIFAR10_noresize", "CIFAR100", "ImageNet", "DTD", "Food101",
    "aircraft", "Pets", "car196", "SVHN", "SUN397", "Flowers102",
    "Country211",
)

# Keys that remain after stripping the 15-char key prefix from the SimSiam
# checkpoint but do not belong to the torchvision ResNet-50 trunk. The
# "r.*" keys look like a sliced "module.predictor.*" head and the "fc.*"
# keys like the projection MLP. Removing them lets the strict
# load_state_dict succeed.
_SIMSIAM_EXTRA_KEYS = (
    "r.0.weight", "r.1.weight", "r.1.bias", "r.1.running_mean",
    "r.1.running_var", "r.1.num_batches_tracked", "r.3.weight", "r.3.bias",
    "fc.0.weight", "fc.1.weight", "fc.1.bias", "fc.1.running_mean",
    "fc.1.running_var", "fc.1.num_batches_tracked", "fc.3.weight",
    "fc.4.weight", "fc.4.bias", "fc.4.running_mean", "fc.4.running_var",
    "fc.4.num_batches_tracked", "fc.6.weight", "fc.6.bias",
    "fc.7.running_mean", "fc.7.running_var", "fc.7.num_batches_tracked",
)


def _load_simsiam_resnet50(checkpoint_path):
    """Build a ResNet-50 trunk (``fc`` = Identity) from a SimSiam checkpoint, in eval mode."""
    model = torchvision.models.resnet50(pretrained=False)
    model.fc = Identity()
    # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path)
    # Drop the first 15 characters of every key. This is presumably the
    # DDP "module.encoder." prefix -- TODO confirm against the checkpoint.
    state_dict = {k[15:]: v for k, v in checkpoint['state_dict'].items()}
    for key in _SIMSIAM_EXTRA_KEYS:
        state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _build_backbone(backbone, pretrain_param, n_classes, checkpoint_path):
    """Instantiate the network named by *backbone*, before any task head is attached."""
    pretrained = pretrain_param == 'super'

    if backbone == "resnet18":
        if pretrain_param == 'nosuper':
            # NOTE(review): ResNet18 is neither defined nor imported in the
            # visible part of this file; it must be provided elsewhere.
            model = ResNet18()
            # SECURITY NOTE: torch.load unpickles the file; trusted input only.
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint["online_network_state_dict"])
            model.eval()
            return model
        return torchvision.models.resnet18(pretrained=pretrained)

    if backbone == "resnet50":
        if pretrain_param == 'nosuper':
            return _load_simsiam_resnet50(checkpoint_path)
        if pretrain_param == 'kaiminginit_res50':
            model = torchvision.models.resnet50(pretrained=False)
            model.apply(weight_init)
            return model
        if pretrain_param == 'Orthogonalinit_res50':
            # Orthogonal initialisation of every Conv2d weight. It is commonly
            # used to mitigate vanishing/exploding gradients.
            model = torchvision.models.resnet50(pretrained=False)
            for module in model.modules():
                if isinstance(module, nn.Conv2d):
                    nn.init.orthogonal_(module.weight)
            return model
        return torchvision.models.resnet50(pretrained=pretrained)

    if backbone == "resnet50_2":
        return timm.create_model('wide_resnet50_2', pretrained=pretrained, num_classes=0)

    if backbone == "resnet101":
        return torchvision.models.resnet101(pretrained=pretrained)

    if backbone == "resnet101_2":
        return timm.create_model('wide_resnet101_2', pretrained=pretrained, num_classes=0)

    if backbone.startswith('vit'):
        return define_Vit(backbone, n_classes, pretrain_param)

    raise ValueError(f"Unknown backbone: {backbone!r}")


def _attach_head(model, task, backbone, n_classes, stragety, pretrain_param):
    """Replace *model*'s classification head in place based on backbone/task/strategy."""
    if (backbone.startswith("resnet18") and task in _RESNET18_SOFTMAX_TASKS
            and stragety != "CrossEntropy"):
        if pretrain_param == 'nosuper':
            # 'projetion' (sic) is kept verbatim; presumably it must match the
            # attribute name on the external ResNet18 class -- confirm before renaming.
            model.projetion = Softmaxnet(512, n_classes)
        else:
            model.fc = Softmaxnet(512, n_classes)
    elif backbone.startswith("resnet50") and task in _RESNET50_SOFTMAX_TASKS:
        # Every pretrain_param value gets the same Softmaxnet head. This
        # covers "resnet50" and "resnet50_2".
        # NOTE(review): Softmaxnet outputs probabilities. If the training loss
        # is nn.CrossEntropyLoss, which expects logits, softmax is applied twice.
        model.fc = Softmaxnet(2048, n_classes)
    elif backbone.startswith("resnet101"):
        model.fc = torch.nn.Linear(2048, n_classes)
    elif backbone.startswith("resnet18"):
        model.fc = torch.nn.Linear(512, n_classes)
    # Any other combination (ViT, or a ResNet-50 on an unlisted task) keeps
    # the head produced by _build_backbone.


def Getmodel(task, backbone, n_classes, stragety, pretrain_param, checkpoint_path=''):
    """Build a backbone and attach a task-specific classification head.

    Args:
        task: Dataset/task name (e.g. "CIFAR10", "ImageNet"). It selects
            whether a Softmaxnet head is attached.
        backbone: One of "resnet18", "resnet50", "resnet50_2" (timm
            wide_resnet50_2), "resnet101", "resnet101_2" (timm
            wide_resnet101_2), or any name starting with "vit" (delegated
            to ``define_Vit``).
        n_classes: Number of output classes of the new head.
        stragety: Training strategy name. For ResNet-18, "CrossEntropy"
            selects a plain Linear head instead of Softmaxnet.
        pretrain_param: Initialisation mode:
            - 'super': library pretrained weights.
            - 'nosuper': load a self-supervised checkpoint from
              *checkpoint_path* (resnet18 / resnet50 only; the model is
              returned in eval mode).
            - 'kaiminginit_res50' / 'Orthogonalinit_res50': custom init,
              resnet50 only.
            - anything else: random initialisation.
        checkpoint_path: Checkpoint file used by 'nosuper'. The default ''
            matches the previous hard-coded path.

    Returns:
        The constructed ``torch.nn.Module``.

    Raises:
        ValueError: if *backbone* is not recognised.
    """
    model = _build_backbone(backbone, pretrain_param, n_classes, checkpoint_path)
    _attach_head(model, task, backbone, n_classes, stragety, pretrain_param)
    return model


class nosuper_resnet50(torch.nn.Module):
    """Wrap a feature extractor with a 2048 -> ``n_classes`` Softmaxnet head.

    ``forward(x)`` runs ``self.net`` and feeds its output to ``self.fc``. The
    wrapped model is expected to emit 2048-dimensional features (TODO:
    confirm for every caller).
    """

    def __init__(self, model, n_classes):
        super().__init__()
        self.net = model
        self.fc = Softmaxnet(2048, n_classes)

    def forward(self, x):
        features = self.net(x)
        return self.fc(features)
    

class Identity(torch.nn.Module):
    """Parameter-free module whose ``forward`` returns its input object unchanged.

    For example, assigning it to a backbone's ``fc`` attribute makes the
    backbone output its pooled features directly.
    """

    def forward(self, x):
        return x
        
def weight_init(m):
    """Re-initialise a single module in place; intended for ``model.apply(weight_init)``.

    - ``nn.Conv2d``: the weight gets Kaiming-normal init (``mode='fan_out'``,
      ``nonlinearity='relu'``). Any conv bias is left untouched.
    - ``nn.BatchNorm2d`` / ``nn.GroupNorm``: the weight is set to 1 and the
      bias to 0. Layers built with ``affine=False`` register these as ``None``;
      they are skipped instead of raising.
    - Every other module type is left unchanged.
    """
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        # Non-affine norm layers have weight/bias == None; nothing to init.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
