import torch
import numpy as np

import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.utils import save_image
import config



# Module-level target label (presumably the default attack target class).
# NOTE(review): not referenced by any code visible in this file -- confirm
# external users before changing or removing it.
target_class = 0

# This trigger setting is deprecated. See config.py!
# Appears to map a trigger/attack name to the file name of its trigger image;
# 'none' maps to the literal string 'none'.
triggers= {
    'badnet': 'badnet_patch_256.png',
    'blend' : 'hellokitty_224.png',
    'trojan' : 'trojan_watermark_224.png',
    'SRA': 'phoenix_corner_256.png',
    'none': 'none'
}

# NOTE(review): placeholder that is never assigned or read in the visible
# part of this file -- confirm whether other modules set it.
test_set_labels = None

# Resizes an image to 256x256 and converts it to a tensor.
# Used by create_256_scaled_version() to build the down-scaled dataset copy.
transform_resize = transforms.Compose([
            transforms.Resize(size=[256, 256]),
            # transforms.Resize(size=256),
            transforms.ToTensor(),
])


# Deterministic pipeline: central 224x224 crop only (no ToTensor / normalization).
# NOTE(review): presumably meant for the 256-scaled images at evaluation
# time -- confirm with callers.
transform_no_aug = transforms.Compose([
            transforms.CenterCrop(224),
        ])

# Augmenting pipeline: random resized crop to 224 followed by a random
# horizontal flip (no ToTensor / normalization).
transform_aug = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip()
        ])




def find_classes(directory: str) -> Tuple[List[str], Dict[str, int], Dict[int, str]]:
    """Discover the class sub-folders of ``directory`` and index them.

    Every immediate sub-directory of ``directory`` is treated as one class.
    Class names are sorted and numbered from 0 in that order; plain files
    are ignored.

    Returns:
        A tuple ``(classes, class_to_idx, idx_to_class)``: the sorted class
        names, the name -> index mapping and the inverse index -> name mapping.

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directory.
    """
    folder_names = [item.name for item in os.scandir(directory) if item.is_dir()]
    folder_names.sort()
    if len(folder_names) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    idx_to_class = dict(enumerate(folder_names))
    class_to_idx = {name: idx for idx, name in idx_to_class.items()}

    return folder_names, class_to_idx, idx_to_class



def assign_img_identifier(directory, classes):
    """Enumerate every entry of every class folder and give it a global id.

    Classes are visited in the order given; within a class the entries are
    visited in sorted name order. The position of an entry in the returned
    lists is its image id.

    Args:
        directory: root folder that contains one sub-folder per class.
        classes: iterable of class folder names; the position of a name is
            used as the label of all entries in that folder.

    Returns:
        A tuple ``(num_imgs, img_id_to_path, img_labels)`` where
        ``img_id_to_path[k]`` is the path of image ``k`` relative to
        ``directory`` (``<class>/<entry>``) and ``img_labels[k]`` is its
        integer label.
    """
    img_id_to_path = []
    img_labels = []

    for label, cls_name in enumerate(classes):
        file_names = [item.name for item in os.scandir(os.path.join(directory, cls_name))]
        file_names.sort()

        img_id_to_path.extend(os.path.join(cls_name, file_name) for file_name in file_names)
        img_labels.extend([label] * len(file_names))

    return len(img_id_to_path), img_id_to_path, img_labels



