import pandas as pd
import numpy as np
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Dict, List, Set, Optional, Tuple, Union
from enum import Enum
import json
import logging
from datetime import datetime

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class KnowledgeState(Enum):
   """Epistemic status recorded on a ``SemanticNode`` for its concept."""
   # Given to concepts that the case data lists as already known.
   KNOWN = "known"
   # Given to every other extracted concept; the graph registers a
   # MissingInformationEntity for nodes added in this state.
   UNKNOWN = "unknown"
   # Not assigned anywhere in this file; presumably meant for concepts derived
   # by reasoning rather than observed -- TODO confirm intended use.
   INFERRED = "inferred"
   # Not assigned anywhere in this file; presumably meant for concepts on which
   # sources contradict each other -- TODO confirm intended use.
   CONFLICTED = "conflicted"

@dataclass
class SemanticNode:
   """One concept in the temporal-causal graph, with its knowledge status.

   The three collection fields may be omitted or passed as ``None``; either
   way ``__post_init__`` gives the instance its own empty container, so
   instances never share mutable state.

   Attributes:
      concept: Name of the concept; the graph uses it as the node key.
      domain: Free-form domain label (this file only ever passes 'medical').
      knowledge_state: A ``KnowledgeState`` member.
      confidence: Confidence in the concept; the graph derives the impact
         level of a gap as ``1.0 - confidence``.
      temporal_position: Position used to bucket and order concepts in time.
      causal_links: Target concept -> causal strength.
      expert_sources: Ids of the experts that contributed causal links.
      missing_info_details: Free-form notes about what is missing; not read
         anywhere in this file.
   """
   concept: str
   domain: str
   knowledge_state: "KnowledgeState"
   confidence: float
   temporal_position: int
   causal_links: Optional[Dict[str, float]] = None
   expert_sources: Optional[List[str]] = None
   missing_info_details: Optional[Dict[str, str]] = None

   def __post_init__(self):
       """Replace omitted/``None`` collections with fresh empty containers."""
       if self.causal_links is None:
           self.causal_links = {}
       if self.expert_sources is None:
           self.expert_sources = []
       if self.missing_info_details is None:
           self.missing_info_details = {}

@dataclass
class ExpertKnowledge:
   """What a single expert knows, does not know, and how they reason.

   The four collection fields after ``domain_expertise`` may be omitted or
   passed as ``None``; ``__post_init__`` then gives the instance its own
   empty container.

   Attributes:
      expert_id: Identifier used as the key when the expert is registered.
      domain_expertise: Domain name -> expertise score.
      known_concepts: Concepts the expert has knowledge of.
      unknown_concepts: Concepts the expert declared a gap for.
      reasoning_patterns: Concept -> list of reasoning chains, where each
         chain is itself a list of step names (newest chain last).
      temporal_understanding: Concept -> integer; not read anywhere in this
         file -- presumably a temporal position, TODO confirm.
   """
   expert_id: str
   domain_expertise: Dict[str, float]
   known_concepts: Optional[Set[str]] = None
   unknown_concepts: Optional[Set[str]] = None
   reasoning_patterns: Optional[Dict[str, List[List[str]]]] = None
   temporal_understanding: Optional[Dict[str, int]] = None

   def __post_init__(self):
       """Replace omitted/``None`` collections with fresh empty containers."""
       if self.known_concepts is None:
           self.known_concepts = set()
       if self.unknown_concepts is None:
           self.unknown_concepts = set()
       if self.reasoning_patterns is None:
           self.reasoning_patterns = {}
       if self.temporal_understanding is None:
           self.temporal_understanding = {}

