import os
import re
import sys
from collections import defaultdict, Counter
from datetime import datetime
from typing import Dict, List, Optional, Tuple

from dynamic_agents import AgentCreator, Agent
from pure_medical_qlearning import PureMedicalQLearning

class PureDataInteractiveSession:
    
    def __init__(self, agent_creator: AgentCreator, mimic_cases: List[Dict]):
        self.agent_creator = agent_creator
        self.mimic_cases = mimic_cases
        self.session_log = []
        self.current_case = None
        
        self.learned_response_patterns = self._learn_response_patterns_from_mimic()
        self.learned_question_outcomes = self._learn_question_outcomes_from_mimic()
        
        self.medical_response_patterns = {}
        self.enhance_with_dynamic_medical_responses(mimic_cases)
        
        self.rag_system = None
        self.llm_responder = None
        
        print("Pure Data-Driven Interactive Session initialized")
    
    def enhance_with_dynamic_medical_responses(self, mimic_cases: List[Dict]):
        
        print("Enhancing with dynamic medical response patterns...")
        
        medical_response_patterns = defaultdict(list)
        
        for case in mimic_cases:
            diagnosis = case.get('admission_info', {}).get('diagnosis', '').lower()
            symptoms = case.get('symptoms', '').lower()
            procedures = case.get('procedures', [])
            medications = case.get('medications', [])
            clinical_notes = case.get('clinical_notes', [])
            
            if any(word in diagnosis or word in symptoms for word in ['heart', 'cardiac', 'chest', 'myocardial']):
                medical_response_patterns['cardiac'].extend([
                    f"Patient presents with {symptoms}",
                    f"Cardiac evaluation includes {', '.join([str(p) for p in procedures[:2] if str(p) != 'nan'])}",
                    f"Current cardiac medications: {', '.join([str(m) for m in medications[:2] if str(m) != 'nan'])}",
                    f"Family history of coronary artery disease is significant given presentation",
                    f"Chest pain characteristics suggest cardiac etiology requiring further evaluation"
                ])
            
            if any(word in diagnosis or word in symptoms for word in ['lung', 'respiratory', 'breath', 'pneumonia']):
                medical_response_patterns['respiratory'].extend([
                    f"Respiratory assessment reveals {symptoms}",
                    f"Pulmonary function evaluation shows {', '.join([str(p) for p in procedures[:2] if str(p) != 'nan'])}",
                    f"Respiratory medications include {', '.join([str(m) for m in medications[:2] if str(m) != 'nan'])}",
                    f"Patient reports breathing difficulties consistent with {diagnosis}"
                ])
            
            for note in clinical_notes[:3]:
                note_text = str(note).lower()
                if 'family history' in note_text:
                    medical_response_patterns['family_history'].append(f"Family history is significant for {diagnosis}")
                if 'medication' in note_text:
                    medical_response_patterns['medication_history'].append(f"Patient is currently taking medications for {diagnosis}")
                if 'pain' in note_text:
                    medical_response_patterns['pain_assessment'].append(f"Pain assessment indicates {symptoms}")
            
            medical_response_patterns['general'].extend([
                f"Clinical findings are consistent with {diagnosis}",
                f"Diagnostic workup includes {', '.join([str(p) for p in procedures[:2] if str(p) != 'nan'])}",
                f"Patient history reveals {symptoms}",
                f"Current treatment plan involves {', '.join([str(m) for m in medications[:2] if str(m) != 'nan'])}"
            ])
        
        cleaned_patterns = {}
        for category, patterns in medical_response_patterns.items():
            cleaned = []
            for pattern in patterns:
                if (pattern and 
                    'nan' not in pattern and 
                    len(pattern) > 10 and
                    pattern not in cleaned):
                    cleaned.append(pattern)
            if cleaned:
                cleaned_patterns[category] = cleaned[:10]
        
        self.medical_response_patterns = cleaned_patterns
        
        print(f"Enhanced with {len(self.medical_response_patterns)} medical pattern categories")
        for category, patterns in self.medical_response_patterns.items():
            print(f"   - {category}: {len(patterns)} patterns")
    
    def get_dynamic_medical_response(self, question: str, agent, case: Dict) -> str:
        
        question_lower = question.lower()
        diagnosis = case.get('admission_info', {}).get('diagnosis', '').lower()
        symptoms = case.get('symptoms', '').lower()
        procedures = case.get('procedures', [])
        medications = case.get('medications', [])
        
        medical_context = 'general'
        if any(word in diagnosis or word in symptoms for word in ['heart', 'cardiac', 'chest', 'myocardial']):
            medical_context = 'cardiac'
        elif any(word in diagnosis or word in symptoms for word in ['lung', 'respiratory', 'breath', 'pneumonia']):
            medical_context = 'respiratory'
        
        if 'family' in question_lower and 'history' in question_lower:
            if medical_context == 'cardiac':
                return f"Family history of coronary artery disease is significant given the current presentation of {symptoms}. This increases risk stratification for {diagnosis}."
            elif medical_context == 'respiratory':
                return f"Family history of respiratory conditions should be evaluated in context of {symptoms} and current diagnosis of {diagnosis}."
            else:
                return f"Family history is relevant to the current diagnosis of {diagnosis} and may influence treatment decisions."
        
        elif any(word in question_lower for word in ['pain', 'ache', 'discomfort']):
            if 'chest' in symptoms or 'cardiac' in medical_context:
                return f"Chest pain characteristics: {symptoms} - location, radiation, quality, and timing suggest cardiac etiology requiring immediate evaluation and intervention."
            else:
                return f"Pain assessment indicates {symptoms} - systematic evaluation needed to determine etiology and appropriate management."
        
        elif any(word in question_lower for word in ['medication', 'drug', 'treatment']):
            if medications and any(str(m) != 'nan' for m in medications):
                med_list = [str(m) for m in medications[:3] if str(m) != 'nan']
                return f"Current medications: {', '.join(med_list)} - reviewing for therapeutic efficacy, drug interactions, and contraindications in context of {diagnosis}."
            else:
                return f"Medication history should be comprehensively reviewed for {diagnosis}. Consider current therapies, allergies, and contraindications."
        
        elif any(word in question_lower for word in ['test', 'lab', 'blood', 'study', 'investigation']):
            if procedures and any(str(p) != 'nan' for p in procedures):
                proc_list = [str(p) for p in procedures[:2] if str(p) != 'nan']
                return f"Diagnostic studies completed: {', '.join(proc_list)} - results should be interpreted in clinical context of {symptoms} and {diagnosis}."
            else:
                return f"Additional diagnostic testing recommended for {diagnosis}. Consider laboratory studies, imaging, and specialized procedures as clinically indicated."
        
        elif any(word in question_lower for word in ['imaging', 'scan', 'xray', 'echo', 'ct']):
            return f"Imaging studies are indicated for {diagnosis}. Results should correlate with clinical presentation of {symptoms} to guide management decisions."
        
        elif any(word in question_lower for word in ['prognosis', 'outcome', 'recovery']):
            return f"Prognosis for {diagnosis} depends on multiple factors including severity of {symptoms}, response to treatment, and presence of comorbidities."
        
        if medical_context in self.medical_response_patterns:
            patterns = self.medical_response_patterns[medical_context]
            if patterns:
                for pattern in patterns:
                    if any(word in pattern.lower() for word in question_lower.split()[:3]):
                        return pattern
                return patterns[0]
        
        if 'general' in self.medical_response_patterns:
            general_patterns = self.medical_response_patterns['general']
            if general_patterns:
                return general_patterns[0]
        
        return f"Clinical assessment for {diagnosis}: {symptoms} requires systematic evaluation. Additional history, physical examination, and diagnostic studies may provide clarity for optimal patient care."
    
    def _learn_response_patterns_from_mimic(self) -> Dict[str, List[str]]:
        
        response_patterns = defaultdict(list)
        
        for case in self.mimic_cases:
            clinical_notes = case.get('clinical_notes', [])
            
            for note in clinical_notes:
                note_text = str(note).lower()
                
                response_phrases = self._extract_response_phrases(note_text)
                
                for phrase in response_phrases:
                    category = self._categorize_response_by_concept(phrase, case)
                    response_patterns[category].append(phrase)
        
        print(f"Learned response patterns for {len(response_patterns)} concept categories")
        return dict(response_patterns)
    
    def _extract_response_phrases(self, text: str) -> List[str]:
        
        response_phrases = []
        
        response_indicators = [
            ('patient reports', 'positive_response'),
            ('patient denies', 'negative_response'),
            ('history of', 'historical_response'),
            ('no history of', 'negative_historical'),
            ('family history', 'family_response'),
            ('currently taking', 'medication_response'),
            ('patient states', 'patient_statement'),
            ('noted', 'clinical_observation'),
            ('positive for', 'positive_finding'),
            ('negative for', 'negative_finding')
        ]
        
        for indicator, response_type in response_indicators:
            if indicator in text:
                start_idx = text.find(indicator)
                sentence_start = text.rfind('.', 0, start_idx) + 1
                sentence_end = text.find('.', start_idx)
                if sentence_end == -1:
                    sentence_end = len(text)
                
                sentence = text[sentence_start:sentence_end].strip()
                if sentence:
                    response_phrases.append(sentence)
        
        return response_phrases
    
    def _categorize_response_by_concept(self, phrase: str, case: Dict) -> str:
        
        phrase_lower = phrase.lower()
        
        case_concepts = set()
        
        if case.get('symptoms'):
            case_concepts.update(case['symptoms'].lower().split())
        if case.get('diagnoses'):
            case_concepts.update([str(d).lower() for d in case['diagnoses']])
        
        for concept in case_concepts:
            if concept in phrase_lower and len(concept) > 2:
                return concept
        
        if any(word in phrase_lower for word in ['family', 'mother', 'father']):
            return 'family_related'
        elif any(word in phrase_lower for word in ['medication', 'drug', 'taking']):
            return 'medication_related'
        elif any(word in phrase_lower for word in ['pain', 'chest', 'abdominal']):
            return 'symptom_related'
        else:
            return 'general_clinical'
    
    def _learn_question_outcomes_from_mimic(self) -> Dict[str, Dict]:
        
        question_outcomes = defaultdict(lambda: defaultdict(int))
        
        for case in self.mimic_cases:
            present_concepts = self._get_present_concepts(case)
            potential_concepts = self._get_potential_concepts_for_diagnosis(case)
            
            for concept1 in present_concepts:
                for concept2 in present_concepts:
                    if concept1 != concept2:
                        question_outcomes[concept1][concept2] += 1
        
        learned_outcomes = {}
        for concept, related_concepts in question_outcomes.items():
            total = sum(related_concepts.values())
            if total > 0:
                learned_outcomes[concept] = {
                    related: count/total 
                    for related, count in related_concepts.items()
                    if count/total > 0.1
                }
        
        return learned_outcomes
    
    def _get_present_concepts(self, case: Dict) -> List[str]:
        
        concepts = []
        
        if case.get('symptoms'):
            concepts.extend(case['symptoms'].lower().split(', '))
        if case.get('diagnoses'):
            concepts.extend([str(d).lower() for d in case['diagnoses']])
        if case.get('procedures'):
            concepts.extend([str(p).lower() for p in case['procedures']])
        
        return [c.strip() for c in concepts if c.strip() and len(c.strip()) > 2]
    
    def _get_potential_concepts_for_diagnosis(self, case: Dict) -> List[str]:
        
        diagnosis = case.get('admission_info', {}).get('diagnosis', '').lower()
        potential_concepts = set()
        
        for other_case in self.mimic_cases:
            other_diagnosis = other_case.get('admission_info', {}).get('diagnosis', '').lower()
            
            if self._diagnoses_similar(diagnosis, other_diagnosis):
                potential_concepts.update(self._get_present_concepts(other_case))
        
        return list(potential_concepts)
    
    def _diagnoses_similar(self, diag1: str, diag2: str) -> bool:
        
        words1 = set(diag1.split()) - {'and', 'the', 'with', 'of'}
        words2 = set(diag2.split()) - {'and', 'the', 'with', 'of'}
        
        if not words1 or not words2:
            return False
        
        overlap = len(words1.intersection(words2)) / len(words1.union(words2))
        return overlap > 0.3
    
    def get_enhanced_agent_response(self, agent, question, case):
        
        if hasattr(self, 'llm_responder') and self.llm_responder and self.llm_responder.llm_available:
            try:
                enhanced_response = self.llm_responder.generate_clinical_response(
                    question, "", case
                )
                
                response_analysis = self.agent_creator.process_agent_response(
                    agent, question, enhanced_response
                )
                
                return {
                    'response': enhanced_response,
                    'analysis': response_analysis['analysis'],
                    'enhanced_with_rag_llm': True
                }
                
            except Exception as e:
                print(f"RAG+LLM enhancement failed: {e}")
        
        response = self._get_data_driven_response(question, agent, case)
        if response == 'user_input_requested':
            return {
                'response': 'user_input_requested',
                'analysis': {'response_type': 'user_input_needed'},
                'enhanced_with_rag_llm': False
            }
        
        response_analysis = self.agent_creator.process_agent_response(agent, question, response)
        return {
            'response': response,
            'analysis': response_analysis['analysis'],
            'enhanced_with_rag_llm': False
        }
    
    def start_pure_data_interactive_session(self, case: Dict, agents: List[Agent], 
                                          max_questions_per_agent: int = 2) -> Dict:
        
        print(f"\nPURE DATA-DRIVEN INTERACTIVE SESSION")
        print(f"Patient: {case['patient_id']}")
        print(f"Diagnosis: {case.get('admission_info', {}).get('diagnosis', 'Unknown')}")
        print(f"Agents: {len(agents)} (learned from data patterns)")
        print("="*60)
        
        self.current_case = case
        session_results = {
            'case_id': case['patient_id'],
            'session_type': 'pure_data_driven',
            'agents_participated': len(agents),
            'total_questions': 0,
            'responses_received': {},
            'agent_summaries': {},
            'learned_patterns_used': len(self.learned_response_patterns)
        }
        
        for i, agent in enumerate(agents):
            print(f"\nAGENT {i+1}: {agent.agent_id}")
            print(f"Focus: {', '.join(list(agent.focus_concepts)[:3])}")
            print(f"Learned from: Q-learning concept relationships")
            print("-" * 40)
            
            questions = self.agent_creator.generate_questions_for_agent(agent, case)
            
            if not questions:
                print("   No questions needed for this agent's focus area")
                continue
            
            agent_responses = {}
            
            for question in questions[:max_questions_per_agent]:
                print(f"\nQuestion: {question}")
                
                enhanced_result = self.get_enhanced_agent_response(agent, question, case)
                response = enhanced_result['response']

                if response == 'user_input_requested':
                    try:
                        print(f"   Please provide your response:")
                        user_response = input("   > ").strip()
                        if user_response.lower() == 'skip':
                            response = "Question skipped by user"
                        elif user_response:
                            response = user_response
                        else:
                            response = "No additional information provided"
                    except KeyboardInterrupt:
                        response = "Session interrupted by user"
                    
                    response_analysis = self.agent_creator.process_agent_response(agent, question, response)
                else:
                    if any(phrase in response for phrase in ["requires", "should be", "Additional"]):
                        print(f"   AI Response: {response}")
                        try:
                            print(f"   Do you have additional input? (Enter for none, 'skip' to skip)")
                            user_input = input("   > ").strip()
                            if user_input.lower() == 'skip':
                                pass
                            elif user_input:
                                response = f"{response} | Doctor adds: {user_input}"
                        except KeyboardInterrupt:
                            pass
                    
                    response_analysis = enhanced_result['analysis']

                print(f"   Final Response: {response[:100]}{'...' if len(response) > 100 else ''}")
                if enhanced_result.get('enhanced_with_rag_llm'):
                    print(f"   Enhanced with Medical LLM + RAG")
                else:
                    print(f"   Enhanced with Dynamic Medical Patterns")
                
                agent_responses[question] = response
                session_results['total_questions'] += 1
                
                print(f"   Analysis: {response_analysis['response_type']}")
                if response_analysis.get('new_info_found'):
                    print(f"   New clinical information discovered")
            
            session_results['responses_received'][agent.agent_id] = agent_responses
            
            agent_summary = self.agent_creator.get_agent_summary(agent)
            session_results['agent_summaries'][agent.agent_id] = agent_summary
            print(f"\n{agent_summary}")
        
        return session_results
    
    def _get_data_driven_response(self, question: str, agent: Agent, case: Dict) -> str:
        
        question_lower = question.lower()
        
        dynamic_response = self.get_dynamic_medical_response(question_lower, agent, case)
        if dynamic_response and "Additional information would help" not in dynamic_response:
            return dynamic_response
        
        best_response = self._find_best_learned_response(question_lower, agent.focus_concepts, case)
        
        if best_response:
            return best_response
        
        return 'user_input_requested'
    
    def _find_best_learned_response(self, question: str, focus_concepts: set, case: Dict) -> str:
        
        for concept in focus_concepts:
            concept_lower = concept.lower()
            
            if concept_lower in question and concept_lower in self.learned_response_patterns:
                learned_responses = self.learned_response_patterns[concept_lower]
                
                if learned_responses:
                    return learned_responses[0]
        
        for category, responses in self.learned_response_patterns.items():
            if any(word in question for word in category.split('_')):
                if responses:
                    return responses[0]
        
        return None
    
    def auto_demo_with_learned_patterns(self, case: Dict, agents: List[Agent]) -> Dict:
        
        print(f"\nDEMO: DYNAMIC MEDICAL PATTERNS FROM MIMIC DATA")
        print(f"Real medical responses - everything learned from clinical data")
        print("="*60)
        
        session_results = {
            'case_id': case['patient_id'],
            'session_type': 'dynamic_medical_patterns_demo',
            'agents_participated': len(agents),
            'responses_received': {},
            'patterns_used': {}
        }
        
        for i, agent in enumerate(agents):
            print(f"\nAGENT {i+1}: {agent.agent_id}")
            print(f"Focus: {', '.join(list(agent.focus_concepts)[:3])}")
            
            questions = self.agent_creator.generate_questions_for_agent(agent, case)
            agent_responses = {}
            
            for question in questions[:2]:
                print(f"   Question: {question}")
                
                response = self.get_dynamic_medical_response(question.lower(), agent, case)
                
                if response:
                    print(f"   Dynamic medical response: {response}")
                    print(f"   Source: Real MIMIC clinical patterns + case context")
                else:
                    response = "No dynamic pattern available for this question"
                    print(f"   {response}")
                
                agent_responses[question] = response
                
                for concept in agent.focus_concepts:
                    if concept.lower() in question.lower():
                        session_results['patterns_used'][concept] = response
            
            session_results['responses_received'][agent.agent_id] = agent_responses
        
        return session_results

