import functools
import json
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np

@dataclass
class ClinicalDocument:
    """A single retrievable record in the clinical RAG knowledge base.

    Holds an identifier, the flattened text that keyword retrieval searches,
    free-form metadata, an optional embedding vector slot and a type tag.
    """

    # Identifier of the document (e.g. "mimic_case_0" for MIMIC-derived docs).
    doc_id: str
    # Flattened text searched by keyword retrieval.
    content: str
    # Arbitrary per-document details supplied by the producer.
    metadata: Dict
    # Optional vector representation; not populated by the code shown here.
    embeddings: Optional[np.ndarray] = None
    # Free-form category tag; MIMIC-derived documents use "mimic_case".
    document_type: str = "clinical_note"

class ClinicalRAGSystem:
    """Keyword-based retrieval over clinical case documents.

    Documents live in ``document_store`` (a list, so a document's position is
    its integer handle). ``concept_to_docs`` is an inverted index from
    lower-case concept tokens to the positions of documents mentioning them.
    No embeddings are computed; ``embeddings_index`` stays ``None`` in this
    class.
    """

    def __init__(self):
        self.document_store = []
        self.embeddings_index = None
        # concept token -> list of positions in ``document_store``
        self.concept_to_docs = defaultdict(list)

        print("Clinical RAG System initialized")

    def add_mimic_cases_to_rag(self, mimic_cases: List[Dict]):
        """Convert case dicts into ``ClinicalDocument``s and index their concepts.

        Args:
            mimic_cases: case dicts. Keys read here or by the helpers are
                ``patient_id``, ``admission_info`` (a dict with ``diagnosis``),
                ``symptoms``, ``clinical_notes``, ``diagnoses``, ``procedures``,
                ``medications`` and ``temporal_sequence``; all are optional.

        Document ids continue from the current store size, so calling this
        method several times never yields duplicate ``doc_id`` values.
        """
        print(f"Adding {len(mimic_cases)} MIMIC cases to RAG...")

        for case in mimic_cases:
            # Store position doubles as the id; it keeps growing across calls.
            doc_idx = len(self.document_store)
            admission_info = case.get('admission_info') or {}

            doc = ClinicalDocument(
                doc_id=f"mimic_case_{doc_idx}",
                content=self._create_document_from_case(case),
                metadata={
                    'patient_id': case.get('patient_id', f'unknown_{doc_idx}'),
                    'diagnosis': admission_info.get('diagnosis', 'unknown'),
                    'symptoms': case.get('symptoms', ''),
                    'procedures': case.get('procedures', []),
                    'medications': case.get('medications', []),
                },
                document_type="mimic_case",
            )
            self.document_store.append(doc)

            for concept in self._extract_concepts_from_case(case):
                self.concept_to_docs[concept].append(doc_idx)

        print(f"RAG knowledge base: {len(self.document_store)} documents indexed")

    @staticmethod
    def _as_list(value) -> list:
        """Wrap a lone string (or other scalar) in a list; otherwise materialise the iterable."""
        if isinstance(value, str) or not hasattr(value, '__iter__'):
            return [value]
        return list(value)

    @staticmethod
    def _is_present(item) -> bool:
        """True for items worth keeping: truthy and not the text 'nan'."""
        return bool(item) and str(item) != 'nan'

    def _create_document_from_case(self, case: Dict) -> str:
        """Flatten a case dict into one ``' | '``-separated text document.

        Sections are emitted in a fixed order: symptoms, clinical notes,
        diagnoses, procedures, up to five medications, temporal sequence.
        A section appears only when its key is present and truthy. List-valued
        fields may also be given as a single string, and empty or ``'nan'``
        items are skipped, matching ``_extract_concepts_from_case``.
        """
        doc_parts = []

        if case.get('symptoms'):
            doc_parts.append(f"Patient presents with: {case['symptoms']}")

        if case.get('clinical_notes'):
            # A single string note must not be iterated character by character.
            for note in self._as_list(case['clinical_notes']):
                if self._is_present(note):
                    doc_parts.append(f"Clinical note: {note}")

        for key, label, limit in (
            ('diagnoses', 'Diagnoses', None),
            ('procedures', 'Procedures', None),
            ('medications', 'Medications', 5),
        ):
            if not case.get(key):
                continue
            items = [str(item) for item in self._as_list(case[key]) if self._is_present(item)]
            items = items[:limit]
            if items:
                doc_parts.append(f"{label}: {', '.join(items)}")

        if case.get('temporal_sequence'):
            steps = [str(step) for step in self._as_list(case['temporal_sequence'])]
            doc_parts.append(f"Clinical sequence: {' → '.join(steps)}")

        return " | ".join(doc_parts)

    def _extract_concepts_from_case(self, case: Dict) -> List[str]:
        """Return the sorted, unique concept tokens of a case.

        Tokens come from ``symptoms``, ``diagnoses``, ``procedures`` and
        ``medications`` (a string field, or list items that are not empty or
        ``'nan'``). They are lower-cased and split on commas and whitespace.
        Only purely alphabetic tokens longer than two characters are kept.
        """
        tokens = set()

        for field in ('symptoms', 'diagnoses', 'procedures', 'medications'):
            value = case.get(field)
            if not value:
                continue
            if isinstance(value, str):
                items = [value]
            elif isinstance(value, list):
                items = [item for item in value if self._is_present(item)]
            else:
                continue
            for item in items:
                tokens.update(str(item).lower().replace(',', ' ').split())

        return sorted(token for token in tokens if len(token) > 2 and token.isalpha())

    def retrieve_relevant_docs(self, query_concepts: List[str], top_k: int = 3) -> List["ClinicalDocument"]:
        """Rank stored documents against *query_concepts* and return the best.

        Scoring per query concept (repeats count each time):
        +1.0 if the document is indexed under that exact concept, and
        +0.5 if the concept occurs as a substring of the document's
        lower-cased content.

        Args:
            query_concepts: free-text concepts; blank or non-string entries are ignored.
            top_k: maximum number of documents to return; ``<= 0`` returns none.

        Returns:
            Up to *top_k* documents with a positive score, best first. Ties
            keep the order in which documents were first scored.
        """
        if top_k <= 0:
            return []

        # Normalise once. A blank concept would otherwise substring-match
        # every document, because '' in s is always True.
        concepts = [c.strip().lower() for c in query_concepts if isinstance(c, str) and c.strip()]
        if not concepts:
            return []

        doc_scores = defaultdict(float)

        for concept in concepts:
            # .get avoids inserting empty entries into the defaultdict index.
            for doc_idx in self.concept_to_docs.get(concept, ()):
                doc_scores[doc_idx] += 1.0

        for doc_idx, doc in enumerate(self.document_store):
            content_lower = doc.content.lower()
            for concept in concepts:
                if concept in content_lower:
                    doc_scores[doc_idx] += 0.5

        ranked = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:top_k]
        return [self.document_store[doc_idx] for doc_idx, score in ranked if score > 0]

