import copy
import json
import os
import sys
import warnings

sys.path.append('..')

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from config import opt

class ImageNet(object):
    """ImageNet train/val datasets (via ``ImageNetKaggle``) and their DataLoaders.

    Data is read from ``<opt.data_dir>/datasets/imagenet``.

    Args:
        input_size: Side length, in pixels, of the square images produced by
            both the train and the test pipelines.
        transform: Optional alternative transform. When truthy,
            ``self.dataset`` is created: the *train* split with this transform
            instead of the default training augmentation.
        partition: Optional training-subset selector.
            ``'<N>pct'`` keeps the indices stored under key ``part<NN>_idx`` in
            ``misc/imagenet_part_idx.npz`` (next to this file). They are applied
            to ``self.train_dataset`` and, if present, ``self.dataset``.
            Names containing ``'cls'`` are accepted but not implemented: a
            warning is emitted and the full split is kept.
            Any other value raises ``Exception``.
    """

    # Per-channel RGB normalization statistics shared by train and test pipelines.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, input_size=32, transform=None, partition=None):
        train_transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.RandomCrop(input_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])
        test_transform = transforms.Compose([
            # Resize(int) only scales the shorter side. CenterCrop makes every
            # validation image exactly input_size x input_size, so images with
            # different aspect ratios can be stacked into one batch.
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])

        root = os.path.join(opt.data_dir, 'datasets/imagenet')
        self.train_dataset = ImageNetKaggle(
            root=root,
            split='train',
            transform=train_transform,
        )
        self.test_dataset = ImageNetKaggle(
            root=root,
            split='val',
            transform=test_transform,
        )
        if transform:
            # Same train samples, different transform. Reuse the already-scanned
            # file list instead of walking the train directory a second time.
            # This also guarantees identical sample order, so the partition
            # indices below select the same images in both datasets.
            self.dataset = copy.copy(self.train_dataset)
            self.dataset.transform = transform

        if partition:
            if 'cls' in partition:
                # Class-subset partitions are not implemented for ImageNet;
                # keep the full split but make that visible to the caller.
                warnings.warn(
                    f'Class-subset partition {partition!r} is not implemented '
                    f'for ImageNet; using the full training split.',
                    stacklevel=2,
                )
            elif 'pct' in partition:
                percent = int(partition.replace('pct', ''))
                idx_path = os.path.join(
                    os.path.dirname(os.path.realpath(__file__)),
                    'misc/imagenet_part_idx.npz',
                )
                # NpzFile keeps the archive open; close it once the array is read.
                with np.load(idx_path) as seg_file:
                    index_list = seg_file[f'part{percent:02d}_idx']
                self.train_dataset = torch.utils.data.Subset(self.train_dataset, index_list)
                if transform:
                    self.dataset = torch.utils.data.Subset(self.dataset, index_list)
            else:
                raise Exception(f'Undefined partition: {partition}')

        print(len(self.train_dataset))

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled DataLoader over ``self.train_dataset``.

        Extra positional/keyword arguments are accepted for caller
        compatibility and ignored. Uses batch size 128 and 4 workers, and
        drops the last incomplete batch.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )

    def test_dataloader(self, *args, **kwargs):
        """Return an unshuffled DataLoader over ``self.test_dataset``.

        Extra positional/keyword arguments are accepted for caller
        compatibility and ignored. Uses batch size 100 and 4 workers, and
        keeps the last incomplete batch.
        """
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=100,
            num_workers=4,
            drop_last=False,
        )

class ImageNetKaggle(Dataset):
    """ImageNet split stored in a Kaggle-style ILSVRC directory layout.

    Expected under ``root``:
      * ``imagenet_class_index.json``: maps class-index strings to lists whose
        first element is a synset id.
      * ``ILSVRC2012_val_labels.json``: maps validation image file names to
        synset ids.
      * ``ILSVRC/Data/CLS-LOC/<split>/``: for ``'train'``, one sub-folder per
        synset id containing that class's images; for ``'val'``, the image
        files themselves.

    Items are ``(image, target)``. ``image`` is the RGB ``PIL.Image``, or the
    result of ``transform`` if one is set. ``target`` is the integer class index.

    Args:
        root: Dataset root directory.
        split: ``'train'`` or ``'val'``.
        transform: Optional callable applied to each image.

    Raises:
        ValueError: if ``split`` is not ``'train'`` or ``'val'``.
    """

    def __init__(self, root, split, transform=None):
        if split not in ("train", "val"):
            raise ValueError(f"Unsupported split {split!r}; expected 'train' or 'val'")
        self.samples = []
        self.targets = []
        self.transform = transform
        with open(os.path.join(root, "imagenet_class_index.json"), "rb") as f:
            class_index = json.load(f)
        self.syn_to_class = {v[0]: int(class_id) for class_id, v in class_index.items()}
        with open(os.path.join(root, "ILSVRC2012_val_labels.json"), "rb") as f:
            self.val_to_syn = json.load(f)

        samples_dir = os.path.join(root, "ILSVRC/Data/CLS-LOC", split)
        # os.listdir order is filesystem-dependent. Sort so sample indices, and
        # any precomputed subset indices applied to them, are reproducible
        # across machines.
        # NOTE(review): if misc/imagenet_part_idx.npz was generated from an
        # unsorted listing, it may need regenerating.
        for entry in sorted(os.listdir(samples_dir)):
            if split == "train":
                # Each entry is a synset folder holding that class's images.
                target = self.syn_to_class[entry]
                syn_folder = os.path.join(samples_dir, entry)
                for sample in sorted(os.listdir(syn_folder)):
                    self.samples.append(os.path.join(syn_folder, sample))
                    self.targets.append(target)
            else:
                # Validation images sit directly in the folder; the label
                # comes from the val-labels JSON.
                target = self.syn_to_class[self.val_to_syn[entry]]
                self.samples.append(os.path.join(samples_dir, entry))
                self.targets.append(target)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Close the file as soon as the image is decoded. Otherwise the handle
        # lingers until garbage collection, which many loader workers can turn
        # into file-descriptor exhaustion.
        with Image.open(self.samples[idx]) as img:
            x = img.convert("RGB")
        if self.transform is not None:
            x = self.transform(x)
        return x, self.targets[idx]