# CSV parsing module for EHR datasets
# Handles parsing and standardization of medical codes from different EHR systems

import os
from datetime import datetime
from collections import OrderedDict

import pandas
import pandas as pd
import numpy as np
import pickle

from collections import defaultdict, Counter
from .disparity import describe_patient_statistics


class EHRParser:
    """
    Base class for parsing EHR datasets.

    Handles the common pipeline for processing admissions, diagnoses, procedures and
    medications. Subclasses describe their CSV layout through ``set_admission``,
    ``set_diagnosis``, ``set_procedures`` and ``set_medications``; each returns a
    ``(filename, cols, converters)`` triple where ``cols`` maps the logical column
    names below to the dataset's raw CSV column names.
    """
    # Logical column names; subclasses map them to raw CSV columns via ``cols``
    pid_col = 'patient_id'
    adm_id_col = 'admission_id'
    adm_time_col = 'admission_time'
    disch_time_col = 'discharge_time'
    death_time_col = 'death_time'
    cid_col = 'concept_id'

    def __init__(self, path, procedure=True, medication=True):
        """
        Initialize parser with data path and feature options.

        Args:
            path: Path to raw data directory
            procedure: Whether to include procedure codes
            medication: Whether to include medication codes
        """
        self.path = path
        self.procedure = procedure
        self.medication = medication

        # When True, ``_parse_concept`` keeps concept rows without checking that their
        # patient id is a key of ``patient_admission``.
        self.skip_pid_check = False

        # Storage for parsed data
        self.patient_admission = None      # pid -> list of admission dicts
        self.admission_codes = None        # adm_id -> list of diagnosis codes
        self.admission_procedures = None   # adm_id -> list of procedure codes
        self.admission_medications = None  # adm_id -> list of medication codes

        # Mapping of concept types to the methods describing their CSV layout
        self.parse_fn = {'d': self.set_diagnosis,
                         'p': self.set_procedures,
                         'm': self.set_medications}

    # Abstract methods to be implemented by specific dataset parsers.
    # Each ``set_*`` method returns ``(filename, cols, converters)``.
    def set_admission(self):
        raise NotImplementedError

    def set_diagnosis(self):
        raise NotImplementedError

    def set_procedures(self):
        raise NotImplementedError

    def set_medications(self):
        raise NotImplementedError

    @staticmethod
    def to_standard_icd9(code: str):
        raise NotImplementedError

    @staticmethod
    def to_icd9_pcs(code: str):
        raise NotImplementedError

    def parse_admission(self, sorting=False):
        """
        Parse admission data from the admission CSV file.

        Builds ``self.patient_admission``: an OrderedDict from patient id to the list of
        that patient's admission records (dicts keyed by the logical column names).
        Only patients with at least two admissions are kept.

        Args:
            sorting: If True, sort each patient's admissions by admission time;
                otherwise keep the order produced by sorting rows on
                (patient id, admission id).
        """
        print('parsing the csv file of admission ...')
        filename, cols, converters = self.set_admission()
        adm_df = pd.read_csv(str(os.path.join(self.path, filename)),
                             usecols=list(cols.values()), converters=converters)
        adm_df = self._after_read_admission(adm_df, cols)

        # Raw (dataset-specific) column names
        pid_key, adm_id_key = cols[self.pid_col], cols[self.adm_id_col]
        adm_time_key, disch_time_key = cols[self.adm_time_col], cols[self.disch_time_col]
        death_time_key = cols[self.death_time_col]

        adm_df = adm_df.sort_values([pid_key, adm_id_key], ascending=True).reset_index(drop=True)
        adm_df = adm_df.astype({pid_key: int, adm_id_key: int})

        n_rows = len(adm_df)
        all_patients = OrderedDict()
        for i, row in adm_df.iterrows():
            if i % 100 == 0:
                print('\r\t%d in %d rows' % (i + 1, n_rows), end='')
            all_patients.setdefault(row[pid_key], []).append({
                self.adm_id_col: row[adm_id_key],
                self.adm_time_col: row[adm_time_key],
                self.disch_time_col: row[disch_time_key],
                self.death_time_col: row[death_time_key],
            })
        print('\r\t%d in %d rows' % (n_rows, n_rows))

        print("Sorting admissions: ", sorting)
        # Keep only patients with at least 2 admissions
        patient_admission = OrderedDict()
        for pid, patient_admissions in all_patients.items():
            if len(patient_admissions) < 2:
                continue
            if sorting:
                patient_admission[pid] = sorted(patient_admissions, key=lambda adm: adm[self.adm_time_col])
            else:
                patient_admission[pid] = patient_admissions

        self.patient_admission = patient_admission

    def _after_read_admission(self, admissions, cols):
        """Hook for post-processing admission data; returns the frame unchanged by default."""
        return admissions

    def _parse_concept(self, concept_type):
        """
        Parse one kind of medical concept (diagnoses, procedures or medications).

        Args:
            concept_type: Type of concept ('d', 'p', or 'm')
        Returns:
            OrderedDict mapping admission id to the list of codes read for it, in row
            order. Codes equal to '' or '0' are skipped.
        """
        assert concept_type in self.parse_fn.keys()
        filename, cols, converters = self.parse_fn[concept_type]()
        concepts = pd.read_csv(str(os.path.join(self.path, filename)),
                               usecols=list(cols.values()), converters=converters)
        concepts = self._after_read_concepts(concepts, concept_type, cols)

        pid_key, adm_id_key, cid_key = cols[self.pid_col], cols[self.adm_id_col], cols[self.cid_col]
        n_rows = len(concepts)
        result = OrderedDict()
        for i, row in concepts.iterrows():
            if i % 100 == 0:
                print('\r\t%d in %d rows' % (i + 1, n_rows), end='')
            if not (self.skip_pid_check or row[pid_key] in self.patient_admission):
                continue
            code = row[cid_key]
            if code == '' or code == '0':
                continue
            result.setdefault(row[adm_id_key], []).append(code)
        print('\r\t%d in %d rows' % (n_rows, n_rows))
        return result

    def _after_read_concepts(self, concepts, concept_type, cols):
        """Hook for post-processing concept data; returns the frame unchanged by default."""
        return concepts

    def parse_diagnoses(self):
        """Parse diagnosis codes into ``self.admission_codes``."""
        print('parsing csv file of diagnosis ...')
        self.admission_codes = self._parse_concept('d')

    def parse_procedures(self):
        """Parse procedure codes into ``self.admission_procedures``."""
        print('parsing csv file of procedures ...')
        self.admission_procedures = self._parse_concept('p')

    def parse_medications(self):
        """Parse medication codes into ``self.admission_medications``."""
        print('parsing csv file of medications ...')
        self.admission_medications = self._parse_concept('m')

    def calibrate_patient_by_admission(self):
        """
        Remove every patient who has at least one admission without diagnosis codes.

        The diagnosis entries of the removed patients' admissions are deleted as well.
        """
        print('calibrating patients by admission ...')
        del_pids = [
            pid for pid, admissions in self.patient_admission.items()
            if any(admission[self.adm_id_col] not in self.admission_codes for admission in admissions)
        ]
        for pid in del_pids:
            for admission in self.patient_admission[pid]:
                self.admission_codes.pop(admission[self.adm_id_col], None)
            del self.patient_admission[pid]

    def _post_calibrate_other_medical_concepts(self, concept_name, adm_id_set):
        """
        Make ``admission_<concept_name>`` cover exactly ``adm_id_set``.

        Entries for admissions outside the set are deleted; admissions in the set
        without an entry get an empty list.
        """
        admission_concepts = getattr(self, 'admission_%s' % concept_name)
        del_adm_ids = [adm_id for adm_id in admission_concepts if adm_id not in adm_id_set]
        add_adm_ids = [adm_id for adm_id in adm_id_set if adm_id not in admission_concepts]

        for del_adm_id in del_adm_ids:
            del admission_concepts[del_adm_id]

        for add_adm_id in add_adm_ids:
            admission_concepts[add_adm_id] = []

    def calibrate_admission_by_patient(self):
        """
        Remove concept entries whose admission does not belong to any kept patient.

        Diagnoses are pruned; enabled procedures/medications are pruned and padded
        with empty lists so they cover every kept admission.
        """
        print('calibrating admission by patients ...')
        adm_id_set = {admission[self.adm_id_col]
                      for admissions in self.patient_admission.values()
                      for admission in admissions}
        del_adm_ids = [adm_id for adm_id in self.admission_codes if adm_id not in adm_id_set]
        for adm_id in del_adm_ids:
            del self.admission_codes[adm_id]

        if self.procedure:
            self._post_calibrate_other_medical_concepts("procedures", adm_id_set)
        if self.medication:
            self._post_calibrate_other_medical_concepts('medications', adm_id_set)

    def sample_patients(self, sample_num, seed):
        """
        Randomly sample a subset of patients and restrict the concept mappings to
        their admissions.

        Note: seeds NumPy's global random state via ``np.random.seed``.

        Args:
            sample_num: Number of patients to sample (without replacement)
            seed: Random seed for reproducibility
        """
        np.random.seed(seed)
        keys = list(self.patient_admission.keys())
        selected_pids = np.random.choice(keys, sample_num, False)
        self.patient_admission = {pid: self.patient_admission[pid] for pid in selected_pids}

        selected_adm_ids = [admission[self.adm_id_col]
                            for admissions in self.patient_admission.values()
                            for admission in admissions]
        self.admission_codes = {adm_id: self.admission_codes[adm_id] for adm_id in selected_adm_ids}
        # Restrict procedures/medications too, so all returned mappings describe the
        # same set of admissions.
        if self.admission_procedures is not None:
            self.admission_procedures = {adm_id: self.admission_procedures[adm_id]
                                         for adm_id in selected_adm_ids
                                         if adm_id in self.admission_procedures}
        if self.admission_medications is not None:
            self.admission_medications = {adm_id: self.admission_medications[adm_id]
                                          for adm_id in selected_adm_ids
                                          if adm_id in self.admission_medications}

    def alignment_admissions_and_patients(self):
        """
        Align admissions across the different medical concepts.

        An admission is valid if it has a diagnosis entry (or, when procedures are
        enabled, a non-empty procedure list) and, when medications are enabled, also a
        non-empty medication list. Invalid admissions are dropped from every mapping,
        and patients left with fewer than two admissions are removed.
        """
        print("Aligning admissions across different concepts ...")
        valid_adm_id = set(self.admission_codes.keys())
        if self.procedure:
            proc_adm_ids = {adm_id for adm_id, codes in self.admission_procedures.items() if len(codes) > 0}
            valid_adm_id = valid_adm_id.union(proc_adm_ids)  # procedures are not required for tasks' labels
        if self.medication:
            med_adm_ids = {adm_id for adm_id, codes in self.admission_medications.items() if len(codes) > 0}
            valid_adm_id = valid_adm_id.intersection(med_adm_ids)
        print("\tnum of total admission: ", len(self.admission_codes))
        print("\tnum of valid admission: ", len(valid_adm_id))

        # Update patient admissions to only include valid admissions
        self.patient_admission = {
            pid: [admission for admission in admissions if admission[self.adm_id_col] in valid_adm_id]
            for pid, admissions in self.patient_admission.items()
        }

        # Patients need >= 2 admissions; a lone remaining admission is invalidated too
        for pid in list(self.patient_admission.keys()):
            if len(self.patient_admission[pid]) == 1:
                valid_adm_id.remove(self.patient_admission[pid][0][self.adm_id_col])
            if len(self.patient_admission[pid]) <= 1:
                del self.patient_admission[pid]

        for adm_id in list(self.admission_codes.keys()):
            if adm_id not in valid_adm_id:
                del self.admission_codes[adm_id]
                if self.procedure:
                    del self.admission_procedures[adm_id]
                if self.medication:
                    del self.admission_medications[adm_id]
        print("\tvalid diagnosis visit num: ", len(valid_adm_id))
        if self.procedure:
            print("\tvalid procedure visit num: ", len(self.admission_procedures))
        if self.medication:
            print("\tvalid medication visit num: ", len(self.admission_medications))

    def parse(self, sample_num=None, sorting=False, alignment=True, seed=6669):
        """
        Run the full parsing pipeline.

        Args:
            sample_num: If not None, number of patients to randomly keep
            sorting: Sort each patient's admissions by admission time
            alignment: Run ``alignment_admissions_and_patients``
            seed: Random seed used when sampling patients
        Returns:
            ``(patient_admission, (admission_codes, admission_procedures, admission_medications))``;
            procedure/medication mappings are None when disabled.
        """
        self.parse_admission(sorting=sorting)
        self.parse_diagnoses()  # self.admission_codes
        if self.procedure:
            self.parse_procedures()  # self.admission_procedures
        if self.medication:
            self.parse_medications()  # self.admission_medications
        self.calibrate_patient_by_admission()
        self.calibrate_admission_by_patient()
        if alignment:
            self.alignment_admissions_and_patients()
        else:
            print('--' * 5)
        if sample_num is not None:
            self.sample_patients(sample_num, seed)
        return self.patient_admission, (self.admission_codes, self.admission_procedures, self.admission_medications)


