import os

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl

# Input directory holding the .npy arrays (previously './graphs/').
path = './1d_data/'
# Output directory for the rendered figures.
save_path = './graphs_new/'

# Alternative palette (appears to match matplotlib's 'tableau-colorblind10'):
#color_hexes = ['#006BA4', '#FF800E', '#ABABAB', '#595959', '#5F9ED1',
#        '#C85200', '#898989', '#A2C8EC', '#FFBC79', '#CFCFCF']
# Active palette: matplotlib's default 'tab10' colour cycle.
color_hexes = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

# Render every piece of figure text in Times New Roman.
mpl.rcParams['font.family'] = 'Times New Roman'

# ---- run configuration (alternatives kept in trailing comments) ----
env = 'bimodal'                   # alternative: 'hetero'
model_type = 'nflows_base'        # alternative: 'pens'
nomap = True                      # alternative: False (drops the '_nomap' file suffix)
just_legend = False               # True: only export the legend image
num_seeds = 10                    # earlier value: 2; currently unused here
numb_uncertainty_estimates = 100  # number of x-grid points per uncertainty curve



def _load_npy(stem):
    """Return the array stored in ``<path>/<stem>.npy``."""
    return np.load(os.path.join(path, f'{stem}.npy'))


# x-interval of the uncertainty grid for the 'hetero' environment, and for
# every other environment (only the upper end is used to filter those).
_HETERO_XMIN, _HETERO_XMAX = -4.5, 4.5
_OTHER_XMIN, _OTHER_XMAX = 0, 3.5

# Selected environment/model: training pairs, model samples, and the three
# uncertainty estimates (file tags mib/bhatt/kl, plotted as
# 'Monte Carlo'/'Bhatt'/'KL').
train_x = _load_npy(f'{env}_train_x')
train_y = _load_npy(f'{env}_train_y')
samps_model = _load_npy(f'{env}_samps_{model_type}')
uncertainty_mib = _load_npy(f'{env}_mib_{model_type}')
uncertainty_bhatt = _load_npy(f'{env}_bhatt_{model_type}')
uncertainty_kl = _load_npy(f'{env}_kl_{model_type}')

# The 'hetero' environment is always loaded too: it fills the right-hand
# panel of the figure when `nomap` is set, whatever `env` is.
uncertainty_mib_hetero = _load_npy(f'hetero_mib_{model_type}')
uncertainty_bhatt_hetero = _load_npy(f'hetero_bhatt_{model_type}')
uncertainty_kl_hetero = _load_npy(f'hetero_kl_{model_type}')
train_x_hetero = _load_npy('hetero_train_x')
train_y_hetero = _load_npy('hetero_train_y')
datarange_hetero = np.linspace(_HETERO_XMIN, _HETERO_XMAX, numb_uncertainty_estimates)
# Keep only training points strictly inside the grid's x-interval.
kp_hetero = (train_x_hetero < _HETERO_XMAX) & (train_x_hetero > _HETERO_XMIN)
train_x_hetero = train_x_hetero[kp_hetero]
train_y_hetero = train_y_hetero[kp_hetero]

if env == 'hetero':
    datarange = np.linspace(_HETERO_XMIN, _HETERO_XMAX, numb_uncertainty_estimates)
    kp = (train_x < _HETERO_XMAX) & (train_x > _HETERO_XMIN)
else:
    datarange = np.linspace(_OTHER_XMIN, _OTHER_XMAX, numb_uncertainty_estimates)
    # Only an upper cut is applied for non-hetero environments.
    kp = train_x < _OTHER_XMAX
# `kp` is also used later to select the matching rows of `samps_model`.
train_x = train_x[kp]
train_y = train_y[kp]

