#%%
import torch
import torch.nn as nn
import matplotlib 
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, SymLogNorm
from matplotlib.ticker import MaxNLocator

matplotlib.rc('font', family='serif')
matplotlib.rc('text', usetex=True)

import os

import seaborn as sns
import pandas as pd
import numpy as np
import pickle

from scipy.stats import wasserstein_distance

from torch.utils.data import DataLoader

from heteroskedastic_bnns.load_uci.uci_loader import UCIDatasets
from datetime import datetime


#from heteroskedastic_bnns.vi.reparam_models import make_heatmap
from heteroskedastic_bnns.generate_data import prep_uci_data, prep_double_sine

#%%
## Load data for field theory
cwd = os.getcwd()
print(cwd)

#%%
base_load_ft = cwd + '/final_plots/univariate/field_theory'


def _load_ft_pickle(rel_path):
    """Return the object unpickled from ``base_load_ft + rel_path``."""
    with open(base_load_ft + rel_path, 'rb') as handle:
        return pickle.load(handle)


# regularizers
alpha_vals = _load_ft_pickle('/alpha_vals.p')
beta_vals = _load_ft_pickle('/beta_vals.p')

# fitted mu/lambda pairs
model_list = _load_ft_pickle('/models.p')

# residuals (mu - sy)
resids_list = _load_ft_pickle('/resids.p')

# data that the models were fit to
xorg_aug = _load_ft_pickle('/data/xorg_aug.p')
sy = _load_ft_pickle('/data/sy.p')

# values from the entropies
ent_vals_res = _load_ft_pickle('/ent_vals_res.p')
ent_vals_sds = _load_ft_pickle('/ent_vals_sds.p')

# metrics (each is later rendered as one heat map)
mse_ft = _load_ft_pickle('/data/mse.p')
w_mse_ft = _load_ft_pickle('/data/w_mse.p')
log_prec_ft = _load_ft_pickle('/data/log_prec.p')
log_wt_mse_ft = _load_ft_pickle('/data/log_wt_mse.p')
mean_mag_ft = _load_ft_pickle('/data/mean_mag.p')
lam_mag_ft = _load_ft_pickle('/data/lam_mag.p')
w_dists_ft = _load_ft_pickle('/data/w_dists.p')
log_w_dists_ft = _load_ft_pickle('/data/log_w_dists.p')
sd_mse_ft = _load_ft_pickle('/data/sd_mse.p')

# absolute-deviation summaries
mean_ad_ft = _load_ft_pickle('/data/mean_abs_dev.p')
median_ad_ft = _load_ft_pickle('/data/median_abs_dev.p')

sd_mean_ad_ft = _load_ft_pickle('/data/sd_mean_abs_dev.p')
sd_median_ad_ft = _load_ft_pickle('/data/sd_median_abs_dev.p')
#%%

def make_heatmap(title, pd_df, xtick, ytick, xlab, ylab, save_path, save=True, symlognorm=True, figsize=(4, 3)):
    """Draw a seaborn heat map of ``pd_df`` on a fresh figure and optionally save it.

    Args:
        title: Figure title.
        pd_df: 2-D data forwarded to ``sns.heatmap``. Callers in this file pass
            torch tensors; anything ``np.asarray`` can convert to floats works.
        xtick: Forwarded as ``xticklabels``.
        ytick: Forwarded as ``yticklabels``.
        xlab: x-axis label.
        ylab: y-axis label.
        save_path: Destination handed to ``plt.savefig`` when ``save`` is true.
        save: Whether to write the figure to ``save_path``.
        symlognorm: If true, colour with a ``SymLogNorm`` spanning the data
            range; otherwise leave the norm to seaborn's default.
        figsize: Figure size passed to ``plt.figure``.

    The figure is closed before returning, whether or not it was saved.
    """
    plt.figure(figsize=figsize)
    plt.title(title)

    norm = None
    if symlognorm:
        # Reduce to plain Python scalars over the *whole* grid. ``pd_df.min()``
        # would be per-column for a DataFrame, and NaN-poisoned for tensors;
        # nanmin/nanmax ignore NaN cells the way seaborn's own defaults do.
        values = np.asarray(pd_df, dtype=float)
        norm = SymLogNorm(
            linthresh=0.03,
            linscale=0.03,
            vmin=float(np.nanmin(values)),
            vmax=float(np.nanmax(values)),
            base=10,
        )

    sns.heatmap(
        pd_df,
        annot=False,
        xticklabels=xtick,
        yticklabels=ytick,
        norm=norm,
        # MaxNLocator(3) limits the number of colorbar ticks; labels use '%.e'.
        cbar_kws={'ticks': MaxNLocator(3), 'format': '%.e'},
    )

    plt.xlabel(xlab)
    plt.ylabel(ylab)
    if save:
        plt.savefig(save_path, dpi=300)
    plt.close()

# grid plots

def generic_grid(rows, cols, figsize, title, sharex=True, sharey=True):
    """Create a ``rows`` x ``cols`` grid of bare subplots under a common title.

    Every panel has its axis labels cleared, ticks removed, and the top and
    right spines hidden. The subplots fill the figure horizontally and leave
    the top 10% for the suptitle.

    Args:
        rows: Number of subplot rows.
        cols: Number of subplot columns.
        figsize: Figure size passed to ``plt.subplots``.
        title: Text for ``fig.suptitle``.
        sharex: Forwarded to ``plt.subplots``.
        sharey: Forwarded to ``plt.subplots``.

    Returns:
        ``(fig, axs)`` exactly as returned by ``plt.subplots``.
    """
    fig, axs = plt.subplots(rows, cols, figsize=figsize, sharex=sharex, sharey=sharey)

    # plt.subplots squeezes its result: a single Axes for a 1x1 grid and a 1-D
    # array for a single row/column. Flatten so every shape is handled alike.
    for ax in np.atleast_1d(axs).ravel():
        ax.set_ylabel("")
        ax.set_xlabel("")

        ax.set_xticks([])
        ax.set_yticks([])

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    fig.suptitle(title)

    fig.subplots_adjust(
        left   = 0.0,   # the left side of the subplots of the figure
        right  = 1.0,   # the right side of the subplots of the figure
        bottom = 0.0,   # the bottom of the subplots of the figure
        top    = 0.9,   # the top of the subplots of the figure
        wspace = 0.1,   # the amount of width reserved for blank space between subplots
        hspace = 0.1,   # the amount of height reserved for white space between subplots
    )

    return fig, axs
