import numpy as np
import pandas as pd
import networkx as nx
from pgmpy.readwrite import BIFReader
# BIF node code -> human-readable label.
# fetch_alarm uses these labels as the DataFrame column names and as the
# ground-truth graph's node names. They are also the keys of
# value_mappings_alarm and variable_description_alarm, so renaming a label
# here requires updating those tables too.
rename_mapping_alarm = dict(
    HISTORY="History",
    CVP="Central Venous Pressure",
    PCWP="Pulmonary Capillary Wedge Pressure",
    HYPOVOLEMIA="Hypovolemia",
    LVEDVOLUME="Left-Ventricular End-Diastolic Volume",
    LVFAILURE="Left-Ventricular Failure",
    STROKEVOLUME="Stroke Volume",
    ERRLOWOUTPUT="Error Low-Voltage",
    HRBP="Heart Rate / Blood Pressure",
    HREKG="Heart Rate EKG",
    ERRCAUTER="Electrocauter",
    HRSAT="Heart Rate / O2 Saturation",
    INSUFFANESTH="Insufficient Anaesthesia",
    ANAPHYLAXIS="Anaphylaxis",
    TPR="Total Peripheral Resistance",
    EXPCO2="Expelled CO2",
    KINKEDTUBE="Kinked Tube",
    # NOTE(review): ils_desc describes MINVOL as "Minute Volume", but the label
    # says "Minimum Volume" (likewise MINVOLSET). Confirm the intended wording
    # before renaming; value_mappings_alarm is keyed by these labels.
    MINVOL="Minimum Volume",
    FIO2="Fraction Inspired O2",
    PVSAT="Pulmonary Venous O2 Saturation",
    SAO2="Arterial O2 Saturation",
    PAP="Pulmonary Artery Pressure",
    PULMEMBOLUS="Pulmonary Embolus",
    SHUNT="Shunt",
    INTUBATION="Intubation Type",
    PRESS="Breathing Pressure",
    DISCONNECT="Disconnection",
    MINVOLSET="Minimum Volume Set",
    VENTMACH="Ventilation Machine",
    VENTTUBE="Ventilation Tube",
    VENTLUNG="Lung Ventilation",
    VENTALV="Pulmonary Alveoli Ventilation",
    ARTCO2="Arterial CO2",
    CATECHOL="Catecholamine",
    HR="Heart Rate",
    CO="Cardiac Output",
    BP="Blood Pressure",
)

