"""
Baseline Graph Generators for Diversity Comparison.

This script generates two baseline sets:
1. ER-mix: Erdős-Rényi graphs with varying edge probabilities
2. Generator-mix: Mixture of different random graph models

Usage:
    # Single run
    python run_baselines.py --n_vertices 16
    python run_baselines.py --n_vertices 64 --k 100
    
    # Run all node counts (N=16,32,64,128,256,512,1024,2048 with k=100)
    python run_baselines.py --all-nodes
    
    # Run all k values (k=1024,2048,4096,8192 with N=16)
    python run_baselines.py --all-ks
    
    # Use memory-efficient evaluation (recommended for large N or k; requires a CUDA-capable GPU)
    python run_baselines.py --n_vertices 1024 --large-scale-optim
    python run_baselines.py --all-nodes --large-scale-optim
    python run_baselines.py --all-ks --large-scale-optim
"""

import numpy as np
import networkx as nx
import pickle
import json
import argparse
import os
import gc
from pathlib import Path
from dataclasses import dataclass
from joblib import Parallel, delayed
from tqdm import tqdm

from evaluate.distances import pairwise_distance_matrix, _netlsd_signature, _gcm_orca_like_theirs
from evaluate.diversity import average_diversity, summarize

from evaluate.distances_optim import (
    compute_portrait_single,
    compute_portraits_batch,
    preprocess_portraits_for_divergence,
    portrait_divergence_fast,
)


# Directory that receives the baseline graph pickles and metadata JSON files
OUTPUT_DIR = Path("data/baselines")

# ORCA path for GCD computation; the GCD helpers forward only the directory
# containing this path as `orca_prefix`
ORCA_PATH = "evaluate/orca/orca"

# Number of parallel jobs for non-chunked operations (passed to joblib as
# `n_jobs`; -1 is joblib's "use all CPUs" setting)
N_JOBS = -1

# Default random seed (used when --seed is not given)
SEED = 42

# Default number of graphs to generate per baseline set
K_DEFAULT = 100

# Node counts to run (for --all-nodes mode)
ALL_NODE_COUNTS = [16, 32, 64, 128, 256, 512, 1024, 2048]

# Values of k to run (for --all-ks mode)
ALL_K_VALUES = [1024, 2048, 4096, 8192]
# Vertex count held fixed for every run of the --all-ks sweep
K_SCALING_N_VERTICES = 16


# Metrics evaluated for every baseline set. The strings are the keys of the
# results dictionaries and the values the evaluation code dispatches on.
METRICS_TO_COMPUTE = ["gcd", "netlsd_heat", "netlsd_wave", "portrait_div"]

import multiprocessing
# CPU count reported by multiprocessing; sizes the chunks and worker pools below
_CPU_COUNT = multiprocessing.cpu_count()


# Number of graphs buffered in large-scale mode before a GCD chunk is
# processed: twice the CPU count, capped at 32
GCD_CHUNK_SIZE = min(2 * _CPU_COUNT, 32)

# Default worker count for the parallel GCD embedding computation
GCD_PARALLEL_WORKERS = _CPU_COUNT

# Number of graphs buffered in large-scale mode before a portrait chunk is
# processed: four times the CPU count, capped at 64
PORTRAIT_CHUNK_SIZE = min(4 * _CPU_COUNT, 64)



# Edge probabilities of the ER-mix baseline; the same values are the
# Erdős-Rényi configurations of the generator mix
ER_PROBABILITIES = [1/16, 1/8, 1/4, 1/2, 3/4, 7/8, 15/16]



# Preferential attachment: for each m, the offset alpha takes m/2, m and 2m.
PA_CONFIGS = [
    {"m": m_val, "alpha": alpha_val}
    for m_val in (1, 2, 4)
    for alpha_val in (m_val / 2, m_val, 2 * m_val)
]

# Holme-Kim: every combination of m and triad-formation probability p.
HK_CONFIGS = [
    {"m": m_val, "p": p_val}
    for m_val in (2, 4)
    for p_val in (0.5, 1.0)
]

# Power-law weight model: one configuration per exponent gamma.
PL_CONFIGS = [{"gamma": exponent} for exponent in (2, 2.5, 3, 4)]

# Stochastic block model: p is the within-block and q the between-block edge
# probability. For each base value s there is one two-block configuration
# with p = 2q and one with q = 2p, followed by two fixed three-block ones.
S_values = [1/16, 1/8, 1/4, 1/2]
SBM_CONFIGS = [
    cfg
    for base in S_values
    for cfg in (
        {"r": 2, "p": 2 * base, "q": base},
        {"r": 2, "p": base, "q": 2 * base},
    )
] + [
    {"r": 3, "p": 1/2, "q": 1/4},
    {"r": 3, "p": 1/5, "q": 2/5},
]



