import math
import os

import torch

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm
from IPython.display import HTML
from matplotlib import animation
from matplotlib import rc

from scipy import signal
from scipy.ndimage import gaussian_filter

from resonance import floquet_theory

# Raise matplotlib's animation embed limit so that long animations can still be
# embedded in HTML output.
# NOTE(review): matplotlib documents this rc value in MB - confirm the intended size.
rc('animation', embed_limit=1e9)


def distances_to_mean_per_frequency(distances, frequencies, runs_per_frequency, tail_fraction):
    """Average the tail of every run, then average the runs of each frequency.

    Args:
        distances: Tensor indexed as ``[row, run]``; each column is one run and
            the columns are grouped by frequency, ``runs_per_frequency``
            consecutive columns per frequency.
        frequencies: Sized collection of swept frequencies (only its length
            is used here).
        runs_per_frequency: Number of consecutive run columns per frequency.
        tail_fraction: Fraction of the row axis at which the averaged tail
            *starts* (e.g. ``0.5`` averages the second half of every run,
            ``0.0`` averages the whole run).

    Returns:
        1-D tensor holding one mean tail distance per frequency.
    """
    num_rows = distances.shape[0]
    first_tail_row = int(num_rows * tail_fraction)

    # One value per run: its mean distance over the tail rows.
    per_run_tail_mean = distances[first_tail_row:].mean(dim=0)

    # Rows correspond to frequencies, columns to the repeated runs.
    grouped_by_frequency = per_run_tail_mean.reshape(len(frequencies), runs_per_frequency)
    return grouped_by_frequency.mean(dim=1)


def plot_frequency_sweep_learning_curves(frequencies, runs_per_frequency, iterations, distances):
    """Show one small figure per swept frequency with all of its runs overlaid.

    Args:
        frequencies: Indexable collection of swept frequencies, one entry per
            group of ``runs_per_frequency`` runs.
        runs_per_frequency: Number of consecutive runs that share a frequency.
        iterations: X values used for every learning curve.
        distances: 2-D array/tensor whose transpose yields one learning curve
            per run (columns are runs, grouped by frequency).
    """
    last_position = runs_per_frequency - 1

    for column, curve in enumerate(distances.T):
        frequency_idx, position_in_group = divmod(column, runs_per_frequency)

        if position_in_group == 0:
            # First run of a new frequency: start a fresh figure for the group.
            swept_frequency = frequencies[frequency_idx]

            plt.figure(figsize=(4, 2))
            plt.title(f'Target distance, domain sampling f={swept_frequency:.5f} Hz (T={1 / swept_frequency:.3f})')
            plt.ylim((0, 1))

        plt.plot(iterations, curve, color='purple', alpha=0.1)

        if position_in_group == last_position:
            # Every run of this frequency has been drawn; display the figure.
            plt.show()


def animate_frequency_sweep_learning_curves(frequencies, runs_per_frequency, iterations, distances):
    """Build an HTML5 video that steps through the learning curves of each frequency.

    Each animation frame corresponds to one frequency and shows the
    ``runs_per_frequency`` learning curves recorded for it.

    Args:
        frequencies: Indexable, sized collection of swept frequencies.
        runs_per_frequency: Number of consecutive runs that share a frequency.
        iterations: X values used for every learning curve.
        distances: 2-D array/tensor whose transpose yields one learning curve
            per run (columns are runs, grouped by frequency).

    Returns:
        ``IPython.display.HTML`` wrapping the rendered animation.
    """
    runs = distances.T

    # Draw the first frequency's runs so the line artists exist; every frame
    # afterwards only swaps the data of these artists.
    plt.ylim((0, 1))
    plt.xlabel('Training iteration (t)')
    curve_artists = []
    for run in runs[:runs_per_frequency]:
        (artist,) = plt.plot(iterations, run, color='purple', alpha=0.1)
        curve_artists.append(artist)

    def draw_frame(frame_idx):
        frequency = frequencies[frame_idx]
        plt.title(f'Target distance, domain sampling f={frequency:.5f} Hz (T={1 / frequency:.3f})')

        first_run = frame_idx * runs_per_frequency
        frame_runs = runs[first_run: first_run + runs_per_frequency]
        for artist, run in zip(curve_artists, frame_runs):
            artist.set_data(iterations, run)

        return curve_artists

    sweep_animation = animation.FuncAnimation(
        plt.gcf(),
        draw_frame,
        frames=range(len(frequencies)),
        interval=100,
        blit=True,
        repeat=True,
    )

    video_markup = sweep_animation.to_html5_video()
    return HTML(video_markup)


