import csv
from pathlib import Path
from typing import List, Tuple, Optional
import os
from absl import app, flags

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.patches import Rectangle, ConnectionPatch
#mpl.rcParams['text.usetex'] = True # for latex fonts

# ----------------------------
# Global visualization settings
# ----------------------------
# Text size is controlled by the per-dataset `text_scale_*` flags below
# (multipliers of matplotlib's default font size; 1.0 = default size).
# NOTE(review): where these flags are applied is not visible in this chunk —
# presumably via apply_text_scale(); confirm against the caller.


# ----------------------------
# absl flags
# ----------------------------
FLAGS = flags.FLAGS
flags.DEFINE_float("text_scale_cifar100LT", 2.0, "Global text scale multiplier (1.0 = default size)")
flags.DEFINE_float("text_scale_cifar10LT", 2.5, "Global text scale multiplier (1.0 = default size)")
# Font preset; presumably passed to apply_font_choice(), which recognizes
# "computer-modern" (also "computer modern", "cm", "tex") and treats anything
# else as "keep matplotlib defaults".
flags.DEFINE_string("font", "computer-modern", "Global font choice: 'default' or 'computer-modern'")
# Output directories for the two figure families.
# NOTE(review): consumers are not visible in this chunk; presumably the
# class-wise distribution/NLL figures and the k/tau ablation figures.
flags.DEFINE_string("output_dir_distribution", "results_distribution", "Output directory")
flags.DEFINE_string("output_dir_ablation", "results_ablation", "Output directory")
# If True, generate_class_ratio_svgs() embeds each model's summed
# |normalized difference| in the normalized-difference SVG file names
# (e.g. "..._i1.23_o0.45_r0.67...").
flags.DEFINE_bool("write_value_file_name", False, "Write value on file name")
# If True, generate_class_ratio_svgs() additionally writes "*_log" variants of
# its plots with a logarithmic y-axis.
flags.DEFINE_bool("generate_log_scale", False, "Generate log scale versions of plots")

# Figure size (width, height) in inches used by the single-panel plots.
FIG_SIZE = (14, 5)
# Width of each per-model bar in x data units (consecutive class indices are 1 apart).
SMALL_BAR_WIDTH = 0.2
# The "Data" bar in plot_distribution is this many small bars wide, so the three
# per-model bars (I-CFM, OT-CFM, Ours) placed at x - w, x, x + w fit inside it.
BIG_BAR_MULTIPLIER = 3.0

# Matplotlib color-cycle specifiers ("CN") shared across plots, keyed by series.
COLORS = {
    "DATA": "C7",
    "I-CFM": "C0",
    "OT-CFM": "C2",
    "OURS": "C3",
}


def apply_text_scale(scale: float) -> None:
    """Set matplotlib's global base font size to ``scale`` times its default.

    The base size is taken from ``rcParamsDefault`` rather than the current
    ``rcParams``, so calling this repeatedly does not compound the scaling.
    Sizes specified relatively elsewhere ('medium', 'large', ...) follow the
    new base size.

    Args:
        scale: Multiplier for the default font size (1.0 keeps the default).
    """
    current_size = mpl.rcParams.get("font.size", 10.0)
    default_size = mpl.rcParamsDefault.get("font.size", current_size)
    mpl.rcParams["font.size"] = default_size * scale


def apply_font_choice(choice: str) -> None:
    """Select the global matplotlib font for all subsequently drawn figures.

    Args:
        choice: Case- and whitespace-insensitive preset name. "computer-modern",
            "computer modern", "cm" and "tex" switch to a Computer Modern serif
            stack with Computer Modern mathtext. Any other value (including
            "default", "" or None) leaves the current settings untouched.
    """
    preset = (choice or "").strip().lower()
    if preset not in {"computer-modern", "computer modern", "cm", "tex"}:
        # Keep matplotlib's current configuration.
        return

    # Families in order of preference, ending in widely available serif faces
    # so a missing Computer Modern install degrades gracefully.
    mpl.rcParams["font.family"] = [
        "Computer Modern",
        "Computer Modern Serif",
        "CMU Serif",
        "DejaVu Serif",
        "Times New Roman",
    ]
    mpl.rcParams["mathtext.fontset"] = "cm"
    # Render minus signs as an ASCII hyphen instead of U+2212, presumably to
    # avoid missing-glyph issues with some of the fonts above.
    mpl.rcParams["axes.unicode_minus"] = False


def set_common_xticks(ax, x_vals: List[int]) -> None:
    """Put x-axis ticks on ``ax``, thinning them out when there are many classes.

    With at most 30 values every value gets a tick. With more, ticks are placed
    every 10 units starting at ``min(x_vals)`` and ending at or below
    ``max(x_vals)``.

    Args:
        ax: A matplotlib Axes (anything with ``set_xticks``).
        x_vals: Integer class indices shown on the x-axis.
    """
    if len(x_vals) <= 30:
        ax.set_xticks(x_vals)
        return

    first, last = min(x_vals), max(x_vals)
    ax.set_xticks(list(range(first, last + 1, 10)))


def save_svg(fig, output_path: Path) -> None:
    """Write ``fig`` to ``output_path`` in SVG format, creating parent directories first."""
    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(output_path, format="svg")


def compute_normalized_difference(generated: List[float], data: List[float]) -> List[float]:
    """Return the element-wise relative error ``(generated - data) / data``.

    Positions whose reference value in ``data`` equals zero map to ``0.0``
    instead of raising ``ZeroDivisionError``. Elements are paired with ``zip``,
    so the result is as long as the shorter input.
    """
    return [
        0.0 if reference == 0 else (produced - reference) / reference
        for produced, reference in zip(generated, data)
    ]


def compute_total_normalized_difference(icfm: List[float], otcfm: List[float], ours: List[float], data_ratio: List[float]) -> Tuple[float, float, float]:
    """Sum, per model, the absolute normalized differences against ``data_ratio``.

    Returns:
        ``(total_icfm, total_otcfm, total_ours)``: for each model, the sum over
        classes of ``|(model - data) / data|`` as computed by
        ``compute_normalized_difference`` (classes with zero data contribute 0).
    """
    totals = [
        sum(abs(diff) for diff in compute_normalized_difference(series, data_ratio))
        for series in (icfm, otcfm, ours)
    ]
    return totals[0], totals[1], totals[2]


def compute_average_nll(nll_icfm: List[float], nll_otcfm: List[float], nll_ours: List[float]) -> Tuple[float, float, float]:
    """Return the arithmetic mean NLL of each model as ``(icfm, otcfm, ours)``.

    Raises:
        ZeroDivisionError: if any of the input lists is empty.
    """
    avg_icfm, avg_otcfm, avg_ours = (
        sum(values) / len(values) for values in (nll_icfm, nll_otcfm, nll_ours)
    )
    return avg_icfm, avg_otcfm, avg_ours


