import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from diffusers import StableDiffusionImg2ImgPipeline
from PIL import Image
import re

# Treatment column code -> human-readable treatment name (used when building prompts).
TREATMENTS_OF_INTEREST = dict(
    L_KneeReplacement="Left Knee Replacement",
    R_KneeReplacement="Right Knee Replacement",
    L_HipReplacement="Left Hip Replacement",
    R_HipReplacement="Right Hip Replacement",
    L_Arthroscopy="Left Knee Arthroscopy",
    R_Arthroscopy="Right Knee Arthroscopy",
    L_Meniscectomy="Left Meniscectomy",
    R_Meniscectomy="Right Meniscectomy",
    L_Hyl_Injection="Left Hyaluronic Acid Injection",
    R_Hyl_Injection="Right Hyaluronic Acid Injection",
    L_Steroid_Injection="Left Steroid Injection",
    R_Steroid_Injection="Right Steroid Injection",
    NSAIDS="NSAIDs Medication",
    NSAIDRX="Prescription NSAIDs",
    LIDOCAINE="Lidocaine Treatment",
    VOLTAREN="Voltaren Treatment",
)

# Side-specific subset of the table above: every "L_"/"R_" code, in the same
# order. The medication codes (NSAIDS, NSAIDRX, LIDOCAINE, VOLTAREN) are left
# out. Despite the name, the hip-replacement entries are part of this subset.
KNEE_TREATMENTS_OF_INTEREST = {
    code: name
    for code, name in TREATMENTS_OF_INTEREST.items()
    if code.startswith(("L_", "R_"))
}

# Radiographic grades recorded per knee.
X_RAY_GRADES = [
    "JSN_Lateral", "JSN_Medial",
    "Osteophytes_Femur_Lateral", "Osteophytes_Femur_Medial",
    "Osteophytes_Tibial_Lateral", "Osteophytes_Tibial_Medial",
    "KLGrade",
]

# Grades used as "later" targets. Only these two are enabled; the remaining
# X_RAY_GRADES entries (JSN_Lateral and the four osteophyte grades) are
# currently switched off.
PERCEPTUAL_FEATURES = ["JSN_Medial", "KLGrade"]

CLINICAL_INFO = ["BMI", "AGE"]

# Only "sex" is enabled; "ethnicity" and "race" are currently switched off.
DEMOGRAPHIC_INFO = ["sex"]

# Earlier-scan feature vector length: grades for both knees + clinical + demographic.
FEATURE_DIM = 2 * len(X_RAY_GRADES) + len(CLINICAL_INFO) + len(DEMOGRAPHIC_INFO)

# Treatment name -> position in the multi-hot label vector.
# "No Treatment" always takes the last index.
TREATMENT_LABELS = {
    name: index
    for index, name in enumerate((
        "Left Knee Replacement",
        "Right Knee Replacement",
        "Left Hip Replacement",
        "Right Hip Replacement",
        "Left Knee Arthroscopy",
        "Right Knee Arthroscopy",
        "Left Meniscectomy",
        "Right Meniscectomy",
        "Left Hyaluronic Acid Injection",
        "Right Hyaluronic Acid Injection",
        "Left Steroid Injection",
        "Right Steroid Injection",
        "No Treatment",
    ))
}

NUM_TREATMENTS = len(TREATMENT_LABELS)

# Load the full dataset once at import time and normalise its coded cells so
# that the timepoint prefixes and per-subject records below can be shared.
_FULL_DATASET_DF = pd.read_csv("./data/dataset.csv")
# Missing-value markers become NaN.
_FULL_DATASET_DF.replace(['-1', '.', '-1.0', -1, '.: Missing Form/Incomplete Workbook'], np.nan, inplace=True)
# Labelled categorical answers collapse to their leading integer code.
for _labels, _code in (
    (['0: Never'], 0),
    (['1: Less than 1 hour', '1: Seldom (1-2 days)'], 1),
    (['2: 1 hour but less than 2 hours', '2: Sometimes (3-4 days)'], 2),
    (['3: 2-4 hours', '3: Often (5-7 days)'], 3),
    (['4: More than 4 hours'], 4),
):
    _FULL_DATASET_DF.replace(_labels, _code, inplace=True)
del _labels, _code

# Sorted, de-duplicated integers taken from column names shaped "<digits>_...".
_TIMEPOINT_PREFIXES = sorted({
    int(hit.group(1))
    for hit in (re.match(r'^(\d+)_', name) for name in _FULL_DATASET_DF.columns)
    if hit
})

