import pandas as pd
import json
import re
import os
from pathlib import Path
from collections import Counter, defaultdict
from typing import Set, List, Dict, Tuple
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
import string

# NLTK data needed by this module, as (lookup path for nltk.data.find,
# package id for nltk.download). 'punkt_tab' is what sent_tokenize /
# word_tokenize load in NLTK >= 3.8.2; 'punkt' is kept for older releases.
# On an NLTK release whose index has no 'punkt_tab' package, nltk.download
# reports the failure and returns False instead of raising.
_REQUIRED_NLTK_RESOURCES = (
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('corpora/stopwords', 'stopwords'),
)

for _resource_path, _package_id in _REQUIRED_NLTK_RESOURCES:
    try:
        nltk.data.find(_resource_path)
    except LookupError:
        # Resource is not installed locally: fetch it once.
        nltk.download(_package_id)

class MedicalKeywordExtractor:
    """Extract, count and categorize medical keywords from a training dataset.

    Keywords come from three sources per dataset row: a report text file, a
    transcript JSON file, and a pipe-separated condition label. Counts are
    document frequencies: every keyword is counted at most once per report or
    transcript, because each file contributes a *set* of keywords.
    """

    # Substrings whose presence marks a word/phrase as medical. Matching is
    # by substring on purpose-built terms, so compounds such as "bibasilar"
    # or "perihilar" are accepted; the flip side is that short entries such
    # as 'rib' also match unrelated words that merely contain them.
    _MEDICAL_INDICATORS = (
        'pulmonary', 'cardiac', 'thoracic', 'pleural', 'bronchial', 'vascular',
        'lobe', 'lung', 'heart', 'chest', 'rib', 'diaphragm', 'mediastinal',
        'hilar', 'basilar', 'apical', 'lateral', 'anterior', 'posterior',
        'consolidation', 'opacity', 'effusion', 'pneumonia', 'atelectasis',
        'cardiomegaly', 'edema', 'mass', 'nodule', 'lesion', 'infiltrate',
        'emphysema', 'fibrosis', 'calcification', 'pneumothorax',
    )

    # Substrings that place an otherwise unmatched keyword in 'devices'.
    _DEVICE_INDICATORS = (
        'tube', 'line', 'catheter', 'device', 'implant', 'pacemaker', 'stent',
    )

    # Regex-backed categories, in the order they are tried; first match wins.
    _PATTERN_CATEGORY_ORDER = ('anatomical', 'pathological', 'descriptive')

    # Spans stripped from (already lower-cased) text before extraction:
    # "dr <name>", numeric dates, clock times and "patient id: <token>".
    _NOISE_PATTERNS = tuple(
        re.compile(pattern) for pattern in (
            r'\bdr\.?\s+\w+\b',
            r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b',
            r'\b\d{1,2}:\d{2}\b',
            r'\bpatient\s+id\s*:?\s*\w+\b',
        )
    )

    def __init__(self):
        """Set up stop-word lists, regex patterns and empty result counters."""
        # General English stop words from the NLTK corpus.
        self.stop_words = set(stopwords.words('english'))

        # Domain words that are too generic to be useful keywords.
        self.medical_stop_words = {
            'patient', 'patients', 'study', 'studies', 'image', 'images', 'x-ray', 'xray',
            'chest', 'radiograph', 'examination', 'exam', 'finding', 'findings',
            'show', 'shows', 'shown', 'demonstrate', 'demonstrates', 'appear', 'appears',
            'seen', 'visualized', 'noted', 'observed', 'present', 'absent', 'normal',
            'abnormal', 'within', 'limits', 'stable', 'unchanged', 'compared', 'prior',
            'previous', 'again', 'still', 'continue', 'continues', 'recommend', 'suggests'
        }

        self.all_stop_words = self.stop_words.union(self.medical_stop_words)

        # Placeholder per-category sets; no method of this class reads or
        # fills them (kept so the attribute stays available to callers).
        self.keyword_categories = {
            'anatomical': set(),
            'pathological': set(),
            'descriptive': set(),
            'devices': set(),
            'procedures': set()
        }

        # Regexes used both to pull phrases out of text and to categorize
        # already-extracted keywords. All groups are non-capturing, so
        # re.findall returns whole-match strings.
        self.medical_patterns = {
            'anatomical': [
                r'\b(?:left|right|bilateral)\s+(?:lung|lobe|hemithorax|pleura|diaphragm)\b',
                r'\b(?:upper|middle|lower)\s+lobe\b',
                r'\bcardiac\s+silhouette\b',
                r'\b(?:mediastinum|hilum|hila|trachea|bronchi|bronchus)\b',
                r'\bcostophrenic\s+angle\b',
                r'\bpulmonary\s+vasculature\b',
                r'\b(?:aorta|heart|pericardium)\b'
            ],
            'pathological': [
                r'\b(?:pneumonia|consolidation|infiltrat\w+|opacity|opacities)\b',
                r'\b(?:atelectasis|pneumothorax|effusion)\b',
                r'\b(?:cardiomegaly|edema|mass|nodule|lesion)\b',
                r'\b(?:emphysema|fibrosis|scarring|calcification)\b',
                r'\b(?:pneumomediastinum|pneumopericardium)\b'
            ],
            'descriptive': [
                r'\b(?:mild|moderate|severe|extensive|minimal|small|large)\b',
                r'\b(?:acute|chronic|subacute|new|old|stable)\b',
                r'\b(?:bilateral|unilateral|focal|diffuse|multifocal)\b',
                r'\b(?:basilar|apical|peripheral|central|hilar)\b',
                r'\b(?:clear|unremarkable|improved|worsened|progressive)\b'
            ]
        }

        # Keyword frequencies per source plus the union of all sources.
        self.extracted_keywords = {
            'reports': Counter(),
            'transcripts': Counter(),
            'conditions': Counter(),
            'combined': Counter()
        }

        # Running totals and human-readable error messages.
        self.processing_stats = {
            'reports_processed': 0,
            'transcripts_processed': 0,
            'total_samples': 0,
            'errors': []
        }

    def clean_text(self, text: str) -> str:
        """Lower-case ``text`` and strip non-clinical noise.

        Removes doctor names ("dr <word>"), numeric dates, clock times and
        "patient id" tokens, then collapses runs of whitespace.

        Args:
            text: Raw text. Empty or non-string values (e.g. ``None``, NaN)
                yield an empty string instead of raising.

        Returns:
            The cleaned, single-spaced, lower-case text.
        """
        if not text or not isinstance(text, str):
            return ""

        text = text.lower()
        for noise in self._NOISE_PATTERNS:
            text = noise.sub('', text)

        # Collapse whitespace last so the removals above leave no gaps.
        return re.sub(r'\s+', ' ', text).strip()

    def extract_medical_phrases(self, text: str) -> Set[str]:
        """Return multi-word medical phrases and pattern matches found in ``text``.

        Two passes are made: every regex in ``self.medical_patterns`` is run
        over the text, and every 2- to 4-word n-gram of each sentence is kept
        when ``is_likely_medical_term`` accepts it. N-grams are built after
        dropping tokens that are non-alphabetic or shorter than three
        characters, so the words of a phrase need not be adjacent in the text.

        Args:
            text: Text to scan (normally the output of ``clean_text``).

        Returns:
            A set of lower-case phrases; empty for empty input.
        """
        phrases: Set[str] = set()

        if not text:
            return phrases

        for patterns in self.medical_patterns.values():
            for pattern in patterns:
                for match in re.findall(pattern, text, re.IGNORECASE):
                    # findall yields tuples only if a pattern has 2+ groups.
                    if isinstance(match, tuple):
                        match = ' '.join(match)
                    match = match.strip()
                    if len(match) > 2:
                        phrases.add(match.lower())

        for sentence in sent_tokenize(text):
            words = [
                token for token in word_tokenize(sentence.lower())
                if token.isalpha() and len(token) > 2
            ]
            for size in range(2, 5):
                for start in range(len(words) - size + 1):
                    candidate = ' '.join(words[start:start + size])
                    if self.is_likely_medical_term(candidate):
                        phrases.add(candidate)

        return phrases

    def is_likely_medical_term(self, phrase: str) -> bool:
        """Heuristically decide whether ``phrase`` is a medical term.

        A phrase is rejected when more than half of its words (integer
        division, so any stop word rejects a single word) are stop words, or
        when it contains a digit or punctuation character. Otherwise it is
        accepted if any medical indicator occurs in it as a substring.

        Args:
            phrase: A single word or space-separated phrase.

        Returns:
            ``True`` if the phrase passes all checks.
        """
        words = phrase.split()

        stop_word_count = sum(1 for word in words if word in self.all_stop_words)
        if stop_word_count > len(words) // 2:
            return False

        if any(char.isdigit() or char in string.punctuation for char in phrase):
            return False

        phrase_lower = phrase.lower()
        return any(indicator in phrase_lower for indicator in self._MEDICAL_INDICATORS)

    def _extract_single_words(self, cleaned_text: str) -> Set[str]:
        """Return single tokens (4+ letters, non-stop-word) that look medical."""
        return {
            word for word in word_tokenize(cleaned_text)
            if len(word) > 3
            and word.isalpha()
            and word not in self.all_stop_words
            and self.is_likely_medical_term(word)
        }

    def _extract_keywords_from_text(self, raw_text: str) -> Set[str]:
        """Clean ``raw_text`` and return its phrases plus single-word keywords."""
        cleaned_text = self.clean_text(raw_text)
        keywords = self.extract_medical_phrases(cleaned_text)
        keywords.update(self._extract_single_words(cleaned_text))
        return keywords

    def extract_from_report(self, report_path: str) -> Set[str]:
        """Extract keywords from a MIMIC-CXR report text file.

        Failures never propagate: a missing file or any processing error is
        appended to ``self.processing_stats['errors']`` and an empty set is
        returned. ``reports_processed`` is incremented only on success.

        Args:
            report_path: Path of the UTF-8 report file.

        Returns:
            The set of keywords found in the report.
        """
        try:
            with open(report_path, 'r', encoding='utf-8') as f:
                content = f.read()
            keywords = self._extract_keywords_from_text(content)
        except FileNotFoundError:
            self.processing_stats['errors'].append(f"Report not found: {report_path}")
            return set()
        except Exception as e:
            # Best-effort batch processing: record the problem and move on.
            self.processing_stats['errors'].append(f"Error processing report {report_path}: {str(e)}")
            return set()

        self.processing_stats['reports_processed'] += 1
        return keywords

    def extract_from_transcript(self, transcript_path: str) -> Set[str]:
        """Extract keywords from a REFLACX transcript JSON file.

        The ``'full_text'`` value (when it is a string) is always processed;
        any other top-level string value longer than 50 characters is
        processed as well. Failures are recorded in
        ``self.processing_stats['errors']`` and produce an empty set;
        ``transcripts_processed`` is incremented only on success.

        Args:
            transcript_path: Path of the UTF-8 JSON transcript.

        Returns:
            The set of keywords found across the selected text fields.
        """
        keywords: Set[str] = set()

        try:
            with open(transcript_path, 'r', encoding='utf-8') as f:
                transcript_data = json.load(f)

            if not isinstance(transcript_data, dict):
                self.processing_stats['errors'].append(
                    f"Error processing transcript {transcript_path}: "
                    f"expected a JSON object, got {type(transcript_data).__name__}"
                )
                return set()

            text_fields = []
            full_text = transcript_data.get('full_text')
            if isinstance(full_text, str):
                text_fields.append(full_text)

            for key, value in transcript_data.items():
                # 'full_text' was handled above; do not process it twice.
                if key != 'full_text' and isinstance(value, str) and len(value) > 50:
                    text_fields.append(value)

            for text in text_fields:
                keywords.update(self._extract_keywords_from_text(text))
        except FileNotFoundError:
            self.processing_stats['errors'].append(f"Transcript not found: {transcript_path}")
            return set()
        except Exception as e:
            # Best-effort batch processing: record the problem and move on.
            self.processing_stats['errors'].append(f"Error processing transcript {transcript_path}: {str(e)}")
            return set()

        self.processing_stats['transcripts_processed'] += 1
        return keywords

    def process_dataset(self, train_csv_path: str, sample_size: int = None, base_path: str = "../../") -> Dict:
        """Extract keywords for every row of the training CSV.

        The CSV must provide ``report_path``, ``transcript_path`` and
        ``condition`` columns; missing (NaN) cells are skipped. ``condition``
        is split on ``'|'`` and every label except ``'no finding'`` is counted.
        Results accumulate in ``self.extracted_keywords``, so repeated calls
        on the same instance add up.

        Args:
            train_csv_path: Path of the CSV file to read.
            sample_size: If truthy, process a reproducible random sample
                (``random_state=42``) of at most this many rows.
            base_path: Prefix concatenated (plain string concatenation) in
                front of each path read from the CSV, so it should end with
                a path separator.

        Returns:
            The summary dictionary produced by ``get_results``.

        Raises:
            ValueError: If ``sample_size`` is negative.
        """
        if sample_size is not None and sample_size < 0:
            raise ValueError(f"sample_size must be non-negative, got {sample_size}")

        print("Starting medical keyword extraction from training dataset...")

        df = pd.read_csv(train_csv_path)
        print(f"Total training samples: {len(df)}")

        if sample_size:
            df = df.sample(n=min(sample_size, len(df)), random_state=42)
            print(f"Processing sample of {len(df)} samples")

        total = len(df)
        self.processing_stats['total_samples'] = total

        # Count by position: after sampling, the index labels are shuffled
        # original row numbers and cannot be used as a progress counter.
        for position, (_, row) in enumerate(df.iterrows()):
            if position % 100 == 0:
                print(f"  Progress: {position}/{total} samples processed")

            if pd.notna(row['report_path']):
                full_report_path = base_path + str(row['report_path'])
                report_keywords = self.extract_from_report(full_report_path)
                self.extracted_keywords['reports'].update(report_keywords)
                self.extracted_keywords['combined'].update(report_keywords)

            if pd.notna(row['transcript_path']):
                full_transcript_path = base_path + str(row['transcript_path'])
                transcript_keywords = self.extract_from_transcript(full_transcript_path)
                self.extracted_keywords['transcripts'].update(transcript_keywords)
                self.extracted_keywords['combined'].update(transcript_keywords)

            if pd.notna(row['condition']):
                for condition in str(row['condition']).split('|'):
                    condition = condition.strip().lower()
                    if condition and condition != 'no finding':
                        self.extracted_keywords['conditions'][condition] += 1
                        self.extracted_keywords['combined'][condition] += 1

        print("Processing complete!")
        return self.get_results()

    def _category_for(self, keyword: str) -> str:
        """Return the category name for one keyword (first match wins)."""
        for category in self._PATTERN_CATEGORY_ORDER:
            for pattern in self.medical_patterns.get(category, ()):
                if re.search(pattern, keyword, re.IGNORECASE):
                    return category

        keyword_lower = keyword.lower()
        if any(indicator in keyword_lower for indicator in self._DEVICE_INDICATORS):
            return 'devices'

        return 'uncategorized'

    def categorize_keywords(self) -> Dict:
        """Group the 500 most common combined keywords by medical category.

        Categories are tried in the order anatomical, pathological,
        descriptive (regex match), then devices (substring match); anything
        else is 'uncategorized'. The 'procedures' bucket is part of the
        result but no rule assigns keywords to it.

        Returns:
            A dict mapping category name to a list of ``(keyword, count)``
            tuples, ordered by descending count.
        """
        categorized = {
            'anatomical': [],
            'pathological': [],
            'descriptive': [],
            'devices': [],
            'procedures': [],
            'uncategorized': []
        }

        for keyword, count in self.extracted_keywords['combined'].most_common(500):
            categorized[self._category_for(keyword)].append((keyword, count))

        return categorized

    def get_results(self) -> Dict:
        """Summarize the extraction state.

        Returns:
            A dict with ``keyword_counts`` (number of unique keywords per
            source), ``top_keywords`` (most common ``(keyword, count)``
            pairs: 50 for reports and transcripts, 100 combined),
            ``categorized_keywords`` (see ``categorize_keywords``) and
            ``processing_stats`` (the live stats dict, not a copy).
        """
        categorized = self.categorize_keywords()

        return {
            'keyword_counts': {
                'reports': len(self.extracted_keywords['reports']),
                'transcripts': len(self.extracted_keywords['transcripts']),
                'conditions': len(self.extracted_keywords['conditions']),
                'combined_unique': len(self.extracted_keywords['combined'])
            },
            'top_keywords': {
                'reports': self.extracted_keywords['reports'].most_common(50),
                'transcripts': self.extracted_keywords['transcripts'].most_common(50),
                'combined': self.extracted_keywords['combined'].most_common(100)
            },
            'categorized_keywords': categorized,
            'processing_stats': self.processing_stats
        }

    def save_results(self, results: Dict, output_dir: str = "keyword_extraction_results"):
        """Write the results of ``get_results`` to ``output_dir``.

        Creates ``complete_results.json``, one ``<category>_keywords.txt``
        file per category (tab-separated keyword and count) and
        ``all_medical_keywords.txt``. All files are written as UTF-8 so that
        non-ASCII keywords do not depend on the platform's default encoding.

        Args:
            results: Dictionary shaped like the return value of ``get_results``.
            output_dir: Directory to write into; created if missing.
        """
        os.makedirs(output_dir, exist_ok=True)

        # Build the JSON payload before opening the file so a conversion
        # error cannot leave a truncated file behind.
        json_results = {}
        for key, value in results.items():
            if key in ('top_keywords', 'categorized_keywords'):
                json_results[key] = {k: list(v) for k, v in value.items()}
            else:
                json_results[key] = value

        with open(os.path.join(output_dir, "complete_results.json"), 'w', encoding='utf-8') as f:
            json.dump(json_results, f, indent=2, ensure_ascii=False)

        for category, keywords in results['categorized_keywords'].items():
            category_file = os.path.join(output_dir, f"{category}_keywords.txt")
            with open(category_file, 'w', encoding='utf-8') as f:
                f.write(f"# {category.title()} Medical Keywords\n")
                f.write(f"# Total: {len(keywords)} keywords\n\n")
                for keyword, count in keywords:
                    f.write(f"{keyword}\t{count}\n")

        with open(os.path.join(output_dir, "all_medical_keywords.txt"), 'w', encoding='utf-8') as f:
            f.write("# All Medical Keywords Extracted from Training Dataset\n")
            f.write(f"# Total unique keywords: {results['keyword_counts']['combined_unique']}\n\n")
            # NOTE: only the entries of results['top_keywords']['combined']
            # are listed, which can be fewer than the total in the header.
            for keyword, _count in results['top_keywords']['combined']:
                f.write(f"{keyword}\n")

        print(f"Results saved to: {output_dir}/")

