#!/usr/bin/env python3
"""
Plot and Table Generation for Bayesian NN Results

This script loads the aggregated MAT files for each method (MAP, DE, HMC, SMC) and generates:
1. Per-method OOD entropy bar plots (saved as PNG)
2. All-domain (close/corrupt/far + All IDs) entropy plots (saved as PNG) and a LaTeX table (printed to stdout)
3. In-domain correct/incorrect breakdown plots (saved as PNG) and a LaTeX table (printed to stdout)
4. Combined OOD predictive entropy subplots (MAP/DE/HMC/SMC) (saved as PNG)
5. Rearranged correct/incorrect comparison plots (saved as PNG)
6. Group-method comparison plots, with and without the ID correct/incorrect split (saved as PNG)
"""
import os, re
import numpy as np
import scipy.io as sio
from scipy.io.matlab import mat_struct
import matplotlib.pyplot as plt
import glob

# ——— choose which “P” file to load
# P selects the HMC/SMC aggregated-results files read by load_metrics():
# phmc_aggregated_results_P{P}.mat and psmc_aggregated_results_P{P}.mat.
P = 8  # alternative value previously in use: 1

# --------------------------
# Helpers to convert MATLAB structs
# --------------------------
def _todict(matobj):
    """Convert a scipy ``mat_struct`` into a plain ``dict``.

    Each name listed in ``matobj._fieldnames`` becomes a key. Its value is
    passed through ``_check`` so that nested structs and object arrays are
    converted as well.
    """
    return {name: _check(getattr(matobj, name)) for name in matobj._fieldnames}


def _check(e):
    """Recursively convert a value returned by ``scipy.io.loadmat``.

    - ``mat_struct`` objects become plain dicts (via ``_todict``).
    - 0-d object arrays are unwrapped and their single element converted.
      Iterating a 0-d array raises ``TypeError``. NOTE(review):
      ``squeeze_me=True`` appears to produce these for 1x1 cells.
    - Other object-dtype arrays are converted element-wise:
      - When the converted elements stack into a regular array, that array is
        returned (unchanged behaviour).
      - For ragged contents, such as cells holding vectors of different
        lengths, a 1-D object array is returned. Building such an array with
        ``np.array`` raises ``ValueError`` on NumPy >= 1.24.
    - Anything else is returned unchanged.
    """
    if isinstance(e, mat_struct):
        return _todict(e)
    if isinstance(e, np.ndarray) and e.dtype == object:
        if e.ndim == 0:
            return _check(e.item())
        converted = [_check(x) for x in e]
        try:
            return np.array(converted)
        except ValueError:
            # Ragged elements: keep them side by side in an object array.
            out = np.empty(len(converted), dtype=object)
            for i, item in enumerate(converted):
                out[i] = item
            return out
    return e
