

import os
import PIL

from torchvision import transforms, datasets
from timm.data import create_transform

from .dataset_folder import ImageFolder


# Per-channel (R, G, B) normalization statistics from OpenAI CLIP's image
# preprocessing; build_transform normalizes both train and eval images with them.
CLIP_DEFAULT_MEAN = (
    0.48145466,
    0.4578275,
    0.40821073,
)
CLIP_DEFAULT_STD = (
    0.26862954,
    0.26130258,
    0.27577711,
)


def build_dataset(is_train, args):
    """Build the train or eval dataset selected by ``args.data_set``.

    Args:
        is_train: If true, build the training split; otherwise the eval split.
        args: Namespace-like config. Reads ``data_set`` and ``data_path``;
            for ``'image_folder'`` also ``eval_data_path`` (eval split) and
            ``nb_classes``; plus whatever ``build_transform`` reads.

    Returns:
        The constructed dataset with the split's transform attached.

    Raises:
        ValueError: If ``args.data_set`` is not a supported name, or if an
            ``'image_folder'`` dataset's class count differs from
            ``args.nb_classes``.
    """
    supported = ('IMNET', 'image_folder')
    # Fail fast on an unknown name: previously execution fell through both
    # branches and crashed with UnboundLocalError on ``dataset``.
    if args.data_set not in supported:
        raise ValueError(
            f"Unsupported data_set {args.data_set!r}; expected one of {supported}"
        )

    transform = build_transform(is_train, args)
    if args.data_set == 'IMNET':
        # Split directories live under the data root: <data_path>/train, <data_path>/val.
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        print(dataset)
    else:  # 'image_folder'
        root = args.data_path if is_train else args.eval_data_path
        dataset = ImageFolder(root, transform=transform)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        num_found = len(dataset.class_to_idx)
        if num_found != args.nb_classes:
            raise ValueError(
                f"Found {num_found} classes under {root!r}, "
                f"but args.nb_classes is {args.nb_classes}"
            )
    return dataset


def build_transform(is_train, args):
    """Build the image transform for training or evaluation.

    Both pipelines normalize with ``CLIP_DEFAULT_MEAN`` / ``CLIP_DEFAULT_STD``.

    Args:
        is_train: If true, return timm's training augmentation pipeline;
            otherwise a deterministic resize + center-crop pipeline.
        args: Namespace-like config. Always reads ``input_size``; the
            training path also reads ``color_jitter``, ``aa``, ``reprob``,
            ``remode`` and ``recount``.

    Returns:
        A callable transform (timm's training transform, or a torchvision
        ``Compose`` ending in ``ToTensor`` + ``Normalize`` for eval).
    """
    mean = CLIP_DEFAULT_MEAN
    std = CLIP_DEFAULT_STD

    if is_train:
        # With is_training=True this should always dispatch to timm's
        # transforms_imagenet_train.
        return create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation='bicubic',
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )

    # Eval: for inputs up to 224 px, resize the short side to
    # input_size / 0.875 (256 for 224) before center-cropping; larger inputs
    # are resized straight to input_size (crop_pct = 1.0).
    crop_pct = 224 / 256 if args.input_size <= 224 else 1.0
    size = int(args.input_size / crop_pct)

    # Prefer torchvision's InterpolationMode enum; integer PIL constants are
    # deprecated for `interpolation` in newer torchvision. Fall back to the
    # PIL constant only on versions that predate the enum.
    interp_enum = getattr(transforms, 'InterpolationMode', None)
    bicubic = (interp_enum.BICUBIC if interp_enum is not None
               else PIL.Image.BICUBIC)

    return transforms.Compose([
        transforms.Resize(size, interpolation=bicubic),
        transforms.CenterCrop(args.input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
