from models.wrn import WideResNet
import torch
import torch.nn as  nn
from torchvision.models import densenet121
import numpy as np
import torchvision
from .resnet import ResNet18
import os
from datasets.connector import build_dataset
import torch.optim as optim
import sys
import torchvision.transforms as trn

import timm


def build_model(model_type, num_classes, device, args):
    """Construct a WideResNet or ResNet18 classifier and move it to ``device``.

    Args:
        model_type: ``"WideResNet"`` or ``"resnet18"``.
        num_classes: Number of output classes handed to the constructor.
        device: Target passed to ``net.to``.
        args: Namespace providing ``gpu`` (``None`` or a comma-separated
            string of GPU ids) and, for WideResNet, ``layers``,
            ``widen_factor`` and ``droprate``.

    Returns:
        The network, wrapped in ``torch.nn.DataParallel`` when ``args.gpu``
        is not ``None`` and has length greater than one.

    Raises:
        NotImplementedError: If ``model_type`` is not one of the supported
            names (previously this surfaced as an ``UnboundLocalError``).
    """
    if model_type == "WideResNet":
        net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    elif model_type == "resnet18":
        net = ResNet18(num_classes)
    else:
        raise NotImplementedError(
            "unsupported model_type: {!r}".format(model_type)
        )

    net.to(device)
    # The length check is on the raw string, so "0,1" triggers wrapping
    # while "0" does not.
    if args.gpu is not None and len(args.gpu) > 1:
        gpu_list = [int(s) for s in args.gpu.split(',')]
        net = torch.nn.DataParallel(net, device_ids=gpu_list)
    return net


def build_common_model_imagnet(modelname,mode="test",pre_trained=True,gpus=[0],dataParallel=False):
    """Return an ImageNet-pretrained classifier placed on the GPU.

    Args:
        modelname: One of the torchvision-backed names listed in the table
            below, or ``"DeiT"`` (loaded through ``timm``).
        mode: ``"test"`` switches the model to eval mode; any other value
            switches it to train mode.
        pre_trained: Accepted but not consulted; pretrained weights are
            always requested.
        gpus: Accepted but not consulted.
        dataParallel: When truthy the model is wrapped in
            ``torch.nn.DataParallel`` before being moved to CUDA.

    Returns:
        The model (or its ``DataParallel`` wrapper) on CUDA.

    Raises:
        NotImplementedError: If ``modelname`` is not supported.
    """
    # Public model name -> attribute name of the torchvision constructor.
    # Looked up lazily so that an unsupported name never touches torchvision.
    torchvision_constructors = {
        'ResNet18': 'resnet18',
        'ResNet50': 'resnet50',
        'ResNet101': 'resnet101',
        'ResNet152': 'resnet152',
        'ResNeXt101': 'resnext101_32x8d',
        'VGG16': 'vgg16',
        'VGG16_BN': 'vgg16_bn',
        'ShuffleNet': 'shufflenet_v2_x1_0',
        'ShuffleNet_v2_x2_0': 'shufflenet_v2_x2_0',
        'Inception': 'inception_v3',
        'DenseNet161': 'densenet161',
        'ViT': 'vit_b_16',
    }

    constructor_name = torchvision_constructors.get(modelname)
    if constructor_name is not None:
        constructor = getattr(torchvision.models, constructor_name)
        model = constructor(weights="IMAGENET1K_V1", progress=True)
    elif modelname == "DeiT":
        # Load the distilled DeiT-base checkpoint from the timm hub.
        model = timm.create_model("hf_hub:timm/deit_base_distilled_patch16_224.fb_in1k", pretrained=True)
    else:
        raise NotImplementedError

    if mode == "test":
        model.eval()
    else:
        model.train()

    if not dataParallel:
        model.cuda()
        return model
    return torch.nn.DataParallel(model).cuda()



# Public model name -> attribute name of the torchvision constructor used as
# the CIFAR-10 backbone.
_CIFAR10_TORCHVISION_CONSTRUCTORS = {
    'ResNet18': 'resnet18',
    'ResNet50': 'resnet50',
    'ResNet101': 'resnet101',
    'ResNet152': 'resnet152',
    'ResNeXt101': 'resnext101_32x8d',
    'VGG16': 'vgg16',
    'VGG16_BN': 'vgg16_bn',
    'ShuffleNet': 'shufflenet_v2_x1_0',
    'ShuffleNet_v2_x2_0': 'shufflenet_v2_x2_0',
    'Inception': 'inception_v3',
    'DenseNet161': 'densenet161',
    'ViT': 'vit_b_16',
}


