{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import sympy\n",
    "import uuid\n",
    "\n",
    "from collections import namedtuple\n",
    "from datasets.base import KnownEquation\n",
    "from datasets.sampling import DefaultSampling\n",
    "from eq.conversion import sympy2sequence, sequence2model\n",
    "from eq.eval import convert_pred_sequence_to_eqs\n",
    "from eq.vocabulary import SOS_TOKEN\n",
    "from sympy.core import numbers\n",
    "from sympy.core import power\n",
    "from torchdistill.common.file_util import get_file_path_list, make_parent_dirs\n",
    "from torchdistill.common.yaml_util import load_yaml_file\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Variables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input locations: the dataset config (relative to the working directory) and the reference equations\n",
    "DATASET_CONFIG_FILE_PATH = os.path.expanduser('./configs/datasets/random/feynman_lm.yaml')\n",
    "REF_EQ_DIR_PATH = os.path.expanduser('~/dataset/symbolic_regression/proposed/random_split/full_set/true_eq/')\n",
    "\n",
    "# n-gram language model settings\n",
    "NGRAM = 2\n",
    "NULL_TOKEN = 'Null'  # padding placed in front of SOS_TOKEN in the initial context\n",
    "TOKEN_DELIMITER = ' '  # joins context tokens into a language-model key\n",
    "\n",
    "# Generation settings\n",
    "MAX_SEQ_LENGTH = 30  # max number of tokens sampled per equation\n",
    "MAX_NUM_EQS = 5000  # number of generation attempts\n",
    "NUM_VARIATIONS = 10  # dataset variations attempted per generated equation\n",
    "\n",
    "# Seed the global generators so that Restart & Run All reproduces the same run.\n",
    "# NOTE(review): assumes the project's samplers draw from `random` / `np.random`; confirm.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "# Entry of a cumulative distribution: `value` is the cumulative probability of token `op_str`\n",
    "Node = namedtuple('Node', ['value', 'op_str'])\n",
    "\n",
    "\n",
    "# Split ratios, normalized so that they sum to one\n",
    "TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 8, 1, 1\n",
    "total = TRAIN_RATIO + VAL_RATIO + TEST_RATIO\n",
    "TRAIN_RATIO /= total\n",
    "VAL_RATIO /= total\n",
    "TEST_RATIO /= total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_eq_tree_seqs(eq_dir_path):\n",
    "    \"\"\"Load pickled equations from a directory and convert each one to a sequence.\n",
    "\n",
    "    Args:\n",
    "        eq_dir_path: directory containing pickled sympy expressions.\n",
    "\n",
    "    Returns:\n",
    "        A list with one `sympy2sequence(..., returns_binary_tree=True)` result per file,\n",
    "        in the order given by `get_file_path_list(..., is_sorted=True)`.\n",
    "    \"\"\"\n",
    "    sequences = []\n",
    "    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.\n",
    "    for path in get_file_path_list(eq_dir_path, is_sorted=True):\n",
    "        with open(path, 'rb') as pickle_file:\n",
    "            sympy_expr = pickle.load(pickle_file)\n",
    "        # Constants are evaluated numerically before the conversion\n",
    "        sequences.append(sympy2sequence(sympy_expr.evalf(), returns_binary_tree=True))\n",
    "    return sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert every pickled reference equation under REF_EQ_DIR_PATH into its sequence form.\n",
    "ref_eq_tree_seqs = load_eq_tree_seqs(REF_EQ_DIR_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build n-gram language model using naive Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_freq_dict(eq_tree_seq, ngram, freq_dict):\n",
    "    \"\"\"Add the n-gram counts of one token sequence to `freq_dict` in place.\n",
    "\n",
    "    Args:\n",
    "        eq_tree_seq: iterable of tokens.\n",
    "        ngram: n-gram order; each token is counted under the `ngram - 1` tokens before it.\n",
    "        freq_dict: dict mapping a context string to a dict of `token -> count`.\n",
    "    \"\"\"\n",
    "    # The window starts as padding followed by the start-of-sequence token\n",
    "    window = [NULL_TOKEN] * (ngram - 1) + [SOS_TOKEN]\n",
    "    for token in eq_tree_seq:\n",
    "        # Drop the oldest token so that the window holds only the context of `token`\n",
    "        del window[0]\n",
    "        context_key = TOKEN_DELIMITER.join(window)\n",
    "        window.append(token)\n",
    "        token_counts = freq_dict.setdefault(context_key, dict())\n",
    "        token_counts[token] = token_counts.get(token, 0) + 1\n",
    "\n",
    "\n",
    "def build_ngram_lm(eq_tree_seqs, ngram):\n",
    "    \"\"\"Build an n-gram language model as cumulative distributions over next tokens.\n",
    "\n",
    "    Args:\n",
    "        eq_tree_seqs: iterable of token sequences.\n",
    "        ngram: n-gram order passed on to `update_freq_dict`.\n",
    "\n",
    "    Returns:\n",
    "        A dict mapping each context string to a list of `Node(value, op_str)` sorted by\n",
    "        descending count, where `value` is the cumulative probability up to and including\n",
    "        `op_str`. The last node of every list has `value == 1.0`.\n",
    "    \"\"\"\n",
    "    freq_dict = dict()\n",
    "    for eq_tree_seq in eq_tree_seqs:\n",
    "        update_freq_dict(eq_tree_seq, ngram, freq_dict)\n",
    "\n",
    "    ngram_lm = dict()\n",
    "    for given_str, sub_dict in freq_dict.items():\n",
    "        denominator = sum(sub_dict.values())\n",
    "        # The sort is stable, so tokens with equal counts keep their first-seen order\n",
    "        ranked = sorted(sub_dict.items(), key=lambda item: item[1], reverse=True)\n",
    "        # Accumulate integer counts and divide once per node: summing rounded probabilities\n",
    "        # can leave the last value slightly below 1.0, and then a draw from random.random()\n",
    "        # may match no node at all.\n",
    "        cumulative_count = 0\n",
    "        node_list = list()\n",
    "        for token, count in ranked:\n",
    "            cumulative_count += count\n",
    "            node_list.append(Node(cumulative_count / denominator, token))\n",
    "        ngram_lm[given_str] = node_list\n",
    "    return ngram_lm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Estimate the next-token distributions of order NGRAM from the reference sequences.\n",
    "ngram_lm = build_ngram_lm(ref_eq_tree_seqs, NGRAM)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Randomly generate equations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_random_equation(ngram_lm, ngram, max_seq_length, token_delimiter):\n",
    "    \"\"\"Sample tokens from the n-gram model until they form a complete equation.\n",
    "\n",
    "    Args:\n",
    "        ngram_lm: dict mapping a context string to a non-empty list of `Node(value, op_str)`\n",
    "            whose `value` fields are cumulative probabilities in ascending order.\n",
    "        ngram: n-gram order the model was built with.\n",
    "        max_seq_length: maximum number of tokens to sample.\n",
    "        token_delimiter: delimiter used to join context tokens into a model key.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(sympy expression, list of sampled tokens)` as soon as the sampled tokens\n",
    "        convert to a model with an empty parent stack, or None if the context is unknown\n",
    "        to the model or no such sequence is reached within `max_seq_length` tokens.\n",
    "    \"\"\"\n",
    "    ngram_list = [NULL_TOKEN] * (ngram - 1) + [SOS_TOKEN]\n",
    "    op_list = list()\n",
    "    for _ in range(max_seq_length):\n",
    "        ngram_list.pop(0)\n",
    "        given_str = token_delimiter.join(ngram_list)\n",
    "        if given_str not in ngram_lm:\n",
    "            return None\n",
    "\n",
    "        random_nodes = ngram_lm[given_str]\n",
    "        random_value = random.random()\n",
    "        # Take the first node whose cumulative probability exceeds the draw. If rounding left\n",
    "        # the last cumulative value below 1.0 and nothing matches, fall back to the last node:\n",
    "        # skipping the token would leave the context window one token short.\n",
    "        op_str = next((node.op_str for node in random_nodes if random_value < node.value),\n",
    "                      random_nodes[-1].op_str)\n",
    "        op_list.append(op_str)\n",
    "        ngram_list.append(op_str)\n",
    "\n",
    "        try:\n",
    "            random_sr_model, parent_stack = sequence2model(op_list, returns_parent_stack=True)\n",
    "            if len(parent_stack) == 0:\n",
    "                sympy_eq_str = random_sr_model.sympy_str()\n",
    "                random_eq = sympy.sympify(sympy_eq_str)\n",
    "                return random_eq, op_list\n",
    "        except Exception:\n",
    "            # Best effort: the sequence sampled so far could not be converted, keep sampling.\n",
    "            # Only Exception is caught so that a kernel interrupt still stops the loop.\n",
    "            continue\n",
    "    return None\n",
    "\n",
    "\n",
    "def reindex_variables(op_list):\n",
    "    \"\"\"Rename variable tokens so that their indices are contiguous and start at zero.\n",
    "\n",
    "    Tokens of the form `x<digits>` are ranked by numeric index, and the k-th smallest\n",
    "    distinct index is renamed to `x<k>`. All other tokens are left untouched.\n",
    "\n",
    "    Args:\n",
    "        op_list: list of token strings.\n",
    "\n",
    "    Returns:\n",
    "        A new list of tokens with the variables renamed.\n",
    "    \"\"\"\n",
    "    # A set is required here: a variable that occurs several times must be ranked only once,\n",
    "    # otherwise the new indices skip values (e.g. x3, x3, x7 would become x1, x1, x2).\n",
    "    var_indices = sorted({int(op[1:]) for op in op_list if op.startswith('x') and op[1:].isdigit()})\n",
    "    var_dict = {f'x{var_index}': f'x{rank}' for rank, var_index in enumerate(var_indices)}\n",
    "    return [var_dict.get(op, op) for op in op_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████| 5000/5000 [00:32<00:00, 156.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4501 random equations generated, and 1841 of them are unique w.r.t. their equation tree\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Keep every completed equation, and track the distinct (variable-reindexed) token sequences\n",
    "random_eq_list = []\n",
    "random_eq_set = set()\n",
    "for _ in tqdm(range(MAX_NUM_EQS)):\n",
    "    output = generate_random_equation(ngram_lm, NGRAM, MAX_SEQ_LENGTH, TOKEN_DELIMITER)\n",
    "    if output is None:\n",
    "        # No complete equation came out of this attempt\n",
    "        continue\n",
    "\n",
    "    random_eq, op_list = output\n",
    "    random_eq_list.append(random_eq)\n",
    "    random_eq_set.add('\\t'.join(reindex_variables(op_list)))\n",
    "\n",
    "print(f'{len(random_eq_list)} random equations generated, and {len(random_eq_set)} of them are unique w.r.t. their equation tree')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create datasets using the randomly generated equations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def generate_random_sampling_objs(num_vars):\n",
    "    \"\"\"Create one randomly configured `DefaultSampling` object per variable.\n",
    "\n",
    "    For each variable an integer exponent `e` is drawn uniformly from [-32, 32], and the\n",
    "    sampling object is built with the bounds `10 ** (e - 1)` and `10 ** (e + 1)`;\n",
    "    `uses_negative` is enabled with probability 0.5.\n",
    "\n",
    "    Args:\n",
    "        num_vars: number of sampling objects to create.\n",
    "\n",
    "    Returns:\n",
    "        A list of `num_vars` `DefaultSampling` objects.\n",
    "    \"\"\"\n",
    "    def _random_sampling_obj():\n",
    "        exponent = random.randint(-32, 32)\n",
    "        allows_negative = random.random() < 0.5\n",
    "        lower = np.power(10.0, exponent - 1)\n",
    "        upper = np.power(10.0, exponent + 1)\n",
    "        return DefaultSampling(lower, upper, uses_negative=allows_negative)\n",
    "\n",
    "    return [_random_sampling_obj() for _ in range(num_vars)]\n",
    "\n",
    "\n",
    "def random_init_constants(random_eq, sub_eq=None, parent_op=None):\n",
    "    \"\"\"Recursively replace the float constants of an expression by random values.\n",
    "\n",
    "    Args:\n",
    "        random_eq: sympy expression that receives the substitutions.\n",
    "        sub_eq: sub-expression currently visited; None means the root, i.e. `random_eq`.\n",
    "        parent_op: expression whose arguments contain `sub_eq` (None for the root).\n",
    "\n",
    "    Returns:\n",
    "        The expression after substituting every visited `Float`. A float whose parent is\n",
    "        a `Pow` becomes an integer-valued float with magnitude in [2, 5] and a random sign;\n",
    "        any other float becomes `random.random() * 10 ** u` with `u` uniform in [-32, 32].\n",
    "    \"\"\"\n",
    "    node = random_eq if sub_eq is None else sub_eq\n",
    "\n",
    "    if isinstance(node, numbers.Float):\n",
    "        # Drawn even when it is overridden below\n",
    "        new_value = random.random() * math.pow(10, random.uniform(-32, 32))\n",
    "        if isinstance(parent_op, power.Pow):\n",
    "            magnitude = random.randint(2, 5)\n",
    "            sign = -1 if random.random() < 0.5 else 1\n",
    "            new_value = float(magnitude * sign)\n",
    "        # subs replaces every occurrence of this float value in the whole expression\n",
    "        random_eq = random_eq.subs(node, new_value)\n",
    "\n",
    "    for child in node.args:\n",
    "        random_eq = random_init_constants(random_eq, child, node)\n",
    "    return random_eq\n",
    "\n",
    "\n",
    "def split_dataset(dataset, train_ratio, val_ratio, test_ratio):\n",
    "    \"\"\"Split a sequence into consecutive train, validation and test parts.\n",
    "\n",
    "    The train and validation sizes are the normalized ratios times the number of samples,\n",
    "    truncated to integers; the test split receives all remaining samples.\n",
    "\n",
    "    Args:\n",
    "        dataset: sliceable sequence of samples supporting `len`.\n",
    "        train_ratio: relative size of the train split.\n",
    "        val_ratio: relative size of the validation split.\n",
    "        test_ratio: relative size of the test split (only used for normalization).\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(train, val, test)` in which an empty split is represented by None.\n",
    "    \"\"\"\n",
    "    ratio_sum = train_ratio + val_ratio + test_ratio\n",
    "    num_samples = len(dataset)\n",
    "    num_train_samples = int(train_ratio / ratio_sum * num_samples)\n",
    "    num_val_samples = int(val_ratio / ratio_sum * num_samples)\n",
    "    val_end = num_train_samples + num_val_samples\n",
    "    num_test_samples = num_samples - val_end\n",
    "\n",
    "    train_dataset = val_dataset = test_dataset = None\n",
    "    if num_train_samples > 0:\n",
    "        train_dataset = dataset[:num_train_samples]\n",
    "    if num_val_samples > 0:\n",
    "        val_dataset = dataset[num_train_samples:val_end]\n",
    "    if num_test_samples > 0:\n",
    "        test_dataset = dataset[-num_test_samples:]\n",
    "    return train_dataset, val_dataset, test_dataset\n",
    "\n",
    "\n",
    "def generate_dataset(eq_instance, eq_name, dataset_config, default_train_ratio, default_val_ratio, default_test_ratio):\n",
    "    \"\"\"Sample a tabular dataset from an equation and write its splits and ground truth to disk.\n",
    "\n",
    "    Args:\n",
    "        eq_instance: object providing `create_dataset(sample_size)` and a `sympy_eq` attribute.\n",
    "        eq_name: file stem used for every output file.\n",
    "        dataset_config: dict with the keys 'sample_size', 'output_dir' and 'output_ext', and\n",
    "            optionally 'train_ratio', 'val_ratio', 'test_ratio' and 'output_delim'.\n",
    "        default_train_ratio: train ratio used when the config has no 'train_ratio'.\n",
    "        default_val_ratio: validation ratio used when the config has no 'val_ratio'.\n",
    "        default_test_ratio: test ratio used when the config has no 'test_ratio'.\n",
    "\n",
    "    Returns:\n",
    "        True if the files were written, False if the dataset could not be created\n",
    "        (nothing is written in that case).\n",
    "    \"\"\"\n",
    "    # Generate tabular dataset\n",
    "    try:\n",
    "        dataset = eq_instance.create_dataset(dataset_config['sample_size'])\n",
    "    except Exception:\n",
    "        # Best effort: the equation is skipped when it cannot produce a dataset. Only\n",
    "        # Exception is caught so that KeyboardInterrupt / SystemExit still stop a long run.\n",
    "        return False\n",
    "\n",
    "    train_ratio = dataset_config.get('train_ratio', default_train_ratio)\n",
    "    val_ratio = dataset_config.get('val_ratio', default_val_ratio)\n",
    "    test_ratio = dataset_config.get('test_ratio', default_test_ratio)\n",
    "    train_dataset, val_dataset, test_dataset = split_dataset(dataset, train_ratio, val_ratio, test_ratio)\n",
    "\n",
    "    # Write out each split\n",
    "    output_dir_path = os.path.expanduser(dataset_config['output_dir'])\n",
    "    output_ext = dataset_config['output_ext']\n",
    "    delimiter = dataset_config.get('output_delim', '\\t' if output_ext == '.tsv' else ' ')\n",
    "    for sub_dataset, split_name in zip((train_dataset, val_dataset, test_dataset), ('train', 'val', 'test')):\n",
    "        # Empty splits are returned as None and produce no file\n",
    "        if sub_dataset is None:\n",
    "            continue\n",
    "\n",
    "        output_file_path = os.path.join(output_dir_path, split_name, eq_name + output_ext)\n",
    "        make_parent_dirs(output_file_path)\n",
    "        # Save tabular dataset\n",
    "        np.savetxt(output_file_path, sub_dataset, delimiter=delimiter)\n",
    "\n",
    "    # Save ground-truth sympy expression\n",
    "    pickle_file_path = os.path.join(output_dir_path, 'true_eq', eq_name + '.pkl')\n",
    "    make_parent_dirs(pickle_file_path)\n",
    "    with open(pickle_file_path, 'wb') as fp:\n",
    "        pickle.dump(eq_instance.sympy_eq, fp)\n",
    "    return True\n",
    "\n",
    "\n",
    "def generate_datasets_from_eq(random_eq, num_trials, base_random_eq_name, dataset_config, train_ratio, val_ratio, test_ratio):\n",
    "    \"\"\"Create several dataset variations of one equation and count the successful ones.\n",
    "\n",
    "    Each trial re-randomizes the constants of the equation produced by the previous trial,\n",
    "    draws fresh sampling objects and writes the result under the name\n",
    "    `<base_random_eq_name>-<trial index>`.\n",
    "\n",
    "    Args:\n",
    "        random_eq: sympy expression to start from.\n",
    "        num_trials: number of variations to attempt.\n",
    "        base_random_eq_name: name prefix shared by all variations.\n",
    "        dataset_config: dataset settings forwarded to `generate_dataset`.\n",
    "        train_ratio: default train ratio forwarded to `generate_dataset`.\n",
    "        val_ratio: default validation ratio forwarded to `generate_dataset`.\n",
    "        test_ratio: default test ratio forwarded to `generate_dataset`.\n",
    "\n",
    "    Returns:\n",
    "        The number of trials for which `generate_dataset` reported success.\n",
    "    \"\"\"\n",
    "    # The variable count is taken once, from the equation as passed in\n",
    "    num_vars = len(random_eq.free_symbols)\n",
    "    outcomes = []\n",
    "    for trial_index in range(num_trials):\n",
    "        random_eq = random_init_constants(random_eq)\n",
    "        sampling_objs = generate_random_sampling_objs(num_vars)\n",
    "        eq_instance = KnownEquation.from_sympy_eq(random_eq, sampling_objs, reindexes=True)\n",
    "        trial_name = f'{base_random_eq_name}-{trial_index}'\n",
    "        outcomes.append(generate_dataset(eq_instance, trial_name, dataset_config,\n",
    "                                         train_ratio, val_ratio, test_ratio))\n",
    "    return sum(1 for succeeded in outcomes if succeeded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset settings read by generate_dataset (e.g. sample_size, output_dir, output_ext).\n",
    "dataset_config = load_yaml_file(DATASET_CONFIG_FILE_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 4501/4501 [1:51:55<00:00,  1.49s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24232 datasets were created.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Attempt NUM_VARIATIONS dataset variations for every generated equation and count the successes\n",
    "success_count = 0\n",
    "for i, random_eq in enumerate(tqdm(random_eq_list)):\n",
    "    base_eq_name = f'random-{i}'\n",
    "    num_created = generate_datasets_from_eq(random_eq, NUM_VARIATIONS, base_eq_name, dataset_config,\n",
    "                                            TRAIN_RATIO, VAL_RATIO, TEST_RATIO)\n",
    "    success_count = success_count + num_created\n",
    "\n",
    "print(f'{success_count} datasets were created.')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "symbolic-regression-for-mis",
   "language": "python",
   "name": "symbolic-regression-for-mis"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
