{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "import sys\nimport os\n# Add the project root to the path\nproject_root = os.path.abspath(os.path.join(os.getcwd(), \"../..\"))\nsys.path.append(project_root)\nsys.path.append(os.path.join(project_root, \"src\"))\nsys.path.append(os.path.join(project_root, \"src/data_ncb\"))\n\nfrom src.config.config_dataclass import TrainerConfig, DatasetConfig, ModelConfig, BooleanConfig, FeatureSelectorConfig\nfrom tests.feature_analysis.feature_analysis import FeatureAnalyzer"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Create configurations: These values should be the same as the ones used in training\ntrainer_config = TrainerConfig(\n    epochs=1,\n    approach=\"sfw\",  # or \"sfw\"\n    seed=42,\n    lr=0,\n)\n\ndataset_config = DatasetConfig(\n    data_dir=\"path/to/your/datasets/clevr-hans/images/confounded/CLEVR_Hans3_4\",\n    batch_size=32,\n    num_workers=4,\n)\n\nmodel_config = ModelConfig(\n    model=\"settransformer\",  # or \"linear\", \"mlp\"\n    pretrained_model=True,\n    pretrained_path=\"path/to/your/merlinarthur-ncb-results/checkpoints/sfw/confounded/sfw_SetTransformer_on_one_hot_padded_seed42_bright-rain-3348/best_model.pth\",\n    n_heads=4,\n    set_transf_hidden=128,\n    #hidden_dim=256, # MLP only\n    #dropout=0.1 # MLP only\n)\n\nbool_config = BooleanConfig()\n\nfeature_selector_config = FeatureSelectorConfig(\n    mask_size=6, \n    lr_fs=0.1, # SFW only\n    l1_penalty_coefficient=5, # SFW only\n    sfw_max_iterations=350, # SFW only\n    sfw_patience=10, # SFW only\n    #fs_model=\"settransformer\",  # or \"mlp\"\n    #fs_hidden_dim=256,\n    #fs_dropout=0.3,\n    #fs_n_heads=4,\n)"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "# Initialize the analyzer\nanalyzer = FeatureAnalyzer(\n    trainer_config=trainer_config,\n    dataset_config=dataset_config,\n    model_config=model_config,\n    bool_config=bool_config,\n    feature_selector_config=feature_selector_config,\n    num_slots=4,\n    num_blocks=16,\n    input_dim=4224,\n    image_folder='path/to/your/models/ncb/trainedmodels_NCBrepo/CLEVR-4/retbind_seed_0/clustered_exemplars',\n)\n"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "analyzer.setup_ncb()\nanalyzer.setup_data()\nanalyzer.setup_model()"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "bs_encs, cs_encs, fnames, labels, preds_merlin, preds_morgana, features_merlin, features_morgana = analyzer.feature_analysis(split='val')"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": "analyzer.plot_results(0) # first image in the batch"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}