import pandas as pd
import numpy as np
import os
import json
import traceback
from collections import defaultdict
import csv
import time

# Define paths
# DATA_DIR = '/hy-tmp/physionet.org/files/mimic-iv-demo/2.2/hosp'
RAW_DIR = os.path.join('/hy-tmp/mimic-iv-3.1/hosp')
AUX_DIR = os.path.join('/hy-tmp/LEADER-lhy-research/data/mimic3/auxiliary')
OUTPUT_DIR = os.path.join('/hy-tmp/dense2MOE/data/mimic4')

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Define a function to safely read CSV files with error handling
def safe_read_csv(file_path, nrows=None):
    """Read a CSV file into a DataFrame without ever raising.

    Args:
        file_path: Path of the CSV file to load.
        nrows: Maximum number of data rows to read. ``None`` (the default)
            reads the whole file; ``0`` reads only the header.

    Returns:
        The loaded DataFrame, or an empty DataFrame if reading failed for any
        reason. The failure and its traceback are printed so the caller can
        keep going on a best-effort basis.
    """
    try:
        print(f"Reading {file_path}...")
        # Forward nrows unchanged: pandas already treats None as "no limit".
        # A truthiness test here would turn nrows=0 into a full-file read.
        df = pd.read_csv(file_path, nrows=nrows)
        print(f"Successfully read {len(df)} rows from {file_path}")
        return df
    except Exception as e:
        # Deliberately broad: a missing or malformed input file should be
        # reported and skipped, not abort the whole run.
        print(f"Error reading {file_path}: {str(e)}")
        traceback.print_exc()
        return pd.DataFrame()

# Hand-maintained seed table of ATC codes and their names.
# Keys are 4-character ATC codes; values are lowercase group names.
# The script copies these entries into its ATC code -> name lookup before
# the WHO ATC file is processed, so codes listed here have a name even if
# that file cannot be read.
common_atc_codes = {
    'A01A': 'stomatological preparations',
    'A02B': 'drugs for peptic ulcer and gastro-oesophageal reflux disease (gord)',
    'A06A': 'drugs for constipation',
    'A07A': 'intestinal antiinfectives',
    'A10A': 'insulins and analogues',
    'A12A': 'calcium',
    'A12B': 'potassium',
    'A12C': 'other mineral supplements',
    'B05C': 'irrigating solutions',
    'C01C': 'cardiac stimulants excl. cardiac glycosides',
    'C02D': 'arteriolar smooth muscle, agents acting on',
    'C03C': 'high-ceiling diuretics',
    'C07A': 'beta blocking agents',
    'M01A': 'antiinflammatory and antirheumatic products, non-steroids',
    'N01A': 'anesthetics, general',
    'N02A': 'opioids',
    'N02B': 'other analgesics and antipyretics',
    'N06A': 'antidepressants',
    'N07A': 'parasympathomimetics'
}

# Column names are lower-cased right after loading, so the candidates below
# cover both the upper-case spelling this script originally used (ICD9_CODE,
# ETHNICITY, ...) and the lower-case MIMIC-IV spelling (icd_code, race).
# NOTE(review): MIMIC-IV schema taken from the PhysioNet docs - confirm
# against the actual files under RAW_DIR.
ICD_CODE_COLUMNS = ('icd9_code', 'icd_code')
DEMOGRAPHIC_COLUMNS = {
    'insurance': ('insurance',),
    'language': ('language',),
    'religion': ('religion',),
    'marital_status': ('marital_status',),
    'ethnicity': ('ethnicity', 'race'),
}


def normalize_columns(df):
    """Lower-case and strip the column names of ``df`` in place and return it."""
    df.columns = [str(col).strip().lower() for col in df.columns]
    return df


def first_present(df, candidates):
    """Return the first name in ``candidates`` that is a column of ``df``, else None."""
    for name in candidates:
        if name in df.columns:
            return name
    return None


