import networkx as nx
import numpy as np
import time
import uuid
import copy
from typing import List, Dict, Optional, Tuple
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from kt_gen.knowledge_graph.utils.pydantic_models import EmbeddingStore, TraversalStep
from kt_gen.knowledge_graph.utils.utils_fgw import (
    compute_embeddings, 
    compute_structure_distances, 
    compute_fgw, 
    compute_gw, 
    compute_similarity_weighted_structure_distances
)


class FGWConfig(BaseModel):
    """Configuration for FGW traversal algorithms.

    Holds the settings that ``FGWTraversal`` reads and forwards as keyword
    arguments to the module-level traversal functions.
    """
    # Weight passed as ``alpha`` to compute_fgw.
    alpha: float = 0.5
    # Number of traversal iterations (one step per depth) from the start node.
    max_depth: int = 3
    # A child counts as relevant when its query similarity exceeds
    # sim_threshold * (highest child similarity).
    sim_threshold: float = 0.8
    # Maximum FGW distance accepted for a fusion.
    fgw_threshold: float = 0.1
    # Default edge "prob" used when an existing edge carries none (genetic traversal).
    initial_prob: float = 0.8
    # "prob" given to edges newly created by the genetic traversal's reward step.
    initial_prob_new: float = 0.1
    # Learning rate forwarded to reward_distribution.
    lr: float = 0.8
    structure_distance: str = "base"  # "base" or "similarity_weighted"
    # Group fused text by section (_create_fused_text_compact) instead of " | "-joining it.
    compact_fusion: bool = True
    # Print progress messages through _log_if_enabled.
    enable_logging: bool = False


def _log_if_enabled(message: str, enable_logging: bool, level: str = "INFO"):
    """Print ``message`` as ``[level] message`` to stdout, but only when enabled.

    Args:
        message: Text to print.
        enable_logging: When false the call is a no-op.
        level: Tag shown in square brackets before the message.
    """
    if not enable_logging:
        return
    print("[{}] {}".format(level, message))


def _create_fused_text_compact(node_ids: List[str], G: nx.DiGraph, compact_fusion: bool = True) -> str:
    """Concatenate the ``text`` attribute of ``node_ids`` into one fused string.

    Without ``compact_fusion`` the texts are simply joined with ``" | "`` in
    the given order. With it, nodes are grouped by the section number taken
    from ids of the form ``section_<s>_..._<p>_..._<k>``; ids that do not start
    with ``section_`` go to the ``unknown`` group. Groups are ordered by
    numeric section (non-numeric last). Inside a group, nodes are ordered by
    the numeric 4th and 6th ``_``-separated id tokens (missing/non-numeric
    last). Each group is rendered as ``From section_<s>: <texts>`` and groups
    are separated by ``\\n----\\n``.
    """
    if not compact_fusion:
        return " | ".join(G.nodes[nid]["text"] for nid in node_ids)

    def _numeric_or_inf(token: str):
        # Numeric tokens sort by value; anything else sorts after all numbers.
        return int(token) if token.isdigit() else float('inf')

    def _position_key(item):
        parts = item[0].split("_")
        paragraph = _numeric_or_inf(parts[3]) if len(parts) > 3 else float('inf')
        sentence = _numeric_or_inf(parts[5]) if len(parts) > 5 else float('inf')
        return (paragraph, sentence)

    grouped: Dict[str, List[Tuple[str, str]]] = {}
    for nid in node_ids:
        section = nid.split("_")[1] if nid.startswith("section_") else "unknown"
        grouped.setdefault(section, []).append((nid, G.nodes[nid]["text"]))

    blocks = []
    for section in sorted(grouped, key=_numeric_or_inf):
        ordered = sorted(grouped[section], key=_position_key)
        body = " ".join(text for _, text in ordered)
        blocks.append(f"From section_{section}: {body}")

    return "\n----\n".join(blocks)


def _clean_graph_dtypes(G: nx.DiGraph):
    """Convert NumPy scalar attribute values of ``G`` to native Python types, in place.

    Every edge and node attribute dictionary is scanned; NumPy floating,
    integer and boolean scalars are replaced by the equivalent ``float``,
    ``int`` or ``bool`` so the graph can be serialized. Other values
    (strings, lists, NumPy arrays, ...) are left untouched.

    Args:
        G: Graph whose attribute dictionaries are modified in place.
    """
    def _to_native(value):
        # np.bool_ is not a subclass of np.integer, so test it explicitly.
        if isinstance(value, np.bool_):
            return bool(value)
        if isinstance(value, np.integer):
            return int(value)
        if isinstance(value, np.floating):
            return float(value)
        return value

    attribute_dicts = [data for _, _, data in G.edges(data=True)]
    attribute_dicts.extend(data for _, data in G.nodes(data=True))

    for data in attribute_dicts:
        for key, value in list(data.items()):
            native = _to_native(value)
            if native is not value:
                data[key] = native


