{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb7302df",
   "metadata": {},
   "outputs": [],
   "source": [
    "from calflops import calculate_flops\n",
    "from model import InterpretableResnet2, InterpretableViT, CBM, ViTConceptModel\n",
    "from processing.utils import get_info_from_lattice\n",
    "from argparse import Namespace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d4e0fdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = Namespace()\n",
    "args.dataset = \"awa2\"  # \"awa2\" / \"cifar100\" / \"inet100\"\n",
    "args.lattice_levels = [2, 1]  # Change with lattice levels\n",
    "args.backbone_layer_ids = [3, 4]  # Change with backbone positions\n",
    "args.pretrained_clfs = False\n",
    "\n",
    "# Per-dataset settings, copied onto `args` below; add a dataset by adding an entry.\n",
    "# Note: '././DATA' is the same directory as './DATA' (resolved against the kernel's cwd).\n",
    "DATASET_CONFIGS = {\n",
    "    \"awa2\": dict(\n",
    "        data_path='././DATA/Animals_with_Attributes2',\n",
    "        concept_file='././DATA/concepts/awa2_concepts.json',\n",
    "        lattice_path='././DATA/lattices/awa2_context.pkl',\n",
    "        num_classes=50,\n",
    "        num_attrs=85,\n",
    "        backbone='resnet18',\n",
    "    ),\n",
    "    \"inet100\": dict(\n",
    "        data_path='././DATA/inet100',\n",
    "        concept_file='././DATA/concepts/inet100_concepts.json',\n",
    "        lattice_path='././DATA/lattices/inet100_context.pkl',\n",
    "        num_classes=100,\n",
    "        num_attrs=700,\n",
    "        backbone='resnet50',\n",
    "    ),\n",
    "    \"cifar100\": dict(\n",
    "        data_path='././DATA/cifar100',\n",
    "        concept_file='././DATA/concepts/cifar100_concepts.json',\n",
    "        lattice_path='././DATA/lattices/cifar100_context.pkl',\n",
    "        num_classes=100,\n",
    "        num_attrs=700,\n",
    "        backbone='resnet50',\n",
    "    ),\n",
    "}\n",
    "\n",
    "# Fail fast: an unrecognised dataset name would otherwise leave data_path, num_classes,\n",
    "# backbone, ... unset and only surface later as an AttributeError.\n",
    "if args.dataset not in DATASET_CONFIGS:\n",
    "    raise ValueError(\n",
    "        f\"Unknown dataset {args.dataset!r}; expected one of {sorted(DATASET_CONFIGS)}\"\n",
    "    )\n",
    "vars(args).update(DATASET_CONFIGS[args.dataset])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d330eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load per-level lattice information for the levels in args.lattice_levels.\n",
    "# Only the lattice-based models (InterpretableResnet2 / InterpretableViT) consume these,\n",
    "# as intent_list / fc_list; the CBM and ViTConceptModel baselines do not use them.\n",
    "perlevel_intents, perlevel_fcs = get_info_from_lattice(args.lattice_path, args.lattice_levels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63808f92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Architecture to profile: \"interpretable_resnet\" / \"interpretable_vit\" / \"cbm\" / \"vit_cbm\".\n",
    "# The ViT variants presumably need a ViT name in args.backbone (e.g. 'vit_base_patch16_224').\n",
    "model_type = \"cbm\"\n",
    "\n",
    "# Builders are wrapped in lambdas so that only the selected model is instantiated.\n",
    "MODEL_BUILDERS = {\n",
    "    \"interpretable_resnet\": lambda: InterpretableResnet2(\n",
    "        intent_list=perlevel_intents,\n",
    "        fc_list=perlevel_fcs,\n",
    "        backbone_layer_ids=args.backbone_layer_ids,\n",
    "        num_classes=args.num_classes,\n",
    "        backbone_name=args.backbone,\n",
    "    ),\n",
    "    \"interpretable_vit\": lambda: InterpretableViT(\n",
    "        intent_list=perlevel_intents,\n",
    "        fc_list=perlevel_fcs,\n",
    "        backbone_layer_ids=args.backbone_layer_ids,\n",
    "        num_classes=args.num_classes,\n",
    "        model_name=args.backbone,\n",
    "    ),\n",
    "    \"cbm\": lambda: CBM(\n",
    "        model_name=args.backbone,\n",
    "        num_classes=args.num_classes,\n",
    "        num_attrs=args.num_attrs,\n",
    "    ),\n",
    "    \"vit_cbm\": lambda: ViTConceptModel(\n",
    "        model_name=args.backbone,\n",
    "        num_classes=args.num_classes,\n",
    "        num_concepts=args.num_attrs,\n",
    "    ),\n",
    "}\n",
    "\n",
    "if model_type not in MODEL_BUILDERS:\n",
    "    raise ValueError(\n",
    "        f\"Unknown model_type {model_type!r}; expected one of {sorted(MODEL_BUILDERS)}\"\n",
    "    )\n",
    "model = MODEL_BUILDERS[model_type]()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf21f783",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Profile `model` on one 3x224x224 input (batch of 1).\n",
    "batch_size = 1\n",
    "input_shape = (batch_size, 3, 224, 224)\n",
    "flops, macs, params = calculate_flops(model=model,\n",
    "                                      input_shape=input_shape,\n",
    "                                      output_as_string=True,\n",
    "                                      output_precision=4)\n",
    "\n",
    "# Label with the profiled class rather than a fixed name, so the line stays correct\n",
    "# whichever architecture was built.\n",
    "model_label = type(model).__name__\n",
    "print(\"%s FLOPs:%s   MACs:%s   Params:%s \\n\" % (model_label, flops, macs, params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdda525b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to a ViT backbone and rebuild the model. Reassigning args.backbone alone does\n",
    "# not touch the already-constructed `model`, so the next cell would silently re-profile\n",
    "# the previous architecture.\n",
    "# NOTE(review): ViTConceptModel is assumed to be the intended ViT concept-bottleneck\n",
    "# counterpart of CBM; confirm against model.py.\n",
    "args.backbone = 'vit_base_patch16_224'\n",
    "model = ViTConceptModel(\n",
    "    model_name=args.backbone,\n",
    "    num_classes=args.num_classes,\n",
    "    num_concepts=args.num_attrs,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5894948e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Profile `model` again after the backbone switch, on the same 3x224x224 input\n",
    "# (batch of 1).\n",
    "batch_size = 1\n",
    "input_shape = (batch_size, 3, 224, 224)\n",
    "flops, macs, params = calculate_flops(model=model,\n",
    "                                      input_shape=input_shape,\n",
    "                                      output_as_string=True,\n",
    "                                      output_precision=4)\n",
    "\n",
    "# Label with the profiled class rather than a fixed name, so the line reflects\n",
    "# the model that was actually measured.\n",
    "model_label = type(model).__name__\n",
    "print(\"%s FLOPs:%s   MACs:%s   Params:%s \\n\" % (model_label, flops, macs, params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2884f185",
   "metadata": {},
   "outputs": [],
   "source": [
    "from cem.models.cem import ConceptEmbeddingModel\n",
    "from cem.train.utils import wrap_pretrained_model\n",
    "from torchvision.models import resnet18, resnet50\n",
    "from calflops import calculate_flops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ce2c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CEM baseline sized from the configuration cell (args.num_attrs / args.num_classes),\n",
    "# so its numbers correspond to the selected dataset like the models profiled above.\n",
    "# The concept extractor follows the per-dataset ResNet of the configuration cell; it is\n",
    "# keyed on args.dataset because args.backbone has been switched to a ViT earlier on.\n",
    "cem_backbones = {\"awa2\": resnet18, \"inet100\": resnet50, \"cifar100\": resnet50}\n",
    "cem_model = ConceptEmbeddingModel(\n",
    "  n_concepts=args.num_attrs, # Number of training-time concepts\n",
    "  n_tasks=args.num_classes, # Number of output labels\n",
    "  emb_size=16,\n",
    "  concept_loss_weight=0.1,\n",
    "  learning_rate=1e-3,\n",
    "  optimizer=\"adam\",\n",
    "  c_extractor_arch=wrap_pretrained_model(cem_backbones[args.dataset]),\n",
    "  training_intervention_prob=0.25, # RandInt probability\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d67dc283",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Profile the CEM baseline on the same 3x224x224 input (batch of 1) as the other models.\n",
    "batch_size = 1\n",
    "input_shape = (batch_size, 3, 224, 224)\n",
    "flops, macs, params = calculate_flops(model=cem_model,\n",
    "                                      input_shape=input_shape,\n",
    "                                      output_as_string=True,\n",
    "                                      output_precision=4)\n",
    "\n",
    "print(\"CEM FLOPs:%s   MACs:%s   Params:%s \\n\" % (flops, macs, params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f24bdc27",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fca4nn",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
