# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import json
import torchvision
import numpy as np
import math

from torchvision import transforms
from .datasetbase import BasicDataset
from semilearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation
from semilearn.datasets.utils import split_ssl_data


# Per-dataset, per-channel normalization statistics passed to
# transforms.Normalize. NOTE(review): 'cifar10' and 'dtd' reuse the ImageNet
# mean/std instead of dataset-specific values — presumably to match
# ImageNet-pretrained backbones; confirm this is intended.
mean = {
    'cifar10': [0.485, 0.456, 0.406],
    'cifar100': [x / 255 for x in [129.3, 124.1, 112.4]],
    'dtd': [0.485, 0.456, 0.406],
}

std = {
    'cifar10': [0.229, 0.224, 0.225],
    'cifar100': [x / 255 for x in [68.2, 65.4, 70.4]],
    'dtd': [0.229, 0.224, 0.225],
}

# Module-level accumulator of image paths, appended to by deep_read_files().
datas = []


def deep_read_files(root_dir, out=None):
    """Recursively collect the paths of all ``.jpg`` files under *root_dir*.

    Subdirectories are descended into. A file is collected only when its name
    ends with ``.jpg`` (case-insensitive), so names such as ``x.jpg.txt`` are
    skipped. Entries are visited in sorted order, so the result is
    deterministic across runs and filesystems.

    Args:
        root_dir: Directory to scan.
        out: List to append the found paths to. Defaults to the module-level
            ``datas`` list, which keeps the original behaviour of accumulating
            results in that global.

    Returns:
        The list the paths were appended to (``out`` or ``datas``).

    Raises:
        OSError: Propagated from ``os.listdir`` if *root_dir* cannot be listed.
    """
    target = datas if out is None else out
    # os.listdir order is arbitrary; sorting makes the collected order reproducible.
    for entry in sorted(os.listdir(root_dir)):
        path = os.path.join(root_dir, entry)
        if os.path.isdir(path):
            deep_read_files(path, target)
        elif entry.lower().endswith('.jpg'):
            target.append(path)
    return target


def get_dtd(args, alg, name, data_dir):
    """Build the unlabeled DTD dataset from every ``.jpg`` found under *data_dir*.

    Args:
        args: Config object; ``args.img_size`` (side of the final square crop)
            and ``args.crop_ratio`` (sets the resize size before cropping,
            ``floor(img_size / crop_ratio)``) are read.
        alg: Algorithm name, forwarded to ``BasicDataset``.
        name: Dataset name (not used in this function).
        data_dir: Root directory scanned recursively for images.

    Returns:
        A ``BasicDataset`` over the collected image paths, with a weak
        (resize / random crop / flip) and a strong (adds random-resized crop and
        RandAugment) augmentation pipeline.
    """
    # Reset the module-level accumulator: previously every call appended to the
    # paths gathered by earlier calls, so repeated calls produced duplicates.
    datas.clear()
    deep_read_files(data_dir)
    # Snapshot the paths so a later call (which clears ``datas``) cannot mutate
    # the list held by a dataset that was already returned.
    image_paths = list(datas)

    img_size = args.img_size
    crop_ratio = args.crop_ratio
    # Images are resized to this square side first (larger than img_size when
    # crop_ratio < 1), then cropped back down to img_size.
    resize_size = int(math.floor(img_size / crop_ratio))

    transform_weak = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean['dtd'], std['dtd'])
    ])

    transform_strong = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        RandomResizedCropAndInterpolation((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        RandAugment(3, 10),
        transforms.ToTensor(),
        transforms.Normalize(mean['dtd'], std['dtd'])
    ])

    print(len(image_paths))
    ulb_dset = BasicDataset(alg, image_paths, transform=transform_weak, strong_transform=transform_strong)

    return ulb_dset