# subject id (as a string) -> that subject's full row; rows without an id are dropped.
_FULL_DATA_DICT = {
    str(record["src_subject_id"]): record
    for _, record in _FULL_DATASET_DF.iterrows()
    if pd.notna(record["src_subject_id"])
}


# -----------------------------
# Helper function to generate a cropped image path.
# -----------------------------
def get_crop_path(original_path, side):
    """
    Map a bilateral scan path to the path of its single-knee crop.

    Every occurrence of "OAI" in the path becomes "YOLO_OAI", and "_{side}"
    is inserted immediately before the file extension.
    """
    stem, suffix = os.path.splitext(original_path.replace("OAI", "YOLO_OAI"))
    return f"{stem}_{side}{suffix}"

def encode_feature(feat, raw_val):
    """
    Convert raw_val to a float, using -1.0 as the "missing" sentinel.

    Returns -1.0 when raw_val cannot be converted to a float, when it equals
    -1, or when it is NaN.
    For 'JSN_Medial', the numeric value is binned:
        <= 0.2 -> 0.0, <= 1.5 -> 1.0, <= 2.5 -> 2.0, otherwise -> 3.0.
    For every other feature the float value is returned unchanged.
    """
    try:
        val = float(raw_val)
    except (TypeError, ValueError, OverflowError):
        return -1.0
    # NaN compares False against everything, so without this guard it would
    # slip past the -1 check and fall through the JSN_Medial thresholds to 3.0.
    if val == -1 or np.isnan(val):
        return -1.0
    if feat == "JSN_Medial":
        if val <= 0.2:
            return 0.0
        elif val <= 1.5:
            return 1.0
        elif val <= 2.5:
            return 2.0
        else:
            return 3.0
    return val

def build_earlier_feature_vector(row):
    """
    Assemble the earlier-scan feature list for one CSV row, in this order:
      1. every X_RAY_GRADES entry for the left knee, then every entry for the
         right knee, each passed through encode_feature;
      2. every CLINICAL_INFO entry ('AGE' is divided by 100 unless it is -1);
      3. every DEMOGRAPHIC_INFO entry converted to float.
    Any clinical or demographic value that cannot be converted becomes -1.0.
    """
    # Bilateral grades: the whole left block first, then the whole right block.
    vector = [
        encode_feature(grade, row.get(f"earlier_x_ray_grades_{knee}_{grade}", -1))
        for knee in ("L", "R")
        for grade in X_RAY_GRADES
    ]

    # Clinical measurements taken at the earlier visit.
    for name in CLINICAL_INFO:
        try:
            number = float(row.get(f"earlier_clinical_info_{name}", -1))
        except Exception:
            number = -1.0
        if name == "AGE" and number != -1:
            number = number / 100.0  # Scale AGE down; the -1 sentinel is left alone.
        vector.append(number)

    # Demographics are read from un-prefixed columns.
    for name in DEMOGRAPHIC_INFO:
        try:
            number = float(row.get(name, None))
        except Exception:
            number = -1.0
        vector.append(number)

    return vector

def build_later_feature_vector(row, side, features_list=PERCEPTUAL_FEATURES):
    """
    For a given side ('L' or 'R'), construct the later feature vector.

    Each value is read from the column "later_x_ray_grades_{side}_{feat}".
    A missing column, an unconvertible value, or a NaN cell yields -1.0, so
    callers can detect "no valid later feature" by checking for -1.
    'JSN_Medial' values are additionally binned via encode_feature.

    Returns a list of floats with one entry per name in features_list.
    """
    feat_list = []
    for feat in features_list:
        col = f"later_x_ray_grades_{side}_{feat}"
        try:
            val = float(row.get(col, -1))
        except (TypeError, ValueError):
            val = -1.0
        if np.isnan(val):
            # Blank CSV cells load as NaN. Map them to the sentinel here so
            # they are neither binned as JSN grade 3.0 nor passed on as NaN.
            val = -1.0
        elif feat == "JSN_Medial":
            val = encode_feature("JSN_Medial", val)
        feat_list.append(val)
    return feat_list

