{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "02022097",
   "metadata": {},
   "source": [
    "# Greedy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a85d9f27",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from functools import reduce\n",
    "from random import sample\n",
    "\n",
    "# A word in the free group: a list of letters where generator g is written g and its\n",
    "# inverse -g (see `remove_complementaries`, which cancels g against -g).\n",
    "Word = List[int]\n",
    "\n",
    "def remove_complementaries(word: Word):\n",
    "    \"\"\"Cancel the last two letters of `word` if they are mutual inverses (g, -g).\n",
    "\n",
    "    Only the tail is inspected, so applying this after every append keeps a\n",
    "    stack-built word freely reduced. Returns a new list when a cancellation\n",
    "    happens, otherwise the very same `word` object.\n",
    "    \"\"\"\n",
    "    tail_cancels = len(word) > 1 and word[-2] == -word[-1]\n",
    "    return word[:-2] if tail_cancels else word\n",
    "\n",
    "def remove_rotations(word: Word, *substrs: Word):\n",
    "    \"\"\"Drop a trailing cyclic rotation of one of `substrs` from `word`.\n",
    "\n",
    "    The candidates are tried in order; for each one that is not longer than\n",
    "    `word`, every cyclic rotation is compared with the tail of `word`. On the\n",
    "    first match, `word` is returned without that tail (a new list). If nothing\n",
    "    matches, the very same `word` object is returned.\n",
    "    \"\"\"\n",
    "    for substr in substrs:\n",
    "        size = len(substr)\n",
    "        if len(word) < size:\n",
    "            continue\n",
    "        tail = word[-size:]\n",
    "        doubled = substr * 2\n",
    "        # The rotations of `substr` are exactly doubled[k:k + size] for k in [0, size);\n",
    "        # larger offsets only repeat rotation 0 or give slices too short to match.\n",
    "        for shift in range(size):\n",
    "            if tail == doubled[shift:shift + size]:\n",
    "                return word[:-size]\n",
    "    return word\n",
    "\n",
    "def combine(x, y):\n",
    "    \"\"\"Append `y` to the list `x` in place and return `x` (an append usable inside `reduce`).\"\"\"\n",
    "    x += [y]\n",
    "    return x\n",
    "\n",
    "def generate_from_intersection(\n",
    "    prefix: Word,\n",
    "    num_generators: int,\n",
    "    max_iterations: int,\n",
    "):\n",
    "    \"\"\"Greedily extend `prefix` until every tracked reduction of it is empty.\n",
    "\n",
    "    One stack is kept per base: the single-generator bases [1], ..., [n] and the\n",
    "    joint base [1, 2, ..., n] (n = `num_generators`). A stack is built by pushing\n",
    "    the letters of the word one at a time and, after each push, dropping a\n",
    "    trailing cyclic rotation of the base or of its inverse `ibase`\n",
    "    (`remove_rotations`) and then a trailing inverse pair (`remove_complementaries`).\n",
    "\n",
    "    At every step each non-empty stack votes for the next letter:\n",
    "      * if its top letter differs from the last letter of the word, it votes for\n",
    "        the inverse of that top letter (to cancel it);\n",
    "      * otherwise, if the top letter belongs to `base` (or else `ibase`), it votes\n",
    "        for the cyclic successor of that letter in `base` (or `ibase`);\n",
    "      * otherwise it votes for both `base[0]` and `ibase[0]`.\n",
    "    The letter with the most votes (ties broken uniformly at random) is appended.\n",
    "\n",
    "    Args:\n",
    "        prefix: starting word; it is copied, not modified.\n",
    "        num_generators: number of free generators n.\n",
    "        max_iterations: maximum number of letters appended after `prefix`.\n",
    "\n",
    "    Returns:\n",
    "        The extended word if all stacks are empty at the end, otherwise None.\n",
    "    \"\"\"\n",
    "    generated = prefix[::]\n",
    "\n",
    "    # One stack per base: n single-generator bases plus the joint base.\n",
    "    stacks = [[] for _ in range(num_generators + 1)]\n",
    "    bases = [[i] for i in range(1, num_generators + 1)] + [list(range(1, num_generators + 1))]\n",
    "    # Inverse of each base word: letters reversed and negated.\n",
    "    ibases = [[-f for f in base[::-1]] for base in bases]\n",
    "\n",
    "    # Reduce the prefix into every stack (reduce starts from a fresh list each time).\n",
    "    for i, (base, ibase) in enumerate(zip(bases, ibases)):\n",
    "        stacks[i] = reduce(lambda x, y: remove_complementaries(remove_rotations(combine(x, y), base, ibase)), prefix, [])\n",
    "\n",
    "    iteration = 0\n",
    "    while not all(map(lambda v: len(v) == 0, stacks)) and iteration < max_iterations:\n",
    "        iteration += 1\n",
    "\n",
    "        # Maps candidate letter -> number of stacks voting for it.\n",
    "        votes = {}\n",
    "        for stack, base, ibase in zip(stacks, bases, ibases):\n",
    "            if len(stack) == 0: continue\n",
    "\n",
    "            last = stack[-1]\n",
    "            if last == generated[-1]:\n",
    "                # Top equals the word's last letter: continue a rotation of the base or\n",
    "                # inverse base on top of it, or start one if the letter is unrelated.\n",
    "                if not last in base and not last in ibase:\n",
    "                    for key in [base[0], ibase[0]]:\n",
    "                        votes[key] = votes.get(key, 0) + 1\n",
    "                elif last in base:\n",
    "                    idx = base.index(last)\n",
    "                    key = base[(idx + 1) % len(base)]\n",
    "                    votes[key] = votes.get(key, 0) + 1\n",
    "                elif last in ibase:\n",
    "                    idx = ibase.index(last)\n",
    "                    key = ibase[(idx + 1) % len(ibase)]\n",
    "                    votes[key] = votes.get(key, 0) + 1\n",
    "            else:\n",
    "                # A letter left over from earlier sits on top: vote to cancel it.\n",
    "                votes[-last] = votes.get(-last, 0) + 1\n",
    "\n",
    "        # `votes` is non-empty: the loop condition guarantees a non-empty stack, and\n",
    "        # every non-empty stack casts at least one vote.\n",
    "        max_val  = max(votes.values())\n",
    "        max_key  = sample([key for key in votes.keys() if votes[key] == max_val], k = 1)[0]\n",
    "\n",
    "        generated.append(max_key)\n",
    "        \n",
    "        # Push the chosen letter onto every stack, applying the same reductions.\n",
    "        for i, (base, ibase) in enumerate(zip(bases, ibases)):\n",
    "            stacks[i] = remove_complementaries(remove_rotations(combine(stacks[i], max_key), base, ibase))\n",
    "    \n",
    "    \n",
    "    # Give up if some stack is still non-empty after `max_iterations` steps.\n",
    "    if not all(map(lambda v: len(v) == 0, stacks)):\n",
    "        return None\n",
    "    \n",
    "    return generated     "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a86c7a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy.random import randint\n",
    "from tqdm import trange\n",
    "\n",
    "number_generators = 4\n",
    "max_prefix_length = 10\n",
    "\n",
    "tries = 200\n",
    "max_iterations = 100\n",
    "\n",
    "num_samples = 20\n",
    "\n",
    "unique = set()\n",
    "\n",
    "# Collect `num_samples` distinct words. Each round draws a random non-empty, freely\n",
    "# reduced prefix (letters in [-n, n] without 0, fewer than max_prefix_length letters)\n",
    "# and gives the greedy search up to `tries` attempts to complete it. A round is\n",
    "# accepted only if it produced a word that has not been seen yet.\n",
    "# NOTE(review): if fewer than `num_samples` distinct words are reachable this never\n",
    "# terminates; consider capping the number of rounds.\n",
    "for _ in trange(num_samples):\n",
    "    while True:\n",
    "        prefix = None\n",
    "        while prefix is None or len(prefix) == 0:\n",
    "            prefix = randint(-number_generators, number_generators + 1, size = randint(max_prefix_length)).tolist()\n",
    "            prefix = [f for f in prefix if f != 0]\n",
    "            prefix = reduce(lambda x, y: remove_complementaries(combine(x, y)), prefix, [])\n",
    "\n",
    "        # Reset every round: a duplicate from the previous round would otherwise stay\n",
    "        # non-None, the retry loop below would be skipped and the outer loop would\n",
    "        # spin forever on the same duplicate.\n",
    "        generated = None\n",
    "        iteration = 0\n",
    "        while generated is None and iteration < tries:\n",
    "            generated = generate_from_intersection(prefix, number_generators, max_iterations)\n",
    "            iteration += 1\n",
    "\n",
    "        if generated is not None and tuple(generated) not in unique:\n",
    "            break\n",
    "    unique.add(tuple(generated))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7464d0f8",
   "metadata": {},
   "source": [
    "# Random Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d323c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from os import environ\n",
    "from typing import Callable\n",
    "\n",
    "@dataclass\n",
    "class TrainDatasetConfig:\n",
    "    \"\"\"Parameters for `train_dataset_fn` (random words built from normal closures).\"\"\"\n",
    "    # Number of free generators n; closures wu_closure(n, 0..n) are sampled.\n",
    "    freegroup_dimension: int = 4\n",
    "    # Sampling radius for the normal closure of wu_closure(n, 0).\n",
    "    max_zero_closure_depth: int = 10\n",
    "    # Sampling radius for the normal closures of wu_closure(n, 1..n).\n",
    "    max_non_zero_closure_depth: int = 30\n",
    "    # Bound for the number of extra free-group words mixed in (drawn as randint(0, this)).\n",
    "    max_freegroup_commutees: int = 1\n",
    "    # Sampling radius for each extra free-group word.\n",
    "    max_freegroup_commutee_length: int = 10\n",
    "    # Exclusive bounds on the length of an accepted word.\n",
    "    max_total_length: int = 200\n",
    "    min_total_length: int = 0\n",
    "    # First argument passed to `normal_closure` (its sampling method).\n",
    "    method: str = 'brackets'\n",
    "\n",
    "# Default configuration instance.\n",
    "train_dataset_config = TrainDatasetConfig()\n",
    "\n",
    "from freegroup.sampling import normal_closure, freegroup, random_tree\n",
    "from freegroup.tools import flatten, normalize, is_from_singleton_normal_closure, wu_closure, to_string\n",
    "from random import sample, randint, shuffle, choice\n",
    "from iteration_utilities import unique_everseen, repeatfunc\n",
    "from itertools import islice\n",
    "from utils import compute_multi_label\n",
    "\n",
    "def train_dataset_fn(config: TrainDatasetConfig):\n",
    "    \"\"\"Return an endless iterator of unique random dataset entries for `config`.\n",
    "\n",
    "    Each candidate word is built by sampling the normal closures of\n",
    "    wu_closure(n, idx) for every idx in 0..n except one randomly excluded index,\n",
    "    adding between 0 and `max_freegroup_commutees` random free-group words,\n",
    "    shuffling, combining them with `random_tree`, flattening and normalizing.\n",
    "    Each candidate gets a 'multi_label' from `compute_multi_label` and is kept only\n",
    "    if min_total_length < len(word) < max_total_length and its labels sum to n.\n",
    "    Entries whose word has the same `to_string` form as an earlier one are dropped.\n",
    "\n",
    "    Returns:\n",
    "        An infinite iterator of dicts with keys 'word' and 'multi_label'.\n",
    "    \"\"\"\n",
    "    # Bind the stdlib RNG explicitly: the Greedy cell rebinds the global name `randint`\n",
    "    # to numpy's version (exclusive upper bound, different signature), which would\n",
    "    # silently change this sampler if that cell is (re-)run after this one.\n",
    "    import random\n",
    "\n",
    "    fdim = config.freegroup_dimension\n",
    "    # Sampling radius per closure index: index 0 has its own depth, 1..n share one.\n",
    "    max_depths = [config.max_zero_closure_depth] + [config.max_non_zero_closure_depth] * fdim\n",
    "\n",
    "    def initial():\n",
    "        # random.randint is inclusive, so any index in 0..n may be left out.\n",
    "        exclude_idx = random.randint(0, fdim)\n",
    "        words = []\n",
    "\n",
    "        for idx in range(fdim + 1):\n",
    "            if idx != exclude_idx:\n",
    "                words.append(normal_closure(\n",
    "                    config.method,\n",
    "                    wu_closure(fdim, idx),\n",
    "                    fdim,\n",
    "                    'uniform',\n",
    "                    {'radius': max_depths[idx]},\n",
    "                ))\n",
    "\n",
    "        for _ in range(random.randint(0, config.max_freegroup_commutees)):\n",
    "            words.append(freegroup(\n",
    "                freegroup_dimension = fdim,\n",
    "                length_method = 'uniform',\n",
    "                length_parameters = {'radius': config.max_freegroup_commutee_length},\n",
    "            ))\n",
    "\n",
    "        random.shuffle(words)\n",
    "        return {'word': normalize(flatten(random_tree(words)))}\n",
    "\n",
    "    def add_multi_label(entry):\n",
    "        entry['multi_label'] = compute_multi_label(entry['word'], freegroup_dimension = fdim)\n",
    "        return entry\n",
    "\n",
    "    def condition(entry):\n",
    "        word, multi_label = entry['word'], entry['multi_label']\n",
    "        # Both length bounds are exclusive.\n",
    "        if not config.min_total_length < len(word) < config.max_total_length:\n",
    "            return False\n",
    "        # NOTE(review): presumably one label per closure (n + 1 in total); requiring a\n",
    "        # sum of n matches the one excluded closure above - confirm with compute_multi_label.\n",
    "        return sum(multi_label) == fdim\n",
    "\n",
    "    return unique_everseen(\n",
    "        filter(condition, map(add_multi_label, repeatfunc(initial))),\n",
    "        key=lambda entry: to_string(entry['word']),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9043e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import sample_dataset\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from itertools import islice\n",
    "from scipy.stats import ks_2samp\n",
    "from utils import sample_dataset, completion_ratio\n",
    "from tqdm.notebook import trange, tqdm\n",
    "from numpy import mean, std\n",
    "from random import seed as rseed\n",
    "\n",
    "# Seed for Python's `random` module, applied before sampling each dataset so every\n",
    "# dataset is reproducible on its own (train_dataset_fn draws from `random`).\n",
    "# NOTE(review): if the freegroup samplers use another RNG (e.g. numpy), seed it too.\n",
    "DATASET_SEED = 0\n",
    "DATASET_SIZE = 5000\n",
    "\n",
    "experiments = []\n",
    "\n",
    "for method, fdim, max_depth_0, max_depth_1, max_length in [\n",
    "    ('brackets', 3, 10, 30, 200),\n",
    "    ('brackets', 4, 10, 30, 400),\n",
    "    ('brackets', 5, 7, 30, 600)\n",
    "]:\n",
    "    experiments.append(TrainDatasetConfig(\n",
    "        method = method,\n",
    "        freegroup_dimension = fdim,\n",
    "        max_zero_closure_depth = max_depth_0,\n",
    "        max_non_zero_closure_depth = max_depth_1,\n",
    "        max_total_length = max_length,\n",
    "    ))\n",
    "\n",
    "results = {\n",
    "    k: [] for k in ['fdim', 'dataset']\n",
    "}\n",
    "\n",
    "for config in tqdm(experiments):\n",
    "    rseed(DATASET_SEED)\n",
    "    results['fdim'].append(config.freegroup_dimension)\n",
    "    results['dataset'].append(sample_dataset(train_dataset_fn, size = DATASET_SIZE, config = config))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab48f5ee",
   "metadata": {},
   "source": [
    "# Evolutionary Method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90609ee4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import random\n",
    "import warnings\n",
    "\n",
    "from freegroup.tools import (\n",
    "    reduce_modulo_singleton_normal_closure, normalize, to_string, is_from_singleton_normal_closure,\n",
    "    wu_closure\n",
    ")\n",
    "from freegroup.sampling import (\n",
    "    freegroup_generator,\n",
    "    normal_closure_generator,\n",
    ")\n",
    "from tqdm.notebook import tqdm, trange\n",
    "\n",
    "\n",
    "def distance_to_singleton_normal_closure(word, closure, approximation=\"reduction\"):\n",
    "    \"\"\"Score how far `word` is from the normal closure of `closure`.\n",
    "\n",
    "    Only the \"reduction\" approximation exists: the length of `word` after\n",
    "    `reduce_modulo_singleton_normal_closure`. A total score of 0 is treated as\n",
    "    success by `optimize`. Any other `approximation` raises NotImplementedError.\n",
    "    \"\"\"\n",
    "    if approximation != \"reduction\":\n",
    "        raise NotImplementedError('unknown `approximation`')\n",
    "    reduced = reduce_modulo_singleton_normal_closure(word, closure)\n",
    "    return len(reduced)\n",
    "\n",
    "\n",
    "def better_base(word, fdim, previous_function, first=None):\n",
    "    \"\"\"Return True iff the summed closure distances of `word` are below `previous_function`.\n",
    "\n",
    "    Distances are accumulated one closure at a time and False is returned as soon\n",
    "    as the running total reaches `previous_function`, so the remaining\n",
    "    reductions are skipped for candidates that are already no better.\n",
    "\n",
    "    With a truthy `first`, closures 1..first are scored; otherwise closures\n",
    "    1..fdim are scored, followed by closure 0.\n",
    "    \"\"\"\n",
    "    if first:\n",
    "        closure_indices = range(1, first + 1)\n",
    "    else:\n",
    "        closure_indices = [*range(1, fdim + 1), 0]\n",
    "\n",
    "    running_total = 0\n",
    "    for closure_idx in closure_indices:\n",
    "        running_total += distance_to_singleton_normal_closure(word, wu_closure(fdim, closure_idx))\n",
    "        if running_total >= previous_function:\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n",
    "def dist_base(word, fdim, first=None):\n",
    "    \"\"\"Total closure distance of `word`: closures 1..first if `first` is truthy, else 0..fdim.\"\"\"\n",
    "    indices = range(1, first + 1) if first else range(fdim + 1)\n",
    "    return sum(\n",
    "        distance_to_singleton_normal_closure(word, wu_closure(fdim, idx))\n",
    "        for idx in indices\n",
    "    )\n",
    "\n",
    "\n",
    "def optimize(\n",
    "    word, dist, better, mutation_rate=0.1, generators_number=2, \n",
    "    max_iters=10, method='gemmate', fixed_size=False, verbose=True):\n",
    "    \"\"\"(1 + 1) evolutionary search for a word whose objective `dist` is 0.\n",
    "\n",
    "    Each iteration mutates the current word, freely reduces the mutant with\n",
    "    `normalize`, skips it if it became empty, and accepts it when\n",
    "    `better(normalized_mutant, current_value)` holds. The search stops as soon as\n",
    "    the objective is 0, including when the starting word already has value 0.\n",
    "\n",
    "    Args:\n",
    "        word: starting word (list of letters, -g being the inverse of g); not modified.\n",
    "        dist: objective to minimize; 0 means the target has been reached.\n",
    "        better: acceptance test, called as better(candidate, current_value).\n",
    "        mutation_rate: for 'gemmate', probability of inserting a letter (a letter is\n",
    "            deleted otherwise); for 'edit', per-position probability of substitution.\n",
    "        generators_number: n; inserted/substituted letters are drawn from +/-1..n.\n",
    "        max_iters: number of mutation attempts.\n",
    "        method: 'gemmate' (single insertion/deletion) or 'edit' (point substitutions).\n",
    "        fixed_size: keep the un-normalized mutant as the new state (meant for 'edit').\n",
    "        verbose: print progress messages.\n",
    "\n",
    "    Returns:\n",
    "        (normalize(best_word), reached), where `reached` is True iff the objective\n",
    "        of the last accepted word (or of the starting word) is 0.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `method` is unknown.\n",
    "    \"\"\"\n",
    "    # https://arxiv.org/pdf/1703.03334.pdf 3.2 (1 + 1) EA\n",
    "    # https://arxiv.org/pdf/1812.11061.pdf 2.2 (\\mu + \\lambda) EA\n",
    "\n",
    "    if method not in ('gemmate', 'edit'):\n",
    "        raise NotImplementedError('unknown `method`')\n",
    "    if method == 'gemmate' and fixed_size:\n",
    "        warnings.warn('gemmate mutation method is not compatible with `fixed_size` set to True')\n",
    "\n",
    "    # A sorted list rather than a set: random.sample() rejects sets since Python 3.11,\n",
    "    # and a sequence makes the draw independent of set iteration order.\n",
    "    generators = sorted(set(range(1, generators_number + 1)) | set(range(-generators_number, 0)))\n",
    "\n",
    "    def pick_letter(excluded):\n",
    "        # Uniform choice among the generators not in `excluded`.\n",
    "        return random.choice([g for g in generators if g not in excluded])\n",
    "\n",
    "    def mutate(word, method='gemmate'):\n",
    "        mutated_word = word.copy()\n",
    "        if method == 'gemmate':\n",
    "            # Position in [0, len(word)]; len(word) means 'append at the end'.\n",
    "            i = random.randint(0, len(word))\n",
    "            # An empty word can only grow, so it always takes the insertion branch.\n",
    "            if not word or random.random() < mutation_rate:\n",
    "                if not word:\n",
    "                    excluded = set()\n",
    "                elif i == len(word):\n",
    "                    excluded = {word[i - 1]}\n",
    "                else:\n",
    "                    # NOTE(review): for i == 0, word[i - 1] is the *last* letter (cyclic\n",
    "                    # neighbour), and equal letters rather than inverses are excluded;\n",
    "                    # kept as originally written, confirm the intent.\n",
    "                    excluded = {word[i], word[i - 1]}\n",
    "                mutated_word.insert(i, pick_letter(excluded))\n",
    "            else:\n",
    "                mutated_word.pop(min(i, len(word) - 1))\n",
    "        else:\n",
    "            for i in range(len(mutated_word)):\n",
    "                if random.random() < mutation_rate:\n",
    "                    mutated_word[i] = pick_letter({mutated_word[i]})\n",
    "        return mutated_word\n",
    "\n",
    "    current_function = dist(word)\n",
    "\n",
    "    if verbose:\n",
    "        print('INFO: optimization started')\n",
    "\n",
    "    for _ in range(max_iters):\n",
    "        if current_function == 0:\n",
    "            # Target reached (possibly by the starting word itself): stop mutating.\n",
    "            break\n",
    "\n",
    "        new_word = mutate(word, method)\n",
    "        normalized = normalize(new_word)\n",
    "\n",
    "        if len(normalized) == 0:\n",
    "            continue\n",
    "\n",
    "        if better(normalized, current_function):\n",
    "            word = (new_word if fixed_size else normalized).copy()\n",
    "            current_function = dist(normalized)\n",
    "\n",
    "            if verbose:\n",
    "                print(f'INFO: f value = {current_function}')\n",
    "                print(to_string(normalized, method = 'su'))\n",
    "\n",
    "    if verbose:\n",
    "        print(\n",
    "            'INFO: optimization finished,',\n",
    "            'reached intersection' if current_function == 0 else 'reached max_iters', '\\n'\n",
    "        )\n",
    "\n",
    "    return normalize(word), current_function == 0\n",
    "\n",
    "\n",
    "class EvolutionarySampler:\n",
    "    \"\"\"Iterator of words accepted by `condition` (membership in the targeted closures).\n",
    "\n",
    "    Candidates are drawn from a baseline generator. A candidate that already\n",
    "    satisfies `condition` is returned as is; otherwise, with probability\n",
    "    `exploration_rate`, it is passed to `optimize` and the result is returned if\n",
    "    the optimization succeeded. Rejected candidates are discarded.\n",
    "\n",
    "    Args:\n",
    "        generators_number: number of free generators n.\n",
    "        max_length: radius passed to the baseline sampler.\n",
    "        exploration_rate: probability of optimizing a non-member candidate;\n",
    "            None means always optimize.\n",
    "        baseline: \"free\" (free-group words), \"joint\" (normal closure of\n",
    "            wu_closure(n, 0)) or \"singleton\" (normal closure of wu_closure(n, 1)).\n",
    "        first: if truthy, only closures 1..first are targeted; otherwise 0..n.\n",
    "        **kwargs: forwarded to `optimize` (e.g. max_iters, mutation_rate, method).\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: for an unknown `baseline`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self, generators_number=2, max_length=10, \n",
    "        exploration_rate=None, baseline=\"free\", first=None, **kwargs):\n",
    "\n",
    "        self.generators_number = generators_number\n",
    "        self.max_length = max_length\n",
    "        self.exploration_rate = exploration_rate\n",
    "\n",
    "        if baseline == \"free\":\n",
    "            self.baseline_group = freegroup_generator(\n",
    "                generators_number, 'uh', {'radius': max_length})\n",
    "        elif baseline == \"joint\":\n",
    "            self.baseline_group = normal_closure_generator('conjugation', wu_closure(generators_number, 0),\n",
    "            generators_number, 'uh', {'radius': max_length}, 'uh', {'radius': 5})\n",
    "        elif baseline == \"singleton\":\n",
    "            self.baseline_group = normal_closure_generator('conjugation', wu_closure(generators_number, 1), \n",
    "            generators_number, 'uh', {'radius': max_length}, 'uh', {'radius': 5})\n",
    "        else:\n",
    "            raise NotImplementedError('unknown `baseline`')\n",
    "\n",
    "        # Every supported baseline shares the same objective, acceptance test and\n",
    "        # membership check over the targeted closure indices.\n",
    "        targets = range(1, first + 1) if first else range(0, generators_number + 1)\n",
    "        self.dist = lambda word: dist_base(word, generators_number, first=first)\n",
    "        self.better = lambda word, previous_function: better_base(word, generators_number, previous_function, first=first)\n",
    "        self.condition = lambda word: all(\n",
    "            is_from_singleton_normal_closure(word, wu_closure(generators_number, idx))\n",
    "            for idx in targets)\n",
    "\n",
    "        self.kwargs = kwargs\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        # Comparing a float with None raises TypeError, so the constructor default\n",
    "        # `exploration_rate=None` is interpreted as 'always optimize'.\n",
    "        rate = 1.0 if self.exploration_rate is None else self.exploration_rate\n",
    "        while True:\n",
    "            word = next(self.baseline_group)\n",
    "            if self.condition(word):\n",
    "                return word\n",
    "            if random.random() > rate:\n",
    "                continue\n",
    "            word, success = optimize(\n",
    "                word, self.dist, self.better, \n",
    "                generators_number=self.generators_number, **self.kwargs)\n",
    "            if success:\n",
    "                return word\n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "\n",
    "from copy import deepcopy\n",
    "# Shared base configuration; each run below overrides the group rank, the maximal\n",
    "# word length, the retry budget and the seed.\n",
    "default = dict(\n",
    "    num_tries=5,\n",
    "    num_samples=1000,\n",
    "    seed=0,\n",
    "    sampler=dict(\n",
    "        baseline='singleton',\n",
    "        generators_number=3,\n",
    "        max_length=60,\n",
    "        max_iters=400,\n",
    "        mutation_rate=0.8,\n",
    "        method='gemmate',\n",
    "        fixed_size=False,\n",
    "        first=None,\n",
    "        verbose=False,\n",
    "    ),\n",
    ")\n",
    "\n",
    "# (generators_number, max_length, num_tries) grid, each point repeated for three seeds.\n",
    "experiments = []\n",
    "for fdim, max_length, num_tries in [\n",
    "    (3, 60, 5), (3, 60, 10), (3, 60, 20),\n",
    "    (4, 100, 5), (4, 100, 10), (4, 100, 20),\n",
    "    (5, 200, 5), (5, 200, 10), (5, 200, 20),\n",
    "]:\n",
    "    for seed in (1, 10, 100):\n",
    "        kwargs = deepcopy(default)\n",
    "        kwargs.update(num_tries=num_tries, seed=seed)\n",
    "        kwargs['sampler'].update(generators_number=fdim, max_length=max_length)\n",
    "        experiments.append(kwargs)\n",
    "        \n",
    "def run_experiment(config):\n",
    "    \"\"\"Estimate how often `optimize` completes words drawn from the baseline sampler.\n",
    "\n",
    "    Seeds the RNGs with `config['seed']`, builds an `EvolutionarySampler` from\n",
    "    `config['sampler']`, draws `config['num_samples']` baseline words and runs\n",
    "    `optimize` on each up to `config['num_tries']` times (every retry continues\n",
    "    from the previous result), counting the words that reach the target.\n",
    "\n",
    "    Returns:\n",
    "        A deep copy of `config` with an added 'completion_ratio' entry.\n",
    "    \"\"\"\n",
    "    import numpy as np\n",
    "\n",
    "    # Seed explicitly so the per-seed replicates in `experiments` are reproducible.\n",
    "    # NOTE(review): `optimize` draws from Python's `random`; numpy's global RNG is seeded\n",
    "    # too in case the freegroup samplers use it - confirm which RNG they rely on.\n",
    "    random.seed(config['seed'])\n",
    "    np.random.seed(config['seed'])\n",
    "\n",
    "    sampler = EvolutionarySampler(**config['sampler'])\n",
    "    completed = 0\n",
    "    for _sample in range(config['num_samples']):\n",
    "        word = next(sampler.baseline_group)\n",
    "        for _attempt in range(config['num_tries']):\n",
    "            word, success = optimize(\n",
    "                word, sampler.dist, sampler.better, \n",
    "                generators_number=sampler.generators_number, **sampler.kwargs)\n",
    "            if success:\n",
    "                completed += 1\n",
    "                break\n",
    "    result = deepcopy(config)\n",
    "    result['completion_ratio'] = completed / config['num_samples']\n",
    "    return result\n",
    "\n",
    "from multiprocess import Pool\n",
    "\n",
    "# Five worker processes; imap_unordered yields results in completion order,\n",
    "# so `results` is not aligned with `experiments`.\n",
    "results = []\n",
    "with Pool(5) as pool:\n",
    "    results.extend(tqdm(pool.imap_unordered(run_experiment, experiments), total = len(experiments)))\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