def recursive_traversal_with_scores_fusion_fgw_base(
    start_node: str,
    question: str,
    embeddings: EmbeddingStore,
    G: nx.DiGraph,
    model: SentenceTransformer,
    alpha: float = 0.5,
    max_depth: int = 3,
    sim_threshold: float = 0.8,
    fgw_threshold: float = 0.2,
    enable_logging: bool = False
) -> List[TraversalStep]:
    """Basic FGW traversal with simple fusion.

    Starting at ``start_node``, each step scores the successors of the
    current node by dot product with the encoded ``question``. When more than
    one child scores above ``sim_threshold * max_similarity``, those children
    are merged into a single "fused" node on a copy of ``G``. If the FGW
    distance between the current graph and that copy is below
    ``fgw_threshold``, the same merge is applied to ``G`` and the traversal
    moves to the fused node. Otherwise it moves to the most similar child.

    Args:
        start_node: Node id where the traversal starts.
        question: Query text, encoded with ``model``.
        embeddings: Store whose ``vectors`` map node ids to embeddings; fused
            nodes are added to it.
        G: Graph to traverse; modified in place when a fusion is accepted.
        model: Encoder used for the query, fused texts and graph features.
        alpha: Passed to ``compute_fgw``.
        max_depth: Maximum number of steps.
        sim_threshold: Relative similarity cut-off for fusion candidates.
        fgw_threshold: A fusion is accepted when its FGW distance is strictly below this.
        enable_logging: Print FGW distances via ``_log_if_enabled``.

    Returns:
        The visited path; the first step is ``start_node`` with ``similarity=None``.
    """

    def _merge_children(graph: nx.DiGraph, parent: str, node_id: str,
                        members: List[str], text: str) -> None:
        # Replace the children ``members`` of ``parent`` with one fused node.
        member_set = set(members)
        graph.add_node(node_id, text=text, type="fused")
        for member in members:
            # Edges between merged siblings are dropped: re-adding an edge to a
            # sibling that was already removed would make networkx recreate it
            # as a node without any attributes.
            for gc in list(graph.successors(member)):
                if gc not in member_set:
                    graph.add_edge(node_id, gc)
            graph.remove_node(member)  # also removes the parent -> member edge
        graph.add_edge(parent, node_id)

    path: List[TraversalStep] = [TraversalStep(node_id=start_node, similarity=None)]
    current = start_node
    query_vec = model.encode([question], normalize_embeddings=True)[0]
    embedding_current_graph = compute_embeddings([G], model=model)
    dist_current_graph = compute_structure_distances([G])

    for depth in range(max_depth):
        children = list(G.successors(current))
        if not children:
            break

        child_vecs = [embeddings.vectors[c] for c in children]
        sims = np.dot(child_vecs, query_vec)
        max_sim = np.max(sims)
        similar_nodes = [i for i, sim in enumerate(sims) if sim > sim_threshold * max_sim]

        if len(similar_nodes) > 1:
            members = [children[i] for i in similar_nodes]
            fused_text = " | ".join(G.nodes[m]["text"] for m in members)
            new_node_id = f"fused_{current}_{depth}_{uuid.uuid4().hex[:8]}"

            # Trial fusion on a copy to measure how much the graph would change.
            G_fus = G.copy()
            _merge_children(G_fus, current, new_node_id, members, fused_text)

            F2 = compute_embeddings([G_fus], model=model)
            D2 = compute_structure_distances([G_fus])
            fgw_dist, _ = compute_fgw(embedding_current_graph[0], dist_current_graph[0], F2[0], D2[0], alpha=alpha)

            _log_if_enabled(f"FGW distance: {fgw_dist}", enable_logging)

            if fgw_dist < fgw_threshold:
                # Accept: apply the identical merge to the real graph.
                _merge_children(G, current, new_node_id, members, fused_text)
                embeddings.vectors[new_node_id] = model.encode([fused_text], normalize_embeddings=True)[0]

                current = new_node_id
                path.append(TraversalStep(node_id=current, similarity=float(np.mean([sims[i] for i in similar_nodes]))))

                # G now matches G_fus exactly (same starting copy, same operations),
                # so its features/distances are already known.
                embedding_current_graph = F2
                dist_current_graph = D2
                continue

        # No fusion: follow the most similar child.
        best_idx = int(np.argmax(sims))
        current = children[best_idx]
        path.append(TraversalStep(node_id=current, similarity=float(sims[best_idx])))

    _clean_graph_dtypes(G)
    return path


def recursive_traversal_with_scores_fusion_fgw_enhanced(
    start_node: str,
    question: str,
    embeddings: EmbeddingStore,
    G: nx.DiGraph,
    model: SentenceTransformer,
    max_depth: int = 3,
    sim_threshold: float = 0.8,
    fgw_threshold: float = 0.1,
    alpha: float = 0.5,
    accelerated: bool = True,
    enable_logging: bool = False
) -> List[TraversalStep]:
    """Enhanced FGW traversal with greedy, incremental fusion.

    At each depth the children of the current node are sorted by similarity
    to ``question``. Those above ``sim_threshold * max_similarity`` are added
    one at a time to a candidate set, which is fused into a single trial node
    on a copy of ``G``. A candidate is kept while the distance between the
    trial graph and the initial graph stays within ``distance_restante``. That
    budget starts at ``fgw_threshold``, becomes ``fgw_threshold - <last
    accepted distance>`` after each acceptance, and is not reset between
    depths. If more than one child is kept they are merged in ``G`` and the
    traversal continues from the fused node. Otherwise it moves to the most
    similar child.

    Args:
        start_node: Node id where the traversal starts.
        question: Query text, encoded with ``model``.
        embeddings: Store whose ``vectors`` map node ids to embeddings; fused
            nodes are added to it.
        G: Graph to traverse; modified in place when a fusion is accepted.
        model: Encoder used for the query, fused texts and graph features.
        max_depth: Maximum number of steps.
        sim_threshold: Relative similarity cut-off for fusion candidates.
        fgw_threshold: Distance budget for fusions.
        alpha: Passed to ``compute_fgw`` when not accelerated.
        accelerated: Compare structures only with ``compute_gw``; node
            embeddings are then not computed at all.
        enable_logging: Print progress via ``_log_if_enabled``.

    Returns:
        The visited path; the first step is ``start_node`` with ``similarity=None``.
    """
    path: List[TraversalStep] = [TraversalStep(node_id=start_node, similarity=None)]
    current = start_node
    query_vec = model.encode([question], normalize_embeddings=True)[0]
    distance_restante = fgw_threshold
    # Feature embeddings are only needed for the non-accelerated FGW comparison.
    embeddings_current_graph = None if accelerated else compute_embeddings([G], model=model)
    dist_current_graph = compute_structure_distances([G])

    for depth in range(max_depth):
        children = list(G.successors(current))
        if not children:
            break

        child_vecs = [embeddings.vectors[c] for c in children]
        sims = np.dot(child_vecs, query_vec)
        cutoff = sim_threshold * np.max(sims)

        sorted_children = sorted(zip(children, sims), key=lambda x: x[1], reverse=True)
        sorted_children_relevant = [c for c in sorted_children if c[1] > cutoff]

        _log_if_enabled(f"Similarity for {current}: {[f'{c[0]}: {c[1]:.3f}' for c in sorted_children_relevant]}", enable_logging)

        # Greedy fusion attempt on a scratch copy. A single trial node is reused
        # (its text updated at each step), so G_fus always holds exactly the
        # fusion of the current candidate set, with no stale intermediate nodes.
        candidate_set: List[str] = []
        G_fus = G.copy()
        tmp_node_id = f"fused_{current}_{depth}_{uuid.uuid4().hex[:8]}"

        for child_id, _sim in sorted_children_relevant:
            tmp_set = candidate_set + [child_id]
            tmp_fused_text = " | ".join(G.nodes[n]["text"] for n in tmp_set)

            if tmp_node_id in G_fus:
                G_fus.nodes[tmp_node_id]["text"] = tmp_fused_text
            else:
                G_fus.add_node(tmp_node_id, text=tmp_fused_text, type="fused")

            for gc in G.successors(child_id):
                # Skip merged siblings: an edge to an already-removed candidate
                # would recreate it in G_fus as an attribute-less node.
                if gc not in tmp_set:
                    G_fus.add_edge(tmp_node_id, gc)

            if child_id in G_fus:
                G_fus.remove_node(child_id)  # also drops current -> child_id
            G_fus.add_edge(current, tmp_node_id)

            dists = compute_structure_distances([G_fus])
            if accelerated:
                fgw_dist, _ = compute_gw(dist_current_graph[0], dists[0])
            else:
                feats = compute_embeddings([G_fus], model=model)
                fgw_dist, _ = compute_fgw(embeddings_current_graph[0], dist_current_graph[0], feats[0], dists[0], alpha=alpha)

            _log_if_enabled(f"FGW dist: {fgw_dist}", enable_logging)

            if fgw_dist <= distance_restante:
                candidate_set = tmp_set
                _log_if_enabled(f"Fusion accepted for {tmp_set} with FGW distance: {fgw_dist}", enable_logging)
                distance_restante = fgw_threshold - fgw_dist
            else:
                break

        if len(candidate_set) > 1:
            fused_text = " | ".join([G.nodes[n]["text"] for n in candidate_set])
            new_node_id = f"fused_{current}_{depth}_{uuid.uuid4().hex[:8]}"
            G.add_node(new_node_id, text=fused_text, type="fused")
            embeddings.vectors[new_node_id] = model.encode([fused_text], normalize_embeddings=True)[0]

            # Iterating the live G is safe here: a member removed earlier is no
            # longer a successor of the remaining members.
            for c in candidate_set:
                for gc in G.successors(c):
                    G.add_edge(new_node_id, gc)
                G.remove_edge(current, c)
                G.remove_node(c)

            G.add_edge(current, new_node_id)
            current = new_node_id
            # candidate_set is a prefix of sorted_children_relevant.
            path.append(TraversalStep(node_id=current, similarity=float(np.mean([sim for _, sim in sorted_children_relevant[:len(candidate_set)]]))))
        else:
            best_idx = int(np.argmax(sims))
            current = children[best_idx]
            path.append(TraversalStep(node_id=current, similarity=float(sims[best_idx])))

    _clean_graph_dtypes(G)
    return path




