from svbmc import SVBMC # https://github.com/acerbilab/S-VBMC/tree/main
from svbmc.utils import overlay_corner_plot
from pyvbmc.variational_posterior import VariationalPosterior
import numpy as np
from pathlib import Path
import cloudpickle as pickle


def save_posterior(obj, filepath):
    """Serialize ``obj`` with the module's ``pickle`` (cloudpickle) to ``<filepath>.pkl``.

    Parameters
    ----------
    obj : object
        Any object cloudpickle can serialize (e.g. an optimized SVBMC instance).
    filepath : str or path-like
        Destination path *without* extension; ``.pkl`` is appended. An existing
        file at that location is overwritten.
    """
    destination = f"{filepath}.pkl"
    with open(destination, mode="wb") as handle:
        pickle.dump(obj, handle)

def load_posterior(filepath):
    """Read back an object written by ``save_posterior`` from ``<filepath>.pkl``.

    Parameters
    ----------
    filepath : str or path-like
        Source path *without* extension; ``.pkl`` is appended.

    Returns
    -------
    object
        The unpickled object.

    Notes
    -----
    Unpickling can execute arbitrary code; only load files you produced yourself.
    """
    source = f"{filepath}.pkl"
    with open(source, mode="rb") as handle:
        restored = pickle.load(handle)
    return restored

def get_vps(sbj, RHO_A, split, n_runs=200):
    """Load every available per-run VariationalPosterior for one subject.

    Probes ``results_truncated_400_split_{split}/<rho folder>/sbj_{NN}/run_{NN}/vp.npz``
    for run indices ``0 .. n_runs - 1`` and loads each file that exists.

    Parameters
    ----------
    sbj : int
        Subject index; truncated to int and zero-padded to two digits in the path.
    RHO_A : float
        ``1`` selects the ``rho_1.0`` folder; any other value selects
        ``rho_1.3333333333``.
    split : int or str
        Split identifier inserted into the results directory name.
    n_runs : int, optional
        Number of run indices to probe (default 200).

    Returns
    -------
    list of VariationalPosterior
        Successfully loaded posteriors, in run-index order. Run indices with no
        ``vp.npz`` file are skipped silently. A file that exists but fails to
        load is skipped with a printed notice instead of being hidden.
    """
    folder = "rho_1.0" if RHO_A == 1 else "rho_1.3333333333"
    sbj_dir = Path(f"results_truncated_400_split_{split}") / folder / f"sbj_{int(sbj):02d}"

    vp_list = []
    for run in range(n_runs):
        vp_path = sbj_dir / f"run_{run:02d}" / "vp.npz"
        if not vp_path.is_file():
            # No result for this run index; not an error.
            continue
        try:
            vp_list.append(VariationalPosterior.load(str(vp_path)))
        except Exception as err:
            # Stay best-effort for the batch, but do not hide corrupt/unreadable
            # files, and let KeyboardInterrupt/SystemExit propagate.
            print(f"Skipping {vp_path}: could not load ({err!r})")

    return vp_list


def run_svbmc(s, rho, split, n_stack=20, max_J_sjk=5, elbo_tol=5):
    """Stack the good per-run posteriors of one subject with S-VBMC and save it.

    Runs are loaded with ``get_vps`` and filtered in two stages:

    1. keep runs with ``stats['stable']`` truthy and ``max(stats['J_sjk'])``
       strictly below ``max_J_sjk``;
    2. of those, keep runs whose ``stats['elbo']`` lies strictly within
       ``elbo_tol`` of the median ELBO of the stage-1 survivors.

    If at least ``n_stack`` runs survive, the first ``n_stack`` of them are
    combined with ``SVBMC``, optimized, and saved via ``save_posterior`` to
    ``results_truncated_400_split_{split}/rho_{r}/sbj_{NN}/stacked_vp.pkl``.
    Otherwise a message is printed and nothing is written.

    Parameters
    ----------
    s : int
        Subject index (zero-padded to two digits in the output path).
    rho : float
        ``1`` maps to ``rho_1.0``; any other value maps to ``rho_1.3333333333``.
    split : int or str
        Split identifier used in the results directory name.
    n_stack : int, optional
        Number of filtered runs to stack (default 20).
    max_J_sjk : float, optional
        Exclusive upper bound on ``max(stats['J_sjk'])`` (default 5).
    elbo_tol : float, optional
        Exclusive tolerance around the median ELBO (default 5).

    Returns
    -------
    SVBMC or None
        The optimized stacked posterior, or None if too few runs passed.
    """
    vp_list = get_vps(s, rho, split)

    # Stage 1: discard runs flagged unstable or with a large J_sjk.
    stable = [
        vp for vp in vp_list
        if vp.stats['stable'] and np.max(vp.stats['J_sjk']) < max_J_sjk
    ]

    # Stage 2: discard ELBO outliers. Guard the empty case explicitly:
    # np.median([]) emits a RuntimeWarning and returns nan.
    if stable:
        elbo_median = np.median([vp.stats['elbo'] for vp in stable])
        good = [vp for vp in stable if np.abs(vp.stats['elbo'] - elbo_median) < elbo_tol]
    else:
        good = []

    if len(good) < n_stack:
        print(f"NOT ENOUGH GOOD RUNS FOR SUBJECT {s} WITH RHO={rho} on split {split}")
        return None

    stacked_vp = SVBMC(good[:n_stack])
    stacked_vp.optimize()

    r = "1.0" if rho == 1 else "1.3333333333"
    filepath = f"results_truncated_400_split_{split}/rho_{r}/sbj_{int(s):02d}/stacked_vp"
    save_posterior(stacked_vp, filepath)
    return stacked_vp