# -----------------------------
# Custom Dataset for pretraining.
# -----------------------------
class KneeFeatureDataset(Dataset):
    """Dataset of (cropped single-knee image, class index) pairs for one graded feature."""

    def __init__(self, csv_file, split, feature, transform=None):
        """
        csv_file: path to CSV (train, val, or test)
        split: "train", "val", or "test". Accepted for interface
            compatibility; this class does not read it.
        feature: grade name (e.g., "JSN_Lateral"). Labels are read from the
            columns "{earlier|later}_{L|R}_{feature}"; absent columns are skipped.
        transform: torchvision transforms to apply.

        A row contributes a sample only when its label parses as a number
        other than -1/NaN, its image path is a string, and the cropped image
        exists on disk.
        """
        self.data = pd.read_csv(csv_file)
        self.feature = feature
        self.transform = transform
        self.samples = []  # list of tuples (image_path, label)

        # For each row, process both "earlier" and "later" images.
        for prefix in ["image_earlier", "image_later"]:
            # Determine the corresponding time prefix: "earlier" or "later"
            time_prefix = prefix.split("_")[1]
            for side in ["L", "R"]:
                label_col = f"{time_prefix}_{side}_{feature}"
                if label_col not in self.data.columns:
                    continue
                for _, row in self.data.iterrows():
                    try:
                        label_val = float(row[label_col])
                    except (TypeError, ValueError):
                        continue
                    # -1 is the explicit "missing" marker. Blank cells load
                    # as NaN and must not become a class of their own.
                    if label_val == -1 or np.isnan(label_val):
                        continue

                    # If training on JSN_Medial, convert continuous value to discrete grade.
                    if feature == "JSN_Medial":
                        label_val = encode_feature("JSN_Medial", label_val)

                    # Get the original bilateral image path; a non-string
                    # cell (e.g. NaN) cannot be turned into a crop path.
                    img_path = row[prefix]
                    if not isinstance(img_path, str):
                        continue
                    crop_path = get_crop_path(img_path, side)
                    if os.path.exists(crop_path):
                        self.samples.append((crop_path, label_val))

        # Remove duplicates: when the same cropped image appears more than
        # once, the last occurrence wins.
        self.samples = list({s[0]: s for s in self.samples}.values())

        # Map the unique label values to integer class indices.
        labels = [s[1] for s in self.samples]
        self.unique_labels = sorted(set(labels))
        self.label_to_idx = {val: i for i, val in enumerate(self.unique_labels)}
        # Update samples to have integer labels.
        self.samples = [(path, self.label_to_idx[label]) for (path, label) in self.samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image, class_index); an unreadable image is replaced by a blank 224x224 RGB image."""
        img_path, label = self.samples[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            print(f"Error loading {img_path}: {e}")
            image = Image.new("RGB", (224, 224))
        if self.transform:
            image = self.transform(image)
        return image, label


# -----------------------------
# New Dataset class that creates datapoints based on treatment information.
# -----------------------------
class KneePairsTreatmentDataset(Dataset):
    """Earlier/later cropped-knee image pairs with a text prompt naming the treatments for that side."""

    def __init__(self, csv_file, image_transform=None, prompt_prefix=""):
        """
        csv_file: Path to CSV (train, val, or test).
        image_transform: torchvision transforms to apply.
        prompt_prefix: Optional text prefix to include in the prompt.

        For each row, this dataset checks for valid treatments on the left side (columns starting with "L_")
        and the right side (columns starting with "R_") among KNEE_TREATMENTS_OF_INTEREST. If a side has any
        valid treatment (i.e. its value parses as a number other than -1) and both cropped images exist on
        disk, a datapoint is created. Left-side datapoints are appended before right-side ones.
        """
        self.df = pd.read_csv(csv_file)
        self.image_transform = image_transform
        self.prompt_prefix = prompt_prefix
        self.samples = []  # List of dictionaries, one per datapoint.

        for _, row in self.df.iterrows():
            month_diff = row["month_later"] - row["month_earlier"]

            for side in ("L", "R"):
                treatments = self._side_treatments(row, side)
                if not treatments:
                    continue
                prompt = f"{self.prompt_prefix}A knee X-ray {month_diff} months after " + ", ".join(treatments)
                input_img_path = get_crop_path(row["image_earlier"], side)
                target_img_path = get_crop_path(row["image_later"], side)
                if os.path.exists(input_img_path) and os.path.exists(target_img_path):
                    self.samples.append({
                        "input_image_path": input_img_path,
                        "target_image_path": target_img_path,
                        "prompt": prompt
                    })

    @staticmethod
    def _side_treatments(row, side):
        """Return the names of KNEE_TREATMENTS_OF_INTEREST on `side` whose value in `row` is numeric and not -1."""
        names = []
        for col, treatment_str in KNEE_TREATMENTS_OF_INTEREST.items():
            if not col.startswith(side + "_"):
                continue
            try:
                val = float(row.get(col, -1))
            except (TypeError, ValueError):
                # A non-numeric cell is treated as "no information", not as a treatment.
                continue
            if val != -1:
                names.append(treatment_str)
        return names

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict with "input_image", "target_image" and "prompt"; unreadable images become blank 256x256 RGB images."""
        sample = self.samples[idx]
        try:
            input_image = Image.open(sample["input_image_path"]).convert("RGB")
        except Exception as e:
            print(f"Error loading {sample['input_image_path']}: {e}")
            input_image = Image.new("RGB", (256, 256))
        try:
            target_image = Image.open(sample["target_image_path"]).convert("RGB")
        except Exception as e:
            print(f"Error loading {sample['target_image_path']}: {e}")
            target_image = Image.new("RGB", (256, 256))

        if self.image_transform:
            input_image = self.image_transform(input_image)
            target_image = self.image_transform(target_image)

        return {
            "input_image": input_image,
            "target_image": target_image,
            "prompt": sample["prompt"]
        }


