# -*- coding: utf-8 -*-
"""
Torch-based FGW traversal (genetic optimized) matching the original implementation.
"""

import time
import uuid
import numpy as np
import torch
import networkx as nx
from typing import List, Dict, Optional, Tuple
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer

from kt_gen.knowledge_graph.utils.pydantic_models import EmbeddingStore, TraversalStep
from kt_gen.knowledge_graph.utils.utils_fgw_torch import (
    compute_embeddings,
    compute_structure_distances,
    compute_similarity_weighted_structure_distances,
    compute_fgw,
    dichotomie_fusion,
)

# ===== Config: hyper-parameters forwarded to the traversal function
class FGWConfig(BaseModel):
    """Hyper-parameters for :class:`FGWTraversal`.

    Every field is forwarded under the same keyword name to
    ``recursive_traversal_with_score_fusion_fgw_genetic_optimized``.
    """

    alpha: float = 0.5  # FGW trade-off weight, passed straight to dichotomie_fusion (semantics defined there)
    max_depth: int = 3  # maximum number of traversal steps
    sim_threshold: float = 0.8  # a child is kept only if its score > sim_threshold * best child score
    fgw_threshold: float = 0.1  # FGW distance budget handed to dichotomie_fusion
    initial_prob: float = 0.8  # fallback edge "prob" during fusion; also passed to dichotomie_fusion
    initial_prob_new: float = 0.1  # "prob" given to newly created fused-node -> grandchild edges
    lr: float = 0.8  # learning rate used by reward_distribution
    structure_distance: str = "base"  # "base" or "similarity_weighted"
    compact_fusion: bool = True  # group fused text by section instead of a plain " | " join
    enable_logging: bool = False  # print progress messages (the traversal function itself defaults to True)


def _log_if_enabled(message: str, enable_logging: bool, level: str = "INFO"):
    if enable_logging:
        print(f"[{level}] {message}")


def _create_fused_text_compact(node_ids: List[str], G: nx.DiGraph, compact_fusion: bool = True) -> str:
    if not compact_fusion:
        return " | ".join([G.nodes[n]["text"] for n in node_ids])

    texts_by_section = {}
    for node_id in node_ids:
        node_text = G.nodes[node_id]["text"]
        if node_id.startswith("section_"):
            section_num = node_id.split("_")[1]
        else:
            section_num = "unknown"
        texts_by_section.setdefault(section_num, []).append((node_id, node_text))

    def _safe_int(x):
        try:
            return int(x)
        except Exception:
            return 10**9

    fused_texts = []
    for section_num in sorted(texts_by_section.keys(), key=_safe_int):
        section_items = texts_by_section[section_num]
        section_items.sort(key=lambda item: (
            _safe_int(item[0].split("_")[3] if len(item[0].split("_")) > 3 else 10**9),
            _safe_int(item[0].split("_")[5] if len(item[0].split("_")) > 5 else 10**9),
        ))
        section_text = " ".join([item[1] for item in section_items])
        fused_texts.append(f"From section_{section_num}: {section_text}")
    return "\n----\n".join(fused_texts)


def _clean_graph_dtypes(G: nx.DiGraph):
    # Convert numpy/torch float types to Python float for safe serialization
    for u, v, data in G.edges(data=True):
        for key, value in list(data.items()):
            if isinstance(value, (np.float32, np.float64)):
                data[key] = float(value)
            elif isinstance(value, torch.Tensor) and value.numel() == 1:
                data[key] = float(value.item())
    for n, data in G.nodes(data=True):
        for key, value in list(data.items()):
            if isinstance(value, (np.float32, np.float64)):
                data[key] = float(value)
            elif isinstance(value, torch.Tensor) and value.numel() == 1:
                data[key] = float(value.item())


def reward_distribution(edges_similarity: Dict[Tuple[str, str], float], reward: float, lr: float = 0.8) -> Optional[Dict[Tuple[str, str], float]]:
    """Compute probability increments for edges scoring at or above the mean.

    For every edge whose similarity ``s`` is >= the mean similarity ``m`` of all
    edges, the increment is ``lr * (s - m) * reward``. Edges below the mean get
    no entry.

    Args:
        edges_similarity: Mapping ``(u, v) -> similarity``.
        reward: Multiplicative reward factor.
        lr: Learning rate.

    Returns:
        ``None`` when ``edges_similarity`` is empty or its mean is exactly zero.
        Otherwise a dict ``edge -> increment`` (possibly empty).
    """
    if not edges_similarity:
        return None
    average = float(np.mean(list(edges_similarity.values())))
    if average == 0:
        return None
    return {
        edge: float(lr * (score - average) * reward)
        for edge, score in edges_similarity.items()
        if score >= average
    }


