"""
This code based on codes from https://github.com/tristandeleu/ntm-one-shot \
                              and https://github.com/kjunelee/MetaOptNet
"""
import numpy as np
import random
import pickle as pkl
from functions import *
import torch
import time
import torch.nn.functional as F

class miniImageNetGenerator(object):
    """Iterator that samples multi-user few-shot episodes from a pickled dataset.

    Each episode draws ``nb_classes`` distinct classes. For every class it takes
    mutually disjoint image subsets: ``n_spt`` support images for each of the
    ``num_user`` users, plus ``n_qry * num_user`` query images shared by all
    users. Labels are episode-local class indices ``0 .. nb_classes - 1``.
    """

    def __init__(self, data_file, nb_classes, num_user, n_spt, n_qry, max_iter=None):
        """
        Args:
            data_file: path to a pickle holding a dict with ``'data'`` (an image
                array, assumed (N, H, W, C) since it is permuted to (N, C, H, W))
                and ``'labels'`` (one label per image).
            nb_classes: number of classes per episode (N-way).
            num_user: number of users; each gets its own disjoint support set.
            n_spt: support images per class per user (K-shot).
            n_qry: query images per class per user. The stored ``self.n_qry``
                is the per-class total, ``n_qry * num_user``.
            max_iter: number of episodes before ``StopIteration``; ``None``
                means iterate forever.
        """
        super(miniImageNetGenerator, self).__init__()
        self.data_file = data_file
        self.nb_classes = nb_classes
        self.max_iter = max_iter
        self.num_iter = 0
        self.data_dict = self._load_data(self.data_file)
        self.num_user = num_user
        self.n_spt = n_spt
        self.n_qry = n_qry * num_user

    def _load_data(self, data_file):
        """Load ``data_file`` and group its images by label.

        Returns:
            dict mapping each label to a tensor of that label's images,
            permuted from (N, H, W, C) to (N, C, H, W).
        """
        dataset = self.load_data(data_file)
        data = dataset['data']
        label2ind = self.buildLabelIndex(dataset['labels'])
        return {
            label: torch.tensor(data[indices]).permute([0, 3, 1, 2])
            for label, indices in label2ind.items()
        }

    def load_data(self, data_file):
        """Unpickle and return the contents of ``data_file``.

        If the default load fails, the file is re-read with
        ``encoding='latin1'`` (presumably for pickles written by Python 2).

        NOTE(security): pickle can execute arbitrary code while loading; only
        use this on trusted dataset files.
        """
        try:
            with open(data_file, 'rb') as fo:
                return pkl.load(fo)
        except Exception:
            # Narrower than a bare ``except:`` so Ctrl-C / SystemExit still
            # propagate. pkl.load(..., encoding='latin1') is the public
            # equivalent of configuring pickle._Unpickler by hand.
            with open(data_file, 'rb') as fo:
                return pkl.load(fo, encoding='latin1')

    def buildLabelIndex(self, labels):
        """Map each distinct label to the list of positions where it occurs."""
        label2inds = {}
        for idx, label in enumerate(labels):
            label2inds.setdefault(label, []).append(idx)
        return label2inds

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        """Return ``(episode_index, x_spt, y_spt, x_qry, y_qry)`` for a new episode.

        Raises:
            StopIteration: once ``max_iter`` episodes have been produced.
        """
        if (self.max_iter is None) or (self.num_iter < self.max_iter):
            self.num_iter += 1
            x_spt, y_spt, x_qry, y_qry = self.sample(self.nb_classes)

            return (self.num_iter - 1), x_spt, y_spt, x_qry, y_qry
        else:
            raise StopIteration()

    def sample(self, nb_classes):
        """Draw one episode of ``nb_classes`` classes.

        Returns:
            x_spt: float tensor (num_user, nb_classes * n_spt, C, H, W)
            y_spt: long tensor (num_user, nb_classes * n_spt)
            x_qry: float tensor (nb_classes * self.n_qry, C, H, W)
            y_qry: long tensor (nb_classes * self.n_qry)

        Raises:
            ValueError: if ``nb_classes`` exceeds the number of classes, or a
                sampled class has fewer than ``num_user * n_spt + self.n_qry``
                images (raised by ``random.sample``).
        """
        n_spt, num_user, n_qry = self.n_spt, self.num_user, self.n_qry

        # random.sample requires a sequence; dict views and sets are rejected
        # since Python 3.11, so materialise the keys as a list.
        sampled_class = random.sample(list(self.data_dict), nb_classes)

        # Take the image shape from the data instead of hard-coding 3x84x84.
        img_shape = (tuple(self.data_dict[sampled_class[0]].shape[1:])
                     if sampled_class else (3, 84, 84))
        x_spt = torch.zeros(num_user, nb_classes * n_spt, *img_shape)
        y_spt = torch.zeros(num_user, nb_classes * n_spt, dtype=torch.long)
        x_qry = torch.zeros(nb_classes * n_qry, *img_shape)
        y_qry = torch.zeros(nb_classes * n_qry, dtype=torch.long)

        per_class = num_user * n_spt + n_qry
        for clsidx, cls in enumerate(sampled_class):
            imgs = self.data_dict[cls]  # (num_images, C, H, W) after _load_data
            # One draw without replacement, then partitioned, keeps every user's
            # support set and the query set mutually disjoint.
            picked = random.sample(range(len(imgs)), per_class)

            spt_slice = slice(n_spt * clsidx, n_spt * (clsidx + 1))
            for user in range(num_user):
                user_ind = picked[user * n_spt:(user + 1) * n_spt]
                x_spt[user, spt_slice] = imgs[user_ind]
                y_spt[user, spt_slice] = clsidx

            qry_slice = slice(n_qry * clsidx, n_qry * (clsidx + 1))
            x_qry[qry_slice] = imgs[picked[num_user * n_spt:]]
            y_qry[qry_slice] = clsidx

        return x_spt, y_spt, x_qry, y_qry



