import streamlit as st
import networkx as nx
import numpy as np
import faiss
import pyvis.network as net
import tempfile
from sentence_transformers import SentenceTransformer
from openai import OpenAI
import streamlit.components.v1 as components
import requests
import json
from huggingface_hub import configure_http_backend
import os
import copy
# module to check execution time
import time
from utils_streamlit.utils_st import draw_graph, draw_graph_with_similarity
from faq_src.utils.utils_graph.kg_base import compute_sim_all_nodes
from faq_src.utils.utils_graph.distance import compute_all_graphs_mds, compute_all_graphs_pca
from faq_src.utils.utils_graph.clustering import cluster_space
from faq_src.utils.utils_fgw.utils_fgw import compute_similarity_weighted_question_structure_distances, compute_embeddings, compute_pairwise_gw, compute_pairwise_fgw
from faq_src.utils.utils_graph.distance import vectorize_graph_by_similarities
from utils_streamlit.utils_st import display_graph_distance

from sklearn.metrics.pairwise import pairwise_distances
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.colors

from dotenv import load_dotenv, find_dotenv

# Load environment variables from the .env file located by find_dotenv().
load_dotenv(find_dotenv())

# Sentence-embedding model to load. An unset or empty MODEL_PATH falls back
# to the default identifier below.
model_path = os.getenv("MODEL_PATH") or "sentence-transformers/all-MiniLM-L6-v2"


# --------------------- SSL Bypass -----------------------------
def backend_factory() -> requests.Session:
    """Return a new ``requests`` session with certificate verification turned off.

    NOTE(security): ``verify = False`` means requests sent through this
    session do not validate the server's TLS certificate. This is the
    deliberate "SSL bypass" of this app; do not reuse the session for
    anything that carries secrets.
    """
    insecure_session = requests.Session()
    insecure_session.verify = False
    return insecure_session

# Hand the factory to huggingface_hub so the sessions it creates come from
# backend_factory (i.e. without certificate verification).
configure_http_backend(backend_factory=backend_factory)

# --------------------------------------------------------------

# --------------------- Streamlit App --------------------------
st.set_page_config(page_title="FAQtorisation", layout="wide")
st.title("FAQtorisation")

# 1. Graph and embeddings
uploaded_file = st.file_uploader("📂 Upload a JSON graph", type=["json"])

# --- Load questions ---
uploaded_questions = st.file_uploader("❓ Upload a questions file (JSON or TXT)", type=["json", "txt"])
questions = []
if not uploaded_questions:
    # Nothing below can run without questions; tell the user why the page
    # stops here instead of halting silently.
    st.info("💡 Upload a questions file to get started.")
    st.stop()

try:
    # Compare the extension case-insensitively so "FILE.JSON" is parsed as JSON.
    if uploaded_questions.name.lower().endswith(".json"):
        questions = json.load(uploaded_questions)
    else:
        # One question per line. "utf-8-sig" drops a leading byte-order mark
        # if present, and stripping keeps whitespace-only lines from being
        # treated as real questions (they stay in the list as empty entries
        # so question numbering matches the file).
        raw_text = uploaded_questions.read().decode("utf-8-sig")
        questions = [line.strip() for line in raw_text.splitlines()]
except (UnicodeDecodeError, json.JSONDecodeError) as e:
    # A malformed upload should produce a readable message, not a traceback.
    st.error(f"Error while loading questions file: {e}")
    st.stop()

st.success(f"{len(questions)} questions loaded.")

# Use session cache to store the graph and the model (plus embeddings)
if 'G' not in st.session_state:
    st.session_state['G'] = None
    st.session_state['model'] = None
    st.session_state['last_uploaded_filename'] = None

if uploaded_file is None:
    st.info("💡 Upload a JSON file to get started.")
    st.stop()

