import numpy as np
import pandas as pd
import networkx as nx
from pgmpy.readwrite import BIFReader
# Natural-language description of each variable in the cancer dataset, keyed by
# the same variable names used in ``value_mappings_cancer`` below.
# Each text spells out the variable's two discrete states (low/high, True/False,
# positive/negative). These are the raw string values that ``fetch_cancer``
# recodes to 0/1.
# NOTE(review): presumably used verbatim as prompt/context text, so edits to these
# strings change downstream behavior. Confirm with callers before rewording.
# (The key is spelled "Dyspnoea" while the text says "dyspnea"; this looks intentional.)
variable_description_cancer = {
    "Pollution": "This variable represents the level of air pollution in a certain area. It is a discrete variable with two possible values:\nlow: Indicates a low level of air pollution.\nhigh: Indicates a high level of air pollution.",
    "Smoker": "This variable represents whether an individual is a smoker or not. It is a discrete variable with two possible values:\nTrue: Indicates that the individual is a smoker.\nFalse: Indicates that the individual is not a smoker.",
    "Cancer": "This variable represents whether an individual has been diagnosed with cancer or not. It is a discrete variable with two possible values:\nTrue: Indicates that the individual has cancer.\nFalse: Indicates that the individual does not have cancer.",
    "Xray": "This variable represents the result of an X-ray examination for an individual. It is a discrete variable with two possible values:\npositive: Indicates that the X-ray examination showed signs of cancer.\nnegative: Indicates that the X-ray examination did not show signs of cancer.",
    "Dyspnoea": "This variable represents whether an individual is experiencing dyspnea or not. Dyspnea is a medical term for shortness of breath or difficulty in breathing. It is a discrete variable with two possible values:\nTrue: Indicates that the individual is experiencing dyspnea.\nFalse: Indicates that the individual is not experiencing dyspnea."
}

# One-line free-text summary of the dataset as a whole.
dataset_description_cancer = "A dataset about the effects of smoking on cancer."

# Integer encoding of each variable's raw (string) states: the "normal" state is
# listed first and gets code 0, and the "adverse" state is listed second and gets
# code 1. Every variable receives its own independent inner dict.
value_mappings_cancer = {
    variable: {state: code for code, state in enumerate(states)}
    for variable, states in (
        ('Pollution', ('low', 'high')),
        ('Smoker', ('False', 'True')),
        ('Cancer', ('False', 'True')),
        ('Xray', ('negative', 'positive')),
        ('Dyspnoea', ('False', 'True')),
    )
}

def fetch_cancer(
    data_path='/net/dali/home/mscbio/rul98/CausalLLM/data/cancer_20000.csv',
    bif_path='/net/dali/home/mscbio/rul98/CausalLLM/data/cancer.bif',
    seed=None,
):
    """Load the cancer dataset and its ground-truth causal graph.

    Reads the observations from ``data_path`` and recodes every column listed
    in ``value_mappings_cancer`` to nullable integers (adverse state -> 1,
    normal state -> 0). Also builds a plain ``networkx.DiGraph`` holding the
    structure of the Bayesian network stored in the BIF file ``bif_path``.

    Parameters
    ----------
    data_path : str or path-like
        CSV file with one column per variable. Defaults to the original
        hard-coded location.
    bif_path : str or path-like
        BIF file describing the reference network. Defaults to the original
        hard-coded location.
    seed : int, numpy.random.RandomState or None
        Forwarded to ``networkx.spring_layout`` so that node positions can be
        reproduced. ``None`` (the default) keeps the previous unseeded,
        non-deterministic layout.

    Returns
    -------
    df : pandas.DataFrame
        The loaded data. Mapped columns have dtype ``Int64``. A raw value that
        is missing from the mapping becomes ``<NA>``, and a warning is emitted
        for it.
    GroundTruth : networkx.DiGraph
        Nodes and directed edges of the BIF model, with no CPDs or attributes.
    pos_data : dict
        Maps each node to its 2-D position from ``networkx.spring_layout``.
    """
    import warnings

    df = pd.read_csv(data_path)

    for col, mapping in value_mappings_cancer.items():
        if col not in df.columns:
            continue
        raw = df[col]
        # Cast to str before mapping: read_csv may infer a bool dtype for
        # True/False columns, while the mapping keys are strings.
        mapped = raw.astype("str").map(mapping)
        # A value that was present in the file but has no mapping entry would
        # otherwise silently turn into <NA>. Surface it instead.
        unmapped = mapped.isna() & raw.notna()
        if unmapped.any():
            unknown = sorted(raw[unmapped].astype("str").unique())
            warnings.warn(
                f"{int(unmapped.sum())} value(s) in column {col!r} are not in "
                f"value_mappings_cancer ({unknown}); they were set to <NA>.",
                stacklevel=2,
            )
        df[col] = mapped.astype("Int64")

    reader = BIFReader(bif_path)
    G_model = reader.get_model()
    # Copy only the graph structure (node names and edges) into a plain
    # DiGraph, leaving the pgmpy model object behind.
    GroundTruth = nx.DiGraph()
    GroundTruth.add_nodes_from(G_model.nodes())
    GroundTruth.add_edges_from(G_model.edges())
    pos_data = nx.spring_layout(GroundTruth, seed=seed)
    return df, GroundTruth, pos_data