from __future__ import print_function

import os
import os.path
import numpy as np
import random
import pickle
import json
import math
import multiprocessing
from PIL import Image
from scipy.stats import beta

import torch
import torch.utils.data as data
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

#import h5py

from PIL import Image
from PIL import ImageEnhance

from pdb import set_trace as breakpoint


def buildLabelIndex(labels):
    """Group sample positions by their label.

    Args:
        labels: an iterable of hashable labels, one per sample.

    Returns:
        A dict mapping every distinct label (in order of first appearance)
        to the increasing list of positions at which it occurs in `labels`.
    """
    index_by_label = {}
    for position, lbl in enumerate(labels):
        index_by_label.setdefault(lbl, []).append(position)
    return index_by_label


def load_data(file):
    """Load a pickled object from `file`.

    If the default load fails (typically a Python 2 pickle whose byte strings
    cannot be decoded as ASCII), the file is re-read with
    `encoding='latin1'`, which maps every byte to a code point 1:1.

    Args:
        file: path to the pickle file.

    Returns:
        The unpickled object.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    try:
        with open(file, 'rb') as fo:
            return pickle.load(fo)
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit are
        # no longer swallowed; any ordinary failure retries with latin1.
        with open(file, 'rb') as fo:
            return pickle.load(fo, encoding='latin1')


class ListDataset(object):
    """
    Lazily maps a sequence of arguments through a loader function.

    Args:
        elem_list (sequence): the per-sample arguments; must support `len`
            and integer indexing. Sample i is `load(elem_list[i])`.
        load (callable, optional): builds a sample from its argument.
            Defaults to the identity function.
    """

    def __init__(self, elem_list, load=lambda x: x):
        self.list, self.load = elem_list, load

    def __len__(self):
        return len(self.list)

    def __getitem__(self, idx):
        # Explicit upper-bound check so iteration via the legacy __getitem__
        # protocol stops at len(self).
        if idx < len(self):
            return self.load(self.list[idx])
        raise IndexError("CustomRange index out of range")


class FewShotDataloader(object):
    """Builds DataLoaders that yield randomly sampled few-shot episodes.

    Each episode draws `kway` categories from `dataset` and, per category,
    `kshot` support ("shot") and `kquery` query examples. `dataset` must
    provide `phase`, `num_cats`, `sampleCategories(n)`,
    `sampleImageIdsFrom(cat, sample_size=...)` and
    `createExamplesTensorData(examples)`; with `fixed=True` it must also
    provide `dataset_dir` and `taskaug`.

    With `fixed=True`, `epoch_size` episodes are generated once (seed 0) and
    cached as a pickle in `dataset.dataset_dir`. The cache is reused on later
    runs, and the episodes are served in a fixed order (shuffle is forced off).
    """

    def __init__(self, dataset, kway=5, kshot=1, kquery=1, batch_size=1, num_workers=2,
                 epoch_size=2000, shuffle=True, fixed=False):

        self.dataset = dataset
        self.phase = self.dataset.phase

        max_possible_cate = self.dataset.num_cats
        assert(kway >= 0 and kway <= max_possible_cate)

        self.kway = kway
        self.kshot = kshot
        self.kquery = kquery

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.epoch_size = epoch_size

        if fixed:
            # NOTE(review): the cache name does not encode epoch_size; a cache
            # written with a smaller epoch_size would make get_fixed_iterator
            # index past the end of self.tasks.
            self.file = os.path.join(
                self.dataset.dataset_dir,
                self.dataset.taskaug + '_' + self.phase + '_' + str(kway) + 'way_' +
                str(kshot) + 'shot_' + str(kquery) + 'query.pkl')
            if os.path.exists(self.file):
                # Unpickling runs arbitrary code; this file is expected to be
                # the cache written by create_alltasks, not external input.
                with open(self.file, 'rb') as f:
                    self.tasks = pickle.load(f)
            else:
                self.tasks = self.create_alltasks()
            # Instance attribute shadows the method so __call__ serves the cached tasks.
            self.get_iterator = self.get_fixed_iterator
            shuffle = False

        self.shuffle = shuffle

    def sampleCategories(self, sample_size=1):
        """
        Samples `sample_size` unique categories (delegates to the dataset).

        Args:
            sample_size: number of categories that will be sampled.

        Returns:
            cat_ids: a list of length `sample_size` with unique category ids.
        """
        assert(self.dataset.num_cats >= sample_size)
        return self.dataset.sampleCategories(sample_size)

    def sample_examples_for_categories(self, categories, kshot, kquery):
        """
        Samples support (shot) and query examples for the given categories.

        Args:
            categories: a list with the ids of the categories.
            kshot: number of support examples sampled per category.
            kquery: number of query examples sampled per category.

        Returns:
            shot_e: a shuffled list of length len(categories) * kshot with
                3-element tuples (img_id, episode_label, category_id), where
                episode_label is the category's position in `categories`
                (range [0, len(categories) - 1]) and category_id is the
                dataset's own category id.
            query_e: the same for the len(categories) * kquery query examples.
        """
        if len(categories) == 0:
            return [], []

        shot_e = []
        query_e = []
        for cate_idx, cate in enumerate(categories):
            # A single draw per category, split into query then shot, keeps the
            # two sets disjoint (assuming sampleImageIdsFrom samples without
            # replacement).
            img_ids = self.dataset.sampleImageIdsFrom(cate, sample_size=(kquery + kshot))

            shot_e += [(img_id, cate_idx, cate) for img_id in img_ids[kquery:]]
            query_e += [(img_id, cate_idx, cate) for img_id in img_ids[:kquery]]

        assert(len(shot_e) == len(categories) * kshot)
        assert(len(query_e) == len(categories) * kquery)

        random.shuffle(shot_e)
        random.shuffle(query_e)

        return shot_e, query_e

    def _episode_tensors(self, shot_e, query_e):
        """Materialize one episode as (Xe, Ye, Oe, Xt, Yt, Ot), or (Xt, Yt, Ot) without support examples."""
        Xt, Yt, Ot = self.dataset.createExamplesTensorData(query_e)
        if len(shot_e) > 0:
            Xe, Ye, Oe = self.dataset.createExamplesTensorData(shot_e)
            return Xe, Ye, Oe, Xt, Yt, Ot
        return Xt, Yt, Ot

    def _make_loader(self, load_function):
        """Wrap `load_function(i)` for i in range(epoch_size) in a DataLoader."""
        listdataset = ListDataset(elem_list=range(self.epoch_size), load=load_function)
        return DataLoader(listdataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, shuffle=self.shuffle)

    def get_iterator(self, epoch=0):
        """Return a DataLoader over `epoch_size` freshly sampled episodes.

        Reseeds Python's, NumPy's and torch's global RNGs with `epoch` first.
        """
        random.seed(epoch)
        np.random.seed(epoch)
        torch.manual_seed(epoch)
        torch.cuda.manual_seed_all(epoch)

        def load_function(iter_idx):
            # iter_idx is ignored: every call samples a new episode.
            categories = self.sampleCategories(self.kway)
            shot_e, query_e = self.sample_examples_for_categories(categories, self.kshot, self.kquery)
            return self._episode_tensors(shot_e, query_e)

        return self._make_loader(load_function)

    def create_alltasks(self, seed=0):
        """Deterministically sample `epoch_size` episodes and cache them to self.file.

        Returns:
            A list of (shot_e, query_e) pairs as produced by
            sample_examples_for_categories.
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        tasks = []
        for _ in range(self.epoch_size):
            categories = self.sampleCategories(self.kway)
            tasks.append(self.sample_examples_for_categories(categories, self.kshot, self.kquery))

        # Close the handle explicitly so the cache is fully flushed to disk.
        with open(self.file, 'wb') as f:
            pickle.dump(tasks, f)
        return tasks

    def get_fixed_iterator(self, epoch=0):
        """DataLoader over the cached episodes; `epoch` is accepted for API parity and ignored."""

        def load_function(iter_idx):
            shot_e, query_e = self.tasks[iter_idx]
            return self._episode_tensors(shot_e, query_e)

        return self._make_loader(load_function)

    def __call__(self, epoch=0):
        return self.get_iterator(epoch)

    def __len__(self):
        # NOTE: the DataLoader uses the default drop_last=False, so when
        # epoch_size is not a multiple of batch_size it yields one more
        # (smaller) batch than this value reports.
        return int(self.epoch_size / self.batch_size)