def load_class_ratios(csv_path: Path) -> Tuple[List[int], List[float], List[float], List[float], List[float]]:
    """Load class indices and per-model ratio columns from a CSV file.

    Expected columns: 'class_id', 'icfm_ratio', 'otcfm_ratio', 'ours_ratio', 'data_ratio'.
    Other columns are ignored. A row is skipped when its 'class_id' is empty or
    missing, or when any required value is missing or not numeric.

    Args:
        csv_path: Path to a UTF-8 encoded CSV file with a header row.

    Returns:
        ``(class_ids, icfm_ratios, otcfm_ratios, ours_ratios, data_ratios)``:
        equally long lists, aligned row by row in file order.
    """
    class_ids: List[int] = []
    icfm_ratios: List[float] = []
    otcfm_ratios: List[float] = []
    ours_ratios: List[float] = []
    data_ratios: List[float] = []

    with csv_path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            if not row.get("class_id"):
                continue
            # Parse the whole row before appending anything, so a malformed
            # cell can never leave the output lists with different lengths.
            try:
                class_id = int(row["class_id"])
                icfm = float(row["icfm_ratio"])
                otcfm = float(row["otcfm_ratio"])
                ours = float(row["ours_ratio"])
                data = float(row["data_ratio"])
            except (KeyError, TypeError, ValueError):
                # KeyError: column absent from the header.
                # TypeError: short row (DictReader fills missing cells with None).
                # ValueError: non-numeric text.
                continue
            class_ids.append(class_id)
            icfm_ratios.append(icfm)
            otcfm_ratios.append(otcfm)
            ours_ratios.append(ours)
            data_ratios.append(data)

    return class_ids, icfm_ratios, otcfm_ratios, ours_ratios, data_ratios


def load_classwise_nll(csv_path: Path) -> Tuple[List[int], List[float], List[float], List[float]]:
    """Load class indices and per-model NLL columns from a CSV file.

    The class index column may be named 'class_id' (preferred) or 'class'.
    The columns 'nll_icfm', 'nll_otcfm' and 'nll_ours' are required; other
    columns are ignored. A row is skipped when its class index is empty, or
    when any required value is missing or not numeric.

    Args:
        csv_path: Path to a UTF-8 encoded CSV file with a header row.

    Returns:
        ``(class_ids, nll_icfm, nll_otcfm, nll_ours)``: equally long lists,
        aligned row by row in file order. All four are empty if the header has
        no class index column.
    """
    class_ids: List[int] = []
    nll_icfm: List[float] = []
    nll_otcfm: List[float] = []
    nll_ours: List[float] = []

    with csv_path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        fieldnames = reader.fieldnames or []
        if "class_id" in fieldnames:
            class_col = "class_id"
        elif "class" in fieldnames:
            class_col = "class"
        else:
            # No usable class column: nothing can be aligned, return empty lists.
            return class_ids, nll_icfm, nll_otcfm, nll_ours

        for row in reader:
            if not row.get(class_col):
                continue
            # Parse the whole row before appending anything, so a malformed
            # cell can never leave the output lists with different lengths.
            try:
                class_id = int(row[class_col])
                icfm = float(row["nll_icfm"])
                otcfm = float(row["nll_otcfm"])
                ours = float(row["nll_ours"])
            except (KeyError, TypeError, ValueError):
                # KeyError: column absent from the header.
                # TypeError: short row (DictReader fills missing cells with None).
                # ValueError: non-numeric text.
                continue
            class_ids.append(class_id)
            nll_icfm.append(icfm)
            nll_otcfm.append(otcfm)
            nll_ours.append(ours)

    return class_ids, nll_icfm, nll_otcfm, nll_ours


def plot_distribution(
    class_ids: List[int],
    icfm: List[float],
    otcfm: List[float],
    ours: List[float],
    data_ratio: List[float],
    output_path: Path,
    log_scale: bool = False,
) -> None:
    """Draw per-class ratios of the real data and of each model, then save as SVG.

    For every class a wide, semi-transparent "Data" bar
    (``SMALL_BAR_WIDTH * BIG_BAR_MULTIPLIER`` wide) is drawn first. Three narrow
    bars (I-CFM, OT-CFM, Ours) are drawn on top, centred at ``x - w``, ``x`` and
    ``x + w``, so together they span the wide bar.

    Args:
        class_ids: Class indices used as x positions.
        icfm, otcfm, ours: Per-class ratios of the generated samples.
        data_ratio: Per-class ratios of the real data.
        output_path: Destination SVG path.
        log_scale: Use a logarithmic y-axis.

    Raises:
        ValueError: if the input lists differ in length.
    """
    if len({len(class_ids), len(icfm), len(otcfm), len(ours), len(data_ratio)}) != 1:
        raise ValueError("Input arrays must have the same length.")

    narrow = SMALL_BAR_WIDTH
    wide = narrow * BIG_BAR_MULTIPLIER

    fig, ax = plt.subplots(figsize=FIG_SIZE)

    # Background layer (zorder 1): the real data distribution.
    ax.bar(
        class_ids,
        data_ratio,
        width=wide,
        color=COLORS["DATA"],
        alpha=0.5,
        align="center",
        edgecolor="none",
        label="Data",
        zorder=1,
    )

    # Foreground layer (zorder 2): (x offset, values, COLORS key, legend label).
    foreground = (
        (-narrow, icfm, "I-CFM", "I-CFM"),
        (0.0, otcfm, "OT-CFM", "OT-CFM"),
        (narrow, ours, "OURS", "Ours"),
    )
    for offset, values, color_key, label in foreground:
        ax.bar(
            [cid + offset for cid in class_ids],
            values,
            width=narrow,
            color=COLORS[color_key],
            align="center",
            edgecolor="none",
            label=label,
            zorder=2,
        )

    ax.set_xlabel("Class index")
    ax.set_ylabel("Ratio (log scale)" if log_scale else "Ratio")
    if log_scale:
        ax.set_yscale('log')
    ax.set_xlim(min(class_ids) - 0.5, max(class_ids) + 0.5)

    set_common_xticks(ax, class_ids)

    ax.legend(ncol=4, frameon=False, loc="upper right")
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    fig.tight_layout(pad=0.1)

    save_svg(fig, output_path)
    plt.close(fig)


def generate_class_ratio_svgs(
    csv_path: Path,
    output_dir: Path,
    name_suffix: str,
) -> None:
    """Render every class-ratio figure for one dataset CSV into ``output_dir``.

    Figures, in order:
    - the ratio distribution;
    - the normalized difference over all classes;
    - its tail and head views (the 40 highest / lowest class indices);
    - the combined class-wise + abs-mean chart;
    - the chart with zoomed insets.

    With ``--generate_log_scale`` the distribution, full, tail and head plots
    also get "_log" variants. With ``--write_value_file_name`` the
    normalized-difference file names embed each model's summed
    |normalized difference|. Does nothing if ``csv_path`` does not exist.

    Args:
        csv_path: Input CSV with class-wise ratios (see load_class_ratios).
        output_dir: Directory that receives the SVG files.
        name_suffix: Appended to every file name to tell datasets apart.
    """
    if not csv_path.exists():
        return

    series = load_class_ratios(csv_path)
    _, icfm, otcfm, ours, data_ratio = series

    total_icfm, total_otcfm, total_ours = compute_total_normalized_difference(icfm, otcfm, ours, data_ratio)
    if FLAGS.write_value_file_name:
        norm_tag = f"_i{total_icfm:.2f}_o{total_otcfm:.2f}_r{total_ours:.2f}"
    else:
        norm_tag = ""

    def svg(stem: str) -> Path:
        # Every output file name ends with the dataset suffix.
        return output_dir / f"{stem}{name_suffix}.svg"

    # Ratio distribution (linear, then optional log).
    plot_distribution(*series, svg("class_ratio_dist"))
    if FLAGS.generate_log_scale:
        plot_distribution(*series, svg("class_ratio_dist_log"), log_scale=True)

    # Normalized difference over all classes.
    plot_normalized_difference(*series, svg(f"class_ratio_normdiff{norm_tag}"))
    if FLAGS.generate_log_scale:
        plot_normalized_difference(*series, svg(f"class_ratio_normdiff_log{norm_tag}"), log_scale=True)

    # Tail / head views; with 40 or fewer classes each covers every class.
    windows = (("tail", {"tail": 40}), ("head", {"head": 40}))
    for view, window in windows:
        plot_normalized_difference(*series, svg(f"class_ratio_normdiff_{view}{norm_tag}"), **window)
    if FLAGS.generate_log_scale:
        for view, window in windows:
            plot_normalized_difference(
                *series, svg(f"class_ratio_normdiff_{view}_log{norm_tag}"), log_scale=True, **window
            )

    # Class-wise plus abs-mean chart, then the chart with zoomed insets.
    plot_ratio_normdiff_combined(*series, svg(f"class_ratio_normdiff_combined{norm_tag}"))
    plot_normalized_difference_with_zoom(*series, svg(f"class_ratio_normdiff_zoom{norm_tag}"))


