{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Libraries & Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from layers import CompleteLayer\n",
    "from inits import Size, Like\n",
    "from inits import (\n",
    "    RandomNormal,\n",
    "    RandomUniform,\n",
    "    Ones,\n",
    "    Zeros,\n",
    "    Triu\n",
    ")\n",
    "from pruning import PruneEnsemble\n",
    "from pruning import (\n",
    "    NoPrune,\n",
    "    RandomPrune,\n",
    "    TopKPrune,\n",
    "    DynamicTopK,\n",
    "    ThresholdPrune,\n",
    "    TrilPrune,\n",
    "    TrilDamp,\n",
    "    DynamicTrilDamp\n",
    ")\n",
    "import data\n",
    "import losses\n",
    "import experiments\n",
    "from training import train\n",
    "from evals import (\n",
    "    NullVisualiser,\n",
    "    LineVisualiser,\n",
    "    BoxVisualiser,\n",
    "    WeightVisualiser,\n",
    "    OrderednessVisualiser\n",
    ")\n",
    "from utils import permute, brute_force_orderedness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU, but only if it can actually allocate a tensor: CUDA may be\n",
    "# reported as available yet still fail on first use. Otherwise use the CPU.\n",
    "device = torch.device('cpu')\n",
    "if torch.cuda.is_available():\n",
    "    try:\n",
    "        _ = torch.tensor([0], device='cuda')\n",
    "        device = torch.device('cuda')\n",
    "    except Exception as err:  # not a bare `except:`, so Ctrl-C still interrupts\n",
    "        print(f'CUDA reported available but unusable ({err}); using CPU instead.')\n",
    "\n",
    "print(f'Using device: {device}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate Tiny Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "TASK_NAME = 'xor'  # one of 'xor', 'sine', 'none'\n",
    "OUT_DIR = Path('../media')\n",
    "FNAME = f'{TASK_NAME}.pdf'\n",
    "\n",
    "# Every figure in this notebook is written into one of these sub-directories;\n",
    "# create them up-front so saving does not fail on a fresh checkout.\n",
    "for subdir in ('dynamics', 'change', 'weights', 'hi', 'so'):\n",
    "    (OUT_DIR / subdir).mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "print(f'Running task: {TASK_NAME}.')\n",
    "\n",
    "# The generator provides the task's dataloader and the hyper-parameters\n",
    "# (`params`: layer sizes, learning rates, epochs, ...) read by later cells.\n",
    "generator = data.TaskGenerator(TASK_NAME, device)\n",
    "dataloader = generator.dataloader\n",
    "params = generator.params"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple Baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every experiment below is seeded with SEED and repeated NUM_TRIES times,\n",
    "# so reported statistics are aggregated over several runs.\n",
    "SEED = 3_141_592\n",
    "NUM_TRIES = 10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def baseline_setup():\n",
    "    \"\"\"Build a fresh MLP baseline with its MSE training loss and Adam optimiser.\n",
    "\n",
    "    Passed to `experiments.run` as `setup_fn`; returns the components\n",
    "    (`model`, `train_criterion`, `optimiser`) that it consumes.\n",
    "    \"\"\"\n",
    "    baseline = generator.get_mlp_baseline()\n",
    "    criterion = losses.MSELoss()\n",
    "    optim = torch.optim.Adam(\n",
    "        baseline.parameters(),\n",
    "        lr=params['baseline_lr']\n",
    "    )\n",
    "    return {\n",
    "        'model': baseline,\n",
    "        'train_criterion': criterion,\n",
    "        'optimiser': optim\n",
    "    }\n",
    "\n",
    "# Like the other `experiments.run` calls in this notebook, unpack the\n",
    "# `(visualisers, result)` pair. Baseline-specific names keep this cell from\n",
    "# overwriting the `visualisers` that the complete-network cells rely on.\n",
    "baseline_visualisers, baseline_result = experiments.run(\n",
    "    tries=NUM_TRIES,\n",
    "    seed=SEED,\n",
    "    track_orderedness=False,\n",
    "    setup_fn=baseline_setup,\n",
    "    visualisers={\n",
    "        'train': LineVisualiser(\n",
    "            lambda r: r['train_losses'],\n",
    "            xlabel='Step',\n",
    "            ylabel='Train Loss',\n",
    "            only_values=True,\n",
    "            fname=OUT_DIR/'dynamics'/f'{TASK_NAME}-mlp.pdf'\n",
    "        )\n",
    "    },\n",
    "    n_epochs=params['baseline_epochs'],\n",
    "    train_dataloader=dataloader,\n",
    "    early_stop=3e-3,\n",
    "    trainable=TASK_NAME != 'none'  # nothing is trained for the 'none' task\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Complete Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def setup():\n",
    "    \"\"\"Build a freshly initialised complete layer together with its Adam optimiser.\"\"\"\n",
    "    model = CompleteLayer(\n",
    "        input_size=params['input_size'],\n",
    "        hidden_size=params['hidden_size'],\n",
    "        output_size=params['output_size'],\n",
    "        values_init=(RandomNormal(), True),\n",
    "        weights_init=(RandomNormal(), True),\n",
    "        activation=F.sigmoid,\n",
    "        use_bias=False,\n",
    "    ).to(device)\n",
    "    return {\n",
    "        'model': model,\n",
    "        'optimiser': torch.optim.Adam(model.parameters(), lr=params['complete_lr']),\n",
    "    }\n",
    "\n",
    "\n",
    "# Orderedness is tracked during training only for these tasks.\n",
    "TRACK_ORDEREDNESS = TASK_NAME in ('none', 'xor')\n",
    "\n",
    "# Values stay dense; weights are pruned with a dynamic top-k (k=0.5).\n",
    "complete_pruner = PruneEnsemble({\n",
    "    'values': NoPrune(),\n",
    "    'weights': DynamicTopK(k=0.5),\n",
    "})\n",
    "\n",
    "complete_visualisers = {\n",
    "    'orderedness-absolute': OrderednessVisualiser(\n",
    "        lambda r: r['model'].weights,\n",
    "        name='weights',\n",
    "        graphs=False\n",
    "    ),\n",
    "    'orderedness-change': BoxVisualiser(\n",
    "        lambda r: r['delta_final'],\n",
    "        name='Change in orderedness',\n",
    "        ylabel='Change in orderedness',\n",
    "        fname=OUT_DIR/'change'/f'{TASK_NAME}-box.pdf'\n",
    "    ),\n",
    "    # Plot the per-step curve only when orderedness is being tracked.\n",
    "    'orderedness-steps': (\n",
    "        LineVisualiser(\n",
    "            lambda r: r['delta_steps'],\n",
    "            xlabel='Step',\n",
    "            ylabel='Change in orderedness',\n",
    "            only_values=True,\n",
    "            fname=OUT_DIR/'change'/f'{TASK_NAME}-curve.pdf'\n",
    "        )\n",
    "        if TRACK_ORDEREDNESS\n",
    "        else NullVisualiser()\n",
    "    ),\n",
    "    'train': LineVisualiser(\n",
    "        lambda r: r['train_losses'],\n",
    "        xlabel='Step',\n",
    "        ylabel='Train Loss',\n",
    "        only_values=True,\n",
    "        fname=OUT_DIR/'dynamics'/f'{TASK_NAME}-clp.pdf'\n",
    "    ),\n",
    "    'weights': WeightVisualiser(\n",
    "        lambda r: r['model'].weights,\n",
    "        name='weights',\n",
    "        show=['sample']\n",
    "    ),\n",
    "}\n",
    "\n",
    "visualisers, result = experiments.run(\n",
    "    pruner=complete_pruner,\n",
    "    visualisers=complete_visualisers,\n",
    "    seed=SEED,\n",
    "    tries=NUM_TRIES,\n",
    "    its=params['its'],\n",
    "    track_orderedness=TRACK_ORDEREDNESS,\n",
    "    n_epochs=params['complete_epochs'],\n",
    "    setup_fn=setup,\n",
    "    train_dataloader=dataloader,\n",
    "    train_criterion=losses.MSELoss(),\n",
    "    trainable=TASK_NAME != 'none'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first recorded weight matrix. Its last `input_size` columns are\n",
    "# kept in place (inputs); the remaining square block is searched for the\n",
    "# ordering with the highest orderedness.\n",
    "sample = visualisers['weights'].all_weights[0]\n",
    "square = sample[:, :-params['input_size']]\n",
    "orderedness, perm = brute_force_orderedness(\n",
    "    square, fixed_size=params['output_size']\n",
    ")\n",
    "print(f'Orderedness: {orderedness:.3f}')\n",
    "\n",
    "# Column order: `perm` followed by the untouched input columns. `list(perm)`\n",
    "# makes the concatenation correct whether `perm` is a list, tuple or NumPy\n",
    "# array (with an array, `+` would add element-wise instead of appending).\n",
    "col_perm = list(perm) + [len(perm) + i for i in range(params['input_size'])]\n",
    "\n",
    "# Draw on a fresh figure/axes rather than whatever axes happens to be current.\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(permute(sample, perm, col_perm), annot=True, cmap='viridis', ax=ax)\n",
    "ax.set_title(f'Orderedness of weights: {orderedness:.3f}')\n",
    "fig.tight_layout()\n",
    "fig.savefig(OUT_DIR/'weights'/FNAME, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Hidden units vs. iterations: change in orderedness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the hidden-unit count against the `its` setting and record, for each\n",
    "# combination, the mean and standard deviation (over NUM_TRIES runs) of the\n",
    "# final change in orderedness.\n",
    "its = list(range(1, 11))\n",
    "units = list(range(1, 9, 2))\n",
    "\n",
    "ui_means = np.zeros((len(units), len(its)))\n",
    "ui_stds = np.zeros((len(units), len(its)))\n",
    "\n",
    "# The context manager closes the progress bar even if a run raises.\n",
    "with tqdm(total=len(units) * len(its)) as pbar:\n",
    "    for u_idx, u in enumerate(units):\n",
    "        for i_idx, i in enumerate(its):\n",
    "\n",
    "            # `hidden_size=u` binds the current loop value when `_setup` is\n",
    "            # defined, instead of reading `u` whenever it happens to be called.\n",
    "            def _setup(hidden_size=u):\n",
    "                complete = CompleteLayer(\n",
    "                    input_size=params['input_size'],\n",
    "                    hidden_size=hidden_size,\n",
    "                    output_size=params['output_size'],\n",
    "                    values_init=(RandomNormal(), True),\n",
    "                    weights_init=(RandomNormal(), True),\n",
    "                    activation=F.sigmoid,\n",
    "                    use_bias=False\n",
    "                ).to(device)\n",
    "                optim = torch.optim.Adam(\n",
    "                    complete.parameters(),\n",
    "                    lr=params['complete_lr']\n",
    "                )\n",
    "                return {\n",
    "                    'model': complete,\n",
    "                    'optimiser': optim\n",
    "                }\n",
    "\n",
    "            # A distinct name so this sweep does not overwrite the\n",
    "            # `visualisers` / `result` produced by the complete-network cell.\n",
    "            hi_visualisers, _ = experiments.run(\n",
    "                its=i,\n",
    "                pruner=PruneEnsemble({\n",
    "                    'values': NoPrune(),\n",
    "                    'weights': DynamicTopK(k=0.5),\n",
    "                }),\n",
    "                visualisers={\n",
    "                    'delta-final': BoxVisualiser(\n",
    "                        lambda r: r['delta_final'],\n",
    "                        show=False\n",
    "                    ),\n",
    "                },\n",
    "                seed=SEED,\n",
    "                tries=NUM_TRIES,\n",
    "                track_orderedness=False,\n",
    "                n_epochs=params['complete_epochs'],\n",
    "                setup_fn=_setup,\n",
    "                train_dataloader=dataloader,\n",
    "                train_criterion=losses.MSELoss(),\n",
    "                early_stop=0,\n",
    "                trainable=TASK_NAME != 'none'\n",
    "            )\n",
    "\n",
    "            ui_means[u_idx, i_idx] = hi_visualisers['delta-final'].mean_x\n",
    "            ui_stds[u_idx, i_idx] = hi_visualisers['delta-final'].std_x\n",
    "            pbar.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw on a fresh figure/axes so the heatmap never lands on a figure left open\n",
    "# by an earlier cell. Rows are flipped so the largest hidden-unit count is at\n",
    "# the top.\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(\n",
    "    ui_means[::-1],\n",
    "    annot=True,\n",
    "    yticklabels=units[::-1],\n",
    "    xticklabels=its,\n",
    "    cmap='crest',\n",
    "    ax=ax\n",
    ")\n",
    "\n",
    "# The title is printed instead of drawn, so the saved PDF carries no title.\n",
    "print(f'Change in Orderedness - Mean ({TASK_NAME})')\n",
    "ax.set_xlabel('Number of iterations')\n",
    "ax.set_ylabel('Number of hidden units')\n",
    "fig.tight_layout()\n",
    "fig.savefig(OUT_DIR/'hi'/f'{TASK_NAME}-mean.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw on a fresh figure/axes so the heatmap never lands on a figure left open\n",
    "# by an earlier cell. Rows are flipped so the largest hidden-unit count is at\n",
    "# the top.\n",
    "fig, ax = plt.subplots()\n",
    "sns.heatmap(\n",
    "    ui_stds[::-1],\n",
    "    annot=True,\n",
    "    yticklabels=units[::-1],\n",
    "    xticklabels=its,\n",
    "    cmap='crest',\n",
    "    ax=ax\n",
    ")\n",
    "\n",
    "# The title is printed instead of drawn, so the saved PDF carries no title.\n",
    "print(f'Change in Orderedness - Stddev ({TASK_NAME})')\n",
    "ax.set_xlabel('Number of iterations')\n",
    "ax.set_ylabel('Number of hidden units')\n",
    "fig.tight_layout()\n",
    "fig.savefig(OUT_DIR/'hi'/f'{TASK_NAME}-std.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pruning sparsity vs. final orderedness"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep the pruning sparsity for several weight pruners and record summary\n",
    "# statistics (over NUM_TRIES runs) of the final orderedness.\n",
    "#\n",
    "# Build the grid from integers and round it: `np.arange(0, 1, 0.1)` accumulates\n",
    "# float error (e.g. 0.30000000000000004), which would be fed to the pruners.\n",
    "sparsity_vals = [round(0.1 * step, 1) for step in range(10)]\n",
    "\n",
    "# name -> (factory mapping a sparsity level to a weight pruner,\n",
    "#          legend label for how sparsity maps onto that pruner's parameter)\n",
    "pruners = {\n",
    "    'Random': (lambda s: RandomPrune(p=s), 'p'),\n",
    "    'Top-K': (lambda s: TopKPrune(k=1-s), '1-k'),\n",
    "    'Dyn. Top-K': (lambda s: DynamicTopK(k=1-s), '1-k'),\n",
    "    'Tril-damp': (lambda s: TrilDamp(f=s), 'f'),\n",
    "    'Dyn. Tril-damp': (lambda s: DynamicTrilDamp(f=s), 'f')\n",
    "}\n",
    "\n",
    "so_means = {p: [] for p in pruners}\n",
    "so_stds = {p: [] for p in pruners}\n",
    "so_maxs = {p: [] for p in pruners}\n",
    "so_mins = {p: [] for p in pruners}\n",
    "\n",
    "\n",
    "def _setup():\n",
    "    \"\"\"Build a fresh complete layer and its Adam optimiser (independent of the sweep).\"\"\"\n",
    "    complete = CompleteLayer(\n",
    "        input_size=params['input_size'],\n",
    "        hidden_size=params['hidden_size'],\n",
    "        output_size=params['output_size'],\n",
    "        values_init=(RandomNormal(), True),\n",
    "        weights_init=(RandomNormal(), True),\n",
    "        activation=F.sigmoid,\n",
    "        use_bias=False\n",
    "    ).to(device)\n",
    "    optim = torch.optim.Adam(\n",
    "        complete.parameters(),\n",
    "        lr=params['complete_lr']\n",
    "    )\n",
    "    return {\n",
    "        'model': complete,\n",
    "        'optimiser': optim\n",
    "    }\n",
    "\n",
    "\n",
    "# The context manager closes the progress bar even if a run raises.\n",
    "with tqdm(total=len(sparsity_vals) * len(pruners)) as pbar:\n",
    "    for pn in pruners:\n",
    "        for v in sparsity_vals:\n",
    "            # A distinct name so this sweep does not overwrite the\n",
    "            # `visualisers` / `result` produced by the complete-network cell.\n",
    "            so_visualisers, _ = experiments.run(\n",
    "                pruner=PruneEnsemble({\n",
    "                    'values': NoPrune(),\n",
    "                    'weights': pruners[pn][0](v),\n",
    "                }),\n",
    "                visualisers={\n",
    "                    'final': BoxVisualiser(\n",
    "                        lambda r: r['final_orderedness'],\n",
    "                        show=False\n",
    "                    ),\n",
    "                },\n",
    "                seed=SEED,\n",
    "                tries=NUM_TRIES,\n",
    "                its=params['its'],\n",
    "                track_orderedness=False,\n",
    "                n_epochs=params['complete_epochs'],\n",
    "                setup_fn=_setup,\n",
    "                train_dataloader=dataloader,\n",
    "                train_criterion=losses.MSELoss(),\n",
    "                early_stop=0,\n",
    "                trainable=TASK_NAME != 'none'\n",
    "            )\n",
    "\n",
    "            final = so_visualisers['final']\n",
    "            so_means[pn].append(final.mean_x)\n",
    "            so_stds[pn].append(final.std_x)\n",
    "            so_maxs[pn].append(final.max_x)\n",
    "            so_mins[pn].append(final.min_x)\n",
    "            pbar.update(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per pruner (mean final orderedness across tries) plus a shaded band\n",
    "# spanning the min-max range, explicitly drawn in the same colour as its line.\n",
    "fig, ax = plt.subplots(figsize=(12, 7))\n",
    "\n",
    "for pn in pruners:\n",
    "    (line,) = ax.plot(sparsity_vals, so_means[pn], label=pn + f' ({pruners[pn][1]})')\n",
    "    ax.fill_between(\n",
    "        sparsity_vals, so_maxs[pn], so_mins[pn],\n",
    "        facecolor=line.get_color(), alpha=0.2\n",
    "    )\n",
    "\n",
    "# The title is printed instead of drawn, so the saved PDF carries no title.\n",
    "print(f'Relationship Between Pruning Sparsity and Orderedness ({TASK_NAME})')\n",
    "ax.legend()\n",
    "ax.set_xlabel('Pruning sparsity')\n",
    "ax.set_ylabel('Orderedness')\n",
    "fig.tight_layout()\n",
    "fig.savefig(OUT_DIR/'so'/FNAME, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml13",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