def _finetune_cifar10(model, modelname, dataset_name, transform, batch_size, num_epochs, lr):
    """Fine-tune ``model`` in place on the training split, on CUDA, with SGD."""
    criterion = nn.CrossEntropyLoss()

    train_dataset, _ = build_dataset(dataset_name, mode="train", transform=transform)
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        pre_acc = 0
        for inputs, labels in trainloader:
            labels = labels.cuda()
            inputs = inputs.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            if modelname == "Inception":
                # In train mode the Inception output carries the main
                # predictions in its ``logits`` field.
                outputs = outputs.logits
            pre_acc += torch.sum(torch.argmax(outputs, dim=1) == labels)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)},ACC.: {pre_acc / len(train_dataset)}")


def build_common_model_cifar10(modelname,mode="test",pre_trained=True,gpus=[0],dataParallel=False,model_pkl_path=""):
    """Return a CIFAR-10 classifier built on an ImageNet-pretrained backbone.

    When ``pre_trained`` is truthy or ``model_pkl_path`` is given, the
    checkpoint at ``model_pkl_path`` (default ``data/ckpt.pth``) is loaded if
    it exists; otherwise the model is fine-tuned on the CIFAR-10 training
    split and the resulting state dict is written to that path.

    Args:
        modelname: A key of ``_CIFAR10_TORCHVISION_CONSTRUCTORS``.
        mode: ``"test"`` selects eval mode; anything else selects train mode.
        pre_trained: Whether to load (or create) fine-tuned weights.
        gpus: Accepted but not consulted.
        dataParallel: Wrap the model in ``torch.nn.DataParallel`` when truthy.
        model_pkl_path: Checkpoint path; empty means "use the default".

    Returns:
        The model (or its ``DataParallel`` wrapper) on CUDA.

    Raises:
        NotImplementedError: If ``modelname`` is not supported.
    """
    constructor_name = _CIFAR10_TORCHVISION_CONSTRUCTORS.get(modelname)
    if constructor_name is None:
        raise NotImplementedError

    num_classes = 10
    mean = (0.492, 0.482, 0.446)
    std = (0.247, 0.244, 0.262)

    # Defaults; Inception and ViT override them below.
    transform = None
    batch_size = 1024
    num_epochs = 10
    lr = 0.001

    def upscaled_transform(size):
        # Standard CIFAR augmentation followed by a resize to ``size``.
        return trn.Compose([trn.RandomHorizontalFlip(),
                            trn.RandomCrop(32, padding=4),
                            trn.Resize(size),
                            trn.ToTensor(),
                            trn.Normalize(mean, std)])

    model = getattr(torchvision.models, constructor_name)(weights="IMAGENET1K_V1", progress=True)

    # Replace the classification layer to match the CIFAR-10 class count.
    # NOTE(review): ResNet152, ResNeXt101, VGG16_BN and both ShuffleNet
    # variants keep their original 1000-way head, exactly as before; changing
    # that would invalidate existing checkpoints - confirm whether intended.
    if modelname in ('ResNet18', 'ResNet50', 'ResNet101'):
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == 'VGG16':
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif modelname == 'Inception':
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        transform = upscaled_transform(299)
        batch_size = 128
        num_epochs = 2
        lr = 1e-4
    elif modelname == 'DenseNet161':
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif modelname == "ViT":
        model.heads = nn.Linear(model.heads[0].in_features, num_classes)
        transform = upscaled_transform(224)
        batch_size = 128
        num_epochs = 4
        lr = 1e-5

    dataset_name = 'cifar10'
    if pre_trained or model_pkl_path:
        if not model_pkl_path:
            model_pkl_path = "data/ckpt.pth"

        if os.path.exists(model_pkl_path):
            # Load onto the CPU first; the model is moved to CUDA below, so
            # the device the checkpoint was saved from does not matter.
            pretrained_dict = torch.load(model_pkl_path, map_location="cpu")
            model.load_state_dict(pretrained_dict)
        else:
            print("Dataset is {}. No pretrained model, starting training!!!".format(dataset_name))
            # Create the target directory up front so a missing folder does
            # not discard the whole fine-tuning run at save time.
            ckpt_dir = os.path.dirname(model_pkl_path)
            if ckpt_dir:
                os.makedirs(ckpt_dir, exist_ok=True)

            _finetune_cifar10(model, modelname, dataset_name, transform, batch_size, num_epochs, lr)

            # Save the fine-tuned model.
            torch.save(model.state_dict(), model_pkl_path)

    # Mode and device placement apply on every path, including
    # pre_trained=False with no checkpoint path (previously skipped).
    if mode == "test":
        model.eval()
    else:
        model.train()
    if dataParallel:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
    return model

