import pandas as pd
import dill
import numpy as np

def med_process(med_file):
    """
    Load the MIMIC prescription table and reduce it to the columns used later.

    :param med_file: prescription med file from MIMIC (anything ``pd.read_csv`` accepts)
    :return: preliminary processed med file: unnecessary columns, rows whose
        NDC is "0", rows that still contain missing values and duplicate rows
        are removed; rows are sorted by SUBJECT_ID, HADM_ID, ICUSTAY_ID and
        STARTDATE before ICUSTAY_ID itself is dropped, and the index is 0..n-1
    """
    # NDC is read as a category so the codes stay strings (leading zeros kept).
    med_pd = pd.read_csv(med_file, dtype={"NDC": "category"})
    med_pd = med_pd.drop(
        columns=[
            "ROW_ID",
            "DRUG_TYPE",
            "DRUG_NAME_POE",
            "DRUG_NAME_GENERIC",
            "FORMULARY_DRUG_CD",
            "PROD_STRENGTH",
            "DOSE_VAL_RX",
            "DOSE_UNIT_RX",
            "FORM_VAL_DISP",
            "FORM_UNIT_DISP",
            "GSN",
            "ROUTE",
            "ENDDATE",
        ]
    )
    med_pd = med_pd.drop(index=med_pd.index[med_pd["NDC"] == "0"])

    # Forward-fill remaining gaps from the previous row. DataFrame.ffill() is
    # the supported spelling; fillna(method="pad") is deprecated in pandas 2.x.
    # The fill runs in file order (before sorting, without grouping), so a gap
    # can be padded from a row that belongs to a different patient.
    med_pd = med_pd.ffill()
    # Only rows with a gap before any earlier value exists can still hold NaN.
    med_pd = med_pd.dropna()
    med_pd = med_pd.drop_duplicates()

    med_pd["ICUSTAY_ID"] = med_pd["ICUSTAY_ID"].astype("int64")
    med_pd["STARTDATE"] = pd.to_datetime(
        med_pd["STARTDATE"], format="%Y-%m-%d %H:%M:%S"
    )
    med_pd = med_pd.sort_values(
        by=["SUBJECT_ID", "HADM_ID", "ICUSTAY_ID", "STARTDATE"]
    )
    med_pd = med_pd.reset_index(drop=True)

    # ICUSTAY_ID was only needed for ordering; dropping it can expose new
    # duplicates, hence the second de-duplication.
    med_pd = med_pd.drop(columns=["ICUSTAY_ID"])
    med_pd = med_pd.drop_duplicates()
    med_pd = med_pd.reset_index(drop=True)

    return med_pd


def filter_most_med(med_pd):
    """
    Keep only the rows whose ATC4 code is among the most frequent codes.

    :param med_pd: med file with an "ATC4" column
    :return: filtered copy of med_pd with a fresh 0..n-1 index
    """
    # One row per ATC4 code with the number of rows that carry it.
    per_code = med_pd.groupby(by=["ATC4"]).size().reset_index()
    per_code = per_code.rename(columns={0: "count"})
    ranked = per_code.sort_values(by=["count"], ascending=False)
    ranked = ranked.reset_index(drop=True)

    # .loc slicing is label-inclusive: labels 0..120 select 121 codes.
    frequent_codes = ranked.loc[:120, "ATC4"]
    keep_mask = med_pd["ATC4"].isin(frequent_codes)

    return med_pd[keep_mask].reset_index(drop=True)


def codeMapping2atc4(med_pd):
    """
    Translate the NDC codes of the med file into ATC4 codes (NDC -> RXCUI -> ATC4).

    Reads the module-level paths ``ndc2RXCUI_file`` and ``RXCUI2atc4_file``.
    The frame passed in is modified in place (an "RXCUI" column is added and
    unmapped rows are removed) before the merged result is built.

    :param med_pd: preliminary processed med file with code in NDC format
    :return: med file with code in ATC4 format
    """
    # NOTE(review): eval() executes whatever the file contains; only run this
    # on a trusted ndc2RXCUI file.
    with open(ndc2RXCUI_file, "r") as handle:
        ndc_to_rxcui = eval(handle.read())
    med_pd["RXCUI"] = med_pd["NDC"].map(ndc_to_rxcui)
    med_pd.dropna(inplace=True)

    # One ATC4 row per RXCUI: the first occurrence in the csv wins.
    rxcui_to_atc4 = pd.read_csv(RXCUI2atc4_file).drop(
        columns=["YEAR", "MONTH", "NDC"]
    )
    rxcui_to_atc4.drop_duplicates(subset=["RXCUI"], inplace=True)

    # NDC codes that map to an empty string have no usable RXCUI.
    blank_rows = med_pd.index[med_pd["RXCUI"].isin([""])]
    med_pd.drop(index=blank_rows, inplace=True)
    med_pd["RXCUI"] = med_pd["RXCUI"].astype("int64")

    mapped = med_pd.reset_index(drop=True).merge(rxcui_to_atc4, on=["RXCUI"])
    mapped = mapped.drop(columns=["NDC", "RXCUI"]).drop_duplicates()
    return mapped.reset_index(drop=True)


