#!/usr/bin/env python3

import os
import json
import requests
from pathlib import Path
import time
from datetime import datetime
from typing import Dict, List, Any

class SectionBasedExtractor:
    def __init__(self, output_dir: str = None):
        self.api_key = os.getenv('OPENAI_API_KEY')
        self.api_url = os.getenv('OPENAI_BASE_URL')
        
        if not self.api_key:
            raise ValueError("Please set the environment variable OPENAI_API_KEY")
        
        script_dir = Path(__file__).parent
        self.output_dir = Path(output_dir) 
        self.output_dir.mkdir(exist_ok=True)
        
        prompt_file = script_dir / 'prompt' / 'section_based_extraction_prompt.txt'
        with open(prompt_file, 'r', encoding='utf-8') as f:
            self.section_prompt = f.read()
    
    def load_paper_chunks(self, paper_file: str) -> List[Dict]:
        try:
            with open(paper_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return data if isinstance(data, list) else [data]
        except Exception as e:
            print(f"Cannot load paper file {paper_file}: {e}")
            return []
    
    def merge_paper_chunks(self, chunks: List[Dict]) -> str:
        """Concatenate chunk bodies into one text with a marker line before each chunk.

        Every chunk that has a 'text' key (preferred) or a 'content' key
        contributes a "=== SECTION: <id> (CHUNK <index>) ===" header followed
        by its body. Chunks with neither key are skipped. The chunk's 'id' is
        used as the section label, falling back to "chunk_<index>". Leading and
        trailing whitespace of the merged text is stripped.
        """
        pieces = []
        for index, chunk in enumerate(chunks):
            label = chunk.get('id', f'chunk_{index}')
            if 'text' in chunk:
                body = chunk['text']
            elif 'content' in chunk:
                body = chunk['content']
            else:
                continue
            pieces.append(f"\n=== SECTION: {label} (CHUNK {index}) ===\n")
            pieces.append(body + "\n")
        return "".join(pieces).strip()
    
    def extract_paper_metadata(self, chunks: List[Dict]) -> Dict[str, str]:
        """Build a metadata record for the paper, guessing the title from the first chunk.

        The title guess is the first of the first ten lines of the opening
        chunk (its 'text', else its 'content') that is longer than 10 and
        shorter than 200 characters after stripping and contains one of a few
        keywords. If nothing matches, the title stays 'Unknown Title'.

        Returns:
            Dict with 'title', 'authors', 'chunks' (number of chunks) and
            'timestamp' (ISO string of the current local time).
        """
        title_keywords = ('learning', 'model', 'network', 'algorithm')
        info = {
            'title': 'Unknown Title',
            'authors': 'Unknown Authors',
            'chunks': len(chunks),
            'timestamp': datetime.now().isoformat()
        }

        if not chunks:
            return info

        opening = chunks[0]
        opening_text = opening.get('text', opening.get('content', ''))

        for raw_line in opening_text.split('\n')[:10]:
            candidate = raw_line.strip()
            if not 10 < len(candidate) < 200:
                continue
            lowered = candidate.lower()
            if any(keyword in lowered for keyword in title_keywords):
                info['title'] = candidate
                break

        return info
    
    def call_openai_api(self, messages: List[Dict], max_tokens: int = 3000, timeout: float = 300) -> str:
        """POST a chat-completion request to ``self.api_url`` and return the generated text.

        Failures are reported in-band rather than raised: on any problem the
        details are printed and a string starting with "API Error:" is
        returned, which callers in this class use to detect a failed call.

        Args:
            messages: Chat messages placed in the request body under "messages".
            max_tokens: Value sent as "max_tokens" in the request body.
            timeout: Seconds passed to ``requests.post`` as its timeout, so a
                stalled connection ends in an error string instead of
                blocking forever.

        Returns:
            The text of the first choice (``message.content``, or ``text`` when
            that is the field present), or an "API Error: ..." string.
        """
        try:
            payload = json.dumps({
                "model": "gpt-4o",
                "messages": messages,
                "max_tokens": max_tokens,
                "temperature": 0.1
            })

            headers = {
                'Accept': 'application/json',
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json'
            }

            # requests applies no timeout unless one is given explicitly.
            response = requests.post(self.api_url, headers=headers, data=payload, timeout=timeout)

            if response.status_code != 200:
                print(f"API call failed with status {response.status_code}: {response.text}")
                return f"API Error: HTTP {response.status_code} - {response.text}"

            try:
                response_data = response.json()
            except Exception as e:
                print(f"Failed to parse JSON response: {e}")
                print(f"Raw response: {response.text[:500]}")
                return f"API Error: Invalid JSON response - {str(e)}"

            choices = response_data.get('choices') if isinstance(response_data, dict) else None
            if not (isinstance(choices, list) and choices):
                print(f"API response missing or invalid choices: {response_data}")
                print(f"Raw response (first 500 chars): {response.text[:500]}")
                return "API Error: Missing or invalid choices array"

            choice0 = choices[0]
            if isinstance(choice0, dict):
                message = choice0.get('message')
                # A 'content' key holding a non-string (for example null) is treated
                # as malformed so the declared str return type always holds.
                if isinstance(message, dict) and isinstance(message.get('content'), str):
                    return message['content']
                if isinstance(choice0.get('text'), str):
                    return choice0['text']
            print(f"API response structure unexpected: {response_data}")
            return "API Error: Unexpected response structure in choices[0]"

        except Exception as e:
            # Deliberately broad: network, serialization and unexpected-shape
            # errors are all converted into the in-band error string.
            import traceback
            print(f"API call failed with exception: {e}")
            print(f"Exception type: {type(e).__name__}")
            print(f"Traceback: {traceback.format_exc()}")
            return f"API Error: {str(e)}"
    
    def extract_single_section(self, section_text: str, section_name: str, chunk_id: int, metadata: Dict, improvement_suggestions: str = "") -> str:
        """Build the system and user prompts for one section and send them to the chat API.

        Args:
            section_text: Body of the section, embedded in the user prompt.
            section_name: Section label used in both prompts and in the
                requested :sourceSection values.
            chunk_id: Chunk identifier used in both prompts and in the
                requested :sourceChunk values.
            metadata: Paper metadata; only metadata['title'] is read.
            improvement_suggestions: Optional evaluator feedback. When
                non-empty, a guidance paragraph is appended inside both prompts.

        Returns:
            Whatever call_openai_api returns for the two-message conversation
            (requested with max_tokens=3000).
        """
        # Optional evaluator feedback; an empty string leaves the prompts untouched.
        guidance = ""
        if improvement_suggestions:
            guidance = "\n".join([
                "",
                "",
                "IMPROVEMENT GUIDANCE (from evaluator):",
                f"{improvement_suggestions}",
                "",
                "Please incorporate these suggestions in your extraction to improve the quality of the knowledge graph.",
            ])

        system_lines = [
            "You are an expert knowledge extractor. Extract ALL entities, relationships, and facts from this academic paper section.",
            "",
            f"Section: {section_name} (Chunk {chunk_id})",
            f"Paper: {metadata['title']}",
            "",
            f"{self.section_prompt}",
            "",
            "FOCUS FOR THIS SECTION:",
            "- Extract EVERY entity mentioned (models, datasets, metrics, methods, authors, organizations)",
            "- Extract EVERY relationship and comparison",
            "- Extract ALL numerical values and scores",
            f"- Create comprehensive triples for multi-hop reasoning{guidance}",
            "",
            "TARGET: 20-50 triples from this section alone.",
        ]

        user_lines = [
            "Extract MAXIMUM TRIPLES from this section of the academic paper.",
            "",
            f"SECTION: {section_name} (CHUNK {chunk_id})",
            "",
            "SECTION CONTENT:",
            f"{section_text}",
            "",
            "EXTRACTION REQUIREMENTS:",
            "- Extract EVERY entity mentioned",
            "- Extract EVERY relationship, comparison, evaluation",
            "- Extract ALL numerical values with context",
            f'- Use :sourceChunk "{chunk_id}" and :sourceSection "{section_name}" for ALL triples{guidance}',
            "",
            "OUTPUT FORMAT:",
            "```turtle",
            f"# Entities from {section_name}",
            ":EntityName rdf:type :EntityType ;",
            f'    :sourceChunk "{chunk_id}" ;',
            f'    :sourceSection "{section_name}" ;',
            '    :contextText "original text snippet" .',
            "",
            # The two trailing spaces are part of the prompt text.
            f"# Relationships from {section_name}  ",
            ":Entity1 :relationshipType :Entity2 ;",
            f'    :sourceChunk "{chunk_id}" ;',
            f'    :sourceSection "{section_name}" ;',
            '    :contextText "original text snippet" .',
            "```",
            "",
            "Be exhaustive - extract EVERY piece of factual information from this section.",
        ]

        conversation = [
            {"role": "system", "content": "\n".join(system_lines)},
            {"role": "user", "content": "\n".join(user_lines)}
        ]

        print(f"   Extracting section: {section_name} (Chunk {chunk_id})")
        return self.call_openai_api(conversation, max_tokens=3000)

    def extract_section_knowledge(self, paper_text: str, metadata: Dict, improvement_suggestions: str = "") -> str:
        """Extract a knowledge graph section by section and join the per-section results.

        The text is split with parse_sections_from_text. If no section marker
        is found, the whole text is handed to extract_full_paper instead and
        that result is returned unchanged.

        Args:
            paper_text: Paper text, normally containing "=== SECTION: ..." markers.
            metadata: Paper metadata forwarded to the extraction calls.
            improvement_suggestions: Optional evaluator feedback forwarded to
                the extraction calls.

        Returns:
            The successful per-section outputs, each preceded by a
            "# ===== SECTION: ... =====" comment line and separated by blank
            lines. Sections whose extraction failed are left out, so the result
            may be an empty string.
        """
        print("Starting section-by-section extraction strategy...")

        if improvement_suggestions:
            print(f"Applying improvement suggestions: {improvement_suggestions}")

        sections = self.parse_sections_from_text(paper_text)

        if not sections:
            print("Cannot parse sections, falling back to full paper extraction")
            return self.extract_full_paper(paper_text, metadata, improvement_suggestions)

        print(f"Found {len(sections)} sections")

        all_knowledge_graphs = []
        last_index = len(sections) - 1

        for index, section_info in enumerate(sections):
            section_name = section_info['name']
            chunk_id = section_info['chunk_id']
            content = section_info['content']

            print(f"   Processing: {section_name} (Length: {len(content)} characters)")

            section_kg = self.extract_single_section(content, section_name, chunk_id, metadata, improvement_suggestions)

            # Every failure string produced by call_openai_api starts with
            # "API Error". Testing the prefix (not any occurrence) keeps valid
            # output that merely mentions the phrase in extracted text.
            if section_kg and not section_kg.startswith("API Error"):
                all_knowledge_graphs.append(f"# ===== SECTION: {section_name.upper()} (CHUNK {chunk_id}) =====\n{section_kg}")
                print(f"   Extraction completed: {section_name}")
            else:
                print(f"   Extraction failed: {section_name}")

            # Pause between consecutive requests only; nothing follows the last one.
            if index < last_index:
                time.sleep(1)

        combined_kg = "\n\n".join(all_knowledge_graphs)

        print(f"Section-by-section extraction completed, processed {len(sections)} sections")
        return combined_kg
    
    def parse_sections_from_text(self, paper_text: str) -> List[Dict]:
        """Split merged paper text back into its marked sections.

        A header is a line that starts with "=== SECTION:" and contains a
        "(CHUNK" marker, in the shape "=== SECTION: <name> (CHUNK <id>) ===".
        The LAST "(CHUNK" on the line is taken as the chunk marker, so a
        section name that itself contains "(CHUNK" is parsed correctly. A line
        that starts like a header but has no "(CHUNK" marker is kept as
        ordinary content of the current section.

        Args:
            paper_text: Text to split. Lines before the first header are ignored.

        Returns:
            One dict per header, in order, with 'name', 'chunk_id' (string) and
            'content' (the stripped lines up to the next header). Empty when no
            header is found.
        """
        header_prefix = '=== SECTION:'
        sections = []
        current_section = None
        current_content = []

        def close_current():
            # Store the section collected so far, if one has been opened.
            if current_section is not None:
                sections.append({
                    'name': current_section['name'],
                    'chunk_id': current_section['chunk_id'],
                    'content': '\n'.join(current_content).strip()
                })

        for line in paper_text.split('\n'):
            name_part, marker, chunk_part = '', '', ''
            if line.startswith(header_prefix) and 'CHUNK' in line:
                name_part, marker, chunk_part = line.rpartition('(CHUNK')

            if not marker:
                # Not a (well-formed) header: plain content, or ignored preamble.
                if current_section is not None:
                    current_content.append(line)
                continue

            close_current()
            current_content = []
            current_section = {
                'name': name_part[len(header_prefix):].strip(),
                'chunk_id': chunk_part.split(')')[0].strip(),
            }

        close_current()
        return sections
    
    def extract_full_paper(self, paper_text: str, metadata: Dict, improvement_suggestions: str = "") -> str:
        """Request a knowledge graph for the whole paper in a single API call.

        Only the first 40000 characters of paper_text are placed in the prompt.

        Args:
            paper_text: Full paper text.
            metadata: Paper metadata; only metadata['title'] is read.
            improvement_suggestions: Optional evaluator feedback. When
                non-empty, a guidance paragraph is appended to both prompts.

        Returns:
            Whatever call_openai_api returns (requested with max_tokens=4000).
        """
        print("Falling back to full paper extraction as a fallback...")

        guidance = ""
        if improvement_suggestions:
            guidance = "\n".join([
                "",
                "",
                "IMPROVEMENT GUIDANCE (from evaluator):",
                f"{improvement_suggestions}",
                "",
                "Please incorporate these suggestions in your extraction to improve the quality of the knowledge graph.",
            ])

        system_content = "\n".join([
            "Extract knowledge from this academic paper for multi-hop reasoning.",
            "",
            f"Paper: {metadata['title']}",
            f"{self.section_prompt}{guidance}",
        ])

        user_content = "\n".join([
            "Extract ALL entities and relationships from this paper:",
            "",
            f"{paper_text[:40000]}",
            "",
            f"Target: 100+ triples covering all sections.{guidance}",
        ])

        conversation = [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content}
        ]

        return self.call_openai_api(conversation, max_tokens=4000)
    
    def save_extraction_results(self, paper_name: str, metadata: Dict, 
                              knowledge_graph: str, raw_text: str) -> Path:
        base_name = paper_name.replace('.json', '')
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        
        turtle_file = self.output_dir / f"{base_name}_section_extraction_{timestamp}.ttl"
        with open(turtle_file, 'w', encoding='utf-8') as f:
            f.write(knowledge_graph)
        
        print(f"Results saved: {turtle_file}")
        return turtle_file
    
    def save_extraction_results_with_evaluation(self, paper_name: str, metadata: Dict, 
                                              knowledge_graph: str, raw_text: str,
                                              evaluation_passed: bool, improvement_suggestions: str) -> Path:
        base_name = paper_name.replace('.json', '')
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        
        evaluation_header = f"""# ===== EVALUATION RESULTS =====
# Evaluation Passed: {evaluation_passed}
# Can Proceed to QA Generation: {evaluation_passed}
# Timestamp: {timestamp}
"""
        
        if improvement_suggestions:
            evaluation_header += f"# Improvement Suggestions: {improvement_suggestions}\n"
        
        evaluation_header += "# ===== KNOWLEDGE GRAPH =====\n\n"
        
        turtle_file = self.output_dir / f"{base_name}_section_extraction_{timestamp}.ttl"
        with open(turtle_file, 'w', encoding='utf-8') as f:
            f.write(evaluation_header)
            f.write(knowledge_graph)
        
        evaluation_file = self.output_dir / f"{base_name}_evaluation_{timestamp}.json"
        evaluation_data = {
            "paper_name": paper_name,
            "evaluation_passed": evaluation_passed,
            "can_proceed_to_qa": evaluation_passed,
            "improvement_suggestions": improvement_suggestions,
            "timestamp": timestamp,
            "metadata": metadata
        }
        
        with open(evaluation_file, 'w', encoding='utf-8') as f:
            json.dump(evaluation_data, f, ensure_ascii=False, indent=2)
        
        print(f"Results saved: {turtle_file}")
        print(f"Evaluation information saved: {evaluation_file}")
        return turtle_file
    

    
    def process_paper(self, paper_file: str, evaluation_result: Dict[str, Any] = None) -> Dict[str, Any]:
        print(f"Processing paper: {paper_file}")
        
        chunks = self.load_paper_chunks(paper_file)
        if not chunks:
            return {"error": "Cannot load paper file", "paper_file": paper_file}
        
        metadata = self.extract_paper_metadata(chunks)
        
        full_text = self.merge_paper_chunks(chunks)
        metadata['text_length'] = len(full_text)
        
        print(f"Paper information:")
        print(f"   Title: {metadata['title']}")
        print(f"   Chunks: {metadata['chunks']}")
        print(f"   Text length: {metadata['text_length']} characters")
        
        improvement_suggestions = ""
        if evaluation_result and not evaluation_result.get("passed", True):
            improvement_suggestions = evaluation_result.get("suggestions", "")
        
        knowledge_graph = self.extract_section_knowledge(full_text, metadata, improvement_suggestions)
        
        evaluation_passed = True
        improvement_suggestions = ""
        
        if evaluation_result:
            evaluation_passed = evaluation_result.get("passed", False)
            improvement_suggestions = evaluation_result.get("suggestions", "")
            
            print(f"Evaluation results:")
            print(f"   Passed: {'Passed' if evaluation_passed else 'Failed'}")
            if improvement_suggestions:
                print(f"   Improvement suggestions: {improvement_suggestions}")
        
        paper_name = Path(paper_file).name
        saved_file = self.save_extraction_results_with_evaluation(
            paper_name, metadata, knowledge_graph, full_text, 
            evaluation_passed, improvement_suggestions
        )
        
        print(f"{paper_file} processing completed!")
        
        return {
            "success": True,
            "paper_file": paper_file,
            "metadata": metadata,
            "knowledge_graph": knowledge_graph,
            "saved_file": str(saved_file),
            "evaluation_passed": evaluation_passed,
            "improvement_suggestions": improvement_suggestions,
            "can_proceed_to_qa": evaluation_passed,
            "statistics": {
                "total_triples": knowledge_graph.count(" ;") + knowledge_graph.count(" .") - knowledge_graph.count("# "),
                "models_extracted": knowledge_graph.count("rdf:type :Model"),
                "datasets_extracted": knowledge_graph.count("rdf:type :Dataset"),
                "metrics_extracted": knowledge_graph.count("rdf:type :Metric"),
                "methods_extracted": knowledge_graph.count("rdf:type :Method"),
                "sections_covered": len(set([line.split('"')[1] for line in knowledge_graph.split('\n') if ':sourceSection' in line])),
                "chunks_covered": len(set([line.split('"')[1] for line in knowledge_graph.split('\n') if ':sourceChunk' in line]))
            }
        }
    
    def process_single_paper(self, paper_file: str, evaluation_result: Dict[str, Any] = None) -> Dict[str, Any]:
        if not Path(paper_file).exists():
            return {"error": f"File not found: {paper_file}", "paper_file": paper_file}
        
        return self.process_paper(paper_file, evaluation_result)
    
    def process_paper_from_text(self, paper_text: str, paper_name: str = "unknown", 
                               evaluation_result: Dict[str, Any] = None) -> Dict[str, Any]:
        print(f"Processing text content: {paper_name}")
        
        metadata = {
            'title': paper_name,
            'authors': 'Unknown',
            'chunks': 1,
            'timestamp': datetime.now().isoformat(),
            'text_length': len(paper_text)
        }
        
        improvement_suggestions = ""
        if evaluation_result and not evaluation_result.get("passed", True):
            improvement_suggestions = evaluation_result.get("suggestions", "")
        
        knowledge_graph = self.extract_section_knowledge(paper_text, metadata, improvement_suggestions)
        
        evaluation_passed = True
        improvement_suggestions = ""
        
        if evaluation_result:
            evaluation_passed = evaluation_result.get("passed", True)
            improvement_suggestions = evaluation_result.get("suggestions", "")
        
        saved_file = self.save_extraction_results_with_evaluation(
            f"{paper_name}.txt", metadata, knowledge_graph, paper_text,
            evaluation_passed, improvement_suggestions
        )
        
        return {
            "success": True,
            "paper_name": paper_name,
            "metadata": metadata,
            "knowledge_graph": knowledge_graph,
            "saved_file": str(saved_file),
            "evaluation_passed": evaluation_passed,
            "improvement_suggestions": improvement_suggestions,
            "can_proceed_to_qa": evaluation_passed,
            "statistics": {
                "total_triples": knowledge_graph.count(" ;") + knowledge_graph.count(" .") - knowledge_graph.count("# "),
                "models_extracted": knowledge_graph.count("rdf:type :Model"),
                "datasets_extracted": knowledge_graph.count("rdf:type :Dataset"),
                "metrics_extracted": knowledge_graph.count("rdf:type :Metric"),
                "methods_extracted": knowledge_graph.count("rdf:type :Method")
            }
        }

