# Copyright (c) 2023, NVIDIA Corporation & Affiliates. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, visit
# https://github.com/NVlabs/prismer/blob/main/LICENSE

import glob
from torch.utils.data import Dataset
from UniDet_eval.dataset.utils import *


class Classification(Dataset):
    """ImageNet classification dataset that pairs each image with expert labels.

    Images are discovered on disk under ``<data_path>/imagenet_train/<class>/``
    (training) or ``<data_path>/imagenet/<class>/`` (evaluation). Each item is
    the processed expert dictionary plus either a text caption (training) or
    the class entry from ``imagenet_class.json`` (evaluation).
    """

    def __init__(self, config, train):
        """Index the image files and load the class/answer lookup tables.

        Args:
            config: Mapping providing ``data_path``, ``label_path``,
                ``experts``, ``dataset``, ``shots``, ``prefix`` and
                ``image_resolution``.
            train: Truthy to build the training split (at most ``shots``
                images per class folder); falsy to build the evaluation split
                (every image).
        """
        self.data_path = config['data_path']
        self.label_path = config['label_path']
        self.experts = config['experts']
        self.dataset = config['dataset']
        self.shots = config['shots']
        self.prefix = config['prefix']

        self.train = train
        # Forward the split flag instead of hard-coding train=True, so the
        # evaluation split is not processed with the training-mode transform.
        # NOTE(review): Transform is defined outside this file; presumably
        # `train` toggles random augmentation -- confirm.
        self.transform = Transform(resize_resolution=config['image_resolution'], scale_size=[0.5, 1.0], train=train)

        split_dir = 'imagenet_train' if train else 'imagenet'
        # Only the training split is truncated; a None stop keeps every file.
        per_class_limit = self.shots if train else None

        # glob returns entries in arbitrary order, so sort both levels to make
        # the few-shot selection reproducible across machines and runs.
        data_folders = sorted(glob.glob(f'{self.data_path}/{split_dir}/*/'))
        self.data_list = [
            {'image': image_path}
            for folder in data_folders
            for image_path in sorted(glob.glob(folder + '*.JPEG'))[:per_class_limit]
        ]

        # Both splits read their lookup tables from the `imagenet` folder.
        self.answer_list = self._load_json(f'{self.data_path}/imagenet/' + 'imagenet_answer.json')
        self.class_list = self._load_json(f'{self.data_path}/imagenet/' + 'imagenet_class.json')

    @staticmethod
    def _load_json(path):
        """Parse the JSON file at ``path``, closing the file handle afterwards."""
        with open(path, encoding='utf-8') as json_file:
            return json.load(json_file)

    @staticmethod
    def _split_image_path(img_path):
        """Return ``('<class>/<file>', '<class>')`` from a '/'-separated image path."""
        parts = img_path.split('/')
        class_name = parts[-2]
        return class_name + '/' + parts[-1], class_name

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data_list)

    def __getitem__(self, index):
        """Load one sample.

        Args:
            index: Position in ``self.data_list``.

        Returns:
            ``(experts, caption)`` for the training split, where ``caption`` is
            ``prefix + ' ' +`` the lower-cased answer string of the image's
            class; ``(experts, class_entry)`` for the evaluation split, where
            ``class_entry`` is the ``imagenet_class.json`` value for the
            image's class folder.
        """
        img_path = self.data_list[index]['image']
        img_name, class_name = self._split_image_path(img_path)

        # The expert labels are looked up under the same split folder name
        # that the image was indexed from.
        split_dir = 'imagenet_train' if self.train else 'imagenet'
        image, labels, labels_info = get_expert_labels(self.data_path, self.label_path, img_name, split_dir, self.experts)

        experts = self.transform(image, labels)
        experts = post_label_process(experts, labels_info)

        if self.train:
            # class_list maps the folder name to an index into answer_list.
            caption = self.prefix + ' ' + self.answer_list[int(self.class_list[class_name])].lower()
            return experts, caption
        return experts, self.class_list[class_name]





# import os
# import glob
#
# data_path = '/Users/shikunliu/Documents/dataset/mscoco/mscoco'
#
# data_folders = glob.glob(f'{data_path}/*/')
# data_list = [data for f in data_folders for data in glob.glob(f + '*.jpg')]