# A change of graph is detected by file name only.
if st.session_state['last_uploaded_filename'] != uploaded_file.name:
    try:
        json_data = json.load(uploaded_file)
        G = nx.DiGraph()

        # Add nodes: every key of the node record becomes a node attribute.
        for node in json_data["nodes"]:
            G.add_node(node["index"], **node)

        # Add edges: each entry is a (source, target) pair.
        for source, target in json_data["edges"]:
            G.add_edge(source, target, prob=1.0)

        # Load the model only once per session
        if st.session_state['model'] is None:
            model = SentenceTransformer(model_path)
            st.session_state['model'] = model
        else:
            model = st.session_state['model']

        # Everything below was derived from the previously loaded graph.
        # Drop it so the rest of the script recomputes it for the new graph
        # instead of silently reusing stale results.
        for stale_key in (
            'embeddings',
            'G_sim_list',
            'last_question',
            'cluster_labels',
            'proj_results',
            'labels',
            'clusters_data',
            'final_json',
            'characteristic_questions',
        ):
            if stale_key in st.session_state:
                del st.session_state[stale_key]

        st.session_state['G'] = G
        st.session_state['last_uploaded_filename'] = uploaded_file.name

        st.success(f"Graph imported with {len(G.nodes)} nodes and {len(G.edges)} edges.")
    except Exception as e:
        # Top-level boundary of the app: report the problem and halt this run.
        st.error(f"Error while loading file: {e}")
        st.stop()
else:
    # Same file as on the previous run: reuse the cached graph and model.
    G = st.session_state['G']
    model = st.session_state['model']
    st.success(f"Graph already loaded with {len(G.nodes)} nodes and {len(G.edges)} edges.")

# Recompute embeddings at each iteration: one normalized vector per node,
# encoded from the node's "text" attribute.
embeddings = {n: st.session_state['model'].encode(G.nodes[n]["text"], normalize_embeddings=True) for n in G.nodes}
# Only stored when the session has no 'embeddings' entry yet; later runs keep
# the entry that is already there.
if 'embeddings' not in st.session_state:
    st.session_state['embeddings'] = embeddings

with st.expander("👁️ View graph"):
    html_path = draw_graph(G)
    # Read through a context manager so the file handle is closed right away
    # (the script is re-executed on every interaction).
    with open(html_path, 'r', encoding='utf-8') as html_file:
        graph_html = html_file.read()
    components.html(graph_html, height=500, scrolling=True)

if 'last_question' not in st.session_state:
    st.session_state['last_question'] = None

# Build one similarity graph per question. The list is cached in the session
# and rebuilt when it is missing or when the uploaded questions differ from
# the ones the cached graphs were computed for.
questions_changed = st.session_state.get('G_sim_questions') != questions
if 'G_sim_list' not in st.session_state or questions_changed:
    st.session_state['G_sim_list'] = []
    # Start from a clean slate so the first question of a new batch is not
    # mistaken for a repeat of the last question of the previous batch.
    st.session_state['last_question'] = None
    # Results derived from an earlier list would no longer line up with the
    # graphs rebuilt below, so drop them.
    for stale_key in (
        'cluster_labels',
        'proj_results',
        'labels',
        'clusters_data',
        'final_json',
        'characteristic_questions',
    ):
        if stale_key in st.session_state:
            del st.session_state[stale_key]

    embedding_struct = compute_embeddings([G], model=model)
    # NOTE(review): this result is overwritten inside the loop before it is
    # read. Kept in case the helper has side effects — confirm and remove.
    dist_graph = compute_similarity_weighted_question_structure_distances([G], embedding_struct, question="graph theory", model=model)

    for i, question in enumerate(questions):
        if not question:
            st.warning(f"Question {i+1} is empty, it will be ignored.")
            continue

        # Skip a question identical to the one processed just before it.
        if st.session_state['last_question'] == question:
            st.warning(f"You already asked question {i+1}. Please ask a new one.")
            continue

        # Compute similarity for all nodes in the graph
        print(f"Computing similarity for question {i+1}: {question}")
        start_time = time.time()
        G_sim = compute_sim_all_nodes(G, question, embeddings=embedding_struct[0], model=model)
        G_sim.graph['question'] = question

        # Compute structural distances weighted by similarity
        dist_graph = compute_similarity_weighted_question_structure_distances([G_sim], embedding_struct, question, model)
        structure_matrix = dist_graph[0]
        G_sim.graph['structure_distance'] = structure_matrix

        # Position of every node in iteration order, built once so the edge
        # loop does not rescan the node list for each endpoint.
        node_position = {node: j for j, node in enumerate(G_sim.nodes())}

        # Each node gets the entry of the distance structure at its position.
        for node, j in node_position.items():
            G_sim.nodes[node]["structure_distance"] = structure_matrix[j]

        # Each edge gets the entry addressed by its two endpoints.
        for src, dst in G_sim.edges():
            structure_dist = structure_matrix[node_position[src], node_position[dst]]
            G_sim.edges[src, dst]['structure_distance'] = structure_dist

        st.session_state['last_question'] = question
        st.session_state['G_sim_list'].append(copy.deepcopy(G_sim))
        end_time = time.time()
        st.success(f"Similarity for question {i+1} computed in {end_time - start_time:.2f} seconds.")

    # Recorded only after the whole batch succeeded, so an interrupted build
    # is retried on the next run.
    st.session_state['G_sim_questions'] = questions

