{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ba56140a",
   "metadata": {},
   "source": [
    "# SOME CODE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "70cb0f39",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.stats import pearsonr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "29b651d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_StoreAction(option_strings=['--gradMethod'], dest='gradMethod', nargs=None, const=None, default='vimco', type=<class 'str'>, choices=None, required=False, help=' vimco | rws ', metavar=None)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import warnings \n",
    "warnings. filterwarnings('ignore')\n",
    "import logging\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.CRITICAL)\n",
    "import argparse\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES']='0'\n",
    "\n",
    "from dataManipulation import *\n",
    "from utils import summary, summary_raw, mcmc_treeprob, get_support_from_mcmc, BitArray, tree_process\n",
    "from vbpi import VBPI\n",
    "import time\n",
    "import numpy as np\n",
    "import datetime\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "\n",
    "parser.add_argument('--dataset',default='DS1', help=' DS1 | DS2 | DS3 | DS4 | DS5 | DS6 | DS7 | DS8 ')\n",
    "parser.add_argument('--supportType', type=str, default='ufboot', help=' ufboot | mcmc ')\n",
    "parser.add_argument('--empFreq', default=False, action='store_true', help='emprical frequence for KL computation')\n",
    "\n",
    "\n",
    "######### Model arguments\n",
    "parser.add_argument('--psp', default=False, action='store_true', help=' turn on psp branch length feature')\n",
    "parser.add_argument('--nf', type=int, default=2, help=' branch length feature embedding dimension ')\n",
    "parser.add_argument('--hdim', type=int, default=100, help='hidden dimension for node embedding net')\n",
    "parser.add_argument('--hL', type=int, default=2, help='number of hidden layers for node embedding net')\n",
    "parser.add_argument('--brlen_model', type=str, default='gnn', help='branch length models')\n",
    "parser.add_argument('--gnn_type', type=str, default='gcn', help='gcn | sage | gin | ggnn')\n",
    "parser.add_argument('--aggr', type=str, default='sum', help='sum | mean | max')\n",
    "parser.add_argument('--proj', default=False, action='store_true', help='use projection first in SAGEConv')\n",
    "parser.add_argument('--test', default=False, action='store_true', help='turn on the test mode')\n",
    "parser.add_argument('--datetime', type=str, default='2022-01-01', help=' 2020-04-01 | 2020-04-02 | ...... ')\n",
    "\n",
    "\n",
    "######### Optimizer arguments\n",
    "parser.add_argument('--stepszTree', type=float, default=0.001, help=' step size for tree topology parameters ')\n",
    "parser.add_argument('--stepszBranch', type=float, default=0.001, help=' stepsz for branch length parameters ')\n",
    "parser.add_argument('--maxIter', type=int, default=200000, help=' number of iterations for training, default=400000')\n",
    "parser.add_argument('--invT0', type=float, default=0.001, help=' initial inverse temperature for annealing schedule, default=0.001')\n",
    "parser.add_argument('--nwarmStart', type=float, default=100000, help=' number of warm start iterations, default=100000')\n",
    "parser.add_argument('--nParticle', type=int, default=10, help='number of particles for variational objectives, default=10')\n",
    "parser.add_argument('--ar', type=float, default=0.75, help='step size anneal rate, default=0.75')\n",
    "parser.add_argument('--af', type=int, default=20000, help='step size anneal frequency, default=20000')\n",
    "parser.add_argument('--tf', type=int, default=1000, help='monitor frequency during training, default=1000')\n",
    "parser.add_argument('--lbf', type=int, default=5000, help='lower bound test frequency, default=5000')\n",
    "parser.add_argument('--gradMethod', type=str, default='vimco', help=' vimco | rws ')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9b588180",
   "metadata": {},
   "outputs": [],
   "source": [
    "from vector_sbnModel import *\n",
    "from utils import namenum\n",
    "import torch \n",
    "def sample_tree(rooted=False, random_p=None):\n",
    "    root = Tree()\n",
    "    node_split_stack = [(root, '0'*tree_model.ntaxa + '1'*tree_model.ntaxa)]\n",
    "    for i in range(tree_model.ntaxa-1):\n",
    "        node, split_bitarr = node_split_stack.pop()\n",
    "        parent_clade_bitarr = bitarray(split_bitarr[tree_model.ntaxa:])\n",
    "        node.clade_bitarr = parent_clade_bitarr\n",
    "        node.split_bitarr = min([parent_clade_bitarr, ~parent_clade_bitarr]).to01()\n",
    "        if node.is_root():\n",
    "            split_prob = tree_model.rs_CPDs\n",
    "            # split = self.rs_reverse_map[np.random.choice(len(split_prob), p=split_prob)]\n",
    "\n",
    "            split = tree_model.rs_reverse_map[torch.multinomial(split_prob, 1).item()]\n",
    "            if random_p is not None:\n",
    "                eps = np.random.random()\n",
    "                if eps < random_p:\n",
    "                    idx = np.random.randint(0, len(tree_model.rs_CPDs))\n",
    "                    split = tree_model.rs_reverse_map[idx]\n",
    "        else:\n",
    "            split_prob = tree_model.get_subsplit_CPDs(split_bitarr)\n",
    "            # split = self.ss_reverse_map[split_bitarr][np.random.choice(len(split_prob), p=split_prob)]\n",
    "            split = tree_model.ss_reverse_map[split_bitarr][torch.multinomial(split_prob, 1).item()]\n",
    "            if random_p is not None:\n",
    "                eps = np.random.random()\n",
    "                if eps < random_p:\n",
    "                    idx = np.random.randint(0, len(split_prob))\n",
    "                    split = tree_model.ss_reverse_map[split_bitarr][idx]\n",
    "                    \n",
    "        comp_split = (parent_clade_bitarr ^ bitarray(split)).to01()\n",
    "\n",
    "        c1 = node.add_child()\n",
    "        c2 = node.add_child()\n",
    "        if split.count('1') > 1:\n",
    "            node_split_stack.append((c1, comp_split + split))\n",
    "        else:\n",
    "            c1.name = tree_model.taxa[split.find('1')]\n",
    "            c1.clade_bitarr = bitarray(split)\n",
    "            c1.split_bitarr = min([c1.clade_bitarr, ~c1.clade_bitarr]).to01()\n",
    "        if comp_split.count('1') > 1:\n",
    "            node_split_stack.append((c2, split + comp_split))\n",
    "        else:\n",
    "            c2.name = tree_model.taxa[comp_split.find('1')]\n",
    "            c2.clade_bitarr = bitarray(comp_split)\n",
    "            c2.split_bitarr = min([c2.clade_bitarr, ~c2.clade_bitarr]).to01()\n",
    "\n",
    "    if not rooted:\n",
    "        root.unroot()\n",
    "    return root\n",
    "\n",
    "import networkx as nx\n",
    "\n",
    "def generate_nx_trees(samp_trees, samp_log_branch, n_leaves):\n",
    "    nx_trees = []\n",
    "\n",
    "\n",
    "    for tree, log_branch in zip(samp_trees, samp_log_branch):\n",
    "        branch_length = log_branch.exp().detach().cpu().numpy()\n",
    "        for idx, node in enumerate(tree.traverse(\"postorder\")):\n",
    "            if not node.is_root():\n",
    "                node.dist = branch_length[node.name]\n",
    "\n",
    "        n_nodes = 2 * n_leaves - 2\n",
    "\n",
    "        nx_tree = nx.Graph()\n",
    "        node_names = [n for n in range(n_nodes)]\n",
    "        leaves = node_names[:n_leaves]\n",
    "\n",
    "        # Add nodes to the graph\n",
    "        for node in leaves:\n",
    "            nx_tree.add_node(node, type='leaf')\n",
    "        for node in node_names[n_leaves:-1]:\n",
    "            nx_tree.add_node(node, type='internal')\n",
    "        nx_tree.add_node(2 * n_leaves - 3, type='root')\n",
    "\n",
    "\n",
    "        for idx, node in enumerate(tree.traverse(\"preorder\")):    \n",
    "            if not node.is_leaf():\n",
    "                children = []\n",
    "                for child_node in node.children:\n",
    "                    t = child_node.dist\n",
    "                    nx_tree.add_edge(node.name, child_node.name, t=t)\n",
    "                    nx_tree.add_node(child_node.name, parent=node.name)\n",
    "\n",
    "                    children.append(child_node.name)\n",
    "                nx_tree.add_node(node.name, children=children)\n",
    "\n",
    "        nx_trees.append(nx_tree)\n",
    "    return nx_trees\n",
    "\n",
    "def generate_results(random_p):\n",
    "    \n",
    "    samp_trees = [sample_tree(random_p=random_p) for particle in range(128)]\n",
    "    [namenum(tree, model.taxa) for tree in samp_trees]    \n",
    "    samp_log_branch, logq_branch = model.branch_model(samp_trees)\n",
    "\n",
    "    logll = torch.stack([model.phylo_model.loglikelihood(log_branch, tree) for log_branch, tree in zip(*[samp_log_branch, samp_trees])])\n",
    "    logp_prior = model.phylo_model.logprior(samp_log_branch)\n",
    "    logq_tree = torch.stack([model.logq_tree(tree) for tree in samp_trees])       \n",
    "    \n",
    "    \n",
    "    sampling_p = (logq_tree + logq_branch -samp_log_branch.sum(-1) ).detach().cpu().numpy()\n",
    "    prior_likelihood = (logp_prior + logll+ model.log_p_tau -samp_log_branch.sum(-1)).detach().cpu().numpy()\n",
    "    \n",
    "    vbpi_pearson_r = pearsonr(sampling_p,prior_likelihood).statistic\n",
    "    \n",
    "    trees = generate_nx_trees(samp_trees, samp_log_branch, len(model.taxa))\n",
    "    \n",
    "    r = {\n",
    "        'vbpi_sampling_p': sampling_p,\n",
    "        'vbpi_prior_likelihood': prior_likelihood,\n",
    "        'vbpi_pearson_r': vbpi_pearson_r,\n",
    "        'nx_trees': trees\n",
    "    }\n",
    "    return r\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "397a8599",
   "metadata": {},
   "source": [
    "# LOAD MODEL "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "01dbe9f9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Loading Data set: DS1 ......\n",
      "Support loaded in 91.0 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/user/vbpi-gnn/phyloModel.py:32: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/torch/csrc/utils/tensor_new.cpp:245.)\n",
      "  self.L, self.site_counts = map(torch.FloatTensor, self.initialCLV(data, unique_site=True))\n",
      "/user/vbpi-gnn/vector_sbnModel.py:121: UserWarning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/aten/src/ATen/native/TensorAdvancedIndexing.cpp:2279.)\n",
      "  temp_mat.masked_scatter_(self.ss_mask, self.CPD_params[self.rs_len:])\n",
      "/user/vbpi-gnn/vector_sbnModel.py:122: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1772.)\n",
      "  masked_temp_mat = temp_mat.masked_fill(1-self.ss_mask, -float('inf'))\n",
      "/user/vbpi-gnn/vector_sbnModel.py:125: UserWarning: masked_select received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1855.)\n",
      "  return masked_CPDs.masked_select(self.ss_mask), masked_CPDs\n"
     ]
    }
   ],
   "source": [
    "args = parser.parse_args('')\n",
    "\n",
    "args.dataset = 'DS1'\n",
    "args.brlen_model = 'gnn'\n",
    "args.gnn_type = 'edge'\n",
    "args.hdim = 100\n",
    "args.maxIter = 400000\n",
    "args.empFreq = True \n",
    "args.psp = True\n",
    "ufboot_support_path = 'data/ufboot_data_DS1-11/'\n",
    "data_path = 'data/hohna_datasets_fasta/'\n",
    "ground_truth_path, samp_size = 'data/raw_data_DS1-11/', 750001\n",
    "\n",
    "###### Load Data\n",
    "print('\\nLoading Data set: {} ......'.format(args.dataset))\n",
    "run_time = -time.time()\n",
    "\n",
    "if args.supportType == 'ufboot':\n",
    "    tree_dict_support, tree_names_support = summary_raw(args.dataset, ufboot_support_path)\n",
    "elif args.supportType == 'mcmc':\n",
    "    tree_dict_support, tree_names_support, _ = mcmc_treeprob(mcmc_support_path + args.dataset + '.trprobs', 'nexus', taxon='keep')\n",
    "\n",
    "data, taxa = loadData(data_path + args.dataset + '.fasta', 'fasta')\n",
    "\n",
    "run_time += time.time()\n",
    "print('Support loaded in {:.1f} seconds'.format(run_time))\n",
    "\n",
    "emp_tree_freq = None\n",
    "\n",
    "rootsplit_supp_dict, subsplit_supp_dict = get_support_from_mcmc(taxa, tree_dict_support, tree_names_support)\n",
    "del tree_dict_support, tree_names_support\n",
    "model = VBPI(taxa, rootsplit_supp_dict, subsplit_supp_dict, data, pden=np.ones(4)/4., subModel=('JC', 1.0),\n",
    "                 emp_tree_freq=emp_tree_freq, feature_dim=args.nf, psp=args.psp, hidden_dim=args.hdim, num_layers=args.hL, branch_model=args.brlen_model, gnn_type=args.gnn_type, aggr=args.aggr, project=args.proj)\n",
    "model = model.cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4bb45fca",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/user/vbpi-gnn/vector_sbnModel.py:121: UserWarning: masked_scatter_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/aten/src/ATen/native/TensorAdvancedIndexing.cpp:2279.)\n",
      "  temp_mat.masked_scatter_(self.ss_mask, self.CPD_params[self.rs_len:])\n",
      "/user/vbpi-gnn/vector_sbnModel.py:122: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1772.)\n",
      "  masked_temp_mat = temp_mat.masked_fill(1-self.ss_mask, -float('inf'))\n",
      "/user/vbpi-gnn/vector_sbnModel.py:125: UserWarning: masked_select received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1682343970094/work/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1855.)\n",
      "  return masked_CPDs.masked_select(self.ss_mask), masked_CPDs\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "-7108.3955078125"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_from('results/DS1/gnn/ufboot_vimco_10_edge_sum_psp_2023-09-14 20:14:33.048146.pt')\n",
    "tree_model = model.tree_model\n",
    "model.lower_bound(n_particles=1000, n_runs=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7147529f",
   "metadata": {},
   "source": [
    "# GENERATE TREES "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0c8617bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "r_ds1_0 = generate_results(0)\n",
    "r_ds1_30 = generate_results(0.3)\n",
    "r_ds1_50 = generate_results(0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b853608c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9531419904261849, 0.5160804107615207, 0.32124576970252716)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "r_ds1_0['vbpi_pearson_r'], r_ds1_30['vbpi_pearson_r'], r_ds1_50['vbpi_pearson_r']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "61fada94",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "pickle.dump(r_ds1_0, open('vbpi_gnn_ds1_0_trees_result.p', 'wb'))\n",
    "pickle.dump(r_ds1_30, open('vbpi_gnn_ds1_30_trees_result.p', 'wb'))\n",
    "pickle.dump(r_ds1_50, open('vbpi_gnn_ds1_50_trees_result.p', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "734fb9fe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch] *",
   "language": "python",
   "name": "conda-env-torch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
