import bisect
import gc
import os
import pickle
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

import networkx as nx
import numpy as np
from joblib import Parallel, delayed
from tqdm import tqdm, trange

from postprocess_new import (
    SelectionConfig,
    SelectionResult,
    DEFAULT_TEMPERATURES,
    DEFAULT_ORCA_PATH,
    EPSILON,
    SPEED_OPTIM,
    to_numpy,
    safe_logit,
    safe_sigmoid,
    sample_bernoulli,
    sample_rigid_topk,
    adj_to_nx,
    evaluate_selection,
)

if SPEED_OPTIM:
    from evaluate.distances_optim import (
        _netlsd_signature_batch,
        compute_gcd_embeddings_parallel,
        compute_portraits_batch,
        preprocess_portraits_for_divergence,
        portrait_divergence_fast,
        _get_device,
    )
else:
    from evaluate.distances import (
        _netlsd_signature,
        _gcm_orca_like_theirs,
    )
    from netrd.distance.portrait_divergence import portrait


class ChunkedProbLoader:
    """Read-only, list-like access to probability matrices stored in pickle chunks.

    The metadata pickle must provide ``chunk_paths`` (one pickle file per chunk),
    ``total_count`` and ``n_vertices``. Each chunk file is expected to unpickle
    to an indexable sequence of matrices. At most one chunk (the most recently
    accessed) is cached in memory.

    Security note: metadata and chunk files are read with ``pickle.load``, which
    can execute arbitrary code -- only use this on trusted files.
    """

    def __init__(self, chunks_meta_path: str):
        with open(chunks_meta_path, "rb") as f:
            meta = pickle.load(f)
        self.chunk_paths = [Path(p) for p in meta["chunk_paths"]]
        self.total_count = meta["total_count"]
        self.n_vertices = meta["n_vertices"]

        # Chunk sizes are measured by unpickling every chunk once. Each chunk is
        # released right after len() so only one is alive during the scan.
        self._chunk_start_indices: list[int] = []
        self._chunk_sizes: list[int] = []
        current_idx = 0
        for chunk_path in self.chunk_paths:
            with open(chunk_path, "rb") as f:
                chunk_size = len(pickle.load(f))
            self._chunk_start_indices.append(current_idx)
            self._chunk_sizes.append(chunk_size)
            current_idx += chunk_size

        # Single-chunk cache used by __getitem__.
        self._cached_chunk_idx = None
        self._cached_chunk_data = None

    def __len__(self):
        return self.total_count

    def _normalize_index(self, idx: int) -> int:
        """Resolve a possibly negative ``idx``; raise IndexError if out of range."""
        if idx < 0:
            idx = self.total_count + idx
        if idx < 0 or idx >= self.total_count:
            raise IndexError(f"Index {idx} out of range [0, {self.total_count})")
        return idx

    def _find_chunk_for_index(self, idx: int) -> tuple[int, int]:
        """Return ``(chunk_idx, offset_within_chunk)`` for a non-negative ``idx``.

        Raises:
            IndexError: if no chunk contains ``idx``.
        """
        # Start offsets are non-decreasing, so bisect_right finds the last chunk
        # that starts at or before idx. An empty chunk shares its start offset
        # with the following chunk, and bisect_right picks the later one.
        chunk_idx = bisect.bisect_right(self._chunk_start_indices, idx) - 1
        if chunk_idx >= 0:
            start = self._chunk_start_indices[chunk_idx]
            if idx < start + self._chunk_sizes[chunk_idx]:
                return chunk_idx, idx - start
        raise IndexError(f"Index {idx} out of range [0, {self.total_count})")

    def __getitem__(self, idx: int) -> np.ndarray:
        """Return item ``idx`` (negative indices count from the end)."""
        idx = self._normalize_index(idx)
        chunk_idx, offset = self._find_chunk_for_index(idx)
        if self._cached_chunk_idx != chunk_idx:
            # Release the old chunk before loading the new one so peak memory
            # stays at one chunk; reset the index too in case loading fails.
            self._cached_chunk_idx = None
            self._cached_chunk_data = None
            with open(self.chunk_paths[chunk_idx], "rb") as f:
                self._cached_chunk_data = pickle.load(f)
            self._cached_chunk_idx = chunk_idx
        return self._cached_chunk_data[offset]

    def get_batch(self, indices: list[int]) -> list[np.ndarray]:
        """Fetch several items, returned in the order given by ``indices``.

        Items are loaded grouped by chunk, so each chunk is unpickled at most
        once per call. Negative indices are accepted, as in ``__getitem__``.
        """
        resolved = [self._normalize_index(i) for i in indices]
        load_order = sorted(
            range(len(resolved)),
            key=lambda pos: self._find_chunk_for_index(resolved[pos])[0],
        )
        results = [None] * len(resolved)
        for pos in load_order:
            results[pos] = self[resolved[pos]]
        return results


