import pandas as pd
from functools import cmp_to_key
import numpy as np

class Event:
    """Base class for one timestamped clinical event of a hospital admission.

    Every subclass passes its own class name as ``event_type``;
    :meth:`from_dict` relies on that to map a serialized dict back to the
    right subclass.

    Events are ordered by ``str(event_time)`` throughout this class.
    NOTE(review): that only matches chronological order if every
    ``event_time`` stringifies as ISO-8601-like text; a ``None`` time becomes
    ``'None'`` and sorts after any digit-leading time — confirm acceptable.
    """

    def __init__(self, event_time, event_type):
        self.event_time = event_time
        self.event_type = event_type

    def to_dict(self):
        """Return this event's attributes as a new dict.

        List-valued attributes are converted element-wise with each item's
        own ``to_dict()``; all other values are stored unchanged.
        """
        event_dict = {}
        for attr, value in self.__dict__.items():
            if isinstance(value, list):
                event_dict[attr] = [item.to_dict() for item in value]
            else:
                event_dict[attr] = value
        return event_dict

    @staticmethod
    def compare_event_by_time(event1, event2):
        """Three-way compare two events by event time; return -1, 0 or 1."""
        return Event.compare_event_time(event1.event_time, event2.event_time)

    @staticmethod
    def compare_event_time(event_time1, event_time2):
        """Three-way compare two event times by their ``str()`` form.

        Returns:
            -1 if ``str(event_time1) < str(event_time2)``, 1 if greater, else 0.
        """
        key1, key2 = str(event_time1), str(event_time2)
        return (key1 > key2) - (key1 < key2)

    @staticmethod
    def sort_event_list(event_list):
        """Sort *event_list* in place by ``str(event_time)``.

        The sort is stable, so events with equal times keep their relative
        order. This orders exactly like :meth:`compare_event_by_time` but
        stringifies each time once instead of on every comparison.
        """
        event_list.sort(key=lambda event: str(event.event_time))

    @staticmethod
    def get_event_list(conn, hadm_id):
        """Load every event of admission *hadm_id*, sorted by time.

        Events from each per-table loader are combined first. OMR records,
        which OmrEvent fetches for the admission's patient (joined on
        subject_id), are then kept only when they fall between the earliest
        and latest of those events.

        Returns:
            A time-sorted list of events; empty if no loader returned any.
        """
        loaders = (
            AdmissionEvent.get_admission_event_list,
            DischargeEvent.get_discharge_event_list,
            TransferEvent.get_transfer_event_list,
            ServiceEvent.get_service_event_list,
            EmarEvent.get_emar_event_list,
            PrescriptionEvent.get_prescription_event_list,
            RadiologyEvent.get_radiology_event_list,
            MicrobiologyEvent.get_microbiology_event_list,
            LabEvent.get_lab_event_list,
            ProcedureEvent.get_procedure_event_list,
        )
        event_list = []
        for load in loaders:
            event_list.extend(load(conn, hadm_id))

        if not event_list:
            # No events means there is no time window to clip OMR records to;
            # return the empty list instead of indexing into it.
            return event_list

        Event.sort_event_list(event_list)
        earliest_event_time = event_list[0].event_time
        latest_event_time = event_list[-1].event_time

        omr_event_list = OmrEvent.filter_omr_event_list(
            OmrEvent.get_omr_event_list(conn, hadm_id),
            earliest_event_time,
            latest_event_time,
        )
        event_list.extend(omr_event_list)

        Event.sort_event_list(event_list)
        return event_list

    @staticmethod
    def _from_dict(event_dict, SubClass):
        """Default rebuild: pass every key of *event_dict* as a constructor kwarg.

        Subclasses holding lists of entry objects override this.
        """
        event = SubClass(
            event_time=event_dict['event_time'],
            **{k: v for k, v in event_dict.items() if k != 'event_time'}
        )
        return event

    @staticmethod
    def from_dict(event_dict):
        """Rebuild an event from a dict produced by :meth:`to_dict`.

        ``event_dict['event_type']`` names the direct Event subclass to
        instantiate. The caller's dict is left unmodified, so it can be
        reused or deserialized again.

        Raises:
            KeyError: if ``event_type`` is missing or names no direct subclass.
        """
        subclass_names = {sub_class.__name__: sub_class for sub_class in Event.__subclasses__()}
        # Shallow copy: this method and the subclass _from_dict pop keys.
        event_dict = dict(event_dict)
        SubClass = subclass_names[event_dict.pop('event_type')]
        return SubClass._from_dict(event_dict, SubClass)
        
        