class XRayPairsDataset(Dataset):
    """One sample per CSV row: the uncropped earlier/later images plus a treatment prompt."""

    def __init__(self, csv_file, image_transform=None, prompt_prefix=""):
        self.df = pd.read_csv(csv_file)
        self.image_transform = image_transform
        self.prompt_prefix = prompt_prefix

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """
        Load both images of row `idx` and build the prompt from every
        TREATMENTS_OF_INTEREST column whose value differs from -1
        ("No Treatment" when there is none).
        """
        record = self.df.iloc[idx]
        earlier = Image.open(record["image_earlier"]).convert("RGB")
        later = Image.open(record["image_later"]).convert("RGB")

        received = []
        for code, label in TREATMENTS_OF_INTEREST.items():
            if record[code] != -1:
                received.append(label)
        treatment_str = ", ".join(received)
        if not treatment_str:
            treatment_str = "No Treatment"

        month_diff = record["month_later"] - record["month_earlier"]
        prompt = f"{self.prompt_prefix}A knee X-ray {month_diff} months after {treatment_str}"

        if self.image_transform:
            earlier = self.image_transform(earlier)
            later = self.image_transform(later)

        return {
            "input_image": earlier,
            "target_image": later,
            "prompt": prompt
        }

class TimeAgnosticPropensityDataset(Dataset):
    """(earlier image, earlier feature vector, multi-hot treatment label) samples, one per usable CSV row."""

    def __init__(self, csv_path, image_transform=None):
        """
        For each CSV row (using "image_earlier" as the earlier image), this dataset:
          - Builds a feature vector by concatenating:
              • Left and right X_RAY_GRADES from earlier scans,
              • CLINICAL_INFO from the earlier scan,
              • DEMOGRAPHIC_INFO.
          - Creates a multi-hot label vector of length NUM_TREATMENTS by combining:
              • The treatments flagged in this pairs CSV (KNEE_TREATMENTS_OF_INTEREST
                columns whose value is not -1), and
              • Every KNEE_TREATMENTS_OF_INTEREST treatment with a positive value in
                any "<timepoint>_<code>" column of the subject's record in
                ./data/dataset.csv.
            If no treatment is flagged from either source, "No Treatment" is set to 1.

        Rows whose "image_earlier" is not a string, equals '-1', or does not
        exist on disk are skipped.
        """
        self.df = pd.read_csv(csv_path)
        self.image_transform = image_transform
        self.samples = []

        # The module already loaded and cleaned ./data/dataset.csv at import
        # time; reuse that per-subject lookup instead of re-reading the file.
        full_dict = _FULL_DATA_DICT

        # Resolve once which full-dataset columns belong to each treatment
        # code (e.g. "12_L_KneeReplacement"), so the regex is not re-run for
        # every row. Column order is preserved.
        treatment_columns = {
            code: [col for col in _FULL_DATASET_DF.columns if re.match(rf'^\d+_{code}(_|$)', col)]
            for code in KNEE_TREATMENTS_OF_INTEREST
        }

        for _, row in self.df.iterrows():
            img_earlier = row.get("image_earlier", None)
            if not isinstance(img_earlier, str) or img_earlier == '-1' or not os.path.exists(img_earlier):
                continue

            features = build_earlier_feature_vector(row)
            features_tensor = torch.tensor(features, dtype=torch.float)

            # Source 1: treatment columns present in the pairs CSV itself.
            label_vector = [0] * NUM_TREATMENTS
            for key, treat_name in KNEE_TREATMENTS_OF_INTEREST.items():
                if key in row:  # Check if column exists in pairs CSV
                    try:
                        val = float(row[key])
                    except (TypeError, ValueError):
                        val = -1
                    if val != -1:
                        label_vector[TREATMENT_LABELS[treat_name]] = 1

            # Source 2: any timepoint of the subject's full record.
            subject_id = row.get("src_subject_id", None)
            if subject_id is not None and str(subject_id) in full_dict:
                full_row = full_dict[str(subject_id)]
                for code, treat_name in KNEE_TREATMENTS_OF_INTEREST.items():
                    for col in treatment_columns[code]:
                        try:
                            val = float(full_row[col])
                        except (TypeError, ValueError):
                            val = -1
                        # NaN (missing) compares False here and is ignored.
                        if val > 0:
                            label_vector[TREATMENT_LABELS[treat_name]] = 1
                            break

            # If no treatment was flagged from either source, mark as "No Treatment".
            if not any(label_vector):
                label_vector[TREATMENT_LABELS["No Treatment"]] = 1

            label_tensor = torch.tensor(label_vector, dtype=torch.float)

            self.samples.append((img_earlier, features_tensor, label_tensor))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image, features, label); the image is opened on demand and transformed if a transform was given."""
        img_path, features, label = self.samples[idx]
        image = Image.open(img_path).convert("RGB")
        if self.image_transform:
            image = self.image_transform(image)
        return image, features, label


class KneeFeatureConditioningDataset(Dataset):
    """Per-knee earlier/later image pairs with prompt, feature vectors and a treatment label."""

    def __init__(self, csv_file, image_transform=None, prompt_prefix=""):
        """
        Reads data from csv_file (e.g., "data/pairs_dataset/train.csv").
        For each CSV row and each knee side ("L" and "R") where at least one later feature (from PERCEPTUAL_FEATURES)
        is valid (i.e. not -1) and both cropped images exist on disk, a datapoint is created.

        Each datapoint returns:
          - "input_image": Cropped image from image_earlier for that side.
          - "target_image": Cropped image from image_later for that side.
          - "prompt": A treatment prompt based on that side's treatment columns.
          - "earlier_features": A float tensor built by build_earlier_feature_vector from the earlier scan.
          - "later_features": A float tensor with one entry per PERCEPTUAL_FEATURES name, for that side.
          - "treatment_label": A multi-hot float tensor of length NUM_TREATMENTS. Only this side's treatments
            are set; "No Treatment" is set when the side has none.
        """
        self.df = pd.read_csv(csv_file)
        self.image_transform = image_transform
        self.prompt_prefix = prompt_prefix
        self.samples = []

        for _, row in self.df.iterrows():
            month_diff = row["month_later"] - row["month_earlier"]
            # The earlier feature vector does not depend on the side, so it
            # is built at most once per row, and only if some side needs it.
            earlier_feats = None

            for side in ["L", "R"]:
                later_feats = build_later_feature_vector(row, side, features_list=PERCEPTUAL_FEATURES)
                # Require at least one later feature to be valid.
                if all(v == -1 for v in later_feats):
                    continue

                # Check the crops before doing any further work for this side.
                input_img_path = get_crop_path(row["image_earlier"], side)
                target_img_path = get_crop_path(row["image_later"], side)
                if not (os.path.exists(input_img_path) and os.path.exists(target_img_path)):
                    continue

                if earlier_feats is None:
                    earlier_feats = build_earlier_feature_vector(row)
                # Each sample gets its own tensor objects.
                earlier_tensor = torch.tensor(earlier_feats, dtype=torch.float)
                later_tensor = torch.tensor(later_feats, dtype=torch.float)

                # Treatments recorded for this side (numeric value other than -1).
                treatment_list = []
                for col, treat_name in KNEE_TREATMENTS_OF_INTEREST.items():
                    if col.startswith(side + "_"):
                        try:
                            val = float(row.get(col, -1))
                        except (TypeError, ValueError):
                            continue
                        if val != -1:
                            treatment_list.append(treat_name)
                treatment_str = ", ".join(treatment_list) if treatment_list else "No Treatment"
                prompt = f"{self.prompt_prefix}A knee X-ray {month_diff} months after {treatment_str}."

                # Build multi-label binary treatment vector.
                label_vec = [0] * NUM_TREATMENTS
                for treat in treatment_list:
                    idx_label = TREATMENT_LABELS.get(treat, None)
                    if idx_label is not None:
                        label_vec[idx_label] = 1
                # If no treatments, mark 'No Treatment'.
                if not any(label_vec):
                    label_vec[TREATMENT_LABELS["No Treatment"]] = 1
                label_tensor = torch.tensor(label_vec, dtype=torch.float)

                self.samples.append({
                    "input_image_path": input_img_path,
                    "target_image_path": target_img_path,
                    "prompt": prompt,
                    "earlier_features": earlier_tensor,
                    "later_features": later_tensor,
                    "treatment_label": label_tensor
                })

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load both images for sample `idx`; unreadable images become blank 224x224 RGB images."""
        sample = self.samples[idx]
        try:
            input_image = Image.open(sample["input_image_path"]).convert("RGB")
        except Exception as e:
            print(f"Error loading {sample['input_image_path']}: {e}")
            input_image = Image.new("RGB", (224, 224))
        try:
            target_image = Image.open(sample["target_image_path"]).convert("RGB")
        except Exception as e:
            print(f"Error loading {sample['target_image_path']}: {e}")
            target_image = Image.new("RGB", (224, 224))
        if self.image_transform:
            input_image = self.image_transform(input_image)
            target_image = self.image_transform(target_image)
        return {
            "input_image": input_image,
            "target_image": target_image,
            "prompt": sample["prompt"],
            "earlier_features": sample["earlier_features"],
            "later_features": sample["later_features"],
            "treatment_label": sample["treatment_label"]
        }

