{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "from pytorch_lightning.utilities.seed import seed_everything\n",
    "import clustering_utils\n",
    "import data_utils\n",
    "import model_utils\n",
    "import persistence_utils\n",
    "import visualisation_utils\n",
    "import models\n",
    "import numpy as np\n",
    "import wandb\n",
    "\n",
    "\n",
    "from experiments_clevr_extra import (train_graph_class,\n",
    "                                     plot_samples, save_centroids,\n",
    "                                     print_near_example, test_retrieval,\n",
    "                                     test_missing_modality)\n",
    "visualisation_utils.set_rc_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Start the Weights & Biases run that the later wandb.log(...) cells report to.\n",
    "wandb.init()"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# --- reproducibility ---------------------------------------------------\n",
    "seed = 0\n",
    "seed_everything(seed)\n",
    "\n",
    "# --- experiment identity and output directory --------------------------\n",
    "DATASET_NAME = 'clevr'\n",
    "MODE = 'SHARCS'\n",
    "MODEL_NAME = f\"{DATASET_NAME}_{MODE}\"\n",
    "path = os.path.join(\"output\", DATASET_NAME, MODE, f\"seed_{seed}\")\n",
    "data_utils.create_path(path)\n",
    "\n",
    "# --- model / training settings -----------------------------------------\n",
    "NUM_CLASSES = 2\n",
    "TRAIN_TEST_SPLIT = 0.8\n",
    "VOCAB_DIM = 22\n",
    "CLUSTER_ENCODING_SIZE = 24\n",
    "EPOCHS = 10\n",
    "LR = 0.001\n",
    "BATCH_SIZE = 5\n",
    "LAYER_NUM = 0\n",
    "\n",
    "# --- compute device: CUDA when torch reports it as available, else CPU --\n",
    "dev = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "device = torch.device(dev)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# load data\n",
    "print(\"Reading dataset\")\n",
    "\n",
    "train_loader, train, test_loader, test = data_utils.create_clevr()\n",
    "\n",
    "# The 'full' loaders hand out batches holding 10% of each split; a later cell pulls a\n",
    "# single batch from each of them for the concept analysis.\n",
    "# Both batch sizes are clamped to at least 1: int(len(split) * 0.1) is 0 for a split\n",
    "# with fewer than 10 samples, and DataLoader rejects a non-positive batch_size.\n",
    "full_train_loader = DataLoader(train, batch_size=max(int(len(train) * 0.1), 1), shuffle=True)\n",
    "full_test_loader = DataLoader(test, batch_size=max(int(len(test) * 0.1), 1))\n",
    "\n",
    "print('Done!')"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Build the CLEVR SHARCS model and move it to the selected device.\n",
    "model = models.CLEVR_SHARCS(VOCAB_DIM, CLUSTER_ENCODING_SIZE, NUM_CLASSES)\n",
    "interpretable = True\n",
    "model.to(dev)\n",
    "\n",
    "# Keep a reference to the model as built, then let model_utils register its hooks.\n",
    "model_to_return = model\n",
    "model = model_utils.register_hooks(model)\n",
    "\n",
    "# Train, and unpack the four values train_graph_class returns.\n",
    "train_acc, test_acc, train_loss, test_loss = train_graph_class(\n",
    "    model, train_loader, test_loader, EPOCHS, LR,\n",
    "    if_interpretable_model=interpretable,\n",
    "    mode=MODE,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# additional experiments setup\n",
    "# Pull one batch from each 'full' loader, run it through the trained model and read\n",
    "# back the concept activations the model exposes as attributes after a forward pass.\n",
    "# The forward passes run under torch.no_grad(): nothing is back-propagated in this\n",
    "# cell, so recording the autograd graph for these large batches only costs memory.\n",
    "# NOTE(review): the model's train/eval mode is left untouched here - confirm whether\n",
    "# model.eval() is wanted before these passes.\n",
    "\n",
    "# ---- train batch ----\n",
    "(train_input_image, train_input_text, train_questions, train_img_index,\n",
    " train_y, _, _, train_txt_aux) = next(iter(full_train_loader))\n",
    "train_input_image = train_input_image.to(device)\n",
    "train_input_text = train_input_text.to(device)\n",
    "train_y = train_y.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    train_node_concepts, _, _, _ = model(train_input_text, train_input_image, train_y)\n",
    "# Read the attributes now: the test forward pass below is expected to replace them.\n",
    "train_graph_concepts = model.gnn_graph_shared_concepts.cpu()\n",
    "train_graph_local_concepts = model.gnn_graph_local_concepts.cpu()\n",
    "train_tab_local_concepts = model.x_tab_local_concepts.cpu()\n",
    "train_tab_concepts = model.tab_shared_concepts.cpu()\n",
    "\n",
    "# ---- test batch ----\n",
    "(test_input_image, test_input_text, test_questions, test_img_index,\n",
    " test_y, _, _, test_txt_aux) = next(iter(full_test_loader))\n",
    "test_input_image = test_input_image.to(device)\n",
    "test_input_text = test_input_text.to(device)\n",
    "test_y = test_y.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    test_node_concepts, _, _, _ = model(test_input_text, test_input_image, test_y)\n",
    "test_graph_concepts = model.gnn_graph_shared_concepts.cpu()\n",
    "test_graph_local_concepts = model.gnn_graph_local_concepts.cpu()\n",
    "test_tab_local_concepts = model.x_tab_local_concepts.cpu()\n",
    "test_tab_concepts = model.tab_shared_concepts.cpu()\n",
    "\n",
    "# ---- train rows first, test rows second, for every combined array below ----\n",
    "graph_concepts = torch.vstack([train_graph_concepts, test_graph_concepts])\n",
    "graph_local_concepts = torch.vstack([train_graph_local_concepts, test_graph_local_concepts])\n",
    "tab_local_concepts = torch.vstack([train_tab_local_concepts, test_tab_local_concepts])\n",
    "tab_concepts = torch.vstack([train_tab_concepts, test_tab_concepts])\n",
    "\n",
    "q_train = train_input_text.cpu()\n",
    "q_test = test_input_text.cpu()\n",
    "q = torch.cat((q_train, q_test))\n",
    "\n",
    "idx_train = train_img_index.cpu()\n",
    "idx_test = test_img_index.cpu()\n",
    "# Built from the CPU copies, like q / img / y, so .numpy() never sees a device tensor.\n",
    "idx = torch.cat((idx_train, idx_test)).numpy()\n",
    "\n",
    "img_train = train_input_image.cpu()\n",
    "img_test = test_input_image.cpu()\n",
    "img = torch.cat((img_train, img_test))\n",
    "\n",
    "y_train = train_y.cpu()\n",
    "y_test = test_y.cpu()\n",
    "y = torch.cat((y_train, y_test))\n",
    "\n",
    "questions = train_questions + test_questions\n",
    "txt_aux = train_txt_aux + test_txt_aux\n",
    "\n",
    "# Boolean masks over the combined rows: the first len(y_train) rows are the train part.\n",
    "train_mask = np.zeros(y.shape[0], dtype=bool)\n",
    "train_mask[:y_train.shape[0]] = True\n",
    "test_mask = ~train_mask"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# local concepts text\n",
    "print(\"\\n_____________THIS IS FOR TEXT____________\")\n",
    "\n",
    "# Cluster the local activations collected from model.gnn_graph_local_concepts.\n",
    "concepts_g_local = torch.Tensor(graph_local_concepts).detach()\n",
    "centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(\n",
    "    concepts_g_local, y\n",
    ")\n",
    "print(f\"Number of graph cenroids: {len(centroids)}\")\n",
    "\n",
    "# Report cluster sizes, then the classifier accuracy used as the completeness score.\n",
    "cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)\n",
    "classifier = models.ActivationClassifierConcepts(\n",
    "    y, used_centroid_labels, train_mask, test_mask\n",
    ")\n",
    "\n",
    "print(f\"Classifier Concept completeness score: {classifier.accuracy}\")\n",
    "concept_metrics = [('cluster_count', cluster_counts)]\n",
    "\n",
    "visualisation_utils.plot_clustering(\n",
    "    seed, concepts_g_local, y, centroids, centroid_labels, used_centroid_labels,\n",
    "    MODEL_NAME, LAYER_NUM, path,\n",
    "    task=\"graph local\", id_path=\"_graph\",\n",
    ")\n",
    "\n",
    "# NumPy view of the same activations, passed to plot_samples together with the questions.\n",
    "g_concepts = concepts_g_local.detach().cpu().numpy()\n",
    "\n",
    "print('TEXT CONCEPTS')\n",
    "\n",
    "sample_graphs, sample_feat = plot_samples(\n",
    "    None, g_concepts, 5, questions, concepts_g_local, path,\n",
    "    concepts=centroids, task='local',\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# local concepts image\n",
    "print(\"\\n_____________THIS IS FOR IMAGE____________\")\n",
    "\n",
    "# Cluster the local activations collected from model.x_tab_local_concepts.\n",
    "concepts_g_img_local = torch.Tensor(tab_local_concepts).detach()\n",
    "centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(\n",
    "    concepts_g_img_local, y\n",
    ")\n",
    "print(f\"Number of graph cenroids: {len(centroids)}\")\n",
    "\n",
    "# Persist the clustering result next to the other outputs of this run.\n",
    "persistence_utils.persist_experiment(centroids, path, 'centroids_g.z')\n",
    "persistence_utils.persist_experiment(centroid_labels, path, 'centroid_labels_g.z')\n",
    "persistence_utils.persist_experiment(used_centroid_labels, path, 'used_centroid_labels_g.z')\n",
    "\n",
    "# Cluster sizes, then the classifier accuracy used as the completeness score.\n",
    "cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)\n",
    "classifier = models.ActivationClassifierConcepts(\n",
    "    y, used_centroid_labels, train_mask, test_mask\n",
    ")\n",
    "\n",
    "print(f\"Classifier Concept completeness score: {classifier.accuracy}\")\n",
    "concept_metrics = [('cluster_count', cluster_counts)]\n",
    "persistence_utils.persist_experiment(concept_metrics, path, 'image_concept_metrics.z')\n",
    "\n",
    "# Scatter plot of the clustered activations.\n",
    "visualisation_utils.plot_clustering(\n",
    "    seed, concepts_g_img_local, y, centroids, centroid_labels, used_centroid_labels,\n",
    "    MODEL_NAME, LAYER_NUM, path,\n",
    "    task=\"graph image local\", id_path=\"_graph\",\n",
    ")\n",
    "g_img_concepts = concepts_g_img_local.detach().cpu().numpy()\n",
    "\n",
    "print('IMAGE CONCEPTS')\n",
    "sample_graphs, sample_feat = plot_samples(\n",
    "    None, g_img_concepts, 5, idx, concepts_g_img_local, path,\n",
    "    concepts=centroids, task='local', mod='image',\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# shared space concepts\n",
    "print(\"\\n_____________THIS IS FOR GRAPHS AND TABLE____________\")\n",
    "\n",
    "# Stack the shared-space activations row-wise: graph_concepts rows first, tab_concepts\n",
    "# rows second. Labels and masks are repeated twice so they line up with the stacked rows.\n",
    "concepts = torch.vstack([torch.Tensor(graph_concepts), torch.Tensor(tab_concepts)]).detach()\n",
    "y_double = torch.cat((y, y), dim=0)\n",
    "train_mask_double = np.concatenate((train_mask, train_mask), axis=0)\n",
    "test_mask_double = np.concatenate((test_mask, test_mask), axis=0)\n",
    "\n",
    "# Cluster the stacked activations.\n",
    "centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(\n",
    "    concepts, y_double\n",
    ")\n",
    "print(f\"Number of graph cenroids: {len(centroids)}\")\n",
    "\n",
    "cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)\n",
    "classifier = models.ActivationClassifierConcepts(\n",
    "    y_double, used_centroid_labels, train_mask_double, test_mask_double\n",
    ")\n",
    "\n",
    "print(f\"Classifier Concept completeness score: {classifier.accuracy}\")\n",
    "concept_metrics = [('cluster_count', cluster_counts)]\n",
    "persistence_utils.persist_experiment(concept_metrics, path, 'shared_concept_metrics.z')\n",
    "wandb.log({'shared completeness': classifier.accuracy, 'num clusters shared': len(centroids)})\n",
    "\n",
    "# Modality indicator used to label the plot: 1 for the first half of the stacked rows,\n",
    "# 0 for the second half.\n",
    "rows_per_modality = concepts.shape[0] // 2\n",
    "tab_or_graph = torch.cat((torch.ones(rows_per_modality), torch.zeros(rows_per_modality)), dim=0)\n",
    "visualisation_utils.plot_clustering(\n",
    "    seed, concepts, tab_or_graph, centroids, centroid_labels, used_centroid_labels,\n",
    "    MODEL_NAME, LAYER_NUM, path,\n",
    "    task=\"shared\", id_path=\"_graph\",\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# text shared concepts\n",
    "print('TEXT CONCEPTS')\n",
    "g_concepts = graph_concepts.detach().cpu().numpy()\n",
    "# NOTE(review): this call passes task='local' while the image counterpart of the shared\n",
    "# space passes task='global' - confirm which one is intended here.\n",
    "top_plot, top_concepts = plot_samples(\n",
    "    None, g_concepts, 5, questions, concepts, path,\n",
    "    concepts=centroids, task='local',\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# image shared concepts\n",
    "print('Image CONCEPTS')\n",
    "# NumPy copy of the shared-space activations collected from model.tab_shared_concepts.\n",
    "t_concepts = tab_concepts.detach().numpy()\n",
    "top_plot_images, top_concepts_img = plot_samples(\n",
    "    None, t_concepts, 5, idx, concepts, path,\n",
    "    concepts=centroids, task='global', mod='image',\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "print('------SHARED SPACE-----')\n",
    "# Join what plot_samples returned for text and for images, text entries first.\n",
    "top_concepts_both = np.array(top_concepts + top_concepts_img)\n",
    "top_plot_both = top_plot + top_plot_images\n",
    "\n",
    "# Only draw the figure when at least one sample concept was returned.\n",
    "if len(top_concepts_both) > 0:\n",
    "    visualisation_utils.plot_clustering_images_inside(\n",
    "        seed, concepts, top_concepts_both, top_plot_both,\n",
    "        used_centroid_labels, path,\n",
    "        'shared space with images',\n",
    "    )"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# combined concepts\n",
    "print(\"\\n_____________THIS IS FOR COMBINED CONCEPTS____________\")\n",
    "\n",
    "# Concatenate the two shared-space activations along the last dimension, so each row\n",
    "# carries the graph_concepts values followed by the tab_concepts values.\n",
    "union_concepts = torch.cat(\n",
    "    [torch.Tensor(graph_concepts), torch.Tensor(tab_concepts)], dim=-1\n",
    ").detach()\n",
    "\n",
    "# Cluster the concatenated activations.\n",
    "centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(\n",
    "    union_concepts, y\n",
    ")\n",
    "print(f\"Number of graph cenroids: {len(centroids)}\")\n",
    "\n",
    "cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)\n",
    "\n",
    "classifier = models.ActivationClassifierConcepts(\n",
    "    y, used_centroid_labels, train_mask, test_mask\n",
    ")\n",
    "\n",
    "# g_concepts / t_concepts are the NumPy arrays produced by the shared-space sample cells.\n",
    "save_centroids(\n",
    "    centroids, y, used_centroid_labels, union_concepts,\n",
    "    g_concepts, questions,\n",
    "    t_concepts, idx,\n",
    "    path,\n",
    ")\n",
    "classifier.plot2(path)\n",
    "\n",
    "print(f\"Classifier Concept completeness score: {classifier.accuracy}\")\n",
    "concept_metrics = [('cluster_count', cluster_counts)]\n",
    "persistence_utils.persist_experiment(concept_metrics, path, 'graph_concept_metrics.z')\n",
    "wandb.log({'combined completeness': classifier.accuracy, 'num clusters combined': len(centroids)})"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Plot the clustering of the combined concepts, passing the train/test masks and the\n",
    "# number of classes through the extra keyword arguments.\n",
    "visualisation_utils.plot_clustering(\n",
    "    seed, union_concepts, y, centroids, centroid_labels, used_centroid_labels,\n",
    "    MODEL_NAME, LAYER_NUM, path,\n",
    "    task=\"combined\", id_path=\"_graph\",\n",
    "    extra=True, train_mask=train_mask, test_mask=test_mask,\n",
    "    n_classes=NUM_CLASSES,\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Uses the shared-space text (g_concepts) and image (t_concepts) arrays from the cells above.\n",
    "print_near_example(g_concepts, questions, t_concepts, idx, path)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# test_missing_modality is imported from experiments_clevr_extra; that module's `dev`\n",
    "# attribute is set to this notebook's device string before the call.\n",
    "import experiments_clevr_extra\n",
    "experiments_clevr_extra.dev = dev\n",
    "\n",
    "missing_accuracy = test_missing_modality(full_test_loader, model, t_concepts, g_concepts)\n",
    "print(missing_accuracy)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Retrieval experiment on the test loader, using the shared-space arrays from above.\n",
    "retrieval_acc = test_retrieval(\n",
    "    full_test_loader, model, g_concepts, questions, t_concepts, txt_aux\n",
    ")\n",
    "print(retrieval_acc)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
