import torch.utils.data as data
import os
import torch
import json
from PIL import Image
from torchvision import transforms
import math
import glob

from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensor
from utils import ROOT_PATH

# JSON file that the dataset classes load to look up the class id and class
# name of each image (keyed by image file name).
name_to_class_ids_file = os.path.join(ROOT_PATH, 'image_name_to_class_id_and_name.json')

# Dataset parameters.
INPUT_SIZE = (3, 224, 224)
INTERPOLATION = 'bicubic'
DEFAULT_CROP_PCT = 0.875  # other values noted by the author: 0.9, 1.0

# Default ImageNet per-channel statistics.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# "Inception"-style statistics (labelled "Imagenet 21k" by the author):
# 0.5 for every channel.
IMAGENET_INCEPTION_MEAN = (0.5,) * 3
IMAGENET_INCEPTION_STD = (0.5,) * 3

# Statistics selected for 'mobilevit_s' in params(): mean 0 / std 1, i.e.
# normalisation leaves the tensor values unchanged.
IMAGENET_ViT_MEAN = (0.0,) * 3
IMAGENET_ViT_STD = (1.0,) * 3


def params(model_name):
    """Return the preprocessing settings associated with ``model_name``.

    Args:
        model_name: Name of the model the images will be fed to.

    Returns:
        A dict with the keys ``'mean'``, ``'std'`` and ``'interpolation'``.
        Every branch except ``'mobilevit_s'`` also provides ``'crop_pct'``.
    """
    # Models that use the default ImageNet statistics.
    default_stat_models = (
        'vit_deit_base_distilled_patch16_224',
        'levit_256',
        'pit_b_224',
        'cait_s24_224',
        'convit_base',
        'visformer_small',
        'deit_base_distilled_patch16_224',
    )

    if model_name in default_stat_models:
        return {
            'mean': IMAGENET_DEFAULT_MEAN,
            'std': IMAGENET_DEFAULT_STD,
            'interpolation': INTERPOLATION,
            'crop_pct': 0.9,
        }

    if model_name == 'mobilevit_s':
        # No 'crop_pct' entry for this model.
        return {
            'mean': IMAGENET_ViT_MEAN,
            'std': IMAGENET_ViT_STD,
            'interpolation': INTERPOLATION,
        }

    # Every other model falls back to the 0.5 / 0.5 statistics.
    return {
        'mean': IMAGENET_INCEPTION_MEAN,
        'std': IMAGENET_INCEPTION_STD,
        'interpolation': INTERPOLATION,
        'crop_pct': 0.9,
    }

def transforms_imagenet_wo_resize(params):
    """Build the tensor-conversion + normalisation pipeline (no resizing).

    Args:
        params: Mapping providing the per-channel ``'mean'`` and ``'std'``
            values, e.g. the dict returned by ``params(model_name)``.

    Returns:
        A ``transforms.Compose`` that applies ``ToTensor`` followed by
        ``Normalize`` with the given statistics.
    """
    to_tensor = transforms.ToTensor()
    channel_mean = torch.tensor(params['mean'])
    channel_std = torch.tensor(params['std'])
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)
    return transforms.Compose([to_tensor, normalize])