def process_visit_lg2(med_pd):
    """
    Find the patients that have more than one distinct admission.

    :param med_pd: med file with "SUBJECT_ID" and "HADM_ID" columns
    :return: one row per qualifying SUBJECT_ID with its unique HADM_ID values
        and their number in "HADM_ID_Len" (index not reset)
    """
    visits = med_pd[["SUBJECT_ID", "HADM_ID"]]
    per_patient = visits.groupby(by="SUBJECT_ID")["HADM_ID"].unique().reset_index()
    per_patient["HADM_ID_Len"] = per_patient["HADM_ID"].map(len)
    has_several = per_patient["HADM_ID_Len"] > 1
    return per_patient[has_several]


def diag_process(diag_file):
    """
    Load the diagnosis table and keep only the most frequent ICD9 codes.

    :param diag_file: diagnosis file from MIMIC (anything ``pd.read_csv`` accepts)
    :return: rows of (SUBJECT_ID, HADM_ID, ICD9_CODE) without missing values
        or duplicates, sorted by SUBJECT_ID and HADM_ID, restricted to the
        2000 most frequent codes
    """
    cleaned = (
        pd.read_csv(diag_file)
        .dropna()
        .drop(columns=["SEQ_NUM", "ROW_ID"])
        .drop_duplicates()
        .sort_values(by=["SUBJECT_ID", "HADM_ID"])
        .reset_index(drop=True)
    )

    # Rank the codes by how many rows carry them, most frequent first.
    code_counts = cleaned.groupby(by=["ICD9_CODE"]).size().reset_index()
    code_counts = code_counts.rename(columns={0: "count"})
    ranked = code_counts.sort_values(by=["count"], ascending=False)
    ranked = ranked.reset_index(drop=True)

    # Labels 0..1999 inclusive -> the 2000 most frequent codes.
    top_codes = ranked.loc[:1999, "ICD9_CODE"]
    kept = cleaned[cleaned["ICD9_CODE"].isin(top_codes)]
    return kept.reset_index(drop=True)


def procedure_process(procedure_file):
    """
    Load the procedure table and order it by patient, admission and sequence.

    :param procedure_file: procedure file from MIMIC (anything ``pd.read_csv`` accepts)
    :return: de-duplicated rows without ROW_ID and SEQ_NUM, in
        SUBJECT_ID / HADM_ID / SEQ_NUM order, with a fresh 0..n-1 index
    """
    # ICD9_CODE is read as a category so the codes stay strings.
    procedures = pd.read_csv(procedure_file, dtype={"ICD9_CODE": "category"})
    procedures = procedures.drop(columns=["ROW_ID"]).drop_duplicates()

    # SEQ_NUM is only used for ordering; removing it can expose new duplicates.
    ordered = procedures.sort_values(by=["SUBJECT_ID", "HADM_ID", "SEQ_NUM"])
    ordered = ordered.drop(columns=["SEQ_NUM"]).drop_duplicates()

    return ordered.reset_index(drop=True)