def main():
    """Run the keyword extraction pipeline interactively.

    Prompts for an optional sample size, processes the training CSV at
    ``../../dataset_splits/train.csv``, prints a summary to stdout and saves
    the results to the default output directory of ``save_results``.
    """
    print("🏥 MEDICAL KEYWORD EXTRACTION PIPELINE")
    print("=" * 60)

    # Initialize extractor
    extractor = MedicalKeywordExtractor()

    # Process dataset
    train_csv = "../../dataset_splits/train.csv"

    # Ask user for sample size; an empty answer means the full dataset.
    raw_sample_size = input("Enter sample size (or press Enter for full dataset): ").strip()
    sample_size = None
    if raw_sample_size:
        try:
            sample_size = int(raw_sample_size)
        except ValueError:
            print("Invalid input, using full dataset")
        else:
            # Zero or negative sizes cannot be sampled; treat them as invalid
            # instead of letting them fail later inside the pipeline.
            if sample_size <= 0:
                print("Invalid input, using full dataset")
                sample_size = None

    # Extract keywords
    results = extractor.process_dataset(train_csv, sample_size=sample_size)

    print("\nKEYWORD EXTRACTION RESULTS")
    print("=" * 40)
    print(f"Reports processed: {results['processing_stats']['reports_processed']}")
    print(f"Transcripts processed: {results['processing_stats']['transcripts_processed']}")
    print(f"Total unique keywords: {results['keyword_counts']['combined_unique']}")
    print(f"Errors encountered: {len(results['processing_stats']['errors'])}")

    print("\nTop 20 Most Common Keywords:")
    for keyword, count in results['top_keywords']['combined'][:20]:
        print(f"  {keyword:<30} {count:>5}")

    print("\nKeywords by Category:")
    for category, keywords in results['categorized_keywords'].items():
        print(f"  {category.title():<15} {len(keywords):>5} keywords")

    extractor.save_results(results)

    print("\nKeyword extraction complete!")
    print("Check 'keyword_extraction_results/' for detailed output")

# Run the interactive pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()