# hybrid_neural_system.py 
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from typing import Dict, List, Tuple, Optional
import logging
from collections import defaultdict
from dataclasses import dataclass


from core_system import AgenticAISystem, SemanticNode, KnowledgeState, ExpertKnowledge


from agentic_layer import DynamicAgentCreationLayer, AgenticLoss, AgentOutputProcessor, PatientDocumentProcessor

logger = logging.getLogger(__name__)

@dataclass
class TrainingExample:
    """A single supervised sample used by ``NeuralTrainer.train_epoch``.

    ``input_features`` are mapped to concept ids for the model input,
    ``target_diagnosis`` supplies the diagnosis (and temporal) class target,
    and ``missing_info`` becomes the multi-hot missing-information target.
    ``expert_reasoning`` and ``temporal_sequence`` are carried along but are
    not read by the trainer in this module.
    """

    # Identifier of the patient case this sample was derived from.
    patient_id: str
    # Concept strings fed to the model (looked up in the concept vocabulary).
    input_features: List[str]
    # Concept string used as the diagnosis classification target.
    target_diagnosis: str
    # Concept strings that should be flagged as missing information.
    missing_info: List[str]
    # Presumably concept -> list of expert reasoning steps; not consumed here (TODO confirm).
    expert_reasoning: Dict[str, List[str]]
    # Ordered clinical events for the case; not consumed by the trainer here.
    temporal_sequence: List[str]

