import streamlit as st
import spacy
import networkx as nx
import matplotlib.pyplot as plt
import json
from kt_gen.knowledge_graph.extract_graph import structure_text



# Initial configuration: page metadata first, then the visible page heading.
PAGE_TITLE = "document Graph (Textual Inclusion)"
st.set_page_config(layout="centered", page_title=PAGE_TITLE)
st.title("document Graph by Textual Inclusion")

# Load spaCy
@st.cache_resource
def load_spacy_model(model_name="en_core_web_sm"):
    """Load the spaCy pipeline that is passed to ``structure_text``.

    ``st.cache_resource`` keeps one loaded pipeline per ``model_name`` across
    Streamlit reruns, so the model is not reloaded on every interaction.

    Args:
        model_name: Name of an installed spaCy pipeline package.

    Returns:
        The loaded spaCy ``Language`` object.
    """
    # English spaCy pipelines are named en_core_web_*; "en_core_news_sm" does
    # not exist, so spacy.load() could never find it.
    return spacy.load(model_name)

nlp = load_spacy_model()



# Visualization with highlight of a selected node
def plot_graph(graph, selected_node=None, level="document"):
    """Draw the part of ``graph`` visible at ``level`` and highlight one node.

    Nodes are filtered on their ``type`` attribute. A ``level`` keeps its own
    type plus every coarser one in the hierarchy
    document > section > paragraph > sentence. Nodes of type ``"merged"`` are
    always drawn, so an unrecognised ``level`` draws only merged nodes.

    Args:
        graph: networkx graph whose nodes carry a ``type`` attribute.
        selected_node: Node drawn in gold and larger (size 1000 instead of
            800); ignored if it is not among the drawn nodes.
        level: One of ``"document"``, ``"section"``, ``"paragraph"``,
            ``"sentence"``.

    Returns:
        The matplotlib ``Figure`` containing the drawing.
    """
    hierarchy = ["document", "section", "paragraph", "sentence"]
    if level in hierarchy:
        visible_types = set(hierarchy[: hierarchy.index(level) + 1])
    else:
        visible_types = set()
    visible_types.add("merged")

    type_colors = {
        "document": "lightblue",
        "section": "lightgreen",
        "paragraph": "lightyellow",
        "sentence": "lightcoral",
        "merged": "lightpink",
    }

    nodelist = []
    node_colors = []
    node_sizes = []
    for node, data in graph.nodes(data=True):
        node_type = data.get("type")
        if node_type not in visible_types:
            continue
        nodelist.append(node)
        if node == selected_node:
            node_colors.append("gold")
            node_sizes.append(1000)
        else:
            node_colors.append(type_colors.get(node_type, "gray"))
            node_sizes.append(800)

    subgraph = graph.subgraph(nodelist)
    pos = nx.spring_layout(subgraph, seed=42)

    fig, ax = plt.subplots(figsize=(12, 8))
    # Pass nodelist explicitly: a subgraph view may iterate its nodes in a
    # different order than `nodelist`, which would misalign colors and sizes.
    nx.draw(
        subgraph,
        pos,
        nodelist=nodelist,
        with_labels=True,
        node_color=node_colors,
        node_size=node_sizes,
        font_size=8,
        ax=ax,
    )
    ax.set_title("document Graph (selected node in yellow)")
    return fig