def test_pure_data_interactive():
    """Smoke-test session construction with two hand-written cases.

    Builds a cardiac and a pneumonia case, wires in a stub agent creator and
    prints how many pattern categories the session derived.
    """
    print("Testing Pure Data-Driven Interactive Session with Dynamic Medical Responses")

    class MockAgentCreator:
        """Stand-in exposing the three methods the session calls."""

        def generate_questions_for_agent(self, agent, case):
            concepts = list(agent.focus_concepts)
            return [f"What about {concepts[0]} in this case?"]

        def process_agent_response(self, agent, question, response):
            analysis = {
                'response_type': 'dynamic_medical_pattern',
                'new_info_found': True
            }
            return {'analysis': analysis}

        def get_agent_summary(self, agent):
            return f"Agent {agent.agent_id}: Used dynamic medical patterns"

    cardiac_case = dict(
        patient_id='DYNAMIC_001',
        symptoms='chest pain, shortness of breath',
        diagnoses=['myocardial infarction'],
        procedures=['ecg', 'cardiac catheterization'],
        medications=['aspirin', 'metoprolol'],
        admission_info={'diagnosis': 'acute myocardial infarction'},
        clinical_notes=[
            'patient reports chest pain started 2 hours ago',
            'family history of coronary artery disease noted',
            'patient denies smoking',
            'currently taking aspirin and metoprolol'
        ],
    )
    pneumonia_case = dict(
        patient_id='DYNAMIC_002',
        symptoms='fever, cough, shortness of breath',
        diagnoses=['pneumonia'],
        procedures=['chest xray', 'blood culture'],
        medications=['antibiotics'],
        admission_info={'diagnosis': 'community acquired pneumonia'},
        clinical_notes=[
            'patient reports productive cough for 3 days',
            'no family history of lung disease',
            'currently on antibiotic therapy'
        ],
    )

    session = PureDataInteractiveSession(MockAgentCreator(), [cardiac_case, pneumonia_case])

    category_count = len(session.medical_response_patterns)
    print(f"Enhanced with {category_count} dynamic medical pattern categories")
    print("No hardcoded responses - everything learned from real medical data")

# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_pure_data_interactive()