import copy
import hashlib
import json
import os
import tempfile
# module to check execution time
import time

# Third-party imports kept in their original relative order (import order of
# faiss vs. torch-based libraries can matter on some platforms — reorder with care).
import streamlit as st
import networkx as nx
import numpy as np
import faiss
import pyvis.network as net
from sentence_transformers import SentenceTransformer
from openai import OpenAI
import streamlit.components.v1 as components
import requests
from huggingface_hub import configure_http_backend
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from dotenv import load_dotenv, find_dotenv

from utils_streamlit.utils_st import draw_graph, draw_graph_with_similarity
from faq_src.utils.utils_graph.kg_base import compute_sim_all_nodes
from faq_src.utils.utils_graph.distance import compute_all_graphs_mds, compute_all_graphs_pca
from faq_src.utils.utils_fgw.utils_fgw import compute_similarity_weighted_question_structure_distances


# Pull variables from the nearest .env file (if one exists) into the process
# environment before any configuration value is read.
load_dotenv(find_dotenv())

# Sentence-embedding model identifier or local path; an unset or empty
# MODEL_PATH falls back to the public MiniLM checkpoint.
model_path = os.getenv("MODEL_PATH") or "sentence-transformers/all-MiniLM-L6-v2"


# --------------------- SSL Bypass -----------------------------
def backend_factory() -> requests.Session:
    """Create the HTTP session that huggingface_hub uses for its requests.

    SECURITY: the returned session has TLS certificate verification turned
    off, so traffic is not protected against man-in-the-middle attacks.
    NOTE(review): presumably a workaround for an intercepting proxy — confirm
    it is still required before deploying.
    """
    insecure_session = requests.Session()
    # Skip certificate validation (see the security note above).
    insecure_session.verify = False
    return insecure_session


# Register the factory so huggingface_hub builds its sessions with it
# (presumably including the SentenceTransformer model download — verify).
configure_http_backend(backend_factory=backend_factory)

# --------------------------------------------------------------

# --------------------- Streamlit App --------------------------
st.set_page_config(page_title="FAQtorization", layout="wide")
st.title("FAQtorization")

# 1. Graph and embeddings
uploaded_file = st.file_uploader(" Upload a JSON graph", type=["json"])


def _graph_from_json(json_data: dict) -> nx.DiGraph:
    """Build a directed graph from the uploaded JSON payload.

    ``json_data["nodes"]`` must be a list of mappings, each with an ``"index"``
    key used as the node id; every key of the mapping is stored as a node
    attribute. ``json_data["edges"]`` must be a list of ``[source, target]``
    pairs; every edge is given the attribute ``prob=1.0``.
    """
    graph = nx.DiGraph()
    for node in json_data["nodes"]:
        graph.add_node(node["index"], **node)
    for source, target in json_data["edges"]:
        graph.add_edge(source, target, prob=1.0)
    return graph


# Use session state to keep the graph and the embedding model across reruns.
if 'G' not in st.session_state:
    st.session_state['G'] = None
    st.session_state['model'] = None
    st.session_state['last_uploaded_filename'] = None

if uploaded_file is None:
    st.info("💡 Upload a JSON file to get started.")
    st.stop()

# Identify the upload by its content rather than only its name, so that
# re-uploading an edited file under the same filename rebuilds the graph
# instead of silently reusing the stale one.
raw_bytes = uploaded_file.getvalue()
upload_digest = hashlib.sha256(raw_bytes).hexdigest()

if st.session_state.get('last_uploaded_digest') != upload_digest:
    try:
        G = _graph_from_json(json.loads(raw_bytes))

        # Load the embedding model only once per session.
        if st.session_state['model'] is None:
            st.session_state['model'] = SentenceTransformer(model_path)
        model = st.session_state['model']

        st.session_state['G'] = G
        st.session_state['last_uploaded_filename'] = uploaded_file.name
        st.session_state['last_uploaded_digest'] = upload_digest
        # NOTE(review): similarity graphs computed for a previously uploaded
        # graph stay in st.session_state['G_sim_list'] — confirm whether they
        # should be cleared when a new graph is loaded.

        st.success(f"Graph imported with {len(G.nodes)} nodes and {len(G.edges)} edges.")
    except Exception as e:
        # UI boundary: report any parsing / model-loading failure and halt this run.
        st.error(f"Error while loading the file: {e}")
        st.stop()
else:
    G = st.session_state['G']
    model = st.session_state['model']
    st.success(f"Graph already loaded with {len(G.nodes)} nodes and {len(G.edges)} edges.")

# Sentence embeddings of every node's text, keyed by node id. Streamlit reruns
# the whole script on every widget interaction, so the embeddings are computed
# in one batched call and cached in session state; they are recomputed only
# when a different graph object has been loaded.
if st.session_state.get('embeddings_graph') is not G:
    node_ids = list(G.nodes)
    node_texts = [G.nodes[n]["text"] for n in node_ids]
    # Guard the empty graph: avoid calling encode() on an empty batch.
    vectors = (
        st.session_state['model'].encode(node_texts, normalize_embeddings=True)
        if node_texts
        else []
    )
    st.session_state['embeddings'] = dict(zip(node_ids, vectors))
    st.session_state['embeddings_graph'] = G