# Function to merge nodes (and subgraphs) in the graph
def merge_nodes(graph, node1, node2):
    """Replace ``node1`` and ``node2`` with a single ``"merged"`` node, in place.

    The new node is named ``"{node1}_{node2}"``. If that name is already
    taken, a numeric suffix is appended. The node gets ``type="merged"`` and a
    short ``text`` describing the merge. It inherits every outgoing *and*
    incoming edge of both original nodes. Edges between the two merged nodes
    are dropped rather than turned into a self-loop, and edge attributes are
    not copied.

    Args:
        graph: Directed networkx graph to modify (``predecessors`` is used).
        node1: First node to merge.
        node2: Second node to merge; must differ from ``node1``.

    Returns:
        The same ``graph`` object, now containing the merged node.

    Raises:
        ValueError: If ``node1 == node2``; checked before the graph is touched.
    """
    if node1 == node2:
        raise ValueError(f"Cannot merge node {node1!r} with itself.")

    merged_node = f"{node1}_{node2}"
    # Never reuse an existing node: add_node would only update its attributes.
    suffix = 1
    while merged_node in graph:
        merged_node = f"{node1}_{node2}_{suffix}"
        suffix += 1

    graph.add_node(merged_node, type="merged", text=f"Merged: {node1} and {node2}")

    originals = (node1, node2)
    for old in originals:
        # Outgoing edges: merged node -> former successors.
        for succ in list(graph.neighbors(old)):
            if succ not in originals:
                graph.add_edge(merged_node, succ)
        # Incoming edges of both nodes (not just node1) are redirected, so
        # edges that pointed at node2 are kept.
        for pred in list(graph.predecessors(old)):
            if pred not in originals:
                graph.add_edge(pred, merged_node)

    # Removing the nodes also removes all of their remaining edges.
    graph.remove_node(node1)
    graph.remove_node(node2)
    return graph

# User interface
with st.expander("Raw text to analyze", expanded=True):
    sample = """1. Introduction

Artificial intelligence is transforming many sectors.
It relies on machine learning.

2. Context

Algorithms are evolving rapidly.
They require large amounts of data."""
    user_text = st.text_area("Raw text:", value=sample, height=300)
if "graph" not in st.session_state:
    st.session_state.graph = None

if st.button("Generate graph"):
    if user_text.strip():
        with st.spinner("Analyzing and generating the graph..."):
            G = structure_text(user_text, nlp)
            st.session_state.graph = G
            # Default selection; next(..., None) avoids IndexError on an empty graph.
            st.session_state.node = next(iter(G.nodes), None)

if st.session_state.graph:
    G = st.session_state.graph
    all_nodes = list(G.nodes)

    # The remembered node can disappear (e.g. after a merge, which edits the
    # graph in place), so fall back to the first node instead of letting
    # list.index() raise ValueError.
    previous_node = st.session_state.get("node")
    default_index = all_nodes.index(previous_node) if previous_node in all_nodes else 0
    selected = st.selectbox("Select a node to explore:", all_nodes, index=default_index)
    st.session_state.node = selected

    # Add a selectbox to choose the branching level
    level = st.selectbox("Select branching level:", ["document", "section", "paragraph", "sentence"], index=0)
    st.session_state.level = level

    fig = plot_graph(G, selected_node=selected, level=level)
    st.pyplot(fig)
    # Close the figure so pyplot does not keep one open figure per rerun.
    plt.close(fig)

    node_data = G.nodes[selected]
    st.markdown("### Details of selected node")
    st.markdown(f"Type: `{node_data.get('type', 'unknown')}`")
    st.markdown("Content:")
    st.code(node_data.get("text", "(no text)"), language="markdown")


## Function to export graphs as JSON with fields: index, text, type, and a list of directed edges
def export_graph_to_json(graph, level="sentence"):
    """Serialise the part of ``graph`` visible at ``level`` to a JSON-ready dict.

    Node filtering matches ``plot_graph``. A ``level`` keeps its own node type
    plus every coarser one (document > section > paragraph > sentence), and
    ``"merged"`` nodes are always kept, so a merged graph exports what is drawn.

    Args:
        graph: networkx graph whose nodes carry ``type`` and ``text`` attributes.
        level: One of ``"document"``, ``"section"``, ``"paragraph"``,
            ``"sentence"``. Any other value keeps only merged nodes.

    Returns:
        A dict ``{"nodes": [...], "edges": [...]}``. Each node entry has
        ``"index"``, ``"text"`` and ``"type"`` keys. ``edges`` holds
        ``(source, target)`` tuples, restricted to edges whose two endpoints
        were both kept.
    """
    hierarchy = ["document", "section", "paragraph", "sentence"]
    if level in hierarchy:
        kept_types = set(hierarchy[: hierarchy.index(level) + 1])
    else:
        kept_types = set()
    kept_types.add("merged")

    nodes = [
        {"index": node, "text": data.get("text", ""), "type": data.get("type", "")}
        for node, data in graph.nodes(data=True)
        if data.get("type") in kept_types
    ]

    # A set keeps the edge filter linear in the number of edges (the list
    # version was rebuilt twice per edge).
    kept_ids = {entry["index"] for entry in nodes}
    edges = [
        (source, target)
        for source, target in graph.edges()
        if source in kept_ids and target in kept_ids
    ]
    return {"nodes": nodes, "edges": edges}


