# Licensed under the MIT License.
# Copyright (c) Microsoft Corporation.
import matplotlib.pyplot as plt
import matplotlib.ticker as tck
import numpy as np

# Font sizes (points) for the three tiers of text used by the figures below.
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

# Global matplotlib styling applied on import. Writing the rcParams keys
# directly is equivalent to the per-group plt.rc(...) form.
plt.rcParams.update({
    'text.usetex': False,               # render $...$ with mathtext, not LaTeX
    'font.size': SMALL_SIZE,            # default text size
    'font.family': 'serif',
    'axes.titlesize': MEDIUM_SIZE,      # axes title
    'axes.linewidth': 1,                # axes spine width
    'axes.labelsize': MEDIUM_SIZE,      # x / y axis labels
    'xtick.labelsize': SMALL_SIZE,      # tick labels
    'ytick.labelsize': SMALL_SIZE,
    'legend.fontsize': SMALL_SIZE,
    'figure.titlesize': BIGGER_SIZE,    # figure (suptitle) title
})


def plot_rul(ref, pred, std=None, name='res', save=False):
    """Draw a prediction-vs-reference scatter plot of RUL values.

    Each marker is one (reference, prediction) pair, and a y = x line is drawn
    underneath for comparison. Both axes span the range of ``ref``, rounded
    outward to multiples of 500, on a square (equal-aspect) box.

    Args:
        ref: Reference values (array-like of numbers, non-empty).
        pred: Predicted values, same length as ``ref``.
        std: Optional per-point values used to colour the markers. The
            colorbar title suggests a predictive standard deviation; confirm
            with callers. When ``None``, markers use the default colour and
            no colorbar is drawn.
        name: Suffix used in the output filename when ``save`` is true.
        save: If true, write ``../figures/rul_{name}.png`` (300 dpi, tight
            bounding box, transparent background).

    Returns:
        ``(fig, ax)``: the created figure and its scatter axes.
    """
    ref = np.asarray(ref)
    pred = np.asarray(pred)

    fig, ax = plt.subplots(figsize=(1.5, 1.5))
    ax.grid()

    # Only request colour mapping when there is data to map. Otherwise
    # matplotlib ignores ``cmap`` with a warning, and a colorbar would show a
    # meaningless default 0..1 scale.
    scatter_kwargs = {'s': 5, 'zorder': 3}
    if std is not None:
        scatter_kwargs.update(c=std, cmap='cool')
    points = ax.scatter(ref, pred, **scatter_kwargs)

    # Identity line: markers on it are exact predictions.
    ax.plot(ref, ref, c='tab:pink', lw=0.5, alpha=1, zorder=2)
    ax.set_xlabel('Reference')
    ax.set_ylabel('Prediction')

    # Same window on both axes, snapped to multiples of 500, covering all of ref.
    lo = ref.min() // 500 * 500
    hi = ref.max() // 500 * 500 + 500
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    ax.tick_params(direction="in", top=True, right=True, which='both')
    ax.set_aspect('equal', adjustable='box')
    ax.locator_params(axis='y', nbins=4)
    ax.locator_params(axis='x', nbins=4)

    if std is not None:
        cb = fig.colorbar(points, ax=ax, format='%d', shrink=0.8)
        cb.ax.yaxis.set_major_locator(tck.LinearLocator(numticks=5))
        # Raw string: '\s' is not a valid Python escape sequence.
        cb.ax.set_title(r'$\sigma_{RUL}$')

    if save:
        fig.savefig(f'../figures/rul_{name}.png', dpi=300, bbox_inches='tight', transparent=True)
    return fig, ax
    

def plot_samples(samples, name='samples', save=False):
    """Show each sample as a small heatmap in a square grid of subplots.

    The grid is ``n x n`` with ``n = ceil(sqrt(num_samples))``. Unused cells
    are hidden. Each heatmap is drawn with ``pcolormesh`` on a mesh whose
    horizontal coordinate runs 1.0..100.0 and vertical coordinate 2.0..3.6.
    Those look like cycle and voltage ranges; confirm against the data source.
    All ticks are removed.

    Args:
        samples: Array-like of shape ``(N, A, B)`` or ``(N, 1, A, B)``. For the
            4-D form, axis 1 must have size 1 (``np.squeeze`` raises otherwise).
            After swapping the last two axes, axis ``A`` runs along the
            horizontal grid coordinate and ``B`` along the vertical one, with
            the order along ``B`` reversed.
        name: Suffix used in the output filename when ``save`` is true.
        save: If true, write ``../figures/heatmaps_{name}.png`` (300 dpi,
            tight bounding box, transparent background).

    Returns:
        ``(fig, ax)``: the figure and the ``(n, n)`` array of axes.

    Raises:
        ValueError: If ``samples`` contains no samples.
    """
    samples = np.asarray(samples)
    if samples.ndim == 4:
        samples = np.squeeze(samples, axis=1)
    # (N, A, B) -> (N, B, A), then reverse axis 1 (the one mapped onto the
    # vertical mesh coordinate below).
    samples = np.flip(np.transpose(samples, (0, 2, 1)), axis=1)

    n_samples = samples.shape[0]
    if n_samples == 0:
        raise ValueError('plot_samples needs at least one sample')

    n = int(np.ceil(np.sqrt(n_samples)))
    v = np.linspace(2.0, 3.6, num=samples.shape[1], endpoint=True)
    cycle = np.linspace(1.0, 100.0, num=samples.shape[2], endpoint=True)
    x, y = np.meshgrid(cycle, v)

    # squeeze=False keeps ``ax`` a 2-D array even when n == 1. In that case
    # plt.subplots would otherwise return a bare Axes, which has no .flatten().
    fig, ax = plt.subplots(n, n, figsize=(n / 2, n / 2), sharey=True, squeeze=False)
    axes_flat = ax.flatten()

    for axx, sample in zip(axes_flat, samples):
        # Colormap is passed per call. plt.set_cmap would change the global
        # default and the current image of whatever figure is active.
        axx.pcolormesh(x, y, sample, cmap='jet')
        axx.set_aspect('auto')
        axx.set_yticks([])
        axx.set_xticks([])

    # Hide the leftover cells of the n x n grid.
    for axx in axes_flat[n_samples:]:
        axx.set_axis_off()

    if save:
        fig.savefig(f'../figures/heatmaps_{name}.png', dpi=300, bbox_inches='tight', transparent=True)
    return fig, ax