def reward_distribution(
    edges_similarity: Dict[Tuple[str, str], float],
    reward: float = 1.0,
    lr: float = 0.8,
    tau_pos: float = 1.0,
    tau_neg: float = 1.0,
    near_margin: float = 0.2,
    neg_margin: float = 0.2,
    neg_frac: float = 0.15,
    enable_logging: bool = False,
) -> Optional[Dict[Tuple[str, str], float]]:
    """Split a reward budget into per-edge probability deltas.

    Similarities are standardized into z-scores. Two pools are then applied:

    * a positive pool of total ``reward``, spread over all edges with softmax
      weights of ``max(z + near_margin, 0) / tau_pos``;
    * a negative pool of total ``neg_frac * reward``, spread as penalties over
      edges with ``z < neg_margin`` using softmax weights of ``|z| / tau_neg``.

    Every delta is finally multiplied by ``lr``. If all similarities are
    (numerically) identical, each edge receives ``lr * reward / n``.

    Args:
        edges_similarity: ``{(u, v): similarity}`` for the edges to update.
        reward: Size of the positive pool.
        lr: Scale applied to every delta.
        tau_pos: Softmax temperature of the positive pool.
        tau_neg: Softmax temperature of the negative pool.
        near_margin: Shift added to z before clipping at 0 in the positive pool.
        neg_margin: Edges with ``z < neg_margin`` share the penalty.
        neg_frac: Size of the negative pool as a fraction of ``reward``.
        enable_logging: Print the resulting deltas.

    Returns:
        ``{(u, v): delta}`` with Python floats, or ``None`` when
        ``edges_similarity`` is empty.
    """
    if not edges_similarity:
        return None

    sims = np.array(list(edges_similarity.values()), dtype=float)
    edges = list(edges_similarity.keys())

    mean = sims.mean()
    std = sims.std()
    eps = 1e-12
    if (mean == 0 and sims.sum() == 0) or np.allclose(std, 0.0):
        # All similarities are (almost) identical -> split the reward uniformly.
        base_inc = (reward / len(sims)) * lr
        return {edge: base_inc for edge in edges}

    # z-scores
    z = (sims - mean) / (std + eps)

    # --- Positive pool (softmax over margin-shifted z, clipped at 0) ---
    # Edges with z <= -near_margin score 0 and still get the base weight exp(0).
    pos_scores = np.clip(z + near_margin, a_min=0.0, a_max=None)
    if pos_scores.sum() > 0:
        pos_weights = np.exp(pos_scores / max(tau_pos, eps))
        pos_weights /= pos_weights.sum()
    else:
        pos_weights = np.ones_like(pos_scores) / len(pos_scores)

    pos_alloc = reward * pos_weights  # sums to reward

    # --- Negative pool (penalties for edges with z below neg_margin) ---
    # NOTE(review): with the default positive neg_margin this also penalizes
    # edges slightly above the mean; if only clearly below-average edges were
    # intended the test would be ``z < -neg_margin`` — confirm intent.
    neg_mask = z < neg_margin
    neg_total = neg_frac * reward
    neg_delta = np.zeros_like(sims)

    if neg_mask.any() and neg_total > 0:
        neg_scores = np.abs(z[neg_mask])
        neg_weights = np.exp(neg_scores / max(tau_neg, eps))
        neg_weights /= neg_weights.sum()
        neg_delta[neg_mask] = -neg_total * neg_weights

    # Final deltas, returned as {(node1, node2): delta}.
    deltas = lr * (pos_alloc + neg_delta)
    updated_edges = {edge: float(delta) for edge, delta in zip(edges, deltas)}
    if enable_logging:
        print(f"[INFO] updated_edges: {updated_edges}")
    return updated_edges