#%%
# shared styling for the grid plots
mean_color, mean_dots = 'red', 'green'  # fitted-mean line / data points on the mean panels
s_size_ft, s_size_nn = 0.55, 1.0        # scatter marker size: field-theory / NN panels
linewidth = 0.4                         # width of the fitted curves

#%%
# Heat maps: one PDF per field-theory metric, written under base_load_ft/plots
sv = True


def _ft_heatmap(title, values, pdf_name):
    """Render ``values`` as a heat map with the shared beta (x) / alpha (y) labels."""
    make_heatmap(title, torch.Tensor(values), "", "", r"$\beta$", r"$\alpha$",
                 base_load_ft + '/plots/' + pdf_name, save=sv)


_ft_heatmap("Field Theory: MSE", mse_ft, 'mse.pdf')
_ft_heatmap("Field Theory: Weighted MSE", w_mse_ft, 'w_mse.pdf')
_ft_heatmap(r"Field Theory: $\log \Lambda$", log_prec_ft, 'log_lam.pdf')
_ft_heatmap(r"Field Theory: $\log$ Weighted MSE", log_wt_mse_ft, 'log_wt_mse.pdf')
_ft_heatmap(r"Field Theory: $||\nabla \mu||$", mean_mag_ft, 'mean_mag.pdf')
_ft_heatmap(r"Field Theory: $||\nabla \Lambda||$", lam_mag_ft, 'lam_mag.pdf')
_ft_heatmap(r"Field Theory: Wasserstein Distance: SDs to Resids", w_dists_ft, 'w_dists.pdf')
_ft_heatmap(r"Field Theory: MSE(sd, res)", sd_mse_ft, 'sd_mse.pdf')
_ft_heatmap(r"Field Theory: Wasserstein(log SDs +1, Resids+1)", log_w_dists_ft, 'log_w_dists.pdf')
_ft_heatmap("Field Theory: Residuals Histogram Entropy", ent_vals_res, 'res_hist_ent.pdf')
_ft_heatmap("Field Theory: Predicted SDs Histogram Entropy", ent_vals_sds, 'sds_hist_ent.pdf')

# the absolute-deviation grids are detached and moved to CPU before conversion
_ft_heatmap("Field Theory: Mean Abs Deviation", mean_ad_ft.detach().cpu(), 'mean_ad_ft.pdf')
_ft_heatmap("Field Theory: Median Abs Deviation", median_ad_ft.detach().cpu(), 'median_ad_ft.pdf')

_ft_heatmap("Field Theory: Mean Abs Deviation(SD, Res)", sd_mean_ad_ft.detach().cpu(), 'sd_mean_ad_ft.pdf')
_ft_heatmap("Field Theory: Median Abs Deviation(SD, Res)", sd_median_ad_ft.detach().cpu(), 'sd_median_ad_ft.pdf')

#%%
ft_figsize = (4, 3)
size = 4

# one 6x6 grid of panels per diagnostic
fig_res_s, axs_res_s = generic_grid(6, 6, ft_figsize, r"Field Theory: Histograms of Residuals")
fig_sds_s, axs_sds_s = generic_grid(6, 6, ft_figsize, r"Field Theory: Histograms of $(\hat \Lambda^*)^{-1/2}$")
fig_res_sds_s, axs_res_sds_s = generic_grid(6, 6, ft_figsize, r"Field Theory: SDs over Residuals")
fig_mns_s, axs_mns_s = generic_grid(6, 6, ft_figsize, r"Field Theory: $\hat\mu^*$ over Data")


plot_alpha = [round(k, 5) for k in alpha_vals]
plot_beta = [round(k, 5) for k in beta_vals]


# quantities that do not change from panel to panel
_ft_x = xorg_aug.cpu()[1:-1]                 # inputs with the first/last entry dropped
_ft_y = sy.cpu()[1:-1]                       # observations, trimmed the same way
_ft_sd_bins = [k / 3 for k in range(30)]     # histogram edges for the SDs
_ft_res_bins = [k / 2 for k in range(20)]    # histogram edges for |residuals|

# `counter` indexes the flat model_list / resids_list; it steps by 3 per panel
# and by a further 4 + 42 after each row of panels.
counter = 0
# rows: each level of penalty on mean function
for i, _panel_rows in enumerate(zip(axs_res_s, axs_sds_s, axs_res_sds_s, axs_mns_s)):

    # columns: each level of penalty on precision
    for j, (ax_res_s, ax_sds_s, ax_res_sds_s, ax_mn_s) in enumerate(zip(*_panel_rows)):

        # SDs computed as exp(-0.5 * second model entry), then histogrammed
        sds = torch.exp(-.5 * model_list[counter][1].cpu()[1:-1]).detach()
        data, b, c = ax_sds_s.hist(sds.flatten().numpy(), bins=_ft_sd_bins)

        # absolute residuals as points, SDs drawn over them as a curve
        ax_res_sds_s.scatter(_ft_x, resids_list[counter].cpu().abs(), s=s_size_ft, lw=0)
        ax_res_sds_s.plot(_ft_x.flatten(), sds, c='orange', linewidth=linewidth)

        # data as points, first model entry (fitted mean) as a curve
        ax_mn_s.scatter(_ft_x, _ft_y, s=s_size_ft, lw=0, c=mean_dots)
        ax_mn_s.plot(_ft_x.flatten(), model_list[counter][0][1:-1].cpu().detach(), c=mean_color, linewidth=linewidth)

        # histogram of absolute residuals
        data, b, c = ax_res_s.hist(resids_list[counter].abs().cpu().flatten().numpy(), bins=_ft_res_bins)

        counter += 3
        print(counter)

    # 4: leftovers; 42: skip down the rows
    counter += 4 + 42

    print(i)

