# pretrain_propensity_model_temporal.py (v4 - using delta_t)

import os
import argparse
from functools import partial
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score
import timm
import torchvision.transforms as T


from data.pairs_dataset.dataset import TemporalKneeFeatureConditioningDataset
# Import other constants needed
from data.pairs_dataset.dataset import (
        KNEE_TREATMENTS_OF_INTEREST, TREATMENT_LABELS, NUM_TREATMENTS, FEATURE_DIM
)
from PIL import Image

# Treatment names in the iteration order of TREATMENT_LABELS. The rest of this
# file indexes this list by class index (TREATMENT_NAMES[i]).
# NOTE(review): presumably the key order matches the column order of the
# `interval_labels` vectors -- confirm against data.pairs_dataset.dataset.
TREATMENT_NAMES = list(TREATMENT_LABELS.keys())

# --- Helper functions (compute_distribution, log_distribution - same as before) ---
def compute_distribution(dataset, label_key="interval_labels"):
    """Count positive labels per treatment class and derive their prevalence.

    Args:
        dataset: Object that may expose a ``samples`` iterable of mappings.
            Each mapping that contains ``label_key`` must hold a tensor that
            can be indexed by class index after ``.cpu().numpy()``.
        label_key: Key of the multi-hot label vector inside each sample.

    Returns:
        ``(counter, prevalence)``: two dicts keyed by class index
        ``0..NUM_TREATMENTS-1``. ``counter`` holds the number of samples whose
        label for that class equals 1; ``prevalence`` is that count divided by
        the number of samples containing ``label_key``. Prevalence is all
        zeros when the dataset has no ``samples`` attribute or no sample
        carries the key.
    """
    class_ids = range(NUM_TREATMENTS)
    counts = {cls: 0 for cls in class_ids}
    zero_prevalence = {cls: 0.0 for cls in class_ids}

    if not hasattr(dataset, 'samples'):
        return counts, zero_prevalence

    n_labelled = 0
    for record in dataset.samples:
        if label_key not in record:
            continue
        labels = record[label_key].cpu().numpy()
        for cls in class_ids:
            if labels[cls] == 1:
                counts[cls] += 1
        n_labelled += 1

    if n_labelled == 0:
        return counts, zero_prevalence

    prevalence = {cls: hits / n_labelled for cls, hits in counts.items()}
    dataset_len = len(dataset)
    print(f"Dist based on {n_labelled} samples w/ '{label_key}' out of {dataset_len}.")
    return counts, prevalence

def log_distribution(log_file, split_name, dataset, label_key="interval_labels"):
    """Append the per-class label distribution of ``dataset`` to ``log_file``.

    Args:
        log_file: Path of the text log; opened in append mode.
        split_name: Human-readable split name written into the header line.
        dataset: Dataset handed to ``compute_distribution``. It may lack a
            ``samples`` attribute, in which case "No samples." is logged.
        label_key: Key of the multi-hot label vector inside each sample.
    """
    counter, prevalence = compute_distribution(dataset, label_key)

    # compute_distribution tolerates datasets without `.samples`; do the same
    # here instead of raising AttributeError while writing the log.
    samples = getattr(dataset, "samples", [])
    num_samples_with_key = sum(1 for s in samples if label_key in s)

    with open(log_file, "a") as f:
        f.write(f"\nDist for {split_name} set ('{label_key}'):\n")
        f.write(f"  Samples analyzed: {num_samples_with_key}\n")
        if num_samples_with_key == 0:
            f.write("  No samples.\n")
            return
        for i in range(NUM_TREATMENTS):
            # Fall back to a placeholder if there are fewer names than classes.
            name = TREATMENT_NAMES[i] if i < len(TREATMENT_NAMES) else f"Unk_{i}"
            cnt = counter.get(i, 0)
            prev = prevalence.get(i, 0.0)
            f.write(f"  Cls {i} ({name}): {cnt}, prev {prev:.4f}\n")