def test_fusion(tmp_set, candidate_set, G, G_fus, map_index_node, 
                embedding_current_graph, dists_current_graph, 
                model, current, depth, uuid_fus,
                initial_prob, alpha, structure_distance, compact_fusion=False):
    """Score a trial fusion of ``tmp_set`` into one node, without mutating the inputs.

    Deep copies of the embeddings, distances, graph and index map are updated
    as follows:

    * the rows of ``tmp_set`` are deleted from the embedding matrix and from
      both axes of the structure-distance matrix;
    * a row with the encoded fused text is appended;
    * the distance matrix grows by one row and column, initialised to
      ``np.exp(2) * 1.5`` with 0 on the diagonal, and is then patched with the
      ``structure_distance`` of the fused node's edges.

    The FGW distance between the caller's matrices and the updated ones is
    then computed.

    Args:
        tmp_set: Child ids of ``current`` to merge.
        candidate_set: Not used by this function.
        G: Graph providing node texts and edge attributes of the merged nodes.
        G_fus: Graph that is deep-copied and turned into the trial graph.
        map_index_node: Node id -> row index in the embedding/distance matrices.
        embedding_current_graph: Sequence whose item ``[0]`` is the embedding
            matrix of the current graph (only index 0 is used).
        dists_current_graph: Sequence whose item ``[0]`` is the structure
            distance matrix of the current graph (only index 0 is used).
        model: Encoder whose ``encode`` is called on the fused text.
        current: Parent of the merged nodes.
        depth, uuid_fus: Used, with ``current``, to build the fused node id
            ``fused_{current}_{depth}_{uuid_fus}``.
        initial_prob: Default ``prob`` for member edges that lack one.
        alpha: Passed to ``compute_fgw``.
        structure_distance: ``"similarity_weighted"`` keeps the incrementally
            patched matrix; any other value recomputes distances with
            ``compute_structure_distances`` on the trial graph.
        compact_fusion: Forwarded to ``_create_fused_text_compact``.

    Returns:
        ``(fgw_dist, new_embeddings, new_distances, trial_graph, new_index_map)``.
        When ``len(tmp_set) <= 1`` nothing is fused, and ``0.0`` is returned
        together with the unchanged inputs (not copies).
    """
    # A single node is not a fusion.
    if len(tmp_set) <= 1:
        return 0.0, embedding_current_graph, dists_current_graph, G_fus, map_index_node
    # Deep copies so that rejected trials leave the caller's data untouched.
    new_embeddings_fus = copy.deepcopy(embedding_current_graph)
    new_distance_struct = copy.deepcopy(dists_current_graph)
    G_fus_copy = copy.deepcopy(G_fus)
    map_index_node_copy = copy.deepcopy(map_index_node)
    
    # Create fused text
    fused_text = _create_fused_text_compact(tmp_set, G, compact_fusion)
    tmp_node_id = f"fused_{current}_{depth}_{uuid_fus}"

    # Remove old nodes from embeddings & distances efficiently.
    # Indices are deleted from the highest down, so each deletion leaves the
    # remaining (smaller) indices in the list valid.
    indices_to_remove = sorted([map_index_node_copy[n] for n in tmp_set], reverse=True)
    for index_to_del in indices_to_remove:
        new_embeddings_fus[0] = np.delete(new_embeddings_fus[0], index_to_del, axis=0)
        new_distance_struct[0] = np.delete(new_distance_struct[0], index_to_del, axis=0)
        new_distance_struct[0] = np.delete(new_distance_struct[0], index_to_del, axis=1)
        
        # Update indices: every node stored after the deleted row moves up by one.
        # NOTE(review): the removed nodes' own entries are kept in the map and
        # can be shifted onto other rows by later deletions; they must not be
        # looked up afterwards.
        for n2, idx in map_index_node_copy.items():
            if idx > index_to_del:
                map_index_node_copy[n2] -= 1

    # Add new fused node
    if tmp_node_id in G_fus_copy.nodes:
        # NOTE(review): this branch appends no embedding row and no distance
        # row/column, yet the code below still writes to row -1 — confirm it is
        # unreachable (it requires G_fus to already contain the trial node).
        G_fus_copy.nodes[tmp_node_id]["text"] = fused_text
    else:
        G_fus_copy.add_node(tmp_node_id, text=fused_text, type="fused")
        # The fused node's embedding becomes the last row.
        new_embedding = model.encode([fused_text], normalize_embeddings=True)[0].reshape(1, -1)
        new_embeddings_fus[0] = np.append(new_embeddings_fus[0], new_embedding, axis=0)
        
        # Add distance matrix rows/columns for the fused node, filled with
        # np.exp(2) * 1.5 (presumably a "far apart" default for pairs with no
        # known distance — TODO confirm) and 0 on the diagonal.
        num_nodes = new_distance_struct[0].shape[0]
        new_distance_struct[0] = np.append(
            new_distance_struct[0],
            np.full((num_nodes, 1), np.exp(2) * 1.5),
            axis=1
        )
        new_distance_struct[0] = np.append(
            new_distance_struct[0],
            np.full((1, new_distance_struct[0].shape[1]), np.exp(2) * 1.5),
            axis=0
        )
        new_distance_struct[0][-1, -1] = 0.0

    # Add edges for fused node: it inherits each member's outgoing edges with
    # their prob / structure_distance, and its distance row is patched for
    # successors that have an entry in the index map.
    for n in tmp_set:
        for gc in G.successors(n):
            G_fus_copy.add_edge(
                tmp_node_id, gc,
                prob=G.edges[n, gc].get("prob", initial_prob),
                structure_distance=G.edges[n, gc].get("structure_distance", 1.0)
            )
            if gc in map_index_node_copy:
                new_distance_struct[0][-1, map_index_node_copy[gc]] = G.edges[n, gc].get("structure_distance", 1.0)

    # Add edge from current to fused node; its attributes are the means over
    # the member edges current -> member.
    new_prob = np.mean([G.edges[current, gc].get("prob", initial_prob) for gc in tmp_set])
    new_structure_distance = np.mean([G.edges[current, gc].get("structure_distance", 1.0) for gc in tmp_set])
    G_fus_copy.add_edge(current, tmp_node_id, prob=new_prob, structure_distance=new_structure_distance)
    
    # Only the current -> fused entry is set; the fused -> current entry keeps its default.
    if current in map_index_node_copy:
        new_distance_struct[0][map_index_node_copy[current], -1] = new_structure_distance

    # Remove old nodes (and their incident edges) from the trial graph.
    for child_id in tmp_set:
        if G_fus_copy.has_edge(current, child_id):
            G_fus_copy.remove_edge(current, child_id)
        if G_fus_copy.has_node(child_id):
            G_fus_copy.remove_node(child_id)

    # Calculate distances
    if structure_distance != "similarity_weighted":
        dists = compute_structure_distances([G_fus_copy])
    else:
        dists = new_distance_struct

    # Calculate FGW distance against the caller's (un-copied) current-graph matrices.
    fgw_dist, _ = compute_fgw(
        embedding_current_graph[0], dists_current_graph[0],
        new_embeddings_fus[0], dists[0], alpha=alpha
    )

    return fgw_dist, new_embeddings_fus, dists, G_fus_copy, map_index_node_copy