def load_prob_matrices(prob_input: list[np.ndarray] | str) -> list[np.ndarray] | ChunkedProbLoader:
    """Normalise the probability-matrix input into something indexable.

    A ``str`` is treated as the path of a chunk-metadata pickle and wrapped in a
    ``ChunkedProbLoader`` (after printing a short notice). Any other value is
    handed back unchanged.
    """
    if not isinstance(prob_input, str):
        return prob_input
    print(f"Loading probability matrices from chunks: {prob_input}")
    return ChunkedProbLoader(prob_input)


@dataclass
class SampleMetadata:
    """Lightweight recipe for (re)drawing one sampled graph.

    Keeping only (matrix index, temperature, seed) instead of the adjacency
    matrix lets callers discard sampled graphs and redraw selected ones later
    (reproducible presumably only if the samplers draw from NumPy's global RNG).
    """

    matrix_idx: int  # position of the source matrix in the probability collection
    temperature: float  # np.inf selects rigid top-k sampling; otherwise Bernoulli
    random_seed: int  # value passed to np.random.seed right before sampling

    def sample(self, prob_matrix: np.ndarray) -> np.ndarray:
        """Draw the adjacency described by this record from ``prob_matrix``.

        Side effect: re-seeds NumPy's process-wide global RNG with
        ``random_seed``. Dispatches to ``sample_rigid_topk`` when the
        temperature is ``np.inf`` and to ``sample_bernoulli`` otherwise.
        """
        np.random.seed(self.random_seed)
        is_rigid = self.temperature == np.inf
        if not is_rigid:
            return sample_bernoulli(prob_matrix, self.temperature)
        return sample_rigid_topk(prob_matrix)


def compute_embeddings_chunked(
    prob_matrices: list[np.ndarray] | ChunkedProbLoader,
    metadata_list: list[SampleMetadata],
    metric: str,
    orca_path: str,
    chunk_size: int,
    n_jobs: int = -1,
    verbose: bool = True,
) -> np.ndarray:
    """Sample the graphs described by ``metadata_list`` and embed them chunk by chunk.

    Only ``chunk_size`` sampled graphs are alive at a time; each chunk's graphs
    are dropped once its embeddings are computed.

    Args:
        prob_matrices: Indexable collection of probability matrices.
        metadata_list: One sampling recipe per output row.
        metric: ``"netlsd_heat"``, ``"netlsd_wave"`` or ``"gcd"``.
        orca_path: ORCA location, used only by the ``"gcd"`` metric.
        chunk_size: Number of graphs sampled and embedded per chunk (>= 1).
        n_jobs: Currently unused; kept for signature compatibility.
        verbose: Print progress and show a tqdm bar over chunks.

    Returns:
        Per-chunk embedding arrays stacked with ``np.vstack`` (one row per
        metadata entry, assuming every embedding has the same length). An empty
        ``metadata_list`` yields an array of shape ``(0, 0)``.

    Raises:
        ValueError: if ``metric`` is unknown or ``chunk_size`` < 1.
    """
    # Validate up front instead of after sampling a whole chunk.
    if metric not in ("netlsd_heat", "netlsd_wave", "gcd"):
        raise ValueError(f"Unknown metric: {metric}")
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    n_samples = len(metadata_list)
    if n_samples == 0:
        # np.vstack([]) would raise; an empty pool simply has no embeddings.
        return np.empty((0, 0))

    n_chunks = (n_samples + chunk_size - 1) // chunk_size
    if verbose:
        print(f"    Computing {metric} embeddings in {n_chunks} chunks of size {chunk_size}...")

    # Loop-invariant ORCA directory for the non-optimized GCD path.
    orca_prefix = (
        os.path.dirname(os.path.abspath(orca_path))
        if metric == "gcd" and not SPEED_OPTIM
        else None
    )

    def embed_one(g):
        """Embed one graph with the non-optimized implementations."""
        if metric == "gcd":
            gcm = _gcm_orca_like_theirs(g, orca_prefix=orca_prefix, graphlet_size=4)
            # Upper triangle (without diagonal) of the graphlet correlation matrix.
            return gcm[np.triu_indices_from(gcm, k=1)]
        return _netlsd_signature(g, "heat" if metric == "netlsd_heat" else "wave")

    chunk_iter = range(0, n_samples, chunk_size)
    if verbose:
        chunk_iter = tqdm(chunk_iter, desc="    Chunks", total=n_chunks)

    all_embeddings = []
    for chunk_start in chunk_iter:
        chunk_metadata = metadata_list[chunk_start:chunk_start + chunk_size]
        chunk_graphs = [
            adj_to_nx(meta.sample(prob_matrices[meta.matrix_idx]))
            for meta in chunk_metadata
        ]

        if SPEED_OPTIM:
            if metric == "gcd":
                chunk_emb = compute_gcd_embeddings_parallel(
                    chunk_graphs, orca_path, verbose=False
                )
            else:
                kernel = "heat" if metric == "netlsd_heat" else "wave"
                chunk_emb = _netlsd_signature_batch(chunk_graphs, kernel)
        else:
            chunk_emb = [embed_one(g) for g in chunk_graphs]

        all_embeddings.append(np.array(chunk_emb))
        # Free this chunk's graphs before sampling the next chunk.
        del chunk_graphs
        gc.collect()

    return np.vstack(all_embeddings)