#%%
# write each field-theory grid figure to base_load_ft/plots
for _fig, _pdf_name in (
    (fig_res_s, 'res_hists-sub.pdf'),
    (fig_sds_s, 'sds_hists-sub.pdf'),
    (fig_mns_s, 'mu_fit-sub.pdf'),
    (fig_res_sds_s, 'res_sds-sub.pdf'),
):
    _fig.savefig(base_load_ft + '/plots/' + _pdf_name, dpi=300)









#%%
'''
size=4
fig_res, axs_res = generic_grid(len(alpha_vals), len(beta_vals),  (len(alpha_vals)*size, len(beta_vals)*size), "Field Theory: Histograms of Residuals")
fig_sds, axs_sds = generic_grid(len(alpha_vals), len(beta_vals), (len(alpha_vals)*size, len(beta_vals)*size), "Field Theory: Histograms of SDs")
fig_res_sds, axs_res_sds = generic_grid(len(alpha_vals), len(beta_vals), (len(alpha_vals)*size, len(beta_vals)*size), "Field Theory: SDs over Residuals")
fig_mns, axs_mns = generic_grid(len(alpha_vals), len(beta_vals), (len(alpha_vals)*size, len(beta_vals)*size), r"Field Theory: $\mu$ over Data")


plot_alpha = [round(k, 5) for k in alpha_vals]
plot_beta = [round(k, 5) for k in beta_vals]


counter = 0
# each level of penalty on mean function
for i in range(len(alpha_vals)):

    # each level of penalty on precision
    for j in range(len(beta_vals)):

        ax_res = axs_res[i][j]
        ax_sds = axs_sds[i][j]
        ax_res_sds = axs_res_sds[i][j]
        ax_mn = axs_mns[i][j]

        #if j == 0:

            #ax_res.set_ylabel(r"$\alpha$: {}".format(plot_alpha[i]))
            #ax_sds.set_ylabel(r"$\alpha$: {}".format(plot_alpha[i]))
            #ax_res_sds.set_ylabel(r"$\alpha$: {}".format(plot_alpha[i]))


        #if i == len(plot_alpha)-1:
            #ax_res.set_xlabel(r"$\beta$: {}".format(plot_beta[j]))
            #ax_sds.set_xlabel(r"$\beta$: {}".format(plot_beta[j]))
            #ax_res_sds.set_xlabel(r"$\beta$: {}".format(plot_beta[j]))

        # compute, plot sds, comp entropy
        sds = torch.exp(-.5 * model_list[counter][1].cpu()[1:-1]).detach()
        data, b, c =ax_sds.hist((sds.flatten().numpy()), bins=[i/3 for i in range(30)])
        
        # plot residuals (as points), pred sds as a function
        ax_res_sds.scatter(xorg_aug.cpu()[1:-1], resids_list[counter].cpu().abs(), marker=",")
        ax_res_sds.plot(xorg_aug.cpu()[1:-1].flatten(), torch.exp(-.5 * model_list[counter][1].cpu()[1:-1]).detach(), c='orange')


        ax_mn.scatter(xorg_aug.cpu()[1:-1], sy.cpu()[1:-1], marker=",")
        ax_mn.plot(xorg_aug.cpu()[1:-1].flatten(), model_list[counter][0][1:-1].cpu().detach(), c='green')

        # plot residuals, comp entropy
        data, b, c =ax_res.hist((resids_list[counter].abs().cpu().flatten().numpy()), bins=[i/2 for i in range(20)])
            
            
        counter += 1


    print(i)

#%%
fig_res.savefig(base_load_ft + '/plots/res_hists-o.pdf', dpi=300)
fig_sds.savefig(base_load_ft + '/plots/sds_hists-o.pdf', dpi=300)
fig_mns.savefig(base_load_ft + '/plots/mu_fit-o.pdf', dpi=300)
fig_res_sds.savefig(base_load_ft + '/plots/res_sds-o.pdf', dpi=300)
'''

# %%

#-------------------------------------------------------------
# load data for 1d simulated data: nn
base_load_uni = cwd + '/final_plots/univariate/thin-sine-models/run-0'


def _load_nn_pickle(rel_path):
    """Return the object unpickled from ``base_load_uni + rel_path``."""
    with open(base_load_uni + rel_path, 'rb') as handle:
        return pickle.load(handle)


mn_priors = _load_nn_pickle('/mn_prior_stds.p')

va_priors = _load_nn_pickle('/va_prior_stds.p')

resids_list_nn = _load_nn_pickle('/resids.p')

fitted_mn_sd_nn = _load_nn_pickle('/fitted_mn_sd.p')

# [img_vals, ivg_vals, res_tr_corr_vals, res_te_corr_vals, mse_tr_vals, mse_te_vals, appr_img_vals, appr_ivg_vals]
unpickled_data_nn = _load_nn_pickle('/data.p')


# values from the entropies (rebinds the names used for the field-theory run)
ent_vals_res = _load_nn_pickle('/data/ent_vals_res.p')
ent_vals_sds = _load_nn_pickle('/data/ent_vals_sds.p')

# metrics
mse_uni = _load_nn_pickle('/data/flip_mse.p')
w_mse_uni = _load_nn_pickle('/data/flip_w_mse.p')
log_sig_uni = _load_nn_pickle('/data/flip_log_sig.p')
log_wt_mse_uni = _load_nn_pickle('/data/flip_log_wt_mse.p')

