import itertools
import logging
import re
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Dict, List, Optional, Set

import numpy as np

from pure_medical_qlearning import PureMedicalQLearning

logger = logging.getLogger(__name__)

@dataclass
class Agent:
    """A focused discovery agent produced by ``AgentCreator``.

    ``current_questions`` and ``responses_received`` start empty and are
    filled in as questions are generated and answered.
    """

    agent_id: str
    focus_concepts: Set[str]
    discovery_patterns: List[str]
    effectiveness: float
    context: Dict
    connections: Dict[str, float]
    success_history: List[float]
    # default_factory gives every instance its own fresh container.
    current_questions: List[str] = field(default_factory=list)
    responses_received: Dict[str, str] = field(default_factory=dict)
    analysis_updated: bool = False

    def __post_init__(self):
        # Backward compatibility: callers may still pass None explicitly.
        if self.current_questions is None:
            self.current_questions = []
        if self.responses_received is None:
            self.responses_received = {}

class AgentCreator:
    """Build ``Agent`` objects for concepts a case appears to be missing.

    Missing concepts come from two places: the Q-learning system's
    ``discover_missing_concepts`` result, and "gaps" between what doctor notes
    mention and what MIMIC cases with a matching admission diagnosis contain.
    Candidates are grouped by Q-matrix clusters (see ``discover_clusters``),
    one agent is built per group, and only agents whose effectiveness exceeds
    ``agent_threshold`` are kept.

    ``qlearning_system`` is duck-typed: this class reads its ``Q``,
    ``medical_concepts``, ``concept_to_index`` and ``index_to_concept``
    attributes and calls ``discover_missing_concepts(concepts)``, whose items
    are read via the ``concept``, ``q_value`` and ``mimic_frequency`` keys.
    """

    # Shared across instances so every agent id is unique, even when two
    # agents for the same group name are created within the same second.
    _agent_seq = itertools.count(1)

    # Short polarity words are matched as whole words: plain substring checks
    # made "unknown"/"now" count as "no" and "eyes" count as "yes".
    _POSITIVE_RE = re.compile(r"\b(?:yes|positive|confirmed)\b")
    _NEGATIVE_RE = re.compile(r"\b(?:no|negative|denied)\b")
    _INFO_WORD_RE = re.compile(r"\b(?:yes|no|positive|negative)\b")
    # Longer indicators stay substring matches so e.g. "previously" still counts.
    _INFO_SUBSTRINGS = ('history', 'family', 'previous')

    def __init__(self, qlearning_system, mimic_cases: List[Dict]):
        """Store the Q-learning system and MIMIC cases and set default thresholds.

        Args:
            qlearning_system: Object exposing the interface described on the class.
            mimic_cases: Case dicts; ``None`` is treated as "no cases".
        """
        self.qlearning_system = qlearning_system
        self.mimic_cases = mimic_cases if mimic_cases is not None else []

        self.concept_clusters = {}
        # condition -> Counter(missing concept -> number of cases missing it)
        self.missing_patterns = defaultdict(Counter)
        # concept -> number of times it was reported as a doctor-note gap
        self.doctor_gaps = defaultdict(int)
        # agent_id -> Agent for every agent that passed agent_threshold
        self.agents = {}

        # Minimum symmetric Q-value for two concepts to share a cluster.
        self.cluster_threshold = 0.15
        # Minimum effectiveness for a created agent to be kept.
        self.agent_threshold = 0.3

        logger.info("COMPLETE FIXED Dynamic Agent Creator - pure data learning")

    # ------------------------------------------------------------------
    # Case normalisation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _diagnosis_of(case: Dict, default=None):
        """Return ``case['admission_info']['diagnosis']``, or *default* if the key is absent.

        A missing or ``None`` ``admission_info`` is treated as an empty dict.
        """
        return (case.get('admission_info') or {}).get('diagnosis', default)

    @staticmethod
    def _split_symptoms(symptoms) -> List[str]:
        """Split a comma-separated symptom string into stripped, lowercased terms.

        A list/tuple of symptoms is accepted as well; empty entries are dropped.
        """
        if not symptoms:
            return []
        parts = symptoms.split(',') if isinstance(symptoms, str) else symptoms
        terms = (str(part).strip().lower() for part in parts)
        return [term for term in terms if term]

    @classmethod
    def _case_terms(cls, case: Dict, max_diagnoses: Optional[int] = None) -> Set[str]:
        """Return the normalised symptom and diagnosis terms of one case.

        Args:
            case: Case dict; ``symptoms`` and ``diagnoses`` are optional.
            max_diagnoses: If given, only the first N diagnoses are used.

        Diagnoses are stringified, stripped and lowercased; blank values and
        the literal ``'nan'`` are skipped.
        """
        terms = set(cls._split_symptoms(case.get('symptoms')))
        diagnoses = case.get('diagnoses') or []
        if max_diagnoses is not None:
            diagnoses = diagnoses[:max_diagnoses]
        for diagnosis in diagnoses:
            text = str(diagnosis).strip().lower()
            if text and text != 'nan':
                terms.add(text)
        return terms

    @staticmethod
    def _strongest_neighbor(q_matrix, index: int):
        """Return ``(target_index, q_value)`` of the largest entry in row *index*, ignoring self.

        The self-connection and NaN entries are excluded; if nothing else is
        left the returned value is ``-inf``. Ties go to the lowest index.
        """
        row = np.array(q_matrix[index], dtype=float)  # copy: never mutate Q itself
        if index < row.size:
            row[index] = -np.inf
        row[np.isnan(row)] = -np.inf
        best = int(np.argmax(row))
        return best, float(row[best])

    # ------------------------------------------------------------------
    # Pattern learning
    # ------------------------------------------------------------------

    def learn_missing_patterns(self, min_count: int = 3) -> Dict[str, List[str]]:
        """Learn, per admission diagnosis, which related concepts cases tend to omit.

        For each MIMIC case, the concepts found in cases with a matching
        diagnosis (``_find_related_concepts``) minus the case's own symptoms
        and first three diagnoses are counted as "missing". Counts are rebuilt
        from scratch on every call, so repeated calls do not double them.

        Args:
            min_count: Minimum number of cases a concept must be missing from
                to be reported for a condition.

        Returns:
            ``{condition: [concept, ...]}`` for every condition with at least
            one concept reaching ``min_count``.
        """
        print("Learning missing patterns from MIMIC...")

        self.missing_patterns = defaultdict(Counter)
        # _find_related_concepts scans every case, so compute it once per condition.
        related_by_condition = {}

        for case in self.mimic_cases:
            condition = self._diagnosis_of(case, 'unknown')
            if condition not in related_by_condition:
                related_by_condition[condition] = self._find_related_concepts(condition)
            present = self._case_terms(case, max_diagnoses=3)
            self.missing_patterns[condition].update(related_by_condition[condition] - present)

        learned = {}
        for condition, missing_counter in self.missing_patterns.items():
            frequent = [concept for concept, count in missing_counter.items() if count >= min_count]
            if frequent:
                learned[condition] = frequent

        print(f"Learned patterns for {len(learned)} conditions")
        return learned

    def _find_related_concepts(self, condition: str) -> Set[str]:
        """Collect symptom/diagnosis terms from MIMIC cases whose diagnosis matches *condition*.

        Two diagnoses match when either (lowercased) contains the other. Cases
        with a blank or ``'nan'`` diagnosis are skipped; otherwise the empty
        string would be "contained" in every condition and match everything.
        """
        condition_lower = str(condition).strip().lower()
        related = set()
        if not condition_lower or condition_lower == 'nan':
            return related

        for case in self.mimic_cases:
            case_diag = str(self._diagnosis_of(case, '')).strip().lower()
            if not case_diag or case_diag == 'nan':
                continue
            if condition_lower in case_diag or case_diag in condition_lower:
                related |= self._case_terms(case)

        return related

    def discover_clusters(self) -> Dict[str, Set[str]]:
        """Greedily cluster concepts whose Q-connection exceeds ``cluster_threshold``.

        The strength between concepts i and j is ``max(Q[i, j], Q[j, i])``.
        Concepts are visited in ``medical_concepts`` order; each concept not
        yet clustered seeds a cluster with every other unclustered concept it
        is strongly connected to. Singleton clusters are discarded.

        Returns:
            ``{"cluster_<k>": {concept, ...}}``, also stored in
            ``self.concept_clusters``. Returns ``{}`` without touching
            ``self.concept_clusters`` when no Q-matrix is available.
        """
        print("Discovering clusters from Q-learning...")

        q_values = getattr(self.qlearning_system, 'Q', None)
        if q_values is None:
            print("Q-matrix not available")
            return {}

        q_matrix = np.asarray(q_values)
        concepts = list(self.qlearning_system.medical_concepts)
        size = len(concepts)

        clusters = {}
        processed = set()

        for i, seed in enumerate(concepts):
            if seed in processed:
                continue

            # Symmetric connection strength from `seed` to every concept.
            strength = np.maximum(q_matrix[i, :size], q_matrix[:size, i])
            members = {seed}
            for j in np.flatnonzero(strength > self.cluster_threshold):
                if j != i and concepts[j] not in processed:
                    members.add(concepts[j])

            if len(members) > 1:
                clusters[f"cluster_{len(clusters)}"] = members
                processed.update(members)

        self.concept_clusters = clusters
        print(f"Found {len(clusters)} concept clusters")
        return clusters

    def analyze_doctor_gaps(self, doctor_notes: str, case: Dict) -> Set[str]:
        """Return concepts suggested by similar MIMIC cases that the notes never mention.

        A suggested concept counts as mentioned when it appears (case-
        insensitively) as a substring of *doctor_notes*. Every returned gap
        also increments ``self.doctor_gaps[gap]``.

        Returns:
            The set of gap concepts; empty when *doctor_notes* is empty.
        """
        if not doctor_notes:
            return set()

        notes_lower = doctor_notes.lower()
        suggested = self._find_related_concepts(self._diagnosis_of(case, 'unknown'))
        gaps = {concept for concept in suggested if concept not in notes_lower}

        for gap in gaps:
            self.doctor_gaps[gap] += 1

        return gaps

    # ------------------------------------------------------------------
    # Agent creation
    # ------------------------------------------------------------------

    def create_agents(self, case: Dict, doctor_notes: str = "") -> List["Agent"]:
        """Create agents for the concepts *case* appears to be missing.

        Candidates are the Q-learning missing concepts plus doctor-note gaps,
        deduplicated case-insensitively and excluding concepts the patient
        already presents. Agents above ``agent_threshold`` are registered in
        ``self.agents`` and returned.
        """
        print("Creating dynamic agents...")

        patient_concepts = self._extract_concepts(case)
        present = set(patient_concepts)  # _extract_concepts already lowercases
        q_missing_raw = self.qlearning_system.discover_missing_concepts(patient_concepts) or []

        q_missing_unique = []
        seen = set()
        for item in q_missing_raw:
            key = str(item['concept']).lower()
            if key not in seen and key not in present:
                q_missing_unique.append(item)
                seen.add(key)

        print(f"Found {len(q_missing_unique)} unique missing concepts (filtered from {len(q_missing_raw)})")

        gaps = self.analyze_doctor_gaps(doctor_notes, case)

        candidates = [item['concept'] for item in q_missing_unique]
        # Sorted so group/agent naming does not depend on set iteration order.
        for gap in sorted(gaps):
            key = gap.lower()
            if key not in seen and key not in present:
                candidates.append(gap)
                seen.add(key)

        created = []
        for group_name, group_concepts in self._group_by_clusters(candidates).items():
            agent = self._create_agent(group_name, group_concepts, case, q_missing_unique)
            if agent.effectiveness > self.agent_threshold:
                created.append(agent)
                self.agents[agent.agent_id] = agent

        print(f"Created {len(created)} agents")
        return created

    def _extract_concepts(self, case: Dict) -> List[str]:
        """Return the case's symptoms and ``known_info`` as lowercased terms longer than 2 chars."""
        concepts = self._split_symptoms(case.get('symptoms'))
        concepts.extend(str(c).strip().lower() for c in case.get('known_info') or [])
        return [c for c in concepts if len(c) > 2]

    def _group_by_clusters(self, missing_concepts: List[str]) -> Dict[str, List[str]]:
        """Group concepts by ``self.concept_clusters``; leftovers become ``single_<i>`` groups.

        Duplicates in *missing_concepts* are collapsed (first occurrence
        wins), so no concept appears twice in a group or gets two single
        groups.
        """
        unique = list(dict.fromkeys(missing_concepts))

        groups = {}
        clustered = set()
        for cluster_name, cluster_concepts in self.concept_clusters.items():
            group = [concept for concept in unique if concept in cluster_concepts]
            if group:
                groups[cluster_name] = group
                clustered.update(group)

        singles = [concept for concept in unique if concept not in clustered]
        for i, concept in enumerate(singles):
            groups[f"single_{i}"] = [concept]

        return groups

    def _create_agent(self, group_name: str, concepts: List[str],
                      case: Dict, q_missing: List[Dict]) -> "Agent":
        """Build one ``Agent`` focused on *concepts*.

        Per concept, effectiveness is the mean of ``q_value / 100`` and
        ``mimic_frequency / 50``, each clamped to ``[0, 1]``. Concepts absent
        from *q_missing* (e.g. doctor-note gaps) get a flat 0.4 and a nominal
        connection strength of 10.0. The agent's effectiveness is the mean
        over its concepts.
        """
        # First entry per concept wins, matching a linear search over q_missing.
        q_by_concept = {}
        for item in q_missing:
            q_by_concept.setdefault(item['concept'], item)

        effectiveness_scores = []
        connections = {}
        for concept in concepts:
            concept_data = q_by_concept.get(concept)
            if concept_data is None:
                effectiveness_scores.append(0.4)
                connections[concept] = 10.0
                continue

            q_value = concept_data.get('q_value', 0.0)
            # Clamp at 0 too: negative Q-values must not produce negative scores.
            q_norm = min(max(q_value / 100.0, 0.0), 1.0)
            freq_norm = min(max(concept_data.get('mimic_frequency', 0) / 50.0, 0.0), 1.0)
            effectiveness_scores.append((q_norm + freq_norm) / 2.0)
            connections[concept] = q_value

        avg_effectiveness = float(np.mean(effectiveness_scores)) if effectiveness_scores else 0.4
        patterns = self._generate_qlearning_patterns(concepts, case)
        created_at = datetime.now()

        return Agent(
            # The sequence suffix keeps ids unique; HHMMSS alone collided.
            agent_id=f"agent_{group_name}_{created_at.strftime('%H%M%S')}_{next(self._agent_seq)}",
            focus_concepts=set(concepts),
            discovery_patterns=patterns,
            effectiveness=avg_effectiveness,
            context={
                'patient_id': case.get('patient_id', 'unknown'),
                'diagnosis': self._diagnosis_of(case, 'unknown'),
                'creation_time': created_at,
                'concept_count': len(concepts)
            },
            connections=connections,
            success_history=[]
        )

    def _generate_qlearning_patterns(self, concepts: List[str], case: Dict) -> List[str]:
        """Return one discovery-pattern sentence per concept.

        Known concepts whose strongest *other* Q-connection exceeds 10 get
        "Explore X in relation to Y", other known concepts get "Investigate X",
        and concepts outside the Q-matrix get "Assess X". ``case`` is accepted
        for interface compatibility and not used.
        """
        index_of = self.qlearning_system.concept_to_index
        q_matrix = None  # built lazily, once, only if some concept is indexed
        patterns = []

        for concept in concepts:
            if concept not in index_of:
                patterns.append(f"Assess {concept}")
                continue

            if q_matrix is None:
                q_matrix = np.asarray(self.qlearning_system.Q)
            target_idx, q_value = self._strongest_neighbor(q_matrix, index_of[concept])
            if q_value > 10:
                target_concept = self.qlearning_system.index_to_concept[target_idx]
                patterns.append(f"Explore {concept} in relation to {target_concept}")
            else:
                patterns.append(f"Investigate {concept}")

        return patterns

    # ------------------------------------------------------------------
    # Interaction
    # ------------------------------------------------------------------

    def get_recommendations(self, agent: "Agent", current_case: Dict) -> List[str]:
        """Return up to three of the agent's patterns, with "condition" replaced by the diagnosis.

        A missing or ``None`` diagnosis leaves the word "condition" in place.
        """
        diagnosis = self._diagnosis_of(current_case, 'condition')
        diagnosis = 'condition' if diagnosis is None else str(diagnosis)
        # NOTE(review): this is a plain substring replace, so a concept name that
        # itself contains "condition" is rewritten too; kept for compatibility.
        return [pattern.replace('condition', diagnosis) for pattern in agent.discovery_patterns[:3]]

    def generate_questions_for_agent(self, agent: "Agent", case: Dict) -> List[str]:
        """Generate up to three follow-up questions for the agent's focus concepts.

        Concepts are taken in sorted order so the selection is deterministic.
        A concept whose strongest *other* Q-connection exceeds 5 gets a
        question linking it to that concept; otherwise a generic question is
        asked. The questions are also stored on ``agent.current_questions``.
        """
        diagnosis = self._diagnosis_of(case, 'condition')
        index_of = self.qlearning_system.concept_to_index
        q_matrix = None
        questions = []

        for concept in sorted(agent.focus_concepts, key=str)[:3]:
            if concept not in index_of:
                questions.append(f"What additional details about {concept} should we consider?")
                continue

            if q_matrix is None:
                q_matrix = np.asarray(self.qlearning_system.Q)
            target_idx, q_value = self._strongest_neighbor(q_matrix, index_of[concept])
            if q_value > 5:
                connected_concept = self.qlearning_system.index_to_concept[target_idx]
                questions.append(self._create_qlearning_question(concept, connected_concept, diagnosis))
            else:
                questions.append(f"Can you provide more information about {concept} in this case?")

        agent.current_questions = questions
        return questions

    def _create_qlearning_question(self, source_concept: str, target_concept: str, diagnosis: str) -> str:
        """Phrase a question linking *source_concept* to *target_concept*.

        ``diagnosis`` is accepted for interface compatibility and not used in
        the current wording.
        """
        if source_concept == target_concept:
            return f"Any additional information about {source_concept}?"
        return f"Given the {source_concept}, have you evaluated for {target_concept}?"

    def process_agent_response(self, agent: "Agent", question: str, response: str) -> Dict:
        """Record a response on *agent* and update its effectiveness.

        Informative responses add 0.1 to effectiveness, capped at 1.0 to
        match the ``[0, 1]`` scale used when agents are created, and append
        1.0 to ``success_history``. Other responses append 0.5.

        Returns:
            Dict with the agent id, question, response, analysis and the
            updated effectiveness.
        """
        agent.responses_received[question] = response

        response_analysis = self._analyze_response(response, agent.focus_concepts)

        if response_analysis['new_info_found']:
            agent.effectiveness = min(1.0, agent.effectiveness + 0.1)
            agent.success_history.append(1.0)
        else:
            agent.success_history.append(0.5)

        agent.analysis_updated = True

        return {
            'agent_id': agent.agent_id,
            'question': question,
            'response': response,
            'analysis': response_analysis,
            'updated_effectiveness': agent.effectiveness
        }

    def _analyze_response(self, response: str, focus_concepts: Set[str]) -> Dict:
        """Classify a free-text response.

        Returns a dict with:
            new_info_found: any of yes/no/positive/negative as a whole word,
                or history/family/previous as a substring.
            concepts_mentioned: focus concepts (sorted) appearing in the text.
            response_type: 'positive' | 'negative' | 'uncertain' | 'unknown'.
            follow_up_needed: True only for 'uncertain'.
        """
        response_lower = (response or '').lower()

        analysis = {
            'new_info_found': bool(self._INFO_WORD_RE.search(response_lower))
            or any(marker in response_lower for marker in self._INFO_SUBSTRINGS),
            'concepts_mentioned': [
                concept for concept in sorted(focus_concepts, key=str)
                if str(concept).lower() in response_lower
            ],
            'response_type': 'unknown',
            'follow_up_needed': False
        }

        if self._POSITIVE_RE.search(response_lower):
            analysis['response_type'] = 'positive'
        elif self._NEGATIVE_RE.search(response_lower):
            analysis['response_type'] = 'negative'
        elif 'unknown' in response_lower or 'unsure' in response_lower:
            analysis['response_type'] = 'uncertain'
            analysis['follow_up_needed'] = True

        return analysis

    def get_agent_summary(self, agent: "Agent") -> str:
        """Return a multi-line text summary of the agent's responses so far.

        Positive/negative counts use the same classification as
        ``_analyze_response``.
        """
        if not agent.responses_received:
            return f"Agent {agent.agent_id}: No responses received yet"

        focus = ', '.join(str(c) for c in sorted(agent.focus_concepts, key=str)[:3])
        response_types = [
            self._analyze_response(response, ())['response_type']
            for response in agent.responses_received.values()
        ]

        summary_parts = [
            f"Agent {agent.agent_id} Summary:",
            f"Focus: {focus}",
            f"Questions asked: {len(agent.responses_received)}",
            f"Effectiveness: {agent.effectiveness:.2f}",
            f"Positive findings: {response_types.count('positive')}",
            f"Negative findings: {response_types.count('negative')}",
        ]
        return "\n".join(summary_parts)

