"""
Streamlit page for generating questions from document graphs.
"""

import streamlit as st
import json
import networkx as nx
from openai import OpenAI
import os
import sys
from typing import Optional
import tempfile


from knowledge_graph.kg_gen_questions import (
    generate_questions_from_graph,
    save_questions_to_json,
    load_questions_from_json,
    QuestionGenerationResult,
    extract_sections_and_paragraphs
)


def load_graph_from_json(graph_data: dict) -> nx.Graph:
    """
    Build a directed graph from a JSON-style dictionary.

    Nodes are read from ``graph_data['nodes']``. Each entry must carry an
    ``'index'`` key, which becomes the node identifier; its ``'text'`` and
    ``'type'`` values are stored as node attributes and default to ``''``
    and ``'unknown'`` when absent.

    Edges are read from ``graph_data['edges']``. An entry is accepted either
    as a two-item list ``[source, target]`` or as a dict with ``'source'``
    and ``'target'`` keys; entries of any other shape are ignored.

    Missing ``'nodes'`` / ``'edges'`` keys are treated as empty lists.
    """
    graph = nx.DiGraph()

    # Nodes: resolve the identifier first, then attach the attributes.
    for entry in graph_data.get('nodes', []):
        node_id = entry['index']
        attributes = {
            'text': entry.get('text', ''),
            'type': entry.get('type', 'unknown'),
        }
        graph.add_node(node_id, **attributes)

    # Edges: two serialisations are understood, anything else is skipped.
    for link in graph_data.get('edges', []):
        if isinstance(link, dict):
            graph.add_edge(link['source'], link['target'])
            continue
        if isinstance(link, list) and len(link) == 2:
            source, target = link
            graph.add_edge(source, target)

    return graph


def analyze_graph_structure(G: nx.Graph) -> dict:
    """
    Summarise a graph: node count, edge count and nodes per type.

    The returned dict has the keys ``'total_nodes'``, ``'total_edges'`` and
    ``'node_types'``. ``'node_types'`` maps each value of the ``'type'``
    node attribute to the number of nodes carrying it; nodes without that
    attribute are counted under ``'unknown'``.
    """
    type_counts = {}
    for _, attributes in G.nodes(data=True):
        kind = attributes.get('type', 'unknown')
        if kind in type_counts:
            type_counts[kind] += 1
        else:
            type_counts[kind] = 1

    summary = {
        'total_nodes': G.number_of_nodes(),
        'total_edges': G.number_of_edges(),
        'node_types': type_counts,
    }
    return summary


def setup_openai_client(api_base: str, api_key: str, model: str) -> OpenAI:
    """
    Create an OpenAI client pointed at the given endpoint.

    ``api_base`` is passed as the client's ``base_url`` and ``api_key`` as
    its ``api_key``. ``model`` is accepted for the caller's convenience but
    is not used here; the model is chosen per request by the caller.
    """
    client_options = {
        "base_url": api_base,
        "api_key": api_key,
    }
    return OpenAI(**client_options)


def _render_sidebar() -> dict:
    """
    Draw the sidebar widgets and return the values the user selected.

    Returns a dict with the keys ``'api_base'``, ``'api_key'``,
    ``'model_name'``, ``'include_sections'``, ``'include_paragraphs'``,
    ``'max_sections'`` and ``'max_paragraphs'``.
    """
    with st.sidebar:
        st.header("⚙️ LLM Configuration")

        api_base = st.text_input(
            "API Base URL",
            value="http://localhost:11434/v1",
            help="URL for local Ollama or another OpenAI-compatible API"
        )

        api_key = st.text_input(
            "API Key",
            value="ollama",
            type="password",
            help="API key (for local Ollama, use 'ollama')"
        )

        model_name = st.text_input(
            "Model name",
            value="llama3",
            help="Name of the LLM model to use"
        )

        st.markdown("---")

        st.header("🎛️ Generation Options")

        include_sections = st.checkbox("Generate questions for sections", value=True)
        include_paragraphs = st.checkbox("Generate questions for paragraphs", value=True)

        max_sections = st.number_input(
            "Section limit (0 = no limit)",
            min_value=0,
            value=5,
            help="Maximum number of sections to process"
        )

        max_paragraphs = st.number_input(
            "Paragraph limit (0 = no limit)",
            min_value=0,
            value=10,
            help="Maximum number of paragraphs to process"
        )

    return {
        'api_base': api_base,
        'api_key': api_key,
        'model_name': model_name,
        'include_sections': include_sections,
        'include_paragraphs': include_paragraphs,
        'max_sections': max_sections,
        'max_paragraphs': max_paragraphs,
    }


