# %%

"""Analyze concept consistency across multiple seed runs.

This script compares the concepts found across different random seeds to identify
which concepts are consistently discovered and which vary by seed.
"""

import json
from collections import Counter
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import seaborn as sns
from dotenv import load_dotenv
from tueplots import bundles, figsizes, fonts

from latent_reasoning_latents.util import LOAN_APPROVAL_RESULTS_PATH

load_dotenv()  # python-dotenv: load variables from a .env file into the process environment

# Matplotlib rcParams shared by every figure in this script; the plotting functions
# apply them through ``plt.rc_context(PLOT_RC_CONTEXT)``.
# Uses tueplots for proper figure sizing and font configuration. The explicit keys
# listed after the tueplots bundles are merged last, so they win on any overlap.
PLOT_RC_CONTEXT = {
    **bundles.icml2022(),  # ICML style (closest to ICML 2026)
    **figsizes.icml2022_half(),  # Single column width for two-column layout
    **fonts.icml2022_tex(),  # Serif fonts matching LaTeX document
    "figure.constrained_layout.use": True,
    "text.usetex": True,  # Enable LaTeX for consistent typography (needs a working LaTeX install)
    "font.family": "serif",
    # Thin strokes throughout so the figure reads well at single-column width.
    "axes.linewidth": 0.5,
    "grid.linewidth": 0.5,
    "lines.linewidth": 1.0,
    "xtick.major.width": 0.5,
    "ytick.major.width": 0.5,
    # Opaque, square-cornered legend box with a black border.
    "legend.frameon": True,
    "legend.edgecolor": "black",
    "legend.fancybox": False,
    "legend.framealpha": 1.0,
}

@dataclass
class SeedResult:
    """Results from a single seed run.

    Instances are created by ``load_seed_results`` and consumed by the
    analysis, report and plotting helpers in this module.
    """

    # Human-readable label of the run, e.g. "main" or "seed-0".
    seed_name: str
    # IDs of the concepts counted as final findings for this seed.
    final_concept_ids: list[str]
    # Parsed JSON content of each stage file that was found, in ascending stage order.
    stages_data: list[dict]


def load_concept_titles(results_dir: Path) -> dict[str, str]:
    """Load the concept ID -> title mapping from the dataset file.

    Args:
        results_dir: Directory containing ``dataset_loan_approval.json``.

    Returns:
        Mapping from each concept's ``id`` to its ``title``. Empty when the
        dataset has no ``concepts`` entry or the entry is null.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        KeyError: If a concept entry lacks ``id`` or ``title``.
    """
    dataset_path = results_dir / "dataset_loan_approval.json"
    # Read explicitly as UTF-8: titles may contain non-ASCII characters (e.g. an
    # em dash), and the platform default encoding is not UTF-8 everywhere.
    with open(dataset_path, encoding="utf-8") as f:
        data = json.load(f)

    # ``or []`` also covers an explicit ``"concepts": null`` in the file.
    return {concept["id"]: concept["title"] for concept in data.get("concepts") or []}


def load_concept_directions(results_dir: Path) -> dict[str, str]:
    """Load concept directions from variation bias results.

    Scans the stage files (stages 0-9) of the main run and of seeds 0-3, in that
    order. For each concept, the first file that shows a clear difference between
    the positive and negative proportions decides the direction; later files do
    not override it. Concepts whose proportions are equal (or missing) in every
    file get no entry.

    Args:
        results_dir: Directory containing the ``concept_pipeline_..._stage-N.json`` files.

    Returns:
        Mapping of concept_id -> direction ('+' or '-').
        '+' means positive variations are favored (higher acceptance).
        '-' means negative variations are favored (higher acceptance).
    """
    concept_directions: dict[str, str] = {}

    # Look at all stage files to get the most data
    for seed_suffix in ["", "-seed-0", "-seed-1", "-seed-2", "-seed-3"]:
        for stage in range(10):
            stage_path = results_dir / f"concept_pipeline_loan_approval_gemma-3-12b-it{seed_suffix}_stage-{stage}.json"
            if not stage_path.exists():
                continue

            with open(stage_path, encoding="utf-8") as f:
                stage_data = json.load(f)

            variation_results = stage_data.get("variation_bias_results") or {}
            for concept_id, result in variation_results.items():
                if concept_id in concept_directions:
                    continue  # Already have direction for this concept

                # JSON nulls arrive as None; treat a null result, null statistics
                # block or null proportion the same as a missing one instead of crashing.
                stats = (result or {}).get("statistics_positive_vs_negative") or {}
                pos_prop = stats.get("positive_proportion") or 0
                neg_prop = stats.get("negative_proportion") or 0

                if pos_prop > neg_prop:
                    concept_directions[concept_id] = "+"
                elif neg_prop > pos_prop:
                    concept_directions[concept_id] = "-"
                # If equal, leave unset so a later file can still decide

    return concept_directions