class LLMClinicalResponder:
    """Answer clinical questions by filling RAG-grounded response templates.

    On construction it also tries to load a Hugging Face text-generation
    pipeline into ``self.pipeline``. Nothing in this class calls that
    pipeline: every response is produced from the templates below.
    """

    # Question-type rules, tried in order; the first keyword hit wins.
    # Keywords must start a word, so "disorder" does not match "order",
    # "because" does not match "cause", and "intestinal" does not match "test".
    _QUESTION_TYPE_RULES = (
        ('temporal_sequence', ('sequence', 'order', 'timeline', 'progression')),
        ('temporal_sequence', ('before', 'after', 'next', 'previous', 'then')),
        ('missing_temporal_info', ('missing', 'gap', 'incomplete', 'unknown')),
        ('causal_relationship', ('cause', 'caused', 'leads to', 'results in')),
        ('symptom_temporal', ('symptom', 'pain', 'onset', 'started')),
        ('diagnostic_sequence', ('diagnosis', 'test', 'workup', 'evaluate')),
        ('sequence_validation', ('validate', 'correct', 'incorrect', 'appropriate', 'timing')),
    )

    # Function words long enough to pass the length filter but carrying no
    # clinical meaning. "with", for instance, would otherwise substring-match
    # the "Patient presents with:" prefix of every indexed document.
    _QUESTION_STOPWORDS = frozenset({
        'what', 'have', 'been', 'does', 'should', 'could', 'would',
        'about', 'with', 'this', 'that', 'from', 'when', 'where', 'which',
        'there', 'their', 'they', 'then', 'than', 'into', 'were', 'will',
    })

    def __init__(self, rag_system: "ClinicalRAGSystem"):
        """Store the retriever, build templates and try to load a local LLM."""
        self.rag_system = rag_system
        self.response_templates = self._create_clinical_templates()

        self.device = "cuda" if self._check_cuda() else "cpu"
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.llm_available = False

        self._initialize_free_huggingface_llm()

        print("LLM Clinical Responder initialized with RAG + FREE LLM")

    def _check_cuda(self):
        """Return True if torch is importable and reports a CUDA device."""
        try:
            import torch
            return torch.cuda.is_available()
        except ImportError:
            return False

    def _initialize_free_huggingface_llm(self):
        """Try candidate checkpoints in order and keep the first that loads.

        On success, sets ``tokenizer``, ``model`` and ``pipeline`` together
        with ``llm_available = True``. A failing candidate is reported and
        skipped. If none load, or torch/transformers are missing, the model
        attributes stay ``None`` and only templates are used.
        """
        try:
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        except ImportError:
            print("transformers/torch not installed, using template responses only")
            self.llm_available = False
            return

        print(f"Initializing FREE Medical LLM on {self.device}...")

        # NOTE(review): the two BERT checkpoints are encoder models. transformers
        # may still load them through AutoModelForCausalLM (with a decoder
        # warning), which would not give meaningful generations. Confirm
        # they belong in this list.
        medical_models = [
            "microsoft/BioGPT-Large",
            "microsoft/BioGPT",
            "dmis-lab/biobert-base-cased-v1.1",
            "emilyalsentzer/Bio_ClinicalBERT",
            "microsoft/DialoGPT-medium",
            "gpt2"
        ]

        for model_name in medical_models:
            try:
                print(f"   Trying medical model: {model_name}")

                tokenizer = AutoTokenizer.from_pretrained(model_name)
                if tokenizer.pad_token is None:
                    tokenizer.pad_token = tokenizer.eos_token

                model = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )

                text_pipeline = pipeline(
                    "text-generation",
                    model=model,
                    tokenizer=tokenizer,
                    device=0 if self.device == "cuda" else -1,
                    max_length=300,
                    do_sample=True,
                    temperature=0.6,
                    top_p=0.9
                )
            except Exception as e:
                # Best effort: report and move on to the next candidate.
                print(f"   Failed: {e}")
                continue

            # Publish only a fully built set, so a failed attempt never leaves
            # a tokenizer behind without a matching model.
            self.tokenizer = tokenizer
            self.model = model
            self.pipeline = text_pipeline
            self.llm_available = True
            print(f"Medical LLM loaded: {model_name}")
            break

        if not self.llm_available:
            print("No medical LLM available, using template responses only")

    def _create_clinical_templates(self) -> Dict[str, str]:
        """Return the response templates keyed by question type."""
        return {
            'temporal_sequence': """
            Clinical Timeline Analysis: {relevant_context}
            
            For the temporal sequence involving {concept}:
            - Typical progression: {temporal_progression}
            - Missing temporal links: {missing_links}
            - Causal relationships: {causal_analysis}
            - Next expected steps: {next_steps}
            """,
            
            'missing_temporal_info': """
            Temporal Gap Analysis: {relevant_context}
            
            Missing information in timeline for {current_diagnosis}:
            - Key missing temporal data: {missing_temporal}
            - Impact on clinical reasoning: {reasoning_impact}
            - Recommended temporal assessment: {temporal_assessment}
            """,
            
            'causal_relationship': """
            Causal Chain Analysis: {relevant_context}
            
            For {symptom} → {condition} relationship:
            - Causal pathway: {causal_pathway}
            - Temporal dependencies: {dependencies}
            - Missing causal links: {missing_causality}
            - Clinical implications: {implications}
            """,
            
            'sequence_validation': """
            Clinical Sequence Validation: {relevant_context}
            
            Evaluating sequence: {clinical_sequence}
            - Temporal consistency: {consistency_check}
            - Expected vs actual timing: {timing_analysis}
            - Sequence anomalies: {anomalies}
            - Clinical recommendations: {recommendations}
            """,
            
            'symptom_temporal': """
            Symptom Timeline Analysis: {relevant_context}
            
            For symptom {symptom} in temporal context:
            - Onset pattern: {onset_pattern}
            - Progression characteristics: {progression}
            - Temporal associations: {associations}
            - Missing temporal details: {missing_details}
            """,
            
            'diagnostic_sequence': """
            Diagnostic Sequence Analysis: {relevant_context}
            
            For {condition} diagnostic timeline:
            - Standard diagnostic sequence: {standard_sequence}
            - Current sequence gaps: {sequence_gaps}
            - Temporal diagnostic markers: {temporal_markers}
            - Missing diagnostic steps: {missing_steps}
            """
        }

    def generate_clinical_response(self, question: str, patient_case: Dict, 
                                 agent_focus: List[str]) -> str:
        """Build a templated answer to *question* grounded in similar cases.

        Args:
            question: free-text clinical question.
            patient_case: the current patient's case dict.
            agent_focus: focus concepts for the asking agent (``None`` is
                treated as empty). They are prepended to the question's
                concepts for retrieval.

        Returns:
            The filled template for the question's type, or a general
            response if that type has no template.
        """
        agent_focus = list(agent_focus or [])
        response_type = self._classify_question_type(question)

        query_concepts = agent_focus + self._extract_question_concepts(question)
        relevant_docs = self.rag_system.retrieve_relevant_docs(query_concepts, top_k=3)
        clinical_context = self._create_clinical_context(relevant_docs, patient_case)

        if response_type in self.response_templates:
            return self._fill_clinical_template(
                response_type, question, patient_case, clinical_context, agent_focus
            )
        return self._generate_general_clinical_response(
            question, patient_case, clinical_context
        )

    def _classify_question_type(self, question: str) -> str:
        """Map *question* to a template key using ``_QUESTION_TYPE_RULES``.

        Matching is case-insensitive and anchored at the start of a word,
        so a keyword also matches its inflections ("symptoms", "caused").
        Questions that match no rule default to ``'temporal_sequence'``.
        """
        question_lower = question.lower()
        for question_type, keywords in self._QUESTION_TYPE_RULES:
            for keyword in keywords:
                if re.search(r"\b" + re.escape(keyword), question_lower):
                    return question_type
        return 'temporal_sequence'

    def _extract_question_concepts(self, question: str) -> List[str]:
        """Return candidate concept words from *question*, in order.

        Words are lower-cased runs of word characters, optionally joined by
        internal hyphens or apostrophes ("follow-up"), so surrounding
        punctuation never sticks to a token. Words of four or more
        characters are kept unless they are common function words.
        """
        words = re.findall(r"\w+(?:['-]\w+)*", question.lower())
        return [word for word in words
                if len(word) > 3 and word not in self._QUESTION_STOPWORDS]

    def _create_clinical_context(self, relevant_docs: List["ClinicalDocument"], 
                               patient_case: Dict) -> str:
        """Summarise up to two retrieved documents plus the current patient.

        Returns a fixed fallback sentence when nothing was retrieved.
        Otherwise returns ``' | '``-joined parts: "Similar case: ..." for at
        most two documents, then a line describing the current patient's
        diagnosis and symptoms.
        """
        if not relevant_docs:
            return "Limited clinical context available from current knowledge base."

        context_parts = [f"Similar case: {doc.content}" for doc in relevant_docs[:2]]

        admission_info = patient_case.get('admission_info') or {}
        current_diagnosis = admission_info.get('diagnosis') or 'unknown'
        symptoms = patient_case.get('symptoms') or 'not specified'

        context_parts.append(f"Current patient: diagnosed with {current_diagnosis}, presenting with {symptoms}")

        return " | ".join(context_parts)

    def _fill_clinical_template(self, template_type: str, question: str, 
                              patient_case: Dict, clinical_context: str, 
                              agent_focus: List[str]) -> str:
        """Render ``self.response_templates[template_type]`` for one patient.

        Args:
            template_type: key into ``self.response_templates``.
            question: the original question (not used in the rendered text).
            patient_case: reads ``admission_info['diagnosis']``, ``symptoms``
                (a string or a list) and ``temporal_sequence``.
            clinical_context: retrieval context; only its first 200 characters are used.
            agent_focus: focus concepts; the first is the primary focus.

        Returns:
            The filled template. A placeholder without a computed value (for
            example in a custom template added to ``response_templates``) is
            rendered as ``'not available'`` instead of raising ``KeyError``.

        Raises:
            KeyError: if *template_type* is not in ``response_templates``.
        """
        template = self.response_templates[template_type]

        admission_info = patient_case.get('admission_info') or {}
        current_diagnosis = admission_info.get('diagnosis') or 'condition under evaluation'
        primary_focus = agent_focus[0] if agent_focus else 'clinical assessment'

        raw_sequence = patient_case.get('temporal_sequence') or []
        if isinstance(raw_sequence, str):
            raw_sequence = [raw_sequence]
        temporal_sequence = [str(step) for step in raw_sequence]

        symptoms = patient_case.get('symptoms') or 'not specified'
        if isinstance(symptoms, (list, tuple)):
            symptoms = ', '.join(str(symptom) for symptom in symptoms)
        symptoms = str(symptoms)

        if template_type == 'temporal_sequence':
            steps = ' → '.join(temporal_sequence[:3]) or 'sequence not documented'
            values = {
                'concept': primary_focus,
                'temporal_progression': f"{primary_focus} typically follows: {steps}",
                'missing_links': "gaps in temporal understanding need investigation",
                'causal_analysis': f"causal relationship between {primary_focus} and {current_diagnosis}",
                'next_steps': "systematic temporal assessment recommended",
            }
        elif template_type == 'missing_temporal_info':
            values = {
                'current_diagnosis': current_diagnosis,
                'missing_temporal': f"timing information for {primary_focus}",
                'reasoning_impact': "affects clinical decision-making timeline",
                'temporal_assessment': f"detailed temporal history of {primary_focus}",
            }
        elif template_type == 'causal_relationship':
            values = {
                # The first comma-separated symptom stands in for the trigger.
                'symptom': symptoms.split(',')[0].strip(),
                'condition': current_diagnosis,
                'causal_pathway': f"pathway from {symptoms} to {current_diagnosis}",
                'dependencies': "temporal dependencies in causal chain",
                'missing_causality': "intermediate causal steps need clarification",
                'implications': "impacts treatment timing and prognosis",
            }
        elif template_type == 'sequence_validation':
            values = {
                'clinical_sequence': ' → '.join(temporal_sequence) or 'sequence under evaluation',
                'consistency_check': "evaluating temporal consistency",
                'timing_analysis': "comparing expected vs actual timing",
                'anomalies': "potential temporal anomalies identified",
                'recommendations': "temporal sequence optimization suggested",
            }
        elif template_type == 'symptom_temporal':
            values = {
                'symptom': primary_focus,
                'onset_pattern': f"temporal onset pattern of {primary_focus}",
                'progression': "symptom progression over time",
                'associations': f"temporal associations with {current_diagnosis}",
                'missing_details': "specific timing details need clarification",
            }
        elif template_type == 'diagnostic_sequence':
            values = {
                'condition': current_diagnosis,
                'standard_sequence': "standard diagnostic temporal sequence",
                'sequence_gaps': "gaps in current diagnostic timeline",
                'temporal_markers': f"temporal diagnostic markers for {current_diagnosis}",
                'missing_steps': "missing temporal diagnostic steps",
            }
        else:
            values = {}

        # format_map consults __missing__, so unknown placeholders degrade gracefully.
        fields = defaultdict(lambda: 'not available')
        fields['relevant_context'] = clinical_context[:200]
        fields.update(values)
        return template.format_map(fields)

    def _generate_general_clinical_response(self, question: str, patient_case: Dict, 
                                          clinical_context: str) -> str:
        """Fallback free-text answer used when no template exists for the question type."""
        admission_info = patient_case.get('admission_info') or {}
        current_diagnosis = admission_info.get('diagnosis') or 'unknown'
        
        return f"""
        Clinical Assessment for {current_diagnosis}:
        
        Based on similar cases in knowledge base: {clinical_context[:150]}
        
        Clinical recommendation: Further evaluation of the specific aspect mentioned 
        in your question would be appropriate given the current clinical context.
        
        Consider systematic assessment and documentation of findings.
        """