def test_fixed_agents(seed: int = 0):
    """Run a small, self-contained smoke demo of ``AgentCreator``.

    Uses an in-function stub Q-learning system (fixed concepts, a seeded
    random Q-matrix and a canned ``discover_missing_concepts`` result), so
    the Q-values, and therefore the generated patterns and questions, are
    reproducible across runs. Agent ids still embed the wall-clock time.

    Args:
        seed: Seed for the stub's random Q-matrix.
    """
    print("Testing COMPLETE FIXED Dynamic Agent Creation")

    class MockQ:
        """Minimal stand-in exposing the attributes AgentCreator reads."""

        def __init__(self):
            self.medical_concepts = ['chest', 'pain', 'heart', 'family', 'history']
            n_concepts = len(self.medical_concepts)
            # Seeded generator: an unseeded np.random.rand changed the output every run.
            rng = np.random.default_rng(seed)
            self.Q = rng.random((n_concepts, n_concepts)) * 50
            self.concept_to_index = {concept: i for i, concept in enumerate(self.medical_concepts)}
            self.index_to_concept = {i: concept for concept, i in self.concept_to_index.items()}

        def discover_missing_concepts(self, concepts):
            return [
                {'concept': 'family', 'q_value': 50.0, 'mimic_frequency': 10},
                {'concept': 'history', 'q_value': 30.0, 'mimic_frequency': 8},
                {'concept': 'heart', 'q_value': 25.0, 'mimic_frequency': 12}
            ]

    mock_cases = [
        {
            'patient_id': 'TEST_001',
            'symptoms': 'chest pain, shortness of breath',
            'diagnoses': ['myocardial infarction'],
            'admission_info': {'diagnosis': 'acute MI'}
        }
    ]

    mock_q = MockQ()
    creator = AgentCreator(mock_q, mock_cases)

    test_case = {
        'patient_id': 'TEST',
        'symptoms': 'chest pain',
        'known_info': ['chest', 'pain'],
        'admission_info': {'diagnosis': 'chest pain'}
    }

    agents = creator.create_agents(test_case, "Patient has chest pain")

    print(f"Created {len(agents)} COMPLETE FIXED agents")
    for agent in agents:
        print(f"  Agent: {agent.agent_id}")
        print(f"  Focus: {agent.focus_concepts}")
        print(f"  Effectiveness: {agent.effectiveness:.3f}")

        questions = creator.generate_questions_for_agent(agent, test_case)
        print(f"  Questions: {questions}")

        recs = creator.get_recommendations(agent, test_case)
        print(f"  Recommendations: {recs}")


if __name__ == "__main__":
    # Script entry point: run the smoke demo above.
    test_fixed_agents()