import torch
import torch.nn as nn
import os
import torchvision
from torchvision.models import resnet50, resnet101, resnet152
import timm
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTConfig

class Normalize(nn.Module):
    """Wrap a model so its inputs are normalized per channel before the forward pass.

    The mean/std statistics are selected from ``dataname`` (case-insensitive):

    * ``'cifar10'``: mean = std = 0.5 per channel when ``modelname == 'vit_l_16'``,
      otherwise the CIFAR-10 channel statistics used below.
    * ``'imagenet'``: the ImageNet channel statistics used below.

    Args:
        model: The wrapped module. ``forward`` returns whatever it returns.
        dataname: Dataset name that selects the normalization statistics.
        device: Device the statistics are moved to on every ``forward`` call.
        modelname: Model identifier. It is only consulted for CIFAR-10, and it is
            optional so callers that omit it (e.g. for ImageNet) still work.

    Raises:
        ValueError: If ``dataname`` is neither CIFAR-10 nor ImageNet.
    """

    def __init__(self, model, dataname, device, modelname=None):
        super().__init__()

        dataset = dataname.lower()
        if dataset == "cifar10":
            if modelname == 'vit_l_16':
                m, s = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
            else:
                m, s = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
        elif dataset == "imagenet":
            m, s = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        else:
            # Previously this fell through and failed with UnboundLocalError on `m`.
            raise ValueError(f"Unsupported dataset for normalization: {dataname!r}")

        # Buffers rather than parameters: they are saved in state_dict and moved by
        # .to(), but they are not returned by .parameters(), so optimizers never update them.
        self.register_buffer('mean', torch.tensor(m))
        self.register_buffer('std', torch.tensor(s))

        self.model = model
        self.device = device

    def forward(self, inputs):
        """Normalize ``inputs`` channel-wise and return ``self.model`` applied to the result."""
        # Shape (1, C, 1, 1) broadcasts over channel dim 1, presumably NCHW image
        # batches (confirm with callers). The explicit .to(self.device) keeps this
        # working even if the wrapper itself was never moved to the device.
        mean = self.mean.view(1, -1, 1, 1).to(self.device)
        std = self.std.view(1, -1, 1, 1).to(self.device)
        return self.model((inputs - mean) / std)


# Adversarially trained ImageNet models built with timm:
# model name -> (timm architecture name, checkpoint path).
_IMAGENET_TIMM_CHECKPOINTS = {
    'resnet50_adv': (
        'resnet50',
        './results/ImageNet/saved_model/resnet50_adv/ARES_ResNet50_AT.pth',
    ),
    'resnet101_adv': (
        'resnet101',
        './results/ImageNet/saved_model/resnet101_adv/ARES_ResNet101_AT.pth',
    ),
    'convnext_adv': (
        'convnext_small',
        './results/ImageNet/saved_model/convnext_adv/ARES_ConvNext_Small_AT.pth',
    ),
}

# CIFAR-10 models built from the local ``models`` package:
# model name -> (constructor name in ``models``, checkpoint path).
_CIFAR10_CHECKPOINTS = {
    'resnet34_adv': ('resnet34', './results/CIFAR10/saved_model/resnet34_adv/resnet34_adv.pt'),
    'vgg19': ('vgg19', './results/CIFAR10/saved_model/vgg19/vgg19.pt'),
    'vgg19_adv': ('vgg19', './results/CIFAR10/saved_model/vgg19_adv/vgg19_adv.pt'),
    'vgg16': ('vgg16', './results/CIFAR10/saved_model/vgg16/vgg16.pt'),
    'vgg16_adv': ('vgg16', './results/CIFAR10/saved_model/vgg16_adv/vgg16_adv.pt'),
    'vgg13': ('vgg13', './results/CIFAR10/saved_model/vgg13/vgg13.pt'),
}

_IMAGENET_VIT_DIR = './results/ImageNet/saved_model/vit'
_CIFAR10_VIT_DIR = './results/CIFAR10/saved_model/vit'
_CIFAR10_VIT_WEIGHTS = './results/CIFAR10/saved_model/vit_train/best_vit_large_cifar10.pth'


def _load_weights(module, checkpoint):
    """Check that *checkpoint* exists, then load its state dict into *module*."""
    assert os.path.isfile(checkpoint), "checkpoint does not exist"
    # map_location='cpu' lets checkpoints saved from a GPU load on CPU-only hosts;
    # load_state_dict then copies the tensors into the module's existing parameters,
    # wherever those live.
    module.load_state_dict(torch.load(checkpoint, map_location='cpu'))


