{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adc3fe93-ef6f-4cf6-a08c-95bbec8f179e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from graphgen.config_utils import generate_and_save_static_config\n",
    "from graphgen.runner import batch_sample_and_save\n",
    "from embedding.pipeline import run_embedding_and_prediction_pipeline\n",
    "from embedding.plot import plot_tsne, plot_raw\n",
    "from evaluation.runner import run_evaluation_pipeline\n",
    "from evaluation.plot import plot_multi_modes\n",
    "from evaluation.p_values import compute_p_values\n",
    "\n",
    "from sklearn.metrics import normalized_mutual_info_score as NMI\n",
    "from sklearn.metrics import adjusted_rand_score as ARI\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "from itertools import product"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e2dc89c-32e2-44f3-86ca-052584d44aee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Experiment configuration: every tunable parameter lives in this cell ---\n",
    "\n",
    "# Sweep p with q held fixed; each (p, q) pair is one experimental setting.\n",
    "p_values = [1.0, 0.9, 0.8, 0.7, 0.6, 0.55, 0.5, 0.45, 0.4, 0.35, 0.3, 0.25, 0.2, 0.1]\n",
    "q = 0.2\n",
    "pq_list = [(p, q) for p in p_values]\n",
    "runs_per_pq = 20  # number of runs per (p, q) setting\n",
    "\n",
    "K = 3  # number of static clusters; also used as the embedding size below\n",
    "# One mapping per time step (presumably in time order): each relabels the K static\n",
    "# clusters for that step, so several static clusters may merge into one label.\n",
    "label_mappings = [\n",
    "    {0: 0, 1: 1, 2: 2},\n",
    "    {0: 0, 1: 0, 2: 1},\n",
    "    {0: 0, 1: 1, 2: 1},\n",
    "]\n",
    "# Derived from label_mappings so the two cannot drift apart (evaluates to [3, 2, 2]).\n",
    "clusters_per_time = [len(set(mapping.values())) for mapping in label_mappings]\n",
    "n_nodes = 200\n",
    "connected = False\n",
    "\n",
    "emb_type_list = [\n",
    "    {\"rep_type\": \"UASE\", \"regularized\": 0},\n",
    "    {\"rep_type\": \"ULSE-n1\", \"regularized\": 0.1},\n",
    "    {\"rep_type\": \"ULSE-n2\", \"regularized\": 0.1},\n",
    "]\n",
    "metrics = {\"NMI\": NMI, \"ARI\": ARI}\n",
    "\n",
    "# pi per synthetic dataset (one weight per static cluster, presumably cluster\n",
    "# proportions). Changing dataset_id switches pi and both directories together.\n",
    "pi_by_dataset = {\n",
    "    1: np.array([0.3, 0.4, 0.3]),\n",
    "    2: np.array([0.4, 0.5, 0.1]),\n",
    "}\n",
    "dataset_id = 1\n",
    "pi = pi_by_dataset[dataset_id]\n",
    "assert pi.shape == (K,) and np.isclose(pi.sum(), 1.0), \"pi must have K entries summing to 1\"\n",
    "\n",
    "base_dir = Path(f\"../data/synthetic/{dataset_id}\")\n",
    "emb_dir = Path(f\"../emb/synthetic/{dataset_id}\")\n",
    "true_label_path = base_dir / \"config/node2label.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b101f0c3-4369-460c-aa6b-a9ddd6004709",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph sampling: generate the static/dynamic node labels (saved under\n",
    "# base_dir/config), then sample graphs for every (p, q) in pq_list into base_dir.\n",
    "static_labels, dynamic_labels = generate_and_save_static_config(\n",
    "    n=n_nodes,\n",
    "    K=K,\n",
    "    pi=pi,\n",
    "    label_mappings=label_mappings,\n",
    "    config_dir=base_dir / \"config\",\n",
    "    seed=42,\n",
    ")\n",
    "batch_sample_and_save(\n",
    "    static_labels=static_labels,\n",
    "    dynamic_node_labels=dynamic_labels,\n",
    "    output_dir=base_dir,\n",
    "    pq_list=pq_list,\n",
    "    runs_per_pq=runs_per_pq,\n",
    "    base_seed=42,\n",
    "    connected=connected,  # use the config-cell setting rather than a hard-coded value\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcdc3399-e191-4c93-bb2b-bb45ae8f2e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding and clustering: run every method in emb_type_list over the graphs\n",
    "# sampled under base_dir; evaluation below reads its results back from emb_dir.\n",
    "run_embedding_and_prediction_pipeline(\n",
    "    base_dir=base_dir,\n",
    "    emb_dir=emb_dir,\n",
    "    emb_type_list=emb_type_list,\n",
    "    pq_list=pq_list,\n",
    "    runs_per_pq=runs_per_pq,\n",
    "    clusters_per_time=clusters_per_time,\n",
    "    emb_size=K,  # embedding dimension tied to the number of static clusters\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aed22bbf-ba91-42d6-b620-d80f3cb45ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation: score the outputs under emb_dir (presumably the predicted labels)\n",
    "# against the ground-truth labels, using every metric in `metrics`.\n",
    "result = run_evaluation_pipeline(\n",
    "    base_dir=emb_dir,\n",
    "    true_label_path=true_label_path,\n",
    "    n_nodes=n_nodes,\n",
    "    pq_list=pq_list,\n",
    "    runs_per_pq=runs_per_pq,\n",
    "    emb_type_list=emb_type_list,\n",
    "    metrics=metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4919af7b-e92b-4a4e-a7e8-2ec34665606c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy plot over the (p, q) sweep for every embedding method and metric.\n",
    "plot_multi_modes(\n",
    "    base_dir=emb_dir,\n",
    "    pq_list=pq_list,\n",
    "    emb_type_list=emb_type_list,\n",
    "    metrics=metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82929ec3-2fbf-40e2-affd-821e55d977dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# p-values: compare the UASE baseline against each ULSE variant.\n",
    "# The keys appear to follow a \"<rep_type>_reg<regularized>\" pattern matching emb_type_list.\n",
    "baseline_key = \"UASE_reg0\"\n",
    "comparison_keys = [\"ULSE-n1_reg0.1\", \"ULSE-n2_reg0.1\"]\n",
    "\n",
    "# Keep every comparison's return value instead of only displaying the last one.\n",
    "p_value_results = {}\n",
    "for comparison_key in comparison_keys:\n",
    "    p_value_results[comparison_key] = compute_p_values(\n",
    "        base_dir=emb_dir,\n",
    "        emb_type_list=emb_type_list,\n",
    "        metrics=metrics,\n",
    "        sample_size=runs_per_pq,\n",
    "        emb_low_key=baseline_key,\n",
    "        emb_high_key=comparison_key,\n",
    "    )\n",
    "p_value_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b647ef2-7972-483c-b885-153d3cda930a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw embedding plots: one figure per method and per number of time steps.\n",
    "# NOTE(review): these paths point at synthetic dataset 2 (\"../data/synthetic/2\",\n",
    "# \"../emb/synthetic_2\"), not at base_dir / emb_dir from the config cell, and\n",
    "# \"TemporalCut_fast_normalized_1_0\" is not produced by this notebook - confirm intended.\n",
    "method_list = [\"USE_UASE\", \"USE_ULSE-n1\", \"USE_ULSE-n2\", \"TemporalCut_fast_normalized_1_0\"]\n",
    "label_path = Path(\"../data/synthetic/2/config/static_node2label.txt\")\n",
    "raw_emb_dir = Path(\"../emb/synthetic_2\")\n",
    "plot_dir = Path(\"./plots\")\n",
    "plot_dir.mkdir(parents=True, exist_ok=True)  # in case plot_raw does not create it\n",
    "\n",
    "for method in method_list:\n",
    "    # The embedding file does not depend on t, so resolve it once per method.\n",
    "    emb_path = raw_emb_dir / f\"synthetic_2_{method}.emb\"\n",
    "    # USE_ULSE-n1 is plotted on embedding columns [1, 2]; every other method on [0, 1].\n",
    "    columns = [1, 2] if method == \"USE_ULSE-n1\" else [0, 1]\n",
    "    for t in [2, 3]:\n",
    "        plot_raw(\n",
    "            embedding_path=emb_path,\n",
    "            label_path=label_path,\n",
    "            save_path=plot_dir / f\"synthetic_2_raw_{t}timesteps_{method}.png\",\n",
    "            method_name=method,\n",
    "            columns=columns,\n",
    "            n_times=t,\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84958c5d-9a01-473a-9b54-8d217cdc5f6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t-SNE embedding plots: one figure per method.\n",
    "# NOTE(review): these paths point at synthetic dataset 2, not at base_dir / emb_dir\n",
    "# from the config cell - confirm intended.\n",
    "# NOTE(review): file names say \"3timesteps\" but no time-step count is passed to\n",
    "# plot_tsne - confirm its default matches.\n",
    "method_list = [\"USE_UASE\", \"USE_ULSE-n1\", \"USE_ULSE-n2\"]\n",
    "label_path = Path(\"../data/synthetic/2/config/static_node2label.txt\")\n",
    "tsne_emb_dir = Path(\"../emb/synthetic_2\")\n",
    "plot_dir = Path(\"./plots\")\n",
    "plot_dir.mkdir(parents=True, exist_ok=True)  # in case plot_tsne does not create it\n",
    "\n",
    "for method in method_list:\n",
    "    plot_tsne(\n",
    "        embedding_path=tsne_emb_dir / f\"synthetic_2_{method}.emb\",\n",
    "        label_path=label_path,\n",
    "        save_path=plot_dir / f\"synthetic_2_tsne_3timesteps_{method}.png\",\n",
    "        method_name=method,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db00e2e-f17e-4348-a531-49bbadfaab26",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
