import re
from collections import defaultdict, Counter
from typing import Dict, List, Optional, Tuple

class EnhancedMIMICLearning:
    """Learn reusable clinical response phrases from MIMIC-style case dicts.

    A case is a ``dict`` that may contain:

    * ``symptoms`` -- a ``", "``-joined string (a list/tuple is also accepted);
    * ``diagnoses``, ``procedures``, ``medications`` -- iterables of values;
    * ``clinical_notes`` -- a list of notes, or a single note string;
    * ``admission_info`` -- a dict with an optional ``diagnosis`` entry.

    Values that are ``None``, blank, or stringify to ``'nan'`` (e.g. a float
    NaN) are ignored. Learned phrases live in ``response_patterns``, keyed by a
    structural category (``family_history``, ``medication_history``,
    ``symptom_assessment``, ``test_results``) or by a lower-cased symptom or
    diagnosis. Each category's list holds distinct phrases in first-seen
    order, so repeated training runs do not pile up copies.
    """

    # Symptom keywords searched for (as plain substrings) in lower-cased notes.
    _NOTE_SYMPTOM_KEYWORDS = ('pain', 'chest pain', 'shortness of breath', 'nausea', 'fever')

    # (question keyword, response category) pairs, in the priority order
    # used by get_contextual_response.
    _QUESTION_CATEGORIES = (
        ('family', 'family_history'),
        ('pain', 'pain'),
        ('medication', 'medication_history'),
    )

    def __init__(self):
        # category -> distinct response phrases, first-seen order
        self.response_patterns = defaultdict(list)
        # symptom -> Counter of diagnoses seen in the same case
        self.clinical_relationships = defaultdict(Counter)
        # symptom -> "Could indicate <diagnosis>", one entry per co-occurrence
        self.symptom_to_findings = defaultdict(list)
        # diagnosis -> "Typically treated with <treatment>", one entry per co-occurrence
        self.diagnosis_to_treatments = defaultdict(list)
        # category -> set mirror of response_patterns[category], for O(1) de-duplication.
        # NOTE(review): only kept in sync by _add_pattern; external appends to
        # response_patterns are not reflected here.
        self._pattern_index = defaultdict(set)

    @staticmethod
    def _is_missing(value) -> bool:
        """Return True for None, blank strings and values whose text is 'nan'."""
        return value is None or str(value).strip().lower() in ('', 'nan')

    @staticmethod
    def _unique(items) -> List:
        """De-duplicate ``items`` keeping first-seen order (deterministic, unlike set())."""
        return list(dict.fromkeys(items))

    @staticmethod
    def _split_symptoms(raw) -> List[str]:
        """Split a ", "-joined symptom string; lists/tuples count as already split.

        Any other value (e.g. a float NaN from a missing cell) yields ``[]``.
        """
        if isinstance(raw, str):
            return raw.split(', ')
        if isinstance(raw, (list, tuple)):
            return [str(item) for item in raw]
        return []

    def _add_pattern(self, category: str, text: str) -> None:
        """Append ``text`` to ``response_patterns[category]`` unless already present."""
        seen = self._pattern_index[category]
        if text not in seen:
            seen.add(text)
            self.response_patterns[category].append(text)

    def extract_rich_patterns_from_cases(self, mimic_cases: List[Dict]) -> Dict[str, List[str]]:
        """Learn symptom/diagnosis/treatment relationships and response phrases.

        Args:
            mimic_cases: Case dicts (see the class docstring for the keys read).

        Returns:
            A shallow ``dict`` copy of ``response_patterns``. The lists inside it
            are the learner's own lists. Calling this again with more cases
            extends the learned state without duplicating phrases.
        """
        print(f"Extracting rich patterns from {len(mimic_cases)} cases...")

        for case in mimic_cases:
            symptoms = self._extract_symptoms(case)
            diagnoses = self._extract_diagnoses(case)
            treatments = self._extract_treatments(case)

            # Every symptom is associated with every diagnosis of the same case.
            for symptom in symptoms:
                for diagnosis in diagnoses:
                    self.clinical_relationships[symptom][diagnosis] += 1
                    self.symptom_to_findings[symptom].append(f"Could indicate {diagnosis}")

            for diagnosis in diagnoses:
                for treatment in treatments:
                    self.diagnosis_to_treatments[diagnosis].append(f"Typically treated with {treatment}")

            self._learn_response_patterns_from_case_structure(case)

        learned_patterns = self._generate_response_patterns()

        print(f"Learned response patterns for {len(learned_patterns)} concept categories")
        return learned_patterns

    def _extract_symptoms(self, case: Dict) -> List[str]:
        """Return distinct lower-cased symptoms from ``symptoms`` and keyword hits in notes."""
        symptoms = [
            s.strip().lower()
            for s in self._split_symptoms(case.get('symptoms'))
            if not self._is_missing(s)
        ]

        notes = case.get('clinical_notes')
        if notes:
            # A single note string must not be iterated character by character.
            if isinstance(notes, str):
                notes = [notes]
            for note in notes:
                note_text = str(note).lower()
                # Plain substring test: "chest pain" also yields "pain".
                symptoms.extend(kw for kw in self._NOTE_SYMPTOM_KEYWORDS if kw in note_text)

        return self._unique(symptoms)

    def _extract_diagnoses(self, case: Dict) -> List[str]:
        """Return distinct lower-cased diagnoses from ``diagnoses`` and ``admission_info``."""
        diagnoses = []

        if case.get('diagnoses'):
            diagnoses.extend(str(d).lower() for d in case['diagnoses'] if not self._is_missing(d))

        # ``admission_info`` may be absent or explicitly None.
        admission_diagnosis = (case.get('admission_info') or {}).get('diagnosis')
        if admission_diagnosis and not self._is_missing(admission_diagnosis):
            diagnoses.append(str(admission_diagnosis).lower())

        return self._unique(diagnoses)

    def _extract_treatments(self, case: Dict) -> List[str]:
        """Return distinct lower-cased procedures plus the first five medications."""
        treatments = []

        if case.get('procedures'):
            treatments.extend(str(p).lower() for p in case['procedures'] if not self._is_missing(p))

        if case.get('medications'):
            treatments.extend(str(m).lower() for m in case['medications'][:5] if not self._is_missing(m))

        return self._unique(treatments)

    def _learn_response_patterns_from_case_structure(self, case: Dict):
        """Add template phrases derived from which fields the case contains."""
        # NOTE(review): plain substring test on the notes, so negated mentions
        # such as "no family history of ..." also count as a positive history.
        if 'family' in str(case.get('clinical_notes', '')).lower():
            self._add_pattern('family_history', "Family history is significant")
            self._add_pattern('family_history', "Positive family history noted")

        medications = case.get('medications', [])
        if medications:
            for med in medications[:3]:
                if self._is_missing(med):
                    continue
                self._add_pattern('medication_history', f"Patient is taking {med}")
                self._add_pattern('medication_history', f"Currently prescribed {med}")

        for symptom in self._split_symptoms(case.get('symptoms', ''))[:3]:
            symptom = symptom.strip()
            if self._is_missing(symptom):
                continue
            self._add_pattern('symptom_assessment', f"Patient reports {symptom}")
            self._add_pattern('symptom_assessment', f"{symptom} is present")

        procedures = case.get('procedures', [])
        if procedures:
            for proc in procedures[:3]:
                if self._is_missing(proc):
                    continue
                self._add_pattern('test_results', f"{proc} was performed")
                self._add_pattern('test_results', f"{proc} shows abnormal findings")

    def _generate_response_patterns(self) -> Dict[str, List[str]]:
        """Fold the accumulated relationships into ``response_patterns``.

        This method is idempotent: phrases already present are not re-added. The
        method can therefore run after every training batch.
        """
        for symptom, diagnosis_counts in self.clinical_relationships.items():
            for diagnosis, count in diagnosis_counts.most_common(2):
                if count > 0:
                    self._add_pattern(symptom, f"Could be related to {diagnosis}")
                    self._add_pattern(symptom, f"Consistent with {diagnosis}")

        # Take the first two *distinct* phrases; the raw lists repeat per case.
        for symptom, findings in self.symptom_to_findings.items():
            for finding in self._unique(findings)[:2]:
                self._add_pattern(symptom, finding)

        for diagnosis, treatments in self.diagnosis_to_treatments.items():
            for treatment in self._unique(treatments)[:2]:
                self._add_pattern(diagnosis, treatment)

        return dict(self.response_patterns)

    def get_contextual_response(self, question: str, patient_case: Dict) -> Optional[str]:
        """Return the first learned phrase for a keyword found in ``question``.

        Keywords are checked in ``_QUESTION_CATEGORIES`` order (family, pain,
        medication). ``patient_case`` is currently unused. Returns ``None`` when
        no keyword matches a non-empty category.
        """
        question_lower = question.lower()
        for keyword, category in self._QUESTION_CATEGORIES:
            if keyword in question_lower:
                responses = self.response_patterns.get(category)
                if responses:
                    return responses[0]
        return None

    def _extract_all_concepts(self, case: Dict) -> List[str]:
        """Return lower-cased symptoms, diagnoses and the first three medications."""
        concepts = list(self._split_symptoms(case.get('symptoms')))
        if case.get('diagnoses'):
            concepts.extend(str(d) for d in case['diagnoses'] if not self._is_missing(d))
        if case.get('medications'):
            concepts.extend(str(m) for m in case['medications'][:3] if not self._is_missing(m))

        return [c.strip().lower() for c in concepts if c.strip()]

