# SKCD: Semiparametric Kernel Conditional Distributional Treatment Effect Test

Code for reproducing the experiments in the paper: the MNIST simulation (Section 5.1) and the 401(k) real data analysis (Section 5.2).

## Requirements

```
numpy
pandas
torch
scikit-learn
lightgbm
optuna
scipy
matplotlib
```

Install via: `pip install numpy pandas torch scikit-learn lightgbm optuna scipy matplotlib`
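
To confirm the environment before launching a long run, a small check like the one below can help. It is a sketch, not part of the repository: `missing_packages` is a hypothetical helper, and the only non-obvious detail is that `scikit-learn` is imported under the name `sklearn`.

```python
from importlib.util import find_spec

# pip name -> import name for the packages listed above.
# scikit-learn is the only one whose import name differs.
REQUIRED = {
    "numpy": "numpy",
    "pandas": "pandas",
    "torch": "torch",
    "scikit-learn": "sklearn",
    "lightgbm": "lightgbm",
    "optuna": "optuna",
    "scipy": "scipy",
    "matplotlib": "matplotlib",
}


def missing_packages(required=REQUIRED):
    """Return the pip names of packages whose import name cannot be found."""
    return [pip_name for pip_name, module in required.items()
            if find_spec(module) is None]


missing = missing_packages()
if missing:
    print("Missing packages: " + ", ".join(missing))
else:
    print("All requirements found.")
```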

## Repository Structure

```
├── run_simulation.py      # Main entry point for MNIST simulation experiments
├── config.py              # Configuration class for simulation parameters
├── proposed_test.py       # Proposed SKCD-MMD and SKCD-Wald tests
├── codite_mmd_test.py     # Baseline KCD test (Park et al., 2021)
├── utils.py               # Kernels, nuisance estimation, matrix computations
├── witness_analysis.py    # Witness function analysis (401(k) real data)
├── pension.csv            # SIPP 401(k) dataset
└── sim_data/              # Pre-processed MNIST embeddings
    ├── null_unedited.npy          # Covariate embeddings X
    ├── null_bright_rand.npy       # Outcome embeddings Y (null)
    ├── alternate_bright_rand_rot_det.npy  # Outcome embeddings Y (alternative)
    ├── trt_assign.npy             # Treatment assignments A
    └── propensities.npy           # True propensity scores
```
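
For inspecting the pre-processed data outside the simulation script, the files in `sim_data/` can be read directly with NumPy. The sketch below uses the file names from the tree above; the helper `load_sim_data`, the dictionary keys, and the scenario labels are assumptions made for illustration, and the array shapes are not documented here, so the snippet only prints them.

```python
from pathlib import Path

import numpy as np

# Outcome file per scenario; file names come from the repository tree,
# the scenario keys mirror the --scenarios values (an assumption).
OUTCOME_FILES = {
    "null": "null_bright_rand.npy",
    "alternate": "alternate_bright_rand_rot_det.npy",
}


def load_sim_data(data_dir="sim_data", scenario="null"):
    """Load covariates X, outcomes Y, treatments A, and propensity scores."""
    data_dir = Path(data_dir)
    return {
        "X": np.load(data_dir / "null_unedited.npy"),
        "Y": np.load(data_dir / OUTCOME_FILES[scenario]),
        "A": np.load(data_dir / "trt_assign.npy"),
        "propensity": np.load(data_dir / "propensities.npy"),
    }


# Only runs when the data directory is present in the working directory.
if Path("sim_data").is_dir():
    for name, arr in load_sim_data(scenario="alternate").items():
        print(name, arr.shape, arr.dtype)
```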

## Running Experiments

### MNIST Simulation (Section 5.1)

```bash
# Full experiment (all configurations)
python run_simulation.py

# Specific configuration
python run_simulation.py \
    --test_types mmd wald \
    --scenarios null alternate \
    --sample_sizes 500 1000 \
    --n_replicates 100 \
    --n_bootstrap 1000

# Dry run (check configuration without running)
python run_simulation.py --dry_run
```

Key arguments:
- `--test_types`: `mmd`, `wald`, `codite` (baseline)
- `--scenarios`: `null`, `alternate`
- `--sample_sizes`: e.g., `250 500 750 1000`
- `--misspec_combinations`: comma-separated pairs in (propensity, outcome) order, e.g., `"False,False" "True,False"`
- `--n_replicates`: Monte Carlo replicates per configuration
- `--n_bootstrap`: Bootstrap samples for p-value computation
- `--use_oracle_propensity`: Use true propensity scores

Results are saved to `simulation_results_<timestamp>/`.
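
To make the shape of these arguments concrete, here is an illustrative `argparse` sketch that accepts the same flags. It is not the parser in `run_simulation.py`: the defaults are placeholders, and `parse_misspec` and `build_parser` are hypothetical helpers showing one way a `"True,False"` string could be turned into a (propensity, outcome) pair of booleans.

```python
import argparse


def parse_misspec(pair):
    """Turn a string such as "True,False" into (True, False)."""
    parts = [p.strip().lower() for p in pair.split(",")]
    if len(parts) != 2 or any(p not in ("true", "false") for p in parts):
        raise argparse.ArgumentTypeError(
            f"expected a pair like 'True,False', got {pair!r}")
    return parts[0] == "true", parts[1] == "true"


def build_parser():
    # Defaults below are placeholders, not the repository's defaults.
    p = argparse.ArgumentParser(description="Illustrative parser only.")
    p.add_argument("--test_types", nargs="+",
                   choices=["mmd", "wald", "codite"], default=["mmd", "wald"])
    p.add_argument("--scenarios", nargs="+",
                   choices=["null", "alternate"], default=["null", "alternate"])
    p.add_argument("--sample_sizes", nargs="+", type=int, default=[500])
    p.add_argument("--misspec_combinations", nargs="+", type=parse_misspec,
                   default=[(False, False)])
    p.add_argument("--n_replicates", type=int, default=100)
    p.add_argument("--n_bootstrap", type=int, default=1000)
    p.add_argument("--use_oracle_propensity", action="store_true")
    p.add_argument("--dry_run", action="store_true")
    return p


args = build_parser().parse_args([
    "--test_types", "mmd", "wald",
    "--sample_sizes", "500", "1000",
    "--misspec_combinations", "False,False", "True,False",
])
print(args.sample_sizes, args.misspec_combinations)
```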

### 401(k) Real Data Analysis (Section 5.2)

```bash
python witness_analysis.py
```

Outputs:
- `witness_results_slice.pkl`: Computed witness functions and confidence bands
- `witness_functions_slice.png`: Visualization (Figure 3 in paper)

## Code Overview

**Proposed Tests** (`proposed_test.py`):
- `proposed_mmd_test()`: SKCD-MMD test with bootstrap inference
- `proposed_wald_test()`: SKCD-Wald test with regularized covariance
- `proposed_mmd_wald_test()`: Unified interface for both tests

**Baseline** (`codite_mmd_test.py`):
- `codite_mmd_test()`: KCD test with permutation inference
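
For readers unfamiliar with permutation inference on kernel statistics, the sketch below shows the idea on a plain, unconditional two-sample MMD test with a Gaussian kernel. It is deliberately simpler than the baseline in `codite_mmd_test.py`: it ignores covariates entirely, and every function name and parameter here is an assumption made for illustration.

```python
import numpy as np


def gaussian_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) Gram matrix between the rows of x and the rows of y."""
    sq = (np.sum(x ** 2, axis=1)[:, None]
          + np.sum(y ** 2, axis=1)[None, :]
          - 2.0 * x @ y.T)
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * bandwidth ** 2))


def mmd2_biased(K, idx_a, idx_b):
    """Biased (V-statistic) squared MMD read off a pooled Gram matrix."""
    return (K[np.ix_(idx_a, idx_a)].mean()
            + K[np.ix_(idx_b, idx_b)].mean()
            - 2.0 * K[np.ix_(idx_a, idx_b)].mean())


def mmd_permutation_test(y0, y1, n_permutations=200, bandwidth=1.0, seed=0):
    """Return the observed statistic and a permutation p-value."""
    rng = np.random.default_rng(seed)
    pooled = np.vstack([y0, y1])
    K = gaussian_kernel(pooled, pooled, bandwidth)  # computed once, reused
    n0, n = len(y0), len(pooled)
    idx = np.arange(n)
    observed = mmd2_biased(K, idx[:n0], idx[n0:])
    exceed = 0
    for _ in range(n_permutations):
        perm = rng.permutation(n)  # shuffle group labels
        exceed += int(mmd2_biased(K, perm[:n0], perm[n0:]) >= observed)
    # Add-one correction keeps the p-value strictly positive.
    return observed, (1 + exceed) / (1 + n_permutations)


rng = np.random.default_rng(1)
stat, p = mmd_permutation_test(rng.normal(0, 1, (40, 2)),
                               rng.normal(3, 1, (40, 2)))
print(f"MMD^2 = {stat:.3f}, p = {p:.3f}")
```

The Gram matrix is built once on the pooled sample, so each permutation only re-indexes it rather than recomputing kernels.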

**Fast Implementation** (`utils.py`):
- `proposed_fast_mmd_setup()`: Precompute MMD bootstrap operator (Algorithm 2, line 3)
- `proposed_fast_wald_setup()`: Precompute Wald operators and LU factorization (Algorithm 2, lines 5-6)
- `proposed_fast_wald_step()`: O(n²) bootstrap iteration (Algorithm 2, lines 15-16)
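
The reason precomputing an LU factorization gives an O(n²) bootstrap step is that the O(n³) factorization is paid once, while each later solve against the stored factors costs only O(n²). The sketch below illustrates this pattern with SciPy on a made-up regularized matrix and a made-up quadratic form; it does not reproduce the operators built by `proposed_fast_wald_setup()`, and `bootstrap_step` is a hypothetical name.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
n, n_bootstrap = 200, 50

# Hypothetical regularized, covariance-like matrix (illustration only).
B = rng.standard_normal((n, n))
M = B @ B.T / n + 1e-2 * np.eye(n)

# Setup: factor once, O(n^3).
lu_piv = lu_factor(M)


def bootstrap_step(lu_piv, v):
    """Quadratic form v^T M^{-1} v via the stored factors, O(n^2) per call."""
    return float(v @ lu_solve(lu_piv, v))


# Each bootstrap draw reuses the factorization instead of re-solving from scratch.
stats = np.array([bootstrap_step(lu_piv, rng.standard_normal(n))
                  for _ in range(n_bootstrap)])
print(stats.mean())
```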

## Hardware

Experiments were run on compute nodes with an NVIDIA T4 GPU and 32 GB of RAM. The code automatically falls back to CPU if CUDA is unavailable.
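
To see which device a run would use on a given machine, the usual PyTorch check is `torch.cuda.is_available()`. The snippet below is a minimal sketch of that fallback, not the repository's own device-selection code; `pick_device` is a hypothetical helper, and the guarded import is there only so the snippet also runs where PyTorch is not installed.

```python
def pick_device():
    """Return "cuda" when PyTorch can see a GPU, otherwise "cpu"."""
    try:
        import torch
    except ImportError:
        # Without PyTorch there is nothing to run on a GPU.
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Running on: {device}")
```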

## License

MIT
