import logging
import random
import json
import os
import gc

from PIL import Image
from torchvision import datasets
from torchvision import transforms

from .randaugment import RandAugmentMC, RandAugmentCAM

logger = logging.getLogger(__name__)

imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
normal_mean = (0.5, 0.5, 0.5)
normal_std = (0.25, 0.25, 0.25)


def get_imagenet(args, root):
    """Build the labeled, unlabeled and test ImageNet datasets.

    Args:
        args: config object; ``num_labeled``, ``num_classes`` and ``cam`` are read.
        root: directory containing ``imagenet/train`` and ``imagenet/val``.

    Returns:
        ``(train_labeled_dataset, train_unlabeled_dataset, test_dataset)``.
        The labeled set keeps ``num_labeled // num_classes`` images per class.
    """
    imagenet_dir = os.path.join(root, 'imagenet')
    per_class_labels = args.num_labeled // args.num_classes
    train_dir = os.path.join(imagenet_dir, "train")
    val_dir = os.path.join(imagenet_dir, "val")

    def get_transform(kind='lb'):
        """Return the transform for *kind*: 'lb', 'ulb' or 'test'."""
        if kind == 'ulb':
            # Unlabeled images yield a (weak, strong) pair of augmented views.
            pair_transform = NewTransformFixMatch if args.cam else TransformFixMatch
            return pair_transform(mean=imagenet_mean, std=imagenet_std)
        if kind == 'lb':
            spatial = [
                transforms.Resize([256, 256]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(224, padding=4, padding_mode='reflect'),
            ]
        elif kind == 'test':
            spatial = [transforms.Resize([224, 224])]
        else:
            raise ValueError(f"Invalid transform type choice {kind}!")
        return transforms.Compose(
            spatial + [transforms.ToTensor(), transforms.Normalize(imagenet_mean, imagenet_std)])

    labeled = ImagenetDataset(root=train_dir, transform=get_transform('lb'),
                              num_labels=per_class_labels, num_classes=args.num_classes)
    unlabeled = ImagenetDataset(root=train_dir, transform=get_transform('ulb'),
                                num_classes=args.num_classes)
    evaluation = ImagenetDataset(root=val_dir, transform=get_transform('test'),
                                 num_classes=args.num_classes)
    return labeled, unlabeled, evaluation


class TransformFixMatch(object):
    """Produce a ``(weak, strong)`` pair of normalized tensors from one image.

    Both views go through the same resize -> horizontal flip -> reflect-padded
    crop pipeline; the strong view additionally applies ``RandAugmentMC(n=2, m=10)``.
    Each view is then converted with ``ToTensor`` and normalized with *mean*/*std*.
    """

    def __init__(self, mean, std):
        def geometric_ops():
            # Fresh instances per pipeline, mirroring two independent Compose chains.
            return [
                transforms.Resize([256, 256]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(size=224, padding=4, padding_mode='reflect'),
            ]

        self.weak = transforms.Compose(geometric_ops())
        self.strong = transforms.Compose(geometric_ops() + [RandAugmentMC(n=2, m=10)])
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        # Weak view is drawn before the strong one, then both are normalized.
        views = (self.weak(x), self.strong(x))
        return tuple(self.normalize(view) for view in views)
    

class NewTransformFixMatch(object):
    """Produce a ``(weak, strong)`` pair of tensors for the CAM variant.

    Both views go through resize -> horizontal flip -> reflect-padded crop; the
    strong view additionally applies ``RandAugmentCAM(n=2, m=10)``.  The weak view
    is converted with ``ToTensor`` and normalized with *mean*/*std*; the strong view
    is only converted with ``ToTensor`` (a tensor, NOT normalized).
    """

    def __init__(self, mean, std):
        def geometric_ops():
            # Shared geometric pipeline; new instances for each Compose chain.
            return [
                transforms.Resize([256, 256]),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(size=224, padding=4, padding_mode='reflect'),
            ]

        self.weak = transforms.Compose(geometric_ops())
        self.strong = transforms.Compose(geometric_ops() + [RandAugmentCAM(n=2, m=10)])
        self.astensor = transforms.ToTensor()
        self.normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

    def __call__(self, x):
        weak_view, strong_view = self.weak(x), self.strong(x)
        return self.normalize(weak_view), self.astensor(strong_view)


def pil_loader(path):
    """Load the image stored at *path* and return it converted to RGB.

    The file is opened explicitly and closed when this function returns, which
    avoids Pillow's ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


class ImagenetDataset(datasets.ImageFolder):
    """ImageFolder variant that can keep only a fixed number of images per class.

    Args:
        root: directory with one sub-directory per class.
        transform: callable applied to each loaded image, or ``None``.
        num_labels: images to keep per class; ``-1`` keeps every image.
        num_classes: expected number of class folders under *root*.

    When ``num_labels != -1`` the chosen file names (relative to their class
    folder) are written to ``./sampled_label_idx.json`` in the current working
    directory.

    Raises:
        ValueError: if the number of class folders differs from *num_classes*.
        RuntimeError: if no image files are found.
    """

    def __init__(self, root, transform, num_labels=-1, num_classes=10):
        # ImageFolder sets transform/target_transform and the other base
        # attributes; torchvision's constructor also lists `root` via
        # self.make_dataset (num_labels=-1), so the listing is redone below
        # with the requested num_labels.
        super().__init__(root, transform)
        extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
        classes, class_to_idx = self.find_classes(self.root)
        if len(classes) != num_classes:
            raise ValueError(
                f"Expected {num_classes} class folders in {self.root}, found {len(classes)}")
        samples = self.make_dataset(self.root, class_to_idx, extensions, None, num_labels)
        if not samples:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = pil_loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]
        # Keep ImageFolder's `imgs` alias consistent with the (possibly
        # subsampled) samples instead of the listing built by the base class.
        self.imgs = samples

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample *index*.

        The image is passed through ``self.transform`` and the target through
        ``self.target_transform`` when those are set; otherwise they are
        returned unchanged.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def make_dataset(
            self,
            directory,
            class_to_idx,
            extensions=None,
            is_valid_file=None,
            num_labels=-1,
            allow_empty=False,
    ):
        """List ``(path, class_index)`` pairs, optionally subsampled per class.

        Args:
            directory: dataset root containing one sub-directory per class.
            class_to_idx: mapping from class folder name to integer label.
            extensions: tuple of accepted lower-case file suffixes; mutually
                exclusive with *is_valid_file*.
            is_valid_file: predicate on a full file path; mutually exclusive
                with *extensions*.
            num_labels: when not ``-1``, keep a random subset of at most this
                many valid files per class (drawn with the ``random`` module, so
                seed it for reproducibility) and record the choice in
                ``./sampled_label_idx.json``.
            allow_empty: accepted for compatibility with torchvision versions
                whose base constructor passes it; unused here because
                ``__init__`` performs its own empty-dataset check.

        Returns:
            List of ``(path, class_index)`` tuples; within each class the paths
            are sorted when no subsampling is requested.

        Raises:
            ValueError: if both or neither of *extensions* / *is_valid_file*
                are given.
        """
        directory = os.path.expanduser(directory)
        both_none = extensions is None and is_valid_file is None
        both_something = extensions is not None and is_valid_file is not None
        if both_none or both_something:
            raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
        if extensions is not None:
            def is_valid_file(x: str) -> bool:
                return x.lower().endswith(extensions)

        instances = []
        lb_idx = {}

        for target_class in sorted(class_to_idx):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            # Collect every valid file of the class first (in sorted order, so a
            # seeded shuffle is reproducible regardless of file-system listing
            # order); filtering before sampling keeps non-image files from
            # consuming the per-class label budget.
            paths = []
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        paths.append(path)
            if num_labels != -1:
                random.shuffle(paths)
                paths = paths[:num_labels]
                lb_idx[target_class] = [os.path.relpath(p, target_dir) for p in paths]
            instances.extend((path, class_index) for path in paths)

        if num_labels != -1:
            with open('./sampled_label_idx.json', 'w', encoding='utf-8') as f:
                json.dump(lb_idx, f)
        return instances
    

# Registry of dataset builders: name -> getter(args, root) returning
# (train_labeled_dataset, train_unlabeled_dataset, test_dataset).
DATASET_GETTERS = dict(imagenet=get_imagenet)


if __name__ == '__main__':
    # Smoke run: build all three splits from ./data/imagenet using a minimal
    # class-based stand-in for the parsed command-line arguments.
    args = type('args', (), dict(
        folds=0,
        cam=False,
        batch_size=4,
        patch_size=8,
        eval_step=128,
        num_classes=10,
        num_labeled=200,
        expand_labels=False,
    ))

    lb, ub, test = get_imagenet(args, './data')