def enhance_interactive_session_learning(session_instance, mimic_cases):
    """Attach EnhancedMIMICLearning patterns to an interactive session.

    Learns response patterns from ``mimic_cases``. Stores them on the session as
    ``learned_response_patterns`` and the learner as ``enhanced_learner``.
    Replaces the session's ``_find_best_learned_response`` with a lookup that:

    1. asks the learner for a keyword-matched contextual response;
    2. otherwise returns the first phrase of the first focus concept that is a
       learned category;
    3. otherwise returns the first phrase of the first category whose
       ``_``-separated words appear in the question;
    4. otherwise returns ``None``.

    Returns the same ``session_instance``, mutated in place.
    """
    enhanced_learner = EnhancedMIMICLearning()

    enhanced_patterns = enhanced_learner.extract_rich_patterns_from_cases(mimic_cases)
    session_instance.learned_response_patterns = enhanced_patterns
    session_instance.enhanced_learner = enhanced_learner

    def enhanced_find_response(question, focus_concepts, case):
        response = enhanced_learner.get_contextual_response(question, case)
        if response:
            return response

        # Best-effort lookup: a malformed concept or pattern table must not break
        # the interactive session, so failures degrade to "no response".
        try:
            patterns = session_instance.learned_response_patterns
            question_lower = question.lower()

            for concept in focus_concepts:
                responses = patterns.get(str(concept).lower())
                if responses:
                    return responses[0]

            # NOTE(review): substring matching, so e.g. "history" in any question
            # matches the 'family_history' / 'medication_history' categories.
            for category, responses in patterns.items():
                # Drop empty fragments (blank key, "a__b"): '' is a substring of
                # every question and would match unconditionally.
                words = [word for word in category.split('_') if word]
                if responses and any(word in question_lower for word in words):
                    return responses[0]

            return None

        except Exception as e:
            print(f"Enhanced response lookup failed: {e}")
            return None

    session_instance._find_best_learned_response = enhanced_find_response

    return session_instance

