import spacy
import networkx as nx
import matplotlib.pyplot as plt
import json
from kt_gen.knowledge_graph.extract_graph import structure_text
import pandas as pd
# Example: displaying the graph of a question:
import matplotlib.pyplot as plt
import os 
import json
from sentence_transformers import SentenceTransformer
from kt_gen.utils.llm.utils_llm import build_prompt, ask_llm_strict, build_EM_prompt_v3
from kt_gen.knowledge_graph.utils.pydantic_models import GraphWrapper, Node, TraversalStep, EmbeddingStore, EmbeddingMatrix
import time
from kt_gen.knowledge_graph.kg_fgw import recursive_traversal_with_score_fusion_fgw_genetic_optimized
from openai import OpenAI

# Load the small English spaCy pipeline once at import time; it is the `nlp`
# object handed to structure_text() when graphs are built from context text.
nlp = spacy.load("en_core_web_sm")

def generate_graphs_from_parsed_file(parsed_file_path, max_graphs=10):
    """
    Read a TSV file with columns 'id', 'question', 'context' and build one
    graph per entry by passing the context text to ``structure_text``.

    Args:
        parsed_file_path: Path to the tab-separated input file.
        max_graphs: Maximum number of distinct question ids to process.
            Pass ``None`` to process every row. Defaults to 10.

    Returns:
        A tuple ``(graphs_by_id, questions_by_id)`` of two dictionaries keyed
        by question id: the first maps to the graph returned by
        ``structure_text``, the second to the question text of that row.
    """
    df = pd.read_csv(parsed_file_path, sep="\t")

    graphs_by_id = {}
    questions_by_id = {}

    for _, row in df.iterrows():
        # Stop *before* processing another row once the limit is reached, so
        # that exactly ``max_graphs`` graphs are produced (not one extra).
        if max_graphs is not None and len(graphs_by_id) >= max_graphs:
            break

        qid = row["id"]
        # Missing contexts (NaN after parsing) are treated as empty text.
        context_text = str(row["context"]) if pd.notnull(row["context"]) else ""
        print(f"Processing question ID: {qid}")
        print(f"Context: {context_text[:100]}...")

        # Generate graph
        graphs_by_id[qid] = structure_text(context_text, nlp)
        questions_by_id[qid] = row["question"]

    return graphs_by_id, questions_by_id


def export_graph_to_json(graph, level="sentence"):
    """
    Export the part of ``graph`` at or above a given hierarchy level as a
    plain dictionary suitable for JSON serialisation.

    The hierarchy is document > section > paragraph > sentence. A node is
    kept when its ``type`` attribute is ``level`` or any coarser level.

    Args:
        graph: Object exposing ``nodes(data=True)`` (yielding
            ``(node, attribute_dict)`` pairs) and ``edges()`` (yielding
            ``(source, target)`` pairs).
        level: One of "document", "section", "paragraph", "sentence".
            Any other value matches no node, so the result is empty.

    Returns:
        ``{"nodes": [...], "edges": [...]}`` where each node entry is
        ``{"index", "text", "type"}`` (missing attributes default to "") and
        each edge is a ``(source, target)`` tuple whose two endpoints were
        both kept. ``json.dump`` writes these tuples as JSON arrays.
    """
    hierarchy = ("document", "section", "paragraph", "sentence")
    if level in hierarchy:
        # Every type from the top of the hierarchy down to ``level`` inclusive.
        allowed_types = hierarchy[: hierarchy.index(level) + 1]
    else:
        allowed_types = ()

    nodes = [
        {"index": node, "text": data.get("text", ""), "type": data.get("type", "")}
        for node, data in graph.nodes(data=True)
        if data.get("type") in allowed_types
    ]

    # Build the membership set once instead of rescanning the node list for
    # every edge endpoint.
    kept_indices = {entry["index"] for entry in nodes}
    edges = [
        (source, target)
        for source, target in graph.edges()
        if source in kept_indices and target in kept_indices
    ]
    return {"nodes": nodes, "edges": edges}





# Location of the sentence-embedding model: the MODEL_PATH environment
# variable takes precedence; an unset or empty value falls back to the
# default path below.
model_path = os.environ.get("MODEL_PATH") or "utils/model/all-MiniLM-L6-v2"

# Shared sentence encoder, loaded with remote code execution disabled.
model = SentenceTransformer(model_path, trust_remote_code=False)

# OpenAI-compatible client pointed at a local endpoint on port 11434.
client = OpenAI(api_key="ollama", base_url="http://localhost:11434/v1")

# Build one graph per question from the parsed HotpotQA file.
graphs, questions = generate_graphs_from_parsed_file("data/hotpotqa_fullwiki_parsed.tsv")

# Both the per-question graph dumps and the answers file live here; create the
# directory up front so the first open() cannot fail on a missing folder.
os.makedirs("data/output", exist_ok=True)

# Answers gathered during this run, keyed by question id.
answers = {}

for qid, G in graphs.items():
    # Export the graph to JSON
    graph_json = export_graph_to_json(G, level="sentence")
    with open(f"data/output/{qid}_graph.json", "w") as f:
        json.dump(graph_json, f, indent=2)
    time_start = time.time()

    # Encode the text of every node with the shared sentence encoder.
    embeddings = {n: model.encode(G.nodes[n]["text"], normalize_embeddings=True) for n in G.nodes}
    embedding_store = EmbeddingStore(vectors=embeddings)
    question = questions[qid]
    path_with_scores = recursive_traversal_with_score_fusion_fgw_genetic_optimized(
        "document",
        question,
        embedding_store,
        G,
        model,
        sim_threshold=0.3,
        fgw_threshold=1.3,
        alpha=0.5,
        structure_distance='similarity_weighted',
    )
    # The context node is the second-to-last element of the traversal result.
    # NOTE(review): this keeps the index the script has always used; confirm
    # against the traversal function that [-2] (rather than [-1]) is intended.
    last_node = path_with_scores[-2]

    context = last_node.node_id + ": " + G.nodes[last_node.node_id]["text"]
    print(f"time : {time.time()-time_start} sec")
    print(f"Question: {question}")
    print(f"Context: {context}")
    prompt = build_EM_prompt_v3(question, context)
    answer = ask_llm_strict(prompt, client)

    # Keep a single {qid: answer} mapping and rewrite the file after every
    # question: appending one JSON document per question would leave a file
    # that json.load() cannot parse, while rewriting keeps it valid even if a
    # later question fails.
    answers[qid] = answer
    with open("data/output/answer.json", "w") as f:
        json.dump(answers, f, indent=2)