class AdmissionEvent(Event):
    """The admission of one hospital stay, read from ``mimiciv_hosp.admissions``.

    Attributes:
        event_time: the row's ``admittime``.
        admission_type: the row's ``admission_type``.
        admission_location: the row's ``admission_location``.
    """

    def __init__(self, event_time, admission_type, admission_location):
        super().__init__(event_time, 'AdmissionEvent')
        self.admission_type = admission_type
        self.admission_location = admission_location

    @staticmethod
    def get_admission_event_list(conn, hadm_id):
        """Return ``[AdmissionEvent]`` for admission *hadm_id*, or ``[]`` if no row matches.

        NaN values in the fetched row are replaced with None.
        NOTE(review): hadm_id is interpolated straight into the SQL text;
        assumes it is a trusted integer id.

        Raises:
            AssertionError: if more than one admissions row matches.
        """
        query = f"""
        SELECT 
        
        hadm_id,
        admittime,
        dischtime,
        deathtime,
        admission_type,
        admit_provider_id,
        admission_location,
        discharge_location,
        insurance,
        language,
        marital_status,
        race,
        edregtime,
        edouttime,
        hospital_expire_flag
        
        FROM mimiciv_hosp.admissions
        WHERE hadm_id = {hadm_id}
        """

        rows = pd.read_sql(query, conn).replace({np.nan: None})
        row_count = len(rows)
        assert row_count <= 1, f"Expected 0 or 1 row, got {row_count}"
        if not row_count:
            return []

        row = rows.iloc[0]
        return [
            AdmissionEvent(
                event_time=row['admittime'],
                admission_type=row['admission_type'],
                admission_location=row['admission_location'],
            )
        ]
        
class DischargeEvent(Event):
    """The discharge ending one hospital stay, read from ``mimiciv_hosp.admissions``.

    Attributes:
        event_time: the row's ``dischtime``.
        discharge_location: the row's ``discharge_location``.
    """

    def __init__(self, event_time, discharge_location):
        super().__init__(event_time, 'DischargeEvent')
        self.discharge_location = discharge_location

    @staticmethod
    def get_discharge_event_list(conn, hadm_id):
        """Return ``[DischargeEvent]`` for admission *hadm_id*, or ``[]`` if no row matches.

        NaN values in the fetched row are replaced with None.
        NOTE(review): hadm_id is interpolated straight into the SQL text;
        assumes it is a trusted integer id.

        Raises:
            AssertionError: if more than one admissions row matches.
        """
        query = f"""
        SELECT 
        
        hadm_id,
        admittime,
        dischtime,
        deathtime,
        admission_type,
        admit_provider_id,
        admission_location,
        discharge_location,
        insurance,
        language,
        marital_status,
        race,
        edregtime,
        edouttime,
        hospital_expire_flag
        
        FROM mimiciv_hosp.admissions
        WHERE hadm_id = {hadm_id}
        """

        admissions = pd.read_sql(query, conn).replace({np.nan: None})
        row_count = len(admissions)
        assert row_count <= 1, f"Expected 0 or 1 row, got {row_count}"
        if not row_count:
            return []

        admission_row = admissions.iloc[0]
        return [
            DischargeEvent(
                event_time=admission_row['dischtime'],
                discharge_location=admission_row['discharge_location'],
            )
        ]