def get_normalization_params(dataset, key="delta_t"):
    """Compute mean and standard deviation of a scalar sample field.

    Args:
        dataset: Object that may expose a ``samples`` iterable of mappings.
        key: Name of the scalar field (e.g. ``"delta_t"``). Samples that lack
            the key, or whose value cannot be converted with ``float()``, are
            skipped.

    Returns:
        ``(mean, std)`` as Python floats. Falls back to ``(0.0, 1.0)`` when no
        usable value exists. ``std`` is replaced by ``1.0`` whenever it cannot
        be used as a divisor: fewer than two values, a (near-)zero spread, or
        a NaN result.
    """
    values = []
    if not hasattr(dataset, 'samples'):
        return 0.0, 1.0
    for sample in dataset.samples:
        if key in sample:
            try:
                values.append(float(sample[key]))
            except (ValueError, TypeError):
                continue
    if not values:
        return 0.0, 1.0

    values_tensor = torch.tensor(values, dtype=torch.float)
    mean = values_tensor.mean().item()

    # The sample (Bessel-corrected) std of a single value is NaN, and NaN
    # slips through a plain `std < 1e-6` test, which would turn every
    # normalized delta_t into NaN. Handle that case explicitly.
    if values_tensor.numel() < 2:
        return mean, 1.0

    std = values_tensor.std().item()
    # `not (std >= eps)` is True for tiny values and for NaN alike.
    if not std >= 1e-6:
        std = 1.0
    return mean, std

def collate_fn_temporal_conditioning(batch, include_images=False):
    """Merge a list of sample dicts into one padded batch dict.

    Args:
        batch: List of sample mappings. Each must provide ``cov_seq`` and
            ``trt_seq`` (tensors whose first dimension is the history length),
            ``interval_labels`` (tensor), ``delta_t`` (scalar), ``side``
            (string), ``subject_id``, ``earlier_tp`` and ``later_tp``.
        include_images: When True, ``image_seq`` tensors are padded and
            returned, but only if every sample carries a non-None one.

    Returns:
        Dict with zero-padded ``cov_seq`` / ``trt_seq`` (batch first),
        ``lengths`` (long tensor of history lengths taken from ``cov_seq``),
        ``image_seq`` (padded tensor or None), ``delta_t`` (float, [B, 1]),
        ``interval_labels`` (stacked), ``side`` (float, [B, 1]; 1.0 when the
        side string starts with "R" case-insensitively, else 0.0) and the
        pass-through lists ``subject_id``, ``earlier_tp``, ``later_tp``.
    """
    cov_list = [item["cov_seq"] for item in batch]
    trt_list = [item["trt_seq"] for item in batch]

    history_lengths = torch.tensor([cov.size(0) for cov in cov_list], dtype=torch.long)
    cov_padded = pad_sequence(cov_list, batch_first=True, padding_value=0.0)
    trt_padded = pad_sequence(trt_list, batch_first=True, padding_value=0.0)

    labels = torch.stack([item["interval_labels"] for item in batch])

    # Scalar time gap per sample, shaped [B, 1].
    gaps = [item["delta_t"] for item in batch]
    delta_t = torch.tensor(gaps, dtype=torch.float)[:, None]

    # Knee side as a numeric column [B, 1]: 0 = Left, 1 = Right.
    side_codes = []
    for item in batch:
        is_right = item["side"].upper().startswith("R")
        side_codes.append(1.0 if is_right else 0.0)
    side_tensor = torch.tensor(side_codes, dtype=torch.float)[:, None]

    # Images are only batched when every sample actually has them.
    padded_images = None
    if include_images:
        every_sample_has_images = all(
            "image_seq" in item and item["image_seq"] is not None for item in batch
        )
        if every_sample_has_images:
            padded_images = pad_sequence(
                [item["image_seq"] for item in batch],
                batch_first=True,
                padding_value=0.0,
            )

    collated = {
        "cov_seq": cov_padded,
        "trt_seq": trt_padded,
        "lengths": history_lengths,
        "image_seq": padded_images,
        "delta_t": delta_t,
        "interval_labels": labels,
        "subject_id": [item["subject_id"] for item in batch],
        "side": side_tensor,
        "earlier_tp": [item["earlier_tp"] for item in batch],
        "later_tp": [item["later_tp"] for item in batch],
    }
    return collated
