import streamlit as st
import networkx as nx
import numpy as np
import json
import time
import os
from sentence_transformers import SentenceTransformer
from openai import OpenAI
import streamlit.components.v1 as components

# ==== Project-specific imports (CPU / NumPy) ====
from kt_gen.knowledge_graph.utils.pydantic_models import EmbeddingStore, TraversalStep
from kt_gen.knowledge_graph.utils.utils_fgw import (
    compute_embeddings as compute_embeddings_cpu,
    elementary_distance as elementary_distance_cpu,
    compute_similarity_weighted_structure_distances as compute_similarity_weighted_structure_distances_cpu,
    compute_structure_distances as compute_structure_distances_cpu,
)
from kt_gen.knowledge_graph.kg_fgw import (
    recursive_traversal_with_scores_fusion_fgw_base,
    recursive_traversal_with_scores_fusion_fgw_enhanced,
    recursive_traversal_with_scores_fusion_fgw_genetic,
    recursive_traversal_with_score_fusion_fgw_genetic_optimized,
)

# ==== Imports (GPU / PyTorch) ====
from kt_gen.knowledge_graph.utils.utils_fgw_torch import (
    compute_embeddings as compute_embeddings_torch,
    compute_similarity_weighted_structure_distances as compute_similarity_weighted_structure_distances_torch,
    compute_structure_distances as compute_structure_distances_torch,
)
from kt_gen.knowledge_graph.kg_fgw_torch import (
    recursive_traversal_with_score_fusion_fgw_genetic_optimized as recursive_traversal_fgw_genetic_optimized_torch,
)

# ==== Other utilities ====
from kt_gen.utils.streamlit.utils_st import (
    draw_graph, display_graph_pyvis, display_graph_distance, display_graph_distance_with_prob
)
from kt_gen.utils.llm.utils_llm import build_prompt, ask_llm
from kt_gen.knowledge_graph.kg_base_algo import (
    recursive_traversal_with_scores_fusion,
    recursive_traversal_with_scores_fusion_and_llm,
)

# ==== Config & ENV ====
# Load variables from a local .env file when python-dotenv is installed; the
# dependency is optional, so its absence is silently tolerated.
try:
    from dotenv import load_dotenv, find_dotenv
    load_dotenv(find_dotenv())
except ImportError:
    pass

# Sentence-transformer checkpoint used for node/question embeddings
# (overridable through the MODEL_PATH environment variable).
model_path = os.getenv("MODEL_PATH", "sentence-transformers/all-MiniLM-L6-v2")


def _env_flag(name: str, default: bool) -> bool:
    """Return the boolean value of the environment variable *name*.

    Unset or empty -> *default*; "0", "false", "no", "off" (case-insensitive)
    -> False; any other value -> True.
    """
    raw = os.getenv(name, "").strip().lower()
    if not raw:
        return default
    return raw not in {"0", "false", "no", "off"}


# SSL bypass (optional, WSL / intercepting proxies).
# WARNING: this disables TLS certificate verification for HTTP requests that
# huggingface_hub issues through this backend. It stays enabled by default to
# preserve existing behaviour; set HF_DISABLE_SSL_VERIFY=0 to keep verification on.
if _env_flag("HF_DISABLE_SSL_VERIFY", default=True):
    try:
        import requests
        from huggingface_hub import configure_http_backend

        def backend_factory() -> requests.Session:
            """Build a requests session that skips TLS certificate verification."""
            session = requests.Session()
            session.verify = False
            return session

        configure_http_backend(backend_factory=backend_factory)
    except ImportError:
        # Either dependency missing (or an hf_hub version without
        # configure_http_backend): keep the default HTTP backend.
        pass

# (optional) HF offline modes; setdefault keeps any value already exported.
# NOTE(review): sentence_transformers (and with it transformers/huggingface_hub)
# is imported at the top of this file, before these lines run. If those
# libraries read the offline flags at import time, setting them here may have
# no effect; consider exporting them in the environment or moving these lines
# above the imports.
os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")
os.environ.setdefault("HF_DATASETS_OFFLINE", "1")
os.environ.setdefault("HF_HUB_OFFLINE", "1")