class TransferEvent(Event):
    """A care-unit transfer, read from ``mimiciv_hosp.transfers``.

    Attributes:
        event_time: the row's ``intime``.
        transfer_location: the row's ``careunit``.
    """

    def __init__(self,
                 event_time,
                 transfer_location,
    ):
        super().__init__(event_time, 'TransferEvent')
        self.transfer_location = transfer_location

    @staticmethod
    def get_transfer_event_list(conn, hadm_id):
        """Return one TransferEvent per transfers row of admission *hadm_id*.

        Rows keep the order returned by the query. NaN values are replaced
        with None, as in every other loader of this module.
        """
        query = f"""
        SELECT 
        
        transfer_id,
        eventtype,
        careunit,
        intime,
        outtime
        
        FROM mimiciv_hosp.transfers
        WHERE hadm_id = {hadm_id}
        """

        df = pd.read_sql(query, conn)
        # Normalize missing values to None like the other loaders, so a missing
        # careunit is None here too rather than float NaN.
        df = df.replace({np.nan: None})
        return [
            TransferEvent(
                event_time=transfer_event_info['intime'],
                transfer_location=transfer_event_info['careunit'],
            )
            for _, transfer_event_info in df.iterrows()
        ]
    
class ServiceEvent(Event):
    """A change of hospital service, read from ``mimiciv_hosp.services``.

    Attributes:
        event_time: the row's ``transfertime``.
        prev_service: the service before the change (the row's ``prev_service``).
        curr_service: the service after the change (the row's ``curr_service``).
    """

    def __init__(self, event_time, prev_service, curr_service):
        super().__init__(event_time, 'ServiceEvent')
        self.prev_service = prev_service
        self.curr_service = curr_service

    @staticmethod
    def get_service_event_list(conn, hadm_id):
        """Return one ServiceEvent per services row of admission *hadm_id*.

        Rows keep query order; NaN values are replaced with None.
        """
        query = f"""
        SELECT 
        
        transfertime,
        prev_service,
        curr_service
        
        FROM mimiciv_hosp.services
        WHERE hadm_id = {hadm_id}
        """

        services = pd.read_sql(query, conn).replace({np.nan: None})
        return [
            ServiceEvent(
                event_time=service['transfertime'],
                prev_service=service['prev_service'],
                curr_service=service['curr_service'],
            )
            for _, service in services.iterrows()
        ]
    
class OmrEvent(Event):
    """One Online Medical Record row from ``mimiciv_hosp.omr``.

    Attributes:
        event_time: the row's ``chartdate``.
        seq_num: the row's ``seq_num``.
        result_name: the row's ``result_name``.
        result_value: the row's ``result_value``.
    """

    def __init__(self, event_time, seq_num, result_name, result_value):
        super().__init__(event_time, 'OmrEvent')
        self.seq_num = seq_num
        self.result_name = result_name
        self.result_value = result_value

    @staticmethod
    def get_omr_event_list(conn, hadm_id):
        """Return an OmrEvent for every OMR row of the patient of admission *hadm_id*.

        The OMR table is joined to admissions on ``subject_id``, so rows from
        outside this admission's time span are included too. Callers narrow
        them with :meth:`filter_omr_event_list`. NaN values become None.
        """
        query = f"""
        SELECT 
        
        chartdate,
        seq_num,
        result_name,
        result_value
        
        FROM mimiciv_hosp.omr
        JOIN mimiciv_hosp.admissions
        ON mimiciv_hosp.omr.subject_id = mimiciv_hosp.admissions.subject_id
        WHERE mimiciv_hosp.admissions.hadm_id = {hadm_id}
        """

        records = pd.read_sql(query, conn).replace({np.nan: None})
        return [
            OmrEvent(
                event_time=record['chartdate'],
                seq_num=record['seq_num'],
                result_name=record['result_name'],
                result_value=record['result_value'],
            )
            for _, record in records.iterrows()
        ]

    @staticmethod
    def filter_omr_event_list(omr_event_list, earliest_event_time, latest_event_time):
        """Return the OMR events whose time is within [earliest, latest], inclusive.

        Times are compared with ``Event.compare_event_time``, i.e. by ``str()``.
        NOTE(review): if chartdate arrives as a bare 'YYYY-MM-DD' while the
        bounds are full timestamps, a record on the first day compares lower
        than the lower bound and is dropped. A record on the last day is kept,
        because the date string is a prefix of the timestamp string. Confirm
        this is intended.
        """
        def _in_window(omr_event):
            after_start = Event.compare_event_time(omr_event.event_time, earliest_event_time) >= 0
            return after_start and Event.compare_event_time(omr_event.event_time, latest_event_time) <= 0

        return [omr_event for omr_event in omr_event_list if _in_window(omr_event)]