def dichotomie_fusion(sorted_children_proba, candidate_set, candidate_vecs, 
                      G, G_fus, map_index_node,
                      embedding_current_graph, dists_current_graph, 
                      model, current, depth, uuid_fus,
                      initial_prob, alpha, structure_distance, 
                      distance_restante, compact_fusion=False, enable_logging=False):
    """Binary search for the largest fusable prefix of ``sorted_children_proba``.

    For a prefix length ``k``, the set ``candidate_set + <first k child ids>``
    is scored with ``test_fusion``. A set is accepted when its FGW distance is
    at most ``distance_restante``. The search assumes acceptance is monotone
    in ``k``, i.e. every shorter prefix of an accepted prefix is also accepted.

    Args:
        sorted_children_proba: ``(child_id, similarity)`` pairs, best first.
        candidate_set: Node ids already selected; prefixes are appended to it.
        candidate_vecs: Similarities already selected; prefixes are appended to it.
        distance_restante: Maximum accepted FGW distance.
        enable_logging: Print search progress via ``_log_if_enabled``.
        All other arguments are forwarded unchanged to ``test_fusion``.

    Returns:
        ``(candidate_set, candidate_vecs, embeddings, distances, G_fus,
        map_index_node)`` for the best accepted set, all taken from the same
        ``test_fusion`` call. If no prefix is accepted (including when
        ``sorted_children_proba`` is empty) the inputs are returned unchanged.
    """
    left, right = 1, len(sorted_children_proba)
    best_candidate_set = candidate_set
    best_candidate_vecs = candidate_vecs[:]
    best_embedding = embedding_current_graph
    best_distance = dists_current_graph
    best_G_fus = G_fus
    # Keep the index map of the *accepted* fusion: the map of the last probe
    # may belong to a rejected, larger set and would not match best_embedding.
    best_map_index_node = map_index_node
    best_fgw_dist = None

    while left <= right:
        _log_if_enabled(f"Binary search iteration: left={left}, right={right}", enable_logging)
        mid = (left + right) // 2
        prefix = sorted_children_proba[:mid]
        tmp_set = candidate_set + [cid for cid, _ in prefix]
        tmp_vecs = candidate_vecs + [s for _, s in prefix]
        _log_if_enabled(f"Testing fusion with set: {tmp_set}", enable_logging)

        fgw_dist, new_emb, new_dist, new_G_fus, new_map_index_node = test_fusion(
            tmp_set, candidate_set, G, G_fus, map_index_node,
            embedding_current_graph, dists_current_graph,
            model, current, depth, uuid_fus,
            initial_prob, alpha, structure_distance, compact_fusion
        )
        _log_if_enabled(f"FGW distance for set {tmp_set}: {fgw_dist}", enable_logging)

        if fgw_dist <= distance_restante:
            best_candidate_set = tmp_set
            best_candidate_vecs = tmp_vecs
            best_embedding = new_emb
            best_distance = new_dist
            best_G_fus = new_G_fus
            best_map_index_node = new_map_index_node
            best_fgw_dist = fgw_dist
            left = mid + 1
        else:
            right = mid - 1

    # The log message is built eagerly, so it must not assume a probe has run.
    best_dist_text = "n/a" if best_fgw_dist is None else f"{best_fgw_dist:.3f}"
    _log_if_enabled(f"Optimal set containing: {best_candidate_set} with FGW distance: {best_dist_text}", enable_logging)
    _log_if_enabled(f"Total candidate set: {sorted_children_proba}", enable_logging)

    return best_candidate_set, best_candidate_vecs, best_embedding, best_distance, best_G_fus, best_map_index_node