def plot_frequency_sweep_mean_distance(frequencies, runs_per_frequency, distances, title_content, tail_fraction):
    """Plot the mean tail distance of each frequency against the frequency.

    Args:
        frequencies: Swept frequencies, used as the x values.
        runs_per_frequency: Number of consecutive runs that share a frequency.
        distances: Per-run distance curves, forwarded to
            ``distances_to_mean_per_frequency``.
        title_content: Text appended to the figure title.
        tail_fraction: Forwarded to ``distances_to_mean_per_frequency`` to
            select the tail of every run.
    """
    per_frequency_means = distances_to_mean_per_frequency(
        distances,
        frequencies,
        runs_per_frequency,
        tail_fraction,
    )

    # Keep the x axis readable by limiting the number of tick bins.
    plt.locator_params(axis='x', nbins=8)

    figure_title = f'Mean Target Distance vs. Domain Sampling Frequency, {title_content}'
    plt.title(figure_title)
    plt.xlabel('Sampling translation frequency (f)')

    plt.plot(frequencies, per_frequency_means, color='blue')
    plt.show()


def plot_momentum_vs_frequency(
    results_dir,
    theoretical_results_dir,
    learning_rates,
    momenta,
    frequencies,
    runs_per_frequency,
    tail_fraction,
    outlier_quantile_threshold,
    amplitude,
    filename_key,
    x_label,
    yticks,
    save_svg,
):
    """Plot a momentum-vs-frequency heatmap with theory contours per learning rate.

    For every learning rate, the empirical results stored as
    ``lr=<learning_rate> mu=<momentum>.pt`` in ``results_dir`` are reduced to
    one row of mean tail distances per momentum, stacked into an image, and
    shown on a log scale. Contour lines of the surface stored in
    ``spectral_norm_surface_<learning_rate>.npy`` are drawn on the same axes.
    Learning rates without any results file are skipped silently.

    Args:
        results_dir: Directory holding the ``.pt`` results files; SVGs are
            also written here when ``save_svg`` is truthy.
        theoretical_results_dir: Directory holding the ``.npy`` surfaces.
        learning_rates: Iterable of learning rates to plot, one figure each.
        momenta: Momentum values; one image row per value that has a results
            file. Row order follows iteration order, and the first row is
            drawn at the top, so descending order matches the y extent.
        frequencies: Swept frequencies (x axis); every 11th one gets a tick,
            labelled with its period ``1 / f``.
        runs_per_frequency: Number of runs recorded per frequency.
        tail_fraction: Fraction of each run at which the averaged tail starts.
        outlier_quantile_threshold: Quantile used to clamp large values, or
            ``None`` to disable clamping.
        amplitude: Unused; kept so existing callers keep working.
        filename_key: Tag inserted into the SVG file name.
        x_label: Label of the x axis.
        yticks: Tick positions of the y axis.
        save_svg: When truthy, save the figure as SVG instead of showing it.

    Raises:
        FileNotFoundError: If results exist for a learning rate but the
            matching theory surface file does not.
    """
    SMALL_SIZE = 14
    MEDIUM_SIZE = 20

    # NOTE: these rc settings are process-wide and persist after this call.
    plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=MEDIUM_SIZE)   # legend fontsize

    for learning_rate in learning_rates:
        min_freq = min(frequencies)
        max_freq = max(frequencies)
        min_mu = min(momenta)
        # Starts at the minimum on purpose: it grows to the largest momentum
        # that actually has a results file, which bounds the image extent.
        max_mu = min_mu

        image_rows = []
        for momentum in tqdm(momenta):
            results_path = os.path.join(results_dir, f'lr={learning_rate} mu={momentum}.pt')
            if not os.path.exists(results_path):
                continue

            # NOTE: torch.load unpickles the file; only load trusted results.
            distances = torch.load(results_path)
            image_rows.append(
                distances_to_mean_per_frequency(distances, frequencies, runs_per_frequency, tail_fraction)
            )
            max_mu = max(max_mu, momentum)

        if not image_rows:
            continue

        _, ax = plt.subplots(1)

        # Plot theory contour map.
        # NOTE(review): the surface is presumably shaped
        # (len(momenta), len(frequencies)) to match the meshgrid - confirm.
        theory_results_path = os.path.join(
            theoretical_results_dir, f'spectral_norm_surface_{learning_rate}.npy'
        )
        with open(theory_results_path, 'rb') as f:
            spectral_norm_surface = np.load(f)
        xv, yv = np.meshgrid(frequencies, momenta)
        ax.contour(
            xv,
            yv,
            spectral_norm_surface,
            levels=np.linspace(1.0, 1.5, 4),
            linewidths=4,
            linestyles='--',
            cmap='Reds',
        )

        # Plot empirical heatmap.
        image = torch.stack(image_rows, dim=0)
        if outlier_quantile_threshold is not None:
            image = torch.clamp(image, 0, torch.quantile(image, outlier_quantile_threshold))

        # Prevent log producing -inf by replacing zeros with the next smallest value in the image.
        image[image == 0] = image[image > 0].min()
        image = torch.log(image)

        mu_span = max_mu - min_mu
        if mu_span:
            aspect = (max_freq - min_freq) / mu_span / 2
        else:
            # Only the smallest momentum had results: the extent has zero
            # height, so let matplotlib choose the aspect instead of dividing by 0.
            aspect = 'auto'

        # NOTE(review): interpolation=None defers to matplotlib's rcParams
        # default; pass 'none' if unsmoothed pixels are intended - confirm.
        ax.imshow(
            image,
            aspect=aspect,
            interpolation=None,
            cmap=plt.cm.gist_heat,
            extent=[
                min_freq,
                max_freq,
                min_mu,
                max_mu,
            ],
        )

        # Tick every 11th frequency and label it with the corresponding period.
        label_frequencies = frequencies[0::11]
        period_labels = [f'{1 / f:.0f}' for f in label_frequencies]
        ax.set_xticks(label_frequencies)
        ax.set_xticklabels(period_labels)
        ax.set_xlabel(x_label)
        ax.set_yticks(yticks)
        ax.yaxis.set_label_coords(-0.075, 0.55)
        # Raw string: '\m' is not a valid escape sequence in a normal literal.
        ylabel = ax.set_ylabel(r'$\mu$')
        ylabel.set_rotation(0)

        if save_svg:
            plot_path = os.path.join(results_dir, f'heatmap-{filename_key}-{learning_rate}.svg')
            plt.savefig(plot_path, bbox_inches='tight', transparent=True, pad_inches=0)
        else:
            plt.show()


