from torchvision.models import resnet50, ResNet50_Weights,vit_b_16, ViT_B_16_Weights
from torchvision import transforms,datasets
from robustbench.utils import load_model as load_model_robustbench
from robustbench.model_zoo.architectures.utils_architectures import normalize_model
from robustbench.data import get_preprocessing
import torch
import numpy as np

def _pipeline(*steps):
    """Compose ``steps`` followed by a final ``ToTensor()`` conversion."""
    return transforms.Compose([*steps, transforms.ToTensor()])


# Named input pipelines, selected by key in load_model. Every pipeline ends
# with ToTensor(); the ``None`` key applies no resizing or cropping at all.
PREPROCESSINGS = {
    'Res256Crop224': _pipeline(
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ),
    'Crop288': _pipeline(transforms.CenterCrop(288)),
    None: _pipeline(),
    # Resize straight to 224x224 (aspect ratio is not preserved).
    'Res224': _pipeline(transforms.Resize([224, 224])),
    'BicubicRes256Crop224': _pipeline(
        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(224),
    ),
}

class Imagenette_wrapper(torch.nn.Module):
    """Keep only selected class columns of a classifier's output.

    By default the selected columns are the ImageNet-1k class indices used for
    Imagenette, so appending this module to a 1000-way ImageNet classifier
    yields a 10-way output. The output column order follows the index order.
    The default order is ascending and is assumed to match the label order of
    the Imagenette dataset used here (TODO confirm).

    Args:
        indices (sequence of int, optional): Columns to select, in output
            order. Defaults to ``IMAGENETTE_INDICES``.
    """

    # ImageNet-1k class indices of the 10 Imagenette classes, ascending.
    IMAGENETTE_INDICES = (0, 217, 482, 491, 497, 566, 569, 571, 574, 701)

    def __init__(self, indices=None):
        super().__init__()
        chosen = self.IMAGENETTE_INDICES if indices is None else indices
        # Column indices picked from the input, in output order.
        self.imagenette_indices = list(chosen)
        # NOTE: not used by forward(); kept so existing attribute access and
        # the module's registered submodules stay unchanged.
        self.softmax = torch.nn.LogSoftmax(1)

    def forward(self, x):
        """Return ``x[:, indices]``.

        Assumes ``x`` is (batch, num_classes) logits. The selected columns are
        returned as-is; no softmax or log-softmax is applied.
        """
        return x[:, self.imagenette_indices]

# RobustBench ImageNet Linf checkpoints supported by load_model. They all use
# the same loading call and the bicubic 256 -> center-crop 224 pipeline.
_ROBUSTBENCH_IMAGENET_LINF_MODELS = frozenset({
    'Xu2024MIMIR_Swin-L',
    'Liu2023Comprehensive_Swin-L',
    'Amini2024MeanSparse_Swin-L',
    'Bai2024MixedNUTS',
    'Liu2023Comprehensive_ConvNeXt-L',
    'Amini2024MeanSparse_ConvNeXt-L',
    'Singh2023Revisiting_ConvNeXt-L-ConvStem',
    'Singh2023Revisiting_ConvNeXt-B-ConvStem',
    'Xu2024MIMIR_Swin-B',
    'Liu2023Comprehensive_ConvNeXt-B',
    'Liu2023Comprehensive_Swin-B',
})