class imagenet_dataset(Dataset):
    """Image-folder dataset with optional label shifting and poisoned samples.

    Two on-disk layouts are supported:

    * ``label_file is None`` -- ``directory`` holds one sub-folder per class;
      labels are the indices of the sorted folder names.
    * ``label_file`` given -- all images sit directly in ``directory`` and
      ``label_file`` is read with ``torch.load`` to obtain the labels, which
      are matched to the images in sorted file-name order.

    Args:
        directory: root folder of the clean images.
        shift: if true, every non-poisoned label ``y`` is replaced by
            ``(y + 1) % num_classes``.
        data_transform: callable applied to the RGB PIL image in
            ``__getitem__``; when ``None`` the PIL image is returned as is.
        poison_directory: root folder from which poisoned images are read.
            Poisoned images use the same relative path as their clean
            counterpart. Required when ``poison_indices`` is non-empty.
        poison_indices: iterable of image ids to mark as poisoned.
        label_file: path of the serialized label list (see above).
        target_class: label assigned to every poisoned image. Required when
            ``poison_indices`` is non-empty.
        num_classes: modulus used by ``shift``.
        scale_for_ct: stored on the instance; not used by this class.
        poison_transform: stored on the instance; not used by this class.

    Raises:
        ValueError: if images are marked as poisoned but ``poison_directory``
            or ``target_class`` is missing.
    """

    def __init__(self, directory, shift=False, data_transform=None,
                 poison_directory=None, poison_indices=None,
                 label_file=None, target_class = None, num_classes=1000, scale_for_ct=False, poison_transform=None):

        self.num_classes = num_classes
        self.shift = shift
        self.data_transform = data_transform
        self.directory = directory
        self.poison_directory = poison_directory
        self.scale_for_ct = scale_for_ct
        self.poison_transform = poison_transform

        if label_file is None:  # divide classes by directory
            self.classes, self.class_to_idx, self.idx_to_class = find_classes(directory)
            self.num_imgs, self.img_id_to_path, self.img_labels = assign_img_identifier(directory, self.classes)
        else:  # samples from all classes are in the same directory
            self.img_id_to_path = sorted(entry.name for entry in os.scandir(directory))
            self.num_imgs = len(self.img_id_to_path)
            self.img_labels = torch.load(label_file)

        self.img_labels = torch.LongTensor(self.img_labels)

        # Per-image poison flag, indexed by image id.
        self.is_poison = [False for _ in range(self.num_imgs)]
        if poison_indices is not None:
            for i in poison_indices:
                self.is_poison[i] = True

        # Fail early with a clear message instead of an opaque TypeError from
        # os.path.join / tensor assignment further down.
        if any(self.is_poison) and (poison_directory is None or target_class is None):
            raise ValueError(
                "poison_indices was given, so both poison_directory and "
                "target_class must be provided."
            )

        self.target_class = target_class
        if self.target_class is not None:
            self.target_class = torch.tensor(self.target_class).long()

        # Turn the relative paths into full paths and finalize the labels:
        # poisoned samples come from poison_directory and get the target
        # label; clean samples come from directory and are optionally shifted.
        for i, poisoned in enumerate(self.is_poison):
            rel_path = self.img_id_to_path[i]
            if poisoned:
                self.img_id_to_path[i] = os.path.join(self.poison_directory, rel_path)
                self.img_labels[i] = self.target_class
            else:
                self.img_id_to_path[i] = os.path.join(self.directory, rel_path)
                if self.shift:
                    self.img_labels[i] = (self.img_labels[i] + 1) % self.num_classes

    def __len__(self):
        """Return the number of images in the dataset."""
        return self.num_imgs

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, label)``.

        ``image`` is the output of ``data_transform`` when one was given,
        otherwise the RGB PIL image; ``label`` is a 0-dim long tensor.
        """
        idx = int(idx)

        img_path = self.img_id_to_path[idx]
        label = self.img_labels[idx]

        # convert() returns a fully loaded copy, so the file can be closed
        # right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        # data_transform defaults to None; calling it unconditionally crashed.
        if self.data_transform is not None:
            img = self.data_transform(img)

        return img, label


def _save_scaled_copy(src_img_path, dst_img_path):
    """Resize one image with ``transform_resize`` and write it to ``dst_img_path``."""
    # Close the source file deterministically; convert() yields a loaded copy.
    with Image.open(src_img_path) as raw:
        scaled_img = transform_resize(raw.convert("RGB"))
    save_image(scaled_img, dst_img_path)


def create_256_scaled_version(src_directory, dst_directory, is_train_set=True):
    """Write a copy of an image directory with every image passed through ``transform_resize``.

    Args:
        src_directory: folder holding the original images.
        dst_directory: folder receiving the rescaled images under the same
            relative names. It is created if it does not exist.
        is_train_set: if true, ``src_directory`` is expected to contain one
            sub-folder per class and that structure is mirrored in
            ``dst_directory``; otherwise the images are expected directly in
            ``src_directory``.

    Raises:
        FileNotFoundError: if ``is_train_set`` is true and ``src_directory``
            contains no sub-directory.

    Progress and elapsed time (in minutes) are printed to stdout.
    """

    import time

    st = time.time()

    if is_train_set:
        classes = sorted(entry.name for entry in os.scandir(src_directory) if entry.is_dir())
        if not classes:
            raise FileNotFoundError(f"Couldn't find any class folder in {src_directory}.")

        tot = len(classes)

        for cnt, cls_name in enumerate(classes, start=1):

            print('start :', cls_name)

            src_cls_dir_path = os.path.join(src_directory, cls_name)
            dst_cls_dir_path = os.path.join(dst_directory, cls_name)
            # makedirs also creates dst_directory itself when it is missing
            # and avoids the exists()/mkdir() race.
            os.makedirs(dst_cls_dir_path, exist_ok=True)

            img_entries = sorted(entry.name for entry in os.scandir(src_cls_dir_path))
            for img_entry in img_entries:
                _save_scaled_copy(os.path.join(src_cls_dir_path, img_entry),
                                  os.path.join(dst_cls_dir_path, img_entry))

            print('[time: %f minutes] progress by classes [%d/%d], done : %s' % ( (time.time() - st)/60, cnt, tot, cls_name) )

    else:

        os.makedirs(dst_directory, exist_ok=True)

        img_entries = sorted(entry.name for entry in os.scandir(src_directory))
        tot = len(img_entries)
        for i, img_entry in enumerate(img_entries):
            _save_scaled_copy(os.path.join(src_directory, img_entry),
                              os.path.join(dst_directory, img_entry))
            print('[time: %f minutes] progress : [%d/%d]' % ((time.time() - st)/60, i+1, tot))





if __name__ == "__main__":
    # Script mode: build the validation label list and store it as
    # <config.imagenet_dir>/val_labels (the `label_file` format consumed by
    # imagenet_dataset).

    import json

    root_path = config.imagenet_dir
    # label_maps = os.path.join(root_path, 'imagenet_class_index.json')
    # val_labels = os.path.join(root_path, 'ILSVRC2012_val_labels.json')
    label_maps = 'data/imagenet/imagenet_class_index.json'
    val_labels = 'data/imagenet/ILSVRC2012_val_labels.json'

    # The index file is keyed by the class index as a string; the first
    # element of each entry is the class name used in the validation table.
    with open(label_maps) as f:
        index_table = json.load(f)
    class_to_id = {index_table[str(class_idx)][0]: class_idx for class_idx in range(1000)}

    # The validation table maps an image file name to its class name; labels
    # are collected in image-number order 1..50000.
    with open(val_labels) as f:
        val_table = json.load(f)
    labels = [
        class_to_id[val_table['ILSVRC2012_val_%08d.JPEG' % img_no]]
        for img_no in range(1, 50001)
    ]

    out_path = os.path.join(root_path, 'val_labels')
    torch.save(labels, out_path)
    print('save: ', out_path)










