{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']= '0'\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from matplotlib import pyplot as plt\n",
    "import random\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from src.utils.logging import get_logger\n",
    "from src.configs.defaults import get_cfg_defaults\n",
    "from src.env.build import build_env\n",
    "import src.env.binary_tree_env\n",
    "import src.env.binary_tree_env_one_step\n",
    "from src.gfn.gfn_evaluation import GFNEvaluator\n",
    "from src.gfn.rollout_worker_phylo import RolloutWorker\n",
    "from src.gfn.training_data_loader import TrainingDataLoader\n",
    "from src.gfn.build import build_gfn\n",
    "import src.utils.plot_utils as plot_utils\n",
    "from src.utils.utils import load_sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Seed the global RNGs so that the sampled trees and the random subsets drawn with\n",
    "# np.random.choice further down are repeatable on a fresh kernel. Generators created\n",
    "# privately inside the src.* modules (if any) are not covered by this.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# input locations, relative to the notebook directory\n",
    "sequences_path = '../dataset/benchmark_datasets/DS1.pickle'\n",
    "cfg_path = '../phylo_conditional_used_for_ds1/conditional_exp/config.yaml'\n",
    "paup_result_path = '../dataset/benchmark_datasets/paup_results/ds1.tre'\n",
    "model_checkpoint_path = '../phylo_conditional_used_for_ds1/conditional_exp/checkpoints/checkpoint_000399.pt'\n",
    "\n",
    "# load sequences\n",
    "all_seqs = load_sequences(sequences_path)\n",
    "\n",
    "# load config: project defaults overridden by the experiment's yaml\n",
    "cfg = get_cfg_defaults()\n",
    "cfg.merge_from_file(cfg_path)\n",
    "\n",
    "# Run on CPU. To use a GPU instead, replace the entry with torch.device('cuda:0').\n",
    "all_device = [torch.device('cpu')]\n",
    "env, state2input = build_env(cfg, all_seqs)\n",
    "rollout_worker = RolloutWorker(cfg.GFN, env, state2input)\n",
    "generator = build_gfn(cfg, state2input, env, all_device)\n",
    "\n",
    "# The evaluator needs the set of scales only when the model is scale-conditioned.\n",
    "if cfg.GFN.CONDITION_ON_SCALE:\n",
    "    scales_set = cfg.GFN.SCALES_SET\n",
    "    assert scales_set is not None, 'must specify \"SCALES_SET\" in config when \"CONDITION_ON_SCALE\" is on'\n",
    "else:\n",
    "    scales_set = None\n",
    "\n",
    "# small evaluation batch size so the notebook also runs on a modest local machine\n",
    "cfg.GFN.MODEL.EVALUATION.BATCH_SIZE = 8\n",
    "# `states` is only a placeholder here; later cells assign gfn_evaluator.states\n",
    "# before calling evaluate_gfn_quality.\n",
    "gfn_evaluator = GFNEvaluator(\n",
    "    cfg.GFN.MODEL.EVALUATION,\n",
    "    rollout_worker,\n",
    "    generator,\n",
    "    states='placeholder',\n",
    "    verbose=True,\n",
    "    scales_set=scales_set,\n",
    ")\n",
    "\n",
    "# restore the trained weights\n",
    "generator.load(model_checkpoint_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Reproducing results for the DS1 conditional model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Grid of 100 evenly spaced scale values in [0.5, 8.0]; the next cell plots them\n",
    "# as the statistical temperature.\n",
    "all_scales = np.linspace(0.5, 8.0, 100)\n",
    "\n",
    "# One scale per row (column vector of float32), placed on the first configured device.\n",
    "scales_column = torch.tensor(all_scales, dtype=torch.float32)[:, None].to(all_device[0])\n",
    "\n",
    "# Inference only, so gradient tracking is switched off.\n",
    "with torch.no_grad():\n",
    "    all_log_z = generator.compute_log_Z(scales_column)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "# Plot the GFlowNet log-partition estimate against the conditioning temperature.\n",
    "# A single explicit figure is created; calling plt.clf() first would implicitly\n",
    "# open an extra, empty default-sized figure that renders as a blank output.\n",
    "fig, ax = plt.subplots(figsize=(7, 7))\n",
    "ax.plot(all_scales, all_log_z.cpu().numpy())\n",
    "ax.set_xlabel('statistical temperature', fontsize=18)\n",
    "ax.set_ylabel('GFlowNet log partition', fontsize=18)\n",
    "ax.tick_params(axis='both', labelsize=18)\n",
    "fig.tight_layout()\n",
    "fig.savefig('conditional_logz.pdf', dpi=350)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scale = 8.0\n",
    "# Pool 10 rounds of generator samples at this temperature (10,000 trees per the\n",
    "# original note; the per-round count is set by the evaluator config - confirm).\n",
    "gfn_sample_result_8 = {'states': [], 'mutations': []}\n",
    "for _ in range(10):\n",
    "    gfn_sample_result_ = gfn_evaluator.evaluate_gfn_samples(scale)\n",
    "    gfn_sample_result_8['states'].extend(gfn_sample_result_['states'])\n",
    "    gfn_sample_result_8['mutations'].extend(gfn_sample_result_['mutations'])\n",
    "\n",
    "# Choose up to 200 distinct parsimony scores at random and keep the first sampled\n",
    "# tree for each. The score array is built once, not once per selected score.\n",
    "sampled_mutations_8 = np.array(gfn_sample_result_8['mutations'])\n",
    "unique_mutations = np.unique(sampled_mutations_8)\n",
    "num_unique_mutations = len(unique_mutations)\n",
    "mutations = np.random.choice(unique_mutations, min(200, num_unique_mutations), replace=False)\n",
    "eval_states_scale_8 = [\n",
    "    gfn_sample_result_8['states'][np.flatnonzero(sampled_mutations_8 == mut)[0]]\n",
    "    for mut in mutations\n",
    "]\n",
    "\n",
    "# Score the selected trees under the model at this temperature.\n",
    "gfn_evaluator.states = eval_states_scale_8\n",
    "ret_scale_8 = gfn_evaluator.evaluate_gfn_quality(scale)\n",
    "\n",
    "# Scatter of model log-probability against negative parsimony score.\n",
    "parsimony_scores = np.array([state.subtrees[0].total_mutations for state in eval_states_scale_8])\n",
    "fig = plt.figure(figsize=(10, 10))\n",
    "plt.scatter(ret_scale_8['log_prob_reward'][0], -parsimony_scores,\n",
    "            label='PearsonR: %.3f' % (ret_scale_8['log_pearsonr']))\n",
    "# The temperature in the label comes from `scale`, so it cannot drift from the value used above.\n",
    "plt.xlabel(f'GFlowNet log probability\\n T={scale}', fontsize=40)\n",
    "plt.ylabel('Negative parsimony score', fontsize=40)\n",
    "plt.xticks(fontsize=40, rotation=45, ha='right')\n",
    "plt.yticks(fontsize=40)\n",
    "plt.legend(loc='lower right', fontsize=40)\n",
    "plt.tight_layout()\n",
    "plt.savefig('scatter_scale_8.pdf', dpi=350)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scale = 4.0\n",
    "# Pool 10 rounds of generator samples at this temperature (10,000 trees per the\n",
    "# original note; the per-round count is set by the evaluator config - confirm).\n",
    "gfn_sample_result_4 = {'states': [], 'mutations': []}\n",
    "for _ in range(10):\n",
    "    gfn_sample_result_ = gfn_evaluator.evaluate_gfn_samples(scale)\n",
    "    gfn_sample_result_4['states'].extend(gfn_sample_result_['states'])\n",
    "    gfn_sample_result_4['mutations'].extend(gfn_sample_result_['mutations'])\n",
    "\n",
    "# Choose up to 200 distinct parsimony scores at random and keep the first sampled\n",
    "# tree for each. The score array is built once, not once per selected score.\n",
    "sampled_mutations_4 = np.array(gfn_sample_result_4['mutations'])\n",
    "unique_mutations = np.unique(sampled_mutations_4)\n",
    "num_unique_mutations = len(unique_mutations)\n",
    "mutations = np.random.choice(unique_mutations, min(200, num_unique_mutations), replace=False)\n",
    "eval_states_scale_4 = [\n",
    "    gfn_sample_result_4['states'][np.flatnonzero(sampled_mutations_4 == mut)[0]]\n",
    "    for mut in mutations\n",
    "]\n",
    "\n",
    "# Fewer than 200 distinct scores at this temperature: top up with trees kept at\n",
    "# T=8.0 whose parsimony score was not sampled here, stopping at 200.\n",
    "if len(eval_states_scale_4) < 200:\n",
    "    for state in eval_states_scale_8:\n",
    "        if state.subtrees[0].total_mutations not in unique_mutations:\n",
    "            eval_states_scale_4.append(state)\n",
    "            if len(eval_states_scale_4) >= 200:\n",
    "                break\n",
    "print(len(eval_states_scale_4))\n",
    "\n",
    "# Score the selected trees under the model at this temperature.\n",
    "gfn_evaluator.states = eval_states_scale_4\n",
    "ret_scale_4 = gfn_evaluator.evaluate_gfn_quality(scale)\n",
    "\n",
    "# Scatter of model log-probability against negative parsimony score.\n",
    "parsimony_scores = np.array([state.subtrees[0].total_mutations for state in eval_states_scale_4])\n",
    "fig = plt.figure(figsize=(10, 10))\n",
    "plt.scatter(ret_scale_4['log_prob_reward'][0], -parsimony_scores,\n",
    "            label='PearsonR: %.3f' % (ret_scale_4['log_pearsonr']))\n",
    "# The temperature in the label comes from `scale`, so it cannot drift from the value used above.\n",
    "plt.xlabel(f'GFlowNet log probability\\n T={scale}', fontsize=40)\n",
    "plt.ylabel('Negative parsimony score', fontsize=40)\n",
    "plt.xticks(fontsize=40, rotation=45, ha='right')\n",
    "plt.yticks(fontsize=40)\n",
    "plt.legend(loc='lower right', fontsize=40)\n",
    "plt.tight_layout()\n",
    "plt.savefig('scatter_scale_4.pdf', dpi=350)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scale = 2.0\n",
    "# Pool 10 rounds of generator samples at this temperature (10,000 trees per the\n",
    "# original note; the per-round count is set by the evaluator config - confirm).\n",
    "gfn_sample_result_2 = {'states': [], 'mutations': []}\n",
    "for _ in range(10):\n",
    "    gfn_sample_result_ = gfn_evaluator.evaluate_gfn_samples(scale)\n",
    "    gfn_sample_result_2['states'].extend(gfn_sample_result_['states'])\n",
    "    gfn_sample_result_2['mutations'].extend(gfn_sample_result_['mutations'])\n",
    "\n",
    "# Choose up to 200 distinct parsimony scores at random and keep the first sampled\n",
    "# tree for each. The score array is built once, not once per selected score.\n",
    "sampled_mutations_2 = np.array(gfn_sample_result_2['mutations'])\n",
    "unique_mutations = np.unique(sampled_mutations_2)\n",
    "num_unique_mutations = len(unique_mutations)\n",
    "mutations = np.random.choice(unique_mutations, min(200, num_unique_mutations), replace=False)\n",
    "eval_states_scale_2 = [\n",
    "    gfn_sample_result_2['states'][np.flatnonzero(sampled_mutations_2 == mut)[0]]\n",
    "    for mut in mutations\n",
    "]\n",
    "\n",
    "# Fewer than 200 distinct scores at this temperature: top up with trees kept at\n",
    "# T=4.0 whose parsimony score was not sampled here, stopping at 200.\n",
    "if len(eval_states_scale_2) < 200:\n",
    "    for state in eval_states_scale_4:\n",
    "        if state.subtrees[0].total_mutations not in unique_mutations:\n",
    "            eval_states_scale_2.append(state)\n",
    "            if len(eval_states_scale_2) >= 200:\n",
    "                break\n",
    "print(len(eval_states_scale_2))\n",
    "\n",
    "# Score the selected trees under the model at this temperature.\n",
    "gfn_evaluator.states = eval_states_scale_2\n",
    "ret_scale_2 = gfn_evaluator.evaluate_gfn_quality(scale)\n",
    "\n",
    "# Scatter of model log-probability against negative parsimony score.\n",
    "parsimony_scores = np.array([state.subtrees[0].total_mutations for state in eval_states_scale_2])\n",
    "fig = plt.figure(figsize=(10, 10))\n",
    "plt.scatter(ret_scale_2['log_prob_reward'][0], -parsimony_scores,\n",
    "            label='PearsonR: %.3f' % (ret_scale_2['log_pearsonr']))\n",
    "# The label previously read 'T=4.0' although this cell evaluates scale 2.0;\n",
    "# it is now taken from `scale` so the figure is labelled with the value actually used.\n",
    "plt.xlabel(f'GFlowNet log probability\\n T={scale}', fontsize=40)\n",
    "plt.ylabel('Negative parsimony score', fontsize=40)\n",
    "plt.xticks(fontsize=40, rotation=45, ha='right')\n",
    "plt.yticks(fontsize=40)\n",
    "plt.legend(loc='lower right', fontsize=40)\n",
    "plt.tight_layout()\n",
    "plt.savefig('scatter_scale_2.pdf', dpi=350)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "scale = 1.0\n",
    "# Pool 10 rounds of generator samples at this temperature (10,000 trees per the\n",
    "# original note; the per-round count is set by the evaluator config - confirm).\n",
    "gfn_sample_result_1 = {'states': [], 'mutations': []}\n",
    "for _ in range(10):\n",
    "    gfn_sample_result_ = gfn_evaluator.evaluate_gfn_samples(scale)\n",
    "    gfn_sample_result_1['states'].extend(gfn_sample_result_['states'])\n",
    "    gfn_sample_result_1['mutations'].extend(gfn_sample_result_['mutations'])\n",
    "\n",
    "# Choose up to 200 distinct parsimony scores at random and keep the first sampled\n",
    "# tree for each. The score array is built once, not once per selected score.\n",
    "sampled_mutations_1 = np.array(gfn_sample_result_1['mutations'])\n",
    "unique_mutations = np.unique(sampled_mutations_1)\n",
    "num_unique_mutations = len(unique_mutations)\n",
    "mutations = np.random.choice(unique_mutations, min(200, num_unique_mutations), replace=False)\n",
    "eval_states_scale_1 = [\n",
    "    gfn_sample_result_1['states'][np.flatnonzero(sampled_mutations_1 == mut)[0]]\n",
    "    for mut in mutations\n",
    "]\n",
    "\n",
    "# Fewer than 200 distinct scores at this temperature: top up with trees kept at\n",
    "# T=2.0 whose parsimony score was not sampled here, stopping at 200.\n",
    "if len(eval_states_scale_1) < 200:\n",
    "    for state in eval_states_scale_2:\n",
    "        if state.subtrees[0].total_mutations not in unique_mutations:\n",
    "            eval_states_scale_1.append(state)\n",
    "            if len(eval_states_scale_1) >= 200:\n",
    "                break\n",
    "print(len(eval_states_scale_1))\n",
    "\n",
    "# Score the selected trees under the model at this temperature.\n",
    "gfn_evaluator.states = eval_states_scale_1\n",
    "ret_scale_1 = gfn_evaluator.evaluate_gfn_quality(scale)\n",
    "\n",
    "# Scatter of model log-probability against negative parsimony score.\n",
    "parsimony_scores = np.array([state.subtrees[0].total_mutations for state in eval_states_scale_1])\n",
    "fig = plt.figure(figsize=(10, 10))\n",
    "plt.scatter(ret_scale_1['log_prob_reward'][0], -parsimony_scores,\n",
    "            label='PearsonR: %.3f' % (ret_scale_1['log_pearsonr']))\n",
    "# The temperature in the label comes from `scale`, so it cannot drift from the value used above.\n",
    "plt.xlabel(f'GFlowNet log probability\\n T={scale}', fontsize=40)\n",
    "plt.ylabel('Negative parsimony score', fontsize=40)\n",
    "plt.xticks(fontsize=40, rotation=45, ha='right')\n",
    "plt.yticks(fontsize=40)\n",
    "plt.legend(loc='lower right', fontsize=40)\n",
    "plt.tight_layout()\n",
    "plt.savefig('scatter_scale_1.pdf', dpi=350)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Overlay the parsimony scores of the pooled samples at each temperature.\n",
    "# Only one figure is created; a leading plt.clf() would open an extra empty one.\n",
    "plt.figure(figsize=(12, 8))\n",
    "hist_inputs = [\n",
    "    ('8.0', gfn_sample_result_8['mutations']),\n",
    "    ('4.0', gfn_sample_result_4['mutations']),\n",
    "    ('2.0', gfn_sample_result_2['mutations']),\n",
    "    ('1.0', gfn_sample_result_1['mutations']),\n",
    "]\n",
    "for temperature_label, sampled_scores in hist_inputs:\n",
    "    # Semi-transparent bars keep overlapping distributions visible; opaque bars\n",
    "    # drawn later would hide the ones drawn earlier.\n",
    "    plt.hist(sampled_scores, alpha=0.6, label=f'GFlowNet samples at T={temperature_label}')\n",
    "plt.legend(fontsize=18)\n",
    "plt.xticks(fontsize=18)\n",
    "plt.yticks(fontsize=18)\n",
    "plt.xlabel('parsimony score', fontsize=18)\n",
    "plt.ylabel('counts', fontsize=18)\n",
    "plt.savefig('conditional_hist_v2.pdf', dpi=350)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
