from sacred import Experiment
from dataclasses import dataclass
import numpy as np

# Sacred experiment object that the config function in this module is
# registered on; save_git_info=False asks sacred not to record git info.
ex = Experiment(name="METER", save_git_info=False)


# Discretisation tables for the multi-class lab tasks.
# LAB_TEST_BINS: bin edges per lab, open-ended at both extremes (+/- inf).
# LABELS: one label per bin, i.e. len(edges) - 1 entries per lab; presumably
# the two are used together (e.g. as pd.cut bins/labels) -- confirm in callers.
LAB_TEST_BINS = dict(
    creatinine=[-np.inf, 1.2, 2.0, 3.5, 5, np.inf],
    platelets=[-np.inf, 20, 50, 100, 150, np.inf],
    wbc=[-np.inf, 4, 12, np.inf],
    hb=[-np.inf, 8, 10, 12, np.inf],
    bicarbonate=[-np.inf, 22, 29, np.inf],
    sodium=[-np.inf, 135, 145, np.inf],
)
LABELS = dict(
    creatinine=[0, 1, 2, 3, 4],
    # Reversed on purpose: the lowest platelet bin gets the highest label.
    platelets=[4, 3, 2, 1, 0],
    wbc=[0, 1, 2],
    hb=[0, 1, 2, 3],
    bicarbonate=[0, 1, 2],
    sodium=[0, 1, 2],
)

# Per dataset, the label strings that identify each lab task in the raw data.
# They are lower-cased below, presumably so matching can be case-insensitive
# -- confirm against the code that consumes TASK2STRING.
TASK2STRING = {
    "mimiciv": {
        "creatinine": ["Creatinine"],
        "platelets": ["Platelet Count"],
        "wbc": ["WBC Count", "White Blood Cells"],
        "hb": ["Hemoglobin"],
        "bicarbonate": ["Bicarbonate"],
        "sodium": ["Sodium"],
    },
    "eicu": {
        "creatinine": ["creatinine"],
        "platelets": ["platelets x 1000"],
        "wbc": ["WBC x 1000"],
        "hb": ["Hgb"],
        "bicarbonate": ["bicarbonate"],
        "sodium": ["sodium"],
    },
}

# Rebuild the table with every label string lower-cased.
TASK2STRING = {
    ehr_name: {
        task_name: list(map(str.lower, labels))
        for task_name, labels in per_task.items()
    }
    for ehr_name, per_task in TASK2STRING.items()
}

# Per-dataset table layout. For each event family ("lab", "input", "med"):
# the source table, the column identifying the item (an id column such as
# "itemid", or a name column such as "labname"/"drugname"/"drug"), and either
# the lab value column or the list of target names ("task") for that family.
EHR_CONFIG = {
    "eicu": {
        "lab": {"table_name": "lab", "itemid_col": "labname", "value_col": "labresult"},
        "input": {
            "table_name": "infusiondrug",
            "itemid_col": "drugname",
            "task": ["norepinephrine", "propofol"],
        },
        "med": {
            "table_name": "medication",
            "itemid_col": "drugname",
            # NOTE(review): "a|b" entries look like regex alternations -- confirm.
            "task": ["magnesium sulfate", "heparin", "potassium chloride|kcl"],
        },
    },
    "mimiciv": {
        "lab": {"table_name": "labevents", "itemid_col": "itemid", "value_col": "valuenum"},
        "input": {
            "table_name": "inputevents",
            "itemid_col": "itemid",
            "task": ["norepinephrine", "propofol"],
        },
        "med": {
            "table_name": "prescriptions",
            "itemid_col": "drug",
            "task": ["magnesium sulfate", "heparin", "potassium chloride|kcl"],
        },
    },
}

@dataclass()
class Task:
    """Specification of a single prediction task.

    Attributes:
        name: task key, e.g. ``'creatinine'`` or ``'heparin'``.
        num_classes: number of output classes (1 for the binary tasks
            registered in ``get_task``).
        property: task type; ``get_task`` uses ``'multi-class'`` or
            ``'binary'``.
    """

    name: str
    num_classes: int
    # Annotation-only field: it does not bind the builtin ``property`` in the
    # class namespace, so the name is harmless here.
    property: str