w_dists_uni = _load_nn_pickle('/data/flip_w_dists.p')
log_w_dists_uni = _load_nn_pickle('/data/flip_log_w_dists.p')
sd_mse_uni = _load_nn_pickle('/data/flip_sd_mse.p')

mean_ad_uni = _load_nn_pickle('/data/flip_mean_abs_dev.p')
median_ad_uni = _load_nn_pickle('/data/flip_median_abs_dev.p')

test_mean_ad_uni = _load_nn_pickle('/data/flip_test_mean_abs_dev.p')
test_median_ad_uni = _load_nn_pickle('/data/flip_test_median_abs_dev.p')

sd_mean_ad_uni = _load_nn_pickle('/data/flip_sd_mean_abs_dev.p')
sd_median_ad_uni = _load_nn_pickle('/data/flip_sd_median_abs_dev.p')

test_sd_mean_ad_uni = _load_nn_pickle('/data/flip_sd_test_mean_abs_dev.p')
test_sd_median_ad_uni = _load_nn_pickle('/data/flip_sd_test_median_abs_dev.p')





#%%
nn_xlabel = r"$\beta$"
nn_ylabel = r"$\alpha$"
sv = True


def _nn_heatmap(title, values, pdf_name):
    """Render ``values`` as a heat map under base_load_uni/plots with the shared NN axis labels."""
    make_heatmap(title, values, "", "", nn_xlabel, nn_ylabel,
                 base_load_uni + '/plots/' + pdf_name, save=sv)


_nn_heatmap("NN: MSE", torch.Tensor(mse_uni), 'mse.pdf')
_nn_heatmap("NN: Weighted MSE", torch.Tensor(w_mse_uni), 'w_mse.pdf')
_nn_heatmap(r"NN: $\log \hat\sigma_\phi$", torch.Tensor(log_sig_uni), 'log_lam.pdf')
_nn_heatmap(r"NN: $\log$ Weighted MSE", torch.Tensor(log_wt_mse_uni), 'log_wt_mse.pdf')

_nn_heatmap(r"NN: Wasserstein Distance: SDs to Resids", torch.Tensor(w_dists_uni), 'w_dists.pdf')
_nn_heatmap(r"NN: MSE(sd, res)", torch.Tensor(sd_mse_uni), 'sd_mse.pdf')
_nn_heatmap(r"NN: Wasserstein(log SDs +1, Resids+1)", torch.Tensor(log_w_dists_uni), 'log_w_dists.pdf')
_nn_heatmap("NN: Residuals Histogram Entropy", torch.Tensor(ent_vals_res), 'res_hist_ent.pdf')
# clipped between the grid minimum and its 0.9 quantile
_nn_heatmap(
    "NN: Predicted SDs Histogram Entropy",
    torch.Tensor(ent_vals_sds).clip(torch.tensor(ent_vals_sds).min(), torch.quantile(torch.tensor(ent_vals_sds), .9)),
    'sds_hist_ent.pdf',
)

_nn_heatmap("NN: Mean Abs Deviation", torch.Tensor(mean_ad_uni.detach().cpu()), 'mean_ad_ft.pdf')
_nn_heatmap("NN: Median Abs Deviation", torch.Tensor(median_ad_uni.detach().cpu()), 'median_ad_ft.pdf')

# the mean-based grids below are clipped to [0, 0.9 quantile]; the median-based ones are not
_nn_heatmap(
    "NN: Mean Abs Deviation(SD, Res)",
    torch.Tensor(sd_mean_ad_uni.detach().cpu()).clip(0, torch.quantile(sd_mean_ad_uni, .9)),
    'sd_mean_ad_ft.pdf',
)
_nn_heatmap("NN: Median Abs Deviation(SD, Res)", torch.Tensor(sd_median_ad_uni.detach().cpu()), 'sd_median_ad_ft.pdf')

_nn_heatmap(
    "NN: Test Mean Abs Deviation",
    torch.Tensor(test_mean_ad_uni.detach().cpu().clip(0, test_mean_ad_uni.quantile(.9))),
    'test_mean_ad_ft.pdf',
)
_nn_heatmap("NN: Test Median Abs Deviation", torch.Tensor(test_median_ad_uni.detach().cpu()), 'test_median_ad_ft.pdf')

_nn_heatmap(
    "NN: Test Mean Abs Deviation(SD, Res)",
    torch.Tensor(test_sd_mean_ad_uni.detach().cpu().clip(0, test_sd_mean_ad_uni.quantile(0.9))),
    'test_sd_mean_ad_ft.pdf',
)
_nn_heatmap("NN: Test Median Abs Deviation(SD, Res)", torch.Tensor(test_sd_median_ad_uni.detach().cpu()), 'test_sd_median_ad_ft.pdf')


sv = True
# Flipped identity matrices: multiplying on the left reverses the row order,
# multiplying on the right reverses the column order.
flipL = torch.eye(unpickled_data_nn[0].size(1)).flip(0)
flipR = torch.eye(unpickled_data_nn[0].size(2)).flip(0)

# average entries 0 and 1 of data.p over their leading dimension, then reverse both axes
flip_mn = flipL @ unpickled_data_nn[0].mean(dim=0) @ flipR
flip_va = flipL @ unpickled_data_nn[1].mean(dim=0) @ flipR

_nn_heatmap(r"NN: $\int ||\nabla \hat\mu_\theta(x)||\, dx$", torch.Tensor(flip_mn), 'mean_heat_mag.pdf')
_nn_heatmap(r"NN: $\int ||\nabla \hat\sigma_\phi(x)||\, dx$", torch.Tensor(flip_va), 'sd_heat_mag.pdf')