def build_atc_name_map(atc_df, seed_names):
    """Build an ATC code -> lowercase name lookup.

    Args:
        atc_df: Table with ``atc_code`` and ``atc_name`` columns (may be empty).
        seed_names: Predefined code -> name entries to start from.

    Returns:
        A new dict. Every code of length >= 3 in ``atc_df`` maps to the name
        on its own row (overriding the seed). A 3/4/5-character prefix that
        has no row of its own and no seed entry falls back to the name of the
        first longer code that starts with it.
    """
    names = dict(seed_names)
    if atc_df.empty:
        return names
    if 'atc_code' not in atc_df.columns or 'atc_name' not in atc_df.columns:
        print(f"ATC file lacks 'atc_code'/'atc_name' columns; found {list(atc_df.columns)}")
        return names

    valid = atc_df[['atc_code', 'atc_name']].dropna()
    pairs = [
        (code, str(name).lower())
        for code, name in zip(valid['atc_code'], valid['atc_name'])
        if isinstance(code, str) and len(code) >= 3
    ]
    # Pass 1: a code is named by its own row. Writing every prefix on every
    # row (as before) let the last drug in a group rename the group itself.
    for code, name in pairs:
        names[code] = name
    # Pass 2: best-effort fill for prefixes that are still unnamed; never
    # overwrites an entry from pass 1 or from the seed.
    for code, name in pairs:
        for width in (3, 4, 5):
            if len(code) > width:
                names.setdefault(code[:width], name)
    return names


def build_code_title_map(dict_df, label):
    """Build a ``{code (as str): long_title}`` lookup from an ICD dictionary table.

    Returns an empty dict (after printing why) when the needed columns are absent.
    """
    if dict_df.empty:
        return {}
    code_col = first_present(dict_df, ICD_CODE_COLUMNS)
    if code_col is None or 'long_title' not in dict_df.columns:
        print(f"Cannot build {label} names: need one of {ICD_CODE_COLUMNS} and "
              f"'long_title', found {list(dict_df.columns)}")
        return {}
    valid = dict_df[[code_col, 'long_title']].dropna()
    # NOTE(review): if the table mixes ICD versions under one code column,
    # identical code strings from different versions share a key here and the
    # later row wins - confirm whether that matters for this dataset.
    return {str(code): title for code, title in zip(valid[code_col], valid['long_title'])}


def group_codes_by_admission(events_df, label):
    """Group ICD codes by ``(subject_id, hadm_id)``, keeping file order.

    Rows with a missing id or code are ignored. Returns an empty dict (after
    printing why) when the needed columns are absent.
    """
    grouped = {}
    if events_df.empty:
        return grouped
    code_col = first_present(events_df, ICD_CODE_COLUMNS)
    missing = [col for col in ('subject_id', 'hadm_id') if col not in events_df.columns]
    if code_col is None or missing:
        print(f"Cannot group {label}: need 'subject_id', 'hadm_id' and one of "
              f"{ICD_CODE_COLUMNS}, found {list(events_df.columns)}")
        return grouped
    valid = events_df[['subject_id', 'hadm_id', code_col]].dropna()
    for subject_id, hadm_id, code in zip(valid['subject_id'], valid['hadm_id'], valid[code_col]):
        grouped.setdefault((subject_id, hadm_id), []).append(str(code))
    return grouped


def map_drug_to_atc(drug_name, name_to_atc):
    """Return the ATC code of the first ``name_to_atc`` key contained in ``drug_name``.

    Matching is a case-insensitive substring test; returns None when nothing matches.
    """
    lowered = str(drug_name).lower()
    for key, code in name_to_atc.items():
        if key in lowered:
            return code
    return None