# ================== UTILITIES ==================

@st.cache_resource
def load_sentence_transformer(model_path: str):
    """Return a SentenceTransformer for *model_path*, cached by Streamlit.

    ``st.cache_resource`` keys the cache on ``model_path``, so reruns reuse the
    same instance instead of reloading the weights. Executing custom code
    shipped with the model repository is explicitly disabled.
    """
    encoder = SentenceTransformer(model_path, trust_remote_code=False)
    return encoder

def load_graph_from_json(json_data: dict):
    """Build a directed graph from a JSON-decoded graph description.

    Args:
        json_data: mapping with a ``"nodes"`` list and an optional ``"edges"``
            list. Each node is a dict whose ``"index"`` value is used as the node
            id; all of its keys become node attributes. Each edge is either
            ``[source, target]`` or ``[source, target, attrs]``, where ``attrs``
            is a dict of edge attributes (``null`` is treated as ``{}``).

    Returns:
        ``(G, has_structure_distances)``: the ``networkx.DiGraph`` and whether
        the payload carries a top-level ``"structure_distances"`` key.

    Edges given as a bare pair get ``prob=1.0``; edges given with an attribute
    dict keep exactly those attributes. Both shapes are accepted per edge,
    independently of the ``"structure_distances"`` flag.

    Raises:
        KeyError: if ``"nodes"`` is missing or a node has no ``"index"``.
        ValueError: if an edge does not have 2 or 3 elements.
    """
    G = nx.DiGraph()
    # Nodes
    for node in json_data["nodes"]:
        G.add_node(node["index"], **node)

    # Edges
    has_structure_distances = "structure_distances" in json_data
    for edge in json_data.get("edges", []):
        if len(edge) == 2:
            source, target = edge
            attrs = {"prob": 1.0}
        elif len(edge) == 3:
            source, target, raw_attrs = edge
            attrs = dict(raw_attrs or {})
        else:
            raise ValueError(
                f"Edge must be [source, target] or [source, target, attrs], got {edge!r}"
            )
        G.add_edge(source, target, **attrs)
    return G, has_structure_distances

def initialize_session_state():
    """Seed ``st.session_state`` with defaults for every key not yet present.

    Keys that already exist are left untouched, so this can run on every
    Streamlit rerun without clobbering state.
    """
    defaults = {
        'G': None,
        'model': None,
        'last_uploaded_filename': None,
        'structure_distances': None,
        'compute_mode': 'CPU (NumPy/POT)',
    }
    missing_keys = (key for key in defaults if key not in st.session_state)
    for key in missing_keys:
        st.session_state[key] = defaults[key]

def display_graph_info(G):
    """Show the node and edge counts of *G* as two side-by-side Streamlit metrics."""
    nodes_col, edges_col = st.columns(2)
    nodes_col.metric("📊 Nodes", len(G.nodes))
    edges_col.metric("🔗 Edges", len(G.edges))

def step_to_pair(step):
    """Normalize one traversal step to a ``(node_id, similarity)`` pair.

    Compatibility layer for the traversal algorithms' return values:
      * a 2-element tuple/list ``(node_id, similarity)`` is unpacked as-is;
      * an object exposing ``node_id`` (e.g. a ``TraversalStep`` model) yields
        its ``node_id`` and its ``similarity`` attribute, or ``None`` if absent;
      * anything else yields ``(str(step), None)``.
    """
    is_pair = isinstance(step, (tuple, list)) and len(step) == 2
    if is_pair:
        node_id, similarity = step
        return node_id, similarity
    if not hasattr(step, "node_id"):
        return str(step), None
    return step.node_id, getattr(step, "similarity", None)


# ================== UI ==================

# Markdown banner shown at the top of the page.
_HEADER_MD = "\n".join([
    "",
    "# 🧠 Exploring a Document Graph",
    "**Modes:** CPU (NumPy/POT — all methods) **or** GPU (PyTorch — optimized genetic FGW)",
    "",
])

# Page configuration must be the first Streamlit command of the script.
st.set_page_config(
    page_title="🧠 Graph LLM Navigator — CPU & GPU",
    layout="wide",
    page_icon="🧠",
)
st.markdown(_HEADER_MD)