def compute_portraits_chunked(
    prob_matrices: list[np.ndarray] | ChunkedProbLoader,
    metadata_list: list[SampleMetadata],
    chunk_size: int,
    n_jobs: int = -1,
    verbose: bool = True,
) -> list[np.ndarray]:
    """Sample the graphs described by ``metadata_list`` and compute their portraits.

    Graphs are sampled ``chunk_size`` at a time and discarded after their
    portraits are computed, so only one chunk of graphs is alive at once.

    Args:
        prob_matrices: Indexable collection of probability matrices.
        metadata_list: One sampling recipe per returned portrait.
        chunk_size: Number of graphs processed per chunk (>= 1).
        n_jobs: Forwarded to ``compute_portraits_batch`` on the SPEED_OPTIM path;
            unused by the fallback path.
        verbose: Print progress and show a tqdm bar over chunks.

    Returns:
        Portraits in the same order as ``metadata_list`` (empty list for empty input).

    Raises:
        ValueError: if ``chunk_size`` < 1 (previously a negative value silently
            produced no portraits and zero raised ZeroDivisionError).
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    n_samples = len(metadata_list)
    n_chunks = (n_samples + chunk_size - 1) // chunk_size
    if verbose:
        print(f"    Computing portraits in {n_chunks} chunks of size {chunk_size}...")

    # Choose the portrait backend once, outside the chunk loop.
    if SPEED_OPTIM:
        def portraits_of(graphs):
            return compute_portraits_batch(graphs, n_jobs=n_jobs, verbose=False)
    else:
        from postprocess_new import compute_portrait_safe

        def portraits_of(graphs):
            return [compute_portrait_safe(g) for g in graphs]

    chunk_iter = range(0, n_samples, chunk_size)
    if verbose:
        chunk_iter = tqdm(chunk_iter, desc="    Chunks", total=n_chunks)

    all_portraits = []
    for chunk_start in chunk_iter:
        chunk_metadata = metadata_list[chunk_start:chunk_start + chunk_size]
        chunk_graphs = [
            adj_to_nx(meta.sample(prob_matrices[meta.matrix_idx]))
            for meta in chunk_metadata
        ]
        all_portraits.extend(portraits_of(chunk_graphs))
        # Free this chunk's graphs before sampling the next chunk.
        del chunk_graphs
        gc.collect()

    return all_portraits


def greedy_maxsum_embeddings(
    embeddings: np.ndarray,
    k: int,
    objective: Literal["energy", "average"],
    start_indices: list[int] | None = None,
    verbose: bool = True,
) -> list[int]:
    N = embeddings.shape[0]
    if k > N:
        raise ValueError(f"Cannot select {k} graphs from {N} candidates")

    selected = []
    selected_mask = np.zeros(N, dtype=bool)
    sum_scores = np.zeros(N)

    if start_indices is None:
        start_indices = [0]

    for idx in start_indices:
        selected.append(idx)
        selected_mask[idx] = True
        dists = np.linalg.norm(embeddings - embeddings[idx], axis=1)
        if objective == "energy":
            with np.errstate(divide='ignore'):
                score_update = -1.0 / dists
                score_update[~np.isfinite(score_update)] = 0
            sum_scores += score_update
        else:
            sum_scores += dists

    sum_scores[selected_mask] = -np.inf
    remaining = k - len(selected)
    if remaining <= 0:
        return selected[:k]

    iterator = (
        tqdm(range(remaining), desc="    Selecting") if verbose else range(remaining)
    )

    for _ in iterator:
        best = np.argmax(sum_scores)
        if sum_scores[best] == -np.inf:
            break
        selected.append(best)
        selected_mask[best] = True
        new_dists = np.linalg.norm(embeddings - embeddings[best], axis=1)
        if objective == "energy":
            with np.errstate(divide='ignore'):
                score_update = -1.0 / new_dists
                score_update[~np.isfinite(score_update)] = 0
            sum_scores += score_update
        else:
            sum_scores += new_dists
        sum_scores[selected_mask] = -np.inf

    return selected


def greedy_maxsum_portraits(
    portraits: list[np.ndarray],
    k: int,
    objective: Literal["energy", "average"],
    start_indices: list[int] | None = None,
    n_jobs: int = -1,
    verbose: bool = True,
    precomputed_distributions: list[np.ndarray] | None = None,
) -> list[int]:
    """Greedily select ``k`` mutually diverse graphs by portrait divergence.

    The seeds in ``start_indices`` (default ``[0]``) are selected first. After
    that, the candidate with the highest accumulated score is added, one at a
    time. For each selected graph, a candidate at divergence ``d`` gains ``d``
    under ``"average"`` and ``-1 / (d + EPSILON)`` under ``"energy"``. Any other
    objective value is treated like ``"average"``.

    Divergences come from ``portrait_divergence_fast`` on
    ``precomputed_distributions`` (computed serially) when SPEED_OPTIM is on and
    those are supplied. Otherwise ``compute_portrait_distance`` is used on the
    raw portraits, fanned out over a single joblib thread pool that is reused
    for the whole run.

    Args:
        portraits: Candidate portraits.
        k: Number of graphs to select.
        objective: Scoring rule (see above).
        start_indices: Seed indices; repeated seeds are collapsed to one.
        n_jobs: joblib worker count for the non-fast path.
        verbose: Show a tqdm progress bar.
        precomputed_distributions: Optional preprocessed portraits for the fast path.

    Returns:
        At most ``k`` distinct indices (Python ints) in selection order.

    Raises:
        ValueError: if ``k`` exceeds the number of portraits.
    """
    from contextlib import nullcontext

    N = len(portraits)
    if k > N:
        raise ValueError(f"Cannot select {k} graphs from {N} candidates")
    if k <= 0:
        return []

    use_fast_divergence = SPEED_OPTIM and precomputed_distributions is not None

    if use_fast_divergence:
        def divergence(idx: int, anchor: int) -> float:
            return portrait_divergence_fast(
                precomputed_distributions[idx],
                precomputed_distributions[anchor],
            )
    else:
        from postprocess_new import compute_portrait_distance

        def divergence(idx: int, anchor: int) -> float:
            return compute_portrait_distance(portraits[idx], portraits[anchor])

    cumulative_scores = np.zeros(N)
    available_mask = np.ones(N, dtype=bool)
    selected: list[int] = []

    def accumulate(anchor: int, parallel) -> None:
        """Add every still-available candidate's score w.r.t. ``anchor``."""
        candidates = np.flatnonzero(available_mask)
        if candidates.size == 0:
            return
        if parallel is None:
            dists = [divergence(idx, anchor) for idx in candidates]
        else:
            dists = parallel(delayed(divergence)(idx, anchor) for idx in candidates)
        for idx, dist in zip(candidates, dists):
            if objective == "energy":
                cumulative_scores[idx] += -1.0 / (dist + EPSILON)
            else:
                cumulative_scores[idx] += dist

    if start_indices is None:
        start_indices = [0]
    # Collapse repeated seeds (e.g. a "dissimilar pair" that degenerated to a
    # single portrait) so no index is selected or scored twice.
    seeds = list(dict.fromkeys(int(i) for i in start_indices))

    # One thread pool for the whole selection instead of one per greedy step.
    pool = nullcontext() if use_fast_divergence else Parallel(n_jobs=n_jobs, prefer="threads")
    with pool as parallel:
        for seed in seeds:
            selected.append(seed)
            available_mask[seed] = False
            accumulate(seed, parallel)

        # Already-selected indices can never win argmax.
        cumulative_scores[~available_mask] = -np.inf
        remaining = k - len(selected)
        if remaining <= 0:
            return selected[:k]

        iterator = trange(remaining, desc="    Selecting") if verbose else range(remaining)
        for _ in iterator:
            best_idx = int(np.argmax(cumulative_scores))
            if cumulative_scores[best_idx] == -np.inf:
                break
            selected.append(best_idx)
            available_mask[best_idx] = False
            cumulative_scores[best_idx] = -np.inf
            if not available_mask.any():
                break
            accumulate(best_idx, parallel)

    return selected