def compute_embedding_single_cpu(G: nx.Graph, metric: str, orca_path: str) -> np.ndarray:
    """Compute the per-graph representation used by ``metric`` on the CPU.

    Args:
        G: Graph to embed.
        metric: One of "netlsd_heat", "netlsd_wave", "gcd" or "portrait_div".
        orca_path: ORCA path; only its parent directory is used, and only
            for the "gcd" metric.

    Returns:
        The NetLSD signature, the strict upper triangle of the GCD matrix,
        or the network portrait, depending on ``metric``.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric == "netlsd_heat" or metric == "netlsd_wave":
        kernel = "heat" if metric == "netlsd_heat" else "wave"
        return _netlsd_signature(G, kernel)

    if metric == "portrait_div":
        return compute_portrait_single(G)

    if metric != "gcd":
        raise ValueError(f"Unknown metric: {metric}")

    # GCD: keep only the entries strictly above the diagonal.
    orca_dir = os.path.dirname(os.path.abspath(orca_path))
    correlation = _gcm_orca_like_theirs(G, orca_prefix=orca_dir, graphlet_size=4)
    upper = np.triu_indices_from(correlation, k=1)
    return correlation[upper]


def _check_gpu_available():
    import torch
    if not torch.cuda.is_available():
        raise RuntimeError(
            "GPU required for --large-scale-optim but CUDA is not available. "
            "Please run on a machine with a CUDA-capable GPU or remove the --large-scale-optim flag."
        )


def _get_torch_device():
    import torch
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def compute_gcd_embedding_single(G: nx.Graph, orca_path: str) -> np.ndarray:
    """Return the GCD embedding of one graph.

    The matrix returned by ``_gcm_orca_like_theirs`` (4-node graphlets) is
    flattened to the entries strictly above its diagonal.

    Args:
        G: Graph to embed.
        orca_path: ORCA path; its parent directory is passed as ``orca_prefix``.
    """
    orca_dir = os.path.dirname(os.path.abspath(orca_path))
    correlation = _gcm_orca_like_theirs(G, orca_prefix=orca_dir, graphlet_size=4)
    rows, cols = np.triu_indices_from(correlation, k=1)
    return correlation[rows, cols]


def compute_gcd_embeddings_parallel(
    graphs: list[nx.Graph],
    orca_path: str,
    n_workers: int = GCD_PARALLEL_WORKERS,
) -> list[np.ndarray]:
    """Compute GCD embeddings for several graphs with a joblib worker pool.

    Args:
        graphs: Graphs to embed.
        orca_path: ORCA path, forwarded unchanged to
            ``compute_gcd_embedding_single`` (which derives the prefix itself).
        n_workers: Value passed to joblib as ``n_jobs``.

    Returns:
        One embedding per input graph, in input order.
    """
    tasks = (
        delayed(compute_gcd_embedding_single)(graph, orca_path)
        for graph in graphs
    )
    return Parallel(n_jobs=n_workers)(tasks)


def compute_portraits_parallel(
    graphs: list[nx.Graph],
    n_workers: int = _CPU_COUNT,
) -> list[np.ndarray]:
    """Compute one network portrait per graph using a thread-preferring joblib pool.

    Args:
        graphs: Graphs to process.
        n_workers: Value passed to joblib as ``n_jobs``.

    Returns:
        The portraits, in the same order as ``graphs``.
    """
    pool = Parallel(n_jobs=n_workers, prefer="threads")
    tasks = (delayed(compute_portrait_single)(graph) for graph in graphs)
    return pool(tasks)


def compute_embedding_single_gpu(G: nx.Graph, metric: str, orca_path: str, device) -> np.ndarray:
    """Compute the per-graph representation used by ``metric``, using the
    torch-based NetLSD implementation for the NetLSD metrics.

    The GPU helper module is imported only for the NetLSD metrics, so the
    "gcd" and "portrait_div" branches (which run on the CPU) and the
    unknown-metric error do not depend on it being importable.

    Args:
        G: Graph to embed.
        metric: One of "netlsd_heat", "netlsd_wave", "gcd" or "portrait_div".
        orca_path: ORCA path; only its parent directory is used, and only
            for the "gcd" metric.
        device: Torch device handed to ``nx_to_adjacency_tensor``; unused for
            the non-NetLSD metrics.

    Returns:
        The embedding (or portrait) as a NumPy array.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    if metric in ("netlsd_heat", "netlsd_wave"):
        # Deferred import: only these two branches need the GPU module.
        from evaluate.distances_gpu import netlsd_heat, netlsd_wave, nx_to_adjacency_tensor

        signature_fn = netlsd_heat if metric == "netlsd_heat" else netlsd_wave
        A = nx_to_adjacency_tensor(G, device)
        # Move the result back to host memory before converting to NumPy.
        return signature_fn(A, G).cpu().numpy()

    if metric == "gcd":
        orca_prefix = os.path.dirname(os.path.abspath(orca_path))
        gcm = _gcm_orca_like_theirs(G, orca_prefix=orca_prefix, graphlet_size=4)
        return gcm[np.triu_indices_from(gcm, k=1)]

    if metric == "portrait_div":
        return compute_portrait_single(G)

    raise ValueError(f"Unknown metric: {metric}")


def compute_distance_matrix_from_embeddings(embeddings: list[np.ndarray]) -> np.ndarray:
    """Build the symmetric matrix of pairwise Euclidean distances.

    Args:
        embeddings: Equal-shaped vectors, one per graph.

    Returns:
        A ``(k, k)`` float array with zeros on the diagonal, where entry
        ``[i, j]`` is ``np.linalg.norm(embeddings[i] - embeddings[j])``.
    """
    size = len(embeddings)
    upper = np.zeros((size, size), dtype=float)
    for row, anchor in enumerate(embeddings):
        for col in range(row + 1, size):
            upper[row, col] = np.linalg.norm(anchor - embeddings[col])
    # Only the strict upper triangle was filled; mirror it to the lower half.
    return upper + upper.T


