# agentic_layer.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional, Set
from dataclasses import dataclass
import logging

# Module-level logger named after this module. NOTE(review): the code visible
# in this file reports progress with print() and does not call it yet.
logger = logging.getLogger(__name__)

@dataclass
class AgentRepresentation:
    """Agent representation: a_i = (F_i, P_i, eff_i, L_i, H_i, N_i) from AAAI paper.

    Plain record describing one active agent for one batch element. Instances
    are built by ``DynamicAgentCreationLayer._create_agent_representations``,
    which fills ``agent_id``, ``focus_concepts``, ``effectiveness`` and
    ``neural_features`` and leaves the remaining containers empty for other
    components to populate.

    NOTE(review): ``@dataclass`` generates a field-wise ``__eq__``; comparing
    two instances that hold different multi-element tensors may raise when the
    tensor comparison is truth-tested. Compare fields explicitly if needed.
    """
    agent_id: int
    focus_concepts: Set[str]  # F_i - neural attention weighted concepts ("concept_<index>" strings)
    discovery_patterns: Dict  # P_i - Q-learning patterns (created empty by the layer)
    effectiveness: float      # eff_i - enhanced effectiveness formula (scalar score)
    concept_connections: Dict # L_i - concept connections (created empty by the layer)
    success_history: List[float] # H_i - success history (created empty by the layer)
    neural_features: torch.Tensor # N_i - neural learned features (the agent's interacted embedding)