def combine_process(med_pd, diag_pd, pro_pd):
    """
    Join medications, diagnoses and procedures into one row per admission.

    Only admissions (SUBJECT_ID, HADM_ID) present in all three inputs are kept.

    :param med_pd: med file with an "ATC4" column
    :param diag_pd: diagnosis file with an "ICD9_CODE" column
    :param pro_pd: procedure file with an "ICD9_CODE" column
    :return: DataFrame with SUBJECT_ID, HADM_ID, ICD9_CODE (unique diagnosis
        codes), ATC4 (list), PRO_CODE (list) and ATC4_num (length of ATC4)
    """
    visit_key = ["SUBJECT_ID", "HADM_ID"]

    # Admissions shared by all three tables.
    shared_visits = med_pd[visit_key].drop_duplicates()
    for other in (diag_pd, pro_pd):
        other_visits = other[visit_key].drop_duplicates()
        shared_visits = shared_visits.merge(other_visits, on=visit_key, how="inner")

    diag_kept = diag_pd.merge(shared_visits, on=visit_key, how="inner")
    med_kept = med_pd.merge(shared_visits, on=visit_key, how="inner")
    pro_kept = pro_pd.merge(shared_visits, on=visit_key, how="inner")

    # Collapse each table to one row per admission holding its unique codes.
    diag_codes = diag_kept.groupby(by=visit_key)["ICD9_CODE"].unique().reset_index()
    med_codes = med_kept.groupby(by=visit_key)["ATC4"].unique().reset_index()
    pro_codes = pro_kept.groupby(by=visit_key)["ICD9_CODE"].unique().reset_index()
    pro_codes = pro_codes.rename(columns={"ICD9_CODE": "PRO_CODE"})

    # Medication and procedure codes become plain lists; diagnosis codes are
    # left as returned by unique().
    med_codes["ATC4"] = med_codes["ATC4"].map(list)
    pro_codes["PRO_CODE"] = pro_codes["PRO_CODE"].map(list)

    data = diag_codes.merge(med_codes, on=visit_key, how="inner")
    data = data.merge(pro_codes, on=visit_key, how="inner")
    data["ATC4_num"] = data["ATC4"].map(len)

    return data


# get vocabulary
class Voc(object):
    """Two-way mapping between words and consecutive integer ids (0, 1, 2, ...)."""

    def __init__(self):
        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        """
        Register every word of ``sentence`` that is not known yet.

        :param sentence: iterable of hashable words
        """
        for token in sentence:
            if token in self.word2idx:
                continue
            # The next free id equals the number of words registered so far.
            next_id = len(self.word2idx)
            self.idx2word[next_id] = token
            self.word2idx[token] = next_id


def ATC_process(med_voc):
    """
    Build ATC4 -> ATC5 and ATC4 -> SMILES lookups for the medication vocabulary
    and pickle them to './output/ATC4_mappings.pkl'.

    Reads './input/ATC_CID.xlsx' (columns ATC4, ATC5, CID are used) and
    './input/DrugInfo.csv' (columns CID, isosmiles are used).

    :param med_voc: Voc whose word2idx keys are the ATC4 codes of interest
    :return: set of ATC4 codes that have both an ATC5 and a SMILES entry
    """
    all_atc4 = set(med_voc.word2idx.keys())
    atc_cid_df = pd.read_excel('./input/ATC_CID.xlsx')
    drug_info_df = pd.read_csv('./input/DrugInfo.csv')
    # Only the first five characters of the code are compared with the vocabulary.
    atc_cid_df['ATC4'] = atc_cid_df['ATC4'].str[:5]

    # Rows whose CID is -1 are excluded (presumably "no compound id" — confirm).
    filtered_atc_cid = atc_cid_df[(atc_cid_df['ATC4'].isin(all_atc4)) & (atc_cid_df['CID'] != -1)]
    atc4_to_atc5_mapping = filtered_atc_cid.groupby('ATC4')['ATC5'].apply(list).to_dict()
    merged_df = filtered_atc_cid.merge(drug_info_df, on='CID', how='inner')
    atc4_to_smiles_mapping = merged_df.groupby('ATC4')['isosmiles'].apply(list).to_dict()

    valid_atc4_keys = set(atc4_to_atc5_mapping.keys()) & set(atc4_to_smiles_mapping.keys())
    filtered_atc4_to_atc5_mapping = {k: atc4_to_atc5_mapping[k] for k in valid_atc4_keys}
    filtered_atc4_to_smiles_mapping = {k: atc4_to_smiles_mapping[k] for k in valid_atc4_keys}

    # Persist the filtered mappings so that both dictionaries hold exactly the
    # keys returned to the caller. Previously the unfiltered dictionaries were
    # written and the filtered ones were computed but never used.
    final_mapping = {
        'ATC4_to_ATC5': filtered_atc4_to_atc5_mapping,
        'ATC4_to_SMILES': filtered_atc4_to_smiles_mapping
    }
    with open('./output/ATC4_mappings.pkl', 'wb') as f:
        dill.dump(final_mapping, f)
    print("get ATC4 mapping done")
    return valid_atc4_keys


