import os
import torch
import numpy as np
from PIL import Image
from collections import defaultdict
import json

from .Dataset import Dataset

class TestDataLoader(Dataset):
    """Test-set loader that serves images one incremental task at a time.

    Class labels ``0 .. len(ids2labels) - 1`` are split into tasks:

    - Task 0 covers the first ``init_class_number`` labels.
    - Every later task ``t`` covers the next ``phase`` labels.

    Call :meth:`task_init` to select a task before fetching batches with
    :meth:`next_batch`.

    Args:
      phase: number of new classes introduced by each task after task 0.
      order_file: path to a JSON file read by :meth:`load_order_file`.
      im_dir: directory that image names are joined to when loading.
      im_names: image file names, indexed in parallel with ``im_ids``.
      im_ids: identity id of each image in ``im_names``.
      ids2labels: mapping from identity id to an integer class label.
        Labels are assumed to be contiguous in ``0 .. len(ids2labels) - 1``
        because the class list is built with ``range(len(ids2labels))``.
      init_class_number: number of classes in task 0.
      **kwargs: forwarded to ``Dataset.__init__`` on every ``task_init``.
    """

    def __init__(self, phase, order_file, im_dir=None, im_names=None, im_ids=None, ids2labels=None, init_class_number=10, **kwargs):
        # These default to None only for signature compatibility; the loader
        # cannot work without them, so fail early with a clear message.
        if im_names is None or im_ids is None or ids2labels is None:
            raise TypeError(
                'TestDataLoader requires im_names, im_ids and ids2labels')
        self.dataset_kwargs = kwargs
        # Directory that every image name is relative to.
        self.im_dir = im_dir
        self.im_names = im_names
        self.im_ids = im_ids
        self.ids2labels = ids2labels

        self.order_file = order_file
        self.init_class_number = init_class_number
        self.phase = phase
        self.class_number = len(self.ids2labels)
        self.classes = list(range(self.class_number))
        # Image names / labels of the task selected by task_init().
        self.current_task = []
        self.current_task_label = []
        self.trained_classes = []
        self.task = 0

        # label -> list of image names with that label.
        self.label_im_names = defaultdict(list)
        # image name -> [label, id].
        self.im_name_info = defaultdict(list)

        for indx, im_id in enumerate(self.im_ids):
            label = self.ids2labels[im_id]
            name = self.im_names[indx]
            self.label_im_names[label].append(name)
            self.im_name_info[name] = [label, im_id]

    def task_init(self, task):
        """Select ``task`` and re-initialise the base ``Dataset`` for it.

        Task 0 spans class indices ``[0, init_class_number)``. Task ``t > 0``
        spans ``[init_class_number + (t-1)*phase, init_class_number + t*phase)``.

        Args:
          task: zero-based task index.
        """
        # Clear in place rather than rebinding, so existing references to
        # this list stay valid.
        self.current_task.clear()
        self.task = task
        if task == 0:
            start_ptr = 0
        else:
            start_ptr = self.init_class_number + (task - 1) * self.phase
        end_ptr = self.init_class_number + task * self.phase
        self.current_task_label = self.classes[start_ptr:end_ptr]
        for label in self.current_task_label:
            self.current_task.extend(self.label_im_names[label])
        # The base Dataset is re-initialised so its size matches this task.
        super(TestDataLoader, self).__init__(dataset_size=len(self.current_task), **self.dataset_kwargs)
        self.print_test_kwargs()

    def load_order_file(self):
        """Load ``phase``, ``part``, ``trained_classes`` and ``classes`` from JSON."""
        with open(self.order_file, 'r', encoding='utf-8') as json_file:
            # json.load reads from a file object; json.loads expects a str.
            order_dict = json.load(json_file)
        self.phase = order_dict['phase']
        self.part = order_dict['part']
        # No trailing comma: store the loaded value itself, not a 1-tuple.
        self.trained_classes = order_dict['trained_classes']
        self.classes = order_dict['classes']

    def print_test_kwargs(self):
        """Print a summary of the currently selected test task."""
        print('-' * 40)
        print('Test Set')
        # format() instead of '+' so a None im_dir does not raise.
        print('Test dataset: {}'.format(self.im_dir))
        print('Test Task Number: ' + str(self.task))
        print('-' * 40)
        print('NO. IDs: ')
        print(self.current_task_label)
        print('NO. Current Task Images: {}'.format(len(self.current_task)))

    def get_sample(self, ptr):
        """Load and pre-process the ``ptr``-th image of the current task.

        Returns:
          A tuple ``(im, id, label, im_name)``, where ``im`` is the output of
          ``self.pre_process_im``.
        """
        im_name = self.current_task[ptr]
        im_path = os.path.join(self.im_dir, im_name)
        # Close the underlying file; convert() returns a fully loaded copy.
        with Image.open(im_path) as pil_im:
            im = np.asarray(pil_im.convert('RGB'))
        im, _ = self.pre_process_im(im)
        label, im_id = self.im_name_info[im_name]
        return im, im_id, label, im_name

    def next_batch(self):
        """Fetch the next test batch from the prefetcher.

        Returns:
          ``(ims, ids, labels, names, epoch_done)``. ``ims`` stacks the
          per-sample images along a new axis 0. The other three are 1-D numpy
          arrays in batch order.
        """
        if self.epoch_done and self.shuffle:
            # get_sample() indexes self.current_task, so that is the list
            # whose order must be shuffled.
            self.prng.shuffle(self.current_task)
        samples, self.epoch_done = self.prefetcher.get_next_test_batch()
        im_list, im_ids, im_labels, im_names = zip(*samples)
        ims = np.stack(im_list, axis=0)
        ids = np.array(im_ids)
        labels = np.array(im_labels)
        names = np.array(im_names)
        return ims, ids, labels, names, self.epoch_done