class Mimic3Parser(EHRParser):
    """Parser for the MIMIC-III CSV files (upper-case file and column names)."""

    def set_admission(self):
        """Describe ADMISSIONS.csv; times are parsed with '%Y-%m-%d %H:%M:%S', empty DEATHTIME -> None."""
        filename = 'ADMISSIONS.csv'
        cols = {self.pid_col: 'SUBJECT_ID', self.adm_id_col: 'HADM_ID', self.adm_time_col: 'ADMITTIME',
                self.disch_time_col: 'DISCHTIME', self.death_time_col: 'DEATHTIME'}
        converter = {
            'SUBJECT_ID': str,
            'HADM_ID': str,
            'ADMITTIME': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S'),
            'DISCHTIME': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S'),
            'DEATHTIME': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S') if cell else None,
        }
        return filename, cols, converter

    def set_diagnosis(self):
        """Describe DIAGNOSES_ICD.csv; codes are normalized with ``to_standard_icd9``."""
        filename = 'DIAGNOSES_ICD.csv'
        cols = {self.pid_col: 'SUBJECT_ID', self.adm_id_col: 'HADM_ID', self.cid_col: 'ICD9_CODE'}
        converter = {'SUBJECT_ID': int, 'HADM_ID': int, 'ICD9_CODE': self.to_standard_icd9}
        return filename, cols, converter

    def set_procedures(self):
        """Describe PROCEDURES_ICD.csv; codes are normalized with ``to_icd9_pcs``."""
        filename = 'PROCEDURES_ICD.csv'
        cols = {self.pid_col: 'SUBJECT_ID', self.adm_id_col: 'HADM_ID', self.cid_col: 'ICD9_CODE'}
        converter = {'SUBJECT_ID': int, 'HADM_ID': int, 'ICD9_CODE': self.to_icd9_pcs}
        return filename, cols, converter

    def set_medications(self):
        """Describe PRESCRIPTIONS.csv; the NDC column is read as the concept code."""
        filename = 'PRESCRIPTIONS.csv'
        cols = {self.pid_col: 'SUBJECT_ID', self.adm_id_col: 'HADM_ID', self.cid_col: 'NDC'}
        converter = {'SUBJECT_ID': int, 'HADM_ID': int, 'NDC': str}
        return filename, cols, converter

    @staticmethod
    def to_standard_icd9(code: str):
        """
        Insert the decimal point into a dot-less ICD-9 diagnosis code.

        The point goes after 4 characters for codes starting with 'E' and after 3
        otherwise; codes not longer than that, and '', are returned unchanged.
        """
        code = str(code)
        if code == '':
            return code
        split_pos = 4 if code.startswith('E') else 3
        icd9_code = code[:split_pos] + '.' + code[split_pos:] if len(code) > split_pos else code
        return icd9_code

    @staticmethod
    def to_icd9_pcs(code: str):
        """Insert the decimal point after the 2nd character of a dot-less ICD-9 procedure code."""
        code = str(code)
        if code == '':
            return code
        split_pos = 2
        icd9_code = code[:split_pos] + '.' + code[split_pos:] if len(code) > split_pos else code
        return icd9_code

    def _after_read_concepts(self, concepts, concept_type, cols):
        """
        Post-process concept data.

        For medications ('m'): rows with NDC '0' are dropped, missing values are
        forward-filled, remaining NaN rows and duplicates are removed, and codes are
        mapped via ``codeMapping2atc4`` then filtered with ``filter_300_most_med``.
        Other concept types are returned unchanged.
        """
        if concept_type == 'm':
            print('\tmapping NDC to ATC codes...')
            cid_col = cols[self.cid_col]
            concepts.drop(index=concepts[concepts[cid_col] == '0'].index, axis=0, inplace=True)
            # ``ffill`` is the replacement for ``fillna(method='pad')`` (deprecated in pandas 2.1)
            concepts.ffill(inplace=True)
            concepts.dropna(inplace=True)
            concepts.drop_duplicates(inplace=True)
            concepts = concepts.reset_index(drop=True)
            print('\t', concepts.shape[0], concepts.columns)
            concepts = codeMapping2atc4(concepts, cid_col)
            print('\t', concepts.shape[0], concepts.columns)
            concepts = filter_300_most_med(concepts, cid_col)

        return concepts


