#!/usr/bin/env python3
"""
Plot and Table Generation for Bayesian NN Results, with HMC-GS integration

This script loads the aggregated MAT files for each method (MAP, DE, HMC, SMC)
plus the HMC-GS gold standard from HMC_results.txt, and generates:
1. OOD entropy bar plots + LaTeX tables (saved as PNG)
2. All-domain group comparison plots + LaTeX tables (saved as PNG)
3. In-domain correct/incorrect breakdown plots + LaTeX tables (saved as PNG)
4. Combined OOD predictive entropy subplots (MAP/DE/HMC/SMC + HMC-GS) (saved as PNG)
5. Rearranged correct/incorrect comparison plots (saved as PNG)
6. Group-method comparison plots (saved as PNG)
"""
import os, re
import numpy as np
import scipy.io as sio
from scipy.io.matlab import mat_struct
import matplotlib.pyplot as plt
import glob

# Choose which "P" file to load: load_metrics() reads
# phmc_aggregated_results_P{P}.mat (HMC) and psmc_aggregated_results_P{P}.mat (SMC).
P = 1 

# --------------------------
# Helpers to convert MATLAB structs
# --------------------------
def _todict(matobj):
    d = {}
    for f in matobj._fieldnames:
        d[f] = _check(getattr(matobj, f))
    return d


def _check(e):
    if isinstance(e, mat_struct):
        return _todict(e)
    if isinstance(e, np.ndarray) and e.dtype == object:
        return np.array([_check(x) for x in e])
    return e

# --------------------------
# Parse HMC-GS text file, including in-domain and correct/incorrect metrics
# --------------------------
def parse_gs_hmc_file(txt_path):
    """Parse the HMC-GS results text file into a dict of one-element lists.

    The text is searched for entries of the form ``<KEY>: <number>``. Each
    value is stored as ``[float]`` under a lower-case key; a key that is not
    found, or whose value is the literal ``nan``, yields ``[np.nan]``.

    Parameters
    ----------
    txt_path : str or path-like
        Path of the text file to read.

    Returns
    -------
    dict
        Keys ``id_tot``, ``id_epi``, ``id_corr_tot``, ``id_corr_epi``,
        ``id_inc_tot``, ``id_inc_epi``; ``ood_<i>_tot`` / ``ood_<i>_epi`` for
        ``i`` in 0..7 (read from ``ID_digit<i>_*``); and the lower-cased
        ``OOD_*`` keys for 8, 9, perturbed, white and all.
    """
    # Context manager so the handle is closed even if reading fails.
    with open(txt_path) as fh:
        text = fh.read()

    # Signed decimal with an optional exponent, so that values written in
    # scientific notation (e.g. "1.5e-03") are not truncated at the "e".
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"

    def grab(key):
        # re.escape keeps the key literal even if it contains regex metacharacters.
        m = re.search(rf"{re.escape(key)}:\s*({number}|nan)", text)
        if m is None or m.group(1) == 'nan':
            return np.nan
        return float(m.group(1))

    d = {}
    # overall in-domain
    for key in ['ID_tot', 'ID_epi']:
        d[key.lower()] = [grab(key)]
    # correct/incorrect breakdown
    for key in ['ID_corr_tot', 'ID_corr_epi', 'ID_inc_tot', 'ID_inc_epi']:
        d[key.lower()] = [grab(key)]
    # per-digit metrics for digits 0-7 (stored under ood_<i>_* keys)
    for i in range(8):
        for suffix in ['tot', 'epi']:
            d[f'ood_{i}_{suffix}'] = [grab(f'ID_digit{i}_{suffix}')]
    # OOD metrics for 8, 9, perturbed, white, and the overall aggregate
    for key in ['OOD_8_tot', 'OOD_9_tot', 'OOD_perturbed_tot', 'OOD_white_tot',
                'OOD_8_epi', 'OOD_9_epi', 'OOD_perturbed_epi', 'OOD_white_epi',
                'OOD_all_tot', 'OOD_all_epi']:
        d[key.lower()] = [grab(key)]
    return d