def load_model(model_name, dataset):
    """Build an evaluation model and the input pipeline it expects.

    Args:
        model_name (str): ``'ResNet50'``, ``'ResNet50finetuned'``, ``'Vit'``,
            or one of the RobustBench names in
            ``_ROBUSTBENCH_IMAGENET_LINF_MODELS``.
        dataset (str): Only consulted for ``'ResNet50'`` and ``'Vit'``. When
            it equals ``'imagenette'``, their 1000-way output is narrowed to
            the 10 Imagenette classes with ``Imagenette_wrapper``.

    Returns:
        tuple: ``(model, preprocessing)``. ``model`` is in eval mode.
        ``preprocessing`` is the matching ``PREPROCESSINGS`` entry.

    Raises:
        ValueError: If ``model_name`` is not recognised.
    """
    prepr = 'Res256Crop224'
    # ImageNet channel statistics. The torchvision models below are wrapped
    # with robustbench's normalize_model, so callers feed un-normalized tensors.
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

    if model_name == 'ResNet50':
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        if dataset == 'imagenette':
            model = torch.nn.Sequential(model, Imagenette_wrapper())
        model = normalize_model(model, mean, std)

    elif model_name == 'ResNet50finetuned':
        # 10-way head fine-tuned on Imagenette. This model always outputs
        # log-probabilities (LogSoftmax) regardless of ``dataset``, unlike the
        # raw logits returned by the other models.
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        model.fc = torch.nn.Linear(2048, 10, bias=True)
        model = torch.nn.Sequential(model, torch.nn.LogSoftmax(1))
        # map_location='cpu': the weights are copied into CPU parameters
        # anyway. Without it, a checkpoint saved from a GPU fails to load on a
        # CPU-only host.
        state_dict = torch.load(
            './data/resnet50_imagenette_weights_jpeg.pt',
            map_location='cpu',
            weights_only=True,
        )
        model.load_state_dict(state_dict)
        model = normalize_model(model, mean, std)

    elif model_name == 'Vit':
        model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        if dataset == 'imagenette':
            model = torch.nn.Sequential(model, Imagenette_wrapper())
        model = normalize_model(model, mean, std)

    elif model_name in _ROBUSTBENCH_IMAGENET_LINF_MODELS:
        # NOTE(review): ``dataset`` is ignored here, so these models keep the
        # full ImageNet output even when dataset == 'imagenette'. Confirm
        # that callers expect this.
        model = load_model_robustbench(model_name, threat_model='Linf', dataset='imagenet')
        prepr = 'BicubicRes256Crop224'

    else:
        raise ValueError(f'Unknown model {model_name}')

    # eval() is recursive, so one call covers every wrapper built above.
    model.eval()
    return model, PREPROCESSINGS[prepr]

# Dataset names accepted by load_testset besides the ``imagenet_<N>`` subsets.
_BASE_TESTSETS = ('imagenette', 'cifar10', 'imagenet')


def _imagenet_subset_size(name):
    """Return N for a name of the form ``imagenet_<N>`` (N >= 1), else None."""
    import re

    if not isinstance(name, str):
        return None
    match = re.fullmatch(r'imagenet_([1-9][0-9]*)', name)
    return int(match.group(1)) if match else None


def _load_subset_indices(size, population):
    """Return the cached sample of ``size`` distinct indices in ``range(population)``.

    The sample is read from ``data/imagenet_<size>.npy``. If that file cannot
    be read, a new unseeded random sample is drawn and saved there so that
    later calls reuse the same subset.
    """
    path = f'data/imagenet_{size}.npy'
    try:
        return np.load(path)
    except OSError:
        indices = np.random.choice(population, size, replace=False)
        np.save(path, indices)
        return indices


def load_testset(name, preprocessing=None):
    """Load a test/validation dataset by name.

    Args:
        name (str): ``'imagenette'``, ``'cifar10'``, ``'imagenet'``, or
            ``'imagenet_<N>'`` for a fixed random subset of N ImageNet
            validation images (N a positive integer, e.g. ``'imagenet_1000'``).
            The subset indices are cached in ``data/imagenet_<N>.npy``.
        preprocessing (callable, optional): Transform applied to each image.
            Defaults to Resize(256) -> CenterCrop(224) -> ToTensor().

    Returns:
        The dataset object. It is a ``torch.utils.data.Subset`` for
        ``imagenet_<N>`` names.

    Raises:
        ValueError: If ``name`` is not a recognised dataset.
    """
    subset_size = _imagenet_subset_size(name)
    if subset_size is None and name not in _BASE_TESTSETS:
        # Fail before touching the filesystem, with an actionable message.
        raise ValueError(f'Dataset not found: {name}')

    if preprocessing is None:
        preprocessing = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])

    if name == 'imagenette':
        return datasets.Imagenette(root='./data', split='val', download=False, transform=preprocessing)
    if name == 'cifar10':
        return datasets.CIFAR10(root='./data/cifar_10', train=False, download=False, transform=preprocessing)

    # Remaining names are the full ImageNet validation split or a subset of it.
    testset = datasets.ImageNet(root='./data', split='val', transform=preprocessing)
    if subset_size is None:
        return testset
    indices = _load_subset_indices(subset_size, len(testset))
    return torch.utils.data.Subset(testset, indices)