def find_most_dissimilar_pair_embeddings(
    embeddings: np.ndarray,
    verbose: bool = True,
) -> tuple[int, int]:
    """Approximate the farthest-apart pair of embedding rows in O(N).

    Point A is the row farthest (Euclidean) from the centroid; point B is the
    row farthest from A. This is a heuristic, not an exact diameter search.
    A == B when there is a single row or all rows are identical.

    Returns:
        ``(A, B)`` as Python ints, consistent with
        ``find_most_dissimilar_pair_portraits``.

    Raises:
        ValueError: if ``embeddings`` has no rows.
    """
    N = embeddings.shape[0]
    if N == 0:
        raise ValueError("Cannot find a dissimilar pair among 0 embeddings")
    if verbose:
        print(f"    Finding most dissimilar pair (O(N) over {N} embeddings)...")
    centroid = embeddings.mean(axis=0)
    idx_a = int(np.argmax(np.linalg.norm(embeddings - centroid, axis=1)))
    dists_to_a = np.linalg.norm(embeddings - embeddings[idx_a], axis=1)
    idx_b = int(np.argmax(dists_to_a))
    if verbose:
        print(f"      Point A: {idx_a}, Point B: {idx_b}, dist={dists_to_a[idx_b]:.6f}")
    return idx_a, idx_b