def _create_imagenet_model(modelname, device):
    """Build the unwrapped ImageNet classifier *modelname* and move it to *device*."""
    # NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
    # `weights=`; kept as-is because the installed torchvision version is unknown here.
    if modelname == 'resnet50':
        model = resnet50(pretrained=True)
    elif modelname == 'resnet101':
        model = resnet101(pretrained=True)
    elif modelname == 'resnet152':
        model = resnet152(pretrained=True)
    elif modelname in _IMAGENET_TIMM_CHECKPOINTS:
        arch, checkpoint = _IMAGENET_TIMM_CHECKPOINTS[modelname]
        model = timm.create_model(arch, pretrained=False)
        _load_weights(model, checkpoint)
    elif modelname == 'vit_l_16':
        model = ViTForImageClassification.from_pretrained(_IMAGENET_VIT_DIR, num_labels=1000)
    else:
        supported = ['resnet50', 'resnet101', 'resnet152', 'vit_l_16', *_IMAGENET_TIMM_CHECKPOINTS]
        raise ValueError(f"Unsupported ImageNet model {modelname!r}; expected one of {supported}")
    return model.to(device)


def _create_cifar10_vit(modelname, dataname, num_classes, test, device):
    """Build the CIFAR-10 ViT. Load fine-tuned weights and wrap it in Normalize when *test*."""
    model = ViTForImageClassification.from_pretrained(_CIFAR10_VIT_DIR)
    # Replace the pretrained classification head with one sized for num_classes.
    model.classifier = nn.Linear(model.classifier.in_features, num_classes)
    if test:
        model.load_state_dict(torch.load(_CIFAR10_VIT_WEIGHTS, map_location='cpu'))
        # NOTE(review): only the test-time model is wrapped in Normalize; confirm the
        # training pipeline normalizes inputs itself.
        model = Normalize(model, dataname, device, modelname)
    return model.to(device)


def _create_cifar10_convnet(modelname, dataname, num_classes, test, logits_dim, device):
    """Build a CIFAR-10 network from the local ``models`` package, wrapped in Normalize."""
    import importlib

    if modelname not in _CIFAR10_CHECKPOINTS:
        supported = ['vit_l_16', *_CIFAR10_CHECKPOINTS]
        raise ValueError(f"Unsupported {dataname} model {modelname!r}; expected one of {supported}")
    constructor_name, checkpoint = _CIFAR10_CHECKPOINTS[modelname]
    constructor = getattr(importlib.import_module('models'), constructor_name)
    net = constructor(num_classes=num_classes, logits_dim=logits_dim, device=device)
    model = Normalize(net, dataname, device, modelname).to(device)
    if test:
        # Checkpoints hold the inner network's weights, not the Normalize wrapper's.
        _load_weights(model.model, checkpoint)
    return model


def create_model(modelname, dataname='CIFAR10', num_classes=10, test=True, logits_dim=10, device='cpu'):
    """Build a classifier by name, wrapped in input normalization where applicable.

    Args:
        modelname: Model identifier.
            ImageNet: 'resnet50', 'resnet101', 'resnet152', 'resnet50_adv',
            'resnet101_adv', 'convnext_adv', 'vit_l_16'.
            Otherwise: 'vit_l_16', 'resnet34_adv', 'vgg19', 'vgg19_adv', 'vgg16',
            'vgg16_adv', 'vgg13'.
        dataname: 'ImageNet' (case-insensitive) selects the ImageNet models. Any other
            value selects the CIFAR-10 models and is passed to Normalize to pick statistics.
        num_classes: Output classes for the non-ImageNet models. Ignored for ImageNet.
        test: Non-ImageNet models only. If true, the saved checkpoint is loaded. For
            'vit_l_16' it also controls whether the model is wrapped in Normalize.
        logits_dim: Forwarded to the constructors in the local ``models`` package.
        device: Device the returned model is moved to.

    Returns:
        The model, moved to ``device``. It is wrapped in Normalize except for the
        non-ImageNet 'vit_l_16' with ``test=False``.

    Raises:
        ValueError: If ``modelname`` is not supported for the selected dataset.
        AssertionError: If a required checkpoint file is missing.
    """
    if dataname.lower() == 'imagenet':
        model = _create_imagenet_model(modelname, device)
        return Normalize(model, dataname, device, modelname).to(device)

    if modelname == 'vit_l_16':
        return _create_cifar10_vit(modelname, dataname, num_classes, test, device)
    return _create_cifar10_convnet(modelname, dataname, num_classes, test, logits_dim, device)
