import importlib
import sys

import jax.numpy as jnp
import numpy as np
import scipy
import scipy.special
from jax import random, vmap

# Miscoverage level: interval scores below use the alpha/2 and 1 - alpha/2
# quantiles of the predictive draws.
alpha = 0.1

# Model class names looked up in `jax_posteriordb.model`.
models = [
    'Election',
    'GLMM',
    'Earnings',
    'KidScore',
    'NES',
    'Radon',
    'Wells',
]

# Models routed to the CRPS / interval-score branch; every other model is
# scored with `process_quad` on its logits instead.
con_models = [
    'GLMM',
    'Earnings',
    'KidScore',
    'NES',
    'Radon',
]

# Seed suffixes of the result files that are aggregated per configuration.
seeds = list(range(5))

# Values substituted for the `{l}` field of the result-file names
# (presumably regularisation strengths - confirm against the training script).
lambs = [0.0, 0.1, 0.01, 1.0]

# Variational family class names looked up in the `posterior` module.
posterior = ['Basic', 'BasicFullRank']

# Values substituted for the `{r}` field of the result-file names.
regularizer = ['VIBasic', 'KLPrior']

# Number of posterior draws requested from `pos.sample` for every evaluation.
prediction_sample = 1_000

def process_quad(logits, y):
    """Return the mean quadratic score of the averaged Bernoulli predictive.

    The logits are mapped to probabilities with the logistic function and
    averaged over axis 0 (callers in this file pass one row per posterior
    draw), giving a predictive probability ``f`` per data point.  Each data
    point is then scored with the quadratic (Brier-type) rule

        S(f, y) = 2 * p(y) - f**2 - (1 - f)**2,   p(y) = f*y + (1 - f)*(1 - y)

    and the scores are averaged.  Higher is better; a perfect forecast
    scores 1 for either class.

    The previous implementation had the ``f**2`` term inside the factor of
    two, i.e. it subtracted ``2 * f**2``.  That made the score asymmetric
    between the two classes and no longer a proper scoring rule.

    Args:
        logits: array whose leading axis is averaged out after the sigmoid.
        y: observed outcomes, broadcastable against the averaged
            probabilities; expected to hold 0/1 labels.

    Returns:
        A scalar JAX array with the mean quadratic score.
    """
    fs = scipy.special.expit(logits)
    ave_fs = np.mean(fs, axis=0)
    # Probability the averaged predictive assigns to the observed class.
    prob_observed = ave_fs * y + (1 - ave_fs) * (1 - y)
    # Squared norm of the predictive pmf (f, 1 - f).
    squared_norm = jnp.square(ave_fs) + jnp.square(1 - ave_fs)
    return jnp.mean(2 * prob_observed - squared_norm)


def _run_path(m, p, method, r, lamb, seed):
    """Return the result-file path of one training run."""
    return f'pdb_result/{m}_0.0_100_{p}_{method}/{r}_{lamb}_{seed}_100'


def _load_theta_sample(pos, path):
    """Read the first line of `path` and draw `prediction_sample` posterior samples.

    The last whitespace-separated token of the line is discarded before the
    remaining tokens are parsed as floats (presumably a trailing
    non-parameter value - confirm against the code that writes these files).
    """
    with open(path, 'r') as f:
        tokens = f.readline().split()[:-1]
    params = jnp.array([float(tok) for tok in tokens])
    return pos.sample(random.PRNGKey(0), params, prediction_sample)


def _predictive_ll(log_likelihood_fn, theta_sample):
    """Return the summed log of the draw-averaged likelihood (log-mean-exp over axis 0)."""
    log_likelihoods = vmap(log_likelihood_fn)(theta_sample)
    return np.sum(
        np.array(scipy.special.logsumexp(log_likelihoods, axis=0)) - np.log(prediction_sample))


def _sample_responses(sample_fn, theta_sample):
    """Draw one predictive response set per posterior sample with fixed PRNG keys."""
    keys = random.split(random.PRNGKey(1), prediction_sample)
    ys = vmap(sample_fn)(keys, theta_sample)
    if ys.shape[1] == 1:
        # Drop a singleton axis right after the draw axis.
        ys = ys[:, 0, ...]
    return ys