def _node_query_similarities(
    embeddings: EmbeddingStore,
    node_ids: List[str],
    query_vec: torch.Tensor,
    device: torch.device,
) -> np.ndarray:
    """Score each node in ``node_ids`` against ``query_vec`` with a dot product.

    Each vector is read from ``embeddings.vectors`` (numpy or torch), cast to
    float64 on ``device`` and stacked, so all nodes are scored in a single
    matrix-vector product. All stored vectors must share one shape (a
    ``torch.stack`` requirement).

    Returns:
        A numpy array aligned with ``node_ids`` (empty when ``node_ids`` is empty).
    """
    if not node_ids:
        return np.array([])
    node_mat = torch.stack(
        [torch.as_tensor(embeddings.vectors[n], dtype=torch.float64, device=device) for n in node_ids],
        dim=0,
    )
    # Plain dot product: the query is L2-normalised by the caller and the stored
    # vectors are presumably normalised as well ("dot since normalized").
    return (node_mat @ query_vec).cpu().numpy()


def recursive_traversal_with_score_fusion_fgw_genetic_optimized(
    start_node: str,
    question: str,
    embeddings: EmbeddingStore,
    G: nx.DiGraph,
    model: SentenceTransformer,
    max_depth: int = 3,
    sim_threshold: float = 0.8,
    fgw_threshold: float = 0.1,
    alpha: float = 0.5,
    initial_prob: float = 0.8,
    initial_prob_new: float = 0.1,
    lr: float = 0.8,
    structure_distance: str = "base",
    compact_fusion: bool = True,
    enable_logging: bool = True,
    *,
    eps: float = 2e-2,
    sinkhorn_max_iter: int = 1000,
    sinkhorn_tol: float = 1e-9
) -> List[TraversalStep]:
    """
    Torch version strictly following the original 'genetic optimized' algorithm.

    Greedy walk over ``G`` from ``start_node`` for at most ``max_depth`` steps.

    At each step:

    1. The children of the current node are scored against ``question``.
    2. ``dichotomie_fusion`` selects a set of candidate children to fuse.
    3. If more than one candidate is selected:

       - The candidates are replaced in ``G`` by a single ``"fused"`` node. Its
         text comes from ``_create_fused_text_compact``; its embedding is written
         to ``embeddings``.
       - The candidates' outgoing edges are re-attached to the fused node.
       - Edges from the fused node to query-relevant grandchildren (via the
         non-fused children) are added.
       - Those edges' ``"prob"`` values are raised via ``reward_distribution``.
       - The walk moves to the fused node.

    4. Otherwise the walk moves to the best-scoring child.

    NOTE: ``G`` and ``embeddings`` are mutated in place.

    Args:
        start_node: Node id the walk starts from.
        question: Query text, encoded with ``model``.
        embeddings: Store whose ``vectors`` map node id -> embedding vector.
        G: Graph to traverse (mutated).
        model: Encoder used for the query and fused node texts.
        max_depth: Maximum number of steps.
        sim_threshold: A child is considered if its score > ``sim_threshold`` times the best score.
        fgw_threshold: FGW distance budget passed to ``dichotomie_fusion``.
        alpha: FGW trade-off weight passed to ``dichotomie_fusion``.
        initial_prob: Fallback edge ``"prob"`` during fusion.
        initial_prob_new: ``"prob"`` of newly created fused-node -> grandchild edges.
        lr: Learning rate for ``reward_distribution``.
        structure_distance: ``"base"`` or ``"similarity_weighted"``.
        compact_fusion: Section-grouped fused text (see ``_create_fused_text_compact``).
        enable_logging: Print progress messages.
        eps, sinkhorn_max_iter, sinkhorn_tol: Solver settings forwarded to ``dichotomie_fusion``.

    Returns:
        The visited path. It starts with ``start_node`` (similarity ``None``).
    """
    # Ensure all tensors are on the same device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if enable_logging:
        print(f"🚀 Using device: {device}")

    path: List[TraversalStep] = [TraversalStep(node_id=start_node, similarity=None)]
    time_start = time.time()
    current = start_node

    # Single L2-normalised query vector (row 0 of the one-item batch), float64 on `device`.
    query_vec = model.encode(
        [question],
        normalize_embeddings=True,
        convert_to_tensor=True
    )[0].to(dtype=torch.float64, device=device)

    # FGW budget handed to dichotomie_fusion; it is not decreased in this function.
    distance_restante = fgw_threshold

    # Graph-wide embeddings (torch) - ensure GPU usage
    embedding_current_graph = compute_embeddings([G], model=model, device=device)
    map_index_node = {n: i for i, n in enumerate(G.nodes)}

    # Structure distances according to mode
    if structure_distance == "similarity_weighted":
        dists_current_graph = compute_similarity_weighted_structure_distances([G], embedding_current_graph)
        # mirror original: write structure_distance on edges
        for u, v in G.edges():
            G.edges[u, v]["structure_distance"] = float(dists_current_graph[0][map_index_node[u], map_index_node[v]].item())
    else:
        dists_current_graph = compute_structure_distances([G])

    # NOTE(review): never incremented below, so the final log always reports 0.
    temps_wasted = 0.0

    for depth in range(max_depth):
        _log_if_enabled(f"Starting depth {depth} with node {current}", enable_logging)

        children = list(G.successors(current))
        if not children:
            break

        sims_np = _node_query_similarities(embeddings, children, query_vec, device)

        # sorting & relevant filters (keep exact logic/thresholding)
        sorted_children = sorted(zip(children, sims_np), key=lambda x: x[1], reverse=True)
        max_sim = float(np.max(sims_np)) if len(sims_np) > 0 else 0.0
        sorted_children_relevant = [c for c in sorted_children if c[1] > sim_threshold * max_sim]
        sorted_children_proba = [c for c in sorted_children if c[1] * G.edges[current, c[0]].get("prob", 1.0) > sim_threshold * max_sim]

        _log_if_enabled(f"Relevant children: {[f'{c[0]}: {c[1]:.3f}' for c in sorted_children_relevant]}", enable_logging)
        _log_if_enabled(f"Probability children: {[f'{c[0]}: {c[1]:.3f}' for c in sorted_children_proba]}", enable_logging)

        candidate_set: List[str] = []
        candidate_vecs: List[float] = []
        G_fus = G.copy()
        _log_if_enabled(f"Preprocessing time: {time.time() - time_start:.3f} seconds", enable_logging)

        uuid_fus = uuid.uuid4().hex[:8]

        # === Binary search on fusion set size (exactly as original) ===
        # (The former pre-call copies of the embeddings/distances were dead: both
        # names are overwritten by this call's results, so they were dropped.)
        candidate_set, candidate_vecs, new_embedding_fus_etape_precedente, new_distance_struct_etape_precedente, G_fus, map_index_node_copy = dichotomie_fusion(
            sorted_children_proba, candidate_set, candidate_vecs,
            G, G_fus, map_index_node,
            embedding_current_graph, dists_current_graph,
            model, current, depth, uuid_fus,
            initial_prob, alpha, structure_distance,
            distance_restante, compact_fusion, enable_logging,
            eps=eps, sinkhorn_max_iter=sinkhorn_max_iter, sinkhorn_tol=sinkhorn_tol
        )

        map_index_node = map_index_node_copy

        if len(candidate_set) > 1:
            time_start_fusion = time.time()
            fused_text = _create_fused_text_compact(candidate_set, G, compact_fusion)
            new_node_id = f"fused_{current}_{depth}_{uuid.uuid4().hex[:8]}"
            G.add_node(new_node_id, text=fused_text, type="fused")
            # Store a single (d,) vector, like query_vec above: encode() on a one-item
            # list returns a (1, d) batch. A (1, d) entry would make torch.stack fail
            # when this node is later scored next to (d,)-shaped siblings.
            emb_new = model.encode([fused_text], normalize_embeddings=True, convert_to_tensor=True)[0].to(dtype=torch.float64)
            embeddings.vectors[new_node_id] = emb_new.cpu().numpy()

            avg_prob = float(np.mean([G.edges[current, c].get("prob", initial_prob) for c in candidate_set]))
            avg_struct_dist = float(np.mean([G.edges[current, c].get("structure_distance", 1.0) for c in candidate_set]))
            G.add_edge(current, new_node_id, prob=avg_prob, structure_distance=avg_struct_dist)

            # Move every candidate's outgoing edges onto the fused node, then drop the candidate.
            for c in candidate_set:
                for gc in G.successors(c):
                    G.add_edge(new_node_id, gc,
                               prob=G.edges[c, gc].get("prob", initial_prob),
                               structure_distance=G.edges[c, gc].get("structure_distance", 1.0))
                G.remove_edge(current, c)
                G.remove_node(c)
                _log_if_enabled(f"Removed edge {current} -> {c} and node {c}", enable_logging)

            # update mappings
            # NOTE(review): assumes the fused node is last in G.nodes; entries of the
            # removed candidates are not dropped here — confirm against dichotomie_fusion.
            map_index_node[new_node_id] = len(G.nodes) - 1
            embedding_current_graph = new_embedding_fus_etape_precedente.copy()
            _log_if_enabled(f"Fusion time: {time.time() - time_start_fusion:.3f} seconds", enable_logging)

            if structure_distance == "similarity_weighted":
                dists_current_graph = new_distance_struct_etape_precedente.copy()
            else:
                dists_current_graph = compute_structure_distances([G])

            # Remaining children → grandchildren scan & reward
            # (assumes candidate_set is a prefix of sorted_children_proba)
            remaining_children = sorted_children_proba[len(candidate_set):]
            all_grandchildren = [
                (c[0], gc) for c in remaining_children for gc in G.successors(c[0])
            ]

            if all_grandchildren:
                grand_child_sims = _node_query_similarities(
                    embeddings, [gc for _, gc in all_grandchildren], query_vec, device
                )

                sorted_gc = sorted(zip(all_grandchildren, grand_child_sims), key=lambda x: x[1], reverse=True)
                max_gc = float(np.max(grand_child_sims)) if len(grand_child_sims) > 0 else 0.0
                sorted_gc_relevant = [c for c in sorted_gc if c[1] > sim_threshold * max_gc]
                sorted_gc_proba = [c for c in sorted_gc if c[1] * G.edges[c[0][0], c[0][1]].get("prob", 0.1) > sim_threshold * max_gc]

                edge_to_consider: Dict[Tuple[str, str], float] = {}
                for child in [gc[0][1] for gc in sorted_gc_relevant + sorted_gc_proba]:
                    sim = float((torch.as_tensor(embeddings.vectors[child], dtype=torch.float64, device=device) @ query_vec).item())
                    # exp(sim) / e == exp(sim - 1): maps sim <= 1 into (0, 1].
                    sim_transformed = float(np.exp(sim) / np.exp(1))
                    if not G.has_edge(new_node_id, child):
                        G.add_edge(new_node_id, child,
                                   prob=initial_prob_new,
                                   structure_distance=1.0 / (sim_transformed + 1e-6))
                    edge_to_consider[(new_node_id, child)] = sim_transformed

                rewarded_edges = reward_distribution(edge_to_consider, reward=5.0, lr=lr)
                if rewarded_edges:
                    for (u, v), change_prob in rewarded_edges.items():
                        current_prob = G.edges[u, v].get("prob", initial_prob_new)
                        G.edges[u, v]["prob"] = min(current_prob + change_prob, 1.0)
                        _log_if_enabled(f"Updated edge {u}->{v} probability by {change_prob:.3f}", enable_logging)

            current = new_node_id
            path.append(TraversalStep(node_id=current, similarity=float(np.mean(candidate_vecs)) if candidate_vecs else None))

        else:
            # No fusion → follow best child
            best_idx = int(np.argmax(sims_np))
            current = children[best_idx]
            path.append(TraversalStep(node_id=current, similarity=float(sims_np[best_idx])))

    _clean_graph_dtypes(G)
    _log_if_enabled(f"Total wasted time in calculations: {temps_wasted:.3f} seconds", enable_logging)
    return path



