{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import logging\n",
    "from pathlib import Path\n",
    "from typing import Dict\n",
    "import math\n",
    "\n",
    "import hydra\n",
    "import matplotlib\n",
    "import matplotlib.colors as colors\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import omegaconf\n",
    "import seaborn as sns\n",
    "import torch  # noqa\n",
    "import wandb\n",
    "from hydra.utils import instantiate\n",
    "from matplotlib import tri\n",
    "from matplotlib.offsetbox import AnnotationBbox, OffsetImage\n",
    "from omegaconf import DictConfig\n",
    "from pytorch_lightning import LightningModule\n",
    "from scipy.stats import qmc\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "from ccmm.utils.utils import normalize_unit_norm, project_onto\n",
    "\n",
    "from nn_core.callbacks import NNTemplateCore\n",
    "from nn_core.common import PROJECT_ROOT\n",
    "from nn_core.common.utils import seed_index_everything\n",
    "from nn_core.model_logging import NNLogger\n",
    "\n",
    "from torch.utils.data import DataLoader, Subset, SubsetRandomSampler\n",
    "\n",
    "import ccmm  # noqa\n",
    "from ccmm.matching.utils import (\n",
    "    apply_permutation_to_statedict,\n",
    "    get_all_symbols_combinations,\n",
    "    plot_permutation_history_animation,\n",
    "    restore_original_weights,\n",
    "    load_permutations,\n",
    ")\n",
    "from ccmm.utils.utils import (\n",
    "    linear_interpolate_state_dicts,\n",
    "    load_model_from_info,\n",
    "    map_model_seed_to_symbol,\n",
    "    save_factored_permutations,\n",
    ")\n",
    "\n",
    "from ccmm.utils.utils import vector_to_state_dict\n",
    "import pytorch_lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams[\"font.family\"] = \"serif\"\n",
    "sns.set_context(\"talk\")\n",
    "matplotlib.rcParams[\"text.usetex\"] = True\n",
    "cmap_name = \"coolwarm_r\"\n",
    "\n",
    "logging.getLogger(\"lightning.pytorch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"torch\").setLevel(logging.WARNING)\n",
    "logging.getLogger(\"pytorch_lightning.accelerators.cuda\").setLevel(logging.WARNING)\n",
    "pylogger = logging.getLogger(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import hydra\n",
    "from hydra import initialize, compose\n",
    "from typing import Dict, List\n",
    "\n",
    "hydra.core.global_hydra.GlobalHydra.instance().clear()\n",
    "initialize(version_base=None, config_path=str(\"../conf\"), job_name=\"matching_n_models\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cfg = compose(config_name=\"matching_n_models\", overrides=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "core_cfg = cfg  # NOQA\n",
    "cfg = cfg.matching\n",
    "\n",
    "seed_index_everything(cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_sampled_points = 500  # 2048\n",
    "num_test_samples = 500"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform = instantiate(core_cfg.dataset.test.transform)\n",
    "\n",
    "train_dataset = instantiate(core_cfg.dataset.train, transform=transform)\n",
    "test_dataset = instantiate(core_cfg.dataset.test, transform=transform)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, num_workers=cfg.num_workers)\n",
    "\n",
    "test_subset = Subset(test_dataset, list(range(num_test_samples)))\n",
    "\n",
    "test_loader = DataLoader(test_subset, batch_size=1000, num_workers=cfg.num_workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = instantiate(cfg.trainer, enable_progress_bar=False, enable_model_summary=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import load_model_from_artifact\n",
    "\n",
    "run = wandb.init(project=core_cfg.core.project_name, entity=core_cfg.core.entity, job_type=\"matching\")\n",
    "\n",
    "# symbol -> seed, i.e. {a: 1, b: 2, c: 3, ..}\n",
    "symbols_to_seed: Dict[str, int] = {map_model_seed_to_symbol(seed): seed for seed in cfg.model_seeds}\n",
    "\n",
    "\n",
    "def artifact_path(seed) -> str:\n",
    "    \"\"\"Build the W&B artifact reference (pinned to version v0) of the model identified by `seed`.\"\"\"\n",
    "    return f\"{core_cfg.core.entity}/{core_cfg.core.project_name}/{core_cfg.model.model_identifier}_{seed}:v0\"\n",
    "\n",
    "\n",
    "# symbol -> model, i.e. {a: model_a, b: model_b, c: model_c, ..}\n",
    "models: Dict[str, LightningModule] = {\n",
    "    map_model_seed_to_symbol(seed): load_model_from_artifact(run, artifact_path(seed)) for seed in cfg.model_seeds\n",
    "}\n",
    "\n",
    "num_models = len(models)\n",
    "\n",
    "pylogger.info(f\"Using {num_models} models with architecture {core_cfg.model.model_identifier}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# always permute the model having larger character order, i.e. c -> b, b -> a and so on ...\n",
    "symbols = set(symbols_to_seed.keys())\n",
    "sorted_symbols = sorted(symbols, reverse=False)\n",
    "\n",
    "# (a, b), (a, c), (b, c), ...\n",
    "all_combinations = get_all_symbols_combinations(symbols)\n",
    "# combinations of the form (a, b), (a, c), (b, c), .. and not (b, a), (c, a) etc\n",
    "canonical_combinations = [(source, target) for (source, target) in all_combinations if source < target]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pylogger.info(f\"Matching the following model pairs: {canonical_combinations}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load permutation specification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutation_spec_builder = instantiate(core_cfg.model.permutation_spec_builder)\n",
    "permutation_spec = permutation_spec_builder.create_permutation_spec()\n",
    "\n",
    "ref_model = list(models.values())[0]\n",
    "assert set(permutation_spec.layer_and_axes_to_perm.keys()) == set(ref_model.model.state_dict().keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matcher = instantiate(cfg.matcher, permutation_spec=permutation_spec)\n",
    "pylogger.info(f\"Matcher: {matcher.name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "permutations, perm_history = matcher(models, symbols=sorted_symbols, combinations=canonical_combinations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {symb: model.to(\"cpu\") for symb, model in models.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models to universe"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import perm_matrix_to_perm_indices\n",
    "\n",
    "\n",
    "# work on copies so that the original models stay untouched\n",
    "models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}\n",
    "\n",
    "for symbol, universe_model in models_permuted_to_universe.items():\n",
    "    # transpose every permutation matrix (the transpose of a permutation matrix is its inverse)\n",
    "    # and convert the result back to permutation indices\n",
    "    perms_to_universe = {\n",
    "        perm_name: perm_matrix_to_perm_indices(perm_indices_to_perm_matrix(perm_indices).T)\n",
    "        for perm_name, perm_indices in permutations[symbol].items()\n",
    "    }\n",
    "\n",
    "    permuted_params = apply_permutation_to_statedict(\n",
    "        permutation_spec, perms_to_universe, universe_model.model.state_dict()\n",
    "    )\n",
    "    universe_model.model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Permute models pairwise"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import unfactor_permutations\n",
    "\n",
    "# models_permuted_pairwise[fixed][permutee] starts as a copy of models[permutee]; for the canonical\n",
    "# (fixed, permutee) pairs below it is then overwritten with the weights of models[permutee] after applying\n",
    "# pairwise_permutations[fixed][permutee]. Exactly one copy is made per ordered pair of distinct symbols.\n",
    "models_permuted_pairwise = {\n",
    "    symbol: {\n",
    "        other_symb: copy.deepcopy(models[other_symb]) for other_symb in sorted(symbols) if other_symb != symbol\n",
    "    }\n",
    "    for symbol in sorted(symbols)\n",
    "}\n",
    "pairwise_permutations = unfactor_permutations(permutations)\n",
    "\n",
    "for fixed, permutee in canonical_combinations:\n",
    "    permuted_params = apply_permutation_to_statedict(\n",
    "        permutation_spec, pairwise_permutations[fixed][permutee], models[permutee].model.state_dict()\n",
    "    )\n",
    "    models_permuted_pairwise[fixed][permutee].model.load_state_dict(permuted_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check performance of models before and after permutation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for symbol, model in models_permuted_to_universe.items():\n",
    "    trainer.test(models_permuted_to_universe[symbol], test_loader)\n",
    "    trainer.test(models[symbol], test_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analyze models as vectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Flatten models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# every model becomes a single 1D vector holding all of its parameters\n",
    "flat_models = {\n",
    "    model_symb: torch.nn.utils.parameters_to_vector(original.parameters()) for model_symb, original in models.items()\n",
    "}\n",
    "flat_models_permuted_to_universe = {\n",
    "    model_symb: torch.nn.utils.parameters_to_vector(permuted.parameters())\n",
    "    for model_symb, permuted in models_permuted_to_universe.items()\n",
    "}\n",
    "\n",
    "# same nesting as models_permuted_pairwise: outer symbol -> inner symbol -> flat vector\n",
    "flat_models_permuted_pairwise = {\n",
    "    outer_symb: {\n",
    "        inner_symb: torch.nn.utils.parameters_to_vector(permuted.parameters())\n",
    "        for inner_symb, permuted in permuted_models.items()\n",
    "    }\n",
    "    for outer_symb, permuted_models in models_permuted_pairwise.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analyze the norms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.utils.utils import to_np\n",
    "\n",
    "\n",
    "# per model: norm of the original weights, of the universe-permuted weights and of their difference\n",
    "norms = {\"model\": [], \"permuted\": [], \"diff\": []}\n",
    "for symbol in models.keys():\n",
    "    original_vec = flat_models[symbol]\n",
    "    permuted_vec = flat_models_permuted_to_universe[symbol]\n",
    "\n",
    "    norms[\"model\"].append(to_np(torch.norm(original_vec)))\n",
    "    norms[\"permuted\"].append(to_np(torch.norm(permuted_vec)))\n",
    "    norms[\"diff\"].append(to_np(torch.norm(original_vec - permuted_vec)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "df = pd.DataFrame(norms, index=models.keys())\n",
    "\n",
    "df = df.apply(pd.to_numeric)\n",
    "\n",
    "plt.figure(figsize=(5, 5))\n",
    "sns.heatmap(df, annot=True, cmap=\"viridis\")\n",
    "plt.title(\"Model Norms Comparison\")\n",
    "plt.ylabel(\"Model Symbol\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# cosine_matrix[i, j]: cosine similarity between model i permuted to the universe and the original model j\n",
    "model_symbols = list(models.keys())\n",
    "cosine_matrix = np.zeros((len(model_symbols), len(model_symbols)))\n",
    "\n",
    "# no_grad + .item(): the flattened parameters may require grad, and such tensors cannot be converted\n",
    "# to NumPy implicitly; storing plain Python floats avoids the conversion altogether\n",
    "with torch.no_grad():\n",
    "    for i, symbol_i in enumerate(model_symbols):\n",
    "        for j, symbol_j in enumerate(model_symbols):\n",
    "            permuted_i = flat_models_permuted_to_universe[symbol_i]\n",
    "            original_j = flat_models[symbol_j]\n",
    "            cosine = permuted_i.dot(original_j) / (torch.norm(permuted_i) * torch.norm(original_j))\n",
    "            cosine_matrix[i, j] = cosine.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot the matrix\n",
    "# rows: models permuted to the universe, columns: original models (both in the order of `models`)\n",
    "model_symbols = list(models.keys())\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "sns.heatmap(\n",
    "    cosine_matrix, annot=True, cmap=\"viridis\", xticklabels=model_symbols, yticklabels=model_symbols, ax=ax\n",
    ")\n",
    "ax.set_title(\"Cosine Similarity Matrix\")\n",
    "ax.set_xlabel(\"Original model\")\n",
    "ax.set_ylabel(\"Model permuted to universe\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# matplotlib 3.6 renamed the bundled seaborn styles to \"seaborn-v0_8*\" and 3.8 removed the old names,\n",
    "# so pick whichever spelling the installed version provides\n",
    "seaborn_style = \"seaborn-v0_8\" if \"seaborn-v0_8\" in plt.style.available else \"seaborn\"\n",
    "plt.style.use(seaborn_style)\n",
    "\n",
    "# distribution of the flattened parameters of model \"a\"\n",
    "x = to_np(flat_models[\"a\"])\n",
    "fig, ax = plt.subplots(figsize=(5, 5))\n",
    "\n",
    "ax.hist(x, bins=500)\n",
    "ax.set_title(\"Parameter distribution of model a\")\n",
    "ax.set_xlabel(\"Parameter value\")\n",
    "ax.set_ylabel(\"Count\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Experiment: sparsify models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sparsify_models = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if sparsify_models:\n",
    "    # NOTE(review): this empty dict is immediately replaced by the comprehension below\n",
    "    sparsified_models = {}\n",
    "\n",
    "    # zero out every entry whose magnitude exceeds one standard deviation of the flat model\n",
    "    # (the mean is not subtracted before thresholding)\n",
    "    sparsified_models = {\n",
    "        symbol: torch.where(torch.abs(flat_model) > torch.std(flat_model), torch.zeros_like(flat_model), flat_model)\n",
    "        for symbol, flat_model in flat_models.items()\n",
    "    }\n",
    "    x = flat_models[\"a\"]\n",
    "    print(x.shape)\n",
    "    # number of entries with magnitude below 1e-2\n",
    "    # NOTE(review): the count is computed but never displayed or stored\n",
    "    torch.sum(torch.abs(x) < 1e-2)\n",
    "\n",
    "    # copies of the original and universe-permuted flat models with entries of magnitude below 1e-3 set to zero\n",
    "    flat_models_sparse = {}\n",
    "    flat_models_perm_sparse = {}\n",
    "    for symb, model in flat_models.items():\n",
    "        flat_models_sparse[symb] = torch.clone(model)\n",
    "        flat_models_sparse[symb][torch.abs(model) < 1e-3] = 0.0\n",
    "\n",
    "        flat_models_perm_sparse[symb] = torch.clone(flat_models_permuted_to_universe[symb])\n",
    "        flat_models_perm_sparse[symb][torch.abs(flat_models_permuted_to_universe[symb]) < 1e-3] = 0.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Going from 2D to high dimensions and back"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sampling the 2D plane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "random_points_plane = qmc.scale(\n",
    "    qmc.Sobol(d=2, scramble=True, seed=cfg.seed_index).random(num_sampled_points),\n",
    "    [-0.5, -0.5],\n",
    "    [0.5, 0.5],\n",
    ")\n",
    "\n",
    "pylogger.info(random_points_plane[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "boundaries = [[-0.5], [0.5]]\n",
    "lower_bounds = np.array([boundaries[0][0], boundaries[0][0]])\n",
    "upper_bounds = np.array([boundaries[1][0], boundaries[1][0]])\n",
    "\n",
    "pylogger.info(f\"Lower bounds: {lower_bounds}\")\n",
    "pylogger.info(f\"Upper bounds: {upper_bounds}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Method 1: Barycentric coordinates"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def represent_2D_point_pentagon_barycentric_coordinates(x):\n",
    "    \"\"\"\n",
    "    Express a 2D point as affine weights over the 5 vertices of the reference pentagon.\n",
    "\n",
    "    The pentagon is centred at the origin with radius 0.5. The weights are the least-squares solution\n",
    "    (minimum-norm, since the system is underdetermined) of `V z = x` augmented with `sum(z) = 1`.\n",
    "\n",
    "    x: point in the plane (2, )\n",
    "    returns: float32 tensor (5, ) of weights that sum to one\n",
    "    \"\"\"\n",
    "    pentagon = get_pentagon_vertices(0.0, 0.0, 0.5)\n",
    "\n",
    "    # (3, 5): vertex coordinates as columns plus a row of ones for the sum-to-one constraint\n",
    "    system = np.vstack([pentagon.T, np.ones(5)])\n",
    "\n",
    "    # (3, ): the point in homogeneous form\n",
    "    target = np.append(x, 1)\n",
    "\n",
    "    weights = torch.from_numpy(np.linalg.lstsq(system, target, rcond=None)[0])\n",
    "\n",
    "    assert torch.allclose(torch.sum(weights).float(), torch.tensor(1.0).float())\n",
    "\n",
    "    return weights.float()\n",
    "\n",
    "\n",
    "def get_pentagon_vertices(center_x, center_y, radius):\n",
    "    \"\"\"\n",
    "    Get the vertices of a pentagon centered at (center_x, center_y) with the given radius.\n",
    "    \"\"\"\n",
    "    pentagon_vertices = []\n",
    "\n",
    "    for i in range(5):\n",
    "        angle_deg = 72 * i  # 72 degrees between each point\n",
    "        angle_rad = math.radians(angle_deg)  # Convert to radians\n",
    "\n",
    "        x = radius * math.cos(angle_rad) + center_x\n",
    "        y = radius * math.sin(angle_rad) + center_y\n",
    "\n",
    "        pentagon_vertices.append((x, y))\n",
    "\n",
    "    return np.array(pentagon_vertices)\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def create_regular_polygon_vertices(sides, radius=1):\n",
    "    \"\"\"Create vertices for a regular polygon centered at the origin\"\"\"\n",
    "    return np.array(\n",
    "        [[radius * np.cos(2 * np.pi * i / sides), radius * np.sin(2 * np.pi * i / sides)] for i in range(sides)]\n",
    "    )\n",
    "\n",
    "\n",
    "def wachspress_coordinates(vertices, p):\n",
    "    \"\"\"Calculate Wachspress barycentric coordinates for a point inside a polygon\"\"\"\n",
    "    n = len(vertices)\n",
    "    alphas = np.zeros(n)\n",
    "    for i in range(n):\n",
    "        v1, v2, v3 = vertices[i - 1], vertices[i], vertices[(i + 1) % n]\n",
    "        area = area_of_triangle(v1, v2, v3) + 1e-6\n",
    "        d1 = distance_to_edge(p, v1, v2) + 1e-6\n",
    "        d2 = distance_to_edge(p, v2, v3) + 1e-6\n",
    "        alphas[i] = area / (d1 * d2) if d1 * d2 > 1e-5 else 0\n",
    "\n",
    "    return alphas / np.sum(alphas)\n",
    "\n",
    "\n",
    "def area_of_triangle(v1, v2, v3):\n",
    "    \"\"\"Calculate the area of a triangle given its vertices\"\"\"\n",
    "    return 0.5 * np.linalg.norm(np.cross(v2 - v1, v3 - v1)) + 1e-6\n",
    "\n",
    "\n",
    "def distance_to_edge(p, v1, v2):\n",
    "    \"\"\"Calculate the distance from point p to the edge (v1, v2)\"\"\"\n",
    "    if np.all(v1 == v2):\n",
    "        return np.linalg.norm(p - v1 + 1e-6)\n",
    "    return np.linalg.norm(np.cross(v2 - v1, v1 - p)) / np.linalg.norm(v2 - v1) + 1e-6\n",
    "\n",
    "\n",
    "def wachspress_coordinates(vertices, p):\n",
    "    \"\"\"Calculate Wachspress barycentric coordinates for a point inside a polygon with a fix for vertices\"\"\"\n",
    "    n = len(vertices)\n",
    "    alphas = np.zeros(n)\n",
    "\n",
    "    # Check if the point is exactly on one of the vertices\n",
    "    for i, vertex in enumerate(vertices):\n",
    "        if np.all(p == vertex):\n",
    "            coords = np.zeros(n)\n",
    "            coords[i] = 1\n",
    "            return coords\n",
    "\n",
    "    # Calculate Wachspress weights if the point is not one of the vertices\n",
    "    for i in range(n):\n",
    "        v1, v2, v3 = vertices[i - 1], vertices[i], vertices[(i + 1) % n]\n",
    "        area = area_of_triangle(v1, v2, v3)\n",
    "        d1 = distance_to_edge(p, v1, v2)\n",
    "        d2 = distance_to_edge(p, v2, v3)\n",
    "        alphas[i] = area / (d1 * d2) if d1 * d2 != 0 else 0\n",
    "\n",
    "    return alphas / np.sum(alphas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vertices = create_regular_polygon_vertices(10, 0.5)\n",
    "\n",
    "plt.scatter(vertices[:, 0], vertices[:, 1], c=\"black\", s=100)\n",
    "\n",
    "circle = plt.Circle((0, 0), 0.5, color=\"black\", fill=False)\n",
    "\n",
    "plt.gca().add_patch(circle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Represent permuted models as barycentric coordinates wrt the n models\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def represent_wrt_models_barycentric(model_to_repr, flat_models):\n",
    "    \"\"\"\n",
    "    Least-squares weights that express `model_to_repr` as a combination of `flat_models`.\n",
    "\n",
    "    model_to_repr: flat parameter vector (num_params_per_model, )\n",
    "    flat_models: iterable of flat parameter vectors, each (num_params_per_model, )\n",
    "    returns: NumPy array (num_models, 1) with the weights\n",
    "\n",
    "    The sum-to-one constraint is added as one extra equation of the least-squares system, so it is\n",
    "    satisfied only as well as the fit allows.\n",
    "    \"\"\"\n",
    "    # no gradients are needed: the inputs may be views of trainable parameters\n",
    "    with torch.no_grad():\n",
    "        # (num_params_per_model, num_models)\n",
    "        A = torch.stack(list(flat_models), dim=1)\n",
    "\n",
    "        scaling = 1.0\n",
    "        # Augment A with an additional row for the sum-to-one constraint,\n",
    "        # created with the dtype and device of A\n",
    "        ones_row = A.new_full((1, A.shape[1]), scaling)\n",
    "\n",
    "        # (num_params_per_model + 1, num_models)\n",
    "        A_augmented = torch.cat([A, ones_row], dim=0)\n",
    "\n",
    "        # Augment the target model with an additional element for the sum-to-one constraint\n",
    "        # (num_params_per_model + 1,)\n",
    "        target = model_to_repr.to(device=A.device, dtype=A.dtype)\n",
    "        target_augmented = torch.cat([target, target.new_tensor([scaling])])\n",
    "\n",
    "        # Solve the linear system (least squares)\n",
    "        # want z such that Az = x\n",
    "        # x is the target model\n",
    "        barycentric_coords = torch.linalg.lstsq(A_augmented, target_augmented.unsqueeze(1)).solution\n",
    "\n",
    "    pylogger.info(barycentric_coords)\n",
    "\n",
    "    return barycentric_coords.cpu().detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# USE WHEN TRYING TO REPRESENT PERMUTED MODELS WRT ORIGIN MODELS\n",
    "def get_model_2D_coordinates_barycentric(flat_models):\n",
    "    \"\"\"\n",
    "    Place the reference models on the vertices of a regular polygon (radius 0.5) and the\n",
    "    universe-permuted models at the combination of those vertices given by their barycentric weights.\n",
    "\n",
    "    flat_models: reference flat models, either a {symbol: vector} mapping or a sequence of vectors\n",
    "    returns: (model_2D_repr, universe_model_2D_repr), two {symbol: (2, ) array} dicts\n",
    "\n",
    "    Reads the notebook globals `symbols_to_seed` and `flat_models_permuted_to_universe`.\n",
    "    \"\"\"\n",
    "    # accept the {symbol: vector} dict as well: stacking a dict directly would iterate over its keys\n",
    "    reference_models = list(flat_models.values()) if isinstance(flat_models, dict) else list(flat_models)\n",
    "\n",
    "    # one vertex per reference model, so that the weights and the vertices always line up\n",
    "    vertices = create_regular_polygon_vertices(len(reference_models), 0.5)\n",
    "\n",
    "    model_2D_repr = {symbol: None for symbol in symbols_to_seed.keys()}\n",
    "    universe_model_2D_repr = {symbol: None for symbol in symbols_to_seed.keys()}\n",
    "\n",
    "    for model_num, (symbol, perm_model) in enumerate(flat_models_permuted_to_universe.items()):\n",
    "        # (num_models, 1)\n",
    "        model_baryc_coordinates = represent_wrt_models_barycentric(perm_model, reference_models)\n",
    "\n",
    "        pylogger.info(f\"{symbol} baryc coords: {model_baryc_coordinates.sum()}\")\n",
    "\n",
    "        model_2D_repr[symbol] = vertices[model_num]\n",
    "        # (num_models, 1) * (num_models, 2) summed over the models -> (2, )\n",
    "        universe_model_2D_repr[symbol] = (model_baryc_coordinates * vertices).sum(axis=0)\n",
    "\n",
    "    return model_2D_repr, universe_model_2D_repr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# USE WHEN CONSIDERING ALL THE 2*N MODELS (PERM AND ORIGINS) TO REPRESENT EACH MODEL AS A VERTEX OF THE 2*N REGULAR POLYGON\n",
    "def get_model_2D_coordinates_barycentric_all_models(all_flat_models, num_models, num_pairwise_perms=0):\n",
    "    \"\"\"\n",
    "    Assign every flat model to its own vertex of a regular polygon of radius 0.45.\n",
    "\n",
    "    all_flat_models: sequence of flat models, expected in the order originals, universe-permuted,\n",
    "        then pairwise-permuted (the order in which the display names are generated)\n",
    "    num_models: number of original models; names use the first `num_models` uppercase letters\n",
    "    num_pairwise_perms: number of extra pairwise-permuted models appended to `all_flat_models`\n",
    "    returns: {display name: (2, ) vertex} dict\n",
    "    \"\"\"\n",
    "    import string\n",
    "\n",
    "    # one uppercase letter per model (the previous hard-coded list stopped at E, i.e. 5 models)\n",
    "    model_letters = list(string.ascii_uppercase[:num_models])\n",
    "\n",
    "    symbol_names = [r\"\\Theta_\" + i for i in model_letters]\n",
    "    perm_symbol_names = [r\"\\pi(\\Theta_\" + i + \")\" for i in model_letters]\n",
    "\n",
    "    pairwise_perm_symbol_names = []\n",
    "\n",
    "    if num_pairwise_perms > 0:\n",
    "        # each model is paired with the one preceding it: B->A, C->B, ...\n",
    "        for ind, letter in enumerate(model_letters[1:]):\n",
    "            pairwise_perm_symbol_names.append(\n",
    "                r\"\\pi_{\" + f\"{letter}->{model_letters[ind]}\" + r\"}(\\Theta_\" + letter + \")\"\n",
    "            )\n",
    "\n",
    "    all_symbol_names = symbol_names + perm_symbol_names + pairwise_perm_symbol_names\n",
    "\n",
    "    origins = create_regular_polygon_vertices(2 * num_models + num_pairwise_perms, 0.45)\n",
    "\n",
    "    model_2D_repr = {symbol_name: None for symbol_name in all_symbol_names}\n",
    "\n",
    "    for model_num, flat_model in enumerate(all_flat_models):\n",
    "        # the weights are only logged; the 2D position is the vertex itself\n",
    "        model_baryc_coordinates = represent_wrt_models_barycentric(flat_model, all_flat_models)\n",
    "\n",
    "        pylogger.info(f\"{all_symbol_names[model_num]} baryc coords: {model_baryc_coordinates.sum()}\")\n",
    "\n",
    "        model_2D_repr[all_symbol_names[model_num]] = origins[model_num]\n",
    "\n",
    "    return model_2D_repr\n",
    "\n",
    "\n",
    "all_flat_models = [*flat_models.values(), *flat_models_permuted_to_universe.values()]\n",
    "# all_flat_models.append(flat_models_permuted_pairwise[\"b\"][\"a\"])\n",
    "\n",
    "model_2D_repr = get_model_2D_coordinates_barycentric_all_models(all_flat_models, num_models, num_pairwise_perms=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Collect the test loss for random samples in the 2D plane "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def evaluate_model_interp_on_point(point, flat_models, model, trainer, test_loader):\n",
    "    \"\"\"\n",
    "    Test loss of the model obtained by combining `flat_models` with the weights of a 2D point.\n",
    "\n",
    "    point: (2, ) point in the plane\n",
    "    flat_models: sequence of flat parameter vectors, one per vertex of the regular polygon (radius 0.45)\n",
    "    model: module whose `.model` weights are OVERWRITTEN in place with the combined parameters\n",
    "    trainer, test_loader: used to run `trainer.test`\n",
    "    returns: the \"loss/test\" entry of the first test result\n",
    "    \"\"\"\n",
    "    num_vertices = len(flat_models)\n",
    "    origins = create_regular_polygon_vertices(num_vertices, radius=0.45)\n",
    "\n",
    "    # the combination is only evaluated, never trained, so no autograd graph is needed\n",
    "    with torch.no_grad():\n",
    "        # (num_models, num_params_per_model)\n",
    "        stacked_models = torch.stack(list(flat_models))\n",
    "\n",
    "        # (num_models, ), cast to the dtype/device of the parameters: the float64 NumPy weights would\n",
    "        # otherwise promote the whole combination to float64\n",
    "        baryc_coords = torch.as_tensor(wachspress_coordinates(origins, point)).to(stacked_models)\n",
    "\n",
    "        # (num_params_per_model, ): weighted sum of the models, without materialising the product\n",
    "        new_flat_params = baryc_coords @ stacked_models\n",
    "\n",
    "    new_params = vector_to_state_dict(new_flat_params, model.model)\n",
    "\n",
    "    model.model.load_state_dict(new_params)\n",
    "\n",
    "    results = trainer.test(model, test_loader, verbose=False)\n",
    "\n",
    "    return results[0][\"loss/test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = copy.deepcopy(models[\"a\"])\n",
    "\n",
    "test_loss_for_random_pts_plane = np.array(\n",
    "    [\n",
    "        evaluate_model_interp_on_point(point, all_flat_models, model, trainer, test_loader)\n",
    "        for point in tqdm(random_points_plane)\n",
    "    ]\n",
    ")\n",
    "\n",
    "test_loss_for_random_pts_plane[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from functools import partial\n",
    "# from multiprocessing import Pool\n",
    "\n",
    "# pool = Pool() #defaults to number of available CPU's\n",
    "\n",
    "# eval_func = partial(evaluate_model_interp_on_point, all_flat_models, model, trainer, test_loader)\n",
    "\n",
    "# results = np.zeros(len(random_points_plane))\n",
    "# for ind, res in enumerate(tqdm(pool.imap(eval_func, iter(random_points_plane)), total=len(random_points_plane))):\n",
    "#     results[ind] = res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Method 2: Reference models as basis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_basis_vectors(origin_model, basis_model_1, basis_model_2):\n",
    "    \"\"\"\n",
    "    Build two orthogonal unit directions spanning the plane through three flat models.\n",
    "\n",
    "    origin_model, basis_model_1, basis_model_2: flat parameter tensors of the same shape\n",
    "    returns: (basis1_normed, basis2_normed, scale1, scale2), where scale_k is the norm of\n",
    "        `basis_model_k - origin_model` BEFORE the second direction is orthogonalised\n",
    "    \"\"\"\n",
    "    basis1 = basis_model_1 - origin_model\n",
    "    # torch.norm: the bare name `norm` is not defined in this notebook\n",
    "    scale1 = torch.norm(basis1)\n",
    "    basis1_normed = normalize_unit_norm(basis1)\n",
    "\n",
    "    basis2 = basis_model_2 - origin_model\n",
    "    scale2 = torch.norm(basis2)\n",
    "    # remove the component along the first direction before normalising\n",
    "    basis2 = basis2 - project_onto(basis2, basis1_normed)\n",
    "    basis2_normed = normalize_unit_norm(basis2)\n",
    "\n",
    "    return basis1_normed, basis2_normed, scale1, scale2\n",
    "\n",
    "\n",
    "# basis_model_1, basis_model_2, scale_1, scale_2 = get_basis_vectors(origin_model=flat_models['a'], basis_model_1=flat_models['b'], basis_model_2=flat_models['c'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from multiprocessing import Pool\n",
    "\n",
    "\n",
    "def evaluate_model_interp_on_point(\n",
    "    point, basis_model_1, basis_model_2, origin_model, ref_model, trainer, test_loader, scale_1, scale_2\n",
    "):\n",
    "    \"\"\"\n",
    "    Test loss of the model located at a 2D point of the plane spanned by two basis directions.\n",
    "\n",
    "    The parameters are `origin_model + scale_1 * basis_model_1 * point[0] + scale_2 * basis_model_2 * point[1]`;\n",
    "    they are loaded into `ref_model.model` (overwriting its weights) before running `trainer.test`.\n",
    "    Returns the \"loss/test\" entry of the first test result.\n",
    "    \"\"\"\n",
    "    offset = scale_1 * basis_model_1 * point[0] + scale_2 * basis_model_2 * point[1]\n",
    "    interpolated_params = vector_to_state_dict(origin_model + offset, ref_model.model)\n",
    "\n",
    "    ref_model.model.load_state_dict(interpolated_params)\n",
    "\n",
    "    return trainer.test(ref_model, test_loader, verbose=False)[0][\"loss/test\"]\n",
    "\n",
    "\n",
    "# ref_model = copy.deepcopy(models['a'])\n",
    "# origin_model = flat_models['a']\n",
    "\n",
    "# eval_results = np.array([evaluate_model_interp_on_point(point, scale_1=scale_1, scale_2=scale_2, basis_model_1=basis_model_1, basis_model_2=basis_model_2, origin_model=origin_model, ref_model=ref_model, trainer=trainer, test_loader=test_loader) for point in tqdm(random_points_plane)])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Represent models 2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def represent_wrt_models(model_to_repr, origin_model, basis1, basis2, scale_1, scale_2):\n",
    "    x_coord = torch.dot(model_to_repr - origin_model, basis1) / scale_1\n",
    "    y_coord = torch.dot(model_to_repr - origin_model, basis2) / scale_2\n",
    "\n",
    "    return torch.stack([x_coord, y_coord]).detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_2D_coordinates(flat_models, flat_perm_models, basis_model_1, basis_model_2, scale_1, scale_2):\n",
    "    \"\"\"\n",
    "    2D coordinates of every original model and of its permuted counterpart.\n",
    "\n",
    "    flat_models, flat_perm_models: {symbol: flat vector} dicts; `flat_models[\"a\"]` is the origin\n",
    "    returns: (model_2D_repr, universe_model_2D_repr), keyed by the symbols of the global\n",
    "        `symbols_to_seed`; symbols missing from `flat_perm_models` keep the value None\n",
    "    \"\"\"\n",
    "\n",
    "    def to_plane(flat_model):\n",
    "        # project one flat model onto the plane anchored at model \"a\"\n",
    "        return represent_wrt_models(\n",
    "            flat_model,\n",
    "            origin_model=flat_models[\"a\"],\n",
    "            basis1=basis_model_1,\n",
    "            basis2=basis_model_2,\n",
    "            scale_1=scale_1,\n",
    "            scale_2=scale_2,\n",
    "        )\n",
    "\n",
    "    model_2D_repr = dict.fromkeys(symbols_to_seed.keys())\n",
    "    universe_model_2D_repr = dict.fromkeys(symbols_to_seed.keys())\n",
    "\n",
    "    for symbol, perm_model in flat_perm_models.items():\n",
    "        model_2D_repr[symbol] = to_plane(flat_models[symbol])\n",
    "        universe_model_2D_repr[symbol] = to_plane(perm_model)\n",
    "\n",
    "    return model_2D_repr, universe_model_2D_repr\n",
    "\n",
    "\n",
    "# model_2D_repr, universe_model_2D_repr = get_model_2D_coordinates(flat_models, flat_models_permuted_to_universe, basis_model_1, basis_model_2, scale_1, scale_2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pylogger.info(model_2D_repr)\n",
    "# pylogger.info(universe_model_2D_repr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create the 2D grid of points and their corresponding losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regular grid on which the interpolated loss surface is evaluated.\n",
    "# NOTE(review): assumes `boundaries` is laid out as [lower_corner, upper_corner] with each\n",
    "# corner being (x, y) -- confirm against the cell that defines it. The y axis must use\n",
    "# index 1 of each corner; using index 0 for both axes is only right for a square region.\n",
    "xi = np.linspace(boundaries[0][0], boundaries[1][0])\n",
    "yi = np.linspace(boundaries[0][1], boundaries[1][1])\n",
    "\n",
    "# Linearly interpolate the data (x, y) on a grid defined by (xi, yi).\n",
    "triang = tri.Triangulation(random_points_plane[:, 0], random_points_plane[:, 1])\n",
    "\n",
    "# We need to cap the maximum loss value so that the contouring is not completely saturated by wildly large losses\n",
    "interpolator = tri.LinearTriInterpolator(triang, np.clip(test_loss_for_random_pts_plane, None, 5))\n",
    "\n",
    "# interpolator = tri.LinearTriInterpolator(triang, jnp.log(jnp.minimum(1.5, eval_results[:, 0])))\n",
    "zi = interpolator(*np.meshgrid(xi, yi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-label x/y offsets, keyed by the LaTeX symbol of each point.\n",
    "# Every key is a raw string so the LaTeX backslashes are never treated as escape sequences.\n",
    "offsets = {\n",
    "    r\"\\Theta_A\": {\"x\": -0.05, \"y\": 0.05},\n",
    "    r\"\\Theta_B\": {\"x\": 0.07, \"y\": 0.06},\n",
    "    r\"\\Theta_C\": {\"x\": -0.05, \"y\": -0.0},\n",
    "    r\"\\pi(\\Theta_A)\": {\"x\": +0.15, \"y\": 0.08},\n",
    "    r\"\\pi(\\Theta_B)\": {\"x\": 0.05, \"y\": 0.1},\n",
    "    r\"\\pi(\\Theta_C)\": {\"x\": 0.15, \"y\": -0.02},\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show which symbols have 2D coordinates; the plotting cell looks up `offsets[symbol]` for each one.\n",
    "# NOTE(review): the nearby call that would assign `model_2D_repr` is commented out -- confirm it is\n",
    "# defined earlier in the notebook, otherwise this cell fails on a fresh kernel.\n",
    "print(model_2D_repr.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 10))\n",
    "num_levels = 13\n",
    "\n",
    "# Grey iso-lines and filled contours of the interpolated loss surface.\n",
    "plt.contour(xi, yi, zi, levels=num_levels, linewidths=0.25, colors=\"grey\", alpha=0.5)\n",
    "\n",
    "# cmap = truncate_colormap(plt.get_cmap(cmap_name), 0.0, 1)\n",
    "\n",
    "plt.contourf(xi, yi, zi, levels=num_levels, cmap=plt.get_cmap(cmap_name), extend=\"both\")\n",
    "\n",
    "label_bboxes = dict(facecolor=\"tab:grey\", boxstyle=\"round\", edgecolor=\"none\", alpha=0.5)\n",
    "\n",
    "# Mark every model on the plane and write its symbol next to it, shifted by `offsets[symbol]`.\n",
    "print(model_2D_repr)\n",
    "for symbol, point in model_2D_repr.items():\n",
    "    plt.scatter(point[0], point[1], marker=\"x\", color=\"white\", zorder=10)\n",
    "\n",
    "    plt.text(\n",
    "        point[0] + offsets[symbol][\"x\"],\n",
    "        point[1] + offsets[symbol][\"y\"],\n",
    "        r\"${\\bf \" + symbol + r\"}$\",\n",
    "        color=\"white\",\n",
    "        fontsize=24,\n",
    "        bbox=label_bboxes,\n",
    "        horizontalalignment=\"right\",\n",
    "        verticalalignment=\"top\",\n",
    "    )\n",
    "\n",
    "\n",
    "box_x = 0\n",
    "box_y = 0.5\n",
    "title_text = r\"$C^2M^3$\"\n",
    "\n",
    "# The title is drawn twice: once fully transparent so that only its box is visible,\n",
    "# and once without a box, slightly lowered, so that only the text is visible.\n",
    "# Draw box only\n",
    "plt.text(\n",
    "    box_x,\n",
    "    box_y,\n",
    "    title_text,\n",
    "    color=(0.0, 0.0, 0.0, 0.0),\n",
    "    fontsize=24,\n",
    "    horizontalalignment=\"center\",\n",
    "    verticalalignment=\"center\",\n",
    "    bbox=dict(boxstyle=\"round\", fc=(1, 1, 1, 1), ec=\"black\", pad=0.4),\n",
    ")\n",
    "# Draw text only\n",
    "plt.text(\n",
    "    box_x,\n",
    "    box_y - 0.0115,\n",
    "    title_text,\n",
    "    color=(0.0, 0.0, 0.0, 1.0),\n",
    "    fontsize=24,\n",
    "    horizontalalignment=\"center\",\n",
    "    verticalalignment=\"center\",\n",
    ")\n",
    "\n",
    "\n",
    "# plt.colorbar()\n",
    "plt.xlim(-0.5, 0.5)\n",
    "plt.ylim(-0.5, 0.5)\n",
    "#   plt.xlim(-0.9, 1.9)\n",
    "#   plt.ylim(-0.9, 1.9)\n",
    "plt.xticks([])\n",
    "plt.yticks([])\n",
    "# NOTE(review): `plt.axis(\"equal\")` may re-adjust the limits set above -- confirm that the\n",
    "# final view is the intended one.\n",
    "plt.axis(\"equal\")\n",
    "# plt.tight_layout()\n",
    "\n",
    "# Create the output directory first: `savefig` does not create missing parent folders.\n",
    "figure_path = Path(\"figures/MLP_cifar_loss_contour.pdf\")\n",
    "figure_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(figure_path, format=\"pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def linear_interpolation(model_a, model_b, lamb):\n",
    "    \"\"\"Blend two models linearly: ``lamb=0`` yields ``model_a`` and ``lamb=1`` yields ``model_b``.\"\"\"\n",
    "    weight_a = 1 - lamb\n",
    "    weight_b = lamb\n",
    "    return weight_a * model_a + weight_b * model_b\n",
    "\n",
    "\n",
    "def get_interp_loss_curve(lambdas, model_a, model_b, ref_model):\n",
    "    \"\"\"Evaluate the test loss at every interpolation coefficient between two flat models.\n",
    "\n",
    "    For each value in ``lambdas`` the interpolated vector is converted with\n",
    "    ``vector_to_state_dict`` and loaded into ``ref_model.model`` (overwriting its weights),\n",
    "    then evaluated with the notebook-level ``trainer`` on ``test_loader``.\n",
    "\n",
    "    Returns:\n",
    "        List with the ``\"loss/test\"`` value obtained for each entry of ``lambdas``, in order.\n",
    "    \"\"\"\n",
    "\n",
    "    def loss_at(lamb):\n",
    "        blended_vector = linear_interpolation(model_a, model_b, lamb)\n",
    "        blended_state = vector_to_state_dict(blended_vector, ref_model.model)\n",
    "        ref_model.model.load_state_dict(blended_state)\n",
    "        return trainer.test(ref_model, test_loader, verbose=False)[0][\"loss/test\"]\n",
    "\n",
    "    return [loss_at(lamb) for lamb in lambdas]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch copy of model \"a\": its weights are overwritten at every evaluation below.\n",
    "ref_model = copy.deepcopy(models[\"a\"])\n",
    "\n",
    "lambdas = np.linspace(0, 1, 25)\n",
    "\n",
    "# A, B\n",
    "interp_ab = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models[\"a\"],\n",
    "    model_b=flat_models[\"b\"],\n",
    "    ref_model=ref_model,\n",
    ")\n",
    "\n",
    "# A, P_AB(B)\n",
    "interp_a_bperm_to_a = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models[\"a\"],\n",
    "    model_b=flat_models_permuted_pairwise[\"a\"][\"b\"],\n",
    "    ref_model=ref_model,\n",
    ")\n",
    "\n",
    "# P(B), P_AB(B)\n",
    "interp_b_uni_bperm_to_a = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models_permuted_to_universe[\"b\"],\n",
    "    model_b=flat_models_permuted_pairwise[\"a\"][\"b\"],\n",
    "    ref_model=ref_model,\n",
    ")\n",
    "\n",
    "# P(B), P(A)\n",
    "interp_b_uni_a_uni = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models_permuted_to_universe[\"b\"],\n",
    "    model_b=flat_models_permuted_to_universe[\"a\"],\n",
    "    ref_model=ref_model,\n",
    ")\n",
    "\n",
    "# B, P(B)\n",
    "interp_b_b_uni = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models[\"b\"],\n",
    "    model_b=flat_models_permuted_to_universe[\"b\"],\n",
    "    ref_model=ref_model,\n",
    ")\n",
    "\n",
    "# A, P(A)\n",
    "interp_a_a_uni = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models[\"a\"],\n",
    "    model_b=flat_models_permuted_to_universe[\"a\"],\n",
    "    ref_model=ref_model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# B, P_BA(A)\n",
    "interp_b_aperm_to_b = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models[\"b\"],\n",
    "    model_b=flat_models_permuted_pairwise[\"b\"][\"a\"],\n",
    "    ref_model=ref_model,\n",
    ")\n",
    "\n",
    "# P_BA(A), P(A)\n",
    "interp_bperm_to_a_a_uni = get_interp_loss_curve(\n",
    "    lambdas=lambdas,\n",
    "    model_a=flat_models_permuted_pairwise[\"b\"][\"a\"],\n",
    "    model_b=flat_models_permuted_to_universe[\"a\"],\n",
    "    ref_model=ref_model,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry is (losses, legend label, extra `plt.plot` keyword arguments).\n",
    "# The list order is the drawing order.\n",
    "interp_curves = [\n",
    "    (interp_ab, \"A, B\", {}),\n",
    "    (interp_a_bperm_to_a, r\"$A, P_{A}P_{B}^\\top(B)$\", {}),\n",
    "    (interp_b_aperm_to_b, r\"$B, P_{B}P_{A}^\\top(A)$\", {}),\n",
    "    (interp_b_uni_bperm_to_a, r\"$P_{B}^\\top(B), P_{A}P_{B}^\\top(B)$\", {}),\n",
    "    (interp_bperm_to_a_a_uni, r\"$P_{B}P_{A}^\\top(A), P_{A}^\\top(A)$\", {\"color\": \"black\"}),\n",
    "    (interp_b_uni_a_uni, r\"$P_{B}^\\top(B), P_{A}^\\top(A)$\", {}),\n",
    "    (interp_b_b_uni, r\"$B, P_{B}^\\top(B)$\", {}),\n",
    "    (interp_a_a_uni, r\"$A, P_{A}^\\top(A)$\", {}),\n",
    "]\n",
    "\n",
    "plt.figure()\n",
    "for curve_losses, curve_label, curve_kwargs in interp_curves:\n",
    "    plt.plot(lambdas, curve_losses, marker=\"x\", label=curve_label, **curve_kwargs)\n",
    "\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: every permutation stored for model \"a\" must satisfy P @ P.T == I.\n",
    "for perm_indices in permutations[\"a\"].values():\n",
    "    perm_matrix = perm_indices_to_perm_matrix(perm_indices)\n",
    "    gram = perm_matrix @ perm_matrix.T\n",
    "    assert torch.all(gram == torch.eye(perm_matrix.shape[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hic sunt leones"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# lambdas = np.linspace(0, 1, 25)\n",
    "\n",
    "# interp_results = []\n",
    "# interp_params = []\n",
    "\n",
    "# model_a_2D = represent_wrt_models(\n",
    "#     flat_models[\"a\"],\n",
    "#     origin_model=flat_models[\"a\"],\n",
    "#     basis1=basis_model_1,\n",
    "#     basis2=basis_model_2,\n",
    "#     scale_1=scale_1,\n",
    "#     scale_2=scale_2,\n",
    "# )\n",
    "# model_e_2D = represent_wrt_models(\n",
    "#     flat_models[\"e\"],\n",
    "#     origin_model=flat_models[\"a\"],\n",
    "#     basis1=basis_model_1,\n",
    "#     basis2=basis_model_2,\n",
    "#     scale_1=scale_1,\n",
    "#     scale_2=scale_2,\n",
    "# )\n",
    "\n",
    "# norms = []\n",
    "# results = {\"2D_interp\": [], \"N_interp\": []}\n",
    "# for lamb in lambdas:\n",
    "#     interp_model = flat_models[\"a\"] * lamb + flat_models[\"e\"] * (1 - lamb)\n",
    "#     interp_params.append(interp_model)\n",
    "\n",
    "#     new_params = vector_to_state_dict(interp_model, ref_model.model)\n",
    "#     ref_model.model.load_state_dict(new_params)\n",
    "#     res = trainer.test(ref_model, test_loader, verbose=False)[0][\"loss/test\"]\n",
    "\n",
    "#     results[\"2D_interp\"].append(res)\n",
    "#     interp_point = model_a_2D * lamb + model_e_2D * (1 - lamb)\n",
    "#     new_params_reconstructed = origin_model + (\n",
    "#         scale_1 * basis_model_1 * interp_point[0] + scale_2 * basis_model_2 * interp_point[1]\n",
    "#     )\n",
    "\n",
    "#     ref_model.model.load_state_dict(vector_to_state_dict(new_params_reconstructed, ref_model.model))\n",
    "#     res = trainer.test(ref_model, test_loader, verbose=False)[0][\"loss/test\"]\n",
    "#     results[\"N_interp\"].append(res)\n",
    "\n",
    "#     norms.append(torch.norm(new_params_reconstructed - interp_model).detach().cpu().numpy())\n",
    "\n",
    "# plt.figure()\n",
    "# plt.plot(lambdas, norms, marker=\"o\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# results[\"2D_interp\"] = np.array(results[\"2D_interp\"])\n",
    "# results[\"N_interp\"] = np.array(results[\"N_interp\"])\n",
    "\n",
    "# plt.figure()\n",
    "# plt.plot(lambdas, results[\"2D_interp\"], marker=\"o\")\n",
    "# plt.plot(lambdas, results[\"N_interp\"], marker=\"x\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (flat_models[\"c\"] / norm(flat_models[\"c\"])) @ (flat_models[\"b\"] / norm(flat_models[\"b\"]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pylogger.info(model_a_2D)\n",
    "# pylogger.info(model_c_2D)\n",
    "\n",
    "# interps_on_plane = []\n",
    "# for lamb in lambdas:\n",
    "#     interps_on_plane.append(model_a_2D * lamb + model_c_2D * (1 - lamb))\n",
    "\n",
    "#     model_interp = flat_models[\"a\"] * (1 - lamb) + flat_models[\"c\"] * lamb\n",
    "\n",
    "\n",
    "# interps_on_plane"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for ind, (par, lambd) in zip(interp_params, lambdas):\n",
    "\n",
    "#     interp_par_2D = represent_wrt_models(\n",
    "#         par, origin_model=flat_models[\"a\"], basis1=basis_model_1, basis2=basis_model_2, scale_1=scale_1, scale_2=scale_2\n",
    "#     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# new_params = vector_to_state_dict(mean_abc, ref_model.model)\n",
    "\n",
    "# ref_model.model.load_state_dict(new_params)\n",
    "\n",
    "# test_loss_for_random_pts_plane = trainer.test(ref_model, test_loader, verbose=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
