import sys
import os
sys.path.append('src')
from typing import List, Dict, TypedDict
from typing_extensions import TypedDict
import json
from datetime import datetime

# Optional LangGraph/LangChain stack. LANGGRAPH_AVAILABLE records whether the
# whole group imported; LangGraphRAGSystem.__init__ checks it before doing any
# work.
try:
    from langgraph.graph import END, StateGraph, START
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_openai import ChatOpenAI
    from langchain_anthropic import ChatAnthropic
    from langchain_core.output_parsers import StrOutputParser
    from pydantic import BaseModel, Field
    LANGGRAPH_AVAILABLE = True
    print("LangGraph available")
except ImportError:
    print("LangGraph not available. Install with: pip install langgraph langchain-openai langchain-anthropic")
    LANGGRAPH_AVAILABLE = False

    # pydantic is the last import in the group, so whenever this branch runs
    # BaseModel and Field are unbound. Without stand-ins the class statement
    # for MedicalResponse raises NameError at import time and the
    # LANGGRAPH_AVAILABLE guard can never be reached.
    class BaseModel:
        """Inert placeholder so model classes can still be declared."""

    def Field(*args, **kwargs):
        """Placeholder for pydantic.Field; ignores its arguments."""
        return None

from rag_llm_integration import ClinicalRAGSystem, LLMClinicalResponder

class MedicalResponse(BaseModel):
    # Structured-output schema handed to `llm.with_structured_output(...)` by
    # the question chain and the reasoning chain.
    # NOTE(review): documented with comments instead of a class docstring
    # because a docstring may end up in the schema description sent to the
    # model - confirm before converting.

    # Read by generate_targeted_questions and perform_clinical_reasoning.
    clinical_reasoning: str = Field(description="Clinical reasoning and differential diagnosis")
    # Read by generate_targeted_questions; becomes state["questions_generated"].
    questions_to_ask: List[str] = Field(description="Specific questions to gather missing information")
    # Not read by any code in this file.
    temporal_analysis: str = Field(description="Analysis of temporal relationships and causality")
    # Becomes state["missing_info_identified"].
    missing_information: List[str] = Field(description="Critical missing information identified")
    # Mapped to a float by _convert_confidence_to_score.
    confidence_level: str = Field(description="Confidence level: high, medium, low")
    # Not read by any code in this file.
    follow_up_needed: bool = Field(description="Whether follow-up is needed")
    # Not read by any code in this file.
    recommended_actions: List[str] = Field(description="Recommended clinical actions")

class MedicalRAGState(TypedDict):
    """State dict passed between the nodes of the LangGraph workflow.

    analyze_patient_with_rag builds the initial value; every node returns a
    copy of the state with some keys replaced.
    """

    # Patient record under analysis; set once in the initial state.
    patient_case: Dict
    # Initialised to [] and not touched by any node in this file.
    messages: List
    # Text produced by _format_rag_context in retrieve_medical_knowledge.
    rag_context: str
    # `content` of each retrieved document.
    retrieved_docs: List[str]
    # Latest reasoning text; overwritten by the question and reasoning nodes.
    current_reasoning: str
    # Questions from the most recent generate_targeted_questions call.
    questions_generated: List[str]
    # question -> answer, filled by collect_clinical_responses.
    responses_collected: Dict[str, str]
    # Gaps reported by the question and reasoning nodes.
    missing_info_identified: List[str]
    # Set by perform_clinical_reasoning; compared to 0.75 in decide_next_action.
    confidence_score: float
    # Incremented by assess_information_completeness.
    iterations: int
    # Written by generate_final_assessment.
    final_assessment: str
    # Initialised to ""; no node in this file writes it.
    error: str