# BIF node code -> free-text description of the variable. Most entries start
# with the variable's role ("Measurement", "Diagnostic" or "Intermediate
# variable"), matching the taxonomy in dataset_description_alarm.
ils_desc = {
    "HISTORY": "Measurement variable. This variable represents whether left ventricular failure will be written to the disease history, only determined by the diagnostic variable of left ventricular failure.",
    "CVP": "Measurement variable. (Central Venous Pressure): This variable represents the pressure within the thoracic vena cava, near the patient's heart.",
    "PCWP": "Measurement variable. (Pulmonary Capillary Wedge Pressure): This variable represents the pressure in the pulmonary capillaries, reflecting the pressure in the left atrium.",
    "HYPOVOLEMIA": "Diagnostic variable. This variable indicates whether the patient is experiencing low blood volume.",
    "LVEDVOLUME": "Intermediate variable. (Left Ventricular End-Diastolic Volume), This variable represents the volume of blood in the left ventricle at the end of diastole, before the ventricle contracts, indicated by test results.",
    "LVFAILURE": "Diagnostic variable. This variable indicates whether the patient is experiencing left ventricular failure.",
    "STROKEVOLUME": "Intermediate variable. This variable represents the volume of blood pumped out of the heart with each beat.",
    "ERRLOWOUTPUT": "Sensor-fault flag (top-level troubleshooting node). True when the arterial-pressure transducer or its amplifier is detected to be in a low-output (low-voltage) error state—i.e. the monitor is malfunctioning, not the patient.",
    "HRBP": "Measurement variable. (Heart Rate Blood Pressure): Combined monitor channel that normally reports the patient’s heart-rate and blood-pressure.",
    "HREKG": "Measurement variable. (Heart Rate Electrocardiogram): This variable represents the heart rate as measured by an electrocardiogram.",
    # NOTE(review): "Intermediate variable and top level" is self-contradictory
    # under the taxonomy in dataset_description_alarm — confirm the role.
    "ERRCAUTER": "Intermediate variable and top level. (Error Cauterization):  This variable indicates whether there is an error related to a cauterization procedure.",
    "HRSAT": "Measurement variable. (Heart Rate Saturation): This variable represents the patient's oxygen saturation in relation to their heart rate.",
    "INSUFFANESTH": "Diagnostic variable. (Insufficient Anesthesia):  This variable indicates whether the patient is experiencing insufficient anesthesia during a procedure.",
    "ANAPHYLAXIS": "Diagnostic variable. This variable indicates whether the patient is experiencing an anaphylactic reaction.",
    "TPR": "Measurement variable. (Total Peripheral Resistance): This variable represents the resistance to blood flow within the circulatory system.",
    "EXPCO2": "Measurement variable. (Exhaled Carbon Dioxide):  This variable represents the amount of carbon dioxide exhaled by the patient.",
    "KINKEDTUBE": "Diagnostic variable. This variable indicates whether there is a kink in a tube, such as an endotracheal or ventilator tube, which may impede airflow.",
    # NOTE(review): described as "Minute Volume" here, but labelled
    # "Minimum Volume" in rename_mapping_alarm — confirm which is intended.
    "MINVOL": "Measurement variable. (Minute Volume):  This variable represents the minute volume.",
    "FIO2": "Measurement variable. (Fraction of Inspired Oxygen): This variable represents the concentration of oxygen in the inspired air, preset with fixed value.",
    # NOTE(review): says "Pulmonary artery", while the label in
    # rename_mapping_alarm is "Pulmonary Venous O2 Saturation" — confirm.
    "PVSAT": "Intermediate variable. Pulmonary artery oxygen saturation",
    "SAO2": "Measurement variable. (Arterial Oxygen Saturation): This variable represents the percentage of hemoglobin in the arterial blood that is saturated with oxygen.",
    "PAP": "Measurement variable. (Pulmonary Artery Pressure): This variable represents the pressure in the pulmonary artery.",
    "PULMEMBOLUS": "Diagnostic variable. (Pulmonary Embolism):  This variable indicates whether the patient is experiencing a pulmonary embolism.",
    "SHUNT": "Intermediate variable. This variable represents the presence and degree of a shunt, which is an abnormal connection between blood vessels or heart chambers, allowing blood to bypass the normal circulatory pathway.",
    "INTUBATION": "Diagnostic variable. This variable represents the placement of an endotracheal tube, with NORMAL indicating proper placement in the trachea, ESOPHAGEAL indicating accidental placement in the esophagus, and ONESIDED indicating that the tube is only in one lung.",
    "PRESS": "Measurement variable. (Pressure):  This variable represents the pressure in various parts of the respiratory system or related equipment.",
    "DISCONNECT": "Diagnostic variable. This variable indicates whether there is a disconnection in the patient's ventilator circuit, potentially causing ventilation issues.",
    # NOTE(review): "Measurement variable and top level" mixes two roles, and
    # "Minute ventilation" differs from the label "Minimum Volume Set".
    "MINVOLSET": "Measurement variable and top level. Minute ventilation",
    # NOTE(review): lists three settings (off/low/high), but the value mapping
    # for this variable has four states (ZERO/LOW/NORMAL/HIGH).
    "VENTMACH": "Intermediate variable. (Ventilator Machine):  This variable represents the operational setting of the ventilator machine, off, low or high.",
    "VENTTUBE": "Intermediate variable. (Ventilator Tube):  This variable represents the airflow through the ventilator tube.",
    "VENTLUNG": "Intermediate variable. (Ventilator Lung):  This variable indicates the state or capacity of the lungs to receive and process the air supplied by the ventilation machine.",
    "VENTALV": "Intermediate variable. (Ventilator Alveoli): Indicates the state or capacity of the alveoli (tiny air sacs in the lungs where oxygen and carbon dioxide exchange takes place).",
    "ARTCO2": "Intermediate variable. (Arterial Carbon Dioxide): This variable represents the partial pressure of carbon dioxide in the arterial blood.",
    "CATECHOL": "Intermediate variable. (Catecholamines):  This variable represents the levels of catecholamines, which are hormones released in response to stress, such as adrenaline and noradrenaline.",
    "HR": "Intermediate variable. (Heart Rate): This variable represents the patient's heart rate.",
    "CO": "Measurement variable. (Cardiac Output): This variable represents the amount of blood the heart pumps per minute.",
    "BP": "Measurement variable. (Blood Pressure): This variable represents the patient's blood pressure, which is the force exerted by the blood against the walls of the blood vessels."
}