# --------------------------
# Load metrics for each method
# --------------------------
def load_metrics(method):
    """Load the metrics dictionary for one method.

    * ``'HMC-GS'``: parsed from ``HMC_results.txt`` via ``parse_gs_hmc_file``.
    * ``'MAP'`` / ``'DE'``: the corresponding ``*_metrics.mat`` file, with
      every non-dunder variable normalised through ``_check``.
    * ``'HMC'`` / ``'SMC'``: the ``p{hmc,smc}_aggregated_results_P{P}.mat``
      file for the module-level ``P``, reshaped into the nested layout
      (``'ID'``, ``'OOD'``, per-digit arrays) read by the plotting functions.

    Raises:
        FileNotFoundError: if the HMC/SMC result file does not exist.
        KeyError: if ``method`` is none of the names above.
    """
    if method == 'HMC-GS':
        return parse_gs_hmc_file('HMC_results.txt')

    def read_mat(path):
        # Drop loadmat's bookkeeping entries ("__header__" etc.).
        loaded = sio.loadmat(path, squeeze_me=True, struct_as_record=False)
        return {name: _check(value)
                for name, value in loaded.items()
                if not name.startswith('__')}

    # MAP and DE: the .mat contents are returned as loaded.
    if method in ('MAP', 'DE'):
        files = {
            'MAP': 'BayesianNN_MNIST_MAP_metrics.mat',
            'DE':  'BayesianNN_MNIST_de_metrics.mat'
        }
        return read_mat(files[method])

    # HMC or SMC: load the file selected by the module-level P.
    pattern = {
        'HMC': f'phmc_aggregated_results_P{P}.mat',
        'SMC': f'psmc_aggregated_results_P{P}.mat'
    }[method]
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError(f"No file matching {pattern!r} found.")
    flat = read_mat(matches[0])

    def field(name):
        # Missing scalar fields default to NaN.
        return flat.get(name, np.nan)

    # Rebuild the nested dict structure the downstream code expects.
    return {
        # in-domain totals
        'total_entropy_inID': field('mean_tot_ent'),
        'epistemic_inID': field('mean_epi'),
        # correct/incorrect breakdown
        'ID': {
            'mean_total_entropy_correct': field('mean_tot_corr'),
            'mean_total_entropy_incorrect': field('mean_tot_inc'),
            'mean_epistemic_correct': field('mean_epi_corr'),
            'mean_epistemic_incorrect': field('mean_epi_inc'),
        },
        # OOD fields (same flat key layout for HMC and SMC)
        'OOD': {
            k: {
                'mean_total_entropy': field(f'ood_{k}_mean_tot'),
                'mean_epistemic': field(f'ood_{k}_mean_epi'),
            }
            for k in ['digit8', 'digit9', 'perturbed', 'white_noise']
        },
        # per-digit arrays default to eight NaNs
        'per_digit_total_entropy': flat.get('mean_digit_tot_ent', np.full(8, np.nan)),
        'per_digit_epistemic_entropy': flat.get('mean_digit_epi_ent', np.full(8, np.nan)),
    }

# --------------------------
# Extract OOD + All-ID summary
# --------------------------
def get_ood_values(m, d):
    """Extract the OOD entropies and the in-domain summary for one method.

    Args:
        m: Method name (``'MAP'``, ``'DE'``, ``'HMC'``, ``'SMC'``, ``'HMC-GS'``).
        d: Metrics dictionary for that method.

    Returns:
        For ``'MAP'``: a flat list ``[tot8, tot9, tot_perturbed,
        tot_whitenoise, all_id]``.
        For every other method: a tuple ``(tot, ale, epi, aT, aA, aE)`` where
        the first three are four-element lists (digit 8, digit 9, perturbed,
        white noise), ``ale`` is ``tot - epi`` element-wise and
        ``aA = aT - aE``. An unrecognised method yields three empty lists
        and three NaNs.
    """
    replicate_suffixes = ['8', '9', 'perturbed', 'whitenoise']

    if m == 'MAP':
        out = []
        for s in replicate_suffixes:
            out.append(np.mean(d.get(f'replicate_total_entropy_ood_{s}', np.nan)))
        out.append(np.nanmean(d.get('per_digit_total_entropy', np.array([np.nan]))))
        return out

    if m == 'HMC-GS':
        # lowercase keys as produced by parse_gs_hmc_file
        gs_names = ['8', '9', 'perturbed', 'white']
        tot = [d[f'ood_{name}_tot'][0] for name in gs_names]
        epi = [d[f'ood_{name}_epi'][0] for name in gs_names]
        aT = d['ood_all_tot'][0]
        aE = d['ood_all_epi'][0]
    elif m == 'DE':
        tot = [np.mean(d[f'replicate_total_entropy_ood_{s}'])
               for s in replicate_suffixes]
        epi = [np.mean(d[f'replicate_epistemic_ood_{s}'])
               for s in replicate_suffixes]
        aT = np.nanmean(d.get('per_digit_total_entropy', np.array([np.nan])))
        aE = np.nanmean(d.get('epistemic_inID', np.nan))
        if np.isnan(aE):
            # Fall back to the unweighted average of the correct/incorrect means.
            mean_correct = np.mean(d.get('replicate_epistemic_in_correct', np.nan))
            mean_incorrect = np.mean(d.get('replicate_epistemic_in_incorrect', np.nan))
            aE = 0.5 * (mean_correct + mean_incorrect)
    elif m in ('HMC', 'SMC'):
        # Both methods share the nested 'OOD' layout.
        o = d['OOD']
        block_names = ['digit8', 'digit9', 'perturbed', 'white_noise']
        tot = [o[name]['mean_total_entropy'] for name in block_names]
        epi = [o[name]['mean_epistemic'] for name in block_names]
        aT = d.get('total_entropy_inID', np.nan)
        aE = d.get('epistemic_inID', np.nan)
    else:
        return [], [], [], np.nan, np.nan, np.nan

    ale = [t - e for t, e in zip(tot, epi)]
    return tot, ale, epi, aT, aT - aE, aE

