import argparse
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet50_Weights, mobilenet_v2, MobileNet_V2_Weights, vit_b_16, ViT_B_16_Weights, swin_b, Swin_B_Weights, \
    inception_v3, Inception_V3_Weights, densenet161, DenseNet161_Weights, resnet152, ResNet152_Weights, efficientnet_b7, EfficientNet_B7_Weights
from torchvision.transforms.functional import to_tensor, resize, center_crop
from transformers import MobileViTForImageClassification, PvtV2ForImageClassification
from tqdm import tqdm
import warnings
from typing import List
import shutil
warnings.filterwarnings("ignore")

class TestDataset(Dataset):
    """Dataset of saved adversarial examples paired with their labels.

    Every entry of ``data_dir`` is treated as one sample. The part of the
    file name before the first ``.`` must be an integer ``k``; the label of
    the sample is read from line ``k`` (1-based) of ``label_path``. Labels in
    that file are 1-based and are returned 0-based.

    Args:
        data_dir: Directory holding one file per sample.
        label_path: Text file with one integer label per line.
        data_type: ``'png'`` to decode images with PIL, ``'pth'`` to load
            tensors saved with ``torch.save``.

    Raises:
        ValueError: If ``data_type`` is not one of the supported types.
    """

    # The class name matches pytest's ``Test*`` pattern; tell pytest that it
    # is not a test case so importing it in a test module stays quiet.
    __test__ = False

    SUPPORTED_TYPES = ('png', 'pth')

    def __init__(self, data_dir, label_path, data_type='pth'):
        if data_type not in self.SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported data type {data_type!r}; "
                f"expected one of {self.SUPPORTED_TYPES}"
            )
        self.data_dir = data_dir
        self.type = data_type
        # Read the labels eagerly so the file handle is closed right away.
        with open(label_path) as label_file:
            self.labels = label_file.readlines()
        # os.listdir returns entries in arbitrary order; sort for
        # reproducible iteration across runs and file systems.
        self.data = sorted(os.listdir(data_dir))

    def __len__(self):
        """Return the number of files found in ``data_dir``."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(data, label, file_name)`` for the ``idx``-th file."""
        file_name = self.data[idx]
        path = os.path.join(self.data_dir, file_name)

        if self.type == 'png':
            # ``convert`` forces the pixel data to be read, so the file can
            # be closed as soon as the tensor has been built.
            with Image.open(path) as image:
                data = to_tensor(image.convert('RGB').resize((299, 299)))
        elif self.type == 'pth':
            # map_location lets tensors saved from a GPU load on any machine.
            data = torch.load(path, map_location='cpu').squeeze(0)
        else:
            raise ValueError(f"Unsupported data type {self.type!r}")

        # File names are 1-based line numbers into the label file.
        label_idx = int(file_name.split('.')[0]) - 1
        # Stored labels are 1-based; shift them to 0-based class indices.
        label = int(self.labels[label_idx].strip()) - 1
        return data, label, file_name