class EmarEntry:
    """One medication-administration row (from ``mimiciv_hosp.emar``) inside an EmarEvent."""

    def __init__(self,
                 emar_id,
                 emar_seq,
                 medication,
                 event_txt,
    ):
        self.emar_id = emar_id
        self.emar_seq = emar_seq
        self.medication = medication
        self.event_txt = event_txt

    def to_dict(self):
        """Return a shallow copy of this entry's attributes.

        The result holds the constructor's keyword names, so
        ``EmarEntry(**entry.to_dict())`` round-trips. A copy is returned so
        that edits to the dict (e.g. during serialization) cannot change the
        entry itself.
        """
        return dict(self.__dict__)
        
        

class EmarEvent(Event):
    """All eMAR administration rows of one admission sharing the same ``charttime``.

    Each row becomes an :class:`EmarEntry` in ``emar_list``.
    """

    def __init__(self,
                 event_time,
    ):
        super().__init__(event_time, 'EmarEvent')
        self.emar_list = []

    @staticmethod
    def get_emar_event_list(conn, hadm_id):
        """Return EmarEvents for *hadm_id*, one per distinct ``charttime``.

        Events appear in order of first occurrence of their charttime in the
        query result; NaN values are replaced with None.
        """
        query = f"""
        SELECT 
        
        emar_id,
        emar_seq,
        poe_id,
        pharmacy_id,
        enter_provider_id,
        charttime,
        medication,
        event_txt,
        scheduletime,
        storetime
        
        FROM mimiciv_hosp.emar
        WHERE hadm_id = {hadm_id}
        """

        df = pd.read_sql(query, conn)
        df = df.replace({np.nan: None})
        emar_event_dict = {}
        for _, emar_event_info in df.iterrows():
            charttime = emar_event_info['charttime']
            emar_event = emar_event_dict.get(charttime)
            if emar_event is None:
                emar_event = emar_event_dict[charttime] = EmarEvent(event_time=charttime)
            emar_event.emar_list.append(
                EmarEntry(
                    emar_id=emar_event_info['emar_id'],
                    emar_seq=emar_event_info['emar_seq'],
                    medication=emar_event_info['medication'],
                    event_txt=emar_event_info['event_txt'],
                )
            )
        return list(emar_event_dict.values())

    @staticmethod
    def _from_dict(event_dict, SubClass):
        """Rebuild an EmarEvent from its ``to_dict()`` form (without ``event_type``).

        Works on a shallow copy, so *event_dict* itself is left unchanged.
        """
        fields = dict(event_dict)
        emar_list = fields.pop('emar_list')
        event = SubClass(**fields)
        event.emar_list = [EmarEntry(**entry) for entry in emar_list]
        return event
        
    
class PrescriptionEntry:
    """One ``mimiciv_hosp.prescriptions`` row inside a PrescriptionEvent."""

    def __init__(self,
                 starttime,
                 stoptime,
                 drug_type,
                 drug,
                 formulary_drug_cd,
                 gsn,
                 ndc,
                 prod_strength,
                 form_rx,
                 dose_val_rx,
                 dose_unit_rx,
                 form_val_disp,
                 form_unit_disp,
                 doses_per_24_hrs,
                 route,
    ):
        self.starttime = starttime
        self.stoptime = stoptime
        self.drug_type = drug_type
        self.drug = drug
        self.formulary_drug_cd = formulary_drug_cd
        self.gsn = gsn
        self.ndc = ndc
        self.prod_strength = prod_strength
        self.form_rx = form_rx
        self.dose_val_rx = dose_val_rx
        self.dose_unit_rx = dose_unit_rx
        self.form_val_disp = form_val_disp
        self.form_unit_disp = form_unit_disp
        self.doses_per_24_hrs = doses_per_24_hrs
        self.route = route

    def to_dict(self):
        """Return a shallow copy of this entry's attributes.

        The result holds the constructor's keyword names, so
        ``PrescriptionEntry(**entry.to_dict())`` round-trips. A copy is
        returned so that edits to the dict cannot change the entry itself.
        """
        return dict(self.__dict__)
    
