import pandas as pd
import os
from tqdm import tqdm

# X-ray image descriptions (matched against image03's image_description column)
# mapped to the column-name suffix used for that image path in the output dataset.
IMAGE_TYPES = dict([
    ("Bilateral PA Fixed Flexion Knee", "BILATERAL_X_RAY"),
])

# Visit month -> folder (under SOURCE_PATH) holding that visit's image03.txt and images.
# Baseline uses "BaselineImages"; every other listed month uses "<MM>MonthImages".
MONTH_FOLDERS = {
    month: ("BaselineImages" if month == "00" else f"{month}MonthImages")
    for month in ("00", "12", "18", "24", "30", "36", "48", "72", "96")
}

# From VisitPrefixDefinitions.pdf in General Documentation:
# visit month -> two-digit visit code, i.e. the month's position in the visit
# schedule. The code appears in file names (AllClinicalXX.txt) and variable prefixes (VXX...).
MONTH_ORDER = {
    month: f"{visit:02d}"
    for visit, month in enumerate(
        ("00", "12", "18", "24", "30", "36", "48", "60", "72", "84", "96", "108")
    )
}

# kxr_sq_bu grading-variable suffixes (without the V<visit> prefix) -> output
# feature names; each becomes a "<month>_<name>_<R|L>" column.
XRAY_FEATURES_NAMES = dict([
    ("XRJSL", "JSN_Lateral"),
    ("XRJSM", "JSN_Medial"),
    ("XROSFL", "Osteophytes_Femur_Lateral"),
    ("XROSFM", "Osteophytes_Femur_Medial"),
    ("XROSTL", "Osteophytes_Tibial_Lateral"),
    ("XROSTM", "Osteophytes_Tibial_Medial"),
    ("XRKL", "KLGrade"),
])

# AllClinicalXX.txt variables (without the VXX visit prefix) indicating treatment
# within the past 12 months, mapped to output dataset column names.
TREATMENTS = dict([
    ("ARTL12", "L_Arthroscopy"),
    ("ARTR12", "R_Arthroscopy"),
    ("MENL12", "L_Meniscectomy"),
    ("MENR12", "R_Meniscectomy"),
    ("HYAINJL", "L_Hyl_Injection"),
    ("HYAINJR", "R_Hyl_Injection"),
    ("STRINJL", "L_Steroid_Injection"),
    ("STRINJR", "R_Steroid_Injection"),
    ("NSAIDS", "NSAIDS"),
    ("NSAIDRX", "NSAIDRX"),
])

# AllClinicalXX.txt variables (without the VXX visit prefix) whose raw values are
# copied into the dataset, mapped to output column names.
ALL_CLINICAL_INFO = {
    "BMI": "BMI",
    "WOMADLL": "L_WOMAC_Disability",
    "WOMADLR": "R_WOMAC_Disability",
    "WOMKPL": "L_WOMAC_Pain",
    "WOMKPR": "R_WOMAC_Pain",
    "WOMSTFL": "L_WOMAC_Stiffness",
    "WOMSTFR": "R_WOMAC_Stiffness",
    "WOMTSL": "L_WOMAC_Total",
    "WOMTSR": "R_WOMAC_Total",
    "KOOSQOL": "KOOS",
    "AGE": "AGE",
}
# PASE1..PASE6 followed by PASE1HR..PASE6HR, each kept under its own name.
# The order matters: it fixes the order of the output CSV columns.
EXERCISE = {
    name: name
    for name in [f"PASE{i}" for i in range(1, 7)] + [f"PASE{i}HR" for i in range(1, 7)]
}
ALL_CLINICAL_INFO.update(EXERCISE)

# MIF ingredient names (matched against the visit's V<visit>INGNAME column)
# mapped to the column-name suffix that stores the medication frequency.
MIF = dict([
    ("LIDOCAINE", "LIDOCAINE"),
    ("DICLOFENAC SODIUM", "VOLTAREN"),
])

# When True, Outcomes99 columns are also copied into each subject's row.
OUTCOMES = False
# Outcomes99 "VSAF" variables: the closest OAI visit after the replacement.
# The visit number before ':' in each value is compared with MONTH_ORDER codes.
OUTCOME_TREATMENTS = dict([
    ("V99ELKVSAF", "L_KneeReplacement"),
    ("V99ERKVSAF", "R_KneeReplacement"),
    ("V99ELHVSAF", "L_HipReplacement"),
    ("V99ERHVSAF", "R_HipReplacement"),
])

