#!/usr/bin/env python3
"""
Plot and Table Generation for Bayesian NN Results

This script loads the aggregated MAT files for each method (MAP, DE, HMC, SMC) and generates:
1. Per-method OOD entropy bar plots (saved as PNG)
2. All-domain group comparison plots (saved as PNG) + LaTeX table (printed to the terminal)
3. In-domain correct/incorrect breakdown plot (saved as PNG) + LaTeX table (printed to the terminal)
4. Per-method full predictive entropy plots: per class, per OOD set and all ID (saved as PNG)
5. Rearranged correct/incorrect comparison plots (saved as PNG)
6. Group-method comparison plots, with and without the ID correct/incorrect split (saved as PNG)
"""
import os, re
import numpy as np
import scipy.io as sio
from scipy.io.matlab import mat_struct
import matplotlib.pyplot as plt
import glob

# Selects which aggregated HMC/SMC results file load_metrics() reads:
# 'phmc_aggregated_results_P{P}.mat' and 'psmc_aggregated_results_P{P}.mat'.
P = 1

# --------------------------
# Helpers to convert MATLAB structs
# --------------------------
def _todict(matobj):
    """Turn a loadmat ``mat_struct`` into a plain dict.

    Keys follow ``matobj._fieldnames`` in order; each field value is passed
    through ``_check`` so nested structs and cell arrays are converted too.
    """
    return {name: _check(getattr(matobj, name)) for name in matobj._fieldnames}


def _check(e):
    """Recursively convert ``loadmat`` output into plain Python/NumPy values.

    ``mat_struct`` objects become dicts (via ``_todict``). Object arrays (MATLAB
    cell arrays) have every element converted and are re-packed into an array.
    Anything else is returned unchanged.
    """
    if isinstance(e, mat_struct):
        return _todict(e)
    if isinstance(e, np.ndarray) and e.dtype == object:
        if e.ndim == 0:
            # A 0-d object array cannot be iterated; unwrap its single element.
            return _check(e.item())
        items = [_check(x) for x in e]
        try:
            return np.array(items)
        except ValueError:
            # Ragged cell contents (e.g. vectors of different lengths): NumPy
            # >= 1.24 refuses to build an inhomogeneous array implicitly, so
            # keep the converted elements in a 1-D object array instead.
            out = np.empty(len(items), dtype=object)
            for i, item in enumerate(items):
                out[i] = item
            return out
    return e
# --------------------------
# Load metrics for each method
# --------------------------
def load_metrics(method):
    """Load one method's metrics from its MAT file into a dict.

    MAP and DE files are returned flat: every non-dunder variable in the MAT
    file, converted with ``_check``. HMC and SMC files (selected by the module
    constant ``P``) are reshaped into the nested layout the plotting code reads:
    ``total_entropy_inID``/``epistemic_inID``, an ``ID`` dict with the
    correct/incorrect breakdown, an ``OOD`` dict keyed by OOD set, and the
    ``per_digit_*`` per-class arrays. Missing HMC/SMC fields default to NaN.

    Args:
        method: One of 'MAP', 'DE', 'HMC', 'SMC'.

    Raises:
        KeyError: if *method* is not one of the four names above.
        FileNotFoundError: if the HMC/SMC results file is not present (for
            MAP/DE a missing file surfaces from ``sio.loadmat``).
    """
    if method in ('MAP', 'DE'):
        flat_files = {'MAP': 'imdb_map_metrics.mat', 'DE': 'imdb_de_metrics.mat'}
        raw = sio.loadmat(flat_files[method], squeeze_me=True, struct_as_record=False)
        return {name: _check(val) for name, val in raw.items() if not name.startswith('__')}

    # HMC and SMC share one layout; only the file prefix differs.
    prefix = {'HMC': 'phmc', 'SMC': 'psmc'}[method]
    pattern = f'{prefix}_aggregated_results_P{P}.mat'
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} found.")
    raw = sio.loadmat(matches[0], squeeze_me=True, struct_as_record=False)
    flat = {name: _check(val) for name, val in raw.items() if not name.startswith('__')}

    ood_sets = ('meta', 'full_meta', 'reviews', 'full_reviews', 'lipsum')
    return {
        # in-domain totals
        'total_entropy_inID': flat.get('mean_tot_ent', np.nan),
        'epistemic_inID': flat.get('mean_epi', np.nan),
        # correct/incorrect breakdown
        'ID': {
            'mean_total_entropy_correct': flat.get('mean_tot_corr', np.nan),
            'mean_total_entropy_incorrect': flat.get('mean_tot_inc', np.nan),
            'mean_epistemic_correct': flat.get('mean_epi_corr', np.nan),
            'mean_epistemic_incorrect': flat.get('mean_epi_inc', np.nan),
        },
        # OOD sets
        'OOD': {
            name: {
                'mean_total_entropy': flat.get(f'ood_{name}_mean_tot', np.nan),
                'mean_epistemic': flat.get(f'ood_{name}_mean_epi', np.nan),
            }
            for name in ood_sets
        },
        # per-class arrays (2 classes)
        'per_digit_total_entropy': flat.get('mean_class_tot_ent', np.full(2, np.nan)),
        'per_digit_epistemic_entropy': flat.get('mean_class_epi_ent', np.full(2, np.nan)),
    }