class Classifier:
    """Pretrained classifier bundled with the preprocessing it expects.

    Calling an instance preprocesses a batch of images and returns the
    model's logits. The whole forward pass runs without gradient tracking.

    Args:
        name: Short model identifier; one of ``TORCHVISION_MODELS`` or a key
            of ``HF_CHECKPOINTS``.
        device: Device the model and its inputs are moved to. Defaults to
            ``'cuda'``.

    Raises:
        ValueError: If ``name`` is not a supported model identifier.
    """

    # Identifiers served by torchvision; resolved to builders in __init__.
    TORCHVISION_MODELS = (
        'res50', 'mnv2', 'incv3', 'dense161', 'res152', 'effb7', 'vitb', 'swinb',
    )
    # Local checkpoint directories of the Hugging Face models.
    HF_CHECKPOINTS = {
        'mobvit': "/home/checkpoints/apple-mobilevit-small",
        'pvt': "/home/checkpoints/opengvlab-pvtv2b5",
    }

    def __init__(self, name, device='cuda') -> None:
        self.model_name = name
        self.device = torch.device(device)

        # Reject unknown names before touching any model library.
        if name not in self.TORCHVISION_MODELS and name not in self.HF_CHECKPOINTS:
            raise ValueError(f"Model {self.model_name} not supported")

        if name in self.HF_CHECKPOINTS:
            checkpoint = self.HF_CHECKPOINTS[name]
            if name == 'mobvit':
                model = MobileViTForImageClassification.from_pretrained(checkpoint)
                self.preprocess = self.mobvit_preprocess
            else:
                model = PvtV2ForImageClassification.from_pretrained(checkpoint)
                self.preprocess = self.pvt_preprocess
        else:
            # (builder, weights) pairs; each weights enum supplies its own
            # preprocessing transform.
            torchvision_zoo = {
                'res50': (resnet50, ResNet50_Weights.DEFAULT),
                'mnv2': (mobilenet_v2, MobileNet_V2_Weights.DEFAULT),
                'incv3': (inception_v3, Inception_V3_Weights.DEFAULT),
                'dense161': (densenet161, DenseNet161_Weights.DEFAULT),
                'res152': (resnet152, ResNet152_Weights.DEFAULT),
                'effb7': (efficientnet_b7, EfficientNet_B7_Weights.DEFAULT),
                'vitb': (vit_b_16, ViT_B_16_Weights.DEFAULT),
                'swinb': (swin_b, Swin_B_Weights.DEFAULT),
            }
            builder, weights = torchvision_zoo[name]
            model = builder(weights=weights)
            self.preprocess = weights.transforms(antialias=True)

        self.model = model.eval().to(self.device)

    def __call__(self, in_data: torch.Tensor):
        """Return the logits of the wrapped model for ``in_data``.

        Args:
            in_data: Batch of images handed to ``self.preprocess``.

        Returns:
            The logits tensor, detached from any autograd graph.
        """
        # Evaluation only: skip autograd bookkeeping for every model, not
        # just the Hugging Face ones.
        with torch.no_grad():
            data = self.preprocess(in_data)
            if self.model_name in self.HF_CHECKPOINTS:
                # Hugging Face preprocessors return keyword arguments that
                # are already on the target device.
                return self.model(**data).logits
            return self.model(data.to(self.device))

    def mobvit_preprocess(self, in_data: torch.Tensor):
        """Resize to 256x256 and reverse the channel axis for MobileViT.

        Returns:
            ``{'pixel_values': tensor}`` on ``self.device``.
        """
        # flip(1) reverses dim 1 -- presumably RGB -> BGR as this checkpoint
        # expects; confirm against the checkpoint's preprocessor config.
        data_tensor = center_crop(resize(in_data, (256, 256)), (256, 256)).flip(1)
        return dict(pixel_values=data_tensor.to(self.device))

    def pvt_preprocess(self, in_data: torch.Tensor):
        """Resize to 224x224 and apply per-channel mean/std normalisation.

        Returns:
            ``{'pixel_values': tensor}`` on ``self.device``.
        """
        # Build the statistics on the input's device so the arithmetic works
        # wherever ``in_data`` lives.
        mean = torch.tensor([0.485, 0.456, 0.406], device=in_data.device)[None, :, None, None]
        std = torch.tensor([0.229, 0.224, 0.225], device=in_data.device)[None, :, None, None]
        data_tensor = (center_crop(resize(in_data, (224, 224)), (224, 224)) - mean) / std
        return dict(pixel_values=data_tensor.to(self.device))



def _attack_success_rate(classifier, dataloader, num_samples):
    """Return the percentage of samples that ``classifier`` mislabels."""
    wrong_num = 0
    # Iterating the loader directly lets tqdm show the total.
    for data, label, _name in tqdm(dataloader):
        pred = classifier(data)
        # batch_size is 1, so prediction and label each reduce to one int.
        if pred.argmax(-1).item() != int(label):
            wrong_num += 1
    return wrong_num / num_samples * 100.0


def main(args):
    """Evaluate the saved adversarial examples against every classifier.

    Prints the attack success rate (share of misclassified samples) per
    model, then the average over the models whose identifier does not occur
    in ``args.data_dir`` -- those are counted as black-box targets.

    Args:
        args: Namespace providing ``data_dir``, ``label_path`` and
            ``data_type``.

    Raises:
        ValueError: If ``args.data_dir`` contains no samples.
    """
    models_to_transfer = ['res50', 'mnv2', 'incv3', 'dense161', 'res152', 'effb7', 'mobvit', 'pvt', 'vitb', 'swinb']

    dataset = TestDataset(args.data_dir, args.label_path, args.data_type)
    num_samples = len(dataset)
    if num_samples == 0:
        # Fail early with a clear message instead of dividing by zero after
        # the first model has already been loaded.
        raise ValueError(f"No samples found in {args.data_dir}")
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    black_asr = []
    for model in models_to_transfer:
        classifier = Classifier(model)
        asr = _attack_success_rate(classifier, dataloader, num_samples)
        print(f"Classifier {model} Attack Success Rate: {asr} %.")
        # A model whose identifier is part of the data path is skipped when
        # averaging the black-box results.
        if model not in args.data_dir:
            black_asr.append(asr)

    print("*"*60)
    if black_asr:
        print("Average Black-box Attack Success Rate: {:.2f} %".format(sum(black_asr)/len(black_asr)))
    else:
        print("Average Black-box Attack Success Rate: n/a (no black-box models evaluated)")
    print("*"*60)

if __name__ == "__main__":
    # Command-line entry point: parse the evaluation options and run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='temp/1000/inversion', help='Directory of the saved adversarial examples.')
    # The help text must match the actual default, which is png.
    parser.add_argument('--data_type', type=str, default='png', choices=['pth', 'png'], help='Data type of the input data. Defaults to png.')
    parser.add_argument('--label_path', type=str, default='third_party/Natural-Color-Fool/dataset/label.txt', help='Text file with one integer label per line.')
    args = parser.parse_args()

    main(args)