import logging
import math
from torch.utils.data import Dataset, DataLoader
from .datasetbase import BasicDataset
import numpy as np
from torchvision.datasets import ImageFolder
import random
import torch
from PIL import Image
from torch.utils.data import ConcatDataset
from golearn.datasets.utils import get_onehot
import cv2
import os

from glob import glob
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

# Per-channel mean and standard deviation handed to transforms.Normalize in the
# pipelines below; the values match the widely used ImageNet RGB statistics.
dataset_mean = (0.485, 0.456, 0.406)
dataset_std = (0.229, 0.224, 0.225)
# Module-level class count. The get_clear* functions below take their own
# num_classes parameter, so this global is not read by the code visible here.
num_classes = 10
from golearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation, str_to_interp_mode
from golearn.datasets.utils import split_ssl_data


def accimage_loader(path):
    """Load the image at ``path`` with accimage, falling back to PIL.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        An ``accimage.Image`` when accimage can open the file, otherwise the
        result of ``pil_loader(path)``.
    """
    # Imported lazily so the module can be used without accimage installed.
    import accimage

    try:
        image = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        image = pil_loader(path)
    return image


def pil_loader(path):
    """Load the image at ``path`` with PIL and return it converted to RGB.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        The image converted to mode ``'RGB'``.
    """
    # The file is opened explicitly (rather than passing the path to PIL) to
    # avoid a ResourceWarning: https://github.com/python-pillow/Pillow/issues/835
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')


def default_loader(path):
    """Load the image at ``path`` using torchvision's configured image backend.

    Args:
        path: Filesystem path of the image to load.

    Returns:
        The result of ``accimage_loader`` when the backend is ``'accimage'``,
        otherwise the result of ``pil_loader``.
    """
    from torchvision import get_image_backend

    loader = accimage_loader if get_image_backend() == 'accimage' else pil_loader
    return loader(path)

