{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enable IPython's autoreload extension. Mode 2 re-imports modules before each cell\n",
    "# executes, so edits to the local `src` package take effect without a kernel restart.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Table: Dataset stats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from src import data, models, functional\n",
    "import torch\n",
    "import os\n",
    "import pandas as pd\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ---- Run configuration: edit these values to target a different model or output folder ----\n",
    "model_name = \"gptj\"\n",
    "results_dir = \"../results/known_samples/\"\n",
    "\n",
    "# Use the GPU when torch reports one is available, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# Create the output folder up front; an already-existing folder is not an error.\n",
    "os.makedirs(results_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model selected by `model_name` onto `device`, passing the loader's fp16 flag.\n",
    "# The result is kept in `mt` and handed to the sample filter further down.\n",
    "mt = models.load_model(\n",
    "    model_name,\n",
    "    fp16=True,\n",
    "    device=device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset using the loader's default arguments.\n",
    "dataset = data.load_dataset()\n",
    "\n",
    "# Run the model-based sample filter over the whole dataset.\n",
    "# NOTE(review): presumably this keeps only the samples the model already answers\n",
    "# correctly (later cells report them as \"known\") -- confirm in src.functional.\n",
    "filtered = functional.filter_dataset_samples(\n",
    "    dataset=dataset,\n",
    "    mt=mt,\n",
    "    batch_size=1,\n",
    "    n_trials=3,\n",
    "    n_icl_lm=functional.DEFAULT_N_ICL_LM,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index both datasets by relation name so each relation can be paired with its\n",
    "# filtered counterpart (a relation may be absent from the filtered dataset).\n",
    "relations_by_name = {r.name: r for r in dataset.relations}\n",
    "filtered_by_name = {r.name: r for r in filtered.relations}\n",
    "\n",
    "samples_known = {}\n",
    "for name, relation in relations_by_name.items():\n",
    "    # Converting to sets de-duplicates repeated samples, so each distinct sample counts once.\n",
    "    relation_samples = set(relation.samples)\n",
    "    # A relation missing from the filtered dataset contributes zero known samples.\n",
    "    filtered_samples = set(filtered_by_name[name].samples) if name in filtered_by_name else set()\n",
    "    samples_known[name] = {\n",
    "        \"known\": len(filtered_samples),\n",
    "        \"total\": len(relation_samples),\n",
    "        # Set iteration order is arbitrary and can change between kernel sessions;\n",
    "        # sort by (subject, object) so the saved JSON is reproducible and diff-friendly.\n",
    "        \"known_samples\": [\n",
    "            {\n",
    "                \"subject\": sample.subject,\n",
    "                \"object\": sample.object,\n",
    "            }\n",
    "            for sample in sorted(\n",
    "                filtered_samples,\n",
    "                key=lambda s: (str(s.subject), str(s.object)),\n",
    "            )\n",
    "        ],\n",
    "    }\n",
    "\n",
    "# One JSON file per model, mapping relation name -> {known, total, known_samples}.\n",
    "with open(f\"{results_dir}/{model_name}.json\", \"w\") as f:\n",
    "    json.dump(samples_known, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One summary row per relation: its total sample count plus this model's known count,\n",
    "# stored in a column named after the model.\n",
    "df_json = [\n",
    "    {\"relation\": relation_name, \"total\": counts[\"total\"], model_name: counts[\"known\"]}\n",
    "    for relation_name, counts in samples_known.items()\n",
    "]\n",
    "\n",
    "# Export the table as a spreadsheet, omitting the DataFrame index column.\n",
    "df = pd.DataFrame(df_json)\n",
    "df.to_excel(f\"{results_dir}/{model_name}.xlsx\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the property fields of the first relation to see which keys can be grouped on.\n",
    "vars(dataset[0].properties)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Literal\n",
    "\n",
    "# Relation property used to bucket the counts; change it to any of the listed literals.\n",
    "property_key: Literal[\"relation_type\", \"fn_type\", \"disambiguating\", \"symmetric\"] = \"disambiguating\"\n",
    "\n",
    "# Maps each observed property value to running {known, total} sample counts.\n",
    "category_wise = {}\n",
    "for name, relation in relations_by_name.items():\n",
    "    property_value = vars(relation.properties)[property_key]\n",
    "    # Count distinct samples; a relation absent from the filtered dataset adds zero known.\n",
    "    n_total = len(set(relation.samples))\n",
    "    n_known = len(set(filtered_by_name[name].samples)) if name in filtered_by_name else 0\n",
    "    bucket = category_wise.setdefault(property_value, {\"known\": 0, \"total\": 0})\n",
    "    bucket[\"known\"] += n_known\n",
    "    bucket[\"total\"] += n_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per value of `property_key`: total samples and this model's known count.\n",
    "df_json = [\n",
    "    {property_key: category, \"total\": tally[\"total\"], model_name: tally[\"known\"]}\n",
    "    for category, tally in category_wise.items()\n",
    "]\n",
    "\n",
    "# Leave the frame as the cell's last expression so the notebook renders it.\n",
    "pd.DataFrame(df_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "relations",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