def group_medications_by_admission(prescriptions_df, name_to_atc):
    """Group distinct ATC codes by ``(subject_id, hadm_id)``.

    Returns:
        ``(grouped, unmapped_rows)``: ``grouped`` maps each admission key to
        its ATC codes in first-seen order without duplicates, and
        ``unmapped_rows`` counts prescription rows whose drug name is missing
        or matches no key of ``name_to_atc``. Those rows are skipped: no code
        is invented for them, so every target label comes from real data.
    """
    grouped = {}
    unmapped_rows = 0
    if prescriptions_df.empty:
        return grouped, unmapped_rows
    required = ['subject_id', 'hadm_id', 'drug']
    missing = [col for col in required if col not in prescriptions_df.columns]
    if missing:
        print(f"Cannot group prescriptions: missing columns {missing}, "
              f"found {list(prescriptions_df.columns)}")
        return grouped, unmapped_rows

    valid = prescriptions_df[required].dropna(subset=['subject_id', 'hadm_id'])
    code_by_drug = {}  # each distinct drug string is matched only once
    for subject_id, hadm_id, drug in zip(valid['subject_id'], valid['hadm_id'], valid['drug']):
        if pd.isna(drug):
            unmapped_rows += 1
            continue
        if drug not in code_by_drug:
            code_by_drug[drug] = map_drug_to_atc(drug, name_to_atc)
        code = code_by_drug[drug]
        if code is None:
            unmapped_rows += 1
            continue
        meds = grouped.setdefault((subject_id, hadm_id), [])
        if code not in meds:
            meds.append(code)
    return grouped, unmapped_rows


def format_codes(codes, names):
    """Render codes as ``"code (name)"`` when a name is known, joined by ``", "``."""
    return ', '.join(
        f"{code} ({names[code]})" if code in names else code
        for code in codes
    )


def in_all(key, tables):
    """Return True when ``key`` is present in every mapping of ``tables``."""
    return all(key in table for table in tables)