class TemporalPropensityModel(nn.Module):
    """Multi-label treatment propensity model over a visit history.

    The history (covariates, past treatments and optionally per-visit image
    features) is encoded with an LSTM or a Transformer encoder. The encoding
    of the last valid step is concatenated with small learned embeddings of
    the normalized ``delta_t`` and of the knee side, and a linear head maps
    the result to ``NUM_TREATMENTS`` logits.

    Args:
        cov_dim: Width of one covariate vector.
        hidden_dim: LSTM hidden size, or Transformer feed-forward width.
        num_layers: Number of encoder layers.
        dropout: Dropout probability inside the encoder (ignored by a
            single-layer LSTM).
        model_type: ``"rnn"`` or ``"transformer"`` (case-insensitive).
        include_images: Whether a timm image encoder is built and used.
        image_model_name: timm model name for the image encoder.
        img_feat_dim: Width the image features are projected to.
        delta_t_feat_dim: Width of the processed ``delta_t`` feature.

    Raises:
        ValueError: For an unsupported ``model_type`` or when the image
            encoder's feature width cannot be determined.
    """

    def __init__(self, cov_dim, hidden_dim, num_layers, dropout,
                 model_type="rnn", include_images=False,
                 image_model_name="efficientnet_b0", img_feat_dim=128,
                 delta_t_feat_dim=16):
        super().__init__()
        self.include_images = include_images
        self.model_type = model_type
        self.img_feat_dim = img_feat_dim
        self.cov_dim = cov_dim
        self.num_treatments = NUM_TREATMENTS
        self.delta_t_feat_dim = delta_t_feat_dim

        # Per-step input: covariates + treatment vector (+ projected image features).
        seq_input_dim = self.cov_dim + self.num_treatments
        if self.include_images:
            print(f"Loading img encoder: {image_model_name}")
            self.img_encoder = timm.create_model(image_model_name, pretrained=True, num_classes=0)
            encoder_feat_dim = self._infer_img_feat_dim(self.img_encoder, image_model_name)
            self.fc_img = nn.Linear(encoder_feat_dim, self.img_feat_dim)
            seq_input_dim += self.img_feat_dim
        print(f"Seq input dim: {seq_input_dim}")

        # Sequence encoder.
        model_kind = model_type.lower()
        if model_kind == "rnn":
            print("Using RNN (LSTM)")
            self.encoder = nn.LSTM(
                seq_input_dim, hidden_dim, num_layers, batch_first=True,
                dropout=dropout if num_layers > 1 else 0,
            )
            encoder_output_dim = hidden_dim
        elif model_kind == "transformer":
            print("Using Transformer")
            nhead = 3
            if seq_input_dim % nhead != 0:
                # Multi-head attention requires d_model % nhead == 0; merely
                # warning would fail at layer construction. Fall back to the
                # largest smaller head count that divides the input width.
                fallback = next(h for h in (2, 1) if seq_input_dim % h == 0)
                print(f"Warn: seq_dim {seq_input_dim} not div by nhead {nhead}. Using nhead={fallback}.")
                nhead = fallback
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=seq_input_dim, nhead=nhead, dim_feedforward=hidden_dim,
                dropout=dropout, batch_first=True, activation='relu',
            )
            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
            encoder_output_dim = seq_input_dim
        else:
            raise ValueError(f"Unsupported type: {model_type}")
        print(f"Seq encoder output dim: {encoder_output_dim}")

        # Small MLP that lifts the scalar (normalized) delta_t to a feature vector.
        self.delta_t_processor = nn.Sequential(
            nn.Linear(1, self.delta_t_feat_dim),
            nn.ReLU()
        )
        print(f"Delta_t feature dim: {self.delta_t_feat_dim}")

        # Same idea for the knee side (0 = Left, 1 = Right).
        self.side_feat_dim = 4
        self.side_processor = nn.Sequential(
            nn.Linear(1, self.side_feat_dim),
            nn.ReLU()
        )
        print(f"Side feature dim: {self.side_feat_dim}")

        # Output head over [history encoding | delta_t features | side features].
        final_head_input_dim = encoder_output_dim + self.delta_t_feat_dim + self.side_feat_dim
        self.treatment_head = nn.Linear(final_head_input_dim, self.num_treatments)
        print(f"Final head input dim: {final_head_input_dim}")
        print(f"Output dim: {self.num_treatments}")

    @staticmethod
    def _infer_img_feat_dim(encoder, model_name):
        """Return the width of the vector the image encoder outputs.

        ``num_features`` is tried first: per timm's feature-extraction docs it
        is the pooled feature width of a ``num_classes=0`` model, whereas the
        last ``feature_info`` entry describes an intermediate feature map and
        can be narrower. The remaining lookups are kept as fallbacks.
        NOTE(review): a few timm architectures may emit a different width
        than ``num_features`` -- verify for new backbones.
        """
        num_features = getattr(encoder, "num_features", None)
        if isinstance(num_features, int) and num_features > 0:
            return num_features
        try:
            return encoder.feature_info.info[-1]['num_chs']
        except (AttributeError, IndexError, KeyError, TypeError):
            pass
        for head_name in ("classifier", "fc"):
            head = getattr(encoder, head_name, None)
            if hasattr(head, "in_features"):
                return head.in_features
        raise ValueError(f"Cannot get feat dim for {model_name}")

    def forward(self, cov_seq, trt_seq, lengths, delta_t, image_seq=None, side=None):
        """Compute treatment logits for a batch of padded histories.

        Args:
            cov_seq: Padded covariates, ``(B, T, cov_dim)``.
            trt_seq: Padded treatment vectors, ``(B, T, num_treatments)``.
            lengths: Valid history length per sample, ``(B,)``; may live on
                CPU, it is moved to the right device where needed.
            delta_t: Already-normalized time gap, ``(B, 1)``.
            image_seq: Padded images ``(B, T, C, H, W)``; required when the
                model was built with ``include_images=True``.
            side: Knee side indicator ``(B, 1)``; required despite the
                ``None`` default kept for signature compatibility.

        Returns:
            Logits of shape ``(B, num_treatments)``.

        Raises:
            ValueError: If ``side`` is missing, if images are expected but
                ``image_seq`` is None, or if ``model_type`` is unsupported.
        """
        if side is None:
            raise ValueError("`side` is required: pass a [B, 1] tensor (0=Left, 1=Right).")

        combined_seq = torch.cat([cov_seq, trt_seq], dim=2)

        if self.include_images:
            if image_seq is None:
                # The encoder input width includes image features, so going on
                # without them can only fail later with a shape error.
                raise ValueError("Model was built with include_images=True but image_seq is None.")
            B, T = image_seq.shape[:2]
            # `lengths` is commonly left on CPU by callers; align devices before comparing.
            mask = torch.arange(T, device=image_seq.device)[None, :] < lengths.to(image_seq.device)[:, None]
            imgs_flat = image_seq[mask]
            if imgs_flat.numel() > 0:
                # Track gradients through the image encoder only when it is in
                # train mode AND the caller has not disabled grad globally.
                with torch.set_grad_enabled(torch.is_grad_enabled() and self.img_encoder.training):
                    img_feats_flat = self.img_encoder(imgs_flat)
                img_feats_proj = F.relu(self.fc_img(img_feats_flat))
                img_feats = torch.zeros(B, T, self.img_feat_dim, device=combined_seq.device, dtype=img_feats_proj.dtype)
                img_feats[mask] = img_feats_proj
            else:
                # No valid image step: keep the feature width consistent with zeros.
                img_feats = torch.zeros(B, T, self.img_feat_dim, device=combined_seq.device, dtype=combined_seq.dtype)
            combined_seq = torch.cat([combined_seq, img_feats], dim=2)

        # History encoding: take the representation of the last valid step.
        model_kind = self.model_type.lower()
        if model_kind == "rnn":
            packed_input = pack_padded_sequence(combined_seq, lengths.cpu(), batch_first=True, enforce_sorted=False)
            try:
                self.encoder.flatten_parameters()
            except AttributeError:
                pass
            _, (h_n, _) = self.encoder(packed_input)
            encoder_output = h_n[-1]  # final hidden state of the top layer
        elif model_kind == "transformer":
            device = combined_seq.device
            lengths_dev = lengths.to(device)
            max_len = combined_seq.size(1)
            # True marks padded positions that attention must ignore.
            src_key_padding_mask = torch.arange(max_len, device=device)[None, :] >= lengths_dev[:, None]
            transformer_output = self.encoder(combined_seq, src_key_padding_mask=src_key_padding_mask)
            last_step_indices = (lengths_dev - 1).clamp(min=0)
            batch_indices = torch.arange(combined_seq.size(0), device=device)
            encoder_output = transformer_output[batch_indices, last_step_indices]
        else:
            raise ValueError(f"Unsupported type: {self.model_type}")

        delta_t_features = self.delta_t_processor(delta_t)  # (B, delta_t_feat_dim)
        side_features = self.side_processor(side)  # (B, side_feat_dim)

        combined_final_features = torch.cat([encoder_output, delta_t_features, side_features], dim=1)
        pred_trts_logits = self.treatment_head(combined_final_features)

        return pred_trts_logits