def plot_normalized_difference(
    class_ids: List[int],
    icfm: List[float],
    otcfm: List[float],
    ours: List[float],
    data_ratio: List[float],
    output_path: Path,
    head: Optional[int] = None,
    tail: Optional[int] = None,
    log_scale: bool = False,
) -> None:
    """Plot normalized differences (generated - data) / data for I-CFM, OT-CFM, and Ours.

    - Data bars are NOT drawn; only the three generated series are shown.
    - Colors come from ``COLORS`` (I-CFM=C0, OT-CFM=C2, Ours=C3).
    - By default all classes are plotted in their original order. If ``head``
      or ``tail`` is given, classes are first sorted by class index and only the
      first ``head`` / last ``tail`` of them are plotted.

    Args:
        class_ids, icfm, otcfm, ours, data_ratio: Equally long per-class series.
        output_path: Destination SVG path.
        head: Plot only the ``head`` lowest class indices (non-negative).
        tail: Plot only the ``tail`` highest class indices (non-negative).
        log_scale: Use a logarithmic y-axis.

    Raises:
        ValueError: if the input lists differ in length, if both ``head`` and
            ``tail`` are given, or if either is negative.
    """
    n = len(class_ids)
    if not (len(icfm) == len(otcfm) == len(ours) == len(data_ratio) == n):
        raise ValueError("Input arrays must have the same length.")

    if head is not None and tail is not None:
        raise ValueError("Provide only one of head or tail, not both.")
    if (head is not None and head < 0) or (tail is not None and tail < 0):
        raise ValueError("head and tail must be non-negative.")

    x = class_ids
    icfm_vals = icfm
    otcfm_vals = otcfm
    ours_vals = ours
    data_vals = data_ratio

    if head is not None or tail is not None:
        combined = sorted(zip(class_ids, icfm, otcfm, ours, data_ratio), key=lambda t: t[0])
        if head is not None:
            sliced = combined[:head]
        else:
            # Use an explicit start index: ``combined[-tail:]`` would return
            # every class when tail == 0.
            sliced = combined[max(len(combined) - tail, 0):]

        x = [t[0] for t in sliced]
        icfm_vals = [t[1] for t in sliced]
        otcfm_vals = [t[2] for t in sliced]
        ours_vals = [t[3] for t in sliced]
        data_vals = [t[4] for t in sliced]

    icfm_nd = compute_normalized_difference(icfm_vals, data_vals)
    otcfm_nd = compute_normalized_difference(otcfm_vals, data_vals)
    ours_nd = compute_normalized_difference(ours_vals, data_vals)

    small_w = SMALL_BAR_WIDTH

    fig, ax = plt.subplots(figsize=FIG_SIZE)

    # Three adjacent bars per class, centred at x - w, x and x + w.
    ax.bar([xi - small_w for xi in x], icfm_nd, width=small_w, color=COLORS["I-CFM"], align="center", edgecolor="none", label="I-CFM")
    ax.bar(x, otcfm_nd, width=small_w, color=COLORS["OT-CFM"], align="center", edgecolor="none", label="OT-CFM")
    ax.bar([xi + small_w for xi in x], ours_nd, width=small_w, color=COLORS["OURS"], align="center", edgecolor="none", label="Ours")

    ax.axhline(0, color="black", linewidth=0.8, alpha=0.6)
    ax.set_xlabel("Class index")
    if log_scale:
        # NOTE(review): a log axis cannot display non-positive values, so
        # classes where a model under-generates (negative difference) vanish.
        ax.set_ylabel("Normalized difference (log scale)")
        ax.set_yscale('log')
    else:
        ax.set_ylabel("Normalized difference")

    set_common_xticks(ax, x)

    ax.legend(ncol=3, frameon=False, loc="upper left")
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    ax.margins(x=0.01)  # Minimize X-axis margins
    fig.tight_layout(pad=0.1)

    save_svg(fig, output_path)
    plt.close(fig)


