import random
import heapq

import numpy

import scipy.stats
import sklearn.mixture
from matplotlib import pyplot

import tqdm


def main(seed, components, peaks, samples, acquisition_size, shuffle_repeat, min_peak, max_peak, bins, screen_offset):
    """Fit a GMM to synthetic 1-D data and plot several batch-acquisition heuristics.

    Draws ``peaks`` clusters of ``samples`` unit-variance normal points, each
    cluster shifted by an offset drawn uniformly from ``[min_peak, max_peak)``.
    Fits a ``components``-component Gaussian mixture to them, then writes three
    figures to the working directory:

    * ``gmm.png``: the samples, their histogram and the weighted GMM components.
    * ``heatmap.png``: batches of ``acquisition_size`` points chosen by
      "most probable", random and modified batch-BALD selection. Each batch's
      per-component posteriors are drawn as a heat map underneath.
    * ``distrib-batchbald.png``: histogram of the modified batch-BALD score over
      every batch produced by ``shuffle_repeat`` shuffled passes over the data.

    Args:
        seed: Seed applied to both ``numpy.random`` and ``random``.
        components: Number of GMM components.
        peaks: Number of synthetic clusters.
        samples: Points drawn per cluster.
        acquisition_size: Number of points in one acquired batch.
        shuffle_repeat: Number of reshuffled passes used when searching batches.
        min_peak: Lower bound of the cluster offsets.
        max_peak: Upper bound (exclusive) of the cluster offsets.
        bins: Number of histogram bins.
        screen_offset: Margin added on both sides of the data range on the x axis.
    """
    numpy.random.seed(seed)
    random.seed(seed)

    data = numpy.random.randn(samples * peaks, 1)
    for i in range(peaks):
        data[i * samples: (i+1) * samples] += numpy.random.rand() * (max_peak-min_peak) + min_peak

    gmm = sklearn.mixture.GaussianMixture(n_components=components)
    gmm.fit(data)
    print(f"Covariance shape: {gmm.covariances_.shape}")

    xs = numpy.linspace(data.min()-screen_offset, data.max()+screen_offset, 1000)
    # The data is 1-D, so each component's covariance holds a single variance.
    # With the default "full" covariance type the shape is (components, 1, 1).
    # scipy's ``scale`` is the standard deviation, hence the square root.
    stds = numpy.sqrt(gmm.covariances_.reshape(components))
    component_ys = [
        scipy.stats.norm.pdf(xs, gmm.means_[i, 0], stds[i]) * gmm.weights_[i]
        for i in range(components)
    ]
    # Height of the "Acquired samples" markers drawn on the GMM plots.
    max_y = max(ys.max() for ys in component_ys)

    # Plot 1:
    #   - samples
    #   - histogram
    #   - gmm overfit

    def plot_data(plot):
        plot.scatter(data, numpy.zeros(data.shape), alpha=0.2, label="Samples")

    def plot_hist(plot):
        plot_data(plot)
        plot.hist(data, bins=bins, alpha=0.3, density=True, label="Sample distribution")

    def plot_gmm(plot):
        plot_hist(plot)
        for ys in component_ys:
            plot.plot(xs, ys)

    fig, axes = pyplot.subplots(ncols=3, nrows=1, sharex=True, sharey=True, figsize=(15, 5))

    plot_data(axes[0])
    plot_hist(axes[1])
    plot_gmm(axes[2])

    axes[0].set_ylabel("Probability density")
    axes[1].set_xlabel("Sample value")

    axes[0].set_title(f"Collected {samples} samples per {peaks} random means")
    axes[1].set_title(f"Histogram of samples ({bins} bins)")
    axes[2].set_title(f"GMM with {components} components")

    axes[2].legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    fig.savefig("gmm.png", bbox_inches="tight")
    pyplot.close(fig)

    # Plot 2:
    # [ A | B | C ]
    # [ D | E | F ]
    # A, B, C are gmm plots with select points chosen.
    # D, E, F are their probability vectors, represented as heat maps.

    fig, axes = pyplot.subplots(nrows=2, ncols=3, sharex="row", sharey="row", figsize=(15, 10))

    # Component order by mean, so heat-map rows run from left-most to right-most component.
    mean_i = numpy.argsort(gmm.means_.reshape(-1))

    def plot_sample_scores(plot, acquisition_fn):
        """Draw one column: acquired points on the GMM (top) and their posteriors (bottom)."""
        plot_gmm(plot[0])
        acquired, probs = acquisition_fn()
        mbbald, _ = score_modified_batchbald(acquired)

        acquired = acquired.reshape(-1)
        sorted_i = numpy.argsort(acquired)
        acquired = acquired[sorted_i]
        probs = probs[sorted_i]

        plot[0].vlines(acquired, 0, max_y, linestyle="dotted", label="Acquired samples")
        plot[1].imshow(probs.T[mean_i], cmap="hot", interpolation="bicubic", aspect="auto")

        plot[1].set_title(f"{mbbald:.2f} M-bBALD")

    def acquire_randomly():
        """Pick ``acquisition_size`` points uniformly at random (with replacement)."""
        batch = numpy.concatenate([random.choice(data) for _ in range(acquisition_size)]).reshape((-1, 1))
        return batch, gmm.predict_proba(batch)

    def iter_shuffled_batches(desc):
        """Yield consecutive, non-overlapping batches from ``shuffle_repeat`` reshuffles of the data."""
        batches = len(data)//acquisition_size
        shuffle = data.copy()
        for _ in tqdm.tqdm(range(shuffle_repeat), desc=desc, ncols=80):
            # numpy.random.shuffle permutes whole rows. random.shuffle swaps
            # row *views* of a 2-D array and silently duplicates rows.
            numpy.random.shuffle(shuffle)
            for i in range(batches):
                yield shuffle[i*acquisition_size: (i+1)*acquisition_size].copy()

    def acquire_best_score(score_fn, desc):
        """Return the shuffled batch (and its posteriors) with the highest ``score_fn`` score."""
        max_score = -float("inf")
        best_batch = None
        best_probs = None
        for batch in iter_shuffled_batches(desc):
            score, probs = score_fn(batch)
            if score > max_score:
                max_score = score
                best_batch = batch
                best_probs = probs

        assert best_batch is not None
        return best_batch, best_probs

    def greedy_acquire_best_score(score_fn, desc):
        """Repeatedly keep the highest-scoring half of candidate batches; return the best final batch."""
        shuffle = data.copy()
        best_batch = best_probs = None

        while len(shuffle) >= 2*acquisition_size:
            # The pool size does not change within one round, so batches is loop-invariant.
            batches = len(shuffle) // acquisition_size
            candidates = []
            for j in tqdm.tqdm(range(shuffle_repeat), desc=desc, ncols=80):
                numpy.random.shuffle(shuffle)
                for i in range(batches):
                    batch = shuffle[i*acquisition_size: (i+1)*acquisition_size].copy()

                    score, probs = score_fn(batch)
                    # (j, i) is unique per candidate, so score ties never fall
                    # through to comparing numpy arrays.
                    candidates.append((score, j, i, batch, probs))

            if not candidates:
                break

            # Higher scores are better (as in acquire_best_score), so keep the top half.
            selected = heapq.nlargest(batches//2, candidates)
            _, _, _, best_batch, best_probs = selected[0]
            shuffle = numpy.concatenate([b for _, _, _, b, _ in selected], axis=0)

        assert best_batch is not None
        return best_batch, best_probs

    # Keeps log() finite when a posterior is exactly zero.
    eps = 1e-32

    def score_most_probable(batch):
        """Score a batch by its total log-likelihood under the fitted GMM."""
        probs = gmm.predict_proba(batch)
        assert probs.shape == (len(batch), components)
        # predict_proba returns posterior responsibilities, which already include
        # the mixture weights. The batch log-likelihood is the sum of log p(x).
        log_p = gmm.score_samples(batch).sum()
        return log_p, probs

    def acquire_most_probable():
        return acquire_best_score(score_most_probable, "Find most probable batch")

    def score_modified_batchbald(batch):
        """Sum over components of log(max posterior any batch point assigns to it)."""
        probs = gmm.predict_proba(batch)
        score = numpy.log(probs.max(axis=0) + eps).sum()
        return score, probs

    def acquire_with_modified_batchbald_score():
        return acquire_best_score(score_modified_batchbald, "Find highest modified batch-BALD score")

    plot_sample_scores(axes[:, 0], acquire_most_probable)
    plot_sample_scores(axes[:, 1], acquire_randomly)
    plot_sample_scores(axes[:, 2], acquire_with_modified_batchbald_score)

    axes[0, 0].set_title("Most probable")
    axes[0, 1].set_title("Random")
    axes[0, 2].set_title("Modified batch-BALD")

    axes[1, 0].set_ylabel("Component (probability)")
    axes[1, 1].set_xlabel("Batch sample")

    fig.savefig("heatmap.png", bbox_inches="tight")
    pyplot.close(fig)

    # Plot 3:
    # Prob vs Modified Batch-BALD

    modified_batchbald_scores = [
        score_modified_batchbald(batch)[0]
        for batch in iter_shuffled_batches("Collecting modified batch-BALD samples")
    ]

    fig, ax = pyplot.subplots()
    ax.hist(modified_batchbald_scores, bins=bins, alpha=0.3, density=True)
    ax.set_title("Distribution of modified batch-BALD scores")
    ax.set_ylabel("Probability density")
    ax.set_xlabel("Modified batch-BALD score (higher is more entropy)")

    fig.savefig("distrib-batchbald.png", bbox_inches="tight")
    pyplot.close(fig)



if __name__ == "__main__":

    import argparse

    # One (flag, type, default) entry per keyword argument of main(), in the
    # order they appear in --help.
    cli_options = (
        ("--seed", int, 1337),
        ("--components", int, 32),
        ("--acquisition_size", int, 8),
        ("--shuffle_repeat", int, 5),
        ("--peaks", int, 16),
        ("--samples", int, 100),
        ("--min_peak", float, -20),
        ("--max_peak", float, 20),
        ("--screen_offset", float, 2),
        ("--bins", int, 50),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, option_type, default_value in cli_options:
        arg_parser.add_argument(flag, type=option_type, default=default_value)

    cli_args = arg_parser.parse_args()
    main(**vars(cli_args))