{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VKuaeT-2F6H9"
   },
   "source": [
    "## Reasoning Scaling Law Example Code for Training a Small Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note this is not the full research code. Only for people to get a idea of what the core code looks like. Default to run on GPU. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 36238,
     "status": "ok",
     "timestamp": 1751815271572,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "dEUhaFfkF6H_"
   },
   "outputs": [],
   "source": [
    "### import relevant packages\n",
    "\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import random\n",
    "from collections import defaultdict\n",
    "import os, json\n",
    "import copy\n",
    "import torch\n",
    "import transformers\n",
    "import matplotlib.pyplot as plt\n",
    "import itertools\n",
    "from transformers import Trainer, TrainingArguments\n",
    "from torch.utils.data import IterableDataset, get_worker_info, Dataset\n",
    "from typing import Dict, Optional, Sequence\n",
    "from sklearn.utils import shuffle\n",
    "from dataclasses import dataclass\n",
    "import torch.nn.functional as F\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from transformers import LlamaForCausalLM\n",
    "from transformers.modeling_outputs import CausalLMOutputWithPast\n",
    "from transformers.cache_utils import Cache\n",
    "from typing import List, Optional, Tuple, Union"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "gA1eQp9T731U"
   },
   "source": [
    "Helper founctions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 32,
     "status": "ok",
     "timestamp": 1751815271602,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "MtrdtzEUHtgs"
   },
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "def add_edge(G, h, t, r):\n",
    "    num_edges = 0\n",
    "    if G.has_edge(h, t):\n",
    "        if r not in G[h][t]['id']:\n",
    "            G[h][t]['id'].append(r)\n",
    "            num_edges += 1\n",
    "        else:\n",
    "            print('edge already exists')\n",
    "    else:\n",
    "        G.add_edge(h, t, id=[r])\n",
    "        num_edges += 1\n",
    "    print('add edge: ', (h, r, t), 'num edges: ', num_edges)\n",
    "    return num_edges\n",
    "\n",
    "\n",
    "def generate_rules(relations, num_rules, L_min, L_max, weighted=False, temperature=0.25):\n",
    "    # Generate K acyclic logic rules with varying lengths\n",
    "    dependency_graph = defaultdict(set)\n",
    "    rules = []\n",
    "    weights = []\n",
    "    if weighted:\n",
    "        for l in range(L_min, L_max + 1):\n",
    "            weights.append(np.exp(-temperature*l))\n",
    "        probs = np.array([w / sum(weights) for w in weights])\n",
    "    else:\n",
    "        weights = [1] * (L_max - L_min + 1)\n",
    "\n",
    "    def has_cycle(start, visited, stack):\n",
    "        \"\"\"Detects if adding a new dependency introduces a cycle.\"\"\"\n",
    "        if start not in visited:\n",
    "            visited.add(start)\n",
    "            stack.add(start)\n",
    "            print('visited: ', visited)\n",
    "            print('stack: ', stack)\n",
    "            for neighbor in dependency_graph[start]:\n",
    "                if neighbor in stack:\n",
    "                    return True\n",
    "                elif has_cycle(neighbor, visited, stack):\n",
    "                    return True\n",
    "        if start in stack:\n",
    "            stack.remove(start)\n",
    "        return False\n",
    "\n",
    "    for _ in range(num_rules):\n",
    "        while True:\n",
    "            if weighted:\n",
    "                length = random.choices(range(L_min, L_max + 1), weights=weights)[0]\n",
    "            else:\n",
    "                length = random.randint(L_min, L_max)\n",
    "            rule_relations = random.choices(relations, k = length + 1) # the first element is the implied relation\n",
    "            valid_rule = True\n",
    "            for i in range(1, len(rule_relations)):\n",
    "                dependency_graph[rule_relations[0]].add(rule_relations[i])\n",
    "\n",
    "                # Check for cycles\n",
    "                if has_cycle(rule_relations[i], set(), set()):\n",
    "                    valid_rule = False\n",
    "                    for j in range(1, i + 1):\n",
    "                        dependency_graph[rule_relations[0]].remove(rule_relations[j])\n",
    "                    break\n",
    "\n",
    "            if valid_rule:\n",
    "                rules.append(tuple(rule_relations))\n",
    "                break\n",
    "\n",
    "    print('rules: ', rules)\n",
    "    return rules\n",
    "\n",
    "def get_node_types(rules, max_num_relations_per_node=3):\n",
    "    # map node types to out relations\n",
    "    node_types = {}\n",
    "    # map out relations to node types\n",
    "    r2node_types = defaultdict(list)\n",
    "    for rule in rules:\n",
    "        for i in range(len(rule)):\n",
    "            node_type = len(node_types)\n",
    "            if i == 0:\n",
    "                node_types[node_type] = [rule[i], rule[1]]\n",
    "                r2node_types[rule[i]].append(node_type)\n",
    "                r2node_types[rule[1]].append(node_type)\n",
    "            elif i == len(rule) - 1:\n",
    "                node_types[node_type] = ['-' + rule[i], '-' + rule[0]]\n",
    "                r2node_types['-' + rule[i]].append(node_type)\n",
    "                r2node_types['-' + rule[0]].append(node_type)\n",
    "            else:\n",
    "                node_types[node_type] = ['-' + rule[i], rule[i+1]]\n",
    "                r2node_types['-' + rule[i]].append(node_type)\n",
    "                r2node_types[rule[i+1]].append(node_type)\n",
    "\n",
    "    print(node_types)\n",
    "    print(r2node_types)\n",
    "\n",
    "    for num_rs in range(2, max_num_relations_per_node):\n",
    "        possible_new_node_types = []\n",
    "        for r in r2node_types:\n",
    "            alt_rs = []\n",
    "            for node_type in r2node_types[r]:\n",
    "                for _r in node_types[node_type]:\n",
    "                    if _r != r:\n",
    "                        alt_rs.append(_r)\n",
    "            alt_rs = list(set(alt_rs))\n",
    "            for node_type in r2node_types[r]:\n",
    "                if len(node_types[node_type]) == num_rs:\n",
    "                    for _r in alt_rs:\n",
    "                        if _r not in node_types[node_type]:\n",
    "                            possible_new_node_types.append(tuple(sorted([_r] + list(node_types[node_type]))))\n",
    "            print(possible_new_node_types)\n",
    "            possible_new_node_types += list(set(possible_new_node_types))\n",
    "        possible_new_node_types = list(set(possible_new_node_types))\n",
    "        print(possible_new_node_types)\n",
    "\n",
    "        for rs in possible_new_node_types:\n",
    "            new_node_type = len(node_types)\n",
    "            node_types[new_node_type] = list(rs)\n",
    "            for _r in rs:\n",
    "                r2node_types[_r].append(new_node_type)\n",
    "\n",
    "    return node_types\n",
    "\n",
    "def get_adj_out_relations(rules):\n",
    "    adj = defaultdict(list)\n",
    "    for rule in rules:\n",
    "        for i in range(len(rule)):\n",
    "            if i == 0:\n",
    "                adj[rule[i]].append(rule[1])\n",
    "                adj[rule[1]].append(rule[i])\n",
    "            elif i == len(rule) - 1:\n",
    "                adj['-' + rule[i]].append('-' + rule[0])\n",
    "                adj['-' + rule[0]].append('-' + rule[i])\n",
    "            else:\n",
    "                adj['-' + rule[i]].append(rule[i+1])\n",
    "                adj[rule[i+1]].append('-' + rule[i])\n",
    "    return adj"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "E-JuCbm_774w"
   },
   "source": [
    "Synthetic graph generation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 114,
     "status": "ok",
     "timestamp": 1751815271726,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "58tty8NWf4D0"
   },
   "outputs": [],
   "source": [
    "def latent_rule_graph(num_rules=50, L_min=2, L_max=4, n=10000, m=10, n_r=200,\n",
    "                      num_test=1000, num_train=150000, check_frequency=100,\n",
    "                      power_law=False, initial_graph=None,\n",
    "                      length_weighted=False, mcmc=0.2, temperature=0.25,\n",
    "                      deductible_ratio=0.5):\n",
    "\n",
    "    relations = ['P' + str(i) for i in range(n_r)]\n",
    "    all_rules = generate_rules(relations, max(n_r//L_min, num_rules), L_min, L_max)\n",
    "    r2rules = {}\n",
    "    for rule in all_rules:\n",
    "        if rule[0] not in r2rules:\n",
    "            r2rules[rule[0]] = []\n",
    "        r2rules[rule[0]].append(rule[1:])\n",
    "    num_triples = 0\n",
    "    repeated_entities = defaultdict(list) # map in relation to entities\n",
    "    child_relations = []\n",
    "    for rule in all_rules:\n",
    "        child_relations += rule[1:]\n",
    "    child_relations = list(set(child_relations))\n",
    "    child_relations += ['-' + r for r in child_relations]\n",
    "    deductible_rules = random.sample(all_rules, num_rules)\n",
    "    if length_weighted:\n",
    "        weights = [int(100*np.exp(-temperature*len(rule))) for rule in all_rules]\n",
    "    else:\n",
    "        weights = [1 for _ in all_rules]\n",
    "    repeated_rules = []\n",
    "    for rule, weight in zip(all_rules, weights):\n",
    "        for _ in range(weight):\n",
    "            repeated_rules.append(rule)\n",
    "    random.shuffle(repeated_rules)\n",
    "    adj = get_adj_out_relations(repeated_rules)\n",
    "    all_deductibles = {}\n",
    "\n",
    "    if initial_graph is None:\n",
    "        # Default initial graph\n",
    "        G = nx.DiGraph()\n",
    "        node_id = 0\n",
    "        min_repeated_entities = 0\n",
    "        while min_repeated_entities < m:\n",
    "            for rule in all_rules:\n",
    "                source = 'Q' + str(node_id)\n",
    "                node_id += 1\n",
    "                h = source\n",
    "                for r in rule[1:]:\n",
    "                    t = 'Q' + str(node_id)\n",
    "                    node_id += 1\n",
    "                    num_triples += add_edge(G, h, t, r)\n",
    "                    repeated_entities[r].append(t)\n",
    "                    repeated_entities['-' + r].append(h)\n",
    "                    h = t\n",
    "                num_triples += add_edge(G, source, t, rule[0])\n",
    "                repeated_entities[rule[0]].append(t)\n",
    "                repeated_entities['-' + rule[0]].append(source)\n",
    "\n",
    "            min_repeated_entities = min([len(set(repeated_entities[r])) for r in child_relations])\n",
    "    else:\n",
    "        if len(initial_graph) < m or len(initial_graph) > n:\n",
    "            raise nx.NetworkXError(\n",
    "                f\"Initial graph needs between m={m} and n={n} nodes\"\n",
    "            )\n",
    "        G = initial_graph.copy()\n",
    "        node_id = len(G)\n",
    "\n",
    "    if not power_law:\n",
    "        repeated_entities = {r: list(set(repeated_entities[r])) for r in repeated_entities}\n",
    "\n",
    "    # adding nodes\n",
    "    while node_id < n:\n",
    "        source = 'Q' + str(node_id)\n",
    "        node_id += 1\n",
    "        possible_relations = [_r for _r in adj if _r in child_relations]\n",
    "        if len(possible_relations) == 0:\n",
    "            print('no adj relations')\n",
    "            break\n",
    "        print('add child edge')\n",
    "        chosen_edges = []\n",
    "        stop = False\n",
    "        for _ in range(m):\n",
    "            it = 0\n",
    "            while (r, t) in chosen_edges:\n",
    "                r = random.choice(possible_relations)\n",
    "                t = random.choice(repeated_entities[r])\n",
    "                it += 1\n",
    "                if it > 100:\n",
    "                    print('failed to find edge')\n",
    "                    stop = True\n",
    "                    break\n",
    "            if stop or len(possible_relations) == 0:\n",
    "                break\n",
    "\n",
    "            possible_relations = [_r for _r in adj[r] if _r in child_relations]\n",
    "            chosen_edges.append((r, t))\n",
    "            if r[0] == '-':\n",
    "                num_triples += add_edge(G, t, source, r[1:])\n",
    "                repeated_entities[r[1:]].append(source)\n",
    "            else:\n",
    "                num_triples += add_edge(G, source, t, r)\n",
    "                repeated_entities['-' + r].append(source)\n",
    "            repeated_entities[r].append(t)\n",
    "            if len(possible_relations) == 0:\n",
    "                print('no adj relations')\n",
    "                break\n",
    "\n",
    "        if not power_law:\n",
    "            repeated_entities = {r: list(set(repeated_entities[r])) for r in repeated_entities}\n",
    "\n",
    "        if node_id % check_frequency == 0 or node_id == n-1:\n",
    "            # add deductibles\n",
    "            all_nodes = list(G.nodes)\n",
    "            random.shuffle(all_nodes)\n",
    "            for h in all_nodes:\n",
    "                for rule in deductible_rules:\n",
    "                    head_list = [h]\n",
    "                    r = rule[0]\n",
    "\n",
    "                    for _r in rule[1:]:\n",
    "                        next_head_list = []\n",
    "                        for e_h in head_list:\n",
    "                            if e_h not in G.nodes:\n",
    "                                continue\n",
    "                            for e_t in G[e_h]:\n",
    "                                if _r in G[e_h][e_t]['id']:\n",
    "                                    if random.random() < mcmc:\n",
    "                                        next_head_list.append(e_t)\n",
    "                        head_list = next_head_list\n",
    "\n",
    "                    for t in head_list:\n",
    "                        if (h, r, t) not in all_deductibles:\n",
    "                            all_deductibles[(h, r, t)] = [rule]\n",
    "                        elif rule not in all_deductibles[(h, r, t)]:\n",
    "                            all_deductibles[(h, r, t)].append(rule)\n",
    "                        if not G.has_edge(h, t) or r not in G[h][t]['id']:\n",
    "                            print('add deductible edge')\n",
    "                            add_edge(G, h, t, r)\n",
    "                            num_triples += 1\n",
    "                            repeated_entities[r].append(t)\n",
    "                            repeated_entities['-' + r].append(h)\n",
    "\n",
    "    atomic_triples = []\n",
    "    deductible_triples = []\n",
    "    for h, t in G.edges:\n",
    "        for r in G[h][t]['id']:\n",
    "            if (h, r, t) not in all_deductibles:\n",
    "                atomic_triples.append((h, r, t))\n",
    "            else:\n",
    "                deductible_triples.append((h, r, t))\n",
    "    random.shuffle(atomic_triples)\n",
    "    random.shuffle(deductible_triples)\n",
    "    assert len(atomic_triples) >= int(num_train * (1-deductible_ratio))\n",
    "    assert len(deductible_triples) >= int(num_train * deductible_ratio) + 2 * num_test\n",
    "\n",
    "    remove_triples = []\n",
    "    train_atomic_triples = atomic_triples[:int(num_train * (1-deductible_ratio))]\n",
    "    remove_triples += atomic_triples[int(num_train * (1-deductible_ratio)):]\n",
    "    train_deductible_triples = deductible_triples[:int(num_train * deductible_ratio)]\n",
    "    remove_triples += deductible_triples[int(num_train * deductible_ratio):]\n",
    "\n",
    "    for h, r, t in remove_triples:\n",
    "        _t = t\n",
    "        rs = G[h][_t]['id']\n",
    "        if r in rs:\n",
    "            if len(rs) == 1:\n",
    "                G.remove_edge(h, _t)\n",
    "            else:\n",
    "                G[h][_t]['id'].remove(r)\n",
    "\n",
    "    train_triples = train_deductible_triples + train_atomic_triples\n",
    "    random.shuffle(train_triples)\n",
    "    print(\"num train triples: \", len(train_triples))\n",
    "\n",
    "    r2rule = {}\n",
    "    for rule in deductible_rules:\n",
    "        if rule[0] in r2rule:\n",
    "            r2rule[rule[0]].append(rule[1:])\n",
    "        else:\n",
    "            r2rule[rule[0]] = [rule[1:]]\n",
    "\n",
    "    def check_deductible(triple):\n",
    "        h, r, t = triple\n",
    "        alt_ts = []\n",
    "        for rule in r2rule[r]:\n",
    "            head_list = [h]\n",
    "            for _r in rule:\n",
    "                next_head_list = []\n",
    "                for e_h in head_list:\n",
    "                    for e_t in G[e_h]:\n",
    "                        if _r in G[e_h][e_t]['id']:\n",
    "                            next_head_list.append(e_t)\n",
    "                head_list = next_head_list\n",
    "            alt_ts += head_list\n",
    "        if t in alt_ts:\n",
    "            return True\n",
    "        return False\n",
    "\n",
    "    id_test_triples = []\n",
    "    for i in range(int(num_train * deductible_ratio), len(deductible_triples)):\n",
    "        if check_deductible(deductible_triples[i]):\n",
    "            id_test_triples.append(deductible_triples[i])\n",
    "        if len(id_test_triples) == num_test:\n",
    "            break\n",
    "\n",
    "    id_test_rules = [all_deductibles[triple] for triple in id_test_triples]\n",
    "    print(\"num id test triples: \", len(id_test_triples))\n",
    "\n",
    "    rule2triples = defaultdict(list)\n",
    "    for triple in deductible_triples[i+1:]:\n",
    "        for rule in all_deductibles[triple]:\n",
    "            rule2triples[rule].append(triple)\n",
    "\n",
    "    # uniformly sample testing triples from each rule\n",
    "    uniform_test_triples = []\n",
    "    for rule in rule2triples:\n",
    "        triples = []\n",
    "        for triple in rule2triples[rule]:\n",
    "            if check_deductible(triple):\n",
    "                triples.append(triple)\n",
    "\n",
    "        if len(triples) > num_test//len(rule2triples):\n",
    "            uniform_test_triples += random.sample(triples, num_test//len(rule2triples))\n",
    "        else:\n",
    "            uniform_test_triples += triples\n",
    "\n",
    "    random.shuffle(uniform_test_triples)\n",
    "    uniform_test_rules = [all_deductibles[triple] for triple in uniform_test_triples]\n",
    "    print(\"num uniform test triples: \", len(uniform_test_triples))\n",
    "\n",
    "    return G, deductible_rules, train_triples, id_test_triples, id_test_rules, uniform_test_triples, uniform_test_rules"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SxUZzyra8m7k"
   },
   "source": [
    "Data class for synthetic graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 25,
     "status": "ok",
     "timestamp": 1751815271754,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "z-tRvGo6SJbY"
   },
   "outputs": [],
   "source": [
    "class LatentRuleGraph:\n",
    "    def __init__(self,\n",
    "                 n=1000, n_r=40, m=5, n_rules=30, n_triples=10000,\n",
    "                 num_test=1000, L_min=2, L_max=4, power_law=False,\n",
    "                 length_weighted=False, mcmc=1.0,\n",
    "                 temperature=0.25, deductible_ratio=0.5, seed=42):\n",
    "        self.n = n\n",
    "        self.n_r = n_r\n",
    "        self.n_triples = n_triples\n",
    "        self.n_rules = n_rules\n",
    "        self.num_test = num_test\n",
    "        self.L_min = L_min\n",
    "        self.L_max = L_max\n",
    "        self.power_law = power_law\n",
    "        self.m = m\n",
    "        self.length_weighted = length_weighted\n",
    "        self.mcmc = mcmc\n",
    "        self.temperature = temperature\n",
    "        self.deductible_ratio = deductible_ratio\n",
    "        self.seed = seed\n",
    "        random.seed(seed)\n",
    "        self.G = nx.DiGraph()\n",
    "        self.load_data()\n",
    "        self.all_es = list(self.G.nodes)\n",
    "        self.all_rs = set()\n",
    "        for h, t, r_dict in self.G.edges(data=True):\n",
    "            for r in r_dict['id']:\n",
    "                self.all_rs.add(r)\n",
    "        self.triple_complet_file = None\n",
    "\n",
    "    def load_data(self):\n",
    "        self.triples = []\n",
    "        self.id_test_triples = []\n",
    "        self.uniform_test_triples = []\n",
    "        self.id_alt_ts = []\n",
    "        self.uniform_alt_ts = []\n",
    "        self.rules = []\n",
    "        self.id_test_rules = []\n",
    "        self.uniform_test_rules = []\n",
    "\n",
    "\n",
    "        self.G, self.rules, self.triples, \\\n",
    "        self.id_test_triples, self.id_test_rules, \\\n",
    "        self.uniform_test_triples, self.uniform_test_rules = latent_rule_graph(\n",
    "            num_rules=self.n_rules, L_min=self.L_min, L_max=self.L_max,\n",
    "            n=self.n, n_r=self.n_r, m=self.m,\n",
    "            num_test=self.num_test, num_train=self.n_triples,\n",
    "            power_law=self.power_law,\n",
    "            length_weighted=self.length_weighted, mcmc=self.mcmc,\n",
    "            deductible_ratio=self.deductible_ratio, temperature=self.temperature)\n",
    "\n",
    "        r2rule = {}\n",
    "        for rule in self.rules:\n",
    "            if rule[0] in r2rule:\n",
    "                r2rule[rule[0]].append(rule[1:])\n",
    "            else:\n",
    "                r2rule[rule[0]] = [rule[1:]]\n",
    "\n",
    "        def get_alt_ts(h, r, t):\n",
    "            alt_ts = []\n",
    "            for rule in r2rule[r]:\n",
    "                head_list = [h]\n",
    "                for _r in rule:\n",
    "                    next_head_list = []\n",
    "                    for e_h in head_list:\n",
    "                        for e_t in self.G[e_h]:\n",
    "                            if _r in self.G[e_h][e_t]['id']:\n",
    "                                next_head_list.append(e_t)\n",
    "                    head_list = next_head_list\n",
    "                alt_ts += head_list\n",
    "            return alt_ts\n",
    "\n",
    "        for h, r, t in self.id_test_triples:\n",
    "            alt_ts = get_alt_ts(h, r, t)\n",
    "            self.id_alt_ts.append(alt_ts)\n",
    "\n",
    "        for h, r, t in self.uniform_test_triples:\n",
    "            alt_ts = get_alt_ts(h, r, t)\n",
    "            self.uniform_alt_ts.append(alt_ts)\n",
    "\n",
    "        self.mem_triples = random.sample(self.triples, k=self.num_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VwfKAW2q_bF9"
   },
   "source": [
    "Create a new synthetic graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 12931,
     "status": "ok",
     "timestamp": 1751815284686,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "u_l4knpm_YIH",
    "outputId": "0c9d6138-ea3b-4427-d33d-5f6c86bd5866"
   },
   "outputs": [],
   "source": [
    "graph = LatentRuleGraph(\n",
    "        n=2000,\n",
    "        n_r=50,\n",
    "        n_triples=10000,\n",
    "        n_rules=20,\n",
    "        L_min=2,\n",
    "        L_max=4,\n",
    "        power_law=True,\n",
    "        deductible_ratio=0.5,\n",
    "        length_weighted=False,\n",
    "        m=6,\n",
    "        num_test=1000,\n",
    "        temperature=0.25,\n",
    "        mcmc=1.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0Usl6GWN8s0g"
   },
   "source": [
    "Training data class for synthetic graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 6,
     "status": "ok",
     "timestamp": 1751815284702,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "iAKmfbSrR5in"
   },
   "outputs": [],
   "source": [
    "class TrainDataset(IterableDataset):\n",
    "    \"\"\"\n",
    "    Iterable dataset that returns constant length chunks of tokens from stream of text files.\n",
    "        Args:\n",
    "            tokenizer (Tokenizer): The processor used for proccessing the data.\n",
    "            dataset (dataset.Dataset): Dataset with text files.\n",
    "            infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n",
    "            seq_length (int): Length of token sequences to return.\n",
    "            num_of_sequences (int): Number of token sequences to keep in buffer.\n",
    "            chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.\n",
    "            tokenized (bool): If true we use a pretokenized dataset.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        graph, # generated graph\n",
    "        tokenizer,\n",
    "        seq_length=256,\n",
    "        num_of_sequences=1024,\n",
    "        chars_per_token=3.6,\n",
    "        seed=42,\n",
    "    ):\n",
    "        super(TrainDataset, self).__init__()\n",
    "\n",
    "        self.tokenizer = tokenizer\n",
    "        self.seq_length = seq_length\n",
    "        self.epoch = 0\n",
    "        self.current_size = 0\n",
    "        self.num_buffer_sequences = num_of_sequences\n",
    "        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n",
    "        self.seed = seed\n",
    "        self.data = graph\n",
    "\n",
    "        print(\"max buffer size: \", self.max_buffer_size)\n",
    "\n",
    "    def set_epoch(self, worker_id):\n",
    "        set_seed(self.seed + self.epoch + worker_id) # int(time.time())\n",
    "\n",
    "    def triple2str(self, triple):\n",
    "        if type(triple[0]) == int or type(triple[1]) == int or type(triple[2]) == int:\n",
    "            return f'Q{triple[0]} P{triple[1]} Q{triple[2]}'\n",
    "        else:\n",
    "            return ' '.join(list(triple))\n",
    "\n",
    "    def iter_fun(self, worker_id=0):\n",
    "        num_sents = len(self.data.triples)\n",
    "        while True:\n",
    "            i = random.randint(0, num_sents-1)\n",
    "            triple = self.data.triples[i]\n",
    "            text = self.triple2str(triple) + '\\n'\n",
    "            if text is None:\n",
    "                print(\"cannot translate \", triple, \" into text.\")\n",
    "                continue\n",
    "            yield text\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data.triples)\n",
    "\n",
    "    def __iter__(self):\n",
    "        more_examples = True\n",
    "        try:\n",
    "            worker_info = get_worker_info()\n",
    "            print(worker_info)\n",
    "            worker_id = worker_info.id\n",
    "        except:\n",
    "            worker_id = 0\n",
    "        self.set_epoch(worker_id)\n",
    "        iterator = self.iter_fun(worker_id=worker_id)\n",
    "        print(\"worker id: \", )\n",
    "\n",
    "        while more_examples:\n",
    "            buffer, buffer_len = [], 0\n",
    "            while True:\n",
    "                if buffer_len >= self.max_buffer_size:\n",
    "                    print(\"data buffer full\")\n",
    "                    break\n",
    "                try:\n",
    "                    buffer.append(next(iterator))\n",
    "                    buffer_len += len(buffer[-1])\n",
    "                except StopIteration:\n",
    "                    self.epoch += 1\n",
    "                    self.set_epoch(worker_id)\n",
    "                    iterator = self.iter_fun()\n",
    "                    print(f\"Dataset epoch: {self.epoch}\")\n",
    "            # print(buffer[:3])\n",
    "\n",
    "            input_lens = []\n",
    "            random.shuffle(buffer)\n",
    "            tokenized_inputs = self.tokenizer(buffer,\n",
    "                                padding=False,\n",
    "                                max_length=self.seq_length,\n",
    "                                truncation=True)[\"input_ids\"]\n",
    "            for tokenized_input in tokenized_inputs:\n",
    "                input_ids = tokenized_input + [self.tokenizer.eos_token_id]\n",
    "                input_lens.append(len(input_ids))\n",
    "                self.current_size += 1\n",
    "                yield dict(input_ids=torch.tensor(input_ids), labels=torch.tensor(input_ids))\n",
    "            print(\"average example length: \", np.mean(input_lens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 8,
     "status": "ok",
     "timestamp": 1751815284717,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "B4sdu4gs_Zyq",
    "outputId": "e36eb113-b35c-49f7-ba4b-e1476a37450c"
   },
   "outputs": [],
   "source": [
    "train_dataset = TrainDataset(\n",
    "        graph,\n",
    "        tokenizer=None,\n",
    "        seq_length=128,\n",
    "        num_of_sequences=1024,\n",
    "        chars_per_token=3.6,\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "aY8D6o8N8vlQ"
   },
   "source": [
    "Tokenizer class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 46,
     "status": "ok",
     "timestamp": 1751815284764,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "JWZC4rdF8oGZ"
   },
   "outputs": [],
   "source": [
    "class BaseTokenizer:\n",
    "    \"\"\"Minimal tokenizer base class with a small Hugging Face-like interface.\n",
    "\n",
    "    Subclasses implement ``build_vocab`` (a token -> id dict that contains the\n",
    "    special tokens ``<BOS>``, ``<EOS>``, ``<PAD>`` and ``<UNK>``) and ``tokenize``\n",
    "    (text -> list of token ids).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n=1, vocab=None, padding_side='right', add_special_tokens=False):\n",
    "        \"\"\"Store the vocabulary and expose the special tokens and their ids.\n",
    "\n",
    "        Args:\n",
    "            n: Stored as ``self.n``; this base class does not use it further.\n",
    "            vocab: Token -> id dict. When ``None``, ``self.build_vocab()`` is used.\n",
    "            padding_side: ``'left'`` or ``'right'``; where batch padding goes.\n",
    "            add_special_tokens: Stored as ``self.add_special_tokens``.\n",
    "        \"\"\"\n",
    "        self.n = n\n",
    "        self.vocab = self.build_vocab() if vocab is None else vocab\n",
    "        self.rev_vocab = {v: k for k, v in self.vocab.items()}\n",
    "        self.padding_side = padding_side\n",
    "        self.add_special_tokens = add_special_tokens\n",
    "        self.bos_token = '<BOS>'\n",
    "        self.bos_token_id = self.vocab['<BOS>']\n",
    "        self.eos_token = '<EOS>'\n",
    "        self.eos_token_id = self.vocab['<EOS>']\n",
    "        self.pad_token = '<PAD>'\n",
    "        self.pad_token_id = self.vocab['<PAD>']\n",
    "        self.unk_token = '<UNK>'\n",
    "        self.unk_token_id = self.vocab['<UNK>']\n",
    "        self.all_special_ids = [self.bos_token_id, self.eos_token_id,\n",
    "                                self.pad_token_id, self.unk_token_id]\n",
    "        self.all_special_tokens = self.all_special_tokens_extended = [\n",
    "            self.bos_token, self.eos_token,\n",
    "            self.pad_token, self.unk_token]\n",
    "\n",
    "    def build_vocab(self):\n",
    "        \"\"\"Return the token -> id dict. Must be overridden by subclasses.\"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def tokenize(self, text: str, max_length: int):\n",
    "        \"\"\"Return the list of token ids for ``text``. Must be overridden by subclasses.\"\"\"\n",
    "        raise NotImplementedError\n",
    "\n",
    "    def _pad_batch(self, seqs, pad_value):\n",
    "        \"\"\"Return copies of ``seqs`` padded with ``pad_value`` to the longest length.\"\"\"\n",
    "        if self.padding_side not in ('left', 'right'):\n",
    "            raise NotImplementedError\n",
    "        target_len = max((len(seq) for seq in seqs), default=0)\n",
    "        padded = []\n",
    "        for seq in seqs:\n",
    "            filler = [pad_value] * (target_len - len(seq))\n",
    "            padded.append(filler + seq if self.padding_side == 'left' else seq + filler)\n",
    "        return padded\n",
    "\n",
    "    def encode(self, text, padding=False, max_length=1024, return_tensors=None, truncation=True):\n",
    "        \"\"\"Tokenize ``text`` (a string or an iterable of strings) into lists of ids.\n",
    "\n",
    "        Padding is only applied to batch input. ``max_length`` is forwarded to\n",
    "        ``tokenize``; ``truncation`` is accepted for interface compatibility and\n",
    "        is not read here.\n",
    "\n",
    "        Returns:\n",
    "            A list of id lists, or a tensor when ``return_tensors == 'pt'``.\n",
    "        \"\"\"\n",
    "        if isinstance(text, str):\n",
    "            ids = [self.tokenize(text, max_length)]\n",
    "        else:\n",
    "            ids = [self.tokenize(t, max_length) for t in text]\n",
    "            if padding:\n",
    "                # Build new lists: the earlier left-padding branch only rebound the\n",
    "                # loop variable, so left padding never reached the returned batch.\n",
    "                ids = self._pad_batch(ids, self.pad_token_id)\n",
    "\n",
    "        if return_tensors == 'pt':\n",
    "            ids = torch.tensor(ids)\n",
    "\n",
    "        return ids\n",
    "\n",
    "    def __call__(self, text, padding=False, max_length=1024, return_tensors=None, truncation=True, device='cpu'):\n",
    "        \"\"\"Tokenize ``text`` and also build the matching attention mask.\n",
    "\n",
    "        Returns:\n",
    "            ``{\"input_ids\": ..., \"attention_mask\": ...}`` with 1 for real tokens\n",
    "            and 0 for padding. Both values are lists of lists, or tensors moved\n",
    "            to ``device`` when ``return_tensors == 'pt'``.\n",
    "        \"\"\"\n",
    "        is_single = isinstance(text, str)\n",
    "        if is_single:\n",
    "            ids = [self.tokenize(text, max_length)]\n",
    "        else:\n",
    "            ids = [self.tokenize(t, max_length) for t in text]\n",
    "        attns = [[1] * len(seq) for seq in ids]\n",
    "\n",
    "        if padding and not is_single:\n",
    "            ids = self._pad_batch(ids, self.pad_token_id)\n",
    "            attns = self._pad_batch(attns, 0)\n",
    "\n",
    "        if return_tensors == 'pt':\n",
    "            ids = torch.tensor(ids).to(device)\n",
    "            attns = torch.tensor(attns).to(device)\n",
    "\n",
    "        return {\"input_ids\": ids, 'attention_mask': attns}\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of entries in the vocabulary.\"\"\"\n",
    "        return len(self.vocab)\n",
    "\n",
    "    def decode(self, token_ids, skip_special_tokens=False):\n",
    "        \"\"\"Turn one id, or a sequence of ids, back into text.\n",
    "\n",
    "        Decoding a sequence stops at the first ``<EOS>``; the ``<EOS>`` string is\n",
    "        emitted unless ``skip_special_tokens`` is set.\n",
    "        \"\"\"\n",
    "        if isinstance(token_ids, int):\n",
    "            return self.rev_vocab[token_ids]\n",
    "\n",
    "        pieces = []\n",
    "        for raw_id in token_ids:\n",
    "            # Tensor / numpy scalars do not work as dict keys, so normalise to int.\n",
    "            token_id = int(raw_id)\n",
    "            if token_id == self.eos_token_id:\n",
    "                if not skip_special_tokens:\n",
    "                    pieces.append(self.eos_token)\n",
    "                break\n",
    "            if skip_special_tokens and token_id in self.all_special_ids:\n",
    "                continue\n",
    "            pieces.append(self.rev_vocab[token_id])\n",
    "        return ''.join(pieces)\n",
    "\n",
    "    def batch_decode(self, sequences, skip_special_tokens=False):\n",
    "        \"\"\"Decode every sequence in ``sequences`` and return the list of strings.\"\"\"\n",
    "        return [self.decode(token_ids, skip_special_tokens) for token_ids in sequences]\n",
    "\n",
    "    def save_pretrained(self, output_dir):\n",
    "        \"\"\"Write the vocabulary to ``<output_dir>/tokenizer.json``.\"\"\"\n",
    "        with open(f'{output_dir}/tokenizer.json', 'w') as wf:\n",
    "            json.dump(self.vocab, wf, indent = 4)\n",
    "\n",
    "    @classmethod\n",
    "    def from_pretrained(cls, pretrained_model_name_or_path, padding_side='right', trust_remote_code=False, revision=None):\n",
    "        \"\"\"Load the vocabulary from ``<path>/tokenizer.json`` when that file exists.\n",
    "\n",
    "        Without the file a tokenizer with the default vocabulary is returned.\n",
    "        ``trust_remote_code`` and ``revision`` are accepted but not read here.\n",
    "        \"\"\"\n",
    "        vocab_path = f\"{pretrained_model_name_or_path}/tokenizer.json\"\n",
    "        if not os.path.exists(vocab_path):\n",
    "            return cls(padding_side=padding_side)\n",
    "\n",
    "        # Use a context manager so the file handle is closed after loading.\n",
    "        with open(vocab_path) as rf:\n",
    "            vocab = json.load(rf)\n",
    "\n",
    "        # Recover ``n`` from the non-special tokens: '<prefix>_<k>' tokens give\n",
    "        # k + 1, any other token gives its length.\n",
    "        n = 1\n",
    "        for token in vocab:\n",
    "            if token not in ['<BOS>', '<EOS>', '<PAD>', '<UNK>']:\n",
    "                if '_' in token:\n",
    "                    n = max(n, int(token.split('_')[1]) + 1)\n",
    "                else:\n",
    "                    n = max(n, len(token))\n",
    "\n",
    "        return cls(n, vocab, padding_side=padding_side)\n",
    "\n",
    "class CharTokenizer(BaseTokenizer):\n",
    "    \"\"\"Character-level tokenizer over ``Q``, ``P``, the ten digits, newline, space, ``-`` and ``?``.\"\"\"\n",
    "\n",
    "    def __init__(self, n=1, vocab=None, padding_side='right', add_special_tokens=False):\n",
    "        super().__init__(n, vocab, padding_side, add_special_tokens)\n",
    "\n",
    "    def build_vocab(self):\n",
    "        \"\"\"Return the fixed token -> id table: 16 symbols followed by the 4 special tokens.\"\"\"\n",
    "        symbols = ['Q', 'P'] + [str(digit) for digit in range(10)] + ['\\n', ' ', '-', '?']\n",
    "        specials = ['<BOS>', '<EOS>', '<PAD>', '<UNK>']\n",
    "        return {token: index for index, token in enumerate(symbols + specials)}\n",
    "\n",
    "    def tokenize(self, text: str, max_length: int):\n",
    "        \"\"\"Map ``text`` to character ids, keeping at most ``max_length`` of them.\n",
    "\n",
    "        Empty lines are skipped. Each whitespace-separated word is followed by a\n",
    "        space id and each kept line by a newline id; characters missing from the\n",
    "        vocabulary become ``<UNK>``. With ``add_special_tokens`` an ``<EOS>`` id\n",
    "        is appended, otherwise the last two ids are removed.\n",
    "        \"\"\"\n",
    "        ids = []\n",
    "        for line in text.split('\\n'):\n",
    "            if not line:\n",
    "                continue\n",
    "            for word in line.split():\n",
    "                ids.extend(self.vocab.get(ch, self.unk_token_id) for ch in word)\n",
    "                ids.append(self.vocab[' '])\n",
    "            ids.append(self.vocab['\\n'])\n",
    "\n",
    "        if self.add_special_tokens:\n",
    "            ids.append(self.vocab['<EOS>'])\n",
    "        else:\n",
    "            ids = ids[:-2]\n",
    "        return ids[:max_length]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Sdn-qc9783PF"
   },
   "source": [
    "Helper functions for training and evaluation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 22,
     "status": "ok",
     "timestamp": 1751815284797,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "5j7oSdDU-mdE"
   },
   "outputs": [],
   "source": [
    "# Label value written over padded and prompt positions so they are excluded from the loss.\n",
    "IGNORE_INDEX = -100\n",
    "# Default special-token strings.\n",
    "# NOTE(review): BOS and EOS are given the same string here; confirm this is intended.\n",
    "DEFAULT_PAD_TOKEN = \"[PAD]\"\n",
    "DEFAULT_EOS_TOKEN = \"</s>\"\n",
    "DEFAULT_BOS_TOKEN = \"</s>\"\n",
    "DEFAULT_UNK_TOKEN = \"<unk>\"\n",
    "\n",
    "def model_path_map(model_name):\n",
    "    \"\"\"Return the relative path of ``model_name`` under the ``../llms/`` directory.\"\"\"\n",
    "    base_dir = '../llms/'\n",
    "    return base_dir + model_name\n",
    "\n",
    "def count_params(model):\n",
    "    \"\"\"Count the elements of every parameter of ``model`` that requires gradients.\"\"\"\n",
    "    total = 0\n",
    "    for param in model.parameters():\n",
    "        if param.requires_grad:\n",
    "            total += param.numel()\n",
    "    return total\n",
    "\n",
    "def compute_llama_param(l, h, v):\n",
    "    \"\"\"Closed-form parameter count for ``l`` layers, ``h`` heads and vocabulary size ``v``.\n",
    "\n",
    "    The hidden size is ``64 * h``. Each layer contributes ``4 * d**2`` attention\n",
    "    weights, ``6 * d**2`` MLP weights and two norm vectors of size ``d``; one more\n",
    "    norm vector and two ``d * v`` embedding-sized matrices are added on top\n",
    "    (presumably the input embedding and the output head).\n",
    "    \"\"\"\n",
    "    hidden = 64 * h\n",
    "    attention_weights = 4 * hidden * hidden\n",
    "    mlp_weights = 2 * hidden * hidden * 3\n",
    "    per_layer = attention_weights + mlp_weights + 2 * hidden\n",
    "    embedding_weights = hidden * v\n",
    "    return l * per_layer + hidden + 2 * embedding_weights\n",
    "\n",
    "@dataclass\n",
    "class DataCollatorForSupervisedDataset(object):\n",
    "    \"\"\"Pad a list of ``input_ids`` / ``labels`` examples into one training batch.\"\"\"\n",
    "\n",
    "    tokenizer: transformers.PreTrainedTokenizer\n",
    "\n",
    "    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"Return padded ``input_ids``, ``labels`` and the matching ``attention_mask``.\n",
    "\n",
    "        Inputs are padded with the tokenizer's pad id and labels with\n",
    "        ``IGNORE_INDEX``; the mask is True wherever ``input_ids`` is not the pad id.\n",
    "        \"\"\"\n",
    "        pad = torch.nn.utils.rnn.pad_sequence\n",
    "        id_seqs = [example[\"input_ids\"] for example in instances]\n",
    "        label_seqs = [example[\"labels\"] for example in instances]\n",
    "\n",
    "        batch_ids = pad(id_seqs, batch_first=True, padding_value=self.tokenizer.pad_token_id)\n",
    "        batch_labels = pad(label_seqs, batch_first=True, padding_value=IGNORE_INDEX)\n",
    "\n",
    "        return {\n",
    "            \"input_ids\": batch_ids,\n",
    "            \"labels\": batch_labels,\n",
    "            \"attention_mask\": batch_ids.ne(self.tokenizer.pad_token_id),\n",
    "        }\n",
    "\n",
    "\n",
    "def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer, max_length: int) -> Dict:\n",
    "    \"\"\"Tokenize every string on its own and report the ids with their non-pad lengths.\n",
    "\n",
    "    ``input_ids`` and ``labels`` refer to the same list, as do ``input_ids_lens``\n",
    "    and ``labels_lens``.\n",
    "    \"\"\"\n",
    "    encoded_batches = []\n",
    "    for text in strings:\n",
    "        encoded = tokenizer(\n",
    "            text,\n",
    "            return_tensors=\"pt\",\n",
    "            padding=\"longest\",\n",
    "            max_length=max_length,\n",
    "            truncation=True,\n",
    "        )\n",
    "        encoded_batches.append(encoded[\"input_ids\"])\n",
    "\n",
    "    # Each call encodes a single string, so row 0 is the whole sequence.\n",
    "    sequences = [batch[0] for batch in encoded_batches]\n",
    "    lengths = [batch.ne(tokenizer.pad_token_id).sum().item() for batch in encoded_batches]\n",
    "    return {\n",
    "        \"input_ids\": sequences,\n",
    "        \"labels\": sequences,\n",
    "        \"input_ids_lens\": lengths,\n",
    "        \"labels_lens\": lengths,\n",
    "    }\n",
    "\n",
    "\n",
    "def prepare_data(\n",
    "    sources: Sequence[str],\n",
    "    targets: Sequence[str],\n",
    "    tokenizer: transformers.PreTrainedTokenizer,\n",
    "    max_length: int,\n",
    ") -> Dict:\n",
    "    \"\"\"Tokenize ``source + target`` pairs and mask the source part of the labels.\n",
    "\n",
    "    Every example gets the tokenizer's EOS id appended. In the labels, the first\n",
    "    ``len(tokenized source)`` positions are set to ``IGNORE_INDEX``.\n",
    "    \"\"\"\n",
    "    full_texts = [src + tgt for src, tgt in zip(sources, targets)]\n",
    "    full_tokenized = _tokenize_fn(full_texts, tokenizer, max_length)\n",
    "    source_tokenized = _tokenize_fn(sources, tokenizer, max_length)\n",
    "\n",
    "    eos_tensor = torch.tensor([tokenizer.eos_token_id])\n",
    "    input_ids = []\n",
    "    for seq in full_tokenized[\"input_ids\"]:\n",
    "        input_ids.append(torch.cat((seq, eos_tensor)))\n",
    "\n",
    "    labels = copy.deepcopy(input_ids)\n",
    "    for seq_labels, prompt_len in zip(labels, source_tokenized[\"input_ids_lens\"]):\n",
    "        seq_labels[:prompt_len] = IGNORE_INDEX\n",
    "    return {\"input_ids\": input_ids, \"labels\": labels}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_4SzSxLe8_Dh"
   },
   "source": [
    "Training function (set bf16 to False if you are not using GPUs):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 82,
     "status": "ok",
     "timestamp": 1751815284880,
     "user": {
      "displayName": "Xinyi Wang",
      "userId": "00375459372959943015"
     },
     "user_tz": -480
    },
    "id": "2PaLLB1bR5si"
   },
   "outputs": [],
   "source": [
    "def train(train_dataset, model_name_or_path='llama-2-2', random_initialize=True,\n",
    "          output_dir='.', bf16=True, device='cuda'):\n",
    "    \"\"\"Train a causal language model on ``train_dataset`` with the Hugging Face ``Trainer``.\n",
    "\n",
    "    Args:\n",
    "        train_dataset: Dataset to train on. Its ``tokenizer`` attribute is\n",
    "            overwritten with the tokenizer created here.\n",
    "        model_name_or_path: With ``random_initialize=True`` a spec of the form\n",
    "            ``'<arch>-<layers>-<heads>'`` (only ``llama`` is supported, hidden size\n",
    "            is ``64 * heads``). Otherwise a directory from which the tokenizer and\n",
    "            weights are loaded and training is resumed.\n",
    "        random_initialize: Build a fresh model from a config instead of loading weights.\n",
    "        output_dir: Where the trainer state and the final model are written.\n",
    "        bf16: Passed to ``TrainingArguments``; it also selects the dtype of a\n",
    "            randomly initialized model (bfloat16 when True, float32 otherwise).\n",
    "        device: Device the loaded model is moved to when ``random_initialize`` is False.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, tokenizer)``.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: If the architecture in ``model_name_or_path`` is not ``llama``.\n",
    "    \"\"\"\n",
    "\n",
    "    set_seed(42) # make sure use the same model initialization\n",
    "    l, h, v = None, None, None\n",
    "\n",
    "    if random_initialize:\n",
    "        print(\"Random initializing...\")\n",
    "        model_name, l, h = model_name_or_path.split('-')\n",
    "        l, h = int(l), int(h)\n",
    "        d = 64 * h\n",
    "        if model_name == 'llama':\n",
    "            config = transformers.LlamaConfig(hidden_size=d,\n",
    "                                            intermediate_size=2*d,\n",
    "                                            num_attention_heads=h,\n",
    "                                            num_hidden_layers=l)\n",
    "        else:\n",
    "            # ``NotImplemented`` is a constant, not an exception; raising it is a TypeError.\n",
    "            raise NotImplementedError(f\"unsupported model architecture: {model_name!r}\")\n",
    "\n",
    "        tokenizer = CharTokenizer()\n",
    "        config.vocab_size = len(tokenizer.vocab)\n",
    "        config.bos_token_id = tokenizer.bos_token_id\n",
    "        config.eos_token_id = tokenizer.eos_token_id\n",
    "        print(\"vocab size: \", len(tokenizer.vocab))\n",
    "        print(\"new config: \", config)\n",
    "\n",
    "        v = config.vocab_size\n",
    "        # Follow the ``bf16`` flag so that bf16=False really yields a float32 model.\n",
    "        model_dtype = torch.bfloat16 if bf16 else torch.float32\n",
    "        model = transformers.AutoModelForCausalLM.from_config(config, torch_dtype=model_dtype)\n",
    "        print(\"embedding size: \", model.get_input_embeddings().weight.data.shape)\n",
    "    else:\n",
    "        print(\"Using pre-trained model weights...\")\n",
    "\n",
    "        tokenizer = CharTokenizer.from_pretrained(model_name_or_path)\n",
    "\n",
    "        model = transformers.AutoModelForCausalLM.from_pretrained(\n",
    "            model_name_or_path\n",
    "        ).to(device)\n",
    "\n",
    "    if l is not None and h is not None and v is not None:\n",
    "        print(\"theoretical # params: \", compute_llama_param(l, h, v))\n",
    "    print(\"actual # params: \", count_params(model))\n",
    "\n",
    "    # The dataset is created without a tokenizer; attach the one built above.\n",
    "    train_dataset.tokenizer = tokenizer\n",
    "\n",
    "    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)\n",
    "\n",
    "    train_args = TrainingArguments(bf16=bf16, max_steps=1000,\n",
    "                      per_device_train_batch_size=32, eval_strategy=\"no\",\n",
    "                      save_steps=1000, save_total_limit=1, learning_rate=1e-4,\n",
    "                      weight_decay=0.0, warmup_ratio=0.2, lr_scheduler_type=\"cosine\",\n",
    "                      logging_steps=1, output_dir=output_dir, report_to=\"none\")\n",
    "\n",
    "    trainer = Trainer(model=model, tokenizer=tokenizer, args=train_args,\n",
    "                    train_dataset=train_dataset, data_collator=data_collator,\n",
    "                    eval_dataset=None)\n",
    "\n",
    "    if not random_initialize:\n",
    "        print(\"resume training from: \", model_name_or_path)\n",
    "        trainer.train(model_name_or_path)\n",
    "    else:\n",
    "        trainer.train()\n",
    "    trainer.save_state()\n",
    "    trainer.save_model(output_dir=output_dir)\n",
    "    return model, tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8NWaGVfcH0uu"
   },
   "source": [
    "Train a 2-layer language model on the generated synthetic graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "collapsed": true,
    "id": "HJjCPy8zBts8",
    "outputId": "7eb080e0-5998-4976-9ea9-d01c5e815b87"
   },
   "outputs": [],
   "source": [
    "# Uses train()'s defaults: a randomly initialized 'llama-2-2' model (2 layers, 2 attention heads) with bf16 enabled.\n",
    "model, tokenizer = train(train_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "bLkdCY9CHUvr"
   },
   "source": [
    "Evaluation data class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Tps9kqoIFuf8"
   },
   "outputs": [],
   "source": [
    "class EvalDataset(Dataset):\n",
    "    \"\"\"Multiple-choice evaluation examples built from a graph's test triples.\n",
    "\n",
    "    Indexing returns ``[question, triple, seen_ts, options, path_length]``, where\n",
    "    ``question`` is ``h + ' ' + r + ' '`` and ``options`` holds the true tail plus\n",
    "    ``num_options - 1`` sampled negatives in shuffled order.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                graph,\n",
    "                tokenizer,\n",
    "                split=\"id\", # or \"uniform\", or \"mem\"\n",
    "                num_options=10,\n",
    "                use_rule_length=False,\n",
    "                seed=42):\n",
    "        \"\"\"Select the triples for ``split`` and build questions, options and path lengths.\n",
    "\n",
    "        Args:\n",
    "            graph: Graph object. The attributes read here are ``num_test``, ``G``,\n",
    "                ``all_es``, the per-split triples / alternative tails and, when\n",
    "                ``use_rule_length`` is set, ``test_rules``.\n",
    "            tokenizer: Stored on the instance; only its ``eos_token`` is read here.\n",
    "            split: ``\"id\"``, ``\"uniform\"`` or ``\"mem\"``.\n",
    "            num_options: Number of answer options per question.\n",
    "            use_rule_length: Take path lengths from ``graph.test_rules`` instead of\n",
    "                shortest paths in ``graph.G``.\n",
    "            seed: Passed to ``set_seed`` before the negatives are sampled.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: If ``split`` is not one of the values above.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.split = split\n",
    "        self.tokenizer = tokenizer\n",
    "        self.eos_token = self.tokenizer.eos_token\n",
    "        self.num_options = num_options\n",
    "        self.num_test = graph.num_test\n",
    "        set_seed(seed)\n",
    "        self.path_length = []\n",
    "\n",
    "        self.data = graph\n",
    "        if split == 'id':\n",
    "            self.triples = self.data.id_test_triples\n",
    "            self.alt_ts = self.data.id_alt_ts\n",
    "        elif split == 'uniform':\n",
    "            self.triples = self.data.uniform_test_triples\n",
    "            self.alt_ts = self.data.uniform_alt_ts\n",
    "        elif split == \"mem\":\n",
    "            # NOTE(review): this used to read ``self.mem_triples``, which is never set on\n",
    "            # this object and therefore always failed. Reading from the graph mirrors the\n",
    "            # other splits -- confirm the attribute name on the graph class.\n",
    "            self.triples = self.data.mem_triples\n",
    "        else:\n",
    "            print(\"no such split: \", split)\n",
    "            raise NotImplementedError\n",
    "\n",
    "        if use_rule_length:\n",
    "            print(\"using rule length\")\n",
    "            self.path_length = [min([len(rule) - 1 for rule in rules]) for rules in self.data.test_rules]\n",
    "\n",
    "        self.get_data()\n",
    "        if len(self.path_length) == 0:\n",
    "            self.get_path_length()\n",
    "\n",
    "    def get_path_length(self):\n",
    "        \"\"\"Append the shortest-path length from head to tail for every input triple.\n",
    "\n",
    "        A length of 0 is recorded (and a message printed) when no path exists or a\n",
    "        node is missing from the graph.\n",
    "        \"\"\"\n",
    "        for h, r, t in self.input_triples:\n",
    "            try:\n",
    "                l = nx.shortest_path_length(self.data.G, source=h, target=t)\n",
    "            except (nx.NetworkXNoPath, nx.NodeNotFound):\n",
    "                # Only the two documented lookup failures are tolerated; anything\n",
    "                # else is a real error and should surface.\n",
    "                l = 0\n",
    "                print(f'cannot find shortest path between {h} and {t}')\n",
    "            self.path_length.append(l)\n",
    "        print(\"avg path length: \", np.mean(self.path_length))\n",
    "\n",
    "    def get_data(self):\n",
    "        \"\"\"Build the question text, answer options and seen tails for every triple.\"\"\"\n",
    "        self.input_text = []\n",
    "        self.input_triples = []\n",
    "        self.seen_ts = []\n",
    "        self.options = []\n",
    "\n",
    "        for idx, triple in enumerate(self.triples):\n",
    "            h, r, t = triple\n",
    "\n",
    "            if self.split == \"mem\":\n",
    "                # Collect every neighbour of ``h`` whose edge carries relation ``r``.\n",
    "                seen_ts = []\n",
    "                if h in self.data.G:\n",
    "                    for e in self.data.G[h]:\n",
    "                        if r in self.data.G[h][e]['id']:\n",
    "                            seen_ts.append(e)\n",
    "            else:\n",
    "                seen_ts = self.alt_ts[idx]\n",
    "            self.seen_ts.append(seen_ts)\n",
    "\n",
    "            question = h + ' ' + r + ' '\n",
    "            ans = t\n",
    "\n",
    "            # Negatives are resampled until they differ from the answer and from\n",
    "            # every seen tail.\n",
    "            options = [ans]\n",
    "            for _ in range(self.num_options-1):\n",
    "                neg_e = random.choice(self.data.all_es)\n",
    "                while neg_e == ans or neg_e in seen_ts:\n",
    "                    neg_e = random.choice(self.data.all_es)\n",
    "                options.append(neg_e)\n",
    "\n",
    "            self.input_text.append(question)\n",
    "            random.shuffle(options)\n",
    "            self.options.append(options)\n",
    "\n",
    "            self.input_triples.append(triple)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of evaluation questions.\"\"\"\n",
    "        return len(self.input_text)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        \"\"\"Return ``[question, triple, seen_ts, options, path_length]`` for item ``i``.\"\"\"\n",
    "        return [self.input_text[i], self.input_triples[i], self.seen_ts[i],\n",
    "                self.options[i], self.path_length[i]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WlBrHd-iHZMd"
   },
   "source": [
    "Create an evaluation dataset based on the synthetic graph generated before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ksgtknd-PpkC"
   },
   "outputs": [],
   "source": [
    "# Defaults apply: split='id', 10 answer options per question, seed=42.\n",
    "eval_dataset = EvalDataset(graph, tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "zp1GYdroJrho"
   },
   "source": [
    "The evaluation function (it defaults to `device=\"cuda\"`; pass `device=\"cpu\"` if you do not have a GPU):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7bJvX4JlJWA0"
   },
   "outputs": [],
   "source": [
    "def eval(eval_dataset, model, batch_size=16, max_length=64, num_test=1000, device=\"cuda\", tokenizer=None):\n",
    "    \"\"\"Score ``model`` on a multiple-choice ``eval_dataset``.\n",
    "\n",
    "    Each question is paired with every one of its options. The option whose\n",
    "    tokens receive the lowest summed cross-entropy is the prediction; question\n",
    "    positions carry ``IGNORE_INDEX`` labels and contribute no loss.\n",
    "\n",
    "    Args:\n",
    "        eval_dataset: Dataset yielding ``(question, triple, seen_ts, options, path_length)``\n",
    "            and exposing ``num_options``, ``num_test`` and ``tokenizer``.\n",
    "        model: Causal LM; it is called with ``input_ids``, ``labels`` and ``attention_mask``.\n",
    "        batch_size: Minimum number of (question, option) rows per forward pass.\n",
    "        max_length: Forwarded to ``prepare_data``.\n",
    "        num_test: Upper bound on the number of questions evaluated.\n",
    "        device: Device the batches are moved to.\n",
    "        tokenizer: Tokenizer to use; defaults to ``eval_dataset.tokenizer``.\n",
    "\n",
    "    Returns:\n",
    "        ``(accuracy, mean loss of the correct option)``; both are NaN when no\n",
    "        question was evaluated.\n",
    "    \"\"\"\n",
    "    if tokenizer is None:\n",
    "        # Avoid depending on a notebook-global ``tokenizer`` variable.\n",
    "        tokenizer = eval_dataset.tokenizer\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    def collect_data(instances, device='cuda'):\n",
    "        \"\"\"Pad the prepared examples into one batch and move it to ``device``.\"\"\"\n",
    "        input_ids = torch.nn.utils.rnn.pad_sequence(\n",
    "            instances[\"input_ids\"], batch_first=True, padding_value=tokenizer.pad_token_id\n",
    "        )\n",
    "        labels = torch.nn.utils.rnn.pad_sequence(instances[\"labels\"], batch_first=True,\n",
    "                                                padding_value=IGNORE_INDEX)\n",
    "        attn_mask = input_ids.ne(tokenizer.pad_token_id)\n",
    "\n",
    "        return dict(\n",
    "            input_ids=input_ids.to(device),\n",
    "            labels=labels.to(device),\n",
    "            attention_mask=attn_mask.to(device),\n",
    "        )\n",
    "\n",
    "    num_choices = eval_dataset.num_options\n",
    "    print(\"number of choices: \", num_choices)\n",
    "\n",
    "    # Number of questions to evaluate. The final batch is flushed on the last one;\n",
    "    # the earlier check compared a 1-based counter against ``len - 1`` and could\n",
    "    # leave the last question unevaluated.\n",
    "    limit = min(num_test, eval_dataset.num_test, len(eval_dataset))\n",
    "\n",
    "    input_texts = []\n",
    "    output_texts = []\n",
    "    gts = []\n",
    "    num_correct = 0\n",
    "    num_all = 0\n",
    "    losses = []\n",
    "    acc, mean_loss = float('nan'), float('nan')\n",
    "    loss_fct = CrossEntropyLoss(reduction='none')\n",
    "\n",
    "    for example_idx in range(limit):\n",
    "        q, triple, seen_t, opts, l = eval_dataset[example_idx]\n",
    "        print(q, triple, seen_t, opts, l)\n",
    "        label = triple[-1]\n",
    "        for op in opts:\n",
    "            input_texts.append(q)\n",
    "            output_texts.append(op)\n",
    "            gts.append(op == label)\n",
    "\n",
    "        is_last = example_idx == limit - 1\n",
    "        if len(input_texts) >= batch_size or is_last:\n",
    "            data_dict = prepare_data(input_texts, output_texts, tokenizer, max_length)\n",
    "            input_data = collect_data(data_dict, device)\n",
    "            print(\"input data shape: \", input_data['input_ids'].shape)\n",
    "            with torch.no_grad():\n",
    "                logits = model(**input_data).logits\n",
    "            # float32 keeps the loss usable with numpy even for a bfloat16 model.\n",
    "            logits = logits.detach().float().cpu()\n",
    "            print(\"logits shape: \", logits.shape)\n",
    "            labels = input_data['labels'].cpu()\n",
    "            # Position t predicts token t + 1.\n",
    "            shift_logits = logits[..., :-1, :].contiguous()\n",
    "            shift_labels = labels[..., 1:].contiguous()\n",
    "            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),\n",
    "                            shift_labels.view(-1))\n",
    "            loss = loss.view(labels.size(0), labels.size(1) - 1).sum(-1)\n",
    "            for i in range(len(input_texts)//num_choices):\n",
    "                group = slice(i*num_choices, (i+1)*num_choices)\n",
    "                pred = torch.argmin(loss[group]).item()\n",
    "                gt = gts[group].index(True)\n",
    "                losses.append(loss[group][gt].item())\n",
    "                if pred == gt:\n",
    "                    num_correct += 1\n",
    "                num_all += 1\n",
    "\n",
    "            acc = num_correct/num_all\n",
    "            print(\"Accuracy: \", acc)\n",
    "            mean_loss = np.mean(losses)\n",
    "            print(\"Loss: \", mean_loss)\n",
    "\n",
    "            input_texts = []\n",
    "            output_texts = []\n",
    "            gts = []\n",
    "\n",
    "    return acc, mean_loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VQRbkUPfOj46"
   },
   "source": [
    "Evaluate the previously trained language model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lWS9yJEaHTBO"
   },
   "outputs": [],
   "source": [
    "# Runs with eval()'s defaults: batch_size=16, max_length=64, num_test=1000, device='cuda'.\n",
    "acc, loss = eval(eval_dataset, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D_FCtkTJPI7e"
   },
   "source": [
    "Sweep function that trains and evaluates a range of model sizes, then plots the loss curve and the accuracy curve (needs GPUs to run):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pDvf2laCOiLW"
   },
   "outputs": [],
   "source": [
    "def _plot_scaling_curve(log_sizes, sizes, values, ylabel, filename):\n",
    "    \"\"\"Plot ``values`` against log model size, save the figure to ``filename`` and show it.\"\"\"\n",
    "    plt.plot(log_sizes, values, 'o-', linewidth=2)\n",
    "    plt.xticks(log_sizes, (sizes/1000000).round(1), size=8)\n",
    "    plt.xlabel('Llama model size (M)', size=12)\n",
    "    plt.ylabel(ylabel, size=12)\n",
    "    # Save before show(): after the figure has been shown, savefig() can write an empty canvas.\n",
    "    plt.savefig(filename)\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "def scaling_sweep(train_dataset, eval_dataset, y_axis='loss'):\n",
    "    \"\"\"Train and evaluate a grid of Llama sizes, then plot accuracy and loss against size.\n",
    "\n",
    "    For ``i`` in 1..7 a model with ``2**i`` layers is trained for each head count\n",
    "    ``2**j`` with ``j`` in ``{i-1, i, i+1}`` restricted to 1..7. The two curves are\n",
    "    written to ``acc.png`` and ``loss.png`` in the working directory.\n",
    "\n",
    "    Args:\n",
    "        train_dataset: Passed to ``train`` for every configuration.\n",
    "        eval_dataset: Passed to ``eval`` for every trained model.\n",
    "        y_axis: Not read by this function; kept so existing calls stay valid.\n",
    "    \"\"\"\n",
    "    accs = []\n",
    "    losses = []\n",
    "    sizes = []\n",
    "    for i in range(1, 8):\n",
    "        for j in (i-1, i, i+1):\n",
    "            if 0 < j < 8:\n",
    "                model, tokenizer = train(train_dataset, f'llama-{2**i}-{2**j}')\n",
    "                acc, loss = eval(eval_dataset, model)\n",
    "                # Size the model with the tokenizer's real vocabulary instead of a\n",
    "                # hard-coded 19 (CharTokenizer.build_vocab creates 20 entries).\n",
    "                num_p = compute_llama_param(2**i, 2**j, len(tokenizer.vocab))\n",
    "                accs.append(acc)\n",
    "                losses.append(loss)\n",
    "                sizes.append(num_p)\n",
    "\n",
    "    # Order every series by model size so the curves run left to right.\n",
    "    idx = np.argsort(sizes)\n",
    "    accs = np.array(accs)[idx]\n",
    "    losses = np.array(losses)[idx]\n",
    "    sizes = np.array(sizes)[idx]\n",
    "    log_sizes = np.log(sizes)\n",
    "    print('log sizes: ', log_sizes)\n",
    "    print('accs: ', accs)\n",
    "    print('losses: ', losses)\n",
    "\n",
    "    _plot_scaling_curve(log_sizes, sizes, accs, 'Accuracy', 'acc.png')\n",
    "    _plot_scaling_curve(log_sizes, sizes, losses, 'Loss', 'loss.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l7h6QycrSe82"
   },
   "outputs": [],
   "source": [
    "# Trains and evaluates every configuration in the sweep, then writes acc.png and loss.png.\n",
    "scaling_sweep(train_dataset, eval_dataset)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "nanotron",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
