# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
import os
import json

from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


import imagenet_h
import imagenet_h_seeds
import imagenet_h_seeds_cap


def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set`` with its class counts.

    Args:
        is_train: Forwarded to ``build_transform`` and passed as the second
            positional argument of the dataset constructor.
        args: Namespace-like object. ``data_set`` is always read; for a
            supported dataset ``data_path`` is read as well, ``texts`` is read
            for 'IMNET-H' and 'IMNET-H-SUPERPIXEL-CAP', and
            ``num_superpixels`` for the two superpixel variants.
            ``build_transform`` reads further attributes of its own.

    Returns:
        A ``(dataset, nb_classes)`` tuple. ``nb_classes`` is the list
        ``[505, 127, 20]`` for every supported dataset (presumably one class
        count per hierarchy level -- confirm against the dataset modules).

    Raises:
        ValueError: If ``args.data_set`` is not one of the supported names.
    """
    supported = ('IMNET-H', 'IMNET-H-SUPERPIXEL', 'IMNET-H-SUPERPIXEL-CAP')
    name = args.data_set
    # Fail fast with a readable message; previously an unknown name fell
    # through every branch and crashed on the unbound locals at `return`.
    if name not in supported:
        raise ValueError(
            f"Unsupported data_set {name!r}; expected one of {list(supported)}"
        )

    transform = build_transform(is_train, args)
    # A fresh list per call, shared by all supported variants.
    nb_classes = [505, 127, 20]

    if name == 'IMNET-H':
        dataset = imagenet_h.ImageNetHier(
            args.data_path,
            is_train,
            transform=transform,
            texts=args.texts,
        )
        return dataset, nb_classes

    # Both superpixel variants take the same segmentation settings; the
    # '-CAP' variant additionally receives ``texts``.
    superpixel_kwargs = dict(
        transform=transform,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        n_segments=args.num_superpixels,
        compactness=10.0,
        blur_ops=None,
        scale_factor=1.0,
    )
    if name == 'IMNET-H-SUPERPIXEL':
        dataset = imagenet_h_seeds.ImageNetHier(
            args.data_path,
            is_train,
            **superpixel_kwargs,
        )
    else:  # 'IMNET-H-SUPERPIXEL-CAP'
        dataset = imagenet_h_seeds_cap.ImageNetHier(
            args.data_path,
            is_train,
            texts=args.texts,
            **superpixel_kwargs,
        )

    return dataset, nb_classes

def build_transform(is_train, args):
    """Create the image transform pipeline for training or evaluation.

    Training: delegates to timm's ``create_transform`` using the augmentation
    settings found on ``args``; when ``args.input_size`` is 32 or smaller the
    first transform of the result is replaced by a padded ``RandomCrop``.

    Evaluation: for inputs larger than 32, resizes to
    ``int(args.input_size / args.eval_crop_ratio)`` and center-crops to
    ``args.input_size``; then converts to a tensor and normalizes, using
    dedicated statistics when ``'INAT'`` occurs in ``args.data_set`` and the
    ImageNet defaults otherwise.

    Args:
        is_train: Selects the training pipeline when truthy.
        args: Namespace-like object providing the attributes described above.

    Returns:
        The transform produced by ``create_transform`` (training) or a
        ``transforms.Compose`` (evaluation).
    """
    needs_resize = args.input_size > 32

    if not is_train:
        pipeline = []
        if needs_resize:
            resized_edge = int(args.input_size / args.eval_crop_ratio)
            # to maintain same ratio w.r.t. 224 images
            pipeline.append(transforms.Resize(resized_edge, interpolation=3))
            pipeline.append(transforms.CenterCrop(args.input_size))

        pipeline.append(transforms.ToTensor())
        if 'INAT' in args.data_set:
            normalize = transforms.Normalize(
                [0.466, 0.471, 0.380], [0.195, 0.194, 0.192])
        else:
            normalize = transforms.Normalize(
                IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        pipeline.append(normalize)
        return transforms.Compose(pipeline)

    # this should always dispatch to transforms_imagenet_train
    train_kwargs = dict(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
    )
    train_transform = create_transform(**train_kwargs)
    if needs_resize:
        return train_transform

    # Small inputs: swap RandomResizedCropAndInterpolation for RandomCrop.
    train_transform.transforms[0] = transforms.RandomCrop(
        args.input_size, padding=4)
    return train_transform