try:
    print("Starting MIMIC-IV data processing...")
    start_time = time.time()

    # Load every table and lower-case its column names so the rest of the
    # script is independent of the header casing used by the data release.
    diagnoses = normalize_columns(safe_read_csv(os.path.join(RAW_DIR, 'diagnoses_icd_sample.csv')))
    diagnoses_dict = normalize_columns(safe_read_csv(os.path.join(RAW_DIR, 'd_icd_diagnoses.csv')))
    procedures = normalize_columns(safe_read_csv(os.path.join(RAW_DIR, 'procedures_icd_sample.csv')))
    procedures_dict = normalize_columns(safe_read_csv(os.path.join(RAW_DIR, 'd_icd_procedures.csv')))
    prescriptions = normalize_columns(safe_read_csv(os.path.join(RAW_DIR, 'prescriptions_sample.csv')))
    admissions = normalize_columns(safe_read_csv(os.path.join(RAW_DIR, 'admissions.csv')))
    atc_codes = normalize_columns(safe_read_csv(os.path.join(AUX_DIR, 'WHO ATC-DDD 2021-12-03.csv')))

    # ATC code -> medication name (seed table first, WHO file on top)
    if not atc_codes.empty:
        print("Processing ATC code to medication name mapping...")
    atc_to_name = build_atc_name_map(atc_codes, common_atc_codes)
    print(f"Created {len(atc_to_name)} ATC code to name mappings")

    # ICD code -> long title lookups
    icd9_to_diagnosis = build_code_title_map(diagnoses_dict, 'diagnosis')
    print(f"Created {len(icd9_to_diagnosis)} ICD9 to diagnosis name mappings")

    icd9_to_procedure = build_code_title_map(procedures_dict, 'procedure')
    print(f"Created {len(icd9_to_procedure)} ICD9 to procedure name mappings")

    print("Processing patient data...")

    # Codes per admission, keyed by (subject_id, hadm_id)
    patient_diagnoses = group_codes_by_admission(diagnoses, 'diagnoses')
    print(f"Processed diagnoses for {len({sid for sid, _ in patient_diagnoses})} patients")

    patient_procedures = group_codes_by_admission(procedures, 'procedures')
    print(f"Processed procedures for {len({sid for sid, _ in patient_procedures})} patients")

    # Substring of the (lower-cased) drug name -> ATC code
    drug_name_to_atc = {
        'tacrolimus': 'L04A',  # Immunosuppressants
        'warfarin': 'B01A',    # Antithrombotic agents
        'heparin': 'B01A',     # Antithrombotic agents
        'insulin': 'A10A',     # Insulins and analogues
        'aspirin': 'N02B',     # Other analgesics and antipyretics
        'ibuprofen': 'M01A',   # Antiinflammatory and antirheumatic products, non-steroids
        'acetaminophen': 'N02B', # Other analgesics and antipyretics
        'morphine': 'N02A',    # Opioids
        'furosemide': 'C03C',  # High-ceiling diuretics
        'metoprolol': 'C07A',  # Beta blocking agents
        'atorvastatin': 'C10A', # Lipid modifying agents, plain
        'lisinopril': 'C09A',  # ACE inhibitors, plain
        'omeprazole': 'A02B',  # Drugs for peptic ulcer and GORD
        'metformin': 'A10B',   # Blood glucose lowering drugs, excl. insulins
        'albuterol': 'R03A',   # Adrenergics, inhalants
        'prednisone': 'H02A',  # Corticosteroids for systemic use, plain
        'amoxicillin': 'J01C', # Beta-lactam antibacterials, penicillins
        'ciprofloxacin': 'J01M', # Quinolone antibacterials
        'vancomycin': 'J01X',  # Other antibacterials
        'lorazepam': 'N05B',   # Anxiolytics
        'zolpidem': 'N05C',    # Hypnotics and sedatives
        'fluoxetine': 'N06A',  # Antidepressants
        'levothyroxine': 'H03A', # Thyroid preparations
    }

    patient_medications, unmapped_rows = group_medications_by_admission(prescriptions, drug_name_to_atc)
    print(f"Processed medications for {len({sid for sid, _ in patient_medications})} patients")
    if unmapped_rows:
        print(f"Skipped {unmapped_rows} prescription rows with no known ATC mapping")

    # Demographics and the ordered list of admissions per patient
    patient_info = {}
    if not admissions.empty:
        missing = [col for col in ('subject_id', 'hadm_id') if col not in admissions.columns]
        if missing:
            print(f"Cannot process admissions: missing columns {missing}, "
                  f"found {list(admissions.columns)}")
        else:
            admissions = admissions.dropna(subset=['subject_id', 'hadm_id'])
            if 'admittime' in admissions.columns:
                # "Previous visits" must really be earlier ones, so order by
                # admission time instead of trusting file order. The sort is
                # stable and unparseable times go last.
                admit_time = pd.to_datetime(admissions['admittime'], errors='coerce')
                ordered_index = admit_time.sort_values(kind='mergesort', na_position='last').index
                admissions = admissions.loc[ordered_index]

            demo_columns = {
                field: first_present(admissions, candidates)
                for field, candidates in DEMOGRAPHIC_COLUMNS.items()
            }
            selected = ['subject_id', 'hadm_id'] + sorted(
                {col for col in demo_columns.values() if col is not None}
            )
            for values in zip(*(admissions[col] for col in selected)):
                record = dict(zip(selected, values))
                subject_id = record['subject_id']
                if subject_id not in patient_info:
                    # Demographics come from the patient's first listed admission;
                    # a field with no source column or no value is 'Unknown'.
                    entry = {
                        field: (record[col] if col is not None and pd.notna(record[col]) else 'Unknown')
                        for field, col in demo_columns.items()
                    }
                    entry['admissions'] = []
                    patient_info[subject_id] = entry
                patient_info[subject_id]['admissions'].append(record['hadm_id'])

    print(f"Processed demographic information for {len(patient_info)} patients")

    print("Creating SFT dataset...")

    sft_data = []
    examples_created = 0
    instruction = "Based on the patient's historical medical records, predict the medications needed for the current visit."
    tables = (patient_diagnoses, patient_procedures, patient_medications)

    for subject_id, info in patient_info.items():
        visits = info['admissions']
        # A patient needs at least one earlier admission to serve as history
        if len(visits) < 2:
            continue
        try:
            # Every admission after the first is a candidate target visit
            for i in range(1, len(visits)):
                current_admission = visits[i]
                current_key = (subject_id, current_admission)

                # The target visit needs diagnoses, procedures and medications
                if not in_all(current_key, tables):
                    continue

                # Earlier visits with a complete record become the history;
                # the visit number keeps its position in the admission list.
                history_parts = ["Patient History:\n"]
                for j in range(i):
                    prev_key = (subject_id, visits[j])
                    if not in_all(prev_key, tables):
                        continue
                    prev_dx = format_codes(patient_diagnoses[prev_key], icd9_to_diagnosis)
                    prev_px = format_codes(patient_procedures[prev_key], icd9_to_procedure)
                    prev_rx = format_codes(patient_medications[prev_key], atc_to_name)
                    history_parts.append(
                        f"Visit {j+1}:\n"
                        f"Diagnoses: {prev_dx}\n"
                        f"Procedures: {prev_px}\n"
                        f"Medications: {prev_rx}\n\n"
                    )

                current_dx = format_codes(patient_diagnoses[current_key], icd9_to_diagnosis)
                current_px = format_codes(patient_procedures[current_key], icd9_to_procedure)
                current_text = (
                    "Current Visit:\n"
                    f"Diagnoses: {current_dx}\n"
                    f"Procedures: {current_px}\n"
                )

                # The medications of the current visit are the prediction target
                output_text = "Recommended medications: " + format_codes(
                    patient_medications[current_key], atc_to_name
                )

                sft_data.append({
                    'instruction': instruction,
                    'input': ''.join(history_parts) + current_text,
                    'output': output_text,
                    'subject_id': subject_id,
                    'current_admission': current_admission
                })

                examples_created += 1
                if examples_created % 1000 == 0:
                    print(f"Processed {examples_created} patient visits")

        except Exception as e:
            # Best effort: one malformed patient record must not end the run
            print(f"Error processing patient {subject_id}: {str(e)}")
            continue

    print(f"Created {len(sft_data)} examples for SFT")

    # Save to CSV
    # NOTE(review): two of the output names below say "mimic3" although they
    # are written next to the mimic4 file; kept unchanged in case downstream
    # code depends on the existing names.
    if sft_data:
        sft_df = pd.DataFrame(sft_data)
        sft_df.to_csv(os.path.join(OUTPUT_DIR, 'mimic4_full_dataset.csv'), index=False)
        print(f"Dataset saved to {os.path.join(OUTPUT_DIR, 'mimic4_full_dataset.csv')}")

        # Also save a sample of the dataset for inspection
        sample_size = min(100, len(sft_df))
        if sample_size > 0:
            sample_df = sft_df.sample(sample_size)
            sample_df.to_csv(os.path.join(OUTPUT_DIR, 'mimic3_full_sample.csv'), index=False)
            print(f"Sample saved to {os.path.join(OUTPUT_DIR, 'mimic3_full_sample.csv')}")

        # Create a version with just instruction, input, and output columns for SFT
        sft_columns = ['instruction', 'input', 'output']
        sft_df_simple = sft_df[sft_columns]
        sft_df_simple.to_csv(os.path.join(OUTPUT_DIR, 'mimic3_full_sft.csv'), index=False)
        print(f"SFT dataset saved to {os.path.join(OUTPUT_DIR, 'mimic3_full_sft.csv')}")

        # Print an example
        print("\nExample from the dataset:")
        print(f"Instruction: {sft_df.iloc[0]['instruction']}")
        print(f"Input: {sft_df.iloc[0]['input'][:200]}...")
        print(f"Output: {sft_df.iloc[0]['output'][:200]}...")
    else:
        print("No examples were created. Check the error messages above.")

    # Print execution time
    end_time = time.time()
    print(f"Total execution time: {end_time - start_time:.2f} seconds")

except Exception as e:
    # Top-level boundary: report the failure with its traceback
    print(f"Error: {str(e)}")
    traceback.print_exc()