import os

import numpy as np
from torchvision import transforms
import torch
import torchvision
import sys
sys.path.append('..')
from config import opt

class CIFAR100(object):
    """CIFAR-100 train/test datasets with optional subsetting of the train split.

    Attributes set by ``__init__``:
        train_dataset: training split with random-crop/flip augmentation;
            wrapped in ``torch.utils.data.Subset`` when ``partition`` is given.
        train_dataset_no_aug: training split with the test-time transform.
            Not subset by ``partition``.
        test_dataset: test split with the test-time transform. Not subset.
        dataset: training split with the caller-supplied ``transform``; only
            set when ``transform`` is truthy, and subset like ``train_dataset``.
    """

    # Per-channel normalization statistics shared by both pipelines.
    _MEAN = (0.5071, 0.4865, 0.4409)
    _STD = (0.2673, 0.2564, 0.2762)

    # Class-name subsets selectable via ``partition='<N>cls'``.
    # 30cls is a subset of 40cls, which is a subset of 50cls; each step adds ten names.
    _CLS_30 = frozenset({
        'orchid', 'poppy', 'rose', 'sunflower', 'tulip',
        'bottle', 'bowl', 'can', 'cup', 'plate',
        'apple', 'mushroom', 'orange', 'pear', 'sweet_pepper',
        'clock', 'keyboard', 'lamp', 'telephone', 'television',
        'bed', 'chair', 'couch', 'table', 'wardrobe',
        'maple_tree', 'oak_tree', 'palm_tree', 'pine_tree', 'willow_tree',
    })
    _CLS_40 = _CLS_30 | {
        'bridge', 'castle', 'house', 'road', 'skyscraper',
        'cloud', 'forest', 'mountain', 'plain', 'sea',
    }
    _CLS_50 = _CLS_40 | {
        'beaver', 'dolphin', 'otter', 'seal', 'whale',
        'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',
    }
    _PARTITION_CLASSES = {
        '50cls': _CLS_50,
        '40cls': _CLS_40,
        '30cls': _CLS_30,
        '10cls': frozenset({
            'plate', 'rose', 'castle', 'keyboard', 'house',
            'forest', 'road', 'television', 'bottle', 'wardrobe',
        }),
        '6cls': frozenset({'road', 'cloud', 'forest', 'mountain', 'plain', 'sea'}),
    }

    def __init__(self, input_size=32, transform=None, partition=None):
        """Build (and download if needed) the CIFAR-100 datasets.

        Args:
            input_size: crop size for training and resize target for testing.
            transform: optional extra transform; when truthy, ``self.dataset``
                is created on the training split with this transform.
            partition: optional train-split subset. Either one of the keys of
                ``_PARTITION_CLASSES`` (e.g. ``'50cls'``) to keep only those
                classes, or ``'<P>pct'`` to load precomputed indices from
                ``misc/cifar_part_idx.npz`` (key ``part<PP>_idx``).

        Raises:
            ValueError: if ``partition`` is not recognised.
        """
        train_transform = transforms.Compose([
            transforms.RandomCrop(input_size, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])
        test_transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])
        root = opt.data_dir + 'datasets'
        self.train_dataset = torchvision.datasets.CIFAR100(
            root=root, train=True, download=True, transform=train_transform)
        self.train_dataset_no_aug = torchvision.datasets.CIFAR100(
            root=root, train=True, download=True, transform=test_transform)
        self.test_dataset = torchvision.datasets.CIFAR100(
            root=root, train=False, download=True, transform=test_transform)
        if transform:
            self.dataset = torchvision.datasets.CIFAR100(
                root=root, train=True, download=True, transform=transform)

        if partition:
            # Computed from the raw train split before it is wrapped in Subset.
            # self.dataset is the same train split (same root, same order), so
            # the same indices apply to it.
            index_list = self._partition_indices(partition)
            self.train_dataset = torch.utils.data.Subset(self.train_dataset, index_list)
            if transform:
                self.dataset = torch.utils.data.Subset(self.dataset, index_list)
            # NOTE(review): train_dataset_no_aug and test_dataset are not subset
            # here; confirm callers expect the full splits for those.
            print(len(self.train_dataset))

    @staticmethod
    def _filter_indices(targets, wanted_labels):
        """Return the positions in ``targets`` whose label is in ``wanted_labels``.

        Uses labels only, so no image is decoded or transformed.
        """
        wanted = set(wanted_labels)
        return [i for i, label in enumerate(targets) if label in wanted]

    def _partition_indices(self, partition):
        """Resolve a ``partition`` spec to train-split indices.

        Must be called while ``self.train_dataset`` is still the unwrapped
        torchvision dataset (it reads ``class_to_idx`` and ``targets``).
        """
        if 'cls' in partition:
            classes_set = self._PARTITION_CLASSES.get(partition)
            if classes_set is None:
                raise ValueError(f'Undefined class: {partition}')
            class_to_idx = self.train_dataset.class_to_idx
            # Sorted so the printed list is deterministic across runs.
            classes_indices = sorted(class_to_idx[name] for name in classes_set)
            print(classes_indices)
            return self._filter_indices(self.train_dataset.targets, classes_indices)

        if 'pct' in partition:
            percent = int(partition.replace('pct', ''))
            npz_path = os.path.join(
                os.path.dirname(os.path.realpath(__file__)), 'misc/cifar_part_idx.npz')
            # Context manager closes the NpzFile; the array is read before exit.
            with np.load(npz_path) as seg_file:
                return seg_file[f'part{percent:02d}_idx']

        raise ValueError(f'Undefined partition: {partition}')

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled DataLoader over ``train_dataset``.

        Batch size 128, 4 workers, incomplete last batch dropped. Any
        positional/keyword arguments are accepted and ignored.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=128,
            shuffle=True,
            num_workers=4,
            drop_last=True,
        )

    def test_dataloader(self, *args, **kwargs):
        """Return an unshuffled DataLoader over ``test_dataset``.

        Batch size 100, 4 workers, last batch kept. Any positional/keyword
        arguments are accepted and ignored.
        """
        return torch.utils.data.DataLoader(
            self.test_dataset,
            batch_size=100,
            num_workers=4,
            drop_last=False,
        )