def create_str_token_mapping(df):
    """
    Build the diagnosis, medication and procedure vocabularies and pickle them
    to the module-level path ``vocabulary_file``.

    The medication vocabulary is rebuilt from only the ATC4 codes that
    ``ATC_process`` reports as valid, so its ids are renumbered.

    :param df: combined data with list-like "ICD9_CODE", "ATC4" and "PRO_CODE" columns
    :return: (diag_voc, med_voc, pro_voc, valid_atc4_keys)
    """
    diag_voc = Voc()
    med_voc = Voc()
    pro_voc = Voc()
    # Each vocabulary only depends on its own column, so the columns can be
    # walked one after another; ids come out in the same row order as before.
    for column, voc in (
        ("ICD9_CODE", diag_voc),
        ("ATC4", med_voc),
        ("PRO_CODE", pro_voc),
    ):
        for codes in df[column]:
            voc.add_sentence(codes)

    valid_atc4_keys = ATC_process(med_voc)
    med_voc = filter_voc_by_valid_keys(med_voc, valid_atc4_keys)

    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(vocabulary_file, "wb") as f:
        dill.dump(
            obj={"diag_voc": diag_voc, "med_voc": med_voc, "pro_voc": pro_voc},
            file=f,
        )
    return diag_voc, med_voc, pro_voc, valid_atc4_keys


def filter_voc_by_valid_keys(voc, valid_keys):
    """
    Build a new vocabulary holding only the words of ``voc`` found in ``valid_keys``.

    Words keep their relative order and are renumbered from 0.

    :param voc: source Voc
    :param valid_keys: container of words to keep
    :return: new Voc; ``voc`` itself is left untouched
    """
    kept = Voc()
    kept.add_sentence(word for word in voc.word2idx if word in valid_keys)
    return kept


# create final records
def create_patient_record(df, diag_voc, med_voc, pro_voc):
    """
    Convert the combined data into nested lists of vocabulary ids and pickle
    them to the module-level path ``ehr_sequence_file``.

    :param df: combined data with "SUBJECT_ID" and list-like "ICD9_CODE",
        "PRO_CODE" and "ATC4" columns
    :param diag_voc: Voc for diagnosis codes
    :param med_voc: Voc for ATC4 codes
    :param pro_voc: Voc for procedure codes
    :return: records[patient][admission] == [diag ids, procedure ids, med ids]
    :raises KeyError: if a code in df is missing from the matching vocabulary
    """
    records = []  # (patient, code_kind:3, codes)  code_kind:diag, proc, med
    # A single groupby replaces one full-frame boolean scan per patient.
    # sort=False keeps patients in order of first appearance, matching
    # df["SUBJECT_ID"].unique().
    for _, visits in df.groupby("SUBJECT_ID", sort=False):
        patient = []
        for diag_codes, pro_codes, med_codes in zip(
            visits["ICD9_CODE"], visits["PRO_CODE"], visits["ATC4"]
        ):
            admission = [
                [diag_voc.word2idx[code] for code in diag_codes],
                [pro_voc.word2idx[code] for code in pro_codes],
                [med_voc.word2idx[code] for code in med_codes],
            ]
            patient.append(admission)
        records.append(patient)

    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(ehr_sequence_file, "wb") as f:
        dill.dump(obj=records, file=f)
    return records

