import pickle
import numpy as np
import networkx as nx
from pathlib import Path
from tqdm import tqdm, trange
from joblib import Parallel, delayed
import os
import matplotlib.pyplot as plt
from typing import Literal
from collections import defaultdict
from dataclasses import dataclass, field
import heapq

# Module-level switch. Nothing in the visible part of this file reads it.
# NOTE(review): presumably meant to toggle the optimized distance code paths —
# confirm it is still consumed somewhere before relying on it.
SPEED_OPTIM = True

from evaluate.distances_optim import (
    pairwise_distance_matrix,
    DistanceName,
    _netlsd_signature,
    _gcm_orca_like_theirs,
    compute_gcd_embeddings_parallel,
    _netlsd_signature_batch,
    _get_device,
    set_device,
    compute_portraits_batch,
    preprocess_portraits_for_divergence,
    portrait_divergence_fast,
)

from evaluate.diversity import average_diversity, summarize
from netrd.distance.portrait_divergence import portrait_divergence, portrait

# Default temperature grid; SelectionConfig.temperatures copies this list.
# A temperature divides the edge logits before sampling (see sample_bernoulli).
DEFAULT_TEMPERATURES = [
    0.01,
    0.1,
    0.3,
    0.5,
    0.7,
    1.0,
    1.5,
    3.0,
    5.0,
]

# Default value of SelectionConfig.orca_path (forwarded to the GCD embedding
# and pairwise-distance helpers).
DEFAULT_ORCA_PATH = "evaluate/orca/orca"
# Small constant used to clip probabilities in safe_logit and to keep the
# energy objective's 1/distance term finite.
EPSILON = 1e-6


@dataclass
class SelectionConfig:
    """Settings consumed by run_selection / run_iterative_survival_selection."""

    # Number of graphs in the final selection.
    k_select: int = 100
    # Distance family; "portrait_div" switches to the incremental portrait
    # path, every other value goes through compute_embeddings.
    metric: DistanceName = "netlsd_wave"
    # Greedy objective: sum of distances ("average") or sum of -1/distance
    # ("energy").
    objective: Literal["energy", "average"] = "average"
    # "iterative_survival" dispatches to run_iterative_survival_selection.
    strategy: Literal["default", "iterative_survival"] = "default"
    # Budget handed to stratified_sample; None means one sample per matrix
    # for every strategy.
    total_budget: int | None = None
    # Discrete temperatures (each list instance is a private copy of the default).
    temperatures: list[float] = field(default_factory=lambda: DEFAULT_TEMPERATURES.copy())
    # Whether deterministic top-k ("rigid") samples are generated as well.
    include_rigid: bool = True
    # How the greedy selection is seeded; "dissimilar" starts from an
    # approximately farthest-apart pair.
    start_mode: Literal["anchors", "random", "dissimilar"] = "dissimilar"
    # Total number of samples the iterative-survival strategy may draw.
    sampling_budget: int = 100000
    # Upper bound on the number of matrices kept after phase 1 of the
    # iterative-survival strategy.
    n_top_matrices: int = 1000
    # If True, phase 2 draws temperatures log-uniformly from temp_range
    # instead of picking from `temperatures`.
    use_continuous_temp: bool = True
    # (min, max) temperature for the log-uniform draw.
    temp_range: tuple[float, float] = (0.01, 5.0)
    # Probability of taking a rigid sample in continuous-temperature mode
    # (treated as 0 when include_rigid is False).
    rigid_prob: float = 0.1
    # Forwarded to the GCD embedding / pairwise-distance helpers.
    orca_path: str = DEFAULT_ORCA_PATH
    # Forwarded to joblib.Parallel and the portrait helpers.
    n_jobs: int = -1
    # Seed for the numpy Generators created by the selection routines.
    random_seed: int = 42
    # Enables progress bars and console reporting.
    verbose: bool = True


@dataclass
class SelectionResult:
    """Outcome of a selection run: the chosen graphs plus summary statistics."""

    # The selected graphs, in selection order.
    graphs: list[nx.Graph]
    # Average-diversity score returned by evaluate_selection.
    diversity: float
    # Energy score returned by evaluate_selection.
    energy: float
    # Count of selected graphs per temperature. Keys are temperatures
    # (np.inf for rigid samples) or, in continuous-temperature mode, string
    # bucket labels such as "0.1-0.5".
    temp_contributions: dict[float | str, int]
    # Edge-count statistics over the selected graphs.
    edge_mean: float
    edge_std: float
    edge_min: int
    edge_max: int
    # Size of the candidate pool the final selection was drawn from.
    pool_size: int
    # Indices of the selected graphs within that candidate pool.
    selected_indices: list[int]


def to_numpy(x):
    """Coerce tensor-like objects and lists to numpy; pass anything else through.

    Objects exposing a ``cpu`` attribute are detached, moved to host memory and
    converted with ``.numpy()``. Lists are wrapped with ``np.array``. Every
    other input is returned unchanged (same object).
    """
    if hasattr(x, "cpu"):
        host_side = x.detach().cpu()
        return host_side.numpy()
    return np.array(x) if isinstance(x, list) else x


def safe_logit(p):
    """Return log(p / (1 - p)) with p first clipped to [EPSILON, 1 - EPSILON].

    The clipping keeps the result finite for probabilities of exactly 0 or 1.
    """
    clipped = np.clip(p, EPSILON, 1.0 - EPSILON)
    odds = clipped / (1.0 - clipped)
    return np.log(odds)


def safe_sigmoid(x):
    """Logistic function with the input clipped to [-500, 500].

    The clip keeps ``np.exp`` from overflowing for very negative inputs.
    """
    bounded = np.clip(x, -500, 500)
    denominator = 1.0 + np.exp(-bounded)
    return 1.0 / denominator


def find_most_dissimilar_pair(
    embeddings: np.ndarray, verbose: bool = True
) -> tuple[int, int]:
    """Approximate the farthest-apart pair of rows with two linear sweeps.

    Point A is the row farthest (Euclidean) from the centroid; point B is the
    row farthest from A. This is a heuristic, not an exact diameter search.

    Args:
        embeddings: 2-D array with one embedding per row.
        verbose: Print the chosen indices and distances.

    Returns:
        ``(idx_a, idx_b)`` as plain Python ints, matching the annotated return
        type and the portrait-based sibling. The two indices coincide when
        all rows are identical.

    Raises:
        ValueError: Propagated from ``np.argmax`` when ``embeddings`` has no rows.
    """
    N = embeddings.shape[0]
    if verbose:
        print(f"  Finding most dissimilar pair (O(N) approximation over {N} graphs)...")

    centroid = embeddings.mean(axis=0)
    dists_to_centroid = np.linalg.norm(embeddings - centroid, axis=1)
    # Cast to int: np.argmax returns np.intp, which leaks numpy scalars into
    # index lists that are annotated (and later reported) as list[int].
    idx_a = int(np.argmax(dists_to_centroid))
    if verbose:
        print(f"    Point A (farthest from centroid): index {idx_a}, dist = {dists_to_centroid[idx_a]:.6f}")

    dists_to_a = np.linalg.norm(embeddings - embeddings[idx_a], axis=1)
    idx_b = int(np.argmax(dists_to_a))
    if verbose:
        print(f"    Point B (farthest from A): index {idx_b}, dist = {dists_to_a[idx_b]:.6f}")
    return idx_a, idx_b