class MedicalConceptEmbedding(nn.Module):
    """Concept-id encoder with diagnosis (D), missing-info (M), temporal (T) and agentic (A) heads.

    Pipeline used by :meth:`forward`:
      1. ``nn.Embedding`` over concept ids (id 0 is the padding index),
      2. a 3-layer MLP encoder (embedding_dim -> hidden_dim -> hidden_dim -> embedding_dim),
      3. 16-head self-attention over the encoded concepts,
      4. a ``DynamicAgentCreationLayer`` fed with the attended features,
      5. a 2-layer bidirectional LSTM over the attended features,
      6. classification heads applied to (masked) mean-pooled features.

    ``embedding_dim`` must be divisible by 16 (attention heads) and even (the
    BiLSTM uses ``embedding_dim // 2`` units per direction and its concatenated
    output must equal ``embedding_dim`` for the temporal head).
    """

    def __init__(self, vocab_size: int = 4005, embedding_dim: int = 256, hidden_dim: int = 512):
        super(MedicalConceptEmbedding, self).__init__()

        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim

        # Layer 1: padding_idx=0 keeps the <PAD> embedding at zero and out of the gradient.
        print(f"  Layer 1: Embedding Layer ({vocab_size:,} → {embedding_dim})")
        self.concept_embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)

        # Layer 2: position-wise MLP; maps back to embedding_dim so attention sees the same width.
        print(f"  Layer 2: Multi-Layer Encoder ({embedding_dim} → {hidden_dim} → {embedding_dim})")
        self.concept_encoder = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.3),

            nn.Linear(hidden_dim, embedding_dim)
        )

        # Layer 3: self-attention; embedding_dim must be divisible by num_heads.
        print(f" Layer 3: Multi-Head Attention (16 heads - Paper specification)")
        self.missing_info_attention = nn.MultiheadAttention(
            embed_dim=embedding_dim,
            num_heads=16,
            batch_first=True
        )

        # Layer 4: agentic layer (project component; its outputs are read by key in forward()).
        print(f" Layer 4: Dynamic Agentic Layer (Variable |M| agents)")
        self.agentic_layer = DynamicAgentCreationLayer(
            attention_dim=embedding_dim,
            max_agents=50,
            agent_embedding_dim=128,
            concept_vocab_size=vocab_size
        )

        # Layer 5: BiLSTM; two directions of embedding_dim // 2 concatenate back to embedding_dim.
        print(f"   ⏰ Layer 5: Bidirectional LSTM (Paper: BiLSTM(256, 256, layers=2))")
        self.temporal_encoder = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=embedding_dim // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.2
        )

        print(f"  Output Layers: D (Diagnosis), M (Missing), T (Temporal), A (Agentic)")

        # D head: raw logits over the concept vocabulary (softmax/CE applied by callers).
        self.diagnosis_classifier = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim // 2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, vocab_size)
        )

        # M head: ends in Sigmoid, so it outputs per-concept probabilities, not logits.
        self.missing_info_predictor = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, vocab_size),
            nn.Sigmoid()
        )

        # T head: raw logits over the concept vocabulary from pooled BiLSTM features.
        self.temporal_classifier = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim // 2),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim // 2, vocab_size)
        )

    def _count_parameters(self):
        """Return the number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def _model_device(self) -> torch.device:
        """Device of the model's parameters (CPU when the module has none)."""
        first_param = next(self.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    @staticmethod
    def _masked_mean(features: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
        """Average ``features`` of shape [batch, seq, dim] over ``seq``.

        When ``mask`` ([batch, seq], truthy = real token) is given, only masked-in
        positions are averaged. A row with no masked-in position falls back to a
        plain mean over all positions instead of dividing by zero.
        """
        if mask is None:
            return features.mean(dim=1)
        valid = mask.to(device=features.device, dtype=torch.bool)
        # Rows that are entirely padding: treat every position as valid.
        valid = valid | ~valid.any(dim=1, keepdim=True)
        weights = valid.to(features.dtype).unsqueeze(-1)
        return (features * weights).sum(dim=1) / weights.sum(dim=1)

    def forward(self, concept_ids: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Run every head on a batch of concept-id sequences.

        Args:
            concept_ids: integer tensor of shape [batch, seq_len]; id 0 is padding.
            mask: optional [batch, seq_len] tensor, truthy at real (non-padding)
                positions. Padding positions are excluded from the attention keys
                and from the mean pooling that feeds the D/M/T heads.

        Returns:
            dict with the D/M/T head outputs, agentic-layer outputs and
            intermediate features (see keys below).

        Raises:
            ValueError: if ``concept_ids`` is not 2-D.
        """
        if concept_ids.dim() != 2:
            raise ValueError(
                f"concept_ids must be 2-D [batch, seq_len], got shape {tuple(concept_ids.shape)}"
            )

        embeddings = self.concept_embedding(concept_ids)         # [batch, seq_len, emb]
        encoded_concepts = self.concept_encoder(embeddings)      # [batch, seq_len, emb]

        valid_mask = None
        if mask is not None:
            # Normalise to bool so `~` is a logical (not bitwise-integer) negation.
            valid_mask = mask.to(device=concept_ids.device, dtype=torch.bool)
            key_padding_mask = ~valid_mask
            # If a row masks every key, the attention softmax yields NaN; let such rows attend everywhere.
            empty_rows = ~valid_mask.any(dim=1)
            if empty_rows.any():
                key_padding_mask = key_padding_mask.masked_fill(empty_rows.unsqueeze(1), False)
            attended_features, attention_weights = self.missing_info_attention(
                encoded_concepts, encoded_concepts, encoded_concepts,
                key_padding_mask=key_padding_mask
            )
        else:
            attended_features, attention_weights = self.missing_info_attention(
                encoded_concepts, encoded_concepts, encoded_concepts
            )

        pooled_attention = self._masked_mean(attended_features, valid_mask)   # [batch, emb]

        # The agentic layer receives the caller's mask unchanged.
        agentic_outputs = self.agentic_layer(
            attention_features=attended_features,
            missing_info_mask=mask
        )

        # NOTE(review): the LSTM still runs over padded steps; only the pooling ignores them.
        temporal_features, _ = self.temporal_encoder(attended_features)
        pooled_temporal = self._masked_mean(temporal_features, valid_mask)   # [batch, emb]

        diagnosis_logits = self.diagnosis_classifier(pooled_attention)
        missing_info_logits = self.missing_info_predictor(pooled_attention)  # probabilities (Sigmoid)
        temporal_logits = self.temporal_classifier(pooled_temporal)

        return {
            'diagnosis_logits': diagnosis_logits,           # D output
            'missing_info_logits': missing_info_logits,     # M output
            'temporal_logits': temporal_logits,             # T output

            'agentic_question_priorities': agentic_outputs['question_priorities'],
            'agentic_missing_predictions': agentic_outputs['missing_info_predictions'],
            'agentic_full_output': agentic_outputs,

            'concept_embeddings': encoded_concepts,
            'attention_weights': attention_weights,
            'temporal_features': temporal_features,
            'pooled_features': pooled_attention,

            'num_active_agents': agentic_outputs['num_active_agents'],
            'agent_representations': agentic_outputs['agent_representations'],
            'agent_effectiveness_scores': agentic_outputs['agent_effectiveness']
        }

    def get_concept_similarity(self, concept1_id: int, concept2_id: int) -> float:
        """Return the cosine similarity between two concepts' raw embedding vectors."""
        device = self._model_device()
        with torch.no_grad():
            emb1 = self.concept_embedding(torch.tensor([concept1_id], device=device))
            emb2 = self.concept_embedding(torch.tensor([concept2_id], device=device))
            similarity = torch.cosine_similarity(emb1, emb2, dim=1)
            return similarity.item()

    def get_agentic_analysis(self, patient_data: Dict, concepts: Optional[List[str]] = None) -> Dict:
        """Run one inference pass and summarise the agentic-layer outputs.

        Args:
            patient_data: patient case; forwarded to :meth:`_process_agentic_outputs`.
            concepts: concept strings to encode. When omitted, an attached
                ``_extract_concepts_from_data(patient_data)`` callable is used if
                present; otherwise ``{'available': False}`` is returned.

        Returns:
            summary dict from :meth:`_process_agentic_outputs`, or
            ``{'available': False}`` when no concepts can be obtained.
        """
        if concepts is None:
            # This module defines no extractor itself; use one only if a caller attached it.
            extractor = getattr(self, '_extract_concepts_from_data', None)
            if extractor is None:
                return {'available': False}
            concepts = extractor(patient_data)

        prepared = self._prepare_concept_tensor(concepts)

        # Evaluate without dropout, then restore whatever mode the caller had set.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                outputs = self.forward(prepared['tensor'], prepared['mask'])
        finally:
            self.train(was_training)

        return self._process_agentic_outputs(outputs, patient_data)

    def _prepare_concept_tensor(self, concepts: List[str], max_len: int = 50) -> Dict:
        """Encode concept strings as a padded [1, max_len] id tensor plus a bool mask.

        Concepts are looked up in ``self.concept_vocab`` when that attribute has
        been set; unknown concepts (or a missing vocabulary) map to the
        ``<UNK>`` id (1 by default). Sequences are truncated/padded with 0 to
        ``max_len``.
        """
        vocab = getattr(self, 'concept_vocab', None) or {}
        unk_id = vocab.get('<UNK>', 1)

        concept_ids = [vocab.get(concept, unk_id) for concept in concepts][:max_len]
        concept_ids += [0] * (max_len - len(concept_ids))

        concept_tensor = torch.tensor([concept_ids], dtype=torch.long, device=self._model_device())
        mask = concept_tensor != 0   # True at real (non-padding) positions
        return {'tensor': concept_tensor, 'mask': mask}

    def _process_agentic_outputs(self, outputs: Dict, patient_data: Dict) -> Dict:
        """Reduce raw ``forward`` outputs to a JSON-friendly agentic summary."""
        agentic_outputs = outputs['agentic_full_output']

        priorities = outputs['agentic_question_priorities']
        # topk raises if k exceeds the number of priorities, so clamp it.
        top_k = min(5, priorities.size(-1))
        top_indices = torch.atleast_1d(torch.topk(priorities, k=top_k).indices.squeeze()).tolist()

        return {
            'available': True,
            'num_active_agents': agentic_outputs['num_active_agents'].item(),
            'agent_effectiveness_mean': torch.mean(agentic_outputs['agent_effectiveness']).item(),
            'top_question_priorities': top_indices,
            'missing_info_confidence': torch.mean(outputs['agentic_missing_predictions']).item(),
            'agent_representations': agentic_outputs['agent_representations'],
            'neural_confidence': torch.softmax(outputs['diagnosis_logits'], dim=1).max().item()
        }

class NeuralTrainer:
    """Per-example (batch size 1) trainer for ``MedicalConceptEmbedding``.

    Each optimisation step minimises

        L = L_diagnosis + 0.5 * L_missing + 0.3 * L_temporal + 0.4 * L_agentic

    using AdamW (weight decay 1e-4) and gradient-norm clipping at 1.0.
    """

    def __init__(self, model: "MedicalConceptEmbedding", learning_rate: float = 0.001):
        self.model = model
        self.optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)
        # NOTE(review): never stepped inside this class; a caller must invoke
        # ``scheduler.step(epoch_loss)`` for the plateau schedule to take effect.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=10, factor=0.5)

        self.diagnosis_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        # The model's missing-info head already ends in Sigmoid, hence plain BCE on probabilities.
        self.missing_info_criterion = nn.BCELoss()
        self.temporal_criterion = nn.CrossEntropyLoss()

        self.agentic_criterion = AgenticLoss()

        self.training_history = []

        # Early-stopping bookkeeping. NOTE(review): not read or updated in this class.
        self.best_loss = float('inf')
        self.patience_counter = 0
        self.patience = 15

    @staticmethod
    def _weighted_total(diagnosis, missing, temporal, agentic):
        """Combine the four loss terms with the fixed objective weights (tensors or floats)."""
        return diagnosis + 0.5 * missing + 0.3 * temporal + 0.4 * agentic

    def _model_device(self) -> torch.device:
        """Device of the model's parameters (CPU when it has none)."""
        first_param = next(self.model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def train_epoch(self, training_examples: List["TrainingExample"],
                    concept_vocab: Dict[str, int]) -> Dict[str, float]:
        """Run one pass over ``training_examples``, stepping the optimiser per example.

        Returns the per-example average of each loss component plus the weighted
        total, and appends the same dict to ``self.training_history``. An empty
        example list yields all-zero metrics instead of dividing by zero.
        """
        self.model.train()
        total_diagnosis_loss = 0.0
        total_missing_loss = 0.0
        total_temporal_loss = 0.0
        total_agentic_loss = 0.0
        total_examples = len(training_examples)

        # Build every input on the model's device so GPU-resident models work.
        device = self._model_device()

        for example in training_examples:
            concept_ids = self._prepare_concept_ids(example.input_features, concept_vocab)
            concept_tensor = torch.tensor([concept_ids], dtype=torch.long, device=device)
            mask_tensor = self._create_mask(concept_ids).unsqueeze(0).to(device)

            targets = {
                name: value.to(device)
                for name, value in self._prepare_targets(example, concept_vocab).items()
            }

            outputs = self.model(concept_tensor, mask_tensor)
            losses = self._calculate_all_losses(outputs, targets)

            total_loss = self._weighted_total(
                losses['diagnosis_loss'],
                losses['missing_loss'],
                losses['temporal_loss'],
                losses['agentic_loss'],
            )

            self.optimizer.zero_grad()
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_diagnosis_loss += losses['diagnosis_loss'].item()
            total_missing_loss += losses['missing_loss'].item()
            total_temporal_loss += losses['temporal_loss'].item()
            total_agentic_loss += losses['agentic_loss'].item()

        denominator = max(total_examples, 1)
        epoch_metrics = {
            'diagnosis_loss': total_diagnosis_loss / denominator,
            'missing_info_loss': total_missing_loss / denominator,
            'temporal_loss': total_temporal_loss / denominator,
            'agentic_loss': total_agentic_loss / denominator,
            'total_loss': self._weighted_total(
                total_diagnosis_loss, total_missing_loss,
                total_temporal_loss, total_agentic_loss,
            ) / denominator
        }

        self.training_history.append(epoch_metrics)
        return epoch_metrics

    def _prepare_concept_ids(self, input_features: List[str], concept_vocab: Dict[str, int],
                             max_len: int = 50) -> List[int]:
        """Map concept strings to ids (unknown -> ``<UNK>``, default 1), truncated/padded with 0 to ``max_len``."""
        unk_id = concept_vocab.get('<UNK>', 1)
        concept_ids = [concept_vocab.get(concept, unk_id) for concept in input_features][:max_len]
        return concept_ids + [0] * (max_len - len(concept_ids))

    def _create_mask(self, concept_ids: List[int]) -> torch.Tensor:
        """Return a 1-D bool tensor that is True at non-padding (non-zero) ids."""
        return torch.tensor([cid != 0 for cid in concept_ids], dtype=torch.bool)

    def _prepare_targets(self, example: "TrainingExample", concept_vocab: Dict[str, int]) -> Dict:
        """Build CPU target tensors for one example.

        Returns a dict with:
          - 'diagnosis': [1] long id of ``example.target_diagnosis`` (unknown -> ``<UNK>``),
          - 'missing_info': [1, len(concept_vocab)] multi-hot over ``example.missing_info``,
          - 'temporal': same tensor as 'diagnosis',
          - 'effectiveness': constant [0.7],
          - 'question': [1] long id.
        """
        target_diagnosis_id = concept_vocab.get(example.target_diagnosis, concept_vocab.get('<UNK>', 1))
        target_diagnosis = torch.tensor([target_diagnosis_id], dtype=torch.long)

        # Width must equal the model's vocab_size; the hybrid system builds both from the same vocab.
        missing_target = torch.zeros(len(concept_vocab))
        for missing_concept in example.missing_info:
            if missing_concept in concept_vocab:
                missing_target[concept_vocab[missing_concept]] = 1.0
        missing_target = missing_target.unsqueeze(0)

        temporal_target = target_diagnosis

        effectiveness_target = torch.tensor([0.7])
        # NOTE(review): the question target is a uniformly random id, so the agentic
        # question loss trains on noise and makes training non-deterministic.
        question_target = torch.randint(0, len(concept_vocab), (1,))

        return {
            'diagnosis': target_diagnosis,
            'missing_info': missing_target,
            'temporal': temporal_target,
            'effectiveness': effectiveness_target,
            'question': question_target
        }

    def _calculate_all_losses(self, outputs: Dict, targets: Dict) -> Dict:
        """Calculate all loss components (unweighted) for one forward pass."""
        diagnosis_loss = self.diagnosis_criterion(outputs['diagnosis_logits'], targets['diagnosis'])
        missing_loss = self.missing_info_criterion(outputs['missing_info_logits'], targets['missing_info'])
        temporal_loss = self.temporal_criterion(outputs['temporal_logits'], targets['temporal'])

        agentic_loss_outputs = self.agentic_criterion(
            outputs['agentic_full_output'],
            targets['effectiveness'],
            targets['missing_info'],
            targets['question']
        )
        agentic_loss = agentic_loss_outputs['total_loss']

        return {
            'diagnosis_loss': diagnosis_loss,
            'missing_loss': missing_loss,
            'temporal_loss': temporal_loss,
            'agentic_loss': agentic_loss
        }

class HybridNeuralSymbolicSystem(AgenticAISystem):
    """``AgenticAISystem`` extended with a neural concept model and an agentic layer.

    The symbolic pieces (causal graph, question generator, expert knowledge,
    concept extraction) are inherited from ``AgenticAISystem``. This class adds:
      * concept-vocabulary building and creation of a ``MedicalConceptEmbedding``
        model plus its ``NeuralTrainer``,
      * neural/agentic inference over a patient case,
      * fusion of neural, agentic and symbolic results.
    """

    # Vocabulary entries that never represent a real clinical concept.
    _SPECIAL_TOKENS = ('<PAD>', '<UNK>')

    def __init__(self, embedding_dim: int = 256, hidden_dim: int = 512):
        super().__init__()

        # Sizes used when build_concept_vocabulary() instantiates the neural model.
        self.neural_embedding_dim = embedding_dim
        self.neural_hidden_dim = hidden_dim

        self.concept_vocab = {'<PAD>': 0, '<UNK>': 1}
        self.neural_model = None
        self.neural_trainer = None

        self.agent_output_processor = None
        self.document_processor = PatientDocumentProcessor()

        self.training_examples = []

        self.neural_symbolic_bridge = {}

        logger.info("Initialized COMPLETE Hybrid Neural-Symbolic-Agentic System")

    def process_patient_documents(self, document_paths: List[str]) -> Dict:
        """Merge several patient documents into one case dict and print a summary of it."""
        print(f"Processing {len(document_paths)} patient documents...")
        unified_case = self.document_processor.process_multiple_documents(document_paths)

        print(f"Created unified patient case from documents:")
        print(f"   Patient ID: {unified_case['patient_id']}")
        print(f"   Symptoms: {unified_case['symptoms']}")
        print(f"   Diagnoses: {len(unified_case['diagnoses'])}")
        print(f"   Procedures: {len(unified_case['procedures'])}")
        print(f"   Medications: {len(unified_case['medications'])}")

        return unified_case

    def build_concept_vocabulary(self, patient_cases: List[Dict]):
        """(Re)build the concept vocabulary and a fresh neural model/trainer sized to it.

        Concepts are ranked by how many cases contain them (ties broken
        alphabetically so ids are reproducible across runs); at most
        4005 - 2 concepts are kept after the ``<PAD>``/``<UNK>`` specials.

        Returns:
            the final vocabulary size (including the two special tokens).
        """
        logger.info("Building concept vocabulary...")

        concept_counts = defaultdict(int)
        for case in patient_cases:
            for concept in self._extract_comprehensive_concepts(case):
                if concept and len(concept) > 2:
                    concept_counts[concept] += 1

        target_vocab_size = 4005

        # Start from only the special tokens: ids are assigned densely from 2, so
        # keeping entries from a previous build would produce colliding ids.
        # clear()+update() keeps the dict object identity for existing holders.
        self.concept_vocab.clear()
        self.concept_vocab.update({'<PAD>': 0, '<UNK>': 1})

        ranked = sorted(concept_counts.items(), key=lambda item: (-item[1], item[0]))
        for concept, _count in ranked[:target_vocab_size - 2]:
            self.concept_vocab[concept] = len(self.concept_vocab)

        actual_vocab_size = len(self.concept_vocab)
        logger.info(f"Built vocabulary: {actual_vocab_size:,} concepts (Target: {target_vocab_size:,})")

        self.neural_model = MedicalConceptEmbedding(
            vocab_size=actual_vocab_size,
            embedding_dim=self.neural_embedding_dim,
            hidden_dim=self.neural_hidden_dim
        )
        # The model's _prepare_concept_tensor() looks concepts up via this attribute;
        # without it every concept would be encoded as <UNK> at inference time.
        self.neural_model.concept_vocab = self.concept_vocab

        self.neural_trainer = NeuralTrainer(self.neural_model, learning_rate=0.001)
        self.agent_output_processor = AgentOutputProcessor(self.concept_vocab)

        return actual_vocab_size

    @staticmethod
    def _normalize_case_entry(value) -> str:
        """Return ``value`` stripped and lower-cased, or '' for empty/None/'nan' entries."""
        if not value:
            return ''
        text = str(value).strip().lower()
        return '' if text == 'nan' else text

    def _extract_comprehensive_concepts(self, case: Dict) -> List[str]:
        """Collect normalised concept strings from one patient case.

        Sources: symptom words, diagnosis/procedure phrases and their words,
        medications, and up to five long alphabetic words per clinical note.
        Only concepts longer than two characters made of alphabetic words
        (single spaces allowed between words) are kept. The result is
        de-duplicated in first-seen order, so it is deterministic.
        """
        all_concepts: List[str] = []

        symptoms = case.get('symptoms')
        if symptoms:
            tokens = symptoms.replace(',', ' ').replace(';', ' ').split()
            all_concepts.extend(token.strip().lower() for token in tokens if token.strip())

        # Diagnoses and procedures contribute the whole phrase and its individual words.
        for key in ('diagnoses', 'procedures'):
            for entry in case.get(key, []):
                clean = self._normalize_case_entry(entry)
                if clean:
                    all_concepts.append(clean)
                    if ' ' in clean:
                        all_concepts.extend(clean.split())

        for med in case.get('medications', []):
            clean = self._normalize_case_entry(med)
            if clean:
                all_concepts.append(clean)

        # Heuristic: the first five alphabetic words longer than four characters per note.
        for note in case.get('clinical_notes', []):
            if note and len(str(note)) > 10:
                note_words = str(note).lower().split()
                medical_words = [w for w in note_words if len(w) > 4 and w.isalpha()]
                all_concepts.extend(medical_words[:5])

        cleaned = {}
        for concept in all_concepts:
            normalized = ' '.join(str(concept).lower().split())
            # Accept multi-word phrases (appended above on purpose) as long as every word is alphabetic.
            if len(normalized) > 2 and all(word.isalpha() for word in normalized.split()):
                cleaned[normalized] = None
        return list(cleaned)

    def _id_to_concept(self) -> Dict[int, str]:
        """Inverse vocabulary (id -> concept), excluding the special tokens."""
        return {
            idx: concept
            for concept, idx in self.concept_vocab.items()
            if concept not in self._SPECIAL_TOKENS
        }

    def neural_enhanced_reasoning(self, patient_data: Dict, expert_inputs: List[Dict]) -> Dict:
        """Run neural+agentic inference, then symbolic reasoning, then fuse the results.

        Returns the symbolic result dict; when a neural model exists it is
        enriched with neural confidence, agentic insights, merged missing
        information and agent-generated questions.
        """
        print(" COMPLETE NEURAL-SYMBOLIC-AGENTIC REASONING:")

        print(" NEURAL + AGENTIC PROCESSING...")
        neural_agentic_insights = self._get_neural_agentic_insights(patient_data)

        if neural_agentic_insights.get('available'):
            print(f"   Neural diagnosis confidence: {neural_agentic_insights.get('predicted_diagnosis_confidence', 0):.3f}")
            print(f"   Active agents created: {neural_agentic_insights.get('num_active_agents', 0)}")
            print(f"   Agent effectiveness mean: {neural_agentic_insights.get('agent_effectiveness_mean', 0):.3f}")
            print(f"   Agentic missing predictions: {len(neural_agentic_insights.get('agentic_missing_concepts', []))}")

        print("2️⃣ SYMBOLIC REASONING...")
        symbolic_result = self._enhanced_symbolic_reasoning(patient_data, expert_inputs, neural_agentic_insights)

        print("3️⃣ NEURAL-SYMBOLIC-AGENTIC FUSION...")
        complete_result = self._combine_neural_symbolic_agentic(
            neural_agentic_insights, symbolic_result, patient_data=patient_data
        )

        print("="*60)
        return complete_result

    def _get_neural_agentic_insights(self, patient_data: Dict) -> Dict:
        """Run the neural model on one case and collect plain-Python insights.

        Returns ``{'available': False}`` when no model has been built yet.
        """
        if self.neural_model is None:
            return {'available': False}

        model = self.neural_model
        input_concepts = self._extract_concepts_from_data(patient_data)
        concept_prep = model._prepare_concept_tensor(input_concepts)

        # Inference without dropout; restore the caller's train/eval mode afterwards.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                outputs = model(concept_prep['tensor'], concept_prep['mask'])
        finally:
            model.train(was_training)

        agentic_outputs = outputs['agentic_full_output']
        agentic_missing_concepts = self._extract_agentic_missing_concepts(agentic_outputs)

        return {
            'available': True,
            'predicted_diagnosis_confidence': torch.softmax(outputs['diagnosis_logits'], dim=1).max().item(),
            'missing_info_predictions': outputs['missing_info_logits'].squeeze().tolist(),
            'attention_weights': outputs['attention_weights'].squeeze().tolist(),

            'num_active_agents': agentic_outputs['num_active_agents'].item(),
            'agent_effectiveness_mean': torch.mean(agentic_outputs['agent_effectiveness']).item(),
            'agentic_question_priorities': agentic_outputs['question_priorities'].squeeze().tolist(),
            'agentic_missing_concepts': agentic_missing_concepts,
            'agent_representations': agentic_outputs['agent_representations']
        }

    def _extract_agentic_missing_concepts(self, agentic_outputs: Dict) -> List[str]:
        """Return vocabulary concepts whose agentic missing-info score exceeds 0.5.

        Scores are matched to concepts by vocabulary id; special tokens are never
        reported. Assumes a single-example batch (scores are flattened).
        """
        scores = agentic_outputs['missing_info_predictions'].reshape(-1).tolist()
        id_to_concept = self._id_to_concept()
        return [
            id_to_concept[idx]
            for idx, score in enumerate(scores)
            if score > 0.5 and idx in id_to_concept
        ]

    def _enhanced_symbolic_reasoning(self, patient_data: Dict, expert_inputs: List[Dict],
                                     neural_agentic_insights: Dict) -> Dict:
        """Run the inherited symbolic pipeline and attach neural/agentic summaries if available."""
        self._create_semantic_nodes_from_data(patient_data)
        self._incorporate_expert_knowledge(expert_inputs)

        critical_missing = self.causal_graph.identify_missing_critical_info()
        context = self._create_adaptive_context(patient_data)
        suggested_question = self.question_generator.generate_targeted_question(context)
        reasoning_output = self._generate_reasoning_output(patient_data, critical_missing)

        result = {
            'reasoning': reasoning_output,
            'missing_critical_info': [entity.concept for entity in critical_missing],
            'suggested_question': suggested_question,
            'expert_consensus': self._get_overall_expert_consensus(),
            'temporal_reasoning_chain': self._get_temporal_chain(patient_data)
        }

        if neural_agentic_insights['available']:
            result['neural_confidence'] = neural_agentic_insights['predicted_diagnosis_confidence']
            result['agentic_insights'] = {
                'num_active_agents': neural_agentic_insights['num_active_agents'],
                'agent_effectiveness_mean': neural_agentic_insights['agent_effectiveness_mean'],
                'agentic_missing_concepts': neural_agentic_insights['agentic_missing_concepts']
            }

        return result

    def _combine_neural_symbolic_agentic(self, neural_agentic_insights: Dict, symbolic_result: Dict,
                                         patient_data: Optional[Dict] = None) -> Dict:
        """Fuse neural/agentic insights into a copy of ``symbolic_result``.

        Args:
            neural_agentic_insights: output of :meth:`_get_neural_agentic_insights`.
            symbolic_result: output of :meth:`_enhanced_symbolic_reasoning` (not mutated).
            patient_data: case used for agent question extraction; when omitted a
                placeholder case is used (the previous behaviour).

        Returns:
            the fused result. Missing info is the ordered, de-duplicated union of
            symbolic, neural (score > 0.5, as concept names) and agentic concepts.
        """
        combined_result = symbolic_result.copy()

        if neural_agentic_insights['available']:
            neural_confidence = neural_agentic_insights['predicted_diagnosis_confidence']
            num_active_agents = neural_agentic_insights['num_active_agents']
            agent_effectiveness = neural_agentic_insights['agent_effectiveness_mean']
            agentic_missing = neural_agentic_insights['agentic_missing_concepts']

            # Map neural scores back to concept names (previously bare index strings).
            neural_scores = neural_agentic_insights.get('missing_info_predictions', [])
            if not isinstance(neural_scores, list):   # squeeze().tolist() of one element is a float
                neural_scores = [neural_scores]
            id_to_concept = self._id_to_concept()
            neural_missing = [
                id_to_concept[idx]
                for idx, score in enumerate(neural_scores)
                if score > 0.5 and idx in id_to_concept
            ]

            enhanced_reasoning = (
                f"{symbolic_result['reasoning']}\n\n"
                f"Neural Component Analysis:\n"
                f"- Diagnosis confidence: {neural_confidence:.3f}\n"
                f"- Neural missing predictions: {len(neural_missing)}\n\n"
                f"AGENTIC LAYER Analysis (AAAI Paper Implementation):\n"
                f"- Dynamic agents created: {num_active_agents}\n"
                f"- Agent effectiveness (eff_i): {agent_effectiveness:.3f}\n"
                f"- Agentic missing concepts: {len(agentic_missing)}\n"
                f"- Question generation: Neural-learned patterns\n"
                f"- Agent representation: a_i = (F_i, P_i, eff_i, L_i, H_i, N_i)\n\n"
                f"TRIPLE FUSION Results:\n"
                f"- Symbolic + Neural + Agentic integration complete\n"
                f"- Enhanced missing information detection\n"
                f"- Dynamic agent-based question generation"
            )

            combined_result['reasoning'] = enhanced_reasoning
            combined_result['neural_agentic_insights'] = neural_agentic_insights

            # Ordered union: symbolic findings first, then neural, then agentic.
            combined_result['missing_critical_info'] = list(dict.fromkeys(
                [*symbolic_result['missing_critical_info'], *neural_missing, *agentic_missing]
            ))

            question_case = patient_data if patient_data is not None else {
                'admission_info': {'diagnosis': 'analysis'}
            }
            combined_result['agentic_questions'] = self._extract_agentic_questions(
                neural_agentic_insights, patient_data=question_case
            )
            combined_result['agentic_layer_active'] = True

        return combined_result

    def _extract_agentic_questions(self, neural_agentic_insights: Dict, patient_data: Dict) -> List[str]:
        """Return up to five question strings produced from the agent representations."""
        if not self.agent_output_processor:
            return []

        agent_outputs = {
            'agent_representations': neural_agentic_insights.get('agent_representations', [])
        }

        questions = self.agent_output_processor.extract_agent_questions(agent_outputs, patient_data)

        return [q['question'] for q in questions[:5]]


def test_complete_hybrid_with_agentic():
    """End-to-end smoke test of the hybrid system on one mock cardiology case.

    Builds the vocabulary from the mock case, runs a random forward pass through
    the neural model, then runs the full neural-symbolic-agentic reasoning and
    prints a summary.

    Returns:
        (hybrid_system, complete_result)
    """
    hybrid_system = HybridNeuralSymbolicSystem()

    print(" Testing Document Processing...")
    mock_documents = ['patient_history.txt', 'lab_results.csv', 'clinical_notes.txt']

    mock_patient_case = {
        'patient_id': 'AGENTIC_TEST_001',
        'symptoms': 'chest pain, shortness of breath, fatigue',
        'diagnoses': ['acute myocardial infarction', 'coronary artery disease'],
        'procedures': ['ECG', 'cardiac catheterization', 'echocardiogram'],
        'medications': ['aspirin', 'metoprolol', 'atorvastatin'],
        'clinical_notes': [
            'Patient reports chest pain started 3 hours ago',
            'Family history of coronary artery disease',
            'Currently taking aspirin and metoprolol'
        ],
        'admission_info': {'diagnosis': 'acute myocardial infarction'},
        'known_info': ['chest pain', 'ECG', 'family history'],
        'temporal_sequence': ['chest pain', 'ECG', 'diagnosis', 'catheterization', 'medications'],
        'document_sources': mock_documents
    }

    print(f" Mock patient case created from {len(mock_documents)} documents")

    print("\n Testing Vocabulary Building...")
    mock_cases = [mock_patient_case]
    vocab_size = hybrid_system.build_concept_vocabulary(mock_cases)
    print(f" Vocabulary built: {vocab_size:,} concepts")

    print("\n Testing Neural Model with Agentic Layer...")
    if hybrid_system.neural_model is not None:
        print(f"   Neural model parameters: {hybrid_system.neural_model._count_parameters():,}")
        print(f"   Agentic layer parameters: {hybrid_system.neural_model.agentic_layer._count_parameters():,}")

        # Ids must be < vocab_size or nn.Embedding raises IndexError (the mock
        # vocabulary holds only a few dozen concepts). Start at 1 so no id is the
        # padding id while the mask marks every position as valid.
        test_input = torch.randint(1, vocab_size, (2, 20))
        test_mask = torch.ones(2, 20, dtype=torch.bool)

        with torch.no_grad():
            outputs = hybrid_system.neural_model(test_input, test_mask)

        print(f"   Forward pass successful:")
        print(f"     Diagnosis logits: {outputs['diagnosis_logits'].shape}")
        print(f"     Missing info logits: {outputs['missing_info_logits'].shape}")
        print(f"     Temporal logits: {outputs['temporal_logits'].shape}")
        print(f"     Agentic question priorities: {outputs['agentic_question_priorities'].shape}")
        print(f"     Active agents: {outputs['num_active_agents']}")

    print("\n Testing Complete Neural-Symbolic-Agentic Reasoning...")

    mock_experts = [{
        'expert_id': 'cardiologist_agentic_test',
        'domain_expertise': {'cardiology': 0.9, 'general_medicine': 0.7},
        'known_concepts': ['chest pain', 'myocardial infarction', 'ECG'],
        'unknown_concepts': ['family_history_details', 'medication_compliance'],
        'reasoning_chain': {
            'chest pain': ['assess_location', 'evaluate_severity', 'check_timing'],
            'myocardial infarction': ['confirm_diagnosis', 'assess_damage', 'plan_treatment']
        }
    }]

    complete_result = hybrid_system.neural_enhanced_reasoning(mock_patient_case, mock_experts)

    print(f" Complete reasoning result:")
    print(f"   Missing critical info: {len(complete_result['missing_critical_info'])}")
    print(f"   Neural confidence: {complete_result.get('neural_confidence', 'N/A')}")
    print(f"   Agentic layer active: {complete_result.get('agentic_layer_active', False)}")
    print(f"   Agentic questions: {len(complete_result.get('agentic_questions', []))}")

    if complete_result.get('agentic_questions'):
        print(f"   Sample agentic questions:")
        for i, q in enumerate(complete_result['agentic_questions'][:3], 1):
            print(f"     {i}. {q}")

    return hybrid_system, complete_result

if __name__ == "__main__":
    test_complete_hybrid_with_agentic()