def export_legend(legend, filename="legend.png", expand=(-5, -5, 5, 5), nomap=True,
                  save_dir='graphs_new'):
    """Save an image cropped to *legend* (plus padding) from its parent figure.

    The figure is drawn first so the legend's window extent is available.
    That extent is padded by *expand* and converted from display units to
    inches, which is what ``savefig(bbox_inches=...)`` expects.

    Parameters
    ----------
    legend : matplotlib.legend.Legend
        Legend to export; its parent figure (``legend.figure``) is saved.
    filename : str
        Output file name. It is used only when *nomap* is True.
    expand : sequence of 4 numbers
        Offsets added to the legend extents (x0, y0, x1, y1) in display units.
        Negative x0/y0 and positive x1/y1 grow the box.
    nomap : bool
        When False, *filename* is ignored and 'legend_nomap.png' is written.
        NOTE(review): this looks inverted relative to the '_nomap' suffix of
        the main figures. The just-legend plotting code in this script is
        inverted the same way, so the saved content matches the file name.
        Change both together or neither.
    save_dir : str
        Output directory. It defaults to the previously hard-coded
        'graphs_new' and is created if missing.
    """
    fig = legend.figure
    # Draw first so the legend's window extent reflects its final layout.
    fig.canvas.draw()
    if not nomap:
        filename = 'legend_nomap.png'
    bbox = legend.get_window_extent()
    # Pad the legend box in display coordinates.
    bbox = bbox.from_extents(*(bbox.extents + np.array(expand)))
    # bbox_inches expects inches, so undo the figure's DPI scaling.
    bbox = bbox.transformed(fig.dpi_scale_trans.inverted())
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, filename), bbox_inches=bbox, dpi=600)

def _minmax_normalize(y):
    """Linearly rescale the array *y* onto [0, 1].

    A constant curve (max == min) maps to zeros instead of the NaNs (and
    divide warning) that the bare division would produce.
    """
    lo = y.min()
    span = y.max() - lo
    if span == 0:
        return np.zeros_like(y)
    return (y - lo) / span


def _plot_uncertainty_curves(axis, x, mib, bhatt, kl, line_width):
    """Plot the axis-0 means of the mib/bhatt/kl estimates against *x* on *axis*."""
    axis.plot(x, mib.mean(0),
        c=color_hexes[3], linewidth=line_width, label='Monte Carlo')
    axis.plot(x, bhatt.mean(0), '--',
        c=color_hexes[0], linewidth=line_width, label='Bhatt')
    axis.plot(x, kl.mean(0), ':',
        c=color_hexes[1], linewidth=line_width, label='KL')


# Display names for the supported model types (used as the legend-figure title).
_MODEL_TITLES = {'nflows_base': 'Nflows Base', 'pens': 'PNEs'}

if just_legend:
    # Build a throwaway figure only to collect legend handles; export_legend()
    # crops the saved image to the legend box.
    # NOTE(review): the `not nomap` branch draws the MC/Bhatt/KL curves that
    # the main figure draws when nomap is True. export_legend() likewise
    # writes 'legend_nomap.png' when nomap is False. The two inversions
    # cancel, so keep them in sync.
    title_size = 4
    line_width = 2
    fig, ax = plt.subplots(1, 1, figsize=(8, 3.5), sharey=True, sharex=True)
    ax2 = ax.twinx()
    ax.scatter(train_x, samps_model[kp], s=2, c=color_hexes[0], alpha=0.5)
    if not nomap:
        ax2.plot(datarange, uncertainty_mib.mean(0),
            c=color_hexes[3], linewidth=line_width, label='MC')
        ax2.plot(datarange, uncertainty_bhatt.mean(0), '--',
            c=color_hexes[0], linewidth=line_width, label='Bhatt')
        ax2.plot(datarange, uncertainty_kl.mean(0), ':',
            c=color_hexes[1], linewidth=line_width, label='KL')
        ncol = 3
    else:
        ax2.plot(datarange, uncertainty_bhatt.mean(0), '--',
            c=color_hexes[0], linewidth=line_width, label='Bhatt')
        ax2.plot(datarange, uncertainty_kl.mean(0), ':',
            c=color_hexes[1], linewidth=line_width, label='KL')
        ax2.scatter(train_x, train_y, s=2, c=color_hexes[9], label='groundtruth', alpha=0.5)
        ax2.scatter(train_x, samps_model[kp], s=2, c=color_hexes[5], label='model', alpha=0.5)
        ncol = 4
    try:
        title = _MODEL_TITLES[model_type]
    except KeyError:
        raise ValueError(
            f'unknown model_type {model_type!r}; expected one of {sorted(_MODEL_TITLES)}'
        ) from None
    ax.set_title(title, fontdict={'fontsize': title_size})
    ax.set_xticks([])
    ax.set_yticks([])
    ax2.set_yticks([])
    # Merge the handles/labels of the primary and twin axes into one legend.
    lines_labels = [a.get_legend_handles_labels() for a in fig.axes]
    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
    leg = fig.legend(lines, labels, loc='lower center', ncol=ncol,
            fontsize=title_size, bbox_to_anchor=(.5, -.5), borderpad=0.5)
    leg.get_frame().set_linewidth(0.5)
    leg.get_frame().set_edgecolor('black')
    export_legend(leg, nomap=nomap)
    plt.close(fig)
