{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries & Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from layers import CompleteLayer\n",
    "from inits import Size, Like\n",
    "from inits import (\n",
    "    RandomNormal,\n",
    "    RandomUniform,\n",
    "    Ones,\n",
    "    Zeros,\n",
    "    Triu\n",
    ")\n",
    "from pruning import PruneEnsemble\n",
    "from pruning import (\n",
    "    NoPrune,\n",
    "    RandomPrune,\n",
    "    TopKPrune,\n",
    "    DynamicTopK,\n",
    "    ThresholdPrune,\n",
    "    TrilPrune,\n",
    "    TrilDamp,\n",
    "    DynamicTrilDamp\n",
    ")\n",
    "import data\n",
    "import losses\n",
    "import experiments\n",
    "from training import train\n",
    "from evals import (\n",
    "    EvalVisualiser,\n",
    "    LineVisualiser,\n",
    "    BoxVisualiser,\n",
    "    WeightVisualiser,\n",
    "    OrderednessVisualiser\n",
    ")\n",
    "from utils import permute, brute_force_orderedness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the compute device: start on the CPU and switch to CUDA only after a\n",
    "# probe allocation on the GPU has actually succeeded.\n",
    "device = torch.device('cpu')\n",
    "if torch.cuda.is_available():\n",
    "    try:\n",
    "        # Tiny probe tensor; if creating it raises, the CPU default is kept.\n",
    "        torch.tensor([0], device='cuda')\n",
    "        device = torch.device('cuda')\n",
    "    except Exception as err:\n",
    "        # Catch `Exception`, not everything: KeyboardInterrupt and SystemExit\n",
    "        # must still propagate so the kernel can be interrupted.\n",
    "        print(f'CUDA probe failed ({err!r}); falling back to CPU')\n",
    "\n",
    "print(f'Using device: {device}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Complete Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment-wide settings, handed to experiments.run as `seed` and `tries`.\n",
    "SEED = 3_141_592  # random seed for reproducibility\n",
    "NUM_TRIES = 10    # number of times each experiment is run, for reliability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialisers and pruning strategy shared by every task in the loop below.\n",
    "WEIGHTS_INIT = RandomNormal()\n",
    "VALUES_INIT = RandomNormal()\n",
    "# A factory rather than an instance: every call builds a new RandomPrune(p=0.5).\n",
    "WEIGHTS_PRUNE = lambda: RandomPrune(p=0.5)\n",
    "\n",
    "for task in ('none', 'xor', 'sine'):\n",
    "\n",
    "    generator = data.TaskGenerator(task, device)\n",
    "    dataloader = generator.dataloader\n",
    "    params = generator.params\n",
    "    # 'none' is the only task that is run with trainable=False.\n",
    "    trainable = task != 'none'\n",
    "\n",
    "    def setup():\n",
    "        '''Build a fresh CompleteLayer and its Adam optimiser for the current task.\n",
    "\n",
    "        Sizes and learning rate are read from the enclosing loop's `params`,\n",
    "        so this closure must be called before the loop moves on to the next\n",
    "        task (presumably experiments.run calls it during the run - TODO confirm).\n",
    "\n",
    "        Returns:\n",
    "            dict with the layer under 'model' and the optimiser under 'optimiser'.\n",
    "        '''\n",
    "        complete = CompleteLayer(\n",
    "            input_size=params['input_size'],\n",
    "            hidden_size=params['hidden_size'],\n",
    "            output_size=params['output_size'],\n",
    "            # NOTE(review): the meaning of the second tuple element (True) is\n",
    "            # defined by CompleteLayer and is not visible here - confirm.\n",
    "            values_init=(VALUES_INIT, True),\n",
    "            weights_init=(WEIGHTS_INIT, True),\n",
    "            activation=F.sigmoid,\n",
    "            use_bias=False\n",
    "        ).to(device)\n",
    "        optim = torch.optim.Adam(\n",
    "            complete.parameters(),\n",
    "            lr=params['complete_lr']\n",
    "        )\n",
    "        return {\n",
    "            'model': complete,\n",
    "            'optimiser': optim\n",
    "        }\n",
    "\n",
    "    visualisers, result = experiments.run(\n",
    "        its=params['its'],\n",
    "        track_orderedness=False,\n",
    "        pruner=PruneEnsemble({\n",
    "            'values': NoPrune(),\n",
    "            'weights': WEIGHTS_PRUNE(),\n",
    "        }),\n",
    "        visualisers={\n",
    "            # Both visualisers are created with show=False; only their\n",
    "            # mean_x / std_x attributes are read further down.\n",
    "            'start': BoxVisualiser(\n",
    "                lambda r: r['start_orderedness'],\n",
    "                show=False\n",
    "            ),\n",
    "            'delta': BoxVisualiser(\n",
    "                lambda r: r['delta_final'],\n",
    "                show=False\n",
    "            ),\n",
    "        },\n",
    "        seed=SEED,\n",
    "        tries=NUM_TRIES,\n",
    "        n_epochs=params['complete_epochs'],\n",
    "        setup_fn=setup,\n",
    "        train_dataloader=dataloader,\n",
    "        train_criterion=losses.MSELoss(),\n",
    "        trainable=trainable\n",
    "    )\n",
    "\n",
    "    # Look the visualisers up before formatting so the f-strings contain no\n",
    "    # nested quotes; reusing the same quote type inside an f-string only\n",
    "    # parses on Python >= 3.12.\n",
    "    if task == 'none':\n",
    "        start_stats = visualisers['start']\n",
    "        print(f'Init: {start_stats.mean_x:.3f} $\\\\pm$ {start_stats.std_x:.3f}')\n",
    "    delta_stats = visualisers['delta']\n",
    "    print(f'{task}: {delta_stats.mean_x:.3f} $\\\\pm$ {delta_stats.std_x:.3f}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml13",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
