"""
Generate comprehension questions from the section and paragraph nodes of a document graph using an LLM.
"""

import json
from typing import Callable, Dict, List, Optional, Tuple

import networkx as nx
from openai import OpenAI
from pydantic import BaseModel, Field


class GeneratedQuestion(BaseModel):
    """One LLM-generated question, tied to the graph node it was built from."""

    # Identifier of the source node in the document graph.
    node_id: str
    # Kind of source node: 'section' or 'paragraph'.
    node_type: str
    # Text of the source node the question was generated from.
    node_text: str
    # The question text returned by the LLM.
    question: str
    # Intended difficulty: 'easy', 'medium' or 'hard'; 'medium' unless set.
    difficulty_level: Optional[str] = Field(default="medium")


class QuestionGenerationResult(BaseModel):
    """Container for every question produced from one document graph."""

    # Questions generated from 'section' nodes, in generation order.
    section_questions: List[GeneratedQuestion] = Field(default_factory=list)
    # Questions generated from 'paragraph' nodes, in generation order.
    paragraph_questions: List[GeneratedQuestion] = Field(default_factory=list)
    # Total count; a plain stored field (not computed), so whoever fills the
    # lists is responsible for setting it.
    total_questions: int = Field(default=0)


def build_question_generation_prompt_for_section(section_text: str, section_id: str) -> str:
    """
    Assemble the LLM prompt that asks for one comprehension question about a section.

    Args:
        section_text: Body of the section, inserted verbatim into the prompt
        section_id: Identifier of the section, shown quoted in the content header

    Returns:
        The newline-joined prompt, ending with the "Question:" cue
    """
    prompt_lines = [
        "You are an expert in educational question creation.",
        "",
        "Based on the following section content, generate ONE relevant question that checks the overall understanding of this section.",
        "",
        f'Section content "{section_id}":',
        f"{section_text}",
        "",
        "Instructions:",
        "- Create a question focusing on the main concepts of this section",
        "- The question must be clear and precise",
        "- Avoid overly general or overly specific questions",
        "- The question must be answerable based on the section content",
        "- Return ONLY the question, with no extra explanation",
        "",
        "Question:",
    ]
    return "\n".join(prompt_lines)


def build_question_generation_prompt_for_paragraph(paragraph_text: str, paragraph_id: str, section_context: str = "") -> str:
    """
    Assemble the LLM prompt that asks for one detail-level question about a paragraph.

    Args:
        paragraph_text: Body of the paragraph, inserted verbatim into the prompt
        paragraph_id: Identifier of the paragraph, shown quoted in the content header
        section_context: Text of the parent section; when truthy it is appended
            after the paragraph under a "Parent section context:" heading

    Returns:
        The newline-joined prompt, ending with the "Question:" cue
    """
    content_lines = [f"{paragraph_text}"]
    if section_context:
        content_lines += ["", "Parent section context:", f"{section_context}"]

    prompt_lines = [
        "You are an expert in educational question creation.",
        "",
        "Based on the following paragraph content, generate ONE specific question that checks the understanding of this particular paragraph.",
        "",
        f'Paragraph content "{paragraph_id}":',
        *content_lines,
        "",
        "Instructions:",
        "- Create a question that specifically focuses on this paragraph",
        "- The question must be precise and target the details of the paragraph",
        "- Avoid overly general questions",
        "- The question must be directly answerable with the paragraph content",
        "- Return ONLY the question, with no extra explanation",
        "",
        "Question:",
    ]
    return "\n".join(prompt_lines)


