import math

import matplotlib.pyplot as plt


def plot_samples_and_means(steps, samples, means, time_domain_plot_truncation, ylim, yticks, filename):
    """Scatter-plot samples and their running means against the training step.

    Creates a new figure (dpi=200) with two scatter series that share the same
    x-values: ``samples`` in black and ``means`` in grey. Only the leading
    ``time_domain_plot_truncation`` entries of each sequence are drawn. The value
    is used as a slice stop, so ``None`` draws everything.

    Note: font sizes are configured with ``plt.rc``, which mutates the global
    matplotlib rcParams and therefore also affects figures created afterwards.

    Args:
        steps: x-coordinates shared by both series.
        samples: y-values of the first series (black).
        means: y-values of the second series (grey).
        time_domain_plot_truncation: slice stop applied to steps, samples and means.
        ylim: forwarded to ``plt.ylim``.
        yticks: forwarded to ``plt.yticks``.
        filename: if not ``None``, the figure is saved to this path with a tight
            bounding box, a transparent background and no padding.
    """
    SMALL_SIZE = 14
    MEDIUM_SIZE = 20

    plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize

    plt.figure(dpi=200)

    # One slice object keeps the truncation identical across all three sequences.
    window = slice(None, time_domain_plot_truncation)
    xs = steps[window]

    plt.scatter(
        xs,
        samples[window],
        c='black',
        s=5,
        label='Samples $X_k$',
    )

    plt.scatter(
        xs,
        means[window],
        c='grey',
        s=5,
        label='Means $\\bar{x}_k$',
    )

    plt.ylim(ylim)

    # Hide the top/right frame lines for a cleaner, open-axes look.
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.legend(frameon=False, loc='upper right')

    plt.xlabel('Training Step $k$')
    plt.yticks(yticks)

    if filename is not None:
        # `transparent` is a boolean flag; the previous string "True" only
        # worked because any non-empty string is truthy.
        plt.savefig(filename, bbox_inches='tight', transparent=True, pad_inches=0)


def plot_psd(samples, means, step_size, NFFT, xlim, ylim, yticks, filename):
    """Plot the power spectral densities of ``samples`` and ``means``.

    Creates a new figure (dpi=200) and draws two ``plt.psd`` curves: ``samples``
    as a solid black line and ``means`` as a dotted grey line. The sampling
    frequency passed to ``plt.psd`` is ``1 / sqrt(step_size)``, i.e. consecutive
    entries are treated as ``sqrt(step_size)`` time units apart.

    Note: font sizes are configured with ``plt.rc``, which mutates the global
    matplotlib rcParams and therefore also affects figures created afterwards.

    Args:
        samples: signal whose PSD is drawn in black.
        means: signal whose PSD is drawn in grey (dotted).
        step_size: positive value; ``sqrt(step_size)`` is the spacing between
            consecutive entries used to derive the sampling frequency.
        NFFT: number of data points per FFT segment, forwarded to ``plt.psd``.
        xlim: forwarded to ``plt.xlim``.
        ylim: forwarded to ``plt.ylim``.
        yticks: forwarded to ``plt.yticks``.
        filename: if not ``None``, the figure is saved to this path with a tight
            bounding box, a transparent background and no padding.

    Raises:
        ValueError: if ``step_size`` is negative (from ``math.sqrt``).
        ZeroDivisionError: if ``step_size`` is zero.
    """
    SMALL_SIZE = 14
    MEDIUM_SIZE = 20

    plt.rc('font', size=SMALL_SIZE)          # controls default text sizes
    plt.rc('axes', titlesize=SMALL_SIZE)     # fontsize of the axes title
    plt.rc('axes', labelsize=MEDIUM_SIZE)    # fontsize of the x and y labels
    plt.rc('xtick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('ytick', labelsize=SMALL_SIZE)    # fontsize of the tick labels
    plt.rc('legend', fontsize=SMALL_SIZE)    # legend fontsize

    plt.figure(dpi=200)

    # True division is required here. Floor division (`//`) truncates the rate,
    # e.g. 9.0 instead of 10.0 for step_size=0.01, and gives 0 for step_size > 1.
    samples_per_time = 1 / math.sqrt(step_size)

    plt.psd(
        samples,
        Fs=samples_per_time,
        NFFT=NFFT,
        color='black',
        label='Samples $X_k$',
    )

    plt.psd(
        means,
        Fs=samples_per_time,
        NFFT=NFFT,
        color='grey',
        linestyle=':',
        linewidth=2,
        label='Means $\\bar{x}_k$',
    )

    plt.xlabel('Freq. (Hz)')
    plt.xlim(xlim)

    plt.ylabel('PSD (dB/Hz)')
    plt.ylim(ylim)
    plt.yticks(yticks)

    plt.legend(frameon=False, loc='upper right')
    # plt.psd turns the grid on; switch it back off for a clean figure.
    plt.grid(visible=False)

    # Hide the top/right frame lines for a cleaner, open-axes look.
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    if filename is not None:
        # `transparent` is a boolean flag; the previous string "True" only
        # worked because any non-empty string is truthy.
        plt.savefig(filename, bbox_inches='tight', transparent=True, pad_inches=0)