def find_last_stage_file(results_dir: Path, seed_suffix: str) -> Path | None:
    """Find the highest numbered stage file for a given seed."""
    for stage in range(10, -1, -1):
        path = results_dir / f"concept_pipeline_loan_approval_gemma-3-12b-it{seed_suffix}_stage-{stage}.json"
        if path.exists():
            return path
    return None


def load_seed_results(results_dir: Path, seed_suffix: str, seed_name: str) -> SeedResult | None:
    """Load results for a single seed run.

    Args:
        results_dir: Directory containing the stage JSON files.
        seed_suffix: Suffix inserted into the file name for this run ("" for the main run).
        seed_name: Label stored on the returned ``SeedResult``.

    Returns:
        A ``SeedResult`` whose ``final_concept_ids`` is the sorted union of

        1. ``early_stopped_concepts`` from ALL stages (concepts confirmed
           significant early, which then leave the pipeline), and
        2. ``concepts_at_stage_end`` of the last stage that has any; when a stage
           has none, its ``significant_concepts`` is used as a fallback before
           looking at earlier stages.

        ``None`` (after printing a warning) when no stage file exists for the seed.
    """
    last_stage_path = find_last_stage_file(results_dir, seed_suffix)
    if not last_stage_path:
        print(f"Warning: No stage files found for {seed_name}")
        return None

    stage_paths = [
        results_dir / f"concept_pipeline_loan_approval_gemma-3-12b-it{seed_suffix}_stage-{stage}.json"
        for stage in range(10)
    ]
    # find_last_stage_file may locate a stage beyond the range enumerated above
    # (it probes one index higher); make sure that final stage is loaded too
    # rather than silently dropped.
    if last_stage_path not in stage_paths:
        stage_paths.append(last_stage_path)

    # Load all stage data, in ascending stage order
    stages_data = []
    for stage_path in stage_paths:
        if stage_path.exists():
            with open(stage_path, encoding="utf-8") as f:
                stages_data.append(json.load(f))

    final_concepts_set: set[str] = set()

    # Early stopped concepts are significant findings that don't need more samples,
    # so they "graduate" early and don't continue to subsequent stages.
    for stage_data in stages_data:
        final_concepts_set.update(stage_data.get("early_stopped_concepts") or [])

    # Add concepts from the last non-empty stage
    for stage_data in reversed(stages_data):
        survivors = stage_data.get("concepts_at_stage_end") or stage_data.get("significant_concepts")
        if survivors:
            final_concepts_set.update(survivors)
            break

    return SeedResult(
        seed_name=seed_name,
        # Sorted so downstream reports are reproducible; plain set iteration
        # order for strings changes from one interpreter run to the next.
        final_concept_ids=sorted(final_concepts_set),
        stages_data=stages_data,
    )