def _render_graph_loader() -> None:
    """Draw the "Load Graph" column and store a loaded graph in session state."""
    st.header("📁 Load Graph")

    # Option 1: Load from examples
    st.subheader("Option 1: Existing examples")
    # NOTE(review): os.path.exists("") is False, so with an empty path the
    # example picker below is never shown. Point this at the directory that
    # holds the example graphs to enable Option 1.
    examples_dir = ""

    if os.path.exists(examples_dir):
        # Sorted so the selectbox order does not depend on directory order.
        graph_files = sorted(
            f for f in os.listdir(examples_dir)
            if f.endswith('.json') and 'graph' in f
        )

        selected_example = st.selectbox(
            "Choose an example graph",
            [""] + graph_files,
            help="Select an example graph to get started"
        )

        if selected_example:
            example_path = os.path.join(examples_dir, selected_example)
            if st.button("Load selected example"):
                try:
                    with open(example_path, 'r', encoding='utf-8') as f:
                        st.session_state.graph_data = json.load(f)
                    st.success(f"Graph loaded: {selected_example}")
                except Exception as e:
                    st.error(f"Error while loading: {e}")

    st.markdown("---")

    # Option 2: File upload
    st.subheader("Option 2: File upload")
    uploaded_file = st.file_uploader(
        "Upload a graph JSON file",
        type=['json'],
        help="Upload a JSON file containing your document graph"
    )

    if uploaded_file is not None:
        try:
            # The script re-runs on every interaction while the file stays
            # selected. getvalue() returns the full payload regardless of
            # the buffer's current read position, so parsing cannot fail
            # merely because the stream was consumed by an earlier run.
            st.session_state.graph_data = json.loads(uploaded_file.getvalue())
            st.success("Graph uploaded successfully!")
        except Exception as e:
            st.error(f"Upload error: {e}")


def _render_preview(title: str, noun: str, items: list, limit: int = 3) -> None:
    """Show the first ``limit`` items in an expander, plus a count of the rest."""
    if not items:
        return

    with st.expander(title):
        for item in items[:limit]:
            st.write(f"**{item['id']}**: {item['text'][:100]}...")
        remaining = len(items) - limit
        if remaining > 0:
            st.write(f"... and {remaining} more {noun}")


def _render_graph_analysis() -> None:
    """Draw the "Graph Analysis" column for the graph held in session state."""
    st.header("📊 Graph Analysis")

    if 'graph_data' not in st.session_state:
        st.info("Load a graph to see the analysis")
        return

    try:
        graph = load_graph_from_json(st.session_state.graph_data)
        analysis = analyze_graph_structure(graph)

        st.metric("Total nodes", analysis['total_nodes'])
        st.metric("Total edges", analysis['total_edges'])

        st.subheader("Node types:")
        for node_type, count in analysis['node_types'].items():
            st.write(f"• **{node_type}**: {count}")

        # Extract sections and paragraphs
        sections, paragraphs = extract_sections_and_paragraphs(graph)

        st.markdown("---")
        st.info(f"🎯 **{len(sections)}** sections and **{len(paragraphs)}** paragraphs detected for question generation")

        _render_preview("Sections preview", "sections", sections)
        _render_preview("Paragraphs preview", "paragraphs", paragraphs)

    except Exception as e:
        # Top-level UI boundary: report any malformed-graph error to the user.
        st.error(f"Error during analysis: {e}")