# --------------------------
# Plot 1: OOD Entropy + save PNG
# --------------------------
def plot_ood(m, d, ylim):
    """Draw the OOD entropy bar chart for one method and save it as a PNG.

    For ``'MAP'`` only the four total-entropy bars are drawn; for every other
    method total, aleatoric and epistemic bars are drawn side by side. The
    figure is written to ``ood_entropy_<m>.png`` and then shown.

    Args:
        m: Method name, forwarded to ``get_ood_values``.
        d: Metrics for that method, forwarded to ``get_ood_values``.
        ylim: Passed unchanged to ``ax.set_ylim``.
    """
    tick_labels = ['Digit 8', 'Digit 9', 'Perturbed', 'White Noise']
    fig, ax = plt.subplots()

    if m == 'MAP':
        # MAP yields a flat list; its first four entries are the OOD totals.
        totals = get_ood_values(m, d)
        ax.bar(range(4), totals[:4], 0.5, label='Total')
    else:
        tot, ale, epi, *_ = get_ood_values(m, d)
        centres = np.arange(4)
        series = (
            (-0.25, tot, 'Total'),
            (0, ale, 'Aleatoric'),
            (0.25, epi, 'Epistemic'),
        )
        for shift, heights, name in series:
            ax.bar(centres + shift, heights, 0.25, label=name)

    ax.set_xticks(range(4))
    ax.set_xticklabels(tick_labels)
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()

    fig.tight_layout()
    fig.savefig(f"ood_entropy_{m}.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 2: All-Domain Group Comparison
# --------------------------
def plot_grp(m, d, ylim):
    """Placeholder for the all-domain group comparison plot.

    The body is empty in this file: the arguments are accepted and ignored,
    nothing is drawn or saved, and ``None`` is returned.
    """
    return None

# --------------------------
# Plot 3: Correct vs Incorrect + LaTeX Table + save PNG
# --------------------------
def plot_ci(all_data, ylim):
    """Plot in-domain correct vs. incorrect entropies for all methods.

    One group of four bars per method (total correct, total incorrect,
    epistemic correct, epistemic incorrect). The figure is saved as
    ``correct_vs_incorrect.png`` and then shown.

    Args:
        all_data: Mapping from method name to its metrics dictionary.
        ylim: Passed unchanged to ``ax.set_ylim``.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC', 'HMC-GS']
    # remapped tick labels
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$',
        'HMC-GS': 'HMC-GS'
    }

    # Candidate keys, tried in order, for methods without a dedicated layout.
    keys = {
        'tc': ['mean_total_entropy_correct', 'replicate_total_entropy_in_correct'],
        'ti': ['mean_total_entropy_incorrect', 'replicate_total_entropy_in_incorrect'],
        'ec': ['mean_epistemic_in_correct', 'replicate_epistemic_in_correct'],
        'ei': ['mean_epistemic_inincorrect', 'replicate_epistemic_in_incorrect']
    }
    # Field names inside the nested 'ID' dict (HMC/SMC) ...
    id_fields = {
        'tc': 'mean_total_entropy_correct',
        'ti': 'mean_total_entropy_incorrect',
        'ec': 'mean_epistemic_correct',
        'ei': 'mean_epistemic_incorrect',
    }
    # ... and in the parsed HMC-GS dictionary.
    gs_fields = {
        'tc': 'id_corr_tot',
        'ti': 'id_inc_tot',
        'ec': 'id_corr_epi',
        'ei': 'id_inc_epi',
    }

    vals = {k: [] for k in keys}
    for m in methods:
        d = all_data[m]
        for k, candidates in keys.items():
            if m in ('HMC', 'SMC') and 'ID' in d:
                value = d['ID'].get(id_fields[k], np.nan)
            elif m == 'HMC-GS':
                value = np.nanmean(d.get(gs_fields[k], np.nan))
            else:
                # First candidate key present wins; NaN when none is.
                value = next((np.mean(d[c]) for c in candidates if c in d), np.nan)
            vals[k].append(value)

    x = np.arange(len(methods))
    w = 0.2
    fig, ax = plt.subplots(figsize=(10, 6))
    bar_specs = [
        (-1.5 * w, 'tc', 'Total Correct'),
        (-0.5 * w, 'ti', 'Total Incorrect'),
        (0.5 * w, 'ec', 'Epistemic Correct'),
        (1.5 * w, 'ei', 'Epistemic Incorrect'),
    ]
    for shift, k, legend_name in bar_specs:
        ax.bar(x + shift, vals[k], w, label=legend_name)

    ax.set_xticks(x)
    ax.set_xticklabels([labels[m] for m in methods])  # apply remapped labels

    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()

    fig.tight_layout()
    fig.savefig("correct_vs_incorrect.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 5: Rearranged Correct/Incorrect Comparison
# --------------------------
def plot_rearranged_ci(all_data, ylim):
    """Plot correct/incorrect entropies grouped by outcome instead of method.

    Two side-by-side panels (total entropy, epistemic entropy), each with a
    "Correct" and an "Incorrect" group holding one bar per method. The figure
    is saved as ``rearranged_correct_incorrect.png`` and then shown.

    Args:
        all_data: Mapping from method name to its metrics dictionary.
        ylim: Accepted for signature parity with the other plot functions;
            it is not applied to either panel.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC', 'HMC-GS']
    # legend labels
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$',
        'HMC-GS': r'HMC-GS'
    }

    # Candidate keys, tried in order, for methods without a dedicated layout.
    keys = {
        'tc': ['mean_total_entropy_correct', 'replicate_total_entropy_in_correct'],
        'ti': ['mean_total_entropy_incorrect', 'replicate_total_entropy_in_incorrect'],
        'ec': ['mean_epistemic_in_correct', 'replicate_epistemic_in_correct'],
        'ei': ['mean_epistemic_inincorrect', 'replicate_epistemic_in_incorrect']
    }
    id_fields = {
        'tc': 'mean_total_entropy_correct',
        'ti': 'mean_total_entropy_incorrect',
        'ec': 'mean_epistemic_correct',
        'ei': 'mean_epistemic_incorrect',
    }
    gs_fields = {
        'tc': 'id_corr_tot',
        'ti': 'id_inc_tot',
        'ec': 'id_corr_epi',
        'ei': 'id_inc_epi',
    }

    vals = {k: [] for k in keys}
    for m in methods:
        d = all_data[m]
        for k, candidates in keys.items():
            if m in ('HMC', 'SMC') and 'ID' in d:
                value = d['ID'].get(id_fields[k], np.nan)
            elif m == 'HMC-GS':
                value = np.nanmean(d.get(gs_fields[k], np.nan))
            else:
                # First candidate key present wins; NaN when none is.
                value = next((np.mean(d[c]) for c in candidates if c in d), np.nan)
            vals[k].append(value)

    groups = ['Correct', 'Incorrect']
    x = np.arange(2)
    w = 0.15
    # Centre the five method bars around each group tick.
    offsets = [(i - 2) * w for i in range(len(methods))]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    panels = [
        (axes[0], 'tc', 'ti', 'Total Entropy'),
        (axes[1], 'ec', 'ei', 'Epistemic Entropy'),
    ]
    for ax, correct_key, incorrect_key, ylabel in panels:
        for i, m in enumerate(methods):
            heights = [vals[correct_key][i], vals[incorrect_key][i]]
            ax.bar(x + offsets[i], heights, w, label=labels[m])
        ax.set_xticks(x)
        ax.set_xticklabels(groups)
        ax.set_ylabel(ylabel)

    # legend inside the right subplot
    axes[1].legend(loc='upper left')

    fig.tight_layout(rect=[0, 0, 1, 1])
    fig.savefig("rearranged_correct_incorrect.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 6: Group-Method Comparison
# --------------------------
def plot_group_methods(all_data):
    """Compare all methods across the OOD groups plus the in-domain summary.

    Two panels (total entropy, epistemic entropy); each has five groups
    (digit 8, digit 9, perturbed, white noise, all IDs) with one bar per
    method. MAP contributes zeros to the epistemic panel. The figure is saved
    as ``group_methods_comparison.png`` and then shown.

    Args:
        all_data: Mapping from method name to its metrics dictionary.
    """
    method_order = ['MAP', 'DE', 'HMC', 'SMC', 'HMC-GS']
    # legend labels
    legend_names = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$',
        'HMC-GS': r'HMC-GS'
    }
    group_names = ['Digit 8', 'Digit 9', 'Perturbed', 'White Noise', 'All IDs']

    tot_vals, epi_vals = {}, {}
    for m in method_order:
        result = get_ood_values(m, all_data[m])
        if m == 'MAP':
            # Flat list: four OOD totals followed by the all-ID value.
            tot_vals[m] = result[:4] + [result[4]]
            epi_vals[m] = [0] * 5
            continue
        tot_list, _ale, epi_list, aT, _aA, aE = result
        tot_vals[m] = tot_list + [aT]
        epi_vals[m] = epi_list + [aE]

    x = np.arange(len(group_names))
    width = 0.15
    # Centre the five method bars around each group tick.
    offsets = [(i - 2) * width for i in range(len(method_order))]

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    panels = (
        (axes[0], tot_vals, 'Total Entropy'),
        (axes[1], epi_vals, 'Epistemic Entropy'),
    )
    for ax, per_method, ylabel in panels:
        for i, m in enumerate(method_order):
            ax.bar(x + offsets[i], per_method[m], width, label=legend_names[m])
        ax.set_xticks(x)
        ax.set_xticklabels(group_names, rotation=0, ha='center')
        ax.set_ylabel(ylabel)

    # legend inside the right subplot
    axes[1].legend(loc='upper left')

    fig.tight_layout()
    fig.savefig("group_methods_comparison.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# NEW: Plot 7: Full-Digit OOD Entropy (0–9, Perturbed, White Noise, All IDs)
# --------------------------
def get_full_ood_values(method, d):
    """Return twelve-entry total/aleatoric/epistemic lists for one method.

    Entries 0-7 are the per-digit values for digits 0-7; entries 8-11 are
    digit 8, digit 9, perturbed and white noise.

    Args:
        method: Method name.
        d: Metrics dictionary for that method.

    Returns:
        Tuple ``(tot, ale, epi)`` of lists. For ``'MAP'`` the aleatoric and
        epistemic lists are all zeros. Missing data is reported as NaN; an
        unrecognised method gets NaN for the last four entries.
    """
    if method == 'MAP':
        per_digit = d.get('per_digit_total_entropy', None)
        if per_digit is None:
            tot = [np.nan] * 8
        else:
            tot = [np.mean(per_digit[i]) for i in range(8)]
        tot.extend(
            np.nanmean(d.get(f'replicate_total_entropy_ood_{key}', [np.nan]))
            for key in ['8', '9', 'perturbed', 'whitenoise']
        )
        count = len(tot)
        return tot, [0] * count, [0] * count

    if method == 'HMC-GS':
        # The parsed dict uses ood_<name>_{tot,epi} for digits 0-7 as well as
        # for 8, 9, perturbed and white.
        names = [str(i) for i in range(8)] + ['8', '9', 'perturbed', 'white']
        tot = [d[f'ood_{name}_tot'][0] for name in names]
        epi = [d[f'ood_{name}_epi'][0] for name in names]
        ale = [t - e for t, e in zip(tot, epi)]
        return tot, ale, epi

    # Remaining methods: digits 0-7 from the per-digit arrays.
    t_arr = d.get('per_digit_total_entropy', None)
    e_arr = d.get('per_digit_epistemic_entropy', None)
    if t_arr is None or e_arr is None:
        tot, epi, ale = [np.nan] * 8, [np.nan] * 8, [np.nan] * 8
    else:
        tot = [np.mean(t_arr[i]) for i in range(8)]
        epi = [np.mean(e_arr[i]) for i in range(8)]
        ale = [t - e for t, e in zip(tot, epi)]

    # Digits 8, 9, perturbed, white noise.
    if method == 'DE':
        pairs = [
            (np.nanmean(d.get(f'replicate_total_entropy_ood_{key}', np.array([np.nan]))),
             np.nanmean(d.get(f'replicate_epistemic_ood_{key}', np.array([np.nan]))))
            for key in ['8', '9', 'perturbed', 'whitenoise']
        ]
    elif method in ('SMC', 'HMC'):
        # Both share the nested 'OOD' dict; absent blocks/fields become NaN.
        ood = d.get('OOD', {})
        pairs = []
        for name in ['digit8', 'digit9', 'perturbed', 'white_noise']:
            blk = ood.get(name, {})
            pairs.append((blk.get('mean_total_entropy', np.nan),
                          blk.get('mean_epistemic', np.nan)))
    else:
        # unknown method
        pairs = [(np.nan, np.nan)] * 4

    for t, e in pairs:
        tot.append(t)
        epi.append(e)
        ale.append(t - e)

    return tot, ale, epi


def plot_full_digit_ood(m, d, ylim):
    """Plot entropies for digits 0-9, perturbed, white noise and all IDs.

    The twelve per-group values come from ``get_full_ood_values``; a
    thirteenth "All IDs" entry is appended here. For ``'MAP'`` only total
    entropy is drawn; otherwise total, aleatoric and epistemic bars are
    drawn. The figure is saved as ``full_ood_entropy_<m>.png`` and then shown.

    Args:
        m: Method name.
        d: Metrics dictionary for that method.
        ylim: Passed unchanged to ``ax.set_ylim``.
    """
    groups = [f'Digit {i}' for i in range(10)]
    groups += ['Perturbed', 'White Noise', 'All IDs']

    tot, ale, epi = get_full_ood_values(m, d)
    is_map = (m == 'MAP')

    # Append the "All IDs" entry: MAP has no aleatoric/epistemic split, the
    # other methods take it from get_ood_values.
    if is_map:
        all_id = (np.nanmean(d.get('per_digit_total_entropy', [np.nan])), 0.0, 0.0)
    else:
        _, _, _, aT, _, aE = get_ood_values(m, d)
        all_id = (aT, aT - aE, aE)
    for series, value in zip((tot, ale, epi), all_id):
        series.append(value)

    x = np.arange(len(groups))
    fig, ax = plt.subplots(figsize=(12, 6))
    if is_map:
        ax.bar(x, tot, 0.6, label='Total')
    else:
        width = 0.25
        ax.bar(x - width, tot, width, label='Total')
        ax.bar(x, ale, width, label='Aleatoric')
        ax.bar(x + width, epi, width, label='Epistemic')

    ax.set_xticks(x)
    ax.set_xticklabels(groups, rotation=0, ha='center')
    ax.set_ylim(ylim)
    ax.set_ylabel('Entropy')
    ax.legend()

    fig.tight_layout()
    fig.savefig(f"full_ood_entropy_{m}.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# Plot 8: Group-Method Comparison with In-Domain Correct vs Incorrect
# --------------------------
def plot_group_methods_ci(all_data):
    """Compare methods across OOD groups and in-domain correct/incorrect.

    Two panels (total entropy, epistemic entropy); each has six groups
    (digit 8, digit 9, perturbed, white noise, ID correct, ID incorrect) with
    one bar per method. The figure is saved as
    ``group_methods_comparison_ci.png`` and then shown.

    Args:
        all_data: Mapping from method name to its metrics dictionary.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC', 'HMC-GS']
    labels = {
        'MAP': 'MAP',
        'DE': 'DE',
        'HMC': r'HMC$_\parallel$',
        'SMC': r'SMC$_\parallel$',
        'HMC-GS': 'HMC-GS'
    }
    groups = ['Digit 8', 'Digit 9', 'Perturbed', 'White Noise', 'ID Cor.', 'ID Inc.']

    tot_vals, epi_vals = {}, {}
    for m in methods:
        d = all_data[m]

        # First four groups: OOD values.
        if m == 'MAP':
            ood_tot = get_ood_values(m, d)[:4]
            ood_epi = [0] * 4
        else:
            ood_tot, _ale, ood_epi, *_ = get_ood_values(m, d)

        # Last two groups: in-domain correct / incorrect.
        if m in ('HMC', 'SMC'):
            id_s = d['ID']
            id_tot = [id_s.get('mean_total_entropy_correct', np.nan),
                      id_s.get('mean_total_entropy_incorrect', np.nan)]
            id_epi = [id_s.get('mean_epistemic_correct', np.nan),
                      id_s.get('mean_epistemic_incorrect', np.nan)]
        elif m == 'HMC-GS':
            id_tot = [np.nanmean(d.get('id_corr_tot', np.nan)),
                      np.nanmean(d.get('id_inc_tot', np.nan))]
            id_epi = [np.nanmean(d.get('id_corr_epi', np.nan)),
                      np.nanmean(d.get('id_inc_epi', np.nan))]
        else:  # MAP, DE
            id_tot = [np.mean(d.get('replicate_total_entropy_in_correct', np.nan)),
                      np.mean(d.get('replicate_total_entropy_in_incorrect', np.nan))]
            id_epi = [np.mean(d.get('replicate_epistemic_in_correct', np.nan)),
                      np.mean(d.get('replicate_epistemic_in_incorrect', np.nan))]

        tot_vals[m] = list(ood_tot) + id_tot
        epi_vals[m] = list(ood_epi) + id_epi

    x = np.arange(len(groups))
    width = 0.15
    # Centre the five method bars around each group tick.
    offsets = [(i - 2) * width for i in range(len(methods))]

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    panels = (
        (axes[0], tot_vals, 'Total Entropy'),
        (axes[1], epi_vals, 'Epistemic Entropy'),
    )
    for ax, per_method, ylabel in panels:
        for i, m in enumerate(methods):
            ax.bar(x + offsets[i], per_method[m], width, label=labels[m])
        ax.set_xticks(x)
        ax.set_xticklabels(groups, rotation=0, ha='center')
        ax.set_ylabel(ylabel)
    axes[1].legend(loc='upper left')

    fig.tight_layout()
    fig.savefig("group_methods_comparison_ci.png", dpi=300, bbox_inches='tight')
    plt.show()


# --------------------------
# main()
# --------------------------
def _entropy_ylim(values):
    """Return ``(0, 1.1 * max)`` over the finite entries of *values*.

    NaN/inf entries are ignored so that one missing metric cannot turn the
    axis limit into NaN. Returns ``None`` when there is no finite value;
    NOTE(review): ``ax.set_ylim(None)`` is expected to leave the limits as
    they are -- confirm against the installed matplotlib.
    """
    chunks = [np.ravel(np.asarray(v, dtype=float)) for v in values]
    if not chunks:
        return None
    flat = np.concatenate(chunks)
    finite = flat[np.isfinite(flat)]
    if finite.size == 0:
        return None
    return (0, float(finite.max()) * 1.1)


def _ci_table_values(all_data, methods):
    """Collect in-domain correct/incorrect entropies, one entry per method."""
    # Candidate keys, tried in order, for methods without a dedicated layout.
    keys = {
        'tc': ['mean_total_entropy_correct',   'replicate_total_entropy_in_correct'],
        'ti': ['mean_total_entropy_incorrect', 'replicate_total_entropy_in_incorrect'],
        'ec': ['mean_epistemic_in_correct',    'replicate_epistemic_in_correct'],
        'ei': ['mean_epistemic_inincorrect',   'replicate_epistemic_in_incorrect']
    }
    vals = {k: [] for k in keys}
    for m in methods:
        d = all_data[m]
        if m in ('HMC', 'SMC') and 'ID' in d:
            id_s = d['ID']
            vals['tc'].append(id_s.get('mean_total_entropy_correct',   np.nan))
            vals['ti'].append(id_s.get('mean_total_entropy_incorrect', np.nan))
            vals['ec'].append(id_s.get('mean_epistemic_correct',       np.nan))
            vals['ei'].append(id_s.get('mean_epistemic_incorrect',     np.nan))
        elif m == 'HMC-GS':
            vals['tc'].append(np.nanmean(d.get('id_corr_tot', np.nan)))
            vals['ti'].append(np.nanmean(d.get('id_inc_tot',  np.nan)))
            vals['ec'].append(np.nanmean(d.get('id_corr_epi', np.nan)))
            vals['ei'].append(np.nanmean(d.get('id_inc_epi',  np.nan)))
        else:
            for k, ks in keys.items():
                # First candidate key present wins; NaN when none is.
                vals[k].append(next((np.mean(d[key]) for key in ks if key in d), np.nan))
    return vals


def _print_ci_table(vals, methods, labels):
    """Print the in-domain correct/incorrect LaTeX table (scientific notation)."""
    print(r"\begin{table}[H]")
    print(r"\centering")
    print(r"\caption{Comparison in ID domain among SMC$_\parallel$ ($N=10$), HMC$_{\parallel}$ ($N$ chains), HMC-GS ($2.0e+4$ samples), DE ($N$ models) and MAP, with fixed number of leapfrog $L=1$ and $s=\frac{1}{4}$, on MNIST (5 realizations and $\pm$ s.e. in entropy).}")
    print(r"\begin{tabular}{l|cc|cc}")
    print(r"\toprule")
    print(r"Method & \multicolumn{2}{c}{$H_{\mathsf{ep}}$} & \multicolumn{2}{c}{$H_{\mathsf{tot}}$} \\")
    print(r"\cmidrule{2-3} \cmidrule{4-5}")
    print(r" & cor. & inc. & cor. & inc. \\")
    print(r"\midrule")
    last = len(methods) - 1
    for i, m in enumerate(methods):
        print(f"{labels[m]} & {vals['ec'][i]:.3e} & {vals['ei'][i]:.3e} & {vals['tc'][i]:.3e} & {vals['ti'][i]:.3e} \\\\")
        if i != last:
            print(r"\midrule")
    print(r"\bottomrule")
    print(r"\end{tabular}")
    print(r"\end{table}")


def _all_domain_values(all_data, methods):
    """Return ``{method: (tot, ale, epi)}`` with thirteen entries per list.

    The first twelve entries come from ``get_full_ood_values``; the last is
    the "All ID" value taken from ``get_ood_values``.
    """
    table = {}
    for m in methods:
        tot12, ale12, epi12 = get_full_ood_values(m, all_data[m])
        if m == 'MAP':
            # get_ood_values returns [tot8, tot9, tot_pert, tot_white, allid]
            allid = get_ood_values(m, all_data[m])[-1]
            aT, aA, aE = allid, allid, 0.0
        else:
            # returns tot4, ale4, epi4, aT, aA, aE
            _, _, _, aT, aA, aE = get_ood_values(m, all_data[m])
        table[m] = (tot12 + [aT], ale12 + [aA], epi12 + [aE])
    return table


def _print_all_domain_table(table, groups, methods):
    """Print the all-domains LaTeX table, one row per group."""
    print(r"\begin{table}[H]")
    print(r"\centering")
    print(r"\scriptsize")
    print(r"\caption{Comparison in all domains among SMC$_\parallel$ ($N=10$), HMC$_{\parallel}$ ($N$ chains), HMC-GS ($2.0e+4$ samples), DE ($N$ models) and MAP, with fixed number of leapfrog $L=1$ and $s=0,\frac{1}{4}$, on MNIST (5 realizations and $\pm$ s.e. in entropy).}")
    print(r"\begin{tabular}{l|c|c|c|c|c|c|c|c|c|c|c|c|c}")
    print(r"\toprule")
    print(r"Group & \multicolumn{1}{c|}{MAP} "
          r"& \multicolumn{3}{c|}{DE} "
          r"& \multicolumn{3}{c|}{HMC} "
          r"& \multicolumn{3}{c|}{SMC} "
          r"& \multicolumn{3}{c}{HMC-GS}\\")
    print(r"\cmidrule(rl){2-2} \cmidrule(rl){3-5} \cmidrule(rl){6-8} \cmidrule(rl){9-11} \cmidrule(rl){12-14}")
    print(r" & $H_{\mathsf{tot}}$ & $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ "
          r"& $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ "
          r"& $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ "
          r"& $H_{\mathsf{tot}}$ & $H_{\mathsf{al}}$ & $H_{\mathsf{ep}}$ \\")
    print(r"\midrule")

    last = len(groups) - 1
    for i, grp in enumerate(groups):
        entries = [grp]
        for m in methods:
            tot, ale, epi = table[m]
            if m == 'MAP':
                # MAP has a single H_tot column
                entries.append(f"{tot[i]:.3e}")
            else:
                entries += [f"{tot[i]:.3e}", f"{ale[i]:.3e}", f"{epi[i]:.3e}"]
        print(" & ".join(entries) + r" \\")
        # Separate rows only; the last row is closed by \bottomrule, matching
        # the in-domain table.
        if i != last:
            print(r"\midrule")

    print(r"\bottomrule")
    print(r"\end{tabular}")
    print(r"\end{table}")


def main():
    """Load every method's metrics, draw all plots and print both LaTeX tables.

    Order of work: per-method OOD plots, the (stub) group plots, the
    correct/incorrect plots, the group-method plots, the in-domain LaTeX
    table, the full-digit plots, and finally the all-domains LaTeX table.
    """
    methods = ['MAP', 'DE', 'HMC', 'SMC', 'HMC-GS']
    labels = {
        'MAP':    'MAP',
        'DE':     'DE',
        'HMC':    r'HMC$_\parallel$',
        'SMC':    r'SMC$_\parallel$',
        'HMC-GS': 'HMC-GS'
    }
    all_data = {m: load_metrics(m) for m in methods}

    # Common y-limit derived from the OOD total entropies of all methods.
    ood_vals = []
    for m in methods:
        vals = get_ood_values(m, all_data[m])
        ood_vals.extend(vals[:4] if m == 'MAP' else vals[0])
    ylim = _entropy_ylim(ood_vals)

    for m in methods:
        plot_ood(m, all_data[m], ylim)
    for m in methods:
        plot_grp(m, all_data[m], ylim)
    plot_ci(all_data, ylim)
    plot_rearranged_ci(all_data, ylim)
    plot_group_methods(all_data)
    # group-method plot including ID correct/incorrect
    plot_group_methods_ci(all_data)

    # --- Print in-domain CI LaTeX table to terminal ---
    _print_ci_table(_ci_table_values(all_data, methods), methods, labels)

    # y-limit for the full-digit plots (MAP contributes total entropy only)
    full_digit_vals = []
    for m in methods:
        tot, ale, epi = get_full_ood_values(m, all_data[m])
        full_digit_vals.extend(tot)
        if m != 'MAP':
            full_digit_vals.extend(ale)
            full_digit_vals.extend(epi)
    full_ylim = _entropy_ylim(full_digit_vals)

    for m in methods:
        plot_full_digit_ood(m, all_data[m], full_ylim)

    # --- Print all-domains LaTeX table to terminal ---
    groups = [f'Digit {i}' for i in range(8)] + ['Digit 8', 'Digit 9', 'Perturbed', 'White Noise', 'All ID']
    _print_all_domain_table(_all_domain_values(all_data, methods), groups, methods)



# Run the pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
