import numpy as np
import pandas as pd
import networkx as nx
from pgmpy.readwrite import BIFReader
# Maps each raw CHILD variable name (as used in the CSV header and the BIF
# graph) to the human-readable label used for DataFrame columns and graph nodes.
rename_mapping_child = {
    "BirthAsphyxia": "Birth Asphyxia (Known)",
    "HypDistrib": "Blood Flow Distribution",
    "HypoxiaInO2": "Hypoxia Severity (O2 Levels)",
    "CO2": "CO2 Levels",
    "ChestXray": "Chest X-ray Findings",
    "Grunting": "Grunting Present",
    "LVHreport": "LVH Report",
    "LowerBodyO2": "Lower Body Oxygen Saturation",
    "RUQO2": "Right Upper Quadrant Oxygen Saturation",
    "CO2Report": "CO2 Report",
    "XrayReport": "Chest X-ray Report",
    "Disease": "Diagnosed Disease",
    "GruntingReport": "Grunting Report",
    "Age": "Age When Sick Enough",
    "LVH": "Left Ventricular Hypertrophy (LVH)",
    "DuctFlow": "Ductus Arteriosus Blood Flow",
    "CardiacMixing": "Cardiac Blood Mixing",
    "LungParench": "Lung Parenchyma Appearance",
    "LungFlow": "Lung Blood Flow",
    "Sick": "Sickness Status",
}

# Free-text description of every CHILD variable, keyed by its raw name.
# Where a variable lists its levels, each level sits on its own "\n"-separated
# line in the form "<level>: <meaning>".
ils_desc = {
    "BirthAsphyxia": "This variable indicates whether the newborn experienced birth asphyxia or not, which is already known.",
    "HypDistrib": (
        "This variable represents the distribution of blood flow in the body.\n"
        "Equal: Blood flow is evenly distributed throughout the body.\n"
        "Unequal: Blood flow is unevenly distributed in the body."
    ),
    "HypoxiaInO2": (
        "This variable indicates the severity of hypoxia (low oxygen levels) in the body.\n"
        "Mild: Mild level of hypoxia.\n"
        "Moderate: Moderate level of hypoxia.\n"
        "Severe: Severe level of hypoxia."
    ),
    "CO2": (
        "This variable represents the carbon dioxide levels in the body.\n"
        "Normal: Normal CO2 levels.\n"
        "Low: Low CO2 levels.\n"
        "High: High CO2 levels."
    ),
    "ChestXray": (
        "This variable describes the appearance of the chest X-ray findings.\n"
        "Normal: Normal chest X-ray.\n"
        "Oligaemic: Oligaemic chest X-ray indicating reduced blood flow.\n"
        "Plethoric: Plethoric chest X-ray indicating increased blood flow.\n"
        "Grd_Glass: Ground-glass opacity in the chest X-ray.\n"
        "Asy/Patch: Asymmetric or patchy appearance in the chest X-ray."
    ),
    "Grunting": (
        "This variable indicates the presence of grunting in the newborn.\n"
        "yes: The newborn is grunting.\n"
        "no: The newborn is not grunting."
    ),
    "LVHreport": "This variable represents the report of left ventricular hypertrophy (LVH), only directly dependent on the LVH variable.",
    "LowerBodyO2": (
        "This variable indicates the oxygen saturation levels in the lower body.\n"
        "<5: Oxygen saturation levels below 5%.\n"
        "5-12: Oxygen saturation levels between 5% and 12%.\n"
        "12+: Oxygen saturation levels above 12%."
    ),
    "RUQO2": (
        "This variable indicates the oxygen saturation levels in the right upper quadrant of the body.\n"
        "<5: Oxygen saturation levels below 5%.\n"
        "5-12: Oxygen saturation levels between 5% and 12%.\n"
        "12+: Oxygen saturation levels above 12%."
    ),
    "CO2Report": "This variable represents the report of CO2 levels, only directly dependent on the CO2 variable.",
    # NOTE(review): the "only directly dependent on ..." clause is attached to
    # the "Normal" level below, whereas the other report variables put it in
    # the opening sentence. Left verbatim; confirm whether it should be moved.
    "XrayReport": (
        "This variable describes the appearance of the chest X-ray findings in the medical report.\n"
        "Normal: Normal chest X-ray in the medical report, only directly dependent on the Chest Xray appearance.\n"
        "Oligaemic: Oligaemic chest X-ray in the medical report.\n"
        "Plethoric: Plethoric chest X-ray in the medical report.\n"
        "Grd_Glass: Ground-glass opacity in the chest X-ray in the medical report.\n"
        "Asy/Patchy: Asymmetric or patchy appearance in the chest X-ray in the medical report."
    ),
    "Disease": (
        "This variable represents the type of birth-determined diagnoses among six congenital cardiopulmonary defects (the structural lesion present when the baby is born)\n"
        "PFC: Persistent fetal circulation.\n"
        "TGA: Transposition of the great arteries.\n"
        "Fallot: Tetralogy of Fallot.\n"
        "PAIVS: Pulmonary atresia with intact ventricular septum.\n"
        "TAPVD: Total anomalous pulmonary venous drainage.\n"
        "Lung: Lung disease."
    ),
    "GruntingReport": "This variable indicates the report of presence of grunting, only directly dependent on the Grunting variable.",
    "Age": "This variable represents an observed time-of-presentation indicator that captures when the baby became sick enough to be referred to Great Ormond Street Hospital",
    "LVH": (
        "This variable indicates the presence of left ventricular hypertrophy (LVH) in the newborn.\n"
        "yes: LVH is present.\n"
        "no: LVH is not present."
    ),
    # Typo fix: the "None" level previously read "inthe ductus arteriosus".
    "DuctFlow": (
        "This variable represents the direction of blood flow in the ductus arteriosus.\n"
        "Lt_to_Rt: Blood flow from left to right.\n"
        "None: No blood flow in the ductus arteriosus.\n"
        "Rt_to_Lt: Blood flow from right to left."
    ),
    "CardiacMixing": (
        "This variable describes the degree of mixing between oxygenated and deoxygenated blood in the heart.\n"
        "None: No mixing of blood.\n"
        "Mild: Mild mixing of blood.\n"
        "Complete: Complete mixing of blood.\n"
        "Transp.: Mixing in transposed circulation."
    ),
    "LungParench": (
        "This variable represents the appearance of the lung parenchyma.\n"
        "Normal: Normal lung parenchyma.\n"
        "Congested: Congested lung parenchyma.\n"
        "Abnormal: Abnormal lung parenchyma."
    ),
    "LungFlow": (
        "This variable indicates the blood flow in the lungs.\n"
        "Normal: Normal blood flow in the lungs.\n"
        "Low: Low blood flow in the lungs.\n"
        "High: High blood flow in the lungs."
    ),
    "Sick": (
        "This variable indicates whether the newborn is sick or not.\n"
        "yes: The newborn is sick.\n"
        "no: The newborn is not sick."
    ),
}