def plot_spectra(means, samples, peak_frequency_hz, label):
    """Plot the periodograms of a mean signal and a sampled signal together.

    Both spectra are computed with ``scipy.signal.periodogram`` using its
    default sampling frequency and drawn on a logarithmic y axis. The view is
    limited to frequencies in ``[0, 0.1]`` and power in ``[1e-5, 1e2]``.

    Args:
        means: Signal whose spectrum is drawn as the dashed black reference.
        samples: Signal whose spectrum is drawn as the 'sampled' curve.
        peak_frequency_hz: Frequency marked with a dashed vertical gray line.
        label: Prefix used in the legend entry of the ``means`` spectrum.
    """
    plt.figure(figsize=(10, 6))

    # Mark the frequency of interest first so the spectra are drawn on top.
    plt.axvline(peak_frequency_hz, color='gray', linestyle='--')

    mean_frequencies, mean_psd = signal.periodogram(means)
    plt.semilogy(mean_frequencies, mean_psd, color='black', label=f'{label} mean', linestyle='--')

    frequencies, psd = signal.periodogram(samples)
    # The line width is numeric; the string '2' only worked through implicit coercion.
    plt.semilogy(frequencies, psd, label='sampled', linewidth=2, alpha=0.7)

    plt.ylim((1e-5, 1e2))
    plt.xlim((0, 0.1))
    plt.legend()
    plt.show()


if __name__ == '__main__':
    # Small example sweep: one learning rate, five momenta, five frequencies.
    results_dir = 'results/test'
    # NOTE(review): placeholder - point this at the directory that holds the
    # spectral_norm_surface_<learning_rate>.npy files.
    theoretical_results_dir = 'results/theory'
    tail_fraction = 0.5

    amplitude = 0.2

    frequencies = torch.linspace(0.0, 0.03, 5)
    learning_rates = [thousandths / 1000 for thousandths in range(1, 2)]
    momenta = [thousandths / 1000 for thousandths in range(999, 994, -1)]

    runs_per_frequency = 10

    outlier_quantile_threshold = 0.98
    # Currently not consumed by plot_momentum_vs_frequency.
    smoothing_sigma = 0.5

    # Keyword arguments keep every value bound to the intended parameter; the
    # previous positional call omitted five required arguments and shifted
    # learning_rates into the theoretical_results_dir slot.
    plot_momentum_vs_frequency(
        results_dir=results_dir,
        theoretical_results_dir=theoretical_results_dir,
        learning_rates=learning_rates,
        momenta=momenta,
        frequencies=frequencies,
        runs_per_frequency=runs_per_frequency,
        tail_fraction=tail_fraction,
        outlier_quantile_threshold=outlier_quantile_threshold,
        amplitude=amplitude,
        # NOTE(review): the values below are example defaults - adjust as needed.
        filename_key='test',
        x_label='Sampling period (T)',  # x ticks are labelled with 1 / f
        yticks=[min(momenta), max(momenta)],
        save_svg=False,
    )