else:
    title_size = 18
    line_width = 3
    tick_size = 8  # not currently used; kept for tick_params tweaks
    if nomap:
        fig, ax = plt.subplots(1, 2, figsize=(8, 3.5))
        # Left: estimates for `env`; right: the same estimates for 'hetero'.
        _plot_uncertainty_curves(ax[0], datarange, uncertainty_mib,
            uncertainty_bhatt, uncertainty_kl, line_width)
        _plot_uncertainty_curves(ax[1], datarange_hetero, uncertainty_mib_hetero,
            uncertainty_bhatt_hetero, uncertainty_kl_hetero, line_width)
    else:
        fig, ax = plt.subplots(1, 2, figsize=(8, 3.5), sharex=True)
        # Left: Bhatt and KL curves, each min-max scaled to [0, 1].
        ax[0].plot(datarange, _minmax_normalize(uncertainty_bhatt.mean(0)), '--',
            c=color_hexes[0], linewidth=line_width, label='Bhatt')
        ax[0].plot(datarange, _minmax_normalize(uncertainty_kl.mean(0)), ':',
            c=color_hexes[1], linewidth=line_width, label='KL')
    ax[0].set_ylabel('Uncertainty', fontsize=title_size)
    ax[0].set_xlabel('x', fontsize=title_size)
    ax[0].set_xticks([])
    # Histogram of the training inputs on a twin y-axis behind the curves.
    ax2 = ax[0].twinx()
    sns.histplot(data=train_x, color=color_hexes[2],
                 stat='probability', ax=ax2, bins=100, alpha=0.5,
                 label='density')
    ax2.set_ylabel('')
    ax2.set_yticks([])
    if nomap:
        # Same histogram overlay for the hetero training inputs.
        ax2 = ax[1].twinx()
        sns.histplot(data=train_x_hetero, color=color_hexes[2],
                     stat='probability', ax=ax2, bins=100, alpha=0.5,
                     label='density')
        ax2.set_ylabel('')
        ax2.set_yticks([])
        ax[1].set_ylabel('')
        ax[1].set_xlabel('x', fontsize=title_size)
        ax[1].set_xticks([])
    else:
        # Right: ground-truth training pairs vs. model samples.
        ax[0].set_yticks([])
        ax[1].scatter(train_x, train_y, s=2, c=color_hexes[9], label='groundtruth', alpha=0.5)
        ax[1].scatter(train_x, samps_model[kp], s=2, c=color_hexes[5], label='model', alpha=0.5)
        ax[1].set_ylabel('y', fontsize=title_size)
        ax[1].set_xlabel('x', fontsize=title_size)
        ax[1].set_yticks([])
        ax[1].set_xticks([])
    fig.tight_layout()
    # File name: <env>[_pnes][_nomap].png
    png_name = f'{env}'
    if model_type == 'pens':
        png_name += '_pnes'
    if nomap:
        png_name += '_nomap'
    png_name += '.png'
    fig.savefig(os.path.join(save_path, png_name), dpi=300)
    plt.close(fig)
