{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# read all the permutations files in 'output/CIFAR10/frank_wolfe/history'\n",
    "\n",
    "import os\n",
    "import json\n",
    "from ccmm.matching.utils import load_permutations\n",
    "from nn_core.common import PROJECT_ROOT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "permutations_dir = os.path.join(PROJECT_ROOT, \"output\", \"CIFAR10\", \"frank_wolfe\", \"history\")\n",
    "chosen_perm = \"P_bg0\"  # key of the permutation extracted from each saved file\n",
    "\n",
    "\n",
    "def natural_sort_key(filename):\n",
    "    \"\"\"Sort key that orders embedded integers numerically (\"iter_2\" before \"iter_10\").\n",
    "\n",
    "    `re.split` with a capturing group alternates text and digit chunks, so digit chunks\n",
    "    always sit at odd positions and keys compare str-to-str and int-to-int.\n",
    "    \"\"\"\n",
    "    chunks = re.split(r\"(\\d+)\", filename)\n",
    "    return [int(chunk) if idx % 2 else chunk for idx, chunk in enumerate(chunks)]\n",
    "\n",
    "\n",
    "# Plain `sorted` is lexicographic (\"10.json\" < \"2.json\"), which would scramble the\n",
    "# iteration order that the plotting cells below label as optimization steps.\n",
    "history_files = sorted(\n",
    "    (name for name in os.listdir(permutations_dir) if name.endswith(\".json\")), key=natural_sort_key\n",
    ")\n",
    "\n",
    "all_permutations = []\n",
    "for filename in history_files:\n",
    "    permutation = load_permutations(os.path.join(permutations_dir, filename), matrix_format=True)\n",
    "    all_permutations.append(permutation[\"a\"][\"b\"][chosen_perm])\n",
    "\n",
    "assert all_permutations, f\"No .json permutation files found in {permutations_dir}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "# One heatmap per permutation matrix in `all_permutations` (torch tensors), laid out on a grid.\n",
    "num_matrices = len(all_permutations)\n",
    "ncols = 3  # for example, can change as needed\n",
    "nrows = max(1, (num_matrices + ncols - 1) // ncols)  # ceil(num_matrices / ncols), at least one row\n",
    "\n",
    "fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4), squeeze=False)\n",
    "\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    if i >= num_matrices:\n",
    "        ax.set_visible(False)  # hide the unused cells of the last grid row\n",
    "        continue\n",
    "    # detach/cpu so tensors that track gradients or live on a GPU can still be converted\n",
    "    image = ax.imshow(all_permutations[i].detach().cpu().numpy(), cmap=\"hot\", interpolation=\"nearest\")\n",
    "    ax.set_title(f\"Matrix {i+1}\")\n",
    "    fig.colorbar(image, ax=ax)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import torch\n",
    "import plotly.io as pio\n",
    "\n",
    "# Paper figure: the first few matrices of the history with a single shared colorbar.\n",
    "num_matrices = min(6, len(all_permutations))  # never ask for more matrices than were loaded\n",
    "ncols = 3\n",
    "nrows = max(1, (num_matrices + ncols - 1) // ncols)\n",
    "\n",
    "# Index 0 is the initialization, so index i > 0 is labelled \"Step i\".\n",
    "subplot_titles = [\"Initialization\"] + [f\"Step {i}\" for i in range(1, num_matrices)]\n",
    "\n",
    "fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=subplot_titles, vertical_spacing=0.2)\n",
    "\n",
    "fig.update_layout(\n",
    "    height=300 * nrows,\n",
    "    width=300 * ncols,\n",
    "    plot_bgcolor=\"rgba(255,255,255,1)\",  # white plot area (rgba alpha must be in [0, 1])\n",
    "    margin=dict(l=25, r=25, t=25, b=25),  # margin around the whole figure\n",
    "    paper_bgcolor=\"rgba(255,255,255,1)\",  # white background for the paper\n",
    "    font=dict(size=25, color=\"black\"),  # font for titles and labels\n",
    ")\n",
    "fig.update_annotations(font_size=25)  # subplot titles are layout annotations\n",
    "\n",
    "# Hide tick labels on every subplot axis.\n",
    "fig.update_xaxes(showticklabels=False, showgrid=True)\n",
    "fig.update_yaxes(showticklabels=False, showgrid=True)\n",
    "\n",
    "# Exact zeros are drawn in cornsilk; any positive entry is shaded from red to blue.\n",
    "colorscale = [[0, \"cornsilk\"], [1e-8, \"red\"], [1, \"blue\"]]\n",
    "\n",
    "for i, perm in enumerate(all_permutations[:num_matrices]):\n",
    "    heatmap = go.Heatmap(\n",
    "        z=perm.detach().cpu().numpy(),\n",
    "        colorscale=colorscale,\n",
    "        # Pin the color range so all subplots and the colorbar share one [0, 1] mapping\n",
    "        # (assumes entries lie in [0, 1] - TODO confirm for non-binary iterates).\n",
    "        zmin=0,\n",
    "        zmax=1,\n",
    "        # Draw the colorbar once, from the last heatmap. A separate dummy heatmap would be\n",
    "        # plotted inside the last subplot and paint over two of its cells.\n",
    "        showscale=(i == num_matrices - 1),\n",
    "    )\n",
    "    fig.add_trace(heatmap, row=i // ncols + 1, col=i % ncols + 1)\n",
    "\n",
    "fig.show()\n",
    "\n",
    "os.makedirs(\"figures\", exist_ok=True)  # write_image fails if the output directory is missing\n",
    "pio.write_image(fig, os.path.join(\"figures\", \"permutation_matrices.pdf\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "\n",
    "# Alternative rendering: white zeros, no colorbar, and a black frame around each subplot.\n",
    "# num_matrices is recomputed here so this cell does not depend on the previous cell's state.\n",
    "num_matrices = min(6, len(all_permutations))\n",
    "ncols = 3  # for example, can change as needed\n",
    "nrows = max(1, (num_matrices + ncols - 1) // ncols)\n",
    "\n",
    "# Index 0 is the initialization, so index i > 0 is labelled \"Step i\".\n",
    "subplot_titles = [\"Initialization\"] + [f\"Step {i}\" for i in range(1, num_matrices)]\n",
    "fig = make_subplots(rows=nrows, cols=ncols, subplot_titles=subplot_titles)\n",
    "\n",
    "# Exact zeros are white; any positive entry is shaded from red to blue.\n",
    "colorscale = [[0, \"white\"], [1e-8, \"red\"], [1, \"blue\"]]\n",
    "border = dict(showline=True, mirror=True, linecolor=\"black\", linewidth=2)\n",
    "\n",
    "for i, matrix in enumerate(all_permutations[:num_matrices]):\n",
    "    row, col = i // ncols + 1, i % ncols + 1\n",
    "    fig.add_trace(go.Heatmap(z=matrix.detach().cpu().numpy(), colorscale=colorscale, showscale=False), row=row, col=col)\n",
    "    # Frame the subplot with its own axis lines (mirror=True repeats them on the opposite\n",
    "    # side). Unlike rectangles placed in paper coordinates, these always follow the real\n",
    "    # subplot domains computed by make_subplots, including its default spacing.\n",
    "    fig.update_xaxes(**border, row=row, col=col)\n",
    "    fig.update_yaxes(**border, row=row, col=col)\n",
    "\n",
    "fig.update_layout(\n",
    "    height=300 * nrows,\n",
    "    width=300 * ncols,\n",
    "    plot_bgcolor=\"rgba(255,255,255,1)\",  # white plot area (rgba alpha must be in [0, 1])\n",
    "    paper_bgcolor=\"rgba(255,255,255,1)\",  # white background for the paper\n",
    "    font=dict(size=25, color=\"black\"),  # font for titles and labels\n",
    ")\n",
    "\n",
    "# Hide tick labels on every subplot axis.\n",
    "fig.update_xaxes(showticklabels=False, showgrid=True)\n",
    "fig.update_yaxes(showticklabels=False, showgrid=True)\n",
    "\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