def find_most_dissimilar_pair_portraits(
    portraits: list[np.ndarray],
    n_jobs: int = -1,
    verbose: bool = True,
    precomputed_distributions: list[np.ndarray] | None = None,
) -> tuple[int, int]:
    N = len(portraits)
    if verbose:
        print(f"    Finding most dissimilar pair (O(N) over {N} portraits)...")
    if N <= 2:
        return 0, min(1, N - 1)

    use_fast_divergence = SPEED_OPTIM and precomputed_distributions is not None
    start_idx = 0

    if use_fast_divergence:
        dists_to_start = [
            portrait_divergence_fast(precomputed_distributions[start_idx], precomputed_distributions[i])
            for i in range(N)
        ]
    else:
        from postprocess_new import compute_portrait_distance
        with Parallel(n_jobs=n_jobs, prefer="threads") as parallel:
            dists_to_start = parallel(
                delayed(compute_portrait_distance)(portraits[start_idx], portraits[i])
                for i in range(N)
            )

    idx_a = int(np.argmax(dists_to_start))

    if use_fast_divergence:
        dists_to_a = [
            portrait_divergence_fast(precomputed_distributions[idx_a], precomputed_distributions[i])
            for i in range(N)
        ]
    else:
        with Parallel(n_jobs=n_jobs, prefer="threads") as parallel:
            dists_to_a = parallel(
                delayed(compute_portrait_distance)(portraits[idx_a], portraits[i])
                for i in range(N)
            )

    idx_b = int(np.argmax(dists_to_a))
    if verbose:
        print(f"      Point A: {idx_a}, Point B: {idx_b}, dist={dists_to_a[idx_b]:.6f}")
    return idx_a, idx_b


