import re
import string
import logging
from typing import Dict, List, Optional, Tuple
import unicodedata


# Cleans and normalizes radiologist report text for evaluation
class ReportTextCleaner:
    """Clean and normalize radiologist report text for evaluation.

    The pipeline in ``clean_text_comprehensive`` strips transcription
    artefacts (timestamps, speaker labels), replaces stray symbols, expands
    common radiology abbreviations and tidies whitespace and punctuation.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Build the pattern tables used by the cleaning steps.

        Args:
            logger: Logger used for batch progress and statistics; defaults
                to the module logger.
        """
        self.logger = logger or logging.getLogger(__name__)

        # Timestamp patterns, applied in list order. Bracketed forms come
        # first so the brackets are removed together with the time (the bare
        # HH:MM:SS pattern would otherwise consume the digits and leave "[]",
        # "()" or "<>"). The digit look-arounds keep the bare patterns from
        # matching inside longer digit runs such as the ratio "1:100".
        self.timestamp_patterns = [
            r'\[\d{1,2}:\d{2}:\d{2}\]',           # [HH:MM:SS]
            r'\(\d{1,2}:\d{2}:\d{2}\)',           # (HH:MM:SS)
            r'<\d{1,2}:\d{2}:\d{2}>',             # <HH:MM:SS>
            r'(?<!\d)\d{1,2}:\d{2}:\d{2}(?!\d)',  # HH:MM:SS
            r'(?<!\d)\d{1,2}:\d{2}(?!\d)',        # HH:MM
        ]

        # Patterns for speaker annotations (matched case-insensitively).
        self.speaker_patterns = [
            r'\b(doctor|dr|radiologist|physician|md)[\s]*:',
            r'\b(resident|fellow|attending)[\s]*:',
            r'\b(speaker|voice|person)\s*\d*[\s]*:',
            r'\b(male|female)\s*(voice|speaker)[\s]*:',
        ]

        # Medical abbreviation normalization. Keys are matched
        # case-insensitively as whole tokens in a single pass, longest key
        # first, so "w/o" is never split by "w/" and expansions are not
        # themselves re-expanded.
        self.medical_normalizations = {
            # Common variations
            'x-ray': 'radiograph',
            'xray': 'radiograph',
            'cxr': 'chest radiograph',
            'ap': 'anteroposterior',
            'pa': 'posteroanterior',
            'lat': 'lateral',
            'wnl': 'within normal limits',
            'nml': 'normal',
            'nl': 'normal',
            'abn': 'abnormal',
            'neg': 'negative',
            'pos': 'positive',
            'bil': 'bilateral',
            'bilat': 'bilateral',
            'r/o': 'rule out',
            'r o': 'rule out',
            'w/': 'with',
            'w/o': 'without',
            'h/o': 'history of',
            'pt': 'patient',
            'pts': 'patients',
            'yr': 'year',
            'yrs': 'years',
            'mo': 'month',
            'mos': 'months',
            'wk': 'week',
            'wks': 'weeks',
            'yo': 'year old',
            'y/o': 'year old',
            'y.o.': 'year old',
        }

        # Punctuation rules, applied in insertion order.
        self.punctuation_patterns = {
            r'\s+([,.;:!?])': r'\1',  # Space before punctuation
            # A run of marks ("...", ",,,", "!!!", or ". . ." once the spaces
            # are gone) collapses to its first mark.
            r'([,.;:!?])(?:\s*[,.;:!?])+': r'\1',
        }

    def remove_timestamps(self, text: str) -> str:
        """Delete every match of ``self.timestamp_patterns`` from *text*.

        Surrounding whitespace is left in place; ``normalize_whitespace``
        tidies it up.
        """
        cleaned_text = text
        for pattern in self.timestamp_patterns:
            cleaned_text = re.sub(pattern, '', cleaned_text)
        return cleaned_text

    def remove_speaker_annotations(self, text: str) -> str:
        """Delete speaker labels such as ``"Doctor:"`` or ``"Speaker 2:"``.

        Matching uses ``self.speaker_patterns`` case-insensitively.
        """
        cleaned_text = text
        for pattern in self.speaker_patterns:
            cleaned_text = re.sub(pattern, '', cleaned_text, flags=re.IGNORECASE)
        return cleaned_text

    def normalize_whitespace(self, text: str) -> str:
        """Collapse each whitespace run to a single space and trim both ends."""
        # ``\s`` already covers "\r" and "\n", so line breaks need no
        # separate handling.
        return re.sub(r'\s+', ' ', text).strip()

    def normalize_punctuation(self, text: str) -> str:
        """Tidy punctuation and make sure the text ends a sentence.

        Applies ``self.punctuation_patterns`` in order. Non-blank text that
        does not end in ``.``, ``!`` or ``?`` gets a trailing period; a
        dangling ``,``, ``;`` or ``:`` is dropped first. Empty or
        whitespace-only input is returned unchanged.
        """
        cleaned_text = text
        for pattern, replacement in self.punctuation_patterns.items():
            cleaned_text = re.sub(pattern, replacement, cleaned_text)

        body = cleaned_text.rstrip()
        if body and not body.endswith(('.', '!', '?')):
            # "findings:" -> "findings." rather than "findings:."
            cleaned_text = body.rstrip(',;:') + '.'
        return cleaned_text

    def normalize_medical_terms(self, text: str, lowercase: bool = True) -> str:
        """Expand the abbreviations in ``self.medical_normalizations``.

        Args:
            text: Input text.
            lowercase: When True (the default), the whole text is lowercased
                first. When False, the original casing is kept; abbreviations
                still match case-insensitively and expansions are inserted
                as written in the table.

        Returns:
            The text with each whole-token abbreviation replaced in a single
            pass.
        """
        cleaned_text = text.lower() if lowercase else text
        if not self.medical_normalizations:
            return cleaned_text

        lookup = {abbrev.lower(): expansion
                  for abbrev, expansion in self.medical_normalizations.items()}
        # Longest first so e.g. "w/o" wins over its prefix "w/". The
        # (?<!\w)/(?!\w) guards replace \b, which fails next to keys that
        # start or end with punctuation ("w/ ", "y.o. ").
        alternation = '|'.join(
            re.escape(abbrev) for abbrev in sorted(lookup, key=len, reverse=True)
        )
        pattern = r'(?<!\w)(?:' + alternation + r')(?!\w)'
        return re.sub(
            pattern,
            lambda match: lookup.get(match.group(0).lower(), match.group(0)),
            cleaned_text,
            flags=re.IGNORECASE,
        )

    def remove_special_characters(self, text: str, keep_medical: bool = True) -> str:
        """Replace characters outside an allow-list with spaces.

        The text is NFKD-normalized and combining marks are dropped, so
        accented letters fold to their base letter ("radiológica" ->
        "radiologica"). Otherwise the marks, which ``\\w`` does not match,
        would be replaced by spaces and split the word.

        Always kept: word characters, whitespace and ``. , ; : ! ? ( ) [ ] { } " '``.
        With ``keep_medical`` also: ``% + - / ° μ × ÷ ± ≤ ≥ < >``.
        """
        decomposed = unicodedata.normalize('NFKD', text)
        text = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))

        allowed = r'\w\s.,;:!?()\[\]{}"\''
        if keep_medical:
            allowed += r'%+\-/°μ×÷±≤≥<>'
        return re.sub('[^' + allowed + ']', ' ', text)

    def clean_text_comprehensive(self, text: str,
                               remove_timestamps: bool = True,
                               remove_speakers: bool = True,
                               normalize_medical: bool = True,
                               normalize_punctuation: bool = True,
                               preserve_case: bool = False) -> str:
        """Run the full cleaning pipeline on a single text.

        Args:
            text: Raw report text; falsy input yields ``""``.
            remove_timestamps: Strip timestamps.
            remove_speakers: Strip speaker labels.
            normalize_medical: Expand medical abbreviations.
            normalize_punctuation: Tidy punctuation and add a final period.
            preserve_case: Keep the original casing instead of lowercasing.

        Returns:
            The cleaned, whitespace-normalized text.
        """
        if not text:
            return ""

        cleaned_text = text

        if remove_timestamps:
            cleaned_text = self.remove_timestamps(cleaned_text)

        if remove_speakers:
            cleaned_text = self.remove_speaker_annotations(cleaned_text)

        cleaned_text = self.remove_special_characters(cleaned_text, keep_medical=True)
        cleaned_text = self.normalize_whitespace(cleaned_text)

        if normalize_medical:
            # preserve_case is honoured here as well; abbreviation matching
            # is case-insensitive either way.
            cleaned_text = self.normalize_medical_terms(cleaned_text,
                                                        lowercase=not preserve_case)
        elif not preserve_case:
            cleaned_text = cleaned_text.lower()

        if normalize_punctuation:
            cleaned_text = self.normalize_punctuation(cleaned_text)

        # Expansions and punctuation edits can leave stray spacing.
        return self.normalize_whitespace(cleaned_text)

    def clean_batch(self, text_batch: Dict[str, str], **kwargs) -> Dict[str, Dict]:
        """Clean every text in *text_batch* and collect per-item statistics.

        Args:
            text_batch: Mapping of text id -> raw text.
            **kwargs: Forwarded unchanged to ``clean_text_comprehensive``.

        Returns:
            Mapping of text id -> result dict holding the original and cleaned
            text, their lengths, ``length_reduction``, ``word_count_change``
            and ``cleaning_applied``. If cleaning an item raises, that item
            keeps its original text, zeroed statistics and an ``error``
            message, and the rest of the batch is still processed.
        """
        cleaned_batch = {}

        self.logger.info(f"Cleaning {len(text_batch)} texts...")

        for text_id, original_text in text_batch.items():
            # Computed up front so a None entry cannot crash the error path.
            original_length = len(original_text) if original_text is not None else 0
            try:
                cleaned_text = self.clean_text_comprehensive(original_text, **kwargs)
                stats = self._calculate_cleaning_stats(original_text, cleaned_text)
            except Exception as e:
                self.logger.error(f"Error cleaning text {text_id}: {e}")
                cleaned_batch[text_id] = {
                    "original_text": original_text,
                    "cleaned_text": original_text,
                    "original_length": original_length,
                    "cleaned_length": original_length,
                    "length_reduction": 0,
                    "word_count_change": 0,
                    "cleaning_applied": [],
                    "error": str(e)
                }
                continue

            cleaned_batch[text_id] = {
                "original_text": original_text,
                "cleaned_text": cleaned_text,
                "original_length": original_length,
                "cleaned_length": len(cleaned_text),
                "length_reduction": stats["length_reduction"],
                "word_count_change": stats["word_count_change"],
                "cleaning_applied": stats["cleaning_applied"]
            }

        self._log_batch_cleaning_stats(cleaned_batch)

        return cleaned_batch

    def _calculate_cleaning_stats(self, original: str, cleaned: str) -> Dict:
        """Compare *original* and *cleaned* and infer which steps changed it.

        Returns:
            Dict with ``length_reduction`` (character delta),
            ``word_count_change`` (whitespace-token delta) and
            ``cleaning_applied`` (list of heuristic step labels).
        """
        original = original or ""
        cleaned = cleaned or ""

        length_reduction = len(original) - len(cleaned)
        word_count_change = len(original.split()) - len(cleaned.split())

        cleaning_applied = []

        if any(re.search(pattern, original) for pattern in self.timestamp_patterns):
            cleaning_applied.append("timestamps_removed")

        if any(re.search(pattern, original, re.IGNORECASE) for pattern in self.speaker_patterns):
            cleaning_applied.append("speakers_removed")

        if original != original.lower() and cleaned == cleaned.lower():
            cleaning_applied.append("case_normalized")

        if re.search(r'\s{2,}', original) and not re.search(r'\s{2,}', cleaned):
            cleaning_applied.append("whitespace_normalized")

        if len(re.findall(r'[.!?]', original)) != len(re.findall(r'[.!?]', cleaned)):
            cleaning_applied.append("punctuation_normalized")

        return {
            "length_reduction": length_reduction,
            "word_count_change": word_count_change,
            "cleaning_applied": cleaning_applied
        }

    def _log_batch_cleaning_stats(self, cleaned_batch: Dict[str, Dict]) -> None:
        """Log aggregate statistics for a ``clean_batch`` result (no-op if empty)."""
        if not cleaned_batch:
            return

        results = list(cleaned_batch.values())
        total_texts = len(results)
        successful_cleanings = sum(1 for result in results if "error" not in result)

        avg_length_reduction = sum(r["length_reduction"] for r in results) / total_texts
        avg_word_reduction = sum(r["word_count_change"] for r in results) / total_texts

        operation_counts: Dict[str, int] = {}
        for result in results:
            for op in result.get("cleaning_applied", []):
                operation_counts[op] = operation_counts.get(op, 0) + 1

        self.logger.info("Text cleaning statistics:")
        self.logger.info(f"  Total texts: {total_texts}, Successful: {successful_cleanings}")
        self.logger.info(f"  Average length reduction: {avg_length_reduction:.1f} characters")
        self.logger.info(f"  Average word reduction: {avg_word_reduction:.1f} words")

        if operation_counts:
            self.logger.info(f"  Cleaning operations applied: {operation_counts}")

    def validate_cleaned_text(self, original: str, cleaned: str) -> Tuple[bool, List[str]]:
        """Heuristically check that cleaning did not damage the text.

        Returns:
            ``(is_valid, issues)``, where ``issues`` lists labels such as
            ``"excessive_text_reduction"`` (cleaned is under 30% of the
            original length), ``"text_completely_removed"``,
            ``"medical_content_lost"`` (under half the keyword hits survive),
            ``"missing_sentence_ending"`` and
            ``"excessive_whitespace_remaining"``.
        """
        issues = []

        if len(cleaned) < len(original) * 0.3:
            issues.append("excessive_text_reduction")

        if original and not cleaned:
            issues.append("text_completely_removed")

        original_medical_words = self._count_medical_words(original)
        cleaned_medical_words = self._count_medical_words(cleaned)

        if original_medical_words > 0 and cleaned_medical_words < original_medical_words * 0.5:
            issues.append("medical_content_lost")

        if cleaned and not re.search(r'[.!?]$', cleaned.strip()):
            issues.append("missing_sentence_ending")

        if re.search(r'\s{2,}', cleaned):
            issues.append("excessive_whitespace_remaining")

        is_valid = len(issues) == 0
        return is_valid, issues

    def _count_medical_words(self, text: str) -> int:
        """Count whole-word, case-insensitive hits of a fixed keyword list."""
        medical_keywords = [
            'chest', 'lung', 'heart', 'radiograph', 'x-ray', 'patient',
            'normal', 'abnormal', 'examination', 'findings', 'impression',
            'bilateral', 'cardiopulmonary', 'thoracic', 'pulmonary',
            'cardiac', 'mediastinal', 'pleural', 'pneumonia', 'pneumothorax'
        ]

        text_lower = text.lower()
        count = 0

        for keyword in medical_keywords:
            count += len(re.findall(r'\b' + re.escape(keyword) + r'\b', text_lower))

        return count