def get_task(pred_task):
    """Return the ``Task`` specification registered under ``pred_task``.

    Lab-value tasks are multi-class (their class counts match the label lists
    in ``LABELS``); every other registered task is binary with
    ``num_classes == 1``. A new ``Task`` object is built on every call.

    Raises:
        KeyError: if ``pred_task`` is not a registered task name.
    """
    multi_class_sizes = {
        'creatinine': 5,
        'platelets': 5,
        'wbc': 3,
        'hb': 4,
        'bicarbonate': 3,
        'sodium': 3,
    }
    binary_names = (
        'magnesium sulfate',
        'heparin',
        'potassium chloride|kcl',
        'norepinephrine',
        'propofol',
        'heartrate',
        'resprate',
        'morphine',
        'ondansetron',
        'detect',
    )

    registry = {
        name: Task(name, size, 'multi-class')
        for name, size in multi_class_sizes.items()
    }
    registry.update((name, Task(name, 1, 'binary')) for name in binary_names)
    # Plain indexing: an unknown name raises KeyError, as callers expect.
    return registry[pred_task]
    
    
@ex.config
def config():
    """Default sacred configuration for the METER experiment.

    Sacred records every local variable assigned here as a config entry, so
    each name below can be overridden from the command line. Entries looked
    up by ``ehr`` / ``obs_size`` raise ``KeyError`` for unsupported values.
    """
    # Source EHR dataset: "mimiciv" or "eicu".
    ehr = "mimiciv"
    # Observation-window size; only 6, 12 and 24 have table entries below.
    obs_size = 12

    # Maximum event-sequence size, per dataset and observation window.
    max_event_size = {
        "mimiciv": {6: 165, 12: 243, 24: 366},
        "eicu": {6: 79, 12: 114, 24: 179},
    }[ehr][obs_size]

    # NOTE(review): fixed size, independent of ehr/obs_size. A per-dataset
    # lookup here (mimiciv {6: 2216, 12: 2328, 24: 2386}, eicu {6: 1328,
    # 12: 1369, 24: 1389}) was always overwritten by this value -- confirm.
    input_index_size = 28996

    # Source tables for the lab / input / med event families.
    table_names = {
        "mimiciv": ["labevents", "inputevents", "prescriptions"],
        "eicu": ["lab", "infusiondrug", "medication"],
    }[ehr]

    # NOTE(review): not switched per ehr -- confirm eicu data uses this column.
    pid_column = "stay_id"
    seed = 0
    embed_dim = 128
    pred_dim = 64
    max_event_token_len = 128
    dropout = 0.1
    n_heads = 4
    n_layers = 2
    embed_list = ["input"]
    num_epochs = 50
    batch_size = 64
    lr = 5e-5
    patience = 10
    pred_tasks = [
        'creatinine', 'platelets', 'wbc', 'hb', 'bicarbonate', 'sodium',
        'magnesium sulfate', 'heparin', 'potassium chloride|kcl',
        'norepinephrine', 'propofol',
    ]
    # Module-level lookup tables; the per-dataset ones are selected by ehr.
    lab_test_bins = LAB_TEST_BINS
    lab_labels = LABELS
    task_string = TASK2STRING[ehr]
    ehr_config = EHR_CONFIG[ehr]
    # File names below are prefixed with the dataset name.
    predef_vocab = f"{ehr}_predef_vocab.pickle"
    col_type = f"{ehr}_col_dtype.pickle"
    process_time = "filter"
    recovery = True
    recovery_save = True
    use_multiprocessing = True
    num_workers = 32
    input_file_name = f'{ehr}_hi_input.npy'
    type_file_name = f'{ehr}_hi_type.npy'
    time_file_name = f'{ehr}_hi_time.npy'
    real_time_file_name = f'{ehr}_hi_num_time.npy'
    reduce_vocab = False
    threshold = 0.5

    postprocess_steps = [1, 2, 3]
    real_data_root = ""
    syn_data_root = ""
    suffix = ""
    create_privacy_data = False
    sample = False
    

    