#### For task level augment
class ProtoData(data.Dataset):
    """Base dataset providing the sampling hooks used by FewShotDataloader.

    Subclasses are expected to set `labelIds` (a sequence of category ids)
    and `label2ind` (category id -> list of sample indices), and to implement
    `__getitem__` returning a tuple whose first element is an image tensor.
    """
    # Tag that FewShotDataloader puts into the fixed-task cache filename.
    taskaug = ''

    def sampleCategories(self, sample_size):
        """Return `sample_size` distinct category ids drawn without replacement from self.labelIds."""
        return random.sample(self.labelIds, sample_size)

    def sampleImageIdsFrom(self, cat_id, sample_size=1):
        """
        Samples `sample_size` number of unique image ids picked from the
        category `cat_id` (i.e., self.label2ind[cat_id]).

        Args:
            cat_id: a scalar with the id of the category from which images will
                be sampled.
            sample_size: number of images that will be sampled.

        Returns:
            image_ids: a list of length `sample_size` with unique image ids.
        """
        assert(cat_id in self.label2ind)
        assert(len(self.label2ind[cat_id]) >= sample_size)
        # Note: random.sample samples elements without replacement.
        return random.sample(self.label2ind[cat_id], sample_size)

    def createExamplesTensorData(self, examples):
        """
        Creates the examples image and label tensor data.

        Args:
            examples: a list of 3-element tuples (img_id, label, category_id),
                each representing a support or query example: the sample
                index, its label within the episode, and its original
                dataset category id.

        Returns:
            images: a tensor of shape [nExamples, *image_shape] stacking
                self[img_id][0] for every example (nExamples = len(examples)).
            labels: a LongTensor of shape [nExamples] with the episode labels.
            dc_labels: a LongTensor of shape [nExamples] with the original
                category ids.
        """
        images = torch.stack(
            [self[img_idx][0] for img_idx, _, _ in examples], dim=0)
        labels = torch.LongTensor([label for _, label, _ in examples])
        dc_labels = torch.LongTensor([label for _, _, label in examples])
        return images, labels, dc_labels