def _compute_pool_features(
    prob_loader: list[np.ndarray] | ChunkedProbLoader,
    metadata: list[SampleMetadata],
    config: SelectionConfig,
    chunk_size: int,
    use_portrait: bool,
    verbose: bool,
):
    """Sample the graphs in ``metadata`` and compute their comparison features.

    Returns ``(portraits, distributions)`` for the portrait metric, where
    ``distributions`` is only computed when SPEED_OPTIM is on (otherwise None).
    Returns ``(embeddings, None)`` for embedding metrics.
    """
    if use_portrait:
        portraits = compute_portraits_chunked(
            prob_loader, metadata, chunk_size, config.n_jobs, verbose
        )
        distributions = None
        if SPEED_OPTIM:
            if verbose:
                print("    Preprocessing portraits for fast divergence...")
            distributions, _ = preprocess_portraits_for_divergence(
                portraits, verbose=False
            )
        return portraits, distributions

    embeddings = compute_embeddings_chunked(
        prob_loader, metadata, config.metric,
        config.orca_path, chunk_size, config.n_jobs, verbose
    )
    if verbose:
        print(f"    Embedding shape: {embeddings.shape}")
    return embeddings, None


def _select_diverse_from_pool(
    features,
    distributions,
    k: int,
    config: SelectionConfig,
    use_portrait: bool,
    verbose: bool,
    greedy_message: str | None = None,
) -> list[int]:
    """Choose start indices per ``config.start_mode`` and run the matching greedy selector.

    ``greedy_message`` (when given and verbose) is printed after the start
    indices are chosen, right before greedy selection begins.
    """
    if use_portrait:
        if config.start_mode == "dissimilar":
            start_indices = list(find_most_dissimilar_pair_portraits(
                features, config.n_jobs, verbose,
                precomputed_distributions=distributions,
            ))
        else:
            start_indices = [0]
        if verbose and greedy_message:
            print(greedy_message)
        return greedy_maxsum_portraits(
            features,
            k,
            config.objective,
            start_indices=start_indices,
            n_jobs=config.n_jobs,
            verbose=verbose,
            precomputed_distributions=distributions,
        )

    if config.start_mode == "dissimilar":
        start_indices = list(find_most_dissimilar_pair_embeddings(features, verbose))
    else:
        start_indices = [0]
    if verbose and greedy_message:
        print(greedy_message)
    return greedy_maxsum_embeddings(
        features,
        k,
        config.objective,
        start_indices=start_indices,
        verbose=verbose,
    )


def _build_phase2_metadata(
    top_matrix_ids: list[int],
    samples_per_matrix: int,
    config: SelectionConfig,
    rng: np.random.Generator,
) -> list[SampleMetadata]:
    """Create ``samples_per_matrix`` sampling tasks for each selected matrix.

    For every task, ``rng`` is consumed in a fixed order: the sample seed
    first, then the temperature draw(s). Results are therefore reproducible for
    a given ``config.random_seed``.
    """
    use_continuous = config.use_continuous_temp
    if use_continuous:
        temp_min, temp_max = config.temp_range
        log_temp_min, log_temp_max = np.log(temp_min), np.log(temp_max)
        rigid_prob = config.rigid_prob if config.include_rigid else 0.0
    else:
        # Built once rather than per sample; list() also accepts tuples/arrays.
        temp_options = list(config.temperatures)
        if config.include_rigid:
            temp_options.append(np.inf)

    metadata = []
    for matrix_idx in top_matrix_ids:
        for _ in range(samples_per_matrix):
            seed = int(rng.integers(0, 2**31))
            if use_continuous:
                if rng.random() < rigid_prob:
                    temp = np.inf
                else:
                    # Log-uniform temperature within temp_range.
                    temp = np.exp(rng.uniform(log_temp_min, log_temp_max))
            else:
                temp = temp_options[rng.integers(0, len(temp_options))]
            metadata.append(SampleMetadata(
                matrix_idx=matrix_idx,
                temperature=temp,
                random_seed=seed,
            ))
    return metadata


def _temperature_bucket(t: float, use_continuous: bool):
    """Map a sampled temperature to its reporting key (``np.inf`` for rigid samples)."""
    if t == np.inf:
        return np.inf
    if not use_continuous:
        return t
    if t < 0.1:
        return "<0.1"
    if t < 0.5:
        return "0.1-0.5"
    if t < 1.0:
        return "0.5-1.0"
    if t < 2.0:
        return "1.0-2.0"
    return "2.0-5.0"