def adj_to_nx(adj: np.ndarray) -> nx.Graph:
    """Build an undirected graph from a (thresholded) adjacency matrix.

    Every row index becomes a node, so isolated nodes are kept. An edge is
    added for each entry strictly greater than 0.5.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(adj.shape[0]))
    sources, targets = np.where(adj > 0.5)
    graph.add_edges_from(zip(sources.tolist(), targets.tolist()))
    return graph


def sample_bernoulli(
    probs: np.ndarray,
    temp: float = 1.0,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Sample a symmetric 0/1 adjacency matrix from edge probabilities.

    The probability matrix is symmetrized and its diagonal zeroed. When
    ``temp`` differs from 1, probabilities are sharpened or flattened by
    dividing their logits by ``temp``. Each upper-triangular entry is then
    drawn independently and mirrored to the lower triangle.

    Args:
        probs: Square matrix of edge probabilities (array, list or tensor-like;
            converted with ``to_numpy``). The input is not modified.
        temp: Temperature; values within 1e-4 of 1.0 skip the logit rescaling.
        rng: Optional numpy Generator used for the noise draw. When ``None``
            (the default) the global ``np.random`` state is used, exactly as
            before, so existing callers are unaffected. Passing a seeded
            Generator makes the sample reproducible.

    Returns:
        Float array of 0.0/1.0 entries, symmetric with a zero diagonal.
    """
    probs = to_numpy(probs).copy()
    probs = (probs + probs.T) / 2.0
    np.fill_diagonal(probs, 0.0)

    if abs(temp - 1.0) > 1e-4:
        probs = safe_sigmoid(safe_logit(probs) / temp)

    if rng is None:
        noise = np.random.rand(*probs.shape)
    else:
        noise = rng.random(probs.shape)

    # Only the strict upper triangle is kept, which also discards the diagonal
    # (whose probability is no longer exactly 0 after temperature scaling).
    upper = np.triu((noise < probs).astype(float), k=1)
    return upper + upper.T


def sample_rigid_topk(probs: np.ndarray) -> np.ndarray:
    """Deterministically keep the k most probable edges.

    The matrix is symmetrized with a zero diagonal; k is the rounded sum of
    the upper-triangular probabilities (the expected edge count). The k
    largest upper-triangular entries become edges and are mirrored.

    Returns:
        Symmetric float 0.0/1.0 adjacency matrix (all zeros when k is 0).
    """
    sym = to_numpy(probs).copy()
    sym = (sym + sym.T) / 2.0
    np.fill_diagonal(sym, 0.0)

    upper_mask = np.triu(np.ones_like(sym), k=1).astype(bool)
    upper_probs = sym[upper_mask]
    n_edges = int(np.round(np.sum(upper_probs)))
    if n_edges == 0:
        return np.zeros_like(sym)

    keep = np.zeros_like(upper_probs)
    top_positions = np.argsort(upper_probs)[-n_edges:]
    keep[top_positions] = 1.0

    half = np.zeros_like(sym)
    half[upper_mask] = keep
    return half + half.T


def stratified_sample(
    prob_matrices: list[np.ndarray],
    temperatures: list[float],
    total_budget: int | None,
    include_rigid: bool = True,
    seed: int = 42,
    verbose: bool = True,
) -> tuple[list[nx.Graph], list[float], list[int]]:
    """Sample graphs from every matrix under each temperature strategy.

    Each temperature (plus the optional deterministic "rigid" strategy) gets
    an equal share of ``total_budget``. With ``total_budget=None`` every
    strategy samples every matrix exactly once.

    Args:
        prob_matrices: Edge-probability matrices to sample from.
        temperatures: Temperatures passed to ``sample_bernoulli``.
        total_budget: Total number of samples, or ``None`` for unlimited.
        include_rigid: Also add ``sample_rigid_topk`` samples.
        seed: Seed for the Generator that picks which matrices to sample.
            The Bernoulli noise itself is drawn inside ``sample_bernoulli``.
        verbose: Print the budget allocation.

    Returns:
        ``(graphs, temps, source_indices)`` — parallel lists holding each
        sampled graph, its temperature (``np.inf`` for rigid samples) and the
        index of the matrix it came from. All three are empty when there are
        no matrices or no strategies.
    """
    rng = np.random.default_rng(seed)
    n_matrices = len(prob_matrices)
    n_strategies = len(temperatures) + (1 if include_rigid else 0)
    strategy_line = (
        f"    Strategies: {n_strategies} "
        f"({len(temperatures)} temps + {'rigid' if include_rigid else 'no rigid'})"
    )

    if total_budget is None:
        budget_per_strategy = n_matrices
        if verbose:
            print("  Budget allocation: UNLIMITED")
            print(strategy_line)
            print(f"    Samples per strategy: {n_matrices} (all matrices)")
            print(f"    Total samples: {n_strategies * n_matrices}")
    else:
        # Guard the division: with no temperatures and no rigid strategy there
        # is nothing to allocate the budget to.
        budget_per_strategy = total_budget // n_strategies if n_strategies else 0
        if verbose:
            print("  Budget allocation:")
            print(f"    Total budget: {total_budget}")
            print(strategy_line)
            print(f"    Per strategy: {budget_per_strategy}")

    graphs: list[nx.Graph] = []
    temps: list[float] = []
    source_indices: list[int] = []

    # Nothing to sample from; also avoids dividing by n_matrices below.
    if n_matrices == 0:
        return graphs, temps, source_indices

    def pick_indices(allow_repeats: bool) -> list[int]:
        """Choose which matrices one strategy samples from."""
        if total_budget is None:
            return list(range(n_matrices))
        if budget_per_strategy < n_matrices:
            return rng.choice(n_matrices, size=budget_per_strategy, replace=False).tolist()
        every_matrix = list(range(n_matrices))
        if not allow_repeats:
            # Rigid samples are deterministic, so repeating a matrix adds nothing.
            return every_matrix
        full_passes, remainder = divmod(budget_per_strategy, n_matrices)
        picked = every_matrix * full_passes
        if remainder > 0:
            picked += rng.choice(n_matrices, size=remainder, replace=False).tolist()
        return picked

    for temp in temperatures:
        for idx in pick_indices(allow_repeats=True):
            adj = sample_bernoulli(prob_matrices[idx], temp)
            graphs.append(adj_to_nx(adj))
            temps.append(temp)
            source_indices.append(idx)

    if include_rigid:
        for idx in pick_indices(allow_repeats=False):
            adj = sample_rigid_topk(prob_matrices[idx])
            graphs.append(adj_to_nx(adj))
            temps.append(np.inf)
            source_indices.append(idx)

    return graphs, temps, source_indices


