import os
import gc
import copy
import json
import random
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import math
from golearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation, str_to_interp_mode
from golearn.datasets.cv_datasets.datasetbase import BasicDataset
import numpy as np
from golearn.datasets.utils import split_ssl_data


# Normalisation statistics keyed by dataset name; the 'imagenet' entries are
# the values handed to transforms.Normalize when the transforms are built.
mean = {'imagenet': [0.485, 0.456, 0.406]}
std = {'imagenet': [0.229, 0.224, 0.225]}


def accimage_loader(path):
    """Load the image at ``path`` with accimage, using ``pil_loader`` on IOError.

    Args:
        path: Location of the image file.

    Returns:
        An ``accimage.Image`` when accimage can read the file, otherwise
        whatever ``pil_loader(path)`` returns.
    """
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Possibly a decoding problem on the accimage side; retry with PIL.
        image = pil_loader(path)
    return image


def pil_loader(path):
    """Read the image file at ``path`` with PIL and return it converted to RGB.

    Args:
        path: Location of the image file.

    Returns:
        The result of ``Image.open(...).convert('RGB')``.
    """
    # Go through an explicit file object (rather than handing PIL the path) so
    # the handle is closed by the ``with`` block and no ResourceWarning is raised
    # (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


def default_loader(path):
    """Load ``path`` with the loader matching torchvision's image backend.

    ``accimage_loader`` is used when ``get_image_backend()`` reports
    ``'accimage'``; every other backend goes through ``pil_loader``.

    Args:
        path: Location of the image file.

    Returns:
        The image object produced by the selected loader.
    """
    from torchvision import get_image_backend

    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)