class MissingInformationEntity:
   """A knowledge gap about one concept, plus bookkeeping to prioritise it.

   Attributes set at construction:
      concept, missing_type, impact_level: stored as given.
      related_concepts: list of ``(concept, relationship)`` tuples.
      potential_questions: starts empty; nothing in this class fills it.
      discovery_priority: starts at 0.0.
   """

   def __init__(self, concept: str, missing_type: str, impact_level: float):
       self.concept, self.missing_type, self.impact_level = (
           concept, missing_type, impact_level
       )
       self.related_concepts, self.potential_questions = [], []
       self.discovery_priority = 0.0

   def add_related_concept(self, concept: str, relationship: str):
       """Record that ``concept`` is tied to this gap by ``relationship``."""
       link = (concept, relationship)
       self.related_concepts.append(link)

   def generate_question(self, context: Dict) -> str:
       """Return a question template chosen by ``missing_type``.

       ``context`` is accepted but not consulted.
       """
       gap_kind = self.missing_type
       if gap_kind == 'expert_gap':
           return f"What is the relationship between {self.concept} and the current symptoms?"
       if gap_kind == 'temporal_gap':
           return f"What typically happens before/after {self.concept}?"
       # Any other gap type falls back to the generic wording.
       return f"What additional information about {self.concept} would be helpful?"

class TemporalCausalGraph:
   """Concept graph holding causal edges, temporal buckets and knowledge gaps.

   Attributes:
      nodes: concept name -> node object (``SemanticNode``).
      causal_edges: source concept -> {target concept: strength}.
      temporal_order: temporal position -> concepts registered at it.
      missing_entities: concept name -> ``MissingInformationEntity`` for
         every node currently registered in the UNKNOWN state.
   """

   # Substrings that raise the discovery priority of a missing concept.
   _PRIORITY_TERMS = ('history', 'family', 'allergy', 'medication', 'symptom')

   def __init__(self):
       self.nodes = {}
       self.causal_edges = defaultdict(dict)
       self.temporal_order = defaultdict(list)
       self.missing_entities = {}

   def add_semantic_node(self, node: "SemanticNode"):
       """Register ``node`` under its concept name, replacing any earlier one.

       When the concept was registered before, its old temporal bucket entry
       and any gap record are dropped first, so the temporal index never
       lists a concept twice and a concept that is no longer UNKNOWN stops
       being reported as missing.
       """
       previous = self.nodes.get(node.concept)
       if previous is not None:
           old_bucket = self.temporal_order.get(previous.temporal_position)
           if old_bucket and node.concept in old_bucket:
               old_bucket.remove(node.concept)
               if not old_bucket:
                   del self.temporal_order[previous.temporal_position]
           self.missing_entities.pop(node.concept, None)

       self.nodes[node.concept] = node
       self.temporal_order[node.temporal_position].append(node.concept)

       if node.knowledge_state == KnowledgeState.UNKNOWN:
           # The less confident we are, the more the gap matters.
           self.missing_entities[node.concept] = MissingInformationEntity(
               node.concept,
               'expert_gap',
               1.0 - node.confidence
           )

   def add_causal_relationship(self, from_concept: str, to_concept: str,
                             strength: float, expert_source: str):
       """Record a directed causal edge and attribute it to ``expert_source``.

       The edge is always stored in ``causal_edges``; the source node's own
       ``causal_links`` / ``expert_sources`` are updated only when that node
       is registered.
       """
       self.causal_edges[from_concept][to_concept] = strength

       source_node = self.nodes.get(from_concept)
       if source_node is None:
           return
       source_node.causal_links[to_concept] = strength
       if expert_source not in source_node.expert_sources:
           source_node.expert_sources.append(expert_source)

   def get_temporal_path(self, start_concept: str, end_concept: str) -> List[str]:
       """Breadth-first search for a causal path that never goes back in time.

       Returns:
          The concepts from ``start_concept`` to ``end_concept`` inclusive
          along the first path found, or ``[]`` when either concept is not
          registered or no such path exists. The graph is not modified.
       """
       if start_concept not in self.nodes or end_concept not in self.nodes:
           return []

       queue = deque([(start_concept, [start_concept])])
       # Marking on enqueue keeps each concept in the queue at most once.
       visited = {start_concept}

       while queue:
           current, path = queue.popleft()
           if current == end_concept:
               return path

           current_time = self.nodes[current].temporal_position
           # .get() instead of indexing: indexing the defaultdict would insert
           # an empty edge map for every leaf the search touches.
           for next_concept in self.causal_edges.get(current, {}):
               next_node = self.nodes.get(next_concept)
               if next_concept in visited or next_node is None:
                   continue
               if next_node.temporal_position < current_time:
                   continue
               visited.add(next_concept)
               queue.append((next_concept, path + [next_concept]))

       return []

   def identify_missing_critical_info(self) -> List["MissingInformationEntity"]:
       """Score every gap and return the gaps ordered by descending priority.

       Side effect: overwrites ``discovery_priority`` on each entity in
       ``missing_entities``. The score is
       ``0.4 * outgoing causal edges + 0.3 * temporal buckets containing the
       concept + 0.3 if the name contains a priority term + 0.2``.
       """
       critical_missing = []

       for concept, missing_entity in self.missing_entities.items():
           dependent_concepts = len(self.causal_edges.get(concept, {}))
           temporal_centrality = sum(
               1 for bucket in self.temporal_order.values() if concept in bucket
           )

           lowered = concept.lower()
           medical_priority_boost = (
               0.3 if any(term in lowered for term in self._PRIORITY_TERMS) else 0.0
           )

           missing_entity.discovery_priority = (dependent_concepts * 0.4 +
                                                temporal_centrality * 0.3 +
                                                medical_priority_boost + 0.2)

           # With the 0.2 baseline every score clears this threshold, so the
           # filter currently lets all gaps through.
           if missing_entity.discovery_priority > 0.1:
               critical_missing.append(missing_entity)

       return sorted(critical_missing, key=lambda x: x.discovery_priority, reverse=True)