def compute_portrait_safe(graph: nx.Graph) -> np.ndarray:
    """Compute a shell-count ("portrait") matrix B for ``graph``.

    ``B[l][k]`` counts the start nodes that have exactly ``k`` nodes at BFS
    distance ``l``. The "safe" part is the fallback used when ``nx.diameter``
    raises and the bounds checks applied before writing into B.

    Returns:
        Array with ``max_path + 1`` rows and N columns, where ``max_path`` is
        the largest BFS depth reached (at least 1). For an empty graph the
        1x1 array ``[[1.0]]`` is returned.
    """
    N = graph.number_of_nodes()
    if N == 0:
        return np.array([[1.0]])
    try:
        dia = nx.diameter(graph)
    except nx.NetworkXError:
        # presumably raised for disconnected graphs — TODO confirm; N is a safe
        # upper bound on any shortest-path length.
        dia = N

    B = np.zeros((dia + 1, N))
    max_path = 1
    adj = graph.adj

    for starting_node in graph.nodes():
        # Level-by-level BFS; nodes_visited maps node -> depth from starting_node.
        nodes_visited = {starting_node: 0}
        search_queue = [starting_node]
        d = 1

        while search_queue:
            next_depth = []
            extend = next_depth.extend
            for n in search_queue:
                l = [i for i in adj[n] if i not in nodes_visited]
                extend(l)
                for j in l:
                    nodes_visited[j] = d
            search_queue = next_depth
            d += 1

        node_distances = nodes_visited.values()
        if len(node_distances) == 0:
            continue
        max_node_distances = max(node_distances)
        curr_max_path = max_node_distances
        if curr_max_path > max_path:
            max_path = curr_max_path

        # Histogram: depth -> number of nodes at that depth. (The loop variable
        # reuses the name `d`; it is reset to 1 for the next starting node.)
        dict_distribution = dict.fromkeys(node_distances, 0)
        for d in node_distances:
            dict_distribution[d] += 1

        for shell, count in dict_distribution.items():
            # B has dia + 1 rows and N columns; skip anything that would index
            # outside of it.
            if shell <= dia and count < N:
                B[shell][count] += 1

        # Shells deeper than this node's farthest reach contain zero nodes.
        max_shell = dia
        while max_shell > max_node_distances:
            B[max_shell][0] += 1
            max_shell -= 1

    # Drop the rows beyond the deepest shell actually reached.
    return B[: max_path + 1, :]


def compute_portraits(
    graphs: list[nx.Graph],
    n_jobs: int = -1,
    verbose: bool = True,
) -> list[np.ndarray]:
    """Compute one portrait per graph by delegating to ``compute_portraits_batch``."""
    portraits = compute_portraits_batch(graphs, n_jobs=n_jobs, verbose=verbose)
    return portraits


def compute_portrait_distance(B1: np.ndarray, B2: np.ndarray) -> float:
    """Return netrd's ``portrait_divergence`` of two portraits as a Python float."""
    divergence = portrait_divergence(B1, B2)
    return float(divergence)


def greedy_maxsum_incremental_portrait(
    portraits: list[np.ndarray],
    k: int,
    objective: Literal["energy", "average"],
    start_indices: list[int] | None = None,
    n_jobs: int = -1,
    verbose: bool = True,
    precomputed_distributions: list[np.ndarray] | None = None,
) -> list[int]:
    N = len(portraits)
    if k > N:
        raise ValueError(f"Cannot select {k} graphs from {N} candidates")

    use_fast_divergence = precomputed_distributions is not None

    if use_fast_divergence:
        def compute_divergence(idx: int, selected_idx: int) -> tuple[int, float]:
            d = portrait_divergence_fast(
                precomputed_distributions[idx],
                precomputed_distributions[selected_idx]
            )
            return idx, d
    else:
        def compute_divergence(idx: int, selected_idx: int) -> tuple[int, float]:
            d = compute_portrait_distance(portraits[idx], portraits[selected_idx])
            return idx, d

    cumulative_scores = np.zeros(N)
    available_mask = np.ones(N, dtype=bool)
    selected = []

    if start_indices is None:
        start_indices = [0]

    for start_idx in start_indices:
        selected.append(start_idx)
        available_mask[start_idx] = False
        available_indices = np.where(available_mask)[0]

        if use_fast_divergence:
            results = [compute_divergence(idx, start_idx) for idx in available_indices]
        else:
            with Parallel(n_jobs=n_jobs, prefer="threads") as parallel:
                results = parallel(
                    delayed(compute_divergence)(idx, start_idx)
                    for idx in available_indices
                )

        for idx, dist in results:
            if objective == "energy":
                cumulative_scores[idx] += -1.0 / (dist + EPSILON)
            else:
                cumulative_scores[idx] += dist

    cumulative_scores[~available_mask] = -np.inf
    remaining = k - len(selected)
    if remaining <= 0:
        return selected[:k]

    iterator = trange(remaining, desc="  Greedy selection") if verbose else range(remaining)

    for _ in iterator:
        best_idx = np.argmax(cumulative_scores)
        if cumulative_scores[best_idx] == -np.inf:
            break
        selected.append(best_idx)
        available_mask[best_idx] = False
        cumulative_scores[best_idx] = -np.inf
        available_indices = np.where(available_mask)[0]
        if len(available_indices) == 0:
            break

        if use_fast_divergence:
            results = [compute_divergence(idx, best_idx) for idx in available_indices]
        else:
            with Parallel(n_jobs=n_jobs, prefer="threads") as parallel:
                results = parallel(
                    delayed(compute_divergence)(idx, best_idx)
                    for idx in available_indices
                )

        for idx, dist in results:
            if objective == "energy":
                cumulative_scores[idx] += -1.0 / (dist + EPSILON)
            else:
                cumulative_scores[idx] += dist

    return selected


