{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set these before any later cell imports tokenizers/torch so they take effect.\n",
    "%env TOKENIZERS_PARALLELISM=false\n",
    "%env CUDA_VISIBLE_DEVICES=1\n",
    "\n",
    "# Free-group dimension shared by DatasetConfig and the tokenizer path below.\n",
    "freegroup_dimension = 4\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from random import randint, random\n",
    "\n",
    "def random_commutator(leaves, proba_commutator = 0.8):\n",
    "    \"\"\"Randomly bracket `leaves` into a binary tree of commutators and products.\n",
    "\n",
    "    Each internal node is a commutator (a 2-tuple) with probability\n",
    "    `proba_commutator` and a multiplication (a 2-element list) otherwise.\n",
    "    The left-to-right order of the leaves is preserved.\n",
    "\n",
    "    Raises ValueError if `leaves` is empty.\n",
    "    \"\"\"\n",
    "    if len(leaves) == 0:\n",
    "        raise ValueError('random_commutator() needs at least one leaf')\n",
    "    if len(leaves) == 1:\n",
    "        return leaves[0]\n",
    "\n",
    "    # Push the split point away from the middle so the trees come out unbalanced.\n",
    "    split_idx = randint(1, len(leaves) - 1)\n",
    "    if split_idx >= len(leaves) // 2:\n",
    "        split_idx = min(len(leaves) - 1, int(split_idx * 1.5))\n",
    "    else:\n",
    "        split_idx = max(1, int(split_idx * 0.5))\n",
    "\n",
    "    # Draw the node type before recursing into the children.\n",
    "    is_commutator = random() < proba_commutator\n",
    "    left = random_commutator(leaves[:split_idx], proba_commutator)\n",
    "    right = random_commutator(leaves[split_idx:], proba_commutator)\n",
    "    return (left, right) if is_commutator else [left, right]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.tools import Visitor\n",
    "\n",
    "class CalculateDepth(Visitor):\n",
    "    \"\"\"Nesting depth of commutator brackets in an expression.\n",
    "\n",
    "    Generators have depth 0, a commutator is one level deeper than its deepest\n",
    "    argument, and a multiplication is as deep as its deepest factor.\n",
    "    \"\"\"\n",
    "    def visit_generator(self, word):\n",
    "        return 0\n",
    "\n",
    "    def visit_commutator(self, commutator):\n",
    "        return max(map(self, commutator)) + 1\n",
    "\n",
    "    def visit_multiplication(self, mult):\n",
    "        # An empty product adds no brackets; avoid max() on an empty sequence.\n",
    "        return max(map(self, mult), default = 0)\n",
    "\n",
    "calculate_depth = CalculateDepth()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
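    "# Dataset\n",
    "\n",
    "Random nested commutators are sampled, flattened and normalized into free-group words, and stored as `(word, commutator)` pairs of integer-encoded strings."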
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.sampling import free_group_bounded, random_length\n",
    "from freegroup.tools import flatten, normalize, to_string\n",
    "\n",
    "from itertools import repeat, islice\n",
    "from tqdm import tqdm\n",
    "from random import randint, sample, random, shuffle\n",
    "\n",
    "from numpy.random import geometric\n",
    "\n",
    "from dataclasses import dataclass\n",
    "\n",
    "@dataclass\n",
    "class DatasetConfig:\n",
    "    \"\"\"Knobs for the random (word, commutator) pair generator.\"\"\"\n",
    "    freegroup_dimension: int = freegroup_dimension\n",
    "    # NOTE(review): not read by data_iterator; commutee lengths come from\n",
    "    # geometric(0.5) instead. TODO confirm whether this should bound them.\n",
    "    max_commutee_length: int = 3\n",
    "    max_commutator_depth: int = 7\n",
    "    proba_commutator: float = 0.85\n",
    "    max_multipliers_number: int = 3\n",
    "    # Words are kept only if min_total_length < len(word) < max_total_length.\n",
    "    min_total_length: int = 0\n",
    "    max_total_length: int = 250\n",
    "\n",
    "data_config = DatasetConfig()\n",
    "\n",
    "def data_iterator(config: DatasetConfig):\n",
    "    \"\"\"Endless iterator of {'word', 'commutator'} dicts of integer-encoded strings.\n",
    "\n",
    "    A random nested commutator expression is sampled. The expression and its\n",
    "    flattening are both normalized. Pairs whose top level is a list are dropped,\n",
    "    and the word length is filtered to (min_total_length, max_total_length).\n",
    "    \"\"\"\n",
    "    commutee = free_group_bounded(config.freegroup_dimension, random_length_method = lambda : geometric(0.5))\n",
    "\n",
    "    def sample_expression(config: DatasetConfig, depth = None):\n",
    "        # Tuples encode commutators, lists encode multiplications, and leaves\n",
    "        # are words drawn from `commutee`. `depth` is a nesting budget.\n",
    "        depth = random_length(config.max_commutator_depth, method='almost_uniform') if depth is None else depth\n",
    "\n",
    "        if depth <= 0:\n",
    "            return next(commutee)\n",
    "\n",
    "        if depth == 1 or random() < config.proba_commutator:\n",
    "            # Commutator: one argument gets budget depth - 1, the other a\n",
    "            # random budget in [0, depth - 1].\n",
    "            depths = [randint(1, depth), depth]\n",
    "            shuffle(depths)\n",
    "            return tuple(sample_expression(config, d - 1) for d in depths)\n",
    "\n",
    "        # Multiplication of randint(2, max_multipliers_number) factors; one\n",
    "        # gets budget depth - 1, the rest random budgets in [1, depth - 1].\n",
    "        depths = [randint(2, depth) for _ in range(randint(2, config.max_multipliers_number) - 1)] + [depth]\n",
    "        shuffle(depths)\n",
    "        return [sample_expression(config, d - 1) for d in depths]\n",
    "\n",
    "    def commutators(config):\n",
    "        while True: yield sample_expression(config)\n",
    "\n",
    "    generator = commutators(config)\n",
    "\n",
    "    generator = map(lambda x: (flatten(x), x), generator)\n",
    "    generator = map(lambda p: tuple(map(normalize, p)), generator)\n",
    "\n",
    "    # Drop pairs whose normalized expression is a list (a multiplication in this encoding).\n",
    "    generator = filter(lambda p: not isinstance(p[1], list), generator)\n",
    "    generator = filter(lambda p: config.min_total_length < len(p[0]) < config.max_total_length, generator)\n",
    "\n",
    "    def to_dict(entry):\n",
    "        word, commutator = entry\n",
    "        return {\n",
    "            'word': to_string(word, method = 'integer'),\n",
    "            'commutator': to_string(commutator, method = 'integer'),\n",
    "        }\n",
    "\n",
    "    return map(to_dict, generator)\n",
    "\n",
    "# Preview sample drawn with the same config that training uses.\n",
    "n_visualize = 1000\n",
    "visualize = list(tqdm(islice(data_iterator(data_config), n_visualize), total = n_visualize))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.tools import Visitor\n",
    "\n",
    "class FindMultiplication(Visitor):\n",
    "    \"\"\"True if the expression contains a multiplication with a commutator factor.\n",
    "\n",
    "    Generators are False. Commutators and multiplications are searched\n",
    "    recursively, so products nested inside other products are also found.\n",
    "    \"\"\"\n",
    "    def visit_generator(self, generator):\n",
    "        return False\n",
    "\n",
    "    def visit_commutator(self, commutator):\n",
    "        return any(map(self, commutator))\n",
    "\n",
    "    def visit_multiplication(self, mult):\n",
    "        # Counts if a direct factor is a commutator, or if a nested factor\n",
    "        # (e.g. a sub-product) contains such a multiplication.\n",
    "        return any(isinstance(m, tuple) for m in mult) or any(map(self, mult))\n",
    "\n",
    "find_multiplication = FindMultiplication()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.tools import from_string, flatten\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from itertools import islice\n",
    "\n",
    "from numpy import arange\n",
    "\n",
    "# Print commutator strings the parser rejects. Catch Exception, not\n",
    "# BaseException, so KeyboardInterrupt still stops the cell.\n",
    "for c in map(lambda x: x['commutator'], visualize):\n",
    "    try: from_string(c, method = 'integer')\n",
    "    except Exception: print(c)\n",
    "\n",
    "words = [from_string(x['word'], method = 'integer') for x in visualize]\n",
    "commutators = [from_string(x['commutator'], method = 'integer') for x in visualize]\n",
    "\n",
    "def plot_histogram(values, title, xlabel, bins = None):\n",
    "    \"\"\"Draw a labelled histogram of `values` in its own figure.\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.hist(values, bins = bins)\n",
    "    ax.set(title = title, xlabel = xlabel, ylabel = 'count')\n",
    "    plt.show()\n",
    "\n",
    "lens = [len(w) for w in words]\n",
    "plot_histogram(lens, 'Length of a word', 'length')\n",
    "\n",
    "depths = [calculate_depth(c) for c in commutators]\n",
    "plot_histogram(depths, 'Maximum commutator depth', 'depth')\n",
    "\n",
    "# How much longer the flattened commutator is than the normalized word.\n",
    "diffs = [len(flatten(c)) - len(w) for w, c in zip(words, commutators)]\n",
    "# Unit-width bins centred on integers that cover every value, so large\n",
    "# differences are not silently dropped.\n",
    "plot_histogram(\n",
    "    diffs, 'Reduction size of length', 'len(flatten(commutator)) - len(word)',\n",
    "    bins = arange(min(0, min(diffs)) - 0.5, max(diffs) + 1),\n",
    ")\n",
    "\n",
    "rate = sum(map(find_multiplication, commutators)) / len(commutators)\n",
    "print(f'THE RATE OF WORDS CONTAINING MULTIPLICATION: {rate}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
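    "# Training\n",
    "\n",
    "Train a BERT-to-BERT encoder-decoder that maps a word to the commutator expression it was generated from."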
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import PreTrainedTokenizerFast\n",
    "\n",
    "# Word-level tokenizer directory for the configured free-group dimension.\n",
    "tokenizer_path = f'tokenizer/word-level-tokenizer-{freegroup_dimension}'\n",
    "tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import BertConfig, GPT2Config, EncoderDecoderConfig, EncoderDecoderModel\n",
    "\n",
    "def make_bert_config(hidden_size = 128, num_hidden_layers = 12, num_attention_heads = 8):\n",
    "    \"\"\"Build a fresh BertConfig for one side of the encoder-decoder.\n",
    "\n",
    "    A new object is returned on every call, so the encoder and decoder configs\n",
    "    never alias each other. The feed-forward size is 4 * hidden_size.\n",
    "    \"\"\"\n",
    "    return BertConfig(\n",
    "        vocab_size              = len(tokenizer),\n",
    "        hidden_size             = hidden_size,\n",
    "        max_position_embeddings = 1024,\n",
    "        num_hidden_layers       = num_hidden_layers,\n",
    "        num_attention_heads     = num_attention_heads,\n",
    "        intermediate_size       = 4 * hidden_size,\n",
    "        pad_token_id            = tokenizer.pad_token_id,\n",
    "    )\n",
    "\n",
    "encoder_config = make_bert_config()\n",
    "decoder_config = make_bert_config()\n",
    "\n",
    "model_config = EncoderDecoderConfig.from_encoder_decoder_configs(\n",
    "    encoder_config = encoder_config,\n",
    "    decoder_config = decoder_config,\n",
    ")\n",
    "\n",
    "model = EncoderDecoderModel(config=model_config)\n",
    "# Special-token ids used for decoding and padding.\n",
    "model.config.decoder_start_token_id = tokenizer.bos_token_id\n",
    "model.config.bos_token_id           = tokenizer.bos_token_id\n",
    "model.config.eos_token_id           = tokenizer.eos_token_id\n",
    "model.config.pad_token_id           = tokenizer.pad_token_id"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evaluate import load, combine\n",
    "import numpy as np\n",
    "\n",
    "metrics = combine([load('bleu'), load('rouge')])\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"BLEU + ROUGE between decoded generations and decoded references.\n",
    "\n",
    "    With predict_with_generate=True the first element holds generated token ids,\n",
    "    not logits. Labels carry -100 on ignored positions (see data_collator). The\n",
    "    Trainer may also pad predictions with -100 when concatenating eval batches,\n",
    "    so -100 is mapped back to the pad id in both before decoding.\n",
    "    \"\"\"\n",
    "    predictions, labels = eval_pred\n",
    "\n",
    "    pad_id = tokenizer.pad_token_id\n",
    "    predictions = np.where(predictions != -100, predictions, pad_id)\n",
    "    labels = np.where(labels != -100, labels, pad_id)\n",
    "\n",
    "    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)\n",
    "    references = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "\n",
    "    # Show one example per evaluation as a sanity check.\n",
    "    if decoded_predictions:\n",
    "        print(decoded_predictions[0])\n",
    "        print(references[0])\n",
    "\n",
    "    return metrics.compute(references = references, predictions = decoded_predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import refueable_dataset\n",
    "\n",
    "def preprocess(input):\n",
    "    \"\"\"Tokenize one generator entry into encoder `input_ids` and target `labels`.\"\"\"\n",
    "    model_inputs = tokenizer(\n",
    "        input['word'],\n",
    "        return_token_type_ids   = False,\n",
    "        return_attention_mask   = False,\n",
    "    )\n",
    "    # Source and target share one tokenizer, so targets use a plain call;\n",
    "    # the deprecated as_target_tokenizer() context is presumably a no-op here.\n",
    "    model_targets = tokenizer(input['commutator'])\n",
    "    model_inputs['labels'] = model_targets['input_ids']\n",
    "    return model_inputs\n",
    "\n",
    "# Smoke test: pull one preprocessed example from the dataset.\n",
    "with refueable_dataset(\n",
    "    generator_fn = data_iterator,\n",
    "    preprocess_fn = preprocess,\n",
    "    config = data_config,\n",
    "    batch_size = 5,\n",
    "    queue_size = 5,\n",
    ") as dataset:\n",
    "    for d in dataset: break\n",
    "\n",
    "\n",
    "from typing import List\n",
    "import torch\n",
    "\n",
    "def data_collator(batch: List):\n",
    "    \"\"\"Pad preprocessed examples into teacher-forcing tensors.\n",
    "\n",
    "    Returns:\n",
    "    - encoder `input_ids` and `attention_mask`;\n",
    "    - `decoder_input_ids` (targets without their last token);\n",
    "    - `labels` (targets without their first token, with pad set to -100 so\n",
    "      the loss ignores it);\n",
    "    - a `decoder_attention_mask` aligned with `decoder_input_ids`.\n",
    "    \"\"\"\n",
    "    def max_length_pad(sequences: List):\n",
    "        # Right-pad to the longest sequence; the mask is 1 on tokens, 0 on pad.\n",
    "        max_length = max(map(len, sequences))\n",
    "        padded = torch.stack([\n",
    "            torch.tensor(x + [tokenizer.pad_token_id] * (max_length - len(x)), dtype=int)\n",
    "            for x in sequences\n",
    "        ])\n",
    "        attention_mask = torch.ones_like(padded)\n",
    "        attention_mask.masked_fill_(padded == tokenizer.pad_token_id, 0)\n",
    "        return padded, attention_mask\n",
    "\n",
    "    input_ids, attention_mask = max_length_pad([x.input_ids for x in batch])\n",
    "    labels, decoder_attention_mask = max_length_pad([x.labels for x in batch])\n",
    "    decoder_input_ids = labels[:, :-1].clone()\n",
    "    labels[labels == tokenizer.pad_token_id] = -100\n",
    "\n",
    "    return {\n",
    "        'input_ids': input_ids,\n",
    "        'attention_mask': attention_mask,\n",
    "        'decoder_input_ids': decoder_input_ids,\n",
    "        'labels': labels[:, 1:],\n",
    "        # Column i of the mask must describe decoder_input_ids[:, i].\n",
    "        'decoder_attention_mask': decoder_attention_mask[:, :-1],\n",
    "    }\n",
    "\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Smoke test: collate two examples into one batch.\n",
    "with refueable_dataset(\n",
    "    generator_fn = data_iterator,\n",
    "    preprocess_fn = preprocess,\n",
    "    config = data_config,\n",
    "    batch_size = 5,\n",
    "    queue_size = 5,\n",
    ") as dataset:\n",
    "    for batch in DataLoader(dataset, batch_size = 2, collate_fn=data_collator):\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer\n",
    "\n",
    "prefix = 'bert-bert-512'\n",
    "\n",
    "# Size of the fixed evaluation set drawn from the same generator as training.\n",
    "n_eval_samples = 1000\n",
    "\n",
    "args = Seq2SeqTrainingArguments(\n",
    "    output_dir                  = prefix,\n",
    "    overwrite_output_dir        = True,\n",
    "\n",
    "    # Evaluate on generated sequences every 200 steps.\n",
    "    predict_with_generate       = True,\n",
    "    evaluation_strategy         = 'steps',\n",
    "    eval_steps                  = 200,\n",
    "    generation_max_length       = 200,\n",
    "\n",
    "    per_device_train_batch_size = 64,\n",
    "    per_device_eval_batch_size  = 64,\n",
    "    learning_rate               = 8e-5,\n",
    "    max_steps                   = 500 * 1000,\n",
    "\n",
    "    logging_steps               = 1000,\n",
    "    save_steps                  = 1000,\n",
    "    save_total_limit            = 100,\n",
    ")\n",
    "\n",
    "eval_dataset = [preprocess(entry) for entry in islice(data_iterator(data_config), n_eval_samples)]\n",
    "\n",
    "with refueable_dataset(\n",
    "    generator_fn    = data_iterator,\n",
    "    preprocess_fn   = preprocess,\n",
    "    config          = data_config,\n",
    "    queue_size      = 3,\n",
    "    batch_size      = 500,\n",
    "    max_consecutive_calls = 3,\n",
    ") as train_dataset:\n",
    "    trainer = Seq2SeqTrainer(\n",
    "        model           = model,\n",
    "        args            = args,\n",
    "        train_dataset   = train_dataset,\n",
    "        eval_dataset    = eval_dataset,\n",
    "        data_collator   = data_collator,\n",
    "        compute_metrics = compute_metrics,\n",
    "    )\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
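    "# Inference\n",
    "\n",
    "Generate candidate commutator expressions for a word and keep the ones whose normalized flattening equals that word."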
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.tools import (\n",
    "    from_string, to_string, normalize, flatten\n",
    ")\n",
    "from itertools import repeat\n",
    "from torch import tensor\n",
    "\n",
    "def translation(word, num_tries = None, generation_config = None):\n",
    "    \"\"\"Generate commutator expressions for `word` until one reduces to it.\n",
    "\n",
    "    Each try runs `model.generate` once and parses every decoded output. The\n",
    "    first parse whose normalized flattening equals that of `word` is returned.\n",
    "    Returns None after `num_tries` unsuccessful tries. With num_tries=None it\n",
    "    retries forever, which only makes sense for stochastic generation.\n",
    "    \"\"\"\n",
    "    generation_config = {} if generation_config is None else generation_config\n",
    "    target = normalize(flatten(word))\n",
    "\n",
    "    input_ids = tokenizer(to_string(word, method = 'tokenizer')).input_ids\n",
    "    input_ids = tensor(input_ids, dtype = int, device = model.device)\n",
    "    tries = repeat(None) if num_tries is None else repeat(None, num_tries)\n",
    "    for _ in tries:\n",
    "        outputs = model.generate(input_ids.unsqueeze(0), **generation_config)\n",
    "        outputs = tokenizer.batch_decode(outputs, skip_special_tokens = True)\n",
    "\n",
    "        for output in outputs:\n",
    "            try:\n",
    "                output = from_string(output, method = 'tokenizer')\n",
    "            except Exception:\n",
    "                # Unparsable generation: skip it.\n",
    "                continue\n",
    "            if normalize(flatten(output)) == target:\n",
    "                return output\n",
    "    return None\n",
    "\n",
    "generation_config = dict(\n",
    "    eos_token_id = tokenizer.eos_token_id,\n",
    "    pad_token_id = tokenizer.pad_token_id,\n",
    ")\n",
    "\n",
    "# This config does not enable sampling, so decoding is presumably deterministic:\n",
    "# one try suffices, and unbounded retries would hang on a wrong answer.\n",
    "translation(normalize(flatten(from_string('[x, y]', method='lu'))), num_tries = 1, generation_config = generation_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import tensor\n",
    "from freegroup import tools, sampling as smp\n",
    "from freegroup.commutators import (\n",
    "    to_tokenizer, from_tokenizer,\n",
    "    # Bound here so `translation` does not depend on a later cell; aliased so\n",
    "    # freegroup.tools.normalize used by earlier cells is not clobbered.\n",
    "    to_freegroup, normalize as commutators_normalize,\n",
    ")\n",
    "from itertools import repeat, islice\n",
    "from re import finditer, escape\n",
    "\n",
    "\n",
    "def sampler(word, **generation_kwargs):\n",
    "    \"\"\"Return a zero-argument callable that runs one `model.generate` on `word`\n",
    "    and returns the decoded outputs with special tokens kept.\"\"\"\n",
    "    def sample():\n",
    "        input_ids = tokenizer(\n",
    "            to_tokenizer(word),\n",
    "        ).input_ids\n",
    "        input_ids = tensor(input_ids, dtype=int, device = model.device)\n",
    "        outputs = model.generate(input_ids.unsqueeze(0), **generation_kwargs)\n",
    "        return tokenizer.batch_decode(outputs)\n",
    "\n",
    "    return sample\n",
    "\n",
    "\n",
    "def trim_eos_tokens(string):\n",
    "    \"\"\"Return the text between the first and second EOS token of `string`.\n",
    "\n",
    "    If there is no second EOS, the text runs to the end of the string. If there\n",
    "    is no EOS at all, the whole string is returned. Raising StopIteration here\n",
    "    would silently end the `map` pipeline in `translation`.\n",
    "    NOTE(review): assumes decoded outputs start with an EOS-like token; TODO confirm.\n",
    "    \"\"\"\n",
    "    # escape(): the EOS token is literal text, not a regular expression.\n",
    "    positions = finditer(escape(tokenizer.eos_token), string)\n",
    "    first = next(positions, None)\n",
    "    if first is None:\n",
    "        return string\n",
    "    second = next(positions, None)\n",
    "    end = len(string) if second is None else second.start()\n",
    "    return string[first.end():end]\n",
    "\n",
    "\n",
    "def try_or(callable, default, *args, **kwargs):\n",
    "    \"\"\"Return callable(*args, **kwargs), or `default` if it raises an Exception.\"\"\"\n",
    "    try:\n",
    "        return callable(*args, **kwargs)\n",
    "    except Exception:\n",
    "        return default\n",
    "\n",
    "\n",
    "def translation(\n",
    "    word,\n",
    "    num_tries = None,\n",
    "    **generation_kwargs,\n",
    "):\n",
    "    \"\"\"Lazily yield parsed generations whose normalized free-group word equals `word`.\n",
    "\n",
    "    Candidates are compared with == against the normalized result, so pass a\n",
    "    normalized `word`. Unparsable generations become [] and are filtered out\n",
    "    unless `word` itself is empty.\n",
    "    \"\"\"\n",
    "    g = smp.iterable_from_batches(sampler(\n",
    "        word, **generation_kwargs\n",
    "    ), num_tries)\n",
    "    g = map(trim_eos_tokens, g)\n",
    "    g = map(lambda x: try_or(from_tokenizer, [], x), g)\n",
    "\n",
    "    return filter(lambda x: commutators_normalize(to_freegroup(x)) == word, g)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from freegroup.commutators import (\n",
    "    from_lu, normalize, to_freegroup, to_lu\n",
    ")\n",
    "\n",
    "# Beam search returning 100 candidates per generate call.\n",
    "generation_config = dict(\n",
    "    max_length = 40,\n",
    "    num_beams = 100,\n",
    "    num_return_sequences = 100,\n",
    ")\n",
    "\n",
    "# Target: the normalized free-group word of a sample commutator.\n",
    "word = normalize(to_freegroup(from_lu('[[x, y], [x, yz]]')))\n",
    "\n",
    "candidates = translation(word, num_tries = 5, **generation_config)\n",
    "try:\n",
    "    c = next(candidates)\n",
    "except StopIteration:\n",
    "    print('Failed!')\n",
    "else:\n",
    "    print(to_lu(c))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "vscode": {
   "interpreter": {
    "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