class LangGraphRAGSystem:
    
    def __init__(self, mimic_cases: List[Dict]):
        """Set up the retrieval index, the LLM, the prompt chains and the graph.

        Args:
            mimic_cases: Patient case dicts; kept on the instance and passed
                to ``ClinicalRAGSystem.add_mimic_cases_to_rag``.

        Raises:
            ImportError: If the LangGraph stack was not importable.
        """
        if not LANGGRAPH_AVAILABLE:
            raise ImportError("LangGraph required. Install with: pip install langgraph")

        self.mimic_cases = mimic_cases

        print("Initializing RAG system with MIMIC data...")
        knowledge_base = ClinicalRAGSystem()
        self.rag_system = knowledge_base
        knowledge_base.add_mimic_cases_to_rag(mimic_cases)

        # The chain factories read self.llm, so the model must exist first.
        self.llm = self._initialize_medical_llm()

        chain_factories = (
            ("rag_chain", self._create_rag_chain),
            ("question_chain", self._create_question_chain),
            ("reasoning_chain", self._create_reasoning_chain),
            ("assessment_chain", self._create_assessment_chain),
        )
        for attribute, factory in chain_factories:
            setattr(self, attribute, factory())

        self.workflow = self._build_langgraph_workflow()

        print("LangGraph RAG System initialized successfully!")
    
    def _initialize_medical_llm(self):
        """Return the first chat model that can be constructed.

        Candidates are tried in order: Claude 3 Sonnet, GPT-4, GPT-3.5 Turbo,
        all with temperature 0.1 and a 3000-token cap. Each failure is
        printed; if the last candidate also fails its exception is raised.
        """
        print("Initializing medical LLM...")

        common = {"temperature": 0.1, "max_tokens": 3000}
        # (lazy constructor, success message, failure-message prefix)
        candidates = (
            (
                lambda: ChatAnthropic(model="claude-3-sonnet-20240229", **common),
                "Using Claude 3 Sonnet for medical reasoning",
                "Claude not available",
            ),
            (
                lambda: ChatOpenAI(model="gpt-4", **common),
                "Using GPT-4 for medical reasoning",
                "GPT-4 not available",
            ),
            (
                lambda: ChatOpenAI(model="gpt-3.5-turbo", **common),
                "Using GPT-3.5 Turbo for medical reasoning",
                "No LLM available",
            ),
        )

        last_error = None
        for build, success_note, failure_label in candidates:
            try:
                client = build()
                print(success_note)
                return client
            except Exception as exc:
                print(f"{failure_label}: {exc}")
                # `exc` is unbound once the except block ends; keep a handle.
                last_error = exc

        raise last_error
    
    def _create_rag_chain(self):
        """Build the free-text analysis chain: prompt -> LLM -> plain string.

        The prompt expects ``retrieved_context``, ``similar_cases``,
        ``patient_case`` and ``query``.
        """
        system_text = """You are an expert medical AI assistant with access to a comprehensive medical knowledge base.

Retrieved Medical Knowledge:
{retrieved_context}

Similar Patient Cases:
{similar_cases}

Your task is to provide medical analysis using the retrieved information. Always:
1. Base your reasoning on the retrieved medical knowledge
2. Reference similar cases when relevant
3. Acknowledge limitations in the available information
4. Provide evidence-based medical reasoning
5. Use appropriate medical terminology"""

        human_text = """Patient Case:
{patient_case}

Question: {query}

Provide medical analysis based on the retrieved knowledge and similar cases."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("human", human_text),
        ])

        model_stage = prompt | self.llm
        return model_stage | StrOutputParser()
    
    def _create_question_chain(self):
        """Build the question-generation chain with ``MedicalResponse`` output.

        The prompt expects ``medical_knowledge``, ``similar_cases``,
        ``diagnosis``, ``symptoms`` and ``known_info``.
        """
        system_text = """You are a clinical expert generating targeted medical questions.

Medical Knowledge Base:
{medical_knowledge}

Similar Cases:
{similar_cases}

Guidelines:
1. Generate specific, actionable medical questions
2. Focus on missing information that impacts diagnosis/treatment
3. Consider temporal relationships and causality
4. Ask about relevant family history, medications, lab values
5. Prioritize questions based on clinical significance
6. Use proper medical terminology

Generate 2-3 high-priority questions."""

        human_text = """Patient Case:
Diagnosis: {diagnosis}
Symptoms: {symptoms}
Known Information: {known_info}