class FGWTraversal:
    """Bind an :class:`FGWConfig` to the genetic FGW traversal function."""

    def __init__(self, config: FGWConfig):
        # Stored as-is; every field is read again on each traversal call.
        self.config = config

    def genetic_traversal(self, start_node: str, question: str, embeddings: EmbeddingStore,
                          G: nx.DiGraph, model: SentenceTransformer) -> List[TraversalStep]:
        """Run ``recursive_traversal_with_score_fusion_fgw_genetic_optimized``.

        The positional arguments are passed through unchanged. Every
        hyper-parameter comes from ``self.config``. The solver settings (``eps``,
        ``sinkhorn_*``) keep the function's defaults.
        """
        cfg = self.config
        options = dict(
            max_depth=cfg.max_depth,
            sim_threshold=cfg.sim_threshold,
            fgw_threshold=cfg.fgw_threshold,
            alpha=cfg.alpha,
            initial_prob=cfg.initial_prob,
            initial_prob_new=cfg.initial_prob_new,
            lr=cfg.lr,
            structure_distance=cfg.structure_distance,
            compact_fusion=cfg.compact_fusion,
            enable_logging=cfg.enable_logging,
        )
        return recursive_traversal_with_score_fusion_fgw_genetic_optimized(
            start_node, question, embeddings, G, model, **options
        )