embeddings = st.session_state['embeddings']

with st.expander(" View graph"):
    html_path = draw_graph(G)
    # Context manager ensures the rendered HTML file handle is closed.
    with open(html_path, 'r', encoding='utf-8') as html_file:
        components.html(html_file.read(), height=500, scrolling=True)


# List of graphs annotated with similarity, one appended per question asked.
if 'G_sim_list' not in st.session_state:
    st.session_state['G_sim_list'] = []

if 'last_question' not in st.session_state:
    st.session_state['last_question'] = None


def _show_node_similarities(graph) -> None:
    """Write one line per node: node id, similarity score (0 if absent), text."""
    for node_id, attrs in graph.nodes(data=True):
        st.write(f"{node_id}: {attrs.get('similarity', 0):.4f} - {attrs['text']}")


question = st.text_input("Ask your question")

if st.button("Add graph with similarity"):
    # Use warnings instead of st.stop(): stopping the script here would also
    # hide every widget rendered below (listing and projection) for this rerun.
    if not question:
        st.warning("Please ask a question before computing similarity.")
    elif st.session_state['last_question'] == question:
        # The same question was just used to build the latest graph.
        st.warning("You already asked this question. Please ask a new one.")
    else:
        start_time = time.time()
        G_sim = compute_sim_all_nodes(G, question)
        G_sim.graph['question'] = question
        st.session_state['last_question'] = question

        # Store a deep copy so later computations cannot alter the saved graph.
        st.session_state['G_sim_list'].append(copy.deepcopy(G_sim))
        end_time = time.time()
        st.success(f"Similarity computed in {end_time - start_time:.2f} seconds.")
        with st.expander(" View graph with similarity"):
            html_path = draw_graph_with_similarity(G_sim)
            with open(html_path, 'r', encoding='utf-8') as html_file:
                components.html(html_file.read(), height=500, scrolling=True)
        st.write("Nodes with similarity:")
        _show_node_similarities(G_sim)

disp_all = st.button("Show all graphs with similarity")
if disp_all:
    if not st.session_state['G_sim_list']:
        st.warning("No graph with similarity has been computed.")
    else:
        for i, G_sim in enumerate(st.session_state['G_sim_list'], start=1):
            st.write(f"Graph {i} - Nodes with similarity:")
            _show_node_similarities(G_sim)


# --------------------- FAQtorization --------------------------


# Projection methods offered in the UI. Each "function" is called as
# ``function(graphs, n_components)``; its result is indexed as
# ``proj_results[:, k]`` below (one row per graph, one column per dimension).
method_configs = {
    "MDS": {"function": compute_all_graphs_mds},
    "PCA": {"function": compute_all_graphs_pca},
}

n_components = st.selectbox("Choose the number of dimensions for the projection", [2, 3], index=0)
method_proj = st.selectbox("Choose the projection method", ["MDS", "PCA"], index=0)


# Projection and visualization of graphs with similarity
if st.button("Visualize graphs with similarity"):
    if not st.session_state['G_sim_list']:
        st.warning("No graph with similarity has been computed.")
    else:
        try:
            proj_results = method_configs[method_proj]["function"](st.session_state['G_sim_list'], n_components)

            st.success(f"{method_proj} computed successfully.")
            # Labels = question associated with each graph
            labels = [graph.graph.get('question', f"Graph {i+1}") for i, graph in enumerate(st.session_state['G_sim_list'])]
            if n_components == 2:
                fig = go.Figure(data=go.Scatter(
                    x=proj_results[:, 0],
                    y=proj_results[:, 1],
                    mode='markers+text',
                    text=labels,
                    textposition='top center',
                    marker=dict(size=10, color='royalblue')
                ))
                fig.update_layout(
                    title=f"2D {method_proj} visualization of graphs (similarity)",
                    xaxis_title="Dimension 1",
                    yaxis_title="Dimension 2"
                )
            else:  # 3D
                fig = go.Figure(data=go.Scatter3d(
                    x=proj_results[:, 0],
                    y=proj_results[:, 1],
                    z=proj_results[:, 2],
                    mode='markers+text',
                    text=labels,
                    textposition='top center',
                    marker=dict(size=5, color='royalblue')
                ))
                fig.update_layout(
                    title=f"3D {method_proj} visualization of graphs (similarity)",
                    scene=dict(
                        xaxis_title="Dim 1",
                        yaxis_title="Dim 2",
                        zaxis_title="Dim 3"
                    )
                )

            st.plotly_chart(fig, use_container_width=True)

        except Exception as e:
            # UI boundary: projection may fail (e.g. too few graphs for the
            # requested dimensions); surface the error instead of crashing.
            st.error(f"Error during interactive visualization: {e}")
