import numpy as np
import pandas as pd
import networkx as nx
from pgmpy.readwrite import BIFReader
# Maps the single-letter column headers of the ASIA samples to the
# descriptive variable names used throughout this module.
rename_mapping_asia = dict(
    D="dyspnoea",
    T="tuberculosis",
    L="lung_cancer",
    B="bronchitis",
    A="visit_to_asia",
    S="smoking",
    X="positive_chest_xray",
    E="either",
)

# Natural-language description of each ASIA variable, keyed by the same
# descriptive names that appear as values in ``rename_mapping_asia``.
variable_description_asia = dict(
    visit_to_asia="This variable indicates whether the person has visited Asia or not in 1980s, where is a large portion of the tuberculosis cases. It is a discrete variable with two possible values: yes (the person has visited Asia) or no (the person has not visited Asia).",
    tuberculosis="This variable indicates whether the person has tuberculosis or not. It is a discrete variable with two possible values: yes (the person has tuberculosis) or no (the person does not have tuberculosis).",
    smoking="This variable indicates whether the person is a smoker or not. It is a discrete variable with two possible values: yes (the person is a smoker) or no (the person is not a smoker).",
    lung_cancer="This variable indicates whether the person has lung cancer or not. It is a discrete variable with two possible values: yes (the person has lung cancer) or no (the person does not have lung cancer).",
    bronchitis="This variable indicates whether the person has bronchitis or not. It is a discrete variable with two possible values: yes (the person has bronchitis) or no (the person does not have bronchitis).",
    either="This variable indicates whether the person has either tuberculosis or lung cancer or not. It is a discrete variable with two possible values: yes (the person has either tuberculosis or lung cancer) or no (the person does not have either tuberculosis or lung cancer).",
    positive_chest_xray="This variable indicates whether the person has a positive X-ray result or not. It is a discrete variable with two possible values: yes (the person has a positive X-ray result) or no (the person does not have a positive X-ray result).",
    dyspnoea="This variable indicates whether the person is experiencing shortness of breath (dyspnoea) or not. It is a discrete variable with two possible values: yes (the person is experiencing dyspnoea) or no (the person is not experiencing dyspnoea).",
)

# Free-text summary of the ASIA problem; the pieces below are joined by
# implicit string concatenation into a single one-line string.
dataset_description_asia = (
    "A dataset about lung diseases (tuberculosis, lung cancer or bronchitis) "
    "and visits to Asia. "
    "Shortness-of-breath (dyspnoea) may be due to tuberculosis, lung cancer or "
    "bronchitis, or none of them, or more than one of them. "
    "A recent visit to Asia increases the chances of tuberculosis, while smoking "
    "is known to be a risk factor for both lung cancer and bronchitis. "
    "The results of a single chest X-ray do not discriminate between lung cancer "
    "and tuberculosis, as neither does the presence or absence of dyspnoea."
)

def fetch_asia(
    csv_path='/net/dali/home/mscbio/rul98/CausalLLM/data/asia.csv',
    bif_path='/net/dali/home/mscbio/rul98/CausalLLM/data/asia.bif',
    *,
    seed=None,
):
    """Load the ASIA samples together with their ground-truth causal graph.

    The samples are read from ``csv_path`` and their columns are renamed with
    ``rename_mapping_asia``. Every cell equal to the string ``'no'`` or
    ``'yes'`` is encoded as ``0`` or ``1`` in a nullable ``Int64`` column. Any
    other value, including a missing cell, becomes ``<NA>``.

    The ground-truth DAG is read from the BIF file at ``bif_path``. Its node
    names are relabelled to the same descriptive names used for the
    DataFrame columns.

    Args:
        csv_path: Path to the CSV file with the samples. The default is the
            original hard-coded location, so existing zero-argument callers
            are unaffected.
        bif_path: Path to the BIF network file. It defaults to the original
            hard-coded location.
        seed: Optional seed forwarded to ``nx.spring_layout`` so the returned
            layout is reproducible. ``None`` keeps the previous unseeded
            behaviour.

    Returns:
        tuple: ``(df, ground_truth, pos_data)``. ``df`` holds the encoded
        samples, ``ground_truth`` is the relabelled ``nx.DiGraph``, and
        ``pos_data`` is the node-position dict returned by
        ``nx.spring_layout``.
    """
    df = pd.read_csv(csv_path)
    df = df.rename(columns=rename_mapping_asia)

    # Strings other than exactly 'no'/'yes' map to NaN. The nullable Int64
    # dtype then represents them as <NA> instead of forcing a float column.
    yes_no = {'no': 0, 'yes': 1}
    for col in df.columns:
        df[col] = df[col].astype("str").map(yes_no).astype("Int64")

    # The BIF file uses abbreviated node names. Relabel them so the graph
    # nodes line up with the DataFrame column names.
    bif_mapping_asia = {
        'dysp': "dyspnoea",
        'tub': "tuberculosis",
        'lung': "lung_cancer",
        'bronc': "bronchitis",
        'asia': "visit_to_asia",
        'smoke': "smoking",
        'xray': "positive_chest_xray",
        'either': "either",
    }
    model = BIFReader(bif_path).get_model()

    ground_truth = nx.DiGraph()
    # Add nodes explicitly so that a node with no incident edges is still kept.
    ground_truth.add_nodes_from(model.nodes())
    ground_truth.add_edges_from(model.edges())
    ground_truth = nx.relabel_nodes(ground_truth, bif_mapping_asia)

    pos_data = nx.spring_layout(ground_truth, seed=seed)
    return df, ground_truth, pos_data