import torch
import torch.nn as nn
from torchvision.models import densenet121
import numpy as np
import torchvision
import os
import torch.optim as optim
import sys
from dataset.utils import build_dataset

import torchvision.transforms as trn


def build_common_model(modelname, dataset_name="imagenet", mode="test", trained=True, gpus=[0], dataParallel=False):
    """Build a classifier for ``dataset_name`` by dispatching to a dataset-specific builder.

    Args:
        modelname: Architecture key understood by the selected builder (e.g. 'ResNet18').
        dataset_name: Any name containing "imagenet" selects the ImageNet builder.
            "cifar10" and "cifar100" must match exactly.
        mode: Forwarded unchanged to the builder.
        trained: Forwarded unchanged to the builder.
        gpus: Forwarded unchanged to the builder.
        dataParallel: Forwarded unchanged to the builder.

    Returns:
        Whatever the selected builder returns.

    Raises:
        NotImplementedError: If ``dataset_name`` is not supported.
    """
    # "imagnet" was the former (misspelled) default. It never matched the
    # substring check, so calls relying on the default always raised.
    # Keep accepting it for callers that copied that spelling.
    if "imagenet" in dataset_name or dataset_name == "imagnet":
        return build_common_model_imagnet(modelname, mode, trained, gpus, dataParallel)
    if dataset_name == "cifar10":
        return build_common_model_cifar10(modelname, mode, trained, gpus, dataParallel)
    if dataset_name == "cifar100":
        return build_common_model_cifar100(modelname, mode, trained, gpus, dataParallel)
    raise NotImplementedError(f"Unsupported dataset: {dataset_name}")


def build_common_model_imagnet(modelname, mode="test", trained=True, gpus=[0], dataParallel=False):
    """Build an ImageNet-pretrained classifier and move it to the GPU.

    Torchvision architectures are loaded with their "IMAGENET1K_V1" weights.
    "DeiT" is loaded through ``timm`` from the Hugging Face hub.

    Args:
        modelname: One of 'ResNet18', 'ResNet50', 'ResNet101', 'ResNet152',
            'ResNeXt101', 'VGG16', 'VGG16_BN', 'ShuffleNet', 'ShuffleNet_v2_x2_0',
            'Inception', 'DenseNet161', 'ViT', 'DeiT'.
        mode: "test" calls ``model.eval()``. Any other value calls ``model.train()``.
        trained: Unused here; kept for signature parity with the CIFAR builders.
        gpus: Unused here; kept for signature parity with the CIFAR builders.
        dataParallel: If truthy, wrap the model in ``torch.nn.DataParallel``
            before calling ``.cuda()``.

    Returns:
        The model, possibly DataParallel-wrapped, after ``.cuda()``.

    Raises:
        NotImplementedError: If ``modelname`` is not supported.
    """
    # Model name -> name of the torchvision.models factory function.
    torchvision_factories = {
        'ResNet18': 'resnet18',
        'ResNet50': 'resnet50',
        'ResNet101': 'resnet101',
        'ResNet152': 'resnet152',
        'ResNeXt101': 'resnext101_32x8d',
        'VGG16': 'vgg16',
        'VGG16_BN': 'vgg16_bn',
        'ShuffleNet': 'shufflenet_v2_x1_0',
        'ShuffleNet_v2_x2_0': 'shufflenet_v2_x2_0',
        'Inception': 'inception_v3',
        'DenseNet161': 'densenet161',
        'ViT': 'vit_b_16',
    }

    if modelname in torchvision_factories:
        factory = getattr(torchvision.models, torchvision_factories[modelname])
        model = factory(weights="IMAGENET1K_V1", progress=True)
    elif modelname == "DeiT":
        # The module never imported timm, so this branch used to fail with
        # NameError. Import it lazily so the other models don't require timm.
        import timm
        model = timm.create_model("hf_hub:timm/deit_base_distilled_patch16_224.fb_in1k", pretrained=True)
    else:
        raise NotImplementedError(f"Unsupported ImageNet model: {modelname}")

    if mode == "test":
        model.eval()
    else:
        model.train()
    if dataParallel:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
    return model