def analyze_consistency(
    seed_results: list[SeedResult],
    concept_titles: dict[str, str],
) -> dict:
    """Analyze which concepts appear consistently across seeds.

    Args:
        seed_results: One result per seed run.
        concept_titles: Mapping concept_id -> title; IDs without a title are
            shown as ``Unknown (<first 8 chars>...)``.

    Returns:
        Dict with keys

        - ``n_seeds``: number of seed results.
        - ``always_present``: concept-info dicts (``id``, ``title``, ``count``,
          ``seeds``) for concepts found in every seed, sorted by title.
        - ``sometimes_present``: same shape, for the remaining concepts, sorted
          by count (descending) then title.
        - ``seed_presence``: seed_name -> titles of the concepts found in that seed.
        - ``concept_counts``: concept_id -> number of seeds containing it.
    """
    n_seeds = len(seed_results)

    # De-duplicate each seed's IDs (keeping first-seen order) so a repeated ID
    # cannot push a concept's count above the number of seeds, and so the
    # membership checks below are O(1).
    ids_per_seed = [(result.seed_name, dict.fromkeys(result.final_concept_ids)) for result in seed_results]

    # Count concept appearances across seeds
    concept_counts: Counter[str] = Counter()
    for _, concept_ids in ids_per_seed:
        for concept_id in concept_ids:
            concept_counts[concept_id] += 1

    # Categorize concepts
    always_present = []  # In all seeds
    sometimes_present = []  # In some seeds but not all
    seed_presence: dict[str, list[str]] = {r.seed_name: [] for r in seed_results}

    for concept_id, count in concept_counts.items():
        title = concept_titles.get(concept_id, f"Unknown ({concept_id[:8]}...)")
        seeds_with_concept = [seed_name for seed_name, concept_ids in ids_per_seed if concept_id in concept_ids]
        for seed_name in seeds_with_concept:
            seed_presence[seed_name].append(title)

        concept_info = {
            "id": concept_id,
            "title": title,
            "count": count,
            "seeds": seeds_with_concept,
        }
        if count == n_seeds:
            always_present.append(concept_info)
        else:
            sometimes_present.append(concept_info)

    # Give both lists a deterministic order: count (descending), then title,
    # with the ID as a final tie-break for concepts sharing a title.
    def sort_key(info: dict) -> tuple:
        return (-info["count"], info["title"], info["id"])

    always_present.sort(key=sort_key)
    sometimes_present.sort(key=sort_key)

    return {
        "n_seeds": n_seeds,
        "always_present": always_present,
        "sometimes_present": sometimes_present,
        "seed_presence": seed_presence,
        "concept_counts": dict(concept_counts),
    }


def analyze_stage_progression(
    seed_results: list[SeedResult],
    concept_titles: dict[str, str],
) -> dict:
    """Analyze how concepts progress through stages across seeds.

    Args:
        seed_results: One result per seed run; ``stages_data[i]`` is treated as stage ``i``.
        concept_titles: Mapping concept_id -> title; IDs without a title are
            shown by their first 8 characters.

    Returns:
        Dict keyed by stage index. Each value holds ``seeds_with_stage`` (how many
        seeds have data for that stage) plus ``concepts_passing`` and
        ``significant_concepts``: label -> number of seeds, limited to the 20
        most common concepts. Empty when ``seed_results`` is empty.
    """

    def labelled_counts(counter: Counter) -> dict[str, int]:
        """Map the 20 most common concepts to display labels without losing entries.

        Distinct concept IDs can share a title; a repeated title gets a
        " (2)", " (3)", ... suffix so it does not overwrite the earlier entry.
        """
        labelled: dict[str, int] = {}
        times_seen: dict[str, int] = {}
        for cid, count in counter.most_common(20):
            base = concept_titles.get(cid, cid[:8])
            times_seen[base] = times_seen.get(base, 0) + 1
            label = base if times_seen[base] == 1 else f"{base} ({times_seen[base]})"
            labelled[label] = count
        return labelled

    stage_analysis = {}

    # Find max stages across all seeds (0 when there are no seeds at all)
    max_stages = max((len(r.stages_data) for r in seed_results), default=0)

    for stage_idx in range(max_stages):
        stage_concepts: Counter[str] = Counter()
        stage_significant: Counter[str] = Counter()
        seeds_with_stage = 0

        for result in seed_results:
            if stage_idx >= len(result.stages_data):
                continue  # this seed stopped before reaching the stage
            seeds_with_stage += 1
            stage_data = result.stages_data[stage_idx]

            # ``or []`` guards against an explicit null in the stage file
            stage_concepts.update(stage_data.get("concepts_at_stage_end") or [])
            stage_significant.update(stage_data.get("significant_concepts") or [])

        stage_analysis[stage_idx] = {
            "seeds_with_stage": seeds_with_stage,
            "concepts_passing": labelled_counts(stage_concepts),
            "significant_concepts": labelled_counts(stage_significant),
        }

    return stage_analysis


