{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3afba055",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e2fde5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "from rae import PROJECT_ROOT\n",
    "import numpy as np\n",
    "import torch.nn.functional as F\n",
    "from pytorch_lightning import seed_everything\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bddf3583",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModel, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a60d9f72",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rae.modules.attention import RelativeAttention, AttentionOutput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f2829d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device: str = \"cuda\"\n",
    "fine_grained: bool = False\n",
    "target_key: str = f\"{'fine' if fine_grained else 'coarse'}_label\"\n",
    "data_key: str = \"image\"\n",
    "dataset_name: str = \"cifar100\"\n",
    "num_anchors: int = 768\n",
    "train_perc: float = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54fdaa84",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset, ClassLabel\n",
    "\n",
    "\n",
    "def get_dataset(split: str, perc: float, seed: int = 42):\n",
    "    \"\"\"Load one split of `dataset_name`, optionally keeping a random fraction of it.\n",
    "\n",
    "    Args:\n",
    "        split: key of the loaded dataset dict to return (e.g. \"train\" or \"test\").\n",
    "        perc: fraction of the split to keep, in (0, 1]; 1 keeps every sample in order.\n",
    "        seed: passed to `seed_everything` before shuffling so the subset is reproducible\n",
    "            (defaults to 42, the value this function always used).\n",
    "\n",
    "    Returns:\n",
    "        The split, or a random subset of `int(len(split) * perc)` of its rows.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if `perc` is outside (0, 1].\n",
    "    \"\"\"\n",
    "    # Relies on seed_everything seeding Python's `random`, which drives the shuffle below.\n",
    "    seed_everything(seed)\n",
    "    assert 0 < perc <= 1, f\"perc must be in (0, 1], got {perc}\"\n",
    "    dataset = load_dataset(dataset_name)[split]\n",
    "\n",
    "    if perc != 1:\n",
    "        indices = list(range(len(dataset)))\n",
    "        random.shuffle(indices)\n",
    "        dataset = dataset.select(indices[: int(len(indices) * perc)])\n",
    "\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b49e4af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training split, subsampled to train_perc.\n",
    "train_dataset = get_dataset(\"train\", train_perc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a41a58f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of label classes and the label-name -> integer-id lookup for `target_key`.\n",
    "num_classes = train_dataset.features[target_key].num_classes\n",
    "class2idx = train_dataset.features[target_key].str2int\n",
    "(num_classes, len(train_dataset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "981b6fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import timm\n",
    "from timm.data import resolve_data_config\n",
    "from timm.data.transforms_factory import create_transform\n",
    "from transformers import AutoFeatureExtractor, AutoModelForImageClassification, AutoModel\n",
    "\n",
    "\n",
    "def load_transformer(transformer_name):\n",
    "    \"\"\"Create a pretrained timm model by name and return it frozen and in eval mode.\n",
    "\n",
    "    `num_classes=0` asks timm for the model without its classification head, so the\n",
    "    forward pass returns features instead of logits.\n",
    "    \"\"\"\n",
    "    backbone = timm.create_model(transformer_name, pretrained=True, num_classes=0)\n",
    "    backbone.requires_grad_(False)  # no parameter will accumulate gradients\n",
    "    backbone.eval()\n",
    "    return backbone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1a0b9f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the test split is also subsampled with train_perc; confirm that is intended.\n",
    "test_dataset = get_dataset(\"test\", train_perc)\n",
    "len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6beba308",
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def call_transformer(batch, transformer):\n",
    "    \"\"\"Run `transformer` on `batch[\"encoding\"]` (moved to `device`) with autograd disabled.\n",
    "\n",
    "    Returns:\n",
    "        {\"hidden\": <raw output of transformer>}\n",
    "    \"\"\"\n",
    "    inputs = batch[\"encoding\"].to(device)\n",
    "    return {\"hidden\": transformer(inputs)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3649cedf",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List\n",
    "\n",
    "\n",
    "def get_anchors(num_anchors: int, seed: int) -> List[int]:\n",
    "    \"\"\"Return `num_anchors` distinct random indices into `train_dataset`.\n",
    "\n",
    "    The global RNGs are reseeded with `seed` first, so for a fixed `train_dataset`\n",
    "    the same (num_anchors, seed) pair yields the same indices.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if more anchors are requested than there are training samples.\n",
    "    \"\"\"\n",
    "    seed_everything(seed)\n",
    "    assert num_anchors <= len(train_dataset)\n",
    "    anchor_idxs = list(range(len(train_dataset)))\n",
    "    random.shuffle(anchor_idxs)\n",
    "    return anchor_idxs[:num_anchors]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37f4b15a",
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer_name = \"vit_small_patch16_224\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e42242d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "transformer = load_transformer(transformer_name)\n",
    "# Preprocessing pipeline derived from the backbone's own data config.\n",
    "config = resolve_data_config({}, model=transformer)\n",
    "transform = create_transform(**config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a88f2425",
   "metadata": {},
   "outputs": [],
   "source": [
    "def relative_projection(x, anchors):\n",
    "    x = F.normalize(x, p=2, dim=-1)\n",
    "    anchors = F.normalize(anchors, p=2, dim=-1)\n",
    "    return torch.einsum(\"bm, am -> ba\", x, anchors)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "397f6d06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def collate_fn(batch, feature_extractor, transform):\n",
    "    return {\"encoding\": torch.stack([transform(sample[\"img\"].convert(\"RGB\")) for sample in batch], dim=0)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a3d7735",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_latents(dataloader, anchors, split: str, transformer) -> Dict[str, torch.Tensor]:\n",
    "    \"\"\"Encode every batch of `dataloader` with `transformer` and collect the outputs on CPU.\n",
    "\n",
    "    Args:\n",
    "        dataloader: yields batches accepted by `call_transformer` (dicts with an \"encoding\" tensor).\n",
    "        anchors: anchor embeddings for `relative_projection` (on the same device as the\n",
    "            transformer outputs), or None to skip the relative projection.\n",
    "        split: label shown in the progress bar only.\n",
    "        transformer: model moved to `device` for the duration of the call and back to CPU afterwards.\n",
    "\n",
    "    Returns:\n",
    "        {\"absolute\": per-batch outputs concatenated along dim 0,\n",
    "         \"relative\": (N, num_anchors) cosine similarities, or an empty list when `anchors` is None}.\n",
    "    \"\"\"\n",
    "    absolute_chunks: List[torch.Tensor] = []\n",
    "    relative_chunks: List[torch.Tensor] = []\n",
    "\n",
    "    transformer = transformer.to(device)\n",
    "    try:\n",
    "        for batch in tqdm(dataloader, desc=f\"[{split}] Computing latents\"):\n",
    "            with torch.no_grad():\n",
    "                hidden = call_transformer(batch=batch, transformer=transformer)[\"hidden\"]\n",
    "                absolute_chunks.append(hidden.cpu())\n",
    "                if anchors is not None:\n",
    "                    relative_chunks.append(relative_projection(x=hidden, anchors=anchors).cpu())\n",
    "    finally:\n",
    "        # Move the model back even if encoding fails or is interrupted, so an aborted\n",
    "        # run does not leave the backbone occupying accelerator memory.\n",
    "        transformer.cpu()\n",
    "\n",
    "    absolute_latents = torch.cat(absolute_chunks, dim=0)\n",
    "    # Empty list (not a tensor) when no anchors were given, kept for backward compatibility.\n",
    "    relative_latents = torch.cat(relative_chunks, dim=0) if relative_chunks else relative_chunks\n",
    "    return {\n",
    "        \"absolute\": absolute_latents,\n",
    "        \"relative\": relative_latents,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "679eb256",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial  # not imported by any earlier cell\n",
    "\n",
    "# Absolute embeddings for both splits (anchors=None skips the relative projection).\n",
    "# The DataLoader is not shuffled, so row i of each tensor matches sample i of its dataset.\n",
    "absolute_latents = {\n",
    "    split: get_latents(\n",
    "        dataloader=DataLoader(\n",
    "            train_dataset if split == \"train\" else test_dataset,\n",
    "            num_workers=4,\n",
    "            pin_memory=True,\n",
    "            collate_fn=partial(collate_fn, feature_extractor=None, transform=transform),\n",
    "            batch_size=32,\n",
    "        ),\n",
    "        split=f\"{split}/{transformer_name}\",\n",
    "        anchors=None,\n",
    "        transformer=transformer,\n",
    "    )[\"absolute\"]\n",
    "    for split in (\"train\", \"test\")\n",
    "}\n",
    "absolute_latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abd28026",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (train, test) shapes of the absolute embeddings.\n",
    "tuple(absolute_latents[split].shape for split in (\"train\", \"test\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b9c2141",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width of the absolute embeddings (size of their last dimension).\n",
    "encoding_dim: int = absolute_latents[\"train\"].shape[-1]\n",
    "encoding_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58529383",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterGrid\n",
    "\n",
    "# Anchor-count sweep, dense for small counts and coarser for large ones; every count is run with 3 seeds.\n",
    "sweep = {\n",
    "    \"seed\": list(range(3)),\n",
    "    \"num_anchors\": list(range(1, 50, 2)) + list(range(50, 100, 5)) + list(range(100, 1000, 20)),\n",
    "}\n",
    "\n",
    "# Grid of all (seed, num_anchors) combinations; the training cell rebuilds ParameterGrid(sweep) itself.\n",
    "experiments = ParameterGrid(sweep)\n",
    "experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd41a8b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_normalize: bool = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f27c1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import ParameterSampler, ParameterGrid\n",
    "import logging\n",
    "from tqdm import tqdm\n",
    "from pprint import pprint\n",
    "from rae.utils.utils import to_device\n",
    "from pytorch_lightning.utilities.seed import log as seed_log\n",
    "from sklearn.metrics import precision_recall_fscore_support, accuracy_score\n",
    "import pandas as pd\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "\n",
    "\n",
    "class Lambda(nn.Module):\n",
    "    \"\"\"Wrap a plain callable as an nn.Module so it can be placed inside nn.Sequential.\"\"\"\n",
    "\n",
    "    def __init__(self, func):\n",
    "        super().__init__()\n",
    "        self.func = func\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.func(x)\n",
    "\n",
    "\n",
    "def _build_relative_classifier(n_anchors: int, hidden_dim: int, n_classes: int) -> nn.Module:\n",
    "    \"\"\"Two-layer classifier over (batch, n_anchors) relative embeddings.\n",
    "\n",
    "    Each InstanceNorm1d sees the transposed (features, batch) matrix, which PyTorch treats as\n",
    "    an unbatched (C, L) input: every feature is normalized with statistics over the batch rows.\n",
    "    \"\"\"\n",
    "    return nn.Sequential(\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=n_anchors),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=n_anchors, out_features=hidden_dim),\n",
    "        nn.Tanh(),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.InstanceNorm1d(num_features=hidden_dim),\n",
    "        Lambda(lambda x: x.permute(1, 0)),\n",
    "        nn.Linear(in_features=hidden_dim, out_features=n_classes),\n",
    "    )\n",
    "\n",
    "\n",
    "stats = {x: [] for x in (\"experiment\", \"epoch\", \"loss\", \"val_fscore\", \"val_acc\", \"num_anchors\")}\n",
    "\n",
    "# Loop-invariant inputs: labels and absolute latents do not depend on the experiment,\n",
    "# so read the label columns and move the latents to the device only once.\n",
    "train_y = torch.as_tensor(train_dataset[target_key]).to(device)\n",
    "test_y = np.asarray(test_dataset[target_key])\n",
    "train_absolute = absolute_latents[\"train\"].to(device)\n",
    "test_absolute = absolute_latents[\"test\"].to(device)\n",
    "n_classes = train_dataset.features[target_key].num_classes\n",
    "\n",
    "for i, experiment in enumerate(pbar := tqdm(ParameterGrid(sweep), desc=\"Experiment\", leave=True)):\n",
    "    seed: int = experiment[\"seed\"]\n",
    "    # Reseed quietly: mute pytorch_lightning's seed logger for this call only.\n",
    "    temp_log_level = seed_log.getEffectiveLevel()\n",
    "    seed_log.setLevel(logging.ERROR)\n",
    "    seed_everything(seed)\n",
    "    seed_log.setLevel(temp_log_level)\n",
    "\n",
    "    # Deliberately NOT named `num_anchors`: that would overwrite the config value\n",
    "    # which later cells read after this loop finishes.\n",
    "    experiment_num_anchors: int = experiment[\"num_anchors\"]\n",
    "    anchor_idxs = get_anchors(num_anchors=experiment_num_anchors, seed=seed)\n",
    "    anchors = train_absolute[anchor_idxs]\n",
    "\n",
    "    train_data = relative_projection(x=train_absolute, anchors=anchors)\n",
    "    test_data = relative_projection(x=test_absolute, anchors=anchors)\n",
    "    if latent_normalize:\n",
    "        train_data = F.normalize(train_data, p=2, dim=-1)\n",
    "        test_data = F.normalize(test_data, p=2, dim=-1)\n",
    "\n",
    "    model = _build_relative_classifier(\n",
    "        n_anchors=experiment_num_anchors, hidden_dim=encoding_dim, n_classes=n_classes\n",
    "    ).to(device)\n",
    "    opt = Adam(model.parameters(), lr=1e-3)\n",
    "    loss_fn = CrossEntropyLoss()\n",
    "\n",
    "    # Full-batch training: one optimizer step per epoch on the whole training set.\n",
    "    for epoch in tqdm(range(10), leave=False, desc=\"epoch\"):\n",
    "        model.train()\n",
    "        pred_y = model(train_data)\n",
    "        loss = loss_fn(pred_y, train_y)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            test_preds = model(test_data).softmax(-1).argmax(-1).cpu().numpy()\n",
    "\n",
    "        precision, recall, fscore, _ = precision_recall_fscore_support(test_y, test_preds, average=\"weighted\")\n",
    "        acc = accuracy_score(test_y, test_preds)\n",
    "        loss_value = loss.item()  # plain float for the stats table and the format spec below\n",
    "\n",
    "        stats[\"experiment\"].append(i)\n",
    "        stats[\"epoch\"].append(epoch)\n",
    "        stats[\"loss\"].append(loss_value)\n",
    "        stats[\"val_fscore\"].append(fscore)\n",
    "        stats[\"val_acc\"].append(acc)\n",
    "        stats[\"num_anchors\"].append(experiment_num_anchors)\n",
    "\n",
    "        pbar.set_description(\n",
    "            f\"Epoch: {epoch}, Loss: {loss_value:.4f}, Val F1: {fscore:.4f}, num_anchors: {experiment_num_anchors}\"\n",
    "        )\n",
    "\n",
    "    model = model.cpu().eval()\n",
    "\n",
    "stats = pd.DataFrame(stats)\n",
    "stats_dir = PROJECT_ROOT / \"experiments\" / \"sec:anchor-analysis\"\n",
    "stats_dir.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories\n",
    "stats.to_csv(\n",
    "    stats_dir / f\"{dataset_name}_data_manifold_stats_anchors_analysis.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3d832e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hidden size of each transformer: length of its first absolute training embedding.\n",
    "# NOTE(review): `transformer2train_latents` is not defined by any cell in this notebook (presumably\n",
    "# a removed cell mapping transformer name -> {\"absolute\": ..., \"relative\": ...}), so this cell\n",
    "# raises NameError on a fresh kernel.\n",
    "transformer_name2hidden_dim = {\n",
    "    transformer_name: latents[\"absolute\"][0].shape[0] for transformer_name, latents in transformer2train_latents.items()\n",
    "}\n",
    "transformer_name2hidden_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7951c369",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit one classifier per (seed, embedding type, transformer).\n",
    "# NOTE(review): `fit` and `transformer2train_latents` are not defined by any cell in this notebook;\n",
    "# this cell only runs if they already exist in the kernel.\n",
    "# NOTE(review): the global `num_anchors` is used as the relative hidden size here; confirm it has not\n",
    "# been reassigned since the config cell and equals the anchor count behind the relative latents.\n",
    "SEEDS = list(range(3))\n",
    "train_classifiers = {\n",
    "    seed: {\n",
    "        embedding_type: {\n",
    "            transformer_name: fit(\n",
    "                train_latents[embedding_type],\n",
    "                train_dataset[target_key],\n",
    "                seed=seed,\n",
    "                normalize=latent_normalize,\n",
    "                hidden_dim=num_anchors\n",
    "                if embedding_type == \"relative\"\n",
    "                else transformer_name2hidden_dim[transformer_name],\n",
    "            )\n",
    "            for transformer_name, train_latents in tqdm(\n",
    "                transformer2train_latents.items(), leave=False, desc=\"transformer\"\n",
    "            )\n",
    "        }\n",
    "        for embedding_type in tqdm([\"absolute\", \"relative\"], leave=False, desc=\"embedding_type\")\n",
    "    }\n",
    "    for seed in tqdm(SEEDS, leave=False, desc=\"seed\")\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc23b282",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Nested mapping: seed -> embedding type -> transformer name -> value returned by fit().\n",
    "train_classifiers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca059a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "\n",
    "import pandas as pd\n",
    "from sklearn.metrics import precision_recall_fscore_support, mean_absolute_error\n",
    "\n",
    "# Cross-evaluate every ordered pair of transformers: test latents come from `transformer_name1`\n",
    "# (recorded as train_model) and are fed to `classifier2`, which was fitted on `transformer_name2`'s\n",
    "# training latents (recorded as test_model).\n",
    "numeric_results = {\n",
    "    \"seed\": [],\n",
    "    \"embed_type\": [],\n",
    "    \"train_model\": [],\n",
    "    \"test_model\": [],\n",
    "    \"precision\": [],\n",
    "    \"recall\": [],\n",
    "    \"fscore\": [],\n",
    "    \"stitched\": [],\n",
    "}\n",
    "# Ground-truth test labels are shared by every pair; read the column once.\n",
    "test_y = np.array(test_dataset[target_key])\n",
    "for seed, embed_type2transformer2classifier in train_classifiers.items():\n",
    "    for embed_type, transformer2classifier in embed_type2transformer2classifier.items():\n",
    "        for (transformer_name1, classifier1), (transformer_name2, classifier2) in itertools.product(\n",
    "            transformer2classifier.items(), repeat=2\n",
    "        ):\n",
    "            # Absolute spaces of different widths cannot be fed to each other's classifier.\n",
    "            if embed_type == \"absolute\" and (\n",
    "                transformer_name2hidden_dim[transformer_name1] != transformer_name2hidden_dim[transformer_name2]\n",
    "            ):\n",
    "                precision = recall = fscore = np.nan\n",
    "            else:\n",
    "                test_latents = transformer2test_latents[transformer_name1][embed_type]\n",
    "                if latent_normalize:\n",
    "                    test_latents = F.normalize(test_latents, p=2, dim=-1)\n",
    "                preds = classifier2(test_latents)\n",
    "                precision, recall, fscore, _ = precision_recall_fscore_support(test_y, preds, average=\"weighted\")\n",
    "            numeric_results[\"embed_type\"].append(embed_type)\n",
    "            numeric_results[\"train_model\"].append(transformer_name1)\n",
    "            numeric_results[\"test_model\"].append(transformer_name2)\n",
    "            numeric_results[\"precision\"].append(precision)\n",
    "            numeric_results[\"recall\"].append(recall)\n",
    "            numeric_results[\"fscore\"].append(fscore)\n",
    "            numeric_results[\"stitched\"].append(transformer_name1 != transformer_name2)\n",
    "            numeric_results[\"seed\"].append(seed)\n",
    "\n",
    "pd.options.display.max_columns = None\n",
    "pd.options.display.max_rows = None\n",
    "df = pd.DataFrame(numeric_results)\n",
    "df.to_csv(\n",
    "    f\"vision_transformer-stitching-{dataset_name}-{'fine' if fine_grained else 'coarse'}-{train_perc}.tsv\",\n",
    "    sep=\"\\t\",\n",
    ")\n",
    "# Drop pairs involving the excluded backbones with one combined mask; chaining boolean\n",
    "# indexers re-aligns each mask against an already-filtered frame and emits a UserWarning.\n",
    "excluded_models = (\"regnetx_002\", \"rexnet_100\")\n",
    "df = df[~df.train_model.isin(excluded_models) & ~df.test_model.isin(excluded_models)]\n",
    "df = df.groupby(\n",
    "    [\n",
    "        \"embed_type\",\n",
    "        \"stitched\",\n",
    "        \"train_model\",\n",
    "        \"test_model\",\n",
    "    ]\n",
    ").agg([\"mean\"])  # string alias: passing np.mean is deprecated in recent pandas\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "372200db",
   "metadata": {},
   "outputs": [],
   "source": [
    "stitching_results_path = f\"vision_transformer-stitching-{dataset_name}-{'fine' if fine_grained else 'coarse'}-{train_perc}.tsv\"\n",
    "# index_col=0 restores the index column that to_csv wrote.\n",
    "full_df = pd.read_csv(stitching_results_path, sep=\"\\t\", index_col=0)\n",
    "\n",
    "# Unlike the evaluation cell, no backbones are excluded here.\n",
    "df = full_df.groupby(\n",
    "    [\n",
    "        \"embed_type\",\n",
    "        \"stitched\",\n",
    "        \"train_model\",\n",
    "        \"test_model\",\n",
    "    ]\n",
    ").agg([\"mean\", \"count\"])  # string aliases: passing np.mean is deprecated in recent pandas\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12d3e07f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean F-score per (embed_type, train_model, test_model), averaged over seeds.\n",
    "full_df.drop(columns=[\"stitched\", \"seed\", \"precision\", \"recall\"]).groupby(\n",
    "    [\"embed_type\", \"train_model\", \"test_model\"]\n",
    ").agg([\"mean\"]).round(3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8591d01b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