class MultiExpertReasoning:
   """Registry of experts and a simple consensus calculation over them.

   Attributes:
      experts: expert id -> expert object (``ExpertKnowledge``).
      expert_agreements / expert_disagreements: created empty here and not
         written by any method of this class.
   """

   def __init__(self):
       self.experts = {}
       self.expert_agreements = defaultdict(dict)
       self.expert_disagreements = defaultdict(list)

   def add_expert(self, expert: "ExpertKnowledge"):
       """Register ``expert`` under its ``expert_id`` (replacing any previous)."""
       self.experts[expert.expert_id] = expert

   def learn_expert_pattern(self, expert_id: str, concept: str,
                          reasoning_chain: List[str], outcome: str):
       """Append a reasoning chain for ``concept`` to a registered expert.

       The concept is also moved into the expert's known set. Calls for an
       unregistered ``expert_id`` are ignored. ``outcome`` is accepted but
       not used.
       """
       expert = self.experts.get(expert_id)
       if expert is None:
           return

       expert.reasoning_patterns.setdefault(concept, []).append(reasoning_chain)
       expert.known_concepts.add(concept)
       expert.unknown_concepts.discard(concept)

   def get_expert_consensus(self, concept: str) -> Dict[str, Union[str, float, List]]:
       """Summarise what the registered experts think about ``concept``.

       Each expert that knows the concept contributes its most recent
       reasoning chain (list chains are rendered as ``'a -> b -> c'``).

       Returns:
          A dict with
          - ``agreement_level``: 0.0 with no opinion, 0.8 with exactly one,
            1.0 when several experts all agree, 0.5 otherwise;
          - ``majority_opinion``: most frequent opinion (first seen wins
            ties) or ``None``;
          - ``disagreements``: the distinct opinions in first-seen order;
          - ``missing_expertise``: ids of experts that list the concept as
            unknown.
       """
       expert_opinions = []
       experts_missing_knowledge = []

       for expert_id, expert in self.experts.items():
           if concept in expert.known_concepts:
               chains = expert.reasoning_patterns.get(concept)
               # An expert may know a concept without having recorded any
               # chain for it (missing key or empty list).
               if chains:
                   opinion = chains[-1]
                   if isinstance(opinion, list):
                       opinion = ' -> '.join(map(str, opinion))
                   expert_opinions.append(opinion)
           elif concept in expert.unknown_concepts:
               experts_missing_knowledge.append(expert_id)

       # Insertion-ordered tally: keeps the result independent of string
       # hash randomisation.
       opinion_counts = {}
       for opinion in expert_opinions:
           opinion_counts[opinion] = opinion_counts.get(opinion, 0) + 1

       agreement_level = 0.0
       majority_opinion = None
       if len(expert_opinions) == 1:
           agreement_level = 0.8
           majority_opinion = expert_opinions[0]
       elif expert_opinions:
           agreement_level = 1.0 if len(opinion_counts) == 1 else 0.5
           majority_opinion = max(opinion_counts, key=opinion_counts.get)

       return {
           'agreement_level': agreement_level,
           'majority_opinion': majority_opinion,
           'disagreements': list(opinion_counts),
           'missing_expertise': experts_missing_knowledge
       }