class TemporalKneeFeatureConditioningDataset(Dataset):
    """
    Dataset including history sequences and delta_t.
    Pre-computes samples in __init__.

    Per-subject history is read from the module-level _FULL_DATA_DICT and
    _TIMEPOINT_PREFIXES; the pairs CSV supplies the earlier/later image
    paths, the two timepoints and the later feature columns.
    """
    def __init__(self, csv_file, image_transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]), include_images=False):
        """
        Args:
            csv_file (str): Path to the pairs CSV file.
            image_transform (callable, optional): A complete torchvision transform
                to apply to images (input, target, and sequence if include_images=True).
                Should include ToTensor and Resize if needed. Defaults to basic ToTensor+Resize.
            include_images (bool, optional): Whether to load and return historical image sequences. Defaults to False.
        """
        self.pairs_df = pd.read_csv(csv_file)
        self.include_images = include_images
        self.image_transform = image_transform
        self.samples = self._create_samples()

    def _build_covariates(self, row, t):
        """
        Build the covariate tensor of length FEATURE_DIM for timepoint `t`.

        Order: for each X_RAY_GRADES entry the L then the R value (columns
        "{t}_{feat}_{side}"), then CLINICAL_INFO, then DEMOGRAPHIC_INFO.
        Missing values are encoded as -1.0.
        """
        # NOTE(review): unlike build_earlier_feature_vector, this interleaves
        # L/R per grade and does not divide AGE by 100 - confirm that
        # consumers of "earlier_features" expect this layout.
        feats = []
        for feat in X_RAY_GRADES:
            for side in ("L", "R"):
                raw = row.get(f"{t}_{feat}_{side}", -1)
                # The full dataset stores missing values as NaN. Passing NaN
                # on would bin a missing JSN_Medial as grade 3.0.
                if pd.isna(raw):
                    raw = -1
                feats.append(encode_feature(feat, raw))

        for feat in CLINICAL_INFO:
            # Prefer the timepoint-specific column, then the un-prefixed one.
            raw = row.get(f"{t}_{feat}", row.get(feat, -1))
            feats.append(float(raw) if pd.notna(raw) else -1.0)

        for feat in DEMOGRAPHIC_INFO:
            raw = row.get(feat, -1)
            val = -1.0
            if pd.notna(raw):
                if feat.lower() == "sex":
                    # m/male -> 0.0, f/female -> 1.0, anything else -> -1.0
                    token = str(raw).strip().lower()
                    if token in ("m", "male"):
                        val = 0.0
                    elif token in ("f", "female"):
                        val = 1.0
                    else:
                        val = -1.0
                else:
                    try:
                        val = float(raw)
                    except (ValueError, TypeError):
                        val = -1.0
            feats.append(val)

        assert len(feats) == FEATURE_DIM
        cov = torch.tensor(feats, dtype=torch.float)
        return torch.nan_to_num(cov, nan=-1.0, posinf=0.0, neginf=0.0)

    def _build_treatment_vector(self, row, t):
        """
        Build the cumulative multi-hot treatment tensor for timepoint `t`.

        A treatment is set when any column "{tp}_{code}" with tp <= t holds a
        positive value. "No Treatment" is set only when nothing else is.
        """
        vec = [0] * NUM_TREATMENTS
        for tp in _TIMEPOINT_PREFIXES:
            if tp > t:
                break  # prefixes are sorted ascending, so nothing later qualifies
            for code, name in KNEE_TREATMENTS_OF_INTEREST.items():
                if name in TREATMENT_LABELS:
                    raw = row.get(f"{tp}_{code}", -1)
                    try:
                        val = float(raw)
                    except Exception:
                        val = -1
                    # NaN (missing) compares False here and is ignored.
                    if val > 0:
                        vec[TREATMENT_LABELS[name]] = 1
        if "No Treatment" in TREATMENT_LABELS:
            no_treat_idx = TREATMENT_LABELS["No Treatment"]
            # Set No Treatment only if no OTHER treatments are active
            other_treatments_active = any(vec[i] for i in range(NUM_TREATMENTS) if i != no_treat_idx)
            vec[no_treat_idx] = 0.0 if other_treatments_active else 1.0
        return torch.tensor(vec, dtype=torch.float)

    def _get_history_sequences(self, full_row, tps_history):
        """Builds covariate and treatment history sequences (one stacked row per timepoint), or (None, None) if empty."""
        cov_seq_full, trt_seq_full = [], []
        for t in tps_history:
            cov_seq_full.append(self._build_covariates(full_row, t))
            trt_seq_full.append(self._build_treatment_vector(full_row, t))
        if not cov_seq_full:  # Should not happen if tps_history is valid
            return None, None
        return torch.stack(cov_seq_full), torch.stack(trt_seq_full)

    def _get_interval_labels_and_delta_t(self, full_row, earlier_tp, later_tp):
        """
        Calculates interval labels and delta_t.

        The label is multi-hot over treatments with a positive value at any
        timepoint t with earlier_tp < t <= later_tp ("No Treatment" when
        there is none); delta_t is later_tp - earlier_tp as a float.
        """
        interval_tps = [t for t in _TIMEPOINT_PREFIXES if earlier_tp < t <= later_tp]
        interval_label = [0] * NUM_TREATMENTS
        for t in interval_tps:
            for code, name in KNEE_TREATMENTS_OF_INTEREST.items():
                try:
                    val = float(full_row.get(f"{t}_{code}", -1))
                except (ValueError, TypeError):
                    val = -1
                if val > 0 and name in TREATMENT_LABELS:
                    interval_label[TREATMENT_LABELS[name]] = 1

        no_treat_idx = TREATMENT_LABELS.get("No Treatment")
        other_treatments_active = any(interval_label[i] for i in range(NUM_TREATMENTS) if i != no_treat_idx)
        interval_label[no_treat_idx] = 0 if other_treatments_active else 1
        interval_label_tensor = torch.tensor(interval_label, dtype=torch.float)
        delta_t = float(later_tp - earlier_tp)
        return interval_label_tensor, delta_t

    def _get_image_sequence(self, full_row, tps_history, side):
        """
        Loads and transforms historical image sequence (if include_images is True).

        Images come from the "{t}_BILATERAL_X_RAY" column, cropped for `side`.
        A missing or unreadable image is replaced by a zero tensor.
        Returns None when include_images is False or the history is empty.
        """
        if not self.include_images:
            return None

        img_seq_side = []
        img_transform = self.image_transform

        for t in tps_history:
            orig_img_path = full_row.get(f"{t}_BILATERAL_X_RAY", None)
            cropped_img_path = get_crop_path(orig_img_path, side) if isinstance(orig_img_path, str) else None
            try:
                if cropped_img_path and os.path.exists(cropped_img_path):
                    img = Image.open(cropped_img_path).convert("RGB")
                    img = img_transform(img)
                else:
                    # Placeholder shape is hard-coded; it assumes the transform
                    # yields 3x224x224 tensors, as the default transform does.
                    img = torch.zeros(3, 224, 224)
            except Exception as e:
                print(f"Error loading/transforming hist img {cropped_img_path}: {e}. Placeholder.")
                img = torch.zeros(3, 224, 224)
            img_seq_side.append(img)

        return torch.stack(img_seq_side) if img_seq_side else None

    def _create_samples(self):
        """
        Iterates through pairs dataframe and pre-computes sample data.

        A (row, side) pair is emitted only when the subject has a full
        record, at least one history timepoint exists, at least one later
        feature is valid, and both cropped images exist on disk.
        """
        samples_list = []
        print(f"Processing {len(self.pairs_df)} rows from input CSV...")
        for _, row in tqdm(self.pairs_df.iterrows(), total=len(self.pairs_df), desc="Generating Samples"):
            subj = str(row.get("src_subject_id"))
            earlier_tp = int(row["month_earlier"])
            later_tp = int(row["month_later"])
            full_row = _FULL_DATA_DICT.get(subj)
            if full_row is None:
                print(f"No full row for {subj}")
                continue

            tps_history = [t for t in _TIMEPOINT_PREFIXES if t <= earlier_tp]
            if not tps_history:
                print(f"No tps_history for {subj}")
                continue

            # Get history sequences
            cov_seq_tensor, trt_seq_tensor = self._get_history_sequences(full_row, tps_history)
            if cov_seq_tensor is None:
                print(f"No cov_seq_tensor for {subj}")
                continue
            # The most recent history step doubles as the "earlier" feature vector.
            earlier_features = cov_seq_tensor[-1]

            # Get interval labels and delta_t
            interval_label_tensor, delta_t = self._get_interval_labels_and_delta_t(full_row, earlier_tp, later_tp)

            # Generate prompt
            month_diff = later_tp - earlier_tp
            interval_treatment_names = [
                name for name, label_idx in TREATMENT_LABELS.items()
                if interval_label_tensor[label_idx] == 1 and name != "No Treatment"
            ]
            prompt_treatment_str = ", ".join(interval_treatment_names) if interval_treatment_names else "No Treatment"
            prompt = f"A knee X-ray {month_diff} months after {prompt_treatment_str}."

            earlier_src = row["image_earlier"]
            later_src = row["image_later"]

            # Process sides
            for side in ["L", "R"]:
                later_feats_vec = build_later_feature_vector(row, side, features_list=PERCEPTUAL_FEATURES)
                if later_feats_vec is None or all(v == -1 for v in later_feats_vec):
                    # Happens often; deliberately not logged.
                    continue
                later_tensor = torch.tensor(later_feats_vec, dtype=torch.float)

                # __getitem__ opens both crops without a fallback, so drop
                # pairs whose files are unavailable here, before the costly
                # image-sequence load.
                if not (isinstance(earlier_src, str) and isinstance(later_src, str)):
                    continue
                input_img_path = get_crop_path(earlier_src, side)
                target_img_path = get_crop_path(later_src, side)
                if not (os.path.exists(input_img_path) and os.path.exists(target_img_path)):
                    continue

                image_seq_tensor = None
                if self.include_images:
                    image_seq_tensor = self._get_image_sequence(full_row, tps_history, side)
                    if image_seq_tensor is None:
                        print(f"No image_seq_tensor for {subj}")
                        continue

                samples_list.append({
                    "input_image_path": input_img_path, "target_image_path": target_img_path,
                    "prompt": prompt, "later_features": later_tensor,
                    "interval_labels": interval_label_tensor, "delta_t": delta_t,
                    "earlier_features": earlier_features, "cov_seq": cov_seq_tensor,
                    "trt_seq": trt_seq_tensor, "image_seq": image_seq_tensor,
                    "subject_id": subj, "side": side, "earlier_tp": earlier_tp, "later_tp": later_tp
                })
        print(f"Generated {len(samples_list)} samples.")
        return samples_list

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Loads images on demand for the pre-computed sample."""
        sample = self.samples[idx]

        img_transform = self.image_transform
        input_image = Image.open(sample["input_image_path"]).convert("RGB")
        input_image = img_transform(input_image)

        target_image = Image.open(sample["target_image_path"]).convert("RGB")
        target_image = img_transform(target_image)
        return {
            "input_image": input_image,
            "target_image": target_image,
            "prompt": sample["prompt"],
            "earlier_features": sample["earlier_features"],
            "later_features": sample["later_features"],
            "cov_seq": sample["cov_seq"],
            "trt_seq": sample["trt_seq"],
            "interval_labels": sample["interval_labels"],
            "delta_t": sample["delta_t"],
            "subject_id": sample["subject_id"],
            "side": sample["side"],
            "earlier_tp": sample["earlier_tp"],
            "later_tp": sample["later_tp"],
            "image_seq": sample.get("image_seq", None)
        }