# Make sure every session key read by the sections below exists.
initialize_session_state()

# ---- LOAD GRAPH ----
with st.container():
    st.markdown("## 📁 Load Graph")

    uploaded_file = st.file_uploader("📤 Upload a JSON graph", type=["json"])

    if uploaded_file is not None:
        # Fingerprint the upload with Streamlit's per-upload id (falling back to
        # the byte size) in addition to its name: comparing names alone would keep
        # a stale graph when an edited file with the same name is re-uploaded.
        upload_fingerprint = (
            getattr(uploaded_file, "file_id", None)
            or getattr(uploaded_file, "size", None)
        )
        is_new_upload = (
            st.session_state['last_uploaded_filename'] != uploaded_file.name
            or st.session_state.get('last_uploaded_file_id') != upload_fingerprint
        )

        if is_new_upload:
            with st.spinner("⏳ Loading graph..."):
                try:
                    json_data = json.load(uploaded_file)
                    G, has_structure_distances = load_graph_from_json(json_data)

                    # The embedding model is loaded lazily, only if not yet in session.
                    if st.session_state['model'] is None:
                        with st.spinner("🤖 Loading model..."):
                            st.session_state['model'] = load_sentence_transformer(model_path)

                    st.session_state['G'] = G
                    st.session_state['last_uploaded_filename'] = uploaded_file.name
                    st.session_state['last_uploaded_file_id'] = upload_fingerprint
                    # NOTE(review): this key holds a bool here, while the
                    # structural-distance section later stores the computed
                    # matrices under the same key.
                    st.session_state['structure_distances'] = has_structure_distances

                    st.success("✅ Graph successfully imported!")
                    display_graph_info(G)
                except Exception as e:
                    st.error(f"❌ Error while loading: {e}")
                    st.stop()
        else:
            G = st.session_state['G']
            model = st.session_state['model']
            st.success("✅ Graph already loaded")
            display_graph_info(G)
    else:
        st.info("💡 Upload a JSON file to get started.")
        st.stop()

# ---- VISUALIZATION ----
with st.expander("👁️ Visualize graph", expanded=False):
    if st.session_state['G'] is not None:
        # draw_graph returns the path of an HTML file that is embedded below.
        # Read it inside a context manager so the handle is closed promptly
        # (this section re-runs on every Streamlit rerun).
        html_path = draw_graph(st.session_state['G'])
        with open(html_path, 'r', encoding='utf-8') as html_file:
            graph_html = html_file.read()
        components.html(graph_html, height=500, scrolling=True)

# ---- MODE CHOICE ----
st.markdown("### ⚙️ Compute mode")

# The CPU label must stay first: the rest of the script tests startswith("CPU").
_MODE_LABELS = ["CPU (NumPy/POT)", "🚀 GPU (PyTorch FGW génétique optimisée)"]
compute_mode = st.selectbox("🖥️ Choose mode:", _MODE_LABELS, index=0)
st.session_state['compute_mode'] = compute_mode

# NOTE(review): the CPU message advertises all methods, but the method selector
# further down currently lists only the Kurisu-G² CPU method.
_mode_note = (
    "🖥️ **CPU (NumPy/POT)**: all methods available."
    if compute_mode.startswith("CPU")
    else "🚀 **GPU (PyTorch)**: optimized genetic FGW (stabilized Sinkhorn)."
)
st.info(_mode_note)

# ---- STRUCTURAL DISTANCES ----
st.markdown("## 🔬 Structural distances")


def _as_numpy(matrix):
    """Return a torch tensor as a host NumPy array; return anything else unchanged."""
    return matrix.detach().cpu().numpy() if hasattr(matrix, "detach") else matrix


def _show_matrices(matrices):
    """Render each matrix of *matrices* as a dataframe captioned "Graph i"."""
    for i, dist in enumerate(matrices):
        M = _as_numpy(dist)
        st.write(f"**Graph {i+1}:**")
        st.dataframe(M, use_container_width=True)