class AdaptiveQuestionGeneration:
   """Turns the highest-priority knowledge gap into a follow-up question.

   Attributes:
      causal_graph: graph queried for prioritised gaps.
      expert_reasoning: expert registry handed in by the caller; kept for
         reference and not consulted by the methods below.
      question_history: one dict per generated question (question text,
         target concept, context snapshot, timestamp).
      effectiveness_scores: question text -> last reported improvement.
   """

   def __init__(self, causal_graph: "TemporalCausalGraph",
                expert_reasoning: "MultiExpertReasoning"):
       self.causal_graph = causal_graph
       self.expert_reasoning = expert_reasoning
       self.question_history = []
       self.effectiveness_scores = {}

   def generate_targeted_question(self, context: Dict[str, str]) -> Optional[str]:
       """Ask about the top-priority gap, or return ``None`` when none exist.

       Every generated question is appended to ``question_history`` together
       with a shallow copy of ``context``, so later mutation of the caller's
       dict cannot rewrite the history.
       """
       critical_missing = self.causal_graph.identify_missing_critical_info()
       if not critical_missing:
           return None

       target_missing = critical_missing[0]
       question = self._create_contextual_question(target_missing, context)

       self.question_history.append({
           'question': question,
           'target_concept': target_missing.concept,
           'context': dict(context),
           'timestamp': datetime.now()
       })

       return question

   def _create_contextual_question(self, missing_entity: "MissingInformationEntity",
                                  context: Dict[str, str]) -> str:
       """Phrase a question about ``missing_entity`` using available context.

       Preference order: non-empty ``patient_symptoms``, then non-empty
       ``current_diagnosis``, then a generic fallback. A key that is present
       but empty is treated as absent, so an empty value never ends up
       interpolated into the question.
       """
       concept = missing_entity.concept

       symptoms = context.get('patient_symptoms')
       if symptoms:
           return f"Given the symptoms {symptoms}, what is the relationship to {concept}?"

       diagnosis = context.get('current_diagnosis')
       if diagnosis:
           return f"How does {concept} relate to the current diagnosis of {diagnosis}?"

       return f"What additional information about {concept} would help understand this case?"

   def evaluate_question_effectiveness(self, question: str, answer: str,
                                     outcome_improvement: float):
       """Store the improvement achieved by ``question`` and log the extremes.

       Scores above 0.7 are logged as effective and scores below 0.3 as
       ineffective; ``answer`` is accepted but not used.
       """
       self.effectiveness_scores[question] = outcome_improvement

       if outcome_improvement > 0.7:
           logger.info(f"Effective question pattern: {question}")
       elif outcome_improvement < 0.3:
           logger.info(f"Ineffective question pattern: {question}")