def recursive_traversal_with_scores_fusion_fgw_genetic(
    start_node: str,
    question: str,
    embeddings: EmbeddingStore,
    G: nx.DiGraph,
    model: SentenceTransformer,
    max_depth: int = 3,
    sim_threshold: float = 0.8,
    fgw_threshold: float = 0.1,
    alpha: float = 0.5,
    initial_prob: float = 0.8,
    initial_prob_new: float = 0.1,
    lr: float = 0.8,
    structure_distance: str = "base",
    enable_logging: bool = False
) -> List[TraversalStep]:
    """Genetic algorithm-based FGW traversal with probability-weighted candidates.

    Like the enhanced traversal, but fusion candidates are the children whose
    similarity multiplied by the edge ``prob`` attribute (default 1.0) exceeds
    ``sim_threshold * max_similarity``. Fused nodes and edges carry ``prob``
    and ``structure_distance`` attributes (means over the merged edges).

    Args:
        start_node: Node id where the traversal starts.
        question: Query text, encoded with ``model``.
        embeddings: Store whose ``vectors`` map node ids to embeddings; fused
            nodes are added to it.
        G: Graph to traverse; modified in place when a fusion is accepted.
        model: Encoder used for the query, fused texts and graph features.
        max_depth: Maximum number of steps.
        sim_threshold: Relative similarity cut-off for candidates.
        fgw_threshold: Distance budget for fusions; the remaining budget
            becomes ``fgw_threshold - <last accepted distance>``.
        alpha: Passed to ``compute_fgw``.
        initial_prob: Not used by this variant (kept for signature parity).
        initial_prob_new: Default ``prob`` for edges that lack one.
        lr: Not used by this variant (kept for signature parity).
        structure_distance: ``"similarity_weighted"`` or ``"base"``.
        enable_logging: Print progress and timings via ``_log_if_enabled``.

    Returns:
        The visited path; the first step is ``start_node`` with ``similarity=None``.
    """
    path: List[TraversalStep] = [TraversalStep(node_id=start_node, similarity=None)]
    time_start = time.time()
    current = start_node
    query_vec = model.encode([question], normalize_embeddings=True)[0]
    distance_restante = fgw_threshold
    # Use the same encoder as the per-candidate features below; otherwise the
    # FGW feature term would compare vectors produced by different models.
    embedding_current_graph = compute_embeddings([G], model=model)

    if structure_distance == "similarity_weighted":
        dists_current_graph = compute_similarity_weighted_structure_distances([G], embedding_current_graph)
    else:
        dists_current_graph = compute_structure_distances([G])

    temps_wasted = 0

    for depth in range(max_depth):
        children = list(G.successors(current))
        if not children:
            break

        child_vecs = [embeddings.vectors[c] for c in children]
        sims = np.dot(child_vecs, query_vec)
        cutoff = sim_threshold * np.max(sims)

        sorted_children = sorted(zip(children, sims), key=lambda x: x[1], reverse=True)
        sorted_children_relevant = [c for c in sorted_children if c[1] > cutoff]
        sorted_children_proba = [
            c for c in sorted_children
            if c[1] * G.edges[current, c[0]].get("prob", 1.0) > cutoff
        ]

        _log_if_enabled(f"Relevant children: {[f'{c[0]}: {c[1]:.3f}' for c in sorted_children_relevant]}", enable_logging)
        _log_if_enabled(f"Probability children: {[f'{c[0]}: {c[1]:.3f}' for c in sorted_children_proba]}", enable_logging)

        candidate_set: List[str] = []
        candidate_vecs = []
        G_fus = G.copy()
        # One trial node per depth, re-labelled as the candidate set grows, so
        # G_fus always represents exactly the fusion being scored.
        tmp_node_id = f"fused_{current}_{depth}_{uuid.uuid4().hex[:8]}"
        _log_if_enabled(f"Preprocessing time: {time.time() - time_start:.3f} seconds", enable_logging)

        for child_id, sim_score in sorted_children_proba:
            tmp_set = candidate_set + [child_id]
            tmp_fused_text = " | ".join(G.nodes[n]["text"] for n in tmp_set)

            if tmp_node_id in G_fus:
                G_fus.nodes[tmp_node_id]["text"] = tmp_fused_text
            else:
                G_fus.add_node(tmp_node_id, text=tmp_fused_text, type="fused")

            for gc in G.successors(child_id):
                # Skip merged siblings: an edge to an already-removed candidate
                # would recreate it in G_fus as an attribute-less node.
                if gc in tmp_set:
                    continue
                G_fus.add_edge(tmp_node_id, gc,
                               prob=G.edges[child_id, gc].get("prob", initial_prob_new),
                               structure_distance=G.edges[child_id, gc].get("structure_distance", 1.0))

            new_prob = np.mean([G.edges[current, c].get("prob", initial_prob_new) for c in tmp_set])
            new_structure_distance = np.mean([G.edges[current, c].get("structure_distance", 1.0) for c in tmp_set])
            G_fus.add_edge(current, tmp_node_id, prob=new_prob, structure_distance=new_structure_distance)
            G_fus.remove_node(child_id)  # also drops current -> child_id

            time_start_calc = time.time()
            Gs = [G_fus]
            feats = compute_embeddings(Gs, model=model)
            temps_wasted += time.time() - time_start_calc

            if structure_distance == "similarity_weighted":
                time_start_calc = time.time()
                dists = compute_similarity_weighted_structure_distances(Gs, feats)
                temps_wasted += time.time() - time_start_calc
            else:
                dists = compute_structure_distances(Gs)

            fgw_dist, _ = compute_fgw(embedding_current_graph[0], dists_current_graph[0], feats[0], dists[0], alpha=alpha)

            if fgw_dist <= distance_restante:
                candidate_set = tmp_set
                candidate_vecs.append(sim_score)
                _log_if_enabled(f"Fusion accepted for {tmp_set} with FGW distance: {fgw_dist}", enable_logging)
                distance_restante = fgw_threshold - fgw_dist
            else:
                break

        if len(candidate_set) > 1:
            fused_text = " | ".join([G.nodes[n]["text"] for n in candidate_set])
            new_node_id = f"fused_{current}_{depth}_{uuid.uuid4().hex[:8]}"
            G.add_node(new_node_id, text=fused_text, type="fused")
            embeddings.vectors[new_node_id] = model.encode([fused_text], normalize_embeddings=True)[0]

            new_prob = np.mean([G.edges[current, c].get("prob", initial_prob_new) for c in candidate_set])
            new_structure_distance = np.mean([G.edges[current, c].get("structure_distance", 1.0) for c in candidate_set])
            G.add_edge(current, new_node_id, prob=new_prob, structure_distance=new_structure_distance)

            # Iterating the live G is safe: members removed earlier are no longer
            # successors of the remaining members.
            for c in candidate_set:
                for gc in G.successors(c):
                    G.add_edge(new_node_id, gc,
                               prob=G.edges[c, gc].get("prob", initial_prob_new),
                               structure_distance=G.edges[c, gc].get("structure_distance", 1.0))
                G.remove_edge(current, c)
                G.remove_node(c)

            current = new_node_id
            path.append(TraversalStep(node_id=current, similarity=float(np.mean(candidate_vecs))))

        else:
            best_idx = int(np.argmax(sims))
            current = children[best_idx]
            path.append(TraversalStep(node_id=current, similarity=float(sims[best_idx])))

    _clean_graph_dtypes(G)
    _log_if_enabled(f"Total wasted time in calculations: {temps_wasted:.3f} seconds", enable_logging)
    return path