class PrescriptionEvent(Event):
    """All prescriptions of one admission sharing the same ``starttime``.

    Each ``mimiciv_hosp.prescriptions`` row becomes a
    :class:`PrescriptionEntry` in ``prescription_list``.
    """

    def __init__(self,
                 event_time,
    ):
        super().__init__(event_time, 'PrescriptionEvent')
        self.prescription_list = []

    @staticmethod
    def get_prescription_event_list(conn, hadm_id):
        """Return PrescriptionEvents for *hadm_id*, one per distinct ``starttime``.

        Events appear in order of first occurrence of their starttime in the
        query result; NaN values are replaced with None.
        """
        query = f"""
        SELECT
        
        pharmacy_id,
        poe_id,
        poe_seq,
        order_provider_id,
        starttime,
        stoptime,
        drug_type,
        drug,
        formulary_drug_cd,
        gsn,
        ndc,
        prod_strength,
        form_rx,
        dose_val_rx,
        dose_unit_rx,
        form_val_disp,
        form_unit_disp,
        doses_per_24_hrs,
        route
        
        FROM mimiciv_hosp.prescriptions
        WHERE hadm_id = {hadm_id}
        """

        df = pd.read_sql(query, conn)
        df = df.replace({np.nan: None})
        prescription_event_dict = {}
        for _, prescription_event_info in df.iterrows():
            starttime = prescription_event_info['starttime']
            prescription_event = prescription_event_dict.get(starttime)
            if prescription_event is None:
                prescription_event = prescription_event_dict[starttime] = PrescriptionEvent(
                    event_time=starttime,
                )
            prescription_event.prescription_list.append(
                PrescriptionEntry(
                    starttime=prescription_event_info['starttime'],
                    stoptime=prescription_event_info['stoptime'],
                    drug_type=prescription_event_info['drug_type'],
                    drug=prescription_event_info['drug'],
                    formulary_drug_cd=prescription_event_info['formulary_drug_cd'],
                    gsn=prescription_event_info['gsn'],
                    ndc=prescription_event_info['ndc'],
                    prod_strength=prescription_event_info['prod_strength'],
                    form_rx=prescription_event_info['form_rx'],
                    dose_val_rx=prescription_event_info['dose_val_rx'],
                    dose_unit_rx=prescription_event_info['dose_unit_rx'],
                    form_val_disp=prescription_event_info['form_val_disp'],
                    form_unit_disp=prescription_event_info['form_unit_disp'],
                    doses_per_24_hrs=prescription_event_info['doses_per_24_hrs'],
                    route=prescription_event_info['route'],
                )
            )
        return list(prescription_event_dict.values())

    @staticmethod
    def _from_dict(event_dict, SubClass):
        """Rebuild a PrescriptionEvent from its ``to_dict()`` form (without ``event_type``).

        Works on a shallow copy, so *event_dict* itself is left unchanged.
        """
        fields = dict(event_dict)
        prescription_list = fields.pop('prescription_list')
        event = SubClass(**fields)
        event.prescription_list = [PrescriptionEntry(**entry) for entry in prescription_list]
        return event
    