# Test the Report Text Cleaner functionality
def test_text_cleaner():
    """Run the cleaner over representative report samples and print results.

    Returns:
        True if every sample was processed and no ``clean_batch`` item
        recorded an error; False otherwise.
    """
    import traceback

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("test")

    test_samples = [
        "Doctor: 09:15:30 The chest X-ray shows normal lung fields. Radiologist: 09:16:45 No acute abnormalities seen.",
        "Normal   chest  radiograph  .  .  .  The  lungs  are  clear    bilaterally,,,  heart  size  WNL.",
        "CXR shows bil clear lungs. Pt is 45 y/o w/ h/o COPD. R/O PNA. Heart size nl.",
        "Chest X-ray: NORMAL!!! No acute findings... Patient's condition: STABLE @@@ Temperature: 98.6°F",
        "CHEST RADIOGRAPH   demonstrates   NORMAL   cardiac   silhouette   AND   clear   LUNG   fields   bilaterally   .",
        "Normal chest radiograph with clear lung fields bilaterally.",
        "AP and lateral chest radiographs demonstrate normal cardiac silhouette. The lungs are clear bilaterally. No pleural effusions or pneumothorax. Bony structures are intact.",
        "Speaker 1: 12:34:56 Um, the, uh, chest... Speaker 2: 12:35:10 shows normal... Speaker 1: 12:35:20 findings.",
        "",
        "Radiológica normal. Pulmões limpos. Coração—tamanho normal."
    ]

    try:
        cleaner = ReportTextCleaner(logger)

        print("Testing Report Text Cleaner...")

        for i, sample in enumerate(test_samples, 1):
            print(f"\n--- Test Sample {i} ---")
            print(f"Original: '{sample[:60]}{'...' if len(sample) > 60 else ''}'")

            cleaned = cleaner.clean_text_comprehensive(sample)
            print(f"Cleaned:  '{cleaned[:60]}{'...' if len(cleaned) > 60 else ''}'")

            length_change = len(sample) - len(cleaned)
            print(f"Length change: {length_change} characters")

            is_valid, issues = cleaner.validate_cleaned_text(sample, cleaned)
            print(f"Valid: {is_valid}")
            if issues:
                print(f"Issues: {issues}")

        print("\n--- Batch Cleaning Test ---")
        batch_dict = {f"sample_{i}": sample for i, sample in enumerate(test_samples, 1)}
        batch_results = cleaner.clean_batch(batch_dict)

        print(f"Batch processed: {len(batch_results)} items")

        # clean_batch swallows per-item exceptions and records them under
        # "error"; report them instead of claiming success.
        failed_ids = [text_id for text_id, result in batch_results.items() if "error" in result]
        if failed_ids:
            print(f"Batch cleaning errors in: {failed_ids}")
            return False

        print("\nAll text cleaning tests completed!")
        return True

    except Exception as e:
        print(f"Test failed: {e}")
        traceback.print_exc()
        return False


# Run tests when script is executed directly
if __name__ == "__main__":
    success = test_text_cleaner()

    if success:
        print("\nText Cleaner tests passed!")
    else:
        print("\nSome tests failed!")
        # A non-zero exit status lets shells and CI detect the failure.
        raise SystemExit(1)