def run_selection_large(
    prob_matrices: list[np.ndarray] | str,
    config: SelectionConfig,
) -> SelectionResult:
    """Two-phase diverse graph selection that bounds memory use for large graphs.

    Phase 1 samples one graph (temperature 1.0) from each of the first
    ``min(n_matrices, sampling_budget // 10)`` probability matrices and
    greedily keeps ``n_top_matrices`` diverse ones. Phase 2 spends the rest of
    the budget sampling those matrices at varied temperatures and greedily
    selects ``config.k_select`` graphs. Sampled graphs are discarded once their
    features are computed. Only their sampling metadata (matrix index,
    temperature, seed) is kept, and the selected graphs are regenerated from it
    at the end.

    Args:
        prob_matrices: In-memory list of probability matrices, or the path of a
            chunk-metadata pickle (see ``load_prob_matrices``).
        config: Selection settings. ``chunk_size`` is optional (defaults to 100).

    Returns:
        ``SelectionResult`` for the final selection; ``selected_indices`` index
        into the phase-2 pool.

    Raises:
        ValueError: if the phase-2 pool implied by the budget is smaller than
            ``config.k_select``. This is checked before any sampling.
    """
    verbose = config.verbose
    rng = np.random.default_rng(config.random_seed)
    use_portrait = config.metric == "portrait_div"
    prob_loader = load_prob_matrices(prob_matrices)
    n_matrices = len(prob_loader)
    n_vertices = prob_loader[0].shape[0]
    k_final = config.k_select
    chunk_size = getattr(config, "chunk_size", 100)

    total_budget = config.sampling_budget
    phase1_budget = min(n_matrices, total_budget // 10)
    phase2_budget = total_budget - phase1_budget
    n_top_matrices = min(config.n_top_matrices, n_matrices, phase1_budget)
    samples_per_matrix = phase2_budget // n_top_matrices if n_top_matrices > 0 else 0

    # Fail fast: the final greedy step would raise anyway, but only after the
    # expensive phase 1 (or with a cryptic error when the pool is empty).
    pool_size = n_top_matrices * samples_per_matrix
    if pool_size < k_final:
        raise ValueError(
            f"Cannot select {k_final} graphs from {pool_size} candidates: "
            f"sampling_budget={total_budget} yields {n_top_matrices} matrices "
            f"x {samples_per_matrix} samples each"
        )

    if verbose:
        print(f"\n{'=' * 60}")
        print("MEMORY-EFFICIENT SELECTION FOR LARGE GRAPHS")
        print(f"{'=' * 60}")
        print(f"  Metric: {config.metric}")
        print(f"  Total budget: {total_budget}")
        print(f"  Phase 1 (matrix selection): {phase1_budget} samples")
        print(f"  Phase 2 (focused sampling): {phase2_budget} samples")
        print(f"  Top matrices to select: {n_top_matrices}")
        print(f"  Samples per selected matrix: {samples_per_matrix}")
        print(f"  Chunk size: {chunk_size}")
        print(f"  Total probability matrices: {n_matrices}")
        print(f"  Graph size: {n_vertices} vertices")
        print(f"  Using chunked loader: {isinstance(prob_loader, ChunkedProbLoader)}")
        print(f"  SPEED_OPTIM: {SPEED_OPTIM}")
        print()
        print(f"{'=' * 40}")
        print("PHASE 1: MATRIX DIVERSITY DISCOVERY")
        print(f"{'=' * 40}")

    # ---- Phase 1: one representative sample per candidate matrix ----
    matrices_to_sample = min(n_matrices, phase1_budget)
    phase1_metadata = [
        SampleMetadata(
            matrix_idx=idx,
            temperature=1.0,
            random_seed=int(rng.integers(0, 2**31)),
        )
        for idx in range(matrices_to_sample)
    ]

    if verbose:
        print(f"  Sampling {len(phase1_metadata)} representative graphs...")

    phase1_features, phase1_distributions = _compute_pool_features(
        prob_loader, phase1_metadata, config, chunk_size, use_portrait, verbose
    )
    selected_matrix_indices = _select_diverse_from_pool(
        phase1_features,
        phase1_distributions,
        n_top_matrices,
        config,
        use_portrait,
        verbose,
        greedy_message=f"  Running greedy selection for {n_top_matrices} diverse matrices...",
    )
    del phase1_features, phase1_distributions
    gc.collect()

    top_matrix_ids = [phase1_metadata[i].matrix_idx for i in selected_matrix_indices]
    if verbose:
        print(f"  Selected {len(top_matrix_ids)} diverse matrices")
        if top_matrix_ids:
            selected_means = [np.mean(prob_loader[i]) for i in top_matrix_ids]
            print(f"  Selected matrix mean probabilities: min={min(selected_means):.3f}, "
                  f"max={max(selected_means):.3f}, mean={np.mean(selected_means):.3f}")

    # ---- Phase 2: focused sampling of the selected matrices ----
    if verbose:
        print(f"\n{'=' * 40}")
        print("PHASE 2: FOCUSED SAMPLING")
        print(f"{'=' * 40}")

    phase2_metadata = _build_phase2_metadata(top_matrix_ids, samples_per_matrix, config, rng)

    if verbose:
        print(f"  Created {len(phase2_metadata)} sampling tasks...")
        feature_kind = "portraits" if use_portrait else "embeddings"
        print(f"  Computing {feature_kind} for phase 2 samples...")

    phase2_features, phase2_distributions = _compute_pool_features(
        prob_loader, phase2_metadata, config, chunk_size, use_portrait, verbose
    )

    if verbose:
        print(f"\n{'=' * 40}")
        print("FINAL SELECTION")
        print(f"{'=' * 40}")
        print(f"  Running greedy selection for final {k_final} graphs...")

    final_indices = _select_diverse_from_pool(
        phase2_features, phase2_distributions, k_final, config, use_portrait, verbose
    )
    del phase2_features, phase2_distributions
    gc.collect()

    # ---- Regenerate only the selected graphs from their metadata ----
    if verbose:
        print(f"\n{'=' * 40}")
        print("RECONSTRUCTING SELECTED GRAPHS")
        print(f"{'=' * 40}")
        print(f"  Regenerating {len(final_indices)} selected graphs...")

    selected_metadata = [phase2_metadata[i] for i in final_indices]
    final_graphs = []
    final_temps = []
    final_sources = []
    for meta in tqdm(selected_metadata, desc="    Reconstructing", disable=not verbose):
        adj = meta.sample(prob_loader[meta.matrix_idx])
        final_graphs.append(adj_to_nx(adj))
        final_temps.append(meta.temperature)
        final_sources.append(meta.matrix_idx)

    if verbose:
        print(f"\n{'=' * 40}")
        print("EVALUATION")
        print(f"{'=' * 40}")

    final_diversity, final_energy = evaluate_selection(
        final_graphs, config.metric, config.orca_path, config.n_jobs
    )

    use_continuous = config.use_continuous_temp
    final_temp_contributions = defaultdict(int)
    for t in final_temps:
        final_temp_contributions[_temperature_bucket(t, use_continuous)] += 1

    unique_sources = len(set(final_sources))
    edges = [g.number_of_edges() for g in final_graphs]

    result = SelectionResult(
        graphs=final_graphs,
        diversity=final_diversity,
        energy=final_energy,
        temp_contributions=dict(final_temp_contributions),
        edge_mean=float(np.mean(edges)),
        edge_std=float(np.std(edges)),
        edge_min=int(min(edges)),
        edge_max=int(max(edges)),
        pool_size=len(phase2_metadata),
        selected_indices=final_indices,
    )

    if verbose:
        print(f"\nFinal diversity: {final_diversity:.6f}")
        print(f"Final energy: {final_energy:.6f}")
        print(f"Unique source matrices in final selection: {unique_sources}/{k_final}")
        print("\nTemperature contributions:")
        if use_continuous:
            bucket_order = ["<0.1", "0.1-0.5", "0.5-1.0", "1.0-2.0", "2.0-5.0", np.inf]
            for bucket in bucket_order:
                if bucket in final_temp_contributions:
                    label = "rigid" if bucket == np.inf else bucket
                    print(f"  {label}: {final_temp_contributions[bucket]}")
        else:
            # Finite temperatures ascending, rigid (inf) last.
            for t in sorted(
                final_temp_contributions.keys(),
                key=lambda x: (x == np.inf, x if x != np.inf else 0),
            ):
                label = "rigid" if t == np.inf else f"{t:.2f}"
                print(f"  {label}: {final_temp_contributions[t]}")
        print("\nEdge statistics:")
        print(f"  Mean: {result.edge_mean:.1f}, Std: {result.edge_std:.1f}")
        print(f"  Range: [{result.edge_min}, {result.edge_max}]")

    return result


if __name__ == "__main__":
    # This file is a library: point users at the intended entry-point script.
    for _line in (
        "This module is meant to be imported, not run directly.",
        "Use run_large_graphs.py for large graph experiments.",
    ):
        print(_line)