# --------------------------
# Load metrics for each method
# --------------------------
def load_metrics(method, p=None):
    """Load the metrics of one method and return them as a (nested) dict.

    MAP / DE
        The method's MAT file is returned with its top-level variables kept
        as-is, after MATLAB structs are converted to dicts.
    HMC / SMC
        ``phmc_aggregated_results_P{p}.mat`` / ``psmc_aggregated_results_P{p}.mat``
        is read and re-packed into the layout the plotting code expects:
        ``total_entropy_inID``, ``epistemic_inID``, ``ID`` (correct/incorrect
        means) and ``OOD`` (``close`` / ``corrupt`` / ``far`` sub-dicts with
        ``mean_total_entropy`` and ``mean_epistemic``). Missing fields become
        NaN.

    Parameters
    ----------
    method : {'MAP', 'DE', 'HMC', 'SMC'}
    p : int, optional
        Which ``P`` results file to read for HMC/SMC. Defaults to the
        module-level ``P``.

    Raises
    ------
    KeyError
        If ``method`` is not one of the four supported names.
    FileNotFoundError
        If the HMC/SMC results file does not exist (``sio.loadmat`` raises
        for missing MAP/DE files).
    """
    if method in ('MAP', 'DE'):
        path = {
            'MAP': 'BayesianNN_CIFAR_MAP_metrics.mat',
            'DE':  'BayesianNN_CIFAR_DE_metrics.mat',
        }[method]
        raw = sio.loadmat(path, squeeze_me=True, struct_as_record=False)
        return {k: _check(v) for k, v in raw.items() if not k.startswith('__')}

    if p is None:
        p = P
    path = {
        'HMC': f'phmc_aggregated_results_P{p}.mat',
        'SMC': f'psmc_aggregated_results_P{p}.mat',
    }[method]
    if not os.path.isfile(path):
        raise FileNotFoundError(f"No file matching {path!r} found.")
    raw = sio.loadmat(path, squeeze_me=True, struct_as_record=False)
    flat = {k: _check(v) for k, v in raw.items() if not k.startswith('__')}

    # Rebuild the nested dict structure the downstream code expects.
    d = {
        # in-domain totals
        'total_entropy_inID': flat.get('mean_tot_ent', np.nan),
        'epistemic_inID':     flat.get('mean_epi',     np.nan),
        # correct/incorrect breakdown
        'ID': {
            'mean_total_entropy_correct':   flat.get('mean_tot_corr', np.nan),
            'mean_total_entropy_incorrect': flat.get('mean_tot_inc',  np.nan),
            'mean_epistemic_correct':       flat.get('mean_epi_corr', np.nan),
            'mean_epistemic_incorrect':     flat.get('mean_epi_inc',  np.nan),
        },
        # OOD fields: HMC and SMC files share the same variable names.
        'OOD': {
            k: {
                'mean_total_entropy': flat.get(f'ood_{k}_mean_tot', np.nan),
                'mean_epistemic':     flat.get(f'ood_{k}_mean_epi', np.nan),
            }
            for k in ('close', 'corrupt', 'far')
        },
    }
    return d

# --------------------------
# Extract OOD + All-ID summary
# --------------------------
def get_ood_values(m, d):
    """Summarise OOD and all-ID entropies for one method.

    Returns
    -------
    MAP
        list ``[tot_close, tot_corrupt, tot_far, allid_total]``.
    DE / HMC / SMC
        tuple ``(tot, ale, epi, aT, aA, aE)``, where:
        - ``tot``, ``ale`` and ``epi`` are lists over (close, corrupt, far),
          with ``ale = tot - epi``.
        - ``aT`` / ``aE`` are the in-domain total / epistemic entropies.
        - ``aA = aT - aE``.
    any other name
        ``([], [], [], nan, nan, nan)``.
    """
    splits = ('close', 'corrupt', 'far')

    if m == 'MAP':
        ood_tot = [np.mean(d.get('replicate_total_entropy_od_' + s, np.nan))
                   for s in splits]
        in_dom = np.nanmean(d.get('total_entropy_mean', np.nan))
        return ood_tot + [in_dom]

    if m == 'DE':
        ood_tot = [np.mean(d['replicate_total_entropy_od_' + s]) for s in splits]
        ood_epi = [np.mean(d['replicate_epistemic_od_' + s]) for s in splits]
        in_tot = np.nanmean(d.get('replicate_total_entropy', np.array([np.nan])))
        in_epi = np.nanmean(d.get('replicate_epistemic_entropy', np.array([np.nan])))
    elif m in ('HMC', 'SMC'):
        # Both samplers share the nested layout built by load_metrics().
        blocks = d['OOD']
        ood_tot = [blocks[s]['mean_total_entropy'] for s in splits]
        ood_epi = [blocks[s]['mean_epistemic'] for s in splits]
        in_tot = d.get('total_entropy_inID', np.nan)
        in_epi = d.get('epistemic_inID', np.nan)
    else:
        return [], [], [], np.nan, np.nan, np.nan

    ood_ale = [t - e for t, e in zip(ood_tot, ood_epi)]
    return ood_tot, ood_ale, ood_epi, in_tot, in_tot - in_epi, in_epi