def _as_column(y):
    """Give 1-D observations a trailing axis so they broadcast against the draws."""
    if len(y.shape) == 1:
        y = jnp.expand_dims(y, 1)
    return y


def _crps(ys, y):
    """Return the summed sample-based CRPS, using two halves of the draws."""
    half = prediction_sample // 2
    y1 = jnp.swapaxes(ys[:half], 0, 1)
    y2 = jnp.swapaxes(ys[half:], 0, 1)
    return -jnp.sum(
        -jnp.mean(jnp.abs(y - y1), axis=1) / 2
        - jnp.mean(jnp.abs(y - y2), axis=1) / 2
        + jnp.mean(jnp.abs(y1 - y2), axis=1) / 2)


def _interval_score(ys, y):
    """Return the summed interval score of the central (1 - alpha) predictive interval.

    `keepdims=True` keeps the quantiles aligned with the column-shaped
    observations.  Without it the `(N,)` quantiles broadcast against the
    `(N, 1)` observations into an `(N, N)` matrix, so every observation was
    scored against every other observation's interval.
    """
    y0 = jnp.swapaxes(ys, 0, 1)
    lower = jnp.quantile(y0, alpha / 2, axis=1, keepdims=True)
    upper = jnp.quantile(y0, 1 - alpha / 2, axis=1, keepdims=True)
    return jnp.sum(
        (upper - lower)
        + 2 / alpha * (lower - y) * (y < lower)
        + 2 / alpha * (y - upper) * (y > upper))


def _valid_ll(mod, theta_sample):
    """Validation predictive log-likelihood (higher is better)."""
    return _predictive_ll(mod.valid_log_likelihoods, theta_sample)


def _valid_crps(mod, theta_sample):
    """Validation CRPS (lower is better)."""
    ys = _sample_responses(mod.sample_valid_datapoint, theta_sample)
    return _crps(ys, _as_column(mod.valid_data()))


def _valid_interval_score(mod, theta_sample):
    """Validation interval score (lower is better)."""
    ys = _sample_responses(mod.sample_valid_datapoint, theta_sample)
    return _interval_score(ys, _as_column(mod.valid_data()))


def _valid_quad(mod, theta_sample):
    """Validation quadratic score (higher is better)."""
    return process_quad(vmap(mod.valid_logits)(theta_sample), mod.valid_data())


def _empty_scores():
    """Return a fresh container for per-seed scores of one configuration."""
    return {'valid': [], 'll': [], 'resp': [], 'is': []}


def _score_fit(mod, pos, path, valid_metric, response, scores):
    """Score the run stored at `path`, appending each metric as soon as it is computed.

    Appending incrementally means metrics computed before a failure are kept.
    """
    theta_sample = _load_theta_sample(pos, path)
    if valid_metric is not None:
        scores['valid'].append(valid_metric(mod, theta_sample))
    scores['ll'].append(_predictive_ll(mod.test_log_likelihoods, theta_sample))
    if response == 'continuous':
        # One set of draws serves both metrics: the PRNG keys are fixed, so
        # re-sampling would reproduce exactly the same values.
        ys = _sample_responses(mod.sample_test_datapoint, theta_sample)
        y = _as_column(mod.test_data())
        scores['resp'].append(_crps(ys, y))
        scores['is'].append(_interval_score(ys, y))
    elif response == 'binary':
        logits = vmap(mod.test_logits)(theta_sample)
        scores['resp'].append(process_quad(logits, mod.test_data()))