class RadiologyEvent(Event):
    """One radiology report, read from ``mimiciv_note.radiology``.

    ``modality`` is optional. It becomes an attribute, and therefore appears
    in ``to_dict()``, only when passed non-None here or set via
    :meth:`add_modality`.
    """

    def __init__(self, event_time, note_id, note_type, note_seq, text, modality=None):
        super().__init__(event_time, 'RadiologyEvent')
        self.note_id = note_id
        self.note_type = note_type
        self.note_seq = note_seq
        self.text = text
        if modality is not None:
            self.modality = modality

    @staticmethod
    def get_radiology_event_list(conn, hadm_id):
        """Return one RadiologyEvent per radiology note of admission *hadm_id*.

        event_time is the note's ``charttime``. Notes keep query order, and
        NaN values are replaced with None. No modality is set.
        """
        query = f"""
        SELECT 
        
        note_id,
        note_type,
        note_seq,
        charttime,
        storetime,
        text
        
        FROM mimiciv_note.radiology
        WHERE hadm_id = {hadm_id}
        """

        notes = pd.read_sql(query, conn).replace({np.nan: None})
        return [
            RadiologyEvent(
                event_time=note['charttime'],
                note_id=note['note_id'],
                note_type=note['note_type'],
                note_seq=note['note_seq'],
                text=note['text'],
            )
            for _, note in notes.iterrows()
        ]

    def add_modality(self, modality):
        """Attach (or overwrite) the imaging modality of this report."""
        self.modality = modality
    
class MicrobiologyEntry:
    """One ``mimiciv_hosp.microbiologyevents`` row inside a MicrobiologyEvent."""

    def __init__(self,
                 micro_event_id,
                 test_seq,
                 org_itemid,
                 org_name,
                 isolate_num,
                 quantity,
                 ab_itemid,
                 ab_name,
                 dilution_text,
                 dilution_comparison,
                 dilution_value,
                 interpretation,
                 comments,
    ):
        self.micro_event_id = micro_event_id
        self.test_seq = test_seq
        self.org_itemid = org_itemid
        self.org_name = org_name
        self.isolate_num = isolate_num
        self.quantity = quantity
        self.ab_itemid = ab_itemid
        self.ab_name = ab_name
        self.dilution_text = dilution_text
        self.dilution_comparison = dilution_comparison
        self.dilution_value = dilution_value
        self.interpretation = interpretation
        self.comments = comments

    def to_dict(self):
        """Return a shallow copy of this entry's attributes.

        The result holds the constructor's keyword names, so
        ``MicrobiologyEntry(**entry.to_dict())`` round-trips. A copy is
        returned so that edits to the dict cannot change the entry itself.
        """
        return dict(self.__dict__)
    