Based on the medical knowledge and similar cases, what specific questions should be asked?"""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("human", human_text),
        ])

        structured_model = self.llm.with_structured_output(MedicalResponse)
        return prompt | structured_model
    
    def _create_reasoning_chain(self):
        """Build the clinical-reasoning chain with ``MedicalResponse`` output.

        The prompt expects ``medical_knowledge``, ``similar_cases``,
        ``patient_responses``, ``patient_case``, ``questions`` and
        ``responses``.
        """
        system_text = """You are performing comprehensive clinical reasoning using retrieved medical knowledge.

Medical Knowledge Base:
{medical_knowledge}

Similar Cases:
{similar_cases}

Patient Responses:
{patient_responses}

Perform systematic clinical reasoning:
1. Analyze symptoms and timeline using retrieved knowledge
2. Consider differential diagnoses based on similar cases
3. Evaluate risk factors and contributing factors
4. Assess temporal relationships and causality
5. Identify knowledge gaps and missing information
6. Provide confidence assessment based on available evidence

Be thorough, evidence-based, and acknowledge uncertainties."""

        human_text = """Patient Case:
{patient_case}

Questions Asked: {questions}
Responses Received: {responses}

Provide comprehensive clinical reasoning using the medical knowledge base."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("human", human_text),
        ])

        structured_model = self.llm.with_structured_output(MedicalResponse)
        return prompt | structured_model
    
    def _create_assessment_chain(self):
        """Build the final-assessment chain: prompt -> LLM -> plain string.

        The prompt expects ``medical_knowledge``, ``analysis_summary``,
        ``patient_case``, ``qa_summary`` and ``reasoning``.
        """
        system_text = """You are providing a final clinical assessment based on comprehensive analysis.

Medical Knowledge Base:
{medical_knowledge}

Analysis Summary:
{analysis_summary}

Provide a final assessment including:
1. Clinical summary and key findings
2. Most likely diagnosis with supporting evidence
3. Differential diagnoses to consider
4. Critical missing information
5. Recommended next steps and actions
6. Confidence level and reasoning

Be definitive where evidence supports conclusions, acknowledge uncertainty where appropriate."""

        human_text = """Complete Analysis:

Patient Case: {patient_case}
Questions & Responses: {qa_summary}
Clinical Reasoning: {reasoning}

Provide final clinical assessment and recommendations."""

        prompt = ChatPromptTemplate.from_messages([
            ("system", system_text),
            ("human", human_text),
        ])

        model_stage = prompt | self.llm
        return model_stage | StrOutputParser()
    
    def _build_langgraph_workflow(self):
        """Assemble and compile the state graph.

        Five nodes run in a fixed sequence; after the completeness check
        ``decide_next_action`` either loops back to question generation or
        moves on to the final assessment, which ends the graph.
        """
        graph = StateGraph(MedicalRAGState)

        # Each node name is also the name of the bound method implementing it.
        linear_stages = (
            "retrieve_medical_knowledge",
            "generate_targeted_questions",
            "collect_clinical_responses",
            "perform_clinical_reasoning",
            "assess_information_completeness",
        )
        closing_stage = "generate_final_assessment"

        for stage in linear_stages + (closing_stage,):
            graph.add_node(stage, getattr(self, stage))

        graph.add_edge(START, linear_stages[0])
        for upstream, downstream in zip(linear_stages, linear_stages[1:]):
            graph.add_edge(upstream, downstream)

        routes = {
            "continue_questioning": "generate_targeted_questions",
            "finalize_assessment": closing_stage,
            "max_iterations": closing_stage,
        }
        graph.add_conditional_edges(
            "assess_information_completeness",
            self.decide_next_action,
            routes,
        )

        graph.add_edge(closing_stage, END)

        return graph.compile()
    
    def retrieve_medical_knowledge(self, state: MedicalRAGState):
        
        print("Retrieving medical knowledge from MIMIC database...")
        
        patient_case = state["patient_case"]
        
        query_concepts = []
        if patient_case.get('symptoms'):
            query_concepts.extend(patient_case['symptoms'].split(', '))
        if patient_case.get('admission_info', {}).get('diagnosis'):
            query_concepts.append(patient_case['admission_info']['diagnosis'])
        if patient_case.get('procedures'):
            query_concepts.extend(patient_case['procedures'][:3])
        
        relevant_docs = self.rag_system.retrieve_relevant_docs(query_concepts, top_k=5)
        
        rag_context = self._format_rag_context(patient_case, relevant_docs)
        retrieved_docs = [doc.content for doc in relevant_docs]
        
        print(f"   Retrieved {len(relevant_docs)} relevant medical documents")
        
        return {
            **state,
            "rag_context": rag_context,
            "retrieved_docs": retrieved_docs
        }
    
    def generate_targeted_questions(self, state: MedicalRAGState):
        
        print("Generating targeted medical questions...")
        
        patient_case = state["patient_case"]
        rag_context = state["rag_context"]
        retrieved_docs = state["retrieved_docs"]
        
        try:
            response = self.question_chain.invoke({
                "medical_knowledge": rag_context,
                "similar_cases": "\n".join(retrieved_docs[:3]),
                "diagnosis": patient_case.get('admission_info', {}).get('diagnosis', 'Unknown'),
                "symptoms": patient_case.get('symptoms', 'Clinical presentation'),
                "known_info": str(patient_case.get('known_info', []))
            })
            
            questions = response.questions_to_ask
            missing_info = response.missing_information
            
            print(f"   Generated {len(questions)} targeted questions:")
            for i, q in enumerate(questions, 1):
                print(f"   {i}. {q}")
            
            return {
                **state,
                "questions_generated": questions,
                "missing_info_identified": missing_info,
                "current_reasoning": response.clinical_reasoning
            }
            
        except Exception as e:
            print(f"   Question generation error: {e}")
            fallback_questions = [
                "What is the timeline and progression of symptoms?",
                "Any relevant family history or past medical history?",
                "Current medications and their effectiveness?"
            ]
            
            return {
                **state,
                "questions_generated": fallback_questions,
                "missing_info_identified": ["symptom timeline", "medical history", "medications"],
                "current_reasoning": "Basic clinical assessment needed"
            }
    
    def collect_clinical_responses(self, state: MedicalRAGState):
        
        print("Collecting clinical responses...")
        
        questions = state["questions_generated"]
        responses = state.get("responses_collected", {})
        
        for question in questions:
            if question not in responses:
                print(f"\n{question}")
                
                try:
                    response = input("Clinical Response: ").strip()
                    
                    if response.lower() in ['skip', 'exit', 'quit']:
                        response = "Information not available"
                    elif not response:
                        response = "No additional information provided"
                    
                    responses[question] = response
                    print(f"   Response recorded")
                    
                except KeyboardInterrupt:
                    print("\nInput interrupted")
                    responses[question] = "Response collection interrupted"
                    break
        
        return {
            **state,
            "responses_collected": responses
        }
    
    def perform_clinical_reasoning(self, state: MedicalRAGState):
        
        print("Performing clinical reasoning with medical knowledge...")
        
        patient_case = state["patient_case"]
        rag_context = state["rag_context"]
        retrieved_docs = state["retrieved_docs"]
        questions = state["questions_generated"]
        responses = state["responses_collected"]
        
        try:
            reasoning_response = self.reasoning_chain.invoke({
                "medical_knowledge": rag_context,
                "similar_cases": "\n".join(retrieved_docs[:3]),
                "patient_responses": str(responses),
                "patient_case": str(patient_case),
                "questions": questions,
                "responses": responses
            })
            
            return {
                **state,
                "current_reasoning": reasoning_response.clinical_reasoning,
                "confidence_score": self._convert_confidence_to_score(reasoning_response.confidence_level),
                "missing_info_identified": reasoning_response.missing_information
            }
            
        except Exception as e:
            print(f"   Reasoning error: {e}")
            
            return {
                **state,
                "current_reasoning": "Clinical assessment based on available information",
                "confidence_score": 0.5,
                "missing_info_identified": ["additional clinical details"]
            }
    
    def assess_information_completeness(self, state: MedicalRAGState):
        
        print("Assessing information completeness...")
        
        confidence_score = state.get("confidence_score", 0.0)
        iterations = state.get("iterations", 0)
        missing_info = state.get("missing_info_identified", [])
        
        iterations += 1
        
        print(f"   Confidence Score: {confidence_score:.2f}")
        print(f"   Missing Information Items: {len(missing_info)}")
        print(f"   Iterations Completed: {iterations}")
        
        return {
            **state,
            "iterations": iterations
        }
    
    def decide_next_action(self, state: MedicalRAGState):
        
        confidence_score = state.get("confidence_score", 0.0)
        iterations = state.get("iterations", 0)
        missing_info = state.get("missing_info_identified", [])
        
        max_iterations = 3
        confidence_threshold = 0.75
        
        if iterations >= max_iterations:
            print("---DECISION: MAX ITERATIONS REACHED---")
            return "max_iterations"
        elif confidence_score >= confidence_threshold and len(missing_info) <= 2:
            print("---DECISION: SUFFICIENT INFORMATION FOR ASSESSMENT---")
            return "finalize_assessment"
        else:
            print("---DECISION: NEED MORE INFORMATION---")
            return "continue_questioning"
    
    def generate_final_assessment(self, state: MedicalRAGState):
        
        print("Generating final clinical assessment...")
        
        patient_case = state["patient_case"]
        rag_context = state["rag_context"]
        questions = state["questions_generated"]
        responses = state["responses_collected"]
        reasoning = state["current_reasoning"]
        
        try:
            final_assessment = self.assessment_chain.invoke({
                "medical_knowledge": rag_context,
                "analysis_summary": f"Confidence: {state.get('confidence_score', 0):.2f}, Iterations: {state.get('iterations', 0)}",
                "patient_case": str(patient_case),
                "qa_summary": str(dict(zip(questions, [responses.get(q, 'No response') for q in questions]))),
                "reasoning": reasoning
            })
            
            return {
                **state,
                "final_assessment": final_assessment
            }
            
        except Exception as e:
            print(f"   Assessment generation error: {e}")
            
            fallback_assessment = f"""
            Clinical Assessment Summary:
            
            Patient: {patient_case.get('patient_id', 'Unknown')}
            Diagnosis: {patient_case.get('admission_info', {}).get('diagnosis', 'Unknown')}
            
            Clinical Reasoning: {reasoning}
            
            Confidence Level: {state.get('confidence_score', 0):.2f}
            Information Gathered: {len(responses)} responses to clinical questions
            
            Assessment: Based on available information and clinical reasoning.
            """
            
            return {
                **state,
                "final_assessment": fallback_assessment
            }
    
    def _format_rag_context(self, patient_case: Dict, relevant_docs: List) -> str:
        
        context_parts = []
        
        diagnosis = patient_case.get('admission_info', {}).get('diagnosis', 'Unknown')
        symptoms = patient_case.get('symptoms', 'Clinical presentation')
        context_parts.append(f"Current Patient: {diagnosis} presenting with {symptoms}")
        
        if relevant_docs:
            context_parts.append("Retrieved Medical Knowledge:")
            for i, doc in enumerate(relevant_docs[:3], 1):
                context_parts.append(f"{i}. {doc.content[:300]}...")
        
        return "\n".join(context_parts)
    
    def _convert_confidence_to_score(self, confidence_level: str) -> float:
        
        confidence_mapping = {
            "high": 0.9,
            "medium": 0.6,
            "low": 0.3
        }
        
        return confidence_mapping.get(confidence_level.lower(), 0.5)
    
    def analyze_patient_with_rag(self, patient_case: Dict) -> Dict:
        
        print(f"Starting LangGraph RAG Analysis...")
        print(f"   Patient: {patient_case.get('patient_id', 'Unknown')}")
        print(f"   Diagnosis: {patient_case.get('admission_info', {}).get('diagnosis', 'Unknown')}")
        print(f"   Symptoms: {patient_case.get('symptoms', 'Unknown')}")
        
        initial_state = {
            "patient_case": patient_case,
            "messages": [],
            "rag_context": "",
            "retrieved_docs": [],
            "current_reasoning": "",
            "questions_generated": [],
            "responses_collected": {},
            "missing_info_identified": [],
            "confidence_score": 0.0,
            "iterations": 0,
            "final_assessment": "",
            "error": ""
        }
        
        try:
            result = self.workflow.invoke(initial_state)
            
            print("\nLANGGRAPH RAG ANALYSIS COMPLETE")
            print("="*50)
            print(f"Final Assessment:")
            print(f"{result['final_assessment']}")
            print(f"\nAnalysis Summary:")
            print(f"   Confidence Score: {result['confidence_score']:.2f}")
            print(f"   Iterations: {result['iterations']}")
            print(f"   Questions Asked: {len(result['questions_generated'])}")
            print(f"   Responses Collected: {len(result['responses_collected'])}")
            print(f"   RAG Documents Retrieved: {len(result['retrieved_docs'])}")
            
            return result
            
        except Exception as e:
            print(f"LangGraph RAG workflow error: {e}")
            return {
                "error": str(e),
                "final_assessment": "RAG analysis failed",
                "confidence_score": 0.0
            }