# Demographic label -> integer code (position in the list). Labels not listed
# here are encoded as -1 where these maps are used.
ethnicity_map = {
    label: code
    for code, label in enumerate(['Not Hispanic or Latino', 'Hispanic or Latino'])
}

race_map = {
    label: code
    for code, label in enumerate(
        ['Asian', 'Black or African American', 'Other Non-White', 'White']
    )
}

# Root of the local OAI download; every input path is built from it.
SOURCE_PATH = "/local2/acc/OAI"
# Enrollee table (tab-delimited). skiprows=[1] drops the file's second line
# (presumably a column-description row under the header -- confirm).
df = pd.read_csv(os.path.join(SOURCE_PATH, "24MonthImages/oai_enrollee01.txt"), delimiter='\t', skiprows=[1])
# Outcomes99 table (pipe-delimited); subjects are looked up by its 'id' column.
outcomes_df = pd.read_csv(os.path.join(SOURCE_PATH, "OAI_Text_Data/Outcomes99.txt"), delimiter='|')

def _visit_number(raw):
    """Parse the integer visit code before the first ':' of an Outcomes99 value.

    Returns None when that part is not a non-negative integer -- e.g. the "."
    missing marker, or a NaN cell (which ``str`` renders as "nan").
    """
    code = str(raw).split(':')[0].strip()
    return int(code) if code.isdigit() else None


TEXT_DATA_PATH = f"{SOURCE_PATH}/OAI_Text_Data"
# Visit codes (MONTH_ORDER values) for which kxr_sq_bu<code>.txt grades are read.
XRAY_GRADE_ORDERS = ("00", "01", "03", "05", "06", "08", "10")
# Months for which no TREATMENTS / ALL_CLINICAL_INFO columns are produced.
CLINICAL_SKIP_MONTHS = ("00", "18", "30", "60", "84", "108")
# Visit codes for which no MIF<code>.txt medication table is read.
MIF_SKIP_ORDERS = ("07", "09", "11")

# Load every per-visit table once, keyed by visit month. None of them depend on
# the subject, so reading them inside the subject loop would re-parse each file
# once per subject.
all_clinical_tables = {}
image_tables = {}
xray_grade_tables = {}
mif_tables = {}
for month, month_order in MONTH_ORDER.items():
    if month not in CLINICAL_SKIP_MONTHS:
        all_clinical_tables[month] = pd.read_csv(
            f"{TEXT_DATA_PATH}/AllClinical{month_order}.txt", delimiter='|')
    if month in MONTH_FOLDERS:
        img_df = pd.read_csv(f"{SOURCE_PATH}/{MONTH_FOLDERS[month]}/image03.txt",
                             delimiter="\t", skiprows=[1], low_memory=False)
        # Keep only the image types we extract. Row order is preserved, so the
        # first match per subject is the same as in the unfiltered table.
        image_tables[month] = img_df[img_df.image_description.isin(list(IMAGE_TYPES))]
    if month_order in XRAY_GRADE_ORDERS:
        xray_grade_tables[month] = pd.read_csv(
            f"{TEXT_DATA_PATH}/kxr_sq_bu{month_order}.txt", delimiter='|')
    if month_order not in MIF_SKIP_ORDERS:
        mif = pd.read_csv(f"{TEXT_DATA_PATH}/MIF{month_order}.txt", delimiter='|', encoding='latin1')
        # Keep only the ingredients listed in MIF (row order preserved).
        mif_tables[month] = mif[mif[f"V{month_order}INGNAME"].isin(list(MIF))]

output_rows = []

