import torch
import torchvision
from torch import nn
from torchvision import models, transforms
import random


def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of *model* when *feature_extracting* is truthy.

    Every tensor yielded by ``model.parameters()`` gets ``requires_grad = False``,
    so the existing weights are excluded from gradient computation. When
    *feature_extracting* is falsy the model is left untouched.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        feature_extracting: Whether to freeze the parameters.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False


def initialize_model(model_name, num_classes, feature_extract, use_pretrained):
    """Build a torchvision classifier whose final layer outputs *num_classes*.

    Args:
        model_name: ``"mobilenet"`` (MobileNet_v2), ``"resnet"`` (ResNet-50) or
            ``"vgg"`` (VGG11_bn).
        num_classes: Output size of the freshly created final ``nn.Linear``.
        feature_extract: If truthy, freeze every parameter of the loaded network
            before the final layer is replaced, so only the new layer trains.
        use_pretrained: Forwarded as ``pretrained=`` to the torchvision constructor.

    Returns:
        ``(model_ft, input_size)``; ``input_size`` is the square input resolution
        (224 for all supported models).

    Raises:
        ValueError: If *model_name* is not one of the supported names.
    """
    # NOTE(review): ``pretrained=`` is deprecated in newer torchvision releases in
    # favour of ``weights=``; kept here for compatibility with older versions.
    if model_name == "mobilenet":
        # MobileNet_v2: the final linear layer is classifier[1].
        model_ft = models.mobilenet_v2(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[1].in_features
        model_ft.classifier[1] = nn.Linear(num_ftrs, num_classes)
    elif model_name == "resnet":
        # ResNet-50: the final linear layer is ``fc``.
        model_ft = models.resnet50(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
    elif model_name == "vgg":
        # VGG11_bn: the final linear layer is classifier[6].
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
    else:
        # Raise instead of exit(): callers can catch/report this, and exit() is an
        # interactive-session helper not guaranteed to exist (e.g. python -S).
        raise ValueError(
            f"Invalid model name {model_name!r}; expected 'mobilenet', 'resnet' or 'vgg'"
        )

    # Freezing happens before the head is swapped, so the new nn.Linear keeps its
    # default requires_grad=True and remains trainable.
    input_size = 224
    return model_ft, input_size


def initialize_dataset(dataset_name, input_size, batch_size, data_path='your path', num_workers=1):
    """Create CIFAR train/val DataLoaders with ImageNet-style preprocessing.

    Args:
        dataset_name: ``'cifar10'`` or ``'cifar100'``.
        input_size: Side length of the square crop fed to the model.
        batch_size: Batch size for both loaders.
        data_path: Root directory for the dataset (downloaded there if missing).
            The default ``'your path'`` is a placeholder and should be overridden.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``{'train': DataLoader, 'val': DataLoader}``. The train loader is shuffled
        with random-resized-crop + horizontal-flip augmentation; the val loader
        uses resize + center-crop and is not shuffled.

    Raises:
        ValueError: If *dataset_name* is not supported.
    """
    print("Initializing Datasets and Dataloaders...")

    # Validate the name first so an unknown dataset fails before any work is done.
    if dataset_name == 'cifar10':
        dataset_cls = torchvision.datasets.CIFAR10
    elif dataset_name == 'cifar100':
        dataset_cls = torchvision.datasets.CIFAR100
    else:
        raise ValueError(
            f"Dataset not found: {dataset_name!r}; expected 'cifar10' or 'cifar100'"
        )

    # ImageNet mean/std, matching the statistics used by torchvision's pretrained models.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transforms = {
        # Data augmentation and normalization for training
        'train': transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        # Deterministic resize + crop for evaluation
        'val': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    trainset = dataset_cls(root=data_path, train=True, download=True, transform=data_transforms['train'])
    testset = dataset_cls(root=data_path, train=False, download=True, transform=data_transforms['val'])

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation metrics don't depend on order; not shuffling keeps val runs reproducible.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return {'train': trainloader, 'val': testloader}


def random_sample(input_tensor, size):
    """Draw *size* distinct elements of *input_tensor* uniformly at random.

    The tensor is flattened first, so sampling ignores its original shape.
    Sampled elements are returned in ascending flat-index order. Randomness
    comes from Python's ``random`` module, so ``random.seed`` controls it.

    Args:
        input_tensor: Tensor of any shape to sample from.
        size: Number of distinct elements to draw, ``0 <= size <= numel``.

    Returns:
        A tensor of shape ``(1, size)`` rebuilt from Python scalars via
        ``torch.tensor``; like before, it is created on the default device with
        an inferred dtype rather than inheriting the input's dtype/device.

    Raises:
        ValueError: If *size* is negative or larger than the number of elements
            (the previous rejection loop spun forever in that case).
    """
    flat = input_tensor.reshape(-1)
    total = flat.numel()
    if not 0 <= size <= total:
        raise ValueError(f"size must be between 0 and {total}, got {size}")

    # One draw without replacement replaces the unbounded rejection loop.
    indices = sorted(random.sample(range(total), size))
    # Single gather + tolist() instead of one .item() (host sync) per element.
    picked = flat[torch.as_tensor(indices, dtype=torch.long)].tolist()
    return torch.tensor(picked).unsqueeze(0)