#%%
# data that was used to train these models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
extra_ends = 1
N = 1024
xorg_init, xorg_aug, ind, true_mu, mal, sy = prep_double_sine(N, n_mean_cycles=2, n_noise_cycles=3, extra_ends=extra_ends)

# difference between the first two entries of xorg_init
gw = (xorg_init[1]-xorg_init[0])

# fixed seed so the same shuffled train/test split is reproduced on every run
torch.manual_seed(4)
rand_inds = torch.randperm(N)

cut = 25

# the first `cut` shuffled points form the training set,
# the next `cut` shuffled points form the test set
x = torch.tensor(xorg_init[rand_inds[:cut]]).float()
y = torch.tensor(sy[rand_inds[:cut]]).float()

x_t = torch.tensor(xorg_init[rand_inds[cut:2*cut]]).float()
y_t = torch.tensor(sy[rand_inds[cut:2*cut]]).float()

x = x.to(device)
x_t = x_t.to(device)
y = y.to(device)
y_t = y_t.to(device)


def _grid_indices(points, grid, tol=5e-7):
    """Return the positions in ``grid`` that match the values in ``points``.

    ``points`` is sorted ascending and ``grid`` is scanned once from the start;
    each time the current grid entry is within ``tol`` of the next unmatched
    point its index is recorded. The scan stops once every point is matched.
    A single forward pass only finds all points if ``grid`` is ascending too
    -- NOTE(review): assumed for xorg_aug, confirm against prep_double_sine.

    Args:
        points: Tensor of values to look up (any shape; it is flattened).
        grid: Tensor of candidate values, indexed along its first dimension.
        tol: Absolute tolerance for treating two values as equal.

    Returns:
        List of integer indices into ``grid``, in increasing order.
    """
    targets = points.flatten().sort()[0].cpu()
    grid = grid.cpu()
    matches = []
    for idx in range(len(grid)):
        if len(matches) == len(targets):
            break
        if (targets[len(matches)] - grid[idx]).abs() < tol:
            matches.append(idx)
    return matches


train_inds = _grid_indices(x, xorg_aug)
# Search the held-out inputs here: the earlier version searched `x` again, so
# test_inds was just a copy of train_inds.
test_inds = _grid_indices(x_t, xorg_aug)

'''
size=4
fig_res, axs_res = generic_grid(len(mn_priors), len(va_priors), figsize=(len(mn_priors)*size, len(va_priors)*size), title='Synthetic: Residuals')
fig_sds, axs_sds = generic_grid(len(mn_priors), len(va_priors), figsize=(len(mn_priors)*size, len(va_priors)*size), title='Synthetic: Pred SDs over Residuals')
fig_res_sds, axs_res_sds = generic_grid(len(mn_priors), len(va_priors), figsize=(len(mn_priors)*size, len(va_priors)*size), title='Synthetic: SDs')
fig_mn, axs_mn = generic_grid(len(mn_priors), len(va_priors), figsize=(len(mn_priors)*size, len(va_priors)*size), title='Synthetic: Means')


ent_vals_res = np.zeros((len(mn_priors), len(va_priors)))
ent_vals_sds = np.zeros((len(mn_priors), len(va_priors)))



plot_mn = [round(k, 5) for k in mn_priors]
plot_va = [round(k, 5) for k in va_priors]

rev_plot_mn = [i for i in reversed(plot_mn)]
rev_plot_va = [i for i in reversed(plot_va)]
sdm = 100

counter = 0
# each value of prior for the mean network

for i in range(len(mn_priors)-1, -1, -1):
    # each value of prior for the var network

    for j in range(len(va_priors)-1, -1, -1):
        ax_res = axs_res[i][j]
        ax_sds = axs_sds[i][j]
        ax_res_sds = axs_res_sds[i][j]
        ax_mn = axs_mn[i][j]


        
        mns = fitted_mn_sd_nn[counter][0].cpu().detach()
        sds = fitted_mn_sd_nn[counter][1].cpu().detach()    
  
        
        if sdm > sds.min():
            sdm = sds.min()
        
        data, b, c =ax_sds.hist((sds.flatten().numpy()), bins=[i/3 for i in range(20)])

            
        ax_res_sds.scatter(x.cpu(), resids_list_nn[counter].cpu().abs(), s=s_size, lw=0)
        ax_res_sds.plot(xorg_init.sort(0)[0].cpu().flatten(), sds[1:-1].flatten(), c='orange', linewidth=linewidth)
        ax_res_sds.set_ylim(-0.5, 10)
        
        ax_mn.scatter(x.cpu(), y.cpu(), s=s_size, lw=0)
        #ax_mn.plot(xorg_init.sort(0)[0].cpu().flatten(), true_mu[1:-1].flatten(), c='black', linestyle='dashed')
        ax_mn.plot(xorg_init.sort(0)[0].cpu().flatten(), mns[1:-1].flatten(), c=mean_color, linewidth=linewidth)
        ax_mn.set_ylim(-7.5, 7.5)
        

        data, b, c =ax_res.hist((resids_list[counter].cpu().flatten().numpy()), bins=[i/3 for i in range(20)])

            
        counter += 1
        

    print(i, sdm)
    



plt.show()






fig_res.savefig(base_load_uni + '/plots/res_hists-o.pdf', dpi=300)
fig_sds.savefig(base_load_uni + '/plots/sds_hists-o.pdf', dpi=300)
fig_mn.savefig(base_load_uni + '/plots/mu_fit-o.pdf', dpi=300)
fig_res_sds.savefig(base_load_uni + '/plots/res_sds-o.pdf', dpi=300)
'''


# subplot, every other 14x14 --> 7 x7