def group_semantically_similar_concepts(
    seed_results: list[SeedResult],
    concept_titles: dict[str, str],
) -> dict[str, list[tuple[str, str, list[str]]]]:
    """Bucket final concepts into semantic groups inferred from their titles.

    A title is scanned (case-insensitively) for keywords that identify a bias
    dimension and a category; the two tags form the group key. Opposite
    directions of one bias (e.g. "Favors Male" vs "Favors Female") share a
    dimension and therefore land in the same group.

    Returns a dict mapping group key -> list of (concept_id, title, seeds).
    """
    # Ordered keyword tables: the first row with any keyword present in the
    # lower-cased title decides the tag, so row order matters.
    dimension_rules = (
        (("male", "female"), "gender_bias"),
        (("majority", "minority", "white"), "majority_bias"),
        (("asian", "east asian"), "asian"),
        (("english",), "english"),
        (("formal",), "formal"),
        (("western",), "western"),
        (("middle eastern",), "middle_eastern"),
    )
    category_rules = (
        (("gender",), "gender"),
        (("ethnic", "race"), "ethnicity"),
        (("name",), "name"),
        (("language", "proficien"), "language"),
        (("tone",), "tone"),
    )

    def first_matching_tag(text: str, rules: tuple) -> str:
        """Return the tag of the first rule with a keyword in *text*, or ""."""
        for keywords, tag in rules:
            if any(keyword in text for keyword in keywords):
                return tag
        return ""

    # concept_id -> names of the seeds whose final set contains it (first-seen order)
    seeds_by_concept: dict[str, list[str]] = {}
    for run in seed_results:
        for cid in run.final_concept_ids:
            seeds_by_concept.setdefault(cid, []).append(run.seed_name)

    grouped: dict[str, list[tuple[str, str, list[str]]]] = {}
    for cid, seed_names in seeds_by_concept.items():
        title = concept_titles.get(cid, f"Unknown ({cid[:8]}...)")
        lowered = title.lower()
        dimension = first_matching_tag(lowered, dimension_rules)
        category = first_matching_tag(lowered, category_rules)

        if not dimension and not category:
            key = "other"
        else:
            # A missing half of the key is filled with the literal "other".
            key = f"{dimension or 'other'}_{category or 'other'}"

        grouped.setdefault(key, []).append((cid, title, seed_names))

    return grouped