class AgenticAISystem:
   """Facade that turns case data plus expert input into a reasoning report.

   It owns a ``TemporalCausalGraph``, a ``MultiExpertReasoning`` registry and
   an ``AdaptiveQuestionGeneration`` wired to both. State accumulates across
   calls to ``process_medical_case``.
   """

   # Case fields scanned for concepts; concepts are emitted in this order.
   _CONCEPT_FIELDS = ('symptoms', 'presenting_issues', 'issues', 'diagnosis',
                      'assessment_results', 'current_status', 'procedures',
                      'interventions_tried', 'known_info', 'known_information')
   # Fields whose concepts are registered as already known.
   _KNOWN_FIELDS = ('known_info', 'known_information')
   # Tokens dropped during extraction (tokens of <= 2 chars are dropped too).
   _STOP_WORDS = frozenset(['and', 'the', 'with', 'for', 'to', 'of', 'in', 'on'])

   def __init__(self):
       self.causal_graph = TemporalCausalGraph()
       self.expert_reasoning = MultiExpertReasoning()
       self.question_generator = AdaptiveQuestionGeneration(
           self.causal_graph, self.expert_reasoning
       )
       self.learning_history = []

   def process_medical_case(self, patient_data: Dict, expert_inputs: List[Dict]) -> Dict:
       """Ingest one case and return the reasoning report for it.

       Args:
          patient_data: Case fields; strings or lists under the keys in
             ``_CONCEPT_FIELDS`` are mined for concepts.
          expert_inputs: One dict per expert; ``expert_id`` is required,
             ``domain_expertise``, ``known_concepts``, ``unknown_concepts``,
             ``reasoning_chain`` and ``outcome`` are optional.

       Returns:
          Dict with keys ``reasoning`` (text), ``missing_critical_info``
          (concept names by priority), ``suggested_question``,
          ``expert_consensus`` and ``temporal_reasoning_chain``.

       Raises:
          KeyError: if an expert input has no ``expert_id``.
       """
       self._create_semantic_nodes_from_data(patient_data)

       self._incorporate_expert_knowledge(expert_inputs)

       critical_missing = self.causal_graph.identify_missing_critical_info()

       context = self._create_adaptive_context(patient_data)

       suggested_question = self.question_generator.generate_targeted_question(context)

       reasoning_output = self._generate_reasoning_output(patient_data, critical_missing)

       return {
           'reasoning': reasoning_output,
           'missing_critical_info': [entity.concept for entity in critical_missing],
           'suggested_question': suggested_question,
           'expert_consensus': self._get_overall_expert_consensus(),
           'temporal_reasoning_chain': self._get_temporal_chain(patient_data)
       }

   def _create_adaptive_context(self, patient_data: Dict) -> Dict[str, str]:
       """Build the question context for the kind of case the data describes.

       The first matching field family wins: medical (symptoms / diagnosis /
       procedures), then assessment, then project status. Only keys with a
       non-empty value are included, so consumers that test for key presence
       never see a blank placeholder.
       """
       if any(field in patient_data for field in ['symptoms', 'diagnosis', 'procedures']):
           candidates = {
               'patient_symptoms': patient_data.get('symptoms', ''),
               'current_diagnosis': patient_data.get('diagnosis', ''),
           }
       elif any(field in patient_data for field in ['presenting_issues', 'assessment_results']):
           candidates = {
               'student_issues': patient_data.get('presenting_issues', ''),
               'current_assessment': patient_data.get('assessment_results', ''),
           }
       elif any(field in patient_data for field in ['current_status', 'interventions_tried']):
           candidates = {
               'project_issues': patient_data.get('presenting_issues', patient_data.get('issues', '')),
               'current_status': patient_data.get('current_status', ''),
           }
       else:
           candidates = {}

       return {key: value for key, value in candidates.items() if value}

   def _create_semantic_nodes_from_data(self, patient_data: Dict):
       """Add a node for every extracted concept not yet in the graph.

       Concepts that also appear in the known-information fields are added
       as KNOWN with confidence 0.9, all others as UNKNOWN with confidence
       0.1. The known set goes through the same normalisation as the
       concepts themselves, so the comparison is case-insensitive.
       """
       concepts = self._extract_concepts_from_data(patient_data)
       known_source = {field: patient_data[field]
                       for field in self._KNOWN_FIELDS if field in patient_data}
       known_concepts = set(self._extract_concepts_from_data(known_source))

       for position, concept in enumerate(concepts):
           if concept in self.causal_graph.nodes:
               continue

           if concept in known_concepts:
               knowledge_state, confidence = KnowledgeState.KNOWN, 0.9
           else:
               knowledge_state, confidence = KnowledgeState.UNKNOWN, 0.1

           self.causal_graph.add_semantic_node(SemanticNode(
               concept=concept,
               domain='medical',
               knowledge_state=knowledge_state,
               confidence=confidence,
               temporal_position=position,
               causal_links={},
               expert_sources=[],
               missing_info_details={}
           ))

   def _extract_concepts_from_data(self, patient_data: Dict) -> List[str]:
       """Return the distinct, lower-cased concepts found in ``patient_data``.

       String fields are split into words (``,`` ``;`` ``/`` count as
       whitespace); list fields contribute one concept per item. The result
       keeps first-occurrence order, which makes the temporal positions
       derived from it reproducible between runs.
       """
       # NOTE(review): free-text fields are tokenised per word, so a phrase
       # such as 'chest pain' only survives as one concept when it is given
       # as a list item -- confirm this is intended.
       raw_concepts = []
       for field in self._CONCEPT_FIELDS:
           if field not in patient_data:
               continue
           value = patient_data[field]
           if isinstance(value, str):
               cleaned = value.replace(',', ' ').replace(';', ' ').replace('/', ' ')
               raw_concepts.extend(cleaned.split())
           elif isinstance(value, list):
               raw_concepts.extend(str(item) for item in value)

       normalised = (concept.strip().lower() for concept in raw_concepts)
       kept = [concept for concept in normalised
               if len(concept) > 2 and concept not in self._STOP_WORDS]

       # dict.fromkeys de-duplicates while preserving insertion order.
       return list(dict.fromkeys(kept))

   def _incorporate_expert_knowledge(self, expert_inputs: List[Dict]):
       """Register unseen experts and record the reasoning chains they sent.

       Declared known/unknown concepts are only taken from the input that
       first introduces an expert; reasoning chains are learned every time.
       """
       for expert_input in expert_inputs:
           expert_id = expert_input['expert_id']

           if expert_id not in self.expert_reasoning.experts:
               expert = ExpertKnowledge(
                   expert_id=expert_id,
                   domain_expertise=expert_input.get('domain_expertise', {}),
                   known_concepts=set(expert_input.get('known_concepts', [])),
                   unknown_concepts=set(expert_input.get('unknown_concepts', [])),
                   reasoning_patterns={},
                   temporal_understanding={}
               )
               self.expert_reasoning.add_expert(expert)

           outcome = expert_input.get('outcome', '')
           for concept, chain in expert_input.get('reasoning_chain', {}).items():
               self.expert_reasoning.learn_expert_pattern(
                   expert_id, concept, chain, outcome
               )

   def _generate_reasoning_output(self, patient_data: Dict,
                                critical_missing: List["MissingInformationEntity"]) -> str:
       """Render the four-section, human-readable reasoning summary."""
       reasoning_parts = []

       reasoning_parts.append("Current Understanding:")
       known_concepts = [node.concept for node in self.causal_graph.nodes.values()
                        if node.knowledge_state == KnowledgeState.KNOWN]
       reasoning_parts.append(f"- Known information: {', '.join(known_concepts)}")

       reasoning_parts.append("\nCritical Missing Information:")
       # Only the three highest-priority gaps are listed.
       for missing in critical_missing[:3]:
           reasoning_parts.append(f"- {missing.concept} (Priority: {missing.discovery_priority:.2f})")

       reasoning_parts.append("\nExpert Perspective:")
       consensus = self._get_overall_expert_consensus()
       if consensus['agreements']:
           reasoning_parts.append(f"- Experts agree on: {', '.join(consensus['agreements'])}")
       if consensus['disagreements']:
           reasoning_parts.append(f"- Experts disagree on: {', '.join(consensus['disagreements'])}")

       reasoning_parts.append("\nTemporal Analysis:")
       temporal_chain = self._get_temporal_chain(patient_data)
       reasoning_parts.append(f"- Likely progression: {' → '.join(temporal_chain)}")

       return '\n'.join(reasoning_parts)

   def _get_overall_expert_consensus(self) -> Dict:
       """Split the graph's concepts into agreed and disputed ones.

       A concept is agreed when its agreement level exceeds 0.8 and disputed
       when the experts hold more than one distinct opinion on it; concepts
       matching neither rule are left out.
       """
       agreements = []
       disagreements = []

       for concept in self.causal_graph.nodes:
           consensus = self.expert_reasoning.get_expert_consensus(concept)
           if consensus['agreement_level'] > 0.8:
               agreements.append(concept)
           elif len(consensus['disagreements']) > 1:
               disagreements.append(concept)

       return {
           'agreements': agreements,
           'disagreements': disagreements
       }

   def _get_temporal_chain(self, patient_data: Dict) -> List[str]:
       """Return all graph concepts ordered by temporal position.

       The sort is stable, so concepts sharing a position keep their
       insertion order. ``patient_data`` is accepted but not used.
       """
       nodes = self.causal_graph.nodes
       return sorted(nodes, key=lambda concept: nodes[concept].temporal_position)