def compute_distance_matrix_from_portraits(
    portraits: list[np.ndarray],
    verbose: bool = True,
) -> np.ndarray:
    """Build the symmetric matrix of pairwise portrait divergences.

    The portraits are preprocessed once with
    ``preprocess_portraits_for_divergence``; every unordered pair is then
    scored with ``portrait_divergence_fast``.

    Args:
        portraits: One portrait per graph.
        verbose: When true, print progress messages.

    Returns:
        A ``(k, k)`` float array with zeros on the diagonal.
    """
    count = len(portraits)

    if verbose:
        print(f"    Preprocessing {count} portraits for fast divergence...")

    # The second return value of the preprocessing step is not needed here.
    distributions, _ = preprocess_portraits_for_divergence(portraits, verbose=False)

    if verbose:
        print(f"    Computing {count*(count-1)//2} pairwise divergences...")

    D = np.zeros((count, count), dtype=float)
    for row in range(count):
        row_dist = distributions[row]
        for col in range(row + 1, count):
            divergence = portrait_divergence_fast(row_dist, distributions[col])
            D[row, col] = divergence
            D[col, row] = divergence

    return D


def generate_er(n: int, p: float, seed: int) -> nx.Graph:
    """Sample an Erdős-Rényi graph on ``n`` nodes with edge probability ``p``."""
    graph = nx.erdos_renyi_graph(n, p, seed=seed)
    return graph


def generate_preferential_attachment(n: int, m: int, alpha: float, seed: int) -> nx.Graph:
    """Grow a graph by preferential attachment with an additive offset.

    Starting from a complete graph, each new node attaches to ``m`` distinct
    existing nodes drawn with probability proportional to ``degree + alpha``.

    Args:
        n: Number of vertices of the result.
        m: Number of edges added per new node; the seed clique has
            ``m + 1`` nodes (fewer when ``n`` is smaller than that).
        alpha: Offset added to every degree before normalising.
        seed: Seed for the NumPy random generator.

    Returns:
        The generated graph, with plain-int node labels ``0 .. n-1``.
    """
    rng = np.random.default_rng(seed)

    # Cap the seed clique at n so the result never has more than n vertices.
    G = nx.complete_graph(min(m + 1, n))

    for new_node in range(m + 1, n):
        G.add_node(new_node)

        # Attachment weights of all older nodes (0 .. new_node - 1).
        degrees = np.array([G.degree(i) for i in range(new_node)])
        weights = degrees + alpha
        probs = weights / weights.sum()

        targets = rng.choice(new_node, size=min(m, new_node), replace=False, p=probs)

        # Cast to int so NumPy integer scalars never end up as adjacency keys.
        G.add_edges_from((new_node, int(target)) for target in targets)

    return G


def generate_holme_kim(n: int, m: int, p: float, seed: int) -> nx.Graph:
    """Grow a graph with preferential attachment plus triad formation.

    Each new node first attaches to one existing node chosen proportionally
    to degree. For each of the remaining ``m - 1`` links it closes a
    triangle with probability ``p`` by attaching to a random neighbour of
    the most recently chosen target; otherwise, or when no such neighbour
    is available, it falls back to a degree-proportional choice among the
    nodes not yet targeted.

    Args:
        n: Number of vertices of the result.
        m: Number of edges added per new node; the seed clique has
            ``m + 1`` nodes (fewer when ``n`` is smaller than that).
        p: Probability of attempting a triad-formation step.
        seed: Seed for the NumPy random generator.

    Returns:
        The generated graph, with plain-int node labels ``0 .. n-1``.
    """
    rng = np.random.default_rng(seed)

    # Cap the seed clique at n so the result never has more than n vertices.
    G = nx.complete_graph(min(m + 1, n))

    def pick_by_degree(candidates: list) -> int:
        """Draw one candidate with probability proportional to its degree."""
        degrees = np.array([G.degree(v) for v in candidates])
        total = degrees.sum()
        if total == 0:
            # No edges among the candidates yet: fall back to uniform.
            probs = np.ones(len(candidates)) / len(candidates)
        else:
            probs = degrees / total
        return candidates[int(rng.choice(len(candidates), p=probs))]

    for new_node in range(m + 1, n):
        G.add_node(new_node)

        # First link is always a preferential-attachment step.
        # `last_target` is tracked explicitly because a set has no insertion
        # order, so the latest target cannot be recovered from `targets`.
        last_target = pick_by_degree(list(range(new_node)))
        targets = {last_target}
        G.add_edge(new_node, last_target)

        for _ in range(m - 1):
            if len(targets) >= new_node:
                break

            if rng.random() < p:
                # Triad formation: neighbours of the latest target that are
                # not yet linked to the new node.
                neighbors = [
                    nb for nb in G.neighbors(last_target)
                    if nb not in targets and nb != new_node
                ]

                if neighbors:
                    last_target = int(rng.choice(neighbors))
                    targets.add(last_target)
                    G.add_edge(new_node, last_target)
                    continue

            # Preferential-attachment step over the nodes not targeted yet.
            available = [i for i in range(new_node) if i not in targets]
            if not available:
                break

            last_target = pick_by_degree(available)
            targets.add(last_target)
            G.add_edge(new_node, last_target)

    return G