# For every cached similarity graph, show the 'structure_distance' attribute
# of its nodes as a table.
for sim_graph in st.session_state['G_sim_list']:
    per_node_values = [sim_graph.nodes[node_id]["structure_distance"] for node_id in sim_graph.nodes()]
    structure_distances = np.array(per_node_values)
    st.dataframe(structure_distances, use_container_width=True)

# List every node of every similarity graph with its similarity and text.
disp_all = st.button("Show all graphs with similarity")
if disp_all:
    if not st.session_state['G_sim_list']:
        st.warning("No similarity graph has been computed.")
    else:
        for graph_number, G_sim in enumerate(st.session_state['G_sim_list'], start=1):
            st.write(f"Graph {graph_number} - Nodes with similarity:")
            for node_id, attrs in G_sim.nodes(data=True):
                # A node without a 'similarity' attribute is shown as 0.
                st.write(f"{node_id}: {attrs.get('similarity', 0):.4f} - {attrs['text']}")

# --------------------- FAQtorisation --------------------------

# Projection settings chosen by the user.
n_components = st.selectbox("Choose number of dimensions for MDS visualization", [2, 3], index=0)
method_proj = st.selectbox("Choose projection method", ["MDS", "PCA"], index=0)
dist_choice = st.selectbox("Choose distance metric for projection", ["fgw", "euclidean"], index=0)
alpha = st.slider("Alpha for graph fusion", min_value=0.0, max_value=1.0, value=0.5, step=0.1)


def _make_projection(project_fn):
    """Wrap a projection routine so it receives the session's embeddings and model
    together with the distance metric and alpha currently selected."""
    def _project(graphs, n_components):
        return project_fn(
            graphs,
            embeddings=st.session_state['embeddings'],
            model=st.session_state['model'],
            n_components=n_components,
            distance=dist_choice,
            alpha=alpha
        )
    return _project


# Configuration dictionary keyed by projection method name.
method_configs = {
    "MDS": {"function": _make_projection(compute_all_graphs_mds)},
    "PCA": {"function": _make_projection(compute_all_graphs_pca)},
}

# Clustering settings chosen by the user.
n_clusters = st.number_input("Number of clusters for KMeans", min_value=1, max_value=10, value=3, step=1)
n_questions_per_cluster = st.number_input("Number of questions per cluster", min_value=1, max_value=10, value=5, step=1)