# Descriptions re-keyed from BIF node codes to their human-readable labels,
# in ils_desc order. Codes with no entry in rename_mapping_alarm are skipped.
variable_description_alarm = dict(
    (rename_mapping_alarm[code], text)
    for code, text in ils_desc.items()
    if code in rename_mapping_alarm
)

def _level_codes(*levels):
    """Build a ``{spelling: code}`` table for one categorical variable.

    Each positional argument is a tuple of accepted spellings for one state,
    upper-case first and then the mixed-case variant. Every spelling of the
    i-th state maps to ``i``. A fresh dict is returned on each call.
    """
    return {
        spelling: code
        for code, spellings in enumerate(levels)
        for spelling in spellings
    }


# Human-readable label -> {state spelling: integer code}.
# Both the upper-case and mixed-case spelling of each state are listed.
# fetch_alarm maps each column's ``astype("str")`` values through these
# tables, so any other spelling ends up as <NA>.
value_mappings_alarm = {
    # ── three-level LOW / NORMAL / HIGH ─────────────────────────────
    **{
        label: _level_codes(("LOW", "Low"), ("NORMAL", "Normal"), ("HIGH", "High"))
        for label in (
            "Central Venous Pressure",
            "Pulmonary Capillary Wedge Pressure",
            "Total Peripheral Resistance",
            "Blood Pressure",
            "Cardiac Output",
            "Heart Rate / Blood Pressure",
            "Heart Rate EKG",
            "Heart Rate / O2 Saturation",
            "Pulmonary Artery Pressure",
            "Arterial O2 Saturation",
            "Heart Rate",
            "Left-Ventricular End-Diastolic Volume",
            "Stroke Volume",
            "Pulmonary Venous O2 Saturation",
            "Arterial CO2",
        )
    },
    # ── binary FALSE / TRUE (code 0-1) ──────────────────────────────
    **{
        label: _level_codes(("FALSE", "False"), ("TRUE", "True"))
        for label in (
            "History",
            "Hypovolemia",
            "Left-Ventricular Failure",
            "Anaphylaxis",
            "Insufficient Anaesthesia",
            "Pulmonary Embolus",
            "Kinked Tube",
            "Disconnection",
            "Error Low-Voltage",
            "Electrocauter",
        )
    },
    # ── two-level NORMAL / HIGH ─────────────────────────────────────
    **{
        label: _level_codes(("NORMAL", "Normal"), ("HIGH", "High"))
        for label in ("Catecholamine", "Shunt")
    },
    # ── LOW / NORMAL ────────────────────────────────────────────────
    "Fraction Inspired O2": _level_codes(("LOW", "Low"), ("NORMAL", "Normal")),
    # ── four-level ZERO / LOW / NORMAL / HIGH ───────────────────────
    **{
        label: _level_codes(
            ("ZERO", "Zero"), ("LOW", "Low"), ("NORMAL", "Normal"), ("HIGH", "High")
        )
        for label in (
            "Breathing Pressure",
            "Expelled CO2",
            "Minimum Volume",
            "Pulmonary Alveoli Ventilation",
            "Lung Ventilation",
            "Ventilation Tube",
            "Ventilation Machine",
        )
    },
    # ── LOW / NORMAL / HIGH (listed after the four-level group) ─────
    "Minimum Volume Set": _level_codes(
        ("LOW", "Low"), ("NORMAL", "Normal"), ("HIGH", "High")
    ),
    # ── NORMAL / ESOPHAGEAL / ONESIDED ("OneSided" is spelled irregularly) ─
    "Intubation Type": _level_codes(
        ("NORMAL", "Normal"), ("ESOPHAGEAL", "Esophageal"), ("ONESIDED", "OneSided")
    ),
}

# Free-text summary of the ALARM network and its three variable roles
# (diagnostic, measurement, intermediate), one sentence per source line.
dataset_description_alarm = (
    "A medical diagnostic network used to simulate and monitor ICU patient conditions. "
    "Each edge reflects how one clinical variable influences another. "
    "Three types of variables are presented in a causal network. "
    "Diagnose variables and other qualitative information are at the top level of the network. "
    "These variables have no predecessors (no variable causes a diagnose variable). "
    "Measurement variables represent any available quantitative information. "
    "Intermediate variables are inferred entities that cannot be measured directly."
)