class DynamicAgentCreationLayer(nn.Module):
    """Activate, specialise and aggregate a fixed pool of neural "agents".

    For every batch element the layer:

    1. mean-pools the incoming attention features over the sequence axis,
    2. predicts an activation probability per agent and thresholds it,
    3. builds one embedding per agent with that agent's own specialiser,
    4. scores each agent's effectiveness and lets agents attend to each other,
    5. predicts per-agent concept priorities and missing-information
       probabilities, and
    6. aggregates those predictions with effectiveness-based weights that are
       restricted to the active agents.

    Args:
        attention_dim: Size of the last axis of the incoming features.
        max_agents: Number of agent slots in the pool.
        agent_embedding_dim: Size of each agent embedding. Must be divisible
            by ``num_heads``.
        concept_vocab_size: Number of concepts scored by the output heads.
        num_heads: Number of heads used by the agent-to-agent attention.
    """

    # How many top-priority concepts are recorded per AgentRepresentation.
    FOCUS_CONCEPTS_PER_AGENT = 5

    def __init__(self,
                 attention_dim: int = 256,
                 max_agents: int = 50,
                 agent_embedding_dim: int = 128,
                 concept_vocab_size: int = 4005,
                 num_heads: int = 8):
        super().__init__()

        self.attention_dim = attention_dim
        self.max_agents = max_agents
        self.agent_embedding_dim = agent_embedding_dim
        self.concept_vocab_size = concept_vocab_size

        # Pooled features -> one activation probability per agent slot.
        self.agent_detector = nn.Sequential(
            nn.Linear(attention_dim, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Dropout(0.2),
            nn.Linear(128, max_agents),
            nn.Sigmoid()  # Agent activation probabilities
        )

        # One independent specialiser per agent slot (weights are NOT shared).
        self.agent_specializer = nn.ModuleList([
            nn.Sequential(
                nn.Linear(attention_dim, 256),
                nn.ReLU(),
                nn.LayerNorm(256),
                nn.Dropout(0.3),
                nn.Linear(256, agent_embedding_dim),
                nn.Tanh()
            ) for _ in range(max_agents)
        ])

        # Shared head: [agent embedding ; pooled features] -> score in (0, 1).
        self.effectiveness_calculator = nn.Sequential(
            nn.Linear(agent_embedding_dim + attention_dim, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        # Shared head: agent embedding -> probability distribution over concepts.
        self.question_priority_network = nn.Sequential(
            nn.Linear(agent_embedding_dim, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, concept_vocab_size),
            nn.Softmax(dim=-1)
        )

        # Agent-to-agent self-attention over the agent axis.
        self.agent_interaction = nn.MultiheadAttention(
            embed_dim=agent_embedding_dim,
            num_heads=num_heads,
            batch_first=True
        )

        # Shared head: agent embedding -> independent per-concept probabilities.
        self.missing_info_head = nn.Sequential(
            nn.Linear(agent_embedding_dim, 256),
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(256, concept_vocab_size),
            nn.Sigmoid()
        )

        # Buffers travel with the module (device moves, state_dict) but are not trained.
        self.register_buffer('agent_activation_threshold', torch.tensor(0.3))
        self.register_buffer('effectiveness_history', torch.zeros(max_agents, 100))

        print(f" Dynamic Agentic Layer initialized with {self._count_parameters():,} parameters")

    def _count_parameters(self):
        """Return the number of trainable parameters in this module."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @staticmethod
    def _build_key_padding_mask(active_agents_mask: torch.Tensor) -> torch.Tensor:
        """Return the boolean key mask (True = ignore) for agent attention.

        Inactive agents are hidden as keys. A row with no active agent at all
        is left completely unmasked, because masking every key makes the
        attention softmax produce NaNs.
        """
        active = active_agents_mask.bool()
        has_active = active.any(dim=-1, keepdim=True)
        return ~active & has_active

    @staticmethod
    def _compute_agent_weights(active_agents_mask: torch.Tensor,
                               agent_effectiveness: torch.Tensor) -> torch.Tensor:
        """Softmax effectiveness over the active agents of each row.

        Inactive agents receive exactly zero weight. (Multiplying their score
        by 0 before the softmax would still hand them weight exp(0).) Rows
        without any active agent fall back to a softmax over all agents so the
        weights stay finite and sum to one.
        """
        scores = agent_effectiveness.squeeze(-1)
        active = active_agents_mask.bool()
        has_active = active.any(dim=-1, keepdim=True)
        candidates = active | ~has_active
        scores = scores.masked_fill(~candidates, float('-inf'))
        return F.softmax(scores, dim=-1)

    def forward(self,
                attention_features: torch.Tensor,
                q_learning_context: Optional[torch.Tensor] = None,
                missing_info_mask: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Run the agentic layer on a batch of attention features.

        Args:
            attention_features: Tensor of shape
                ``[batch_size, seq_len, attention_dim]``.
            q_learning_context: Accepted for interface compatibility; not used.
            missing_info_mask: Accepted for interface compatibility; not used.

        Returns:
            Dict with the keys ``agent_activation_probs``,
            ``active_agents_mask``, ``num_active_agents``,
            ``agent_embeddings``, ``interacted_agents``,
            ``agent_effectiveness``, ``agent_attention_weights``,
            ``question_priorities`` and ``missing_info_predictions`` (both
            aggregated over agents to ``[batch_size, concept_vocab_size]``),
            ``agent_representations`` (a list of lists of
            ``AgentRepresentation``) and ``agent_weights``.

        Raises:
            ValueError: If ``attention_features`` is not 3-dimensional.
        """
        if attention_features.dim() != 3:
            raise ValueError(
                "attention_features must have shape [batch_size, seq_len, attention_dim], "
                f"got {tuple(attention_features.shape)}"
            )

        batch_size = attention_features.shape[0]
        pooled_attention = attention_features.mean(dim=1)  # [batch_size, attention_dim]

        # Which agents fire for each batch element.
        agent_activation_probs = self.agent_detector(pooled_attention)
        active_agents_mask = (agent_activation_probs > self.agent_activation_threshold).float()
        num_active_agents = active_agents_mask.sum(dim=-1)

        # Each agent has its own specialiser, so this loop cannot be batched.
        agent_embeddings = torch.stack(
            [specializer(pooled_attention) for specializer in self.agent_specializer],
            dim=1,
        )  # [batch_size, max_agents, agent_embedding_dim]

        # The effectiveness head is shared across agents: run it once on the
        # stacked tensor instead of once per agent.
        pooled_per_agent = pooled_attention.unsqueeze(1).expand(-1, self.max_agents, -1)
        agent_effectiveness = self.effectiveness_calculator(
            torch.cat([agent_embeddings, pooled_per_agent], dim=-1)
        )  # [batch_size, max_agents, 1]

        interacted_agents, agent_attention_weights = self.agent_interaction(
            agent_embeddings, agent_embeddings, agent_embeddings,
            key_padding_mask=self._build_key_padding_mask(active_agents_mask)
        )

        # Shared output heads act on the last axis, so they also run in one shot.
        question_priorities = self.question_priority_network(interacted_agents)
        missing_info_predictions = self.missing_info_head(interacted_agents)

        agent_representations = self._create_agent_representations(
            agent_embeddings, interacted_agents, agent_effectiveness,
            question_priorities, active_agents_mask, batch_size
        )

        agent_weights = self._compute_agent_weights(active_agents_mask, agent_effectiveness)
        per_agent_weight = agent_weights.unsqueeze(-1)

        aggregated_question_priorities = (question_priorities * per_agent_weight).sum(dim=1)
        aggregated_missing_predictions = (missing_info_predictions * per_agent_weight).sum(dim=1)

        return {
            'agent_activation_probs': agent_activation_probs,
            'active_agents_mask': active_agents_mask,
            'num_active_agents': num_active_agents,
            'agent_embeddings': agent_embeddings,
            'interacted_agents': interacted_agents,
            'agent_effectiveness': agent_effectiveness,
            'agent_attention_weights': agent_attention_weights,
            'question_priorities': aggregated_question_priorities,
            'missing_info_predictions': aggregated_missing_predictions,
            'agent_representations': agent_representations,
            'agent_weights': agent_weights
        }

    def _create_agent_representations(self, agent_embeddings, interacted_agents,
                                    agent_effectiveness, question_priorities,
                                    active_agents_mask, batch_size):
        """Create agent representations matching AAAI paper formulation.

        Builds one ``AgentRepresentation`` per active agent and batch element.

        Args:
            agent_embeddings: Unused; kept for interface compatibility.
            interacted_agents: ``[batch, max_agents, dim]`` embeddings after
                agent attention; each row becomes ``neural_features``.
            agent_effectiveness: ``[batch, max_agents, 1]`` scores.
            question_priorities: ``[batch, max_agents, vocab]`` per-agent
                concept scores (not yet aggregated).
            active_agents_mask: ``[batch, max_agents]`` mask; values above 0.5
                count as active.
            batch_size: Number of batch elements to process.

        Returns:
            A list with one entry per batch element, each a list of
            ``AgentRepresentation`` in ascending agent-index order.
        """
        active = active_agents_mask > 0.5
        # Guard against vocabularies smaller than the requested top-k.
        top_k = min(self.FOCUS_CONCEPTS_PER_AGENT, question_priorities.shape[-1])

        agent_representations = []
        for batch_idx in range(batch_size):
            active_ids = torch.nonzero(active[batch_idx], as_tuple=False).flatten().tolist()

            batch_agents = []
            for agent_idx in active_ids:
                top_concept_indices = torch.topk(
                    question_priorities[batch_idx, agent_idx], k=top_k
                ).indices
                focus_concepts = {f"concept_{idx}" for idx in top_concept_indices.tolist()}

                batch_agents.append(AgentRepresentation(
                    agent_id=agent_idx,
                    focus_concepts=focus_concepts,
                    discovery_patterns={},  # P_i - filled by Q-learning
                    effectiveness=agent_effectiveness[batch_idx, agent_idx, 0].item(),
                    concept_connections={},  # L_i - filled by symbolic
                    success_history=[],     # H_i - tracked during training
                    neural_features=interacted_agents[batch_idx, agent_idx]
                ))

            agent_representations.append(batch_agents)

        return agent_representations

    def get_agent_effectiveness_formula(self, q_values, mimic_frequencies, neural_confidence):
        """Combine three signals into one effectiveness score in [0, 1].

        Each input is clamped and scaled to [0, 1] (``q_values`` over
        [0, 100], ``mimic_frequencies`` over [0, 50], ``neural_confidence``
        over [0, 1]) and the three results are averaged with equal weight.
        """
        q_normalized = torch.clamp(q_values, 0, 100) / 100.0
        freq_normalized = torch.clamp(mimic_frequencies, 0, 50) / 50.0
        neural_normalized = torch.clamp(neural_confidence, 0, 1)

        effectiveness = (q_normalized + freq_normalized + neural_normalized) / 3.0
        return effectiveness

class AgenticLoss(nn.Module):
    """Combined training loss for the outputs of the agentic layer.

    The total is the unweighted sum of an effectiveness regression loss, a
    missing-information loss, a question-relevance loss, and two small
    regularisers: sparsity (weight 0.01) and diversity (weight 0.05).
    """

    # Lower bound applied before taking the log of predicted probabilities.
    _EPS = 1e-8

    def __init__(self):
        super().__init__()
        self.effectiveness_loss = nn.MSELoss()
        self.missing_info_loss = nn.BCELoss()
        # 'question_priorities' are already probabilities (softmax outputs
        # averaged with weights that sum to one). CrossEntropyLoss would apply
        # a second softmax, so use NLL on their log instead.
        self.question_relevance_loss = nn.NLLLoss()

    def forward(self, agentic_outputs, target_effectiveness, target_missing_info, target_questions):
        """Compute every loss component and their sum.

        Args:
            agentic_outputs: Dict produced by the agentic layer. Reads
                ``agent_effectiveness``, ``missing_info_predictions``,
                ``question_priorities``, ``agent_embeddings``,
                ``active_agents_mask`` and either ``agent_activation_probs``
                (preferred) or ``num_active_agents``.
            target_effectiveness: Target for the per-sample mean agent
                effectiveness.
            target_missing_info: Binary targets with the same shape as
                ``missing_info_predictions``.
            target_questions: Either integer class indices, or floating-point
                class probabilities with the same shape as
                ``question_priorities``.

        Returns:
            Dict with ``total_loss`` and the five individual components.
        """
        # Mean effectiveness over the agent axis, one value per sample.
        predicted_effectiveness = agentic_outputs['agent_effectiveness'].mean(dim=1).squeeze(-1)
        eff_loss = self.effectiveness_loss(predicted_effectiveness, target_effectiveness)

        missing_loss = self.missing_info_loss(
            agentic_outputs['missing_info_predictions'],
            target_missing_info
        )

        log_question_probs = torch.log(
            agentic_outputs['question_priorities'].clamp_min(self._EPS)
        )
        if torch.is_floating_point(target_questions):
            # Soft targets: cross-entropy against a probability distribution.
            question_loss = -(target_questions * log_question_probs).sum(dim=-1).mean()
        else:
            question_loss = self.question_relevance_loss(log_question_probs, target_questions)

        # The hard active-agent count comes from a threshold comparison and
        # carries no gradient, so penalising it cannot influence training.
        # Use the expected count (sum of activation probabilities) when it is
        # available and fall back to the hard count otherwise.
        activation_probs = agentic_outputs.get('agent_activation_probs')
        if activation_probs is not None:
            expected_active = activation_probs.sum(dim=-1)
        else:
            expected_active = agentic_outputs['num_active_agents'].float()
        sparsity_loss = torch.mean(expected_active) * 0.01

        # Penalise similar embeddings among active agents; inactive agents are
        # zeroed out so they contribute zero similarity.
        # NOTE(review): the mean includes each agent's self-similarity
        # (diagonal of the Gram matrix) - confirm that is intended.
        agent_embeddings = agentic_outputs['agent_embeddings']
        active_mask = agentic_outputs['active_agents_mask'].unsqueeze(-1)
        active_embeddings = agent_embeddings * active_mask
        similarities = torch.bmm(active_embeddings, active_embeddings.transpose(1, 2))
        diversity_loss = torch.mean(similarities) * 0.05

        total_loss = eff_loss + missing_loss + question_loss + sparsity_loss + diversity_loss

        return {
            'total_loss': total_loss,
            'effectiveness_loss': eff_loss,
            'missing_info_loss': missing_loss,
            'question_loss': question_loss,
            'sparsity_loss': sparsity_loss,
            'diversity_loss': diversity_loss
        }

class AgentOutputProcessor:
    """Turn agentic-layer outputs into questions and missing-concept lists.

    Args:
        concept_vocab: Mapping from concept name to vocabulary index. The
            inverse mapping is stored as ``vocab_to_concept``.
    """

    # Prefix of the placeholder names ("concept_<index>") used when only a
    # vocabulary index is known.
    _PLACEHOLDER_PREFIX = "concept_"

    # (keywords, question template) pairs, checked in order; first match wins.
    _QUESTION_TEMPLATES = (
        (('family', 'history'),
         "Any family history or past medical history relevant to {diagnosis}?"),
        (('pain',),
         "Can you describe the pain characteristics in context of {diagnosis}?"),
        (('medication',),
         "Current medications related to {diagnosis} management?"),
        (('test', 'lab'),
         "Recent test results relevant to {diagnosis} assessment?"),
    )

    def __init__(self, concept_vocab: Dict[str, int]):
        self.concept_vocab = concept_vocab
        self.vocab_to_concept = {v: k for k, v in concept_vocab.items()}

    def _resolve_concept_name(self, concept: str) -> str:
        """Map a ``concept_<index>`` placeholder to its vocabulary name.

        Names that are already vocabulary entries, are not placeholders, or
        whose index is unknown are returned unchanged.
        """
        if concept in self.concept_vocab:
            return concept
        prefix = self._PLACEHOLDER_PREFIX
        if concept.startswith(prefix):
            suffix = concept[len(prefix):]
            if suffix.isdecimal():
                return self.vocab_to_concept.get(int(suffix), concept)
        return concept

    def extract_agent_questions(self, agentic_outputs, patient_case, top_k_per_agent=3):
        """Collect the questions of every agent in every batch element.

        Args:
            agentic_outputs: Dict containing ``agent_representations``, a list
                (per batch element) of lists of agent representations.
            patient_case: Dict describing the patient; see
                ``_generate_questions_for_agent``.
            top_k_per_agent: Maximum number of questions per agent.

        Returns:
            A flat list of question dicts in batch order, then agent order.
        """
        return [
            question
            for batch_agents in agentic_outputs['agent_representations']
            for agent_rep in batch_agents
            for question in self._generate_questions_for_agent(
                agent_rep, patient_case, top_k_per_agent
            )
        ]

    def _generate_questions_for_agent(self, agent_rep, patient_case, top_k):
        """Build up to ``top_k`` question dicts for one agent.

        Focus concepts are visited in sorted order so the output is
        reproducible; ``focus_concepts`` is a set and carries no priority
        order of its own. Placeholder concepts are resolved through the
        vocabulary before the keyword templates are matched, because a bare
        ``concept_<index>`` can never match a keyword.

        Args:
            agent_rep: Object exposing ``agent_id``, ``focus_concepts`` and
                ``effectiveness``.
            patient_case: Dict; ``patient_case['admission_info']['diagnosis']``
                is used when present, otherwise ``'condition'``.
            top_k: Maximum number of questions to generate.

        Returns:
            List of dicts with keys ``agent_id``, ``question``,
            ``focus_concept`` (the original concept identifier),
            ``effectiveness`` and ``priority_rank`` (1-based).
        """
        admission_info = patient_case.get('admission_info') or {}
        diagnosis = admission_info.get('diagnosis', 'condition')

        questions = []
        for rank, concept in enumerate(sorted(agent_rep.focus_concepts)[:top_k], start=1):
            concept_name = self._resolve_concept_name(concept)
            lowered = concept_name.lower()

            question = f"Additional information about {concept_name} in context of {diagnosis}?"
            for keywords, template in self._QUESTION_TEMPLATES:
                if any(keyword in lowered for keyword in keywords):
                    question = template.format(diagnosis=diagnosis)
                    break

            questions.append({
                'agent_id': agent_rep.agent_id,
                'question': question,
                'focus_concept': concept,
                'effectiveness': agent_rep.effectiveness,
                'priority_rank': rank
            })

        return questions

    def integrate_with_symbolic_system(self, agentic_outputs, symbolic_missing_info,
                                       top_k: int = 10, threshold: float = 0.5):
        """Merge neural missing-concept predictions with symbolic ones.

        For each batch row the ``top_k`` most probable concepts are inspected
        and those with probability above ``threshold`` are kept.

        Args:
            agentic_outputs: Dict containing ``missing_info_predictions``
                (``[batch, vocab]`` probabilities), ``agent_effectiveness``
                and ``num_active_agents``.
            symbolic_missing_info: Iterable of concept names from the
                symbolic system; ``None`` is treated as empty.
            top_k: Candidates inspected per batch row (capped at the
                vocabulary size).
            threshold: Minimum probability for a concept to count as missing.

        Returns:
            Dict with ``combined_missing_info`` (symbolic first, then new
            neural concepts, without duplicates), ``neural_missing_concepts``
            (unique, in order of first appearance),
            ``symbolic_missing_concepts``, ``neural_confidence`` (mean agent
            effectiveness), ``num_active_agents`` (batch mean) and
            ``agentic_enhancement`` (always ``True``).
        """
        symbolic = list(symbolic_missing_info) if symbolic_missing_info is not None else []

        neural_missing_probs = agentic_outputs['missing_info_predictions']
        # Never ask topk for more entries than the vocabulary holds.
        k = min(top_k, neural_missing_probs.shape[-1])

        neural_missing_concepts = []
        for batch_probs in neural_missing_probs:
            top = torch.topk(batch_probs, k=k)
            for prob, idx in zip(top.values.tolist(), top.indices.tolist()):
                if prob > threshold:
                    neural_missing_concepts.append(
                        self.vocab_to_concept.get(idx, f"concept_{idx}")
                    )

        # dict.fromkeys removes duplicates while keeping first-seen order,
        # which makes the result reproducible (set order is not).
        neural_missing_concepts = list(dict.fromkeys(neural_missing_concepts))
        combined_missing = list(dict.fromkeys(symbolic + neural_missing_concepts))
        neural_confidence = float(torch.mean(agentic_outputs['agent_effectiveness']).item())

        return {
            'combined_missing_info': combined_missing,
            'neural_missing_concepts': neural_missing_concepts,
            'symbolic_missing_concepts': symbolic,
            'neural_confidence': neural_confidence,
            'num_active_agents': agentic_outputs['num_active_agents'].mean().item(),
            'agentic_enhancement': True
        }


class PatientDocumentProcessor:
    """Merge several patient documents into one keyword-derived patient case.

    Extraction is a plain substring search over lower-cased text; it tags
    which keyword families occur in the documents rather than parsing
    clinical meaning.
    """

    _SYMPTOM_KEYWORDS = (
        'pain', 'ache', 'fever', 'nausea', 'fatigue', 'dizziness', 'shortness of breath',
    )

    # (unified-case field, tag prefix, keywords) for the tagged list fields.
    _TAGGED_KEYWORDS = (
        ('diagnoses', 'documented',
         ('diagnosis', 'condition', 'disease', 'syndrome', 'disorder')),
        ('procedures', 'performed',
         ('test', 'scan', 'x-ray', 'blood work', 'examination')),
        ('medications', 'prescribed',
         ('medication', 'drug', 'pill', 'treatment', 'prescription')),
    )

    # Clinical notes keep at most this many leading characters per document.
    _NOTE_PREVIEW_CHARS = 200

    def __init__(self):
        self.supported_formats = ['.txt', '.pdf', '.docx', '.csv']

    @staticmethod
    def _dedupe(items):
        """Return ``items`` without duplicates, keeping first-seen order."""
        return list(dict.fromkeys(items))

    def process_multiple_documents(self, document_paths: List[str]) -> Dict:
        """Process multiple patient documents into unified patient case.

        Documents that cannot be read or parsed are reported on stdout and
        skipped, so one bad file does not abort the batch.

        Args:
            document_paths: Paths of the documents to merge.

        Returns:
            The unified case dict (symptoms, diagnoses, procedures,
            medications, clinical notes, derived known-info and temporal
            sequence, admission info and the list of source paths).
        """
        print(f"📄 Processing {len(document_paths)} patient documents...")

        unified_case = {
            'patient_id': f"DOC_PATIENT_{len(document_paths)}",
            'symptoms': '',
            'diagnoses': [],
            'procedures': [],
            'medications': [],
            'clinical_notes': [],
            'temporal_sequence': [],
            'known_info': [],
            'admission_info': {'diagnosis': 'multi-document analysis'},
            'document_sources': document_paths
        }

        for doc_path in document_paths:
            try:
                doc_content = self._extract_document_content(doc_path)
                self._parse_medical_content(doc_content, unified_case)
            except Exception as e:
                # Deliberately broad: processing is best-effort per document.
                print(f"⚠️ Error processing {doc_path}: {e}")
                continue

        self._finalize_unified_case(unified_case)

        print(f"Unified patient case created from {len(document_paths)} documents")
        return unified_case

    def _extract_document_content(self, doc_path: str) -> str:
        """Return the text content of one document.

        ``.txt`` files are read as UTF-8 and ``.csv`` files are rendered via
        pandas; the extension match ignores case. Any other format is not
        parsed and yields the placeholder ``"Document: <path>"``.

        Raises:
            Whatever ``open`` or ``pandas.read_csv`` raises for unreadable
            input; ``process_multiple_documents`` catches those per document.
        """
        lowered_path = doc_path.lower()

        if lowered_path.endswith('.txt'):
            with open(doc_path, 'r', encoding='utf-8') as f:
                return f.read()

        if lowered_path.endswith('.csv'):
            # Imported lazily so pandas is only required when a CSV is processed.
            import pandas as pd
            return pd.read_csv(doc_path).to_string()

        return f"Document: {doc_path}"

    def _parse_medical_content(self, content: str, unified_case: Dict):
        """Add the keyword hits found in ``content`` to ``unified_case``.

        Mutates ``unified_case`` in place: merges symptoms (no duplicates),
        appends tagged diagnoses / procedures / medications, and stores a
        preview of the content as a clinical note.
        """
        content_lower = content.lower()

        found_symptoms = [kw for kw in self._SYMPTOM_KEYWORDS if kw in content_lower]
        if found_symptoms:
            # Merge with symptoms from earlier documents without repeating them.
            known_symptoms = [
                part.strip() for part in unified_case['symptoms'].split(',') if part.strip()
            ]
            unified_case['symptoms'] = ', '.join(self._dedupe(known_symptoms + found_symptoms))

        for field_name, tag, keywords in self._TAGGED_KEYWORDS:
            unified_case[field_name].extend(
                f"{tag}_{keyword}" for keyword in keywords if keyword in content_lower
            )

        limit = self._NOTE_PREVIEW_CHARS
        note = content[:limit] + '...' if len(content) > limit else content
        unified_case['clinical_notes'].append(note)

    def _finalize_unified_case(self, unified_case: Dict):
        """Finalize the unified patient case.

        De-duplicates the tagged lists in first-seen order, so the derived
        fields and the chosen admission diagnosis are reproducible between
        runs, then rebuilds ``known_info`` and ``temporal_sequence`` and uses
        the first diagnosis, if any, as the admission diagnosis.
        """
        for field_name in ('diagnoses', 'procedures', 'medications'):
            unified_case[field_name] = self._dedupe(unified_case[field_name])

        diagnoses = unified_case['diagnoses']
        procedures = unified_case['procedures']
        medications = unified_case['medications']

        unified_case['known_info'] = diagnoses[:3] + procedures[:3] + medications[:2]

        unified_case['temporal_sequence'] = (
            ['patient_presentation'] +
            procedures[:3] +
            ['diagnosis_established'] +
            medications[:2]
        )

        if diagnoses:
            unified_case['admission_info']['diagnosis'] = diagnoses[0]

def test_complete_agentic_integration():
    """Smoke-test the agentic layer and build a document processor.

    Runs one forward pass on random features, prints a short report and
    returns ``(agentic_layer, agentic_outputs, doc_processor)``.
    """
    print("🧪 TESTING COMPLETE AGENTIC LAYER INTEGRATION")
    print("=" * 60)

    # Build the layer first, then draw the random input (same order as the
    # parameter initialisation / input sampling always had).
    layer_config = {
        'attention_dim': 256,
        'max_agents': 50,
        'agent_embedding_dim': 128,
        'concept_vocab_size': 4005,
    }
    agentic_layer = DynamicAgentCreationLayer(**layer_config)

    batch_size, seq_len = 4, 20
    attention_features = torch.randn(batch_size, seq_len, 256)
    agentic_outputs = agentic_layer(attention_features)

    print("🤖 Agentic Layer Results:")
    print(f"   Active agents per batch: {agentic_outputs['num_active_agents']}")
    print(f"   Parameters: {agentic_layer._count_parameters():,}")

    print("\n📄 Testing Document Processor:")
    doc_processor = PatientDocumentProcessor()

    # Illustrative fixtures only: they show the expected inputs and the shape
    # of a unified case, but nothing below reads them.
    sample_documents = ['patient_report.txt', 'lab_results.csv', 'clinical_notes.txt']
    sample_case = dict(
        patient_id='MULTI_DOC_001',
        symptoms='chest pain, shortness of breath',
        diagnoses=['acute myocardial infarction'],
        procedures=['ECG', 'cardiac_catheterization'],
        medications=['aspirin', 'metoprolol'],
        admission_info={'diagnosis': 'acute myocardial infarction'},
        known_info=['chest pain', 'ECG'],
        temporal_sequence=['chest pain', 'ECG', 'diagnosis', 'aspirin'],
    )

    print(" Mock unified case created for document processing")

    return agentic_layer, agentic_outputs, doc_processor

# Run the integration smoke test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_complete_agentic_integration()