# Export graph to JSON format
if st.button("Export graph to JSON format"):
    if st.session_state.graph:
        export_level = st.session_state.get("level", "sentence")
        graph_json = export_graph_to_json(st.session_state.graph, level=export_level)
        st.write("level:", export_level)
        st.json(graph_json)
    else:
        st.warning("No graph to export.")

# Render the download button directly: nested inside st.button it would only
# exist for a single rerun and need two clicks.
if st.session_state.graph:
    download_level = st.session_state.get("level", "sentence")
    graph_json = export_graph_to_json(st.session_state.graph, level=download_level)
    st.download_button(
        "Download graph in JSON format",
        data=json.dumps(graph_json, ensure_ascii=False, indent=2),
        file_name="graph.json",
        mime="application/json",
    )



### TODO: Import into the FGW distance calculation interface.

### TODO: Check the notebook with the ideas to discuss.

### Interface for merging nodes:

if st.session_state.graph:
    st.subheader("Merge nodes")
    G = st.session_state.graph
    all_nodes = list(G.nodes)
    selected_node1 = st.selectbox("Select the first node to merge:", all_nodes, index=0)
    # index=1 would be out of range for a single-node graph.
    selected_node2 = st.selectbox(
        "Select the second node to merge:", all_nodes, index=min(1, len(all_nodes) - 1)
    )
    # Reset on every rerun: the merged view below is only shown on the run in
    # which the merge button was clicked.
    st.session_state.merged_graph = None

    if st.button("Merge selected nodes"):
        if selected_node1 != selected_node2:
            with st.spinner("Merging nodes..."):
                # merge_nodes edits G in place, so this merge is also applied
                # to st.session_state.graph and persists across reruns.
                G_fus = merge_nodes(G, selected_node1, selected_node2)
                st.session_state.merged_graph = G_fus

                st.success(f"Nodes {selected_node1} and {selected_node2} merged.")
        else:
            st.warning("Please select two different nodes to merge.")

    # Update node list after merging
    if st.session_state.merged_graph is not None:
        merged_graph = st.session_state.merged_graph
        all_nodes = list(merged_graph.nodes)
        # The previously selected node may just have been merged away; fall
        # back to the first node instead of letting list.index() raise.
        previous_node = st.session_state.get("node")
        default_index = all_nodes.index(previous_node) if previous_node in all_nodes else 0
        selected = st.selectbox("Select a node to explore after merging:", all_nodes, index=default_index)
        st.session_state.node = selected
        fig_fus = plot_graph(
            merged_graph, selected_node=selected, level=st.session_state.get("level", "document")
        )
        st.pyplot(fig_fus)
        # Close the figure so pyplot does not keep one open figure per rerun.
        plt.close(fig_fus)
        node_data = merged_graph.nodes[selected]
        st.markdown("### Details of selected node after merging")
        st.markdown(f"Type: `{node_data.get('type', 'unknown')}`")
        st.markdown("Content:")
        st.code(node_data.get("text", "(no text)"), language="markdown")
        st.markdown("### Graph details after merging")
        st.markdown(f"Number of nodes: {len(merged_graph.nodes)}")
        st.markdown(f"Number of edges: {len(merged_graph.edges)}")
        st.markdown("### Graph after merging")