# --------------------------
# Plot 1: OOD Entropy + save PNG
# --------------------------
def plot_ood(m, d, ylim):
    """Bar-plot the close / corrupt / far OOD entropy of one method and save it.

    Bars per group depend on the method:
    - MAP: a single 'Total' bar.
    - Other methods: Total / Aleatoric / Epistemic side by side, taken from
      ``get_ood_values``.

    The y axis is logarithmic with limits ``ylim``. The figure is written to
    ``cifar_ood_entropy_{m}.png`` (300 dpi) and then shown.
    """
    fig, ax = plt.subplots()
    if m != 'MAP':
        tot, ale, epi, *_ = get_ood_values(m, d)
        centers = np.arange(3)
        for shift, heights, name in ((-0.25, tot, 'Total'),
                                     (0.0, ale, 'Aleatoric'),
                                     (0.25, epi, 'Epistemic')):
            ax.bar(centers + shift, heights, 0.25, label=name)
    else:
        summary = get_ood_values(m, d)
        ax.bar(range(3), summary[:3], 0.5, label='Total')

    ax.set_xticks(range(3))
    ax.set_xticklabels(['close', 'corrupt', 'far'])
    ax.set_yscale('log')
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"cifar_ood_entropy_{m}.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 3: Correct vs Incorrect + save PNG (the matching LaTeX table is printed by main())
# --------------------------
def plot_ci(all_data, ylim):
    """Grouped bar chart of in-domain entropy, split by prediction correctness.

    Each method gets four bars: total entropy on correct and on incorrect
    predictions, then epistemic entropy on correct and on incorrect
    predictions. Sources:
    - HMC/SMC: the ``ID`` sub-dict built by ``load_metrics``.
    - MAP/DE: the mean of the first ``replicate_*`` key found (NaN if none).

    The y axis is logarithmic with limits ``ylim``. Saves
    ``cifar_correct_vs_incorrect.png`` (300 dpi) and shows the figure.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC']
    tick_names = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$',
    }
    # MAP/DE: candidate keys per quantity, tried in order (first hit wins).
    replicate_keys = {
        'tc': ('replicate_correct_entropy', 'replicate_correct_total_entropyt'),
        'ti': ('replicate_incorrect_entropy', 'replicate_incorrect_total_entropy'),
        'ec': ('replicate_correct_epistemic_entropy', 'replicate_correct_epistemic_entropyt'),
        'ei': ('replicate_incorrect_epistemic_entropy',),
    }
    # HMC/SMC: field names inside d['ID'].
    id_fields = {
        'tc': 'mean_total_entropy_correct',
        'ti': 'mean_total_entropy_incorrect',
        'ec': 'mean_epistemic_correct',
        'ei': 'mean_epistemic_incorrect',
    }

    def _first_mean(d, candidates):
        found = next((key for key in candidates if key in d), None)
        return np.nan if found is None else np.mean(d[found])

    series = {q: [] for q in replicate_keys}
    for m in methods:
        d = all_data[m]
        from_id = m in ('HMC', 'SMC') and 'ID' in d
        for q, candidates in replicate_keys.items():
            if from_id:
                series[q].append(d['ID'].get(id_fields[q], np.nan))
            else:
                series[q].append(_first_mean(d, candidates))

    centers = np.arange(len(methods))
    bar_w = 0.2
    fig, ax = plt.subplots(figsize=(10, 6))
    for shift, q, name in ((-1.5, 'tc', 'Total Correct'),
                           (-0.5, 'ti', 'Total Incorrect'),
                           (0.5, 'ec', 'Epistemic Correct'),
                           (1.5, 'ei', 'Epistemic Incorrect')):
        ax.bar(centers + shift * bar_w, series[q], bar_w, label=name)

    ax.set_xticks(centers)
    ax.set_xticklabels([tick_names[m] for m in methods])
    ax.set_yscale('log')
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()

    fig.tight_layout()
    fig.savefig("cifar_correct_vs_incorrect.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 5: Rearranged Correct/Incorrect Comparison
# --------------------------
def plot_rearranged_ci(all_data, ylim):
    """Two-panel bar chart of ID entropy: 'Correct' vs 'Incorrect', one bar per method.

    The left panel shows total entropy; the right panel shows epistemic
    entropy. Sources:
    - HMC/SMC: the ``ID`` sub-dict built by ``load_metrics``.
    - MAP/DE: the mean of the first matching ``replicate_*`` key (NaN if none).

    ``ylim`` is accepted for call compatibility but is not applied: both
    panels keep matplotlib's default linear autoscaling.

    Saves ``cifar_rearranged_correct_incorrect.png`` (300 dpi) and shows it.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC']
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$'
    }

    # MAP/DE: candidate replicate keys per quantity; the first one present wins.
    keys = {
        'tc': ['replicate_correct_total_entropyt', 'replicate_correct_entropy'],
        'ti': ['replicate_incorrect_total_entropy', 'replicate_incorrect_entropy'],
        'ec': ['replicate_correct_epistemic_entropy'],
        'ei': ['replicate_incorrect_epistemic_entropy'],
    }
    vals = {k: [] for k in keys}
    for m in methods:
        d = all_data[m]
        if m in ('HMC', 'SMC') and 'ID' in d:
            id_struct = d['ID']
            vals['tc'].append(id_struct.get('mean_total_entropy_correct', np.nan))
            vals['ti'].append(id_struct.get('mean_total_entropy_incorrect', np.nan))
            vals['ec'].append(id_struct.get('mean_epistemic_correct', np.nan))
            vals['ei'].append(id_struct.get('mean_epistemic_incorrect', np.nan))
        else:
            for k, ks in keys.items():
                v = np.nan
                for key in ks:
                    if key in d:
                        v = np.mean(d[key])
                        break
                vals[k].append(v)

    groups = ['Correct', 'Incorrect']
    x = np.arange(len(groups))
    w = 0.15
    # Offsets symmetric about each tick (-1.5w, -0.5w, +0.5w, +1.5w for four
    # methods) so each cluster of bars is centred on its group label.
    n = len(methods)
    offsets = [(i - (n - 1) / 2) * w for i in range(n)]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = ((axes[0], 'tc', 'ti', 'Total Entropy'),
              (axes[1], 'ec', 'ei', 'Epistemic Entropy'))
    for ax, k_cor, k_inc, ylabel in panels:
        for i, m in enumerate(methods):
            ax.bar(x + offsets[i], [vals[k_cor][i], vals[k_inc][i]], w,
                   label=labels[m])
        ax.set_xticks(x)
        ax.set_xticklabels(groups)
        ax.set_ylabel(ylabel)

    # legend inside the right subplot
    axes[1].legend(loc='upper left')

    fig.tight_layout()
    fig.savefig("cifar_rearranged_correct_incorrect.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 6: Group-Method Comparison
# --------------------------
def plot_group_methods(all_data):
    """Two-panel comparison of entropy per group (close / corrupt / far / All IDs).

    The left panel shows total entropy; the right panel shows epistemic
    entropy. Each group gets one bar per method. ``get_ood_values`` returns
    only totals for MAP, so MAP's epistemic bars are drawn as 0.

    Saves ``cifar_group_methods_comparison.png`` (300 dpi) and shows it.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC']
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$'
    }

    groups = ['close', 'corrupt', 'far', 'All IDs']

    tot_vals = {}
    epi_vals = {}
    for m in methods:
        if m == 'MAP':
            # [close, corrupt, far, all-ID] totals
            tot = get_ood_values(m, all_data[m])[:4]
            epi = [0] * 4
        else:
            tot_list, _, epi_list, aT, _, aE = get_ood_values(m, all_data[m])
            tot = tot_list + [aT]
            epi = epi_list + [aE]
        tot_vals[m] = tot
        epi_vals[m] = epi

    x = np.arange(len(groups))
    width = 0.15
    # Offsets symmetric about each tick (-1.5w, -0.5w, +0.5w, +1.5w for four
    # methods) so each cluster of bars is centred on its group label.
    n = len(methods)
    offsets = [(i - (n - 1) / 2) * width for i in range(n)]

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    for ax, values, ylabel in ((axes[0], tot_vals, 'Total Entropy'),
                               (axes[1], epi_vals, 'Epistemic Entropy')):
        for i, m in enumerate(methods):
            ax.bar(x + offsets[i], values[m], width, label=labels[m])
        ax.set_xticks(x)
        ax.set_xticklabels(groups, rotation=0, ha='center')
        ax.set_ylabel(ylabel)

    # legend inside the right subplot
    axes[1].legend(loc='upper left')

    fig.tight_layout()
    fig.savefig("cifar_group_methods_comparison.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 7: Full OOD Entropy per method (close, corrupt, far, All IDs)
# --------------------------
def get_full_ood_values(method, d):
    """Return ``(tot, ale, epi)`` lists over the OOD splits close / corrupt / far.

    MAP
        Totals are the ``nanmean`` of ``replicate_total_entropy_od_<split>``.
        Aleatoric and epistemic are reported as 0.
    DE
        ``nanmean`` of the per-replicate total and epistemic arrays, with
        aleatoric = total - epistemic.
    HMC / SMC
        Read from the nested ``d['OOD'][split]`` dicts built by
        ``load_metrics``. Missing entries become NaN.
    anything else
        Three NaNs in each list.

    The three returned lists are always distinct objects; callers append to
    each of them separately.
    """
    splits = ('close', 'corrupt', 'far')

    if method == 'MAP':
        tot = [np.nanmean(d.get(f'replicate_total_entropy_od_{s}', [np.nan]))
               for s in splits]
        return tot, [0] * len(tot), [0] * len(tot)

    if method == 'DE':
        pairs = [
            (np.nanmean(d.get(f'replicate_total_entropy_od_{s}', np.array([np.nan]))),
             np.nanmean(d.get(f'replicate_epistemic_od_{s}', np.array([np.nan]))))
            for s in splits
        ]
    elif method in ('HMC', 'SMC'):
        ood = d.get('OOD', {})
        blocks = [ood.get(s, {}) for s in splits]
        pairs = [(blk.get('mean_total_entropy', np.nan),
                  blk.get('mean_epistemic', np.nan))
                 for blk in blocks]
    else:
        # unknown method
        nans = [np.nan] * 3
        return list(nans), list(nans), list(nans)

    tot = [t for t, _ in pairs]
    ale = [t - e for t, e in pairs]
    epi = [e for _, e in pairs]
    return tot, ale, epi


def plot_full_digit_ood(m, d, ylim):
    """Per-method bar plot over close / corrupt / far / All IDs, saved as PNG.

    The OOD entries come from ``get_full_ood_values``. The 'All IDs' entry is
    computed per method:
    - MAP: the ``nanmean`` of ``total_entropy_mean``, with aleatoric and
      epistemic set to 0.
    - Other methods: the in-domain total / aleatoric / epistemic values from
      ``get_ood_values``.

    MAP shows totals only; the other methods show Total / Aleatoric /
    Epistemic side by side. The y axis is logarithmic with limits ``ylim``.
    Writes ``cifar_full_ood_entropy_{m}.png`` (300 dpi) and shows the figure.
    """
    groups = ['close', 'corrupt', 'far', 'All IDs']
    tot, ale, epi = get_full_ood_values(m, d)
    if m == 'MAP':
        tot.append(np.nanmean(d.get('total_entropy_mean', np.nan)))
        ale.append(0.0)
        epi.append(0.0)
    else:
        summary = get_ood_values(m, d)
        in_tot, in_epi = summary[3], summary[5]
        tot.append(in_tot)
        ale.append(in_tot - in_epi)
        epi.append(in_epi)

    centers = np.arange(len(groups))
    fig, ax = plt.subplots(figsize=(12, 6))
    if m != 'MAP':
        bar_w = 0.25
        ax.bar(centers - bar_w, tot, bar_w, label='Total')
        ax.bar(centers, ale, bar_w, label='Aleatoric')
        ax.bar(centers + bar_w, epi, bar_w, label='Epistemic')
    else:
        ax.bar(centers, tot, 0.6, label='Total')

    ax.set_xticks(centers)
    ax.set_xticklabels(groups, rotation=0, ha='center')
    ax.set_yscale('log')
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()
    fig.tight_layout()
    fig.savefig(f"cifar_full_ood_entropy_{m}.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 8: Group-Method Comparison with Correct/Incorrect ID
# --------------------------
def plot_group_methods_ci(all_data):
    """Two-panel comparison over the OOD splits plus ID correct / incorrect.

    The x groups are Close, Corrupt and Far (from ``get_ood_values``),
    followed by ID Cor. and ID Inc. The left panel shows total entropy; the
    right panel shows epistemic entropy (MAP's OOD epistemic bars are 0).

    ID values come from one of three sources:
    - HMC/SMC: ``d['ID']``, when present.
    - MAP: the ``replicate_(in)correct_entropy`` keys.
    - Otherwise: the ``replicate_(in)correct_total_entropy(t)`` keys.
    Missing replicate keys give NaN. Saves ``cifar_group_methods_ci.png``
    (300 dpi) and shows it.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC']
    legend_names = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$'
    }
    groups = ['Close', 'Corrupt', 'Far', 'ID Cor.', 'ID Inc.']

    # All three key sets are in (tc, ti, ec, ei) order.
    id_fields = ('mean_total_entropy_correct', 'mean_total_entropy_incorrect',
                 'mean_epistemic_correct', 'mean_epistemic_incorrect')
    map_keys = ('replicate_correct_entropy', 'replicate_incorrect_entropy',
                'replicate_correct_epistemic_entropy', 'replicate_incorrect_epistemic_entropy')
    other_keys = ('replicate_correct_total_entropyt', 'replicate_incorrect_total_entropy',
                  'replicate_correct_epistemic_entropy', 'replicate_incorrect_epistemic_entropy')

    tot_vals, epi_vals = {}, {}
    for m in methods:
        d = all_data[m]
        summary = get_ood_values(m, d)
        if m == 'MAP':
            ood_tot, ood_epi = summary[:3], [0, 0, 0]
        else:
            ood_tot, ood_epi = summary[0], summary[2]

        if m in ('HMC', 'SMC') and 'ID' in d:
            tc, ti, ec, ei = (d['ID'][f] for f in id_fields)
        else:
            chosen = map_keys if m == 'MAP' else other_keys
            tc, ti, ec, ei = (np.nanmean(d.get(k, np.nan)) for k in chosen)

        tot_vals[m] = ood_tot + [tc, ti]
        epi_vals[m] = ood_epi + [ec, ei]

    x = np.arange(len(groups))
    width = 0.15
    offsets = [(i - 1.5) * width for i in range(len(methods))]

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    for ax, values, ylabel in ((axes[0], tot_vals, 'Total Entropy'),
                               (axes[1], epi_vals, 'Epistemic Entropy')):
        for off, m in zip(offsets, methods):
            ax.bar(x + off, values[m], width, label=legend_names[m])
        ax.set_xticks(x)
        ax.set_xticklabels(groups, rotation=0, ha='center')
        ax.set_ylabel(ylabel)
    axes[1].legend(loc='upper left')

    fig.tight_layout()
    fig.savefig("cifar_group_methods_ci.png", dpi=300, bbox_inches='tight')
    plt.show()



# --------------------------
# main()
# --------------------------
# MAP/DE keys for each in-domain quantity, in order of preference; the first
# key present wins. Elsewhere in this script MAP totals are read from
# ``replicate_(in)correct_entropy`` and DE totals from
# ``replicate_(in)correct_total_entropy(t)``, so both spellings are searched.
_ID_CI_KEYS = {
    'tc': ('replicate_correct_total_entropyt', 'replicate_correct_entropy'),
    'ti': ('replicate_incorrect_total_entropy', 'replicate_incorrect_entropy'),
    'ec': ('replicate_correct_epistemic_entropy', 'replicate_correct_epistemic_entropyt'),
    'ei': ('replicate_incorrect_epistemic_entropy',),
}

# Field names of the HMC/SMC ``ID`` sub-dict built by load_metrics(), in
# (tc, ti, ec, ei) order.
_ID_CI_FIELDS = ('mean_total_entropy_correct', 'mean_total_entropy_incorrect',
                 'mean_epistemic_correct', 'mean_epistemic_incorrect')


def _id_ci_values(m, d):
    """Return the in-domain ``(tc, ti, ec, ei)`` means of method ``m``, NaN where absent.

    tc/ti: total entropy on correct/incorrect predictions.
    ec/ei: epistemic entropy on correct/incorrect predictions.
    """
    if m in ('HMC', 'SMC') and 'ID' in d:
        return tuple(d['ID'].get(f, np.nan) for f in _ID_CI_FIELDS)
    values = []
    for q in ('tc', 'ti', 'ec', 'ei'):
        key = next((k for k in _ID_CI_KEYS[q] if k in d), None)
        values.append(np.nan if key is None else np.mean(d[key]))
    return tuple(values)


def _finite_log_ylim(values, lo_scale, hi_scale, floor=1e-10):
    """Return ``(lo, hi)`` log-axis limits from the finite entries of ``values``.

    - ``lo`` is the smallest finite value, clamped below at ``floor``, times
      ``lo_scale``.
    - ``hi`` is the largest finite value times ``hi_scale``.

    NaN/inf entries are ignored; Python's ``min`` can return NaN when the list
    contains one. If nothing finite remains, ``None`` is returned, which
    ``Axes.set_ylim`` treats as "leave the limits unchanged".
    """
    arr = np.asarray(values, dtype=float).ravel()
    arr = arr[np.isfinite(arr)]
    if arr.size == 0:
        return None
    return max(arr.min(), floor) * lo_scale, arr.max() * hi_scale


def _all_domain_values(all_data, methods):
    """Return ``{method: (tot, ale, epi)}`` over close / corrupt / far / All IDs.

    For MAP only ``tot`` is reported in the table. Its All-ID aleatoric slot
    repeats the total, and its epistemic slot is 0.
    """
    full_vals = {}
    for m in methods:
        tot, ale, epi = get_full_ood_values(m, all_data[m])
        if m == 'MAP':
            # For MAP, get_ood_values returns [close, corrupt, far, all-ID] totals.
            allid = get_ood_values(m, all_data[m])[-1]
            aT, aA, aE = allid, allid, 0.0
        else:
            # (tot, ale, epi, aT, aA, aE)
            _, _, _, aT, aA, aE = get_ood_values(m, all_data[m])
        full_vals[m] = (tot + [aT], ale + [aA], epi + [aE])
    return full_vals


def _print_id_table(all_data, methods, labels):
    """Print the in-domain correct/incorrect LaTeX table to stdout."""
    rows = {m: _id_ci_values(m, all_data[m]) for m in methods}
    print(r"\begin{table}[H]")
    print(r"\centering")
    print(r"\caption{Comparison in ID domain among SMC$_\parallel$ ($N=10$), HMC$_{\parallel}$ ($N$ chains), DE ($N$ models) and MAP, with fixed number of leapfrog $L=1$ and $s=$, on CIFAR10 (5 realizations and $\pm$ s.e. in entropy).}")
    print(r"\begin{tabular}{l|cc|cc}")
    print(r"\toprule")
    print(r"Method & \multicolumn{2}{c}{$H_{\mathsf{ep}}$} & \multicolumn{2}{c}{$H_{\mathsf{tot}}$} \\")
    print(r"\cmidrule{2-3} \cmidrule{4-5}")
    print(r" & cor. & inc. & cor. & inc. \\")
    print(r"\midrule")
    for m in methods:
        tc, ti, ec, ei = rows[m]
        print(f"{labels[m]} & {ec:.3e} & {ei:.3e} & {tc:.3e} & {ti:.3e} \\\\")
        if m != methods[-1]:
            print(r"\midrule")
    print(r"\bottomrule")
    print(r"\end{tabular}")
    print(r"\end{table}")


def _print_all_domain_table(full_vals, methods):
    """Print the all-domains LaTeX table to stdout, one row per group.

    The header is written for the column order MAP, DE, HMC, SMC.
    """
    groups = ['close', 'corrupt', 'far', 'All IDs']
    print(r"\begin{table}[H]")
    print(r"\centering")
    print(r"\scriptsize")
    print(r"\caption{Comparison in all domains among SMC$_\parallel$ ($N=10$), HMC$_{\parallel}$ ($N$ chains), DE ($N$ models) and MAP, with fixed number of leapfrog $L=1$ and $s=0,$, on CIFAR10 (5 realizations and $\pm$ s.e. in entropy).}")
    print(r"\begin{tabular}{l|c|c|c|c|c|c|c|c|c|c|}")
    print(r"\toprule")
    print(r"Group & \multicolumn{1}{c|}{MAP} "
          r"& \multicolumn{3}{c|}{DE} "
          r"& \multicolumn{3}{c|}{HMC} "
          r"& \multicolumn{3}{c|}{SMC} \\")
    print(r"\cmidrule(rl){2-2} \cmidrule(rl){3-5} \cmidrule(rl){6-8} \cmidrule(rl){9-11}")
    print(r" & $H_{\mathsf{tot}}$ & $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ "
          r"& $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ "
          r"& $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ \\")
    print(r"\midrule")

    for i, grp in enumerate(groups):
        entries = [grp]
        for m in methods:
            tot, ale, epi = full_vals[m]
            if m == 'MAP':
                # just H_tot
                entries.append(f"{tot[i]:.3e}")
            else:
                entries += [f"{tot[i]:.3e}", f"{ale[i]:.3e}", f"{epi[i]:.3e}"]
        print(" & ".join(entries) + r" \\")
        # Rules only between rows: the last row is closed by \bottomrule.
        if i < len(groups) - 1:
            print(r"\midrule")

    print(r"\bottomrule")
    print(r"\end{tabular}")
    print(r"\end{table}")


def main():
    """Run the whole pipeline: load metrics, draw every figure, print both tables.

    Figures are saved as PNG files in the working directory and shown. The
    in-domain and all-domain LaTeX tables are printed to stdout.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC']
    labels = {
        'MAP':    'MAP',
        'DE':     'DE',
        'HMC':    r'HMC$_\parallel$',
        'SMC':    r'SMC$_\parallel$'
    }
    all_data = {m: load_metrics(m) for m in methods}

    # Shared log-scale limits for the per-method OOD plots.
    ood_vals = []
    for m in methods:
        vals = get_ood_values(m, all_data[m])
        ood_vals.extend(vals[:3] if m == 'MAP' else vals[0])
    ylim = _finite_log_ylim(ood_vals, lo_scale=1e-4, hi_scale=3)

    for m in methods:
        plot_ood(m, all_data[m], ylim)

    # Limits for the correct/incorrect plot. Every candidate key that is
    # present is included, so the limits cover whichever spelling each plot
    # function prefers (including MAP's replicate_(in)correct_entropy).
    ci_candidates = []
    for m in methods:
        d = all_data[m]
        if m in ('HMC', 'SMC') and 'ID' in d:
            ci_candidates.extend(_id_ci_values(m, d))
        else:
            ci_candidates.extend(np.mean(d[k])
                                 for keys in _ID_CI_KEYS.values()
                                 for k in keys if k in d)
    positive = [v for v in ci_candidates if np.isfinite(v) and v > 0]
    ci_ylim = (min(positive) * 0.1, max(positive) * 3) if positive else None

    plot_ci(all_data, ci_ylim)
    plot_rearranged_ci(all_data, ylim)
    plot_group_methods(all_data)
    plot_group_methods_ci(all_data)

    _print_id_table(all_data, methods, labels)

    # All-domain values (close/corrupt/far + All IDs) feed both the per-method
    # full plots and the LaTeX table. The All-ID entries are included in the
    # limits so those bars are not clipped; MAP contributes totals only.
    full_vals = _all_domain_values(all_data, methods)
    full_inputs = []
    for m in methods:
        tot, ale, epi = full_vals[m]
        full_inputs.extend(tot if m == 'MAP' else tot + ale + epi)
    full_ylim = _finite_log_ylim(full_inputs, lo_scale=0.05, hi_scale=3)

    for m in methods:
        plot_full_digit_ood(m, all_data[m], full_ylim)

    _print_all_domain_table(full_vals, methods)


# Run the full pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