def generate_power_law_cluster(n: int, gamma: float, seed: int) -> nx.Graph:
    """Sample a graph from heavy-tailed node weights.

    Node weights are drawn as ``pareto(gamma - 1) + 1``, rescaled, capped at
    the square root of their sum, and each pair ``(i, j)`` is then connected
    independently with probability ``min(1, w_i * w_j / sum(w))``.

    Args:
        n: Number of vertices.
        gamma: Exponent parameter; ``gamma - 1`` is the Pareto shape.
        seed: Seed for the NumPy random generator.

    Returns:
        The generated graph on nodes ``0 .. n-1``.
    """
    rng = np.random.default_rng(seed)

    node_weights = rng.pareto(gamma - 1, size=n) + 1

    # Rescale the weights relative to a target of n(n-1)/4 edges.
    desired_edges = n * (n - 1) / 4
    factor = np.sqrt(2 * desired_edges / (node_weights.sum() ** 2 / node_weights.mean()))
    node_weights = node_weights * factor

    # Cap every weight at sqrt(total weight).
    cap = np.sqrt(node_weights.sum())
    node_weights = np.minimum(node_weights, cap)
    total_weight = node_weights.sum()

    G = nx.Graph()
    G.add_nodes_from(range(n))

    # One uniform draw per pair, in row-major (i < j) order.
    G.add_edges_from(
        (i, j)
        for i in range(n)
        for j in range(i + 1, n)
        if rng.random() < min(1.0, node_weights[i] * node_weights[j] / total_weight)
    )

    return G


def generate_sbm(n: int, r: int, p: float, q: float, seed: int) -> nx.Graph:
    """Sample a stochastic block model graph with ``r`` near-equal blocks.

    Args:
        n: Total number of vertices; the first ``n % r`` blocks get one extra.
        r: Number of blocks.
        p: Edge probability within a block.
        q: Edge probability between different blocks.
        seed: Seed for the NumPy generator from which the networkx seed is drawn.

    Returns:
        The generated graph with all per-node attributes removed.
    """
    rng = np.random.default_rng(seed)

    base_size, leftover = divmod(n, r)
    sizes = [base_size + 1 if block < leftover else base_size for block in range(r)]

    # q everywhere, p on the diagonal.
    probs = np.full((r, r), q)
    np.fill_diagonal(probs, p)

    sbm_seed = int(rng.integers(0, 2**31))
    G = nx.stochastic_block_model(sizes, probs, seed=sbm_seed)

    # Drop the node attributes that networkx attaches to SBM nodes.
    for _, attrs in G.nodes(data=True):
        attrs.clear()

    return G

def generate_er_mix(n: int, k: int, seed: int = SEED) -> tuple[list[nx.Graph], list[dict]]:
    """Generate the ER-mix baseline as in-memory lists.

    The ``k`` graphs are split as evenly as possible over
    ``ER_PROBABILITIES`` (earlier probabilities receive the remainder), each
    graph gets its own seed drawn from the master generator, and the result
    is shuffled.

    Args:
        n: Number of vertices per graph.
        k: Number of graphs.
        seed: Seed of the master random generator.

    Returns:
        ``(graphs, metadata)`` in shuffled order; each metadata entry's
        "index" is the graph's position before shuffling.
    """
    rng = np.random.default_rng(seed)

    base_count, extra = divmod(k, len(ER_PROBABILITIES))

    graphs = []
    metadata = []
    for slot, p in enumerate(ER_PROBABILITIES):
        quota = base_count + 1 if slot < extra else base_count
        for _ in range(quota):
            g_seed = int(rng.integers(0, 2**31))
            metadata.append({
                "generator": "erdos_renyi",
                "params": {"p": p},
                "seed": g_seed,
                "index": len(graphs),
            })
            graphs.append(generate_er(n, p, g_seed))

    order = rng.permutation(len(graphs))
    shuffled_graphs = [graphs[pos] for pos in order]
    shuffled_meta = [metadata[pos] for pos in order]

    return shuffled_graphs, shuffled_meta


def generate_er_mix_iter(n: int, k: int, seed: int = SEED):
    """Lazily yield the ER-mix baseline one graph at a time.

    All (probability, seed, index) assignments are drawn and shuffled up
    front; the graphs themselves are only built as the iterator is consumed.

    Args:
        n: Number of vertices per graph.
        k: Number of graphs.
        seed: Seed of the master random generator.

    Yields:
        ``(graph, metadata)`` pairs in shuffled order.
    """
    rng = np.random.default_rng(seed)

    base_count, extra = divmod(k, len(ER_PROBABILITIES))

    plan = []
    for slot, p in enumerate(ER_PROBABILITIES):
        quota = base_count + 1 if slot < extra else base_count
        for _ in range(quota):
            g_seed = int(rng.integers(0, 2**31))
            plan.append((p, g_seed, len(plan)))

    order = rng.permutation(len(plan))

    for pos in order:
        p, g_seed, idx = plan[pos]
        graph = generate_er(n, p, g_seed)
        yield graph, {
            "generator": "erdos_renyi",
            "params": {"p": p},
            "seed": g_seed,
            "index": idx,
        }