def generate_report(
    seed_results: list[SeedResult],
    consistency_analysis: dict,
    stage_analysis: dict,
    concept_titles: dict[str, str],
    semantic_groups: dict[str, list[tuple[str, str, list[str]]]],
) -> str:
    """Generate a markdown report of the analysis.

    Args:
        seed_results: One result per seed run (used for the per-seed sections).
        consistency_analysis: Dict with ``n_seeds``, ``always_present`` and
            ``sometimes_present`` (lists of dicts with ``title``, ``count``, ``seeds``).
        stage_analysis: Stage index -> dict with ``seeds_with_stage`` and
            ``concepts_passing`` (label -> number of seeds).
        concept_titles: Mapping concept_id -> title for the per-seed lists.
        semantic_groups: Group key -> list of (concept_id, title, seeds).

    Returns:
        The full report as one newline-joined Markdown string.
    """
    # Friendly names for the two halves of a "<dimension>_<category>" group key.
    # Defined once and shared by the grouping section and the key findings.
    dimension_labels = {
        "gender_bias": "Gender Bias",
        "majority_bias": "Majority/Ethnicity Bias",
        "asian": "Asian Ethnicity",
        "english": "English Language",
        "formal": "Formal Tone",
        "western": "Western Name",
        "middle_eastern": "Middle Eastern Ethnicity",
    }
    category_labels = {
        "gender": "Gender",
        "ethnicity": "Ethnicity",
        "name": "Name",
        "language": "Language",
        "tone": "Tone",
    }

    def format_group_label(group_name: str) -> str:
        """Turn a group key such as ``gender_bias_gender`` into ``Gender Bias (Gender)``."""
        parts = group_name.split("_")
        if len(parts) >= 2 and parts[-1] in category_labels:
            dimension = "_".join(parts[:-1])
            category = parts[-1]
            dim_label = dimension_labels.get(dimension, dimension.replace("_", " ").title())
            cat_label = category_labels.get(category, category.title())
            return f"{dim_label} ({cat_label})"
        # No recognised category suffix: label the key as a whole.
        return dimension_labels.get(group_name, group_name.replace("_", " ").title())

    n_seeds = consistency_analysis["n_seeds"]
    always_present = consistency_analysis["always_present"]
    sometimes_present = consistency_analysis["sometimes_present"]

    # Seeds covered by each semantic group (union over the group's concepts)
    group_seed_coverage: dict[str, set[str]] = {
        group_name: {seed for _, _, seeds in concepts for seed in seeds}
        for group_name, concepts in semantic_groups.items()
    }
    # Sort groups by seed coverage (most coverage first), then by key
    sorted_groups = sorted(
        semantic_groups.items(),
        key=lambda item: (-len(group_seed_coverage[item[0]]), item[0]),
    )
    consistent_groups = [name for name, seeds in group_seed_coverage.items() if len(seeds) == n_seeds]
    groups_in_all_seeds = len(consistent_groups)
    groups_in_some_seeds = len(group_seed_coverage) - groups_in_all_seeds

    lines = [
        "# Concept Consistency Analysis Across Seeds",
        "",
        f"**Number of seeds analyzed:** {n_seeds}",
        "",
    ]

    # Seed summary
    lines += ["## Seed Summary", "", "| Seed | Final Concepts |", "|------|----------------|"]
    lines += [f"| {result.seed_name} | {len(result.final_concept_ids)} |" for result in seed_results]
    lines.append("")

    # Always present concepts
    lines += [f"## Concepts Found in ALL Seeds ({len(always_present)} concepts)", ""]
    if always_present:
        lines += [f"- **{concept['title']}**" for concept in always_present]
    else:
        lines.append("*No concepts were found in all seeds.*")
    lines.append("")

    # Sometimes present concepts
    lines += [f"## Concepts Found in SOME Seeds ({len(sometimes_present)} concepts)", ""]
    if sometimes_present:
        lines.append("| Concept | Seeds Present | Which Seeds |")
        lines.append("|---------|---------------|-------------|")
        for concept in sometimes_present:
            seeds_str = ", ".join(concept["seeds"])
            lines.append(f"| {concept['title']} | {concept['count']}/{n_seeds} | {seeds_str} |")
    else:
        lines.append("*All concepts were found in all seeds.*")
    lines.append("")

    # Semantic grouping analysis
    lines += [
        "## Semantic Grouping Analysis",
        "",
        "Concepts grouped by semantic similarity (even if they have different IDs):",
        "",
    ]
    for group_name, concepts in sorted_groups:
        seeds_in_group = group_seed_coverage[group_name]
        group_label = format_group_label(group_name)
        coverage = f"{len(seeds_in_group)}/{n_seeds} seeds"

        if len(seeds_in_group) == n_seeds:
            lines.append(f"### {group_label} ✓ (ALL {coverage})")
        else:
            lines.append(f"### {group_label} ({coverage})")
        lines.append("")

        for _, title, seeds in concepts:
            seeds_str = ", ".join(seeds)
            lines.append(f"- {title} [{seeds_str}]")
        lines.append("")

    # Per-seed breakdown
    lines += ["## Per-Seed Concept Lists", ""]
    for result in seed_results:
        lines += [f"### {result.seed_name}", ""]
        for concept_id in result.final_concept_ids:
            title = concept_titles.get(concept_id, f"Unknown ({concept_id[:8]}...)")
            lines.append(f"- {title}")
        lines.append("")

    # Stage progression analysis
    lines += [
        "## Stage-by-Stage Progression",
        "",
        "This shows how concepts converge or diverge at each stage.",
        "",
    ]
    for stage_idx, stage_data in stage_analysis.items():
        lines += [
            f"### Stage {stage_idx}",
            "",
            f"*Seeds with data for this stage: {stage_data['seeds_with_stage']}*",
            "",
        ]

        # Show concepts that passed and how many seeds they passed in
        passing = stage_data["concepts_passing"]
        if passing:
            lines += ["**Concepts passing this stage:**", ""]
            for title, count in list(passing.items())[:15]:
                # NOTE(review): the check mark is awarded against the total number of
                # seeds, while the fraction is shown against the seeds that reached
                # this stage - confirm this asymmetry is intended.
                marker = "✓" if count == n_seeds else f"({count}/{stage_data['seeds_with_stage']})"
                lines.append(f"- {title} {marker}")
            if len(passing) > 15:
                lines.append(f"- ... and {len(passing) - 15} more")
            lines.append("")

    # Summary statistics
    all_concepts = set()
    for result in seed_results:
        all_concepts.update(result.final_concept_ids)

    lines += [
        "## Summary Statistics",
        "",
        "### By Exact Concept ID",
        "",
        f"- **Total unique concepts found across all seeds:** {len(all_concepts)}",
        f"- **Concepts found in all seeds:** {len(always_present)}",
        f"- **Concepts found in some but not all seeds:** {len(sometimes_present)}",
    ]
    if always_present:
        consistency_rate = len(always_present) / len(all_concepts) * 100
        lines.append(f"- **Consistency rate:** {consistency_rate:.1f}%")
    lines.append("")

    # Semantic grouping statistics
    lines += [
        "### By Semantic Group",
        "",
        f"- **Total semantic groups:** {len(semantic_groups)}",
        f"- **Groups found in all seeds:** {groups_in_all_seeds}",
        f"- **Groups found in some but not all seeds:** {groups_in_some_seeds}",
    ]
    if groups_in_all_seeds > 0:
        semantic_consistency_rate = groups_in_all_seeds / len(semantic_groups) * 100
        lines.append(f"- **Semantic consistency rate:** {semantic_consistency_rate:.1f}%")
    lines.append("")

    # Key finding summary
    lines += ["### Key Findings", ""]
    if groups_in_all_seeds > 0:
        lines += [f"**Consistently detected biases (found in all {n_seeds} seeds):**", ""]
        lines += [f"- {format_group_label(group)}" for group in consistent_groups]
        lines.append("")
    else:
        lines.append("No semantic groups were found consistently across all seeds.")
        lines.append("")

    return "\n".join(lines)