# from. import models_cifar100
# Network name -> (submodule of ``.models_cifar100``, factory function name).
_CIFAR100_FACTORIES = {
    'VGG16': ('vgg', 'vgg16_bn'),
    'vgg13': ('vgg', 'vgg13_bn'),
    'vgg11': ('vgg', 'vgg11_bn'),
    'vgg19': ('vgg', 'vgg19_bn'),
    'densenet121': ('densenet', 'densenet121'),
    'DenseNet161': ('densenet', 'densenet161'),
    'densenet169': ('densenet', 'densenet169'),
    'densenet201': ('densenet', 'densenet201'),
    'googlenet': ('googlenet', 'googlenet'),
    'Inception': ('inceptionv3', 'inceptionv3'),
    'inceptionv4': ('inceptionv4', 'inceptionv4'),
    'inceptionresnetv2': ('inceptionv4', 'inception_resnet_v2'),
    'xception': ('xception', 'xception'),
    'ResNet18': ('resnet', 'resnet18'),
    'resnet34': ('resnet', 'resnet34'),
    'ResNet50': ('resnet', 'resnet50'),
    'ResNet101': ('resnet', 'resnet101'),
    'resnet152': ('resnet', 'resnet152'),
    'preactresnet18': ('preactresnet', 'preactresnet18'),
    'preactresnet34': ('preactresnet', 'preactresnet34'),
    'preactresnet50': ('preactresnet', 'preactresnet50'),
    'preactresnet101': ('preactresnet', 'preactresnet101'),
    'preactresnet152': ('preactresnet', 'preactresnet152'),
    'resnext50': ('resnext', 'resnext50'),
    'resnext101': ('resnext', 'resnext101'),
    'resnext152': ('resnext', 'resnext152'),
    'shufflenet': ('shufflenet', 'shufflenet'),
    'shufflenetv2': ('shufflenetv2', 'shufflenetv2'),
    'squeezenet': ('squeezenet', 'squeezenet'),
    'mobilenet': ('mobilenet', 'mobilenet'),
    'mobilenetv2': ('mobilenetv2', 'mobilenetv2'),
    'nasnet': ('nasnet', 'nasnet'),
    'attention56': ('attention', 'attention56'),
    'attention92': ('attention', 'attention92'),
    'seresnet18': ('senet', 'seresnet18'),
    'seresnet34': ('senet', 'seresnet34'),
    'seresnet50': ('senet', 'seresnet50'),
    'seresnet101': ('senet', 'seresnet101'),
    'seresnet152': ('senet', 'seresnet152'),
    'wideresnet': ('wideresidual', 'wideresnet'),
    'stochasticdepth18': ('stochasticdepth', 'stochastic_depth_resnet18'),
    'stochasticdepth34': ('stochasticdepth', 'stochastic_depth_resnet34'),
    'stochasticdepth50': ('stochasticdepth', 'stochastic_depth_resnet50'),
    'stochasticdepth101': ('stochasticdepth', 'stochastic_depth_resnet101'),
}

# Checkpoint used when the caller does not supply ``model_pkl_path``.
_CIFAR100_DEFAULT_CHECKPOINTS = {
    'VGG16': "/home//MyFiles/pytorch-cifar100/checkpoint/vgg16/Monday_31_July_2023_14h_56m_26s/vgg16-200-regular.pth",
    'DenseNet161': "/home//MyFiles/pytorch-cifar100/checkpoint/DenseNet161/Monday_31_July_2023_15h_24m_16s/DenseNet161-200-regular.pth",
    'Inception': "/home//MyFiles/pytorch-cifar100/checkpoint/inceptionv3/Tuesday_22_August_2023_12h_22m_25s/inceptionv3-200-regular.pth",
    'ResNet18': "/home//MyFiles/pytorch-cifar100/checkpoint/resnet18/Monday_31_July_2023_15h_17m_56s/resnet18-200-regular.pth",
    'ResNet50': "/home//MyFiles/pytorch-cifar100/checkpoint/resnet50/Monday_31_July_2023_15h_20m_53s/resnet50-200-regular.pth",
    'ResNet101': "/home//MyFiles/pytorch-cifar100/checkpoint/ResNet101/Monday_31_July_2023_15h_26m_23s/ResNet101-200-regular.pth",
}