def enhance_interactive_session_with_rag_llm(session_instance, mimic_cases, agent_creator, *, snippet_limit=200):
    """Layer RAG + LLM answers on top of a session's learned-response lookup.

    Builds a ``ClinicalRAGSystem`` from ``mimic_cases`` and an
    ``LLMClinicalResponder`` on top of it. Then wraps
    ``session_instance._find_best_learned_response``:

    * if the learned lookup answers, the LLM answer is appended, cut to
      ``snippet_limit`` characters (with ``"..."`` only when actually cut);
    * otherwise the LLM answer is returned on its own.

    LLM failures are printed and fall back to the learned answer, or to a fixed
    message. ``agent_creator`` is accepted for interface compatibility but is
    not used here. When ``rag_llm_integration`` cannot be imported, the session
    is returned unchanged.

    Returns the same ``session_instance``.
    """
    try:
        from rag_llm_integration import ClinicalRAGSystem, LLMClinicalResponder

        print("Integrating RAG + LLM into interactive session...")

        rag_system = ClinicalRAGSystem()
        rag_system.add_mimic_cases_to_rag(mimic_cases)

        llm_responder = LLMClinicalResponder(rag_system)

        original_find_response = session_instance._find_best_learned_response

        def _snippet(text):
            """Cut text to snippet_limit chars; add "..." only if something was cut."""
            if len(text) > snippet_limit:
                return f"{text[:snippet_limit]}..."
            return text

        def rag_llm_enhanced_response(question, focus_concepts, case):
            learned_response = original_find_response(question, focus_concepts, case)

            if learned_response:
                try:
                    enhanced_response = llm_responder.generate_clinical_response(
                        question, case, list(focus_concepts)
                    )
                    return f"{learned_response}\n\nRAG+LLM Enhancement: {_snippet(enhanced_response)}"
                except Exception as e:
                    # Enhancement is optional: keep the learned answer.
                    print(f"RAG+LLM enhancement failed: {e}")
                    return learned_response

            try:
                rag_llm_response = llm_responder.generate_clinical_response(
                    question, case, list(focus_concepts)
                )
                return f"RAG+LLM Response: {rag_llm_response}"
            except Exception as e:
                print(f"RAG+LLM failed: {e}")
                return "Unable to generate enhanced response"

        session_instance._find_best_learned_response = rag_llm_enhanced_response
        session_instance.rag_system = rag_system
        session_instance.llm_responder = llm_responder

        print(f"RAG+LLM integrated! Using {len(mimic_cases)} cases for knowledge base")
        print(f"Medical LLM available: {llm_responder.llm_available}")

        return session_instance

    except ImportError:
        print("RAG+LLM components not available - skipping enhancement")
        return session_instance

