#!/usr/bin/env python3
"""
RAG Final Answer Generation - Compare answer quality using different claim sets
Analyzes how including uncertain but relevant claims affects answer accuracy
"""

import os
import json
import time
from typing import Dict, List, Any, Tuple
from dotenv import load_dotenv
from openai import OpenAI
from tqdm import tqdm

# Load environment variables from a local .env file (python-dotenv) so that
# OPENAI_API_KEY can be read via os.getenv when the OpenAI client is created.
load_dotenv()

def convert_numpy_types(obj):
    """Convert numpy types to Python native types for JSON serialization.

    Recursively walks dicts (keys and values), lists and tuples, turning
    numpy booleans, integers and floats into ``bool``/``int``/``float`` and
    ndarrays into (nested) lists. Tuples are returned as lists, which is how
    ``json`` serializes them anyway. Any other object is returned unchanged.

    numpy is optional: if it is not installed no numpy objects can exist, so
    the containers are simply copied through.
    """
    # Import once per top-level call (not on every recursive step); a failed
    # import is not cached by Python, so retrying it per element would be slow.
    try:
        import numpy as np
    except ImportError:
        np = None

    def _convert(value):
        if np is not None:
            if isinstance(value, np.bool_):
                return bool(value)
            if isinstance(value, np.integer):
                return int(value)
            if isinstance(value, np.floating):
                return float(value)
            if isinstance(value, np.ndarray):
                return value.tolist()
        if isinstance(value, dict):
            # Keys are converted too: json.dump rejects numpy scalar keys.
            return {_convert(key): _convert(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [_convert(item) for item in value]
        return value

    return _convert(obj)

class RAGFinalAnswerGenerator:
    """Generate and evaluate final answers using different claim combinations.

    For every question in ``rag_simulation_results.json`` two answers are
    generated with the OpenAI chat API:

    1. from the confident claims only (original wording), and
    2. from the confident claims plus the uncertain claims flagged
       ``is_included`` (final wording).

    Each answer is graded YES/NO against the reference answer taken from
    ``final_questions.json`` and the results are written to
    ``rag_final_answer_*.json`` in the current working directory.
    """

    # Sentinel returned by _generate_answer_from_claims when every retry fails.
    _GENERATION_ERROR = "Error generating answer"

    def __init__(self):
        """Create the OpenAI client and load both input JSON files."""
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

        # Load data
        self.rag_results = self._load_rag_results()
        self.final_answers = self._load_final_answers()

        print("Initialized RAGFinalAnswerGenerator")
        print(f"Loaded {len(self.rag_results)} questions from RAG results")
        print(f"Loaded {len(self.final_answers)} final answers")

    def _load_rag_results(self) -> List[Dict]:
        """Load RAG simulation results (the ``results`` list); [] if the file is missing."""
        try:
            with open("rag_simulation_results.json", 'r', encoding='utf-8') as f:
                data = json.load(f)
                return data.get('results', [])
        except FileNotFoundError:
            print("❌ rag_simulation_results.json not found!")
            return []

    def _load_final_answers(self) -> List[Dict]:
        """Load the final questions / reference answers; [] if the file is missing."""
        try:
            with open("final_questions.json", 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            print("❌ final_questions.json not found!")
            return []

    def _get_reference_answer(self, question: str) -> str:
        """Return the reference answer of the first item whose ``open_question``
        equals *question*, or "" when there is no match."""
        for item in self.final_answers:
            if item.get('open_question', '') == question:
                return item.get('reference_answer', '')
        return ""

    def _get_original_claim(self, claim_data: Dict) -> str:
        """Return the pre-update claim text (``original_claim`` if ``was_updated``, else ``claim``)."""
        if claim_data.get('was_updated', False):
            return claim_data.get('original_claim', '')
        else:
            return claim_data.get('claim', '')

    def _get_final_claim(self, claim_data: Dict) -> str:
        """Return the final (possibly updated) claim text."""
        return claim_data.get('claim', '')

    def _generate_answer_from_claims(self, question: str, claims: List[str]) -> str:
        """Generate an answer to *question* using only *claims*.

        Retries the API call up to 3 times (2 s apart). Returns the stripped
        model output, or ``_GENERATION_ERROR`` if every attempt fails.
        """
        claims_text = "\n".join([f"- {claim}" for claim in claims])

        prompt = f"""You are an expert answering a complex question based on provided factual claims.

QUESTION: {question}

AVAILABLE CLAIMS:
{claims_text}

TASK: Generate a comprehensive and accurate answer to the question using only the information provided in the claims above.

REQUIREMENTS:
- Use only the factual information from the provided claims
- Synthesize the claims into a coherent, well-structured answer
- If claims conflict, prioritize the most specific and detailed information
- If the claims don't fully address the question, acknowledge the limitations
- Do not add information beyond what's provided in the claims

Generate a clear, comprehensive answer:"""

        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=800,
                    temperature=0.3
                )

                return response.choices[0].message.content.strip()

            except Exception as e:
                print(f"Error generating answer (attempt {attempt + 1}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(2)

        return self._GENERATION_ERROR

    @staticmethod
    def _parse_yes_no(text: str) -> bool:
        """Return True when the grader's reply starts with the word YES.

        Tolerates case, leading/trailing whitespace, punctuation and markdown
        emphasis (e.g. "Yes.", "**YES**") that an exact ``== "YES"`` check
        would misgrade as NO. Words merely starting with YES (e.g.
        "YESTERDAY") do not count.
        """
        import re
        match = re.match(r'\W*(YES|NO)\b', text.upper())
        return bool(match) and match.group(1) == 'YES'

    def _evaluate_answer_accuracy(self, question: str, generated_answer: str, reference_answer: str) -> Dict[str, Any]:
        """Ask the model whether *generated_answer* is semantically correct.

        Returns ``{'is_semantically_correct': bool, 'raw_response': str}``
        where ``raw_response`` is the upper-cased reply, or ``'ERROR'`` after
        3 failed attempts (graded as incorrect).
        """
        prompt = f"""You are evaluating the semantic correctness of a generated answer against a reference answer.

QUESTION: {question}

GENERATED ANSWER: {generated_answer}

REFERENCE ANSWER: {reference_answer}

TASK: Determine if the generated answer is semantically correct compared to the reference answer.

EVALUATION CRITERIA:
- YES if the generated answer conveys the same key information and conclusions as the reference
- YES if the generated answer reaches similar conclusions even if using different words or structure  
- NO if the generated answer contradicts the reference answer
- NO if the generated answer misses critical information that changes the conclusion
- NO if the generated answer contains significant factual errors

Consider semantic equivalence, not exact word matching.

Respond with a single word: YES or NO"""

        max_retries = 3
        for attempt in range(max_retries):
            try:
                response = self.client.chat.completions.create(
                    model="gpt-4o",
                    messages=[{"role": "user", "content": prompt}],
                    max_tokens=10,
                    temperature=0.1
                )

                result = response.choices[0].message.content.strip().upper()
                is_correct = self._parse_yes_no(result)

                return {
                    'is_semantically_correct': is_correct,
                    'raw_response': result
                }

            except Exception as e:
                print(f"Error evaluating answer accuracy (attempt {attempt + 1}): {e}")
                if attempt < max_retries - 1:
                    time.sleep(2)

        # Fallback
        return {'is_semantically_correct': False, 'raw_response': 'ERROR'}

    def _answer_and_grade(self, question: str, claims: List[str], reference_answer: str,
                          label: str) -> Tuple[str, Dict[str, Any]]:
        """Generate one answer from *claims* and grade it; *label* names the variant in logs."""
        answer = self._generate_answer_from_claims(question, claims)
        time.sleep(0.5)

        if answer == self._GENERATION_ERROR:
            # Grading the error sentinel would waste an API call and could
            # produce a spurious verdict; count it as incorrect directly.
            print(f"  Skipping evaluation of {label} answer: generation failed")
            return answer, {'is_semantically_correct': False, 'raw_response': 'GENERATION_ERROR'}

        print(f"  Evaluating {label} answer...")
        evaluation = self._evaluate_answer_accuracy(question, answer, reference_answer)
        time.sleep(0.5)
        return answer, evaluation

    @staticmethod
    def _skipped_result(question: str) -> Dict[str, Any]:
        """Build a result with the full per-question schema, marked ``skipped``."""
        return {
            'question': question,
            'reference_answer': '',
            'skipped': True,
            'confident_claims_count': 0,
            'uncertain_included_claims_count': 0,
            'confident_only_result': {
                'generated_answer': '',
                'is_semantically_correct': False,
                'claims_used': []
            },
            'confident_plus_included_result': {
                'generated_answer': '',
                'is_semantically_correct': False,
                'confident_claims_used': [],
                'uncertain_claims_used': []
            }
        }

    def _process_single_question(self, question_result: Dict) -> Dict[str, Any]:
        """Generate and grade both answer variants for one RAG result.

        Args:
            question_result: One RAG result entry; reads ``question`` and an
                optional ``claims`` list of claim dicts (fields used:
                ``is_confident``, ``is_included``, ``was_updated``,
                ``original_claim``, ``claim``).

        Returns:
            A dict with ``question``, ``reference_answer``, the two claim
            counts, and ``confident_only_result`` /
            ``confident_plus_included_result``. When no reference answer is
            found, the same schema is returned with empty answers,
            ``is_semantically_correct=False`` and ``skipped=True``, and no
            API call is made.
        """
        question = question_result['question']
        claims = question_result.get('claims', [])
        reference_answer = self._get_reference_answer(question)

        if not reference_answer:
            print("⚠️  No reference answer found for question")
            return self._skipped_result(question)

        # Extract claim sets. A claim without an is_confident key falls in
        # neither set (default False for the first filter, True for the second).
        confident_claims = [c for c in claims if c.get('is_confident', False)]
        uncertain_included_claims = [c for c in claims if not c.get('is_confident', True) and c.get('is_included', False)]

        print(f"  Found {len(confident_claims)} confident claims, {len(uncertain_included_claims)} uncertain but included claims")

        # Set 1: Only confident claims (original wording); empty texts are dropped
        confident_original_claims = [text for text in map(self._get_original_claim, confident_claims) if text]

        print(f"  Generating answer from {len(confident_original_claims)} confident claims...")
        confident_only_answer, confident_only_eval = self._answer_and_grade(
            question, confident_original_claims, reference_answer, "confident-only")

        # Set 2: Confident claims (original) + uncertain included claims (final)
        uncertain_final_claims = [text for text in map(self._get_final_claim, uncertain_included_claims) if text]
        combined_claims = confident_original_claims + uncertain_final_claims

        print(f"  Generating answer from {len(combined_claims)} total claims ({len(confident_original_claims)} confident + {len(uncertain_final_claims)} uncertain included)...")
        combined_answer, combined_eval = self._answer_and_grade(
            question, combined_claims, reference_answer, "combined")

        return {
            'question': question,
            'reference_answer': reference_answer,
            'confident_claims_count': len(confident_original_claims),
            'uncertain_included_claims_count': len(uncertain_final_claims),
            'confident_only_result': {
                'generated_answer': confident_only_answer,
                'is_semantically_correct': confident_only_eval['is_semantically_correct'],
                'claims_used': confident_original_claims
            },
            'confident_plus_included_result': {
                'generated_answer': combined_answer,
                'is_semantically_correct': combined_eval['is_semantically_correct'],
                'confident_claims_used': confident_original_claims,
                'uncertain_claims_used': uncertain_final_claims
            }
        }

    def run_answer_generation(self) -> None:
        """Run the complete answer generation analysis.

        Processes every loaded RAG result, tallies how often each answer
        variant is graded correct, and writes the confident-only, combined
        and summary JSON files. Questions without a reference answer are
        reported and excluded from the saved results and statistics; a
        question that raises is logged and skipped.
        """
        if not self.rag_results:
            print("❌ No RAG results available!")
            return

        print("\n📝 Starting RAG Final Answer Generation")
        print(f"Processing {len(self.rag_results)} questions...")

        all_results = []
        skipped_count = 0
        confident_only_correct = 0
        combined_correct = 0

        for i, question_result in enumerate(tqdm(self.rag_results, desc="Generating answers")):
            try:
                print(f"\n--- Processing Question {i+1}: {question_result['question'][:50]}... ---")

                result = self._process_single_question(question_result)
                if result.get('skipped'):
                    # Nothing to grade against; counting it as wrong would
                    # understate both accuracies equally and hide the gap.
                    skipped_count += 1
                    print(f"Question {i+1}: skipped (no reference answer)")
                    continue

                conf_ok = result['confident_only_result']['is_semantically_correct']
                comb_ok = result['confident_plus_included_result']['is_semantically_correct']
                # Record only after the fields were read successfully, so a
                # malformed result can never reach (and crash) the save step.
                all_results.append(result)

                # Update statistics
                if conf_ok:
                    confident_only_correct += 1
                if comb_ok:
                    combined_correct += 1

                conf_status = "✅" if conf_ok else "❌"
                comb_status = "✅" if comb_ok else "❌"
                print(f"Question {i+1}: Confident-only {conf_status}, Combined {comb_status}")

            except Exception as e:
                print(f"Error processing question {i+1}: {e}")

        # Save results
        self._save_confident_only_results(all_results)
        self._save_combined_results(all_results)
        self._save_summary(all_results, confident_only_correct, combined_correct, skipped_count=skipped_count)

        print("\n🎉 Answer Generation Complete!")
        print("Results saved to rag_final_answer_*.json files")

    @staticmethod
    def _write_json(output_file: str, data: Any) -> None:
        """Write *data* (numpy values converted) to *output_file* as indented UTF-8 JSON."""
        clean_data = convert_numpy_types(data)
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(clean_data, f, ensure_ascii=False, indent=4)

    def _save_confident_only_results(self, results: List[Dict]) -> None:
        """Save confident-only claims answer results."""
        output_file = "rag_final_answer_confident_only.json"

        data = {
            'methodology': 'Generate answers using only confident claims (original versions)',
            'claim_selection': 'is_confident=true, using original_claim if was_updated=true else claim',
            'results': [{
                'question': r['question'],
                'reference_answer': r['reference_answer'],
                'claims_used': r['confident_only_result']['claims_used'],
                'generated_answer': r['confident_only_result']['generated_answer'],
                'is_semantically_correct': r['confident_only_result']['is_semantically_correct'],
                'claims_count': r['confident_claims_count']
            } for r in results]
        }

        self._write_json(output_file, data)
        print(f"Confident-only results saved to {output_file}")

    def _save_combined_results(self, results: List[Dict]) -> None:
        """Save combined claims answer results."""
        output_file = "rag_final_answer_combined.json"

        data = {
            'methodology': 'Generate answers using confident claims (original) + uncertain but included claims (final)',
            'claim_selection': 'confident (original) + uncertain with is_included=true (final versions)',
            'results': [{
                'question': r['question'],
                'reference_answer': r['reference_answer'],
                'confident_claims_used': r['confident_plus_included_result']['confident_claims_used'],
                'uncertain_claims_used': r['confident_plus_included_result']['uncertain_claims_used'],
                'generated_answer': r['confident_plus_included_result']['generated_answer'],
                'is_semantically_correct': r['confident_plus_included_result']['is_semantically_correct'],
                'confident_claims_count': r['confident_claims_count'],
                'uncertain_included_claims_count': r['uncertain_included_claims_count']
            } for r in results]
        }

        self._write_json(output_file, data)
        print(f"Combined results saved to {output_file}")

    @staticmethod
    def _build_summary(results: List[Dict], confident_only_correct: int, combined_correct: int,
                       skipped_count: int = 0) -> Dict[str, Any]:
        """Compute the summary comparison dict from per-question results (no I/O)."""
        total_questions = len(results)

        # Calculate rates
        confident_only_accuracy = confident_only_correct / max(total_questions, 1)
        combined_accuracy = combined_correct / max(total_questions, 1)
        accuracy_improvement = combined_accuracy - confident_only_accuracy
        # A relative change against a 0% baseline is undefined; report null
        # (JSON) instead of dividing by an epsilon and reporting e.g. 1000x.
        relative_improvement = (accuracy_improvement / confident_only_accuracy
                                if confident_only_accuracy > 0 else None)

        # Per-question analysis
        improvements = []
        degradations = []
        unchanged = []

        for r in results:
            conf_correct = r['confident_only_result']['is_semantically_correct']
            comb_correct = r['confident_plus_included_result']['is_semantically_correct']

            if not conf_correct and comb_correct:
                improvements.append(r['question'])
            elif conf_correct and not comb_correct:
                degradations.append(r['question'])
            else:
                unchanged.append(r['question'])

        # Calculate average claims used
        avg_confident_claims = sum(r['confident_claims_count'] for r in results) / len(results) if results else 0
        avg_uncertain_claims = sum(r['uncertain_included_claims_count'] for r in results) / len(results) if results else 0

        # Describe the direction of the change correctly (a negative delta
        # previously read "did not improve accuracy by -x%").
        if accuracy_improvement > 0:
            improvement_meaning = f"Including uncertain but relevant claims improved accuracy by {accuracy_improvement:.2%}"
        elif accuracy_improvement < 0:
            improvement_meaning = f"Including uncertain but relevant claims reduced accuracy by {-accuracy_improvement:.2%}"
        else:
            improvement_meaning = "Including uncertain but relevant claims did not change accuracy"

        return {
            'overall_statistics': {
                'total_questions': total_questions,
                'skipped_count': skipped_count,
                'confident_only_correct': confident_only_correct,
                'combined_correct': combined_correct,
                'confident_only_accuracy': confident_only_accuracy,
                'combined_accuracy': combined_accuracy,
                'accuracy_improvement': accuracy_improvement,
                'relative_improvement': relative_improvement,
                'improvement_count': len(improvements),
                'degradation_count': len(degradations),
                'unchanged_count': len(unchanged)
            },
            'claim_usage_statistics': {
                'avg_confident_claims_per_question': avg_confident_claims,
                'avg_uncertain_included_claims_per_question': avg_uncertain_claims,
                'avg_total_claims_in_combined': avg_confident_claims + avg_uncertain_claims
            },
            'detailed_analysis': {
                'questions_improved': improvements[:5],  # Show first 5 examples
                'questions_degraded': degradations[:5],  # Show first 5 examples
                'improvement_rate': len(improvements) / total_questions if total_questions > 0 else 0,
                'degradation_rate': len(degradations) / total_questions if total_questions > 0 else 0
            },
            'interpretation': {
                'methodology': 'Compare answer quality using confident claims only vs confident + uncertain included claims',
                'evaluation_criteria': 'Semantic correctness against reference answers',
                'improvement_meaning': improvement_meaning
            }
        }

    def _save_summary(self, results: List[Dict], confident_only_correct: int, combined_correct: int,
                      skipped_count: int = 0) -> None:
        """Save the summary comparison to JSON and print the key statistics.

        ``skipped_count`` (optional) is the number of questions excluded for
        lack of a reference answer; it is recorded in the summary.
        """
        output_file = "rag_final_answer_summary.json"

        summary = self._build_summary(results, confident_only_correct, combined_correct, skipped_count)
        self._write_json(output_file, summary)
        print(f"Summary saved to {output_file}")

        stats = summary['overall_statistics']
        usage = summary['claim_usage_statistics']
        total_questions = stats['total_questions']

        # Print key statistics
        print("\n📊 Final Answer Generation Summary:")
        print(f"  Total questions: {total_questions}")
        if skipped_count:
            print(f"  Skipped (no reference answer): {skipped_count}")
        print(f"  Confident-only accuracy: {confident_only_correct}/{total_questions} ({stats['confident_only_accuracy']:.2%})")
        print(f"  Combined accuracy: {combined_correct}/{total_questions} ({stats['combined_accuracy']:.2%})")
        print(f"  Accuracy improvement: {stats['accuracy_improvement']:.2%}")
        print(f"  Questions improved: {stats['improvement_count']}")
        print(f"  Questions degraded: {stats['degradation_count']}")
        print(f"  Avg confident claims per question: {usage['avg_confident_claims_per_question']:.1f}")
        print(f"  Avg uncertain included claims per question: {usage['avg_uncertain_included_claims_per_question']:.1f}")

def main():
    """Script entry point: build the generator and run the full comparison.

    A Ctrl-C is reported as a user interruption; any other exception is
    reported together with its traceback instead of propagating.
    """
    print("📝 RAG Final Answer Generation Analysis")
    print("=" * 50)

    try:
        RAGFinalAnswerGenerator().run_answer_generation()
    except KeyboardInterrupt:
        print("\n\n⚠️  Answer generation interrupted by user")
    except Exception as exc:
        import traceback
        print(f"\n❌ Answer generation failed: {exc}")
        traceback.print_exc()

if __name__ == "__main__":
    # Start the analysis only when run as a script, never on import.
    main()