# Variable descriptions re-keyed by display label. Raw names that have no
# entry in rename_mapping_child are skipped; order follows ils_desc.
variable_description_child = dict(
    (rename_mapping_child[raw_name], ils_desc[raw_name])
    for raw_name in ils_desc
    if raw_name in rename_mapping_child
)

# Integer encoding for every column, keyed by the *renamed* column label
# (the values of rename_mapping_child). Each inner dict maps the raw string
# level to its integer code.
#
# The "None" key is the placeholder fetch_child() substitutes for missing
# cells; it is always encoded as 0.
# NOTE(review): per ils_desc, "None" is also a genuine level of
# 'Ductus Arteriosus Blood Flow' and 'Cardiac Blood Mixing', so in those two
# columns a real "None" shares code 0 with the first listed level. Confirm
# this collapse is intended.
value_mappings_child = {
    # Binary: missing -> 0, negative -> 0, positive -> 1
    "Birth Asphyxia (Known)": {
        "None": 0,
        "no": 0,
        "yes": 1,
    },
    "Grunting Present": {
        "None": 0,
        "no": 0,
        "yes": 1,
    },
    "LVH Report": {
        "None": 0,
        "no": 0,
        "yes": 1,
    },
    "Grunting Report": {
        "None": 0,
        "no": 0,
        "yes": 1,
    },
    "Left Ventricular Hypertrophy (LVH)": {
        "None": 0,
        "no": 0,
        "yes": 1,
    },
    "Sickness Status": {
        "None": 0,
        "no": 0,
        "yes": 1,
    },

    # Two-level ordinal
    "CO2 Report": {
        "None": 0,
        "<7.5": 0,
        ">=7.5": 1,
    },

    # Three-level ordinal (severity or amount)
    "Hypoxia Severity (O2 Levels)": {
        "None": 0,
        "Mild": 0,
        "Moderate": 1,
        "Severe": 2,
    },
    "CO2 Levels": {
        "None": 0,
        "Low": 0,
        "Normal": 1,
        "High": 2,
    },
    "Lower Body Oxygen Saturation": {
        "None": 0,
        "<5": 0,
        "5-12": 1,
        "12+": 2,
    },
    "Right Upper Quadrant Oxygen Saturation": {
        "None": 0,
        "<5": 0,
        "5-12": 1,
        "12+": 2,
    },
    # Key must equal rename_mapping_child["Age"]; the former key
    # 'Newborn Age' matched no column, so Age was never encoded.
    "Age When Sick Enough": {
        "None": 0,
        "0-3_days": 0,
        "4-10_days": 1,
        "11-30_days": 2,
    },
    "Lung Blood Flow": {
        "None": 0,
        "Low": 0,
        "Normal": 1,
        "High": 2,
    },

    # Nominal variables (two to six levels, no inherent order)
    "Blood Flow Distribution": {
        "None": 0,
        "Equal": 0,
        "Unequal": 1,
    },
    "Chest X-ray Findings": {
        "None": 0,
        "Oligaemic": 0,
        "Normal": 1,
        "Asy/Patch": 2,
        "Plethoric": 3,
        "Grd_Glass": 4,
    },
    "Chest X-ray Report": {
        "None": 0,
        "Oligaemic": 0,
        "Normal": 1,
        "Asy/Patchy": 2,
        "Plethoric": 3,
        "Grd_Glass": 4,
    },
    "Diagnosed Disease": {
        "None": 0,
        "Fallot": 1,
        "TGA": 2,
        "PAIVS": 3,
        "Lung": 4,
        "TAPVD": 5,
        "PFC": 6,
    },

    # Mixed (use domain semantics for order)
    "Ductus Arteriosus Blood Flow": {
        "None": 0,
        "Lt_to_Rt": 0,
        "Rt_to_Lt": 1,
    },
    "Cardiac Blood Mixing": {
        "None": 0,
        "Mild": 0,
        "Complete": 1,
        "Transp.": 2,
    },
    "Lung Parenchyma Appearance": {
        "None": 0,
        "Normal": 0,
        "Abnormal": 1,
        "Congested": 2,
    },
}

