
import numpy as np
import matplotlib.pyplot as plt


def quantize_continuous(bins, value):
    """One-hot encode a continuous value against a set of bin edges.

    The edges in ``bins`` split the real line into ``len(bins) + 1`` classes;
    the class index is whatever ``np.searchsorted(bins, value)`` returns
    (default ``side='left'``, so a value equal to an edge falls into the
    class below that edge).

    Args:
        bins: sequence of bin edges (presumably sorted ascending, as
            ``np.searchsorted`` expects -- not checked here).
        value: the continuous value to quantize.

    Returns:
        A tuple ``(target, n_classes)`` where ``target`` is a float32 vector
        of length ``n_classes`` holding 1.0 at the selected class and 0.0
        elsewhere.

    Note:
        Prints the class count and the value -> class mapping to stdout.
    """
    # one class per gap between edges, plus the two open-ended outer classes
    n_classes = 1 + len(bins)
    print("n_classes", n_classes)

    # locate the class the value falls into
    class_index = np.searchsorted(bins, value)
    print("continuous value", np.round(value, 2), "->", class_index)

    # build the one-hot target vector
    one_hot = np.zeros(n_classes, dtype=np.float32)
    one_hot[class_index] = 1.0

    return one_hot, n_classes


def main():
    """Demo: quantize a random value, project it back, and plot the result.

    Draws one value uniformly from [-50, 50), one-hot encodes it with
    ``quantize_continuous``, maps the one-hot vector back to a representative
    continuous value, and shows both in a two-panel matplotlib figure.
    """
    # discretize action space (bin edges)
    bins = np.array([-20, -10, -5, -1, 1, 5, 10, 20], dtype=np.float32)

    # sample random action
    value = np.random.uniform(-50, 50, 1)[0]

    # discretize action
    target, n_classes = quantize_continuous(bins, value)

    # project back to continuous space: each interior class maps to the
    # midpoint of its two edges; the two open-ended outer classes have no
    # midpoint, so they map to the outermost edge itself
    midpoints = (bins[:-1] + bins[1:]) / 2
    bin_centers = np.concatenate(([bins[0]], midpoints, [bins[-1]]))
    # NOTE: for a one-hot target this equals bin_centers[np.argmax(target)]
    proj_value = np.sum(bin_centers * target)

    plt.figure("Target Vector")
    plt.clf()

    # top panel: bin edges (dashed black), sampled value (dashed magenta)
    # and its projection back to continuous space (solid magenta)
    plt.subplot(2, 1, 1)
    for edge in bins:  # not named `bin`, which would shadow the builtin
        plt.plot([edge, edge], [0, 1], 'k--', alpha=0.5)
    plt.plot([value, value], [0, 1], 'm--')
    plt.plot([proj_value, proj_value], [0, 1], 'm-')
    plt.grid()
    plt.xlim([-50, 50])
    plt.ylim([0, 1.1])

    # bottom panel: the one-hot target vector, one bar per class
    plt.subplot(2, 1, 2)
    plt.bar(range(n_classes), target)
    plt.grid()
    plt.xlim([-0.5, n_classes])
    plt.ylim([0, 1.1])

    plt.show()


if __name__ == "__main__":
    main()