class Mimic4Parser(EHRParser):
    """Parser for the MIMIC-IV CSV files (lower-case file and column names)."""

    def __init__(self, path, procedure=True, medication=True):
        """
        Args:
            path: Path to raw data directory (must also hold icd10-icd9.csv and patients.csv)
            procedure: Whether to include procedure codes
            medication: Whether to include medication codes
        """
        super().__init__(path, procedure, medication)
        self.icd_ver_col = 'icd_version'
        self.icd_map = self._load_icd_map()
        self.patient_year_map = self._load_patient()

    def _load_icd_map(self):
        """Load the ICD-10 -> ICD-9 map from icd10-icd9.csv (later rows win on duplicate keys)."""
        print('loading ICD-10 to ICD-9 map ...')
        filename = 'icd10-icd9.csv'
        cols = ['ICD10', 'ICD9']
        converters = {'ICD10': str, 'ICD9': str}
        icd_csv = pd.read_csv(os.path.join(self.path, filename), usecols=cols, converters=converters)
        return dict(zip(icd_csv['ICD10'], icd_csv['ICD9']))

    def _load_patient(self):
        """
        Map each integer subject_id to ``anchor_year - <first year of anchor_year_group>``.

        This offset converts a shifted admission year back to an approximate real year.
        """
        print('loading patients anchor year ...')
        filename = 'patients.csv'
        cols = ['subject_id', 'anchor_year', 'anchor_year_group']
        converters = {'subject_id': int, 'anchor_year': int, 'anchor_year_group': lambda cell: int(str(cell)[:4])}
        patient_csv = pd.read_csv(os.path.join(self.path, filename), usecols=cols, converters=converters)
        # Consider anchor year as the earliest year of the range
        year_offsets = patient_csv['anchor_year'] - patient_csv['anchor_year_group']
        return dict(zip(patient_csv['subject_id'], year_offsets))

    def set_admission(self):
        """Describe admissions.csv; times are parsed with '%Y-%m-%d %H:%M:%S', empty deathtime -> None."""
        filename = 'admissions.csv'
        cols = {self.pid_col: 'subject_id', self.adm_id_col: 'hadm_id', self.adm_time_col: 'admittime',
                self.disch_time_col: 'dischtime', self.death_time_col: 'deathtime'}
        converter = {
            'subject_id': str,
            'hadm_id': str,
            'admittime': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S'),
            'dischtime': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S'),
            'deathtime': lambda cell: datetime.strptime(str(cell), '%Y-%m-%d %H:%M:%S') if cell else None,
        }
        return filename, cols, converter

    def set_diagnosis(self):
        """Describe diagnoses_icd.csv; codes are mapped in ``_after_read_concepts``."""
        filename = 'diagnoses_icd.csv'
        cols = {
            self.pid_col: 'subject_id',
            self.adm_id_col: 'hadm_id',
            self.cid_col: 'icd_code',
            self.icd_ver_col: 'icd_version'
        }
        converter = {'subject_id': int, 'hadm_id': int, 'icd_code': str, 'icd_version': int}
        return filename, cols, converter

    def set_procedures(self):
        """Describe procedures_icd.csv; codes are filtered in ``_after_read_concepts``."""
        filename = 'procedures_icd.csv'
        cols = {
            self.pid_col: 'subject_id',
            self.adm_id_col: 'hadm_id',
            self.cid_col: 'icd_code',
            self.icd_ver_col: 'icd_version'
        }
        converter = {'subject_id': int, 'hadm_id': int, 'icd_code': str, 'icd_version': int}
        return filename, cols, converter

    def set_medications(self):
        """Describe prescriptions.csv; the ndc column is read as the concept code."""
        filename = 'prescriptions.csv'
        cols = {self.pid_col: 'subject_id', self.adm_id_col: 'hadm_id', self.cid_col: 'ndc'}
        converter = {'subject_id': int, 'hadm_id': int, 'ndc': str}
        return filename, cols, converter

    def _after_read_admission(self, admissions, cols):
        """
        Keep only admissions whose year, shifted by the patient's anchor offset, is after 2012.

        Returns the selected rows (original index labels kept).
        """
        print('\tselecting valid admission ...')
        pid_key, adm_time_key = cols[self.pid_col], cols[self.adm_time_col]
        valid_labels = []
        n = len(admissions)
        for i, row in admissions.iterrows():
            if i % 100 == 0:
                print('\r\t\t%d in %d rows' % (i + 1, n), end='')
            # ``set_admission`` reads subject_id as str, but ``patient_year_map`` is keyed by int
            pid = int(row[pid_key])
            year = row[adm_time_key].year - self.patient_year_map[pid]
            if year > 2012:  # Avoid the overlapping with MIMIC-III
                valid_labels.append(i)
        print('\r\t\t%d in %d rows' % (n, n))
        print('\t\tremaining %d rows' % len(valid_labels))
        # ``iterrows`` yields index labels, so select by label rather than by position
        return admissions.loc[valid_labels]

    def _after_read_concepts(self, concepts, concept_type, cols):
        """
        Post-process concept data read from the MIMIC-IV CSV files.

        - 'd': ICD-10 codes are mapped to ICD-9 via ``self.icd_map`` (falling back to
          ``code + '1'``; unmapped codes and 'NoDx' become ''), then every code is
          normalized with ``to_standard_icd9``.
        - 'p': ICD-10 procedure codes become '' (skipped later when concepts are
          collected); ICD-9 codes are normalized with ``to_icd9_pcs``.
        - 'm': NDC codes are cleaned, mapped with ``codeMapping2atc4`` and filtered with
          ``filter_300_most_med``.
        """
        cid_col = cols[self.cid_col]

        if concept_type == 'd':
            print('\tmapping ICD-10 to ICD-9 ...')
            n = len(concepts)
            icd_ver_col = cols[self.icd_ver_col]

            def _10to9(i, row):
                if i % 100 == 0:
                    print('\r\t\t%d in %d rows' % (i + 1, n), end='')
                cid = row[cid_col]
                if row[icd_ver_col] == 10:
                    if cid not in self.icd_map:
                        code = self.icd_map[cid + '1'] if cid + '1' in self.icd_map else ''
                    else:
                        code = self.icd_map[cid]
                    if code == 'NoDx':
                        code = ''
                else:
                    code = cid
                return self.to_standard_icd9(code)

            col = np.array([_10to9(i, row) for i, row in concepts.iterrows()])
            print('\r\t\t%d in %d rows' % (n, n))
            concepts[cid_col] = col

        if concept_type == 'p':
            # NOTE: ICD-10 PCS codes are dropped here, not mapped (the log text says "mapping")
            print('\tmapping ICD-10 to ICD-9 ...')
            n = len(concepts)
            icd_ver_col = cols[self.icd_ver_col]

            def _drop10(i, row):
                if i % 100 == 0:
                    print('\r\t\t%d in %d rows' % (i + 1, n), end='')
                code = '' if row[icd_ver_col] == 10 else row[cid_col]
                return self.to_icd9_pcs(code)

            col = np.array([_drop10(i, row) for i, row in concepts.iterrows()])
            print('\r\t\t%d in %d rows' % (n, n))
            concepts[cid_col] = col

        if concept_type == 'm':
            print('\tmapping NDC to ATC codes...')
            concepts.drop(index=concepts[concepts[cid_col] == '0'].index, axis=0, inplace=True)
            # ``ffill`` is the replacement for ``fillna(method='pad')`` (deprecated in pandas 2.1)
            concepts.ffill(inplace=True)
            concepts.dropna(inplace=True)
            concepts.drop_duplicates(inplace=True)
            concepts = concepts.reset_index(drop=True)
            concepts = codeMapping2atc4(concepts, cid_col)
            concepts = filter_300_most_med(concepts, cid_col)

        return concepts

    @staticmethod
    def to_standard_icd9(code: str):
        """Same normalization as ``Mimic3Parser.to_standard_icd9``."""
        return Mimic3Parser.to_standard_icd9(code)

    @staticmethod
    def to_icd9_pcs(code: str):
        """Same normalization as ``Mimic3Parser.to_icd9_pcs``."""
        return Mimic3Parser.to_icd9_pcs(code)