def test_langgraph_rag_system():
    """Smoke-test the full pipeline on two hard-coded sample cases.

    Builds a LangGraphRAGSystem from both samples and analyses the first one.
    The run is interactive: the workflow's response-collection node prompts
    on stdin.

    Returns:
        The constructed system on success; ``None`` when LangGraph is not
        available or anything raises.
    """
    print("Testing LangGraph RAG System")
    print("=" * 40)

    cardiac_case = {
        'patient_id': 'RAG_TEST_001',
        'admission_info': {'diagnosis': 'acute myocardial infarction'},
        'symptoms': 'chest pain, shortness of breath, diaphoresis',
        'known_info': ['chest pain', 'elevated troponin'],
        'procedures': ['ECG', 'cardiac catheterization'],
        'medications': ['aspirin', 'metoprolol', 'atorvastatin'],
    }
    respiratory_case = {
        'patient_id': 'RAG_TEST_002',
        'admission_info': {'diagnosis': 'pneumonia'},
        'symptoms': 'cough, fever, shortness of breath',
        'known_info': ['productive cough', 'fever'],
        'procedures': ['chest X-ray', 'blood culture'],
        'medications': ['antibiotics', 'bronchodilators'],
    }
    sample_cases = [cardiac_case, respiratory_case]

    if not LANGGRAPH_AVAILABLE:
        print("LangGraph not available. Please install: pip install langgraph")
        return None

    try:
        system = LangGraphRAGSystem(sample_cases)
        outcome = system.analyze_patient_with_rag(sample_cases[0])

        print("\nLangGraph RAG System test completed")
        print(f"   Result type: {type(outcome)}")
        print(f"   Has final assessment: {'final_assessment' in outcome}")

        return system

    except Exception as exc:
        print(f"Test failed: {exc}")
        return None