col1, col2 = st.columns(2)
with col1:
    if st.button("📊 Compute similarity-weighted structural distances"):
        with st.spinner("⏳ Computing..."):
            G = st.session_state['G']
            model = st.session_state['model']

            # Same parameters on both backends (factor=1.5, similarity_power=1.0).
            if compute_mode.startswith("CPU"):
                embedding_struct = compute_embeddings_cpu([G], model)
                structure_distances = compute_similarity_weighted_structure_distances_cpu(
                    [G], embedding_struct, factor=1.5, similarity_power=1.0
                )
                base_structure_distances = compute_structure_distances_cpu([G], factor=1.5)
            else:
                embedding_struct = compute_embeddings_torch([G], model=model)
                structure_distances = compute_similarity_weighted_structure_distances_torch(
                    [G], embedding_struct, factor=1.5, similarity_power=1.0
                )
                base_structure_distances = compute_structure_distances_torch([G], factor=1.5)

            st.session_state['structure_distances'] = structure_distances
            st.success("✅ Distances computed!")

            with st.expander("📈 Weighted matrix"):
                _show_matrices(structure_distances)

            with st.expander("📊 Base matrix"):
                _show_matrices(base_structure_distances)

            # Copy the weighted distance of the (single) graph onto each edge.
            # Matrix rows/columns are indexed by position in G.nodes iteration
            # order; build that node -> position map once instead of calling
            # list(G.nodes).index(...) twice per edge (O(V) each).
            node_pos = {node: i for i, node in enumerate(G.nodes)}
            weighted = structure_distances[0]
            for src, dst in G.edges:
                val = weighted[node_pos[src], node_pos[dst]]
                if hasattr(val, "item"):  # NumPy scalar or 0-d torch tensor
                    val = val.item()
                G.edges[src, dst]["structure_distance"] = float(val)

            st.markdown("### 🎯 Graph with distances")
            display_graph_distance(G)

with col2:
    if compute_mode.startswith("CPU"):
        if st.button("🧮 Elementary distance (CPU)"):
            G = st.session_state['G']
            model = st.session_state['model']
            embeddings = {n: model.encode(G.nodes[n]["text"], normalize_embeddings=True) for n in G.nodes}
            d = elementary_distance_cpu(G, EmbeddingStore(vectors=embeddings), model)
            st.success(f"📏 Elementary distance (CPU): {d:.4f}")
    else:
        st.info("ℹ️ In GPU mode, navigation uses FGW (Sinkhorn). Elementary distance not computed here.")

# ---- LLM NAVIGATION ----
st.markdown("## 🤖 LLM Navigation")

# OpenAI-compatible client pointed at a local Ollama server. The endpoint
# defaults to http://localhost:11434/v1 and can be overridden with the
# OLLAMA_BASE_URL environment variable (e.g. remote host or container); the
# api_key is a fixed placeholder value.
client = OpenAI(
    base_url=os.getenv("OLLAMA_BASE_URL", "http://localhost:11434/v1"),
    api_key="ollama",
)

question = st.text_input(
    "❓ Ask your question:",
    placeholder="e.g., Why don't fungi perform photosynthesis?",
)

# Method selector depending on mode. The labels must match the strings compared
# in the navigation run section exactly.
if compute_mode.startswith("CPU"):
    method_options = [
        "🚀 Kurisu-G² (CPU)",
    ]
    method_help = "Classic methods (NumPy/POT)"
else:
    method_options = [
        "🚀Kurisu-G² (PyTorch)",
    ]
    method_help = "GPU-accelerated method (PyTorch + Sinkhorn)"

method = st.selectbox("🛠️ Choose method", method_options, help=method_help)

# ---- PARAMS ----
st.markdown("### ⚙️ Parameters")

# Parameters shared by every method, laid out on a single row.
sim_col, fgw_col, alpha_col = st.columns(3)
with sim_col:
    sim_threshold = st.slider(
        "🎯 Similarity threshold", min_value=0.0, max_value=1.0, value=0.7, step=0.01
    )
with fgw_col:
    fgw_threshold = st.slider(
        "🌊 FGW threshold", min_value=0.0, max_value=3.0, value=1.0, step=0.01
    )