def ask_llm_for_question_generation(prompt: str, client: OpenAI, *,
                                    model: str = "llama3",
                                    temperature: float = 0.3,
                                    max_tokens: int = 200) -> str:
    """
    Query the LLM to generate a question.

    Args:
        prompt: The prompt to send to the LLM (sent as the user message)
        client: The configured OpenAI client
        model: Model name passed to the chat completions endpoint
        temperature: Sampling temperature; the 0.3 default keeps output
            slightly creative but coherent
        max_tokens: Completion length cap, to avoid overly long questions

    Returns:
        The question generated by the LLM, stripped of surrounding whitespace

    Raises:
        ValueError: If the response has no choices or the message content is null
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are an expert in educational question creation. You must create clear, relevant, and useful questions to assess comprehension."},
            {"role": "user", "content": prompt}
        ],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    if not response.choices:
        raise ValueError("LLM response contained no choices")

    content = response.choices[0].message.content
    # `message.content` is nullable in the chat completions API; fail with a
    # clear message instead of an AttributeError on `.strip()`.
    if content is None:
        raise ValueError("LLM response message has no content")
    return content.strip()


def extract_sections_and_paragraphs(G: nx.Graph) -> Tuple[List[Dict], List[Dict]]:
    """
    Split the graph's 'section' and 'paragraph' nodes into two lists.

    Each entry is a dict with keys 'id' (the node id), 'text' (the node's
    'text' attribute, '' when the attribute is absent) and 'type'. Nodes whose
    'type' attribute is anything else are ignored. Entries keep the order in
    which G.nodes yields them.

    Args:
        G: The NetworkX graph whose nodes are scanned

    Returns:
        A (sections, paragraphs) tuple of lists
    """
    node_items = list(G.nodes(data=True))

    sections = [
        {'id': node_id, 'text': attrs.get('text', ''), 'type': 'section'}
        for node_id, attrs in node_items
        if attrs.get('type') == 'section'
    ]
    paragraphs = [
        {'id': node_id, 'text': attrs.get('text', ''), 'type': 'paragraph'}
        for node_id, attrs in node_items
        if attrs.get('type') == 'paragraph'
    ]
    return sections, paragraphs


def get_section_context_for_paragraph(paragraph_id: str, G: nx.Graph) -> str:
    """
    Find the section context for a given paragraph.

    Two strategies are tried in order:
    1. An adjacent node of type 'section': predecessors on a directed graph,
       neighbors on an undirected one. The first match wins.
    2. An id-based fallback: for a string id of the form
       '<section_id>_para_<...>', the node '<section_id>' if it exists and is
       of type 'section'.

    Args:
        paragraph_id: The identifier of the paragraph
        G: The NetworkX graph (directed or undirected)

    Returns:
        The text of the parent section, or an empty string when no section is
        found, the paragraph is not in the graph, or the section has no text
    """
    if paragraph_id not in G:
        return ""

    # `predecessors` exists only on directed graphs; an undirected nx.Graph
    # exposes adjacency through `neighbors`.
    if G.is_directed():
        adjacent = G.predecessors(paragraph_id)
    else:
        adjacent = G.neighbors(paragraph_id)

    for node in adjacent:
        if G.nodes[node].get('type') == 'section':
            return G.nodes[node].get('text') or ""

    # Fall back to the naming convention when no adjacent section exists.
    # Graph node ids may be non-strings (e.g. ints), so guard the substring test.
    if isinstance(paragraph_id, str) and '_para_' in paragraph_id:
        section_id = paragraph_id.split('_para_')[0]
        if section_id in G and G.nodes[section_id].get('type') == 'section':
            return G.nodes[section_id].get('text') or ""

    return ""


def _preview_question(question: str, limit: int = 80) -> str:
    """Return *question* cut to *limit* characters, adding '...' only when it was cut."""
    return question if len(question) <= limit else f"{question[:limit]}..."


def _collect_questions(nodes: List[Dict], node_type: str,
                       build_prompt: Callable[[Dict], str],
                       client: OpenAI) -> List[GeneratedQuestion]:
    """
    Generate one question per node, best-effort.

    Args:
        nodes: Node dicts with 'id' and 'text' keys
        node_type: 'section' or 'paragraph'; stored on each question and used in log lines
        build_prompt: Builds the LLM prompt for one node dict
        client: The configured OpenAI client

    Returns:
        The successfully generated questions, in node order. Any node whose
        prompt building, LLM call or model validation raises is reported and skipped.
    """
    questions: List[GeneratedQuestion] = []
    for node in nodes:
        try:
            question = ask_llm_for_question_generation(build_prompt(node), client)
            generated = GeneratedQuestion(
                # Graph node ids may be any hashable (e.g. int), but the model
                # field is `str`, which pydantic v2 will not coerce from int.
                node_id=str(node['id']),
                node_type=node_type,
                node_text=node['text'],
                question=question,
            )
        except Exception as e:
            # Deliberate best-effort: one failing node must not abort the whole run.
            print(f"Error generating question for {node_type} {node['id']}: {e}")
            continue
        questions.append(generated)
        print(f"Generated question for {node_type} {node['id']}: {_preview_question(question)}")
    return questions


def generate_questions_from_graph(G: nx.Graph, client: OpenAI, 
                                 include_sections: bool = True, 
                                 include_paragraphs: bool = True) -> QuestionGenerationResult:
    """
    Generate questions from section and paragraph nodes in a document graph.

    One LLM call is made per selected node. Failures on individual nodes are
    printed and skipped, so the result may contain fewer questions than nodes.

    Args:
        G: The NetworkX graph containing the nodes
        client: The configured OpenAI client
        include_sections: If True, generate questions for sections
        include_paragraphs: If True, generate questions for paragraphs
        
    Returns:
        QuestionGenerationResult containing all generated questions, with
        total_questions set to the combined count
    """
    result = QuestionGenerationResult()

    sections, paragraphs = extract_sections_and_paragraphs(G)
    print(f"Found {len(sections)} sections and {len(paragraphs)} paragraphs")

    if include_sections:
        print("Generating questions for sections...")
        result.section_questions.extend(_collect_questions(
            sections,
            'section',
            lambda node: build_question_generation_prompt_for_section(node['text'], node['id']),
            client,
        ))

    if include_paragraphs:
        print("Generating questions for paragraphs...")
        # Paragraph prompts also carry the parent section's text, when one is found.
        result.paragraph_questions.extend(_collect_questions(
            paragraphs,
            'paragraph',
            lambda node: build_question_generation_prompt_for_paragraph(
                node['text'], node['id'], get_section_context_for_paragraph(node['id'], G)
            ),
            client,
        ))

    result.total_questions = len(result.section_questions) + len(result.paragraph_questions)
    print(f"Generation completed: {result.total_questions} questions created")

    return result


def save_questions_to_json(result: QuestionGenerationResult, output_path: str):
    """
    Write the generated questions to *output_path* as pretty-printed UTF-8 JSON.

    The document holds the dumped section and paragraph questions, the
    result's stored total_questions, and a "metadata" object with per-kind counts.
    Non-ASCII characters are written as-is (ensure_ascii=False).

    Args:
        result: The question generation result to serialize
        output_path: Destination file path (overwritten if it exists)
    """
    section_dicts = [question.model_dump() for question in result.section_questions]
    paragraph_dicts = [question.model_dump() for question in result.paragraph_questions]

    payload = {
        "section_questions": section_dicts,
        "paragraph_questions": paragraph_dicts,
        "total_questions": result.total_questions,
        "metadata": {
            "sections_count": len(section_dicts),
            "paragraphs_count": len(paragraph_dicts),
        },
    }

    with open(output_path, mode='w', encoding='utf-8') as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)

    print(f"Questions saved to {output_path}")


def load_questions_from_json(json_path: str) -> QuestionGenerationResult:
    """
    Load questions from a JSON file in the layout written by save_questions_to_json.

    Missing or null "section_questions" / "paragraph_questions" entries are
    treated as empty lists. When "total_questions" is missing or null, it is
    derived from the loaded questions instead of defaulting to 0.

    Args:
        json_path: The path to the JSON file

    Returns:
        QuestionGenerationResult loaded from the file

    Raises:
        pydantic.ValidationError: If a stored question does not match GeneratedQuestion
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # `or []` also tolerates an explicit `null` in the file.
    section_questions = [
        GeneratedQuestion.model_validate(q_data)
        for q_data in data.get("section_questions") or []
    ]
    paragraph_questions = [
        GeneratedQuestion.model_validate(q_data)
        for q_data in data.get("paragraph_questions") or []
    ]

    total_questions = data.get("total_questions")
    if total_questions is None:
        total_questions = len(section_questions) + len(paragraph_questions)

    return QuestionGenerationResult(
        section_questions=section_questions,
        paragraph_questions=paragraph_questions,
        total_questions=total_questions,
    )
