import re
import logging
from typing import Dict, List, Optional, Union, Tuple
import json


# Extracts and processes text from radiologist transcript data
class ReportTextExtractor:
    """Extract report text from radiologist transcript records.

    A transcript record is a dict. Text is looked up in this order:

    1. the ``"full_text"`` entry,
    2. the ``"time_stamped_text"`` entry (a list of segments joined with
       single spaces),
    3. the first non-empty entry among ``"text"``, ``"report"``,
       ``"content"``, ``"transcript"`` and ``"description"``.

    The same lookup drives both the returned text and the reported
    ``extraction_method``, so the two always agree.
    """

    # Top-level keys consulted, in order, once "full_text" and
    # "time_stamped_text" have produced nothing.
    _FALLBACK_TEXT_FIELDS = ("text", "report", "content", "transcript", "description")

    # Keys consulted, in order, inside each time-stamped segment dict.
    _SEGMENT_TEXT_FIELDS = ("text", "content", "transcript", "words", "phrase")

    # Prefix of the extraction-method label used for fallback fields.
    _FALLBACK_PREFIX = "fallback_"

    # Patterns used by _detect_text_issues / validate_extracted_text,
    # compiled once instead of on every call.
    _NON_ASCII_RE = re.compile(r'[^\x00-\x7F]+')
    _CLOCK_TIME_RE = re.compile(r'\d{1,2}:\d{2}')
    _SENTENCE_END_RE = re.compile(r'[.!?]')

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Create an extractor.

        Args:
            logger: Logger to report through; defaults to this module's logger.
        """
        self.logger = logger or logging.getLogger(__name__)

        # Abbreviation -> expansion lookup. Kept as a public attribute; none
        # of the methods in this class read it.
        self.medical_abbreviations = {
            'AP': 'anteroposterior',
            'PA': 'posteroanterior',
            'LAT': 'lateral',
            'CXR': 'chest X-ray',
            'NPO': 'nothing by mouth',
            'SOB': 'shortness of breath',
            'DOE': 'dyspnea on exertion',
            'CHF': 'congestive heart failure',
            'COPD': 'chronic obstructive pulmonary disease',
            'PE': 'pulmonary embolism',
            'PNA': 'pneumonia',
            'PTX': 'pneumothorax',
            'R/O': 'rule out',
            'WNL': 'within normal limits'
        }

    def _locate_text(self, transcript_data: Dict) -> Tuple[str, str]:
        """Find the report text and name the source it came from.

        Performs no logging, so it is safe to call more than once.

        Returns:
            ``(text, method)`` where ``method`` is ``"full_text"``,
            ``"timestamp_reconstruction"``, ``"fallback_<field>"`` or
            ``"none"`` (in which case ``text`` is ``""``).
        """
        value = transcript_data.get("full_text")
        if value:
            # str() mirrors the fallback fields and keeps a non-string
            # value from raising on .strip().
            text = str(value).strip()
            if text:
                return text, "full_text"

        if "time_stamped_text" in transcript_data:
            text = self.reconstruct_from_timestamps(transcript_data["time_stamped_text"])
            if text:
                return text, "timestamp_reconstruction"

        for field in self._FALLBACK_TEXT_FIELDS:
            value = transcript_data.get(field)
            if value:
                text = str(value).strip()
                if text:
                    return text, f"{self._FALLBACK_PREFIX}{field}"

        return "", "none"

    def _log_extraction(self, text: str, method: str) -> None:
        """Log the outcome of one lookup performed by ``_locate_text``."""
        length = len(text)
        if method == "full_text":
            self.logger.debug("Extracted full_text: %d characters", length)
        elif method == "timestamp_reconstruction":
            self.logger.debug("Reconstructed from timestamps: %d characters", length)
        elif method.startswith(self._FALLBACK_PREFIX):
            field = method[len(self._FALLBACK_PREFIX):]
            self.logger.debug("Extracted from %s: %d characters", field, length)
        else:
            self.logger.warning("No text found in transcript data")

    def extract_full_text(self, transcript_data: Dict) -> str:
        """Return the report text of one transcript record.

        Args:
            transcript_data: Transcript record (a dict).

        Returns:
            The stripped text, or ``""`` (after logging a warning) when no
            source yields any text.
        """
        text, method = self._locate_text(transcript_data)
        self._log_extraction(text, method)
        return text

    def reconstruct_from_timestamps(self, time_stamped_data: List[Dict]) -> str:
        """Join time-stamped segments into a single string.

        Each segment may be a plain string or a dict; for a dict the first
        truthy entry among ``"text"``, ``"content"``, ``"transcript"``,
        ``"words"`` and ``"phrase"`` is used. Segments of any other type,
        and segments that are empty after stripping, are skipped.

        Args:
            time_stamped_data: List (or tuple) of segments.

        Returns:
            The segments joined with single spaces, or ``""`` when the input
            is empty or is not a list/tuple.
        """
        if not time_stamped_data:
            return ""

        # A bare string or dict is iterable too, but iterating it would
        # yield characters / keys rather than segments.
        if not isinstance(time_stamped_data, (list, tuple)):
            self.logger.warning(
                "Ignoring time-stamped data of unexpected type %s",
                type(time_stamped_data).__name__,
            )
            return ""

        pieces = []
        for segment in time_stamped_data:
            if isinstance(segment, str):
                piece = segment.strip()
            elif isinstance(segment, dict):
                piece = ""
                for field in self._SEGMENT_TEXT_FIELDS:
                    if segment.get(field):
                        piece = str(segment[field]).strip()
                        break
            else:
                continue

            # Dropping empty pieces avoids doubled spaces in the result.
            if piece:
                pieces.append(piece)

        return " ".join(pieces)

    def extract_text_with_metadata(
        self, transcript_data: Dict
    ) -> Dict[str, Union[str, int, bool, List[str]]]:
        """Extract the report text together with descriptive metadata.

        Args:
            transcript_data: Transcript record (a dict).

        Returns:
            Dict with the keys ``text``, ``text_length``, ``is_empty``,
            ``has_full_text_field``, ``has_timestamps``, ``timestamp_count``,
            ``extraction_method``, ``word_count``, ``sentence_count`` and
            ``potential_issues``.
        """
        full_text, method = self._locate_text(transcript_data)
        self._log_extraction(full_text, method)

        # Only a real list/tuple of segments has a meaningful count; None or
        # a malformed value counts as zero instead of raising or reporting a
        # character count.
        segments = transcript_data.get("time_stamped_text")
        timestamp_count = len(segments) if isinstance(segments, (list, tuple)) else 0

        metadata = {
            "text": full_text,
            "text_length": len(full_text),
            "is_empty": len(full_text) == 0,
            "has_full_text_field": "full_text" in transcript_data,
            "has_timestamps": "time_stamped_text" in transcript_data,
            "timestamp_count": timestamp_count,
            "extraction_method": method,
            "word_count": len(full_text.split()) if full_text else 0,
            # Approximation: counts the non-blank pieces between periods.
            "sentence_count": len([s for s in full_text.split('.') if s.strip()]) if full_text else 0
        }

        metadata["potential_issues"] = self._detect_text_issues(full_text, transcript_data)

        return metadata

    def _determine_extraction_method(self, transcript_data: Dict, extracted_text: str) -> str:
        """Name the source that ``extract_full_text`` would use.

        Args:
            transcript_data: Transcript record (a dict).
            extracted_text: Text previously extracted from the record.

        Returns:
            ``"none"`` when ``extracted_text`` is empty; otherwise
            ``"full_text"``, ``"timestamp_reconstruction"`` or
            ``"fallback_<field>"``; ``"unknown"`` when text was supplied but
            no source in the record yields any.
        """
        if not extracted_text:
            return "none"

        _, method = self._locate_text(transcript_data)
        return "unknown" if method == "none" else method

    def _detect_text_issues(self, text: str, transcript_data: Dict) -> List[str]:
        """Return heuristic quality flags for extracted text.

        Args:
            text: The extracted text.
            transcript_data: The source record (not inspected here).

        Returns:
            List of issue labels; ``["empty_text"]`` alone for empty text.
        """
        issues = []

        if not text:
            issues.append("empty_text")
            return issues

        if len(text) < 10:
            issues.append("very_short_text")

        if len(text) > 5000:
            issues.append("very_long_text")

        if self._NON_ASCII_RE.search(text):
            issues.append("non_ascii_characters")

        words = text.lower().split()
        if len(words) > 10:
            word_freq = {}
            for word in words:
                word_freq[word] = word_freq.get(word, 0) + 1

            # Flag text in which one word makes up more than 30% of all words.
            if max(word_freq.values()) > len(words) * 0.3:
                issues.append("excessive_repetition")

        if not text.rstrip().endswith(('.', '!', '?')):
            issues.append("incomplete_sentence")

        # Anything shaped like H:MM / HH:MM is flagged.
        if self._CLOCK_TIME_RE.search(text):
            issues.append("timestamp_artifacts")

        if any(marker in text for marker in ('{', '}', '[', ']', '":')):
            issues.append("json_artifacts")

        return issues

    def extract_batch(self, transcript_batch: Dict[str, Dict]) -> Dict[str, Dict]:
        """Extract text and metadata for every transcript in a batch.

        Args:
            transcript_batch: Mapping of image id to transcript entry. When
                an entry has a ``"raw_data"`` key, that value is used as the
                transcript record; otherwise the entry itself is used.

        Returns:
            Mapping of image id to the result of
            ``extract_text_with_metadata`` extended with ``image_id``,
            ``source_type``, ``file_path`` and ``original_text_length``.
            An entry that raises is logged and replaced by an error record
            whose ``extraction_method`` is ``"error"``.
        """
        extracted_batch = {}

        self.logger.info("Extracting text from %d transcripts...", len(transcript_batch))

        for image_id, transcript_data in transcript_batch.items():
            try:
                raw_data = transcript_data.get("raw_data", transcript_data)
                extraction_result = self.extract_text_with_metadata(raw_data)

                extraction_result.update({
                    "image_id": image_id,
                    "source_type": transcript_data.get("source_type", "unknown"),
                    "file_path": transcript_data.get("file_path", ""),
                    "original_text_length": transcript_data.get("text_length", 0)
                })

                extracted_batch[image_id] = extraction_result

            except Exception as e:
                # Batch boundary: one malformed entry must not abort the rest.
                self.logger.error("Error extracting text for %s: %s", image_id, e)
                extracted_batch[image_id] = {
                    "image_id": image_id,
                    "text": "",
                    "text_length": 0,
                    "is_empty": True,
                    "extraction_method": "error",
                    "potential_issues": ["extraction_error"],
                    "error": str(e)
                }

        self._log_extraction_statistics(extracted_batch)

        return extracted_batch

    def _log_extraction_statistics(self, extracted_batch: Dict[str, Dict]) -> None:
        """Log summary counts for a batch produced by ``extract_batch``.

        Logs nothing when the batch is empty.
        """
        if not extracted_batch:
            return

        results = list(extracted_batch.values())
        total = len(results)

        # A result without an "is_empty" flag is counted as empty.
        text_lengths = [
            result.get("text_length", 0)
            for result in results
            if not result.get("is_empty", True)
        ]
        successful = len(text_lengths)
        empty = total - successful

        if text_lengths:
            avg_length = sum(text_lengths) / len(text_lengths)
            min_length = min(text_lengths)
            max_length = max(text_lengths)
        else:
            avg_length = min_length = max_length = 0

        methods = {}
        issue_counts = {}
        for result in results:
            method = result.get("extraction_method", "unknown")
            methods[method] = methods.get(method, 0) + 1
            for issue in result.get("potential_issues", []):
                issue_counts[issue] = issue_counts.get(issue, 0) + 1

        self.logger.info("Text extraction statistics:")
        self.logger.info("  Total: %s, Successful: %s, Empty: %s", total, successful, empty)
        self.logger.info(
            "  Text lengths - Avg: %.1f, Min: %s, Max: %s",
            avg_length, min_length, max_length,
        )
        self.logger.info("  Extraction methods: %s", methods)

        if issue_counts:
            self.logger.info("  Issues found: %s", issue_counts)

    def validate_extracted_text(self, text: str, min_length: int = 5,
                                max_length: int = 10000) -> Tuple[bool, List[str]]:
        """Check extracted text against simple quality rules.

        Args:
            text: Text to validate; ``None`` is treated as ``""``.
            min_length: Minimum accepted length in characters.
            max_length: Maximum accepted length in characters.

        Returns:
            ``(is_valid, issues)``; ``is_valid`` is True only when ``issues``
            is empty.
        """
        if text is None:
            text = ""

        issues = []

        if len(text) < min_length:
            issues.append(f"text_too_short (< {min_length} chars)")

        if len(text) > max_length:
            issues.append(f"text_too_long (> {max_length} chars)")

        medical_terms = (
            "chest", "lung", "heart", "normal", "abnormal", "radiograph",
            "x-ray", "impression", "findings", "patient", "examination"
        )

        text_lower = text.lower()
        has_medical_content = any(term in text_lower for term in medical_terms)

        # Texts of 20 characters or fewer are exempt from the keyword check.
        if not has_medical_content and len(text) > 20:
            issues.append("no_medical_content_detected")

        if text and not self._SENTENCE_END_RE.search(text):
            issues.append("no_sentence_endings")

        is_valid = len(issues) == 0
        return is_valid, issues


# Test the Report Text Extractor functionality
def test_text_extractor():
    """Smoke-run ReportTextExtractor over a handful of sample transcripts.

    Prints the extracted text, metadata and validation outcome for each
    sample. Returns True when every sample is processed without raising,
    False (after printing the error) otherwise.
    """
    import logging

    logging.basicConfig(level=logging.INFO)
    test_logger = logging.getLogger("test")

    # Record that carries its text directly in "full_text".
    direct_sample = {
        "full_text": "Normal chest radiograph. The lungs are clear bilaterally. The heart size is normal. No acute abnormalities identified.",
        "time_stamped_text": []
    }
    # Record whose text has to be rebuilt from time-stamped segments.
    segmented_sample = {
        "full_text": "",
        "time_stamped_text": [
            {"text": "Normal chest"},
            {"text": "radiograph showing"},
            {"text": "clear lungs"}
        ]
    }
    # Record that only offers a fallback field.
    fallback_sample = {
        "report": "Chest X-ray demonstrates normal cardiac silhouette and clear lung fields.",
        "time_stamped_text": []
    }
    # Malformed record.
    malformed_sample = {
        "full_text": None,
        "time_stamped_text": "not_a_list"
    }

    samples = [direct_sample, segmented_sample, fallback_sample, malformed_sample]

    try:
        report_extractor = ReportTextExtractor(test_logger)

        print("Testing Report Text Extractor...")

        for i, record in enumerate(samples, 1):
            print(f"\n--- Test Sample {i} ---")

            text = report_extractor.extract_full_text(record)
            print(f"Extracted text: '{text[:50]}{'...' if len(text) > 50 else ''}'")

            metadata = report_extractor.extract_text_with_metadata(record)
            print(f"Text length: {metadata['text_length']}")
            print(f"Extraction method: {metadata['extraction_method']}")
            print(f"Issues: {metadata['potential_issues']}")

            is_valid, issues = report_extractor.validate_extracted_text(text)
            print(f"Valid: {is_valid}, Validation issues: {issues}")

        print("\nAll text extraction tests completed!")
        return True

    except Exception as e:
        print(f"Test failed: {e}")
        return False


# Run tests when script is executed directly
if __name__ == "__main__":
    # Run the smoke test and report its overall outcome on stdout.
    success = test_text_extractor()

    print("\nText Extractor tests passed!" if success else "\nSome tests failed!")