def plot_survivorship_histogram(
    consistency_analysis: dict,
    n_seeds: int,
    output_path: Path,
    concept_titles: dict[str, str],
    concept_directions: dict[str, str],
) -> None:
    """Create a horizontal bar chart showing how many seeds each concept survived.

    Args:
        consistency_analysis: Dict whose ``concept_counts`` entry maps
            concept_id -> number of seeds the concept appeared in.
        n_seeds: Total number of seeds; sets the x-axis range, the colour scale
            and the denominator of the value labels.
        output_path: Where to write the raster figure; a PDF with the same stem
            is written next to it. Missing parent directories are created.
        concept_titles: Mapping concept_id -> title used as the bar label.
        concept_directions: Mapping concept_id -> '+' / '-' prefix for the label.
    """
    concept_counts = consistency_analysis["concept_counts"]

    # (title, survival_count, direction) rows; IDs without a title fall back to
    # a truncated ID.
    rows = [
        (concept_titles.get(concept_id) or concept_id[:12] + "...", count, concept_directions.get(concept_id, ""))
        for concept_id, count in concept_counts.items()
    ]
    # Sort by count (descending), then alphabetically by title
    rows.sort(key=lambda row: (-row[1], row[0]))

    # Prefix the direction indicator and make repeated labels unique: duplicate
    # labels would otherwise be treated as one category by the bar plot.
    titles = []
    occurrences: dict[str, int] = {}
    for title, _, direction in rows:
        base_title = f"{direction} {title}" if direction else title
        occurrences[base_title] = occurrences.get(base_title, 0) + 1
        nth = occurrences[base_title]
        titles.append(base_title if nth == 1 else f"{base_title} ({nth})")
    counts = [count for _, count, _ in rows]

    # savefig does not create directories; make sure the target folder exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Use ICML paper styling
    with plt.rc_context(PLOT_RC_CONTEXT):
        # Height grows with the number of bars; width is fixed
        fig_height = max(1.5, len(rows) * 0.18 + 0.3)
        fig, ax = plt.subplots(figsize=(3.5, fig_height))

        # Create color palette based on survival count
        palette = sns.color_palette("Blues", n_colors=n_seeds + 1)
        colors = [palette[c] for c in counts]

        # Use seaborn barplot with narrower bars
        sns.barplot(
            x=counts,
            y=titles,
            palette=colors,
            edgecolor="black",
            linewidth=0.5,
            ax=ax,
            width=0.6,
        )

        # Add value labels at the end of bars
        for i, count in enumerate(counts):
            ax.text(
                count + 0.05,
                i,
                f"{count}/{n_seeds}",
                ha="left",
                va="center",
                fontsize=plt.rcParams["font.size"] * 0.85,
            )

        ax.set_xlabel("Num. Seeds Survived")
        ax.set_ylabel("")

        # Set x-axis limits and ticks
        ax.set_xlim(0, n_seeds + 0.5)
        ax.set_xticks(range(n_seeds + 1))
        ax.set_xticklabels([str(i) for i in range(n_seeds + 1)])

        # Clean up legend frame if present
        legend = ax.get_legend()
        if legend:
            legend.get_frame().set_linewidth(0.5)

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
        fig.savefig(output_path.with_suffix(".pdf"), bbox_inches="tight")
        # Close this figure specifically rather than whichever one is current.
        plt.close(fig)

    print(f"Survivorship histogram saved to: {output_path} and {output_path.with_suffix('.pdf')}")


