{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "\n",
    "import torchvision\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "\n",
    "from random import randint\n",
    "from collections import Counter, defaultdict\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from dino_disentanglement_multi_clevr_dual import FeatureExtractorMulti, FeatureGeneratorMulti\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_images_from_folder(folder):\n",
    "    \"\"\"Load every ``.png`` file in ``folder`` (sorted by filename) into one stacked array.\n",
    "\n",
    "    Returns an array of shape (num_images, *image_shape); np.stack requires every image\n",
    "    to have the same shape. Raises ValueError if the folder contains no ``.png`` files.\n",
    "    \"\"\"\n",
    "    filenames = sorted(f for f in os.listdir(folder) if f.endswith('.png'))\n",
    "    if not filenames:\n",
    "        raise ValueError(f'No .png files found in {folder!r}')\n",
    "    images = []\n",
    "    for filename in filenames:\n",
    "        # The context manager closes the file handle; np.array forces the pixel data to load first.\n",
    "        with Image.open(os.path.join(folder, filename)) as img:\n",
    "            images.append(np.array(img))\n",
    "    return np.stack(images)\n",
    "\n",
    "def load_json(path):\n",
    "    \"\"\"Read the JSON document stored at ``path`` and return the decoded Python object.\"\"\"\n",
    "    with open(path, 'r') as json_file:\n",
    "        return json.load(json_file)\n",
    "\n",
    "# Load a pickle file while showing real read progress.\n",
    "def load_with_progress(filename, description):\n",
    "    \"\"\"Unpickle ``filename``, showing a tqdm progress bar labelled ``description``.\n",
    "\n",
    "    The file is read exactly once, in fixed-size chunks, and the bar advances by the number\n",
    "    of bytes read, so it always ends at the file size (no artificial delay, no second read).\n",
    "\n",
    "    WARNING: unpickling can execute arbitrary code; only load files you trust.\n",
    "    \"\"\"\n",
    "    chunk_size = 1 << 20  # 1 MiB per read\n",
    "    buffer = bytearray()\n",
    "    progress_bar = tqdm(total=os.path.getsize(filename), desc=description, unit='B', unit_scale=True)\n",
    "    try:\n",
    "        with open(filename, 'rb') as f:\n",
    "            while True:\n",
    "                chunk = f.read(chunk_size)\n",
    "                if not chunk:\n",
    "                    break\n",
    "                buffer.extend(chunk)\n",
    "                progress_bar.update(len(chunk))\n",
    "    finally:\n",
    "        # Close the bar even if reading fails, so it does not linger in the notebook output.\n",
    "        progress_bar.close()\n",
    "    return pickle.loads(buffer)\n",
    "\n",
    "class Flatten(nn.Module):\n",
    "    \"\"\"Flatten every dimension except the leading batch dimension.\"\"\"\n",
    "\n",
    "    def forward(self, input):\n",
    "        return input.reshape(input.size(0), -1)\n",
    "\n",
    "\n",
    "class UnFlatten(nn.Module):\n",
    "    \"\"\"Reshape a flat (batch, size) tensor to (batch, size, 1, 1) for the transposed-conv decoder.\n",
    "\n",
    "    ``size`` defaults to 1024. nn.Sequential passes only the tensor, so the VAE decoder always\n",
    "    uses this default, i.e. it assumes h_dim == 1024.\n",
    "    \"\"\"\n",
    "\n",
    "    def forward(self, input, size=1024):\n",
    "        return input.reshape(input.size(0), size, 1, 1)\n",
    "\n",
    "\n",
    "class VAE(nn.Module):\n",
    "    \"\"\"Convolutional variational autoencoder.\n",
    "\n",
    "    Encoder: four stride-2, 4x4 convolutions; for a 64x64 input they produce 256 x 2 x 2 = 1024\n",
    "    features, which must equal ``h_dim``. Decoder: ``fc3`` followed by four transposed\n",
    "    convolutions mapping (batch, h_dim, 1, 1) back to (batch, image_channels, 64, 64), with a\n",
    "    final Sigmoid.\n",
    "\n",
    "    Args:\n",
    "        image_channels: number of input/output image channels.\n",
    "        h_dim: size of the flattened encoder output (1024 for 64x64 inputs).\n",
    "        z_dim: size of the latent code.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image_channels=3, h_dim=1024, z_dim=64):\n",
    "        super().__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(32, 64, kernel_size=4, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(64, 128, kernel_size=4, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(128, 256, kernel_size=4, stride=2),\n",
    "            nn.ReLU(),\n",
    "            Flatten(),\n",
    "        )\n",
    "\n",
    "        self.fc1 = nn.Linear(h_dim, z_dim)  # encoder features -> mu\n",
    "        self.fc2 = nn.Linear(h_dim, z_dim)  # encoder features -> logvar\n",
    "        self.fc3 = nn.Linear(z_dim, h_dim)  # latent -> decoder input\n",
    "\n",
    "        self.decoder = nn.Sequential(\n",
    "            UnFlatten(),\n",
    "            nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),\n",
    "            nn.ReLU(),\n",
    "            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),\n",
    "            nn.Sigmoid(),\n",
    "        )\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        \"\"\"Sample z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I).\n",
    "\n",
    "        eps comes from ``torch.randn_like``, so it lives on the same device and has the same\n",
    "        dtype as the inputs; no device is hard-coded and the model also runs on CPU.\n",
    "        \"\"\"\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + std * eps\n",
    "\n",
    "    def bottleneck(self, h):\n",
    "        \"\"\"Project encoder features ``h`` to (z, mu, logvar).\"\"\"\n",
    "        mu, logvar = self.fc1(h), self.fc2(h)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        return z, mu, logvar\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Encode images ``x`` to (z, mu, logvar).\"\"\"\n",
    "        h = self.encoder(x)\n",
    "        return self.bottleneck(h)\n",
    "\n",
    "    def decode(self, z):\n",
    "        \"\"\"Decode latent codes ``z`` to images (Sigmoid-activated).\"\"\"\n",
    "        return self.decoder(self.fc3(z))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return (reconstruction, mu, logvar) for images ``x``.\"\"\"\n",
    "        z, mu, logvar = self.encode(x)\n",
    "        return self.decode(z), mu, logvar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import pickle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "\n",
    "#model = VAE(image_channels=3).to('cuda')\n",
    "#model.load_state_dict(torch.load('../model_outputs/clevrtex2/vae_64_clevrtex_100.torch', map_location='cuda'))\n",
    "\n",
    "# DINOv2 backbone. This loads the ViT-S/14 ('small') variant; the name dinov2_vitl14 is kept\n",
    "# only as an alias for compatibility with code that already refers to it.\n",
    "dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').cuda()\n",
    "dinov2_vitl14 = dinov2_vits14\n",
    "\n",
    "# Standard ImageNet channel statistics.\n",
    "IMAGENET_MEAN = [0.485, 0.456, 0.406]\n",
    "IMAGENET_STD = [0.229, 0.224, 0.225]\n",
    "\n",
    "# NOTE: transforms.Normalize accepts tensors only, so the pipelines that normalise expect\n",
    "# tensor input (e.g. after transforms.ToTensor()), not PIL images.\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize(256),\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),\n",
    "])\n",
    "\n",
    "# Crop size should be a multiple of the model patch size: 518 = 37 * 14.\n",
    "transform1 = transforms.Compose([\n",
    "    transforms.Resize(520),\n",
    "    transforms.CenterCrop(518),\n",
    "    transforms.Normalize(mean=0.5, std=0.2),\n",
    "])\n",
    "\n",
    "# Same geometry as transform1, without normalisation.\n",
    "transform2 = transforms.Compose([\n",
    "    transforms.Resize(520),\n",
    "    transforms.CenterCrop(518),\n",
    "])\n",
    "\n",
    "transform_segmented = transforms.Compose([\n",
    "    transforms.Resize((64, 64)),\n",
    "])\n",
    "\n",
    "transform_1515 = transforms.Compose([\n",
    "    transforms.Resize((15, 15)),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data: single-object scenes are used as queries, multi-object scenes as candidates.\n",
    "# NOTE: these are pickle files -- only load trusted data (unpickling can run arbitrary code).\n",
    "# Alternative dataset (5000 scenes), presumably in a clevr15_50/ subfolder of EMBEDDINGS_DIR:\n",
    "#   {multi,single}_imagef_clevrtex_{data,metadata}_5000_sq1550_2_384.pickle\n",
    "EMBEDDINGS_DIR = 'pickled_embeddings'\n",
    "\n",
    "full_image_data = load_with_progress(os.path.join(EMBEDDINGS_DIR, 'multi6_imagef_clevrtex_data_500_384.pickle'), 'Loading full_image_data')\n",
    "full_image_metadata = load_with_progress(os.path.join(EMBEDDINGS_DIR, 'multi6_imagef_clevrtex_metadata_500_384.pickle'), 'Loading full_image_metadata')\n",
    "single_image_data = load_with_progress(os.path.join(EMBEDDINGS_DIR, 'single6_imagef_clevrtex_data_500_384.pickle'), 'Loading single_image_data')\n",
    "single_image_metadata = load_with_progress(os.path.join(EMBEDDINGS_DIR, 'single6_imagef_clevrtex_metadata_500_384.pickle'), 'Loading single_image_metadata')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "m_size = 500  # number of multi-object scenes (data and metadata) to keep\n",
    "s_size = 389  # number of single-object metadata entries to keep\n",
    "# NOTE(review): s_size truncates only the single-object *metadata*; the single-object arrays\n",
    "# below are not truncated -- presumably s_size matches their length. Confirm.\n",
    "\n",
    "# Multi-object (candidate) features\n",
    "full_images_m = full_image_data['full_images'][:m_size]\n",
    "large_images_m = full_image_data['large_images'][:m_size]\n",
    "square_regions_m = full_image_data['square_regions'][:m_size]\n",
    "dino_full_atts_m = full_image_data['dino_full_atts'][:m_size]\n",
    "dino_projection_atts_m = full_image_data['dino_projection_atts'][:m_size]\n",
    "projections_m = full_image_data['projections'][:m_size]\n",
    "\n",
    "# Single-object (query) features\n",
    "full_images_s = single_image_data['full_images']\n",
    "large_images_s = single_image_data['large_images']\n",
    "square_regions_s = single_image_data['square_regions']\n",
    "dino_full_atts_s = single_image_data['dino_full_atts']\n",
    "dino_projection_atts_s = single_image_data['dino_projection_atts']\n",
    "projections_s = single_image_data['projections']\n",
    "\n",
    "dino_full_atts_m = np.stack(dino_full_atts_m)\n",
    "projections_m = np.asarray(projections_m)\n",
    "square_regions_m = np.asarray(square_regions_m)\n",
    "dino_projection_atts_m = np.asarray(dino_projection_atts_m)\n",
    "\n",
    "dino_full_atts_s = np.stack(dino_full_atts_s)\n",
    "projections_s = np.asarray(projections_s)\n",
    "square_regions_s = np.asarray(square_regions_s)\n",
    "dino_projection_atts_s = np.asarray(dino_projection_atts_s)\n",
    "\n",
    "# Metadata\n",
    "full_images_m = np.stack(full_images_m)\n",
    "color_items_m = full_image_metadata['color'][:m_size]\n",
    "shape_items_m = full_image_metadata['shape'][:m_size]\n",
    "material_items_m = full_image_metadata['material'][:m_size]\n",
    "size_items_m = full_image_metadata['size'][:m_size]\n",
    "threedcoords_items_m = full_image_metadata['3d_coords'][:m_size]\n",
    "numobjects_items_m = full_image_metadata['num_objects'][:m_size]\n",
    "\n",
    "full_images_s = np.stack(full_images_s)\n",
    "color_items_s = single_image_metadata['color'][:s_size]\n",
    "shape_items_s = single_image_metadata['shape'][:s_size]\n",
    "size_items_s = single_image_metadata['size'][:s_size]\n",
    "material_items_s = single_image_metadata['material'][:s_size]\n",
    "threedcoords_items_s = single_image_metadata['3d_coords'][:s_size]\n",
    "numobjects_items_s = single_image_metadata['num_objects'][:s_size]\n",
    "\n",
    "# Shape sanity checks\n",
    "print(projections_m.shape)\n",
    "print(\"PS: \", projections_s[0].shape)\n",
    "print(\"PM: \", projections_m[0].shape)\n",
    "print(\"COLOR_ITEM_SHAPE: \", color_items_m.shape)\n",
    "print(\"DINO_FULL_ATTS_S: \", dino_full_atts_s.shape)\n",
    "print(\"COLOR_ITEM_SHAPE_S: \", color_items_s.shape)\n",
    "print(color_items_m[1])\n",
    "print(color_items_s[:5])  # first few entries only, instead of dumping the whole array\n",
    "print(full_images_m.shape)\n",
    "print(projections_s.shape)\n",
    "\n",
    "plt.imshow(full_images_s[1])\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spot-check one single-object query scene and its colour annotation.\n",
    "index = 300\n",
    "plt.figure()\n",
    "plt.imshow(full_images_s[index])\n",
    "plt.show()\n",
    "print(color_items_s[index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Similarity functions\n",
    "\n",
    "def dino_comparison(ref_dino_full_atts, dino_full_atts_m):\n",
    "    \"\"\"Rank candidates by cosine similarity to a reference DINO embedding.\n",
    "\n",
    "    The reference and every candidate are flattened to vectors (they must all have the same\n",
    "    number of elements) and compared in a single vectorised cosine_similarity call.\n",
    "\n",
    "    Returns:\n",
    "        sorted_indices: candidate indices ordered from most to least similar.\n",
    "        similarities: 1-D array of similarities in the original candidate order.\n",
    "        Both are empty when there are no candidates.\n",
    "    \"\"\"\n",
    "    if len(dino_full_atts_m) == 0:\n",
    "        return np.zeros(0, dtype=np.intp), np.zeros(0)\n",
    "    ref_vector = np.asarray(ref_dino_full_atts).reshape(1, -1)\n",
    "    candidate_matrix = np.stack([np.asarray(att).reshape(-1) for att in dino_full_atts_m])\n",
    "    similarities = cosine_similarity(ref_vector, candidate_matrix)[0]\n",
    "    sorted_indices = np.argsort(similarities)[::-1]  # descending similarity\n",
    "    return sorted_indices, similarities\n",
    "\n",
    "\n",
    "def vae_dino_comparison_gpu2(ref_projections_dino_concat, projections_m, dino_full_atts_m, vae_influence, topk=1, device='cuda'):\n",
    "    \"\"\"Rank candidate scenes by slot-level similarity to reference [DINO | projection] rows.\n",
    "\n",
    "    For each candidate, every slot projection is concatenated after the candidate's DINO\n",
    "    embedding (the projection part scaled by ``vae_influence``). Each reference row is scored by\n",
    "    the mean of its ``topk`` highest cosine similarities to those rows, and the candidate's score\n",
    "    is the mean over reference rows.\n",
    "\n",
    "    Args:\n",
    "        ref_projections_dino_concat: 2-D (R, D) reference rows.\n",
    "        projections_m: per candidate, slot projections; ``squeeze(1)`` must give (S, D_proj).\n",
    "        dino_full_atts_m: per candidate, a DINO embedding that ``repeat(S, 1)`` turns into\n",
    "            (S, D_dino), with D_dino + D_proj == D.\n",
    "        vae_influence: weight applied to the projection part of each row.\n",
    "        topk: number of best-matching slots averaged per reference row (default 1: best match only).\n",
    "        device: torch device used for the computation.\n",
    "\n",
    "    Returns:\n",
    "        (sorted_indices, sorted_similarities): candidate indices by descending score, and the scores.\n",
    "    \"\"\"\n",
    "    ref_rows = torch.as_tensor(ref_projections_dino_concat, device=device)\n",
    "    scores = []\n",
    "    for projection_m, dino_att in zip(projections_m, dino_full_atts_m):\n",
    "        projection = torch.as_tensor(projection_m, device=device)\n",
    "        dino = torch.as_tensor(dino_att, device=device)\n",
    "        slot_rows = torch.cat((dino.repeat(projection.shape[0], 1), projection.squeeze(1) * vae_influence), dim=1)\n",
    "        # (R, S) cosine similarities between reference rows and slot rows\n",
    "        cosines = F.cosine_similarity(ref_rows.unsqueeze(1), slot_rows.unsqueeze(0), dim=2)\n",
    "        best = torch.sort(cosines, dim=1, descending=True).values[:, :topk].mean(dim=1)\n",
    "        scores.append(best.mean().item())\n",
    "\n",
    "    scores = np.array(scores)\n",
    "    sorted_indices = np.argsort(scores)[::-1]\n",
    "    return sorted_indices, scores[sorted_indices]\n",
    "\n",
    "\n",
    "def vae_comparison3_opt(ref_slot, slots, topk=1, device='cuda'):\n",
    "    \"\"\"Rank candidate slot sets by similarity to a reference slot set.\n",
    "\n",
    "    Each reference slot is scored by the mean of its ``topk`` highest cosine similarities to a\n",
    "    candidate's slots; the candidate's score is the mean over reference slots.\n",
    "\n",
    "    Args:\n",
    "        ref_slot: 2-D (R, D) reference slot vectors.\n",
    "        slots: sequence of 2-D (S, D) candidate slot sets.\n",
    "        topk: number of best-matching slots averaged per reference slot (default 1: best match only).\n",
    "        device: torch device used for the computation.\n",
    "\n",
    "    Returns:\n",
    "        (sorted_indices, sorted_similarities): candidate indices by descending score, and the scores.\n",
    "    \"\"\"\n",
    "    ref_rows = torch.as_tensor(ref_slot, device=device)\n",
    "    scores = []\n",
    "    for slot in slots:\n",
    "        slot_rows = torch.as_tensor(slot, device=device)\n",
    "        cosines = F.cosine_similarity(ref_rows.unsqueeze(1), slot_rows.unsqueeze(0), dim=2)  # (R, S)\n",
    "        best = torch.sort(cosines, dim=1, descending=True).values[:, :topk].mean(dim=1)\n",
    "        scores.append(best.mean().item())\n",
    "\n",
    "    scores = np.array(scores)\n",
    "    sorted_indices = np.argsort(scores)[::-1]\n",
    "    return sorted_indices, scores[sorted_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample data for evaluation\n",
    "\n",
    "# Query data: single-object-set scenes whose colour list has exactly one entry.\n",
    "filtered_1_indeces = [i for i, color_item in enumerate(color_items_s) if len(color_item) == 1]\n",
    "(filtered_1_images, filtered_1_dino_full_atts, filtered_1_projections,\n",
    " filtered_1_color_items, filtered_1_shape_items, filtered_1_size_items,\n",
    " filtered_1_material_items, filtered_1_threedcoords_items, filtered_1_numobjects_items) = (\n",
    "    source[filtered_1_indeces]\n",
    "    for source in (full_images_s, dino_full_atts_s, projections_s,\n",
    "                   color_items_s, shape_items_s, size_items_s,\n",
    "                   material_items_s, threedcoords_items_s, numobjects_items_s)\n",
    ")\n",
    "\n",
    "# Candidate data: multi-object-set scenes whose colour list does not have exactly one entry.\n",
    "filtered_1_remaining_indeces = [i for i, color_item in enumerate(color_items_m) if len(color_item) != 1]\n",
    "(filtered_1_remaining_images, filtered_1_remaining_dino_full_atts, filtered_1_remaining_projections,\n",
    " filtered_1_remaining_color_items, filtered_1_remaining_shape_items, filtered_1_remaining_size_items,\n",
    " filtered_1_remaining_material_items, filtered_1_remaining_threedcoords_items,\n",
    " filtered_1_remaining_numobjects_items) = (\n",
    "    source[filtered_1_remaining_indeces]\n",
    "    for source in (full_images_m, dino_full_atts_m, projections_m,\n",
    "                   color_items_m, shape_items_m, size_items_m,\n",
    "                   material_items_m, threedcoords_items_m, numobjects_items_m)\n",
    ")\n",
    "\n",
    "print(filtered_1_images.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the query batches\n",
    "def get_ref_data(data_in, k=3, n=100, seed=42):\n",
    "    \"\"\"Draw ``k`` reproducible random batches of ``n`` queries from parallel per-sample data.\n",
    "\n",
    "    Args:\n",
    "        data_in: sequence whose first 10 items are parallel per-sample collections, in the order\n",
    "            (indices, images, dino_full_atts, projections, color_items, shape_items,\n",
    "             size_items, material_items, threedcoords_items, numobjects_items).\n",
    "            indices/images may be lists; the other items must support NumPy fancy indexing.\n",
    "        k: number of batches.\n",
    "        n: batch size. Samples are drawn without replacement, so n must not exceed the number\n",
    "            of samples (np.random.choice raises ValueError otherwise).\n",
    "        seed: batch ``i`` is drawn right after ``np.random.seed(seed + i)``. This reseeds NumPy's\n",
    "            global RNG, so later unseeded np.random calls stay reproducible as well.\n",
    "\n",
    "    Returns:\n",
    "        List of ``k`` 10-tuples in the same field order as ``data_in``.\n",
    "    \"\"\"\n",
    "    filtered_indices = np.array(data_in[0])\n",
    "    filtered_images = np.array(data_in[1])\n",
    "    indexable_fields = data_in[2:10]\n",
    "    # Sample from this dataset's own length (color_items, field 4), not from a notebook global.\n",
    "    num_samples = len(data_in[4])\n",
    "\n",
    "    ref_data = []\n",
    "    for i in range(k):\n",
    "        np.random.seed(seed + i)\n",
    "        random_indices = np.random.choice(num_samples, n, replace=False)\n",
    "        batch = (filtered_indices[random_indices], filtered_images[random_indices])\n",
    "        batch += tuple(field[random_indices] for field in indexable_fields)\n",
    "        ref_data.append(batch)\n",
    "\n",
    "    return ref_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query-set fields in the order get_ref_data expects; draw 7 batches of 50 queries each.\n",
    "data_in = [\n",
    "    filtered_1_indeces,\n",
    "    filtered_1_images,\n",
    "    filtered_1_dino_full_atts,\n",
    "    filtered_1_projections,\n",
    "    filtered_1_color_items,\n",
    "    filtered_1_shape_items,\n",
    "    filtered_1_size_items,\n",
    "    filtered_1_material_items,\n",
    "    filtered_1_threedcoords_items,\n",
    "    filtered_1_numobjects_items,\n",
    "]\n",
    "\n",
    "ref_data = get_ref_data(data_in, k=7, n=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extract candidates which match the query closely. Record number of matches for normalisation.\n",
    "\n",
    "def _objects_with_coords(colors, shapes, materials, sizes, coords):\n",
    "    \"\"\"Describe one scene as (Counter of object tuples, {object tuple: coordinates}).\n",
    "\n",
    "    Objects are keyed by (color, shape, material, size); ``coords[k]`` belongs to object k.\n",
    "    If several objects share a key, the coordinates of the last one are kept.\n",
    "    \"\"\"\n",
    "    objects = [tuple(obj) for obj in zip(colors, shapes, materials, sizes)]\n",
    "    coords_by_object = {obj: coords[k] for k, obj in enumerate(objects)}\n",
    "    return Counter(objects), coords_by_object\n",
    "\n",
    "\n",
    "def extract_examples_with_matches(ref_filtered_indices,\n",
    "                                ref_filtered_images,\n",
    "                                ref_filtered_dino_full_atts,\n",
    "                                ref_filtered_projections,\n",
    "                                ref_filtered_color_items,\n",
    "                                ref_filtered_shape_items,\n",
    "                                ref_filtered_size_items,\n",
    "                                ref_filtered_material_items,\n",
    "                                ref_filtered_threedcoords_items,\n",
    "                                ref_filtered_numobjects_items,\n",
    "                                filtered_remaining_color_items,\n",
    "                                filtered_remaining_shape_items,\n",
    "                                filtered_remaining_size_items,\n",
    "                                filtered_remaining_material_items,\n",
    "                                filtered_remaining_threedcoords_items,\n",
    "                                filtered_remaining_numobjects_items,\n",
    "                                filtered_remaining_images,\n",
    "                                matching_exp=\"1<=x<=5\",\n",
    "                                max_distance=1.5,\n",
    "                                verbose=False,\n",
    "                                plot=False):\n",
    "    \"\"\"Find, for every query scene, the candidate scenes that share nearby matching objects.\n",
    "\n",
    "    An object type (color, shape, material, size) of query ``i`` matches candidate ``j`` when\n",
    "    both scenes contain it the same number of times and the two objects' coordinates are less\n",
    "    than ``max_distance`` apart. Candidate ``j`` is kept for query ``i`` when the number of\n",
    "    matched object types satisfies ``matching_exp``.\n",
    "\n",
    "    Args:\n",
    "        ref_filtered_*: parallel per-query collections. indices/images may be lists; the\n",
    "            others must support NumPy fancy indexing.\n",
    "        filtered_remaining_*: parallel per-candidate collections. For each scene, the\n",
    "            colour/shape/size/material lists and the coordinate list are aligned per object.\n",
    "            filtered_remaining_numobjects_items is accepted but not used.\n",
    "        filtered_remaining_images: only used when ``plot`` is True.\n",
    "        matching_exp: inclusive bounds on the number of matched objects; one of\n",
    "            1<=x<=3, 2<=x<=3, x==3, 1<=x<=5, 2<=x<=5, 3<=x<=5, 4<=x<=5, x==5.\n",
    "        max_distance: strict upper bound on the coordinate distance for an object match.\n",
    "        verbose: print per-pair debugging details and the full match dictionaries.\n",
    "        plot: show the first kept query and its first matching candidate.\n",
    "\n",
    "    Returns:\n",
    "        (the 10 ref_filtered_* inputs restricted to queries with at least one match, in the\n",
    "         input order; list of match counts per kept query; dict {query index: [candidate indices]}).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``matching_exp`` is not one of the supported rules.\n",
    "    \"\"\"\n",
    "    # Inclusive (min, max) number of matched objects for each supported rule.\n",
    "    match_bounds = {\n",
    "        \"1<=x<=3\": (1, 3),\n",
    "        \"2<=x<=3\": (2, 3),\n",
    "        \"x==3\": (3, 3),\n",
    "        \"1<=x<=5\": (1, 5),\n",
    "        \"2<=x<=5\": (2, 5),\n",
    "        \"3<=x<=5\": (3, 5),\n",
    "        \"4<=x<=5\": (4, 5),\n",
    "        \"x==5\": (5, 5),\n",
    "    }\n",
    "    if matching_exp not in match_bounds:\n",
    "        raise ValueError(f\"Unknown matching_exp {matching_exp!r}; expected one of {list(match_bounds)}\")\n",
    "    min_matches, max_matches = match_bounds[matching_exp]\n",
    "\n",
    "    # Candidate descriptions do not depend on the query, so build them once.\n",
    "    candidates = [\n",
    "        _objects_with_coords(color_r, shape_r, material_r, size_r, filtered_remaining_threedcoords_items[j])\n",
    "        for j, (color_r, shape_r, size_r, material_r) in enumerate(zip(\n",
    "            filtered_remaining_color_items, filtered_remaining_shape_items,\n",
    "            filtered_remaining_size_items, filtered_remaining_material_items))\n",
    "    ]\n",
    "\n",
    "    matches = 0\n",
    "    matched_items = defaultdict(list)      # query index -> indices of matching candidates\n",
    "    num_matching_items = defaultdict(int)  # query index -> number of matching candidates\n",
    "    for i, (color_item, shape_item, size_item, material_item) in enumerate(zip(\n",
    "            ref_filtered_color_items, ref_filtered_shape_items,\n",
    "            ref_filtered_size_items, ref_filtered_material_items)):\n",
    "        # Each query object uses its own coordinates (index k), not always those of object 0.\n",
    "        ref_counts, ref_coords = _objects_with_coords(\n",
    "            color_item, shape_item, material_item, size_item, ref_filtered_threedcoords_items[i])\n",
    "\n",
    "        for j, (cand_counts, cand_coords) in enumerate(candidates):\n",
    "            num_matches = 0\n",
    "            for ref_object, ref_count in ref_counts.items():\n",
    "                if ref_object not in cand_counts:\n",
    "                    continue\n",
    "                distance = np.linalg.norm(np.asarray(ref_coords[ref_object]) - np.asarray(cand_coords[ref_object]))\n",
    "                if verbose:\n",
    "                    print(distance)\n",
    "                if ref_count == cand_counts[ref_object] and distance < max_distance:\n",
    "                    if verbose:\n",
    "                        print(\"REF_OBJECT: \", ref_object)\n",
    "                        print(\"ELEMENT_COUNTS: \", ref_counts)\n",
    "                        print(\"REMAINING_EC: \", cand_counts)\n",
    "                        print(\"THREED_REF: \", ref_coords[ref_object])\n",
    "                        print(\"THREED: \", cand_coords[ref_object])\n",
    "                    num_matches += 1\n",
    "\n",
    "            if min_matches <= num_matches <= max_matches:\n",
    "                matched_items[i].append(j)\n",
    "                num_matching_items[i] += 1\n",
    "                matches += 1\n",
    "\n",
    "    print(\"MATCHES: \", matches / max(len(ref_filtered_color_items), 1))\n",
    "    print(\"LENGTH: \", len(matched_items))\n",
    "    if verbose:\n",
    "        print(\"MATCHED_ITEMS: \", matched_items)\n",
    "        print(\"KEYS: \", matched_items.keys())\n",
    "        print(num_matching_items)\n",
    "        print(list(num_matching_items.values()))\n",
    "\n",
    "    # Keep only queries with at least one match (keys are in ascending query order).\n",
    "    new_keys = list(matched_items.keys())\n",
    "    ref_filtered_indices = np.array(ref_filtered_indices)[new_keys]\n",
    "    ref_filtered_images = np.array(ref_filtered_images)[new_keys]\n",
    "    ref_filtered_dino_full_atts = ref_filtered_dino_full_atts[new_keys]\n",
    "    ref_filtered_projections = ref_filtered_projections[new_keys]\n",
    "    ref_filtered_color_items = ref_filtered_color_items[new_keys]\n",
    "    ref_filtered_shape_items = ref_filtered_shape_items[new_keys]\n",
    "    ref_filtered_size_items = ref_filtered_size_items[new_keys]\n",
    "    ref_filtered_material_items = ref_filtered_material_items[new_keys]\n",
    "    ref_filtered_threedcoords_items = ref_filtered_threedcoords_items[new_keys]\n",
    "    ref_filtered_numobjects_items = ref_filtered_numobjects_items[new_keys]\n",
    "\n",
    "    if plot and new_keys:\n",
    "        plt.imshow(ref_filtered_images[0])\n",
    "        plt.show()\n",
    "        plt.imshow(filtered_remaining_images[matched_items[new_keys[0]][0]])\n",
    "        plt.show()\n",
    "\n",
    "    return (ref_filtered_indices,\n",
    "            ref_filtered_images,\n",
    "            ref_filtered_dino_full_atts,\n",
    "            ref_filtered_projections,\n",
    "            ref_filtered_color_items,\n",
    "            ref_filtered_shape_items,\n",
    "            ref_filtered_size_items,\n",
    "            ref_filtered_material_items,\n",
    "            ref_filtered_threedcoords_items,\n",
    "            ref_filtered_numobjects_items), list(num_matching_items.values()), matched_items\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matching_exp = \"1<=x<=3\"\n",
    "ref_data_matched = []\n",
    "samples_lists = []\n",
    "matched_items_lists = []\n",
    "\n",
    "# Every query batch is matched against the same pool of candidate scenes.\n",
    "candidate_args = (filtered_1_remaining_color_items,\n",
    "                  filtered_1_remaining_shape_items,\n",
    "                  filtered_1_remaining_size_items,\n",
    "                  filtered_1_remaining_material_items,\n",
    "                  filtered_1_remaining_threedcoords_items,\n",
    "                  filtered_1_remaining_numobjects_items,\n",
    "                  filtered_1_remaining_images)\n",
    "\n",
    "for data in ref_data:\n",
    "    (ref_filtered_indices, ref_filtered_images, ref_filtered_dino_full_atts,\n",
    "     ref_filtered_projections, ref_filtered_color_items, ref_filtered_shape_items,\n",
    "     ref_filtered_size_items, ref_filtered_material_items,\n",
    "     ref_filtered_threedcoords_items, ref_filtered_numobjects_items) = data\n",
    "\n",
    "    ref_data_out, sample_list, matched_items = extract_examples_with_matches(\n",
    "        ref_filtered_indices, ref_filtered_images, ref_filtered_dino_full_atts,\n",
    "        ref_filtered_projections, ref_filtered_color_items, ref_filtered_shape_items,\n",
    "        ref_filtered_size_items, ref_filtered_material_items,\n",
    "        ref_filtered_threedcoords_items, ref_filtered_numobjects_items,\n",
    "        *candidate_args,\n",
    "        matching_exp=matching_exp)\n",
    "\n",
    "    ref_data_matched.append(ref_data_out)\n",
    "    samples_lists.append(sample_list)\n",
    "    matched_items_lists.append(matched_items)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cap every per-query match count at 10; smaller counts pass through unchanged.\n",
    "samples_lists = [np.minimum(counts, 10) for counts in samples_lists]\n",
    "print(samples_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.metrics import structural_similarity as ssim\n",
    "\n",
    "# Attribute subsets scored in the ablation (S=size, D=shape, M=material, C=color):\n",
    "# every ordered selection of 1-3 attributes (4 + 12 + 24 = 40 keys). Listed explicitly\n",
    "# so the key order of the returned/printed ablation dict stays the same.\n",
    "_ABL_SUBSETS = (\n",
    "    \"S\", \"D\", \"M\", \"C\",\n",
    "    \"SD\", \"DS\", \"SM\", \"MS\", \"SC\", \"CS\", \"DM\", \"MD\", \"DC\", \"CD\", \"MC\", \"CM\",\n",
    "    \"SDM\", \"SMD\", \"DSM\", \"DMS\", \"MSD\", \"MDS\",\n",
    "    \"SDC\", \"SCD\", \"DSC\", \"DCS\", \"CSD\", \"CDS\",\n",
    "    \"SMC\", \"SCM\", \"MSC\", \"MCS\", \"CSM\", \"CMS\",\n",
    "    \"DMC\", \"DCM\", \"MDC\", \"MCD\", \"CMD\", \"CDM\",\n",
    ")\n",
    "\n",
    "# Supported values of `matching_exp` -> predicate on a match count.\n",
    "_MATCHING_RULES = {\n",
    "    \"1<=x<=3\": lambda n: 1 <= n <= 3,\n",
    "    \"2<=x<=3\": lambda n: 2 <= n <= 3,\n",
    "    \"x==3\": lambda n: n == 3,\n",
    "}\n",
    "\n",
    "\n",
    "def _subset_counts(color_items, shape_items, material_items, size_items):\n",
    "    \"\"\"Count the per-object attribute tuples of one scene for every ablation subset.\n",
    "\n",
    "    Returns a dict mapping each key of _ABL_SUBSETS (e.g. \"SD\") to a Counter over\n",
    "    tuples built from those attributes in that order, e.g. (size, shape).\n",
    "    \"\"\"\n",
    "    per_object = [\n",
    "        {\"S\": size_items[j], \"D\": shape_items[j], \"M\": material_items[j], \"C\": color_items[j]}\n",
    "        for j in range(len(color_items))\n",
    "    ]\n",
    "    return {\n",
    "        subset: Counter(tuple(obj[letter] for letter in subset) for obj in per_object)\n",
    "        for subset in _ABL_SUBSETS\n",
    "    }\n",
    "\n",
    "\n",
    "def _shared_count(ref_counts, cand_counts):\n",
    "    \"\"\"Sum the query multiplicities of the keys that also occur in the candidate.\"\"\"\n",
    "    return sum(count for key, count in ref_counts.items() if count > 0 and cand_counts.get(key, 0) > 0)\n",
    "\n",
    "\n",
    "def vae_dino_comparison_multi_abl(ref_object_indeces,\n",
    "                                  ref_full_images_m,\n",
    "                                  ref_dino_full_atts_m,\n",
    "                                  ref_projections_m,\n",
    "                                  remaining_images_m,\n",
    "                                  remaining_projections_m,\n",
    "                                  remaining_dino_full_atts_m,\n",
    "                                  ref_color_items_m,\n",
    "                                  ref_shape_items_m,\n",
    "                                  ref_size_items_m,\n",
    "                                  ref_material_items_m,\n",
    "                                  ref_threedcoords_items_m,\n",
    "                                  ref_numobjects_items_m,\n",
    "                                  remaining_color_items_m,\n",
    "                                  remaining_shape_items_m,\n",
    "                                  remaining_size_items_m,\n",
    "                                  remaining_material_items_m,\n",
    "                                  remaining_threedcoords_items_m,\n",
    "                                  remaining_numobjects_items_m,\n",
    "                                  matching_exp=\"1<=x<=3\",\n",
    "                                  sample_list=[10],\n",
    "                                  verbose=True):\n",
    "    \"\"\"Retrieval precision of the combined VAE+DINO ranking, with attribute ablations.\n",
    "\n",
    "    For every query image, the candidates in the `remaining_*` pool are ranked with\n",
    "    `vae_dino_comparison_gpu2` (defined elsewhere in this notebook) and the top\n",
    "    `num_samples` (10) are scored. A candidate counts as a match when the number of\n",
    "    query objects whose (color, shape, material, size) tuple also occurs in the\n",
    "    candidate satisfies `matching_exp`. The ablation repeats the same test using only\n",
    "    the attribute subsets listed in _ABL_SUBSETS.\n",
    "\n",
    "    Args:\n",
    "        ref_object_indeces: query entries; only its length is used here.\n",
    "        ref_dino_full_atts_m, ref_projections_m: per-query DINO features and VAE projections.\n",
    "        remaining_projections_m, remaining_dino_full_atts_m: candidate features, passed\n",
    "            to vae_dino_comparison_gpu2.\n",
    "        ref_*_items_m, remaining_*_items_m: per-image sequences of per-object attributes.\n",
    "        matching_exp: one of \"1<=x<=3\", \"2<=x<=3\", \"x==3\".\n",
    "        ref_full_images_m, remaining_images_m, ref_numobjects_items_m,\n",
    "        remaining_numobjects_items_m, sample_list: unused here; kept for call compatibility.\n",
    "        verbose: print per-image / per-candidate diagnostics (default True, as before).\n",
    "\n",
    "    Returns:\n",
    "        (overall_precision, overall_w_precision, overall_precision_abl_dict,\n",
    "         overall_threedcoords_distance, Counter(overall_precision), matches_dict):\n",
    "        per-query precision and rank-weighted precision lists, per-subset precision\n",
    "        lists, a one-element list with the averaged matched-object distance, a histogram\n",
    "        of per-query precisions, and a histogram of per-candidate full-tuple match counts.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `matching_exp` is not one of the supported rules.\n",
    "    \"\"\"\n",
    "    if matching_exp not in _MATCHING_RULES:\n",
    "        raise ValueError(f\"Unsupported matching_exp: {matching_exp!r}\")\n",
    "    is_match = _MATCHING_RULES[matching_exp]\n",
    "\n",
    "    vae_influence = 40  # scale applied to the VAE projections before concatenation\n",
    "    num_samples = 10    # top-ranked candidates scored per query\n",
    "\n",
    "    overall_precision = []\n",
    "    overall_w_precision = []\n",
    "    overall_precision_abl_dict = defaultdict(list)\n",
    "    overall_threedcoords_distance = []\n",
    "    matches_dict = defaultdict(int)\n",
    "\n",
    "    # Best achievable rank-weighted score: every one of the top num_samples candidates matches.\n",
    "    theoretical_max = sum(1 / (k + 1) for k in range(num_samples))\n",
    "\n",
    "    threedcoords_distance_sum = 0\n",
    "    total_num_matches = 0\n",
    "\n",
    "    for image_index_m in range(len(ref_object_indeces)):\n",
    "        if verbose:\n",
    "            print(\"CURRENT IMAGE: \", image_index_m)\n",
    "        ref_dino_full_atts = ref_dino_full_atts_m[image_index_m]\n",
    "        ref_projections = ref_projections_m[image_index_m]\n",
    "\n",
    "        # One row per VAE projection: [DINO feature | vae_influence * projection].\n",
    "        # squeeze(1) assumes the DINO feature has a singleton second axis - TODO confirm.\n",
    "        ref_dino_full_atts_repeated = np.repeat(np.expand_dims(ref_dino_full_atts, 0), ref_projections.shape[0], axis=0)\n",
    "        ref_projection_dino_concat = np.concatenate((ref_dino_full_atts_repeated.squeeze(1), ref_projections * vae_influence), axis=1)\n",
    "\n",
    "        sorted_vd_indices, sorted_vd_sim = vae_dino_comparison_gpu2(ref_projection_dino_concat, remaining_projections_m, remaining_dino_full_atts_m, vae_influence, 1)\n",
    "\n",
    "        ref_color_items = ref_color_items_m[image_index_m]\n",
    "        ref_shape_items = ref_shape_items_m[image_index_m]\n",
    "        ref_material_items = ref_material_items_m[image_index_m]\n",
    "        ref_size_items = ref_size_items_m[image_index_m]\n",
    "        ref_threedcoords_items = ref_threedcoords_items_m[image_index_m]\n",
    "\n",
    "        ref_objects = [[ref_color_items[j], ref_shape_items[j], ref_material_items[j], ref_size_items[j]]\n",
    "                       for j in range(len(ref_color_items))]\n",
    "        ref_objects_tuples = [tuple(obj) for obj in ref_objects]\n",
    "        element_counts = Counter(ref_objects_tuples)\n",
    "        element_counts_abl_dict = _subset_counts(ref_color_items, ref_shape_items, ref_material_items, ref_size_items)\n",
    "\n",
    "        if verbose:\n",
    "            print(\"TEST_REF: \", ref_color_items)\n",
    "            print(\"REF_OBJECTS: \", ref_objects)\n",
    "\n",
    "        avg_match_count = 0\n",
    "        avg_w_match_count = 0\n",
    "        avg_match_count_abl_dict = defaultdict(int)\n",
    "\n",
    "        # `rank` is the 0-based retrieval position that drives the 1/(rank+1) weight;\n",
    "        # the per-object loops use `j` so they cannot overwrite it.\n",
    "        for rank, index in enumerate(sorted_vd_indices[:num_samples]):\n",
    "            color_items = remaining_color_items_m[index]\n",
    "            shape_items = remaining_shape_items_m[index]\n",
    "            material_items = remaining_material_items_m[index]\n",
    "            size_items = remaining_size_items_m[index]\n",
    "            threedcoords_items = remaining_threedcoords_items_m[index]\n",
    "\n",
    "            objects = [[color_items[j], shape_items[j], material_items[j], size_items[j]]\n",
    "                       for j in range(len(color_items))]\n",
    "            objects_tuples = [tuple(obj) for obj in objects]\n",
    "            element_counts_objects = Counter(objects_tuples)\n",
    "            element_counts_objects_abl_dict = _subset_counts(color_items, shape_items, material_items, size_items)\n",
    "\n",
    "            if verbose:\n",
    "                print(\"REF_OBJECTS: \", ref_objects)\n",
    "                print(\"OBJECTS: \", objects)\n",
    "                print(\"Element Counts:\", element_counts)\n",
    "\n",
    "            num_matches = 0\n",
    "            for ref_key, ref_count in element_counts.items():\n",
    "                if element_counts_objects.get(ref_key, 0) == 0:\n",
    "                    continue\n",
    "                if verbose:\n",
    "                    print(ref_key)\n",
    "                    print(ref_count)\n",
    "                    print(element_counts_objects[ref_key])\n",
    "                # NOTE(review): adds the query multiplicity rather than\n",
    "                # min(query, candidate) - confirm this is the intended match count.\n",
    "                num_matches += ref_count\n",
    "                # Planar distance between the first query object and the first candidate\n",
    "                # object carrying this attribute tuple (looked up by position in each scene).\n",
    "                # Assumes each threedcoords entry starts with (x, y) - TODO confirm.\n",
    "                ref_xy = np.ravel(ref_threedcoords_items[ref_objects_tuples.index(ref_key)])[:2]\n",
    "                cand_xy = np.ravel(threedcoords_items[objects_tuples.index(ref_key)])[:2]\n",
    "                threedcoords_distance_sum += np.linalg.norm(ref_xy - cand_xy)\n",
    "\n",
    "            num_matches_abl_dict = {\n",
    "                subset: _shared_count(element_counts_abl_dict[subset], element_counts_objects_abl_dict[subset])\n",
    "                for subset in _ABL_SUBSETS\n",
    "            }\n",
    "\n",
    "            if verbose:\n",
    "                print(\"NUM_MATCHES: \", num_matches)\n",
    "            matches_dict[num_matches] += 1\n",
    "            total_num_matches += num_matches\n",
    "\n",
    "            if is_match(num_matches):\n",
    "                avg_w_match_count += 1 / (rank + 1)\n",
    "                avg_match_count += 1\n",
    "            for subset in _ABL_SUBSETS:\n",
    "                if is_match(num_matches_abl_dict[subset]):\n",
    "                    avg_match_count_abl_dict[subset] += 1\n",
    "\n",
    "        precision = avg_match_count / num_samples\n",
    "        w_precision = avg_w_match_count / theoretical_max\n",
    "        if verbose:\n",
    "            print(\"AVG_PRECISION: \", precision)\n",
    "            print(\"AVG_W_PRECISION: \", w_precision)\n",
    "        overall_w_precision.append(w_precision)\n",
    "        overall_precision.append(precision)\n",
    "        for subset in _ABL_SUBSETS:\n",
    "            overall_precision_abl_dict[subset].append(avg_match_count_abl_dict[subset] / num_samples)\n",
    "        if verbose:\n",
    "            print(\"OVERALL_ABL: \", overall_precision_abl_dict)\n",
    "\n",
    "    # Summed matched-object distances divided by the total match count; NaN (instead of\n",
    "    # ZeroDivisionError) when no candidate matched anything.\n",
    "    overall_threedcoords_distance.append(\n",
    "        threedcoords_distance_sum / total_num_matches if total_num_matches else float(\"nan\"))\n",
    "\n",
    "    return overall_precision, overall_w_precision, overall_precision_abl_dict, overall_threedcoords_distance, Counter(overall_precision), matches_dict\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every matched query set against the shared candidate pool.\n",
    "num_runs = 3  # NOTE(review): not referenced in this cell\n",
    "overall_precision_vae_dino = []\n",
    "overall_w_precision_vae_dino = []\n",
    "overall_precision_ssm_vae_dino = []\n",
    "threedcoords_distance_vae_dino = []\n",
    "matches_dict_vae_dino_arr_vae_dino = []\n",
    "hist_arr_vae_dino = []\n",
    "\n",
    "for i in range(len(ref_data_matched)):\n",
    "    # First ten fields of the matched query set, in the order the evaluator takes them.\n",
    "    (ref_filtered_3_indices,\n",
    "     ref_filtered_3_images,\n",
    "     ref_filtered_3_dino_full_atts,\n",
    "     ref_filtered_3_projections,\n",
    "     ref_filtered_3_color_items,\n",
    "     ref_filtered_3_shape_items,\n",
    "     ref_filtered_3_size_items,\n",
    "     ref_filtered_3_material_items,\n",
    "     ref_filtered_3_threedcoords_items,\n",
    "     ref_filtered_3_numobjects_items) = [ref_data_matched[i][k] for k in range(10)]\n",
    "\n",
    "    # NOTE(review): vae_dino_comparison_multi_abl returns Counter(overall_precision) 5th and\n",
    "    # its per-candidate match-count histogram 6th, so `matches_dict` / `hist_dict` hold\n",
    "    # those values respectively despite their names.\n",
    "    (overall_precision, overall_w_precision, overall_precision_ssm,\n",
    "     threedcoords_distance, matches_dict, hist_dict) = vae_dino_comparison_multi_abl(\n",
    "        ref_filtered_3_indices,\n",
    "        ref_filtered_3_images,\n",
    "        ref_filtered_3_dino_full_atts,\n",
    "        ref_filtered_3_projections,\n",
    "        filtered_1_remaining_images,\n",
    "        filtered_1_remaining_projections,\n",
    "        filtered_1_remaining_dino_full_atts,\n",
    "        ref_filtered_3_color_items,\n",
    "        ref_filtered_3_shape_items,\n",
    "        ref_filtered_3_size_items,\n",
    "        ref_filtered_3_material_items,\n",
    "        ref_filtered_3_threedcoords_items,\n",
    "        ref_filtered_3_numobjects_items,\n",
    "        filtered_1_remaining_color_items,\n",
    "        filtered_1_remaining_shape_items,\n",
    "        filtered_1_remaining_size_items,\n",
    "        filtered_1_remaining_material_items,\n",
    "        filtered_1_remaining_threedcoords_items,\n",
    "        filtered_1_remaining_numobjects_items,\n",
    "        matching_exp=matching_exp,\n",
    "        sample_list=samples_lists[i])\n",
    "\n",
    "    overall_precision_vae_dino.append(overall_precision)\n",
    "    overall_w_precision_vae_dino.append(overall_w_precision)\n",
    "    overall_precision_ssm_vae_dino.append(overall_precision_ssm)\n",
    "    threedcoords_distance_vae_dino.append(threedcoords_distance)\n",
    "    matches_dict_vae_dino_arr_vae_dino.append(matches_dict)\n",
    "    hist_arr_vae_dino.append(hist_dict)\n",
    "\n",
    "print(\"OVERALL_PRECISION: \", overall_precision_vae_dino)\n",
    "print(\"OVERALL_W_PRECISION: \", overall_w_precision_vae_dino)\n",
    "print(\"OVERALL_PRECISION_SSM: \", overall_precision_ssm_vae_dino)\n",
    "print(\"THREEDCOORDS_DISTANCE: \", threedcoords_distance_vae_dino)\n",
    "print(\"MATCHES_DICT: \", matches_dict_vae_dino_arr_vae_dino)\n",
    "print(\"HIST_DICT: \", hist_arr_vae_dino) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ablation summary: mean precision per attribute subset, then averaged by subset size.\n",
    "# (defaultdict and np come from the imports cell at the top of the notebook.)\n",
    "\n",
    "# Per query set: the mean of each subset's per-query precisions.\n",
    "means_list = [{key: np.mean(values) for key, values in d.items()}\n",
    "              for d in overall_precision_ssm_vae_dino]\n",
    "\n",
    "# Gather each subset's per-set means, then average them across query sets.\n",
    "overall_means = defaultdict(list)\n",
    "for mean_dict in means_list:\n",
    "    for key, value in mean_dict.items():\n",
    "        overall_means[key].append(value)\n",
    "\n",
    "final_means = {k: np.mean(v) for k, v in overall_means.items()}\n",
    "\n",
    "print(final_means)\n",
    "\n",
    "# Sum and count the subset means grouped by how many attributes the subset uses.\n",
    "length_1_mean = sum(final_means[k] for k in final_means if len(k) == 1)\n",
    "length_1_num_elements = sum(1 for k in final_means if len(k) == 1)\n",
    "length_2_mean = sum(final_means[k] for k in final_means if len(k) == 2)\n",
    "length_2_num_elements = sum(1 for k in final_means if len(k) == 2)\n",
    "length_3_mean = sum(final_means[k] for k in final_means if len(k) == 3)\n",
    "length_3_num_elements = sum(1 for k in final_means if len(k) == 3)\n",
    "\n",
    "print(length_1_mean / length_1_num_elements)\n",
    "print(length_2_mean / length_2_num_elements)\n",
    "print(length_3_mean / length_3_num_elements)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get attribute ablation\n",
    "\n",
    "def get_filtered(final_means_filtered):\n",
    "    \"\"\"Print the average of `final_means_filtered` over 1-, 2- and 3-letter subset keys.\n",
    "\n",
    "    One line is printed per key length, in that order; keys of other lengths are ignored.\n",
    "    A key length with no entries raises ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    for key_length in (1, 2, 3):\n",
    "        group = [final_means_filtered[key] for key in final_means_filtered if len(key) == key_length]\n",
    "        print(sum(group) / len(group))\n",
    "\n",
    "\n",
    "print(\"COLOR ABLATION\")\n",
    "# Keep only the subsets that do not involve color.\n",
    "final_means_filtered = {key: value for key, value in final_means.items() if 'C' not in key}\n",
    "\n",
    "get_filtered(final_means_filtered)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get overall score\n",
    "\n",
    "print(\"OVERALL_PRECISION: \", np.mean(overall_precision_vae_dino))\n",
    "print(\"OVERALL_W_PRECISION: \", np.mean(overall_w_precision_vae_dino))\n",
    "print(\"THREEDCOORDS_DISTANCE: \", np.mean(threedcoords_distance_vae_dino))\n",
    "print(\"MATCHES_DICT: \", matches_dict_vae_dino_arr_vae_dino)\n",
    "print(\"HIST_DICT: \", hist_arr_vae_dino)\n",
    "\n",
    "# Per query set: how many query images scored a precision of exactly 0.0. Each entry of\n",
    "# matches_dict_vae_dino_arr_vae_dino is the Counter(overall_precision) returned by\n",
    "# vae_dino_comparison_multi_abl, so a missing 0.0 key counts as 0.\n",
    "num_query_sets = len(matches_dict_vae_dino_arr_vae_dino)\n",
    "error_rates = [precision_hist[0.0] for precision_hist in matches_dict_vae_dino_arr_vae_dino]\n",
    "error_rate = sum(error_rates)\n",
    "\n",
    "print(error_rates)\n",
    "\n",
    "print(error_rate / num_query_sets)\n",
    "print(np.mean(error_rates))\n",
    "print(np.std(error_rates))\n",
    "\n",
    "# Spread across query sets of each set's mean (weighted) precision.\n",
    "print(np.std(np.mean(np.array(overall_precision_vae_dino), axis=1)))\n",
    "print(np.std(np.mean(np.array(overall_w_precision_vae_dino), axis=1)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.metrics import structural_similarity as ssim\n",
    "\n",
    "# Attribute subsets evaluated in the ablation study. Letters: S = size, D = shape,\n",
    "# M = material, C = color. Every ordered selection of 1-3 attributes is listed; the\n",
    "# orderings yield identical match counts but are kept so the result keys stay the\n",
    "# same for the downstream analysis cells.\n",
    "ABL_SUBSETS = [\"S\", \"D\", \"M\", \"C\",\n",
    "               \"SD\", \"DS\", \"SM\", \"MS\", \"SC\", \"CS\", \"DM\", \"MD\", \"DC\", \"CD\", \"MC\", \"CM\",\n",
    "               \"SDM\", \"SMD\", \"DSM\", \"DMS\", \"MSD\", \"MDS\",\n",
    "               \"SDC\", \"SCD\", \"DSC\", \"DCS\", \"CSD\", \"CDS\",\n",
    "               \"SMC\", \"SCM\", \"MSC\", \"MCS\", \"CSM\", \"CMS\",\n",
    "               \"DMC\", \"DCM\", \"MDC\", \"MCD\", \"CMD\", \"CDM\"]\n",
    "\n",
    "ABL_MATCHING_EXPS = (\"1<=x<=3\", \"2<=x<=3\", \"x==3\")\n",
    "\n",
    "\n",
    "def _abl_is_match(num_matches, matching_exp):\n",
    "    \"\"\"Return True when `num_matches` satisfies the `matching_exp` criterion.\"\"\"\n",
    "    if matching_exp == \"1<=x<=3\":\n",
    "        return 1 <= num_matches <= 3\n",
    "    if matching_exp == \"2<=x<=3\":  # For multi-object\n",
    "        return 2 <= num_matches <= 3\n",
    "    if matching_exp == \"x==3\":\n",
    "        return num_matches == 3\n",
    "    raise ValueError(f\"Unknown matching_exp {matching_exp!r}; expected one of {ABL_MATCHING_EXPS}\")\n",
    "\n",
    "\n",
    "def _abl_attribute_counts(color_items, shape_items, size_items, material_items):\n",
    "    \"\"\"Map each subset in ABL_SUBSETS to a Counter of per-object attribute tuples.\"\"\"\n",
    "    attributes = {\"S\": size_items, \"D\": shape_items, \"M\": material_items, \"C\": color_items}\n",
    "    return {\n",
    "        subset: Counter(\n",
    "            tuple(attributes[letter][obj] for letter in subset)\n",
    "            for obj in range(len(color_items))\n",
    "        )\n",
    "        for subset in ABL_SUBSETS\n",
    "    }\n",
    "\n",
    "\n",
    "def _abl_count_matches(ref_counts, candidate_counts):\n",
    "    \"\"\"Sum the reference multiplicity of every key that also occurs in the candidate.\"\"\"\n",
    "    return sum(count for key, count in ref_counts.items()\n",
    "               if count > 0 and candidate_counts.get(key, 0) > 0)\n",
    "\n",
    "\n",
    "def dino_test_single_abl(ref_full_images_m,\n",
    "                    ref_object_indeces,\n",
    "                    ref_dino_full_atts_m,\n",
    "                    remaining_full_images_m,\n",
    "                    remaining_dino_full_atts_m,\n",
    "                    ref_color_items_m,\n",
    "                    ref_shape_items_m,\n",
    "                    ref_size_items_m,\n",
    "                    ref_material_items_m,\n",
    "                    ref_threedcoords_items_m,\n",
    "                    ref_numobjects_items_m,\n",
    "                    remaining_color_items_m,\n",
    "                    remaining_shape_items_m,\n",
    "                    remaining_size_items_m,\n",
    "                    remaining_material_items_m,\n",
    "                    remaining_threedcoords_items_m,\n",
    "                    remaining_numobjects_items_m,\n",
    "                    matching_exp=\"1<=x<=3\",\n",
    "                    sample_list=[10],\n",
    "                    verbose=False):\n",
    "    \"\"\"Score DINO-based retrieval by object-attribute overlap, with attribute ablations.\n",
    "\n",
    "    For every reference image, `dino_comparison` ranks the remaining images and the\n",
    "    top `num_samples` (10) are kept. A retrieved image's `num_matches` is the number of\n",
    "    reference objects whose (color, shape, material, size) tuple also appears in it;\n",
    "    the same count is computed for every attribute subset in `ABL_SUBSETS`. A retrieved\n",
    "    image is a hit when its count satisfies `matching_exp`.\n",
    "\n",
    "    `ref_full_images_m`, `remaining_full_images_m`, the `*_numobjects_items_m` lists and\n",
    "    `sample_list` are accepted for interface compatibility but not used. Set\n",
    "    `verbose=True` to print per-retrieval debugging output.\n",
    "\n",
    "    Returns:\n",
    "        overall_precision: per reference image, hits / num_samples.\n",
    "        overall_w_precision: per reference image, sum of 1/rank over hits, normalised\n",
    "            by its maximum sum(1/k for k = 1..num_samples).\n",
    "        overall_precision_abl_dict: subset -> list of per-image hit rates for that subset.\n",
    "        overall_threedcoords_distance: one-element list, the accumulated planar distance\n",
    "            between matched objects divided by the total match count (nan if none).\n",
    "        Counter(overall_precision): histogram of per-image precision values.\n",
    "        matches_dict: histogram of num_matches over all retrieved images.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `matching_exp` is not one of ABL_MATCHING_EXPS.\n",
    "    \"\"\"\n",
    "    if matching_exp not in ABL_MATCHING_EXPS:\n",
    "        raise ValueError(f\"Unknown matching_exp {matching_exp!r}; expected one of {ABL_MATCHING_EXPS}\")\n",
    "\n",
    "    num_samples = 10\n",
    "\n",
    "    overall_precision = []\n",
    "    overall_w_precision = []\n",
    "    overall_precision_abl_dict = defaultdict(list)\n",
    "    overall_threedcoords_distance = []\n",
    "    matches_dict = defaultdict(int)\n",
    "\n",
    "    # Best achievable weighted score: every one of the top-k results is a hit.\n",
    "    theoretical_max = sum(1 / (rank + 1) for rank in range(num_samples))\n",
    "\n",
    "    avg_threedcoords_distance = 0\n",
    "    total_num_matches = 0\n",
    "\n",
    "    for image_index_m in range(len(ref_object_indeces)):\n",
    "        print(\"CURRENT IMAGE: \", image_index_m)\n",
    "\n",
    "        sorted_vd_indices, _ = dino_comparison(ref_dino_full_atts_m[image_index_m], remaining_dino_full_atts_m)\n",
    "\n",
    "        ref_color_items = ref_color_items_m[image_index_m]\n",
    "        ref_shape_items = ref_shape_items_m[image_index_m]\n",
    "        ref_material_items = ref_material_items_m[image_index_m]\n",
    "        ref_size_items = ref_size_items_m[image_index_m]\n",
    "        ref_threedcoords_items = ref_threedcoords_items_m[image_index_m]\n",
    "\n",
    "        ref_objects_tuples = list(zip(ref_color_items, ref_shape_items, ref_material_items, ref_size_items))\n",
    "        element_counts = Counter(ref_objects_tuples)\n",
    "        ref_abl_counts = _abl_attribute_counts(ref_color_items, ref_shape_items, ref_size_items, ref_material_items)\n",
    "\n",
    "        if verbose:\n",
    "            print(\"TEST_REF: \", ref_color_items)\n",
    "            print(\"REF_OBJECTS: \", ref_objects_tuples)\n",
    "\n",
    "        avg_match_count = 0\n",
    "        avg_w_match_count = 0\n",
    "        avg_match_count_abl_dict = defaultdict(int)\n",
    "\n",
    "        # `rank` must not be rebound by any inner loop: it weights the hit below.\n",
    "        for rank, index in enumerate(sorted_vd_indices[0:num_samples]):\n",
    "            color_items = remaining_color_items_m[index]\n",
    "            shape_items = remaining_shape_items_m[index]\n",
    "            material_items = remaining_material_items_m[index]\n",
    "            size_items = remaining_size_items_m[index]\n",
    "            threedcoords_items = remaining_threedcoords_items_m[index]\n",
    "\n",
    "            objects_tuples = list(zip(color_items, shape_items, material_items, size_items))\n",
    "            element_counts_objects = Counter(objects_tuples)\n",
    "            object_abl_counts = _abl_attribute_counts(color_items, shape_items, size_items, material_items)\n",
    "\n",
    "            if verbose:\n",
    "                print(\"INDEX: \", index)\n",
    "                print(\"OBJECTS: \", objects_tuples)\n",
    "                print(\"Element Counts:\", element_counts)\n",
    "                print(\"Element Counts Objects:\", element_counts_objects)\n",
    "\n",
    "            num_matches = 0\n",
    "            for ref_key, ref_count in element_counts.items():\n",
    "                if ref_count > 0 and element_counts_objects.get(ref_key, 0) > 0:\n",
    "                    num_matches += ref_count\n",
    "                    # Planar distance between the first reference object and the first\n",
    "                    # retrieved object carrying this attribute tuple. Assumes each\n",
    "                    # threedcoords entry is an (x, y, z) sequence -- TODO confirm.\n",
    "                    ref_xy = np.asarray(ref_threedcoords_items[ref_objects_tuples.index(ref_key)])[:2]\n",
    "                    obj_xy = np.asarray(threedcoords_items[objects_tuples.index(ref_key)])[:2]\n",
    "                    avg_threedcoords_distance += np.linalg.norm(ref_xy - obj_xy)\n",
    "\n",
    "            num_matches_abl_dict = {\n",
    "                subset: _abl_count_matches(ref_abl_counts[subset], object_abl_counts[subset])\n",
    "                for subset in ABL_SUBSETS\n",
    "            }\n",
    "\n",
    "            if verbose:\n",
    "                print(num_matches_abl_dict)\n",
    "                print(\"NUM_MATCHES: \", num_matches)\n",
    "            matches_dict[num_matches] += 1\n",
    "            total_num_matches += num_matches\n",
    "\n",
    "            if _abl_is_match(num_matches, matching_exp):\n",
    "                # Rank-weighted hit: 1 for the best result, 1/2 for the second, ...\n",
    "                avg_w_match_count += 1 / (rank + 1)\n",
    "                avg_match_count += 1\n",
    "\n",
    "            for subset in ABL_SUBSETS:\n",
    "                if _abl_is_match(num_matches_abl_dict[subset], matching_exp):\n",
    "                    avg_match_count_abl_dict[subset] += 1\n",
    "\n",
    "        print(\"AVG_PRECISION: \", (avg_match_count / num_samples))\n",
    "        print(\"AVG_W_PRECISION: \", (avg_w_match_count / theoretical_max))\n",
    "\n",
    "        overall_w_precision.append(avg_w_match_count / theoretical_max)\n",
    "        overall_precision.append(avg_match_count / num_samples)\n",
    "\n",
    "        for subset in ABL_SUBSETS:\n",
    "            overall_precision_abl_dict[subset].append(avg_match_count_abl_dict[subset] / num_samples)\n",
    "\n",
    "        if verbose:\n",
    "            print(\"OVERALL_ABL: \", overall_precision_abl_dict)\n",
    "\n",
    "    # Avoid the ZeroDivisionError raised when no retrieved image shared any object.\n",
    "    overall_threedcoords_distance.append(\n",
    "        avg_threedcoords_distance / total_num_matches if total_num_matches else float(\"nan\"))\n",
    "\n",
    "    return overall_precision, overall_w_precision, overall_precision_abl_dict, overall_threedcoords_distance, Counter(overall_precision), matches_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the DINO ablation test once per matched reference set.\n",
    "# `matching_exp` and the `filtered_1_remaining_*` data come from earlier cells.\n",
    "num_runs = 3\n",
    "overall_precision_dino = []\n",
    "overall_w_precision_dino = []\n",
    "overall_precision_ssm_dino = []\n",
    "threedcoords_distance_dino = []\n",
    "matches_dict_dino_arr_dino = []\n",
    "hist_arr_dino = []\n",
    "\n",
    "for ref_entry in ref_data_matched:\n",
    "    # Entry layout: indices, images, dino atts, projections (unused by the DINO\n",
    "    # test), color, shape, size, material, 3D coords, number of objects.\n",
    "    (ref_filtered_indices,\n",
    "     ref_filtered_images,\n",
    "     ref_filtered_dino_full_atts,\n",
    "     ref_filtered_projections,\n",
    "     ref_filtered_color_items,\n",
    "     ref_filtered_shape_items,\n",
    "     ref_filtered_size_items,\n",
    "     ref_filtered_material_items,\n",
    "     ref_filtered_threedcoords_items,\n",
    "     ref_filtered_numobjects_items) = ref_entry[:10]\n",
    "\n",
    "    # `sample_list` is ignored by dino_test_single_abl, so it is not passed: the\n",
    "    # previous `samples_lists[i]` argument may be undefined on a fresh kernel\n",
    "    # (its definition in this cell was commented out).\n",
    "    # NOTE: the 5th/6th return values are Counter(overall_precision) and the\n",
    "    # num_matches histogram; they keep their historical names below.\n",
    "    (overall_precision, overall_w_precision, overall_precision_ssm,\n",
    "     threedcoords_distance, matches_dict, hist_dict) = dino_test_single_abl(\n",
    "        ref_filtered_images,\n",
    "        ref_filtered_indices,\n",
    "        ref_filtered_dino_full_atts,\n",
    "        filtered_1_remaining_images,\n",
    "        filtered_1_remaining_dino_full_atts,\n",
    "        ref_filtered_color_items,\n",
    "        ref_filtered_shape_items,\n",
    "        ref_filtered_size_items,\n",
    "        ref_filtered_material_items,\n",
    "        ref_filtered_threedcoords_items,\n",
    "        ref_filtered_numobjects_items,\n",
    "        filtered_1_remaining_color_items,\n",
    "        filtered_1_remaining_shape_items,\n",
    "        filtered_1_remaining_size_items,\n",
    "        filtered_1_remaining_material_items,\n",
    "        filtered_1_remaining_threedcoords_items,\n",
    "        filtered_1_remaining_numobjects_items,\n",
    "        matching_exp=matching_exp)\n",
    "    overall_precision_dino.append(overall_precision)\n",
    "    overall_w_precision_dino.append(overall_w_precision)\n",
    "    overall_precision_ssm_dino.append(overall_precision_ssm)\n",
    "    threedcoords_distance_dino.append(threedcoords_distance)\n",
    "    matches_dict_dino_arr_dino.append(matches_dict)\n",
    "    hist_arr_dino.append(hist_dict)\n",
    "\n",
    "print(\"OVERALL_PRECISION: \", overall_precision_dino)\n",
    "print(\"OVERALL_W_PRECISION: \", overall_w_precision_dino)\n",
    "print(\"OVERALL_PRECISION_SSM: \", overall_precision_ssm_dino)\n",
    "print(\"THREEDCOORDS_DISTANCE: \", threedcoords_distance_dino)\n",
    "print(\"MATCHES_DICT: \", matches_dict_dino_arr_dino)\n",
    "print(\"HIST_DICT: \", hist_arr_dino)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import numpy as np\n",
    "\n",
    "# Average each ablation subset's precision first over the reference images of a\n",
    "# run, then across runs.\n",
    "means_list = [{key: np.mean(values) for key, values in run_dict.items()}\n",
    "              for run_dict in overall_precision_ssm_dino]\n",
    "\n",
    "overall_means = defaultdict(list)\n",
    "for mean_dict in means_list:\n",
    "    for key, value in mean_dict.items():\n",
    "        overall_means[key].append(value)\n",
    "\n",
    "final_means = {k: np.mean(v) for k, v in overall_means.items()}\n",
    "\n",
    "print(final_means)\n",
    "\n",
    "\n",
    "def _sum_and_count_by_key_length(means, length):\n",
    "    \"\"\"Return (sum, count) of the values in `means` whose key has `length` characters.\"\"\"\n",
    "    values = [value for key, value in means.items() if len(key) == length]\n",
    "    return sum(values), len(values)\n",
    "\n",
    "\n",
    "# `length_k_mean` holds the SUM over subsets with k attributes; divide by the count.\n",
    "length_1_mean, length_1_num_elements = _sum_and_count_by_key_length(final_means, 1)\n",
    "length_2_mean, length_2_num_elements = _sum_and_count_by_key_length(final_means, 2)\n",
    "length_3_mean, length_3_num_elements = _sum_and_count_by_key_length(final_means, 3)\n",
    "\n",
    "# Empty groups print nan instead of raising ZeroDivisionError.\n",
    "for total, count in ((length_1_mean, length_1_num_elements),\n",
    "                     (length_2_mean, length_2_num_elements),\n",
    "                     (length_3_mean, length_3_num_elements)):\n",
    "    print(total / count if count else float(\"nan\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_filtered(final_means_filtered):\n",
    "    \"\"\"Print and return the mean score of the 1-, 2- and 3-letter ablation keys.\n",
    "\n",
    "    Args:\n",
    "        final_means_filtered: dict mapping an attribute-subset key (e.g. \"S\", \"SD\",\n",
    "            \"SDM\") to its mean precision. Keys of other lengths are ignored.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of the means over keys of length 1, 2 and 3. A length with no keys\n",
    "        yields nan (this used to raise ZeroDivisionError).\n",
    "    \"\"\"\n",
    "    sums = {1: 0, 2: 0, 3: 0}\n",
    "    counts = {1: 0, 2: 0, 3: 0}\n",
    "    for key, value in final_means_filtered.items():\n",
    "        if len(key) in sums:\n",
    "            sums[len(key)] += value\n",
    "            counts[len(key)] += 1\n",
    "\n",
    "    means = tuple(sums[n] / counts[n] if counts[n] else float(\"nan\") for n in (1, 2, 3))\n",
    "    for mean in means:\n",
    "        print(mean)\n",
    "\n",
    "    return means\n",
    "\n",
    "\n",
    "print(\"COLOR ABLATION\")\n",
    "# Keep only the subsets that do not involve color.\n",
    "final_means_filtered = {key: value for key, value in final_means.items() if 'C' not in key}\n",
    "\n",
    "l1c, l2c, l3c = get_filtered(final_means_filtered)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"OVERALL_PRECISION: \", np.mean(overall_precision_dino))\n",
    "print(\"OVERALL_W_PRECISION: \", np.mean(overall_w_precision_dino))\n",
    "print(\"THREEDCOORDS_DISTANCE: \", np.mean(threedcoords_distance_dino))\n",
    "print(\"MATCHES_DICT: \", matches_dict_dino_arr_dino)\n",
    "print(\"HIST_DICT: \", hist_arr_dino)\n",
    "\n",
    "# Each entry is the Counter(overall_precision) returned by dino_test_single_abl, so\n",
    "# key 0.0 counts reference images whose top-10 retrieval had no hit.\n",
    "# Iterate over the runs actually recorded instead of a hard-coded 7, and use\n",
    "# .get so a missing key counts as 0 instead of raising or mutating the dict.\n",
    "error_rates = [run_dict.get(0.0, 0) for run_dict in matches_dict_dino_arr_dino]\n",
    "error_rate = sum(error_rates)\n",
    "\n",
    "print(error_rate / len(error_rates))\n",
    "print(np.mean(error_rates))\n",
    "print(np.std(error_rates))\n",
    "\n",
    "print(np.std(np.mean(np.array(overall_precision_dino), axis=1)))\n",
    "print(np.std(np.mean(np.array(overall_w_precision_dino), axis=1)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skimage.metrics import structural_similarity as ssim\n",
    "\n",
    "def vae_vae_comparison_multi_abl(ref_object_indeces,\n",
    "                              ref_full_images_m,\n",
    "                              ref_dino_full_atts_m,\n",
    "                              ref_projections_m,\n",
    "                              remaining_images_m,\n",
    "                              remaining_projections_m, \n",
    "                              remaining_dino_full_atts_m, \n",
    "                              ref_color_items_m,\n",
    "                              ref_shape_items_m,\n",
    "                              ref_size_items_m,\n",
    "                              ref_material_items_m,\n",
    "                              ref_threedcoords_items_m,\n",
    "                              ref_numobjects_items_m,\n",
    "                              remaining_color_items_m, \n",
    "                              remaining_shape_items_m, \n",
    "                              remaining_size_items_m, \n",
    "                              remaining_material_items_m,\n",
    "                              remaining_threedcoords_items_m,\n",
    "                              remaining_numobjects_items_m,\n",
    "                              matching_exp=\"1<=x<=3\",\n",
    "                              sample_list=[10]):\n",
    "\n",
    "    num_samples = 10\n",
    "\n",
    "    overall_precision = []\n",
    "    overall_w_precision = []\n",
    "\n",
    "    overall_precision_abl_dict = defaultdict(list)\n",
    "    overall_threedcoords_distance = []\n",
    "    matches_dict = defaultdict(int)\n",
    "\n",
    "    theoretical_max = 0\n",
    "    for i in range(num_samples):\n",
    "        theoretical_max += 1 / (i+1)\n",
    "\n",
    "    avg_threedcoords_distance = 0\n",
    "    total_num_matches = 0\n",
    "\n",
    "\n",
    "    for image_index_m in range(len(ref_object_indeces)):\n",
    "        print(\"CURRENT IMAGE: \", image_index_m)\n",
    "        ref_dino_full_atts = ref_dino_full_atts_m[image_index_m]\n",
    "        ref_projections = ref_projections_m[image_index_m]\n",
    "\n",
    "        sorted_vd_indices, sorted_vd_sim = vae_comparison3_opt(ref_projections, remaining_projections_m)\n",
    "        \n",
    "        ref_color_items = ref_color_items_m[image_index_m]\n",
    "        ref_shape_items = ref_shape_items_m[image_index_m]\n",
    "        ref_material_items = ref_material_items_m[image_index_m]\n",
    "        ref_size_items = ref_size_items_m[image_index_m]\n",
    "        ref_threedcoords_items = ref_threedcoords_items_m[image_index_m]\n",
    "\n",
    "        ref_objects = []\n",
    "        ref_threed_distances = []\n",
    "        ref_abl_dict = defaultdict(list)\n",
    "\n",
    "\n",
    "        for i in range(len(ref_color_items)):\n",
    "            ref_objects.append([ref_color_items[i], ref_shape_items[i], ref_material_items[i], ref_size_items[i]])\n",
    "\n",
    "            ref_abl_dict['S'].append([ref_size_items[i]])\n",
    "            ref_abl_dict['D'].append([ref_shape_items[i]])\n",
    "            ref_abl_dict['M'].append([ref_material_items[i]])\n",
    "            ref_abl_dict['C'].append([ref_color_items[i]])\n",
    "\n",
    "            ref_abl_dict['SD'].append([ref_size_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['DS'].append([ref_shape_items[i], ref_size_items[i]])\n",
    "            ref_abl_dict['SM'].append([ref_size_items[i], ref_material_items[i]])\n",
    "            ref_abl_dict['MS'].append([ref_material_items[i], ref_size_items[i]])\n",
    "            ref_abl_dict['SC'].append([ref_size_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['CS'].append([ref_color_items[i], ref_size_items[i]])\n",
    "            ref_abl_dict['DM'].append([ref_shape_items[i], ref_material_items[i]])\n",
    "            ref_abl_dict['MD'].append([ref_material_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['DC'].append([ref_shape_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['CD'].append([ref_color_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['MC'].append([ref_material_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['CM'].append([ref_color_items[i], ref_material_items[i]])\n",
    "\n",
    "            ref_abl_dict['SDM'].append([ref_size_items[i], ref_shape_items[i], ref_material_items[i]])\n",
    "            ref_abl_dict['SMD'].append([ref_size_items[i], ref_material_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['DSM'].append([ref_shape_items[i], ref_size_items[i], ref_material_items[i]])\n",
    "            ref_abl_dict['DMS'].append([ref_shape_items[i], ref_material_items[i], ref_size_items[i]])\n",
    "            ref_abl_dict['MSD'].append([ref_material_items[i], ref_size_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['MDS'].append([ref_material_items[i], ref_shape_items[i], ref_size_items[i]])\n",
    "\n",
    "            ref_abl_dict['SDC'].append([ref_size_items[i], ref_shape_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['SCD'].append([ref_size_items[i], ref_color_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['DSC'].append([ref_shape_items[i], ref_size_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['DCS'].append([ref_shape_items[i], ref_color_items[i], ref_size_items[i]])\n",
    "            ref_abl_dict['CSD'].append([ref_color_items[i], ref_size_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['CDS'].append([ref_color_items[i], ref_shape_items[i], ref_size_items[i]])\n",
    "\n",
    "            ref_abl_dict['SMC'].append([ref_size_items[i], ref_material_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['SCM'].append([ref_size_items[i], ref_color_items[i], ref_material_items[i]])\n",
    "            ref_abl_dict['MSC'].append([ref_material_items[i], ref_size_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['MCS'].append([ref_material_items[i], ref_color_items[i], ref_size_items[i]])\n",
    "            ref_abl_dict['CSM'].append([ref_color_items[i], ref_size_items[i], ref_material_items[i]])\n",
    "            ref_abl_dict['CMS'].append([ref_color_items[i], ref_material_items[i], ref_size_items[i]])\n",
    "\n",
    "            ref_abl_dict['DMC'].append([ref_shape_items[i], ref_material_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['DCM'].append([ref_shape_items[i], ref_color_items[i], ref_material_items[i]])\n",
    "            ref_abl_dict['MDC'].append([ref_material_items[i], ref_shape_items[i], ref_color_items[i]])\n",
    "            ref_abl_dict['MCD'].append([ref_material_items[i], ref_color_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['CMD'].append([ref_color_items[i], ref_material_items[i], ref_shape_items[i]])\n",
    "            ref_abl_dict['CDM'].append([ref_color_items[i], ref_shape_items[i], ref_material_items[i]])\n",
    "            ref_threed_distances.append([ref_threedcoords_items[i]])\n",
    "\n",
    "        print(\"TEST_REF: \", ref_color_items)\n",
    "        print(\"REF_OBJECTS: \", ref_objects)\n",
    "\n",
    "\n",
    "        avg_match_count = 0\n",
    "        avg_w_match_count = 0\n",
    "        avg_match_count_abl_dict = defaultdict(int)\n",
    "\n",
    "\n",
    "        for i, index in enumerate(sorted_vd_indices[0:num_samples+0]):\n",
    "            color_items = remaining_color_items_m[index]\n",
    "            shape_items = remaining_shape_items_m[index]\n",
    "            material_items = remaining_material_items_m[index]\n",
    "            size_items = remaining_size_items_m[index]\n",
    "            threedcoords_items = remaining_threedcoords_items_m[index]\n",
    "\n",
    "            num_matches = 0\n",
    "            num_matches_abl_dict = defaultdict(int)\n",
    "            objects = []\n",
    "            object_abl_dict = defaultdict(list)\n",
    "            \n",
    "            threed_distances = []\n",
    "\n",
    "            for i in range(len(color_items)):\n",
    "                objects.append([color_items[i], shape_items[i], material_items[i], size_items[i]])\n",
    "                object_abl_dict['S'].append([size_items[i]])\n",
    "                object_abl_dict['D'].append([shape_items[i]])\n",
    "                object_abl_dict['M'].append([material_items[i]])\n",
    "                object_abl_dict['C'].append([color_items[i]])\n",
    "\n",
    "                object_abl_dict['SD'].append([size_items[i], shape_items[i]])\n",
    "                object_abl_dict['DS'].append([shape_items[i], size_items[i]])\n",
    "                object_abl_dict['SM'].append([size_items[i], material_items[i]])\n",
    "                object_abl_dict['MS'].append([material_items[i], size_items[i]])\n",
    "                object_abl_dict['SC'].append([size_items[i], color_items[i]])\n",
    "                object_abl_dict['CS'].append([color_items[i], size_items[i]])\n",
    "                object_abl_dict['DM'].append([shape_items[i], material_items[i]])\n",
    "                object_abl_dict['MD'].append([material_items[i], shape_items[i]])\n",
    "                object_abl_dict['DC'].append([shape_items[i], color_items[i]])\n",
    "                object_abl_dict['CD'].append([color_items[i], shape_items[i]])\n",
    "                object_abl_dict['MC'].append([material_items[i], color_items[i]])\n",
    "                object_abl_dict['CM'].append([color_items[i], material_items[i]])\n",
    "\n",
    "                object_abl_dict['SDM'].append([size_items[i], shape_items[i], material_items[i]])\n",
    "                object_abl_dict['SMD'].append([size_items[i], material_items[i], shape_items[i]])\n",
    "                object_abl_dict['DSM'].append([shape_items[i], size_items[i], material_items[i]])\n",
    "                object_abl_dict['DMS'].append([shape_items[i], material_items[i], size_items[i]])\n",
    "                object_abl_dict['MSD'].append([material_items[i], size_items[i], shape_items[i]])\n",
    "                object_abl_dict['MDS'].append([material_items[i], shape_items[i], size_items[i]])\n",
    "\n",
    "                object_abl_dict['SDC'].append([size_items[i], shape_items[i], color_items[i]])\n",
    "                object_abl_dict['SCD'].append([size_items[i], color_items[i], shape_items[i]])\n",
    "                object_abl_dict['DSC'].append([shape_items[i], size_items[i], color_items[i]])\n",
    "                object_abl_dict['DCS'].append([shape_items[i], color_items[i], size_items[i]])\n",
    "                object_abl_dict['CSD'].append([color_items[i], size_items[i], shape_items[i]])\n",
    "                object_abl_dict['CDS'].append([color_items[i], shape_items[i], size_items[i]])\n",
    "\n",
    "                object_abl_dict['SMC'].append([size_items[i], material_items[i], color_items[i]])\n",
    "                object_abl_dict['SCM'].append([size_items[i], color_items[i], material_items[i]])\n",
    "                object_abl_dict['MSC'].append([material_items[i], size_items[i], color_items[i]])\n",
    "                object_abl_dict['MCS'].append([material_items[i], color_items[i], size_items[i]])\n",
    "                object_abl_dict['CSM'].append([color_items[i], size_items[i], material_items[i]])\n",
    "                object_abl_dict['CMS'].append([color_items[i], material_items[i], size_items[i]])\n",
    "\n",
    "                object_abl_dict['DMC'].append([shape_items[i], material_items[i], color_items[i]])\n",
    "                object_abl_dict['DCM'].append([shape_items[i], color_items[i], material_items[i]])\n",
    "                object_abl_dict['MDC'].append([material_items[i], shape_items[i], color_items[i]])\n",
    "                object_abl_dict['MCD'].append([material_items[i], color_items[i], shape_items[i]])\n",
    "                object_abl_dict['CMD'].append([color_items[i], material_items[i], shape_items[i]])\n",
    "                object_abl_dict['CDM'].append([color_items[i], shape_items[i], material_items[i]])\n",
    "                threed_distances.append([threedcoords_items[i]])\n",
    "            \n",
    "            print(\"REF_OBJECTS: \", ref_objects)\n",
    "            print(\"OBJECTS: \", objects)\n",
    "\n",
    "            ref_objects_tuples_abl_dict = defaultdict(list)\n",
    "            element_counts_abl_dict = defaultdict(list)\n",
    "            object_tuples_abl_dict = defaultdict(list)\n",
    "            element_counts_objects_abl_dict = defaultdict(list)\n",
    "\n",
    "           # Convert sublists to tuples to make them hashable\n",
    "            ref_objects_tuples = [tuple(sublist) for sublist in ref_objects]\n",
    "            \n",
    "            # Count the occurrences of each unique sublist\n",
    "            element_counts = Counter(ref_objects_tuples)\n",
    "           \n",
    "\n",
    "            print(\"Element Counts:\", element_counts)\n",
    "\n",
    "            objects_tuples = [tuple(sublist) for sublist in objects]\n",
    "           \n",
    "            element_counts_objects = Counter(objects_tuples)\n",
    "           \n",
    "            for subset in [\"S\", \"D\", \"M\", \"C\", \"SD\", \"DS\",\"SM\", \"MS\",\"SC\", \"CS\",\"DM\", \"MD\",\"DC\", \"CD\",\"MC\", \"CM\", \"SDM\", \"SMD\", \"DSM\", \"DMS\", \"MSD\", \"MDS\", \"SDC\", \"SCD\", \"DSC\", \"DCS\", \"CSD\", \"CDS\", \"SMC\", \"SCM\", \"MSC\", \"MCS\", \"CSM\", \"CMS\", \"DMC\", \"DCM\", \"MDC\", \"MCD\", \"CMD\", \"CDM\"]:\n",
    "\n",
    "                \n",
    "                ref_objects_tuples_abl_dict[subset] = [tuple(sublist) for sublist in ref_abl_dict[subset]]\n",
    "\n",
    "               \n",
    "                element_counts_abl_dict[subset] = Counter(ref_objects_tuples_abl_dict[subset])\n",
    "\n",
    "                print(\"Element Counts:\", element_counts)\n",
    "\n",
    "             \n",
    "                object_tuples_abl_dict[subset] = [tuple(sublist) for sublist in object_abl_dict[subset]]\n",
    "\n",
    "\n",
    "               \n",
    "                element_counts_objects_abl_dict[subset] = Counter(object_tuples_abl_dict[subset])\n",
    "\n",
    "\n",
    "            print(\"Element Counts Objects:\", element_counts_objects)\n",
    "            print(\"INDEX: \", index)\n",
    "\n",
    "            print(element_counts_objects.keys())\n",
    "\n",
    "            for ref_key in element_counts:\n",
    "                if ref_key in element_counts_objects:\n",
    "                    print(ref_key)\n",
    "                    print(element_counts[ref_key])\n",
    "                    print(element_counts_objects[ref_key])\n",
    "                    if element_counts[ref_key] > 0 and element_counts_objects[ref_key] > 0:\n",
    "                        num_matches += element_counts[ref_key]\n",
    "                        avg_threedcoords_distance += np.linalg.norm(np.array(ref_threed_distances[ref_key == element_counts.keys()])[:2] - np.array(threed_distances[ref_key == element_counts_objects.keys()])[:2])\n",
    "\n",
    "\n",
    "            for subset in [\"S\", \"D\", \"M\", \"C\", \"SD\", \"DS\",\"SM\", \"MS\",\"SC\", \"CS\",\"DM\", \"MD\",\"DC\", \"CD\",\"MC\", \"CM\", \"SDM\", \"SMD\", \"DSM\", \"DMS\", \"MSD\", \"MDS\", \"SDC\", \"SCD\", \"DSC\", \"DCS\", \"CSD\", \"CDS\", \"SMC\", \"SCM\", \"MSC\", \"MCS\", \"CSM\", \"CMS\", \"DMC\", \"DCM\", \"MDC\", \"MCD\", \"CMD\", \"CDM\"]:\n",
    "                for ref_key in element_counts_abl_dict[subset]:\n",
    "                    if ref_key in element_counts_objects_abl_dict[subset]:\n",
    "                        if element_counts_abl_dict[subset][ref_key] > 0 and element_counts_objects_abl_dict[subset][ref_key] > 0:\n",
    "                            num_matches_abl_dict[subset] += element_counts_abl_dict[subset][ref_key]\n",
    "           \n",
    "\n",
    "            print(\"NUM_MATCHES: \", num_matches)    \n",
    "            matches_dict[num_matches] += 1\n",
    "            total_num_matches += num_matches\n",
    "\n",
    "            if matching_exp == \"1<=x<=3\":\n",
    "                if num_matches >= 1 and num_matches <= 3:\n",
    "                    avg_w_match_count += 1 / (i+1)\n",
    "                    avg_match_count += 1\n",
    "               \n",
    "            elif matching_exp == \"2<=x<=3\":\n",
    "                if num_matches >= 2 and num_matches <= 3:\n",
    "                    avg_w_match_count += 1 / (i+1)\n",
    "                    avg_match_count += 1\n",
    "               \n",
    "            elif matching_exp == \"x==3\":\n",
    "                if num_matches == 3:\n",
    "                    avg_w_match_count += 1 / (i+1)\n",
    "                    avg_match_count += 1\n",
    "               \n",
    "            for subset in [\"S\", \"D\", \"M\", \"C\", \"SD\", \"DS\",\"SM\", \"MS\",\"SC\", \"CS\",\"DM\", \"MD\",\"DC\", \"CD\",\"MC\", \"CM\", \"SDM\", \"SMD\", \"DSM\", \"DMS\", \"MSD\", \"MDS\", \"SDC\", \"SCD\", \"DSC\", \"DCS\", \"CSD\", \"CDS\", \"SMC\", \"SCM\", \"MSC\", \"MCS\", \"CSM\", \"CMS\", \"DMC\", \"DCM\", \"MDC\", \"MCD\", \"CMD\", \"CDM\"]:\n",
    "                    if matching_exp == \"1<=x<=3\":\n",
    "                        if num_matches_abl_dict[subset] >= 1 and num_matches_abl_dict[subset] <= 3:\n",
    "                            avg_match_count_abl_dict[subset] += 1\n",
    "\n",
    "\n",
    "                    elif matching_exp == \"2<=x<=3\": \n",
    "                        if num_matches_abl_dict[subset] >= 2 and num_matches_abl_dict[subset] <= 3:\n",
    "                            avg_match_count_abl_dict[subset] += 1\n",
    "\n",
    "                    elif matching_exp == \"x==3\":\n",
    "                        if num_matches_abl_dict[subset] == 3:\n",
    "                            avg_match_count_abl_dict[subset] += 1\n",
    "\n",
    "\n",
    "\n",
    "        print(\"AVG_PRECISION: \", (avg_match_count / num_samples))\n",
    "        print(\"AVG_W_PRECISION: \", (avg_w_match_count / theoretical_max))\n",
    "\n",
    "        overall_w_precision.append(( avg_w_match_count / theoretical_max))\n",
    "        overall_precision.append((avg_match_count / num_samples))\n",
    "        \n",
    "        for subset in [\"S\", \"D\", \"M\", \"C\", \"SD\", \"DS\",\"SM\", \"MS\",\"SC\", \"CS\",\"DM\", \"MD\",\"DC\", \"CD\",\"MC\", \"CM\", \"SDM\", \"SMD\", \"DSM\", \"DMS\", \"MSD\", \"MDS\", \"SDC\", \"SCD\", \"DSC\", \"DCS\", \"CSD\", \"CDS\", \"SMC\", \"SCM\", \"MSC\", \"MCS\", \"CSM\", \"CMS\", \"DMC\", \"DCM\", \"MDC\", \"MCD\", \"CMD\", \"CDM\"]:\n",
    "            overall_precision_abl_dict[subset].append(avg_match_count_abl_dict[subset] / num_samples)\n",
    "        \n",
    "        print(\"OVERALL_ABL: \", overall_precision_abl_dict)\n",
    "\n",
    "    overall_threedcoords_distance.append((avg_threedcoords_distance / total_num_matches))\n",
    "\n",
    "    return overall_precision, overall_w_precision, overall_precision_abl_dict, overall_threedcoords_distance, Counter(overall_precision), matches_dict\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the ablation comparison once per reference split in `ref_data_matched`\n",
    "# and collect each run's results (one list entry per run).\n",
    "overall_precision_vae_vae = []\n",
    "overall_w_precision_vae_vae = []\n",
    "overall_precision_ssm_vae_vae = []\n",
    "threedcoords_distance_vae_vae = []\n",
    "matches_dict_vae_vae_arr_vae_vae = []\n",
    "hist_arr_vae_vae = []\n",
    "\n",
    "for run_index, ref_row in enumerate(ref_data_matched):\n",
    "    # Each row bundles one run's filtered reference data in this fixed order.\n",
    "    (ref_filtered_3_indices,\n",
    "     ref_filtered_3_images,\n",
    "     ref_filtered_3_dino_full_atts,\n",
    "     ref_filtered_3_projections,\n",
    "     ref_filtered_3_color_items,\n",
    "     ref_filtered_3_shape_items,\n",
    "     ref_filtered_3_size_items,\n",
    "     ref_filtered_3_material_items,\n",
    "     ref_filtered_3_threedcoords_items,\n",
    "     ref_filtered_3_numobjects_items) = ref_row[:10]\n",
    "\n",
    "    # NOTE(review): confirm the 5th/6th return values against the definition of\n",
    "    # `vae_vae_comparison_multi_abl`; if it returns (..., Counter(overall_precision), matches_dict)\n",
    "    # then `matches_dict` here holds the precision histogram and `hist_dict` the match counts.\n",
    "    (overall_precision, overall_w_precision, overall_precision_ssm,\n",
    "     threedcoords_distance, matches_dict, hist_dict) = vae_vae_comparison_multi_abl(\n",
    "        ref_filtered_3_indices,\n",
    "        ref_filtered_3_images,\n",
    "        ref_filtered_3_dino_full_atts,\n",
    "        ref_filtered_3_projections,\n",
    "        filtered_1_remaining_images,\n",
    "        filtered_1_remaining_projections,\n",
    "        filtered_1_remaining_dino_full_atts,\n",
    "        ref_filtered_3_color_items,\n",
    "        ref_filtered_3_shape_items,\n",
    "        ref_filtered_3_size_items,\n",
    "        ref_filtered_3_material_items,\n",
    "        ref_filtered_3_threedcoords_items,\n",
    "        ref_filtered_3_numobjects_items,\n",
    "        filtered_1_remaining_color_items,\n",
    "        filtered_1_remaining_shape_items,\n",
    "        filtered_1_remaining_size_items,\n",
    "        filtered_1_remaining_material_items,\n",
    "        filtered_1_remaining_threedcoords_items,\n",
    "        filtered_1_remaining_numobjects_items,\n",
    "        matching_exp=matching_exp,\n",
    "        sample_list=samples_lists[run_index],\n",
    "    )\n",
    "    overall_precision_vae_vae.append(overall_precision)\n",
    "    overall_w_precision_vae_vae.append(overall_w_precision)\n",
    "    overall_precision_ssm_vae_vae.append(overall_precision_ssm)\n",
    "    threedcoords_distance_vae_vae.append(threedcoords_distance)\n",
    "    matches_dict_vae_vae_arr_vae_vae.append(matches_dict)\n",
    "    hist_arr_vae_vae.append(hist_dict)\n",
    "\n",
    "print(\"OVERALL_PRECISION: \", overall_precision_vae_vae)\n",
    "print(\"OVERALL_W_PRECISION: \", overall_w_precision_vae_vae)\n",
    "print(\"OVERALL_PRECISION_SSM: \", overall_precision_ssm_vae_vae)\n",
    "print(\"THREEDCOORDS_DISTANCE: \", threedcoords_distance_vae_vae)\n",
    "print(\"MATCHES_DICT: \", matches_dict_vae_vae_arr_vae_vae)\n",
    "print(\"HIST_DICT: \", hist_arr_vae_vae)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary statistics over all runs collected by the previous cell.\n",
    "print(\"OVERALL_PRECISION: \", np.mean(overall_precision_vae_vae))\n",
    "print(\"OVERALL_W_PRECISION: \", np.mean(overall_w_precision_vae_vae))\n",
    "print(\"THREEDCOORDS_DISTANCE: \", np.mean(threedcoords_distance_vae_vae))\n",
    "print(\"HIST_DICT: \", hist_arr_vae_vae)\n",
    "\n",
    "# Derive the run count from the collected results instead of hard-coding 7,\n",
    "# so the averages stay correct whenever `ref_data_matched` changes size.\n",
    "n_runs = len(matches_dict_vae_vae_arr_vae_vae)\n",
    "# NOTE(review): 50 is presumably the number of reference images per run -- confirm.\n",
    "REFS_PER_RUN = 50\n",
    "\n",
    "# Per run: the count stored under the 0.0 key of that run's histogram.\n",
    "error_rates = [run_hist[0.0] for run_hist in matches_dict_vae_vae_arr_vae_vae]\n",
    "error_rate = sum(error_rates)\n",
    "\n",
    "print(error_rate / n_runs)\n",
    "print(np.mean(error_rates) / REFS_PER_RUN)\n",
    "print(np.std(error_rates))\n",
    "\n",
    "# Spread across runs of each run's mean (weighted) precision.\n",
    "print(np.std(np.mean(np.array(overall_precision_vae_vae), axis=1)))\n",
    "print(np.std(np.mean(np.array(overall_w_precision_vae_vae), axis=1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aggregate the ablation precision per attribute subset.\n",
    "# Subset keys are strings of attribute letters:\n",
    "#   S = size, D = shape, M = material, C = colour.\n",
    "# `defaultdict` and `np` come from the imports cell at the top of the notebook.\n",
    "\n",
    "# Per run: average each subset's precision over that run's reference images.\n",
    "means_list = [\n",
    "    {key: np.mean(values) for key, values in run_dict.items()}\n",
    "    for run_dict in overall_precision_ssm_vae_vae\n",
    "]\n",
    "\n",
    "# Across runs: gather each subset's per-run mean, then report mean and std.\n",
    "overall_means = defaultdict(list)\n",
    "for mean_dict in means_list:\n",
    "    for key, value in mean_dict.items():\n",
    "        overall_means[key].append(value)\n",
    "\n",
    "final_means = {k: np.mean(v) for k, v in overall_means.items()}\n",
    "final_std = {k: np.std(v) for k, v in overall_means.items()}\n",
    "\n",
    "print(final_means)\n",
    "print(final_std)\n",
    "\n",
    "# Average precision over all subsets with the same number of attributes (1, 2, 3).\n",
    "means_by_length = defaultdict(list)\n",
    "for key, value in final_means.items():\n",
    "    if len(key) in (1, 2, 3):\n",
    "        means_by_length[len(key)].append(value)\n",
    "\n",
    "for key_length in (1, 2, 3):\n",
    "    group = means_by_length[key_length]\n",
    "    # Left-to-right sum from 0 matches the previous running-total accumulation.\n",
    "    print(sum(group) / len(group))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_filtered(final_means_filtered, std_dict=None):\n",
    "    \"\"\"Average subset precisions grouped by subset length (1, 2 or 3 letters).\n",
    "\n",
    "    Args:\n",
    "        final_means_filtered: mapping from subset key (e.g. \"S\", \"SD\", \"SDM\")\n",
    "            to its mean precision. Keys of any other length are ignored.\n",
    "        std_dict: mapping from subset key to its across-run std. Defaults to\n",
    "            the notebook-level ``final_std`` (the original, implicit behaviour).\n",
    "\n",
    "    Prints the three group averages, then ``np.std`` of the per-key stds of\n",
    "    the length-1 and length-2 groups.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(avg_len1, avg_len2, avg_len3)``. A group with no keys yields\n",
    "        ``nan`` instead of raising ``ZeroDivisionError``.\n",
    "    \"\"\"\n",
    "    if std_dict is None:\n",
    "        std_dict = final_std  # global produced by the aggregation cell\n",
    "\n",
    "    means_by_length = {1: [], 2: [], 3: []}\n",
    "    stds_by_length = {1: [], 2: []}\n",
    "    for key, value in final_means_filtered.items():\n",
    "        key_length = len(key)\n",
    "        if key_length in means_by_length:\n",
    "            means_by_length[key_length].append(value)\n",
    "        if key_length in stds_by_length:\n",
    "            stds_by_length[key_length].append(std_dict[key])\n",
    "\n",
    "    def _average(values):\n",
    "        # Left-to-right sum keeps results identical to the old running totals.\n",
    "        return sum(values) / len(values) if values else float(\"nan\")\n",
    "\n",
    "    averages = tuple(_average(means_by_length[n]) for n in (1, 2, 3))\n",
    "    for average in averages:\n",
    "        print(average)\n",
    "\n",
    "    # NOTE(review): this is the spread of the per-key stds, and nothing is\n",
    "    # printed for length-3 keys -- confirm that is the intended statistic.\n",
    "    print(np.std(stds_by_length[1]))\n",
    "    print(np.std(stds_by_length[2]))\n",
    "\n",
    "    return averages\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "print(\"COLOUR ABLATION\")\n",
    "# Colour ablation: keep only the attribute subsets whose key has no 'C'.\n",
    "final_means_filtered = {\n",
    "    key: value for key, value in final_means.items() if 'C' not in key\n",
    "}\n",
    "\n",
    "l1m, l2m, l3m = get_filtered(final_means_filtered)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dmlab",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