class MicrobiologyEvent(Event):
    """Microbiology rows of one admission sharing (charttime, specimen type, test name).

    Each ``mimiciv_hosp.microbiologyevents`` row becomes a
    :class:`MicrobiologyEntry` in ``microbiology_list``.
    """

    def __init__(self,
                 event_time,
                 spec_type_desc,
                 test_name,
    ):
        super().__init__(event_time, 'MicrobiologyEvent')
        self.spec_type_desc = spec_type_desc
        self.test_name = test_name
        self.microbiology_list = []

    @staticmethod
    def get_microbiology_event_list(conn, hadm_id):
        """Return MicrobiologyEvents for *hadm_id*, grouped by (charttime, spec_type_desc, test_name).

        Events appear in order of first occurrence of their group key in the
        query result; NaN values are replaced with None.
        """
        query = f"""
        SELECT
        
        microevent_id,
        micro_specimen_id,
        order_provider_id,
        chartdate,
        charttime,
        spec_itemid,
        spec_type_desc,
        test_seq,
        storedate,
        storetime,
        test_itemid,
        test_name,
        org_itemid,
        org_name,
        isolate_num,
        quantity,
        ab_itemid,
        ab_name,
        dilution_text,
        dilution_comparison,
        dilution_value,
        interpretation,
        comments
        
        FROM mimiciv_hosp.microbiologyevents
        WHERE hadm_id = {hadm_id}
        """

        df = pd.read_sql(query, conn)
        df = df.replace({np.nan: None})
        microbiology_event_dict = {}
        for _, microbiology_event_info in df.iterrows():
            handle = (
                microbiology_event_info['charttime'],
                microbiology_event_info['spec_type_desc'],
                microbiology_event_info['test_name'],
            )
            microbiology_event = microbiology_event_dict.get(handle)
            if microbiology_event is None:
                microbiology_event = microbiology_event_dict[handle] = MicrobiologyEvent(
                    event_time=microbiology_event_info['charttime'],
                    spec_type_desc=microbiology_event_info['spec_type_desc'],
                    test_name=microbiology_event_info['test_name'],
                )
            microbiology_event.microbiology_list.append(
                MicrobiologyEntry(
                    micro_event_id=microbiology_event_info['microevent_id'],
                    test_seq=microbiology_event_info['test_seq'],
                    org_itemid=microbiology_event_info['org_itemid'],
                    org_name=microbiology_event_info['org_name'],
                    isolate_num=microbiology_event_info['isolate_num'],
                    quantity=microbiology_event_info['quantity'],
                    ab_itemid=microbiology_event_info['ab_itemid'],
                    ab_name=microbiology_event_info['ab_name'],
                    dilution_text=microbiology_event_info['dilution_text'],
                    dilution_comparison=microbiology_event_info['dilution_comparison'],
                    dilution_value=microbiology_event_info['dilution_value'],
                    interpretation=microbiology_event_info['interpretation'],
                    comments=microbiology_event_info['comments'],
                )
            )
        return list(microbiology_event_dict.values())

    @staticmethod
    def _from_dict(event_dict, SubClass):
        """Rebuild a MicrobiologyEvent from its ``to_dict()`` form (without ``event_type``).

        Works on a shallow copy, so *event_dict* itself is left unchanged.
        """
        fields = dict(event_dict)
        microbiology_list = fields.pop('microbiology_list')
        event = SubClass(**fields)
        event.microbiology_list = [MicrobiologyEntry(**entry) for entry in microbiology_list]
        return event
    
class LabEntry:
    """One ``mimiciv_hosp.labevents`` row (plus its d_labitems label) inside a LabEvent."""

    def __init__(self,
                 lab_event_id,
                 itemid,
                 value,
                 valuenum,
                 valueuom,
                 ref_range_lower,
                 ref_range_upper,
                 flag,
                 priority,
                 comments,
                 label,
    ):
        self.lab_event_id = lab_event_id
        self.itemid = itemid
        self.value = value
        self.valuenum = valuenum
        self.valueuom = valueuom
        self.ref_range_lower = ref_range_lower
        self.ref_range_upper = ref_range_upper
        self.flag = flag
        self.priority = priority
        self.comments = comments
        self.label = label

    def to_dict(self):
        """Return a shallow copy of this entry's attributes.

        The result holds the constructor's keyword names, so
        ``LabEntry(**entry.to_dict())`` round-trips. A copy is returned so
        that edits to the dict cannot change the entry itself.
        """
        return dict(self.__dict__)
    