# Compute pairwise graph distances, project the graphs to 2D/3D, cluster the
# projected points and plot the result.
if st.button("Visualize graphs with similarity"):
    if not st.session_state['G_sim_list']:
        st.warning("No similarity graph has been computed.")
    else:
        try:
            sim_graphs = st.session_state['G_sim_list']

            dist_pairwise_gw = compute_pairwise_gw(
                sim_graphs,
                embeddings=st.session_state['embeddings'],
                model=st.session_state['model']
            )
            # The embeddings computed for the first graph are reused for
            # every graph in the fused (FGW) computation.
            F1 = compute_embeddings([sim_graphs[0]], model=st.session_state['model'])[0]
            EmbList = [F1] * len(sim_graphs)
            dist_pairwise_fgw = compute_pairwise_fgw(
                sim_graphs,
                embeddings=EmbList,
                model=st.session_state['model'],
                alpha=alpha
            )

            st.session_state['dist_pairwise'] = dist_pairwise_gw
            st.session_state['dist_pairwise_fgw'] = dist_pairwise_fgw
            st.write("Pairwise fused Gromov–Wasserstein distances:")
            st.dataframe(dist_pairwise_fgw, use_container_width=True)
            st.write("Pairwise Gromov–Wasserstein distances:")
            st.dataframe(dist_pairwise_gw, use_container_width=True)

            vectors = [vectorize_graph_by_similarities(sim_graph) for sim_graph in sim_graphs]
            dist_pairwise_euclidean = pairwise_distances(vectors, metric="euclidean")
            st.session_state['dist_pairwise_euclidean'] = dist_pairwise_euclidean
            st.write("Pairwise Euclidean distances:")
            st.dataframe(dist_pairwise_euclidean, use_container_width=True)

            proj_results = method_configs[method_proj]["function"](sim_graphs, n_components)
            st.success(f"{method_proj} computed successfully.")

            # One label per graph: its question, or a positional fallback.
            labels = [graph.graph.get('question', f"Graph {i+1}") for i, graph in enumerate(sim_graphs)]

            # Display each graph with distances using helper
            for graph in sim_graphs:
                display_graph_distance(graph)

            # Colors by blocks of `n_questions_per_cluster` consecutive questions
            color_palette = plotly.colors.qualitative.Plotly
            n_colors = len(color_palette)
            colors = [color_palette[(i // n_questions_per_cluster) % n_colors] for i in range(len(labels))]

            # Clustering on projected coordinates. There cannot be more
            # clusters than projected graphs, so cap the requested number
            # instead of letting the clustering step fail.
            requested_clusters = min(n_clusters, len(labels))
            if requested_clusters != n_clusters:
                st.warning(
                    f"Only {len(labels)} graph(s) available: using {requested_clusters} "
                    f"cluster(s) instead of {n_clusters}."
                )
            cluster_labels, _ = cluster_space(proj_results, n_clusters=requested_clusters)
            # From here on, n_clusters is the number of clusters actually found.
            n_clusters = len(set(cluster_labels))

            # Store important data in session state (read by the export step)
            st.session_state['cluster_labels'] = cluster_labels
            st.session_state['proj_results'] = proj_results
            st.session_state['labels'] = labels
            st.session_state['n_components'] = n_components
            st.session_state['method_proj'] = method_proj
            st.session_state['dist_choice'] = dist_choice
            st.session_state['n_clusters'] = n_clusters
            st.session_state['alpha'] = alpha

            # Different shapes by cluster (2D only)
            symbols = ['circle', 'square', 'diamond', 'cross', 'x', 'triangle-up', 'triangle-down', 'triangle-left', 'triangle-right']
            cluster_shapes = [symbols[cl % len(symbols)] for cl in cluster_labels]

            if n_components == 2:
                fig = go.Figure(data=go.Scatter(
                    x=proj_results[:, 0],
                    y=proj_results[:, 1],
                    mode='markers+text',
                    text=labels,
                    textposition='top center',
                    marker=dict(
                        size=10,
                        color=colors,
                        symbol=cluster_shapes,
                        line=dict(width=1, color='black')
                    ),
                ))
                fig.update_layout(
                    title=f"2D {method_proj} visualization of graphs (clusters + similarity)",
                    xaxis_title="Dimension 1",
                    yaxis_title="Dimension 2"
                )
            else:  # 3D: plot a Scatter3d per cluster
                fig = go.Figure()
                for cl in sorted(set(cluster_labels)):
                    indices = [i for i, c in enumerate(cluster_labels) if c == cl]
                    fig.add_trace(go.Scatter3d(
                        x=proj_results[indices, 0],
                        y=proj_results[indices, 1],
                        z=proj_results[indices, 2],
                        mode='markers+text',
                        name=f"Cluster {cl}",
                        text=[labels[i] for i in indices],
                        textposition='top center',
                        marker=dict(
                            size=6,
                            color=[colors[i] for i in indices],
                            line=dict(width=1, color='black'),
                            symbol='circle'  # single shape in 3D
                        )
                    ))
                fig.update_layout(
                    title=f"3D {method_proj} visualization of graphs (clusters + similarity)",
                    scene=dict(
                        xaxis_title="Dim 1",
                        yaxis_title="Dim 2",
                        zaxis_title="Dim 3"
                    )
                )

            st.plotly_chart(fig, use_container_width=True)

            # Show clusters: the questions of each cluster, in input order
            for cl in sorted(set(cluster_labels)):
                st.markdown(f"**Cluster {cl}:**")
                for label, assigned in zip(labels, cluster_labels):
                    if assigned == cl:
                        st.write(f"- {label}")

        except Exception as e:
            # Top-level boundary of the app: surface the failure in the UI.
            st.error(f"Error during interactive visualization: {e}")

# Cluster extraction to JSON
st.markdown("---")
st.markdown("### Cluster extraction")

if st.button("Export clusters to JSON"):
    if 'cluster_labels' not in st.session_state:
        st.warning("Please visualize the graphs with similarity first.")
    else:
        # Retrieve the data stored by the visualization step
        cluster_labels = st.session_state['cluster_labels']
        proj_results = st.session_state['proj_results']
        labels = st.session_state['labels']
        n_components = st.session_state['n_components']
        method_proj = st.session_state['method_proj']
        dist_choice = st.session_state['dist_choice']
        n_clusters = st.session_state['n_clusters']
        alpha = st.session_state['alpha']

        clusters_data = {}

        # Build the clusters dictionary: one entry per cluster id, holding
        # the questions assigned to it.
        for cl in sorted(set(cluster_labels)):
            cluster_questions = []
            for idx, label in enumerate(labels):
                if cluster_labels[idx] != cl:
                    continue

                coordinates = {
                    "dim_1": float(proj_results[idx, 0]),
                    "dim_2": float(proj_results[idx, 1])
                }
                # Add the 3rd dimension if present
                if n_components == 3:
                    coordinates["dim_3"] = float(proj_results[idx, 2])

                # Metadata for the corresponding graph. The similarity list
                # is built once; nodes without a 'similarity' attribute count
                # as 0, and a graph without nodes is summarized as 0.0
                # because np.max / np.min raise on an empty sequence.
                graph_sim = st.session_state['G_sim_list'][idx]
                similarities = [graph_sim.nodes[n].get('similarity', 0) for n in graph_sim.nodes()] or [0.0]

                cluster_questions.append({
                    "question": label,
                    "coordinates": coordinates,
                    "metadata": {
                        "num_nodes": len(graph_sim.nodes()),
                        "num_edges": len(graph_sim.edges()),
                        "avg_similarity": float(np.mean(similarities)),
                        "max_similarity": float(np.max(similarities)),
                        "min_similarity": float(np.min(similarities))
                    }
                })

            clusters_data[f"cluster_{cl}"] = {
                "cluster_id": int(cl),
                "num_questions": len(cluster_questions),
                "questions": cluster_questions
            }

        # Add global metadata (alpha is only reported for the "fgw" distance)
        final_json = {
            "extraction_info": {
                "method": method_proj,
                "distance_metric": dist_choice,
                "n_components": n_components,
                "n_clusters": n_clusters,
                "alpha": alpha if dist_choice == "fgw" else None,
                "total_questions": len(labels)
            },
            "clusters": clusters_data
        }

        # Store in session state for later use (question generation step)
        st.session_state['clusters_data'] = clusters_data
        st.session_state['final_json'] = final_json

        # Display JSON
        st.json(final_json)

        # Download button
        json_str = json.dumps(final_json, indent=2, ensure_ascii=False)
        st.download_button(
            label="💾 Download clusters JSON",
            data=json_str,
            file_name=f"clusters_{method_proj.lower()}_{dist_choice}_{n_clusters}clusters.json",
            mime="application/json"
        )
        print(f"Extracted clusters: {len(clusters_data)} clusters found.")

        st.success(f"✅ Clusters successfully extracted! {len(clusters_data)} clusters found.")

# Characteristic question generation with Ollama via OpenAI
st.markdown("---")
st.markdown("### 🤖 Characteristic question generation")

# Two-column settings row: server and model on the left, generation
# parameters on the right.
available_models = ["llama3", "llama3:8b", "llama3:70b", "mistral", "codellama"]
server_col, params_col = st.columns(2)
with server_col:
    ollama_host = st.text_input("Ollama host", value="http://localhost:11434", help="URL of the Ollama server")
    model_name = st.selectbox("Ollama model", available_models, index=0)
with params_col:
    max_questions = st.number_input("Max questions per cluster", min_value=1, max_value=5, value=2)
    temperature = st.slider("Temperature (creativity)", min_value=0.0, max_value=1.0, value=0.7, step=0.1)

def _extract_generated_questions(generated_text, max_q):
    """Return up to *max_q* distinct questions parsed from the model's reply.

    Only lines that are bullets (start with '-') or contain the word
    'Question' are considered. Bullet markers and a leading
    "Question N:" label are removed from the kept text.
    """
    found = []
    for raw_line in generated_text.split('\n'):
        candidate = raw_line.strip()
        if not candidate:
            continue

        if candidate.startswith('-'):
            # Bullet line: drop the marker, then a "Question N:" label if the
            # bullet carries one. A colon inside the question itself is kept.
            candidate = candidate.lstrip('-').strip()
            if candidate.startswith('Question'):
                candidate = candidate.split(':', 1)[-1].strip()
        elif 'Question' in candidate:
            # "Question N: text" -> "text"
            candidate = candidate.split(':', 1)[-1].strip()
        else:
            continue

        if candidate and candidate not in found:
            found.append(candidate)

    return found[:max_q]


def generate_characteristic_questions(cluster_questions, cluster_id, model="llama3", host="http://localhost:11434", max_q=2, temp=0.7):
    """
    Generate characteristic questions for a given cluster using Ollama via an OpenAI-compatible API.

    Args:
        cluster_questions: Items of the cluster; each must expose its text under the 'question' key.
        cluster_id: Identifier of the cluster, used in fallback and error messages.
        model: Name of the model to query.
        host: Base URL of the Ollama server (a trailing slash is tolerated).
        max_q: Maximum number of questions to request and return.
        temp: Sampling temperature passed to the model.

    Returns:
        A list of at most ``max_q`` question strings. If nothing could be
        parsed from the reply, a single placeholder string is returned; if
        the call fails, the error is shown with ``st.error`` and a single
        error string is returned.
    """
    try:
        # Prepare prompt
        questions_text = "\n".join([f"- {q['question']}" for q in cluster_questions])

        prompt = f"""You are an expert in data analysis and FAQs.
Analyze the following questions that belong to the same thematic cluster:

{questions_text}

Generate {max_q} question(s) that best summarize and characterize the main theme of this cluster.
These questions must:
1. Capture the common essence of all the questions in the cluster
2. Be phrased clearly and concisely
3. Represent the primary concerns of this group

Expected response format:
Question 1: [your question]
Question 2: [your question] (if requested)

Reply only with the generated questions, without any additional explanation."""

        # Call Ollama via OpenAI-compatible endpoint. Strip a trailing slash
        # so "http://host:11434/" does not become "http://host:11434//v1".
        ollama_endpoint = f"{host.rstrip('/')}/v1"
        client = OpenAI(
            base_url=ollama_endpoint,
            api_key="ollama"  # Ollama does not require a real API key
        )

        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            temperature=temp,
            max_tokens=500
        )

        generated_text = response.choices[0].message.content
        if generated_text is None:
            # Report a missing reply explicitly instead of failing later with
            # an unrelated attribute error.
            raise ValueError("the model returned an empty response")

        questions = _extract_generated_questions(generated_text, max_q)
        return questions if questions else [f"Characteristic question for cluster {cluster_id}"]

    except Exception as e:
        # Boundary with an external service: show the failure and keep the
        # app running with a placeholder result.
        st.error(f"Ollama error for cluster {cluster_id}: {str(e)}")
        return [f"Error during generation for cluster {cluster_id}"]

# Run the generation for every exported cluster, then show the results.
if st.button("🚀 Generate characteristic questions"):
    if 'clusters_data' in st.session_state:
        st.info(f"Generation in progress with model {model_name}...")

        clusters_data = st.session_state['clusters_data']
        characteristic_questions = {}

        progress_bar = st.progress(0)
        total_clusters = len(clusters_data)

        for done, (cluster_key, cluster_data) in enumerate(clusters_data.items(), start=1):
            cluster_id = cluster_data['cluster_id']
            cluster_qs = cluster_data['questions']

            st.write(f"Processing cluster {cluster_id}...")

            generated_questions = generate_characteristic_questions(
                cluster_qs,
                cluster_id,
                model=model_name,
                host=ollama_host,
                max_q=max_questions,
                temp=temperature
            )

            # Keep the generated questions next to the originals they summarize.
            original_texts = [q['question'] for q in cluster_qs]
            characteristic_questions[cluster_key] = {
                "cluster_id": cluster_id,
                "original_questions_count": len(cluster_qs),
                "characteristic_questions": generated_questions,
                "original_questions": original_texts
            }

            # Fraction of clusters processed so far
            progress_bar.progress(done / total_clusters)

        st.success("Characteristic questions generated successfully!")

        st.session_state['characteristic_questions'] = characteristic_questions

        # One expander per cluster: generated questions first, originals after.
        for data in characteristic_questions.values():
            with st.expander(f"🎯 Cluster {data['cluster_id']} - Characteristic questions"):
                st.markdown("**Generated questions:**")
                for j, char_q in enumerate(data['characteristic_questions'], 1):
                    st.write(f"{j}. {char_q}")

                st.markdown("**Original cluster questions:**")
                for orig_q in data['original_questions']:
                    st.write(f"• {orig_q}")
    else:
        st.warning("Please extract clusters to JSON first.")

# Export characteristic questions (outside the generation block, so the
# download stays available on later reruns once results exist)
if 'characteristic_questions' in st.session_state:
    st.markdown("---")
    st.markdown("### 💾 Export characteristic questions")

    # NOTE(review): these settings are read from the widgets at export time,
    # not recorded when the questions were generated.
    generation_info = {
        "model": model_name,
        "api_provider": "ollama_via_openai",
        "host": ollama_host,
        "temperature": temperature,
        "max_questions_per_cluster": max_questions,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }
    export_data = {
        "generation_info": generation_info,
        "characteristic_questions": st.session_state['characteristic_questions']
    }
    export_json = json.dumps(export_data, indent=2, ensure_ascii=False)

    # ':' in model tags (e.g. "llama3:8b") is replaced for the file name.
    safe_model_name = model_name.replace(':', '_')
    file_stamp = time.strftime('%Y%m%d_%H%M%S')
    st.download_button(
        label="💾 Download characteristic questions",
        data=export_json,
        file_name=f"characteristic_questions_{safe_model_name}_{file_stamp}.json",
        mime="application/json"
    )