# --- EVALUATION FUNCTION (Using delta_t) ---
def evaluate(model, loader, device, criterion_trt, delta_t_mean, delta_t_std):
    """Run the model over ``loader`` in eval mode and compute metrics.

    Args:
        model: Model exposing ``include_images`` and callable as
            ``model(cov_seq, trt_seq, lengths, delta_t_norm, image_seq, side)``.
        loader: Iterable of batch dicts as produced by the collate function.
        device: Device the batch tensors are moved to.
        criterion_trt: Loss applied to ``(logits, interval_labels)``.
        delta_t_mean: Mean used to normalize ``delta_t``.
        delta_t_std: Standard deviation used to normalize ``delta_t``.

    Returns:
        Dict with ``loss`` / ``loss_trt`` (sample-weighted mean loss),
        ``accuracy`` (element-wise over all samples and classes),
        ``total_recall`` (micro recall over all classes), ``auc`` (mean of the
        per-class ROC AUCs for classes with both label values present) and the
        per-class ``per_class_accuracy`` / ``per_class_recall`` dicts keyed by
        treatment name. Predictions use a 0.5 probability threshold. When no
        sample was seen, all scalars are 0.0 and the per-class dicts are empty.
    """
    model.eval()
    total_loss_trt = 0.0
    num_samples = 0
    all_preds_trt, all_labels_trt, all_probs_trt = [], [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            cov_seq = batch["cov_seq"].to(device)
            trt_seq = batch["trt_seq"].to(device)
            lengths = batch["lengths"]
            delta_t = batch["delta_t"].to(device)
            image_seq = batch["image_seq"].to(device) if model.include_images and batch["image_seq"] is not None else None
            target_trts = batch["interval_labels"].to(device)
            side = batch["side"].to(device)

            delta_t_norm = (delta_t - delta_t_mean) / delta_t_std

            pred_trts_logits = model(cov_seq, trt_seq, lengths, delta_t_norm, image_seq, side)

            # Count the samples actually evaluated: with drop_last or a
            # sampler the loader does not yield len(loader.dataset) items, so
            # dividing by the dataset size would bias the average loss.
            batch_size = cov_seq.size(0)
            loss_trt = criterion_trt(pred_trts_logits, target_trts)
            total_loss_trt += loss_trt.item() * batch_size
            num_samples += batch_size

            probs = torch.sigmoid(pred_trts_logits).cpu().numpy()
            preds = (probs >= 0.5).astype(float)
            all_probs_trt.append(probs)
            all_preds_trt.append(preds)
            all_labels_trt.append(target_trts.cpu().numpy())

    if num_samples == 0:
        # Same keys as the regular result so callers can index unconditionally.
        return {"loss": 0.0, "loss_trt": 0.0, "accuracy": 0.0, "total_recall": 0.0, "auc": 0.0,
                "per_class_accuracy": {}, "per_class_recall": {}}

    avg_loss_trt = total_loss_trt / num_samples
    all_preds_trt = np.concatenate(all_preds_trt, axis=0)
    all_labels_trt = np.concatenate(all_labels_trt, axis=0)
    all_probs_trt = np.concatenate(all_probs_trt, axis=0)

    # Micro recall: pooled over every (sample, class) pair.
    preds_flat = all_preds_trt.flatten()
    labels_flat = all_labels_trt.flatten()
    total_tp = ((preds_flat == 1) & (labels_flat == 1)).sum()
    total_fn = ((preds_flat == 0) & (labels_flat == 1)).sum()
    total_recall = float(total_tp / (total_tp + total_fn)) if (total_tp + total_fn) > 0 else 0.0

    correct = all_preds_trt == all_labels_trt
    acc = correct.mean()
    per_class_acc = correct.mean(axis=0)

    per_class_recall = {}
    aucs = []
    for i in range(NUM_TREATMENTS):
        tp_i = ((all_preds_trt[:, i] == 1) & (all_labels_trt[:, i] == 1)).sum()
        fn_i = ((all_preds_trt[:, i] == 0) & (all_labels_trt[:, i] == 1)).sum()
        per_class_recall[i] = float(round(tp_i / (tp_i + fn_i), 3)) if (tp_i + fn_i) > 0 else 0.0
        y_true = all_labels_trt[:, i]
        y_score = all_probs_trt[:, i]
        # AUC is undefined when only one label value is present for a class.
        if len(np.unique(y_true)) > 1:
            try:
                aucs.append(roc_auc_score(y_true, y_score))
            except ValueError:
                pass  # best effort: skip classes sklearn still rejects

    macro_auc = float(round(np.mean(aucs), 3)) if aucs else 0.0

    metrics = {
        "loss": avg_loss_trt,
        "loss_trt": avg_loss_trt,
        # Element-wise accuracy counts true negatives too, so it can be high
        # on sparse labels even when recall is low.
        "accuracy": float(round(acc, 3)),
        "total_recall": float(round(total_recall, 3)),
        "auc": macro_auc,
        "per_class_accuracy": {TREATMENT_NAMES[i]: float(round(acc_i, 3)) for i, acc_i in enumerate(per_class_acc)},
        "per_class_recall": {TREATMENT_NAMES[i]: rec for i, rec in per_class_recall.items()},
    }
    return metrics

# --- MAIN TRAINING FUNCTION ---
def main(args):
    """Train the temporal propensity model and write logs and checkpoints.

    Loads ``train.csv`` / ``val.csv`` from ``args.pairs_dir``, derives the
    delta_t normalization and the BCE ``pos_weight`` from the training set,
    trains with AdamW and ReduceLROnPlateau on validation macro AUC, saves the
    best and the final weights into ``args.checkpoint_dir`` and appends
    progress to ``args.log_file``.

    Args:
        args: Namespace with the attributes defined by this script's argument
            parser (``pairs_dir``, ``model_type``, ``hidden_dim``,
            ``num_layers``, ``dropout``, ``include_images``, ``image_model``,
            ``img_feat_dim``, ``delta_t_feat_dim``, ``epochs``, ``batch_size``,
            ``lr``, ``weight_decay``, ``grad_clip``, ``pos_weight_max_clamp``,
            ``use_sampler``, ``no_treatment_down_weight``,
            ``scheduler_patience``, ``early_stopping_patience``,
            ``num_workers``, ``log_file``, ``checkpoint_dir``).

    Raises:
        ValueError: If the training set is empty or the covariate dimension
            cannot be read from its first sample.
    """
    # dirname is "" for a bare file name, and os.makedirs("") raises.
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    # Harmless if the caller already created it; needed for programmatic use.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    with open(args.log_file, "w") as f:
        f.write(f"Temporal Propensity Model (Delta_t)\nStart: {pd.Timestamp.now()}\nArgs: {vars(args)}\n")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    image_transform = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor()
    ])

    # Dataset loading
    train_csv = os.path.join(args.pairs_dir, "train.csv")
    val_csv = os.path.join(args.pairs_dir, "val.csv")
    print("Loading Train DS...")
    train_ds = TemporalKneeFeatureConditioningDataset(csv_file=train_csv, image_transform=image_transform, include_images=args.include_images)
    print(f"Train size: {len(train_ds)}")
    print("Loading Val DS...")
    val_ds = TemporalKneeFeatureConditioningDataset(csv_file=val_csv, image_transform=image_transform, include_images=args.include_images)
    print(f"Val size: {len(val_ds)}")

    # Fail early with a clear message, before the sampler/loaders trip over it.
    if len(train_ds) == 0:
        raise ValueError("Train DS empty!")

    log_distribution(args.log_file, "Train", train_ds)
    log_distribution(args.log_file, "Validation", val_ds)

    # delta_t normalization statistics come from the training split only.
    print("Calculating delta_t norm params...")
    delta_t_mean, delta_t_std = get_normalization_params(train_ds, key="delta_t")
    print(f"  Mean={delta_t_mean:.4f}, Std={delta_t_std:.4f}")
    with open(args.log_file, "a") as f:
        f.write(f"Delta_t Norm: Mean={delta_t_mean:.4f}, Std={delta_t_std:.4f}\n")

    # Optional sampler that down-weights samples whose only label is "No Treatment".
    sampler = None
    if args.use_sampler:
        print("Calculating sampler weights...")
        no_t_idx = TREATMENT_NAMES.index("No Treatment")
        weights = []
        n_down = 0
        for s in train_ds.samples:
            labels = s["interval_labels"]
            # bool() keeps this a plain Python bool; `tensor and bool` could
            # otherwise yield a tensor and turn n_down into one.
            is_nt = bool(labels[no_t_idx] == 1) and int(labels.sum().item()) == 1
            weights.append(args.no_treatment_down_weight if is_nt else 1.0)
            n_down += int(is_nt)
        if len(weights) == len(train_ds):
            print(f"Down-weighting {n_down} NT samples.")
            sampler = WeightedRandomSampler(weights, len(weights), True)
        else:
            print("Warn: Sampler failed.")

    # DataLoaders; pinning host memory only helps when copying to a GPU.
    collate = partial(collate_fn_temporal_conditioning, include_images=args.include_images)
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=sampler, shuffle=sampler is None, collate_fn=collate, num_workers=args.num_workers, pin_memory=pin_memory, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=collate, num_workers=args.num_workers, pin_memory=pin_memory)

    try:
        cov_dim = train_ds[0]["cov_seq"].size(-1)
    except (IndexError, KeyError) as err:
        raise ValueError("Cannot get cov_dim.") from err
    print(f"Covariate dim: {cov_dim}")

    model = TemporalPropensityModel(
        cov_dim=cov_dim, hidden_dim=args.hidden_dim, num_layers=args.num_layers, dropout=args.dropout,
        model_type=args.model_type, include_images=args.include_images, image_model_name=args.image_model,
        img_feat_dim=args.img_feat_dim, delta_t_feat_dim=args.delta_t_feat_dim
    ).to(device)

    # pos_weight = (#negatives / #positives) per class, clamped to [1, max].
    # Accumulate on CPU and move once instead of one device copy per sample.
    print("Calculating pos_weight...")
    pos_counts = torch.zeros(NUM_TREATMENTS)
    n_train = 0
    for s in train_ds.samples:
        if "interval_labels" in s:
            pos_counts += s["interval_labels"].detach().cpu()
            n_train += 1
    if n_train > 0:
        neg = n_train - pos_counts
        safe_pos = torch.clamp(pos_counts, min=1e-6)
        pos_w = torch.clamp(neg / safe_pos, min=1.0, max=args.pos_weight_max_clamp).to(device)
        print(f"pos_w: {pos_w.cpu().numpy()}")
        with open(args.log_file, "a") as f:
            f.write(f"pos_w: {pos_w.cpu().numpy().tolist()}\n")
    else:
        print("Warn: Using default pos_w.")
        pos_w = torch.ones(NUM_TREATMENTS, device=device)
    criterion_trt = nn.BCEWithLogitsLoss(pos_weight=pos_w)

    # Optimizer & scheduler. The scheduler's `verbose` flag is deprecated in
    # recent PyTorch releases, so LR changes are reported manually below.
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=args.scheduler_patience)

    best_val_auc = -1.0
    epochs_no_improve = 0
    for epoch in range(1, args.epochs + 1):
        model.train()
        epoch_loss = 0.0
        n_seen = 0
        prog_bar = tqdm(train_loader, desc=f"Epoch {epoch} Train", leave=False)
        for batch in prog_bar:
            cov = batch["cov_seq"].to(device)
            trt = batch["trt_seq"].to(device)
            lens = batch["lengths"]
            delta_t = batch["delta_t"].to(device)
            targets = batch["interval_labels"].to(device)
            imgs = batch["image_seq"].to(device) if args.include_images and batch["image_seq"] is not None else None
            side = batch["side"].to(device)

            delta_t_norm = (delta_t - delta_t_mean) / delta_t_std

            optimizer.zero_grad()
            logits = model(cov, trt, lens, delta_t_norm, imgs, side)
            loss = criterion_trt(logits, targets)

            if torch.isnan(loss):
                print(f"Warn: NaN loss ep {epoch} batch {prog_bar.n}. Skip.")
                continue
            loss.backward()
            if args.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()
            epoch_loss += loss.item() * cov.size(0)
            n_seen += cov.size(0)
            prog_bar.set_postfix(loss=loss.item())

        # Average over the samples that actually contributed a step
        # (drop_last and skipped NaN batches make this differ from len(dataset)).
        avg_loss = epoch_loss / n_seen if n_seen > 0 else 0.0

        # NOTE: train metrics are computed through train_loader, i.e. under the
        # same sampler / drop_last setting used for training.
        print(f"\n--- Evaluating Epoch {epoch} ---")
        train_metrics = evaluate(model, train_loader, device, criterion_trt, delta_t_mean, delta_t_std)
        val_metrics = evaluate(model, val_loader, device, criterion_trt, delta_t_mean, delta_t_std)

        val_recall_fmt = {k: f'{v:.3f}' for k, v in val_metrics['per_class_recall'].items()}
        log = (
            f"Epoch {epoch}/{args.epochs}:\n"
            f"  Trn (running) | Loss={avg_loss:.4f}\n"
            f"  Trn | Loss={train_metrics['loss_trt']:.4f}, Acc={train_metrics['accuracy']:.3f}, AUC={train_metrics['auc']:.3f}, TotRecall={train_metrics['total_recall']:.3f}\n"
            f"  Val | Loss={val_metrics['loss_trt']:.4f}, Acc={val_metrics['accuracy']:.3f}, AUC={val_metrics['auc']:.3f}, TotRecall={val_metrics['total_recall']:.3f}\n"
            f"  Val Per-Class Recall: {val_recall_fmt}\n"
        )
        print(log)
        with open(args.log_file, "a") as f:
            f.write(log)

        cur_auc = val_metrics['auc']
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(cur_auc)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after != lr_before:
            lr_msg = f"LR reduced: {lr_before:.3e} -> {lr_after:.3e}"
            print(lr_msg)
            with open(args.log_file, "a") as f:
                f.write(lr_msg + "\n")

        if cur_auc > best_val_auc:
            best_val_auc = cur_auc
            epochs_no_improve = 0
            print(f"*** Best model (AUC: {best_val_auc:.3f}) ***")
            save_pth = os.path.join(args.checkpoint_dir, "propensity_model_temporal_best.pth")
            # Written to disk immediately, so no in-memory copy of the weights is kept.
            torch.save(model.state_dict(), save_pth)
            with open(args.log_file, "a") as f:
                f.write(f"*** Saved best model ep {epoch} to {save_pth} ***\n")
        else:
            epochs_no_improve += 1
            print(f"Val AUC no improve {epochs_no_improve} epochs.")

        if args.early_stopping_patience > 0 and epochs_no_improve >= args.early_stopping_patience:
            print("Stopping early.")
            with open(args.log_file, "a") as f:
                f.write(f"Early stop ep {epoch}.\n")
            break

    # Save the weights of the last epoch regardless of validation performance.
    final_pth = os.path.join(args.checkpoint_dir, "propensity_model_temporal_final.pth")
    torch.save(model.state_dict(), final_pth)
    print(f"Saved final model to {final_pth}")
    with open(args.log_file, "a") as f:
        f.write(f"Saved final model: {final_pth}\nEnd: {pd.Timestamp.now()}\n")

