import torch
import numpy as np

import matplotlib.pyplot as plt

from matplotlib import animation
from matplotlib import rc

from training import train


def mean_steady_state_loss(losses, steady_state_fraction=0.5):
    """Return the mean of the trailing ``steady_state_fraction`` of ``losses``.

    Assumes training has converged by the start of that trailing window, so
    its mean approximates the steady-state loss.

    Args:
        losses: 1-D tensor of per-step losses, in training order.
        steady_state_fraction: Fraction of the run (counted from the end)
            treated as steady state. Defaults to the last half.

    Returns:
        A scalar tensor holding the mean of the trailing window. The window
        always contains at least one element when ``losses`` is non-empty.
    """
    total = len(losses)
    # Clamp to at least one element: a tail length of 0 would turn the slice
    # into ``losses[-0:]``, which silently selects the *entire* tensor.
    tail_length = max(1, int(total * steady_state_fraction))
    # Clamp the start index so fractions above 1 still mean "everything".
    start = max(0, total - tail_length)
    return torch.mean(losses[start:])


def histogram(sample, xlim, bins):
    """Build step-plot coordinates for a density histogram of ``sample``.

    The sample is binned with ``np.histogram`` (``density=True``) over the
    range ``xlim``. Every bin then contributes two points, one at its left
    edge and one at its right edge, both at the bin's density, so that
    drawing the points as a line traces the histogram outline.

    Args:
        sample: Values to bin.
        xlim: ``(low, high)`` range passed to ``np.histogram``.
        bins: Bin specification passed to ``np.histogram``.

    Returns:
        ``(plot_freqs, plot_bin_edges)``: two lists of equal length (two
        entries per bin) holding the y and x coordinates respectively.
    """
    densities, edges = np.histogram(sample, bins=bins, range=xlim, density=True)

    # Each density appears twice: once for each end of its bin.
    plot_freqs = [density for density in densities for _ in range(2)]

    # Interleave left/right edges: [l0, r0, l1, r1, ...].
    plot_bin_edges = [
        edge
        for left_right in zip(edges[:-1], edges[1:])
        for edge in left_right
    ]

    return plot_freqs, plot_bin_edges


def animate_training(model, optimizer, target_function, distributions, title, ylim, xlim, dist_title, dist_ylim,
                     dist_bins, steps_per_frame):
    """Train ``model`` frame by frame and animate the learned function.

    The figure has two panels. The upper panel shows ``target_function`` as a
    dashed black reference curve together with the model's current output and
    a fading trail of its most recent outputs. The lower panel shows a
    density histogram of the samples returned by ``train`` for the current
    frame, with the same fading trail.

    Each animation frame calls ``train`` on the next ``steps_per_frame``
    entries of ``distributions``, so rendering the animation is what actually
    trains the model (and replaying it keeps training the same model).

    Args:
        model: Callable evaluated on a ``(100, 1)`` tensor; its output is
            reshaped to 100 values for plotting.
        optimizer: Passed through to ``train`` unchanged.
        target_function: Callable evaluated on a 1-D tensor of 100 points to
            draw the reference curve.
        distributions: Sliceable sequence with one entry per training step.
            ``distributions[0]`` is called with ``steps_per_frame`` to obtain
            the sample for the initial histogram, so it must be non-empty.
        title: Title of the upper (function) panel.
        ylim: y-limits of the upper panel.
        xlim: ``(low, high)`` domain used for the model evaluation grid and
            for the histogram range / x-limits of the lower panel.
        dist_title: Title of the lower (distribution) panel.
        dist_ylim: y-limits of the lower panel.
        dist_bins: Bin specification forwarded to ``histogram``.
        steps_per_frame: Number of ``distributions`` entries consumed per
            animation frame.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` driving the figure.
    """
    # Lift matplotlib's embed-size cap so long animations can be embedded.
    rc('animation', embed_limit=1e9)

    # NOTE(review): the reference curve is drawn on a fixed [-1, 1] grid,
    # whereas the learned function below is drawn over ``xlim`` — confirm
    # whether this should follow ``xlim`` as well.
    target_x = torch.linspace(-1, 1, 100)
    fig = plt.figure(constrained_layout=True, figsize=(8, 8))
    grid = fig.add_gridspec(5, 3, figure=fig)
    main_ax = fig.add_subplot(grid[:3, :])
    dist_ax = fig.add_subplot(grid[3:, :])

    main_ax.plot(target_x, target_function(target_x), linestyle='--', color='black')
    main_ax.set_ylim(ylim)
    main_ax.set_title(title)

    # Histogram shown before any training step has run.
    initial_freqs, initial_bin_edges = histogram(distributions[0](steps_per_frame), xlim, dist_bins)
    dist_ax.set_title(dist_title)
    dist_ax.set_ylim(dist_ylim)
    dist_ax.set_xlim(xlim)

    # Index 0 holds the newest curve; higher indices hold progressively older
    # (and more transparent) ones.
    recent_history_length = 7
    recent_learned_functions = []
    recent_dist = []

    # Evaluation grid for the model; invariant across frames.
    eval_x = torch.linspace(xlim[0], xlim[1], 100).reshape((100, 1))

    def init():
        # FuncAnimation calls init_func more than once (at start, on every
        # repeat, and when saving). Discard artists from any earlier call so
        # they do not accumulate on the axes or in the history lists.
        for stale_artist in recent_learned_functions + recent_dist:
            stale_artist.remove()
        recent_learned_functions.clear()
        recent_dist.clear()

        for t_elapsed in range(recent_history_length):
            # Opacity halves with every step back in history.
            alpha = 2 ** -t_elapsed
            with torch.no_grad():
                learned_function, = main_ax.plot(
                    eval_x.reshape(100),
                    model(eval_x).reshape(100),
                    alpha=alpha,
                    color='#1f77b4',
                )
            domain_dist, = dist_ax.plot(
                initial_bin_edges,
                initial_freqs,
                color='#d62728',
                alpha=alpha,
                linewidth=3,
            )

            recent_learned_functions.append(learned_function)
            recent_dist.append(domain_dist)

        dist_ax.set_xlim(xlim)

    def animate(i):
        # Run the training steps that belong to frame ``i``.
        frame_start = i * steps_per_frame
        frame_end = (i + 1) * steps_per_frame
        _losses, samples = train(model, optimizer, distributions[frame_start:frame_end])

        # Age the trail: walk from oldest to newest so each slot takes over
        # the data of the next-newer slot before that one is overwritten.
        for t_elapsed in range(recent_history_length - 1, 0, -1):
            recent_learned_functions[t_elapsed].set_data(recent_learned_functions[t_elapsed - 1].get_data())
            recent_dist[t_elapsed].set_data(recent_dist[t_elapsed - 1].get_data())

        with torch.no_grad():
            recent_learned_functions[0].set_data(eval_x.reshape(100), model(eval_x).reshape(100))

        freqs, bin_edges = histogram(samples, xlim, dist_bins)
        recent_dist[0].set_data(bin_edges, freqs)

    anim = animation.FuncAnimation(
        fig,
        animate,
        init_func=init,
        # Only whole frames are rendered; leftover steps are dropped.
        frames=len(distributions) // steps_per_frame,
        blit=False,
        repeat=True,
        repeat_delay=100,
    )

    return anim
