import os
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import numpy as np

# This script generates an enriched image-pair dataset from patient x-ray data.
# It iterates through each patient, sorts their available x-rays chronologically,
# and examines all pairs of scans. For each pair (earlier_scan, later_scan),
# it identifies treatments recorded after the earlier scan's month and up to
# (and including) the later scan's month.
#
# Specifically, the dataset includes:
#   - the earlier x-ray image,
#   - for each treatment, the first month it was recorded between the two scans (or -1 if it was not recorded),
#   - the later x-ray image,
#   - radiological features (x_ray_grades) recorded at both scans for both knees,
#   - clinical information (e.g., BMI, AGE) recorded at both scans,
#   - demographic information (e.g., sex, ethnicity, race) recorded once.
#
# If no treatments occur between two scans, the dataset includes:
#   - both x-ray images,
#   - all treatment months as -1,
#   - radiological features (x_ray_grades) at both scans,
#   - clinical information for both scans, and demographic data.

# Constants

# Month codes; each is used as the "<month>_" prefix of per-timepoint columns
# (e.g. "<month>_BILATERAL_X_RAY").
MONTHS = [
    "00", "12", "18", "24", "30", "36",
    "48", "60", "72", "84", "96", "108",
]

# Treatment names; looked up in the data as "<month>_<treatment>" columns and
# emitted as output columns under the bare treatment name.
TREATMENTS = [
    "L_Arthroscopy",
    "R_Arthroscopy",
    "L_Meniscectomy",
    "R_Meniscectomy",
    "L_Hyl_Injection",
    "R_Hyl_Injection",
    "L_Steroid_Injection",
    "R_Steroid_Injection",
    "NSAIDS",
    "NSAIDRX",
    "L_KneeReplacement",
    "R_KneeReplacement",
    "L_HipReplacement",
    "R_HipReplacement",
    "LIDOCAINE",
    "VOLTAREN",
]

# Radiological features; looked up as "<month>_<feature>_<knee side>" columns.
X_RAY_GRADES = [
    "JSN_Lateral",
    "JSN_Medial",
    "Osteophytes_Femur_Lateral",
    "Osteophytes_Femur_Medial",
    "Osteophytes_Tibial_Lateral",
    "Osteophytes_Tibial_Medial",
    "KLGrade",
]

# Per-timepoint clinical features; looked up as "<month>_<feature>" columns.
CLINICAL_INFO = ["BMI", "AGE"]

# Per-patient features; looked up directly by column name.
DEMOGRAPHIC_INFO = ["sex", "ethnicity", "race"]

def get_xray_path(row, month, df_columns):
    """Return the bilateral x-ray path stored for ``month``, or None.

    None is returned when the ``<month>_BILATERAL_X_RAY`` column is not in
    ``df_columns``, or when the value held in ``row`` is not a string or is
    the literal ``'-1'``.
    """
    column = f"{month}_BILATERAL_X_RAY"
    if column not in df_columns:
        return None

    value = row[column]
    if not isinstance(value, str) or value == '-1':
        return None
    return value

def extract_xray_grades(row, month, knee_side, df_columns):
    """Collect every X_RAY_GRADES feature for one knee side at one timepoint.

    Each feature is read from the ``<month>_<feature>_<knee_side>`` column and
    stored under the key ``<knee_side>_<feature>``. A feature whose column is
    absent from ``df_columns``, or whose value is NA, is reported as -1.
    """
    result = {}
    for feature in X_RAY_GRADES:
        source_col = f"{month}_{feature}_{knee_side}"
        value = -1
        if source_col in df_columns:
            cell = row[source_col]
            if not pd.isna(cell):
                value = cell
        result[f"{knee_side}_{feature}"] = value
    return result

def extract_clinical_info(row, month, df_columns):
    """Collect every CLINICAL_INFO feature for one timepoint (no knee side).

    Each feature is read from the ``<month>_<feature>`` column and stored
    under the bare feature name. A missing column or an NA value yields -1.
    """

    def value_of(feature):
        # Fall back to -1 when the column is absent or the cell is NA.
        column = f"{month}_{feature}"
        if column in df_columns and not pd.isna(row[column]):
            return row[column]
        return -1

    return {f"{feature}": value_of(feature) for feature in CLINICAL_INFO}

def extract_demographic(row, df_columns):
    """Collect the DEMOGRAPHIC_INFO features of a patient row.

    Each feature is read from the column of the same name. A feature whose
    column is absent from ``df_columns``, or whose value is NA, maps to None.
    """
    demographics = {}
    for name in DEMOGRAPHIC_INFO:
        if name not in df_columns:
            demographics[name] = None
            continue
        value = row[name]
        demographics[name] = None if pd.isna(value) else value
    return demographics

def _image_open_error(path, cache):
    """Return the error message from decoding ``path``, or None on success.

    Results are memoised in ``cache`` (path -> message or None) so that a file
    taking part in several pairs is decoded only once.
    """
    if path not in cache:
        try:
            # The context manager closes the underlying file handle;
            # convert() forces a full decode so unreadable files are detected.
            with Image.open(path) as image:
                image.convert("RGB")
        except Exception as exc:
            # Keep only the message, not the exception and its traceback.
            cache[path] = f"{exc}"
        else:
            cache[path] = None
    return cache[path]