# One-paragraph summary of the CHILD dataset.
dataset_description_child = (
    "A Bayesian network modeling pediatric congenital heart disease diagnoses. "
    "It captures physiological and diagnostic relationships in newborns, "
    "including symptoms, test results, and disease states."
)

def fetch_child(
    data_path='/net/dali/home/mscbio/rul98/CausalLLM/data/child_20000.csv',
    bif_path='/net/dali/home/mscbio/rul98/CausalLLM/data/child.bif',
    seed=None,
):
    """Load the CHILD samples, integer-encode them and build the reference graph.

    Args:
        data_path: CSV file holding the sampled data, with raw variable names
            as column headers. Defaults to the original hard-coded location.
        bif_path: BIF file describing the network structure. Defaults to the
            original hard-coded location.
        seed: Forwarded to ``nx.spring_layout``. ``None`` (the default) keeps
            the layout randomly initialised, as before; pass an int for a
            reproducible layout.

    Returns:
        A ``(df, GroundTruth, pos_data)`` tuple:

        * ``df`` - the data with columns renamed via ``rename_mapping_child``
          and every column listed in ``value_mappings_child`` converted to
          nullable ``Int64`` codes. Values absent from a column's mapping
          become ``<NA>``.
        * ``GroundTruth`` - ``nx.DiGraph`` with the BIF model's nodes and
          edges, relabelled with the same display names.
        * ``pos_data`` - node positions computed by ``nx.spring_layout``.
    """
    df = pd.read_csv(data_path)

    # Normalize all missing values (including string "<NA>") to the single
    # placeholder "None", which every value mapping knows how to encode.
    df = df.replace(["<NA>", "nan", pd.NA], "None")

    # Switch from raw variable names to display labels.
    df = df.rename(columns=rename_mapping_child)

    # Encode each known column; columns without a mapping are left untouched.
    for col, mapping in value_mappings_child.items():
        if col in df.columns:
            df[col] = df[col].astype("str").map(mapping).astype("Int64")

    model = BIFReader(bif_path).get_model()

    # Copy the model structure into a plain directed graph, then relabel its
    # nodes so they match the DataFrame's column names.
    GroundTruth = nx.DiGraph()
    GroundTruth.add_nodes_from(model.nodes())
    GroundTruth.add_edges_from(model.edges())
    GroundTruth = nx.relabel_nodes(GroundTruth, rename_mapping_child)

    pos_data = nx.spring_layout(GroundTruth, seed=seed)
    return df, GroundTruth, pos_data