{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): presumably set so the HF `tokenizers` library does not use its own thread pool in a\n",
    "# process that later forks workers (multiprocess.Pool in RefillableDataset) - confirm.\n",
    "%env TOKENIZERS_PARALLELISM=false"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.tools import from_string, flatten, normalize\n",
    "\n",
    "# Rank of the free group, and the word-length cap shared by the datasets and by generation.\n",
    "fdim = 4\n",
    "max_length = 400\n",
    "\n",
    "# Reference words for the 'non-trivial' eval set, keyed by rank (only ranks 3 and 4 are provided).\n",
    "representatives = {\n",
    "    3: '[[x, y], [x, yz]]',\n",
    "    4: '[[[x, y], [x, yz]], [[x, y], [x, yzp]]]'\n",
    "}\n",
    "\n",
    "# Multi-label bitmasks over the fdim + 1 closures (bit i <-> closure i) with exactly one bit cleared,\n",
    "# listed from bit fdim down to bit 0.\n",
    "_all_closures_mask = (1 << (fdim + 1)) - 1\n",
    "singles = [_all_closures_mask ^ (1 << (fdim - idx)) for idx in range(fdim + 1)]\n",
    "\n",
    "total_steps = int(5e5)\n",
    "sched_steps = 200\n",
    "\n",
    "\n",
    "def _brackets_closure(radius):\n",
    "    return {'method': 'brackets', 'depth_method': 'u', 'depth_parameters': {'radius': radius}}\n",
    "\n",
    "def _intersection_args(zero_radius, non_zero_radius, **extra):\n",
    "    # Arguments shared by the incomplete-/complete-intersection samplers; `extra` keys come last.\n",
    "    return {\n",
    "        'freegroup_dimension': fdim,\n",
    "        'zero_closure_parameters': _brackets_closure(zero_radius),\n",
    "        'non_zero_closure_parameters': _brackets_closure(non_zero_radius),\n",
    "        'max_freegroups': 1,\n",
    "        'freegroup_parameters': {'length_method': 'u', 'length_parameters': {'radius': 10}},\n",
    "        'total_max_length': max_length,\n",
    "        **extra,\n",
    "    }\n",
    "\n",
    "def _fixed(size):\n",
    "    return {'name': 'fixed', 'args': {'size': size}}\n",
    "\n",
    "def _refillable(size, max_calls):\n",
    "    return {'name': 'refillable', 'args': {'size': size, 'max_calls': max_calls}}\n",
    "\n",
    "def _prefix_dataset(radius):\n",
    "    # Plain free-group words used as generation prefixes.\n",
    "    return {\n",
    "        'dist': {\n",
    "            'name': 'freegroup',\n",
    "            'args': {'freegroup_dimension': fdim, 'length_method': 'c', 'length_parameters': {'radius': radius}},\n",
    "        },\n",
    "        'type': _refillable(64 * 5, 1),\n",
    "    }\n",
    "\n",
    "\n",
    "config = {\n",
    "    'freegroup_dimension': fdim,\n",
    "    'seed': 42,\n",
    "    'notebook_name': 'train.ipynb',\n",
    "    'device': 'cuda:1',\n",
    "\n",
    "    # How the multi-label reaches the model: 'ignore', 'prompt' or 'mask' (see `data_collator`).\n",
    "    'method': 'mask',\n",
    "\n",
    "    'tokenizer': {'pretrained_model_name_or_path': f'tokenizer/word-level-tokenizer-{fdim}'},\n",
    "\n",
    "    'model': {\n",
    "        'model_type': 'gpt2',\n",
    "        'n_positions': 1024,\n",
    "        'n_embd': 10 * 12,\n",
    "        'n_layer': 12,\n",
    "        'n_head': 10,\n",
    "    },\n",
    "\n",
    "    'train': {\n",
    "        'batch_size': 16,\n",
    "        'steps': total_steps,\n",
    "        'log_steps': 400,\n",
    "        'save_steps': 25000,\n",
    "        'optimizer': {'name': 'AdamW', 'args': {'lr': 1e-5}},\n",
    "        # total_iters = total_steps // sched_steps: presumably the scheduler is stepped once every `steps` steps.\n",
    "        'scheduler': {\n",
    "            'name': 'Linear',\n",
    "            'args': {'start_factor': 1., 'end_factor': 0.5, 'total_iters': total_steps // sched_steps},\n",
    "            'steps': sched_steps,\n",
    "        },\n",
    "        'dataset': {\n",
    "            'dist': {\n",
    "                'name': 'incomplete-intersection',\n",
    "                'args': _intersection_args(10, 30, probas = {k: 1 for k in singles}),\n",
    "            },\n",
    "            'type': _refillable(64 * 100, 3),\n",
    "        },\n",
    "    },\n",
    "\n",
    "    'eval': {\n",
    "        'steps': 2000,\n",
    "        'batch_size': 16,\n",
    "        'dataset': {\n",
    "            'validation': {\n",
    "                'dist': {\n",
    "                    'name': 'incomplete-intersection',\n",
    "                    'args': _intersection_args(10, 30, probas = {k: 1 for k in singles}),\n",
    "                },\n",
    "                'type': _fixed(64 * 10),\n",
    "            },\n",
    "            'trivial': {\n",
    "                'dist': {\n",
    "                    'name': 'complete-intersection',\n",
    "                    'args': _intersection_args(8, 15, max_multipliers = 1),\n",
    "                },\n",
    "                'type': _fixed(64 * 10),\n",
    "            },\n",
    "            'non-trivial': {\n",
    "                'dist': {\n",
    "                    'name': 'generator-permutations',\n",
    "                    'args': {\n",
    "                        'freegroup_dimension': fdim,\n",
    "                        'word': normalize(flatten(from_string(representatives[fdim], method = 'lu'))),\n",
    "                    },\n",
    "                },\n",
    "                'type': _fixed(fdim),\n",
    "            },\n",
    "        },\n",
    "    },\n",
    "\n",
    "    'gen': {\n",
    "        'steps': 2500,\n",
    "        'batch_size': 16,\n",
    "        'dataset': {f'prefix_{radius}': _prefix_dataset(radius) for radius in (5, 7, 10)},\n",
    "        'methods': {\n",
    "            'beam': {\n",
    "                'num_beams': 5,\n",
    "                'num_return_sequences': 5,\n",
    "                'max_length': max_length,\n",
    "                'repetition_penalty': 1.2,\n",
    "            },\n",
    "            'sample': {\n",
    "                'do_sample': True,\n",
    "                'num_return_sequences': 5,\n",
    "                'max_length': max_length,\n",
    "                'top_p': 0.9,\n",
    "            },\n",
    "        },\n",
    "    },\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DATASET"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.tools import (\n",
    "    is_from_singleton_normal_closure, wu_closure,\n",
    "    flatten, normalize, Mult\n",
    ")\n",
    "from freegroup.sampling import (\n",
    "    normal_closure, freegroup,\n",
    "    random_tree\n",
    ")\n",
    "from iteration_utilities import repeatfunc, unique_everseen\n",
    "from itertools import islice\n",
    "from tqdm.notebook import tqdm\n",
    "from copy import deepcopy\n",
    "from numpy.random import choice, shuffle, randint\n",
    "\n",
    "from contextlib import contextmanager\n",
    "from multiprocess import Pool\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "def compute_multi_label(word, freegroup_dimension):\n",
    "    # Membership flags of `word` in the Wu closures 0, 1, ..., freegroup_dimension, in that order.\n",
    "    closures = (wu_closure(freegroup_dimension, idx) for idx in range(freegroup_dimension + 1))\n",
    "    return [is_from_singleton_normal_closure(word, closure) for closure in closures]\n",
    "\n",
    "def sample_dataset(\n",
    "    dist, size, dist_kwargs = None,\n",
    "    unique = True, unique_key = None,\n",
    "    progress = True, tqdm_kwargs = None\n",
    "):\n",
    "    \"\"\"Take up to `size` entries from the iterator returned by `dist(**dist_kwargs)`.\n",
    "\n",
    "    unique: skip entries already seen, compared by `unique_key(entry)`\n",
    "        (default key: the tuple of entry['word']).\n",
    "    progress: show a tqdm bar configured by `tqdm_kwargs`.\n",
    "    Returns a list, shorter than `size` only if the iterator runs out first.\n",
    "    \"\"\"\n",
    "    iterator = dist(**(dist_kwargs or {}))\n",
    "\n",
    "    if unique:\n",
    "        # `unique_key` maps an entry to its dedup key, so it is used as the key function itself.\n",
    "        key = unique_key if unique_key is not None else (lambda entry: tuple(entry['word']))\n",
    "        iterator = unique_everseen(iterator, key = key)\n",
    "\n",
    "    iterator = islice(iterator, size)\n",
    "\n",
    "    if progress:\n",
    "        iterator = tqdm(iterator, total = size, **(tqdm_kwargs or {}))\n",
    "\n",
    "    return list(iterator)\n",
    "\n",
    "class RefillableDataset(Dataset):\n",
    "    \"\"\"Map-style dataset serving a pool of sampled entries that is resampled in the background.\n",
    "\n",
    "    A batch of `size` entries is drawn with `sample_dataset(dist, dist_kwargs=dist_kwargs)` in a\n",
    "    single worker process while the current batch is being served. After max_calls * len(batch)\n",
    "    item accesses the prefetched batch becomes current and the next prefetch is submitted.\n",
    "    Call `stop()` to terminate the worker process.\n",
    "    \"\"\"\n",
    "    def __init__(self, dist, dist_kwargs, size, max_calls):\n",
    "        self.size = size\n",
    "        self.calls, self.max_calls = 0, max_calls\n",
    "\n",
    "        self.pool = Pool(1)\n",
    "\n",
    "        # Submits one asynchronous sampling job; `.get()` on its result blocks until it finishes.\n",
    "        self.next_batch_fn = lambda: self.pool.apply_async(\n",
    "            sample_dataset,\n",
    "            kwds = {\n",
    "                'dist': dist, 'dist_kwargs': dist_kwargs,\n",
    "                'size': size, 'progress': False,\n",
    "            })\n",
    "\n",
    "        self.curr_batch = None\n",
    "        self.next_batch = self.next_batch_fn()\n",
    "\n",
    "    def __len__(self): return self.max_calls * self.size\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        # `idx` only selects an entry; swapping batches is driven by the number of accesses.\n",
    "        self.calls += 1\n",
    "        if self.curr_batch is None or self.calls >= self.max_calls * len(self.curr_batch):\n",
    "            self.curr_batch = self.next_batch.get()\n",
    "            self.next_batch = self.next_batch_fn()\n",
    "            self.calls = 0\n",
    "            if not self.curr_batch:\n",
    "                raise RuntimeError('RefillableDataset: the distribution yielded no entries')\n",
    "\n",
    "        # Wrap by the real batch length: `sample_dataset` returns fewer than `size` entries\n",
    "        # when the distribution's iterator runs out early.\n",
    "        return self.curr_batch[idx % len(self.curr_batch)]\n",
    "\n",
    "    def stop(self): self.pool.terminate()\n",
    "        \n",
    "@contextmanager\n",
    "def refillable_dataset(dist, dist_kwargs, size, max_calls):\n",
    "    \"\"\"Yield a `RefillableDataset`, terminating its worker process on exit (also after an error).\"\"\"\n",
    "    dataset = RefillableDataset(dist, dist_kwargs, size, max_calls)\n",
    "    try:\n",
    "        yield dataset\n",
    "    finally:\n",
    "        dataset.stop()\n",
    "    \n",
    "def incomplete_intersection_dist(\n",
    "    freegroup_dimension = 3,\n",
    "    zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 5}},\n",
    "    non_zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 30}},\n",
    "    max_freegroups = 1,\n",
    "    freegroup_parameters = {'length_method': 'c', 'length_parameters': {'radius': 5}},\n",
    "    max_multipliers = 1,\n",
    "    probas = {k: 1 for k in range(1, 2 ** 4 - 1)},\n",
    "    total_max_length = 200,\n",
    "):\n",
    "    \"\"\"Infinite iterator of entries whose word lies in some, but not all, closures 0..freegroup_dimension.\n",
    "\n",
    "    Each candidate draws a bitmask from `probas` (bit `idx` set: one element of closure `idx` is\n",
    "    sampled as a leaf), adds 0..max_freegroups free-group leaves, combines the shuffled leaves with\n",
    "    `random_tree`, and multiplies 1..max_multipliers such trees. A candidate is yielded as\n",
    "    {'word': ..., 'multi_label': ...} if its word is non-empty, shorter than `total_max_length`,\n",
    "    and not in every closure.\n",
    "\n",
    "    probas: bitmask -> unnormalised weight; bitmasks missing from it get weight 0. Not mutated.\n",
    "    \"\"\"\n",
    "    n_closures = freegroup_dimension + 1\n",
    "\n",
    "    # Probability vector over every proper non-empty bitmask of n_closures bits (the caller's keys\n",
    "    # first, in their order); masks span n_closures bits, one per closure.\n",
    "    weights = dict(probas)\n",
    "    for mask in range(1, 2 ** n_closures - 1):\n",
    "        weights.setdefault(mask, 0.)\n",
    "    total = sum(weights.values())\n",
    "    multi_labels = list(weights.keys())\n",
    "    label_probas = [weight / total for weight in weights.values()]\n",
    "\n",
    "    def _init():\n",
    "        leaves = []\n",
    "        multi_label = choice(multi_labels, p = label_probas)\n",
    "        for idx in range(n_closures):\n",
    "            if multi_label & (1 << idx):\n",
    "                leaves.append(normal_closure(\n",
    "                    freegroup_dimension = freegroup_dimension,\n",
    "                    closure = wu_closure(freegroup_dimension, idx),\n",
    "                    **(zero_closure_parameters if idx == 0 else non_zero_closure_parameters),\n",
    "                ))\n",
    "        # numpy's randint excludes `high`: + 1 makes max_freegroups an inclusive bound, like max_multipliers.\n",
    "        for _ in range(randint(low = 0, high = max_freegroups + 1)):\n",
    "            leaves.append(freegroup(\n",
    "                freegroup_dimension = freegroup_dimension,\n",
    "                **freegroup_parameters,\n",
    "            ))\n",
    "\n",
    "        shuffle(leaves)\n",
    "        return random_tree(leaves)\n",
    "\n",
    "    def init(): return normalize(flatten(Mult([_init() for _ in range(randint(low = 1, high = max_multipliers + 1))])))\n",
    "\n",
    "    def features(word):\n",
    "        return {\n",
    "            'word': word[::],\n",
    "            'multi_label': compute_multi_label(word, freegroup_dimension)\n",
    "        }\n",
    "\n",
    "    def condition(entry):\n",
    "        # Keep non-empty words below the length cap that miss at least one closure.\n",
    "        if 0 == len(entry['word']) or len(entry['word']) >= total_max_length:\n",
    "            return False\n",
    "        return sum(entry['multi_label']) < n_closures\n",
    "\n",
    "    return filter(condition, map(features, repeatfunc(init)))\n",
    "\n",
    "\n",
    "\n",
    "def complete_intersection_dist(\n",
    "    freegroup_dimension = 3,\n",
    "    zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 5}},\n",
    "    non_zero_closure_parameters = {'method': 'brackets', 'depth_method': 'c', 'depth_parameters': {'radius': 30}},\n",
    "    max_freegroups = 1,\n",
    "    freegroup_parameters = {'length_method': 'c', 'length_parameters': {'radius': 5}},\n",
    "    max_multipliers = 1,\n",
    "    total_max_length = 200,\n",
    "):\n",
    "    \"\"\"Infinite iterator of entries whose word lies in every closure 0..freegroup_dimension.\n",
    "\n",
    "    Each candidate samples one element of every closure as a leaf, adds 0..max_freegroups\n",
    "    free-group leaves, combines the shuffled leaves with `random_tree`, and multiplies\n",
    "    1..max_multipliers such trees. A candidate is yielded as {'word': ..., 'multi_label': ...}\n",
    "    if its word is non-empty, shorter than `total_max_length`, and in all closures.\n",
    "    \"\"\"\n",
    "    n_closures = freegroup_dimension + 1\n",
    "\n",
    "    def _init():\n",
    "        leaves = []\n",
    "        for idx in range(n_closures):\n",
    "            leaves.append(normal_closure(\n",
    "                freegroup_dimension = freegroup_dimension,\n",
    "                closure = wu_closure(freegroup_dimension, idx),\n",
    "                **(zero_closure_parameters if idx == 0 else non_zero_closure_parameters),\n",
    "            ))\n",
    "        # numpy's randint excludes `high`: + 1 makes max_freegroups an inclusive bound, like max_multipliers.\n",
    "        for _ in range(randint(low = 0, high = max_freegroups + 1)):\n",
    "            leaves.append(freegroup(\n",
    "                freegroup_dimension = freegroup_dimension,\n",
    "                **freegroup_parameters,\n",
    "            ))\n",
    "\n",
    "        shuffle(leaves)\n",
    "        return random_tree(leaves)\n",
    "\n",
    "    def init(): return normalize(flatten(Mult([_init() for _ in range(randint(low = 1, high = max_multipliers + 1))])))\n",
    "\n",
    "    def features(word):\n",
    "        return {\n",
    "            'word': word[::],\n",
    "            'multi_label': compute_multi_label(word, freegroup_dimension)\n",
    "        }\n",
    "\n",
    "    def condition(entry):\n",
    "        # Keep non-empty words below the length cap that lie in every closure.\n",
    "        if 0 == len(entry['word']) or len(entry['word']) >= total_max_length:\n",
    "            return False\n",
    "        return sum(entry['multi_label']) == n_closures\n",
    "\n",
    "    return filter(condition, map(features, repeatfunc(init)))\n",
    "\n",
    "\n",
    "def generator_permutations_dist(freegroup_dimension, word):\n",
    "    for s in range(freegroup_dimension):\n",
    "        yield {\n",
    "            'word': [-1 if f < 0 else 1 * (1 + (abs(f) - 1 + s) % freegroup_dimension) for f in word],\n",
    "            'multi_label': [1] * (freegroup_dimension + 1),\n",
    "        }\n",
    "        \n",
    "def wrapper_dist(freegroup_dimension, word_generator, word_generator_kwargs):\n",
    "    # Lazily turn each word of `word_generator(**word_generator_kwargs)` into a labelled entry.\n",
    "    words = word_generator(**word_generator_kwargs)\n",
    "\n",
    "    def to_entry(word):\n",
    "        return {'word': word, 'multi_label': compute_multi_label(word, freegroup_dimension)}\n",
    "\n",
    "    return map(to_entry, words)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "def plot_word_lengths(dist, dist_kwargs, title, size = 1000):\n",
    "    \"\"\"Sample `size` entries from `dist`, show a histogram of their word lengths, return the entries.\"\"\"\n",
    "    entries = sample_dataset(dist, size = size, dist_kwargs = dist_kwargs)\n",
    "    _, ax = plt.subplots()\n",
    "    ax.hist([len(entry['word']) for entry in entries])\n",
    "    ax.set(title = title, xlabel = 'word length', ylabel = 'number of words')\n",
    "    plt.show()\n",
    "    return entries\n",
    "\n",
    "plot_word_lengths(\n",
    "    incomplete_intersection_dist, config['train']['dataset']['dist']['args'],\n",
    "    title = 'train: incomplete intersection',\n",
    ")\n",
    "visualize = plot_word_lengths(\n",
    "    complete_intersection_dist, config['eval']['dataset']['trivial']['dist']['args'],\n",
    "    title = 'eval (trivial): complete intersection',\n",
    ")\n",
    "\n",
    "# Last sampled set, kept for interactive inspection.\n",
    "words = [entry['word'] for entry in visualize]\n",
    "masks = [entry['multi_label'] for entry in visualize]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TRAIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(**config['tokenizer'])\n",
    "\n",
    "# A freshly initialised causal LM whose special-token ids agree with the tokenizer.\n",
    "special_token_ids = dict(\n",
    "    bos_token_id = tokenizer.bos_token_id,\n",
    "    eos_token_id = tokenizer.eos_token_id,\n",
    "    pad_token_id = tokenizer.pad_token_id,\n",
    ")\n",
    "model_config = AutoConfig.for_model(**special_token_ids, **config['model'])\n",
    "model = AutoModelForCausalLM.from_config(model_config)\n",
    "\n",
    "# Number of trainable parameters.\n",
    "sum(param.numel() for param in model.parameters() if param.requires_grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Dict, Any\n",
    "from freegroup.tools import batch_to_string\n",
    "from torch import tensor\n",
    "\n",
    "def data_collator(\n",
    "    model,\n",
    "    tokenizer,\n",
    "    freegroup_dimension,\n",
    "    method = 'ignore', # choice ['ignore', 'prompt', 'mask']\n",
    "    mode = 'predict', # choice: ['predict', 'generate']\n",
    "):\n",
    "    \"\"\"Build a DataLoader `collate_fn` turning {'word', 'multi_label'} entries into model inputs.\n",
    "\n",
    "    method -- how the multi-label is given to the model:\n",
    "        'ignore': not at all (plain language modelling of the word);\n",
    "        'prompt': the word is prefixed with the label spelled as `y`/`n` tokens followed by `sep`;\n",
    "        'mask':   as a GPT-2 `head_mask`; each closure owns n_head // (freegroup_dimension + 1)\n",
    "                  heads, which are zeroed when the word is not in that closure.\n",
    "    mode -- 'predict' returns input_ids / attention_mask / labels (+ head_mask) for a forward pass;\n",
    "        any other value returns {'inputs': ...} prompts for `model.generate`.\n",
    "\n",
    "    Raises ValueError if method == 'mask' and n_head is not divisible by freegroup_dimension + 1.\n",
    "    \"\"\"\n",
    "    y, n, sep = tokenizer.additional_special_tokens\n",
    "    n_closures = freegroup_dimension + 1\n",
    "\n",
    "    if method == 'mask':\n",
    "        heads_per_closure, leftover_heads = divmod(model.config.n_head, n_closures)\n",
    "        if leftover_heads:\n",
    "            raise ValueError(\n",
    "                f'head_mask needs n_head ({model.config.n_head}) to be divisible by freegroup_dimension + 1 ({n_closures})'\n",
    "            )\n",
    "\n",
    "    def prompt(multi_label):\n",
    "        return ' '.join([y if f else n for f in multi_label] + [sep])\n",
    "\n",
    "    def predict_collate_fn(batch: List[Dict[str, Any]]):\n",
    "        words = batch_to_string([entry['word'] for entry in batch])\n",
    "        multi_labels = [entry['multi_label'] for entry in batch]\n",
    "\n",
    "        if method == 'prompt':\n",
    "            # NOTE(review): nothing separates `sep` from the word text here - confirm that strings from\n",
    "            # `batch_to_string` start with whitespace, otherwise the tokenizer may merge the two.\n",
    "            words = [prompt(label) + w for w, label in zip(words, multi_labels)]\n",
    "\n",
    "        inputs = tokenizer(words, padding = True, return_tensors = 'pt')\n",
    "\n",
    "        # HF causal-LM heads shift `labels` left by one inside `forward`, so labels stay aligned with\n",
    "        # input_ids; shifting them here as well would train position t to predict token t + 2.\n",
    "        input_ids = inputs.input_ids\n",
    "        attention_mask = inputs.attention_mask\n",
    "        labels = input_ids.clone()\n",
    "        labels[labels == tokenizer.pad_token_id] = -100  # padding positions are ignored by the loss\n",
    "\n",
    "        if method in ('ignore', 'prompt'):\n",
    "            return {\n",
    "                'input_ids': input_ids,\n",
    "                'labels': labels,\n",
    "                'attention_mask': attention_mask,\n",
    "            }\n",
    "\n",
    "        if method == 'mask':\n",
    "            # (batch, n_closures) -> (n_layer, batch, n_head, 1, 1), the 5-d head_mask layout; the\n",
    "            # trailing singleton dims broadcast over the (seq_len, seq_len) attention map.\n",
    "            head_mask = (\n",
    "                tensor(multi_labels)\n",
    "                .unsqueeze(0).unsqueeze(-1).unsqueeze(-1)\n",
    "                .repeat(1, 1, heads_per_closure, 1, 1)\n",
    "                .expand(model.config.n_layer, -1, -1, -1, -1)\n",
    "                .clone()\n",
    "            )\n",
    "            return {\n",
    "                'input_ids': input_ids,\n",
    "                'attention_mask': attention_mask,\n",
    "                'labels': labels,\n",
    "                'head_mask': head_mask,\n",
    "            }\n",
    "\n",
    "        raise ValueError(f'unknown method: {method!r}')\n",
    "\n",
    "    def generate_collate_fn(batch: List[Dict[str, Any]]):\n",
    "        words = batch_to_string([entry['word'] for entry in batch])\n",
    "\n",
    "        if method == 'prompt':\n",
    "            # Generation prompts claim membership in every closure (an all-`y` label).\n",
    "            all_closures = prompt([1] * n_closures)\n",
    "            words = [all_closures + w for w in words]\n",
    "\n",
    "        inputs = tokenizer(words, padding = True, return_tensors = 'pt')\n",
    "\n",
    "        # NOTE(review): only the last column is dropped, so with right padding only the longest prompt\n",
    "        # loses its `eos` - confirm prompts in a batch share a length or the tokenizer pads on the left.\n",
    "        return {\n",
    "            'inputs': inputs.input_ids[:, :-1] # exclude `eos` token\n",
    "        }\n",
    "\n",
    "    return predict_collate_fn if mode == 'predict' else generate_collate_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "from transformers import LogitsProcessor\n",
    "from torch import tensor, inf\n",
    "\n",
    "class NoLastTokenReductionProcessor(LogitsProcessor):\n",
    "    \"\"\"Logits processor banning the inverse `-x` of each sequence's last token `x`.\n",
    "\n",
    "    This keeps a generated word from continuing with a letter that cancels the previous one.\n",
    "    reciprocal_tokens maps the token id of every letter to the token id of its inverse.\n",
    "    \"\"\"\n",
    "    def __init__(self, reciprocal_tokens: Dict[int, int]):\n",
    "        self.reciprocal_tokens = reciprocal_tokens\n",
    "\n",
    "    @staticmethod\n",
    "    def from_tokenizer(freegroup_dimension, tokenizer):\n",
    "        \"\"\"Build the processor for a vocabulary spelling letter x as str(x), 0 < |x| <= freegroup_dimension.\"\"\"\n",
    "        from itertools import chain\n",
    "\n",
    "        reciprocal_tokens = dict()\n",
    "        for x in chain(range(-freegroup_dimension, 0), range(1, freegroup_dimension + 1)):\n",
    "            idx, _idx = tokenizer.convert_tokens_to_ids([str(x), str(-x)])\n",
    "            reciprocal_tokens[idx] = _idx\n",
    "\n",
    "        return NoLastTokenReductionProcessor(reciprocal_tokens)\n",
    "\n",
    "    def __call__(self, input_ids, scores):\n",
    "        rows, banned = [], []\n",
    "        for row, token_id in enumerate(input_ids[:, -1].tolist()):\n",
    "            if token_id in self.reciprocal_tokens:\n",
    "                rows.append(row)\n",
    "                banned.append(self.reciprocal_tokens[token_id])\n",
    "\n",
    "        # Nothing to ban when no sequence currently ends in a letter token.\n",
    "        if rows:\n",
    "            scores[tensor(rows, dtype = int, device = scores.device),\n",
    "                   tensor(banned, dtype = int, device = scores.device)] = -inf\n",
    "        return scores\n",
    "    \n",
    "# https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/generation/logits_process.py#L892\n",
    "class SuppressTokensLogitsProcessor(LogitsProcessor):\n",
    "    \"\"\"Logits processor that sets the scores of a fixed list of token ids to -inf at every step.\"\"\"\n",
    "    def __init__(self, tokens):\n",
    "        self.tokens = [token for token in tokens]\n",
    "\n",
    "    def __call__(self, input_ids, scores):\n",
    "        scores[:, self.tokens] = float('-inf')\n",
    "        return scores\n",
    "    \n",
    "\n",
    "from freegroup.tools import (\n",
    "    batch_normalize, batch_is_from_singleton_normal_closure, wu_closure,\n",
    "    batch_reduce_modulo_singleton_normal_closure, batch_to_string,\n",
    "    batch_from_string,\n",
    ")\n",
    "\n",
    "from freegroup.sampling import (\n",
    "    freegroup_generator\n",
    ")\n",
    "from copy import deepcopy\n",
    "from itertools import islice\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def completion_ratio(outputs, references, freegroup_dimension):\n",
    "    \"\"\"Fraction of references with at least one generated word in every closure 0..freegroup_dimension.\n",
    "\n",
    "    outputs: generated words grouped consecutively, len(outputs) // len(references) per reference.\n",
    "    An output counts only if it is non-empty and lies in all freegroup_dimension + 1 closures.\n",
    "    Returns {'completion_ratio': float}. Raises ValueError if `references` is empty or\n",
    "    len(outputs) is not a multiple of len(references).\n",
    "    \"\"\"\n",
    "    if len(references) == 0 or len(outputs) % len(references):\n",
    "        raise ValueError(\n",
    "            f'expected a whole number of outputs per reference, got {len(outputs)} outputs '\n",
    "            f'for {len(references)} references'\n",
    "        )\n",
    "\n",
    "    # (n_closures, n_outputs) membership table -> per-output 'in every closure' flag.\n",
    "    in_every_closure = np.array([\n",
    "        batch_is_from_singleton_normal_closure(outputs, wu_closure(freegroup_dimension, idx))\n",
    "        for idx in range(freegroup_dimension + 1)\n",
    "    ]).all(axis = 0)\n",
    "    non_empty = np.array([len(word) > 0 for word in outputs], dtype = bool)\n",
    "\n",
    "    completed = (in_every_closure & non_empty).reshape(len(references), -1).any(axis = -1)\n",
    "    return {'completion_ratio': completed.mean()}\n",
    "\n",
    "def reduction_ratio(outputs, references, freegroup_dimension):\n",
    "    \"\"\"Mean relative length reduction of `outputs` modulo each closure 0..freegroup_dimension.\n",
    "\n",
    "    For closure `idx` each word scores (len(word) - len(reduced word)) / len(word), with NaN (0/0 for\n",
    "    empty words) mapped to 0 by np.nan_to_num; `reduction_ratio_{idx}` averages over words and\n",
    "    `reduction_ratio` averages those per-closure values. `references` is not used.\n",
    "    \"\"\"\n",
    "    word_lengths = np.array([len(word) for word in outputs])\n",
    "    reduced_lengths = np.array([\n",
    "        [len(reduced) for reduced in batch_reduce_modulo_singleton_normal_closure(outputs, wu_closure(freegroup_dimension, idx))]\n",
    "        for idx in range(freegroup_dimension + 1)\n",
    "    ])\n",
    "\n",
    "    relative = (word_lengths[None, :] - reduced_lengths) / word_lengths[None, :]\n",
    "    per_closure = np.nan_to_num(relative).mean(axis = 1)\n",
    "\n",
    "    metrics = {f'reduction_ratio_{idx}': per_closure[idx] for idx in range(freegroup_dimension + 1)}\n",
    "    metrics['reduction_ratio'] = per_closure.mean()\n",
    "    return metrics\n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, Javascript\n",
    "# Ask the frontend to save the notebook, so the file later logged as a wandb artifact\n",
    "# (config['notebook_name']) is up to date. NOTE(review): `IPython.notebook` is a classic-Notebook\n",
    "# JS global; presumably this has no effect in JupyterLab / VS Code - confirm.\n",
    "display(Javascript('IPython.notebook.save_checkpoint();'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import wandb\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from os.path import exists\n",
    "from os import makedirs\n",
    "from shutil import rmtree\n",
    "\n",
    "from freegroup.sampling import freegroup_generator\n",
    "from freegroup.tools import batch_from_string\n",
    "\n",
    "from transformers import TopPLogitsWarper, RepetitionPenaltyLogitsProcessor, LogitsProcessorList\n",
    "from torch.optim import AdamW\n",
    "from torch.optim.lr_scheduler import LinearLR\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import no_grad, save\n",
    "from transformers import set_seed\n",
    "\n",
    "set_seed(config['seed'])\n",
    "\n",
    "metrics, report = {}, {}\n",
    "def add_metric(name, value):\n",
    "    if not name in metrics:\n",
    "        metrics[name] = []\n",
    "    metrics[name].append(value)\n",
    "    \n",
    "def report_metric(name):\n",
    "    report[f'{name}_mean'] = np.mean(metrics[name])\n",
    "    report[f'{name}_max'] = np.max(metrics[name])\n",
    "    report[f'{name}_min'] = np.min(metrics[name])\n",
    "    metrics[name] = []\n",
    "    \n",
    "def construct_dataset(conf):\n",
    "    \"\"\"Instantiate a dataset from {'dist': {'name', 'args'}, 'type': {'name', 'args'}}.\n",
    "\n",
    "    Type 'refillable' returns a `RefillableDataset`; type 'fixed' returns the list produced once by\n",
    "    `sample_dataset`. Raises ValueError for an unknown dist or type name.\n",
    "    \"\"\"\n",
    "    dist_name, dist_kwargs = conf['dist']['name'], conf['dist']['args']\n",
    "    if dist_name == 'freegroup':\n",
    "        # Plain free-group words: wrap the generator so every word gets its multi-label.\n",
    "        dist = wrapper_dist\n",
    "        dist_kwargs = {\n",
    "            'freegroup_dimension': config['freegroup_dimension'],\n",
    "            'word_generator': freegroup_generator,\n",
    "            'word_generator_kwargs': conf['dist']['args'],\n",
    "        }\n",
    "    else:\n",
    "        dists = {\n",
    "            'incomplete-intersection': incomplete_intersection_dist,\n",
    "            'complete-intersection': complete_intersection_dist,\n",
    "            'generator-permutations': generator_permutations_dist,\n",
    "        }\n",
    "        if dist_name not in dists:\n",
    "            raise ValueError(f'unknown dataset dist: {dist_name!r}')\n",
    "        dist = dists[dist_name]\n",
    "\n",
    "    type_name, type_kwargs = conf['type']['name'], conf['type']['args']\n",
    "    if type_name == 'refillable':\n",
    "        return RefillableDataset(dist, dist_kwargs, **type_kwargs)\n",
    "    if type_name == 'fixed':\n",
    "        return sample_dataset(dist, dist_kwargs = dist_kwargs, **type_kwargs)\n",
    "    raise ValueError(f'unknown dataset type: {type_name!r}')\n",
    "        \n",
    "train_dataset = construct_dataset(config['train']['dataset'])\n",
    "eval_dataset = {k: construct_dataset(v) for k, v in config['eval']['dataset'].items()}\n",
    "gen_dataset = {k: construct_dataset(v) for k, v in config['gen']['dataset'].items()}\n",
    "\n",
    "# Bracket/comma syntax and (presumably) the label-prompt tokens must never be generated.\n",
    "suppressed_tokens = tokenizer.convert_tokens_to_ids(['[', ']', ',', 'y', 'n', ':'])\n",
    "\n",
    "gen_methods = {}\n",
    "for k, v in config['gen']['methods'].items():\n",
    "    # Copy before popping the processor-only keys: mutating `config` would drop them from the config\n",
    "    # logged to wandb and silently disable them when this cell is re-run.\n",
    "    v = dict(v)\n",
    "    logits_processors = []\n",
    "\n",
    "    if 'top_p' in v:\n",
    "        logits_processors.append(TopPLogitsWarper(v.pop('top_p')))\n",
    "    if 'repetition_penalty' in v:\n",
    "        logits_processors.append(RepetitionPenaltyLogitsProcessor(v.pop('repetition_penalty')))\n",
    "\n",
    "    logits_processors.append(SuppressTokensLogitsProcessor(suppressed_tokens))\n",
    "    logits_processors.append(\n",
    "        NoLastTokenReductionProcessor.from_tokenizer(config['freegroup_dimension'], tokenizer)\n",
    "    )\n",
    "\n",
    "    # Remaining keys are kept as-is (presumably forwarded to `model.generate`).\n",
    "    gen_methods[k] = {'logits_processor': LogitsProcessorList(logits_processors), **v}\n",
    "    \n",
    "\n",
    "def _to_device(batch, device):\n",
    "    # move every entry of a collated batch to `device` in place; returns the same batch\n",
    "    for key, value in batch.items():\n",
    "        batch[key] = value.to(device)\n",
    "    return batch\n",
    "\n",
    "def _collator(mode):\n",
    "    # fresh data_collator for this model/tokenizer; mode is 'predict' (loss) or 'generate'\n",
    "    return data_collator(\n",
    "        model, tokenizer, config['freegroup_dimension'],\n",
    "        method = config['method'], mode = mode,\n",
    "    )\n",
    "\n",
    "# metrics reported per (generation method, dataset) pair after each generation round\n",
    "gen_metric_names = ['completion_ratio', 'reduction_ratio'] + [\n",
    "    f'reduction_ratio_{idx}' for idx in range(config['freegroup_dimension'] + 1)\n",
    "]\n",
    "\n",
    "try:\n",
    "    with wandb.init(notes = '', config = config) as run, tqdm(total = config['train']['steps']) as progress:\n",
    "\n",
    "        model.to(config['device'])\n",
    "\n",
    "        # store this notebook with the run\n",
    "        artifact = wandb.Artifact('notebook', type='notebook')\n",
    "        artifact.add_file(config['notebook_name'])\n",
    "        run.log_artifact(artifact)\n",
    "\n",
    "        # fail fast on unsupported names: otherwise a stale optimizer/scheduler left in\n",
    "        # the kernel by an earlier run would be silently reused (or a NameError raised later)\n",
    "        optimizer_name = config['train']['optimizer']['name']\n",
    "        if optimizer_name != 'AdamW':\n",
    "            raise ValueError(f'unsupported optimizer: {optimizer_name!r}')\n",
    "        optimizer = AdamW(params = model.parameters(), **config['train']['optimizer']['args'])\n",
    "\n",
    "        scheduler_name = config['train']['scheduler']['name']\n",
    "        if scheduler_name != 'Linear':\n",
    "            raise ValueError(f'unsupported scheduler: {scheduler_name!r}')\n",
    "        scheduler = LinearLR(optimizer, **config['train']['scheduler']['args'])\n",
    "\n",
    "        train_dataloader_fn = lambda: DataLoader(\n",
    "            train_dataset,\n",
    "            batch_size = config['train']['batch_size'],\n",
    "            collate_fn = _collator('predict'),\n",
    "        )\n",
    "\n",
    "        dataloader = iter([])  # empty, so the first step builds a DataLoader\n",
    "\n",
    "        while progress.n < config['train']['steps']:\n",
    "            progress.update()\n",
    "            model.train()\n",
    "\n",
    "            # cycle through the training set, rebuilding the DataLoader when it is exhausted\n",
    "            try:\n",
    "                batch = next(dataloader)\n",
    "            except StopIteration:\n",
    "                dataloader = iter(train_dataloader_fn())\n",
    "                batch = next(dataloader)\n",
    "            _to_device(batch, model.device)\n",
    "\n",
    "            outputs = model(**batch)\n",
    "            optimizer.zero_grad()\n",
    "            outputs[0].backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            add_metric('train/loss', outputs[0].item())\n",
    "\n",
    "            if progress.n % config['train']['scheduler']['steps'] == 0:\n",
    "                scheduler.step()\n",
    "\n",
    "            if progress.n % config['eval']['steps'] == 0:\n",
    "                model.eval()\n",
    "                for key, dataset in eval_dataset.items():\n",
    "                    for batch in DataLoader(\n",
    "                        dataset, shuffle=True,\n",
    "                        batch_size = config['eval']['batch_size'],\n",
    "                        collate_fn = _collator('predict'),\n",
    "                    ):\n",
    "                        _to_device(batch, model.device)\n",
    "                        with no_grad():\n",
    "                            outputs = model(**batch)\n",
    "                        add_metric(f'eval/{key}_loss', outputs[0].item())\n",
    "                    report_metric(f'eval/{key}_loss')\n",
    "\n",
    "            if progress.n % config['gen']['steps'] == 0:\n",
    "                model.eval()\n",
    "                for dataset_key, dataset in gen_dataset.items():\n",
    "                    for method_key, method in gen_methods.items():\n",
    "                        prefix = f'gen/{method_key}_{dataset_key}'\n",
    "                        for batch in DataLoader(\n",
    "                            dataset, shuffle = True,\n",
    "                            batch_size = config['gen']['batch_size'],\n",
    "                            collate_fn = _collator('generate'),\n",
    "                        ):\n",
    "                            _to_device(batch, model.device)\n",
    "\n",
    "                            outputs = model.generate(**batch, **method)\n",
    "                            outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)\n",
    "                            outputs = batch_from_string(outputs)\n",
    "\n",
    "                            for ratio_fn in (completion_ratio, reduction_ratio):\n",
    "                                for k, v in ratio_fn(\n",
    "                                    outputs, batch['inputs'],\n",
    "                                    freegroup_dimension = config['freegroup_dimension']\n",
    "                                ).items():\n",
    "                                    add_metric(f'{prefix}/{k}', v)\n",
    "\n",
    "                        for name in gen_metric_names:\n",
    "                            report_metric(f'{prefix}/{name}')\n",
    "\n",
    "            if progress.n % config['train']['log_steps'] == 0:\n",
    "                report_metric('train/loss')\n",
    "\n",
    "            if progress.n % config['train']['save_steps'] == 0:\n",
    "                model_artifact = wandb.Artifact(run.id, type='model')\n",
    "                checkpoint = f'{run.dir}/checkpoint'\n",
    "                if exists(checkpoint): rmtree(checkpoint)\n",
    "                makedirs(checkpoint)\n",
    "                model.save_pretrained(checkpoint)\n",
    "                tokenizer.save_pretrained(checkpoint)\n",
    "                save(optimizer.state_dict(), f'{checkpoint}/optimizer.pt')\n",
    "                save(scheduler.state_dict(), f'{checkpoint}/scheduler.pt')\n",
    "                model_artifact.add_dir(checkpoint)\n",
    "                run.log_artifact(model_artifact)\n",
    "\n",
    "            if report:\n",
    "                # log a snapshot and reset it, so each step sends only the metrics reported\n",
    "                # during that step rather than re-logging stale values on every later step\n",
    "                # NOTE(review): assumes report_metric only ever adds entries to `report`\n",
    "                wandb.log(dict(report), commit = True, step = progress.n)\n",
    "                report.clear()\n",
    "finally:\n",
    "    # stop refillable datasets even when training raises or is interrupted, so their\n",
    "    # (presumably background) refilling does not outlive this cell in the kernel\n",
    "    for dataset in [train_dataset, *eval_dataset.values(), *gen_dataset.values()]:\n",
    "        if isinstance(dataset, RefillableDataset):\n",
    "            dataset.stop()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