def _run_generation(settings: dict) -> None:
    """
    Test the LLM connection, generate questions and store the result.

    ``settings`` is the dict returned by ``_render_sidebar``. On success the
    result is stored in ``st.session_state.generation_result``; any failure
    is reported in the page instead of being raised.
    """
    model_name = settings['model_name']
    max_sections = settings['max_sections']
    max_paragraphs = settings['max_paragraphs']

    try:
        client = setup_openai_client(settings['api_base'], settings['api_key'], model_name)

        # Connection test: the reply itself is irrelevant, only that the
        # request completes without raising.
        with st.spinner("Testing LLM connection..."):
            client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": "Connection test. Simply reply 'OK'."}],
                max_tokens=10
            )

        st.success("✅ LLM connection successful")

        # Load the graph
        graph = load_graph_from_json(st.session_state.graph_data)

        # Question generation
        with st.spinner("Generating questions..."):
            progress_bar = st.progress(0)

            result = generate_questions_from_graph(
                graph, client,
                include_sections=settings['include_sections'],
                include_paragraphs=settings['include_paragraphs']
            )

            # Apply limits.
            # NOTE(review): the limits only trim the returned lists;
            # presumably the generator has already processed every node by
            # this point — confirm whether it can take the limits directly.
            if max_sections > 0 and len(result.section_questions) > max_sections:
                result.section_questions = result.section_questions[:max_sections]

            if max_paragraphs > 0 and len(result.paragraph_questions) > max_paragraphs:
                result.paragraph_questions = result.paragraph_questions[:max_paragraphs]

            # Keep the total consistent with the (possibly trimmed) lists.
            result.total_questions = len(result.section_questions) + len(result.paragraph_questions)
            progress_bar.progress(1.0)

        # Store results
        st.session_state.generation_result = result
        st.success(f"🎉 {result.total_questions} questions generated successfully!")

    except Exception as e:
        st.error(f"❌ Error during generation: {e}")
        if "connection" in str(e).lower():
            st.warning("💡 Make sure Ollama is running and the model is available")


def _render_question_list(questions: list, label: str) -> None:
    """Show one expander per generated question, or a notice when there are none."""
    if not questions:
        st.info(f"No {label.lower()} questions generated")
        return

    for i, q in enumerate(questions, 1):
        with st.expander(f"{label} {i}: {q.node_id}"):
            st.write(f"**{label} text:**")
            st.write(q.node_text)
            st.write("**Generated question:**")
            st.success(q.question)


def _render_results(result: QuestionGenerationResult) -> None:
    """Draw the statistics, the question tabs and the JSON export for ``result``."""
    st.markdown("---")
    st.header("📝 Generated Questions")

    # Stats
    sections_col, paragraphs_col, total_col = st.columns(3)
    with sections_col:
        st.metric("Section questions", len(result.section_questions))
    with paragraphs_col:
        st.metric("Paragraph questions", len(result.paragraph_questions))
    with total_col:
        st.metric("Total", result.total_questions)

    # Tabs to display questions
    sections_tab, paragraphs_tab, export_tab = st.tabs(
        ["🏗️ Section Questions", "📄 Paragraph Questions", "💾 Export"]
    )

    with sections_tab:
        _render_question_list(result.section_questions, "Section")

    with paragraphs_tab:
        _render_question_list(result.paragraph_questions, "Paragraph")

    with export_tab:
        st.subheader("Export results")

        # Prepare data for export
        export_data = {
            "section_questions": [q.model_dump() for q in result.section_questions],
            "paragraph_questions": [q.model_dump() for q in result.paragraph_questions],
            "total_questions": result.total_questions,
            "metadata": {
                "sections_count": len(result.section_questions),
                "paragraphs_count": len(result.paragraph_questions)
            }
        }

        # Download button
        st.download_button(
            label="📥 Download questions (JSON)",
            data=json.dumps(export_data, ensure_ascii=False, indent=2),
            file_name="generated_questions.json",
            mime="application/json"
        )

        # JSON preview
        with st.expander("JSON preview"):
            st.json(export_data)


def main():
    """
    Render the question-generation page.

    The page is laid out top to bottom as: sidebar settings, graph loading
    next to graph analysis, the generation button, the generated questions
    (once a result exists in session state) and a footer tip. The loaded
    graph and the generation result are kept in ``st.session_state`` under
    ``graph_data`` and ``generation_result``.
    """
    st.set_page_config(
        page_title="Question Generation",
        page_icon="❓",
        layout="wide"
    )

    st.title("🤖 Question Generation from document Graphs")
    st.markdown("---")

    # LLM API configuration
    settings = _render_sidebar()

    # Main interface
    loader_col, analysis_col = st.columns([1, 1])

    with loader_col:
        _render_graph_loader()

    with analysis_col:
        _render_graph_analysis()

    st.markdown("---")

    # Generation section
    st.header("🚀 Question Generation")

    if 'graph_data' in st.session_state:
        # Three equal columns; only the middle one is used, to centre the button.
        _, center_col, _ = st.columns([1, 1, 1])

        with center_col:
            if st.button("🎯 Generate Questions", type="primary", use_container_width=True):
                _run_generation(settings)

    # Display results
    if 'generation_result' in st.session_state:
        _render_results(st.session_state.generation_result)

    # Footer
    st.markdown("---")
    st.markdown(
        "💡 **Tip**: Make sure your graph contains nodes of type 'section' and 'paragraph' "
        "for optimal question generation."
    )


# Render the page only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