#%%
#%%
# 6x6 sub-sampled panels of the NN fits
punchin_fs = (4, 3) # (len(mn_priors)*size, len(va_priors)*size)
size = 4
fig_res_s, axs_res_s = generic_grid(6, 6, figsize=punchin_fs, title=r'NN: Hists of Resids')
fig_sds_s, axs_sds_s = generic_grid(6, 6, figsize=punchin_fs, title=r'NN: Hists of $\hat \sigma_\phi$')
fig_res_sds_s, axs_res_sds_s = generic_grid(6, 6, figsize=punchin_fs, title=r'NN: $\hat \sigma_\phi$ over Abs Resids')
fig_mn_s, axs_mn_s = generic_grid(6, 6, figsize=punchin_fs, title=r'NN: $\hat \mu_\theta$')

plot_mn = [round(k, 5) for k in mn_priors]
plot_va = [round(k, 5) for k in va_priors]

rev_plot_mn = list(reversed(plot_mn))
rev_plot_va = list(reversed(plot_va))
sdm = 100  # running minimum of the fitted SDs over the plotted panels

# quantities that do not change from panel to panel
_nn_x_sorted = xorg_init.sort(0)[0].cpu().flatten()
_nn_bins = [k / 3 for k in range(20)]

# NOTE: start counter at 1 since we do negative indexing
# (fitted_mn_sd_nn / resids_list_nn are read from the end: -1, -3, ...)
counter = 1

# rows: each value of prior for the mean network
for i, _panel_rows in enumerate(zip(axs_res_s, axs_sds_s, axs_res_sds_s, axs_mn_s)):

    # columns: each value of prior for the var network
    for j, (ax_res_s, ax_sds_s, ax_res_sds_s, ax_mn_s) in enumerate(zip(*_panel_rows)):

        mns = fitted_mn_sd_nn[-counter][0].cpu().detach()
        sds = fitted_mn_sd_nn[-counter][1].cpu().detach()

        _smallest_sd = sds.min()
        if sdm > _smallest_sd:
            sdm = _smallest_sd

        data, b, c = ax_sds_s.hist(sds.flatten().numpy(), bins=_nn_bins)

        # absolute training residuals as points, fitted SD curve on top
        ax_res_sds_s.scatter(x.cpu(), resids_list_nn[-counter].cpu().abs(), s=s_size_nn, lw=0)
        ax_res_sds_s.plot(_nn_x_sorted, sds[1:-1].flatten(), c='orange', lw=linewidth)
        ax_res_sds_s.set_ylim(-0.5, 10)

        # training data as points, fitted mean curve on top
        ax_mn_s.scatter(x.cpu(), y.cpu(), s=s_size_nn, lw=0, c=mean_dots)
        #ax_mn_s.plot(xorg_init.sort(0)[0].cpu().flatten(), true_mu[1:-1].flatten(), c='black', linestyle='dashed')
        ax_mn_s.plot(_nn_x_sorted, mns[1:-1].flatten(), c=mean_color, lw=linewidth)
        ax_mn_s.set_ylim(-7.5, 7.5)

        # NOTE(review): residuals are binned *signed* on edges starting at 0, so
        # negative values fall outside every bin (the field-theory histogram
        # bins absolute residuals) -- confirm this is intended.
        data, b, c = ax_res_s.hist(resids_list_nn[-counter].cpu().flatten().numpy(), bins=_nn_bins)

        counter += 2

    # extra stride applied after each row of panels
    counter += 12

    print(i, sdm)


plt.show()


for _fig, _pdf_name in (
    (fig_res_s, 'res_hists-sub.pdf'),
    (fig_sds_s, 'sds_hists-sub.pdf'),
    (fig_mn_s, 'mu_fit-sub.pdf'),
    (fig_res_sds_s, 'res_sds-sub.pdf'),
):
    _fig.savefig(base_load_uni + '/plots/' + _pdf_name, dpi=300)





#%%



#--------------------------------------------------------------
# load data for uci regression
base_load_ucif = cwd + '/final_plots/UCI/'

# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# dataset name -> directory holding that dataset's run-0 outputs
uci_long = dict(
    concrete=base_load_ucif + "concrete/run-0",
    power=base_load_ucif + "power/run-0",
    yacht=base_load_ucif + "yacht/run-0",
    housing=base_load_ucif + "housing/run-0",
)

#dataset = "housing"
# datasets that actually get plotted ("yacht" has a path above but is not listed)
plot_list = ["concrete", "power", "housing"]

def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE: pickle executes arbitrary code on load; only point this at result
    files produced by this project.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


def _uci_heatmap(title, values, out_path):
    """Draw one heatmap using the axis labels / save flag shared by all UCI plots."""
    make_heatmap(title, values, "", "", uci_xlabel, uci_ylabel, out_path, save=sv)


# Axis labels and save flag are the same for every dataset, so set them once.
uci_xlabel = r'$\beta$'
uci_ylabel = r'$\alpha$'
sv = True