def integrate_with_main_system(main_system_results):
    """Attach a LangGraph RAG run to the main system's results dict.

    Builds a LangGraphRAGSystem from ``main_system_results['patient_cases']``
    and analyses the first case.

    Args:
        main_system_results: Dict-like results object; mutated in place.

    Returns:
        ``None`` when no results or no patient cases are supplied. Otherwise
        the same ``main_system_results`` object, with
        ``langgraph_rag_integrated`` set to ``True`` (plus
        ``langgraph_rag_system`` and ``langgraph_rag_result``) on success, or
        to ``False`` if construction or analysis raised.
    """
    print("Integrating LangGraph RAG with main system...")

    if not main_system_results:
        print("No main system results provided")
        return None

    patient_cases = main_system_results.get('patient_cases', [])

    if not patient_cases:
        print("No patient cases available")
        return None

    try:
        rag_system = LangGraphRAGSystem(patient_cases)

        test_patient = patient_cases[0]

        print("\nTesting LangGraph RAG with real MIMIC patient...")
        # Use .get like the rest of the file: a case without 'patient_id'
        # must not abort the integration from inside a log line.
        print(f"   Patient: {test_patient.get('patient_id', 'Unknown')}")

        rag_result = rag_system.analyze_patient_with_rag(test_patient)

        main_system_results['langgraph_rag_system'] = rag_system
        main_system_results['langgraph_rag_result'] = rag_result
        main_system_results['langgraph_rag_integrated'] = True

        print("LangGraph RAG integration complete!")
        return main_system_results

    except Exception as e:
        # Best effort: record the failure on the results and hand them back.
        print(f"Integration failed: {e}")
        main_system_results['langgraph_rag_integrated'] = False
        return main_system_results

if __name__ == "__main__":
    test_langgraph_rag_system()