def plot_normalized_difference_with_zoom(
    class_ids: List[int],
    icfm: List[float],
    otcfm: List[float],
    ours: List[float],
    data_ratio: List[float],
    output_path: Path,
    zoom_ranges: Optional[List[Tuple[int, int]]] = None,
    log_scale: bool = False,
    zoom_ylims: Optional[List[Tuple[float, float]]] = None,
) -> None:
    """Plot normalized differences for all classes plus zoomed insets on class ranges.

    Each zoom range is marked on the main axes by a dashed rectangle. A dashed
    connector joins it to an inset axes that re-plots the same classes with
    its own y-limits.

    Args:
        class_ids, icfm, otcfm, ours, data_ratio: Equally long per-class series;
            the plotted values are (model - data) / data.
        output_path: Destination SVG path.
        zoom_ranges: Inclusive (start, end) class-index ranges to zoom into.
            Defaults to [(5, 25), (60, 80)]. Each range is clipped to the
            smallest/largest class index present, and a range that contains no
            class is skipped entirely.
        log_scale: Use a logarithmic y-axis on the main axes.
        zoom_ylims: (bottom, top) y-limits per zoom range, shared by its marker
            rectangle and its inset. Defaults to [(-1, 3), (-1, 13)]. Ranges
            beyond the end of this list reuse its last entry.

    Raises:
        ValueError: if the input lists differ in length.
    """
    n = len(class_ids)
    if not (len(icfm) == len(otcfm) == len(ours) == len(data_ratio) == n):
        raise ValueError("Input arrays must have the same length.")

    # Defaults are built per call rather than as shared mutable default arguments.
    if zoom_ranges is None:
        zoom_ranges = [(5, 25), (60, 80)]
    if not zoom_ylims:
        zoom_ylims = [(-1, 3), (-1, 13)]
    # Inset placement in figure coordinates [left, bottom, width, height]:
    # the first zoom range goes on the left, every later one on the right.
    inset_rects = ([0.12, 0.38, 0.375, 0.45], [0.55, 0.38, 0.375, 0.45])

    # Calculate normalized differences for all data
    icfm_nd = compute_normalized_difference(icfm, data_ratio)
    otcfm_nd = compute_normalized_difference(otcfm, data_ratio)
    ours_nd = compute_normalized_difference(ours, data_ratio)

    small_w = SMALL_BAR_WIDTH
    x = class_ids

    fig, ax = plt.subplots(figsize=FIG_SIZE)

    # Main plot - all classes, three adjacent bars per class.
    ax.bar([xi - small_w for xi in x], icfm_nd, width=small_w, color=COLORS["I-CFM"], align="center", edgecolor="none", label="I-CFM")
    ax.bar(x, otcfm_nd, width=small_w, color=COLORS["OT-CFM"], align="center", edgecolor="none", label="OT-CFM")
    ax.bar([xi + small_w for xi in x], ours_nd, width=small_w, color=COLORS["OURS"], align="center", edgecolor="none", label="Ours")

    ax.axhline(0, color="black", linewidth=0.8, alpha=0.6)
    ax.set_xlabel("Class index")
    if log_scale:
        # NOTE(review): a log axis cannot show negative differences, nor the
        # y0 = -1 edge of the default zoom markers.
        ax.set_ylabel("Normalized difference (log scale)")
        ax.set_yscale('log')
    else:
        ax.set_ylabel("Normalized difference")

    set_common_xticks(ax, x)
    ax.legend(ncol=3, frameon=False, loc="upper left")
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    ax.margins(x=0.01)

    # Overall class extent; zoom ranges are clipped to it.
    first_class = min(class_ids) if class_ids else 0
    last_class = max(class_ids) if class_ids else 0

    for i, (start_idx, end_idx) in enumerate(zoom_ranges):
        zoom_data = sorted(
            (
                (class_id, icfm_nd[j], otcfm_nd[j], ours_nd[j])
                for j, class_id in enumerate(class_ids)
                if start_idx <= class_id <= end_idx
            ),
            key=lambda item: item[0],
        )
        if not zoom_data:
            # No class in this range (e.g. (60, 80) on a 10-class dataset).
            # A marker rectangle would still enlarge the main axes' autoscaled
            # limits out to the empty range, so draw nothing for it.
            continue

        lo = max(start_idx, first_class)
        hi = min(end_idx, last_class)
        y0, y1 = zoom_ylims[min(i, len(zoom_ylims) - 1)]

        # Dashed rectangle on the main plot marking the zoomed region.
        rect = Rectangle(
            (lo - 0.5, y0),
            (hi - lo) + 1,
            (y1 - y0),
            fill=False,
            edgecolor="black",
            linewidth=1.2,
            linestyle="--",
            alpha=0.9,
        )
        ax.add_patch(rect)

        inset_ax = fig.add_axes(inset_rects[min(i, len(inset_rects) - 1)])

        # Dashed connector from the rectangle's top centre to the inset's bottom centre.
        connector = ConnectionPatch(
            xyA=((lo + hi) / 2.0, y1), coordsA=ax.transData,
            xyB=(0.5, 0.0), coordsB=inset_ax.transAxes,
            linestyle="--", linewidth=1.0, color="black", arrowstyle='-'
        )
        fig.add_artist(connector)

        zoom_x = [item[0] for item in zoom_data]
        zoom_icfm = [item[1] for item in zoom_data]
        zoom_otcfm = [item[2] for item in zoom_data]
        zoom_ours = [item[3] for item in zoom_data]

        inset_ax.bar([xi - small_w for xi in zoom_x], zoom_icfm, width=small_w, color=COLORS["I-CFM"], align="center", edgecolor="none")
        inset_ax.bar(zoom_x, zoom_otcfm, width=small_w, color=COLORS["OT-CFM"], align="center", edgecolor="none")
        inset_ax.bar([xi + small_w for xi in zoom_x], zoom_ours, width=small_w, color=COLORS["OURS"], align="center", edgecolor="none")

        inset_ax.axhline(0, color="black", linewidth=0.5, alpha=0.6)
        inset_ax.set_xlim(lo - 0.5, hi + 0.5)

        # About five tick intervals per inset, to leave room for the large tick labels.
        tick_step = max(1, (hi - lo) // 5)
        inset_ax.set_xticks(range(lo, hi + 1, tick_step))
        inset_ax.set_ylim(y0, y1)

        # Style the inset
        inset_ax.grid(True, axis="y", linestyle=":", alpha=0.3)
        inset_ax.locator_params(axis='y', nbins=4)
        inset_ax.tick_params(labelsize=20)

        # Range label
        inset_ax.text(0.5, 0.95, f'Classes {lo}-{hi}',
                      transform=inset_ax.transAxes, ha='center', va='top',
                      fontsize=18, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))

    fig.tight_layout(pad=0.1)
    save_svg(fig, output_path)
    plt.close(fig)


def plot_nll(
    class_ids: List[int],
    nll_icfm: List[float],
    nll_otcfm: List[float],
    nll_ours: List[float],
    output_path: Path,
    log_scale: bool = False,
) -> None:
    """Plot per-class NLL for I-CFM, OT-CFM and Ours as grouped bars and save as SVG.

    Each class gets three adjacent bars centred at ``x - w``, ``x`` and ``x + w``
    (``w = SMALL_BAR_WIDTH``), coloured via ``COLORS``.

    Args:
        class_ids: Class indices used as x positions.
        nll_icfm, nll_otcfm, nll_ours: Per-class NLL values aligned with ``class_ids``.
        output_path: Destination SVG path. On a linear axis its file name is also
            checked for "cifar10" (but not "cifar100"); such CIFAR-10-LT figures
            get extra room below the lowest bar.
        log_scale: Use a logarithmic y-axis. Otherwise the linear axis starts
            just below the smallest NLL.

    Raises:
        ValueError: if the input lists differ in length.
    """
    n = len(class_ids)
    if not (len(nll_icfm) == len(nll_otcfm) == len(nll_ours) == n):
        raise ValueError("Input arrays must have the same length.")

    small_w = SMALL_BAR_WIDTH
    x = class_ids

    fig, ax = plt.subplots(figsize=FIG_SIZE)

    ax.bar([xi - small_w for xi in x], nll_icfm, width=small_w, color=COLORS["I-CFM"], align="center", edgecolor="none", label="I-CFM")
    ax.bar(x, nll_otcfm, width=small_w, color=COLORS["OT-CFM"], align="center", edgecolor="none", label="OT-CFM")
    ax.bar([xi + small_w for xi in x], nll_ours, width=small_w, color=COLORS["OURS"], align="center", edgecolor="none", label="Ours")

    ax.set_xlabel("Class index")
    if log_scale:
        ax.set_ylabel("NLL (log scale)")
        ax.set_yscale('log')
    else:
        ax.set_ylabel("NLL")
        # Crop the axis just below the global minimum across all NLL series so
        # small differences between models remain visible.
        y_min = min(min(nll_icfm), min(nll_otcfm), min(nll_ours))
        # CIFAR-10-LT figures get more downward padding for clearer separation.
        # "cifar10" is a prefix of "cifar100", so CIFAR-100 names are excluded explicitly.
        name_lower = output_path.name.lower()
        is_cifar10 = "cifar10" in name_lower and "cifar100" not in name_lower
        pad = 0.01 if is_cifar10 else 0.001
        ax.set_ylim(bottom=y_min - pad)

    set_common_xticks(ax, x)

    ax.legend(ncol=3, frameon=False, loc="upper right")
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    # Reduce left/right x-axis margins so bars start/end closer to edges
    ax.margins(x=0.01)
    fig.tight_layout(pad=0.1)

    save_svg(fig, output_path)
    plt.close(fig)