class Rotate90(object):
    """Randomly rotate a PIL image by 90, 180 or 270 degrees.

    The rotation probability ramps linearly from 0 up to `p` over the first
    `img_num_down` calls. The call counter is a `multiprocessing.Value`, so
    the ramp can be shared by processes that inherit this object.

    Args:
        p: final rotation probability; -1 selects the default of 3/4.
        img_num_down: number of calls over which the probability ramps up.
    """

    def __init__(self, p, img_num_down=8e4):
        # Starts at -1 so the first call sees a count of 0 (probability 0).
        self.img_num = multiprocessing.Value("d", -1.)
        self.img_num_down = img_num_down
        self.p = 3. / 4. if p == -1 else p

    def __call__(self, img):
        # `value += 1` is a read-modify-write and is not atomic across
        # processes; hold the Value's lock so concurrent workers don't lose
        # increments.
        with self.img_num.get_lock():
            self.img_num.value += 1.
            count = self.img_num.value
        p = self.p * min(1., count / self.img_num_down)
        if random.random() >= p:
            return img
        i = random.randint(0, 2)
        if i == 0:
            return img.transpose(Image.ROTATE_90)
        if i == 1:
            return img.transpose(Image.ROTATE_180)
        return img.transpose(Image.ROTATE_270)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'p=' + str(self.p) + ', ' \
               + 'img_num_down=' + str(self.img_num_down) + ')'


## FMix
def fftfreqnd(h, w=None, z=None):
    """ Return radial frequency magnitudes for a grid of size (h, w, z).

    :param h: Required, first dimension size
    :param w: Optional, second dimension size; only the first w // 2 + 1
        frequency bins are kept (w // 2 + 2 when w is odd)
    :param z: Optional, third dimension size
    :return: sqrt(fx**2 + fy**2 + fz**2), broadcast over the requested axes
    """
    freq_y = np.fft.fftfreq(h)
    freq_x = 0
    freq_z = 0

    if w is not None:
        freq_y = np.expand_dims(freq_y, -1)
        n_keep = w // 2 + 2 if w % 2 == 1 else w // 2 + 1
        freq_x = np.fft.fftfreq(w)[:n_keep]

    if z is not None:
        freq_y = np.expand_dims(freq_y, -1)
        # Both parities of z use the full frequency column.
        freq_z = np.fft.fftfreq(z)[:, None]

    squared = freq_x * freq_x + freq_y * freq_y + freq_z * freq_z
    return np.sqrt(squared)


def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
    """ Sample random Fourier coefficients whose amplitude decays as 1/f**decay_power.

    :param freqs: Bin values for the discrete fourier transform (e.g. from fftfreqnd)
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param ch: Number of channels for the resulting mask
    :param h: Required, first dimension size
    :param w: Optional, second dimension size
    :param z: Optional, third dimension size
    :return: array of shape (ch, *freqs.shape, 2) of standard-normal draws,
        each scaled by its bin's amplitude. The trailing axis holds two
        independent draws per bin (intended as real and imaginary parts).
    """
    # Frequencies are clipped from below at 1 / max(h, w, z) so the DC bin
    # does not divide by zero.
    lowest_freq = np.array([1. / max(w, h, z)])
    amplitude = np.ones(1) / (np.maximum(freqs, lowest_freq) ** decay_power)

    noise = np.random.randn(ch, *freqs.shape, 2)

    # Leading channel axis and trailing draw axis for broadcasting with `noise`.
    amplitude = amplitude[None, ..., None]
    return amplitude * noise