# Loop through each subject in the enrollee dataframe
for _, subject in tqdm(df.iterrows(), total=len(df)):
    subject_id_str = str(subject['src_subject_id'])
    subject_id = int(subject_id_str)
    subject_data = {
        "src_subject_id": subject_id_str,
        "sex": subject['sex'],
        "ethnicity": ethnicity_map.get(subject.get('ethnicity'), -1),
        "race": race_map.get(subject.get('race'), -1),
        "cohort": "Progression" if subject['e_cohort'] == 1 else "Incidence" if subject['e_cohort'] == 2 else "Control"
    }

    # This subject's Outcomes99 row(s); empty if the subject is absent.
    outcomes_rows = outcomes_df.loc[outcomes_df['id'] == subject_id]

    # Extract all Outcomes (Optional), excluding the identifier/version columns.
    if OUTCOMES and not outcomes_rows.empty:
        for column in outcomes_df.columns:
            if column not in ('id', 'version'):
                # .values[0] stores the scalar cell, not a one-element Series.
                subject_data[column] = outcomes_rows[column].values[0]

    # Visit code at which each replacement was recorded, or None when missing.
    replacement_visits = {
        t_name: None if outcomes_rows.empty else _visit_number(outcomes_rows[t_key].values[0])
        for t_key, t_name in OUTCOME_TREATMENTS.items()
    }

    # BEGIN loop through all months
    for month, month_order in MONTH_ORDER.items():
        # Extract X-Ray paths from IMAGE_TYPES ("-1" when no image is found)
        if month in image_tables:
            folder = MONTH_FOLDERS[month]
            img_df = image_tables[month]
            subject_imgs = img_df[img_df.src_subject_id == subject_id]
            for x_ray_type, x_ray_col_name in IMAGE_TYPES.items():
                imgs = subject_imgs[subject_imgs.image_description == x_ray_type].image_thumbnail_file.values
                if len(imgs) > 0:
                    # Re-root the listed path: drop its first five '/'-separated
                    # components and prepend the local folder for this visit.
                    subject_data[f"{month}_{x_ray_col_name}"] = (
                        f"{SOURCE_PATH}/{folder}/image03/{month}m/" + "/".join(imgs[0].split('/')[5:]))
                else:
                    subject_data[f"{month}_{x_ray_col_name}"] = "-1"

        # Extract KL, JSN, and Osteophytes Grades
        if month in xray_grade_tables:
            xray_f_df = xray_grade_tables[month]
            subject_grades = xray_f_df[xray_f_df.ID == subject_id]
            if not subject_grades.empty:
                for f_col, f_name in XRAY_FEATURES_NAMES.items():
                    # kxr_sq_bu05 spells the KL and JSN columns with a lowercase 'v'.
                    prefix = "v" if month_order == "05" and f_col in ("XRKL", "XRJSL", "XRJSM") else "V"
                    # SIDE == 1 is written as 'R', SIDE == 2 as 'L'.
                    for side_val, side in enumerate(['R', 'L']):
                        f_row = subject_grades[subject_grades.SIDE == side_val + 1][prefix + month_order + f_col]
                        if not f_row.isna().all():
                            subject_data[f"{month}_{f_name}_{side}"] = f_row.values[0]

        # Extract Treatments from TREATMENTS and info from ALL_CLINICAL_INFO.
        # A subject absent from the visit's AllClinical table gets -1 treatment
        # flags and NaN info values instead of crashing the run.
        if month in all_clinical_tables:
            all_clinical = all_clinical_tables[month]
            clinical_rows = all_clinical.loc[all_clinical.ID == subject_id]
            has_clinical = not clinical_rows.empty
            for t_key, t_name in TREATMENTS.items():
                t_val = str(clinical_rows[f'V{month_order}{t_key}'].values[0]) if has_clinical else ""
                # 1 if the coded value contains "1", else 0 if it contains "0", else -1.
                subject_data[f"{month}_{t_name}"] = 1 if "1" in t_val else 0 if "0" in t_val else -1
            for i_key, i_name in ALL_CLINICAL_INFO.items():
                subject_data[f"{month}_{i_name}"] = (
                    clinical_rows[f'V{month_order}{i_key}'].values[0] if has_clinical else float("nan"))

        # Extract MIF medication frequency ("-1" when the medication is not listed)
        if month in mif_tables:
            mif = mif_tables[month]
            subject_meds = mif[mif.ID == subject_id]
            for med, med_name in MIF.items():
                freq = subject_meds[subject_meds[f"V{month_order}INGNAME"] == med][f"V{month_order}MIFFREQ"].values
                subject_data[f"{month}_{med_name}"] = freq[0] if len(freq) > 0 else "-1"

        # Flag (1/0) whether each OUTCOME_TREATMENTS replacement falls at this visit.
        # Missing or unparsable visit codes count as 0.
        if month != "00":
            for t_name in OUTCOME_TREATMENTS.values():
                subject_data[f"{month}_{t_name}"] = 1 if replacement_visits[t_name] == int(month_order) else 0

    # After processing all months, append the subject's data to the output list
    output_rows.append(subject_data)

# Build one row per subject from the collected dicts and write the dataset,
# without the DataFrame's integer index.
output_df = pd.DataFrame(data=output_rows)
output_df.to_csv(path_or_buf="dataset.csv", index=False)


