import re
from collections import Counter, defaultdict
from typing import Dict, List, Optional, Tuple, Set, Union
import logging

import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent))

from base_metric import BaseMetric


# Medical domain-specific scorer for radiology reports
class MedicalScorer(BaseMetric):
    """Score a candidate radiology report against a reference report.

    ``calculate`` combines six sub-scores, each in [0, 1], into
    ``medical_score`` using fixed weights:

    * 0.25 - F1 over all recognised medical concepts
    * 0.20 - Jaccard overlap of anatomy terms
    * 0.25 - pathology agreement (aware of normal-vs-abnormal contradictions)
    * 0.15 - agreement in the number of negation cues
    * 0.10 - agreement in the severity level mentioned
    * 0.05 - Jaccard overlap of single-token medical vocabulary

    Matching is purely lexical (vocabulary look-ups on lower-cased text), so
    a negated finding such as "no pneumonia" still registers "pneumonia" as a
    pathology concept; negation is only reflected in the negation sub-score.
    """

    def __init__(self, use_negation_detection: bool = True,
                 use_severity_matching: bool = True,
                 normalize_abbreviations: bool = True,
                 logger: Optional[logging.Logger] = None):
        """Configure the scorer and build its vocabularies.

        Args:
            use_negation_detection: If False, the negation sub-score is
                always 1.0.
            use_severity_matching: If False, the severity sub-score is
                always 1.0.
            normalize_abbreviations: If True, abbreviations listed in
                ``self.abbreviations`` (e.g. "cxr", "rul") are expanded
                before concept extraction.
            logger: Passed through to ``BaseMetric.__init__``.
        """
        super().__init__("medical", logger)
        self.description = "Medical domain-specific scorer for radiology reports"
        self.metric_type = "clinical"

        self.use_negation_detection = use_negation_detection
        self.use_severity_matching = use_severity_matching
        self.normalize_abbreviations = normalize_abbreviations

        self._initialize_medical_vocabularies()

        # Regexes for negation cues. Some overlap ("no" vs "no evidence of");
        # _detect_negations matches them as one longest-first alternation so
        # each cue in the text is counted only once.
        self.negation_patterns = [
            r'\bno\b', r'\bnot\b', r'\bnegative\b', r'\babsent\b', r'\bwithout\b',
            r'\bdenies\b', r'\bdenied\b', r'\bunremarkable\b', r'\bnormal\b',
            r'\bno evidence of\b', r'\bno signs of\b', r'\bno indication of\b'
        ]

        # Severity term -> ordinal level (1 = mildest, 4 = most severe).
        self.severity_levels = {
            'mild': 1, 'minimal': 1, 'slight': 1, 'small': 1,
            'moderate': 2, 'medium': 2,
            'severe': 3, 'marked': 3, 'significant': 3, 'large': 3,
            'massive': 4, 'extensive': 4, 'complete': 4
        }

        # NOTE(review): self.logger is presumably set by BaseMetric.__init__.
        self.logger.debug("Initialized Medical scorer with clinical vocabularies")

    def _initialize_medical_vocabularies(self) -> None:
        """Build the anatomy/pathology/normal vocabularies and abbreviation map.

        Multi-word vocabulary entries use '_' as the word separator; they are
        matched as whitespace-separated phrases by _extract_medical_concepts.
        """
        self.anatomy_terms = {
            'lung', 'lungs', 'pulmonary', 'respiratory', 'bronchi', 'bronchus',
            'alveoli', 'pleura', 'pleural', 'diaphragm', 'mediastinum',
            'heart', 'cardiac', 'cardio', 'aorta', 'aortic', 'ventricle', 'atrium',
            'ribs', 'sternum', 'clavicle', 'spine', 'vertebra', 'thorax', 'thoracic',

            'apex', 'base', 'hilum', 'hila', 'bilateral', 'unilateral',
            'upper_lobe', 'middle_lobe', 'lower_lobe', 'lingula',
            'right_lung', 'left_lung', 'costophrenic', 'cardiophrenic'
        }

        self.pathology_terms = {
            'pneumonia', 'infection', 'infiltrate', 'consolidation', 'opacity',
            'effusion', 'fluid', 'edema', 'congestion', 'atelectasis',
            'pneumothorax', 'collapse', 'mass', 'nodule', 'lesion',
            'cardiomegaly', 'enlargement', 'hypertrophy', 'dilation',

            # 'copd' is rewritten to its expansion when normalize_abbreviations
            # is on, so the expanded phrase must be in the vocabulary too or
            # COPD would never be detected under the default configuration.
            'copd', 'chronic_obstructive_pulmonary_disease',
            'emphysema', 'fibrosis', 'scarring', 'granuloma',
            'tuberculosis', 'cancer', 'malignancy', 'metastasis',
            'fracture', 'break', 'displacement', 'deformity'
        }

        self.normal_terms = {
            'normal', 'unremarkable', 'clear', 'patent', 'intact',
            'within_normal_limits', 'wnl', 'negative', 'no_abnormalities',
            'stable', 'unchanged', 'appropriate', 'adequate'
        }

        # Lower-case abbreviation -> expansion, applied in insertion order.
        self.abbreviations = {
            'wnl': 'within normal limits',
            'copd': 'chronic obstructive pulmonary disease',
            'chf': 'congestive heart failure',
            'pe': 'pulmonary embolism',
            'ptx': 'pneumothorax',
            'cxr': 'chest x-ray',
            'ct': 'computed tomography',
            'mri': 'magnetic resonance imaging',
            'rul': 'right upper lobe',
            'rml': 'right middle lobe',
            'rll': 'right lower lobe',
            'lul': 'left upper lobe',
            'lll': 'left lower lobe',
            'bil': 'bilateral',
            'r/o': 'rule out',
            'h/o': 'history of',
            's/p': 'status post'
        }

        self.concept_categories = {
            'anatomy': self.anatomy_terms,
            'pathology': self.pathology_terms,
            'normal': self.normal_terms
        }

    def calculate(self, reference: str, candidate: str, **kwargs) -> Dict[str, float]:
        """Compute medical-domain similarity scores.

        Args:
            reference: Ground-truth report text.
            candidate: Report text to evaluate.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            Dict with ``medical_score`` (weighted sum described on the class),
            each sub-score, and concept counts (as floats).

        Raises:
            ValueError: If ``self.validate_inputs`` rejects the inputs.
        """
        is_valid, issues = self.validate_inputs(reference, candidate)
        if not is_valid:
            raise ValueError(f"Invalid inputs for Medical scoring: {issues}")

        ref_processed = self._preprocess_medical_text(reference)
        cand_processed = self._preprocess_medical_text(candidate)

        ref_concepts = self._extract_medical_concepts(ref_processed)
        cand_concepts = self._extract_medical_concepts(cand_processed)

        concept_scores = self._calculate_concept_scores(ref_concepts, cand_concepts)
        anatomy_score = self._calculate_anatomy_score(ref_concepts, cand_concepts)
        pathology_score = self._calculate_pathology_score(ref_concepts, cand_concepts)
        # Negation and severity are detected on the raw (un-normalised) text.
        negation_score = self._calculate_negation_score(reference, candidate)
        severity_score = self._calculate_severity_score(reference, candidate)
        terminology_score = self._calculate_terminology_score(ref_processed, cand_processed)

        medical_score = (
            0.25 * concept_scores['concept_f1'] +
            0.20 * anatomy_score +
            0.25 * pathology_score +
            0.15 * negation_score +
            0.10 * severity_score +
            0.05 * terminology_score
        )

        return {
            "medical_score": medical_score,
            "medical_concept_precision": concept_scores['concept_precision'],
            "medical_concept_recall": concept_scores['concept_recall'],
            "medical_concept_f1": concept_scores['concept_f1'],
            "medical_anatomy_score": anatomy_score,
            "medical_pathology_score": pathology_score,
            "medical_negation_score": negation_score,
            "medical_severity_score": severity_score,
            "medical_terminology_score": terminology_score,
            "reference_concepts": float(len(ref_concepts['all'])),
            "candidate_concepts": float(len(cand_concepts['all'])),
            "concept_overlap": float(len(ref_concepts['all'] & cand_concepts['all']))
        }

    def _preprocess_medical_text(self, text: str) -> str:
        """Lower-case, optionally expand abbreviations, and normalise spacing."""
        text = text.lower()

        if self.normalize_abbreviations:
            for abbrev, expansion in self.abbreviations.items():
                pattern = r'\b' + re.escape(abbrev) + r'\b'
                text = re.sub(pattern, expansion, text)

        # Insert a space after a period sitting between two letters ("a.b").
        text = re.sub(r'([a-z])\.([a-z])', r'\1. \2', text)

        # Collapse runs of whitespace.
        return re.sub(r'\s+', ' ', text).strip()

    def _extract_medical_concepts(self, text: str) -> Dict[str, Set[str]]:
        """Collect vocabulary terms found in already-preprocessed ``text``.

        Returns:
            Dict mapping each category ('anatomy', 'pathology', 'normal') to
            the set of matched terms, plus 'all' for their union. Multi-word
            terms are reported in their '_'-joined vocabulary form.
        """
        concepts: Dict[str, Set[str]] = {
            'anatomy': set(),
            'pathology': set(),
            'normal': set(),
            'all': set()
        }

        tokens = set(re.findall(r'\b\w+\b', text))

        for category, terms in self.concept_categories.items():
            found = {token for token in tokens if token in terms}

            # Multi-word entries ('upper_lobe') match whitespace-separated
            # phrases ("upper lobe") in the text.
            for term in terms:
                if '_' in term:
                    phrase = r'\s+'.join(re.escape(word) for word in term.split('_'))
                    if re.search(r'\b' + phrase + r'\b', text):
                        found.add(term)

            concepts[category] |= found
            concepts['all'] |= found

        return concepts

    def _calculate_concept_scores(self, ref_concepts: Dict[str, Set[str]],
                                  cand_concepts: Dict[str, Set[str]]) -> Dict[str, float]:
        """Precision/recall/F1 of the candidate's concepts against the reference's.

        Both sides empty counts as a perfect match; exactly one side empty
        scores zero.
        """
        ref_all = ref_concepts['all']
        cand_all = cand_concepts['all']

        if not ref_all and not cand_all:
            return {'concept_precision': 1.0, 'concept_recall': 1.0, 'concept_f1': 1.0}

        if not ref_all or not cand_all:
            return {'concept_precision': 0.0, 'concept_recall': 0.0, 'concept_f1': 0.0}

        overlap = ref_all & cand_all
        precision = len(overlap) / len(cand_all)
        recall = len(overlap) / len(ref_all)

        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0

        return {
            'concept_precision': precision,
            'concept_recall': recall,
            'concept_f1': f1
        }

    def _calculate_anatomy_score(self, ref_concepts: Dict[str, Set[str]],
                                 cand_concepts: Dict[str, Set[str]]) -> float:
        """Jaccard overlap of anatomy terms (1.0 if both sides have none)."""
        ref_anatomy = ref_concepts['anatomy']
        cand_anatomy = cand_concepts['anatomy']

        if not ref_anatomy and not cand_anatomy:
            return 1.0

        if not ref_anatomy or not cand_anatomy:
            return 0.0

        return len(ref_anatomy & cand_anatomy) / len(ref_anatomy | cand_anatomy)

    def _calculate_pathology_score(self, ref_concepts: Dict[str, Set[str]],
                                   cand_concepts: Dict[str, Set[str]]) -> float:
        """Score agreement on findings, penalising normal-vs-abnormal conflicts.

        Returns:
            1.0 if neither side reports pathology and both agree on whether
            normal terms are present; 0.0 if one report is purely normal
            (normal terms, no pathology) while the other reports pathology;
            Jaccard overlap if both report pathology; otherwise 0.5.
        """
        ref_pathology = ref_concepts['pathology']
        cand_pathology = cand_concepts['pathology']

        ref_normal = ref_concepts['normal']
        cand_normal = cand_concepts['normal']

        if not ref_pathology and not cand_pathology:
            # Both normal, or neither mentions normal/pathology at all.
            if bool(ref_normal) == bool(cand_normal):
                return 1.0
            return 0.5

        # Only a report that is *purely* normal contradicts a pathological
        # one; reports mixing normal and pathology terms (e.g. "stable
        # nodule") are compared on their pathology terms instead.
        ref_purely_normal = bool(ref_normal) and not ref_pathology
        cand_purely_normal = bool(cand_normal) and not cand_pathology
        if (ref_purely_normal and cand_pathology) or (cand_purely_normal and ref_pathology):
            return 0.0

        if ref_pathology and cand_pathology:
            return len(ref_pathology & cand_pathology) / len(ref_pathology | cand_pathology)

        # Pathology on one side, nothing classifiable on the other.
        return 0.5

    def _calculate_negation_score(self, reference: str, candidate: str) -> float:
        """Compare negation-cue counts: 1.0 equal, 0.5 off by one, else 0.0."""
        if not self.use_negation_detection:
            return 1.0

        ref_negations = self._detect_negations(reference)
        cand_negations = self._detect_negations(candidate)

        difference = abs(ref_negations - cand_negations)
        if difference == 0:
            return 1.0
        if difference == 1:
            return 0.5
        return 0.0

    def _detect_negations(self, text: str) -> int:
        """Count non-overlapping negation cues in ``text`` (case-insensitive).

        The patterns are combined into one alternation, longest first, so a
        phrase such as "no evidence of" counts once rather than also being
        counted by the shorter "no" pattern.
        """
        if not self.negation_patterns:
            # An empty alternation would match the empty string everywhere.
            return 0

        ordered = sorted(self.negation_patterns, key=len, reverse=True)
        combined = '|'.join(f'(?:{pattern})' for pattern in ordered)
        return len(re.findall(combined, text.lower()))

    def _calculate_severity_score(self, reference: str, candidate: str) -> float:
        """Score agreement of severity levels.

        1.0 if both have none or they match; 0.5 if only one side has one;
        otherwise 0.7 / 0.4 / 0.1 for level differences of 1 / 2 / 3+.
        """
        if not self.use_severity_matching:
            return 1.0

        ref_severity = self._extract_severity(reference)
        cand_severity = self._extract_severity(candidate)

        if ref_severity == 0 and cand_severity == 0:
            return 1.0

        if ref_severity == 0 or cand_severity == 0:
            return 0.5

        severity_diff = abs(ref_severity - cand_severity)

        if severity_diff == 0:
            return 1.0
        elif severity_diff == 1:
            return 0.7
        elif severity_diff == 2:
            return 0.4
        else:
            return 0.1

    def _extract_severity(self, text: str) -> int:
        """Return the level of the first ``severity_levels`` term found, else 0.

        Terms are tried in ``severity_levels`` iteration order; with the
        ascending order defined in ``__init__`` this yields the mildest level
        mentioned.
        NOTE(review): confirm "mildest" rather than "most severe" is intended.
        """
        text = text.lower()
        for severity_term, level in self.severity_levels.items():
            if re.search(r'\b' + re.escape(severity_term) + r'\b', text):
                return level

        return 0

    def _calculate_terminology_score(self, ref_text: str, cand_text: str) -> float:
        """Jaccard overlap of single-token vocabulary terms (1.0 if both empty)."""
        vocabulary: Set[str] = set().union(*self.concept_categories.values())

        ref_terms = {tok for tok in re.findall(r'\b\w+\b', ref_text) if tok in vocabulary}
        cand_terms = {tok for tok in re.findall(r'\b\w+\b', cand_text) if tok in vocabulary}

        if not ref_terms and not cand_terms:
            return 1.0

        return len(ref_terms & cand_terms) / len(ref_terms | cand_terms)

    def get_name(self) -> str:
        """Return the display name of this metric."""
        return "Medical"

    def get_description(self) -> str:
        """Return a description including the active feature flags."""
        return (f"Medical domain-specific scorer for radiology reports "
                f"(negation={self.use_negation_detection}, "
                f"severity={self.use_severity_matching}, "
                f"abbreviations={self.normalize_abbreviations})")


