
import numpy as np
import matplotlib.pyplot as plt


def continuous_to_softmax(bins, value):
    """Encode a continuous scalar as a soft target over discrete classes.

    The sorted edges in ``bins`` define ``len(bins) + 1`` classes: one below
    the first edge, one per interval between consecutive edges, and one above
    the last edge. Every class gets a representative center: the interval
    midpoint for the inner classes, and the first / last edge itself for the
    two outer classes.

    Values at or below the first edge, or above the last edge, are encoded as
    a one-hot vector on the respective outer class. Any other value is encoded
    by linear interpolation between the two centers that enclose it, so that
    the expected value of the returned distribution over ``bin_centers``
    equals ``value``.

    Args:
        bins: Non-empty, ascending 1-D sequence of bin edges.
        value: Scalar to encode (Python float or numpy scalar).

    Returns:
        A tuple ``(target, n_classes, bin_centers)`` where ``target`` is a
        float32 array of shape ``(n_classes,)`` summing to 1, ``n_classes`` is
        ``len(bins) + 1`` and ``bin_centers`` is a list with one center per
        class.
    """
    n_classes = len(bins) + 1
    inner_centers = [np.mean([lo, hi]) for lo, hi in zip(bins[:-1], bins[1:])]
    bin_centers = [bins[0]] + inner_centers + [bins[-1]]
    print("n_classes", n_classes)

    # compute target class
    idx = np.searchsorted(bins, value)
    print("continuous value", np.round(value, 2), "->", idx)

    # compute target vector
    target = np.zeros((n_classes,), dtype=np.float32)

    if idx == 0 or idx == (n_classes - 1):
        # outside the covered range: all mass on the outer class
        target[idx] = 1
    else:
        # Work on an ndarray so the arithmetic also accepts a plain Python
        # float (a list cannot be subtracted from a float).
        centers = np.asarray(bin_centers, dtype=np.float64)

        # Pick the two centers that enclose the value rather than the two
        # nearest ones: the outer classes are only half as wide, so near them
        # the two nearest centers may lie on the same side of the value.
        # In this branch bins[0] < value <= bins[-1], hence
        # centers[upper - 1] < value <= centers[upper] and the span is > 0.
        upper = int(np.searchsorted(centers, value))
        lower = upper - 1
        span = centers[upper] - centers[lower]
        target[lower] = (centers[upper] - value) / span
        target[upper] = (value - centers[lower]) / span

    return target, n_classes, bin_centers


if __name__ == "__main__":
    # Interactive demo: repeatedly sample a random value, encode it as a soft
    # target distribution, project it back to a scalar and plot both.
    # plt.show() is called once per sample at the end of each iteration.

    # discretize action space
    bins = np.array([-22.5, -17.5, -12.5, -7.5, -2.5, 2.5, 7.5, 12.5, 17.5, 22.5], dtype=np.float32)

    while True:

        # sample random action
        value = np.random.uniform(-30, 30, 1)[0]

        # discretize action
        target, n_classes, bin_centers = continuous_to_softmax(bins, value)

        # project back to continuous space: probability-weighted sum of the
        # centers of the most likely class and its direct neighbours
        mode_idx = np.argmax(target)
        indices = [mode_idx]
        if mode_idx > 0:
            indices.insert(0, mode_idx - 1)
        # the last valid class index is n_classes - 1, so every mode below it
        # has a right neighbour that must be included
        if mode_idx < (n_classes - 1):
            indices.append(mode_idx + 1)
        print(indices, [bin_centers[i] for i in indices])
        proj_value = np.sum([bin_centers[i] * target[i] for i in indices])

        plt.figure("Softmax - Quantization")
        plt.clf()

        # top panel: bin edges, sampled value and its re-projection
        plt.subplot(2, 1, 1)
        for edge in bins:
            plt.plot([edge, edge], [0, 1], 'k--', alpha=0.5)
        plt.plot([value, value], [0, 1], 'm--', label="target value")
        plt.plot([proj_value, proj_value], [0, 1], 'm-', alpha=0.25, linewidth=3,
                 label="re-projected value (expected value of discrete distribution)")
        plt.grid()
        plt.xlim([-30, 30])
        plt.ylim([0, 1.1])
        plt.legend()
        plt.title("Bins and Target")

        # bottom panel: class centers and the soft target placed on them
        plt.subplot(2, 1, 2)
        for center in bin_centers:
            plt.plot([center, center], [0, 1], 'k--', alpha=0.5)
        plt.bar(bin_centers, target)
        plt.grid()
        plt.xlim([-30, 30])
        plt.title("Soft Target Distribution")

        plt.show()