def extract_direction_from_group(
    concepts: list[tuple[str, str, list[str]]],
    concept_directions: dict[str, str],
) -> str:
    """Return the single direction shared by a semantic group, if there is one.

    Each (concept_id, title, seeds) entry is looked up in ``concept_directions``;
    concepts without a (non-empty) direction are ignored. The result is that
    direction when all remaining concepts agree, and "" when they disagree or
    none of them has a direction.
    """
    distinct = {
        known
        for concept_id, _title, _seeds in concepts
        if (known := concept_directions.get(concept_id, ""))
    }
    if len(distinct) != 1:
        # Either nothing known or conflicting directions.
        return ""
    (only_direction,) = distinct
    return only_direction


def plot_category_survivorship(
    semantic_groups: dict[str, list[tuple[str, str, list[str]]]],
    n_seeds: int,
    output_path: Path,
    concept_directions: dict[str, str],
) -> None:
    """Create a horizontal bar chart showing how many seeds each semantic category survived.

    A category survives a seed if ANY concept in it survived that seed. Bars are
    labelled by bias dimension only, so group keys that differ just in their
    category suffix (e.g. ``gender_bias_gender`` and ``gender_bias_name``) are
    merged into one bar whose seeds are the union of both groups.

    Args:
        semantic_groups: Group key -> list of (concept_id, title, seeds).
        n_seeds: Total number of seeds; sets the x-axis range, the colour scale
            and the denominator of the value labels.
        output_path: Where to write the raster figure; a PDF with the same stem
            is written next to it. Missing parent directories are created.
        concept_directions: Accepted for signature compatibility; the category
            chart shows labels without direction indicators, so it is not used.
    """
    # Friendly names for bias dimensions
    dimension_labels = {
        "gender_bias": "Gender Bias",
        "majority_bias": "Ethnicity Bias",
        "asian": "Asian Ethnicity Bias",
        "english": "English Proficiency Bias",
        "formal": "Formal Tone Bias",
        "western": "Western Name Bias",
        "middle_eastern": "Middle Eastern Ethnicity Bias",
    }
    category_labels = {
        "gender": "Gender",
        "ethnicity": "Ethnicity",
        "name": "Name",
        "language": "Language",
        "tone": "Tone",
    }

    def format_group_label(group_name: str) -> str:
        """Label a group key by its dimension, dropping a recognised category suffix."""
        parts = group_name.split("_")
        if len(parts) >= 2 and parts[-1] in category_labels:
            dimension = "_".join(parts[:-1])
            return dimension_labels.get(dimension, dimension.replace("_", " ").title())
        return dimension_labels.get(group_name, group_name.replace("_", " ").title() + " Bias")

    # Union of surviving seeds per displayed label. Keying by label (not group
    # key) keeps every bar label unique; repeated labels would be aggregated by
    # the bar plot and leave the value labels pointing at the wrong rows.
    seeds_by_label: dict[str, set[str]] = {}
    for group_name, concepts in semantic_groups.items():
        surviving = seeds_by_label.setdefault(format_group_label(group_name), set())
        for _, _, seeds in concepts:
            surviving.update(seeds)

    # Sort by survival count (descending), then alphabetically
    data = sorted(
        ((label, len(surviving)) for label, surviving in seeds_by_label.items()),
        key=lambda row: (-row[1], row[0]),
    )
    labels = [label for label, _ in data]
    counts = [count for _, count in data]

    # savefig does not create directories; make sure the target folder exists.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Use ICML paper styling
    with plt.rc_context(PLOT_RC_CONTEXT):
        # Height grows with the number of bars; width is fixed
        fig_height = max(1.5, len(data) * 0.22 + 0.3)
        fig, ax = plt.subplots(figsize=(3.5, fig_height))

        # Create color palette based on survival count (same as survivorship histogram)
        palette = sns.color_palette("Blues", n_colors=n_seeds + 1)
        colors = [palette[c] for c in counts]

        # Use seaborn barplot with narrower bars
        sns.barplot(
            x=counts,
            y=labels,
            palette=colors,
            edgecolor="black",
            linewidth=0.5,
            ax=ax,
            width=0.6,
        )

        # Add value labels at the end of bars
        for i, count in enumerate(counts):
            ax.text(
                count + 0.05,
                i,
                f"{count}/{n_seeds}",
                ha="left",
                va="center",
                fontsize=plt.rcParams["font.size"] * 0.85,
            )

        ax.set_xlabel("Num. Seeds Survived")
        ax.set_ylabel("")

        # Set x-axis limits and ticks
        ax.set_xlim(0, n_seeds + 0.5)
        ax.set_xticks(range(n_seeds + 1))
        ax.set_xticklabels([str(i) for i in range(n_seeds + 1)])

        # Clean up legend frame if present
        legend = ax.get_legend()
        if legend:
            legend.get_frame().set_linewidth(0.5)

        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
        fig.savefig(output_path.with_suffix(".pdf"), bbox_inches="tight")
        # Close this figure specifically rather than whichever one is current.
        plt.close(fig)

    print(f"Category survivorship chart saved to: {output_path} and {output_path.with_suffix('.pdf')}")


