#!/usr/bin/env python3
"""
TDDFT Results Analysis and Visualization

This script analyzes time-dependent DFT (TDDFT) excitation energies from ML-XC functionals
and generates publication-quality comparison plots against reference SCAN functional.

Purpose:
    Compare TDDFT excitation energies calculated with learned functionals (DEI-XC pipeline)
    against the SCAN reference functional (SCAN-repo pipeline) to evaluate how well the
    ML-XC functional reproduces excited-state properties.

Loss Configurations Analyzed:
    The script handles multiple loss function combinations including:
    - E: Energy only
    - ρ (P): Density
    - F: Forces
    - ORG: Orbital rotation gradient (∇)
    - ORH: Orbital rotation Hessian (H)
    - Etot: Total energy
    - Exc: XC energy
    - Vxc: XC potential

Metrics:
    - MAE: Mean absolute error vs SCAN-repo (eV)
    - Relative MAE: Mean absolute relative deviation (%)
    - Per-state analysis: Breakdown by excitation state index
    - Per-checkpoint analysis: Individual model comparisons

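Input:
    CSV file with TDDFT results containing columns:
    - pipeline: 'DEI-XC' or 'SCAN-repo'
    - energy_eV: Excitation energy
    - checkpoint_id: Model checkpoint identifier
    - molecule_idx, state_index: Molecular system and excited state
    - method: TDDFT method label (only rows with method 'TDDFT' are plotted)
    - run_seed: Random seed of the training run
    - config_path: Path to the checkpoint YAML file; the trailing number in the
      file name is used to select and label checkpoints
    - n_atoms, elements: Molecule descriptors (used only as pivot index columns)
    - dynamic_loss.relative_weights.*: Loss function weights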

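Output:
    Generates plots in specified output directory:
    - mae_combined_by_model_and_loss_*.png: Overall and per-state MAE, grouped by
      model and loss configuration
    - mae_by_checkpoint_overall_*.png: Overall MAE by checkpoint
    - mae_by_checkpoint_per_state_*.png: MAE broken down by state
    - mae_by_config_overall_*.png: MAE averaged over random seeds
    - mae_by_config_overall_relative_*.png: Relative MAE (%)
    - mae_by_config_per_state_*.png: Per-state MAE by configuration

    Note: the calls that produce the mae_by_checkpoint_* and mae_by_config_* plots
    are currently commented out in main().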

TDDFT Context:
    TDDFT calculations probe the orbital derivatives of the learned functional through
    linear response theory. Good TDDFT accuracy indicates the functional correctly captures
    the orbital energy structure and XC kernel, beyond just ground-state properties.
"""

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scienceplots  # noqa
import seaborn as sns

import wandb

# Plot styling configuration.
# The 'science' and 'grid' style names are presumably registered by the
# otherwise-unused `scienceplots` import above -- TODO confirm.
plt.style.use(['science', 'grid'])
plt.rcParams.update({'font.size': 16})

# Apply the seaborn 'whitegrid' theme and fetch the 'deep' palette (note: this is
# NOT the 'colorblind' palette). `colors` is indexed by the plotting code below.
# NOTE(review): set_theme() runs after the rcParams update above and may override
# some of those settings (e.g. the font size) -- confirm the intended order.
sns.set_theme(style='whitegrid')
colors = sns.color_palette('deep')

# Map loss-weight column names (as they appear in the input CSV) to the short
# symbols substituted into human-readable configuration labels by
# format_config_label().
WEIGHT_COLS_MAP = {
    'dynamic_loss.relative_weights.density': 'ρ',
    'dynamic_loss.relative_weights.forces': 'F',
    'dynamic_loss.relative_weights.orbital_rotation_gradient': '∇',
    'dynamic_loss.relative_weights.orbital_rotation_hessian': 'H',
    'dynamic_loss.relative_weights.total_energy': 'E',
    'dynamic_loss.relative_weights.xc_energy': 'Exc',
    'dynamic_loss.relative_weights.xc_potential': 'Vxc',
}

# B3LYP checkpoints to include in the analysis, grouped by model name and then
# by loss-configuration key. Each integer is compared (as a string) with the
# numeric suffix parsed from the YAML file name in `config_path`.
# NOTE: main() flattens this into a single ID -> loss-key lookup and filters rows
# by ID only, without regard to the model name, so IDs must be unique across all
# models listed here.
# The key spellings differ between models ('grad' vs 'gradient',
# 'grad_and_hess' vs 'grad_and_hessian'); main() normalizes them before plotting.
B3LYP_CHECKPOINTS = {
    'EGXC_b3lyp_qm7': {
        'baseline': 58,
        'gradient': 61,
        'grad_and_hess': 76,
    },
    'NNmGGA_b3lyp_qm7': {
        'baseline': 64,
        # The config for this run was originally named ..._58; preprocess_data()
        # rewrites that config_path to ..._10, hence ID 10 here.
        'grad': 10,
        'grad_and_hessian': 11,
    },
    'Skala_mGGA_b3lyp_qm7': {
        'baseline': 6,
        'gradient': 18,
        'grad_and_hessian': 21,
    },
}


def parse_args():
    """
    Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace with the attributes ``input``, ``output_dir``,
        ``wandb_project``, ``wandb_entity`` and ``wandb``.
    """
    # One (flag, add_argument keyword arguments) pair per supported option.
    option_table = (
        (
            "--input",
            {
                "type": str,
                "default": "./results_tddft_XCdiff_scan_qm5_QM5_n5_df_all_checkpoints_combined-2.csv",
                "help": "Path to the combined CSV file containing TDDFT results.",
            },
        ),
        (
            "--output-dir",
            {
                "type": str,
                "default": "./plots",
                "help": "Directory to save generated plots.",
            },
        ),
        (
            "--wandb-project",
            {
                "type": str,
                "default": "tddft-analysis",
                "help": "WandB project name.",
            },
        ),
        (
            "--wandb-entity",
            {
                "type": str,
                "default": None,
                "help": "WandB entity/username.",
            },
        ),
        (
            "--wandb",
            {
                "action": "store_true",
                "help": "Enable WandB logging.",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Plot TDDFT MAE analysis.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def get_config_signature(row, weight_cols):
    """
    Build a string that identifies a row's combination of loss weights.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) indexable by column name
        weight_cols: Iterable of column names holding the loss weights

    Returns:
        String form of the (column, value) pairs, ordered by column name
    """
    selected = {}
    for name in weight_cols:
        selected[name] = row[name]

    # Ordering by column name keeps the signature independent of column order.
    ordered_pairs = sorted(selected.items())
    return str(ordered_pairs)


def format_config_label(config_str, weight_map):
    """
    Turn a configuration signature into a compact, human-readable label.

    Args:
        config_str: Signature string as produced by get_config_signature
        weight_map: Mapping from full weight column names to short labels

    Returns:
        Space-separated label made of short names followed by their weights;
        entries whose weight is written as 0.0 are left out
    """
    # Strip list/tuple punctuation and both kinds of quotes in one pass.
    text = config_str.translate(str.maketrans('', '', '[]"\'()'))
    # A trailing separator lets the last entry match the same patterns as the others.
    text += ', '

    # Drop entries whose weight is 0.0, in either the "name:0.0," or "name, 0.0," form.
    for column in weight_map:
        for zero_entry in (f'{column}:0.0,', f'{column}, 0.0,'):
            text = text.replace(zero_entry, '')

    # Swap the long column names for their short labels.
    for column, short_label in weight_map.items():
        text = text.replace(column, short_label)

    # Delete spaces and turn commas into spaces, then trim ".0" off whole-number weights.
    text = text.translate(str.maketrans({' ': None, ',': ' '}))
    return text.replace('.0 ', ' ').strip()


def preprocess_data(df):
    """
    Put DEI-XC and SCAN-repo energies side by side and derive error columns.

    Rows from other pipelines are discarded. The remaining rows are pivoted so
    that each calculation (checkpoint, molecule, state, method, seed, config
    path and loss weights) has one 'DEI-XC' and one 'SCAN-repo' energy column.
    Progress and dataset statistics are printed along the way.
    NOTE(review): pivot_table is called without `aggfunc`, so duplicate entries
    are combined with the pandas default aggregation -- confirm this is intended.

    Args:
        df: DataFrame with TDDFT results from both pipelines

    Returns:
        Pivoted DataFrame with the added columns 'abs_error', 'error',
        'rel_error' (percent), 'config_id' and 'config_short'

    Raises:
        ValueError: If the pivoted data lacks a 'DEI-XC' or 'SCAN-repo' column
    """
    print('\n=== Data Preprocessing ===')
    print(f'Loaded {len(df)} rows')

    # Keep only the ML functional (DEI-XC) and the reference (SCAN-repo).
    print(f'Pipelines found: {df["pipeline"].unique()}')
    wanted_rows = df['pipeline'].isin(['DEI-XC', 'SCAN-repo'])
    df = df[wanted_rows].copy()

    # Loss-weight columns define the training configuration of a checkpoint.
    weight_cols = []
    for column in df.columns:
        if column.startswith('dynamic_loss.relative_weights'):
            weight_cols.append(column)

    # Report how many distinct values each key column has.
    print('\nDataset statistics:')
    summary_columns = (
        ('Molecules', 'molecule_idx'),
        ('Excited states', 'state_index'),
        ('Methods', 'method'),
        ('Checkpoints', 'checkpoint_id'),
        ('Random seeds', 'run_seed'),
    )
    for caption, column in summary_columns:
        print(f'  {caption}: {len(df[column].unique())}')

    # One config file was saved under a wrong name; point it at the right one.
    misnamed_path = "evaluations/tddft/b3lyp/NNmGGA/grad/NNmGGA2_b3lyp_qm7_58/checkpoint/NNmGGA2_b3lyp_qm7_58.yaml"
    corrected_path = "evaluations/tddft/b3lyp/NNmGGA/grad/NNmGGA2_b3lyp_qm7_10/checkpoint/NNmGGA2_b3lyp_qm7_10.yaml"
    df["config_path"] = df["config_path"].replace(misnamed_path, corrected_path)
    print("Config paths: ", df["config_path"].unique())

    # Columns that together identify a single calculation.
    index_cols = [
        'checkpoint_id',
        'molecule_idx',
        'state_index',
        'method',
        'n_atoms',
        'elements',
        'run_seed',
        'config_path',
        *weight_cols,
    ]

    # Spread the two pipelines into separate energy columns.
    print('\nPivoting data...')
    pivoted = df.pivot_table(index=index_cols, columns='pipeline', values='energy_eV')
    data_pivoted = pivoted.reset_index()
    print(f'Pivoted to {len(data_pivoted)} rows')

    # Both energy columns are required to compute any error.
    has_prediction = 'DEI-XC' in data_pivoted.columns
    has_reference = 'SCAN-repo' in data_pivoted.columns
    if not (has_prediction and has_reference):
        raise ValueError(
            'Input CSV must contain both "DEI-XC" and "SCAN-repo" pipelines.'
        )

    # Signed and absolute deviation in eV, then the relative deviation in percent.
    deviation = data_pivoted['DEI-XC'] - data_pivoted['SCAN-repo']
    data_pivoted['abs_error'] = deviation.abs()
    data_pivoted['error'] = deviation
    reference_magnitude = data_pivoted['SCAN-repo'].abs()
    data_pivoted['rel_error'] = (data_pivoted['abs_error'] / reference_magnitude) * 100.0

    # Tag every row with the signature of its loss-weight configuration.
    def _signature_of(row):
        return get_config_signature(row, weight_cols)

    data_pivoted['config_id'] = data_pivoted.apply(_signature_of, axis=1)

    # Number the configurations in order of first appearance for short labels.
    short_names = {}
    for position, signature in enumerate(data_pivoted['config_id'].unique()):
        short_names[signature] = f'Cfg {position}'
    data_pivoted['config_short'] = data_pivoted['config_id'].map(short_names)

    return data_pivoted


def plot_mae(
    data,
    x_col,
    hue_col,
    title,
    filename,
    output_dir,
    use_wandb,
    errorbar=('ci', 95),
    write_values_over_bars=True,
    relative_error=False,
    use_log_scale=True,
):
    """
    Create bar plot comparing TDDFT MAE across configurations.

    The figure is written to ``output_dir / filename`` and is always closed
    before returning, including when plotting, saving or logging raises.

    Args:
        data: DataFrame with error metrics ('abs_error' or 'rel_error' column)
        x_col: Column name for x-axis (e.g., 'checkpoint_id', 'config_human_readable')
        hue_col: Column name for grouping bars (e.g., 'state_index'), or None
        title: Plot title
        filename: Output filename
        output_dir: Directory to save plot (Path or string)
        use_wandb: Whether to log to Weights & Biases
        errorbar: Error bar style - ('ci', 95) for 95% confidence interval,
                  'sd' for standard deviation, 'se' for standard error
        write_values_over_bars: Whether to annotate bars with values. Annotations
                  are only drawn on a linear axis, i.e. they are skipped whenever
                  use_log_scale is True.
        relative_error: If True, plot relative error (%), else absolute error (eV)
        use_log_scale: Whether to use log scale for y-axis (default: True)
    """
    # Accept a plain string as well as a Path for the destination directory.
    save_path = Path(output_dir) / filename

    # Select metric and label
    if relative_error:
        y_col = 'rel_error'
        y_label = 'Relative MAE vs SCAN-repo (%)'
    else:
        y_col = 'abs_error'
        y_label = 'MAE vs SCAN-repo (eV)'

    fig, ax = plt.subplots(figsize=(14, 8))
    try:
        # Create bar plot with improved styling
        plot_kwargs = {
            'data': data,
            'x': x_col,
            'y': y_col,
            'hue': hue_col,
            'errorbar': errorbar,
            'capsize': 1.5,
            'err_kws': {'linewidth': 1.0},
            'alpha': 1.0,
            'ax': ax,
            'edgecolor': 'black',
            'linewidth': 0.5,
        }

        # A palette only applies when bars are grouped by hue; otherwise use one color
        if hue_col is not None:
            plot_kwargs['palette'] = 'deep'
        else:
            plot_kwargs['color'] = colors[0]

        sns.barplot(**plot_kwargs)

        # Apply log scale if requested
        if use_log_scale:
            ax.set_yscale('log')

        # Formatting
        ax.set_title(title, fontsize=18, pad=15)
        ax.set_ylabel(y_label, fontsize=16)
        ax.set_xlabel(None)
        ax.tick_params(axis='x', rotation=45, labelsize=14)

        # Disable LaTeX rendering for x-tick labels (they contain Unicode characters)
        for label in ax.get_xticklabels():
            label.set_usetex(False)
        plt.setp(ax.get_xticklabels(), ha='right')

        # Improved tick positioning (from tabel_visualization.py)
        ax.yaxis.set_tick_params(pad=18)
        for label in ax.get_yticklabels():
            label.set_ha('left')

        # Cleaner grid styling (from tabel_visualization.py)
        ax.grid(True, axis='y', alpha=0.8, linewidth=0.5, zorder=0)
        ax.grid(False, axis='x')
        ax.set_axisbelow(True)

        # Annotate bars with values (optional; skipped on a log axis)
        if write_values_over_bars and not use_log_scale:
            for container in ax.containers:
                # The `datavalues` attribute is used as the marker for bar
                # containers; anything without it is skipped.
                if not hasattr(container, 'datavalues'):
                    continue
                for bar in container:
                    height = bar.get_height()
                    if height > 0:  # Only annotate positive values
                        ax.annotate(
                            f'{height:.2f}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 5),
                            textcoords='offset points',
                            ha='center',
                            va='bottom',
                            fontsize=10,
                            color='black',
                        )

        # Configure legend
        if hue_col:
            ax.legend(
                title=hue_col.replace('_', ' ').title(),
                frameon=False,
                loc='upper right',
            )

        fig.tight_layout()

        # Save plot
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Saved: {save_path}')

        # Log to WandB if enabled
        if use_wandb:
            wandb.log({filename.replace('.png', ''): wandb.Image(str(save_path))})
    finally:
        # Release this specific figure even if something above failed.
        plt.close(fig)


def main():
    """Main analysis and plotting pipeline."""
    args = parse_args()
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f'\n=== TDDFT Analysis Pipeline ===')
    print(f'Input: {args.input}')
    print(f'Output: {output_dir}')

    # Initialize WandB if requested
    use_wandb = args.wandb
    if use_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            config={'input_file': args.input},
        )
        print('WandB logging enabled')

    # Load raw data
    df_raw = pd.read_csv(args.input)
    print(f"Columns: {df_raw.columns}")

    # Preprocess and calculate errors
    data_all = preprocess_data(df_raw)

    # Analyze each TDDFT method separately (TDA, TDDFT, etc.)
    methods = ['TDDFT']  # Can extend to ['TDA', 'TDDFT'] if needed

    for method in methods:
        print(f'\n=== Processing Method: {method} ===')
        data = data_all[data_all['method'] == method].copy()
        print(f"Columns: {data.columns}")

        if data.empty:
            print(f'No data found for method {method}, skipping...')
            continue

        method_suffix = f'_{method}'

        # Create human-readable configuration labels
        data['config_human_readable'] = data['config_id'].apply(
            lambda cfg: format_config_label(cfg, WEIGHT_COLS_MAP)
        )

        # Create checkpoint labels with configuration and checkpoint ID
        # Extract ID from YAML filename in config_path (e.g., 'EGXC_b3lyp_qm7_58.yaml' -> '58')
        yaml_id = data['config_path'].str.split('/').str[-1].str.replace('.yaml', '').str.split('_').str[-1]
        data['yaml_id'] = yaml_id

        # Filter to only include B3LYP checkpoints specified in B3LYP_CHECKPOINTS
        # Also create reverse mapping from ID to loss configuration key
        valid_ids = set()
        id_to_loss_key = {}
        for model_name, model_checkpoints in B3LYP_CHECKPOINTS.items():
            for loss_key, id_val in model_checkpoints.items():
                valid_ids.add(str(id_val))
                id_to_loss_key[str(id_val)] = loss_key

        print(f'\n=== Filtering to B3LYP checkpoints ===')
        print(f'Valid checkpoint IDs: {sorted(valid_ids)}')
        print(f'IDs before: {sorted(data["yaml_id"].unique())}')
        print(f'ID to loss key mapping: {id_to_loss_key} {len(id_to_loss_key)}')
        print(f'Rows before filtering: {len(data)}')

        data = data[data['yaml_id'].isin(valid_ids)].copy()

        print(f'Rows after filtering: {len(data)}')
        print(f'{len(data["yaml_id"].unique())} unique checkpoints after filtering: {sorted(data["yaml_id"].unique())}')

        if data.empty:
            print(f'No data remains after filtering for B3LYP checkpoints, skipping...')
            continue

        # Map ID to loss configuration key
        data['loss_name'] = data['yaml_id'].map(id_to_loss_key)

        # Map loss configuration keys to descriptive names
        loss_key_to_description = {
            'baseline': 'E + rho',
            'grad': 'E + rho + grad',
            'gradient': 'E + rho + grad',
            'hess': 'E + rho + hess',
            'hessian': 'E + rho + hess',
            'grad_and_hess': 'E + rho + grad + hess',
            'grad_and_hessian': 'E + rho + grad + hess',
        }
        data['loss_description'] = data['loss_name'].map(loss_key_to_description)

        print(f'Loss names found: {sorted(data["loss_name"].unique())}')
        print(f'Loss descriptions: {sorted(data["loss_description"].unique())}')

        data['ckpt_human_readable'] = (
            data['config_human_readable'] +
            ' (' + data['yaml_id'] + ')'
        )

        # Sort by label length for cleaner visualization
        data = data.sort_values(by='ckpt_human_readable', key=lambda x: x.str.len())

        print(f'Data summary:')
        print(f'  Checkpoints: {len(data["checkpoint_id"].unique())}')
        print(f'  States: {len(data["state_index"].unique())}')
        print(f'  Molecules: {len(data["molecule_idx"].unique())}')
        print(f'  Total entries: {len(data)}')

        # # ===== Per-Checkpoint Analysis (Individual Runs) =====

        # # Overall MAE across all states and molecules
        # plot_mae(
        #     data=data,
        #     x_col='ckpt_human_readable',
        #     hue_col=None,
        #     title=f'{method} All-State MAE',
        #     filename=f'mae_by_checkpoint_overall{method_suffix}.png',
        #     output_dir=output_dir,
        #     use_wandb=use_wandb,
        #     write_values_over_bars=False,
        #     use_log_scale=True,
        # )

        # # MAE broken down by excited state index
        # plot_mae(
        #     data=data,
        #     x_col='ckpt_human_readable',
        #     hue_col='state_index',
        #     title=f'{method} Per-State MAE',
        #     filename=f'mae_by_checkpoint_per_state{method_suffix}.png',
        #     output_dir=output_dir,
        #     use_wandb=use_wandb,
        #     write_values_over_bars=False,
        #     use_log_scale=True,
        # )

        # # ===== Per-Configuration Analysis (Averaged over Random Seeds) =====

        # # Overall MAE
        # plot_mae(
        #     data=data,
        #     x_col='config_human_readable',
        #     hue_col=None,
        #     title=f'{method} All-State MAE (Avg over Seeds)',
        #     filename=f'mae_by_config_overall{method_suffix}.png',
        #     output_dir=output_dir,
        #     use_wandb=use_wandb,
        #     write_values_over_bars=False,
        #     use_log_scale=True,
        # )

        # # Overall relative MAE (percentage)
        # plot_mae(
        #     data=data,
        #     x_col='config_human_readable',
        #     hue_col=None,
        #     title=f'{method} All-State Relative MAE (Avg over Seeds)',
        #     filename=f'mae_by_config_overall_relative{method_suffix}.png',
        #     output_dir=output_dir,
        #     use_wandb=use_wandb,
        #     relative_error=True,
        #     write_values_over_bars=False,
        #     use_log_scale=True,
        # )

        # # Per-state MAE
        # plot_mae(
        #     data=data,
        #     x_col='config_human_readable',
        #     hue_col='state_index',
        #     title=f'{method} Per-State MAE (Avg over Seeds)',
        #     filename=f'mae_by_config_per_state{method_suffix}.png',
        #     output_dir=output_dir,
        #     use_wandb=use_wandb,
        #     write_values_over_bars=False,
        #     use_log_scale=True,
        # )

        # ===== Combined Figure: tabel_visualization.py style =====
        # 2 rows x 3 columns: Overall MAE + 5 per-state MAEs
        ERROR_STYLE = None # std, se, None
        print(f'\n=== Generating combined figure (tabel_visualization.py style) ===')
        print(f'Error style: {ERROR_STYLE}')

        # Extract model name from checkpoint_id
        def extract_model(checkpoint_id):
            """Extract model architecture name from checkpoint ID."""
            if checkpoint_id.startswith('EGXC'):
                return 'EGXC'
            elif checkpoint_id.startswith('NNmGGA'):
                return 'NNmGGA'
            elif checkpoint_id.startswith('Skala'):
                return 'Skala-mGGA'
            else:
                return 'Unknown'

        # Determine if checkpoint is _more or _less variant
        def get_variant(checkpoint_id):
            """Determine if checkpoint is _more or _less variant."""
            if '_less' in checkpoint_id:
                return 'less'
            elif '_more' in checkpoint_id:
                return 'more'
            else:
                return 'baseline'

        data['model'] = data['checkpoint_id'].apply(extract_model)
        data['variant'] = data['checkpoint_id'].apply(get_variant)

        # Filter out hessian-only configurations
        data = data[~data['loss_name'].isin(['hessian', 'hess'])].copy()

        # rename grad_and_hess to grad_and_hessian
        data['loss_name'] = data['loss_name'].replace('grad_and_hess', 'grad_and_hessian')
        # rename grad to gradient
        data['loss_name'] = data['loss_name'].replace('grad', 'gradient')

        # Print detailed breakdown of loss configurations
        print('\n=== Loss Configuration Breakdown ===')
        for model in data['model'].unique():
            model_data = data[data['model'] == model]
            print(f'\n{model}:')
            for checkpoint in model_data['checkpoint_id'].unique():
                ckpt_data = model_data[model_data['checkpoint_id'] == checkpoint].iloc[0]
                loss_name = ckpt_data['loss_name']
                loss_desc = ckpt_data['loss_description']

                print(f'  {checkpoint}:')
                print(f'    Loss name: {loss_name}')
                print(f'    Loss description: {loss_desc}')
        
        # Count unique loss configurations
        print('\n=== Summary ===')
        loss_name_counts = data.groupby(['model', 'loss_name'])['checkpoint_id'].nunique()
        print(f'\nCheckpoints per model+loss combination:')
        print(loss_name_counts)

        print(f'\nTotal unique loss names: {data["loss_name"].nunique()}')
        print(f'Loss names: {sorted(data["loss_name"].unique())}')

        # Map loss_name to LaTeX labels for plotting
        loss_name_to_latex = {
            'baseline': r'$E + \rho$',
            'gradient': r'$+ \nabla$',
            'grad_and_hessian': r'$+ \nabla + H$',
        }

        # Define loss order and colors
        # Palette indices used below: colors[1] (baseline), colors[0] (gradient), colors[3] (grad_and_hessian)
        loss_name_order = ['baseline', 'gradient', 'grad_and_hessian']
        loss_name_colors = {
            'baseline': colors[1],           # Yellow
            'gradient': colors[0],           # Blue
            'grad_and_hessian': colors[3],   # Red
        }
        model_order = ['EGXC', 'NNmGGA', 'Skala-mGGA']

        # Filter to only models/loss_names that exist
        available_models = sorted([m for m in model_order if m in data['model'].unique()])
        available_loss_names = [l for l in loss_name_order if l in data['loss_name'].unique()]

        print(f'\n  Available models: {available_models}')
        print(f'  Available loss names: {available_loss_names}')

        # --- Create figure with 2x3 subplots ---
        fig, axes = plt.subplots(
            2, 3, figsize=(14, 8), gridspec_kw={'hspace': 0.15, 'wspace': 0.30}
        )
        axes = axes.flatten()

        x = np.arange(len(available_models))
        width = 0.2  # Width for 3 bars
        base_offsets = [-width, 0, width]

        # Remove subplot titles, instead name y-axis.
        subplot_ylabels = ['Overall MAE (eV)'] + [f'State {i} MAE (eV)' for i in range(5)]

        # Store y-limits from first subplot to apply to all
        first_ylim = None

        for subplot_idx in range(6):
            ax = axes[subplot_idx]

            # Filter data for this subplot
            if subplot_idx == 0:
                # Overall: all states
                subplot_data = data.copy()
            else:
                # Per-state
                state_idx = subplot_idx - 1
                subplot_data = data[data['state_index'] == state_idx].copy()

            if subplot_data.empty:
                ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            # Aggregate: mean and std over random seeds for each model+loss+checkpoint combination
            # Don't average different checkpoints together
            agg_data = subplot_data.groupby(['model', 'loss_name', 'checkpoint_id', 'variant'])['abs_error'].agg(['mean', 'std']).reset_index()

            # Plot bars for each loss configuration
            for loss_idx, loss_name in enumerate(available_loss_names):
                color = loss_name_colors[loss_name]

                for model_idx, model in enumerate(available_models):
                    # Get data for this model+loss combination
                    subset = agg_data[(agg_data['model'] == model) & (agg_data['loss_name'] == loss_name)]

                    if subset.empty:
                        continue

                    # Check if we have multiple variants (_more/_less)
                    checkpoints = subset['checkpoint_id'].unique()

                    if len(checkpoints) == 1:
                        # Single checkpoint: use standard position
                        val = subset['mean'].values[0]
                        if ERROR_STYLE == 'se':
                            # Use standard error (std / sqrt(n)) for smaller error bars
                            n = subset.shape[0]
                            err = (subset['std'].values[0] / np.sqrt(n)) if n > 0 and not np.isnan(subset['std'].values[0]) else 0
                        elif ERROR_STYLE == 'std':
                            err = subset['std'].values[0] if not np.isnan(subset['std'].values[0]) else 0
                        else:
                            err = None
                        variant = subset['variant'].values[0]

                        # Bar
                        ax.bar(
                            model_idx + base_offsets[loss_idx],
                            val,
                            width,
                            label=loss_name_to_latex[loss_name] if (subplot_idx == 0 and model_idx == 0) else None,
                            color=color,
                            edgecolor='black',
                            linewidth=0.5,
                            zorder=3,
                        )

                        # Error bar
                        ax.errorbar(
                            model_idx + base_offsets[loss_idx],
                            val,
                            yerr=err,
                            fmt='none',
                            color='0.2',
                            capsize=1.5,
                            capthick=0.6,
                            linewidth=1,
                            zorder=4,
                        )
                    else:
                        # Multiple checkpoints: plot side by side
                        # Sort by variant: baseline first, then less, then more
                        variant_order = {'baseline': 0, 'less': 1, 'more': 2}
                        subset_sorted = subset.sort_values(by='variant', key=lambda x: x.map(variant_order))

                        sub_width = width / len(checkpoints)
                        for sub_idx, (_, row) in enumerate(subset_sorted.iterrows()):
                            val = row['mean']
                            # Use standard error (std / sqrt(n)) for smaller error bars
                            n = subset.shape[0]
                            if ERROR_STYLE == 'se':
                                err = (subset['std'].values[0] / np.sqrt(n)) if n > 0 and not np.isnan(subset['std'].values[0]) else 0
                            elif ERROR_STYLE == 'std':
                                err = subset['std'].values[0] if not np.isnan(subset['std'].values[0]) else 0
                            else:
                                err = None
                            variant = row['variant']

                            # Calculate position offset
                            offset = base_offsets[loss_idx] + (sub_idx - len(checkpoints)/2 + 0.5) * sub_width

                            # Bar
                            ax.bar(
                                model_idx + offset,
                                val,
                                sub_width,
                                label=loss_name_to_latex[loss_name] if (subplot_idx == 0 and model_idx == 0 and sub_idx == 0) else None,
                                color=color,
                                edgecolor='black',
                                linewidth=0.5,
                                zorder=3,
                            )

                            # Error bar
                            ax.errorbar(
                                model_idx + offset,
                                val,
                                yerr=err,
                                fmt='none',
                                color='0.2',
                                capsize=1.5,
                                capthick=0.6,
                                linewidth=1,
                                zorder=4,
                            )

            # Formatting
            ax.set_ylabel(subplot_ylabels[subplot_idx], fontsize=14)
            # Remove subplot title
            # ax.set_title(subplot_titles[subplot_idx], fontsize=15)
            ax.set_xticks(x)
            ax.set_xticklabels(available_models, fontsize=12)

            # Use normal (linear) scaling on the y-axis
            ax.set_yscale('linear')

            # Grid styling (from tabel_visualization.py)
            ax.grid(True, axis='y', alpha=0.8, linewidth=0.5, zorder=0)
            ax.grid(False, axis='x')
            ax.set_axisbelow(True)

            # Tick positioning (from tabel_visualization.py)
            ax.yaxis.set_tick_params(pad=18)
            for label in ax.get_yticklabels():
                label.set_ha('left')
                label.set_usetex(False)

            # Capture y-limits from first subplot and apply to all others
            if subplot_idx == 0:
                first_ylim = ax.get_ylim()
            elif first_ylim is not None:
                ax.set_ylim(first_ylim)

        # Create custom legend with all loss configurations
        from matplotlib.patches import Patch

        # Define legend labels and colors for available loss names
        legend_labels = [loss_name_to_latex[ln] for ln in available_loss_names]
        legend_colors = [loss_name_colors[ln] for ln in available_loss_names]

        # Create custom legend handles with the correct colors
        legend_handles = [
            Patch(facecolor=legend_colors[i], edgecolor='black', linewidth=0.5)
            for i in range(len(available_loss_names))
        ]

        fig.legend(
            legend_handles,
            legend_labels,
            loc='upper center',
            bbox_to_anchor=(0.5, 0.98),
            ncol=3,
            frameon=False,
            columnspacing=1.5,
            handletextpad=0.4,
            fontsize=14,
        )

        plt.subplots_adjust(top=0.92)

        # Save
        combined_filename = f'mae_combined_by_model_and_loss{method_suffix}.png'
        save_path = output_dir / combined_filename
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Saved: {save_path}')

        if use_wandb:
            wandb.log({combined_filename.replace('.png', ''): wandb.Image(str(save_path))})

        plt.close()

        #########################################
        # --- Create figure with 2x2 subplots ---
        fig, axes = plt.subplots(
            2, 2, figsize=(10, 8), gridspec_kw={'hspace': 0.15, 'wspace': 0.30}
        )
        axes = axes.flatten()

        x = np.arange(len(available_models))
        width = 0.2  # Width for 3 bars
        base_offsets = [-width, 0, width]

        # Remove subplot titles, instead name y-axis.
        subplot_ylabels = ['Overall MAE (eV)'] + [f'State {i} MAE (eV)' for i in range(5)]

        # Store y-limits from first subplot to apply to all
        first_ylim = None

        for subplot_idx in range(4):
            ax = axes[subplot_idx]

            # Filter data for this subplot
            if subplot_idx == 0:
                # Overall: all states
                subplot_data = data.copy()
            else:
                # Per-state
                state_idx = subplot_idx - 1
                subplot_data = data[data['state_index'] == state_idx].copy()

            if subplot_data.empty:
                ax.text(0.5, 0.5, 'No data', ha='center', va='center', transform=ax.transAxes)
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            # Aggregate: mean and std over random seeds for each model+loss+checkpoint combination
            # Don't average different checkpoints together
            agg_data = subplot_data.groupby(['model', 'loss_name', 'checkpoint_id', 'variant'])['abs_error'].agg(['mean', 'std']).reset_index()

            # Plot bars for each loss configuration
            for loss_idx, loss_name in enumerate(available_loss_names):
                color = loss_name_colors[loss_name]

                for model_idx, model in enumerate(available_models):
                    # Get data for this model+loss combination
                    subset = agg_data[(agg_data['model'] == model) & (agg_data['loss_name'] == loss_name)]

                    if subset.empty:
                        continue

                    # Check if we have multiple variants (_more/_less)
                    checkpoints = subset['checkpoint_id'].unique()

                    if len(checkpoints) == 1:
                        # Single checkpoint: use standard position
                        val = subset['mean'].values[0]
                        if ERROR_STYLE == 'se':
                            # Use standard error (std / sqrt(n)) for smaller error bars
                            n = subset.shape[0]
                            err = (subset['std'].values[0] / np.sqrt(n)) if n > 0 and not np.isnan(subset['std'].values[0]) else 0
                        elif ERROR_STYLE == 'std':
                            err = subset['std'].values[0] if not np.isnan(subset['std'].values[0]) else 0
                        else:
                            err = None
                        variant = subset['variant'].values[0]

                        # Bar
                        ax.bar(
                            model_idx + base_offsets[loss_idx],
                            val,
                            width,
                            label=loss_name_to_latex[loss_name] if (subplot_idx == 0 and model_idx == 0) else None,
                            color=color,
                            edgecolor='black',
                            linewidth=0.5,
                            zorder=3,
                        )

                        # Error bar
                        ax.errorbar(
                            model_idx + base_offsets[loss_idx],
                            val,
                            yerr=err,
                            fmt='none',
                            color='0.2',
                            capsize=1.5,
                            capthick=0.6,
                            linewidth=1,
                            zorder=4,
                        )
                    else:
                        # Multiple checkpoints: plot side by side
                        # Sort by variant: baseline first, then less, then more
                        variant_order = {'baseline': 0, 'less': 1, 'more': 2}
                        subset_sorted = subset.sort_values(by='variant', key=lambda x: x.map(variant_order))

                        sub_width = width / len(checkpoints)
                        for sub_idx, (_, row) in enumerate(subset_sorted.iterrows()):
                            val = row['mean']
                            # Use standard error (std / sqrt(n)) for smaller error bars
                            n = subset.shape[0]
                            if ERROR_STYLE == 'se':
                                err = (subset['std'].values[0] / np.sqrt(n)) if n > 0 and not np.isnan(subset['std'].values[0]) else 0
                            elif ERROR_STYLE == 'std':
                                err = subset['std'].values[0] if not np.isnan(subset['std'].values[0]) else 0
                            else:
                                err = None
                            variant = row['variant']

                            # Calculate position offset
                            offset = base_offsets[loss_idx] + (sub_idx - len(checkpoints)/2 + 0.5) * sub_width

                            # Bar
                            ax.bar(
                                model_idx + offset,
                                val,
                                sub_width,
                                label=loss_name_to_latex[loss_name] if (subplot_idx == 0 and model_idx == 0 and sub_idx == 0) else None,
                                color=color,
                                edgecolor='black',
                                linewidth=0.5,
                                zorder=3,
                            )

                            # Error bar
                            ax.errorbar(
                                model_idx + offset,
                                val,
                                yerr=err,
                                fmt='none',
                                color='0.2',
                                capsize=1.5,
                                capthick=0.6,
                                linewidth=1,
                                zorder=4,
                            )

            # Formatting
            ax.set_ylabel(subplot_ylabels[subplot_idx], fontsize=14)
            # Remove subplot title
            # ax.set_title(subplot_titles[subplot_idx], fontsize=15)
            ax.set_xticks(x)
            ax.set_xticklabels(available_models, fontsize=12)

            # Use normal (linear) scaling on the y-axis
            ax.set_yscale('linear')

            # Grid styling (from tabel_visualization.py)
            ax.grid(True, axis='y', alpha=0.8, linewidth=0.5, zorder=0)
            ax.grid(False, axis='x')
            ax.set_axisbelow(True)

            # Tick positioning (from tabel_visualization.py)
            ax.yaxis.set_tick_params(pad=18)
            for label in ax.get_yticklabels():
                label.set_ha('left')
                label.set_usetex(False)

            # Capture y-limits from first subplot and apply to all others
            if subplot_idx == 0:
                first_ylim = ax.get_ylim()
            elif first_ylim is not None:
                ax.set_ylim(first_ylim)

        # Create custom legend with all loss configurations
        from matplotlib.patches import Patch

        # Define legend labels and colors for available loss names
        legend_labels = [loss_name_to_latex[ln] for ln in available_loss_names]
        legend_colors = [loss_name_colors[ln] for ln in available_loss_names]

        # Create custom legend handles with the correct colors
        legend_handles = [
            Patch(facecolor=legend_colors[i], edgecolor='black', linewidth=0.5)
            for i in range(len(available_loss_names))
        ]

        fig.legend(
            legend_handles,
            legend_labels,
            loc='upper center',
            bbox_to_anchor=(0.5, 0.98),
            ncol=3,
            frameon=False,
            columnspacing=1.5,
            handletextpad=0.4,
            fontsize=14,
        )

        plt.subplots_adjust(top=0.92)

        # Save
        combined_filename = f'mae_combined_by_model_and_loss_2x2{method_suffix}.png'
        save_path = output_dir / combined_filename
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f'Saved: {save_path}')

        if use_wandb:
            wandb.log({combined_filename.replace('.png', ''): wandb.Image(str(save_path))})

        plt.close()

    # Finalize WandB logging
    if use_wandb:
        wandb.finish()
        print('\nWandB logging completed')

    print(f'\n=== Analysis Complete ===')
    print(f'All plots saved to: {output_dir}')


# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not trigger the analysis run.
if __name__ == '__main__':
    main()