# NOTE(review): `make_heatmap` and `generic_grid` are expected to be defined
# earlier in this file/session (the import of make_heatmap at the top of the
# file is commented out) -- confirm before running this cell on its own.
for dataset in plot_list:
    uci_load = uci_long[dataset]
    data_dir = uci_load + '/data'
    plot_dir = uci_load + '/plots'

    # ---- load everything that was pickled for this dataset -------------
    mn_priors = _load_pickle(uci_load + '/mn_prior_stds.p')
    va_priors = _load_pickle(uci_load + '/va_prior_stds.p')
    model_list = _load_pickle(uci_load + '/models.p')
    resids_list = _load_pickle(uci_load + '/resids.p')

    # [img_vals, ivg_vals, res_tr_corr_vals, res_te_corr_vals, mse_tr_vals, mse_te_vals, appr_img_vals, appr_ivg_vals]
    unpickled_data = _load_pickle(uci_load + '/data.p')

    uci_x, uci_y, uci_xt, uci_yt = prep_uci_data(dataset, device)

    # values from the entropies
    ent_vals_res = _load_pickle(data_dir + '/ent_vals_res.p')
    ent_vals_sds = _load_pickle(data_dir + '/ent_vals_sds.p')

    # metrics
    mse_uci = _load_pickle(data_dir + '/flip_mse.p')
    w_mse_uci = _load_pickle(data_dir + '/flip_w_mse.p')
    log_sig_uci = _load_pickle(data_dir + '/flip_log_sig.p')
    log_wt_mse_uci = _load_pickle(data_dir + '/flip_log_wt_mse.p')

    w_dists_uci = _load_pickle(data_dir + '/flip_w_dists.p')
    log_w_dists_uci = _load_pickle(data_dir + '/flip_log_w_dists.p')
    sd_mse_uci = _load_pickle(data_dir + '/flip_sd_mse.p')

    mean_ad_uci = _load_pickle(data_dir + '/flip_mean_abs_dev.p')
    median_ad_uci = _load_pickle(data_dir + '/flip_median_abs_dev.p')

    test_mean_ad_uci = _load_pickle(data_dir + '/flip_test_mean_abs_dev.p')
    test_median_ad_uci = _load_pickle(data_dir + '/flip_test_median_abs_dev.p')

    sd_mean_ad_uci = _load_pickle(data_dir + '/flip_sd_mean_abs_dev.p')
    sd_median_ad_uci = _load_pickle(data_dir + '/flip_sd_median_abs_dev.p')

    test_sd_mean_ad_uci = _load_pickle(data_dir + '/flip_test_sd_mean_abs_dev.p')
    test_sd_median_ad_uci = _load_pickle(data_dir + '/flip_test_sd_median_abs_dev.p')

    # ---- heatmaps of the metrics ----------------------------------------
    _uci_heatmap(dataset + ": MSE", torch.Tensor(mse_uci), plot_dir + '/mse.pdf')
    _uci_heatmap(dataset + ": Weighted MSE", torch.Tensor(w_mse_uci), plot_dir + '/w_mse.pdf')
    _uci_heatmap(dataset + r": avg $\log \hat \sigma_\phi$", torch.Tensor(log_sig_uci), plot_dir + '/log_sig.pdf')
    _uci_heatmap(dataset + r": avg $\log$ Weighted MSE", torch.Tensor(log_wt_mse_uci.detach()), plot_dir + '/log_wt_mse.pdf')
    # The next two are clipped at their own 90% quantile before plotting.
    _uci_heatmap(
        dataset + r": W(res, sds)",
        torch.Tensor(w_dists_uci.clip(0, torch.quantile(w_dists_uci, .9))),
        plot_dir + '/w_dists.pdf',
    )
    _uci_heatmap(
        dataset + r": MSE(res, sds)",
        torch.Tensor(sd_mse_uci.detach().clip(0, torch.quantile(sd_mse_uci, .9))),
        plot_dir + '/sd_mse.pdf',
    )
    _uci_heatmap(dataset + r": W(log(res+1), log(sds+1))", log_w_dists_uci, plot_dir + '/log_w_dists.pdf')

    _uci_heatmap(
        dataset + ": Mean Abs Deviation",
        torch.Tensor(mean_ad_uci.detach().cpu()).clip(0, mean_ad_uci.quantile(0.9)),
        plot_dir + '/mean_ad_ft.pdf',
    )
    _uci_heatmap(dataset + ": Median Abs Deviation", torch.Tensor(median_ad_uci.detach().cpu()), plot_dir + '/median_ad_ft.pdf')

    # NOTE(review): this clips the SD metric at the 90% quantile of
    # `mean_ad_uci`, not of `sd_mean_ad_uci` -- possibly a copy-paste slip;
    # left unchanged, confirm which ceiling is intended.
    _uci_heatmap(
        dataset + ": Mean Abs Deviation(SD, Res)",
        torch.Tensor(sd_mean_ad_uci.detach().cpu()).clip(0, mean_ad_uci.quantile(0.9)),
        plot_dir + '/sd_mean_ad_ft.pdf',
    )
    _uci_heatmap(dataset + ": Median Abs Deviation(SD, Res)", torch.Tensor(sd_median_ad_uci.detach().cpu()), plot_dir + '/sd_median_ad_ft.pdf')

    _uci_heatmap(dataset + ": Test Mean Abs Deviation", torch.Tensor(test_mean_ad_uci.detach().cpu()), plot_dir + '/test_mean_ad_ft.pdf')
    _uci_heatmap(dataset + ": Test Median Abs Deviation", torch.Tensor(test_median_ad_uci.detach().cpu()), plot_dir + '/test_median_ad_ft.pdf')

    _uci_heatmap(dataset + ": Test Mean Abs Deviation(SD, Res)", torch.Tensor(test_sd_mean_ad_uci.detach().cpu()), plot_dir + '/test_sd_mean_ad_ft.pdf')
    _uci_heatmap(dataset + ": Test Median Abs Deviation(SD, Res)", torch.Tensor(test_sd_median_ad_uci.detach().cpu()), plot_dir + '/test_sd_median_ad_ft.pdf')

    # Reversed identity matrices: left-multiplying reverses the row order,
    # right-multiplying reverses the column order of the averaged grids.
    flipL = torch.eye(unpickled_data[0].size()[1]).flip(0)
    flipR = torch.eye(unpickled_data[0].size()[2]).flip(0)

    flip_mn = flipL @ unpickled_data[0].mean(dim=0) @ flipR
    flip_va = flipL @ unpickled_data[1].mean(dim=0) @ flipR

    _uci_heatmap(dataset + r": $\int ||\nabla \hat \mu_\theta(x)||_2^2 \, dx$", torch.Tensor(flip_mn), plot_dir + '/flip_mean_heat_mag.pdf')
    _uci_heatmap(dataset + r": $\int ||\nabla \hat \sigma_\phi(x)||_2^2\, dx$", torch.Tensor(flip_va), plot_dir + '/flip_sd_heat_mag.pdf')

    _uci_heatmap(dataset + ": Res Histogram Entropy", torch.Tensor(ent_vals_res), plot_dir + '/flip_res_hist_ent.pdf')
    _uci_heatmap(dataset + ": SDs Histogram Entropy", torch.Tensor(ent_vals_sds), plot_dir + '/flip_sds_hist_ent.pdf')

    # ---- histograms over the full (mean prior x variance prior) grid ----
    size = 4
    n_mn, n_va = len(mn_priors), len(va_priors)
    fig_res, axs_res = generic_grid(n_mn, n_va, figsize=(n_mn*size, n_va*size), title='UCI ' + dataset + ': Residuals')
    fig_sds, axs_sds = generic_grid(n_mn, n_va, figsize=(n_mn*size, n_va*size), title='UCI ' + dataset + ': SDs')

    # Fixed bin edges (20 each) shared by every panel; computed once
    # because they do not depend on the model being plotted.
    max_sd = 1.5
    divisor = 20//max_sd
    sd_bins = [k/divisor for k in range(20)]

    max_res = 2.
    divisor_res = 20//max_res
    res_bins = [k/divisor_res for k in range(20)]

    # Models are consumed in stored order while the grid is filled from the
    # bottom-right panel towards the top-left one.
    counter = 0
    for i in range(n_mn - 1, -1, -1):       # each value of prior for the mean network
        for j in range(n_va - 1, -1, -1):   # each value of prior for the var network
            sds = model_list[counter](uci_x)[1].cpu()[1:-1].detach()
            axs_sds[i][j].hist(sds.flatten().numpy(), bins=sd_bins)
            axs_res[i][j].hist(resids_list[counter].cpu().flatten().numpy(), bins=res_bins)
            counter += 1

        print(i)

    plt.show()

    fig_res.savefig(plot_dir + '/res_hists-o.pdf', dpi=300)
    fig_sds.savefig(plot_dir + '/sds_hists-o.pdf', dpi=300)

    # ---- 6x6 sub-sampled version of the same histograms -----------------
    punchin_uci = (4, 3)  # (len(mn_priors)*size, len(va_priors)*size)
    fig_res_s, axs_res_s = generic_grid(6, 6, figsize=punchin_uci, title=dataset + r': Hists of Resids')
    fig_sds_s, axs_sds_s = generic_grid(6, 6, figsize=punchin_uci, title=dataset + r': Hists of $\hat \sigma_\phi$')

    # NOTE: start counter at 1 since we do negative indexing.
    # Stepping by 2 inside a row and by a further 12 after each row walks
    # backwards through model_list, picking every other entry of every other
    # block of 12.
    counter = 1
    for i in range(6):
        for j in range(6):
            sds = model_list[-counter](uci_x)[1].cpu()[1:-1].detach()
            axs_sds_s[i][j].hist(sds.flatten().numpy(), bins=sd_bins)
            axs_res_s[i][j].hist(resids_list[-counter].cpu().flatten().numpy(), bins=res_bins)
            counter += 2
        counter += 12

        print(i)

    plt.show()

    # Only the two figures created in this loop are saved.  The previous
    # version also saved `fig_mn_s` / `fig_res_sds_s`, which belong to the
    # earlier univariate section and are never rebuilt here, so stale
    # univariate plots were written into every UCI plot directory.
    fig_res_s.savefig(plot_dir + '/res_hists-sub.pdf', dpi=300)
    fig_sds_s.savefig(plot_dir + '/sds_hists-sub.pdf', dpi=300)