def generate_generator_mix(n: int, k: int, seed: int = SEED) -> tuple[list[nx.Graph], list[dict]]:
    """Generate the generator-mix baseline as in-memory lists.

    For each of the ``k`` graphs a configuration is drawn uniformly from the
    pooled ER, preferential-attachment, Holme-Kim, power-law and SBM
    configurations, together with a fresh per-graph seed.

    Args:
        n: Number of vertices per graph.
        k: Number of graphs.
        seed: Seed of the master random generator.

    Returns:
        ``(graphs, metadata)`` in generation order. Metadata entries drawn
        from the same configuration share one "params" dict object.

    Raises:
        ValueError: If a configuration names an unknown generator.
    """
    rng = np.random.default_rng(seed)

    # Pool of (generator name, parameter dict) pairs, family by family.
    all_configs = (
        [("erdos_renyi", {"p": p}) for p in ER_PROBABILITIES]
        + [("preferential_attachment", cfg.copy()) for cfg in PA_CONFIGS]
        + [("holme_kim", cfg.copy()) for cfg in HK_CONFIGS]
        + [("power_law", cfg.copy()) for cfg in PL_CONFIGS]
        + [("sbm", cfg.copy()) for cfg in SBM_CONFIGS]
    )

    print(f"  Total generator configurations: {len(all_configs)}")

    # Generator name -> callable(params, seed) that builds one graph.
    builders = {
        "erdos_renyi": lambda cfg, s: generate_er(n, cfg["p"], s),
        "preferential_attachment": lambda cfg, s: generate_preferential_attachment(
            n, cfg["m"], cfg["alpha"], s
        ),
        "holme_kim": lambda cfg, s: generate_holme_kim(n, cfg["m"], cfg["p"], s),
        "power_law": lambda cfg, s: generate_power_law_cluster(n, cfg["gamma"], s),
        "sbm": lambda cfg, s: generate_sbm(n, cfg["r"], cfg["p"], cfg["q"], s),
    }

    graphs = []
    metadata = []

    for i in range(k):
        # Draw order matters for reproducibility: configuration first, then seed.
        config_idx = rng.integers(0, len(all_configs))
        gen_name, params = all_configs[config_idx]
        g_seed = int(rng.integers(0, 2**31))

        build = builders.get(gen_name)
        if build is None:
            raise ValueError(f"Unknown generator: {gen_name}")

        graphs.append(build(params, g_seed))
        metadata.append({
            "generator": gen_name,
            "params": params,
            "seed": g_seed,
            "index": i,
        })

    return graphs, metadata


def generate_generator_mix_iter(n: int, k: int, seed: int = SEED):
    """Lazily yield the generator-mix baseline one graph at a time.

    For each of the ``k`` graphs a configuration is drawn uniformly from the
    pooled ER, preferential-attachment, Holme-Kim, power-law and SBM
    configurations, together with a fresh per-graph seed.

    Args:
        n: Number of vertices per graph.
        k: Number of graphs.
        seed: Seed of the master random generator.

    Yields:
        ``(graph, metadata)`` pairs in generation order.

    Raises:
        ValueError: If a configuration names an unknown generator.
    """
    rng = np.random.default_rng(seed)

    families = (
        ("erdos_renyi", [{"p": p} for p in ER_PROBABILITIES]),
        ("preferential_attachment", PA_CONFIGS),
        ("holme_kim", HK_CONFIGS),
        ("power_law", PL_CONFIGS),
        ("sbm", SBM_CONFIGS),
    )
    # Copy each dict so the module-level configuration tables are never shared.
    all_configs = [(name, dict(cfg)) for name, cfgs in families for cfg in cfgs]

    def build(gen_name, params, g_seed):
        """Construct one graph for the given generator family."""
        if gen_name == "erdos_renyi":
            return generate_er(n, params["p"], g_seed)
        if gen_name == "preferential_attachment":
            return generate_preferential_attachment(n, params["m"], params["alpha"], g_seed)
        if gen_name == "holme_kim":
            return generate_holme_kim(n, params["m"], params["p"], g_seed)
        if gen_name == "power_law":
            return generate_power_law_cluster(n, params["gamma"], g_seed)
        if gen_name == "sbm":
            return generate_sbm(n, params["r"], params["p"], params["q"], g_seed)
        raise ValueError(f"Unknown generator: {gen_name}")

    for i in range(k):
        # Draw order matters for reproducibility: configuration first, then seed.
        config_idx = rng.integers(0, len(all_configs))
        gen_name, params = all_configs[config_idx]
        g_seed = int(rng.integers(0, 2**31))

        graph = build(gen_name, params, g_seed)
        yield graph, {
            "generator": gen_name,
            "params": params,
            "seed": g_seed,
            "index": i,
        }

def evaluate_graphs(
    graphs: list[nx.Graph],
    orca_path: str = ORCA_PATH,
    n_jobs: int = N_JOBS,
) -> dict[str, dict[str, float]]:
    """Score an in-memory graph set on every metric in ``METRICS_TO_COMPUTE``.

    For each metric a full pairwise distance matrix is built and reduced
    with ``average_diversity`` and ``summarize``.

    Args:
        graphs: Graphs to evaluate.
        orca_path: ORCA path forwarded to ``pairwise_distance_matrix``.
        n_jobs: joblib ``n_jobs`` for the shared worker pool and the
            portrait computation.

    Returns:
        Mapping ``metric -> {"diversity": ..., "energy": ...}``.
    """
    workers = Parallel(n_jobs=n_jobs)
    results = {}

    for metric in METRICS_TO_COMPUTE:
        print(f"  Computing {metric}...")

        if metric != "portrait_div":
            D = pairwise_distance_matrix(
                graphs,
                distance=metric,
                workers=workers,
                orca_path=orca_path,
            )
        else:
            # Portrait divergence goes through the optimized portrait pipeline.
            portraits = compute_portraits_batch(graphs, n_jobs=n_jobs, verbose=False)
            D = compute_distance_matrix_from_portraits(portraits, verbose=True)

        diversity = average_diversity(D)
        energy = summarize(D).energy

        results[metric] = {"diversity": diversity, "energy": energy}

        print(f"    diversity={diversity:.6f}, energy={energy:.6f}")

    return results


