{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19a3e56-6c87-43c7-b6b9-626583af1fe0",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96132a58-74c9-4f6d-a8ff-dd265c4631d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import os\n",
    "import shutil\n",
    "import torch\n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "import metrics\n",
    "from scipy.stats import pearsonr\n",
    "\n",
    "def check_modular(points, num_thetas=500):\n",
    "    # expects points to be : dim X N\n",
    "    # Given a set of points generate all the lines and test whether the inequalities are satisfied\n",
    "    points_demeaned = points - np.mean(points, axis=1, keepdims=True)\n",
    "    extreme_points = np.min(np.abs(np.array([np.min(points_demeaned, axis=1), np.max(points_demeaned, axis=1)])),\n",
    "                            axis=0)\n",
    "    corr = np.mean(np.multiply(points_demeaned[0, :], points_demeaned[1, :]))\n",
    "    S = np.array([[extreme_points[0] ** 2, -corr], [-corr, extreme_points[1] ** 2]])\n",
    "    flag = 0\n",
    "\n",
    "    thetas = np.arange(0, np.pi * 2, 2 * np.pi / num_thetas) + 0.01\n",
    "    diffs = []\n",
    "    for theta in thetas:\n",
    "        w = np.array([np.cos(theta), np.sin(theta)])\n",
    "        crit_value = w.T @ S @ w\n",
    "        diff = np.min(w @ points_demeaned) ** 2 - crit_value\n",
    "        diffs.append(diff)\n",
    "        if diff < 0:\n",
    "            flag = 1\n",
    "    return {'mixed': flag,\n",
    "            'S': S,\n",
    "            'ds_demean': points_demeaned,\n",
    "            'diffs': diffs,\n",
    "            }\n",
    "\n",
    "N = 32\n",
    "d = 2\n",
    "n_repeats = 100\n",
    "n_epochs = 1000000\n",
    "batch_size = 32\n",
    "latent_dim = 16\n",
    "w_reg = 1e-4\n",
    "z_reg = 1e-3\n",
    "z_nn = 5e-1\n",
    "\n",
    "\n",
    "class AE(torch.nn.Module):\n",
    "    def __init__(self, input_dim, latent_dim):\n",
    "        super(AE, self).__init__()\n",
    "\n",
    "        self.linear1 = torch.nn.Linear(input_dim, latent_dim)\n",
    "        # self.activation = torch.nn.ReLU()\n",
    "        self.linear2 = torch.nn.Linear(latent_dim, input_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.linear1(x)\n",
    "        x_hat = self.linear2(z)\n",
    "        return x_hat, z\n",
    "\n",
    "\n",
    "def most_mixed_neuron(model):\n",
    "    a, b = torch.abs(model.linear1.weight).detach().numpy().T\n",
    "    keep = a + b > np.max(a + b) / 10\n",
    "    mixed = (np.minimum(a, b) / (a + b))[keep]\n",
    "    # return angle (from 0 of 90) in 1st quadrant:\n",
    "    angle = np.arctan(a / b)[keep]\n",
    "    angle = np.minimum(angle, np.pi / 2 - angle)\n",
    "\n",
    "    return {'av_mixed': np.mean(mixed),\n",
    "            'most_mixed': np.max(mixed),\n",
    "            'most_angle': np.max(angle),\n",
    "            'av_angle': np.mean(angle),\n",
    "            }\n",
    "\n",
    "# create your datset\n",
    "results = {'mod':\n",
    "               {'av_mixed': [],\n",
    "                'most_mixed': [],\n",
    "                'most_angle': [],\n",
    "                'av_angle': [],\n",
    "                'diffs': [],\n",
    "                'multiinfo': [],\n",
    "                'lcinfom': [],\n",
    "                'pred_loss': [],\n",
    "                'z_loss': [],\n",
    "                'nn_loss': [],\n",
    "                'weight_loss': []\n",
    "                },\n",
    "           'mix':\n",
    "               {'av_mixed': [],\n",
    "                'most_mixed': [],\n",
    "                'most_angle': [],\n",
    "                'av_angle': [],\n",
    "                'diffs': [],\n",
    "                'multiinfo': [],\n",
    "                'lcinfom': [],\n",
    "                'pred_loss': [],\n",
    "                'z_loss': [],\n",
    "                'nn_loss': [],\n",
    "                'weight_loss': []\n",
    "                },\n",
    "           }\n",
    "\n",
    "for repeat in range(n_repeats): # you will want to parralelise this...\n",
    "    if repeat % 2 == 0:\n",
    "        mixed = 1\n",
    "        while mixed == 1:\n",
    "            dataset = np.random.rand(N, d).astype(np.float32)\n",
    "            mod_res = check_modular(dataset.T)\n",
    "            mixed = mod_res['mixed']\n",
    "    else:\n",
    "        mixed = 0\n",
    "        while mixed == 0:\n",
    "            dataset = np.random.rand(N, d).astype(np.float32)\n",
    "            mod_res = check_modular(dataset.T)\n",
    "            mixed = mod_res['mixed']\n",
    "\n",
    "    sources = dataset\n",
    "    corr_h = pearsonr(sources[:, 0], sources[:, 1])[0]\n",
    "    sources = metrics.discretize_binning(sources, bins='auto')\n",
    "    mi_h = metrics.normalized_multiinformation(sources)\n",
    "\n",
    "    msg = \"repeat={:.2f}, mixed={:.2f}, multi_info={:.2f}\".format(repeat, mixed, mi_h)\n",
    "    print(msg)\n",
    "\n",
    "    results_run = {'av_mixed': [],\n",
    "                   'most_mixed': [],\n",
    "                   'most_angle': [],\n",
    "                   'av_angle': [],\n",
    "                   'diffs': mod_res['diffs'],\n",
    "                   'multiinfo': mi_h,\n",
    "                   'lcinfom': [],\n",
    "                   'pred_loss': [],\n",
    "                   'z_loss': [],\n",
    "                   'nn_loss': [],\n",
    "                   'weight_loss': []\n",
    "                   }\n",
    "\n",
    "    my_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
    "    loss = torch.nn.MSELoss()\n",
    "    model = AE(d, latent_dim)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "    for epoch in range(n_epochs):\n",
    "        for X_batch in my_dataset:\n",
    "            for param in model.parameters():\n",
    "                param.grad = None\n",
    "            x_hat, z = model(X_batch)\n",
    "            pred_loss = loss(X_batch, x_hat)\n",
    "            z_loss = 0.5 * torch.mean(torch.sum(z ** 2, dim=1))\n",
    "            nn_loss = torch.mean(torch.sum(torch.nn.ReLU()(-z), dim=1))\n",
    "            weight_l2 = 0.0\n",
    "            for name, p in model.named_parameters():\n",
    "                if 'weight' in name:\n",
    "                    weight_l2 += 0.5 * (p ** 2).sum()\n",
    "            loss_tot = pred_loss + z_reg * z_loss + z_nn * nn_loss + w_reg * weight_l2\n",
    "\n",
    "            loss_tot.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            losses = {'pred_loss': pred_loss.detach().numpy(),\n",
    "                      'z_loss': z_loss.detach().numpy(),\n",
    "                      'nn_loss': nn_loss.detach().numpy(),\n",
    "                      'weight_loss': weight_l2.detach().numpy()}\n",
    "\n",
    "        if epoch % 1000 == 0:\n",
    "            res = most_mixed_neuron(model)\n",
    "            for key, val in res.items():\n",
    "                results_run[key].append(val)\n",
    "            for key, val in losses.items():\n",
    "                results_run[key].append(val)\n",
    "        if epoch % 10000 == 0:\n",
    "            print('', end='.')\n",
    "            # this is slow to computre\n",
    "            latents = z.detach().numpy()\n",
    "            sources = X_batch.detach().numpy()\n",
    "            lcinfom = metrics.compute_linear_metrics(sources, latents, 'continuous', 'continuous')\n",
    "            results_run['lcinfom'].append(lcinfom['linear_cinfom'])\n",
    "\n",
    "    for key, val in results_run.items():\n",
    "        results['mix' if mixed == 1 else 'mod'][key].append(val)\n",
    "\n",
    "    msg = \"repeat={:.2f}, \".format(repeat) + ''.join(\n",
    "        f'{key}={str(val[-1])[:7]}, ' for key, val in results_run.items() if key not in ['diffs', 'multiinfo'])\n",
    "    print(msg)\n",
    "\n",
    "save_path = '.' # choose path to save to\n",
    "np.save(save_path + '/results_all' + '.npy', results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e9c0c55-f6bb-4020-8aea-269b98e22d53",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from skimage.transform import resize\n",
    "import itertools\n",
    "from itertools import repeat \n",
    "from disentangled_rnn_utils import DotDict as Dd\n",
    "from scipy import stats\n",
    "\n",
    "import seaborn\n",
    "seaborn.set_style(style='white')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e995c84b-e7c1-43bd-9e53-e70215698402",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "save_path = '.'\n",
    "results = np.load(save_path + '/results_all.npy', allow_pickle=True).item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0246fd0c-95ff-4806-bab7-7627a7eb3a54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MOD VS MIXED : VS DIFF / MULTINFO\n",
    "no_axes = False\n",
    "num_plots = 4\n",
    "s = 10\n",
    "figsize = (2.5,2)\n",
    "for key in results['mix'].keys():\n",
    "    if key not in ['most_angle']:\n",
    "        continue\n",
    "    if key in ['diffs', 'multiinfo']:\n",
    "        continue\n",
    "    plt.figure(figsize=figsize)\n",
    "\n",
    "    x = results['mod']['multiinfo']\n",
    "    y = np.array(results['mod'][key])[:,-1]\n",
    "    plt.scatter(x, y, label='mod', s=s)\n",
    "    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)))\n",
    "\n",
    "    x = results['mix']['multiinfo']\n",
    "    y = np.array(results['mix'][key])[:,-1]\n",
    "    plt.scatter(x, y, label='mix', s=s)\n",
    "    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)))\n",
    "    if no_axes:\n",
    "        plt.gca().spines['top'].set_visible(False)\n",
    "        plt.gca().spines['right'].set_visible(False)\n",
    "        plt.gca().spines['left'].set_visible(False)\n",
    "        plt.gca().spines['bottom'].set_visible(False)\n",
    "        plt.tick_params(left=False, right=False, labelleft=False, labelright=False, bottom=False, top=False)\n",
    "    plt.xlabel('Normalised Source Multiinformation')\n",
    "    plt.ylabel(\"Most Mixed Neuron's Angle\")\n",
    "    plt.title('Correlation: ' + str(np.round(stats.pearsonr(x, y).statistic, 3)) + ' , ' + 'p=' + str(np.round(stats.pearsonr(x, y).pvalue, 3)))\n",
    "\n",
    "    plt.savefig('nsmi.png', bbox_inches='tight', dpi=300)\n",
    "\n",
    "    plt.figure(figsize=figsize)\n",
    "\n",
    "    x = -np.min(np.array(results['mod']['diffs']), axis=1)\n",
    "    y = np.array(results['mod'][key])[:,-1]\n",
    "    plt.scatter(x, y, label='mod', s=s)\n",
    "    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)))\n",
    "\n",
    "    x = -np.min(np.array(results['mix']['diffs']), axis=1)\n",
    "    y = np.array(results['mix'][key])[:,-1]\n",
    "    plt.scatter(x, y, label='mix', s=s)\n",
    "    plt.plot(np.unique(x), np.poly1d(np.polyfit(x, y, 1))(np.unique(x)))\n",
    "\n",
    "    if no_axes:\n",
    "        plt.gca().spines['top'].set_visible(False)\n",
    "        plt.gca().spines['right'].set_visible(False)\n",
    "        plt.gca().spines['left'].set_visible(False)\n",
    "        plt.gca().spines['bottom'].set_visible(False)\n",
    "        plt.tick_params(left=False, right=False, labelleft=False, labelright=False, bottom=False, top=False)\n",
    "    plt.xlabel('Mixing Energy Gain from Theory')\n",
    "    plt.ylabel(\"Most Mixed Neuron's Angle\")\n",
    "    plt.title('Correlation: ' + str(np.round(stats.pearsonr(x, y).statistic, 3)) + ' , ' + 'p=' + str(np.round(stats.pearsonr(x, y).pvalue, 3)))\n",
    "\n",
    "    plt.savefig('energy.png', bbox_inches='tight', dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "936aed19-7418-454d-b852-fa1d660f6c17",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
