# rare_handler.py 
import numpy as np
from typing import Dict, List, Tuple, Set
from collections import defaultdict, Counter
import logging

logger = logging.getLogger(__name__)

class RareHandler:
    """Find concept pairs that co-occur in only a few cases and score them.

    Concepts are the lower-cased entries of each case's 'symptoms',
    'diagnoses', 'procedures' and 'medications' fields. A pair is "rare" when
    it co-occurs in at least one but fewer than ``frequency_threshold`` cases.
    """

    def __init__(self, mimic_cases: List[Dict]):
        """Store the cases and the default tuning values.

        Args:
            mimic_cases: Case dicts. The methods of this class read the
                optional keys 'symptoms', 'diagnoses', 'procedures',
                'medications' and 'temporal_sequence'.
        """
        self.mimic_cases = mimic_cases
        # Not read by any method of this class; kept for external users.
        self.rare_cache = {}
        # Pairs seen in fewer than this many cases count as rare.
        self.frequency_threshold = 5
        # Not read by any method of this class; kept for external users.
        self.importance_weights = defaultdict(float)

        logger.info(" Rare Medical Relationship Handler initialized")

    def find_rare_critical_relationships(self) -> Dict[str, Dict[str, float]]:
        """Return rare co-occurring concept pairs with their importance score.

        Returns:
            ``{concept1: {concept2: importance}}`` for every ordered pair that
            co-occurs in 1..frequency_threshold-1 cases and scores above 0.3.
        """
        print(" Finding rare but critical relationships...")

        # Count, for every ordered pair of distinct concepts, the number of
        # cases in which both appear. _extract_concepts returns unique
        # concepts, so each case contributes at most 1 per pair.
        cooccurrence = defaultdict(Counter)
        for case in self.mimic_cases:
            concepts = self._extract_concepts(case)
            for concept1 in concepts:
                for concept2 in concepts:
                    if concept1 != concept2:
                        cooccurrence[concept1][concept2] += 1

        rare_relationships: Dict[str, Dict[str, float]] = {}
        for concept1, related in cooccurrence.items():
            for concept2, count in related.items():
                if not 1 <= count < self.frequency_threshold:
                    continue
                importance = self._calculate_importance(concept1, concept2, count)
                # NOTE(review): with the default threshold of 5 the score is
                # always >= 0.3 * 0.25 + 0.5 * 0.5 = 0.325, so this cut never
                # rejects a pair -- confirm the intended cut-off.
                if importance > 0.3:
                    rare_relationships.setdefault(concept1, {})[concept2] = importance

        print(f" Found {len(rare_relationships)} concepts with rare critical relationships")
        return rare_relationships

    @staticmethod
    def _normalized_items(value, limit=None) -> List[str]:
        """Return the entries of one case field as stripped, lower-cased strings.

        A string is split on commas; any other iterable is taken item by item;
        a non-iterable scalar is treated as a single entry. Empty entries and
        the literal 'nan' are dropped. ``limit`` truncates the raw entries
        before that filtering.
        """
        if value is None:
            return []
        if isinstance(value, str):
            raw_items = value.split(',')
        else:
            try:
                raw_items = list(value)
            except TypeError:
                raw_items = [value]
        if limit is not None:
            raw_items = raw_items[:limit]

        items = []
        for raw in raw_items:
            text = str(raw).strip().lower()
            if text and text != 'nan':
                items.append(text)
        return items

    @staticmethod
    def _normalized_sequence(case: Dict) -> List[str]:
        """Return the case's 'temporal_sequence' stripped and lower-cased.

        Concepts are lower-cased on extraction, so the sequence has to be
        normalized the same way before a concept is looked up in it.
        """
        return [str(step).strip().lower() for step in (case.get('temporal_sequence') or [])]

    @staticmethod
    def _share_exceeds(contexts: List[str], label: str, fraction: float) -> bool:
        """Return True when ``label`` makes up more than ``fraction`` of ``contexts``."""
        return contexts.count(label) > len(contexts) * fraction

    def _extract_concepts(self, case: Dict) -> List[str]:
        """Return the unique concepts of a case in first-seen order.

        Only the first five medication entries are considered. A concept is
        kept when it is longer than two characters and consists of letters,
        optionally separated by spaces or hyphens.
        """
        concepts = []
        concepts.extend(self._normalized_items(case.get('symptoms')))
        concepts.extend(self._normalized_items(case.get('diagnoses')))
        concepts.extend(self._normalized_items(case.get('procedures')))
        concepts.extend(self._normalized_items(case.get('medications'), limit=5))

        cleaned = []
        for concept in concepts:
            # A bare isalpha() would reject every multi-word concept such as
            # "chest pain", so separators are ignored for the letter check.
            letters_only = concept.replace(' ', '').replace('-', '')
            if len(concept) > 2 and letters_only.isalpha():
                cleaned.append(concept)

        # dict.fromkeys de-duplicates while keeping a deterministic order.
        return list(dict.fromkeys(cleaned))

    def _calculate_importance(self, concept1: str, concept2: str, count: int) -> float:
        """Combine rarity, pattern type and temporal consistency into one score.

        The score is ``0.3 * rarity + 0.5 * pattern + 0.2 * temporal``, capped
        at 1.0, where rarity is ``1 / count``.
        """
        # Rarity score (rare but existing = important)
        rarity_score = 1.0 / max(count, 1)

        pattern_scores = {
            'symptom_to_condition': 0.8,
            'condition_to_complication': 0.9,
            'medication_to_effect': 0.7,
            'test_to_finding': 0.6,
            'general': 0.5
        }
        pattern_type = self._infer_pattern(concept1, concept2)
        medical_importance = pattern_scores.get(pattern_type, 0.5)

        temporal_importance = self._calculate_temporal_importance(concept1, concept2)

        total = (rarity_score * 0.3 + medical_importance * 0.5 + temporal_importance * 0.2)
        return float(min(total, 1.0))

    def _infer_pattern(self, concept1: str, concept2: str) -> str:
        """Classify the pair by the fields each concept mostly appears in.

        Returns one of 'symptom_to_condition', 'condition_to_complication',
        'medication_to_effect', 'test_to_finding' or 'general'; the first
        matching rule wins.
        """
        # Scan the cases once per concept and reuse the result for every rule.
        contexts1 = self._get_contexts(concept1)
        contexts2 = self._get_contexts(concept2)
        first_is_diagnosis = self._share_exceeds(contexts1, 'diagnosis', 0.5)
        second_is_diagnosis = self._share_exceeds(contexts2, 'diagnosis', 0.5)

        # Pattern detection
        if self._share_exceeds(contexts1, 'symptom', 0.5) and second_is_diagnosis:
            return 'symptom_to_condition'
        if first_is_diagnosis and self._is_complication(concept2):
            return 'condition_to_complication'
        if (self._share_exceeds(contexts1, 'medication', 0.5)
                and self._share_exceeds(contexts2, 'symptom', 0.5)):
            return 'medication_to_effect'
        # Same rule as _is_test(concept1) and _is_finding(concept2).
        if (self._share_exceeds(contexts1, 'procedure', 0.3)
                and second_is_diagnosis and len(concept2) > 5):
            return 'test_to_finding'
        return 'general'

    def _get_contexts(self, concept: str) -> List[str]:
        """List the field kinds in which ``concept`` occurs, one entry per hit.

        For every case the result gains 'symptom', 'diagnosis', 'procedure'
        and/or 'medication' depending on where the concept is found. Fields
        that are missing or None contribute nothing.
        """
        contexts = []

        for case in self.mimic_cases:
            symptoms = case.get('symptoms')
            symptom_text = '' if symptoms is None else str(symptoms).lower()
            # Symptoms are matched as a substring of the whole field; the
            # other fields need an exact entry match.
            # NOTE(review): confirm the substring match is intended.
            if concept in symptom_text:
                contexts.append('symptom')
            if concept in self._normalized_items(case.get('diagnoses')):
                contexts.append('diagnosis')
            if concept in self._normalized_items(case.get('procedures')):
                contexts.append('procedure')
            if concept in self._normalized_items(case.get('medications')):
                contexts.append('medication')

        return contexts

    def _is_symptom(self, concept: str) -> bool:
        """Return True when more than half of the concept's contexts are 'symptom'."""
        return self._share_exceeds(self._get_contexts(concept), 'symptom', 0.5)

    def _is_diagnosis(self, concept: str) -> bool:
        """Return True when more than half of the concept's contexts are 'diagnosis'."""
        return self._share_exceeds(self._get_contexts(concept), 'diagnosis', 0.5)

    def _is_complication(self, concept: str) -> bool:
        """Return True when the concept's average timeline position exceeds 0.6."""
        return self._calculate_temporal_position(concept) > 0.6

    def _is_medication(self, concept: str) -> bool:
        """Return True when more than half of the concept's contexts are 'medication'."""
        return self._share_exceeds(self._get_contexts(concept), 'medication', 0.5)

    def _is_test(self, concept: str) -> bool:
        """Return True when more than 30% of the concept's contexts are 'procedure'."""
        return self._share_exceeds(self._get_contexts(concept), 'procedure', 0.3)

    def _is_finding(self, concept: str) -> bool:
        """Return True for a diagnosis-dominated concept longer than five characters."""
        return self._is_diagnosis(concept) and len(concept) > 5

    def _calculate_temporal_importance(self, concept1: str, concept2: str) -> float:
        """Score how consistently ``concept2`` is offset from ``concept1`` in time.

        For every case containing both concepts, the index difference in the
        case's 'temporal_sequence' is collected. The score is
        ``1 - std / (|mean| + 1)`` floored at 0, or 0.3 when no case provides
        an ordering for the pair.
        """
        first = concept1.strip().lower()
        second = concept2.strip().lower()
        offsets = []

        for case in self.mimic_cases:
            if not (self._concept_in_case(concept1, case)
                    and self._concept_in_case(concept2, case)):
                continue
            # Compare against the normalized sequence: concepts are
            # lower-cased, raw sequence entries (e.g. 'ECG') may not be.
            sequence = self._normalized_sequence(case)
            if first in sequence and second in sequence:
                offsets.append(sequence.index(second) - sequence.index(first))

        if offsets:
            avg_diff = np.mean(offsets)
            consistency = 1.0 - (np.std(offsets) / (abs(avg_diff) + 1))
            return float(max(0, consistency))

        return 0.3

    def _concept_in_case(self, concept: str, case: Dict) -> bool:
        """Return True when ``concept`` is a substring of the case's field text.

        The four concept fields are lower-cased and concatenated (list
        entries joined with spaces) before the case-insensitive check.
        """
        fragments = []
        for field in ['symptoms', 'diagnoses', 'procedures', 'medications']:
            value = case.get(field)
            if not value:
                continue
            if isinstance(value, str):
                fragments.append(value.lower() + " ")
            elif isinstance(value, list):
                fragments.append(" ".join([str(item).lower() for item in value]) + " ")

        return concept.lower() in "".join(fragments)

    def _calculate_temporal_position(self, concept: str) -> float:
        """Calculate average temporal position (0=early, 1=late).

        Only sequences with more than one entry are used; 0.5 is returned
        when the concept appears in none of them.
        """
        target = concept.strip().lower()
        positions = []

        for case in self.mimic_cases:
            sequence = self._normalized_sequence(case)
            if len(sequence) > 1 and target in sequence:
                positions.append(sequence.index(target) / (len(sequence) - 1))

        return float(np.mean(positions)) if positions else 0.5