class AdvDataset(data.Dataset):
    """Dataset over the adversarial ``*.png`` images stored in one directory.

    The label of an image is looked up in the JSON file at
    ``name_to_class_ids_file`` under the key ``<stem>.JPEG``, where ``<stem>``
    is the part of the PNG file name before its first dot. Each JSON entry
    must provide ``'class_id'`` and ``'class_name'``.
    """

    def __init__(self, model_name, adv_path):
        """Index the PNG files of ``adv_path`` and load the label mapping.

        Args:
            model_name: Model identifier; selects the normalisation
                statistics and the resizing applied in ``__getitem__``.
            adv_path: Directory containing the adversarial ``.png`` images.
        """
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        paths = self._list_png_names(adv_path)
        print ('Using ', len(paths))
        paths = [i.strip() for i in paths]
        self.query_paths = [i.split('.')[0]+'.JPEG' for i in paths]
        self.paths = [os.path.join(adv_path, i) for i in paths]
        self.model_name = model_name

        with open(name_to_class_ids_file, 'r') as ipt:
            self.json_info = json.load(ipt)

    @staticmethod
    def _list_png_names(folder):
        """Return the sorted base names of the ``*.png`` files in ``folder``."""
        # os.path.basename works with any path separator (a plain
        # split('/') does not on Windows); sorting makes the sample order
        # reproducible because glob returns matches in arbitrary order.
        return sorted(
            os.path.basename(p)
            for p in glob.glob(os.path.join(folder, '*.png'))
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.paths)

    def _preprocess(self, img):
        """Apply the model-specific preprocessing to one PIL image."""
        if self.model_name == "tf2torch_resnet_v2_101":
            # Resized to 299x299 and converted to a tensor only; the
            # normalising transform is not applied for this model.
            img = transforms.Resize((299,299))(img)
            return transforms.ToTensor()(img)
        if self.model_name == 'mobilevit_s':
            img = transforms.Resize((320, 320))(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position of the sample in the indexed file list.

        Returns:
            Tuple ``(img, class_id, class_name, image_name)`` where ``img``
            is the preprocessed image and ``image_name`` is the PNG file
            name without its directory.

        Raises:
            KeyError: If the image has no entry in the label mapping.
        """
        path = self.paths[index]
        label = self.json_info[self.query_paths[index]]
        class_id = label['class_id']
        class_name = label['class_name']
        image_name = os.path.basename(path)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixel data has been decoded.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        img = self._preprocess(img)
        return img, class_id, class_name, image_name

class Clean_AdvDataset(data.Dataset):
    """Dataset pairing each adversarial PNG with its clean counterpart.

    The sample list is built from the ``*.png`` files in ``adv_path``; the
    clean image of a sample is read from ``clean_path`` under the same file
    name. Labels are looked up in the JSON file at
    ``name_to_class_ids_file`` under the key ``<stem>.JPEG``, where
    ``<stem>`` is the part of the PNG file name before its first dot.
    """

    def __init__(self, model_name, clean_path, adv_path):
        """Index the adversarial PNG files and load the label mapping.

        Args:
            model_name: Model identifier; selects the normalisation
                statistics and the resizing applied in ``__getitem__``.
            clean_path: Directory holding the clean images, named like the
                adversarial ones.
            adv_path: Directory containing the adversarial ``.png`` images.
        """
        print(f'clean_path:{clean_path}, adv_path:{adv_path}')
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        paths = self._list_png_names(adv_path)
        print ('Using ', len(paths))
        paths = [i.strip() for i in paths]
        self.query_paths = [i.split('.')[0]+'.JPEG' for i in paths]
        self.paths = [os.path.join(adv_path, i) for i in paths]
        # Paths of the clean samples (same file names, different directory).
        self.clean_paths = [os.path.join(clean_path, i) for i in paths]
        self.model_name = model_name

        with open(name_to_class_ids_file, 'r') as ipt:
            self.json_info = json.load(ipt)

    @staticmethod
    def _list_png_names(folder):
        """Return the sorted base names of the ``*.png`` files in ``folder``."""
        # os.path.basename works with any path separator (a plain
        # split('/') does not on Windows); sorting makes the sample order
        # reproducible because glob returns matches in arbitrary order.
        return sorted(
            os.path.basename(p)
            for p in glob.glob(os.path.join(folder, '*.png'))
        )

    def __len__(self):
        """Return the number of indexed image pairs."""
        return len(self.paths)

    def _preprocess(self, img):
        """Apply the model-specific preprocessing to one PIL image."""
        if self.model_name == "tf2torch_resnet_v2_101":
            # Resized to 299x299 and converted to a tensor only; the
            # normalising transform is not applied for this model.
            img = transforms.Resize((299,299))(img)
            return transforms.ToTensor()(img)
        if self.model_name == 'mobilevit_s':
            img = transforms.Resize((320, 320))(img)
        if self.transform is not None:
            img = self.transform(img)
        return img

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return its content as an RGB PIL image."""
        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixel data has been decoded.
        with Image.open(path) as raw:
            return raw.convert('RGB')

    def __getitem__(self, index):
        """Load one (clean, adversarial) pair.

        Args:
            index: Position of the sample in the indexed file list.

        Returns:
            Tuple ``(cimg, img, class_id, class_name, image_name)`` where
            ``cimg`` / ``img`` are the preprocessed clean / adversarial
            images and ``image_name`` is the PNG file name without its
            directory.

        Raises:
            KeyError: If the image has no entry in the label mapping.
        """
        path = self.paths[index]
        clean_path = self.clean_paths[index]

        label = self.json_info[self.query_paths[index]]
        class_id = label['class_id']
        class_name = label['class_name']
        image_name = os.path.basename(path)

        # Both images go through exactly the same preprocessing.
        img = self._preprocess(self._load_rgb(path))
        cimg = self._preprocess(self._load_rgb(clean_path))
        return cimg, img, class_id, class_name, image_name


class CNNDataset(data.Dataset):
    """Dataset over adversarial ``*.png`` images, always resized to 299x299.

    Labels are looked up in the JSON file at ``name_to_class_ids_file``
    under the key ``<stem>.JPEG``, where ``<stem>`` is the part of the PNG
    file name before its first dot.
    """

    def __init__(self, model_name, adv_path):
        """Index the PNG files of ``adv_path`` and load the label mapping.

        Args:
            model_name: Model identifier used to build ``self.transform``.
                ``__getitem__`` does not apply that transform.
            adv_path: Directory containing the adversarial ``.png`` images.
        """
        self.transform = transforms_imagenet_wo_resize(params(model_name))
        paths = self._list_png_names(adv_path)
        print ('Using ', len(paths))
        paths = [i.strip() for i in paths]
        self.query_paths = [i.split('.')[0]+'.JPEG' for i in paths]
        self.paths = [os.path.join(adv_path, i) for i in paths]

        with open(name_to_class_ids_file, 'r') as ipt:
            self.json_info = json.load(ipt)

    @staticmethod
    def _list_png_names(folder):
        """Return the sorted base names of the ``*.png`` files in ``folder``."""
        # os.path.basename works with any path separator (a plain
        # split('/') does not on Windows); sorting makes the sample order
        # reproducible because glob returns matches in arbitrary order.
        return sorted(
            os.path.basename(p)
            for p in glob.glob(os.path.join(folder, '*.png'))
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position of the sample in the indexed file list.

        Returns:
            Tuple ``(img, class_id, class_name, image_name)`` where ``img``
            is the image resized to 299x299 and converted to a tensor
            (no normalisation) and ``image_name`` is the PNG file name
            without its directory.

        Raises:
            KeyError: If the image has no entry in the label mapping.
        """
        path = self.paths[index]
        label = self.json_info[self.query_paths[index]]
        class_id = label['class_id']
        class_name = label['class_name']
        image_name = os.path.basename(path)
        # convert() returns an independent image, so the file handle can be
        # released as soon as the pixel data has been decoded.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        img = transforms.Resize((299,299))(img)
        img = transforms.ToTensor()(img)
        return img, class_id, class_name, image_name