def test_medical_scorer():
    """Smoke-test MedicalScorer on a fixed set of report pairs.

    Prints every sub-score for each pair. Returns True when all pairs are
    scored without raising; otherwise prints the traceback and returns False.
    """
    logging.basicConfig(level=logging.INFO)
    test_logger = logging.getLogger("test")

    # (description, reference, candidate)
    cases = (
        ("Perfect medical match",
         "Normal chest radiograph with clear lung fields.",
         "Normal chest radiograph with clear lung fields."),
        ("Anatomy and pathology match",
         "Bilateral lung infiltrates consistent with pneumonia.",
         "Bilateral pulmonary infiltrates suggesting pneumonia."),
        ("Normal vs pathological mismatch",
         "Normal chest radiograph with clear lung fields.",
         "Bilateral pneumonia with consolidation."),
        ("Severity level difference",
         "Mild cardiomegaly with no acute abnormalities.",
         "Moderate heart enlargement without acute findings."),
        ("Negation consistency",
         "No evidence of pneumonia or effusion.",
         "No pneumonia or pleural fluid identified."),
        ("Medical abbreviation normalization",
         "CXR shows wnl findings in rul and lul.",
         "Chest x-ray shows within normal limits findings in right upper lobe and left upper lobe."),
    )

    # (printed label, key in the scores dict), in print order.
    reported = (
        ("Medical Score", "medical_score"),
        ("Concept F1", "medical_concept_f1"),
        ("Anatomy Score", "medical_anatomy_score"),
        ("Pathology Score", "medical_pathology_score"),
        ("Negation Score", "medical_negation_score"),
        ("Severity Score", "medical_severity_score"),
    )

    try:
        print("Testing Medical Scorer...")

        scorer = MedicalScorer(
            use_negation_detection=True,
            use_severity_matching=True,
            normalize_abbreviations=True,
            logger=test_logger,
        )

        print(f"Metric name: {scorer.get_name()}")
        print(f"Metric description: {scorer.get_description()}")

        print("\n--- Medical Score Tests ---")
        for number, (label, ref_text, cand_text) in enumerate(cases, start=1):
            # Score before printing the header, so a failure leaves no partial entry.
            result = scorer.calculate(ref_text, cand_text)

            print(f"\nTest {number}: {label}")
            print(f"Reference: '{ref_text}'")
            print(f"Candidate: '{cand_text}'")
            for score_label, score_key in reported:
                print(f"{score_label}: {result[score_key]:.4f}")

        print("\nAll Medical scorer tests completed!")
        return True

    except Exception as err:
        print(f"Test failed: {err}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    # Run the smoke test when executed as a script.
    success = test_medical_scorer()

    if success:
        print("\nMedical Scorer tests passed!")
    else:
        print("\nSome tests failed!")
        # Exit non-zero so shells and CI jobs can detect the failure.
        sys.exit(1)