if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword options), registered
    # in exactly this order.
    cli_options = [
        # Data
        ("--pairs_dir", dict(type=str, default="data/pairs_dataset", help="Dir with train.csv/val.csv")),
        # Model
        ("--model_type", dict(type=str, choices=["rnn", "transformer"], default="transformer")),
        ("--hidden_dim", dict(type=int, default=128)),
        ("--num_layers", dict(type=int, default=2)),
        ("--dropout", dict(type=float, default=0.2)),
        ("--include_images", dict(action="store_true")),
        ("--image_model", dict(type=str, default="efficientformerv2_s0.snap_dist_in1k")),
        ("--img_feat_dim", dict(type=int, default=128)),
        ("--delta_t_feat_dim", dict(type=int, default=32, help="Dimension for processed delta_t feature")),
        # Training
        ("--epochs", dict(type=int, default=100)),
        ("--batch_size", dict(type=int, default=64)),
        ("--lr", dict(type=float, default=1e-4)),
        ("--weight_decay", dict(type=float, default=1e-2)),
        ("--grad_clip", dict(type=float, default=100.0)),
        ("--pos_weight_max_clamp", dict(type=float, default=500.0)),
        ("--use_sampler", dict(action="store_true")),
        ("--no_treatment_down_weight", dict(type=float, default=0.3)),
        ("--scheduler_patience", dict(type=int, default=5)),
        ("--early_stopping_patience", dict(type=int, default=12)),
        # System
        ("--num_workers", dict(type=int, default=8)),
        ("--log_file", dict(type=str, default="checkpoints/Propensity_Model_Temporal/train_propensity.log")),
        ("--checkpoint_dir", dict(type=str, default="checkpoints/Propensity_Model_Temporal")),
    ]

    arg_parser = argparse.ArgumentParser(description="Train Temporal Propensity Model (using delta_t)")
    for flag, options in cli_options:
        arg_parser.add_argument(flag, **options)

    args = arg_parser.parse_args()
    # Checkpoints are written during training, so the directory must exist first.
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    main(args)