# --------------------------
# Extract OOD + All-ID summary
# --------------------------
def get_ood_values(m, d):
    """Summarise one method's OOD and overall in-domain entropies.

    The five OOD sets are always taken in the order meta, full_meta, reviews,
    full_reviews, lipsum.

    Args:
        m: Method name ('MAP', 'DE', 'HMC' or 'SMC').
        d: Metrics dict returned by ``load_metrics(m)``.

    Returns:
        For MAP, a flat 6-element list: the five OOD total entropies followed
        by the mean in-domain ("All ID") total entropy.
        For DE/HMC/SMC, a tuple ``(tot, ale, epi, aT, aA, aE)``: three
        5-element lists for the OOD sets (total, aleatoric = total - epistemic,
        epistemic), then the in-domain total, aleatoric and epistemic values.
        For any other name, ``([], [], [], nan, nan, nan)``.

    Raises:
        KeyError: for DE, if a ``replicate_*_od_*`` array is missing; for
            HMC/SMC, if the nested ``OOD`` dict or one of its entries is missing.
    """
    ood_sets = ('meta', 'full_meta', 'reviews', 'full_reviews', 'lipsum')
    classes = ('0', '1')

    if m == 'MAP':
        ood_totals = [np.mean(d.get('replicate_total_entropy_od_' + s, np.nan))
                      for s in ood_sets]
        class_means = [np.nanmean(d.get('total_entropy_class' + c + '_mean', np.nan))
                       for c in classes]
        return ood_totals + [np.nanmean(class_means)]

    if m == 'DE':
        ood_totals = [np.mean(d['replicate_total_entropy_od_' + s]) for s in ood_sets]
        ood_epis = [np.mean(d['replicate_epistemic_od_' + s]) for s in ood_sets]
        absent = np.array([np.nan])
        in_tot = np.nanmean([np.nanmean(d.get('replicate_total_entropy_in_class' + c, absent))
                             for c in classes])
        in_epi = np.nanmean([np.nanmean(d.get('replicate_epistemic_in_class' + c, absent))
                             for c in classes])
        if np.isnan(in_epi):
            # No per-class epistemic arrays: fall back to the unweighted
            # average of the correct and incorrect epistemic means.
            in_epi = 0.5 * (np.mean(d.get('replicate_correct_epistemic_entropy', np.nan))
                            + np.mean(d.get('replicate_incorrect_epistemic_entropy', np.nan)))
    elif m in ('HMC', 'SMC'):
        nested = d['OOD']
        ood_totals = [nested[s]['mean_total_entropy'] for s in ood_sets]
        ood_epis = [nested[s]['mean_epistemic'] for s in ood_sets]
        in_tot = d.get('total_entropy_inID', np.nan)
        in_epi = d.get('epistemic_inID', np.nan)
    else:
        return [], [], [], np.nan, np.nan, np.nan

    ood_ales = [t - e for t, e in zip(ood_totals, ood_epis)]
    return ood_totals, ood_ales, ood_epis, in_tot, in_tot - in_epi, in_epi