# create ddi matrix
def create_ddi_matrix(med_voc, ddi_file):
    """
    Build a symmetric 0/1 drug-drug-interaction matrix over the medication
    vocabulary and pickle it to the module-level path ``ddi_matrix_file``.

    Entry [i, j] is 1 when some compound (CID) of ATC4 code i and some
    compound of ATC4 code j appear together as a pair in the filtered DDI table.

    :param med_voc: Voc whose word2idx maps ATC4 codes to row/column indices
    :param ddi_file: csv of drug pairs; the columns "STITCH 1", "STITCH 2",
        "Polypharmacy Side Effect" and "Side Effect Name" are read
    :return: the matrix that was written to disk
    """
    # get atc4_cid dataframe
    all_atc4 = set(med_voc.word2idx.keys())
    atc_cid_df = pd.read_excel('./input/ATC_CID.xlsx')
    atc_cid_df['ATC4'] = atc_cid_df['ATC4'].str[:5]
    filtered_atc_cid = atc_cid_df[(atc_cid_df['ATC4'].isin(all_atc4)) & (atc_cid_df['CID'] != -1)].reset_index(drop=True)
    atc4_cid_df = filtered_atc_cid[['ATC4', 'CID']]

    # get cid-cid interaction dataframe
    TOPK = 40
    ddi_df = pd.read_csv(ddi_file)
    ddi_most_pd = (
        ddi_df.groupby(by=["Polypharmacy Side Effect", "Side Effect Name"])
        .size()
        .reset_index()
        .rename(columns={0: "count"})
        .sort_values(by=["count"], ascending=False)
        .reset_index(drop=True)
    )
    # NOTE(review): the table is sorted by descending count, so the last TOPK
    # rows are the LEAST frequent side effects. Left unchanged because it
    # defines the resulting matrix — confirm this is the intended selection.
    ddi_most_pd = ddi_most_pd.iloc[-TOPK:, :]
    filter_ddi_df = ddi_df.merge(ddi_most_pd[["Side Effect Name"]], how="inner", on=["Side Effect Name"])
    ddi_df = (filter_ddi_df[["STITCH 1", "STITCH 2"]].drop_duplicates().reset_index(drop=True))
    # Strip the "CID" text from the STITCH ids and parse the rest as integers.
    ddi_df['CID1'] = ddi_df['STITCH 1'].str.replace('CID','').astype(int)
    ddi_df['CID2'] = ddi_df['STITCH 2'].str.replace('CID', '').astype(int)

    # get ddi_matrix
    matrix_size = len(med_voc.word2idx)
    ddi_matrix = np.zeros((matrix_size, matrix_size), dtype=int)

    # CID -> vocabulary indices of every ATC4 code that contains the compound.
    cid_to_indices = {}
    for atc4, cid in zip(atc4_cid_df['ATC4'], atc4_cid_df['CID']):
        cid_to_indices.setdefault(cid, set()).add(med_voc.word2idx[atc4])

    # One pass over the interacting pairs instead of rescanning the whole DDI
    # table for every pair of vocabulary entries; the same cells get set.
    for cid1, cid2 in zip(ddi_df['CID1'], ddi_df['CID2']):
        indices1 = cid_to_indices.get(cid1)
        indices2 = cid_to_indices.get(cid2)
        if not indices1 or not indices2:
            continue
        for idx1 in indices1:
            for idx2 in indices2:
                ddi_matrix[idx1, idx2] = 1
                ddi_matrix[idx2, idx1] = 1

    # Context manager guarantees the pickle is flushed and the handle closed.
    with open(ddi_matrix_file, "wb") as f:
        dill.dump(obj=ddi_matrix, file=f)
    return ddi_matrix

# Script entry point: runs the whole preprocessing pipeline and writes the
# vocabulary, DDI matrix and patient records under ./output.
if __name__ == "__main__":
    # NOTE: the path variables below are module-level globals. Several
    # functions above read them by name instead of taking them as parameters
    # (codeMapping2atc4, create_str_token_mapping, create_patient_record,
    # create_ddi_matrix), so they must keep these exact names.

    # MIMIC dataset
    med_file = "./input/PRESCRIPTIONS.csv"
    diag_file = "./input/DIAGNOSES_ICD.csv"
    procedure_file = "./input/PROCEDURES_ICD.csv"

    # auxiliary files
    RXCUI2atc4_file = "./input/RXCUI2atc4.csv"
    ndc2RXCUI_file = "./input/ndc2RXCUI.txt"
    ddi_file = './input/drug-DDI.csv'

    # output files
    vocabulary_file = "./output/voc.pkl"
    ehr_sequence_file = "./output/records.pkl"
    ddi_matrix_file = "./output/ddi_matrix.pkl"

    # process of med
    med_pd = med_process(med_file)
    # keep only patients with more than one admission in the med table
    med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)
    med_pd = med_pd.merge(
        med_pd_lg2[["SUBJECT_ID"]], on="SUBJECT_ID", how="inner"
    ).reset_index(drop=True)
    med_pd = codeMapping2atc4(med_pd)
    med_pd = filter_most_med(med_pd)
    print("complete medication processing")

    # process of diagnosis
    diag_pd = diag_process(diag_file)
    print("complete diagnosis processing")

    # process procedure
    pro_pd = procedure_process(procedure_file)
    print("complete procedure processing")

    # combine
    data = combine_process(med_pd, diag_pd, pro_pd)
    print("complete combining")

    # create vocab
    diag_voc, med_voc, pro_voc, valid_atc4 = create_str_token_mapping(data)
    print("obtain voc")

    create_ddi_matrix(med_voc, ddi_file)
    print("obtain ddi")
    # drop every admission that contains at least one ATC4 code outside
    # valid_atc4; med_voc only holds valid codes, so such rows could not be
    # encoded by create_patient_record
    data = data[data["ATC4"].apply(lambda atc_list: all(atc in valid_atc4 for atc in atc_list))]

    # create ehr sequence data
    records = create_patient_record(data, diag_voc, med_voc, pro_voc)
    print("obtain ehr sequence data")