def evaluate_graphs_large_scale(
    generator_func,
    n_vertices: int,
    k: int,
    seed: int,
    orca_path: str = ORCA_PATH,
    use_gpu: bool = True,
) -> dict[str, dict[str, float]]:
    """Evaluate a baseline while streaming graphs from a generator.

    Graphs are consumed one at a time from
    ``generator_func(n_vertices, k, seed)``, which must yield
    ``(graph, metadata)`` pairs. NetLSD embeddings are computed immediately
    for each graph; graphs are additionally buffered in chunks of
    ``GCD_CHUNK_SIZE`` / ``PORTRAIT_CHUNK_SIZE`` for the parallel GCD and
    portrait computations. Only embeddings and portraits are kept once a
    chunk has been processed. After all graphs are consumed, one pairwise
    distance matrix per metric is built and summarized.

    Args:
        generator_func: Callable returning an iterable of (graph, metadata).
        n_vertices: Number of vertices per graph, forwarded to the generator.
        k: Number of graphs, forwarded to the generator and used as the
            progress-bar total.
        seed: Seed forwarded to the generator.
        orca_path: ORCA path forwarded to the GCD embedding computation.
        use_gpu: When true, require CUDA (raises ``RuntimeError`` via
            ``_check_gpu_available`` otherwise); when false, use the torch
            CPU device.

    Returns:
        Mapping ``metric -> {"diversity": ..., "energy": ...}``.
    """

    import torch
    
    if use_gpu:
        _check_gpu_available()
        device = _get_torch_device()
        print(f"  Using device: {device}")
    else:
        device = torch.device("cpu")
        print(f"  Using device: CPU")
    
    # Split the metric list by how each metric is computed: GCD and portraits
    # are chunked, everything else is treated as a per-graph NetLSD embedding.
    metrics = METRICS_TO_COMPUTE
    has_gcd = "gcd" in metrics
    has_portrait = "portrait_div" in metrics
    netlsd_metrics = [m for m in metrics if m not in ("gcd", "portrait_div")]
    
    print(f"  Generating {k} graphs and computing embeddings...")
    print(f"  Metrics: {metrics}")
    if has_gcd:
        print(f"  GCD chunk size: {GCD_CHUNK_SIZE}, parallel workers: {GCD_PARALLEL_WORKERS}")
    if has_portrait:
        print(f"  Portrait chunk size: {PORTRAIT_CHUNK_SIZE}")
    
    # Accumulators for the finished per-graph results.
    embeddings = {metric: [] for metric in netlsd_metrics}
    if has_gcd:
        embeddings["gcd"] = []
    if has_portrait:
        portraits = []  
    
    # Buffers of graphs waiting for the next chunked computation.
    if has_gcd:
        gcd_chunk = []  
    
    if has_portrait:
        portrait_chunk = []
    
    # NOTE(review): this module is imported even when use_gpu is False; the
    # NetLSD embeddings always go through it, on whichever device was chosen.
    from evaluate.distances_gpu import netlsd_heat, netlsd_wave, nx_to_adjacency_tensor
    
    pbar = tqdm(generator_func(n_vertices, k, seed), total=k, desc="  Generating")
    for graph, metadata in pbar:
        # NetLSD signatures are computed right away and moved to host memory.
        for metric in netlsd_metrics:
            if metric == "netlsd_heat":
                A = nx_to_adjacency_tensor(graph, device)
                emb = netlsd_heat(A, graph).cpu().numpy()
            elif metric == "netlsd_wave":
                A = nx_to_adjacency_tensor(graph, device)
                emb = netlsd_wave(A, graph).cpu().numpy()
            embeddings[metric].append(emb)
        
        if has_gcd:
            gcd_chunk.append(graph)
            
            # Flush a full chunk through the parallel GCD workers.
            if len(gcd_chunk) >= GCD_CHUNK_SIZE:
                pbar.set_description(f"  Processing GCD chunk ({len(gcd_chunk)} graphs)")
                chunk_embeddings = compute_gcd_embeddings_parallel(
                    gcd_chunk, orca_path, n_workers=GCD_PARALLEL_WORKERS
                )
                embeddings["gcd"].extend(chunk_embeddings)
                
                # Drop the references to the processed graphs.
                gcd_chunk = []
                gc.collect()
                pbar.set_description("  Generating")
        
        if has_portrait:
            portrait_chunk.append(graph)
            
            # Flush a full chunk through the parallel portrait computation.
            if len(portrait_chunk) >= PORTRAIT_CHUNK_SIZE:
                pbar.set_description(f"  Computing portraits ({len(portrait_chunk)} graphs)")
                chunk_portraits = compute_portraits_parallel(
                    portrait_chunk, n_workers=_CPU_COUNT
                )
                portraits.extend(chunk_portraits)
                
                portrait_chunk = []
                gc.collect()
                pbar.set_description("  Generating")
        
        # The graph may only be released here when no chunk buffer holds it.
        if not has_gcd and not has_portrait:
            del graph
        del metadata
    
    # Process whatever is left in the partially filled chunks.
    if has_gcd and len(gcd_chunk) > 0:
        print(f"  Processing final GCD chunk ({len(gcd_chunk)} graphs)...")
        chunk_embeddings = compute_gcd_embeddings_parallel(
            gcd_chunk, orca_path, n_workers=GCD_PARALLEL_WORKERS
        )
        embeddings["gcd"].extend(chunk_embeddings)
        gcd_chunk = []
    
    if has_portrait and len(portrait_chunk) > 0:
        print(f"  Processing final portrait chunk ({len(portrait_chunk)} graphs)...")
        chunk_portraits = compute_portraits_parallel(
            portrait_chunk, n_workers=_CPU_COUNT
        )
        portraits.extend(chunk_portraits)
        portrait_chunk = []
    
    gc.collect()
    if use_gpu:
        torch.cuda.empty_cache()
    
    results = {}
    
    # Embedding-based metrics (NetLSD and GCD): pairwise Euclidean distances.
    # Iterate over a copy of the keys because entries are deleted in the loop.
    for metric in list(embeddings.keys()):
        print(f"  Computing {metric} distances from {len(embeddings[metric])} embeddings...")
        D = compute_distance_matrix_from_embeddings(embeddings[metric])
        
        diversity = average_diversity(D)
        scores = summarize(D)
        
        results[metric] = {
            "diversity": diversity,
            "energy": scores.energy,
        }
        
        print(f"    diversity={diversity:.6f}, energy={scores.energy:.6f}")
        
        # Release the matrix and embeddings before handling the next metric.
        del D
        del embeddings[metric]
        gc.collect()
    
    # Portraits use the portrait divergence instead of a Euclidean distance.
    if has_portrait:
        print(f"  Computing portrait_div distances from {len(portraits)} portraits...")
        D = compute_distance_matrix_from_portraits(portraits, verbose=True)
        
        diversity = average_diversity(D)
        scores = summarize(D)
        
        results["portrait_div"] = {
            "diversity": diversity,
            "energy": scores.energy,
        }
        
        print(f"    diversity={diversity:.6f}, energy={scores.energy:.6f}")
        
        del D
        del portraits
        gc.collect()
    
    return results


