import random
import numpy as np

def getExample(X, Y, ids, width):
    """Assemble one ``width``-digit number from the images indexed by ``ids``.

    The first ``width`` entries of ``ids`` are read left to right as the
    digits (most significant first) of a base-10 number, using the labels in
    ``Y``. If the leading digit is 0 and some later digit is non-zero, the
    first non-zero image is swapped into the leading position so the number
    does not start with a 0.

    Args:
        X: indexable collection of images.
        Y: indexable collection of integer digit labels, parallel to ``X``.
        ids: mutable sequence of indices into ``X``/``Y``. It is modified in
            place by the swap described above.
        width: number of digits to read from ``ids``.

    Returns:
        ``(images, number, ids)``: the ``width`` selected images in digit
        order, the number they spell as a Python ``int``, and ``ids`` itself
        (the same, possibly reordered, object).
    """
    # int() avoids overflow for narrow label dtypes such as uint8, where numpy
    # would otherwise compute e.g. uint8(9) * 100 in uint8 and wrap/raise.
    labels = [int(Y[ids[i]]) for i in range(width)]

    if labels and labels[0] == 0:
        first_nonzero = next(
            (i for i, label in enumerate(labels) if label != 0), None
        )
        if first_nonzero is not None:
            # Swap in place: callers read the reordered ids back.
            ids[0], ids[first_nonzero] = ids[first_nonzero], ids[0]
            labels[0], labels[first_nonzero] = labels[first_nonzero], labels[0]

    images = [X[ids[i]] for i in range(width)]
    number = 0
    for label in labels:
        number = number * 10 + label
    return images, number, ids


def sample_extras(X, Y, num_sample, width):
    """Randomly build ``num_sample`` two-number examples.

    Each example draws two independent sets of ``width`` distinct image
    indices, turns each set into a number with ``getExample``, and adds the
    two numbers.

    Returns:
        A list of ``[ids, total, images]`` entries, where ``ids`` is the
        concatenation of both (possibly reordered) index lists, ``total`` is
        the sum of the two numbers, and ``images`` is the concatenation of
        both image lists. Returns an empty list when ``num_sample`` is not
        positive.

    NOTE(review): this entry layout differs from the
    ``[images, total, rows]`` layout produced by ``generateExampleTuple``.
    """
    samples = []
    if not num_sample > 0:
        return samples

    remaining = num_sample
    while remaining != 0:
        pool = range(len(X))
        first = random.sample(pool, width)
        second = random.sample(pool, width)
        first_imgs, first_num, first = getExample(X, Y, first, width)
        second_imgs, second_num, second = getExample(X, Y, second, width)
        samples.append(
            [first + second, first_num + second_num, first_imgs + second_imgs]
        )
        remaining -= 1
    return samples


def generateExampleTuple(X, Y, id_list, width, height):
    """Turn ``height * width`` image ids into ``height`` numbers and their total.

    ``id_list`` is split into ``height`` consecutive rows of ``width`` ids.
    Each row is passed to ``getExample`` to form one number.

    Returns:
        ``[images, total, rows]``:
        - ``images`` holds one list of images per row.
        - ``total`` is the sum of the row numbers.
        - ``rows`` holds each row's ids as returned (possibly reordered) by
          ``getExample``.
    """
    row_images = []
    row_ids = []
    total = 0
    for row in range(height):
        start = row * width
        imgs, number, ids = getExample(X, Y, id_list[start:start + width], width)
        row_images.append(imgs)
        total += number
        row_ids.append(ids)
    return [row_images, total, row_ids]


def _chunk_ids(ids, chunk_size):
    """Split ``ids`` into consecutive chunks of exactly ``chunk_size`` ids.

    A trailing partial chunk is dropped.

    Raises:
        ValueError: if ``chunk_size`` is not positive (chunking could never
            advance in that case).
    """
    if chunk_size <= 0:
        raise ValueError("height * width must be positive")
    n_chunks = len(ids) // chunk_size
    return [ids[k * chunk_size:(k + 1) * chunk_size] for k in range(n_chunks)]


def _sample_balanced_ids(Y, num_ids, num_images):
    """Draw exactly ``num_ids`` image ids, spread evenly over digit labels 0-9.

    Each digit class present in ``Y`` contributes ``num_ids // 10`` ids, drawn
    with replacement. The remainder, plus the share of any digit that does not
    occur in ``Y``, is drawn uniformly (with replacement) from all
    ``num_images`` ids. The result is shuffled.

    Raises:
        ValueError: if ids are needed but there are no images to draw from.
    """
    per_digit = num_ids // 10
    sampled = []
    if per_digit > 0:
        for digit in range(10):
            members = [i for i, label in enumerate(Y) if label == digit]
            if members:  # np.random.choice rejects an empty population
                sampled += np.random.choice(
                    members, size=per_digit, replace=True
                ).tolist()

    shortfall = num_ids - len(sampled)
    if shortfall > 0:
        if num_images <= 0:
            raise ValueError("cannot sample extra examples from an empty dataset")
        sampled += np.random.choice(
            num_images, size=shortfall, replace=True
        ).tolist()

    random.shuffle(sampled)
    return sampled


def generateDataset(X, Y, num_examples, height, width):
    """Build ``num_examples`` examples of ``height`` numbers with ``width`` digits.

    First, every image index is used at most once: the indices are shuffled
    and split into groups of ``height * width``, each turned into an example
    by ``generateExampleTuple``. If that yields more than ``num_examples``
    examples, the list is truncated. If it yields fewer, the missing examples
    are built from extra ids drawn with replacement, roughly balanced across
    the digit labels 0-9.

    Args:
        X: indexable collection of images.
        Y: digit labels parallel to ``X``.
        num_examples: number of examples to produce.
        height: numbers per example.
        width: digits per number.

    Returns:
        ``(dataset, unique_ids)``:
        - ``dataset`` is a list of ``[images, total, rows]`` entries as
          returned by ``generateExampleTuple``.
        - ``unique_ids`` lists every image id used by at least one example,
          in arbitrary order.

    Raises:
        ValueError: if ``height * width`` is not positive.
    """
    example_size = height * width

    ids = list(range(len(X)))
    random.shuffle(ids)
    dataset = [
        generateExampleTuple(X, Y, chunk, width, height)
        for chunk in _chunk_ids(ids, example_size)
    ]

    extra_examples = 0
    if len(dataset) > num_examples:
        dataset = dataset[:num_examples]
    else:
        extra_examples = num_examples - len(dataset)

    if extra_examples > 0:
        extra_ids = _sample_balanced_ids(Y, extra_examples * example_size, len(X))
        # Extra examples go through generateExampleTuple, so every entry has the
        # same [images, total, rows] layout that the unique-id pass relies on.
        for chunk in _chunk_ids(extra_ids, example_size):
            dataset.append(generateExampleTuple(X, Y, chunk, width, height))

    unique_dataset_ids = {
        img_id for entry in dataset for row in entry[2] for img_id in row
    }
    return dataset, list(unique_dataset_ids)