def recursive_traversal_with_score_fusion_fgw_genetic_optimized(
    start_node: str,
    question: str,
    embeddings: EmbeddingStore,
    G: nx.DiGraph,
    model: SentenceTransformer,
    max_depth: int = 3,
    sim_threshold: float = 0.8,
    fgw_threshold: float = 0.1,
    alpha: float = 0.5,
    initial_prob: float = 0.8,
    initial_prob_new: float = 0.1,
    lr: float = 0.8,
    structure_distance: str = "base",
    compact_fusion: bool = True,
    enable_logging: bool = False
) -> List[TraversalStep]:
    """Optimized genetic algorithm-based FGW traversal.

    At each depth:

    1. Children of the current node are ranked by query similarity.
    2. Those whose similarity times the edge ``prob`` (default 1.0) clears
       ``sim_threshold * max_similarity`` are passed to ``dichotomie_fusion``,
       which finds the largest prefix whose fusion stays within
       ``fgw_threshold``.
    3. If that prefix has more than one node, it is merged into a new fused
       node in ``G`` and the traversal moves there. Otherwise it moves to the
       most similar child.
    4. In both cases, grandchildren reachable through the remaining candidate
       children are linked from the new current node, and the ``prob`` of
       those links is updated with ``reward_distribution``.

    Args:
        start_node: Node id where the traversal starts.
        question: Query text, encoded with ``model``.
        embeddings: Store whose ``vectors`` map node ids to embeddings; fused
            nodes are added to it.
        G: Graph to traverse; nodes, edges and ``prob`` attributes are modified in place.
        model: Encoder used for the query, fused texts and graph features.
        max_depth: Maximum number of steps.
        sim_threshold: Relative similarity cut-off for candidates.
        fgw_threshold: Maximum FGW distance accepted for a fusion.
        alpha: Passed through to the FGW computation.
        initial_prob: Default ``prob`` for existing edges that lack one.
        initial_prob_new: ``prob`` of newly created shortcut edges.
        lr: Learning rate forwarded to ``reward_distribution``.
        structure_distance: ``"similarity_weighted"`` or ``"base"``.
        compact_fusion: Forwarded to ``_create_fused_text_compact``.
        enable_logging: Print progress and debug data via ``_log_if_enabled``.

    Returns:
        The visited path; the first step is ``start_node`` with ``similarity=None``.
    """

    def _reinforce_shortcuts(source: str, remaining_children) -> None:
        # Link ``source`` to the best grandchildren reachable through
        # ``remaining_children`` and update those edges' probabilities.
        _log_if_enabled(f"Remaining children after fusion: {[c[0] for c in remaining_children]}", enable_logging)
        all_grandchildren = [(c[0], gc) for c in remaining_children for gc in G.successors(c[0])]
        _log_if_enabled(f"all_grandchildren: {all_grandchildren}", enable_logging)
        if not all_grandchildren:
            return

        grand_child_vecs = [embeddings.vectors[gc[1]] for gc in all_grandchildren]
        grand_child_sims = np.dot(grand_child_vecs, query_vec)
        gc_cutoff = sim_threshold * np.max(grand_child_sims)
        sorted_gc = sorted(zip(all_grandchildren, grand_child_sims), key=lambda x: x[1], reverse=True)
        sorted_gc_relevant = [c for c in sorted_gc if c[1] > gc_cutoff]
        sorted_gc_proba = [c for c in sorted_gc if c[1] * G.edges[c[0][0], c[0][1]].get("prob", 0.1) > gc_cutoff]

        edge_to_consider = {}
        for target in [gc[0][1] for gc in sorted_gc_relevant + sorted_gc_proba]:
            if target == source:
                # A self-loop would make the node its own child at the next depth.
                continue
            sim = np.dot(embeddings.vectors[target], query_vec)
            sim_transformed = np.exp(sim) / np.exp(1)

            if not G.has_edge(source, target):
                G.add_edge(source, target,
                           prob=initial_prob_new,
                           structure_distance=1.0 / (sim_transformed + 1e-6))

            edge_to_consider[(source, target)] = sim_transformed
        _log_if_enabled(f"edge_to_consider: {edge_to_consider}", enable_logging)

        rewarded_edges = reward_distribution(edge_to_consider, reward=1, lr=lr)
        if rewarded_edges:
            for (u, v), change_prob in rewarded_edges.items():
                current_prob = G.edges[u, v].get("prob", initial_prob_new)
                # Deltas can be negative; keep the value a valid probability.
                G.edges[u, v]["prob"] = min(max(current_prob + change_prob, 0.0), 1.0)
                _log_if_enabled(f"Updated edge {u}->{v} probability by {change_prob:.3f}", enable_logging)

    path: List[TraversalStep] = [TraversalStep(node_id=start_node, similarity=None)]
    time_start = time.time()
    current = start_node
    query_vec = model.encode([question], normalize_embeddings=True)[0]
    distance_restante = fgw_threshold
    # Encode with ``model`` so these features match the fused-node embeddings
    # that test_fusion produces with ``model.encode``.
    embedding_current_graph = compute_embeddings([G], model=model)
    # Row index of every node in the embedding / distance matrices.
    map_index_node = {n: i for i, n in enumerate(G.nodes)}

    if structure_distance == "similarity_weighted":
        dists_current_graph = compute_similarity_weighted_structure_distances([G], embedding_current_graph)
        # Store the pairwise distance on each edge; fusions read it back from there.
        for u, v in G.edges():
            G.edges[u, v]["structure_distance"] = dists_current_graph[0][map_index_node[u], map_index_node[v]]
    else:
        dists_current_graph = compute_structure_distances([G])

    temps_wasted = 0

    for depth in range(max_depth):
        _log_if_enabled(f"Starting depth {depth} with node {current}", enable_logging)

        children = list(G.successors(current))
        if not children:
            break

        child_vecs = [embeddings.vectors[c] for c in children]
        sims = np.dot(child_vecs, query_vec)
        cutoff = sim_threshold * np.max(sims)

        sorted_children = sorted(zip(children, sims), key=lambda x: x[1], reverse=True)
        sorted_children_relevant = [c for c in sorted_children if c[1] > cutoff]
        # Children whose similarity, discounted by the edge probability, still clears the cutoff.
        sorted_children_proba = [
            c for c in sorted_children
            if c[1] * G.edges[current, c[0]].get("prob", 1.0) > cutoff
        ]
        prob = [G.edges[current, c[0]].get("prob", 1.0) for c in sorted_children_proba]
        _log_if_enabled(f"Relevant children: {[f'{c[0]}: {c[1]:.3f}' for c in sorted_children_relevant]}", enable_logging)
        _log_if_enabled(f"Probability children: {[f'{c[0]}: {c[1]:.3f}' for c in sorted_children_proba]}", enable_logging)
        _log_if_enabled(f"Probabilities: {[f'{p:.3f}' for p in prob]}", enable_logging)
        _log_if_enabled(f"Preprocessing time: {time.time() - time_start:.3f} seconds", enable_logging)

        uuid_fus = uuid.uuid4().hex[:8]
        candidate_set, candidate_vecs, fused_embeddings, fused_distances, _, map_index_node = dichotomie_fusion(
            sorted_children_proba, [], [],
            G, G.copy(), map_index_node,
            embedding_current_graph, dists_current_graph,
            model, current, depth, uuid_fus,
            initial_prob, alpha, structure_distance,
            distance_restante, compact_fusion, enable_logging
        )

        if len(candidate_set) > 1:
            time_start_fusion = time.time()
            fused_text = _create_fused_text_compact(candidate_set, G, compact_fusion)
            new_node_id = f"fused_{current}_{depth}_{uuid.uuid4().hex[:8]}"
            G.add_node(new_node_id, text=fused_text, type="fused")
            embeddings.vectors[new_node_id] = model.encode([fused_text], normalize_embeddings=True)[0]

            avg_prob = np.mean([G.edges[current, c].get("prob", initial_prob) for c in candidate_set])
            avg_struct_dist = np.mean([G.edges[current, c].get("structure_distance", 1.0) for c in candidate_set])
            G.add_edge(current, new_node_id, prob=avg_prob, structure_distance=avg_struct_dist)

            for c in candidate_set:
                for gc in G.successors(c):
                    G.add_edge(new_node_id, gc,
                               prob=G.edges[c, gc].get("prob", initial_prob),
                               structure_distance=G.edges[c, gc].get("structure_distance", 1.0))
                G.remove_edge(current, c)
                G.remove_node(c)
                _log_if_enabled(f"Removed edge {current} -> {c} and node {c}", enable_logging)

            # The fused node was added last, matching the row appended by test_fusion.
            map_index_node[new_node_id] = len(G.nodes) - 1
            embedding_current_graph = fused_embeddings.copy()
            _log_if_enabled(f"Fusion time: {time.time() - time_start_fusion:.3f} seconds", enable_logging)

            if structure_distance == "similarity_weighted":
                dists_current_graph = fused_distances.copy()
            else:
                dists_current_graph = compute_structure_distances([G])

            # candidate_set is a prefix of sorted_children_proba.
            _reinforce_shortcuts(new_node_id, sorted_children_proba[len(candidate_set):])

            current = new_node_id
            path.append(TraversalStep(node_id=current, similarity=float(np.mean(candidate_vecs))))

        else:
            best_idx = int(np.argmax(sims))
            current = children[best_idx]

            _reinforce_shortcuts(current, sorted_children_proba[len(candidate_set):])

            path.append(TraversalStep(node_id=current, similarity=float(sims[best_idx])))

    _clean_graph_dtypes(G)
    _log_if_enabled(f"Total wasted time in calculations: {temps_wasted:.3f} seconds", enable_logging)
    return path