def run_baselines(n_vertices: int, k: int = K_DEFAULT, seed: int = SEED, large_scale_optim: bool = False):
    """Generate and evaluate the ER-mix and generator-mix baselines for one size.

    For each baseline the metric results are written to a ``*_meta.json``
    file in ``OUTPUT_DIR``. In the default mode the graphs are also pickled
    and per-graph metadata is included in the JSON; in large-scale mode the
    graphs are streamed through ``evaluate_graphs_large_scale`` (with
    ``use_gpu=True``) and neither graphs nor per-graph metadata are saved.

    Args:
        n_vertices: Number of vertices per graph.
        k: Number of graphs per baseline. Values other than ``K_DEFAULT``
            add a ``_k{k}`` suffix to the output file names.
        seed: Random seed; the same seed is used for both baselines.
        large_scale_optim: Use the streaming evaluation path.

    Returns:
        ``{"er_mix": {"results": ...}, "generator_mix": {"results": ...}}``.
    """
    
    print(f"\n{'=' * 70}")
    print(f"GENERATING BASELINES FOR N={n_vertices}")
    print(f"{'=' * 70}")
    print(f"  k={k} graphs per baseline")
    print(f"  seed={seed}")
    print(f"  large_scale_optim={large_scale_optim}")
    print(f"  metrics={METRICS_TO_COMPUTE}")
    
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    
    # File names carry k only when it differs from the default.
    k_suffix = f"_k{k}" if k != K_DEFAULT else ""
    
    # ---- ER-mix baseline ----
    print(f"\n{'=' * 40}")
    print("ER-MIX BASELINE")
    print(f"{'=' * 40}")
    print(f"  Probabilities: {[f'{p:.4f}' for p in ER_PROBABILITIES]}")
    
    if large_scale_optim:
        er_results = evaluate_graphs_large_scale(
            generate_er_mix_iter,
            n_vertices,
            k,
            seed,
            ORCA_PATH,
            use_gpu=True,
        )
        er_meta = []  # No metadata in large scale mode
    else:
        er_graphs, er_meta = generate_er_mix(n_vertices, k, seed)
        print(f"  Generated {len(er_graphs)} graphs")
        
        edges = [g.number_of_edges() for g in er_graphs]
        print(f"  Edge stats: mean={np.mean(edges):.1f}, std={np.std(edges):.1f}, "
              f"range=[{min(edges)}, {max(edges)}]")
        
        print(f"\n  Evaluating...")
        er_results = evaluate_graphs(er_graphs)
        
        er_graphs_path = OUTPUT_DIR / f"baseline_ER_mix_{n_vertices}{k_suffix}.pkl"
        with open(er_graphs_path, "wb") as f:
            pickle.dump(er_graphs, f)
        print(f"\n  Saved graphs to {er_graphs_path}")
        
        # Free the graph list before the second baseline is generated.
        del er_graphs
        gc.collect()
    
    er_meta_path = OUTPUT_DIR / f"baseline_ER_mix_{n_vertices}{k_suffix}_meta.json"
    er_full_meta = {
        "n_vertices": n_vertices,
        "k": k,
        "seed": seed,
        "probabilities": ER_PROBABILITIES,
        "results": er_results,
        "large_scale_optim": large_scale_optim,
        "metrics_computed": METRICS_TO_COMPUTE,
    }
    if not large_scale_optim:
        er_full_meta["graphs_metadata"] = er_meta
    with open(er_meta_path, "w") as f:
        json.dump(er_full_meta, f, indent=2)
    

    # ---- Generator-mix baseline ----
    print(f"\n{'=' * 40}")
    print("GENERATOR-MIX BASELINE")
    print(f"{'=' * 40}")
    
    if large_scale_optim:
        gen_results = evaluate_graphs_large_scale(
            generate_generator_mix_iter,
            n_vertices,
            k,
            seed,
            ORCA_PATH,
            use_gpu=True,
        )
        gen_meta = []
    else:
        gen_graphs, gen_meta = generate_generator_mix(n_vertices, k, seed)
        print(f"  Generated {len(gen_graphs)} graphs")
        
        # Count how many graphs each generator family contributed.
        gen_counts = {}
        for m in gen_meta:
            gen_name = m["generator"]
            gen_counts[gen_name] = gen_counts.get(gen_name, 0) + 1
        print(f"  Generator distribution: {gen_counts}")
        
        edges = [g.number_of_edges() for g in gen_graphs]
        print(f"  Edge stats: mean={np.mean(edges):.1f}, std={np.std(edges):.1f}, "
              f"range=[{min(edges)}, {max(edges)}]")
        
        print(f"\n  Evaluating...")
        gen_results = evaluate_graphs(gen_graphs)
        
        gen_graphs_path = OUTPUT_DIR / f"baseline_generator_mix_{n_vertices}{k_suffix}.pkl"
        with open(gen_graphs_path, "wb") as f:
            pickle.dump(gen_graphs, f)
        print(f"\n  Saved graphs to {gen_graphs_path}")
        
        del gen_graphs
        gc.collect()
    
    gen_meta_path = OUTPUT_DIR / f"baseline_generator_mix_{n_vertices}{k_suffix}_meta.json"
    gen_full_meta = {
        "n_vertices": n_vertices,
        "k": k,
        "seed": seed,
        "generators": {
            "erdos_renyi": {"configs": len(ER_PROBABILITIES)},
            "preferential_attachment": {"configs": len(PA_CONFIGS)},
            "holme_kim": {"configs": len(HK_CONFIGS)},
            "power_law": {"configs": len(PL_CONFIGS)},
            "sbm": {"configs": len(SBM_CONFIGS)},
        },
        "results": gen_results,
        "large_scale_optim": large_scale_optim,
        "metrics_computed": METRICS_TO_COMPUTE,
    }
    if not large_scale_optim:
        # NOTE(review): the rounding below only touches top-level float
        # values. The entries built by generate_generator_mix hold a string,
        # a params dict and two ints, so floats nested inside "params" are
        # written unrounded — confirm whether that is intended.
        gen_full_meta["graphs_metadata"] = [{k: (v if not isinstance(v, float) else round(v, 6)) 
                            for k, v in m.items()} for m in gen_meta]
    with open(gen_meta_path, "w") as f:
        json.dump(gen_full_meta, f, indent=2)
    

    # ---- Side-by-side diversity summary ----
    print(f"\n{'=' * 70}")
    print(f"SUMMARY FOR N={n_vertices}")
    print(f"{'=' * 70}")
    
    print(f"\n{'Metric':<15} {'ER-mix div':>12} {'Gen-mix div':>12}")
    print("-" * 42)
    for metric in METRICS_TO_COMPUTE:
        er_div = er_results[metric]["diversity"]
        gen_div = gen_results[metric]["diversity"]
        print(f"{metric:<15} {er_div:>12.6f} {gen_div:>12.6f}")
    
    return {
        "er_mix": {"results": er_results},
        "generator_mix": {"results": gen_results},
    }