def test_enhanced_learning():
    """Smoke-test EnhancedMIMICLearning on two hand-written cases.

    Prints the number of learned categories, the first phrase of each
    non-empty category, and the learner's answer to a family-history question
    about the first (cardiac) case.
    """
    cardiac_case = {
        'symptoms': 'chest pain, shortness of breath',
        'diagnoses': ['myocardial infarction'],
        'procedures': ['ecg', 'cardiac catheterization'],
        'medications': ['aspirin', 'metoprolol'],
        'clinical_notes': ['patient reports chest pain started 2 hours ago', 'family history of heart disease']
    }
    respiratory_case = {
        'symptoms': 'fever, cough',
        'diagnoses': ['pneumonia'],
        'procedures': ['chest xray'],
        'medications': ['antibiotics'],
        'clinical_notes': ['patient experiencing productive cough', 'no family history of lung disease']
    }
    sample_cases = [cardiac_case, respiratory_case]

    learner = EnhancedMIMICLearning()
    learned = learner.extract_rich_patterns_from_cases(sample_cases)

    print("Testing Enhanced MIMIC Learning")
    print(f"Extracted {len(learned)} pattern categories")

    populated = ((name, phrases) for name, phrases in learned.items() if phrases)
    for name, phrases in populated:
        print(f"  {name}: {phrases[0]}")

    answer = learner.get_contextual_response("What about family history?", cardiac_case)
    print(f"Test response: {answer}")

# Running this module as a script executes the demo/smoke test above.
if __name__ == "__main__":
    test_enhanced_learning()