def integrate_rag_llm_with_system(agent_creator, mimic_cases):
    """Attach RAG retrieval and the template responder to an agent creator.

    Builds a ClinicalRAGSystem over *mimic_cases*, wraps it in an
    LLMClinicalResponder, and patches *agent_creator* in place:

    * ``generate_questions_for_agent`` is replaced by a wrapper that appends
      a "consider clinical context" hint to every generated question. The
      hint is never added twice, even if this function runs again on the
      same object.
    * ``rag_system`` and ``llm_responder`` attributes are set.

    Args:
        agent_creator: object exposing ``generate_questions_for_agent(agent, case)``
            that returns an iterable of questions (``None`` is treated as empty).
        mimic_cases: case dicts passed to ``ClinicalRAGSystem.add_mimic_cases_to_rag``.

    Returns:
        The same *agent_creator* object, mutated.
    """
    print("Integrating RAG + LLM with existing system...")

    rag_system = ClinicalRAGSystem()
    rag_system.add_mimic_cases_to_rag(mimic_cases)

    llm_responder = LLMClinicalResponder(rag_system)

    original_generate_questions = agent_creator.generate_questions_for_agent
    context_hint = " (Consider clinical context and similar cases)"

    @functools.wraps(original_generate_questions)
    def enhanced_generate_questions(agent, case):
        questions = original_generate_questions(agent, case) or []
        # On a second integration the wrapped callable is an earlier wrapper,
        # so its questions already carry the hint; don't stack it again.
        texts = (f"{question}" for question in questions)
        return [text if text.endswith(context_hint) else f"{text}{context_hint}"
                for text in texts]

    agent_creator.generate_questions_for_agent = enhanced_generate_questions
    agent_creator.rag_system = rag_system
    agent_creator.llm_responder = llm_responder

    print("RAG + LLM integration complete!")
    return agent_creator

def test_rag_llm_integration():
    """Smoke-test retrieval and templated responses on one synthetic case.

    Only prints progress; nothing is asserted. Building the
    LLMClinicalResponder attempts to load (and possibly download) a
    Hugging Face model.
    """
    print("Testing RAG + LLM Integration")

    sample_case = {
        'patient_id': 'RAG_001',
        'symptoms': 'chest pain, shortness of breath',
        'diagnoses': ['myocardial infarction'],
        'clinical_notes': ['patient reports family history of heart disease'],
        'admission_info': {'diagnosis': 'acute myocardial infarction'},
    }
    cases = [sample_case]

    knowledge_base = ClinicalRAGSystem()
    knowledge_base.add_mimic_cases_to_rag(cases)

    retrieved = knowledge_base.retrieve_relevant_docs(['chest', 'pain'], top_k=2)
    print(f"RAG retrieved {len(retrieved)} relevant documents")

    responder = LLMClinicalResponder(knowledge_base)
    answer = responder.generate_clinical_response(
        "What about family history?",
        sample_case,
        ['family', 'history'],
    )
    preview = answer[:100]
    print(f"LLM generated response: {preview}...")


if __name__ == "__main__":
    # Executing the module directly runs the smoke test above.
    test_rag_llm_integration()