def main():
    """Run the extractor once against a hard-coded sample paper and print a summary.

    Initialization and processing failures are reported on stdout; the
    closing success message is printed only when processing succeeded.
    """
    print("Section-based knowledge graph extraction test program")
    print("=" * 50)

    try:
        # Pass the output directory explicitly so this demo does not depend on
        # how the constructor handles a missing output_dir.
        extractor = SectionBasedExtractor(output_dir=str(Path(__file__).parent / "results"))
    except Exception as e:
        print(f"Initialization failed: {e}")
        return

    paper_file = "data/papers_0_merged.json"

    evaluation_result = {
        "passed": True,
        "suggestions": "Please increase the number of entity relationships to improve the completeness of the graph"
    }

    result = extractor.process_single_paper(paper_file, evaluation_result)

    if not result.get("success"):
        # Stop here: announcing success after a failed run would be misleading.
        print(f"Test failed: {result.get('error')}")
        return

    statistics = result['statistics']
    print("Test successful!")
    print(f"   Triples: {statistics['total_triples']}")
    print(f"   Models: {statistics['models_extracted']}")
    print(f"   Datasets: {statistics['datasets_extracted']}")
    print(f"   Evaluation passed: {result['evaluation_passed']}")
    print(f"   Can proceed to QA generation: {result['can_proceed_to_qa']}")
    if result['improvement_suggestions']:
        print(f"   Improvement suggestions: {result['improvement_suggestions']}")
    print(f"   Saved file: {result['saved_file']}")

    print(f"\nProgram executed successfully! Results saved in: {extractor.output_dir}")


def example_workflow_usage():
    """Show both entry points of the extractor: from a chunk file and from raw text.

    The results of the two calls are not inspected; the function only prints
    progress messages around them.
    """
    print("\nWorkflow integration example:")
    print("-" * 30)

    workflow_extractor = SectionBasedExtractor(output_dir="workflow_results")

    # Entry point 1: a chunked paper stored as a JSON file.
    workflow_extractor.process_single_paper("data/paper.json")

    # Entry point 2: paper text that is already in memory.
    sample_text = """
    This is a sample paper about machine learning.
    The paper discusses various algorithms and their performance.
    """
    workflow_extractor.process_paper_from_text(sample_text, "sample_paper")

    print("Workflow integration examples completed!")

# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