def find_most_dissimilar_pair_portrait(
    portraits: list[np.ndarray],
    n_jobs: int = -1,
    verbose: bool = True,
    precomputed_distributions: list[np.ndarray] | None = None,
) -> tuple[int, int]:
    N = len(portraits)
    if verbose:
        print(f"  Finding most dissimilar pair (O(N) approximation over {N} graphs)...")
    if N <= 2:
        return 0, min(1, N - 1)

    use_fast_divergence = precomputed_distributions is not None
    start_idx = 0

    if use_fast_divergence:
        dists_to_start = [
            portrait_divergence_fast(precomputed_distributions[start_idx], precomputed_distributions[i])
            for i in range(N)
        ]
    else:
        with Parallel(n_jobs=n_jobs, prefer="threads") as parallel:
            dists_to_start = parallel(
                delayed(compute_portrait_distance)(portraits[start_idx], portraits[i])
                for i in range(N)
            )

    idx_a = int(np.argmax(dists_to_start))
    if verbose:
        print(f"    Point A (farthest from start): index {idx_a}, dist = {dists_to_start[idx_a]:.6f}")

    if use_fast_divergence:
        dists_to_a = [
            portrait_divergence_fast(precomputed_distributions[idx_a], precomputed_distributions[i])
            for i in range(N)
        ]
    else:
        with Parallel(n_jobs=n_jobs, prefer="threads") as parallel:
            dists_to_a = parallel(
                delayed(compute_portrait_distance)(portraits[idx_a], portraits[i])
                for i in range(N)
            )

    idx_b = int(np.argmax(dists_to_a))
    if verbose:
        print(f"    Point B (farthest from A): index {idx_b}, dist = {dists_to_a[idx_b]:.6f}")
    return idx_a, idx_b


def compute_embeddings(
    graphs: list[nx.Graph],
    metric: DistanceName,
    orca_path: str,
    n_jobs: int = -1,
    verbose: bool = True,
) -> np.ndarray:
    """Embed every graph for the given embedding-based metric.

    ``"netlsd_heat"`` / ``"netlsd_wave"`` use the batched NetLSD signature
    helper and ``"gcd"`` uses the ORCA-based helper. ``n_jobs`` is accepted
    but not used by this function.

    Raises:
        ValueError: For ``"portrait_div"`` (which has no embedding and must go
            through the incremental greedy path) and for unknown metrics.
    """
    if metric == "portrait_div":
        raise ValueError("portrait_div should use incremental greedy selection")

    if verbose:
        print(f"  Computing {metric} embeddings for {len(graphs)} graphs...")
        print(f"  Using device: {_get_device()}")

    if metric == "netlsd_heat" or metric == "netlsd_wave":
        kernel = "wave" if metric != "netlsd_heat" else "heat"
        signatures = _netlsd_signature_batch(graphs, kernel)
        return np.array(signatures)

    if metric == "gcd":
        gcd_vectors = compute_gcd_embeddings_parallel(graphs, orca_path, verbose=verbose)
        return np.array(gcd_vectors)

    raise ValueError(f"Unknown metric: {metric}")


def greedy_maxsum(
    embeddings: np.ndarray,
    k: int,
    objective: Literal["energy", "average"],
    start_indices: list[int] | None = None,
    verbose: bool = True,
) -> list[int]:
    N = embeddings.shape[0]
    if k > N:
        raise ValueError(f"Cannot select {k} graphs from {N} candidates")

    selected = []
    selected_mask = np.zeros(N, dtype=bool)
    sum_scores = np.zeros(N)

    if start_indices is None:
        start_indices = [0]

    for idx in start_indices:
        selected.append(idx)
        selected_mask[idx] = True
        dists = np.linalg.norm(embeddings - embeddings[idx], axis=1)
        if objective == "energy":
            with np.errstate(divide='ignore'):
                score_update = -1.0 / dists
                score_update[~np.isfinite(score_update)] = 0
            sum_scores += score_update
        else:
            sum_scores += dists

    sum_scores[selected_mask] = -np.inf
    remaining = k - len(selected)
    if remaining <= 0:
        return selected[:k]

    iterator = (
        tqdm(range(remaining), desc="  Selecting") if verbose else range(remaining)
    )

    for _ in iterator:
        best = np.argmax(sum_scores)
        if sum_scores[best] == -np.inf:
            break
        selected.append(best)
        selected_mask[best] = True
        new_dists = np.linalg.norm(embeddings - embeddings[best], axis=1)
        if objective == "energy":
            with np.errstate(divide='ignore'):
                score_update = -1.0 / new_dists
                score_update[~np.isfinite(score_update)] = 0
            sum_scores += score_update
        else:
            sum_scores += new_dists
        sum_scores[selected_mask] = -np.inf

    return selected


def evaluate_selection(
    graphs: list[nx.Graph],
    metric: DistanceName,
    orca_path: str,
    n_jobs: int,
) -> tuple[float, float]:
    """Score a set of graphs with the full pairwise distance matrix.

    Returns:
        ``(average_diversity(D), summarize(D).energy)`` where ``D`` is the
        pairwise distance matrix under ``metric``.
    """
    distance_matrix = pairwise_distance_matrix(
        graphs,
        distance=metric,
        workers=Parallel(n_jobs=n_jobs),
        orca_path=orca_path,
    )
    diversity = average_diversity(distance_matrix)
    energy = summarize(distance_matrix).energy
    return diversity, energy


