# PRISM: Prototype-Regularized Information-Splitting Model

This repository contains the official PyTorch Lightning implementation of PRISM, a deep generative model designed for disentangled representation learning.

The primary goal of PRISM is to learn a latent space that is split into two distinct subspaces:
1.  An **identity subspace** ($z_1$) that captures the core, class-conditional factors of variation (e.g., the digit '8' in MNIST, a person's identity in a face dataset).
2.  A **residual subspace** ($z_0$) that captures all other "style" variations (e.g., handwriting style, lighting conditions, pose).

It achieves this through a combination of prototype-based regularization on the identity subspace and information-theoretic splitting of the latent code, enforced by an adversarial training scheme.
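
The README does not spell out PRISM's loss terms, so the following is only a generic illustration of what prototype-based regularization on $z_1$ can look like. It is not the model's actual objective. It splits each latent vector into $z_1$ and $z_0$, takes the per-class batch mean of $z_1$ as that class's prototype, and penalizes the squared Euclidean distance from each $z_1$ to its prototype. The layout (identity dimensions first), the batch-mean prototypes, and the squared distance are assumptions. Plain Python lists stand in for tensors so the sketch runs without PyTorch.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def split_latent(z: Sequence[float], identity_dim: int) -> Tuple[Vector, Vector]:
    """Split z into (z1, z0). ASSUMPTION: identity dimensions come first."""
    return list(z[:identity_dim]), list(z[identity_dim:])


def class_prototypes(z1_batch: Sequence[Vector], labels: Sequence[int]) -> Dict[int, Vector]:
    """Prototype of each class = mean z1 over the batch samples with that label."""
    sums: Dict[int, Vector] = {}
    counts: Dict[int, int] = defaultdict(int)
    for z1, y in zip(z1_batch, labels):
        if y not in sums:
            sums[y] = [0.0] * len(z1)
        sums[y] = [s + v for s, v in zip(sums[y], z1)]
        counts[y] += 1
    return {y: [s / counts[y] for s in total] for y, total in sums.items()}


def prototype_loss(z1_batch: Sequence[Vector], labels: Sequence[int]) -> float:
    """Mean squared Euclidean distance from each z1 to its class prototype."""
    prototypes = class_prototypes(z1_batch, labels)
    total = 0.0
    for z1, y in zip(z1_batch, labels):
        total += sum((a - b) ** 2 for a, b in zip(z1, prototypes[y]))
    return total / len(z1_batch)


if __name__ == "__main__":
    latents = [[0.0, 0.0, 0.3], [2.0, 0.0, -1.0], [5.0, 5.0, 0.7]]
    labels = [8, 8, 3]  # e.g. digit classes
    z1_batch = [split_latent(z, identity_dim=2)[0] for z in latents]
    print(prototype_loss(z1_batch, labels))  # prints 0.6666666666666666
```

Note that nothing in this sketch constrains $z_0$; in PRISM, keeping class information out of the residual subspace is the job of the information-splitting and adversarial components described above.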

## Requirements

The project is written in Python and built on PyTorch and PyTorch Lightning. All dependencies are specified in the `environment.yml` file.

1.  **Install Conda:** If you don't have it, install [Miniconda](https://docs.conda.io/en/latest/miniconda.html) or Anaconda.

2.  **Create the Conda Environment:** Create and activate the environment from the project root directory.
    ```bash
    conda env create -f environment.yml
    conda activate torchenv
    ```
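
Once the environment is active, a quick way to confirm that the core packages resolve is `importlib.util.find_spec`, which locates a module without importing it. The package list below is an assumption made for illustration: `environment.yml` is the source of truth, and depending on the version pinned there PyTorch Lightning may be importable as `pytorch_lightning` or as `lightning`, so either name is accepted.

```python
import importlib.util
from typing import Iterable, List

# ASSUMPTION: illustrative dependency list; check environment.yml for the real one.
REQUIRED = ["torch", "yaml"]
LIGHTNING_ALIASES = ["pytorch_lightning", "lightning"]


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def check_environment() -> List[str]:
    """List missing packages; an empty list means everything was found."""
    problems = missing_modules(REQUIRED)
    if len(missing_modules(LIGHTNING_ALIASES)) == len(LIGHTNING_ALIASES):
        problems.append("pytorch_lightning (or lightning)")
    return problems


if __name__ == "__main__":
    problems = check_environment()
    print("Missing: " + ", ".join(problems) if problems else "Environment looks OK.")
```

If anything is reported missing, the usual cause is that `torchenv` is not the active environment.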

## Dataset Setup

The model is configured to run on several datasets, including Morpho-MNIST and CelebA.

1.  **Download the Datasets:**
    *   **Morpho-MNIST:** Download the "global" variant from the [official repository](https://github.com/dccastro/Morpho-MNIST).
    *   **CelebA:** Download from the [official website](https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html).

2.  **Configure Paths:**
    **IMPORTANT:** Before running any training, you must update the dataset paths in the configuration files to point to their locations on your local machine.

    For example, in `configs/base.yaml`, you will find entries like this:
    ```yaml
    morpho-mnist:
      dir: "<path_to_datasets>/morpho-mnist-global" # <-- CHANGE THIS
    celeba:
      dir: "<path_to_datasets>/CelebA"             # <-- CHANGE THIS
    ```
    Replace the value of `dir` with the correct path. The key configuration files to check are:
    *   `configs/base.yaml`
    *   `configs/celeba_config.yaml`
    *   `configs/baselines/fader_morphomnist.yaml`

3.  **Expected Directory Structure:** The code expects the following structure within your dataset directories:
    *   **Morpho-MNIST:**
        ```
        <your_path>/morpho-mnist-global/
        ├── train-images-idx3-ubyte.gz
        ├── train-labels-idx1-ubyte.gz
        ├── train-morpho.csv
        ├── t10k-images-idx3-ubyte.gz
        ├── t10k-labels-idx1-ubyte.gz
        └── t10k-morpho.csv
        ```
    *   **CelebA:**
        ```
        <your_path>/CelebA/
        ├── Img/
        │   ├── img_align_celeba/
        │   │   ├── 000001.jpg
        │   │   └── ...
        └── Anno/
            ├── list_attr_celeba.txt
            ├── list_eval_partition.txt
            └── ...
        ```
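
Both setup steps above can be checked programmatically before a long run. The helper below is a hypothetical sketch, not part of the repository. It walks a parsed config (for example, the result of PyYAML's `yaml.safe_load` on `configs/base.yaml`), flags any `dir` value that still contains the `<path_to_datasets>` placeholder or does not exist, and reports entries from the expected layouts above that are missing. It assumes the dataset keys match the excerpt (`morpho-mnist`, `celeba`) and only checks the files named in this README.

```python
import os
from collections.abc import Mapping
from typing import Any, Dict, List, Tuple

PLACEHOLDER = "<path_to_datasets>"

# Files and folders listed in the expected directory structures above.
EXPECTED_LAYOUT: Dict[str, List[str]] = {
    "morpho-mnist": [
        "train-images-idx3-ubyte.gz",
        "train-labels-idx1-ubyte.gz",
        "train-morpho.csv",
        "t10k-images-idx3-ubyte.gz",
        "t10k-labels-idx1-ubyte.gz",
        "t10k-morpho.csv",
    ],
    "celeba": [
        "Img/img_align_celeba",
        "Anno/list_attr_celeba.txt",
        "Anno/list_eval_partition.txt",
    ],
}


def find_dir_problems(node: Any, prefix: str = "") -> List[Tuple[str, str]]:
    """Recursively flag `dir` entries that are placeholders or missing on disk."""
    problems: List[Tuple[str, str]] = []
    if isinstance(node, Mapping):
        for key, value in node.items():
            path = f"{prefix}.{key}" if prefix else str(key)
            if key == "dir" and isinstance(value, str):
                if PLACEHOLDER in value:
                    problems.append((path, "placeholder not replaced"))
                elif not os.path.isdir(value):
                    problems.append((path, "directory does not exist"))
            else:
                problems.extend(find_dir_problems(value, path))
    return problems


def missing_layout_entries(dataset: str, root: str) -> List[str]:
    """Return the entries of EXPECTED_LAYOUT[dataset] that are absent under root."""
    return [
        rel for rel in EXPECTED_LAYOUT[dataset]
        if not os.path.exists(os.path.join(root, *rel.split("/")))
    ]
```

For example, `find_dir_problems(yaml.safe_load(open("configs/base.yaml")))` returns an empty list once every `dir` points at an existing folder, and `missing_layout_entries("celeba", path)` returns an empty list once that folder matches the CelebA tree.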

## Usage

The framework is controlled via the `main.py` script and YAML configuration files.
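
For orientation, the sketch below is a minimal `argparse` parser that accepts the commands and flags documented in this README (`train --config [--resume [VERSION]]`, `evaluate --run_dir`, and `visualize --run_dir`). It is a hypothetical reconstruction for illustration, not the repository's `main.py`, which may define additional options. Declaring `--resume` with `nargs="?"` and `const="latest"` is one way to let the flag work both on its own and followed by a version number.

```python
import argparse
from typing import List, Optional

POSTHOC_HELP = {
    "evaluate": "compute post-hoc metrics for a finished run",
    "visualize": "generate post-hoc visualizations for a finished run",
}


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="main.py")
    sub = parser.add_subparsers(dest="command", required=True)

    train = sub.add_parser("train", help="train a model from a YAML config")
    train.add_argument("--config", required=True, help="path to a YAML config file")
    # Bare `--resume` -> "latest"; `--resume 0` -> "0"; flag absent -> None.
    train.add_argument("--resume", nargs="?", const="latest", default=None)

    for name, text in POSTHOC_HELP.items():
        cmd = sub.add_parser(name, help=text)
        cmd.add_argument("--run_dir", required=True, help="e.g. runs/morpho/version_0")
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse argv (defaults to sys.argv[1:] when None)."""
    return build_parser().parse_args(argv)
```

Keeping `--resume` as a string rather than `type=int` avoids argparse trying to convert the `"latest"` constant to an integer; the caller can turn `"0"` into `0` itself.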

### Training the Main PRISM Model

To train the main model, use the `train` command and specify a configuration file.

*   **Train on Morpho-MNIST (default):**
    ```bash
    python main.py train --config configs/base.yaml
    ```

*   **Train on CelebA:**
    ```bash
    python main.py train --config configs/celeba_config.yaml
    ```

*   **Resume Training:** To resume the latest run for a given experiment, pass the `--resume` flag; add a version number after it to resume that specific run instead.
    ```bash
    # Resume the latest run defined in the config's sweep_name
    python main.py train --config configs/base.yaml --resume

    # Resume a specific version (e.g., version_0)
    python main.py train --config configs/base.yaml --resume 0
    ```
    Logs and checkpoints will be saved to the directory specified by `run.log_dir` in the config file (default is `../runs`).
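
The README does not describe how `--resume` locates a run. Assuming runs are stored in Lightning-style `version_N` subdirectories of the experiment directory (as in the `version_0` example above), the hypothetical helper below resolves a bare `--resume` to the highest version number and `--resume N` to `version_N`. Versions are compared numerically, so `version_10` correctly sorts after `version_2`, which a plain string sort would get wrong.

```python
import os
import re
from typing import Iterable, Optional

_VERSION_DIR = re.compile(r"version_(\d+)")


def parse_version(name: str) -> Optional[int]:
    """Return N for a directory named 'version_N', else None."""
    match = _VERSION_DIR.fullmatch(name)
    return int(match.group(1)) if match else None


def pick_version(names: Iterable[str], requested: Optional[int] = None) -> int:
    """Choose the newest version, or validate an explicitly requested one."""
    versions = sorted(v for v in map(parse_version, names) if v is not None)
    if not versions:
        raise FileNotFoundError("no version_N run directories found")
    if requested is None:
        return versions[-1]
    if requested not in versions:
        raise FileNotFoundError(f"version_{requested} not found; have {versions}")
    return requested


def resolve_resume_dir(experiment_dir: str, requested: Optional[int] = None) -> str:
    """Map a --resume request onto an existing run directory on disk."""
    names = [
        n for n in os.listdir(experiment_dir)
        if os.path.isdir(os.path.join(experiment_dir, n))
    ]
    return os.path.join(experiment_dir, f"version_{pick_version(names, requested)}")
```

Raising instead of silently starting a fresh run makes a mistyped version number fail loudly.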

### Post-Hoc Analysis

After a model has been trained, you can run the evaluation and visualization commands on its saved artifacts by pointing `--run_dir` at the run directory.

*   **Run Post-Hoc Metrics:**
    ```bash
    python main.py evaluate --run_dir path/to/your/runs/morpho/version_0
    ```

*   **Generate Post-Hoc Visualizations:**
    ```bash
    python main.py visualize --run_dir path/to/your/runs/morpho/version_0
    ```
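
When a sweep contains several versions, the two post-hoc commands above can be scripted. The sketch below is a hypothetical convenience wrapper around exactly those commands: for each run directory it calls `main.py evaluate` and then `main.py visualize` through `subprocess`, using `sys.executable` so the same interpreter (and thus the same Conda environment) is reused. It assumes it is launched from the project root so that `main.py` resolves.

```python
import subprocess
import sys
from typing import Iterable, List, Sequence

DEFAULT_STEPS = ("evaluate", "visualize")


def posthoc_commands(run_dir: str, steps: Sequence[str] = DEFAULT_STEPS) -> List[List[str]]:
    """Build the argv lists for each post-hoc step on one run directory."""
    return [[sys.executable, "main.py", step, "--run_dir", run_dir] for step in steps]


def run_posthoc(run_dirs: Iterable[str], steps: Sequence[str] = DEFAULT_STEPS) -> None:
    """Run every step on every run; stops at the first failing command."""
    for run_dir in run_dirs:
        for cmd in posthoc_commands(run_dir, steps):
            subprocess.run(cmd, check=True)
```

For example, `run_posthoc(sorted(glob.glob("path/to/your/runs/morpho/version_*")))` processes every version in a sweep. Because of `check=True`, a failing run raises `subprocess.CalledProcessError` instead of being skipped silently.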

### Running Baselines

The repository includes an implementation of the Fader Network baseline, which is trained through its own entry point (`user_extensions.baselines.main`) rather than `main.py`.

*   **Train Fader Network on Morpho-MNIST:**
    ```bash
    python -m user_extensions.baselines.main --config configs/baselines/fader_morphomnist.yaml
    ```

### Running Experiments (Reproducibility)

Scripts are provided to reproduce the hyperparameter optimization (HPO) and ablation studies.

*   **Run the Ablation Study:**
    This script runs an HPO search for each ablation condition, then retrains the best configuration found with several random seeds.
    ```bash
    python -m user_extensions.experiments.run_ablation \
      --study_config configs/experiment/additive_ablation.yaml \
      --hpo_config configs/experiment/hpo_loss.yaml \
      --num_gpus 2
    ```

*   **Run a Standalone HPO Study:**
    This script performs a broad HPO search for a single model configuration.
    ```bash
    python -m user_extensions.experiments.run_hpo \
      --study_config configs/experiment/hpo_architecture.yaml \
      --num_gpus 2
    ```
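
The ablation procedure described above (one HPO search per condition, then retraining the best configuration under several seeds) has roughly the following control flow. This is a schematic sketch with hypothetical names, not the code in `user_extensions.experiments.run_ablation`: `search` and `train` stand in for whatever the real scripts invoke and are passed in as callables, so the loop can be exercised without GPUs.

```python
from typing import Callable, Dict, List, Mapping, Sequence

Params = Dict[str, float]


def run_ablation(
    conditions: Sequence[str],
    search: Callable[[str], Params],              # HPO for one condition -> best params
    train: Callable[[str, Params, int], float],   # (condition, params, seed) -> metric
    seeds: Sequence[int] = (0, 1, 2),
) -> Dict[str, List[float]]:
    """Run HPO once per condition, then retrain the winning params for each seed."""
    results: Dict[str, List[float]] = {}
    for condition in conditions:
        best = search(condition)
        results[condition] = [train(condition, best, seed) for seed in seeds]
    return results


def summarize(results: Mapping[str, Sequence[float]]) -> Dict[str, float]:
    """Mean metric across seeds for each condition."""
    return {name: sum(vals) / len(vals) for name, vals in results.items()}
```

Tuning each condition separately means every ablated variant is compared at its own best hyperparameters, and repeating the final training over seeds shows how much of the difference is seed noise.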

## Codebase Structure

The codebase is organized into a core framework and project-specific extensions.

*   `prism/`: Contains the core, reusable components of the framework.
    *   `core/`: Base classes and registries for modularity.
    *   `systems/`: The main `PrismSystem` LightningModule.
    *   `models/`: Model definitions (Encoder, Generator, backbones, heads).
    *   `losses/`: Custom loss function implementations.
    *   `callbacks/`: Callbacks for logging, artifact saving, and visualizations.
    *   `evaluation/`: Metric calculation and visualization logic.
*   `user_extensions/`: Contains project-specific implementations.
    *   `datasets/`: Custom `LightningDataModule` for each dataset.
    *   `baselines/`: Code for baseline models (e.g., Fader Network).
    *   `experiments/`: Scripts for running ablation and HPO studies.
*   `configs/`: Contains all YAML configuration files for experiments.
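
`prism/core/` is described as holding registries, but their API is not documented here. The following is only a generic sketch of the registry pattern commonly used for this kind of modularity: components register themselves under a string name with a decorator, and a YAML config can then select a component by name. `Registry`, `register`, `build`, `BACKBONES`, and `MLPBackbone` are all hypothetical.

```python
from typing import Any, Callable, Dict, Generic, Type, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    """Maps string names (as written in YAML configs) to classes."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self._entries: Dict[str, Type[T]] = {}

    def register(self, name: str) -> Callable[[Type[T]], Type[T]]:
        """Class decorator that records `cls` under `name`."""
        def decorator(cls: Type[T]) -> Type[T]:
            if name in self._entries:
                raise KeyError(f"{self.kind} '{name}' is already registered")
            self._entries[name] = cls
            return cls
        return decorator

    def build(self, name: str, **kwargs: Any) -> T:
        """Instantiate the class registered under `name` with config kwargs."""
        try:
            cls = self._entries[name]
        except KeyError:
            known = ", ".join(sorted(self._entries))
            raise KeyError(f"unknown {self.kind} '{name}'; known: {known}") from None
        return cls(**kwargs)


BACKBONES: Registry[object] = Registry("backbone")


@BACKBONES.register("mlp")
class MLPBackbone:
    def __init__(self, hidden_dim: int = 128) -> None:
        self.hidden_dim = hidden_dim
```

A config entry such as `backbone: {name: mlp, hidden_dim: 256}` would then map to `BACKBONES.build("mlp", hidden_dim=256)`, so adding a new component only requires decorating its class, not editing the training code.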

## Citation

Citation information will be provided upon publication.