def plot_ratio_normdiff_combined(
    class_ids: List[int],
    icfm: List[float],
    otcfm: List[float],
    ours: List[float],
    data_ratio: List[float],
    output_path: Path,
) -> None:
    """Plot class-wise normalized differences (left) next to their absolute means (right).

    Left panel: three bars per class with (model - data) / data for I-CFM,
    OT-CFM and Ours. Right panel: each model's mean |normalized difference|,
    annotated with its value. Both panels share the same y-limits, so bar
    heights are directly comparable.

    Args:
        class_ids, icfm, otcfm, ours, data_ratio: Equally long per-class series.
        output_path: Destination SVG path. If its file name contains "cifar10"
            but not "cifar100" (a CIFAR-10-LT figure), the right panel is wider.

    Raises:
        ValueError: if the input lists differ in length.
    """
    n = len(class_ids)
    if not (len(icfm) == len(otcfm) == len(ours) == len(data_ratio) == n):
        raise ValueError("Input arrays must have the same length.")

    icfm_nd = compute_normalized_difference(icfm, data_ratio)
    otcfm_nd = compute_normalized_difference(otcfm, data_ratio)
    ours_nd = compute_normalized_difference(ours, data_ratio)

    # Mean absolute normalized difference per model
    mean_icfm = sum(abs(v) for v in icfm_nd) / len(icfm_nd)
    mean_otcfm = sum(abs(v) for v in otcfm_nd) / len(otcfm_nd)
    mean_ours = sum(abs(v) for v in ours_nd) / len(ours_nd)

    # Shared y-range covering both panels
    all_values = icfm_nd + otcfm_nd + ours_nd + [mean_icfm, mean_otcfm, mean_ours]
    y_min = min(all_values)
    y_max = max(all_values)
    # Add small symmetric margins to avoid touching bounds
    y_range = y_max - y_min
    y_pad = max(0.005, 0.05 * y_range) if y_range > 0 else 0.02

    fig = plt.figure(figsize=(12, 5))
    # "cifar10" is a prefix of "cifar100", so CIFAR-100 names are excluded explicitly.
    name_lower = output_path.name.lower()
    is_cifar10 = "cifar10" in name_lower and "cifar100" not in name_lower
    # Make the right subplot wider for CIFAR-10-LT outputs.
    width_ratios = [15, 2] if is_cifar10 else [16, 1]
    gs = fig.add_gridspec(1, 2, width_ratios=width_ratios)

    # Left subplot: classwise normalized difference
    ax1 = fig.add_subplot(gs[0, 0])
    small_w = SMALL_BAR_WIDTH
    x = class_ids

    ax1.bar([xi - small_w for xi in x], icfm_nd, width=small_w, color=COLORS["I-CFM"], align="center", edgecolor="none", label="I-CFM")
    ax1.bar(x, otcfm_nd, width=small_w, color=COLORS["OT-CFM"], align="center", edgecolor="none", label="OT-CFM")
    ax1.bar([xi + small_w for xi in x], ours_nd, width=small_w, color=COLORS["OURS"], align="center", edgecolor="none", label="Ours")

    ax1.axhline(0, color="black", linewidth=0.8, alpha=0.6)
    ax1.set_xlabel("Class index")
    ax1.set_ylabel("Normalized difference")
    ax1.set_ylim(y_min - y_pad, y_max + y_pad)
    set_common_xticks(ax1, x)
    ax1.legend(ncol=3, frameon=False, loc="upper left")
    ax1.grid(True, axis="y", linestyle=":", alpha=0.4)
    ax1.margins(x=0.01)

    # Right subplot: mean absolute normalized difference, no ticks or axis labels
    ax2 = fig.add_subplot(gs[0, 1])
    x_positions = [0, 0.1, 0.2]
    means = [mean_icfm, mean_otcfm, mean_ours]
    colors = [COLORS["I-CFM"], COLORS["OT-CFM"], COLORS["OURS"]]

    bars = ax2.bar(x_positions, means, color=colors, width=0.1)
    ax2.set_xlim(-0.05, 0.25)
    ax2.margins(x=0.05)
    ax2.set_ylabel("")
    ax2.set_yticks([])
    ax2.set_xticks([])
    ax2.set_xlabel("")
    ax2.set_ylim(y_min - y_pad, y_max + y_pad)

    # Value labels: each label starts at twice its bar's height, rotated vertically.
    for bar, mean in zip(bars, means):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width() / 2., 2.0 * height,
                 f'{mean:.2f}', ha='center', va='bottom', fontsize='small', rotation=90)

    # 'Abs mean' caption below the right panel (axes coordinates)
    ax2.text(0.5, -0.025, 'Abs\nmean', ha='center', va='top', transform=ax2.transAxes,
             fontsize='medium')

    fig.tight_layout(pad=0.05)
    save_svg(fig, output_path)
    plt.close(fig)


def plot_nll_combined(
    class_ids: List[int],
    nll_icfm: List[float],
    nll_otcfm: List[float],
    nll_ours: List[float],
    output_path: Path,
) -> None:
    """Plot class-wise NLL (wide left panel) next to each model's mean NLL (narrow right panel).

    Both panels are cropped at the bottom to 0.001 below the smallest class-wise
    NLL, so small differences between models remain visible. The right panel's
    top is capped at the largest class-wise NLL + 0.001.

    Args:
        class_ids: Class indices used as x positions in the left panel.
        nll_icfm, nll_otcfm, nll_ours: Per-class NLL values aligned with ``class_ids``.
        output_path: Destination SVG path.

    Raises:
        ValueError: if the input lists differ in length.
    """
    n = len(class_ids)
    if not (len(nll_icfm) == len(nll_otcfm) == len(nll_ours) == n):
        raise ValueError("Input arrays must have the same length.")

    # Y-range of the class-wise NLL, reused for the mean panel
    y_min = min(min(nll_icfm), min(nll_otcfm), min(nll_ours))
    y_max = max(max(nll_icfm), max(nll_otcfm), max(nll_ours))

    avg_icfm = sum(nll_icfm) / len(nll_icfm)
    avg_otcfm = sum(nll_otcfm) / len(nll_otcfm)
    avg_ours = sum(nll_ours) / len(nll_ours)

    fig = plt.figure(figsize=(12, 5))
    gs = fig.add_gridspec(1, 2, width_ratios=[16, 1])  # 16:1 width ratio

    # Left subplot: classwise NLL (wider)
    ax1 = fig.add_subplot(gs[0, 0])
    small_w = SMALL_BAR_WIDTH
    x = class_ids

    ax1.bar([xi - small_w for xi in x], nll_icfm, width=small_w, color=COLORS["I-CFM"], align="center", edgecolor="none", label="I-CFM")
    ax1.bar(x, nll_otcfm, width=small_w, color=COLORS["OT-CFM"], align="center", edgecolor="none", label="OT-CFM")
    ax1.bar([xi + small_w for xi in x], nll_ours, width=small_w, color=COLORS["OURS"], align="center", edgecolor="none", label="Ours")

    ax1.set_xlabel("Class index")
    ax1.set_ylabel("NLL")
    ax1.set_ylim(bottom=y_min-0.001)
    set_common_xticks(ax1, x)
    ax1.legend(ncol=3, frameon=False, loc="upper right")
    ax1.grid(True, axis="y", linestyle=":", alpha=0.4)

    # Minimize X-axis margins
    ax1.margins(x=0.01)

    # Right subplot: average NLL (much narrower), no ticks or axis labels
    ax2 = fig.add_subplot(gs[0, 1])
    x_positions = [0, 0.1, 0.2]
    averages = [avg_icfm, avg_otcfm, avg_ours]
    colors = [COLORS["I-CFM"], COLORS["OT-CFM"], COLORS["OURS"]]

    bars = ax2.bar(x_positions, averages, color=colors, width=0.1)
    ax2.set_xlim(-0.05, 0.25)
    ax2.margins(x=0.05)
    ax2.set_ylabel("")
    ax2.set_yticks([])
    ax2.set_xticks([])
    ax2.set_xlabel("")
    ax2.set_ylim(y_min-0.001, y_max+0.001)

    # Value labels just above each bar (5% of its height), rotated vertically
    for bar, avg in zip(bars, averages):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + height*0.05,
                 f'{avg:.2f}', ha='center', va='bottom', fontsize='small', rotation=90)

    # 'Mean' caption below the right panel (axes coordinates)
    ax2.text(0.5, -0.025, 'Mean', ha='center', va='top', transform=ax2.transAxes,
             fontsize='medium')

    fig.tight_layout(pad=0.05)
    save_svg(fig, output_path)
    plt.close(fig)