with alpha_col:
    alpha_fgw = st.slider(
        "⚖️ FGW alpha", min_value=0.0, max_value=1.0, value=0.5, step=0.01
    )

# Free-text LLM answer, requested only for methods whose label contains "LLM".
llm_answer = None
if "LLM" in method:
    llm_answer = st.text_input("🤖 LLM answer:", value="")
    if not llm_answer:
        st.warning("⚠️ Enter an LLM answer to use this method.")

# Structural distance variant passed to the genetic FGW traversals.
structure_dist = st.selectbox(
    "📏 Structural distance", options=["similarity_weighted", "base"]
)

# Sinkhorn solver settings: these widgets (and therefore eps,
# sinkhorn_max_iter and sinkhorn_tol) only exist in GPU mode.
_gpu_mode = not compute_mode.startswith("CPU")
if _gpu_mode:
    eps_col, iter_col, tol_col = st.columns(3)
    with eps_col:
        eps = st.number_input(
            "ε (entropy)",
            min_value=1e-4, max_value=1.0, value=2e-2, step=1e-3, format="%.5f",
        )
    with iter_col:
        sinkhorn_max_iter = st.number_input(
            "Sinkhorn iterations",
            min_value=0, max_value=20000, value=50, step=10,
        )
    with tol_col:
        sinkhorn_tol = st.number_input(
            "Sinkhorn tolerance",
            min_value=1e-14, max_value=1e-6, value=1e-12, step=1e-14, format="%.1e",
        )

# ---- RUN ----
# Score types counted in the "Average score" metric: Python numbers plus NumPy
# scalars (np.float32 and NumPy integer scalars are not subclasses of float/int).
_NUMERIC_SCORE_TYPES = (int, float, np.integer, np.floating)


def _format_score(score) -> str:
    """Format *score* with 3 decimals, falling back to ``str(score)`` if unsupported."""
    try:
        return f"{score:.3f}"
    except (TypeError, ValueError):
        return str(score)