if __name__ == "__main__":
   # Demo run: one cardiac case reviewed by two experts, report printed to stdout.
   demo_case = {
       'symptoms': 'chest pain, shortness of breath, fatigue',
       'diagnosis': 'myocardial infarction',
       'procedures': ['ECG', 'blood_test'],
       'known_info': ['chest pain', 'ECG'],
   }

   # Specialist: knows the diagnosis and supplies chains for two concepts.
   cardiologist = {
       'expert_id': 'cardiologist_1',
       'domain_expertise': {'cardiology': 0.9, 'general_medicine': 0.7},
       'known_concepts': ['chest pain', 'myocardial infarction', 'ECG'],
       'unknown_concepts': ['specific_enzyme_levels'],
       'reasoning_chain': {
           'chest pain': ['assess_location', 'check_duration', 'evaluate_triggers'],
           'myocardial infarction': ['confirm_with_enzymes', 'assess_damage', 'plan_treatment'],
       },
   }

   # Generalist: lists the diagnosis as a gap and reasons about chest pain only.
   general_physician = {
       'expert_id': 'general_physician_1',
       'domain_expertise': {'general_medicine': 0.8, 'cardiology': 0.4},
       'known_concepts': ['chest pain', 'shortness of breath'],
       'unknown_concepts': ['myocardial infarction', 'specific_cardiac_procedures'],
       'reasoning_chain': {
           'chest pain': ['rule_out_cardiac', 'consider_other_causes'],
       },
   }

   report = AgenticAISystem().process_medical_case(
       demo_case, [cardiologist, general_physician]
   )

   print("=== AGENTIC AI REASONING OUTPUT ===")
   print(report['reasoning'])
   print(f"\nSuggested Question: {report['suggested_question']}")
   print(f"Missing Critical Info: {report['missing_critical_info']}")
   print(f"Temporal Chain: {report['temporal_reasoning_chain']}")