def _first_treatment_months(row, df_columns, month_earlier, month_later):
    """Map each treatment to the first month in (month_earlier, month_later]
    whose ``<MM>_<treatment>`` column equals 1, or -1 if there is none."""
    first_months = {treatment: -1 for treatment in TREATMENTS}
    for month in range(month_earlier + 1, month_later + 1):
        prefix = f"{month:02d}"
        for treatment in TREATMENTS:
            if first_months[treatment] != -1:
                # Only the first occurrence in the window is kept.
                continue
            col = f"{prefix}_{treatment}"
            if col in df_columns and row[col] == 1:
                first_months[treatment] = month
    return first_months


def prepare_dataset(df):
    """Build one output row per (earlier scan, later scan) pair of each patient.

    For every row of ``df`` the available x-rays (see ``get_xray_path``) are
    ordered by month and all pairs with the earlier scan first are examined.
    A pair is skipped, with a message printed, when either image cannot be
    opened and converted to RGB.

    Args:
        df: DataFrame with one row per patient. Must contain a
            ``src_subject_id`` column; all other columns are optional and are
            looked up by name.

    Returns:
        A DataFrame whose rows hold, in this column order: subject id, the
        month and image path of both scans, one column per treatment (first
        month it is flagged in the window after the earlier scan up to and
        including the later scan, else -1), x-ray grades for the left and
        right knee at both scans, clinical info at both scans, and the
        demographic info.
    """
    rows = []
    df_columns = df.columns
    # path -> error message (or None); avoids re-decoding an image per pair.
    image_errors = {}

    for _, row in df.iterrows():
        subject_id = row["src_subject_id"]
        # Extract demographic information once per patient.
        demographic_info = extract_demographic(row, df_columns)

        xray_months = []
        for m in MONTHS:
            path = get_xray_path(row, m, df_columns)
            if path:
                xray_months.append((int(m), path))
        xray_months.sort()

        for i, (month_earlier, img_earlier) in enumerate(xray_months):
            for month_later, img_later in xray_months[i + 1:]:
                # Verify images can be opened before doing any other work.
                error = _image_open_error(img_earlier, image_errors)
                if error is not None:
                    print(f"Cannot open earlier image: {img_earlier}. Error: {error}")
                    continue

                error = _image_open_error(img_later, image_errors)
                if error is not None:
                    print(f"Cannot open later image: {img_later}. Error: {error}")
                    continue

                row_data = {
                    "src_subject_id": subject_id,
                    "month_earlier": month_earlier,
                    "image_earlier": img_earlier,
                    "month_later": month_later,
                    "image_later": img_later,
                    **_first_treatment_months(row, df_columns, month_earlier, month_later),
                }

                timepoints = (
                    ("earlier", f"{month_earlier:02d}"),
                    ("later", f"{month_later:02d}"),
                )
                # X_RAY_GRADES for both knees; all grade columns are added
                # before any clinical column to keep the output column order.
                for label, month_key in timepoints:
                    for knee_side in ("L", "R"):
                        grades = extract_xray_grades(row, month_key, knee_side, df_columns)
                        row_data.update(
                            {f"{label}_x_ray_grades_{k}": v for k, v in grades.items()}
                        )
                # CLINICAL_INFO for each time point.
                for label, month_key in timepoints:
                    clinical = extract_clinical_info(row, month_key, df_columns)
                    row_data.update(
                        {f"{label}_clinical_info_{k}": v for k, v in clinical.items()}
                    )
                # Demographic info (only once).
                row_data.update(demographic_info)
                rows.append(row_data)
    return pd.DataFrame(rows)

if __name__ == "__main__":
    df = pd.read_csv("../dataset.csv")

    # Data cleaning: turn the missing-value markers into NaN ...
    df.replace(['-1', '.', '-1.0', -1, '.: Missing Form/Incomplete Workbook'], np.nan, inplace=True)
    # ... and the labelled ordinal answers into their integer codes.
    ordinal_labels = [
        (0, ['0: Never']),
        (1, ['1: Less than 1 hour', '1: Seldom (1-2 days)']),
        (2, ['2: 1 hour but less than 2 hours', '2: Sometimes (3-4 days)']),
        (3, ['3: 2-4 hours', '3: Often (5-7 days)']),
        (4, ['4: More than 4 hours']),
    ]
    for code, labels in ordinal_labels:
        df.replace(labels, code, inplace=True)

    dataset = prepare_dataset(df)
    dataset.to_csv("pairs-dataset.csv", index=False)

    # Split on subject ids rather than on rows, so that all pairs of one
    # subject end up in the same split.
    unique_subjects = dataset['src_subject_id'].unique()

    # Split the unique subjects into train (80%) and temporary (20%) groups.
    train_subjects, temp_subjects = train_test_split(unique_subjects, train_size=0.8, random_state=42)

    # Split the temporary subjects evenly into validation and test groups (each 10% of the total).
    val_subjects, test_subjects = train_test_split(temp_subjects, test_size=0.5, random_state=42)

    # Subset the dataset based on these subject splits.
    train = dataset[dataset['src_subject_id'].isin(train_subjects)]
    val = dataset[dataset['src_subject_id'].isin(val_subjects)]
    test = dataset[dataset['src_subject_id'].isin(test_subjects)]

    # Save the splits to CSV files.
    for split, filename in ((train, "train.csv"), (val, "val.csv"), (test, "test.csv")):
        split.to_csv(filename, index=False)

    print(f"Train size: {len(train)}, Val size: {len(val)}, Test size: {len(test)}")
