import numpy as np
from skimage.color import rgb2hsv, hsv2rgb
import tensorflow as tf
from tensorflow.keras.utils import to_categorical


def _grey_to_saturated_hsv(grey_images):
    """Convert a stack of greyscale images to HSV with saturation forced to 1.

    Args:
        grey_images: array indexed as (image, row, column).

    Returns:
        A new HSV array with a trailing channel axis of size 3 whose
        saturation channel (index 1) is 1 everywhere.
    """
    # Replicate the single grey channel three times to obtain an RGB-shaped
    # array that ``rgb2hsv`` accepts.
    rgb_images = np.repeat(grey_images[..., np.newaxis], 3, axis=3)
    hsv_images = rgb2hsv(rgb_images)
    hsv_images[..., 1] = 1  # set saturation to 1
    return hsv_images


def _colorize_with_random_hues(hsv_images):
    """Give every HSV image one random hue and convert it to RGB.

    One hue per image is drawn from ``np.random.uniform(0, 1)`` (the global
    NumPy RNG), in image order, so the result depends on the RNG state at
    call time. ``hsv_images`` itself is not modified.

    Args:
        hsv_images: array of HSV images indexed as (image, row, column, channel).

    Returns:
        A ``(rgb_images, hues)`` pair: the recoloured images stacked into one
        array, and the drawn hues as a column vector (``reshape(-1, 1)``).
    """
    rgb_images = []
    hues = []
    for hsv_img in hsv_images:
        hue = np.random.uniform(0, 1)
        # Work on a copy so the caller's HSV array is not silently altered;
        # the hue channel is overwritten entirely, so the output is the same
        # as writing in place.
        tinted = hsv_img.copy()
        tinted[:, :, 0] = hue
        rgb_images.append(hsv2rgb(tinted))
        hues.append(hue)
    return np.array(rgb_images), np.array(hues).reshape(-1, 1)


def color_mnist_generator():
    """Build a randomly coloured version of MNIST.

    MNIST is loaded through ``tf.keras.datasets.mnist.load_data()``; each
    greyscale digit is converted to HSV with saturation 1, assigned a single
    random hue in [0, 1), and converted back to RGB. The hue drawn for each
    image is returned alongside it as a label.

    Side effects:
        Reseeds the global NumPy RNG: ``np.random.seed(1)`` before colouring
        the train set (the test set continues from the same random stream),
        then ``np.random.seed(112)`` before colouring the second test set.

    Returns:
        A 9-tuple, in this order:

        1. ``img_train`` - coloured training images.
        2. ``y_test`` - raw integer labels of the *test* split.
        3. ``num_labels_train`` - one-hot (10 classes) training labels.
        4. ``hue_labels_train`` - hue of each training image, one per row.
        5. ``img_test`` - coloured test images.
        6. ``num_labels_test`` - one-hot (10 classes) test labels.
        7. ``hue_labels_test`` - hue of each test image, one per row.
        8. ``img_test_cf`` - the same test digits recoloured with a
           different seed (presumably the "counterfactual" set).
        9. ``hue_labels_test_cf`` - hues used for ``img_test_cf``.
    """
    (x_train_grey, y_train), (x_test_grey, y_test) = tf.keras.datasets.mnist.load_data()

    num_labels_train = to_categorical(y_train, num_classes=10)
    num_labels_test = to_categorical(y_test, num_classes=10)

    hsv_mnist_train = _grey_to_saturated_hsv(x_train_grey)
    hsv_mnist_test = _grey_to_saturated_hsv(x_test_grey)

    # Train and test hues come from one seeded stream: train consumes the
    # first draws, test the following ones. The call order must not change.
    np.random.seed(1)
    img_train, hue_labels_train = _colorize_with_random_hues(hsv_mnist_train)
    img_test, hue_labels_test = _colorize_with_random_hues(hsv_mnist_test)

    # Second colouring of the same test digits under a different seed.
    np.random.seed(112)
    img_test_cf, hue_labels_test_cf = _colorize_with_random_hues(hsv_mnist_test)

    # NOTE(review): the second element is y_test (raw test labels) although it
    # sits next to the training images, and y_train is never returned. Kept
    # unchanged for backward compatibility - confirm whether this is intended.
    return (
        img_train,
        y_test,
        num_labels_train,
        hue_labels_train,
        img_test,
        num_labels_test,
        hue_labels_test,
        img_test_cf,
        hue_labels_test_cf,
    )