def plot_average_nll(
    nll_icfm: List[float],
    nll_otcfm: List[float],
    nll_ours: List[float],
    output_path: Path,
) -> None:
    """Draw a narrow, undecorated bar chart of each model's mean NLL and save as SVG.

    The three bars (I-CFM, OT-CFM, Ours) sit close together at x = 0, 0.1, 0.2.
    Each is annotated with its mean (2 decimals, rotated 90 degrees). The
    y-range spans the smallest to largest per-class NLL over all three series,
    plus a 0.001 margin, so it lines up with the class-wise NLL plots.
    There is no title, no axis labels, no ticks and no grid.
    """
    per_model = (nll_icfm, nll_otcfm, nll_ours)
    averages = [sum(values) / len(values) for values in per_model]
    lowest = min(min(values) for values in per_model)
    highest = max(max(values) for values in per_model)

    fig, ax = plt.subplots(figsize=(2, 5))

    bars = ax.bar(
        [0, 0.1, 0.2],
        averages,
        color=[COLORS["I-CFM"], COLORS["OT-CFM"], COLORS["OURS"]],
        alpha=0.8,
        width=0.1,
    )

    # Keep the x-range tight around the three closely spaced bars.
    ax.set_xlim(-0.05, 0.25)
    ax.margins(x=0.05)

    # Strip axis decorations: no model names, no labels, no y ticks.
    ax.set_xticks([])
    ax.set_xlabel("")
    ax.set_ylabel("")
    ax.set_yticks([])

    ax.set_ylim(lowest - 0.001, highest + 0.001)

    for bar, avg in zip(bars, averages):
        top = bar.get_height()
        centre = bar.get_x() + bar.get_width() / 2.
        ax.text(centre, top + top * 0.01, f'{avg:.2f}',
                ha='center', va='bottom', fontsize='medium', rotation=90)

    fig.tight_layout(pad=0.1)
    save_svg(fig, output_path)
    plt.close(fig)


def plot_nll_normalized_difference(
    class_ids: List[int],
    nll_icfm: List[float],
    nll_otcfm: List[float],
    nll_ours: List[float],
    output_path: Path,
    log_scale: bool = False,
) -> None:
    """Plot, per class, Ours' NLL relative to each baseline's NLL.

    Two side-by-side bars per class, centred at ``x -/+ w/2``:
      - (nll_ours - nll_icfm) / nll_icfm,   label "Diff. from I-CFM",  color #4E9B96
      - (nll_ours - nll_otcfm) / nll_otcfm, label "Diff. from OT-CFM", color #B96478
    A baseline NLL of 0 yields a difference of 0.

    Args:
        class_ids: Class indices used as x positions.
        nll_icfm, nll_otcfm, nll_ours: Per-class NLL values aligned with ``class_ids``.
        output_path: Destination SVG path. If its file name contains "cifar10"
            but not "cifar100" (a CIFAR-10-LT figure), the legend goes to the
            lower left; otherwise it goes to the upper right.
        log_scale: Use a logarithmic y-axis. Otherwise the linear axis is
            cropped to the data with a little headroom.

    Raises:
        ValueError: if the input lists differ in length.
    """
    n = len(class_ids)
    if not (len(nll_icfm) == len(nll_otcfm) == len(nll_ours) == n):
        raise ValueError("Input arrays must have the same length.")

    nd_ours_vs_icfm = compute_normalized_difference(nll_ours, nll_icfm)
    nd_ours_vs_otcfm = compute_normalized_difference(nll_ours, nll_otcfm)

    small_w = SMALL_BAR_WIDTH
    x = class_ids

    fig, ax = plt.subplots(figsize=FIG_SIZE)

    # Two adjacent bars per class: teal-ish for Ours vs I-CFM, rose-ish for Ours vs OT-CFM.
    ax.bar([xi - (small_w/2) for xi in x], nd_ours_vs_icfm, width=small_w, color="#4E9B96", align="center", edgecolor="none", label="Diff. from I-CFM")
    ax.bar([xi + (small_w/2) for xi in x], nd_ours_vs_otcfm, width=small_w, color="#B96478", align="center", edgecolor="none", label="Diff. from OT-CFM")

    ax.axhline(0, color="black", linewidth=0.8, alpha=0.6)
    ax.set_xlabel("Class index")
    if log_scale:
        # NOTE(review): a log axis cannot display non-positive values, so classes
        # where Ours has the lower NLL (negative difference) vanish.
        ax.set_ylabel("Normalized NLL Difference (log scale)")
        ax.set_yscale('log')
    else:
        ax.set_ylabel("Normalized NLL Diff.")
        # More headroom above the largest value than below the smallest.
        y_min = min(min(nd_ours_vs_icfm), min(nd_ours_vs_otcfm))
        y_max = max(max(nd_ours_vs_icfm), max(nd_ours_vs_otcfm))
        pad_positive = 0.01
        pad_negative = 0.001
        ax.set_ylim(bottom=y_min - pad_negative, top=y_max + pad_positive)

    set_common_xticks(ax, x)

    # Legend position depends on the dataset: CIFAR-10-LT lower left, otherwise
    # (e.g. CIFAR-100-LT) upper right. "cifar10" is a prefix of "cifar100", so
    # CIFAR-100 names are excluded explicitly.
    output_name = output_path.name.lower()
    if "cifar10" in output_name and "cifar100" not in output_name:
        legend_loc = "lower left"
    else:
        legend_loc = "upper right"

    ax.legend(ncol=2, frameon=False, loc=legend_loc)
    ax.grid(True, axis="y", linestyle=":", alpha=0.4)
    ax.margins(x=0.01)  # Minimize X-axis margins
    fig.tight_layout(pad=0.1)

    save_svg(fig, output_path)
    plt.close(fig)