# Convenience class for new API (optional)
class FGWTraversal:
    """Object-oriented front end for the module's FGW traversal functions.

    An instance stores a single ``FGWConfig``. Each traversal method unpacks
    the relevant fields of that config as keyword arguments and delegates to
    the matching module-level traversal function, returning its result
    unchanged. No state other than ``config`` is kept between calls.

    Attributes:
        config: The ``FGWConfig`` whose fields are forwarded on every call.
    """

    def __init__(self, config: FGWConfig):
        self.config = config

    def _shared_kwargs(self) -> dict:
        """Return the keyword arguments every traversal variant receives.

        These are ``alpha``, ``max_depth``, ``sim_threshold``,
        ``fgw_threshold`` and ``enable_logging``, read from ``self.config``.
        """
        cfg = self.config
        return {
            "alpha": cfg.alpha,
            "max_depth": cfg.max_depth,
            "sim_threshold": cfg.sim_threshold,
            "fgw_threshold": cfg.fgw_threshold,
            "enable_logging": cfg.enable_logging,
        }

    def base_traversal(self, start_node: str, question: str, embeddings: EmbeddingStore,
                      G: nx.DiGraph, model: SentenceTransformer) -> List[TraversalStep]:
        """Delegate to ``recursive_traversal_with_scores_fusion_fgw_base``.

        Only the shared config fields are forwarded. The genetic-only fields
        (``initial_prob``, ``initial_prob_new``, ``lr``,
        ``structure_distance``, ``compact_fusion``) are not passed here.
        """
        shared = self._shared_kwargs()
        return recursive_traversal_with_scores_fusion_fgw_base(
            start_node, question, embeddings, G, model, **shared
        )

    def enhanced_traversal(self, start_node: str, question: str, embeddings: EmbeddingStore,
                          G: nx.DiGraph, model: SentenceTransformer) -> List[TraversalStep]:
        """Delegate to ``recursive_traversal_with_scores_fusion_fgw_enhanced``.

        Forwards the same shared config fields as ``base_traversal``.
        """
        shared = self._shared_kwargs()
        return recursive_traversal_with_scores_fusion_fgw_enhanced(
            start_node, question, embeddings, G, model, **shared
        )

    def genetic_traversal(self, start_node: str, question: str, embeddings: EmbeddingStore,
                         G: nx.DiGraph, model: SentenceTransformer) -> List[TraversalStep]:
        """Delegate to ``recursive_traversal_with_score_fusion_fgw_genetic_optimized``.

        Forwards the shared config fields plus the genetic-specific ones:
        ``initial_prob``, ``initial_prob_new``, ``lr``,
        ``structure_distance`` and ``compact_fusion``.
        """
        cfg = self.config
        genetic_only = {
            "initial_prob": cfg.initial_prob,
            "initial_prob_new": cfg.initial_prob_new,
            "lr": cfg.lr,
            "structure_distance": cfg.structure_distance,
            "compact_fusion": cfg.compact_fusion,
        }
        return recursive_traversal_with_score_fusion_fgw_genetic_optimized(
            start_node, question, embeddings, G, model,
            **self._shared_kwargs(), **genetic_only
        )
