# mimic_processor.py 
import pandas as pd
import numpy as np
import os
import gc
import psutil
from typing import Dict, List, Tuple, Optional, Set
import logging
from pathlib import Path
import re
from tqdm import tqdm
import pickle
from collections import Counter, defaultdict
import string
from datetime import datetime

logger = logging.getLogger(__name__)

class MIMICProcessor:
    
    
    def __init__(self, mimic_root_path: str):
        self.mimic_root = Path(mimic_root_path)
        self.data_cache = {}
        self.processed_patients = {}
        
      
        self.max_memory_gb = 8.0
        self.chunk_size = 5000
        
        
        self.learned_patterns = {}
        self.discovered_categories = {}
        self.data_relationships = defaultdict(list)
        self.frequency_maps = {}
        
        
        self.cache_dir = self.mimic_root / "processing_cache"
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        print(f" MIMIC Processor initialized")
        print(f"    Root: {self.mimic_root}")
        print(f"    Cache: {self.cache_dir}")
    
    def discover_mimic_files(self) -> Dict[str, Path]:
        
        
        print(" Discovering MIMIC file structure...")
        discovered_files = {}
        
       
        expected_tables = [
            'ADMISSIONS', 'PATIENTS', 'DIAGNOSES_ICD', 'PROCEDURES_ICD',
            'PRESCRIPTIONS', 'CHARTEVENTS', 'LABEVENTS', 'NOTEEVENTS',
            'ICUSTAYS', 'TRANSFERS', 'SERVICES', 'CALLOUT'
        ]
        
        for table_name in expected_tables:
            print(f"    Looking for {table_name}...")
            
            
            possible_paths = [
                
                self.mimic_root / f"{table_name}.csv",
                
                self.mimic_root / f"{table_name}.csv" / f"{table_name}.csv",
                
                self.mimic_root / f"{table_name}.csv",
            ]
            
            found = False
            for path in possible_paths:
                if path.exists() and path.is_file():
                    try:
                        
                        test_df = pd.read_csv(path, nrows=5)
                        if len(test_df) > 0:
                            discovered_files[table_name] = {
                                'type': 'single_csv',
                                'path': path,
                                'size_mb': path.stat().st_size / 1024 / 1024
                            }
                            print(f"    {table_name}: Found at {path}")
                            found = True
                            break
                    except Exception as e:
                        continue
            
            
            if not found and table_name in ['DIAGNOSES_ICD', 'PROCEDURES_ICD']:
                folder_path = self.mimic_root / f"{table_name}.csv"
                if folder_path.exists() and folder_path.is_dir():
                    csv_files = list(folder_path.glob("*.csv"))
                    if not csv_files:
                        csv_files = list(folder_path.glob("*")) 
                    
                    if csv_files:
                        discovered_files[table_name] = {
                            'type': 'individual_files',
                            'path': folder_path,
                            'file_count': len(csv_files),
                            'sample_files': [f.name for f in csv_files[:5]]
                        }
                        print(f"   {table_name}: Folder with {len(csv_files)} files")
                        found = True
            
            if not found:
                print(f"    {table_name}: Not found")
        
        print(f" Discovery complete: Found {len(discovered_files)} tables")
        return discovered_files
    
    def load_mimic_table_smart(self, table_name: str, table_info: Dict, 
                              sample_size: Optional[int] = None) -> pd.DataFrame:
        
        
        if table_info['type'] == 'single_csv':
            return self._load_single_csv_file(table_name, table_info['path'], sample_size)
        elif table_info['type'] == 'individual_files':
            return self._load_individual_files_folder(table_name, table_info['path'], 
                                                    table_info['file_count'], sample_size)
        else:
            print(f"    Unknown table type: {table_info['type']}")
            return pd.DataFrame()
    
    def _load_single_csv_file(self, table_name: str, file_path: Path, 
                             sample_size: Optional[int]) -> pd.DataFrame:
        
        
        file_size_mb = file_path.stat().st_size / 1024 / 1024
        print(f"    Loading {table_name}: {file_size_mb:.1f}MB")
        
        try:
            if file_size_mb < 500:  
                df = pd.read_csv(file_path, low_memory=False)
                print(f"    Loaded: {len(df)} rows, {len(df.columns)} columns")
                
                if sample_size and len(df) > sample_size:
                    df = df.sample(n=sample_size, random_state=42)
                    print(f"  Sampled to: {len(df)} rows")
                
                return df
            
            else:  
                print(f"  Large file - reading in chunks...")
                chunks = []
                rows_collected = 0
                target_rows = sample_size or 10000
                
                chunk_iter = pd.read_csv(file_path, chunksize=self.chunk_size, low_memory=False)
                
                for chunk in chunk_iter:
                    if len(chunk) > 0:
                        chunks.append(chunk)
                        rows_collected += len(chunk)
                        
                        if rows_collected >= target_rows:
                            break
                    
                    if len(chunks) % 20 == 0:
                        self._clear_memory_if_needed()
                
                if chunks:
                    result = pd.concat(chunks, ignore_index=True)
                    print(f"    Loaded: {len(result)} rows from chunks")
                    return result
                else:
                    print(f"    No data loaded from chunks")
                    return pd.DataFrame()
        
        except Exception as e:
            print(f"    Error loading {table_name}: {e}")
            return pd.DataFrame()
    
    def _load_individual_files_folder(self, table_name: str, folder_path: Path, 
                                     file_count: int, sample_size: Optional[int]) -> pd.DataFrame:
        
        
        print(f"    Loading {table_name}: {file_count} individual files")
        
        try:
            all_files = list(folder_path.iterdir())
            csv_files = [f for f in all_files if f.is_file()]
            
            if not csv_files:
                print(f"    No files found in {folder_path}")
                return pd.DataFrame()
            
           
            files_to_process = csv_files
            if sample_size and len(csv_files) > sample_size:
                import random
                random.seed(42)
                files_to_process = random.sample(csv_files, min(sample_size, len(csv_files)))
            
            
            return self._combine_individual_csv_files(files_to_process, table_name)
            
        except Exception as e:
            print(f"    Error loading {table_name}: {e}")
            return pd.DataFrame()
    
    def _combine_individual_csv_files(self, csv_files: List[Path], table_name: str) -> pd.DataFrame:
        
        
        print(f"  Combining {len(csv_files)} files...")
        
        
        strategy1_result = self._try_normal_csv_reading(csv_files, table_name)
        if not strategy1_result.empty:
            return strategy1_result
        
        
        strategy2_result = self._try_filename_as_data(csv_files, table_name)
        if not strategy2_result.empty:
            return strategy2_result
        
       
        strategy3_result = self._try_file_content_reading(csv_files, table_name)
        if not strategy3_result.empty:
            return strategy3_result
        
        print(f"    All strategies failed for {table_name}")
        return pd.DataFrame()
    
    def _try_normal_csv_reading(self, csv_files: List[Path], table_name: str) -> pd.DataFrame:
        
        
        print(f"    Strategy 1: Reading as normal CSV files...")
        
        successful_dfs = []
        
        for file_path in csv_files[:100]:  
            try:
                df = pd.read_csv(file_path, low_memory=False)
                if len(df) > 0:
                    successful_dfs.append(df)
                    
                if len(successful_dfs) >= 50:  
                    break
                    
            except Exception:
                continue
        
        if successful_dfs:
            result = pd.concat(successful_dfs, ignore_index=True)
            print(f"   Strategy 1 success: {len(result)} rows from {len(successful_dfs)} files")
            return result
        
        print(f"    Strategy 1 failed")
        return pd.DataFrame()
    
    def _try_filename_as_data(self, csv_files: List[Path], table_name: str) -> pd.DataFrame:
       
        
        print(f"    Strategy 2: Parsing filenames as data...")
        
        
        if table_name == 'DIAGNOSES_ICD':
            headers = ['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'SEQ_NUM', 'ICD9_CODE']
        elif table_name == 'PROCEDURES_ICD':
            headers = ['ROW_ID', 'SUBJECT_ID', 'HADM_ID', 'SEQ_NUM', 'ICD9_CODE']
        else:
            headers = ['COL_0', 'COL_1', 'COL_2', 'COL_3', 'COL_4']
        
        parsed_rows = []
        
        for file_path in csv_files[:1000]:  
            filename = file_path.name
            
            
            if ',' in filename or '_' in filename:
                
                name_without_ext = filename.replace('.csv', '').replace('.txt', '')
                
               
                if ',' in name_without_ext:
                    parts = name_without_ext.split(',')
                else:
                    
                    parts = name_without_ext.split('_')
                
               
                clean_parts = []
                for part in parts:
                    clean_part = part.strip()
                    if clean_part and clean_part.replace('.', '').replace('-', '').isalnum():
                        clean_parts.append(clean_part)
                
                if len(clean_parts) >= 3: 
                    while len(clean_parts) < len(headers):
                        clean_parts.append('')
                    clean_parts = clean_parts[:len(headers)]
                    
                    row_dict = dict(zip(headers, clean_parts))
                    parsed_rows.append(row_dict)
        
        if parsed_rows:
            result = pd.DataFrame(parsed_rows)
            print(f"    Strategy 2 success: {len(result)} rows from filenames")
            return result
        
        print(f"    Strategy 2 failed")
        return pd.DataFrame()
    
    def _try_file_content_reading(self, csv_files: List[Path], table_name: str) -> pd.DataFrame:
       
        
        print(f"   Strategy 3: Reading file contents...")
        
        content_rows = []
        
        for file_path in csv_files[:500]:  
            try:
                
                with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read().strip()
                
                if content and len(content) > 0:
                    
                    if ',' in content:
                        parts = content.split(',')
                        clean_parts = [part.strip() for part in parts if part.strip()]
                        
                        if len(clean_parts) >= 3:
                            content_rows.append(clean_parts[:5])  
                            
            except Exception:
                continue
        
        if content_rows:
            
            max_cols = max(len(row) for row in content_rows)
            headers = [f'COL_{i}' for i in range(max_cols)]
            
            
            padded_rows = []
            for row in content_rows:
                padded_row = row + [''] * (max_cols - len(row))
                padded_rows.append(padded_row)
            
            result = pd.DataFrame(padded_rows, columns=headers)
            print(f"   Strategy 3 success: {len(result)} rows from file contents")
            return result
        
        print(f"   Strategy 3 failed")
        return pd.DataFrame()
    
    def _clear_memory_if_needed(self):
        """Drop cached data when memory use exceeds ``self.max_memory_gb``.

        When the limit is exceeded, ``self.data_cache`` is emptied, a garbage
        collection is forced and the usage before and after is printed.
        Nothing happens while usage is within the limit.
        """
        usage_before = self._get_memory_usage()
        if not usage_before > self.max_memory_gb:
            return

        print(f"    Memory cleanup: {usage_before:.1f}GB -> ", end="")
        self.data_cache.clear()
        gc.collect()
        usage_after = self._get_memory_usage()
        print(f"{usage_after:.1f}GB")
    
    def _get_memory_usage(self) -> float:
        """Return the resident set size of the current process in GiB."""
        rss_bytes = psutil.Process(os.getpid()).memory_info().rss
        # 1024 ** 3 bytes per GiB; same result as dividing by 1024 three times.
        return rss_bytes / (1024 ** 3)
    
    def load_core_tables(self, sample_size: int = 40000) -> Dict[str, pd.DataFrame]:
        
        return self.load_mimic_tables(sample_size)
    
    def load_mimic_tables(self, sample_size: int = 40000) -> Dict[str, pd.DataFrame]:
        
        
        print(f" LOADING MIMIC-III DATA")
        print(f" Target sample size: {sample_size}")
        print("="*60)
        
        
        mimic_files = self.discover_mimic_files()
        
        if not mimic_files:
            raise ValueError(" No MIMIC tables found! Check directory structure.")
        
        tables = {}
        
        
        priority_order = [
            'PATIENTS', 'ADMISSIONS', 'DIAGNOSES_ICD', 'PROCEDURES_ICD',
            'PRESCRIPTIONS', 'ICUSTAYS', 'TRANSFERS',
            'CHARTEVENTS', 'LABEVENTS', 'NOTEEVENTS'
        ]
        
        for table_name in priority_order:
            if table_name in mimic_files:
                print(f"\n Processing {table_name}...")
                
                try:
                    df = self.load_mimic_table_smart(
                        table_name, 
                        mimic_files[table_name],
                        sample_size
                    )
                    
                    if not df.empty:
                        tables[table_name] = df
                        print(f"    {table_name}: {len(df):,} rows, {len(df.columns)} columns")
                        
                        
                        if len(df) > 0:
                            print(f"    Sample columns: {list(df.columns)[:5]}")
                    else:
                        print(f"    {table_name}: No data loaded")
                        
                except Exception as e:
                    print(f"    {table_name} failed: {e}")
                    continue
                
                self._clear_memory_if_needed()
        
        
        for table_name, table_info in mimic_files.items():
            if table_name not in tables:
                print(f"\n Processing additional table: {table_name}...")
                try:
                    df = self.load_mimic_table_smart(table_name, table_info, sample_size)
                    if not df.empty:
                        tables[table_name] = df
                        print(f"    {table_name}: {len(df):,} rows")
                except Exception as e:
                    print(f"    {table_name} failed: {e}")
        
        print(f"\n MIMIC LOADING COMPLETE!")
        print(f" Successfully loaded {len(tables)} tables:")
        total_rows = 0
        for table_name, df in tables.items():
            rows = len(df)
            total_rows += rows
            print(f"    {table_name}: {rows:,} rows")
        
        print(f"Total data: {total_rows:,} rows across all tables")
        return tables
    
    def extract_patient_cases(self, tables: Dict[str, pd.DataFrame], 
                             num_patients: int = 40000) -> List[Dict]:
        """Build patient case dicts from the loaded tables.

        The standard extraction is used when both ``PATIENTS`` and
        ``ADMISSIONS`` are present. Otherwise cases are derived from whatever
        tables are available; a notice is printed when neither of the two
        tables was loaded.

        Args:
            tables: Loaded tables keyed by table name.
            num_patients: Upper bound on the number of cases to create.

        Returns:
            The list of case dicts produced by the chosen extraction.
        """
        print(f" EXTRACTING {num_patients} PATIENT CASES FROM MIMIC DATA")
        print("="*60)

        available_tables = list(tables.keys())
        print(f" Available tables: {available_tables}")

        has_patients = 'PATIENTS' in available_tables
        has_admissions = 'ADMISSIONS' in available_tables

        if has_patients and has_admissions:
            return self._extract_standard_cases(tables, num_patients)

        if not (has_patients or has_admissions):
            print(f" No required tables found. Creating cases from available data...")

        return self._extract_cases_from_available_tables(tables, num_patients)
    
    def _extract_standard_cases(self, tables: Dict[str, pd.DataFrame], 
                               num_patients: int) -> List[Dict]:
        """Create one case per sampled patient from PATIENTS and ADMISSIONS.

        If there are more patients than ``num_patients``, a random sample
        (``random_state=42``) is used. Each patient is matched to admissions
        by ``SUBJECT_ID``; the first matching admission is used for the case,
        and patients without one get a basic case instead.

        Args:
            tables: Loaded tables; must contain ``PATIENTS`` and ``ADMISSIONS``.
            num_patients: Maximum number of patients to turn into cases.

        Returns:
            The created case dicts.
        """
        patients_df = tables['PATIENTS']
        admissions_df = tables['ADMISSIONS']

        print(f" Standard extraction:")
        print(f"    Patients: {len(patients_df):,}")
        print(f"    Admissions: {len(admissions_df):,}")

        needs_sampling = len(patients_df) > num_patients
        sampled_patients = (
            patients_df.sample(n=num_patients, random_state=42)
            if needs_sampling else patients_df
        )

        # Without a SUBJECT_ID column no admission can be matched to a patient.
        admissions_have_subject = 'SUBJECT_ID' in admissions_df.columns

        patient_cases = []
        progress = tqdm(sampled_patients.iterrows(),
                        desc="Processing patients",
                        total=len(sampled_patients))

        for _, patient_row in progress:
            subject_id = patient_row.get('SUBJECT_ID', patient_row.iloc[0])

            if admissions_have_subject:
                matches = admissions_df[admissions_df['SUBJECT_ID'] == subject_id]
            else:
                matches = pd.DataFrame()

            if len(matches) == 0:
                case = self._create_basic_case_from_patient(patient_row, None, tables)
            else:
                case = self._create_case_from_patient_admission(
                    patient_row, matches.iloc[0], tables
                )

            if case:
                patient_cases.append(case)

        print(f" Extracted {len(patient_cases)} standard cases")
        return patient_cases
    
    def _extract_cases_from_available_tables(self, tables: Dict[str, pd.DataFrame], 
                                           num_patients: int) -> List[Dict]:
        
        
        print(f" Flexible extraction from available tables...")
        
        patient_cases = []
        
       
        largest_table_name = max(tables.keys(), key=lambda k: len(tables[k]))
        base_table = tables[largest_table_name]
        
        print(f"    Using {largest_table_name} as base table ({len(base_table)} rows)")
        
        
        if len(base_table) > num_patients:
            sampled_rows = base_table.sample(n=num_patients, random_state=42)
        else:
            sampled_rows = base_table
        
        for i, (_, row) in enumerate(tqdm(sampled_rows.iterrows(), 
                                         desc="Creating cases", 
                                         total=len(sampled_rows))):
            
            case = self._create_case_from_any_table_row(row, tables, i)
            if case:
                patient_cases.append(case)
        
        print(f" Extracted {len(patient_cases)} flexible cases")
        return patient_cases
    
    def _create_case_from_patient_admission(self, patient_row: pd.Series, 
                                          admission_row: pd.Series, 
                                          tables: Dict[str, pd.DataFrame]) -> Optional[Dict]:
        
        
        subject_id = patient_row.get('SUBJECT_ID', str(patient_row.iloc[0]))
        hadm_id = admission_row.get('HADM_ID', str(admission_row.iloc[0]))
        
        case = {
            'patient_id': f"MIMIC_{subject_id}",
            'subject_id': subject_id,
            'hadm_id': hadm_id,
            'demographics': {
                'gender': str(patient_row.get('GENDER', 'Unknown')),
                'dob': str(patient_row.get('DOB', 'Unknown')),
                'dod': str(patient_row.get('DOD', '')) if pd.notna(patient_row.get('DOD')) else None
            },
            'admission_info': {
                'admittime': str(admission_row.get('ADMITTIME', 'Unknown')),
                'dischtime': str(admission_row.get('DISCHTIME', 'Unknown')),
                'admission_type': str(admission_row.get('ADMISSION_TYPE', 'Unknown')),
                'diagnosis': str(admission_row.get('DIAGNOSIS', 'Unknown'))
            },
            'diagnoses': [],
            'procedures': [],
            'medications': [],
            'clinical_notes': [],
            'lab_values': [],
            'chart_events': []
        }
        
       
        case = self._enrich_case_with_table_data(case, tables, subject_id, hadm_id)
        
        return case
    
    def _create_basic_case_from_patient(self, patient_row: pd.Series, 
                                      admission_row: Optional[pd.Series], 
                                      tables: Dict[str, pd.DataFrame]) -> Optional[Dict]:
       
        
        subject_id = str(patient_row.iloc[0]) if len(patient_row) > 0 else f"UNKNOWN_{np.random.randint(10000, 99999)}"
        
        case = {
            'patient_id': f"MIMIC_{subject_id}",
            'subject_id': subject_id,
            'hadm_id': f"ADM_{subject_id}",
            'demographics': {
                'gender': str(patient_row.get('GENDER', patient_row.iloc[1] if len(patient_row) > 1 else 'Unknown')),
                'dob': str(patient_row.get('DOB', patient_row.iloc[2] if len(patient_row) > 2 else 'Unknown')),
                'dod': None
            },
            'admission_info': {
                'admittime': 'Unknown',
                'dischtime': 'Unknown',
                'admission_type': 'Unknown',
                'diagnosis': 'Unknown'
            },
            'diagnoses': [],
            'procedures': [],
            'medications': [],
            'clinical_notes': [],
            'lab_values': [],
            'chart_events': []
        }
        
        
        case = self._enrich_case_with_table_data(case, tables, subject_id, None)
        
        return case
    
    def _create_case_from_any_table_row(self, row: pd.Series, 
                                       tables: Dict[str, pd.DataFrame], 
                                       index: int) -> Optional[Dict]:
        
        subject_id = None
        hadm_id = None
        
        for col in row.index:
            if 'SUBJECT' in str(col).upper():
                subject_id = str(row[col])
            elif 'HADM' in str(col).upper():
                hadm_id = str(row[col])
        
        if not subject_id:
            subject_id = f"FLEX_{index}"
        if not hadm_id:
            hadm_id = f"ADM_{index}"
        
        case = {
            'patient_id': f"MIMIC_{subject_id}",
            'subject_id': subject_id,
            'hadm_id': hadm_id,
            'demographics': {
                'gender': 'Unknown',
                'dob': 'Unknown',
                'dod': None
            },
            'admission_info': {
                'admittime': 'Unknown',
                'dischtime': 'Unknown',
                'admission_type': 'Unknown',
                'diagnosis': 'Unknown'
            },
            'diagnoses': [],
            'procedures': [],
            'medications': [],
            'clinical_notes': [],
            'lab_values': [],
            'chart_events': []
        }
        
  
        case = self._extract_data_from_row(case, row)
        
  
        case = self._enrich_case_with_table_data(case, tables, subject_id, hadm_id)
        
        return case
    
    def _extract_data_from_row(self, case: Dict, row: pd.Series) -> Dict:
      
        
        row_data = []
        for col, value in row.items():
            if pd.notna(value) and str(value).strip() and str(value) != 'nan':
                clean_value = str(value).strip()
                if len(clean_value) > 1:
                    row_data.append(clean_value)
        
       
        for item in row_data[:10]: 
            item_lower = item.lower()
            
            
            if any(term in item_lower for term in ['icd', 'diagnosis', 'disease']):
                case['diagnoses'].append(item)
            elif any(term in item_lower for term in ['procedure', 'surgery', 'operation']):
                case['procedures'].append(item)
            elif any(term in item_lower for term in ['drug', 'medication', 'medicine']):
                case['medications'].append(item)
            else:
              
                case['clinical_notes'].append(item)
        
       
        if case['diagnoses']:
            case['symptoms'] = ', '.join(case['diagnoses'][:3])
        else:
            case['symptoms'] = 'clinical presentation from medical data'
        
        
        case['known_info'] = case['diagnoses'][:2] + case['procedures'][:2]
        
        
        case['temporal_sequence'] = ['presentation'] + case['procedures'][:3] + ['diagnosis']
        
        return case
    
    def _enrich_case_with_table_data(self, case: Dict, tables: Dict[str, pd.DataFrame], 
                                    subject_id: str, hadm_id: Optional[str]) -> Dict:
      
        for table_name, table_df in tables.items():
            try:
                
                matching_rows = pd.DataFrame()
                
               
                if 'SUBJECT_ID' in table_df.columns and subject_id:
                    matching_rows = table_df[table_df['SUBJECT_ID'].astype(str) == str(subject_id)]
                elif 'HADM_ID' in table_df.columns and hadm_id:
                    matching_rows = table_df[table_df['HADM_ID'].astype(str) == str(hadm_id)]
                
                if len(matching_rows) > 0:
                    
                    if 'DIAGNOSES' in table_name.upper():
                        if 'ICD9_CODE' in matching_rows.columns:
                            new_diagnoses = matching_rows['ICD9_CODE'].dropna().astype(str).tolist()
                            case['diagnoses'].extend(new_diagnoses[:5])
                    
                    elif 'PROCEDURES' in table_name.upper():
                        if 'ICD9_CODE' in matching_rows.columns:
                            new_procedures = matching_rows['ICD9_CODE'].dropna().astype(str).tolist()
                            case['procedures'].extend(new_procedures[:5])
                    
                    elif 'PRESCRIPTIONS' in table_name.upper():
                        if 'DRUG' in matching_rows.columns:
                            new_medications = matching_rows['DRUG'].dropna().astype(str).tolist()
                            case['medications'].extend(new_medications[:5])
                    
                    elif 'NOTEEVENTS' in table_name.upper():
                        if 'TEXT' in matching_rows.columns:
                            note_texts = matching_rows['TEXT'].dropna().astype(str).tolist()
                            for text in note_texts[:3]:
                                if len(text) > 10:
                                   
                                    sentences = text.split('.')[:2]
                                    case['clinical_notes'].extend([s.strip() for s in sentences if s.strip()])
                    
                    elif 'LABEVENTS' in table_name.upper():
                        if 'ITEMID' in matching_rows.columns and 'VALUE' in matching_rows.columns:
                            for _, lab_row in matching_rows.head(5).iterrows():
                                item_id = lab_row.get('ITEMID', 'Unknown')
                                value = lab_row.get('VALUE', 'Unknown')
                                case['lab_values'].append(f"Lab_{item_id}_{value}")
            
            except Exception as e:
               
                continue
        
       
        self._finalize_case_attributes(case)
        
        return case
    
    def _finalize_case_attributes(self, case: Dict):
        
        case['diagnoses'] = list(set([str(d) for d in case['diagnoses'] if str(d) != 'nan' and str(d).strip()]))[:10]
        case['procedures'] = list(set([str(p) for p in case['procedures'] if str(p) != 'nan' and str(p).strip()]))[:10]
        case['medications'] = list(set([str(m) for m in case['medications'] if str(m) != 'nan' and str(m).strip()]))[:10]
        case['clinical_notes'] = [str(n) for n in case['clinical_notes'] if str(n) != 'nan' and len(str(n)) > 5][:10]
        
    
        if not case.get('symptoms') or case['symptoms'] == 'clinical presentation from medical data':
            symptom_sources = []
            if case['diagnoses']:
                symptom_sources.extend(case['diagnoses'][:2])
            if case['clinical_notes']:
                
                for note in case['clinical_notes'][:2]:
                    note_lower = str(note).lower()
                    if any(word in note_lower for word in ['pain', 'fever', 'cough', 'nausea', 'breath']):
                        symptom_sources.append('documented symptoms')
                        break
            
            case['symptoms'] = ', '.join(symptom_sources) if symptom_sources else 'clinical presentation'
        
        
        known_info = []
        known_info.extend(case['diagnoses'][:3])
        known_info.extend(case['procedures'][:2])
        case['known_info'] = [str(k) for k in known_info if str(k).strip()]
        
      
        temporal_sequence = ['initial_presentation']
        if case['procedures']:
            temporal_sequence.extend(case['procedures'][:3])
        if case['diagnoses']:
            temporal_sequence.append('diagnosis_established')
        if case['medications']:
            temporal_sequence.append('treatment_initiated')
        
        case['temporal_sequence'] = temporal_sequence
    
    def create_expert_knowledge_from_mimic(self, patient_cases: List[Dict]) -> List[Dict]:
        
        
        print(f" Creating expert knowledge from {len(patient_cases)} MIMIC cases...")
        
       
        all_diagnoses = []
        all_procedures = []
        all_medications = []
        
        for case in patient_cases:
            all_diagnoses.extend(case.get('diagnoses', []))
            all_procedures.extend(case.get('procedures', []))
            all_medications.extend(case.get('medications', []))
        
       
        diagnosis_counts = Counter(all_diagnoses)
        procedure_counts = Counter(all_procedures)
        medication_counts = Counter(all_medications)
        
        
        experts = []
        
        
        common_diagnoses = [d for d, c in diagnosis_counts.most_common(20)]
        common_procedures = [p for p, c in procedure_counts.most_common(15)]
        
        expert1 = {
            'expert_id': 'mimic_emergency_physician',
            'domain_expertise': {
                'emergency_medicine': 0.9,
                'general_medicine': 0.8,
                'critical_care': 0.7
            },
            'known_concepts': set(common_diagnoses[:10] + common_procedures[:10]),
            'unknown_concepts': set(['long_term_outcomes', 'rehabilitation', 'family_counseling']),
            'reasoning_patterns': {},
            'temporal_understanding': {}
        }
        experts.append(expert1)
        
        
        specialist_procedures = [p for p, c in procedure_counts.most_common(30)[10:20]]
        specialist_diagnoses = [d for d, c in diagnosis_counts.most_common(30)[10:20]]
        
        expert2 = {
            'expert_id': 'mimic_specialist_physician',
            'domain_expertise': {
                'cardiology': 0.8,
                'pulmonology': 0.7,
                'general_medicine': 0.6
            },
            'known_concepts': set(specialist_procedures + specialist_diagnoses),
            'unknown_concepts': set(['emergency_protocols', 'trauma_management']),
            'reasoning_patterns': {},
            'temporal_understanding': {}
        }
        experts.append(expert2)
        
        
        common_medications = [m for m, c in medication_counts.most_common(15)]
        
        expert3 = {
            'expert_id': 'mimic_internist',
            'domain_expertise': {
                'internal_medicine': 0.9,
                'medication_management': 0.8,
                'chronic_disease': 0.7
            },
            'known_concepts': set(common_medications + common_diagnoses[10:20]),
            'unknown_concepts': set(['surgical_procedures', 'emergency_interventions']),
            'reasoning_patterns': {},
            'temporal_understanding': {}
        }
        experts.append(expert3)
        
        print(f"Created {len(experts)} expert profiles from MIMIC data patterns")
        print(f"   Based on {len(diagnosis_counts)} unique diagnoses")
        print(f"   Based on {len(procedure_counts)} unique procedures")
        print(f"   Based on {len(medication_counts)} unique medications")
        
        return experts


def save_preprocessed_mimic_data(processor, tables, patient_cases, experts, save_path="preprocessed_mimic_data.pkl"):
    """Pickle the preprocessed MIMIC artifacts into a single file.

    The package written to disk contains the patient cases, the expert
    profiles, per-table row counts, a small config summary, aggregate
    counts and an ISO timestamp.

    Args:
        processor: Object exposing a ``mimic_root`` attribute (stored as str).
        tables: Mapping of table name to an object supporting ``len()``.
        patient_cases: Sequence of dict-like cases; the ``diagnoses``,
            ``procedures``, ``medications`` and ``clinical_notes`` entries
            are counted when present.
        experts: Sequence of expert profiles, stored as-is.
        save_path: Destination file for the pickle.

    Returns:
        ``save_path`` on success, ``None`` if writing failed.

    Note:
        The file is a pickle; it should only ever be loaded from a trusted
        location, since unpickling can execute arbitrary code.
    """
    print(f"\n Saving preprocessed MIMIC data to {save_path}...")

    def _count(field):
        # Total number of items stored under `field` across all cases.
        return sum(len(case.get(field, [])) for case in patient_cases)

    data_package = {
        'patient_cases': patient_cases,
        'experts': experts,
        'tables_info': {name: len(df) for name, df in tables.items()},
        'processor_config': {
            'mimic_root': str(processor.mimic_root),
            'total_patients': len(patient_cases),
            'total_experts': len(experts),
            'sample_size': len(patient_cases)
        },
        'preprocessing_stats': {
            'total_diagnoses': _count('diagnoses'),
            'total_procedures': _count('procedures'),
            'total_medications': _count('medications'),
            'total_clinical_notes': _count('clinical_notes')
        },
        'timestamp': datetime.now().isoformat(),
        'version': '1.0'
    }

    # Write to a sibling temp file first: opening the destination directly
    # with 'wb' would truncate a previous good file even when the dump fails.
    tmp_path = f"{os.fspath(save_path)}.tmp"

    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(data_package, f, protocol=pickle.HIGHEST_PROTOCOL)

        # Publish the finished file in one step, replacing any older one.
        os.replace(tmp_path, save_path)

        file_size_mb = os.path.getsize(save_path) / 1024 / 1024

        print(f" Successfully saved preprocessed MIMIC data!")
        print(f"    File: {save_path}")
        print(f"    Size: {file_size_mb:.1f} MB")
        print(f"    Patients: {len(patient_cases)}")
        print(f"    Experts: {len(experts)}")
        print(f"    Timestamp: {data_package['timestamp']}")

        return save_path

    except Exception as e:
        # Best-effort cleanup of the partial temp file; the original
        # destination (if any) is left untouched.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        print(f" Failed to save preprocessed data: {e}")
        return None

def test_mimic_processor_with_save(mimic_path: Optional[str] = None,
                                   sample_size: int = 40000,
                                   num_patients: int = 40000,
                                   save_path: str = "preprocessed_mimic_data.pkl") -> Optional[Tuple]:
    """Run the full pipeline end to end and pickle the result.

    Manual smoke run (not a unit test): discovers the MIMIC files, loads the
    tables, extracts patient cases, builds expert profiles and saves
    everything via ``save_preprocessed_mimic_data``.

    Args:
        mimic_path: Root directory of the MIMIC dataset. When omitted, the
            original placeholder value is used, which must be edited or
            overridden before a real run.
        sample_size: Forwarded to ``load_mimic_tables(sample_size=...)``.
        num_patients: Forwarded to ``extract_patient_cases(num_patients=...)``.
        save_path: Output file handed to ``save_preprocessed_mimic_data``.

    Returns:
        ``(processor, tables, patient_cases, experts, saved_path)`` on
        success (``saved_path`` is ``None`` if only the save step failed),
        or ``None`` if any earlier stage failed.
    """
    MIMIC_PATH = r"#input path" if mimic_path is None else mimic_path

    print(" TESTING MIMIC PROCESSOR WITH SAVE FUNCTIONALITY")
    print("="*60)

    try:
        # MIMICProcessor.__init__ creates <root>/processing_cache with
        # parents=True, so a wrong root would silently create a bogus
        # directory tree. Refuse to continue instead.
        if not Path(MIMIC_PATH).is_dir():
            print(f" MIMIC path is not an existing directory: {MIMIC_PATH}")
            return None

        processor = MIMICProcessor(MIMIC_PATH)

        print(" Testing file discovery...")
        files = processor.discover_mimic_files()
        print(f"   Found {len(files)} table structures")

        if not files:
            print(" No files found! Check your MIMIC path.")
            return None

        print("\n Testing table loading...")
        tables = processor.load_mimic_tables(sample_size=sample_size)

        if not tables:
            print(" No tables loaded!")
            return None

        print(f"    Loaded {len(tables)} tables")

        print("\n Testing patient case extraction...")
        patient_cases = processor.extract_patient_cases(tables, num_patients=num_patients)

        if not patient_cases:
            print(" No patient cases extracted!")
            return None

        print(f"    Extracted {len(patient_cases)} patient cases")

        print("\n Testing expert knowledge creation...")
        experts = processor.create_expert_knowledge_from_mimic(patient_cases)
        print(f"    Created {len(experts)} expert profiles")

        print("\n Saving preprocessed data for fast training...")
        saved_to = save_preprocessed_mimic_data(
            processor, tables, patient_cases, experts, save_path=save_path
        )

        if saved_to:
            print(f"\n  Preprocessed data saved to: {saved_to}")
            print(" Now you can run fast training without reprocessing!")
            print(" Next step: Run working_main_system.py for instant training!")

        return processor, tables, patient_cases, experts, saved_to

    except Exception as e:
        # Top-level boundary of a manual run: report and signal failure.
        print(f" Error: {e}")
        import traceback
        traceback.print_exc()
        return None


def test_mimic_processor(mimic_path: Optional[str] = None,
                         sample_size: int = 5000,
                         num_patients: int = 50) -> Optional[Tuple]:
    """Run a small smoke pass over the pipeline and print a sample case.

    Manual smoke run (not a unit test): discovers files, loads tables,
    extracts a few patient cases, prints the first one and builds expert
    profiles. Nothing is written by this function itself.

    Args:
        mimic_path: Root directory of the MIMIC dataset. When omitted, the
            original hard-coded developer path is used.
        sample_size: Forwarded to ``load_mimic_tables(sample_size=...)``.
        num_patients: Forwarded to ``extract_patient_cases(num_patients=...)``.

    Returns:
        ``(processor, tables, patient_cases, experts)`` on success, or
        ``None`` if any stage failed.
    """
    if mimic_path is None:
        MIMIC_PATH = r"C:\Users\Pranjali Talanki\Desktop\mcw_smartheath\mimic-iii-clinical-database-1.4"
    else:
        MIMIC_PATH = mimic_path

    print(" TESTING FIXED MIMIC PROCESSOR")
    print("="*50)

    try:
        # MIMICProcessor.__init__ creates <root>/processing_cache with
        # parents=True, so a wrong root would silently create a bogus
        # directory tree. Refuse to continue instead.
        if not Path(MIMIC_PATH).is_dir():
            print(f" MIMIC path is not an existing directory: {MIMIC_PATH}")
            return None

        processor = MIMICProcessor(MIMIC_PATH)

        print(" Testing file discovery...")
        files = processor.discover_mimic_files()
        print(f"    Found {len(files)} table structures")

        if not files:
            print(" No files found! Check your MIMIC path.")
            return None

        print("\n Testing table loading...")
        tables = processor.load_mimic_tables(sample_size=sample_size)

        if not tables:
            print(" No tables loaded!")
            return None

        print(f"   Loaded {len(tables)} tables")
        for table_name, df in tables.items():
            print(f"      - {table_name}: {len(df)} rows, {len(df.columns)} columns")

        print("\n Testing patient case extraction...")
        patient_cases = processor.extract_patient_cases(tables, num_patients=num_patients)

        if not patient_cases:
            print(" No patient cases extracted!")
            return None

        print(f"   Extracted {len(patient_cases)} patient cases")

        # patient_cases is known to be non-empty here (guarded above).
        sample = patient_cases[0]
        print(f"\n📋 SAMPLE EXTRACTED CASE:")
        print(f"   Patient ID: {sample['patient_id']}")
        print(f"   Diagnosis: {sample['admission_info']['diagnosis']}")
        print(f"   Symptoms: {sample['symptoms']}")
        print(f"   Diagnoses: {len(sample['diagnoses'])} found")
        print(f"   Procedures: {len(sample['procedures'])} found")
        print(f"   Medications: {len(sample['medications'])} found")
        print(f"   Clinical notes: {len(sample['clinical_notes'])} found")

        print("\n Testing expert knowledge creation...")
        experts = processor.create_expert_knowledge_from_mimic(patient_cases)
        print(f"   Created {len(experts)} expert profiles")

        print("\n FIXED MIMIC PROCESSOR TEST COMPLETE!")
        print("Successfully processed REAL MIMIC data:")
        print(f"   Tables loaded: {len(tables)}")
        print(f"   Patient cases: {len(patient_cases)}")
        print(f"   Expert profiles: {len(experts)}")

        return processor, tables, patient_cases, experts

    except Exception as e:
        # Top-level boundary of a manual run: report and signal failure.
        print(f" Error: {e}")
        import traceback
        traceback.print_exc()
        return None

if __name__ == "__main__":
    # Script entry point: run the full pipeline and persist the result.
    test_mimic_processor_with_save()