def get_clear10(args, alg, dataset, num_labels, num_classes, data_dir='./data', include_lb_to_ulb=True):
    """Build the labeled, unlabeled and evaluation datasets for CLEAR-10.

    Every bucket directory is wrapped in its own ``CLEAR`` dataset with a
    random permutation (drawn from ``np.random``) as its index list, and the
    buckets of each split are joined with ``ConcatDataset``.

    Args:
        args: Unused here; kept for a uniform dataset-builder signature.
        alg: Algorithm name forwarded to every ``CLEAR`` instance.
        dataset: Unused here.
        num_labels: Unused here.
        num_classes: Unused here.
        data_dir: Directory that contains the ``CLEAR`` folder.
        include_lb_to_ulb: Unused here.

    Returns:
        Tuple ``(train_labeled_dataset, train_unlabeled_dataset, val_dataset)``,
        each a ``ConcatDataset`` of per-bucket ``CLEAR`` datasets.
    """
    img_size = 224
    crop_ratio = 0.875
    resize_to = int(math.floor(img_size / crop_ratio))

    transform_weak = transforms.Compose([
        transforms.Resize((resize_to, resize_to)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(dataset_mean, dataset_std)
    ])

    transform_strong = transforms.Compose([
        transforms.Resize((resize_to, resize_to)),
        RandomResizedCropAndInterpolation((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        RandAugment(3, 10),
        transforms.ToTensor(),
        transforms.Normalize(dataset_mean, dataset_std)
    ])

    transform_val = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(dataset_mean, dataset_std)
    ])

    train_folder = [str(bucket) for bucket in range(1, 11)]
    test_folder = [str(bucket) for bucket in range(11)]
    train_root = os.path.join(data_dir, 'CLEAR/clear10-train/train/labeled_images')
    test_root = os.path.join(data_dir, 'CLEAR/clear10-test/test/labeled_images')

    def _load_buckets(split_root, folders, **clear_kwargs):
        """Concatenate one shuffled CLEAR dataset per bucket folder."""
        buckets = []
        for folder in folders:
            bucket = CLEAR(alg, os.path.join(split_root, folder), **clear_kwargs)
            # The permutation needs the sample count, so it is attached after
            # construction; this scans each bucket directory once instead of
            # building a throwaway dataset just to read its length.
            bucket.idx_list = np.random.permutation(len(bucket))
            buckets.append(bucket)
        return ConcatDataset(buckets)

    train_labeled_dataset = _load_buckets(
        train_root, train_folder[:10],
        transform=transform_weak, transform_strong=transform_strong)

    # Evaluation uses the deterministic resize + center-crop pipeline rather
    # than the random training augmentation. transform_strong is still passed
    # so the constructed objects differ from the training ones only in their
    # primary transform.
    val_dataset = _load_buckets(
        test_root, test_folder[:11],
        transform=transform_val, transform_strong=transform_strong)

    # The unlabeled pool re-uses the first three training buckets.
    train_unlabeled_dataset = _load_buckets(
        train_root, train_folder[:3],
        is_ulb=True, transform=transform_weak, transform_strong=transform_strong)

    return train_labeled_dataset, train_unlabeled_dataset, val_dataset


def get_clear100(args, alg, dataset, num_labels, num_classes, data_dir='./data', include_lb_to_ulb=True):
    """Build the labeled, unlabeled and evaluation datasets for CLEAR-100.

    Every bucket directory is wrapped in its own ``CLEAR`` dataset with a
    random permutation (drawn from ``np.random``) as its index list, and the
    buckets of each split are joined with ``ConcatDataset``.

    Args:
        args: Unused here; kept for a uniform dataset-builder signature.
        alg: Algorithm name forwarded to every ``CLEAR`` instance.
        dataset: Unused here.
        num_labels: Unused here.
        num_classes: Unused here.
        data_dir: Directory that contains the ``CLEAR`` folder.
        include_lb_to_ulb: Unused here.

    Returns:
        Tuple ``(train_labeled_dataset, train_unlabeled_dataset, val_dataset)``,
        each a ``ConcatDataset`` of per-bucket ``CLEAR`` datasets.
    """
    img_size = 224
    crop_ratio = 0.875
    resize_to = int(math.floor(img_size / crop_ratio))

    transform_weak = transforms.Compose([
        transforms.Resize((resize_to, resize_to)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(dataset_mean, dataset_std)
    ])

    transform_strong = transforms.Compose([
        transforms.Resize((resize_to, resize_to)),
        RandomResizedCropAndInterpolation((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        RandAugment(3, 10),
        transforms.ToTensor(),
        transforms.Normalize(dataset_mean, dataset_std)
    ])

    transform_val = transforms.Compose([
        transforms.Resize(resize_to),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(dataset_mean, dataset_std)
    ])

    train_folder = [str(bucket) for bucket in range(11)]
    test_folder = [str(bucket) for bucket in range(11)]
    train_root = os.path.join(data_dir, 'CLEAR/clear100-train/labeled_images')
    test_root = os.path.join(data_dir, 'CLEAR/clear100-test/labeled_images')

    def _load_buckets(split_root, folders, **clear_kwargs):
        """Concatenate one shuffled CLEAR dataset per bucket folder."""
        buckets = []
        for folder in folders:
            bucket = CLEAR(alg, os.path.join(split_root, folder), **clear_kwargs)
            # The permutation needs the sample count, so it is attached after
            # construction; this scans each bucket directory once instead of
            # building a throwaway dataset just to read its length.
            bucket.idx_list = np.random.permutation(len(bucket))
            buckets.append(bucket)
        return ConcatDataset(buckets)

    # NOTE(review): only the first ten training buckets ('0'..'9') are loaded
    # although '10' is listed; kept as-is -- confirm this is intended.
    train_labeled_dataset = _load_buckets(
        train_root, train_folder[:10],
        transform=transform_weak, transform_strong=transform_strong)

    # Evaluation uses the deterministic resize + center-crop pipeline rather
    # than the random training augmentation. transform_strong is still passed
    # so the constructed objects differ from the training ones only in their
    # primary transform.
    val_dataset = _load_buckets(
        test_root, test_folder[:11],
        transform=transform_val, transform_strong=transform_strong)

    # The unlabeled pool re-uses the first three training buckets.
    train_unlabeled_dataset = _load_buckets(
        train_root, train_folder[:3],
        is_ulb=True, transform=transform_weak, transform_strong=transform_strong)

    return train_labeled_dataset, train_unlabeled_dataset, val_dataset


class CLEAR(BasicDataset, ImageFolder):
    """Image-folder dataset over a single CLEAR bucket directory.

    File discovery re-uses ``find_classes`` / ``make_dataset`` inherited from
    the image-folder base class, while item retrieval is delegated to
    ``BasicDataset.__getitem__``, which obtains raw samples via ``__sample__``.
    """

    def __init__(self, alg, root, index=None, is_ulb=False, transform=None, target_transform=None, transform_strong=None):
        """Scan ``root`` for images and record their paths and class targets.

        Args:
            alg: Algorithm name; stored and checked against a fixed list when
                an unlabeled dataset is built without a strong transform.
            root: Directory containing one sub-folder per class.
            index: Optional sequence of positions into the scanned samples;
                when given, item ``i`` of the dataset is sample ``index[i]``.
            is_ulb: Whether this dataset is used as unlabeled data.
            transform: Transform stored as ``self.transform``.
            target_transform: Accepted for signature compatibility; not used
                by the code in this class.
            transform_strong: Transform stored as ``self.strong_transform``.

        Raises:
            RuntimeError: If no file with a supported extension is found.
            AssertionError: See the note on the assertion below.
        """
        self.loader = default_loader
        self.alg = alg
        self.is_ulb = is_ulb
        self.transform = transform
        self.root = root
        self.idx_list = index
        self.onehot = False

        is_valid_file = None
        extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        # Paths and integer targets are kept as two parallel lists.
        self.data = [s[0] for s in samples]
        self.targets = [s[1] for s in samples]

        self.strong_transform = transform_strong
        if self.strong_transform is None:
            if self.is_ulb:
                # NOTE(review): this fails for the algorithms *in* the list and
                # passes for every other one, which looks at odds with the
                # message text -- behaviour kept unchanged, confirm intent.
                assert self.alg not in ['fullysupervised', 'supervised', 'pseudolabel', 'vat', 'pimodel', 'meanteacher',
                                        'mixmatch'], f"alg {self.alg} requires strong augmentation"

    def __sample__(self, idx):
        """Return the ``(image, target)`` pair for dataset position ``idx``.

        When an index list was supplied, ``idx`` is first mapped through it.
        """
        if self.idx_list is not None:
            idx = self.idx_list[idx]
        path = self.data[idx]
        img = self.loader(path)
        target = self.targets[idx]
        return img, target

    def __getitem__(self, index):
        """Delegate explicitly to ``BasicDataset.__getitem__``."""
        return BasicDataset.__getitem__(self, index)

    def __len__(self):
        """Return the number of addressable items.

        Positions are mapped through ``idx_list`` when it is set, so its
        length is the dataset length in that case: a subset index must not
        report more items than it can serve, and a longer index must not be
        truncated. Without an index list every scanned sample is addressable.
        """
        if self.idx_list is not None:
            return len(self.idx_list)
        return len(self.data)




if __name__ == '__main__':
    # Ad-hoc manual check: `root` is a machine-specific absolute path to one
    # CLEAR-10 training bucket. Nothing is executed with it at present.
    root = '/media/parnec/7AFCDBA8FCDB5CC7/wxr/Generalized-Online-Continual-Learning/data/CLEAR/clear10-train/train/labeled_images/1'
    # get_clear(root)  # disabled; no `get_clear` is defined in the visible file