# Short column header as it appears in alarm.csv -> BIF node code.
# fetch_alarm applies this rename first, then rename_mapping_alarm turns the
# codes into human-readable labels. Headers that already equal their node
# code map to themselves.
alarm_df_mapping = dict(
    CVP="CVP",
    PCWP="PCWP",
    HIST="HISTORY",
    TPR="TPR",
    BP="BP",
    CO="CO",
    HRBP="HRBP",
    HREK="HREKG",
    HRSA="HRSAT",
    PAP="PAP",
    SAO2="SAO2",
    FIO2="FIO2",
    PRSS="PRESS",
    ECO2="EXPCO2",
    MINV="MINVOL",
    MVS="MINVOLSET",
    HYP="HYPOVOLEMIA",
    LVF="LVFAILURE",
    APL="ANAPHYLAXIS",
    ANES="INSUFFANESTH",
    PMB="PULMEMBOLUS",
    INT="INTUBATION",
    KINK="KINKEDTUBE",
    DISC="DISCONNECT",
    LVV="LVEDVOLUME",
    STKV="STROKEVOLUME",
    CCHL="CATECHOL",
    ERLO="ERRLOWOUTPUT",
    HR="HR",
    ERCA="ERRCAUTER",
    SHNT="SHUNT",
    PVS="PVSAT",
    ACO2="ARTCO2",
    VALV="VENTALV",
    VLNG="VENTLUNG",
    VTUB="VENTTUBE",
    VMCH="VENTMACH",
)

def fetch_alarm(
    csv_path='/net/dali/home/mscbio/rul98/CausalLLM/data/alarm.csv',
    bif_path='/net/dali/home/mscbio/rul98/CausalLLM/data/alarm.bif',
    seed=None,
):
    """Load the ALARM samples and ground-truth DAG under human-readable labels.

    The CSV columns are renamed in two steps. ``alarm_df_mapping`` turns the
    short headers into BIF node codes, and ``rename_mapping_alarm`` turns the
    codes into labels. Every column listed in ``value_mappings_alarm`` is then
    integer-encoded into a nullable ``Int64`` column. Missing values, and any
    value with no entry in the mapping, become ``<NA>``. The ground-truth
    structure is read from the BIF file and relabelled with the same labels.

    Args:
        csv_path: Path to the sampled data. Defaults to the original
            hard-coded location, so existing callers are unaffected.
        bif_path: Path to the BIF network file. Same default policy.
        seed: Forwarded to ``nx.spring_layout``. ``None`` keeps the previous,
            unseeded layout; an int makes ``pos_data`` reproducible.

    Returns:
        A tuple ``(df, ground_truth, pos_data)``:
            df: The encoded DataFrame.
            ground_truth: The relabelled ``nx.DiGraph``.
            pos_data: The node-position dict returned by ``nx.spring_layout``.

    Warns:
        UserWarning: If a non-missing value in an encoded column has no entry
            in its mapping. The value is still set to ``<NA>``, as before, but
            is no longer dropped silently.
    """
    import warnings

    df = pd.read_csv(csv_path)

    # Normalize all missing values (including the literal string "<NA>") to
    # the "None" sentinel. No value mapping contains "None", so these cells
    # become <NA> in the encoding step below.
    df = df.replace(["<NA>", "nan", pd.NA], "None")

    # Short CSV headers -> BIF node codes -> human-readable labels.
    df = df.rename(columns=alarm_df_mapping)
    df = df.rename(columns=rename_mapping_alarm)

    # Integer-encode each known categorical column.
    for col, mapping in value_mappings_alarm.items():
        if col not in df.columns:
            continue
        as_text = df[col].astype("str")
        encoded = as_text.map(mapping)
        # A present value with an unknown spelling also maps to <NA>. Report
        # it, because a spelling mismatch can otherwise empty a whole column.
        was_missing = df[col].isna() | (as_text == "None")
        unknown = encoded.isna() & ~was_missing
        if unknown.any():
            examples = sorted(as_text[unknown].unique())[:5]
            warnings.warn(
                f"fetch_alarm: {int(unknown.sum())} value(s) in column {col!r} "
                f"are not in value_mappings_alarm and were set to <NA> "
                f"(e.g. {examples})",
                stacklevel=2,
            )
        df[col] = encoded.astype("Int64")

    reader = BIFReader(bif_path)
    model = reader.get_model()

    # Copy only the structure (nodes and edges) into a plain networkx DiGraph.
    # Then relabel the BIF codes with the labels used for df's columns.
    ground_truth = nx.DiGraph()
    ground_truth.add_nodes_from(model.nodes())
    ground_truth.add_edges_from(model.edges())
    ground_truth = nx.relabel_nodes(ground_truth, rename_mapping_alarm)
    pos_data = nx.spring_layout(ground_truth, seed=seed)
    return df, ground_truth, pos_data