def get_imagenet(args, alg, name, num_labels, num_classes, data_dir='./data', include_lb_to_ulb=True):
    """Build the labelled, unlabelled and evaluation datasets for an ImageNet-style folder.

    The labelled set is a per-class random subset of ``<data_dir>/<name>/train``
    whose order is arranged as a sequence of ``args.split_num`` class-sorted
    segments: indices are sorted by target, split into contiguous segments,
    shuffled inside each segment, and finally a window around every segment
    boundary is shuffled again so neighbouring segments blend into each other.

    Args:
        args: Namespace providing ``img_size``, ``crop_ratio``, ``split_num``
            and ``transition_ratio``.
        alg: Algorithm name, forwarded to every ``ImagenetDataset``.
        name: Dataset name; its lower-cased form is appended to ``data_dir``.
        num_labels: Requested number of labelled samples; converted into the
            fraction ``num_labels / len(train set)`` that is kept per class.
        num_classes: Accepted for interface compatibility; not used here.
        data_dir: Root directory containing the ``<name>`` folder.
        include_lb_to_ulb: Whether files chosen for the labelled set may also
            appear in the unlabelled set.

    Returns:
        Tuple ``(lb_dset, ulb_dset, eval_dset)``.
    """
    img_size = args.img_size
    crop_ratio = args.crop_ratio
    resize_dim = int(math.floor(img_size / crop_ratio))

    transform_weak = transforms.Compose([
        transforms.Resize((resize_dim, resize_dim)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean['imagenet'], std['imagenet'])
    ])

    transform_strong = transforms.Compose([
        transforms.Resize((resize_dim, resize_dim)),
        RandomResizedCropAndInterpolation((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        RandAugment(3, 10),
        transforms.ToTensor(),
        transforms.Normalize(mean['imagenet'], std['imagenet'])
    ])

    # Validation passes a single int to Resize (not a (h, w) pair as above).
    transform_val = transforms.Compose([
        transforms.Resize(math.floor(int(img_size / crop_ratio))),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean['imagenet'], std['imagenet'])
    ])

    data_dir = os.path.join(data_dir, name.lower())
    train_root = os.path.join(data_dir, "train")

    # The full training set is built once only to learn its size, which turns
    # the absolute label budget into a per-class keep fraction.
    dataset = ImagenetDataset(root=train_root, transform=transform_weak, idx=None, ulb=False, alg=alg)
    percentage = num_labels / len(dataset)

    # First labelled pass: only its targets are needed, to derive the ordering.
    lb_dset = ImagenetDataset(root=train_root, transform=transform_weak, idx=None,
                              strong_transform=transform_strong, ulb=False, alg=alg, percentage=percentage)

    lb_targets = np.array(lb_dset.targets)
    lb_sorted_indices = np.argsort(lb_targets)
    num_task = args.split_num
    lb_indices_segments = np.array_split(lb_sorted_indices, num_task)
    for segment in lb_indices_segments:
        np.random.shuffle(segment)
    lb_idx = np.concatenate(lb_indices_segments)

    # Re-shuffle a window straddling each boundary between consecutive
    # segments. The slice is a view, so the shuffle updates lb_idx in place.
    boundary = 0
    for i in range(num_task - 1):
        boundary += len(lb_indices_segments[i])
        half_ratio = 0.5 * args.transition_ratio
        before = int(half_ratio * len(lb_indices_segments[i]))
        after = int(half_ratio * len(lb_indices_segments[i + 1]))
        np.random.shuffle(lb_idx[boundary - before:boundary + after])
    lb_idx = [int(x) for x in lb_idx]

    # Final labelled set, rebuilt with the ordering indices.
    lb_dset = ImagenetDataset(root=train_root, transform=transform_weak, idx=lb_idx,
                              strong_transform=transform_strong, ulb=False, alg=alg, percentage=percentage)

    # Pass the labelled file names along: make_dataset looks up
    # lb_index[target_class] whenever include_lb_to_ulb is False, so leaving
    # lb_index at its None default would fail there.
    ulb_dset = ImagenetDataset(root=train_root, transform=transform_weak, alg=alg, idx=None, ulb=True,
                               strong_transform=transform_strong, include_lb_to_ulb=include_lb_to_ulb,
                               lb_index=lb_dset.lb_idx)

    eval_dset = ImagenetDataset(root=os.path.join(data_dir, "val"), transform=transform_val, idx=None,
                                strong_transform=transform_strong, alg=alg, ulb=False)

    return lb_dset, ulb_dset, eval_dset
    


class ImagenetDataset(BasicDataset, ImageFolder):
    """Folder-per-class image dataset with optional per-class subsampling.

    File paths and class indices are collected by ``make_dataset`` and stored
    in ``self.data`` / ``self.targets``. After construction ``self.lb_idx``
    maps each class directory name to the file names that were kept for it
    (it stays empty when ``percentage`` is -1).
    """

    def __init__(self, root, transform, ulb, alg, idx, strong_transform=None, percentage=-1, include_lb_to_ulb=True, lb_index=None):
        """Scan ``root`` and record the samples.

        Args:
            root: Directory containing one sub-directory per class.
            transform: Transform stored on ``self.transform``.
            ulb: Whether this dataset is the unlabelled one.
            alg: Algorithm name (checked when no strong transform is given).
            idx: Optional sequence of sort keys used to reorder the samples;
                ``None`` keeps the scan order.
            strong_transform: Optional strong-augmentation transform.
            percentage: Fraction of files kept per directory; -1 keeps all.
            include_lb_to_ulb: When False, files listed in ``lb_index`` are
                skipped during the scan.
            lb_index: Mapping from class directory name to file names to skip;
                it is indexed per class when ``include_lb_to_ulb`` is False.

        Raises:
            RuntimeError: If the scan yields no samples.
        """
        self.alg = alg
        self.is_ulb = ulb
        self.transform = transform
        # The attributes below are read by make_dataset(), so they have to be
        # in place before it is invoked.
        self.root = root
        self.percentage = percentage
        self.include_lb_to_ulb = include_lb_to_ulb
        self.lb_index = lb_index
        self.idx = idx

        extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, None)
        if not samples:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = default_loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx

        self.data = [path for path, _ in samples]
        self.targets = [label for _, label in samples]

        self.strong_transform = strong_transform
        if self.strong_transform is None and self.is_ulb:
            assert self.alg not in ['fullysupervised', 'supervised', 'pseudolabel', 'vat', 'pimodel', 'meanteacher', 'mixmatch'], f"alg {self.alg} requires strong augmentation"

    def __sample__(self, index):
        """Return ``(loaded image, target)`` for the sample at ``index``."""
        return self.loader(self.data[index]), self.targets[index]

    def make_dataset(
            self,
            directory,
            class_to_idx,
            extensions=None,
            is_valid_file=None,
    ):
        """Walk ``directory`` and collect ``(path, class_index)`` pairs.

        File names in every walked directory are shuffled with ``random``;
        when ``self.percentage`` is not -1 only the leading fraction is kept
        and remembered in ``self.lb_idx``. When ``self.idx`` is set, the
        collected pairs are zipped with it and returned sorted by those keys.

        Args:
            directory: Root directory (``~`` is expanded).
            class_to_idx: Mapping from class directory name to class index.
            extensions: Tuple of accepted lower-case file suffixes.
            is_valid_file: Predicate on a path; mutually exclusive with
                ``extensions``.

        Returns:
            List of ``(path, class_index)`` tuples.

        Raises:
            ValueError: If ``extensions`` and ``is_valid_file`` are both given
                or both omitted.
        """
        directory = os.path.expanduser(directory)
        # Exactly one of the two filters has to be supplied.
        if (extensions is None) == (is_valid_file is None):
            raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
        if extensions is not None:
            def is_valid_file(x: str) -> bool:
                return x.lower().endswith(extensions)

        kept_names = {}
        instances = []
        for target_class in sorted(class_to_idx):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                random.shuffle(fnames)
                if self.percentage != -1:
                    # Keep the leading fraction of the shuffled names and
                    # remember which ones were chosen for this class.
                    fnames = fnames[:int(len(fnames) * self.percentage)]
                    kept_names[target_class] = fnames
                for fname in fnames:
                    if not self.include_lb_to_ulb and fname in self.lb_index[target_class]:
                        continue
                    path = os.path.join(root, fname)
                    if is_valid_file(path):
                        instances.append((path, class_index))
        gc.collect()
        self.lb_idx = kept_names

        if self.idx is not None:
            # Pair each sample with its key from idx and order by that key
            # (zip stops at the shorter of the two sequences).
            keyed = sorted(zip(self.idx, instances))
            instances = [sample for _, sample in keyed]
        return instances

class ImageSet(Dataset):
    """Minimal dataset over two parallel, index-aligned containers."""

    def __init__(self, data, targets):
        """Store ``data`` and ``targets`` as given, without copying."""
        self.data = data
        self.targets = targets

    def __len__(self):
        """Return the number of entries in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(sample, target)`` pair stored at ``index``."""
        return self.data[index], self.targets[index]