class EICUParser(EHRParser):
    """Parser for the eICU CSV files (diagnoses only)."""

    def __init__(self, path, procedure=False, medication=False):
        """
        Args:
            path: Path to raw data directory
            procedure: Whether to include procedure codes. Defaults to False because
                this parser does not implement ``set_procedures``.
            medication: Whether to include medication codes. Defaults to False because
                this parser does not implement ``set_medications``.
        """
        super().__init__(path, procedure, medication)
        # Diagnosis rows are not filtered by patient id: set_diagnosis maps the
        # patient-id column to 'diagnosisid'.
        self.skip_pid_check = True

    def set_admission(self):
        """Describe patient.csv; times are integer offsets, death is a bool from unitdischargestatus."""
        filename = 'patient.csv'
        cols = {
            self.pid_col: 'patienthealthsystemstayid',
            self.adm_id_col: 'patientunitstayid',
            self.adm_time_col: 'hospitaladmitoffset',
            self.disch_time_col: 'unitdischargeoffset',
            self.death_time_col: 'unitdischargestatus'
        }
        converter = {
            'patienthealthsystemstayid': int,
            'patientunitstayid': int,
            'hospitaladmitoffset': lambda cell: -int(cell),
            'unitdischargeoffset': int,
            'unitdischargestatus': lambda cell: cell == 'Expired' if cell else False
        }
        return filename, cols, converter

    def _after_read_admission(self, admissions, cols):
        """Set discharge time = admission time + unit discharge offset (raw columns from ``cols``)."""
        print('\tcalculating the real discharge time ...')
        adm_time_key, disch_time_key = cols[self.adm_time_col], cols[self.disch_time_col]
        admissions[disch_time_key] = admissions[adm_time_key] + admissions[disch_time_key]
        return admissions

    def set_diagnosis(self):
        """Describe diagnosis.csv; codes are normalized with ``to_standard_icd9``."""
        filename = 'diagnosis.csv'
        cols = {self.pid_col: 'diagnosisid', self.adm_id_col: 'patientunitstayid', self.cid_col: 'icd9code'}
        converter = {'diagnosisid': int, 'patientunitstayid': int, 'icd9code': EICUParser.to_standard_icd9}
        return filename, cols, converter

    @staticmethod
    def to_standard_icd9(code: str):
        """
        Normalize an eICU ``icd9code`` cell to a single ICD-9 code.

        Only the first comma-separated code is kept (whitespace stripped).
        - Numeric codes: the part before the dot is zero-padded to 3 digits.
        - E codes: kept only if the part between 'E' and the dot has 3 characters.
        - V codes: kept as-is.
        - Anything else (e.g. ICD-10 codes): returns ''.
        """
        code = str(code)
        if code == '':
            return code
        code = code.split(',')[0].strip()
        if code == '':
            return code
        c = code[0].lower()
        dot = code.find('.')
        if dot == -1:
            dot = None
        if not c.isalpha():
            prefix = code[:dot]
            # Without a dot there is no decimal part to re-append
            suffix = code[dot:] if dot is not None else ''
            if len(prefix) < 3:
                code = ('%03d' % int(prefix)) + suffix
            return code
        if c == 'e':
            if len(code[1:dot]) != 3:
                return ''
            return code
        if c == 'v':
            return code
        return ''

    def parse_diagnoses(self):
        """Parse diagnoses, then de-duplicate each admission's codes keeping first-seen order."""
        super().parse_diagnoses()
        # dict.fromkeys keeps insertion order, unlike set(), so output is reproducible
        self.admission_codes = OrderedDict(
            (adm_id, list(dict.fromkeys(codes))) for adm_id, codes in self.admission_codes.items()
        )