def run_iterative_survival_selection(
    prob_matrices: list[np.ndarray],
    config: SelectionConfig,
) -> SelectionResult:
    """Two-phase selection: pick diverse matrices first, then sample them densely.

    Phase 1 draws one T=1.0 graph from each of the first ``phase1_budget``
    matrices and greedily keeps the ``n_top_matrices`` most diverse ones.
    Phase 2 draws additional graphs from only those matrices at varied
    temperatures. A final greedy pass picks ``config.k_select`` graphs from
    the phase-2 pool, which are then scored with ``evaluate_selection``.

    Args:
        prob_matrices: Edge-probability matrices (indexable by position).
        config: Selection settings; see SelectionConfig.

    Returns:
        A SelectionResult whose ``selected_indices`` index into the phase-2
        pool and whose ``pool_size`` is the size of that pool.
    """
    verbose = config.verbose
    # This Generator only drives the temperature / rigid choices in phase 2;
    # sample_bernoulli draws its noise from the global np.random state unless
    # it is handed a Generator, which this function does not do.
    rng = np.random.default_rng(config.random_seed)
    use_portrait = config.metric == "portrait_div"

    n_matrices = len(prob_matrices)
    k_final = config.k_select

    # Budget split: at most a tenth of the budget (capped at one sample per
    # matrix) goes to phase 1, the rest to phase 2.
    total_budget = config.sampling_budget
    phase1_budget = min(n_matrices, total_budget // 10)
    phase2_budget = total_budget - phase1_budget
    n_top_matrices = min(config.n_top_matrices, n_matrices, phase1_budget)
    samples_per_matrix = phase2_budget // n_top_matrices if n_top_matrices > 0 else 0

    if verbose:
        print(f"\n{'=' * 60}")
        print("EMBEDDING-GUIDED MATRIX SELECTION")
        print(f"{'=' * 60}")
        print(f"  Total budget: {total_budget}")
        print(f"  Phase 1 (matrix selection): {phase1_budget} samples")
        print(f"  Phase 2 (focused sampling): {phase2_budget} samples")
        print(f"  Top matrices to select: {n_top_matrices}")
        print(f"  Samples per selected matrix: {samples_per_matrix}")
        print(f"  Total probability matrices: {n_matrices}")
        print(f"  Using portrait (incremental): {use_portrait}")
        print()

    if verbose:
        print(f"{'=' * 40}")
        print("PHASE 1: MATRIX DIVERSITY DISCOVERY")
        print(f"{'=' * 40}")
        print(f"  Sampling one graph per matrix at T=1.0...")

    # ---- Phase 1: one representative graph per matrix -------------------
    phase1_graphs = []
    phase1_sources = []
    matrices_to_sample = min(n_matrices, phase1_budget)
    # The first `matrices_to_sample` matrices are used, in order (no shuffling).
    matrix_indices = list(range(matrices_to_sample))

    for idx in tqdm(matrix_indices, desc="  Phase 1 sampling", disable=not verbose):
        adj = sample_bernoulli(prob_matrices[idx], temp=1.0)
        phase1_graphs.append(adj_to_nx(adj))
        phase1_sources.append(idx)

    if verbose:
        print(f"  Sampled {len(phase1_graphs)} representative graphs")

    if use_portrait:
        if verbose:
            print(f"  Computing portrait matrices...")
        phase1_portraits = compute_portraits(phase1_graphs, config.n_jobs, verbose=False)

        phase1_distributions = None
        if verbose:
            print(f"  Preprocessing portraits for fast divergence...")
        phase1_distributions, _ = preprocess_portraits_for_divergence(phase1_portraits, verbose=False)

        if verbose:
            print(f"  Finding most dissimilar pair...")

        # Any start_mode other than "dissimilar" seeds the greedy pass with
        # index 0 here.
        if config.start_mode == "dissimilar":
            idx1, idx2 = find_most_dissimilar_pair_portrait(
                phase1_portraits, config.n_jobs, verbose=False,
                precomputed_distributions=phase1_distributions
            )
            start_indices = [idx1, idx2]
        else:
            start_indices = [0]

        if verbose:
            print(f"  Running incremental greedy selection for {n_top_matrices} matrices...")

        selected_matrix_indices = greedy_maxsum_incremental_portrait(
            phase1_portraits,
            n_top_matrices,
            config.objective,
            start_indices=start_indices,
            n_jobs=config.n_jobs,
            verbose=verbose,
            precomputed_distributions=phase1_distributions,
        )
    else:
        if verbose:
            print(f"  Computing embeddings...")
        phase1_embeddings = compute_embeddings(
            phase1_graphs, config.metric, config.orca_path, config.n_jobs, verbose=False
        )
        if verbose:
            print(f"  Embedding shape: {phase1_embeddings.shape}")

        if config.start_mode == "dissimilar":
            idx1, idx2 = find_most_dissimilar_pair(phase1_embeddings, verbose=False)
            start_indices = [idx1, idx2]
        else:
            start_indices = [0]

        selected_matrix_indices = greedy_maxsum(
            phase1_embeddings,
            n_top_matrices,
            config.objective,
            start_indices=start_indices,
            verbose=False,
        )

    # Map positions in the phase-1 sample back to indices into prob_matrices.
    top_matrix_ids = [phase1_sources[i] for i in selected_matrix_indices]

    if verbose:
        print(f"  Selected {len(top_matrix_ids)} diverse matrices")
        # NOTE(review): min()/max() raise ValueError if no matrix was selected.
        selected_means = [np.mean(prob_matrices[i]) for i in top_matrix_ids]
        print(f"  Selected matrix mean probabilities: min={min(selected_means):.3f}, "
              f"max={max(selected_means):.3f}, mean={np.mean(selected_means):.3f}")

    if verbose:
        print(f"\n{'=' * 40}")
        print("PHASE 2: FOCUSED SAMPLING")
        print(f"{'=' * 40}")
        print(f"  Sampling {samples_per_matrix} graphs from each of {len(top_matrix_ids)} matrices...")

    # ---- Phase 2: dense sampling from the surviving matrices ------------
    # The pool starts with the phase-1 graph of every surviving matrix, so
    # their portraits / embeddings are reused rather than recomputed.
    phase2_graphs = [phase1_graphs[i] for i in selected_matrix_indices]
    phase2_temps = [1.0] * len(phase2_graphs)
    phase2_sources = top_matrix_ids.copy()

    if use_portrait:
        phase2_portraits = [phase1_portraits[i] for i in selected_matrix_indices]
    else:
        phase2_embeddings = phase1_embeddings[selected_matrix_indices].copy()

    use_continuous = config.use_continuous_temp
    if use_continuous:
        # Temperatures are drawn uniformly in log space (log-uniform).
        temp_min, temp_max = config.temp_range
        log_temp_min, log_temp_max = np.log(temp_min), np.log(temp_max)
        rigid_prob = config.rigid_prob if config.include_rigid else 0.0

    new_graphs = []
    new_temps = []
    new_sources = []

    for matrix_idx in tqdm(top_matrix_ids, desc="  Phase 2 sampling", disable=not verbose):
        prob_matrix = prob_matrices[matrix_idx]
        # One fewer than samples_per_matrix: the phase-1 graph of this matrix
        # is already in the pool.
        for _ in range(samples_per_matrix - 1):
            if use_continuous:
                if rng.random() < rigid_prob:
                    # np.inf marks a deterministic top-k ("rigid") sample.
                    temp = np.inf
                    adj = sample_rigid_topk(prob_matrix)
                else:
                    log_temp = rng.uniform(log_temp_min, log_temp_max)
                    temp = np.exp(log_temp)
                    adj = sample_bernoulli(prob_matrix, temp)
            else:
                # Discrete mode: pick uniformly among the configured
                # temperatures, with np.inf standing in for the rigid strategy.
                temp_options = config.temperatures.copy()
                if config.include_rigid:
                    temp_options.append(np.inf)
                temp = temp_options[rng.integers(0, len(temp_options))]
                if temp == np.inf:
                    adj = sample_rigid_topk(prob_matrix)
                else:
                    adj = sample_bernoulli(prob_matrix, temp)

            new_graphs.append(adj_to_nx(adj))
            new_temps.append(temp)
            new_sources.append(matrix_idx)

    if verbose:
        print(f"  Sampled {len(new_graphs)} additional graphs")

    if len(new_graphs) > 0:
        if use_portrait:
            if verbose:
                print(f"  Computing portraits for new graphs...")
            new_portraits = compute_portraits(new_graphs, config.n_jobs, verbose=verbose)
            phase2_graphs.extend(new_graphs)
            phase2_portraits.extend(new_portraits)
        else:
            if verbose:
                print(f"  Computing embeddings for new graphs...")
            new_embeddings = compute_embeddings(
                new_graphs, config.metric, config.orca_path, config.n_jobs, verbose=verbose
            )
            phase2_graphs.extend(new_graphs)
            phase2_embeddings = np.vstack([phase2_embeddings, new_embeddings])

        phase2_temps.extend(new_temps)
        phase2_sources.extend(new_sources)

    if verbose:
        print(f"  Total phase 2 pool: {len(phase2_graphs)} graphs")

    phase2_distributions = None
    if use_portrait:
        if verbose:
            print(f"  Preprocessing portraits for fast divergence...")
        phase2_distributions, _ = preprocess_portraits_for_divergence(phase2_portraits, verbose=verbose)

    if verbose:
        print(f"\n{'=' * 40}")
        print("FINAL SELECTION")
        print(f"{'=' * 40}")
        print(f"  Running greedy selection for final {k_final} graphs...")

    # ---- Final selection over the phase-2 pool --------------------------
    # The greedy helpers raise ValueError when k_final exceeds the pool size.
    if use_portrait:
        if config.start_mode == "dissimilar":
            idx1, idx2 = find_most_dissimilar_pair_portrait(
                phase2_portraits, config.n_jobs, verbose=False,
                precomputed_distributions=phase2_distributions
            )
            start_indices = [idx1, idx2]
        else:
            start_indices = [0]

        final_indices = greedy_maxsum_incremental_portrait(
            phase2_portraits,
            k_final,
            config.objective,
            start_indices=start_indices,
            n_jobs=config.n_jobs,
            verbose=verbose,
            precomputed_distributions=phase2_distributions,
        )
    else:
        if config.start_mode == "dissimilar":
            idx1, idx2 = find_most_dissimilar_pair(phase2_embeddings, verbose=False)
            start_indices = [idx1, idx2]
        else:
            start_indices = [0]

        final_indices = greedy_maxsum(
            phase2_embeddings,
            k_final,
            config.objective,
            start_indices=start_indices,
            verbose=verbose,
        )

    final_graphs = [phase2_graphs[i] for i in final_indices]
    final_temps = [phase2_temps[i] for i in final_indices]
    final_sources = [phase2_sources[i] for i in final_indices]

    if verbose:
        print(f"\n{'=' * 40}")
        print("EVALUATION")
        print(f"{'=' * 40}")

    final_diversity, final_energy = evaluate_selection(
        final_graphs, config.metric, config.orca_path, config.n_jobs
    )

    # Tally which temperatures produced the selected graphs. In continuous
    # mode temperatures are grouped into string-labelled buckets, so the keys
    # of this mapping are a mix of str labels and np.inf (rigid).
    final_temp_contributions = defaultdict(int)
    for t in final_temps:
        if t == np.inf:
            final_temp_contributions[np.inf] += 1
        elif use_continuous:
            if t < 0.1:
                final_temp_contributions["<0.1"] += 1
            elif t < 0.5:
                final_temp_contributions["0.1-0.5"] += 1
            elif t < 1.0:
                final_temp_contributions["0.5-1.0"] += 1
            elif t < 2.0:
                final_temp_contributions["1.0-2.0"] += 1
            else:
                # Catch-all for every finite temperature >= 2.0, whatever the
                # label says.
                final_temp_contributions["2.0-5.0"] += 1
        else:
            final_temp_contributions[t] += 1

    unique_sources = len(set(final_sources))
    edges = [g.number_of_edges() for g in final_graphs]

    # NOTE(review): min(edges) / max(edges) raise ValueError when the final
    # selection is empty.
    result = SelectionResult(
        graphs=final_graphs,
        diversity=final_diversity,
        energy=final_energy,
        temp_contributions=dict(final_temp_contributions),
        edge_mean=float(np.mean(edges)),
        edge_std=float(np.std(edges)),
        edge_min=int(min(edges)),
        edge_max=int(max(edges)),
        pool_size=len(phase2_graphs),
        selected_indices=final_indices,
    )

    if verbose:
        print(f"\nFinal diversity: {final_diversity:.6f}")
        print(f"Final energy: {final_energy:.6f}")
        print(f"Total budget used: {phase1_budget + len(new_graphs)}")
        print(f"Unique source matrices in final selection: {unique_sources}/{k_final}")
        print(f"\nTemperature contributions (final selection):")
        if use_continuous:
            # Fixed display order; buckets with no selected graph are skipped.
            bucket_order = ["<0.1", "0.1-0.5", "0.5-1.0", "1.0-2.0", "2.0-5.0", np.inf]
            for bucket in bucket_order:
                if bucket in final_temp_contributions:
                    label = "rigid" if bucket == np.inf else bucket
                    print(f"  {label}: {final_temp_contributions[bucket]}")
        else:
            # Sort finite temperatures ascending and put the rigid (inf) entry last.
            for t in sorted(
                final_temp_contributions.keys(),
                key=lambda x: (x == np.inf, x if x != np.inf else 0),
            ):
                label = "rigid" if t == np.inf else f"{t:.2f}"
                print(f"  {label}: {final_temp_contributions[t]}")
        print(f"\nEdge statistics:")
        print(f"  Mean: {result.edge_mean:.1f}, Std: {result.edge_std:.1f}")
        print(f"  Range: [{result.edge_min}, {result.edge_max}]")

    return result


def plot_temperature_contributions(
    temps: list[float],
    selected_indices: list[int],
    output_path: str,
):
    """Save a bar chart of how many selected graphs came from each temperature.

    NaN temperatures are counted as "anchors", ``np.inf`` as "rigid"; finite
    temperatures get one bar each in ascending order. The figure is written
    to ``output_path`` (parent directories are created) and then closed.
    """
    anchor_count = 0
    temp_counts = defaultdict(int)
    for t in (temps[idx] for idx in selected_indices):
        if np.isnan(t):
            anchor_count += 1
            continue
        temp_counts[t] += 1

    # One (label, count, color) entry per bar, in display order.
    bar_specs = []
    if anchor_count > 0:
        bar_specs.append(("anchors", anchor_count, "gold"))
    for t in sorted(key for key in temp_counts.keys() if key != np.inf):
        bar_specs.append((f"{t:.2f}", temp_counts[t], "steelblue"))
    if np.inf in temp_counts:
        bar_specs.append(("rigid", temp_counts[np.inf], "coral"))

    x_labels = [label for label, _, _ in bar_specs]
    counts = [count for _, count, _ in bar_specs]
    colors = [color for _, _, color in bar_specs]
    positions = range(len(x_labels))

    fig, ax = plt.subplots(figsize=(12, 6))
    bars = ax.bar(positions, counts, color=colors, alpha=0.8, edgecolor="black")
    ax.set_xticks(positions)
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.set_xlabel("Temperature")
    ax.set_ylabel("Number of Selected Graphs")
    ax.set_title(f"Temperature Contributions to Final Selection (k={len(selected_indices)})")
    ax.grid(axis="y", alpha=0.3)

    # Annotate every non-empty bar with its count, just above the bar.
    for bar, count in zip(bars, counts):
        if count <= 0:
            continue
        x_center = bar.get_x() + bar.get_width() / 2
        y_text = bar.get_height() + 0.5
        ax.text(x_center, y_text, str(count), ha="center", va="bottom", fontsize=9)

    plt.tight_layout()
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()


def _reorder_selection_pool(
    order,
    graphs,
    temps,
    portraits,
    precomputed_distributions,
    embeddings,
):
    """Apply one index ordering to all parallel per-graph containers.

    ``portraits``, ``precomputed_distributions`` and ``embeddings`` are each
    passed through untouched when they are ``None``.
    """
    graphs = [graphs[i] for i in order]
    temps = [temps[i] for i in order]
    if portraits is not None:
        portraits = [portraits[i] for i in order]
    if precomputed_distributions is not None:
        precomputed_distributions = [precomputed_distributions[i] for i in order]
    if embeddings is not None:
        embeddings = embeddings[order]
    return graphs, temps, portraits, precomputed_distributions, embeddings


def run_selection(
    prob_matrices: list[np.ndarray],
    config: SelectionConfig | None = None,
) -> SelectionResult:
    """Sample a pool of graphs and greedily pick a diverse subset of it.

    With ``config.strategy == "iterative_survival"`` the call is delegated to
    ``run_iterative_survival_selection``. Otherwise four phases run:
    stratified temperature sampling, portrait/embedding computation, greedy
    max-sum selection seeded according to ``config.start_mode``, and
    evaluation of the selected graphs.

    Args:
        prob_matrices: Edge-probability matrices; each is passed through
            ``to_numpy`` before use.
        config: Selection settings. A default ``SelectionConfig`` is used
            when ``None``.

    Returns:
        A ``SelectionResult``. ``selected_indices`` refer to the pool *after*
        the start-mode reordering (and, in "anchors" mode, after the two
        anchor graphs were prepended). ``pool_size`` is the number of sampled
        graphs, not counting anchors. ``temp_contributions`` skips entries
        whose temperature is NaN (the anchors).

    Raises:
        ValueError: If ``config.start_mode`` is not one of "anchors",
            "random" or "dissimilar".
    """
    if config is None:
        config = SelectionConfig()

    verbose = config.verbose
    use_portrait = config.metric == "portrait_div"
    prob_matrices = [to_numpy(x) for x in prob_matrices]

    if config.strategy == "iterative_survival":
        if verbose:
            print("Using ITERATIVE SURVIVAL strategy")
        return run_iterative_survival_selection(prob_matrices, config)

    # Fail fast: reject a bad start mode before the expensive sampling and
    # embedding phases instead of after them.
    if config.start_mode not in ("anchors", "random", "dissimilar"):
        raise ValueError(f"Unknown start_mode: {config.start_mode}")

    if verbose:
        print(f"Running selection: metric={config.metric}, objective={config.objective}, k={config.k_select}")
        print(f"Using portrait (incremental): {use_portrait}")
        print("\nPhase 1: Stratified temperature sampling...")

    # The third return value (source indices) is not used by this strategy.
    graphs, temps, _source_indices = stratified_sample(
        prob_matrices=prob_matrices,
        temperatures=config.temperatures,
        total_budget=config.total_budget,
        include_rigid=config.include_rigid,
        seed=config.random_seed,
        verbose=verbose,
    )

    pool_size = len(graphs)
    if verbose:
        print(f"  Total graphs sampled: {pool_size}")
        print(f"\nPhase 2: Computing {'portraits' if use_portrait else 'embeddings'}...")

    # Exactly one representation is populated, depending on the metric.
    portraits = None
    precomputed_distributions = None
    embeddings = None

    if use_portrait:
        portraits = compute_portraits(graphs, config.n_jobs, verbose)
        if verbose:
            print("  Preprocessing portraits for fast divergence...")
        precomputed_distributions, _ = preprocess_portraits_for_divergence(portraits, verbose=verbose)
    else:
        embeddings = compute_embeddings(
            graphs, config.metric, config.orca_path, config.n_jobs, verbose
        )
        if verbose:
            print(f"  Embedding shape: {embeddings.shape}")

    if verbose:
        print(f"\n  Start mode: {config.start_mode}")

    # Separate stream from the sampling seed so the ordering is reproducible.
    rng = np.random.default_rng(config.random_seed + 1)

    if config.start_mode == "anchors":
        # Seed the selection with the empty and the complete graph.
        n_nodes = prob_matrices[0].shape[0]
        empty_graph = nx.Graph()
        empty_graph.add_nodes_from(range(n_nodes))
        anchor_graphs = [empty_graph, nx.complete_graph(n_nodes)]
        n_anchors = len(anchor_graphs)

        if use_portrait:
            anchor_portraits = compute_portraits(anchor_graphs, config.n_jobs, verbose=False)
            portraits = anchor_portraits + portraits
            # Distributions are recomputed over the pool including anchors.
            precomputed_distributions, _ = preprocess_portraits_for_divergence(portraits, verbose=False)
        else:
            anchor_embeddings = compute_embeddings(
                anchor_graphs, config.metric, config.orca_path, config.n_jobs, verbose=False
            )
            embeddings = np.vstack([anchor_embeddings, embeddings])

        graphs = anchor_graphs + graphs
        # NaN marks anchors: they have no sampling temperature.
        temps = [np.nan] * n_anchors + temps

        # Anchors stay in front; only the sampled graphs are shuffled.
        shuffled_rest = rng.permutation(len(graphs) - n_anchors) + n_anchors
        order = np.concatenate([list(range(n_anchors)), shuffled_rest])
        start_indices = list(range(n_anchors))

    elif config.start_mode == "random":
        order = rng.permutation(len(graphs))
        start_indices = [0]

    else:  # "dissimilar" (start_mode was validated above)
        if use_portrait:
            idx1, idx2 = find_most_dissimilar_pair_portrait(
                portraits, config.n_jobs, verbose=verbose,
                precomputed_distributions=precomputed_distributions
            )
        else:
            idx1, idx2 = find_most_dissimilar_pair(embeddings, verbose=verbose)

        # The most dissimilar pair goes first; everything else is shuffled.
        other_indices = [i for i in range(len(graphs)) if i not in (idx1, idx2)]
        rng.shuffle(other_indices)
        order = [idx1, idx2] + other_indices
        start_indices = [0, 1]

    graphs, temps, portraits, precomputed_distributions, embeddings = _reorder_selection_pool(
        order, graphs, temps, portraits, precomputed_distributions, embeddings
    )

    if verbose:
        print(f"\nPhase 3: Greedy max-sum selection (k={config.k_select})...")

    if use_portrait:
        selected_indices = greedy_maxsum_incremental_portrait(
            portraits,
            config.k_select,
            config.objective,
            start_indices=start_indices,
            n_jobs=config.n_jobs,
            verbose=verbose,
            precomputed_distributions=precomputed_distributions,
        )
    else:
        selected_indices = greedy_maxsum(
            embeddings,
            config.k_select,
            config.objective,
            start_indices=start_indices,
            verbose=verbose,
        )

    if verbose:
        print("\nPhase 4: Evaluation...")

    selected_graphs = [graphs[i] for i in selected_indices]
    diversity, energy = evaluate_selection(
        selected_graphs, config.metric, config.orca_path, config.n_jobs
    )

    # Count selected graphs per temperature, skipping NaN-tagged anchors.
    temp_contributions = defaultdict(int)
    for idx in selected_indices:
        t = temps[idx]
        if not np.isnan(t):
            temp_contributions[t] += 1

    edges = [g.number_of_edges() for g in selected_graphs]

    result = SelectionResult(
        graphs=selected_graphs,
        diversity=diversity,
        energy=energy,
        temp_contributions=dict(temp_contributions),
        edge_mean=float(np.mean(edges)),
        edge_std=float(np.std(edges)),
        edge_min=int(min(edges)),
        edge_max=int(max(edges)),
        pool_size=pool_size,
        selected_indices=selected_indices,
    )

    if verbose:
        print(f"\n{'=' * 60}")
        print("RESULTS")
        print(f"{'=' * 60}")
        print(f"Average pairwise diversity: {diversity:.6f}")
        print(f"Energy: {energy:.6f}")
        print(f"Pool size before selection: {pool_size}")
        print("\nEdge statistics:")
        print(f"  Mean: {result.edge_mean:.1f}, Std: {result.edge_std:.1f}")
        print(f"  Range: [{result.edge_min}, {result.edge_max}]")

    return result


# Script-level settings used when this file is run directly.

# File-system locations.
INPUT_PROBS_PATH = "data/diverse_probs_128.pkl"
OUTPUT_DIR = "data/selected"
PLOTS_DIR = "training_visuals"
ORCA_PATH = DEFAULT_ORCA_PATH

# What to select and how to score it.
K_SELECT = 100
METRIC: DistanceName = "netlsd_wave"
OBJECTIVE: Literal["energy", "average"] = "average"
START_MODE: Literal["anchors", "random", "dissimilar"] = "dissimilar"

# How the candidate pool is sampled.
TOTAL_BUDGET = None  # main() reports a falsy budget as 'UNLIMITED'
TEMPERATURES = list(DEFAULT_TEMPERATURES)  # private copy of the shared default
INCLUDE_RIGID = True

# Execution settings.
N_JOBS = -1
RANDOM_SEED = 42


def main():
    """Run the selection pipeline with the module-level settings.

    Loads the probability matrices from ``INPUT_PROBS_PATH``, runs
    ``run_selection`` with a ``SelectionConfig`` built from the module-level
    constants, and pickles the selected graphs plus a metadata dict into
    ``OUTPUT_DIR``. If the input file is missing or contains no matrices, an
    error is printed and the function returns without writing anything.
    """
    print("=" * 60)
    print("MULTI-TEMPERATURE SAMPLING FOR DIVERSE GRAPH SELECTION")
    print("=" * 60)
    print(f"Metric: {METRIC}")
    print(f"Objective: {OBJECTIVE}")
    print(f"Target K: {K_SELECT}")
    print(f"Total budget: {TOTAL_BUDGET if TOTAL_BUDGET else 'UNLIMITED'}")
    print(f"Temperatures: {TEMPERATURES}")
    print(f"Include rigid: {INCLUDE_RIGID}")
    print(f"Start mode: {START_MODE}")

    print(f"\nLoading probability matrices from {INPUT_PROBS_PATH}...")
    if not Path(INPUT_PROBS_PATH).exists():
        print(f"ERROR: Could not find {INPUT_PROBS_PATH}")
        return

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point INPUT_PROBS_PATH at files produced by a trusted source.
    with open(INPUT_PROBS_PATH, "rb") as f:
        data = pickle.load(f)

    prob_matrices = [to_numpy(x) for x in data]
    if not prob_matrices:
        # Without this guard the shape report below raises IndexError.
        print(f"ERROR: No probability matrices found in {INPUT_PROBS_PATH}")
        return
    print(f"Loaded {len(prob_matrices)} probability matrices, shape {prob_matrices[0].shape}")

    config = SelectionConfig(
        k_select=K_SELECT,
        metric=METRIC,
        objective=OBJECTIVE,
        total_budget=TOTAL_BUDGET,
        temperatures=TEMPERATURES,
        include_rigid=INCLUDE_RIGID,
        start_mode=START_MODE,
        orca_path=ORCA_PATH,
        n_jobs=N_JOBS,
        random_seed=RANDOM_SEED,
        verbose=True,
    )

    result = run_selection(prob_matrices, config)

    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    output_path = f"{OUTPUT_DIR}/multitemp_{METRIC}_{OBJECTIVE}.pkl"
    with open(output_path, "wb") as f:
        pickle.dump(result.graphs, f)
    print(f"\nSaved {len(result.graphs)} graphs to {output_path}")

    # Settings and scores are stored next to the graphs so a run can be
    # identified later.
    metadata = {
        "metric": METRIC,
        "objective": OBJECTIVE,
        "k": K_SELECT,
        "budget": TOTAL_BUDGET,
        "temperatures": TEMPERATURES,
        "include_rigid": INCLUDE_RIGID,
        "diversity": result.diversity,
        "energy": result.energy,
        "temp_contributions": result.temp_contributions,
        "selected_indices": result.selected_indices,
    }
    meta_path = f"{OUTPUT_DIR}/multitemp_{METRIC}_{OBJECTIVE}_meta.pkl"
    with open(meta_path, "wb") as f:
        pickle.dump(metadata, f)

    print("\nDone!")


# Run the selection pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