if question and st.button("🚀 Start navigation", type="primary"):
    with st.spinner(f"🔄 Navigating ({method})..."):
        G = st.session_state['G']
        model = st.session_state['model']

        # EmbeddingStore (numpy arrays): one normalized embedding per node text.
        embeddings = {n: model.encode(G.nodes[n]["text"], normalize_embeddings=True) for n in G.nodes}
        embedding_store = EmbeddingStore(vectors=embeddings)

        # perf_counter is monotonic (unlike time.time), so durations are reliable.
        t0 = time.perf_counter()
        try:
            # ===== CPU =====
            if compute_mode.startswith("CPU"):
                # NOTE(review): the CPU method selector currently offers only
                # "🚀 Kurisu-G² (CPU)", so the other CPU branches are unreachable
                # until their labels are added to method_options.
                if method == "🔍 Méthode de fusion de base":
                    path = recursive_traversal_with_scores_fusion("document", question, embedding_store, G, model, threshold=sim_threshold)
                    show = display_graph_pyvis

                elif method == "🤖 Méthode de fusion avec LLM":
                    if not llm_answer:
                        st.error("Please provide the LLM answer.")
                        st.stop()
                    path = recursive_traversal_with_scores_fusion_and_llm(
                        "document", question, embedding_store, G, model,
                        llm_answer, threshold=sim_threshold, alpha=0.5, type_fuse="franchement"
                    )
                    show = display_graph_pyvis

                elif method == "🌊 Méthode FGW de base":
                    path = recursive_traversal_with_scores_fusion_fgw_base(
                        "document", question, embedding_store, G, model,
                        alpha=alpha_fgw, sim_threshold=sim_threshold, fgw_threshold=fgw_threshold
                    )
                    show = display_graph_pyvis

                elif method == "⚡ Méthode FGW améliorée":
                    path = recursive_traversal_with_scores_fusion_fgw_enhanced(
                        "document", question, embedding_store, G, model,
                        max_depth=3, sim_threshold=sim_threshold, fgw_threshold=fgw_threshold,
                        alpha=alpha_fgw, accelerated=False
                    )
                    show = display_graph_pyvis

                elif method == "🧬 Méthode FGW génétique":
                    path = recursive_traversal_with_scores_fusion_fgw_genetic(
                        "document", question, embedding_store, G, model,
                        max_depth=3, sim_threshold=sim_threshold, fgw_threshold=fgw_threshold,
                        alpha=alpha_fgw, structure_distance=structure_dist
                    )
                    show = display_graph_distance_with_prob

                elif method == "🚀 Kurisu-G² (CPU)":
                    path = recursive_traversal_with_score_fusion_fgw_genetic_optimized(
                        "document", question, embedding_store, G, model,
                        max_depth=3, sim_threshold=sim_threshold, fgw_threshold=fgw_threshold,
                        alpha=alpha_fgw, structure_distance=structure_dist, enable_logging=True
                    )
                    show = display_graph_distance_with_prob

                else:
                    st.error("Unknown CPU method.")
                    st.stop()

            # ===== GPU / PyTorch =====
            else:
                if method == "🚀Kurisu-G² (PyTorch)":
                    path = recursive_traversal_fgw_genetic_optimized_torch(
                        start_node="document",
                        question=question,
                        embeddings=embedding_store,
                        G=G,
                        model=model,  # reuse the same cached model
                        max_depth=3,
                        sim_threshold=sim_threshold,
                        fgw_threshold=fgw_threshold,
                        alpha=alpha_fgw,
                        structure_distance=structure_dist,
                        eps=float(eps),
                        sinkhorn_max_iter=int(sinkhorn_max_iter),
                        sinkhorn_tol=float(sinkhorn_tol),
                    )
                    show = display_graph_distance_with_prob
                else:
                    st.error("Unknown GPU method.")
                    st.stop()

            t1 = time.perf_counter()
            elapsed = t1 - t0
            st.success(f"✅ Navigation completed in {elapsed:.2f} s")

            # Metrics
            if path:
                # Keep only numeric similarities (the root step may carry None).
                scores = [
                    score
                    for _, score in (step_to_pair(step) for step in path)
                    if isinstance(score, _NUMERIC_SCORE_TYPES)
                ]
                m1, m2, m3 = st.columns(3)
                with m1:
                    st.metric("⏱️ Time", f"{elapsed:.2f}s")
                with m2:
                    st.metric("🛤️ Visited nodes", len(path))
                with m3:
                    st.metric("📊 Average score", f"{np.mean(scores):.3f}" if scores else "N/A")

            # Graph display
            st.markdown("### 🎯 Navigation graph")
            show(G)

            # Context & LLM answer: the last node of the path is the LLM context.
            if path:
                node_ids = [step_to_pair(s)[0] for s in path]
                if node_ids:
                    last_node = node_ids[-1]
                    node_text = G.nodes.get(last_node, {}).get("text", "Text not available")
                    context = f"{last_node}: {node_text}"

                    with st.expander("📄 Extracted context", expanded=True):
                        st.markdown(f"**Final node:** {context}")
                        st.markdown("**Full path:**")
                        for nid in node_ids:
                            text = G.nodes.get(nid, {}).get("text", "Text not available")
                            st.markdown(f"- **{nid}** : {text}")

                    with st.spinner("🤖 Generating answer..."):
                        prompt = build_prompt(question, context)
                        response = ask_llm(prompt, client)

                    st.markdown("### 💬 LLM Answer")
                    st.success(response)

                    with st.expander("🗺️ Path details"):
                        for s in path:
                            nid, score = step_to_pair(s)
                            if nid in G.nodes:
                                text = G.nodes[nid].get("text", "Text not available")
                                if score is not None:
                                    st.markdown(f"- **{nid}** (similarity: `{_format_score(score)}`) : {text}")
                                else:
                                    st.markdown(f"- **{nid}** (root) : {text}")
                            else:
                                st.markdown(f"- **{nid}** (node not found) : Score {score}")
                else:
                    st.warning("⚠️ No node found in the path.")
            else:
                st.warning("⚠️ No path found in the graph.")

        except Exception as e:
            st.error(f"❌ Error during navigation: {e}")
            st.exception(e)
