import torch
import random
import json
import os.path as osp
import numpy as np
from PIL import Image
from collections import defaultdict

from .Dataset import Dataset


class MetaTrainDataLoader(Dataset):
    def __init__(self,
                 phase,
                 order_file,
                 im_dir=None,
                 im_names=None,
                 im_ids=None,
                 ids2labels=None,
                 ids_per_batch=None,
                 ims_per_id=None,
                 init_ids_per_batch=None,
                 init_ims_per_id=None,
                 init_class_number=10,
                 **kwargs):
        # save super kwargs
        self.dataset_kwargs = kwargs
        # train process set
        self.order_file = order_file
        self.init_class_number = init_class_number
        self.phase = phase
        self.ids_per_batch = ids_per_batch
        self.ims_per_id = ims_per_id
        self.init_ids_per_batch = init_ids_per_batch
        self.init_ims_per_id = init_ims_per_id
        self.use_ids_per_batch = self.init_ids_per_batch
        self.use_ims_per_id = self.init_ims_per_id
        self.batch_size = self.ids_per_batch * self.ims_per_id
        # dataset parameters
        self.im_dir = im_dir
        self.im_ids = im_ids
        self.im_names = im_names
        self.ids2labels = ids2labels

        self.label_im_names = defaultdict(list)
        self.im_name_info = defaultdict(list)

        for _indx, _id in enumerate(self.im_ids):
            _label = self.ids2labels[_id]
            _name = self.im_names[_indx]
            self.label_im_names[_label].append(_name)
            self.im_name_info[_name] = [_indx, _id]

        # dataset symbol
        self.part = 'expert'
        self.class_number = len(self.ids2labels)
        self.trained_classes = []
        self.current_task_label = []
        self.current_task = defaultdict(list)
        self.classes = list(range(self.class_number))
        assert (self.class_number-self.init_class_number)%self.phase==0, 'Wrong phase, or init class number!'
        self.task_number = (self.class_number-self.init_class_number)//self.phase+1
        self.task_ptr = 0
        self.class_ptr = 0
        print(self.dataset_kwargs)

    def print_train_kwargs(self):
        print('-' * 40)
        print('Train Set')
        print('Train dataset: '+self.im_dir)
        print('Task number: '+str(self.task_ptr))
        print('-' * 40)
        print('NO. Curent Task IDs: ')
        print(self.current_task_label)
        print('NO. Current Train Track Steps(Pseudo Id): {}'.format(len(self.current_task)))
        print('NO. Current Images Per Step: {}'.format(self.ims_per_id))
        batch_size = self.use_ids_per_batch*self.use_ims_per_id
        print('Batch size prior:' + str(batch_size))
        print('-' * 40)


    def dump_order_file(self):
        print('-'*10+'Start dumpping order file'+10*'-')
        order_dict = dict(
            phase=self.phase,
            part=self.part,
            trained_classes=self.trained_classes,
            classes=self.classes
        )
        order_json_str = json.dumps(order_dict, indent=4)
        with open(self.order_file, 'w') as json_file:
            json_file.write(order_json_str)
        print('-'*10+'Dumped the order file'+10*'-')


    def load_order_file(self):
        with open(self.order_file) as json_file:
            order_dict = json.loads(json_file.read())
        self.phase = order_dict['phase']
        self.part = order_dict['part']
        self.trained_classes = list(set(order_dict['trained_classes']))
        self.classes = order_dict['classes']
        self.class_ptr = len(self.trained_classes)
        # self.task_ptr = (len(self.trained_classes)-self.init_class_number)/self.phase


    def change_task(self, resume=False):
        if resume:
            assert osp.exists(self.order_file), "The order file isn't exists! Please check to resume!"
            self.load_order_file()
            start_ptr = self.class_ptr-self.phase if self.class_ptr != self.init_class_number else 0
            end_ptr = self.class_ptr
        else:
            increment_step = self.init_class_number if self.class_ptr == 0 else self.phase
            start_ptr = self.class_ptr
            end_ptr = self.class_ptr+increment_step
        self.current_task_label = self.classes[start_ptr:end_ptr]
        self.task_ptr = (end_ptr-self.init_class_number)//self.phase+1
        if self.task_ptr == 1:
            self.use_ids_per_batch = self.init_ids_per_batch
            self.use_ims_per_id = self.init_ims_per_id
        else:
            self.use_ids_per_batch = self.ids_per_batch
            self.use_ims_per_id = self.ims_per_id
        self._make_task()
        self.print_train_kwargs()
        self.class_ptr = end_ptr
        self.trained_classes+=self.current_task_label
        self.trained_classes = list(set(self.trained_classes))


    def _make_task(self):
        """Build the step table of the current task and (re)initialise the base.

        ``current_task_label`` is shuffled in place. Then, round by round,
        every label that still has ``use_ims_per_id`` images at the current
        offset contributes one step: that many image names followed by the
        label itself. Steps are stored in ``self.current_task`` under
        consecutive integer keys. Building stops once a round adds no step.

        Finally ``Dataset.__init__`` is called with the number of steps as
        ``dataset_size`` and ``use_ids_per_batch`` as ``batch_size``.

        NOTE(review): a label is never used for two consecutive steps, so a
        task consisting of a single label yields one step at most, whatever
        its image count -- confirm that this is intended.
        """
        del self.current_task

        task_labels = self.current_task_label
        np.random.shuffle(task_labels)

        chunk = self.use_ims_per_id
        steps = defaultdict(list)
        offset = 0
        next_step = 0
        previous_label = -1
        steps_after_previous_round = -1
        while True:
            for label in task_labels:
                pool = self.label_im_names[label]
                if len(pool) < offset + chunk or previous_label == label:
                    continue
                steps[next_step] += pool[offset:offset + chunk]
                # The label rides along as the last element of the step.
                steps[next_step].append(label)
                previous_label = label
                next_step += 1
            # A round that produced nothing new ends the construction.
            if steps_after_previous_round == len(steps):
                break
            steps_after_previous_round = len(steps)
            offset += chunk
        self.current_task = steps

        # One base-class "sample" is one step, so the base batch size is
        # the number of steps per batch.
        self.dataset_kwargs['batch_size'] = self.use_ids_per_batch
        super(MetaTrainDataLoader, self).__init__(dataset_size=len(self.current_task), **self.dataset_kwargs)


    def get_sample(self, ptr):
        """Load and pre-process all images of one step.

        One sample is one entry of ``self.current_task``: a list holding
        ``use_ims_per_id`` image names followed by their common label.

        Args:
            ptr: key of the step in ``self.current_task``.

        Returns:
            A tuple ``(ims, im_names, labels, mirrored, ids)``:
            ims: tuple of pre-processed images, one per name.
            im_names: list of the image file names of the step.
            labels: list repeating the step label once per image.
            mirrored: tuple of the second value returned by
                ``pre_process_im`` for each image.
            ids: list with the id of every image, from ``im_name_info``.
        """
        step = self.current_task[ptr]
        im_names = step[:self.use_ims_per_id]
        # im_name_info maps name -> [index, id]; only the id is needed.
        ids = [self.im_name_info[name][1] for name in im_names]
        labels = [step[-1]] * len(im_names)

        raw_ims = []
        for name in im_names:
            im_path = osp.join(self.im_dir, name)
            raw_ims.append(np.asarray(Image.open(im_path).convert('RGB')))

        # pre_process_im is not defined in this class (presumably inherited
        # from Dataset); it is unpacked here as an (image, mirrored) pair.
        processed = [self.pre_process_im(im) for im in raw_ims]
        ims, mirrored = zip(*processed)
        return ims, im_names, labels, mirrored, ids


    def next_batch(self):
        """
        Next batch of images and labels.
        Returns:
        ims: numpy array with shape [N, H, W, C] or [N, C, H, W], N >= 1
        img_names: a numpy array of image names, len(img_names) >= 1
        labels: a numpy array of image labels, len(labels) >= 1
        mirrored: a numpy array of booleans, whether the images are mirrored
        self.epoch_done: whether the epoch is over
        """
        # Start enqueuing and other preparation at the beginning of an epoch.
        if self.epoch_done and self.shuffle:
            np.random.shuffle(self.current_task)

        get_next_batch = self.prefetcher.get_next_train_batch
        samples, self.epoch_done = get_next_batch(used_label_index=2)
        im_list, im_names, im_labels, im_mirrored, im_ids = zip(*samples)

        ims = np.concatenate(im_list)
        names = np.concatenate(im_names)
        labels = np.concatenate(im_labels)
        mirrored = np.concatenate(im_mirrored)
        ids = np.concatenate(im_ids)

        samples_info = [ims, ids, mirrored, names, labels, self.epoch_done]
        return samples_info