def codeMapping2atc4(med_pd, cid_col, rxnorm2RXCUI_file='./input/rxnorm2RXCUI.txt',
                     RXCUI2atc4_file='./input/RXCUI2atc4.csv'):
    """
    Map medication codes in ``med_pd[cid_col]`` to 4-character ATC prefixes.

    Args:
        med_pd: DataFrame of medication rows. It is modified in place (an 'RXCUI'
            column is added and rows are dropped) before a new frame is returned.
        cid_col: Name of the column holding the medication code.
        rxnorm2RXCUI_file: Text file containing a Python dict literal mapping values of
            ``cid_col`` to RXCUI values.
        RXCUI2atc4_file: CSV with RXCUI, ATC4, YEAR, MONTH and NDC columns
            (YEAR/MONTH/NDC are discarded; the first row per RXCUI is used).
    Returns:
        New DataFrame where ``cid_col`` holds the first 4 characters of the ATC4 code,
        with duplicate rows removed and a fresh RangeIndex. Rows whose code has no
        RXCUI, or whose RXCUI has no ATC4 entry, are dropped.
    """
    import ast

    with open(rxnorm2RXCUI_file, 'r') as f:
        # literal_eval only parses Python literals; unlike eval() it cannot run code
        rxnorm2RXCUI = ast.literal_eval(f.read())
    med_pd['RXCUI'] = med_pd[cid_col].map(rxnorm2RXCUI)
    med_pd.dropna(inplace=True)

    rxnorm2atc4 = pd.read_csv(RXCUI2atc4_file)
    rxnorm2atc4 = rxnorm2atc4.drop(columns=['YEAR', 'MONTH', 'NDC'])
    rxnorm2atc4.drop_duplicates(subset=['RXCUI'], inplace=True)
    # Codes mapped to an empty RXCUI string cannot be cast to int
    med_pd.drop(index=med_pd[med_pd['RXCUI'].isin([''])].index, axis=0, inplace=True)

    med_pd['RXCUI'] = med_pd['RXCUI'].astype('int64')
    med_pd = med_pd.reset_index(drop=True)
    med_pd = med_pd.merge(rxnorm2atc4, on=['RXCUI'])
    med_pd.drop(columns=[cid_col, 'RXCUI'], inplace=True)
    med_pd['ATC4'] = med_pd['ATC4'].map(lambda x: x[:4])
    med_pd = med_pd.rename(columns={'ATC4': cid_col})
    med_pd = med_pd.drop_duplicates()
    med_pd = med_pd.reset_index(drop=True)

    return med_pd