def build_common_model_cifar100(net,mode,pre_trained,gpus,dataParallel,model_pkl_path=""):
    """Return the requested CIFAR-100 network, optionally with loaded weights.

    Args:
        net: Network name; a key of ``_CIFAR100_FACTORIES`` or ``"ViT"``.
        mode: ``"test"`` selects eval mode; anything else selects train mode.
        pre_trained: When truthy, a state dict is loaded from
            ``model_pkl_path`` or, if that is empty, from the per-network
            default checkpoint.
        gpus: Accepted but not consulted.
        dataParallel: Wrap the model in ``torch.nn.DataParallel`` when truthy.
        model_pkl_path: Checkpoint path; a non-empty value always takes
            precedence over the built-in default (this now also holds for
            ``ResNet50``).

    Returns:
        The model (or its ``DataParallel`` wrapper) on CUDA.

    Raises:
        NotImplementedError: If ``net`` is not a supported name.
        FileNotFoundError: If ``pre_trained`` is truthy but no checkpoint
            path is available for ``net``.
    """
    import importlib  # function-scope import keeps the block self-contained

    if net in _CIFAR100_FACTORIES:
        module_name, factory_name = _CIFAR100_FACTORIES[net]
        # All architectures are resolved relative to this package, the same
        # way the branches that ship default checkpoints always were.
        module = importlib.import_module(
            f".models_cifar100.{module_name}", package=__package__
        )
        model = getattr(module, factory_name)()
    elif net == "ViT":
        model = torchvision.models.vit_b_16(weights="IMAGENET1K_V1", progress=True)
        in_features = model.heads[0].in_features
        model.heads = nn.Linear(in_features, 100)
    else:
        # Raise instead of exiting the interpreter so callers can handle it,
        # in line with the other builders in this file.
        raise NotImplementedError('the network name you have entered is not supported yet')

    if not model_pkl_path:
        model_pkl_path = _CIFAR100_DEFAULT_CHECKPOINTS.get(net, "")

    if pre_trained:
        if not model_pkl_path:
            raise FileNotFoundError(
                "no checkpoint path given and no default checkpoint is "
                "registered for network {!r}".format(net)
            )
        # Load onto the CPU; the model is moved to CUDA below.
        pretrained_dict = torch.load(model_pkl_path, map_location="cpu")
        model.load_state_dict(pretrained_dict)

    if mode == "test":
        model.eval()
    else:
        model.train()
    if dataParallel:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()

    return model

def build_common_model(modelname,dataset_name="imagnet",mode="test",pre_trained=True,gpus=[0],dataParallel=False,path="",category=""):
    """Dispatch to the builder matching ``category``, ``modelname`` and dataset.

    Resolution order:
      1. ``category == "hua"``: a ``DataParallel``-wrapped ``ResNet50`` loaded
         from a fixed checkpoint (its ``'net'`` entry) and moved to CUDA.
      2. ``modelname == "CLIP"``: returns ``[model, preprocess]`` for the
         ``ViT-B/32`` CLIP variant, on CUDA when available, otherwise CPU.
      3. Otherwise the dataset-specific builder selected by ``dataset_name``.

    Args:
        modelname: Architecture name forwarded to the dataset builder.
        dataset_name: A string containing ``"imagenet"`` (the legacy spelling
            ``"imagnet"``, which is also the default, is accepted too),
            or exactly ``"cifar10"`` / ``"cifar100"``.
        mode, pre_trained, gpus, dataParallel: Forwarded unchanged.
        path: Checkpoint path forwarded to the CIFAR builders.
        category: ``"hua"`` selects the special-case ResNet50 above.

    Returns:
        A model, or a ``[model, preprocess]`` list for CLIP.

    Raises:
        NotImplementedError: If ``dataset_name`` matches no known dataset.
    """
    if category == "hua":
        from .models_cifar100.resnet_hua import ResNet50
        model_pkl_path = "/home//MyFiles/CP_test/data/ckpt.pth"
        net = ResNet50()
        # Wrapped before loading; presumably the checkpoint keys carry the
        # DataParallel "module." prefix - verify against the checkpoint.
        net = torch.nn.DataParallel(net)
        torch.backends.cudnn.benchmark = True

        # Load onto the CPU; the model is moved to CUDA right after.
        pretrained_dict = torch.load(model_pkl_path, map_location="cpu")
        net.load_state_dict(pretrained_dict['net'])
        net.cuda()

        return net

    if modelname == "CLIP":
        # Load the model
        from .clip import load as clipload
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clipload('ViT-B/32', device)
        return [model.eval().to(device), preprocess]

    # The default value is spelled "imagnet", which the original
    # '"imagenet" in dataset_name' test could never match; accept both.
    if "imagenet" in dataset_name or "imagnet" in dataset_name:
        return build_common_model_imagnet(modelname,mode,pre_trained,gpus,dataParallel)

    if dataset_name == "cifar10":
        return build_common_model_cifar10(modelname,mode,pre_trained,gpus,dataParallel,model_pkl_path=path)

    if dataset_name == "cifar100":
        return build_common_model_cifar100(modelname,mode,pre_trained,gpus,dataParallel,model_pkl_path=path)

    raise NotImplementedError
    
    