def make_low_freq_image(decay, shape, ch=1):
    """ Sample a low frequency image from fourier space, rescaled to [0, 1].

    :param decay: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param ch: Number of channels to sample (only the first is returned)
    :return: array of shape (1, *shape) with min 0 and (unless constant) max 1
    """
    freqs = fftfreqnd(*shape)
    spectrum = get_spectrum(freqs, decay, ch, *shape)
    # The trailing axis (size 2) of `spectrum` holds the real and imaginary
    # parts. The former `spectrum[:, 0] + 1j * spectrum[:, 1]` indexed axis 1
    # (the first frequency axis) instead. That kept only two frequency rows and
    # fed the size-2 axis to irfftn as a frequency axis.
    spectrum = spectrum[..., 0] + 1j * spectrum[..., 1]
    mask = np.real(np.fft.irfftn(spectrum, shape))

    # irfftn with s=shape already yields trailing dims equal to `shape`, so
    # these crops effectively just keep the first channel.
    if len(shape) == 1:
        mask = mask[:1, :shape[0]]
    if len(shape) == 2:
        mask = mask[:1, :shape[0], :shape[1]]
    if len(shape) == 3:
        mask = mask[:1, :shape[0], :shape[1], :shape[2]]

    mask = mask - mask.min()
    peak = mask.max()
    if peak > 0:
        # Guard against a constant field, where 0 / 0 would produce NaNs.
        mask = mask / peak
    return mask


def sample_lam(alpha, reformulate=False):
    """ Draw a mixing coefficient lambda from a Beta distribution.

    :param alpha: Alpha value for beta distribution
    :param reformulate: If True, draw from Beta(alpha + 1, alpha) instead of
        the symmetric Beta(alpha, alpha).
    :return: a scalar in [0, 1]
    """
    first_shape = alpha + 1 if reformulate else alpha
    return beta.rvs(first_shape, alpha)


def binarise_mask(mask, lam, in_shape, max_soft=0.0):
    """ Binarises a given low frequency image such that it has mean lambda.

    The highest-valued pixels are set to 1 and the rest to 0. The number of
    ones is ceil or floor of lam * mask.size, chosen with probability 1/2 via
    `random.random()`. With `max_soft` > 0, a band of pixels around the
    threshold gets a linear ramp from 1 to 0 instead of a hard edge; the mean
    is unchanged.

    :param mask: Low frequency image, usually the result of `make_low_freq_image`.
        It is not modified; a flattened copy is binarised.
    :param lam: Target mean of the final mask, in [0, 1]
    :param in_shape: Shape of inputs; mask.size must equal the product of in_shape
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
        Clipped to min(lam, 1 - lam).
    :return: array of shape (1, *in_shape)
    """
    # Copy: a plain reshape(-1) returns a view and would overwrite the caller's array.
    flat = np.array(mask).reshape(-1)
    idx = flat.argsort()[::-1]  # pixel indices, largest value first
    num = math.ceil(lam * flat.size) if random.random() > 0.5 else math.floor(lam * flat.size)

    eff_soft = max_soft
    if max_soft > lam or max_soft > (1 - lam):
        eff_soft = min(lam, 1 - lam)

    soft = int(flat.size * eff_soft)
    num_low = num - soft
    num_high = num + soft

    flat[idx[:num_high]] = 1
    flat[idx[num_low:]] = 0
    flat[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))

    return flat.reshape((1, *in_shape))


def sample_mask(lam, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Create a low frequency image and binarise it so that its mean is `lam`.

    Lambda is supplied by the caller; it is not sampled here.

    :param lam: Desired mean of the mask, in [0, 1]
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask: an int, or a sequence of up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: Unused; kept for signature compatibility.
    :return: (lam, mask) where mask has shape (1, *shape)
    """
    if isinstance(shape, int):
        shape = (shape,)

    mask = make_low_freq_image(decay_power, shape)
    mask = binarise_mask(mask, lam, shape, max_soft)

    return lam, mask


def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
    """ Mix each example of a batch with a randomly permuted partner using an FMix mask.

    NOTE(review): `alpha` is forwarded to `sample_mask` as the mask mean
    `lam`. It is NOT used as a Beta-distribution parameter, because
    `sample_mask` does not sample lambda. To get Beta-sampled mixing, draw
    `lam = sample_lam(alpha, reformulate)` first and pass that value here.

    :param x: Image batch on which to apply fmix, of shape [b, c, *shape]
    :param alpha: Used directly as the mean of the mask (the mixing proportion), in [0, 1]
    :param decay_power: Decay power for frequency decay prop 1/f**d
    :param shape: Shape of desired mask, list up to 3 dims
    :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
    :param reformulate: Forwarded to `sample_mask`, which ignores it.
    :return: (mixed input, permutation indices, lambda value of mix)
    """
    lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
    index = np.random.permutation(x.shape[0])

    x1, x2 = x * mask, x[index] * (1 - mask)
    return x1 + x2, index, lam