def main():
    """Parse the command line and run the requested baseline sweep.

    ``--all-nodes`` takes precedence over ``--all-ks``, which takes
    precedence over a single ``--n_vertices`` run. With none of them given,
    the help text is printed.
    """
    parser = argparse.ArgumentParser(description="Generate baseline graph sets")
    parser.add_argument("--n_vertices", type=int, help="Number of vertices")
    parser.add_argument("--k", type=int, default=K_DEFAULT, help="Number of graphs to generate")
    parser.add_argument("--seed", type=int, default=SEED, help="Random seed")
    parser.add_argument(
        "--all-nodes",
        action="store_true",
        help=f"Run all node counts {ALL_NODE_COUNTS} with k={K_DEFAULT}",
    )
    parser.add_argument(
        "--all-ks",
        action="store_true",
        help=f"Run all k values {ALL_K_VALUES} with N={K_SCALING_N_VERTICES}",
    )
    parser.add_argument(
        "--large-scale-optim",
        action="store_true",
        help="Use memory-efficient evaluation with GPU acceleration (recommended for large N or k)",
    )
    args = parser.parse_args()

    optim = args.large_scale_optim

    if args.all_nodes:
        # Sweep over the vertex count at the default k; --k is ignored here.
        print(f"Running all node counts: {ALL_NODE_COUNTS}")
        print(f"Fixed k={K_DEFAULT}")
        print(f"Large-scale optimization: {optim}")
        for node_count in ALL_NODE_COUNTS:
            run_baselines(node_count, K_DEFAULT, args.seed, optim)
        return

    if args.all_ks:
        # Sweep over k at a fixed vertex count; --n_vertices is ignored here.
        print(f"Running all k values: {ALL_K_VALUES}")
        print(f"Fixed N={K_SCALING_N_VERTICES}")
        print(f"Large-scale optimization: {optim}")
        for graph_count in ALL_K_VALUES:
            run_baselines(K_SCALING_N_VERTICES, graph_count, args.seed, optim)
        return

    if args.n_vertices:
        run_baselines(args.n_vertices, args.k, args.seed, optim)
        return

    parser.print_help()


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