class TemporalHandler:
    """Learn where concepts usually sit in case timelines and flag deviations.

    A position is an entry's index in a case's 'temporal_sequence' divided by
    the last index, so 0.0 is the first entry and 1.0 the last.
    """

    def __init__(self, mimic_cases: List[Dict]):
        """Store the cases and create the empty pattern containers.

        Args:
            mimic_cases: Case dicts; only the optional 'temporal_sequence'
                key is read by this class.
        """
        self.mimic_cases = mimic_cases
        # concept -> positions observed by the last learn_temporal_patterns().
        self.temporal_patterns = defaultdict(list)
        # Not read or written by any method of this class.
        self.causal_relationships = defaultdict(dict)

        logger.info(" Dynamic Temporal Handler initialized")

    def learn_temporal_patterns(self) -> Dict[str, float]:
        """Compute each concept's average position over all stored cases.

        ``self.temporal_patterns`` is rebuilt on every call, so calling this
        repeatedly does not accumulate duplicate observations.

        Returns:
            ``{concept: average position}`` for every concept that appears in
            a sequence with more than one entry.
        """
        print(" Learning temporal patterns...")

        observed = defaultdict(list)
        for case in self.mimic_cases:
            sequence = case.get('temporal_sequence') or []
            # A single-entry sequence has no meaningful relative position.
            if len(sequence) > 1:
                last_index = len(sequence) - 1
                for i, concept in enumerate(sequence):
                    observed[concept].append(i / last_index)
        self.temporal_patterns = observed

        avg_positions = {
            concept: float(np.mean(positions))
            for concept, positions in observed.items()
        }

        print(f" Learned temporal patterns for {len(avg_positions)} concepts")
        return avg_positions

    def detect_sequence_anomalies(self, case_sequence: List[str]) -> List[Dict]:
        """Report concepts whose position deviates from the learned average.

        Args:
            case_sequence: Ordered concepts of the case under review.

        Returns:
            One dict per concept deviating by more than 0.3, with the keys
            'concept', 'expected_position', 'actual_position', 'deviation'
            and 'severity' ('high' above 0.5, otherwise 'medium'). Concepts
            without a learned average are skipped.
        """
        anomalies = []
        avg_positions = self.learn_temporal_patterns()
        last_index = len(case_sequence) - 1

        for i, concept in enumerate(case_sequence):
            if concept not in avg_positions:
                continue
            expected_pos = avg_positions[concept]
            # A lone entry is placed in the middle of the timeline.
            actual_pos = i / last_index if len(case_sequence) > 1 else 0.5
            deviation = abs(expected_pos - actual_pos)

            if deviation > 0.3:
                anomalies.append({
                    'concept': concept,
                    'expected_position': expected_pos,
                    'actual_position': actual_pos,
                    'deviation': deviation,
                    'severity': 'high' if deviation > 0.5 else 'medium'
                })

        return anomalies

    def suggest_sequence_corrections(self, anomalies: List[Dict]) -> List[str]:
        """Turn anomaly dicts into human-readable review suggestions.

        Each anomaly yields one message based on whether the concept came
        later than expected, plus a second message when the deviation
        exceeds 0.5.
        """
        suggestions = []

        for anomaly in anomalies:
            concept = anomaly['concept']

            if anomaly['actual_position'] > anomaly['expected_position']:
                suggestions.append(f"Consider if {concept} should have occurred earlier in timeline")
            else:
                suggestions.append(f"Verify if {concept} timing is appropriate for this case")

            if anomaly['deviation'] > 0.5:
                suggestions.append(f"Significant timing anomaly detected for {concept} - review case")

        return suggestions