def _evaluate_runs(m, p, mod, pos, method, r, lamb, valid_metric, response, required=False):
    """Score every seed of one (method, regularizer, lambda) configuration.

    Args:
        valid_metric: callable ``(mod, theta_sample) -> score`` used for model
            selection, or None to skip the validation metric.
        response: 'continuous' (CRPS + interval score), 'binary' (quadratic
            score) or None (predictive log-likelihood only).
        required: when True any failure propagates; otherwise a missing result
            file is skipped silently and any other error is reported on stderr
            and skipped.

    Returns:
        dict with lists 'valid', 'll', 'resp' and 'is' (one entry per scored seed).
    """
    scores = _empty_scores()
    for seed in seeds:
        path = _run_path(m, p, method, r, lamb, seed)
        if required:
            _score_fit(mod, pos, path, valid_metric, response, scores)
            continue
        try:
            _score_fit(mod, pos, path, valid_metric, response, scores)
        except FileNotFoundError:
            # Runs that were never launched or did not finish are expected.
            continue
        except Exception as err:
            # Stay best-effort, but do not hide real failures.
            print(f'warning: skipping {path}: {err!r}', file=sys.stderr)
    return scores


def _mean_score(values):
    """Return the mean of `values`, or NaN for an empty list."""
    return np.mean(values) if len(values) else np.nan


def _summary(values, scale=1):
    """Format `values` as ('mean', '(std)') strings with two decimals, after dividing by `scale`."""
    if len(values) == 0:
        return 'nan', '(nan)'
    return '%.2f' % (np.mean(values) / scale), '(%.2f)' % (np.std(values) / scale)


def _print_row(m, p, label, scores):
    """Print one result row: test log-likelihood, response metric, interval score."""
    # NOTE(review): the /1000 display scale of the interval score predates the
    # broadcasting fix in `_interval_score`; revisit whether it is still wanted.
    print(m, p, label, *_summary(scores['ll']), *_summary(scores['resp']),
          *_summary(scores['is'], 1000))


def _report(m, p, mod, pos):
    """Print the VI baseline and the validation-selected PVI rows for one model/posterior pair."""
    continuous = m in con_models
    response = 'continuous' if continuous else 'binary'

    # Baseline VI runs are mandatory; GLMM reports the log-likelihood only.
    baseline = _evaluate_runs(
        m, p, mod, pos, 'VIBasic', 'VIBasic', 0.0,
        valid_metric=None,
        response=None if m == 'GLMM' else response,
        required=True)
    _print_row(m, p, 'VI', baseline)

    best_log, picked_log = -np.inf, _empty_scores()
    # CRPS is minimised, whereas the quadratic score is maximised.  Previously
    # both started from +1e9, so no quadratic-score run could ever be selected.
    best_resp, picked_resp = (np.inf if continuous else -np.inf), _empty_scores()
    best_is, picked_is = np.inf, _empty_scores()

    for r in regularizer:
        for lamb in lambs:
            runs = _evaluate_runs(m, p, mod, pos, 'PACMVIBasic', r, lamb, _valid_ll, response)
            score = _mean_score(runs['valid'])
            if not np.isnan(score) and score > best_log:
                best_log, picked_log = score, runs

            if m == 'GLMM':
                # No response-metric-trained runs are evaluated for GLMM.
                continue

            if continuous:
                runs = _evaluate_runs(m, p, mod, pos, 'PVICRPS', r, lamb, _valid_crps, response)
                score = _mean_score(runs['valid'])
                if not np.isnan(score) and score < best_resp:
                    best_resp, picked_resp = score, runs

                runs = _evaluate_runs(
                    m, p, mod, pos, 'IntervalScore', r, lamb, _valid_interval_score, response)
                score = _mean_score(runs['valid'])
                if not np.isnan(score) and score < best_is:
                    best_is, picked_is = score, runs
            else:
                runs = _evaluate_runs(
                    m, p, mod, pos, 'QuadraticBernoulli', r, lamb, _valid_quad, response)
                score = _mean_score(runs['valid'])
                if not np.isnan(score) and score > best_resp:
                    best_resp, picked_resp = score, runs

    _print_row(m, p, 'PVI-log', picked_log)
    _print_row(m, p, 'PVI-resp', picked_resp)
    _print_row(m, p, 'PVI-resp2', picked_is)


for m in models:
    module1 = importlib.import_module('jax_posteriordb.model')
    mod = getattr(module1, m)()
    for p in posterior:
        module2 = importlib.import_module('posterior')
        pos = getattr(module2, p)(mod.n)
        _report(m, p, mod, pos)