# %%

# Scratch check of how the flip matrices and negative indexing reorder a
# small 4x3 grid of running indices.
zgrid = torch.zeros((4, 3))
zgrid_rot = torch.zeros((4, 3))
flipL = torch.eye(4).flip(0)
flipR = torch.eye(3).flip(0)

primes = torch.tensor([2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37])
inds12 = torch.arange(12)
print(zgrid)

# Fill the grid row by row with 0..11.
zgrid.copy_(inds12.reshape(zgrid.shape))
print(zgrid)

# Reverse both the row and the column order via the flip matrices.
rev_inds12 = (flipL @ zgrid @ flipR).flatten()
zgrid_rot.copy_(rev_inds12.reshape(zgrid_rot.shape))

print(zgrid_rot)

# Same fill as above, but reading the indices from the back (11..0).
zgrid.copy_(inds12.flip(0).reshape(zgrid.shape))
print(zgrid)
# %%
pseudo196 = torch.arange(196)
zgrid49 = torch.zeros(7, 7)

# View the 196 running indices as a 14x14 grid and keep every second row and
# every second column, i.e. entry (i, j) is 28*i + 2*j.
zgrid49.copy_(pseudo196.reshape(14, 14)[::2, ::2])
# %%
print(zgrid49)
# %%
zgrid462 = torch.zeros(7, 7)
inds462 = torch.arange(462)

# Entry (i, j) reads flat position 63*i + 3*j: a step of 3 inside a row and
# 63 (= 7*3 + 42) from the start of one row to the start of the next.
row_starts = 63 * torch.arange(7).unsqueeze(1)
col_steps = 3 * torch.arange(7)
zgrid462.copy_(inds462[row_starts + col_steps])
# %%
mse_ft
# %%
# Sub-sample the flattened MSE values into a 7x7 grid: entry (row, col) is
# taken from flat position 63*row + 3*col (step 3 inside a row, 63 between
# the starts of consecutive rows).
sub_mse_ft = torch.zeros(7, 7)

mse_ft_flat = mse_ft.flatten()

for row in range(7):
    for col in range(7):
        sub_mse_ft[row, col] = mse_ft_flat[63 * row + 3 * col]

# %%
