import os.path

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.collections import LineCollection


def plot_state(samples, lik, target, outcomes, child_width, plot_outcomes=False, *, show_mean=True):
    """Plot the weighted sample cloud, the target and a "child" disc of radius `child_width`.

    Draws, on a new figure titled with ``len(outcomes)``:
      * every sample, coloured by its weight in `lik`;
      * the target as a black dot overlaid with a yellow cross;
      * a red circle of radius `child_width` around either the weighted mean
        (``show_mean=True``, the default and previous behaviour) or the
        highest-weight sample, annotated with the summed `lik` of the samples
        strictly inside it (a probability mass when `lik` sums to 1);
      * optionally every ``(A, B)`` pair in `outcomes` (green dot / pink cross);
      * a line segment through the weighted mean along the principal axis of
        the weighted sample covariance.
    The covariance matrix is printed, then ``plt.show()`` is called.

    Args:
        samples: sample coordinates (torch tensor or array-like); assumed shape
            (N, 2) -- TODO confirm; only columns 0 and 1 are plotted.
        lik: per-sample weights, shape (N,) or (N, 1).
        target: target point, read as ``target[0, 0]``, ``target[0, 1]``.
        outcomes: sequence of ``(A, B)`` pairs, each point read as
            ``X[0, 0]``, ``X[0, 1]``; only its length is used unless
            `plot_outcomes` is True.
        child_width: radius of the child disc, in the same units as `samples`.
        plot_outcomes: also plot the points in `outcomes`.
        show_mean: centre the child disc on the weighted mean (True) or on the
            highest-weight sample (False).
    """
    samples = torch.as_tensor(samples)
    lik = torch.as_tensor(lik)
    weights = lik.flatten()

    # Divide by the total weight so this is a true weighted mean even when
    # `lik` does not sum to 1 (torch.cov below normalises aweights the same way).
    mean = (samples * weights.reshape((-1, 1))).sum(axis=0) / weights.sum()

    plt.figure()
    plt.title(f"{len(outcomes)}")
    # the weighted sample distribution
    plt.scatter(samples[:, 0], samples[:, 1], c=lik)

    # the target
    plt.scatter(target[0, 0], target[0, 1], c="black")
    plt.scatter(target[0, 0], target[0, 1], c="yellow", marker="x")

    # the child disc, centred on the mean or on the mode
    if show_mean:
        center = mean
    else:
        center = samples[torch.argmax(weights)]
        plt.scatter(center[0], center[1], c="r")
    inside = torch.linalg.norm(samples - center, dim=-1) < child_width
    mass = weights[inside].sum()
    circle = plt.Circle(center, child_width, fill=False, linewidth=3, color="r")
    plt.text(center[0] + 0.1, center[1], f"{mass:1.3}", fontdict={"fontsize": 24}, c="r")
    plt.gca().add_patch(circle)

    if plot_outcomes:
        # each outcome pair: A as a green dot, B as a pink cross
        for a, b in outcomes:
            plt.scatter(a[0, 0], a[0, 1], c="green")
            plt.scatter(b[0, 0], b[0, 1], c="pink", marker="x")

    # principal axis of the weighted covariance, drawn through the mean
    cov = torch.cov(samples.T, aweights=weights)
    print(cov)
    eig_vals, eig_vecs = torch.linalg.eigh(cov)
    # eigh returns eigenvalues in ascending order with the matching
    # eigenvectors as COLUMNS, so the principal axis is the last column
    # (taking the last row gives an unrelated direction).
    max_eigvec = eig_vecs[:, -1]
    max_eigval = eig_vals[-1]
    # The segment's half-length is the largest eigenvalue (a variance,
    # not a standard deviation).
    half_axis = max_eigvec * max_eigval
    segment = [np.asarray(mean + half_axis), np.asarray(mean - half_axis)]
    plt.gca().add_collection(LineCollection([segment]))
    plt.scatter(mean[0], mean[1], c="r")

    plt.show()

# ## plotting


def create_plots(all_runs, config, plotdir=None):
    """Plot convergence curves for a batch of runs and save them as JPEG files.

    Args:
        all_runs: list of runs. Each run is a sequence with one value per stage
            of the algorithm (currently the distance to the target). All runs
            are truncated to the length of the shortest one before stacking,
            and the values are plotted on a natural-log scale.
        config: mapping of search parameters. It must contain every key in
            `keys` below; only those keys are kept, and the dict's ``str()``
            is used as the file-name prefix.
        plotdir: output directory, default ``"../plots"``. It is created,
            including missing parents, if it does not exist.

    Two figures are saved into `plotdir` (and shown):
        ``<config>_fillbetween.jpg``: mean log-distance with a +/- 1 std band.
        ``<config>_individualruns.jpg``: mean log-distance plus up to 10
        individual runs.

    Raises:
        ValueError: if `all_runs` is empty.
        KeyError: if `config` lacks one of the required keys.
    """
    if plotdir is None:
        plotdir = "../plots"

    # makedirs also creates missing parents; exist_ok avoids the race between
    # an existence check and the directory creation.
    os.makedirs(plotdir, exist_ok=True)

    keys = [
        "gamma",
        "num_dims",
        "proceed_factor",
        "threshold_backtrack",
        "threshold_proceed",
        "num_samples"]

    config = {key: config[key] for key in keys}
    # str(dict) is kept verbatim as the file-name prefix so output names are unchanged.
    prefix = os.path.join(plotdir, str(config))

    if not all_runs:
        raise ValueError("all_runs must contain at least one run")

    # Stack the runs, truncated to a common length, then move to log space.
    minlen = min(len(run) for run in all_runs)
    vals = np.zeros((len(all_runs), minlen))
    for idx, run in enumerate(all_runs):
        vals[idx] = run[:minlen]
    vals = np.log(vals)

    mean_v = vals.mean(axis=0)
    std_v = vals.std(axis=0)
    steps = np.arange(minlen)

    # Figure 1: mean curve with a +/- 1 std band.
    fig, ax = plt.subplots(figsize=(5, 4), dpi=600)
    ax.fill_between(steps, mean_v + std_v, mean_v - std_v, alpha=0.3, label="+/- std")
    ax.plot(steps, mean_v, label='mean distance')
    ax.set_xlabel("number of queries")
    ax.set_ylabel("log(distance from target)")
    ax.set_title("Convergence rate")
    ax.legend()
    fig.savefig(f"{prefix}_fillbetween.jpg")
    plt.show()
    # Close explicitly so repeated calls don't accumulate open high-dpi figures.
    plt.close(fig)

    # Figure 2: mean curve plus a handful of individual runs.
    fig, ax = plt.subplots(figsize=(5, 4), dpi=600)
    ax.plot(steps, mean_v, label="average")
    # At most 10 runs keep the figure readable; only the first gets a legend entry.
    for n, run_vals in enumerate(vals[:10]):
        ax.plot(steps, run_vals, color="blue", alpha=0.1,
                label='individual runs' if n == 0 else "_nolegend_")
    ax.set_xlabel("number of queries")
    ax.set_ylabel("log(distance from target)")
    ax.set_title("Convergence rate")
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"{prefix}_individualruns.jpg")
    plt.show()
    plt.close(fig)