class _AveragedDistilledHead(nn.Module):
    """Average two heads applied to the two column ranges of the input."""

    def __init__(self, layer1, layer2, split=768) -> None:
        super().__init__()
        self.layer1 = layer1
        self.layer2 = layer2
        # Column index where the features for ``layer2`` begin.
        self.split = split

    def forward(self, x):
        x_t = x[:, :self.split]
        x_dist = x[:, self.split:]
        return (self.layer1(x_t) + self.layer2(x_dist)) / 2


def getFinalLayer(model,modelname,dataset_name):
    """Return the final classification layer of ``model``.

    Args:
        model: The network, optionally wrapped in ``nn.DataParallel`` (the
            wrapper is unwrapped before the attribute lookup).
        modelname: Architecture name that decides which attribute is used.
        dataset_name: Accepted but not consulted.

    Returns:
        ``model.fc``, ``model.classifier``, ``model.classifier[6]`` or
        ``model.heads`` depending on ``modelname``. For ``"DeiT"`` a module
        that splits its input columns between ``model.head`` and
        ``model.head_dist`` and averages the two outputs.

    Raises:
        NotImplementedError: If ``modelname`` is not supported.
    """
    # DataParallel keeps the real network under ``.module``.
    if isinstance(model, nn.DataParallel):
        model = model.module

    if modelname in ("ResNeXt101", "ResNet152", "ResNet101", "ResNet50", "ResNet18",
                     "Inception", "ShuffleNet", "ShuffleNet_v2_x2_0"):
        return model.fc
    if modelname == "DenseNet161":
        return model.classifier
    if modelname in ("VGG16", "VGG16_BN"):
        return model.classifier[6]
    if modelname == "ViT":
        return model.heads
    if modelname == "DeiT":
        head = model.head
        head_dist = model.head_dist
        # Split where the first head's input ends; fall back to the
        # previously hard-coded 768 when the head exposes no in_features.
        split = getattr(head, "in_features", 768)
        return _AveragedDistilledHead(head, head_dist, split)
    raise NotImplementedError
    
def get_prediction_head(modelname,dataset_name="imagnet",mode="test",pre_trained=True,gpus=[0],dataParallel=False,path=""):
    """Build the model for ``dataset_name`` and return its prediction head.

    Args:
        modelname: Architecture name forwarded to the dataset builder.
        dataset_name: A string containing ``"imagenet"`` (the legacy spelling
            ``"imagnet"``, which is also the default, is accepted too),
            or exactly ``"cifar10"`` / ``"cifar100"``.
        mode, pre_trained, gpus, dataParallel: Forwarded unchanged.
        path: Checkpoint path forwarded to the CIFAR builders.

    Returns:
        The layer selected by ``getFinalLayer`` for ImageNet and CIFAR-10.
        For CIFAR-100 the builder's result is returned as is.

    Raises:
        NotImplementedError: If ``dataset_name`` matches no known dataset.
    """
    # The default value is spelled "imagnet", which the original
    # '"imagenet" in dataset_name' test could never match; accept both.
    if "imagenet" in dataset_name or "imagnet" in dataset_name:
        model = build_common_model_imagnet(modelname,mode,pre_trained,gpus,dataParallel)
        return getFinalLayer(model,modelname,dataset_name)

    if dataset_name == "cifar10":
        model = build_common_model_cifar10(modelname,mode,pre_trained,gpus,dataParallel,model_pkl_path=path)
        return getFinalLayer(model,modelname,dataset_name)

    if dataset_name == "cifar100":
        # NOTE(review): this returns the whole network, not just its final
        # layer, unlike the two branches above - confirm whether intended.
        return build_common_model_cifar100(modelname,mode,pre_trained,gpus,dataParallel,model_pkl_path=path)

    raise NotImplementedError