def load_ablation_k_data(csv_path: Path) -> Tuple[List[float], List[float], List[float], float, float]:
    """Load the k-ablation results from a CSV file.

    Expected layout: a ``Dataset`` column whose value is ``CIFAR-10-LT`` or
    ``CIFAR-100-LT``, one score column per k value named ``k=<value>``, and a
    ``UOT-CFM`` column holding the baseline score.

    The k columns are taken from the header (in column order), so both
    datasets' score lists are aligned with ``k_values`` regardless of which
    rows are present or in which order they appear.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        ``(k_values, cifar10_scores, cifar100_scores, uotcfm_cifar10,
        uotcfm_cifar100)``. For a dataset whose row is absent, its score list
        is empty and its UOT-CFM score is ``0.0``. If a dataset row appears
        more than once, the last occurrence wins.
    """
    cifar10_scores: List[float] = []
    cifar100_scores: List[float] = []
    uotcfm_cifar10 = 0.0
    uotcfm_cifar100 = 0.0

    with csv_path.open(newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        # Read the k columns once from the header (deduplicated, order kept) so
        # every row is parsed against the same k list. Iterating over row keys
        # instead would also visit the ``None`` key DictReader uses for
        # surplus fields on over-long rows.
        k_columns = list(dict.fromkeys(
            col for col in (reader.fieldnames or []) if col.startswith("k=")
        ))
        k_values = [float(col.split("=", 1)[1]) for col in k_columns]

        for row in reader:
            dataset = (row.get("Dataset") or "").strip()
            if dataset == "CIFAR-10-LT":
                cifar10_scores = [float(row[col]) for col in k_columns]
                uotcfm_cifar10 = float(row["UOT-CFM"])
            elif dataset == "CIFAR-100-LT":
                cifar100_scores = [float(row[col]) for col in k_columns]
                uotcfm_cifar100 = float(row["UOT-CFM"])

    return k_values, cifar10_scores, cifar100_scores, uotcfm_cifar10, uotcfm_cifar100


def load_ablation_comb_data(csv_path: Path) -> Tuple[List[float], List[float], List[float], List[float], List[float], List[float], List[float]]:
    """Load tau/k combination ablation results from a CSV file.

    CSV structure: tau\\k,k=6.0,k=8.0,k=10.0,k=16.0,k=6.0,k=8.0,k=10.0,k=16.0
    The first 4 k columns are CIFAR-10-LT, the last 4 are CIFAR-100-LT.
    The first column holds tau. Rows whose tau is 2, 4 or 6 are used (compared
    numerically, so "2" and "2.0" both match). Blank rows and rows whose first
    cell is not a number are skipped.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        ``(k_values, tau2_cifar10, tau4_cifar10, tau6_cifar10, tau2_cifar100,
        tau4_cifar100, tau6_cifar100)``. A tau missing from the file yields
        empty lists.

    Raises:
        ValueError: If a used tau row has fewer than 8 score columns, or a
            score cell is not a number.
    """
    k_values = [6.0, 8.0, 10.0, 16.0]  # Fixed k values from the CSV structure
    n_k = len(k_values)
    # tau -> (CIFAR-10-LT scores, CIFAR-100-LT scores)
    by_tau = {2.0: ([], []), 4.0: ([], []), 6.0: ([], [])}

    with csv_path.open(newline="", encoding="utf-8") as f:
        # csv.reader (unlike DictReader) keeps the duplicated "k=..." header
        # names as separate positional columns.
        rows = list(csv.reader(f))

    for values in rows[1:]:  # rows[0] is the header
        if not values:
            continue  # blank line
        try:
            tau = float(values[0])
        except ValueError:
            continue  # not a tau row
        if tau not in by_tau:
            continue
        if len(values) < 1 + 2 * n_k:
            raise ValueError(
                f"{csv_path}: row for tau={values[0]} has {len(values) - 1} "
                f"score columns, expected {2 * n_k}"
            )
        cifar10 = [float(v) for v in values[1:1 + n_k]]
        cifar100 = [float(v) for v in values[1 + n_k:1 + 2 * n_k]]
        by_tau[tau] = (cifar10, cifar100)

    tau2_cifar10, tau2_cifar100 = by_tau[2.0]
    tau4_cifar10, tau4_cifar100 = by_tau[4.0]
    tau6_cifar10, tau6_cifar100 = by_tau[6.0]
    return k_values, tau2_cifar10, tau4_cifar10, tau6_cifar10, tau2_cifar100, tau4_cifar100, tau6_cifar100


def plot_ablation_k(
    k_values: List[float],
    cifar10_scores: List[float],
    cifar100_scores: List[float],
    uotcfm_cifar10: float,
    uotcfm_cifar100: float,
    output_path: Path,
) -> None:
    """Plot FID against k for Ours on both datasets, with UOT-CFM drawn at k=0.

    Args:
        k_values: x-axis values shared by both score lists.
        cifar10_scores: Ours FID on CIFAR-10-LT, one per entry of ``k_values``.
        cifar100_scores: Ours FID on CIFAR-100-LT, one per entry of ``k_values``.
        uotcfm_cifar10: UOT-CFM FID on CIFAR-10-LT, plotted at k=0.
        uotcfm_cifar100: UOT-CFM FID on CIFAR-100-LT, plotted at k=0.
        output_path: Destination passed to ``save_svg``. The figure is closed
            afterwards.
    """
    fig, ax = plt.subplots(figsize=(10, 8.5))

    ours_color = "C3"  # red
    uotcfm_color = "C4"  # purple

    # The fmt string alone sets both marker and line style. Also passing
    # linestyle= makes matplotlib warn that the property is redundantly defined.
    ax.plot(k_values, cifar10_scores, "o-", color=ours_color, linewidth=2,
            markersize=18, label="Ours, CIFAR-10-LT")
    ax.plot(k_values, cifar100_scores, "s--", color=ours_color, linewidth=2,
            markersize=18, label="Ours, CIFAR-100-LT")

    # UOT-CFM is drawn as a single point per dataset at k=0.
    ax.scatter([0], [uotcfm_cifar10], color=uotcfm_color, s=300, marker="o",
               label="UOT-CFM, CIFAR-10-LT", zorder=5)
    ax.scatter([0], [uotcfm_cifar100], color=uotcfm_color, s=300, marker="s",
               label="UOT-CFM, CIFAR-100-LT", zorder=5)

    ax.set_xlabel("k value")
    ax.set_ylabel("FID Score")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Start the x-axis just left of 0 so the UOT-CFM points are visible.
    # default=0.0 keeps an empty k list from crashing max().
    ax.set_xlim(-0.5, max(k_values, default=0.0) + 0.5)

    # Pad the y-range: 1 unit below the lowest score, 4.5 units above the highest.
    all_values = [*cifar10_scores, *cifar100_scores, uotcfm_cifar10, uotcfm_cifar100]
    ax.set_ylim(min(all_values) - 1, max(all_values) + 4.5)

    fig.tight_layout()
    save_svg(fig, output_path)
    plt.close(fig)


def plot_ablation_comb(
    k_values: List[float],
    tau2_cifar10: List[float],
    tau4_cifar10: List[float],
    tau6_cifar10: List[float],
    tau2_cifar100: List[float],
    tau4_cifar100: List[float],
    tau6_cifar100: List[float],
    output_path: Path,
) -> None:
    """Plot FID against k for each tau, with one subplot per dataset.

    Left subplot: CIFAR-10-LT (solid lines, circle markers).
    Right subplot: CIFAR-100-LT (dashed lines, square markers).
    Each tau value gets one fixed-colour line in each subplot.

    Args:
        k_values: x-axis values shared by every score list.
        tau2_cifar10, tau4_cifar10, tau6_cifar10: CIFAR-10-LT FID per k for
            tau = 2, 4, 6.
        tau2_cifar100, tau4_cifar100, tau6_cifar100: CIFAR-100-LT FID per k
            for tau = 2, 4, 6.
        output_path: Destination passed to ``save_svg``. The figure is closed
            afterwards.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 8.5))

    marker_size = 16
    # (legend label, colour, CIFAR-10-LT scores, CIFAR-100-LT scores) per tau
    tau_series = [
        ("τ=2.0", "C3", tau2_cifar10, tau2_cifar100),       # red
        ("τ=4.0", "#FF8C00", tau4_cifar10, tau4_cifar100),  # orange
        ("τ=6.0", "#FFD700", tau6_cifar10, tau6_cifar100),  # yellow
    ]
    for label, color, cifar10, cifar100 in tau_series:
        # Line style comes from the fmt string only. A duplicate linestyle=
        # keyword makes matplotlib warn about a redundantly defined property.
        ax1.plot(k_values, cifar10, "o-", color=color, linewidth=2,
                 markersize=marker_size, label=label)
        ax2.plot(k_values, cifar100, "s--", color=color, linewidth=2,
                 markersize=marker_size, label=label)

    # Shared axis decoration. Only the left subplot carries the y-label.
    for ax, title, ylabel in ((ax1, "CIFAR-10-LT", "FID Score"), (ax2, "CIFAR-100-LT", "")):
        ax.set_xlabel("k value")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.legend(loc="upper right")

    # Independent y-ranges per subplot, with dataset-specific padding.
    cifar10_values = [*tau2_cifar10, *tau4_cifar10, *tau6_cifar10]
    ax1.set_ylim(min(cifar10_values) - 0.2, max(cifar10_values) + 0.5)

    cifar100_values = [*tau2_cifar100, *tau4_cifar100, *tau6_cifar100]
    ax2.set_ylim(min(cifar100_values) - 1, max(cifar100_values) + 2)

    fig.tight_layout()
    save_svg(fig, output_path)
    plt.close(fig)


def _generate_nll_charts(nll_csv: Path, output_dir: Path, dataset_suffix: str = "") -> None:
    """Render every class-wise NLL chart for one dataset; do nothing if the CSV is missing.

    Args:
        nll_csv: Class-wise NLL CSV, read with ``load_classwise_nll``.
        output_dir: Directory the SVG files are written to.
        dataset_suffix: Appended to every file name after the optional
            average-NLL values (e.g. ``"_cifar10LT"``). Empty for CIFAR-100-LT.
    """
    if not nll_csv.exists():
        return

    class_ids, nll_icfm, nll_otcfm, nll_ours = load_classwise_nll(nll_csv)
    # Sort every series by class id so all charts share one class ordering.
    order = sorted(range(len(class_ids)), key=lambda i: class_ids[i])
    class_ids = [class_ids[i] for i in order]
    nll_icfm = [nll_icfm[i] for i in order]
    nll_otcfm = [nll_otcfm[i] for i in order]
    nll_ours = [nll_ours[i] for i in order]

    # Optionally encode each method's average NLL in the file names.
    avg_icfm, avg_otcfm, avg_ours = compute_average_nll(nll_icfm, nll_otcfm, nll_ours)
    values_suffix = (
        f"_i{avg_icfm:.2f}_o{avg_otcfm:.2f}_r{avg_ours:.2f}"
        if FLAGS.write_value_file_name
        else ""
    )
    suffix = f"{values_suffix}{dataset_suffix}"
    series = (class_ids, nll_icfm, nll_otcfm, nll_ours)

    plot_nll(*series, output_dir / f"classwise_nll{suffix}.svg")
    if FLAGS.generate_log_scale:
        plot_nll(*series, output_dir / f"classwise_nll_log{suffix}.svg", log_scale=True)

    # Normalized difference of Ours against each baseline (I-CFM, OT-CFM).
    plot_nll_normalized_difference(*series, output_dir / f"classwise_nll_normdiff{suffix}.svg")
    if FLAGS.generate_log_scale:
        plot_nll_normalized_difference(
            *series, output_dir / f"classwise_nll_normdiff_log{suffix}.svg", log_scale=True
        )

    plot_average_nll(nll_icfm, nll_otcfm, nll_ours, output_dir / f"average_nll{suffix}.svg")
    plot_nll_combined(*series, output_dir / f"nll_combined{suffix}.svg")


def main(argv=None) -> None:
    """Render all class-ratio, NLL and ablation figures as SVG files.

    Input CSVs are read from this script's directory. Class-ratio and NLL
    figures go to ``FLAGS.output_dir_distribution``. Ablation figures go to
    ``FLAGS.output_dir_ablation``. Both directories are resolved relative to
    the script's directory and created if missing. The NLL and ablation CSVs
    are optional: each chart group is skipped when its CSV is absent.
    """
    here = Path(__file__).resolve().parent
    output_dir = here / FLAGS.output_dir_distribution
    output_dir.mkdir(parents=True, exist_ok=True)

    # Class-ratio figures: CIFAR-100-LT first, with its own text scale.
    apply_text_scale(FLAGS.text_scale_cifar100LT)
    apply_font_choice(FLAGS.font)
    generate_class_ratio_svgs(here / "classified_gen_cifar100LT.csv", output_dir, name_suffix="")

    # CIFAR-10-LT figures use a different text scale.
    apply_text_scale(FLAGS.text_scale_cifar10LT)
    generate_class_ratio_svgs(here / "classified_gen_cifar10LT.csv", output_dir, name_suffix="_cifar10LT")

    # Class-wise NLL charts, each rendered with its dataset's text scale.
    apply_text_scale(FLAGS.text_scale_cifar100LT)
    _generate_nll_charts(here / "classwise_nll_cifar100LT.csv", output_dir)

    apply_text_scale(FLAGS.text_scale_cifar10LT)
    _generate_nll_charts(here / "classwise_nll_cifar10LT.csv", output_dir, dataset_suffix="_cifar10LT")

    # Ablation studies. These are drawn with the last-applied (CIFAR-10-LT) text scale.
    ablation_output_dir = here / FLAGS.output_dir_ablation
    ablation_output_dir.mkdir(parents=True, exist_ok=True)

    ablation_k_csv = here / "ablation_k_tau2.0.csv"
    if ablation_k_csv.exists():
        k_values, cifar10_scores, cifar100_scores, uotcfm_cifar10, uotcfm_cifar100 = (
            load_ablation_k_data(ablation_k_csv)
        )
        plot_ablation_k(
            k_values, cifar10_scores, cifar100_scores, uotcfm_cifar10, uotcfm_cifar100,
            ablation_output_dir / "ablation_k_study.svg",
        )

    ablation_comb_csv = here / "ablation_comb_cifar10LT.csv"
    if ablation_comb_csv.exists():
        (k_values, tau2_cifar10, tau4_cifar10, tau6_cifar10,
         tau2_cifar100, tau4_cifar100, tau6_cifar100) = load_ablation_comb_data(ablation_comb_csv)
        plot_ablation_comb(
            k_values, tau2_cifar10, tau4_cifar10, tau6_cifar10,
            tau2_cifar100, tau4_cifar100, tau6_cifar100,
            ablation_output_dir / "ablation_combination_study.svg",
        )

    # Best-effort interactive display. Failures are ignored (presumably
    # headless environments). NOTE(review): the plot helpers visible here
    # close their figures, so this may show nothing.
    try:
        plt.show()
    except Exception:
        pass


if __name__ == "__main__":
    # Script entry point: hand control to absl's app runner, which calls main().
    app.run(main)