class LabEvent(Event):
    """Lab results of one admission sharing (charttime, specimen_id, fluid, category).

    Each ``mimiciv_hosp.labevents`` row (joined with ``d_labitems``) becomes a
    :class:`LabEntry` in ``lab_list``.
    """

    def __init__(self,
                 event_time,
                 specimen_id,
                 fluid,
                 category,
    ):
        super().__init__(event_time, 'LabEvent')
        self.specimen_id = specimen_id
        self.fluid = fluid
        self.category = category
        self.lab_list = []

    @staticmethod
    def get_lab_event_list(conn, hadm_id):
        """Return LabEvents for *hadm_id*, grouped by (charttime, specimen_id, fluid, category).

        Events appear in order of first occurrence of their group key in the
        query result; NaN values are replaced with None.
        """
        query = f"""
        SELECT
        
        labevent_id,
        specimen_id,
        labevents.itemid as itemid,
        order_provider_id,
        charttime,
        storetime,
        value,
        valuenum,
        valueuom,
        ref_range_lower,
        ref_range_upper,
        flag,
        priority,
        comments,
        label,
        fluid,
        category
        
        FROM mimiciv_hosp.labevents
        JOIN mimiciv_hosp.d_labitems
        ON labevents.itemid = d_labitems.itemid
        WHERE hadm_id = {hadm_id}
        """

        df = pd.read_sql(query, conn)
        df = df.replace({np.nan: None})
        lab_event_dict = {}
        for _, lab_event_info in df.iterrows():
            handle = (
                lab_event_info['charttime'],
                lab_event_info['specimen_id'],
                lab_event_info['fluid'],
                lab_event_info['category'],
            )
            lab_event = lab_event_dict.get(handle)
            if lab_event is None:
                lab_event = lab_event_dict[handle] = LabEvent(
                    event_time=lab_event_info['charttime'],
                    specimen_id=lab_event_info['specimen_id'],
                    fluid=lab_event_info['fluid'],
                    category=lab_event_info['category'],
                )
            lab_event.lab_list.append(
                LabEntry(
                    lab_event_id=lab_event_info['labevent_id'],
                    itemid=lab_event_info['itemid'],
                    value=lab_event_info['value'],
                    valuenum=lab_event_info['valuenum'],
                    valueuom=lab_event_info['valueuom'],
                    ref_range_lower=lab_event_info['ref_range_lower'],
                    ref_range_upper=lab_event_info['ref_range_upper'],
                    flag=lab_event_info['flag'],
                    priority=lab_event_info['priority'],
                    comments=lab_event_info['comments'],
                    label=lab_event_info['label'],
                )
            )
        return list(lab_event_dict.values())

    @staticmethod
    def _from_dict(event_dict, SubClass):
        """Rebuild a LabEvent from its ``to_dict()`` form (without ``event_type``).

        Works on a shallow copy, so *event_dict* itself is left unchanged.
        """
        fields = dict(event_dict)
        lab_list = fields.pop('lab_list')
        event = SubClass(**fields)
        event.lab_list = [LabEntry(**entry) for entry in lab_list]
        return event
    
class ProcedureEvent(Event):
    """One billed ICD procedure of an admission, from ``mimiciv_hosp.procedures_icd``.

    Attributes:
        event_time: the row's ``chartdate``.
        seq_num: the row's ``seq_num``.
        icd_code, icd_version: the procedure code and its ICD version.
        long_title: the code's title from ``d_icd_procedures``.
    """

    def __init__(self,
                 event_time,
                 seq_num,
                 icd_code,
                 icd_version,
                 long_title,
    ):
        super().__init__(event_time, 'ProcedureEvent')
        self.seq_num = seq_num
        self.icd_code = icd_code
        self.icd_version = icd_version
        self.long_title = long_title

    @staticmethod
    def get_procedure_event_list(conn, hadm_id):
        """Return one ProcedureEvent per procedure row of admission *hadm_id*.

        Rows keep query order; NaN values are replaced with None.
        """
        # Both tables carry icd_code AND icd_version (hence the qualified
        # column names below). Joining on both prevents a code from matching a
        # dictionary row of the other ICD version, which would give duplicate
        # or mismatched titles.
        query = f"""
        SELECT
        
        seq_num,
        chartdate,
        procedures_icd.icd_code,
        procedures_icd.icd_version,
        long_title
        
        FROM mimiciv_hosp.procedures_icd
        JOIN mimiciv_hosp.d_icd_procedures
        ON procedures_icd.icd_code = d_icd_procedures.icd_code
        AND procedures_icd.icd_version = d_icd_procedures.icd_version
        WHERE hadm_id = {hadm_id}
        """

        df = pd.read_sql(query, conn)
        df = df.replace({np.nan: None})
        return [
            ProcedureEvent(
                event_time=procedure_event_info['chartdate'],
                seq_num=procedure_event_info['seq_num'],
                icd_code=procedure_event_info['icd_code'],
                icd_version=procedure_event_info['icd_version'],
                long_title=procedure_event_info['long_title'],
            )
            for _, procedure_event_info in df.iterrows()
        ]