# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import json
import torchvision
import numpy as np
import math
from PIL import Image

from torchvision import transforms, datasets
from .datasetbase import BasicDataset
from semilearn.datasets.augmentation import RandAugment, RandomResizedCropAndInterpolation
from semilearn.datasets.utils import split_ssl_data


def _load_dtd_split(data_dir, split):
    """Load one DTD split into memory with every image resized to 256x256.

    The split is walked exactly once, collecting image and label from the
    same sample. Iterating the dataset separately for images and for labels
    would fetch (and, for torchvision's DTD, re-open and decode) every image
    file twice.

    Args:
        data_dir: root directory handed to ``datasets.DTD``.
        split: split name handed to ``datasets.DTD`` ('train', 'val', 'test').

    Returns:
        ``(images, labels)`` where ``images`` is ``np.array`` applied to the
        list of resized images and ``labels`` is a plain list of the labels
        in the same order.
    """
    dset = datasets.DTD(data_dir, split=split, download=True)
    images, labels = [], []
    for img, label in dset:
        images.append(np.array(img.resize((256, 256))))
        labels.append(label)
    return np.array(images), labels


def _count_per_class(labels, num_classes):
    """Return a list whose i-th entry is how often class index i occurs."""
    counts = [0] * num_classes
    for c in labels:
        counts[c] += 1
    return counts


def get_dtd(args, alg, name, num_labels, num_classes, data_dir='./data', include_lb_to_ulb=True):
    """Build the labeled, unlabeled and evaluation datasets for DTD.

    The 'train' and 'val' splits are merged into one pool, which is divided
    into labeled / unlabeled parts by ``split_ssl_data``. The 'test' split is
    used for evaluation.

    Args:
        args: configuration object. This function reads ``img_size``,
            ``crop_ratio``, ``ulb_num_labels``, ``lb_imb_ratio``,
            ``ulb_imb_ratio`` and ``tzsl`` from it, and also forwards it to
            ``split_ssl_data``.
        alg: algorithm name, forwarded to ``BasicDataset``. When it equals
            ``'fullysupervised'`` the labeled set is the whole train+val pool.
        name: dataset name; its lower-cased form is appended to ``data_dir``.
        num_labels: forwarded to ``split_ssl_data`` as ``lb_num_labels``.
        num_classes: number of classes; used for the per-class counts and
            forwarded to ``split_ssl_data`` and ``BasicDataset``.
        data_dir: parent directory under which the dataset is stored.
        include_lb_to_ulb: forwarded to ``split_ssl_data``.
            NOTE(review): presumably controls whether labeled samples are
            also kept in the unlabeled set -- confirm against that helper.

    Returns:
        ``(lb_dset, ulb_dset, eval_dset, tzsl_dict)``. ``tzsl_dict`` is empty
        unless ``args.tzsl`` is truthy, in which case it holds a dataset over
        the unlabeled data built with the validation transforms, the raw
        unlabeled data/targets, and the weak, strong and CLIP training
        transforms.
    """
    data_dir = os.path.join(data_dir, name.lower())

    # Train and val are pooled; the labeled/unlabeled split is drawn from it.
    train_data, train_targets = _load_dtd_split(data_dir, 'train')
    valid_data, valid_targets = _load_dtd_split(data_dir, 'val')
    data = np.concatenate((train_data, valid_data), axis=0)
    targets = np.concatenate((train_targets, valid_targets), axis=0)

    imgnet_mean = (0.485, 0.456, 0.406)
    imgnet_std = (0.229, 0.224, 0.225)
    img_size = args.img_size
    crop_ratio = args.crop_ratio
    # Size images are resized to before cropping down to img_size.
    resize_size = int(math.floor(img_size / crop_ratio))

    transform_weak = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.RandomCrop((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(imgnet_mean, imgnet_std)
    ])

    transform_strong = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        RandomResizedCropAndInterpolation((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        RandAugment(1, 3),
        transforms.ToTensor(),
        transforms.Normalize(imgnet_mean, imgnet_std)
    ])

    transform_val = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(imgnet_mean, imgnet_std)
    ])

    # Separate pipelines for the CLIP branch: fixed 224 input size, bicubic
    # interpolation and their own normalization statistics.
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)

    clip_transform = transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.5, 1), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=clip_mean, std=clip_std)])

    clip_transform_val = transforms.Compose([
        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=clip_mean, std=clip_std)])

    lb_data, lb_targets, ulb_data, ulb_targets = split_ssl_data(args, data, targets, num_classes,
                                                                lb_num_labels=num_labels,
                                                                ulb_num_labels=args.ulb_num_labels,
                                                                lb_imbalance_ratio=args.lb_imb_ratio,
                                                                ulb_imbalance_ratio=args.ulb_imb_ratio,
                                                                include_lb_to_ulb=include_lb_to_ulb)

    # Report the class distribution of the split (before any override below).
    print("lb count: {}".format(_count_per_class(lb_targets, num_classes)))
    print("ulb count: {}".format(_count_per_class(ulb_targets, num_classes)))

    if alg == 'fullysupervised':
        # Fully supervised training uses every pooled sample as labeled data.
        lb_data = data
        lb_targets = targets

    lb_dset = BasicDataset(alg, lb_data, lb_targets, num_classes, transform_weak, False, None, None, clip_transform, False)

    ulb_dset = BasicDataset(alg, ulb_data, ulb_targets, num_classes, transform_weak, True, None, transform_strong, clip_transform, False)

    test_data, test_targets = _load_dtd_split(data_dir, 'test')
    eval_dset = BasicDataset(alg, test_data, test_targets, num_classes, transform_val, False, None, None, clip_transform_val, False)

    if args.tzsl:
        # Unlabeled data wrapped with the deterministic validation transforms.
        tzsl_dset = BasicDataset(alg, ulb_data, ulb_targets, num_classes, transform_val, False, None, None, clip_transform_val, False)
        tzsl_dict = {'tzsl_dset': tzsl_dset, 'raw_data': ulb_data, 'raw_targets': ulb_targets, 'tfm_wk': transform_weak, 'tfm_st': transform_strong, 'clip_tfm': clip_transform}
    else:
        tzsl_dict = {}
    return lb_dset, ulb_dset, eval_dset, tzsl_dict


# class DTDDataset(BasicDataset):
#     def __sample__(self, idx):
#         path = self.data[idx]
#         img = Image.open(path).convert("RGB")
#         target = self.targets[idx]
#         return img, target 

