import numpy as np
import math
import random
import os

from PIL import Image

# Default number of (row, col) locations that sample_pixel returns.
subsample_pixel_count = 1_000


def num_class_estimator(dict_val):
    """Return how many classes in ``dict_val`` are still active.

    Args:
        dict_val: Mapping whose values are dicts that carry a ``"done"`` flag.

    Returns:
        The number of values whose ``"done"`` entry is falsy.
    """
    return sum(1 for entry in dict_val.values() if not entry["done"])


def _group_pixel_locations(image_pix):
    """Map each distinct pixel value to its list of ``(row, col)`` locations.

    Both the dict (by first occurrence) and each location list follow a
    row-major scan. The sampler relies on that order to be reproducible.
    """
    image = np.squeeze(np.asarray(image_pix))
    # NOTE(review): squeeze drops every size-1 axis, so a single-row or
    # single-column image collapses below 2-D and ``shape[1]`` raises here.
    height, width = image.shape[0], image.shape[1]
    groups = {}
    for i in range(height):
        for j in range(width):
            groups.setdefault(image[i, j], []).append((i, j))
    return groups


def sample_pixel(image_pix, sample_count=None, seed=30):
    """Pick pixel locations balanced across the distinct pixel values.

    Each distinct pixel value is treated as a class. Sampling runs in rounds.
    In each round, every class that still has unpicked pixels contributes up
    to ``threshold`` randomly chosen locations. The first threshold is
    ``ceil(sample_count / num_classes)``. A class with no more pixels than
    the threshold gives up all of them, and its unused quota ("excess") is
    spread evenly over the classes that remain in the next round. Sampling
    stops when more than ``sample_count`` locations have been collected (the
    list is then cut to ``sample_count``), or when no class remains or no
    excess is left.

    Args:
        image_pix: Anything ``np.asarray`` accepts, such as a PIL image or an
            array. Size-1 axes are squeezed away. The result is indexed as
            ``[row, col]``, and each element must be hashable.
        sample_count: Target number of locations. Defaults to the module-level
            ``subsample_pixel_count``, which is read at call time.
        seed: Seed for the fresh ``random.Random`` created before every draw.
            The global ``random`` module state is not modified.

    Returns:
        A list of ``(row, col)`` tuples with at most ``sample_count`` entries.
        If the image has fewer pixels than that, all of them are returned.
        An image with no pixels gives ``[]``.
    """
    if sample_count is None:
        sample_count = subsample_pixel_count

    # Locations not yet picked, one list per class, in first-seen order.
    pools = list(_group_pixel_locations(image_pix).values())
    active = len(pools)
    if active == 0:
        # Empty image: nothing to sample (formerly a ZeroDivisionError).
        return []

    pixel_locations = []
    threshold_limit = math.ceil(sample_count / active)

    while threshold_limit > 0:
        # Quota offered this round but left unused by classes that ran out.
        excess_count = 0
        for idx, pool in enumerate(pools):
            if not pool:
                continue  # class was exhausted in an earlier round
            # A fresh generator per draw reproduces the historical
            # "random.seed(30) before every sample" results without
            # reseeding the process-wide random module.
            rng = random.Random(seed)
            if len(pool) <= threshold_limit:
                excess_count += threshold_limit - len(pool)
                picked = rng.sample(pool, len(pool))
                pools[idx] = []
            else:
                picked = rng.sample(pool, threshold_limit)
                # One order-preserving pass instead of an O(n) list.remove
                # per picked location (locations are unique).
                chosen = set(picked)
                pools[idx] = [loc for loc in pool if loc not in chosen]
            pixel_locations.extend(picked)

        if len(pixel_locations) > sample_count:
            return pixel_locations[:sample_count]

        active = sum(1 for pool in pools if pool)
        if active == 0 or excess_count == 0:
            break
        # Spread the unused quota evenly over classes that still have pixels.
        threshold_limit = math.ceil(excess_count / active)

    return pixel_locations