# --------------------------
# Plot 1: OOD Entropy + save PNG
# --------------------------
def plot_ood(m, d, ylim):
    """Bar-plot one method's OOD entropies on a log axis and save a PNG.

    MAP shows only the total entropy. The other methods show total, aleatoric
    and epistemic bars side by side for each OOD set. The figure is written to
    ``imdb_ood_entropy_{m}.png`` and then shown.

    Args:
        m: Method name ('MAP', 'DE', 'HMC' or 'SMC').
        d: Metrics dict returned by ``load_metrics(m)``.
        ylim: (bottom, top) limits for the log-scaled y axis.
    """
    group_names = ['Meta', 'Full Meta', 'Reviews', 'Full reviews', 'Lipsum']
    fig, ax = plt.subplots()
    if m != 'MAP':
        totals, aleatoric, epistemic = get_ood_values(m, d)[:3]
        positions = np.arange(5)
        bar_w = 0.25
        ax.bar(positions - bar_w, totals, bar_w, label='Total')
        ax.bar(positions, aleatoric, bar_w, label='Aleatoric')
        ax.bar(positions + bar_w, epistemic, bar_w, label='Epistemic')
    else:
        # MAP's summary is a flat list; its first five entries are the OOD totals.
        map_summary = get_ood_values(m, d)
        ax.bar(range(5), map_summary[:5], 0.5, label='Total')

    ax.set_xticks(range(5))
    ax.set_xticklabels(group_names)
    ax.set_yscale('log')
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"imdb_ood_entropy_{m}.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 3: Correct vs Incorrect + save PNG (the LaTeX table is printed by main())
# --------------------------
def plot_ci(all_data, ylim):
    """Grouped bars of in-domain entropy split by correct/incorrect predictions.

    For each method, four bars are drawn: total entropy on correct and on
    incorrect predictions, then epistemic entropy on correct and on incorrect
    predictions. HMC/SMC values come from the nested ``ID`` dict. MAP/DE use
    the mean of the first matching ``replicate_*`` key, or NaN if none is
    present. The y axis is log-scaled with limits *ylim*. The figure is saved
    to ``imdb_correct_vs_incorrect.png`` and then shown.

    Args:
        all_data: Mapping method name -> metrics dict from ``load_metrics``.
        ylim: (bottom, top) y-axis limits.
    """
    method_order = ['MAP', 'DE', 'HMC', 'SMC']
    tick_text = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$',
    }
    # (series id, nested ID field for HMC/SMC, flat fallback keys, legend text)
    series_spec = [
        ('tc', 'mean_total_entropy_correct',
         ['replicate_correct_entropy', 'replicate_correct_total_entropy'], 'Total Correct'),
        ('ti', 'mean_total_entropy_incorrect',
         ['replicate_incorrect_entropy', 'replicate_incorrect_total_entropy'], 'Total Incorrect'),
        ('ec', 'mean_epistemic_correct',
         ['replicate_correct_epistemic_entropy'], 'Epistemic Correct'),
        ('ei', 'mean_epistemic_incorrect',
         ['replicate_incorrect_epistemic_entropy'], 'Epistemic Incorrect'),
    ]

    heights = {sid: [] for sid, *_ in series_spec}
    for name in method_order:
        metrics = all_data[name]
        nested = metrics['ID'] if name in ('HMC', 'SMC') and 'ID' in metrics else None
        for sid, field, flat_keys, _ in series_spec:
            if nested is not None:
                heights[sid].append(nested.get(field, np.nan))
                continue
            found = next((k for k in flat_keys if k in metrics), None)
            heights[sid].append(np.nan if found is None else np.mean(metrics[found]))

    centers = np.arange(len(method_order))
    bar_w = 0.2
    fig, ax = plt.subplots(figsize=(10, 6))
    for slot, (sid, _, _, legend) in zip((-1.5, -0.5, 0.5, 1.5), series_spec):
        ax.bar(centers + slot * bar_w, heights[sid], bar_w, label=legend)

    ax.set_xticks(centers)
    ax.set_xticklabels([tick_text[name] for name in method_order])
    ax.set_yscale('log')
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()

    fig.tight_layout()
    fig.savefig("imdb_correct_vs_incorrect.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 5: Rearranged Correct/Incorrect Comparison
# --------------------------
def plot_rearranged_ci(all_data, ylim):
    """Two-panel correct/incorrect comparison with one bar per method.

    Left panel: total entropy. Right panel: epistemic entropy. Each panel has
    a 'Correct' and an 'Incorrect' group with one bar per method, centred on
    the group's tick. Values are gathered the same way as in ``plot_ci``:
    HMC/SMC from the nested ``ID`` dict, MAP/DE from the first matching
    ``replicate_*`` key, otherwise NaN. Saved to
    ``imdb_rearranged_correct_incorrect.png`` and then shown.

    Args:
        all_data: Mapping method name -> metrics dict from ``load_metrics``.
        ylim: Accepted for call compatibility but currently unused; both
            panels keep linear, autoscaled y axes.
    """
    methods = ['MAP','DE','HMC','SMC']
    # updated legend labels
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$'
    }

    keys = {
        'tc':['replicate_correct_entropy','replicate_correct_total_entropy'],
        'ti':['replicate_incorrect_entropy','replicate_incorrect_total_entropy'],
        'ec':['replicate_correct_epistemic_entropy'],
        'ei':['replicate_incorrect_epistemic_entropy']
    }
    vals = {k:[] for k in keys}
    for m in methods:
        d = all_data[m]
        if m in ('HMC','SMC') and 'ID' in d:
            id_struct = d['ID']
            vals['tc'].append(id_struct.get('mean_total_entropy_correct', np.nan))
            vals['ti'].append(id_struct.get('mean_total_entropy_incorrect', np.nan))
            vals['ec'].append(id_struct.get('mean_epistemic_correct', np.nan))
            vals['ei'].append(id_struct.get('mean_epistemic_incorrect', np.nan))
        else:
            for k, ks in keys.items():
                v = np.nan
                for key in ks:
                    if key in d:
                        v = np.mean(d[key])
                        break
                vals[k].append(v)

    groups = ['Correct','Incorrect']
    x = np.arange(2)
    w = 0.15
    # Offsets symmetric around zero (-1.5w, -0.5w, 0.5w, 1.5w for four
    # methods) so each group of bars is centred on its tick, as in plot_ci.
    offsets = [(i - (len(methods) - 1) / 2) * w for i in range(len(methods))]

    fig, axes = plt.subplots(1,2,figsize=(12,5))

    # Total entropy panel
    ax = axes[0]
    for i, m in enumerate(methods):
        ax.bar(x+offsets[i], [vals['tc'][i], vals['ti'][i]], w,
               label=labels[m])
    ax.set_xticks(x); ax.set_xticklabels(groups)
    ax.set_ylabel('Total Entropy')

    # Epistemic entropy panel
    ax = axes[1]
    for i, m in enumerate(methods):
        ax.bar(x+offsets[i], [vals['ec'][i], vals['ei'][i]], w,
               label=labels[m])
    ax.set_xticks(x); ax.set_xticklabels(groups)
    ax.set_ylabel('Epistemic Entropy')

    # legend inside the right subplot
    axes[1].legend(loc='upper left')

    fig.tight_layout(rect=[0,0,1,1])
    fig.savefig("imdb_rearranged_correct_incorrect.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 6: Group-Method Comparison
# --------------------------
def plot_group_methods(all_data):
    """Two-panel comparison of methods across the OOD sets and the overall ID set.

    Groups are the five OOD sets plus 'All IDs'. Left panel: total entropy.
    Right panel: epistemic entropy; MAP's epistemic bars are zero. There is one
    bar per method, centred on each group's tick, on linear y axes. Saved to
    ``imdb_group_methods_comparison.png`` and then shown.

    Args:
        all_data: Mapping method name -> metrics dict from ``load_metrics``.
    """
    methods = ['MAP','DE','HMC','SMC']
    # updated legend labels
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$'
    }

    groups = ['Meta','Full Meta','Reviews','Full reviews','Lipsum', 'All IDs']

    tot_vals = {}
    epi_vals = {}
    for m in methods:
        if m == 'MAP':
            # MAP summary is a flat list: 5 OOD totals followed by the All-ID total.
            vals = get_ood_values(m, all_data[m])
            tot = list(vals[:6])
            epi = [0] * 6
        else:
            tot_list, _, epi_list, aT, _, aE = get_ood_values(m, all_data[m])
            tot = tot_list + [aT]
            epi = epi_list + [aE]
        tot_vals[m] = tot
        epi_vals[m] = epi

    x = np.arange(len(groups))
    width = 0.15
    # Offsets symmetric around zero so each group of bars is centred on its
    # tick (-1.5, -0.5, 0.5, 1.5 bar widths for four methods).
    offsets = [(i - (len(methods) - 1) / 2) * width for i in range(len(methods))]

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Total Entropy panel
    ax = axes[0]
    for i, m in enumerate(methods):
        ax.bar(x + offsets[i], tot_vals[m], width,
               label=labels[m])
    ax.set_xticks(x); ax.set_xticklabels(groups, rotation=0, ha='center')
    ax.set_ylabel('Total Entropy')

    # Epistemic Entropy panel
    ax = axes[1]
    for i, m in enumerate(methods):
        ax.bar(x + offsets[i], epi_vals[m], width,
               label=labels[m])
    ax.set_xticks(x); ax.set_xticklabels(groups, rotation=0, ha='center')
    ax.set_ylabel('Epistemic Entropy')

    # legend inside the right subplot
    axes[1].legend(loc='upper left')

    fig.tight_layout()
    fig.savefig("imdb_group_methods_comparison.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 7: Full OOD Entropy (Negative, Positive, OOD sets, All IDs)
# --------------------------
def get_full_ood_values(method, d):
    """Return per-group (total, aleatoric, epistemic) entropies for one method.

    The seven groups are, in order: class 0 ('Negative'), class 1 ('Positive'),
    then the OOD sets meta, full_meta, reviews, full_reviews and lipsum.
    Callers append the "All ID" entry themselves.

    MAP reports total entropy only; its aleatoric and epistemic lists are
    zeros. DE reads the flat ``replicate_*`` arrays. HMC and SMC read the
    ``per_digit_*`` arrays and the nested ``OOD`` dict built by
    ``load_metrics``. Any other method gets the per-class values (if present)
    and NaN for the OOD sets. Missing entries become NaN.

    Args:
        method: 'MAP', 'DE', 'HMC' or 'SMC'.
        d: Metrics dict returned by ``load_metrics(method)``.

    Returns:
        Tuple ``(tot, ale, epi)`` of three 7-element lists.
    """
    ood_keys = ['meta', 'full_meta', 'reviews', 'full_reviews', 'lipsum']

    if method == 'MAP':
        # Reduce each per-class entry to a scalar with nanmean, as the DE
        # branch and get_ood_values() do. Returning the raw value (an array,
        # or the 1-element NaN default) breaks the '.3e' formatting of the
        # all-domains table printed by main().
        tot = [np.nanmean(d.get(f'total_entropy_class{s}_mean', np.array([np.nan])))
               for s in ['0', '1']]
        for key in ood_keys:
            tot.append(np.nanmean(d.get(f'replicate_total_entropy_od_{key}', [np.nan])))
        # No aleatoric/epistemic decomposition for MAP: both lists are zeros.
        return tot, [0] * len(tot), [0] * len(tot)

    if method == 'DE':
        tot = [np.nanmean(d.get(f'replicate_total_entropy_in_class{s}', np.array([np.nan])))
               for s in ['0', '1']]
        epi = [np.nanmean(d.get(f'replicate_epistemic_in_class{s}', np.array([np.nan])))
               for s in ['0', '1']]
        ale = [t - e for t, e in zip(tot, epi)]
        for key in ood_keys:
            t = np.nanmean(d.get(f'replicate_total_entropy_od_{key}', np.array([np.nan])))
            e = np.nanmean(d.get(f'replicate_epistemic_od_{key}', np.array([np.nan])))
            tot.append(t)
            epi.append(e)
            ale.append(t - e)
        return tot, ale, epi

    tot, ale, epi = [], [], []
    # Per-class entries (2 classes) from the per_digit_* arrays.
    t_arr = d.get('per_digit_total_entropy', None)
    e_arr = d.get('per_digit_epistemic_entropy', None)
    if t_arr is not None and e_arr is not None:
        for i in range(2):
            t = np.mean(t_arr[i])
            e = np.mean(e_arr[i])
            tot.append(t)
            epi.append(e)
            ale.append(t - e)
    else:
        tot.extend([np.nan] * 2)
        epi.extend([np.nan] * 2)
        ale.extend([np.nan] * 2)

    # OOD sets: HMC and SMC share the nested OOD dict layout.
    if method in ('HMC', 'SMC'):
        o = d.get('OOD', {})
        for key in ood_keys:
            blk = o.get(key, {})
            t = blk.get('mean_total_entropy', np.nan)
            e = blk.get('mean_epistemic', np.nan)
            tot.append(t)
            epi.append(e)
            ale.append(t - e)
    else:
        # unknown method
        tot.extend([np.nan] * 5)
        epi.extend([np.nan] * 5)
        ale.extend([np.nan] * 5)

    return tot, ale, epi


def plot_full_digit_ood(m, d, ylim):
    """Per-class, per-OOD-set and All-ID entropy bars for one method (log axis).

    MAP shows total entropy only. The other methods show total, aleatoric and
    epistemic bars per group. The "All IDs" entry always comes from
    ``get_ood_values``. Saved to ``imdb_full_ood_entropy_{m}.png`` and then
    shown.

    Args:
        m: Method name ('MAP', 'DE', 'HMC' or 'SMC').
        d: Metrics dict returned by ``load_metrics(m)``.
        ylim: (bottom, top) limits for the log-scaled y axis.
    """
    groups = ['Negative', 'Positive'] + ['Meta','Full Meta','Reviews','Full reviews', 'Lipsum','All IDs']
    tot, ale, epi = get_full_ood_values(m, d)
    if m == 'MAP':
        # Use the same All-ID value as get_ood_values() (mean of the per-class
        # means), which is also what main() prints in the all-domains table.
        # A single flat nanmean over the raw per-class entries can fail when
        # those entries are arrays of different lengths.
        allid = get_ood_values(m, d)[-1]
        tot.append(allid)
        ale.append(0.0)
        epi.append(0.0)
    else:
        _, _, _, aT, _, aE = get_ood_values(m, d)
        tot.append(aT)
        ale.append(aT - aE)
        epi.append(aE)

    x = np.arange(len(groups))
    fig, ax = plt.subplots(figsize=(12,6))
    if m == 'MAP':
        ax.bar(x, tot, 0.6, label='Total')
    else:
        width = 0.25
        ax.bar(x-width, tot, width, label='Total')
        ax.bar(x, ale, width, label='Aleatoric')
        ax.bar(x+width, epi, width, label='Epistemic')
    ax.set_xticks(x); ax.set_xticklabels(groups, rotation=0, ha='center')
    ax.set_yscale('log')
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy'); ax.legend()
    fig.tight_layout()
    fig.savefig(f"imdb_full_ood_entropy_{m}.png",dpi=300,bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 8: Group-Method Comparison with ID Correct vs Incorrect
# --------------------------
def plot_group_methods_ci(all_data):
    """Compare methods per group, with in-domain split into correct/incorrect.

    Groups are the five OOD sets followed by 'ID Cor.' and 'ID Inc.'. Left
    panel: total entropy. Right panel: epistemic entropy; MAP's OOD epistemic
    bars are zero. Both panels are log-scaled. Saved to
    ``imdb_group_methods_comparison_ci.png`` and then shown.

    Args:
        all_data: Mapping method name -> metrics dict from ``load_metrics``.
    """
    methods = ['MAP','DE','HMC','SMC']
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$'
    }
    groups = ['Meta','Full Meta','Reviews','Full reviews','Lipsum','ID Cor.','ID Inc.']

    # Same fallback key lists as plot_ci() and the in-domain table in main(),
    # so all three report the same MAP/DE correct/incorrect numbers.
    ci_keys = {
        'tc': ['replicate_correct_entropy', 'replicate_correct_total_entropy'],
        'ti': ['replicate_incorrect_entropy', 'replicate_incorrect_total_entropy'],
        'ec': ['replicate_correct_epistemic_entropy'],
        'ei': ['replicate_incorrect_epistemic_entropy'],
    }

    def first_mean(d, candidates):
        """np.mean of the first key in *candidates* present in *d*, else NaN."""
        for key in candidates:
            if key in d:
                return np.mean(d[key])
        return np.nan

    tot_vals = {}
    epi_vals = {}
    for m in methods:
        d = all_data[m]
        # One call: MAP returns a flat 6-element list, the other methods a
        # tuple (tot, ale, epi, aT, aA, aE).
        ood = get_ood_values(m, d)
        if m == 'MAP':
            ood_tot = list(ood[:5])
            ood_epi = [0] * 5
        else:
            ood_tot = list(ood[0])
            ood_epi = list(ood[2])

        if m in ('HMC', 'SMC') and 'ID' in d:
            id_s = d['ID']
            id_tot = [id_s.get('mean_total_entropy_correct', np.nan),
                      id_s.get('mean_total_entropy_incorrect', np.nan)]
            id_epi = [id_s.get('mean_epistemic_correct', np.nan),
                      id_s.get('mean_epistemic_incorrect', np.nan)]
        else:
            id_tot = [first_mean(d, ci_keys['tc']), first_mean(d, ci_keys['ti'])]
            id_epi = [first_mean(d, ci_keys['ec']), first_mean(d, ci_keys['ei'])]

        tot_vals[m] = ood_tot + id_tot
        epi_vals[m] = ood_epi + id_epi

    x = np.arange(len(groups))
    width = 0.2
    offsets = [(i - 1.5) * width for i in range(len(methods))]

    fig, axes = plt.subplots(1, 2, figsize=(14,6))
    # Total Entropy panel
    ax = axes[0]
    for i, m in enumerate(methods):
        ax.bar(x + offsets[i], tot_vals[m], width, label=labels[m])
    ax.set_xticks(x)
    ax.set_xticklabels(groups, rotation=0, ha='center')
    ax.set_ylabel('Total Entropy')
    ax.set_yscale('log')

    # Epistemic Entropy panel
    ax = axes[1]
    for i, m in enumerate(methods):
        ax.bar(x + offsets[i], epi_vals[m], width, label=labels[m])
    ax.set_xticks(x)
    ax.set_xticklabels(groups, rotation=0, ha='center')
    ax.set_ylabel('Epistemic Entropy')
    ax.set_yscale('log')
    axes[1].legend(loc='upper left')

    fig.tight_layout()
    fig.savefig("imdb_group_methods_comparison_ci.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# main()
# --------------------------
def _collect_ci_means(all_data, methods):
    """Return in-domain correct/incorrect entropy means, one list entry per method.

    Keys: 'tc'/'ti' (total entropy on correct/incorrect predictions) and
    'ec'/'ei' (epistemic entropy on correct/incorrect predictions). HMC/SMC
    read the nested ``d['ID']`` dict. Other methods take ``np.mean`` of the
    first key present from each fallback list. Missing values are NaN.
    """
    fallback_keys = {
        'tc': ['replicate_correct_entropy', 'replicate_correct_total_entropy'],
        'ti': ['replicate_incorrect_entropy', 'replicate_incorrect_total_entropy'],
        'ec': ['replicate_correct_epistemic_entropy'],
        'ei': ['replicate_incorrect_epistemic_entropy'],
    }
    id_fields = {
        'tc': 'mean_total_entropy_correct',
        'ti': 'mean_total_entropy_incorrect',
        'ec': 'mean_epistemic_correct',
        'ei': 'mean_epistemic_incorrect',
    }
    vals = {k: [] for k in fallback_keys}
    for m in methods:
        d = all_data[m]
        if m in ('HMC', 'SMC') and 'ID' in d:
            for k, field in id_fields.items():
                vals[k].append(d['ID'].get(field, np.nan))
            continue
        for k, candidates in fallback_keys.items():
            key = next((c for c in candidates if c in d), None)
            vals[k].append(np.nan if key is None else np.mean(d[key]))
    return vals


def _log_ylim(values):
    """Return ``(lower, upper)`` limits for a log-scaled entropy axis.

    NaNs are ignored. The lower bound is the smallest value, clamped to 1e-10
    so zeros cannot give a non-positive log limit, times 0.05. The upper bound
    is the largest value times 3.
    """
    flat = np.concatenate([np.ravel(np.asarray(v, dtype=float)) for v in values])
    return (max(np.nanmin(flat), 1e-10) * 0.05, np.nanmax(flat) * 3)


def _print_ci_table(methods, labels, vals):
    """Print the in-domain correct/incorrect LaTeX table to stdout."""
    print(r"\begin{table}[H]")
    print(r"\centering")
    print(r"\caption{Comparison in ID domain among SMC$_\parallel$ ($N=10$), HMC$_{\parallel}$ ($N$ chains), DE ($N$ models) and MAP, with fixed number of leapfrog $L=1$ and $s=$, on IMDB (5 realizations and $\pm$ s.e. in entropy).}")
    print(r"\begin{tabular}{l|cc|cc}")
    print(r"\toprule")
    print(r"Method & \multicolumn{2}{c}{$H_{\mathsf{ep}}$} & \multicolumn{2}{c}{$H_{\mathsf{tot}}$} \\")
    print(r"\cmidrule{2-3} \cmidrule{4-5}")
    print(r" & cor. & inc. & cor. & inc. \\")
    print(r"\midrule")
    for i, m in enumerate(methods):
        print(f"{labels[m]} & {vals['ec'][i]:.3e} & {vals['ei'][i]:.3e} & {vals['tc'][i]:.3e} & {vals['ti'][i]:.3e} \\\\")
        if m != methods[-1]:
            print(r"\midrule")
    print(r"\bottomrule")
    print(r"\end{tabular}")
    print(r"\end{table}")


def _print_all_domains_table(all_data, methods):
    """Print the all-domains LaTeX table (per class, OOD sets, All ID) to stdout."""
    groups = ['Negative','Positive'] + ['Meta','Full Meta','Reviews','Full reviews','Lipsum','All ID']

    # gather values: 7 entries from get_full_ood_values plus the All-ID triple
    full_vals = {}
    for m in methods:
        tot7, ale7, epi7 = get_full_ood_values(m, all_data[m])
        if m == 'MAP':
            # MAP's get_ood_values() is a flat list whose last entry is All ID.
            allid = get_ood_values(m, all_data[m])[-1]
            aT, aA, aE = allid, allid, 0.0
        else:
            _, _, _, aT, aA, aE = get_ood_values(m, all_data[m])
        full_vals[m] = (tot7 + [aT], ale7 + [aA], epi7 + [aE])

    # header
    print(r"\begin{table}[H]")
    print(r"\centering")
    print(r"\scriptsize")
    print(r"\caption{Comparison in all domains among SMC$_\parallel$ ($N=10$), HMC$_{\parallel}$ ($N$ chains), DE ($N$ models) and MAP, with fixed number of leapfrog $L=1$ and $s=0,$, on IMDB (5 realizations and $\pm$ s.e. in entropy).}")
    print(r"\begin{tabular}{l|c|c|c|c|c|c|c|c|c|c|}")
    print(r"\toprule")
    print(r"Group & \multicolumn{1}{c|}{MAP} "
          r"& \multicolumn{3}{c|}{DE} "
          r"& \multicolumn{3}{c|}{HMC} "
          r"& \multicolumn{3}{c|}{SMC} \\")
    print(r"\cmidrule(rl){2-2} \cmidrule(rl){3-5} \cmidrule(rl){6-8} \cmidrule(rl){9-11}")
    print(r" & $H_{\mathsf{tot}}$ & $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ "
          r"& $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ "
          r"& $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ \\")
    print(r"\midrule")

    # rows
    for i, grp in enumerate(groups):
        entries = [grp]
        for m in methods:
            tot, ale, epi = full_vals[m]
            if m == 'MAP':
                # just H_tot
                entries.append(f"{tot[i]:.3e}")
            else:
                entries += [f"{tot[i]:.3e}", f"{ale[i]:.3e}", f"{epi[i]:.3e}"]
        print(" & ".join(entries) + r" \\")
        # No \midrule after the last row: \bottomrule follows directly.
        if i < len(groups) - 1:
            print(r"\midrule")

    print(r"\bottomrule")
    print(r"\end{tabular}")
    print(r"\end{table}")


def main():
    """Load every method's metrics, draw all figures and print both LaTeX tables."""
    methods = ['MAP','DE','HMC','SMC']
    labels = {
        'MAP':    'MAP',
        'DE':     'DE',
        'HMC':    r'HMC$_\parallel$',
        'SMC':    r'SMC$_\parallel$'
    }
    all_data = {m: load_metrics(m) for m in methods}

    # Common log-axis limits for the per-method OOD plots. MAP's
    # get_ood_values() is a flat list (5 OOD totals + All ID); the other
    # methods return a tuple whose first item is the list of 5 OOD totals.
    ood_vals = []
    for m in methods:
        vals = get_ood_values(m, all_data[m])
        ood_vals.extend(vals[:5] if m == 'MAP' else vals[0])
    ylim = _log_ylim(ood_vals)

    for m in methods:
        plot_ood(m, all_data[m], ylim)

    # Correct/incorrect limits are computed from the same values plot_ci()
    # draws, so no plotted bar (e.g. MAP's 'replicate_correct_entropy') is
    # left out of the range.
    ci_means = _collect_ci_means(all_data, methods)
    ci_vals = [v for k in ('tc', 'ti', 'ec', 'ei') for v in ci_means[k]
               if v > 0 and not np.isnan(v)]
    # With no positive values, set_ylim(None) leaves the limits unchanged.
    ci_ylim = (min(ci_vals) * 0.9, max(ci_vals) * 3) if ci_vals else None
    plot_ci(all_data, ci_ylim)
    plot_rearranged_ci(all_data, ylim)
    plot_group_methods(all_data)
    plot_group_methods_ci(all_data)

    # --- Print in-domain CI LaTeX table to terminal ---
    _print_ci_table(methods, labels, ci_means)

    # Common log-axis limits for the per-class + OOD plots (MAP: totals only).
    full_vals = []
    for m in methods:
        tot, ale, epi = get_full_ood_values(m, all_data[m])
        full_vals.extend(tot + (ale if m != 'MAP' else []) + (epi if m != 'MAP' else []))
    full_ylim = _log_ylim(full_vals)

    for m in methods:
        plot_full_digit_ood(m, all_data[m], full_ylim)

    # --- Print all-domains LaTeX table to terminal ---
    _print_all_domains_table(all_data, methods)



if __name__ == "__main__":
    # Run the full plotting/table pipeline only when executed as a script.
    main()