def test_rare_handler():
    """Run both handlers on two hand-written cases and print what they find.

    This is a print-only demonstration: it makes no assertions and returns
    nothing.
    """
    print("🧪 Testing Rare Relationship Handler")

    # Two small mock cases, each with one entry per concept field.
    cardiac_case = {
        'symptoms': 'chest pain, shortness of breath',
        'diagnoses': ['myocardial infarction'],
        'procedures': ['ECG'],
        'medications': ['aspirin'],
        'temporal_sequence': ['chest pain', 'ECG', 'myocardial infarction', 'aspirin'],
    }
    respiratory_case = {
        'symptoms': 'fever, cough',
        'diagnoses': ['pneumonia'],
        'procedures': ['chest xray'],
        'medications': ['antibiotics'],
        'temporal_sequence': ['fever', 'chest xray', 'pneumonia', 'antibiotics'],
    }
    sample_cases = [cardiac_case, respiratory_case]

    # Rare-relationship discovery.
    found = RareHandler(sample_cases).find_rare_critical_relationships()
    print(f" Found rare relationships: {len(found)}")
    for name in found:
        print(f"  {name}: {found[name]}")

    # Average timeline position per concept.
    timeline_handler = TemporalHandler(sample_cases)
    learned = timeline_handler.learn_temporal_patterns()
    print(f" Learned temporal patterns: {len(learned)}")
    for name, average in learned.items():
        print(f"  {name}: avg position {average:.2f}")

    # The medication is deliberately placed first, i.e. in the wrong order.
    shuffled = ['aspirin', 'chest pain', 'ECG', 'myocardial infarction']
    flagged = timeline_handler.detect_sequence_anomalies(shuffled)
    advice = timeline_handler.suggest_sequence_corrections(flagged)

    print(f" Detected {len(flagged)} temporal anomalies")
    print(f" Generated {len(advice)} suggestions")

# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_rare_handler()