{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "654243ca-4d5d-4c6e-b071-020bb76941f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "07a63596-52c8-4169-b2bf-73afac81970b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from greedy import (\n",
    "    initialize_and_load_model, \n",
    "    extract_parameters, \n",
    "    find_best_permutation,  \n",
    "    SignFlipper,\n",
    "    GlobalPermutator\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a1c8aaf-a859-4885-b1a9-3bee9cc4cf9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import SoftTreeEnsemble, TreeWiseSoftTreeEnsemble"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1a531b4-0e89-45d8-9b69-b83af6fb833e",
   "metadata": {},
   "source": [
    "## 全体動作確認用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "621f5ebd-cf11-420d-8062-91071aacbc27",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by both MNIST checkpoints.\n",
    "input_dim = 28 * 28  # flattened 28x28 image\n",
    "output_dim = 10\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "beta = 1.0\n",
    "n_tree = 1024\n",
    "asym = False\n",
    "\n",
    "# Load the two saved checkpoints that are compared below.\n",
    "model_a = initialize_and_load_model(\n",
    "    \"./mnist_mlp_1.pt\",\n",
    "    input_dim, output_dim, depth, alpha, beta, n_tree, asym,\n",
    ")\n",
    "model_b = initialize_and_load_model(\n",
    "    \"./mnist_mlp_2.pt\",\n",
    "    input_dim, output_dim, depth, alpha, beta, n_tree, asym,\n",
    ")\n",
    "# Operate on a deep copy so model_a is unaffected if the alignment below mutates its input.\n",
    "model_a_clone = copy.deepcopy(model_a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0b91896-3a46-4020-bfce-6e365a502512",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Align model A to model B in two stages: sign flips first, then a global permutation.\n",
    "model_a_flipped = SignFlipper.signflip(model_a_clone, model_b)\n",
    "model_a_permuted = GlobalPermutator.global_permutation(\n",
    "    model_a_flipped,\n",
    "    model_b,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffcbd5f8-ccc1-4668-b4e3-db1e435e28e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parameters of the aligned model A and of the reference model B.\n",
    "params_a, params_b = extract_parameters(model_a_permuted), extract_parameters(model_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d42960f4-8f1a-40bf-b7c7-4ca2021e99b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_index = 130  # tree to inspect\n",
    "fig, (ax_a, ax_b) = plt.subplots(1, 2, figsize=(10, 4))\n",
    "# Index 0 on the first axis is presumably the root node (TODO: confirm layout in greedy.extract_parameters).\n",
    "ax_a.imshow(params_a[\"weight\"][0, tree_index, :].reshape(28, 28))\n",
    "ax_a.set_title(f\"model A (flipped + permuted), tree {tree_index}\")\n",
    "ax_b.imshow(params_b[\"weight\"][0, tree_index, :].reshape(28, 28))\n",
    "ax_b.set_title(f\"model B, tree {tree_index}\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87974829-09cb-43d3-9aa6-1286f051955c",
   "metadata": {},
   "source": [
    "## SignFlipperの動作確認用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30df18c0-3cc0-4e3e-a62a-260b8c22713a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by both MNIST checkpoints.\n",
    "input_dim = 28 * 28  # flattened 28x28 image\n",
    "output_dim = 10\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "beta = 1.0\n",
    "n_tree = 1024\n",
    "asym = False\n",
    "\n",
    "# Reload both checkpoints so this section starts from a clean state.\n",
    "model_a = initialize_and_load_model(\n",
    "    \"./mnist_mlp_1.pt\",\n",
    "    input_dim, output_dim, depth, alpha, beta, n_tree, asym,\n",
    ")\n",
    "model_b = initialize_and_load_model(\n",
    "    \"./mnist_mlp_2.pt\",\n",
    "    input_dim, output_dim, depth, alpha, beta, n_tree, asym,\n",
    ")\n",
    "# Deep copy of model A used by the sign-flip inspection below.\n",
    "model_a_clone = copy.deepcopy(model_a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3c58c50-42bf-4675-8403-5b88027365a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "params_a = extract_parameters(model_a_clone)\n",
    "params_b = extract_parameters(model_b)\n",
    "\n",
    "# Constructor arguments of TreeWiseSoftTreeEnsemble, in positional order.\n",
    "TREEWISE_CONFIG_KEYS = (\"input_dim\", \"output_dim\", \"depth\", \"alpha\", \"beta\", \"n_tree\", \"asym\")\n",
    "\n",
    "\n",
    "def to_treewise(source_model):\n",
    "    \"\"\"Return a TreeWiseSoftTreeEnsemble built from `source_model`'s own config, with its parameters copied in.\"\"\"\n",
    "    treewise = TreeWiseSoftTreeEnsemble(*(source_model.config[key] for key in TREEWISE_CONFIG_KEYS))\n",
    "    treewise.copy_parameters(source_model)\n",
    "    return treewise\n",
    "\n",
    "\n",
    "model_a_treewise = to_treewise(model_a_clone)\n",
    "model_b_treewise = to_treewise(model_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83e07191-2b25-4276-aab6-bd9612f7dcb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_index = 130  # tree to inspect\n",
    "fig, (ax_a, ax_b) = plt.subplots(1, 2)\n",
    "ax_a.imshow(model_a_treewise.trees[tree_index].root.fc.weight.data.reshape(28, 28))\n",
    "ax_a.set_title(f\"model A, tree {tree_index} (root)\")\n",
    "ax_b.imshow(model_b_treewise.trees[tree_index].root.fc.weight.data.reshape(28, 28))\n",
    "ax_b.set_title(f\"model B, tree {tree_index} (root)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a756b12-9580-4092-b669-f45d43b48efa",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_index = 0\n",
    "\n",
    "# Reproduce one SignFlipper step by hand: match trees by the local distance at `node_index`,\n",
    "# then compute the sign-flip decisions for the matched trees.\n",
    "distance_matrix = SignFlipper._compute_local_distance_matrix(\n",
    "    params_a, params_b, node_index\n",
    ")\n",
    "perm = find_best_permutation(distance_matrix)\n",
    "signs = SignFlipper._check_signflip(params_a, params_b, perm, node_index)\n",
    "\n",
    "tree_index = 130  # tree of model B to inspect; model A's tree is perm[tree_index]\n",
    "fig, axes = plt.subplots(2, 2)\n",
    "axes[0, 0].imshow(model_a_treewise.trees[perm[tree_index]].root.fc.weight.data.reshape(28, 28))\n",
    "axes[0, 0].set_title(\"model A, before flip\")\n",
    "axes[0, 1].imshow(model_b_treewise.trees[tree_index].root.fc.weight.data.reshape(28, 28))\n",
    "axes[0, 1].set_title(f\"model B, tree {tree_index}\")\n",
    "\n",
    "# NOTE: flip_children mutates model_a_treewise in place, so re-running this cell flips it again;\n",
    "# re-run the two cells above first for a clean state.\n",
    "model_a_treewise.flip_children(signs, node_index)\n",
    "\n",
    "axes[1, 0].imshow(model_a_treewise.trees[perm[tree_index]].root.fc.weight.data.reshape(28, 28))\n",
    "axes[1, 0].set_title(\"model A, after flip\")\n",
    "axes[1, 1].imshow(model_b_treewise.trees[tree_index].root.fc.weight.data.reshape(28, 28))\n",
    "axes[1, 1].set_title(f\"model B, tree {tree_index}\")\n",
    "fig.suptitle(f\"signs[{tree_index}] = {signs[tree_index].tolist()}\");"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30d2f111-66df-43fc-9ff0-f9e1f669a800",
   "metadata": {},
   "source": [
    "## GlobalPermutatorの動作確認用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f1fb232-4d97-420d-8cc1-e5dcc97a7e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GlobalPermutator._calc_weight for the arguments 3 and 4 (presumably tree depths; confirm in greedy.py).\n",
    "{arg: GlobalPermutator._calc_weight(arg) for arg in (3, 4)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e10ec82d-b503-4e48-8c06-604d06488594",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters shared by both MNIST checkpoints.\n",
    "input_dim = 28 * 28  # flattened 28x28 image\n",
    "output_dim = 10\n",
    "depth = 3\n",
    "alpha = 2.0\n",
    "beta = 1.0\n",
    "n_tree = 1024\n",
    "asym = False\n",
    "\n",
    "# Reload both checkpoints so this section starts from a clean state.\n",
    "model_a = initialize_and_load_model(\n",
    "    \"./mnist_mlp_1.pt\",\n",
    "    input_dim, output_dim, depth, alpha, beta, n_tree, asym,\n",
    ")\n",
    "model_b = initialize_and_load_model(\n",
    "    \"./mnist_mlp_2.pt\",\n",
    "    input_dim, output_dim, depth, alpha, beta, n_tree, asym,\n",
    ")\n",
    "model_a_clone = copy.deepcopy(model_a)\n",
    "\n",
    "# Sign-flip alignment only; the global permutation is done step by step in the next cell.\n",
    "model_a_flipped = SignFlipper.signflip(model_a_clone, model_b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65782814-e33b-4104-b7d6-a75306177d42",
   "metadata": {},
   "outputs": [],
   "source": [
    "params_a_flipped = extract_parameters(model_a_flipped)\n",
    "params_b = extract_parameters(model_b)\n",
    "\n",
    "# Take the depth from a model loaded in this section (model_a_permuted is only assigned at the end of this cell).\n",
    "global_distance_matrix = GlobalPermutator._compute_global_distance_matrix(\n",
    "    params_a_flipped, params_b, depth=model_a.config[\"depth\"]\n",
    ")\n",
    "perm = find_best_permutation(global_distance_matrix)\n",
    "model_a_permuted = GlobalPermutator._apply_permutation(perm, model_a_flipped)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d7f975d-ce13-46ad-b699-50f24cef947c",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "image = ax.imshow(global_distance_matrix)\n",
    "fig.colorbar(image, ax=ax)\n",
    "ax.set_title(\"global distance matrix (flipped model A vs model B)\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9760281f-5fc9-4898-b977-17824c249cbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "tree_index = 130  # tree to inspect\n",
    "fig, (ax_a, ax_b) = plt.subplots(1, 2)\n",
    "ax_a.imshow(model_a_permuted.root.fc.weight.data[tree_index].reshape(28, 28))\n",
    "ax_a.set_title(f\"model A (flipped + permuted), tree {tree_index}\")\n",
    "ax_b.imshow(model_b.root.fc.weight.data[tree_index].reshape(28, 28))\n",
    "ax_b.set_title(f\"model B, tree {tree_index}\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a744eac5-a90e-4773-b568-323af7b8549d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tree-structure visualization helpers.\n",
    "def plot_tree(tree, depth=3):\n",
    "    \"\"\"Draw a single soft tree on a new figure.\n",
    "\n",
    "    Every internal node's `fc.weight` is drawn as a 28x28 image. For nodes whose\n",
    "    children are leaves, the children's `param` values are drawn as overlaid\n",
    "    horizontal bar charts in the panel below the node.\n",
    "\n",
    "    Args:\n",
    "        tree: root node of the tree, e.g. `model.trees[i].root`.\n",
    "        depth: number of internal-node levels used to size the subplot grid\n",
    "            (default 3, the depth used in this notebook).\n",
    "    \"\"\"\n",
    "    plt.figure(figsize=(20, 10))\n",
    "    plot_node(tree, max_depth=depth, current_depth=0, position=0)\n",
    "\n",
    "\n",
    "def plot_node(node, max_depth, current_depth, position):\n",
    "    \"\"\"Recursively draw `node` and its subtree on the current figure.\n",
    "\n",
    "    The figure is a (max_depth + 1) x 2**(max_depth - 1) subplot grid.\n",
    "    Row `current_depth` holds internal nodes (column `position`, left-aligned).\n",
    "    A node whose left child is a leaf gets its leaf panel in the row below it.\n",
    "    Assumes that if `node.left` is a leaf, `node.right` is a leaf as well.\n",
    "    \"\"\"\n",
    "\n",
    "    def plot_weight(node):\n",
    "        # .cpu() as in plot_leaf, so tensors on an accelerator can be passed to imshow.\n",
    "        plt.imshow(node.fc.weight.detach().cpu().reshape(28, 28), cmap=\"coolwarm\")\n",
    "        plt.xticks([])\n",
    "        plt.yticks([])\n",
    "\n",
    "    def plot_leaf(node):\n",
    "        left_values = node.left.param.squeeze().cpu().detach().numpy()\n",
    "        right_values = node.right.param.squeeze().cpu().detach().numpy()\n",
    "        # One bar per value rather than a hard-coded 10, so other output sizes work too.\n",
    "        plt.barh(range(left_values.size), left_values, alpha=0.5)\n",
    "        plt.barh(range(right_values.size), right_values, alpha=0.5)\n",
    "        ax = plt.gca()\n",
    "        for side in (\"top\", \"right\", \"bottom\", \"left\"):\n",
    "            ax.spines[side].set_visible(False)\n",
    "        plt.tick_params(bottom=False, left=False, labelbottom=False, labelleft=False)\n",
    "\n",
    "    cols = 2 ** (max_depth - 1)  # columns = nodes on the deepest internal level of a full tree\n",
    "    rows = max_depth + 1  # internal levels plus one row of leaf panels\n",
    "    index = current_depth * cols + position + 1  # subplot indices are 1-based\n",
    "    plt.subplot(rows, cols, index)\n",
    "    plot_weight(node)\n",
    "\n",
    "    if node.left.leaf:\n",
    "        # Leaf panel directly below this node; it shows both children.\n",
    "        plt.subplot(rows, cols, index + cols)\n",
    "        plot_leaf(node)\n",
    "    else:\n",
    "        plot_node(node.left, max_depth, current_depth + 1, position * 2)\n",
    "\n",
    "    # A leaf right child has already been drawn together with the left one above.\n",
    "    if not node.right.leaf:\n",
    "        plot_node(node.right, max_depth, current_depth + 1, position * 2 + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7344ff33-dab1-4937-8bc4-d5795e8c5c67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tree 0 of each tree-wise model. model_a_treewise may already have been flipped in place\n",
    "# by the SignFlipper section (flip_children).\n",
    "plot_tree(model_a_treewise.trees[0].root)\n",
    "plt.suptitle(\"model A (tree-wise), tree 0\")\n",
    "plot_tree(model_b_treewise.trees[0].root)\n",
    "plt.suptitle(\"model B (tree-wise), tree 0\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