def main():
    """Run the full seed-consistency analysis.

    Loads concept titles, directions and per-seed results from
    ``LOAN_APPROVAL_RESULTS_PATH``, writes ``seed_consistency_report.md`` there,
    saves the two survivorship figures under ``figures/`` and prints the report.
    Stops early (after printing an error) when fewer than two seeds have results.
    """
    results_dir = LOAN_APPROVAL_RESULTS_PATH

    print("Loading concept titles from dataset...")
    concept_titles = load_concept_titles(results_dir)
    print(f"  Found {len(concept_titles)} concepts")

    print("Loading concept directions from results...")
    concept_directions = load_concept_directions(results_dir)
    print(f"  Found directions for {len(concept_directions)} concepts")

    # Define seeds to analyze: (file-name suffix, display name) for the main
    # run and seeds 0-3
    seeds_to_analyze = [
        ("", "main"),
        ("-seed-0", "seed-0"),
        ("-seed-1", "seed-1"),
        ("-seed-2", "seed-2"),
        ("-seed-3", "seed-3"),
    ]

    print("\nLoading seed results...")
    seed_results = []
    for seed_suffix, seed_name in seeds_to_analyze:
        result = load_seed_results(results_dir, seed_suffix, seed_name)
        if result:
            print(f"  {seed_name}: {len(result.final_concept_ids)} final concepts across {len(result.stages_data)} stages")
            seed_results.append(result)

    if len(seed_results) < 2:
        print("Error: Need at least 2 seed results to compare")
        return

    print("\nAnalyzing consistency...")
    consistency_analysis = analyze_consistency(seed_results, concept_titles)

    print("\nAnalyzing stage progression...")
    stage_analysis = analyze_stage_progression(seed_results, concept_titles)

    print("\nGrouping semantically similar concepts...")
    semantic_groups = group_semantically_similar_concepts(seed_results, concept_titles)

    print("\nGenerating report...")
    report = generate_report(seed_results, consistency_analysis, stage_analysis, concept_titles, semantic_groups)

    # Save report. The report contains non-ASCII characters (check marks, em
    # dashes in titles), so write it as UTF-8 instead of the platform default,
    # which cannot encode them on some systems.
    report_path = results_dir / "seed_consistency_report.md"
    report_path.write_text(report, encoding="utf-8")
    print(f"\nReport saved to: {report_path}")

    # Generate survivorship histogram (individual concepts)
    print("\nGenerating survivorship histogram...")
    plot_path = results_dir / "figures" / "seed_survivorship_histogram.png"
    plot_survivorship_histogram(
        consistency_analysis, len(seed_results), plot_path, concept_titles, concept_directions
    )

    # Generate category survivorship chart (grouped by bias category)
    print("\nGenerating category survivorship chart...")
    category_plot_path = results_dir / "figures" / "seed_category_survivorship.png"
    plot_category_survivorship(semantic_groups, len(seed_results), category_plot_path, concept_directions)

    # Print report to console
    print("\n" + "=" * 80)
    print(report)


# Run the analysis only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()

# %%
