{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "### Discrete CoT Model, MNNS (Minimum Non-Negative Sum) Task"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e4b6f22056f671bb"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from transformers import GPT2Config, GPT2LMHeadModel\n",
    "from torch.optim import AdamW\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "# seed everything for reproducibility\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.cuda.manual_seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)  # for multi-GPU\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7026ad87a5345b9c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def tokenize_line(line, token2id_map):\n",
    "    tokens = line.strip().split()\n",
    "    return [token2id_map[t] for t in tokens]\n",
    "\n",
    "class MathExpressionDataset(Dataset):\n",
    "    def __init__(self, tokenized_samples, max_len, token2id):\n",
    "        self.samples = tokenized_samples\n",
    "        self.max_len = max_len\n",
    "        self.token2id = token2id\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.samples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        token_ids = self.samples[idx]\n",
    "        # Truncate if longer than max_len\n",
    "        if len(token_ids) > self.max_len:\n",
    "            token_ids = token_ids[:self.max_len]\n",
    "\n",
    "        # Create attention mask\n",
    "        attention_mask = [1] * len(token_ids)\n",
    "\n",
    "        # Pad if shorter\n",
    "        while len(token_ids) < self.max_len:\n",
    "            token_ids.append(self.token2id[\"<PAD>\"])\n",
    "            attention_mask.append(0)\n",
    "\n",
    "        input_ids = torch.tensor(token_ids, dtype=torch.long)\n",
    "        attention_mask = torch.tensor(attention_mask, dtype=torch.long)\n",
    "\n",
    "        # For a causal LM, labels are the same as input_ids\n",
    "        return {\n",
    "            \"input_ids\": input_ids,\n",
    "            \"attention_mask\": attention_mask,\n",
    "            \"labels\": input_ids.clone()\n",
    "        }"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "bf6d35d52085955f"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "def generate_vocab():\n",
    "    special_tokens = [\"<PAD>\", \"<BOS>\", \"<EOS>\", \"->\", \"+\", \"-\"]\n",
    "\n",
    "    # Digits 0..9\n",
    "    digit_tokens = [f\"D{i}\" for i in range(10)]\n",
    "\n",
    "    # Partial sums from -36..36\n",
    "    sum_tokens = [f\"S{i}\" for i in range(-36, 37)]\n",
    "\n",
    "    # Combine them into a single list\n",
    "    vocab = special_tokens + digit_tokens + sum_tokens\n",
    "\n",
    "    # Create mapping from token to ID and back\n",
    "    token2id = {token: idx for idx, token in enumerate(vocab)}\n",
    "    id2token = {idx: token for token, idx in token2id.items()}\n",
    "\n",
    "    return vocab, token2id, id2token\n",
    "\n",
    "def generate_text_dataset(digit_range=range(1, 6), seq_length=4):\n",
    "    dataset_text = []\n",
    "\n",
    "    for seq in itertools.product(digit_range, repeat=seq_length):\n",
    "        best_final_sum = None\n",
    "        best_partial_sums = None\n",
    "\n",
    "        # Try all sign patterns for the 4 digits (2^4 = 16),\n",
    "        # starting partial_sum = 0, then apply +/- for each digit in seq.\n",
    "        for signs in itertools.product([\"+\", \"-\"], repeat=seq_length):\n",
    "            partial_sum = 0\n",
    "            partial_sums = [partial_sum]  # [0, x, y, ...]\n",
    "\n",
    "            for i in range(seq_length):\n",
    "                if signs[i] == \"+\":\n",
    "                    partial_sum += seq[i]\n",
    "                else:\n",
    "                    partial_sum -= seq[i]\n",
    "                partial_sums.append(partial_sum)\n",
    "\n",
    "            # Check if final sum is >= 0\n",
    "            if partial_sum >= 0:\n",
    "                # If this is the first non-negative final sum found\n",
    "                # or if it's smaller than our current best\n",
    "                if best_final_sum is None or partial_sum < best_final_sum:\n",
    "                    best_final_sum = partial_sum\n",
    "                    best_partial_sums = partial_sums\n",
    "\n",
    "        # If we found at least one sign pattern that yields a non-negative sum,\n",
    "        # record the best partial sums in textual form.\n",
    "        if best_final_sum is not None:\n",
    "            digit_seq_tokens = [f\"D{d}\" for d in seq]\n",
    "            # omit the initial partial sum (index 0) so only 4 sums remain\n",
    "            # best_partial_sums has length 5 => skip best_partial_sums[0]\n",
    "            sum_seq_tokens = [f\"S{ps}\" for ps in best_partial_sums[1:]]\n",
    "\n",
    "            line_tokens = [\"<BOS>\"] + digit_seq_tokens + [\"->\"] + sum_seq_tokens + [\"<EOS>\"]\n",
    "            line_text = \" \".join(line_tokens)\n",
    "            dataset_text.append(line_text)\n",
    "    \n",
    "    return dataset_text"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "862dfac05b8ea74b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def encode_prompt(digit_seq, token2id):\n",
    "    # We'll build a prompt like: <BOS> D5 D3 D2 D4 ->\n",
    "    tokens = [\"<BOS>\"] + [f\"D{d}\" for d in digit_seq] + [\"->\"]\n",
    "    return [token2id[t] for t in tokens]\n",
    "\n",
    "def decode_tokens(token_ids, id2token):\n",
    "    return [id2token[i] for i in token_ids]"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "95156cd7bd05ee64"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_token_loss(outputs, labels, seq_length):\n",
    "    # Minimal additional code to get token-level losses:\n",
    "    logits = outputs.logits\n",
    "    # Shift the logits and labels by one for causal LM:\n",
    "    shifted_logits = logits[..., :-1, :].contiguous()\n",
    "    shifted_labels = labels[..., 1:].contiguous()\n",
    "\n",
    "    loss_fct = nn.CrossEntropyLoss(reduction=\"none\")\n",
    "    per_token_loss = loss_fct(\n",
    "        shifted_logits.view(-1, shifted_logits.size(-1)),\n",
    "        shifted_labels.view(-1)\n",
    "    )\n",
    "    # Reshape into (batch_size, sequence_length - 1) for easy interpretation\n",
    "    per_token_loss = per_token_loss.view(shifted_labels.size())\n",
    "    return per_token_loss[:, seq_length+1:2*seq_length+1].sum(dim=0).cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "def permutation_train_val_split(dataset_text, train_ratio=0.8):\n",
    "    \"\"\"\n",
    "    This function takes a dataset of text lines, groups them by the sorted digits\n",
    "    contained in each line, shuffles the groups, and splits them into training\n",
    "    and validation sets based on a specified ratio.\n",
    "    \"\"\"\n",
    "    # Group lines by sorted digits\n",
    "    groups = defaultdict(list)\n",
    "    for line in dataset_text:\n",
    "        tokens = line.split()\n",
    "        arrow_idx = tokens.index(\"->\")\n",
    "        digit_tokens = tokens[1:arrow_idx]  # ignoring <BOS>\n",
    "        # parse digits from lines like \"D5\"\n",
    "        digits = tuple(sorted(int(dt[1:]) for dt in digit_tokens))\n",
    "        groups[digits].append(line)\n",
    "\n",
    "    # Shuffle the group keys\n",
    "    group_keys = list(groups.keys())\n",
    "    random.shuffle(group_keys)\n",
    "\n",
    "    # Split group keys 80/20\n",
    "    split_idx = int(train_ratio * len(group_keys))\n",
    "    train_keys = group_keys[:split_idx]\n",
    "    val_keys = group_keys[split_idx:]\n",
    "\n",
    "    # Gather lines\n",
    "    train_lines = []\n",
    "    val_lines = []\n",
    "    for k in train_keys:\n",
    "        train_lines.extend(groups[k])\n",
    "    for k in val_keys:\n",
    "        val_lines.extend(groups[k])\n",
    "\n",
    "    return train_lines, val_lines"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7c7bb3c2a6c91b8e"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def parse_line(line):\n",
    "    tokens = line.split()  # e.g. [\"<BOS>\", \"D5\", \"D3\", \"D2\", \"D4\", \"->\", \"S5\", \"S2\", \"S4\", \"S0\", \"<EOS>\"]\n",
    "    arrow_idx = tokens.index(\"->\")  # location of '->'\n",
    "\n",
    "    # digits: everything after <BOS> up to '->'\n",
    "    digit_tokens = tokens[1:arrow_idx]  # ignore <BOS>\n",
    "    # partial sums: everything after '->' up to <EOS>\n",
    "    sum_tokens = tokens[arrow_idx + 1:-1]  # ignore <EOS>\n",
    "\n",
    "    # Convert \"D5\" -> 5, \"S5\" -> 5, etc.\n",
    "    digits = [int(dt[1:]) for dt in digit_tokens]  # strip the first char 'D'\n",
    "    partial_sums = [int(st[1:]) for st in sum_tokens]  # strip the first char 'S'\n",
    "    return digits, partial_sums\n",
    "\n",
    "def generate_partial_sums_step_by_step(\n",
    "    model,\n",
    "    digits,\n",
    "    token2id,\n",
    "    id2token,\n",
    "    device,\n",
    "    max_sums=4,\n",
    "    do_sample=False,\n",
    "    temperature=1.0\n",
    "):\n",
    "    # Build the initial prompt (no partial sums yet).\n",
    "    # Example: \"<BOS> D5 D3 D2 D4 ->\"\n",
    "    prompt_tokens = [\"<BOS>\"] + [f\"D{d}\" for d in digits] + [\"->\"]\n",
    "\n",
    "    # Convert each token string to its ID.\n",
    "    input_ids = torch.tensor([[token2id[t] for t in prompt_tokens]], dtype=torch.long).to(device)\n",
    "\n",
    "    predicted_sums = []\n",
    "    for _ in range(max_sums):\n",
    "        # Generate exactly 1 token from the model (greedy).\n",
    "        # pad_token_id is important to avoid warnings if the sequence grows.\n",
    "        out = model.generate(\n",
    "            input_ids=input_ids,\n",
    "            max_new_tokens=1,\n",
    "            do_sample=do_sample,\n",
    "            pad_token_id=token2id[\"<PAD>\"],\n",
    "            temperature=temperature\n",
    "        )\n",
    "        # The last generated token is out[0, -1].\n",
    "        new_token_id = out[0, -1].item()\n",
    "        new_token_str = id2token[new_token_id]\n",
    "\n",
    "        # If it looks like \"Sxxx\", parse out the integer value, else store None.\n",
    "        if new_token_str.startswith(\"S\"):\n",
    "            val = int(new_token_str[1:])  # e.g. \"S5\" -> 5\n",
    "        else:\n",
    "            val = None\n",
    "\n",
    "        predicted_sums.append(val)\n",
    "\n",
    "        # Update input_ids to include the newly generated token\n",
    "        input_ids = out\n",
    "\n",
    "    return predicted_sums\n",
    "\n",
    "\n",
    "def evaluate_model(model, test_dataset_text, token2id, id2token, device):\n",
    "    \"\"\"\n",
    "    For each line in the test dataset:\n",
    "      - Parse out the digits and the final 'ground-truth' partial sums.\n",
    "      - Generate partial sums step by step.\n",
    "      - Check if the final predicted sum == the final ground-truth sum:\n",
    "    \"\"\"\n",
    "    correct_count = 0\n",
    "\n",
    "    for line in test_dataset_text:\n",
    "        digits, gt_sums = parse_line(line)\n",
    "        # ground_truth_final = minimal non-negative sum (or whatever is in the dataset)\n",
    "        ground_truth_final = gt_sums[-1]\n",
    "\n",
    "        # Model's predicted partial sums\n",
    "        predicted_sums = generate_partial_sums_step_by_step(\n",
    "            model, digits, token2id, id2token, device, max_sums=len(gt_sums)\n",
    "        )\n",
    "\n",
    "        # Check validity of partial sums and final sum\n",
    "        # if is_valid_path(digits, predicted_sums) and (predicted_sums[-1] == ground_truth_final):\n",
    "        #     correct_count += 1\n",
    "        \n",
    "        if predicted_sums[-1] == ground_truth_final:  # For fair comparison, compare the last token only instead\n",
    "            correct_count += 1\n",
    "\n",
    "    accuracy = correct_count / len(test_dataset_text)\n",
    "    return accuracy, correct_count"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ed81a99d5f080d53"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "SEQ_LENGTH = 4\n",
    "MAX_SEQ_LEN = 2 * SEQ_LENGTH + 3\n",
    "EMBEDDING_DIM = 24\n",
    "DIGIT_RANGE = range(1, 10) # 1..9 digits\n",
    "BATCH_SIZE = 16\n",
    "SPLIT_METHOD = \"random_permutation\"  \n",
    "NUM_EPOCHS = 1000\n",
    "OUTPUT_DIR = \"moss-test\"\n",
    "NUM_LAYERS = 2\n",
    "NUM_HEADS = 2\n",
    "\n",
    "print(f\"Max Seq Len: {MAX_SEQ_LEN}, \"\n",
    "          f\"Embedding Dim: {EMBEDDING_DIM}, \"\n",
    "          f\"Digit Range: {DIGIT_RANGE}, \"\n",
    "          f\"Batch Size: {BATCH_SIZE}, \"\n",
    "          f\"Seq Length: {SEQ_LENGTH}, \"\n",
    "          f\"Split Method: {SPLIT_METHOD}, \"\n",
    "          f\"Num Epochs: {NUM_EPOCHS}, \"\n",
    "          f\"Output Dir: {OUTPUT_DIR}, \"\n",
    "          f\"Num Layers: {NUM_LAYERS}, \"\n",
    "          f\"Num Heads: {NUM_HEADS}\")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "2c168d1b3f1d183a"
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### Train for 1000 epochs on the MNNS task; we'll compare the final validation accuracies."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "59229e5dae3c631c"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Generate vocab, dataset\n",
    "vocab, token2id, id2token = generate_vocab()\n",
    "dataset_text = generate_text_dataset(DIGIT_RANGE, SEQ_LENGTH)\n",
    "vocab_size = len(vocab)\n",
    "\n",
    "print(f\"Vocab size = {vocab_size}\")\n",
    "print(f\"Number of valid sequences in dataset: {len(dataset_text)}\")\n",
    "print(\"Sample line:\", dataset_text[0])\n",
    "\n",
    "# Group-level split: all permutations of a digit multiset land in the same split.\n",
    "train_lines, val_lines = permutation_train_val_split(dataset_text, train_ratio=0.8)\n",
    "train_data = [tokenize_line(line, token2id) for line in train_lines]\n",
    "val_data = [tokenize_line(line, token2id) for line in val_lines]\n",
    "random.shuffle(train_data)\n",
    "random.shuffle(val_data)\n",
    "\n",
    "# Create the model configuration\n",
    "config = GPT2Config(\n",
    "    vocab_size=vocab_size,\n",
    "    n_positions=MAX_SEQ_LEN,\n",
    "    n_embd=EMBEDDING_DIM,\n",
    "    n_layer=NUM_LAYERS,\n",
    "    n_head=NUM_HEADS\n",
    ")\n",
    "\n",
    "# Instantiate the model\n",
    "model = GPT2LMHeadModel(config)\n",
    "\n",
    "# Create datasets and loaders\n",
    "train_dataset = MathExpressionDataset(train_data, max_len=MAX_SEQ_LEN, token2id=token2id)\n",
    "val_dataset = MathExpressionDataset(val_data, max_len=MAX_SEQ_LEN, token2id=token2id)\n",
    "\n",
    "# Create DataLoader for training and validation\n",
    "train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)\n",
    "\n",
    "optimizer = AdamW(model.parameters(), lr=1e-4)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "train_losses = []\n",
    "train_losses_tokens = []\n",
    "val_losses = []\n",
    "val_losses_tokens = []\n",
    "val_accuracies = []\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    model.train()\n",
    "    total_loss = 0.\n",
    "    # Running sum (over examples) of the loss at each partial-sum position.\n",
    "    train_loss_tokens = np.zeros((SEQ_LENGTH, ))\n",
    "    for batch in train_loader:\n",
    "        input_ids = batch[\"input_ids\"].to(device)\n",
    "        attention_mask = batch[\"attention_mask\"].to(device)\n",
    "        labels = batch[\"labels\"].to(device)\n",
    "\n",
    "        outputs = model(\n",
    "            input_ids=input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            labels=labels\n",
    "        )\n",
    "        loss = outputs.loss\n",
    "\n",
    "        # Calculate token-level loss (summed over the batch)\n",
    "        train_loss_tokens += get_token_loss(outputs, labels, SEQ_LENGTH)\n",
    "\n",
    "        # Backpropagation\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "\n",
    "    avg_train_loss = total_loss / len(train_loader)\n",
    "    # get_token_loss sums over each batch, so normalise by the true number of\n",
    "    # examples; len(loader) * BATCH_SIZE over-counts when the last batch is partial.\n",
    "    train_loss_tokens /= len(train_dataset)\n",
    "    train_losses.append(avg_train_loss)\n",
    "    train_losses_tokens.append(train_loss_tokens.tolist())\n",
    "\n",
    "    # Validation\n",
    "    model.eval()\n",
    "    val_loss = 0.0\n",
    "    val_loss_tokens = np.zeros((SEQ_LENGTH, ))\n",
    "    with torch.no_grad():\n",
    "        for batch in val_loader:\n",
    "            input_ids = batch[\"input_ids\"].to(device)\n",
    "            attention_mask = batch[\"attention_mask\"].to(device)\n",
    "            labels = batch[\"labels\"].to(device)\n",
    "\n",
    "            outputs = model(\n",
    "                input_ids=input_ids,\n",
    "                attention_mask=attention_mask,\n",
    "                labels=labels\n",
    "            )\n",
    "            val_loss += outputs.loss.item()\n",
    "\n",
    "            # Calculate token-level validation loss (summed over the batch)\n",
    "            val_loss_tokens += get_token_loss(outputs, labels, SEQ_LENGTH)\n",
    "\n",
    "    avg_val_loss = val_loss / len(val_loader)\n",
    "    val_loss_tokens /= len(val_dataset)  # per-example mean, as for training above\n",
    "    val_losses_tokens.append(val_loss_tokens.tolist())\n",
    "    val_losses.append(avg_val_loss)\n",
    "\n",
    "    # Final-answer accuracy with step-by-step greedy decoding (model is in eval mode here).\n",
    "    val_accuracy, _ = evaluate_model(model, val_lines, token2id, id2token, device)\n",
    "    val_accuracies.append(val_accuracy)\n",
    "\n",
    "    print(f\"Epoch {epoch + 1} | \"\n",
    "          f\"Train Loss: {avg_train_loss:.4f} | \"\n",
    "          f\"Val Loss: {avg_val_loss:.4f} | \"\n",
    "          f\"Val Accuracy: {val_accuracy:.2%} | \"\n",
    "          f\"Val Loss Tokens: \" + \"-\".join([f\"{val_loss_tokens[i]:.4f}\" for i in range(len(val_loss_tokens))]) + \" | \"\n",
    "          f\"Train Loss Tokens: \" + \"-\".join([f\"{train_loss_tokens[i]:.4f}\" for i in range(len(train_loss_tokens))]) + \" | \")\n",
    "\n",
    "# Epoch axis and file-name stem for plotting / saving the curves\n",
    "# (presumably consumed by a later cell; nothing is plotted here).\n",
    "epochs_range = range(1, NUM_EPOCHS+1)\n",
    "file_name = (\n",
    "    f\"Digit{DIGIT_RANGE.start}-{DIGIT_RANGE.stop}\"\n",
    "    f\"_Seq{SEQ_LENGTH}_Emb{EMBEDDING_DIM}_Split{SPLIT_METHOD}\"\n",
    "    f\"_Batch{BATCH_SIZE}_Epochs{NUM_EPOCHS}\"\n",
    ")"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "77e2ca43e0c590fa"
  },
  {
   "cell_type": "markdown",
   "source": [
    "### CoT2 Model, MNNS Task\n",
    "We will compare the final performance of this model against the discrete model."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "3e2f2a4028e391c1"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from torch.utils.data import Dataset\n",
    "\n",
    "class FourDigitsSoftDataset(Dataset):\n",
    "    \"\"\"\n",
    "    For each 4-digit sequence in [1..5], we:\n",
    "      1) Start from partial_sum=0.\n",
    "      2) For i in [1..4], we expand partial sums by +/- seq[i],\n",
    "         resulting in 2^i partial sums, each with uniform probability 1/(2^i).\n",
    "         Build a dist_vec for step i, setting dist_vec[S{ps}] = 1/(2^i) if \"S{ps}\" in vocab.\n",
    "      3) Among those 16 final sums, pick the min non-negative sum as final_hard_label = S{best_sum}.\n",
    "      4) If no non-negative sum is found, skip the sequence.\n",
    "    \"\"\"\n",
    "    def __init__(self, token2id, digit_range=range(1, 6), seq_length=4):\n",
    "        super().__init__()\n",
    "        self.token2id = token2id\n",
    "        self.vocab_size = len(token2id)\n",
    "        self.examples = []\n",
    "\n",
    "        for seq in itertools.product(digit_range, repeat=seq_length):\n",
    "            # We'll accumulate partial sums at each step,\n",
    "            # always starting from 0 for step 0.\n",
    "            partial_sums_at_step = []\n",
    "            current_partial_sums = [0]  # step 0\n",
    "            partial_sums_at_step.append(current_partial_sums)\n",
    "\n",
    "            # Build dist_steps (4 steps, each distribution over S{ps})\n",
    "            dist_steps = []\n",
    "\n",
    "            for i in range(seq_length):\n",
    "                # Expand to 2^i+1 partial sums\n",
    "                new_sums = []\n",
    "                digit = seq[i]\n",
    "                for ps in current_partial_sums:\n",
    "                    new_sums.append(ps + digit)\n",
    "                    new_sums.append(ps - digit)\n",
    "                current_partial_sums = new_sums\n",
    "                partial_sums_at_step.append(current_partial_sums)\n",
    "\n",
    "                # Now build a distribution vector: each sum has probability 1/2^(i+1)\n",
    "                dist_vec = torch.zeros(self.vocab_size)\n",
    "                prob = 1.0 / len(current_partial_sums)  # = 1/(2^(i+1))\n",
    "                for ps_val in current_partial_sums:\n",
    "                    key = f\"S{ps_val}\"\n",
    "                    if key in self.token2id:\n",
    "                        dist_vec[self.token2id[key]] += prob\n",
    "                # dist_steps.append(torch.sqrt(dist_vec))\n",
    "                dist_steps.append(dist_vec) \n",
    "\n",
    "            # current_partial_sums now has 16 final sums (2^4)\n",
    "            # pick smallest non-negative final sum\n",
    "            final_sums = current_partial_sums\n",
    "            best_sum = None\n",
    "            for candidate in sorted(final_sums):\n",
    "                if candidate >= 0:\n",
    "                    best_sum = candidate\n",
    "                    break\n",
    "\n",
    "            if best_sum is None:\n",
    "                # skip if no non-negative sum\n",
    "                continue\n",
    "\n",
    "            # final label => S{best_sum}\n",
    "            final_label_str = f\"S{best_sum}\"\n",
    "            if final_label_str not in self.token2id:\n",
    "                # skip if not in vocab\n",
    "                continue\n",
    "            final_label_id = self.token2id[final_label_str]\n",
    "\n",
    "            # Build the prompt: <BOS> + digits\n",
    "            prompt_tokens = [\"<BOS>\"] + [f\"D{d}\" for d in seq]\n",
    "            prompt_ids = []\n",
    "            for pt in prompt_tokens:\n",
    "                if pt in self.token2id:\n",
    "                    prompt_ids.append(self.token2id[pt])\n",
    "\n",
    "            ex = {\n",
    "                \"prompt_ids\": torch.tensor(prompt_ids, dtype=torch.long),\n",
    "                \"dist_steps\": dist_steps,  # 4 distributions\n",
    "                \"final_hard_label\": final_label_id\n",
    "            }\n",
    "            self.examples.append(ex)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.examples)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.examples[idx]\n",
    "\n",
    "\n",
    "def collate_fn(batch):\n",
    "    max_len = max(len(ex[\"prompt_ids\"]) for ex in batch)\n",
    "    prompt_list = []\n",
    "    attn_list = []\n",
    "    dist_steps_list = []\n",
    "    final_labels_list = []\n",
    "\n",
    "    for ex in batch:\n",
    "        p = ex[\"prompt_ids\"]\n",
    "        pad_len = max_len - len(p)\n",
    "        padded = torch.cat([p, torch.full((pad_len,), 0, dtype=torch.long)])\n",
    "        attn = torch.cat([torch.ones(len(p)), torch.zeros(pad_len)])\n",
    "\n",
    "        prompt_list.append(padded.unsqueeze(0))\n",
    "        attn_list.append(attn.unsqueeze(0))\n",
    "        dist_steps_list.append(ex[\"dist_steps\"])\n",
    "        final_labels_list.append(ex[\"final_hard_label\"])\n",
    "\n",
    "    prompt_ids = torch.cat(prompt_list, dim=0)     # (B, max_len)\n",
    "    attention_mask = torch.cat(attn_list, dim=0)  # (B, max_len)\n",
    "\n",
    "    return {\n",
    "        \"prompt_ids\": prompt_ids,\n",
    "        \"attention_mask\": attention_mask,\n",
    "        \"dist_steps\": dist_steps_list,      # list of lists\n",
    "        \"final_labels\": final_labels_list\n",
    "    }"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "771bd897f04aaf07"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "from torch.utils.data import Subset\n",
    "from collections import defaultdict\n",
    "\n",
    "def permutation_train_val_split_continuous(dataset, id2token, seq_length, train_ratio=0.8):\n",
    "    \"\"\"\n",
    "    Ensures that all permutations of the same digit sequence go entirely into\n",
    "    train or val. We do this by grouping examples based on sorted digits in\n",
    "    their prompt, then performing a group-level random split.\n",
    "\n",
    "    Args:\n",
    "      dataset: A PyTorch Dataset whose __getitem__ returns a dict like:\n",
    "               {\n",
    "                 \"prompt_ids\": Tensor,  # shape [prompt_len]\n",
    "                 \"dist_steps\": ...,\n",
    "                 \"final_hard_label\": ...\n",
    "               }\n",
    "      id2token: mapping from token ID to string token, e.g. \"D5\"\n",
    "      seq_length: how many 'D' tokens in each prompt (e.g. 4)\n",
    "      train_ratio: fraction of groups to go to train (e.g. 0.8 => 80% train)\n",
    "\n",
    "    Returns:\n",
    "      train_data, val_data: Subset objects pointing to the train/val samples.\n",
    "    \"\"\"\n",
    "    # Build groups: canonical sorted digits -> list of example indices\n",
    "    groups = defaultdict(list)\n",
    "\n",
    "    for idx in range(len(dataset)):\n",
    "        ex = dataset[idx]     # e.g. {\"prompt_ids\": ..., \"dist_steps\":..., \"final_hard_label\":...}\n",
    "        prompt_ids = ex[\"prompt_ids\"]\n",
    "\n",
    "        # The digits are typically in the tokens from index 1..(1+seq_length) ignoring <BOS>.\n",
    "        digit_ids = prompt_ids[1: 1 + seq_length].tolist()\n",
    "        digit_strs = [id2token[d] for d in digit_ids]  # e.g. [\"D5\", \"D1\", ...]\n",
    "        digits = [int(s[1:]) for s in digit_strs]      # strip off the \"D\", e.g. [5, 1, ...]\n",
    "        canon_digits = tuple(sorted(digits))           # canonical form, e.g. (1,5,5,4)\n",
    "\n",
    "        groups[canon_digits].append(idx)\n",
    "\n",
    "    # Shuffle group keys\n",
    "    group_keys = list(groups.keys())\n",
    "    random.shuffle(group_keys)\n",
    "\n",
    "    # Split group keys based on train_ratio\n",
    "    split_idx = int(train_ratio * len(group_keys))\n",
    "    train_keys = group_keys[:split_idx]\n",
    "    val_keys = group_keys[split_idx:]\n",
    "\n",
    "    # Gather indices\n",
    "    train_indices = []\n",
    "    val_indices = []\n",
    "    for k in train_keys:\n",
    "        train_indices.extend(groups[k])\n",
    "    for k in val_keys:\n",
    "        val_indices.extend(groups[k])\n",
    "\n",
    "    # Create Subset objects for train and val\n",
    "    train_subset = Subset(dataset, train_indices)\n",
    "    val_subset = Subset(dataset, val_indices)\n",
    "\n",
    "    return train_subset, val_subset"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "31a3ff7c7a7915f4"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "def cross_entropy_distribution_batch(logits, target_dist):\n",
    "    \"\"\"\n",
    "    Batch version\n",
    "    logits:  (B, vocab_size)\n",
    "    dist:    (B, vocab_size)  (already on same device as logits)\n",
    "    Returns (B,) => one scalar CE loss per example, same formula as above.\n",
    "    \"\"\"\n",
    "    EPS = 1e-8\n",
    "    log_probs = F.log_softmax(logits, dim=-1)  # (B, vocab_size)\n",
    "    return -(target_dist * log_probs).sum(dim=-1) + (target_dist*torch.log(target_dist + EPS)).sum(dim=-1)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ffa3c133d157c919"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def train_soft_steps_batch(\n",
    "    model,\n",
    "    data_loader,\n",
    "    optimizer,\n",
    "    device,\n",
    "    loss_steps_to_include,\n",
    "    supervision_type,\n",
    "    seq_length,\n",
    "    num_soft_steps\n",
    "):\n",
    "    \"\"\"\n",
    "    Run one training epoch of the soft-step model, processing each batch in parallel.\n",
    "\n",
    "    The prompt is encoded once with its attention mask; the model is then rolled forward\n",
    "    `num_soft_steps` times through its KV cache. At each soft step the last-position\n",
    "    logits are scored with `cross_entropy_distribution_batch` against that step's teacher\n",
    "    distribution, and the next input embedding is a weighted mix of the embedding rows:\n",
    "      - \"soft_teacher\": weights = softmax of the model's own logits\n",
    "      - any other value (\"hard_teacher\"): weights = the teacher distribution\n",
    "    After the soft steps the last logits are scored with hard cross-entropy against\n",
    "    `final_labels`. One optimizer step is taken per batch.\n",
    "\n",
    "    Args:\n",
    "        loss_steps_to_include: step keys whose losses are backpropagated; soft steps\n",
    "            are keyed by str(step_idx) (\"0\", \"1\", ...) and the final step by \"h\".\n",
    "        seq_length: size of the returned per-position loss array. Soft-step losses go\n",
    "            in slots 0..min(num_soft_steps, seq_length - 1) - 1 and the final-label\n",
    "            loss in the last slot (the layout eval_soft_steps_acc uses).\n",
    "        num_soft_steps: number of embedding-fed steps before the final hard step.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, train_loss_tokens): the optimized batch loss averaged over batches,\n",
    "        and the per-position losses (every step, included or not) averaged over batches.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "    steps = 0\n",
    "\n",
    "    ce_loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "    # Per-position losses summed over batches; normalized on return.\n",
    "    train_loss_tokens = np.zeros(seq_length, dtype=np.float32)\n",
    "    # Soft-step slots that fit in front of the final-label slot.\n",
    "    num_logged_steps = min(num_soft_steps, seq_length - 1)\n",
    "    embedding_matrix = model.transformer.wte.weight  # (vocab_size, n_embd)\n",
    "\n",
    "    for batch in data_loader:\n",
    "        prompt_ids = batch[\"prompt_ids\"].to(device)          # (B, max_len)\n",
    "        attention_mask = batch[\"attention_mask\"].to(device)  # (B, max_len)\n",
    "        # as_tensor accepts a list of ints or an existing tensor (torch.tensor on a\n",
    "        # tensor makes a copy and emits a UserWarning).\n",
    "        final_labels = torch.as_tensor(batch[\"final_labels\"], device=device)  # (B,)\n",
    "\n",
    "        # batch[\"dist_steps\"]: B examples, each a sequence of T per-step distributions\n",
    "        # (tensors or plain lists). Stacked to (B, T, vocab_size); this needs the same T\n",
    "        # for every example (otherwise pad with nn.utils.rnn.pad_sequence).\n",
    "        dist_steps = torch.stack([\n",
    "            torch.stack([\n",
    "                d.to(device) if isinstance(d, torch.Tensor)\n",
    "                else torch.tensor(d, dtype=torch.float, device=device)\n",
    "                for d in ex_dists\n",
    "            ])\n",
    "            for ex_dists in batch[\"dist_steps\"]\n",
    "        ])\n",
    "\n",
    "        batch_loss = torch.zeros((), device=device)\n",
    "\n",
    "        # Encode the whole prompt once; later steps reuse the KV cache.\n",
    "        outputs = model(input_ids=prompt_ids, attention_mask=attention_mask, use_cache=True)\n",
    "        past_key_values = outputs.past_key_values\n",
    "\n",
    "        # Detached per-step losses for logging: soft steps first, final-label loss last.\n",
    "        partial_losses = torch.zeros(num_soft_steps + 1, device=device)\n",
    "\n",
    "        for step_idx in range(num_soft_steps):\n",
    "            last_logits = outputs.logits[:, -1, :]  # (B, vocab_size)\n",
    "\n",
    "            # Distribution loss against the teacher distribution at this step\n",
    "            dist_vec = dist_steps[:, step_idx, :]  # (B, vocab_size)\n",
    "            step_loss = cross_entropy_distribution_batch(last_logits, dist_vec).mean()\n",
    "            partial_losses[step_idx] = step_loss.detach()\n",
    "\n",
    "            if str(step_idx) in loss_steps_to_include:\n",
    "                batch_loss += step_loss\n",
    "\n",
    "            # Mixing weights for the next input embedding\n",
    "            if supervision_type == \"soft_teacher\":\n",
    "                mix_weights = F.softmax(last_logits, dim=-1)  # model's own prediction\n",
    "            else:\n",
    "                mix_weights = dist_vec  # \"hard_teacher\": teacher distribution\n",
    "            e_soft = mix_weights @ embedding_matrix  # (B, n_embd)\n",
    "\n",
    "            # NOTE(review): no attention_mask is passed from here on, so if prompts are\n",
    "            # padded, the cached padded positions become visible; confirm prompts are unpadded.\n",
    "            outputs = model(\n",
    "                inputs_embeds=e_soft.unsqueeze(1),  # (B, 1, n_embd)\n",
    "                past_key_values=past_key_values,\n",
    "                use_cache=True\n",
    "            )\n",
    "            past_key_values = outputs.past_key_values\n",
    "\n",
    "        # Final \"hard\" step => CE with final_label\n",
    "        last_logits = outputs.logits[:, -1, :]  # (B, vocab_size)\n",
    "        final_loss = ce_loss_fn(last_logits, final_labels)\n",
    "        partial_losses[-1] = final_loss.detach()\n",
    "\n",
    "        if \"h\" in loss_steps_to_include:\n",
    "            batch_loss += final_loss\n",
    "\n",
    "        # One backward/step per batch; skipped when no loss term was included,\n",
    "        # because backward() on a tensor without a graph raises.\n",
    "        optimizer.zero_grad()\n",
    "        if batch_loss.requires_grad:\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        total_loss += batch_loss.item()\n",
    "        steps += 1\n",
    "\n",
    "        # Accumulate per-position losses for logging.\n",
    "        step_losses = partial_losses.cpu().numpy()\n",
    "        train_loss_tokens[:num_logged_steps] += step_losses[:num_logged_steps]\n",
    "        train_loss_tokens[-1] += step_losses[-1]\n",
    "\n",
    "    avg_loss = total_loss / max(1, steps)\n",
    "    return avg_loss, train_loss_tokens / max(1, steps)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "89d94f94bc8138f5"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def cross_entropy_distribution(logits, target_dist):\n",
    "    \"\"\"\n",
    "    Distribution loss between a target distribution p and softmax(logits) q.\n",
    "\n",
    "    Despite the name, the returned value is the KL divergence\n",
    "        KL(p || q) = -sum_k p(k) log q(k) + sum_k p(k) log(p(k) + EPS),\n",
    "    i.e. the cross-entropy minus the (EPS-smoothed) entropy of p. It has the same\n",
    "    gradient w.r.t. `logits` as the cross-entropy, but is ~0 when q matches p.\n",
    "\n",
    "    logits: shape (vocab_size,)\n",
    "    target_dist: shape (vocab_size,), presumably non-negative and summing to 1\n",
    "    Returns a 0-d tensor.\n",
    "    \"\"\"\n",
    "    EPS = 1e-8  # keeps log() finite where p(k) == 0\n",
    "    log_probs = F.log_softmax(logits, dim=-1)\n",
    "    cross_entropy = -(target_dist * log_probs).sum()\n",
    "    neg_entropy = (target_dist * torch.log(target_dist + EPS)).sum()\n",
    "    return cross_entropy + neg_entropy\n",
    "\n",
    "@torch.no_grad()\n",
    "def eval_soft_steps_acc(model, data_loader, device, seq_length):\n",
    "    \"\"\"\n",
    "    Evaluate the soft-step model example by example, without gradients.\n",
    "\n",
    "    For each example the prompt is encoded once (with its attention mask), then the model\n",
    "    is rolled forward through its KV cache for len(dist_steps) - 1 soft steps. Each soft\n",
    "    step is scored with `cross_entropy_distribution` against the teacher distribution, and\n",
    "    the next input is the softmax-weighted mix of the token embeddings (always the model's\n",
    "    own prediction; there is no teacher-forcing option here). The final position is scored\n",
    "    with hard cross-entropy against the final label and used for argmax accuracy.\n",
    "\n",
    "    Args:\n",
    "        model: GPT-2 style LM; `model.transformer.wte` is used as the embedding matrix.\n",
    "        data_loader: yields dicts with \"prompt_ids\", \"attention_mask\", \"dist_steps\"\n",
    "            and \"final_labels\".\n",
    "        device: device the forward passes run on.\n",
    "        seq_length: number of per-position loss slots; soft-step losses go in slots\n",
    "            0, 1, ... and the final-label loss in the last slot.\n",
    "\n",
    "    Returns:\n",
    "        (avg_loss, accuracy, val_loss_tokens): the per-example loss (sum of all step\n",
    "        losses) averaged within each batch and then over batches; final-label argmax\n",
    "        accuracy; and per-position losses averaged the same way.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    total_loss = 0.0\n",
    "    total_correct = 0\n",
    "    total_samples = 0\n",
    "    steps = 0\n",
    "\n",
    "    ce_loss_fn = nn.CrossEntropyLoss()\n",
    "    embedding_matrix = model.transformer.wte.weight  # (vocab_size, n_embd)\n",
    "    val_loss_tokens = np.zeros((seq_length,))\n",
    "\n",
    "    def get_last_logits(o):\n",
    "        return o.logits[:, -1, :]\n",
    "\n",
    "    for batch in data_loader:\n",
    "        prompt_ids = batch[\"prompt_ids\"].to(device)\n",
    "        attention_mask = batch[\"attention_mask\"].to(device)\n",
    "        dist_steps_list = batch[\"dist_steps\"]\n",
    "        final_labels = batch[\"final_labels\"]\n",
    "        batch_size = prompt_ids.size(0)\n",
    "\n",
    "        # Plain floats are enough: under no_grad only .item() values are accumulated.\n",
    "        batch_loss = 0.0\n",
    "        val_loss_tokens_batch = np.zeros((seq_length,))\n",
    "\n",
    "        for i in range(batch_size):\n",
    "            # Forward the prompt in discrete form\n",
    "            pi = prompt_ids[i].unsqueeze(0)\n",
    "            am = attention_mask[i].unsqueeze(0)\n",
    "\n",
    "            outputs = model(pi, attention_mask=am, use_cache=True)\n",
    "            pkv = outputs.past_key_values\n",
    "\n",
    "            item_loss = 0.0\n",
    "\n",
    "            # Soft steps: every teacher distribution except the last one is scored.\n",
    "            example_dists = dist_steps_list[i]\n",
    "            for count, dist_vec in enumerate(example_dists[:len(example_dists) - 1]):\n",
    "                last_logits = get_last_logits(outputs).squeeze(0)  # (vocab_size,)\n",
    "                # Accept tensors or plain lists, like the training loop does.\n",
    "                if isinstance(dist_vec, torch.Tensor):\n",
    "                    target_dist = dist_vec.to(device)\n",
    "                else:\n",
    "                    target_dist = torch.tensor(dist_vec, dtype=torch.float, device=device)\n",
    "                step_loss = cross_entropy_distribution(last_logits, target_dist).item()\n",
    "                item_loss += step_loss\n",
    "                val_loss_tokens_batch[count] += step_loss\n",
    "\n",
    "                # Feed back the model's own prediction as a soft embedding:\n",
    "                # e_soft = sum_v softmax(last_logits)[v] * embedding_matrix[v]\n",
    "                token_dist_vec = F.softmax(last_logits, dim=-1)\n",
    "                e_soft = torch.matmul(token_dist_vec, embedding_matrix)  # (n_embd,)\n",
    "                outputs = model(\n",
    "                    inputs_embeds=e_soft.unsqueeze(0).unsqueeze(1),  # (1, 1, n_embd)\n",
    "                    past_key_values=pkv,\n",
    "                    use_cache=True\n",
    "                )\n",
    "                pkv = outputs.past_key_values\n",
    "\n",
    "            # Final step: hard CE against the final label, plus argmax accuracy.\n",
    "            last_logits = get_last_logits(outputs).squeeze(0)  # (vocab_size,)\n",
    "            # int() normalizes either a Python int or a 0-d tensor element.\n",
    "            final_label_id = int(final_labels[i])\n",
    "            final_loss = ce_loss_fn(\n",
    "                last_logits.unsqueeze(0), torch.tensor([final_label_id], device=device)\n",
    "            ).item()\n",
    "            item_loss += final_loss\n",
    "            val_loss_tokens_batch[-1] += final_loss\n",
    "\n",
    "            predicted_id = last_logits.argmax(dim=-1).item()\n",
    "            if predicted_id == final_label_id:\n",
    "                total_correct += 1\n",
    "\n",
    "            batch_loss += item_loss\n",
    "\n",
    "        # average over batch\n",
    "        batch_loss /= batch_size\n",
    "        val_loss_tokens_batch /= batch_size\n",
    "        val_loss_tokens += val_loss_tokens_batch\n",
    "        total_loss += batch_loss\n",
    "        total_samples += batch_size\n",
    "        steps += 1\n",
    "\n",
    "    # max(1, ...) avoids ZeroDivisionError on an empty loader (as in the training loop).\n",
    "    avg_loss = total_loss / max(1, steps)\n",
    "    val_loss_tokens /= max(1, steps)\n",
    "    accuracy = total_correct / max(1, total_samples)\n",
    "    return avg_loss, accuracy, val_loss_tokens"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "d67b301fe097b97"
  },
  {
   "cell_type": "markdown",
   "source": [
    "#### Train for 1000 epochs on the MNNS task and compare the final validation accuracies."
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "e013e29fbea1661e"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# 3 soft steps (assumes SEQ_LENGTH == 4, since num_soft_steps = SEQ_LENGTH - 1) + \"h\" for final hard label\n",
    "LOSS_STEPS = [\"0\", \"1\", \"2\", \"h\"]\n",
    "SUPERVISION = \"hard_teacher\"\n",
    "\n",
    "print(f\"Max Seq Len: {MAX_SEQ_LEN}, \"\n",
    "      f\"Embedding Dim: {EMBEDDING_DIM}, \"\n",
    "      f\"Digit Range: {DIGIT_RANGE}, \"\n",
    "      f\"Batch Size: {BATCH_SIZE}, \"\n",
    "      f\"Seq Length: {SEQ_LENGTH}, \"\n",
    "      f\"Loss Steps: {LOSS_STEPS}, \"\n",
    "      f\"Split Method: {SPLIT_METHOD}, \"\n",
    "      f\"Num Epochs: {NUM_EPOCHS}, \"\n",
    "      f\"Output Dir: {OUTPUT_DIR}, \"\n",
    "      f\"Supervision: {SUPERVISION}, \"\n",
    "      f\"Num Layers: {NUM_LAYERS}, \"\n",
    "      f\"Num Heads: {NUM_HEADS}\")\n",
    "\n",
    "# Build vocab & model\n",
    "vocab, token2id, id2token = generate_vocab()\n",
    "vocab_size = len(vocab)\n",
    "print(\"Vocab size =\", vocab_size)\n",
    "\n",
    "config = GPT2Config(\n",
    "    vocab_size=vocab_size,\n",
    "    n_positions=MAX_SEQ_LEN,\n",
    "    n_embd=EMBEDDING_DIM,\n",
    "    n_layer=NUM_LAYERS,\n",
    "    n_head=NUM_HEADS\n",
    ")\n",
    "model = GPT2LMHeadModel(config)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# Create the dataset => 4 digits as prompt + 4 distribution steps + final label\n",
    "dataset = FourDigitsSoftDataset(token2id=token2id, digit_range=DIGIT_RANGE, seq_length=SEQ_LENGTH)\n",
    "\n",
    "train_data, val_data = permutation_train_val_split_continuous(\n",
    "    dataset=dataset,\n",
    "    id2token=id2token,\n",
    "    seq_length=SEQ_LENGTH,\n",
    "    train_ratio=0.8\n",
    ")\n",
    "\n",
    "train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)\n",
    "val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)\n",
    "\n",
    "# Optimizer\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)\n",
    "\n",
    "# Training: record per-epoch losses, validation accuracy and per-position losses\n",
    "train_losses = []\n",
    "val_losses = []\n",
    "val_accuracies = []\n",
    "train_losses_tokens = []\n",
    "val_losses_tokens = []\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    train_loss, train_loss_tokens = train_soft_steps_batch(\n",
    "        model,\n",
    "        train_loader,\n",
    "        optimizer,\n",
    "        device,\n",
    "        LOSS_STEPS,\n",
    "        SUPERVISION,\n",
    "        SEQ_LENGTH,\n",
    "        num_soft_steps=SEQ_LENGTH - 1\n",
    "    )\n",
    "    val_loss, val_acc, val_loss_tokens = eval_soft_steps_acc(model, val_loader, device, SEQ_LENGTH)\n",
    "    train_losses.append(train_loss)\n",
    "    val_losses.append(val_loss)\n",
    "    val_accuracies.append(val_acc)\n",
    "    train_losses_tokens.append(train_loss_tokens)\n",
    "    val_losses_tokens.append(val_loss_tokens)\n",
    "    print(f\"Epoch {epoch + 1} | \"\n",
    "          f\"Train Loss: {train_loss:.4f} | \"\n",
    "          f\"Val Loss: {val_loss:.4f} | \"\n",
    "          f\"Val Accuracy: {val_acc:.2%} | \"\n",
    "          f\"Val Loss Tokens: \" + \"-\".join([f\"{val_loss_tokens[i]:.4f}\" for i in range(len(val_loss_tokens))]) + \" | \")\n",
    "\n",
    "# Epoch axis for plotting the recorded curves (no plot is drawn in this cell)\n",
    "epochs_range = range(1, NUM_EPOCHS + 1)\n",
    "# File name encoding the run configuration\n",
    "file_name = (\n",
    "    f\"Digit{DIGIT_RANGE.start}-{DIGIT_RANGE.stop}\" + \"_seq\" + str(SEQ_LENGTH) + \"_emb\" +\n",
    "    str(EMBEDDING_DIM) + \"_steps\" + \"\".join(LOSS_STEPS) + \"_split\" + SPLIT_METHOD +\n",
    "    \"_batch\" + str(BATCH_SIZE) + \"_epochs\" + str(NUM_EPOCHS)\n",
    ")\n",
    "\n",
    "# Save the model; torch.save does not create missing directories, so create them first\n",
    "model_path = f\"models/{OUTPUT_DIR}continuous_model_{file_name}.pt\"\n",
    "os.makedirs(os.path.dirname(model_path), exist_ok=True)\n",
    "torch.save(model.state_dict(), model_path)"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "ff5a0b47eb90703b"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   },
   "id": "470655ef0ea0acca"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