def filter_300_most_med(med_pd, cid_col, top_k=300):
    """
    Keep only rows whose ``cid_col`` value is among the ``top_k`` most frequent codes.

    Args:
        med_pd: DataFrame of medication rows.
        cid_col: Name of the column holding the medication code.
        top_k: Number of most frequent codes to keep (default 300).
    Returns:
        Filtered copy of ``med_pd`` with a fresh RangeIndex.
    """
    med_count = (med_pd.groupby(by=[cid_col]).size()
                 .reset_index()
                 .rename(columns={0: 'count'})
                 .sort_values(by=['count'], ascending=False)
                 .reset_index(drop=True))
    top_codes = med_count[cid_col].head(top_k)
    med_pd = med_pd[med_pd[cid_col].isin(top_codes)]

    return med_pd.reset_index(drop=True)


def parse_patient_info(raw_path, parsed_path):
    """
    Build per-patient demographic information for MIMIC-III.

    Args:
        raw_path: Directory containing PATIENTS.csv and ADMISSIONS.csv.
        parsed_path: Directory containing the pickled ``patient_admission.pkl``.
    Returns:
        The mapping returned by ``describe_patient_statistics`` (presumably
        SUBJECT_ID -> info dict; TODO confirm). For patients present in
        ``patient_admission``, it is updated with ethnicity, insurance, language,
        religion and marital_status from their latest admission by DISCHTIME.
        The 'DOB' and 'HADM_ID' fields are removed from every entry.
    """
    patient_info = pd.read_csv(os.path.join(raw_path, 'PATIENTS.csv'),
                               usecols=['SUBJECT_ID', 'GENDER', 'DOB'],
                               converters={'SUBJECT_ID': int, 'GENDER': str}).dropna()

    patient_info = patient_info.set_index('SUBJECT_ID').to_dict(orient='index')
    for info in patient_info.values():
        if not isinstance(info.get("DOB"), pd.Timestamp):
            info["DOB"] = pd.Timestamp(info["DOB"])

    # NOTE: pickle can execute code on load; only load files produced by this pipeline.
    with open(os.path.join(parsed_path, 'patient_admission.pkl'), 'rb') as f:
        patient_admission = pickle.load(f)
    patient_info = describe_patient_statistics(patient_info, patient_admission)

    # Parse demographic information from admission records
    demos = pd.read_csv(
        os.path.join(raw_path, 'ADMISSIONS.csv'),
        usecols=['SUBJECT_ID', 'HADM_ID', 'ETHNICITY', 'INSURANCE', 'LANGUAGE',
                 'RELIGION', 'MARITAL_STATUS', 'DISCHTIME'],
        converters={
            'SUBJECT_ID': int,
            'HADM_ID': int,
            'ETHNICITY': str,
            'INSURANCE': str,
            'LANGUAGE': str,
            'RELIGION': str,
            'MARITAL_STATUS': str
        },
        parse_dates=['DISCHTIME']
    ).sort_values(by=['SUBJECT_ID', 'DISCHTIME'], ascending=[True, True])

    # Step 1: Group each patient's admission demographics, in discharge-time order
    patient_demo = defaultdict(list)
    for _, row in demos.iterrows():
        patient_demo[row['SUBJECT_ID']].append({
            'HADM_ID': row['HADM_ID'],
            'ethnicity': row['ETHNICITY'],
            'insurance': row['INSURANCE'],
            'language': row['LANGUAGE'],
            'religion': row['RELIGION'],
            'marital_status': row['MARITAL_STATUS']
        })

    # Step 2: Keep only patients present in patient_admission
    patient_admission_ids = set(patient_admission.keys())
    patient_demo = {pid: admissions for pid, admissions in patient_demo.items()
                    if pid in patient_admission_ids}

    # Step 3: Assign the demographic information from the most recent admission
    for pid, admissions in patient_demo.items():
        patient_info[pid].update(admissions[-1])

    for info in patient_info.values():
        info.pop('DOB', None)
        info.pop('HADM_ID', None)

    return patient_info