def build_common_model_cifar10(modelname, mode="test", trained=True, gpus=[0], dataParallel=False):
    """Build an ImageNet-pretrained classifier adapted to CIFAR-10.

    The backbone is loaded with torchvision "IMAGENET1K_V1" weights, and its final
    classification layer is replaced by a fresh 10-way ``nn.Linear``. If ``trained``
    is truthy, fine-tuned weights are loaded from
    ``~/data/cifar10_trained_models/<modelname>_finetuned.pth``. When that file does
    not exist, the model is fine-tuned on the CIFAR-10 training split and the
    resulting state dict is saved to that path.

    Note: ResNet152, VGG16_BN and the ShuffleNet variants previously kept their
    1000-way ImageNet head. Checkpoints cached with that head will no longer load
    and must be deleted so they are retrained.

    Args:
        modelname: One of 'ResNet18', 'ResNet50', 'ResNet101', 'ResNet152',
            'ResNeXt101', 'VGG16', 'VGG16_BN', 'ShuffleNet', 'ShuffleNet_v2_x2_0',
            'Inception', 'DenseNet161', 'ViT'.
        mode: "test" calls ``model.eval()``. Any other value calls ``model.train()``.
        trained: If truthy, load (or fine-tune and cache) CIFAR-10 weights.
            Otherwise the new head keeps its fresh initialisation.
        gpus: Unused; kept for signature parity with the other builders.
        dataParallel: If truthy, wrap the model in ``torch.nn.DataParallel``
            before calling ``.cuda()``.

    Returns:
        The model, possibly DataParallel-wrapped, after ``.cuda()``.

    Raises:
        NotImplementedError: If ``modelname`` is not supported.
    """
    dataset_name = 'cifar10'
    num_classes = 10
    mean = (0.492, 0.482, 0.446)
    std = (0.247, 0.244, 0.262)
    # Default fine-tuning hyper-parameters; the Inception and ViT branches override them.
    # transform=None is passed straight through to build_dataset.
    transform = None
    lr = 0.001
    batch_size = 1024
    num_epochs = 10

    if modelname == 'ResNet18':
        model = torchvision.models.resnet18(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == 'ResNet50':
        model = torchvision.models.resnet50(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == 'ResNet101':
        model = torchvision.models.resnet101(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == 'ResNet152':
        model = torchvision.models.resnet152(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == 'ResNeXt101':
        model = torchvision.models.resnext101_32x8d(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == 'VGG16':
        model = torchvision.models.vgg16(weights="IMAGENET1K_V1", progress=True)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif modelname == "VGG16_BN":
        model = torchvision.models.vgg16_bn(weights="IMAGENET1K_V1", progress=True)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    elif modelname == 'ShuffleNet':
        model = torchvision.models.shufflenet_v2_x1_0(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == "ShuffleNet_v2_x2_0":
        model = torchvision.models.shufflenet_v2_x2_0(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif modelname == 'DenseNet161':
        model = torchvision.models.densenet161(weights="IMAGENET1K_V1", progress=True)
        model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    elif modelname == 'Inception':
        # This used to be shadowed by an earlier bare 'Inception' branch, so the
        # head swap, resize transform and hyper-parameters below never ran.
        model = torchvision.models.inception_v3(weights="IMAGENET1K_V1", progress=True)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
        transform = trn.Compose([trn.RandomHorizontalFlip(),
                                 trn.RandomCrop(32, padding=4),
                                 trn.Resize(299),
                                 trn.ToTensor(),
                                 trn.Normalize(mean, std)])
        batch_size = 128
        num_epochs = 4
        lr = 1e-4
    elif modelname == "ViT":
        model = torchvision.models.vit_b_16(weights="IMAGENET1K_V1", progress=True)
        # Replaces the whole `heads` container with a single Linear layer, so
        # checkpoint keys are `heads.weight` / `heads.bias`.
        model.heads = nn.Linear(model.heads[0].in_features, num_classes)
        transform = trn.Compose([trn.RandomHorizontalFlip(),
                                 trn.RandomCrop(32, padding=4),
                                 trn.Resize(224),
                                 trn.ToTensor(),
                                 trn.Normalize(mean, std)])
        num_epochs = 4
        lr = 1e-5
        batch_size = 128
    else:
        raise NotImplementedError(f"Unsupported CIFAR-10 model: {modelname}")

    if trained:
        model_pkl_path = os.path.join(os.path.expanduser('~'), "data",
                                      "{}_trained_models".format(dataset_name),
                                      '{}_finetuned.pth'.format(modelname))
        if os.path.exists(model_pkl_path):
            # Load onto the CPU first; the model is moved to the GPU below.
            model.load_state_dict(torch.load(model_pkl_path, map_location="cpu"))
        else:
            print("Dataset is {}. No trained model, starting training!!!".format(dataset_name))
            _finetune_cifar10_model(model, modelname, dataset_name, transform,
                                    batch_size, num_epochs, lr)
            # The cache directory may not exist yet. Create it so the finished
            # training run is not lost to FileNotFoundError.
            os.makedirs(os.path.dirname(model_pkl_path), exist_ok=True)
            torch.save(model.state_dict(), model_pkl_path)

    # Applied whether or not fine-tuned weights were requested. This was
    # previously nested under `trained`, so trained=False ignored `mode` and
    # `dataParallel` and returned a CPU model.
    if mode == "test":
        model.eval()
    else:
        model.train()
    if dataParallel:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model.cuda()
    return model


def _finetune_cifar10_model(model, modelname, dataset_name, transform, batch_size, num_epochs, lr):
    """Fine-tune ``model`` in place on the "train" split of ``dataset_name`` with SGD (momentum 0.9)."""
    criterion = nn.CrossEntropyLoss()
    train_dataset, _ = build_dataset(dataset_name, "train", transform)
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    model.cuda()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_correct = 0
        for inputs, labels in trainloader:
            inputs = inputs.cuda()
            labels = labels.cuda()
            optimizer.zero_grad()
            outputs = model(inputs)
            if modelname == "Inception":
                # In train mode torchvision's Inception v3 returns a namedtuple;
                # only its main `.logits` are used for the loss.
                outputs = outputs.logits
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(
            f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)},ACC.: {num_correct / len(train_dataset)}")

