{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "634a79dd-59e8-4edb-a271-221143c044fc",
   "metadata": {},
   "source": [
    "# To set up the environment, I have exported my working env to a .yml file."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24c43505-638e-4f03-838c-0c0dd022592d",
   "metadata": {},
   "source": [
    "Just run `conda env create -f environment.yml` to install my environment.\n",
    "\n",
    "If you would rather, you could make the environment from scratch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a1d9daad-121a-460e-9347-37e8cea6a2bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import random\n",
    "from collections import defaultdict\n",
    "\n",
    "# Set project root once\n",
    "PROJECT_ROOT = \"./\"\n",
    "os.chdir(PROJECT_ROOT)\n",
    "sys.path.append(PROJECT_ROOT)\n",
    "\n",
    "import yaml\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.colors as mcolors\n",
    "from matplotlib import cm\n",
    "from tqdm.auto import tqdm\n",
    "from sklearn.metrics import roc_curve, auc, log_loss, accuracy_score\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b298dc9-a617-4719-bb86-00c89d726aea",
   "metadata": {},
   "source": [
    "# Dataset Download\n",
    "\n",
    "First, download the dataset used in the paper from this anonymous Dropbox link (it may take a minute)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bbe50cd1-6305-4d52-bf5a-6270bf556c60",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracted to /u/li19/moss_workshop_code\n"
     ]
    }
   ],
   "source": [
    "import requests, zipfile, io\n",
    "import os\n",
    "\n",
    "def to_direct_download_url(dropbox_url):\n",
    "    \"\"\"Return `dropbox_url` with its `dl` query parameter forced to 1.\n",
    "\n",
    "    The flag can sit anywhere in the query string (e.g. `...&st=x&dl=0`), so\n",
    "    the URL is parsed instead of being matched against a literal `?dl=0`.\n",
    "    \"\"\"\n",
    "    from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit\n",
    "\n",
    "    parts = urlsplit(dropbox_url)\n",
    "    # Drop any existing dl flag, then append the direct-download one.\n",
    "    query = [(key, value) for key, value in parse_qsl(parts.query, keep_blank_values=True) if key != \"dl\"]\n",
    "    query.append((\"dl\", \"1\"))\n",
    "    return urlunsplit(parts._replace(query=urlencode(query)))\n",
    "\n",
    "\n",
    "def download_and_unzip_dropbox(dropbox_url, extract_to=\".\", timeout=60):\n",
    "    \"\"\"Download a zip archive from a Dropbox share link and extract it.\n",
    "\n",
    "    Args:\n",
    "        dropbox_url: Dropbox share URL; its `dl` flag is rewritten to `dl=1`.\n",
    "        extract_to: Directory the archive is extracted into.\n",
    "        timeout: Seconds passed to `requests.get` as its timeout.\n",
    "\n",
    "    Raises:\n",
    "        requests.HTTPError: If the server answers with an error status.\n",
    "        zipfile.BadZipFile: If the downloaded payload is not a zip archive.\n",
    "    \"\"\"\n",
    "    # Transform to direct download link\n",
    "    download_url = to_direct_download_url(dropbox_url)\n",
    "\n",
    "    # Download and unzip; the whole archive is held in memory before extraction\n",
    "    response = requests.get(download_url, timeout=timeout)\n",
    "    response.raise_for_status()\n",
    "\n",
    "    with zipfile.ZipFile(io.BytesIO(response.content)) as zip_ref:\n",
    "        zip_ref.extractall(extract_to)\n",
    "    print(f\"Extracted to {os.path.abspath(extract_to)}\")\n",
    "\n",
    "dropbox_url = \"https://dl.dropboxusercontent.com/scl/fi/bk0gcvfy2ytni5eeu4gah/dataset.zip?rlkey=da701q1fc3xas3s1et76h5f7a&st=ft5099v2&dl=0\"\n",
    "download_and_unzip_dropbox(dropbox_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26bed7e3-3000-49f6-8bf8-0563cb56fa07",
   "metadata": {},
   "source": [
    "# Set up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "60829567-3103-45d9-889b-e01ec1aa5a6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from definition import *\n",
    "from model.device_check import *\n",
    "import tool.dynamic as dynamic\n",
    "from engine.tweet import FixTensorFusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4d155af5-d5d9-4ccc-bb1a-fe9129cfd1ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lock every random number generator to make the experiment replicable\n",
    "seed = 1\n",
    "\n",
    "random.seed(seed)  # Python random module.\n",
    "np.random.seed(seed)  # Numpy module.\n",
    "torch.manual_seed(seed)\n",
    "torch.cuda.manual_seed(seed)\n",
    "torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.\n",
    "\n",
    "# Disable cuDNN autotuning and request deterministic cuDNN algorithms\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1aaab4bb-0eaf-47b3-bf00-b9adb4de5618",
   "metadata": {},
   "outputs": [],
   "source": [
    "RESULTS_ROOT = os.path.join(PROJECT_ROOT, \"results\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "db865f91-7f5c-4463-8ef9-8aa28b76e775",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b00117a-3aca-4980-8ca6-0416f20a8880",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3c6648d8-0dd1-4c44-860d-92a347a5b546",
   "metadata": {},
   "outputs": [],
   "source": [
    "def main_train(config, exp_version=None, reset=False, nTrials=1, verbose=True, synth=False, shorten=False):\n",
    "    \"\"\"Train and test one fusion model over a sweep of training-subset sizes.\n",
    "\n",
    "    For every trial a folder is created under RESULTS_ROOT. The engine named by\n",
    "    the config's \"engine\" entry is then trained and tested once per subset size\n",
    "    2**j, with j = 5..15 (13..15 when `shorten` is set). The list of test\n",
    "    accuracies and the latest finished subset are rewritten as JSON after each\n",
    "    subset, so an interrupted run still leaves usable results.\n",
    "\n",
    "    Args:\n",
    "        config: File name of a YAML config inside PROJECT_ROOT/config/tweet.\n",
    "        exp_version: Optional tag appended to the experiment name as \"_e_{tag}\".\n",
    "        reset: Not used by this function; kept for existing callers.\n",
    "        nTrials: Number of trials; names get a \"_t{n}\" suffix when it is not 1.\n",
    "        verbose: Print the name of each subset experiment before it starts.\n",
    "        synth: Value stored in cfg[\"data\"][\"synthetic\"].\n",
    "        shorten: Run only the three largest subset sizes.\n",
    "    \"\"\"\n",
    "    import copy\n",
    "\n",
    "    # Load configuration file\n",
    "    # NOTE(review): get_config_for() reads the same files with yaml.safe_load;\n",
    "    # FullLoader is kept here in case a config relies on non-basic YAML tags.\n",
    "    with open(os.path.join(PROJECT_ROOT, \"config\", \"tweet\", config), 'r') as f:\n",
    "        cfg = yaml.load(f, Loader=yaml.FullLoader)\n",
    "    print(cfg)\n",
    "    cfg[\"data\"][\"synthetic\"] = synth\n",
    "    base_name = cfg[\"exp_name\"]\n",
    "\n",
    "    # Exponents of the subset sizes to sweep: 2**5 .. 2**15, or 2**13 .. 2**15\n",
    "    exponents = range(13 if shorten else 5, 16)\n",
    "\n",
    "    for n in range(nTrials):\n",
    "        exp_name = base_name\n",
    "        if nTrials != 1:\n",
    "            exp_name += \"_t{}\".format(n)\n",
    "\n",
    "        if exp_version is not None:\n",
    "            exp_name += \"_e_{}\".format(exp_version)\n",
    "\n",
    "        # Path to save output files, like losses, scores, figures etc.\n",
    "        report_path = os.path.join(RESULTS_ROOT, exp_name)\n",
    "        os.makedirs(report_path, exist_ok=True)\n",
    "\n",
    "        # Initialize test performances\n",
    "        test_performances = []\n",
    "\n",
    "        # Progress tracking file path\n",
    "        progress_file = os.path.join(report_path, \"latest_subset.json\")\n",
    "\n",
    "        for exponent in exponents:\n",
    "            i = 2 ** exponent\n",
    "\n",
    "            # Deep copy: a shallow cfg.copy() shares the nested \"data\" dict, so\n",
    "            # the per-subset overrides below would leak back into cfg.\n",
    "            cur_cfg = copy.deepcopy(cfg)\n",
    "            cur_cfg[\"data\"][\"batch_size\"] = 16\n",
    "            cur_cfg[\"data\"][\"num_subsets\"] = i\n",
    "            cur_exp_name = \"subset{}\".format(i)\n",
    "            cur_cfg[\"exp_name\"] = cur_exp_name\n",
    "\n",
    "            # Path to save output files for the current subset\n",
    "            cur_report_path = os.path.join(report_path, cur_exp_name)\n",
    "            os.makedirs(cur_report_path, exist_ok=True)\n",
    "\n",
    "            if verbose:\n",
    "                print(f\"Current exp: {cur_exp_name}\")\n",
    "\n",
    "            # Initialize and run training\n",
    "            p = dynamic.import_string(cur_cfg[\"engine\"])(cur_cfg, cur_report_path)\n",
    "            p.verbose = True  # engine verbosity stays on regardless of `verbose`\n",
    "            p.train(cfg[\"train\"])\n",
    "            test_loss, test_acc = p.test()\n",
    "            print(\"Subset size {}, test accuracy {}\".format((i + 1) * 20, test_acc))\n",
    "            test_performances.append(test_acc)\n",
    "\n",
    "            # Save test performances\n",
    "            with open(os.path.join(report_path, \"test_performances.json\"), \"w\") as outfile:\n",
    "                outfile.write(json.dumps(test_performances))\n",
    "\n",
    "            # Update and save progress to the latest_subset file\n",
    "            with open(progress_file, \"w\") as outfile:\n",
    "                json.dump({\"latest_subset\": i + 1}, outfile)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb1426c7-ad2e-4545-a520-0ad163f58bbf",
   "metadata": {},
   "source": [
    "# Set fusion method and data source here\n",
    "\n",
    "To train all models listed in the paper, you must run the training code using all configs, both with and without `synth == True`.\n",
    "\n",
    "Not all settings must be run to plot results; for the easiest test, do a single training run and plot the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "607679ee-ee25-4196-abd6-33b0deff720d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Pick one config (fusion method) and one data source per run\n",
    "all_cfg = [\"concat_tweet.yaml\", \"tensorfusion_tweet.yaml\", \"product_tweet.yaml\"]\n",
    "cfg = all_cfg[0]\n",
    "synth = False\n",
    "\n",
    "# The experiment tag mirrors the data source; main_train appends it to the results folder name\n",
    "exp = \"synth\" if synth else \"basic\"\n",
    "\n",
    "main_train(config=cfg, exp_version=exp, reset=False, nTrials=1, synth=synth, shorten=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f9947cd-cccd-47b1-9bce-cd95540dec43",
   "metadata": {},
   "source": [
    "# Accuracy Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f62c78d-abb5-43c5-9625-8f324d34169d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the per-subset test accuracies of every \"tweet\" run under RESULTS_ROOT\n",
    "folders = [f for f in os.listdir(RESULTS_ROOT) if \"tweet\" in f]\n",
    "all_data = {}\n",
    "for f in folders:\n",
    "    log_file = os.path.join(RESULTS_ROOT, f, \"test_performances.json\")\n",
    "    # isfile() skips runs without results and plain files whose name contains\n",
    "    # \"tweet\" (calling os.listdir() on those raises NotADirectoryError).\n",
    "    if not os.path.isfile(log_file):\n",
    "        continue\n",
    "    with open(log_file, 'r') as file:\n",
    "        data = json.load(file)\n",
    "    all_data[f] = data\n",
    "\n",
    "window_size = 2  # Smoothing window size\n",
    "\n",
    "colors = sorted(\n",
    "            mcolors.BASE_COLORS, key=lambda c: tuple(mcolors.rgb_to_hsv(mcolors.to_rgb(c))))\n",
    "\n",
    "# One list of runs per (fusion method, data source) pair, aligned with `labels`\n",
    "concat_base = []\n",
    "concat_synth = []\n",
    "\n",
    "prod_base = []\n",
    "prod_synth = []\n",
    "\n",
    "tfuse_base = []\n",
    "tfuse_synth = []\n",
    "\n",
    "all_sets = [concat_base,concat_synth,prod_base,prod_synth,tfuse_base,tfuse_synth]\n",
    "labels = [\"concat_base\",\"concat_synth\",\"prod_base\",\"prod_synth\",\"tfuse_base\",\"tfuse_synth\"]\n",
    "for t in list(all_data.keys()):\n",
    "    # Index layout: +0 concat, +2 product, +4 tensorfusion; +1 for synthetic data\n",
    "    idx = 0\n",
    "    if \"prod\" in t:\n",
    "        idx += 2\n",
    "    elif \"tensorfusion\" in t:\n",
    "        idx += 4\n",
    "\n",
    "    if \"synth\" in t:\n",
    "        idx += 1\n",
    "    all_sets[idx].append(all_data[t])\n",
    "\n",
    "all_sets = [np.array(subset) for subset in all_sets]\n",
    "\n",
    "all_means = []\n",
    "all_stds = []\n",
    "\n",
    "for model_data_pair in all_sets:\n",
    "    if model_data_pair.size == 0:\n",
    "        # No run for this pair: keep an empty placeholder so the lists stay\n",
    "        # aligned with `labels` without a \"mean of empty slice\" warning.\n",
    "        all_means.append(np.array([]))\n",
    "        all_stds.append(np.array([]))\n",
    "        continue\n",
    "    all_means.append(model_data_pair.mean(axis=0))\n",
    "    all_stds.append(model_data_pair.std(axis=0))\n",
    "\n",
    "# Size the x axis from the first group that has data; the first group may be empty\n",
    "x = np.arange(next((len(mean) for mean in all_means if mean.size), 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3e78945-2ce4-40f3-886e-0962a3df3dac",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(10, 5))\n",
    "\n",
    "# Replacements that turn the short group labels into legend wording, applied in order\n",
    "legend_wording = ((\"_\", \" \"), (\"synth\", \"synthetic\"), (\"tfuse\", \"tensorfusion\"), (\"prod\", \"product\"), (\"base\", \"real\"))\n",
    "\n",
    "for mean, std, label in zip(all_means, all_stds, labels):\n",
    "    # Dashed lines mark synthetic data; the colour encodes the fusion method\n",
    "    line_style = 'dashed' if \"synth\" in label else 'solid'\n",
    "    color = \"blue\" if \"concat\" in label else ('red' if \"tfuse\" in label else 'green')\n",
    "\n",
    "    show_label = label\n",
    "    for old, new in legend_wording:\n",
    "        show_label = show_label.replace(old, new)\n",
    "\n",
    "    # Skip groups whose mean is empty or entirely NaN\n",
    "    if mean.size == 0 or np.isnan(mean).all():\n",
    "        continue\n",
    "\n",
    "    print(mean)\n",
    "    errorbar_style = dict(capsize=2, elinewidth=2, markeredgewidth=1, linestyle=line_style, color=color)\n",
    "    plt.errorbar(x, mean, yerr=std, label=show_label, **errorbar_style)\n",
    "\n",
    "plt.xlabel(\"Dataset Size in Chunks of 2^(x+5)\")\n",
    "plt.ylabel(\"Validation Accuracy\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89ac7fbe-f442-4264-afec-8fae6799e3b5",
   "metadata": {},
   "source": [
    "# ROC Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d0b6cf3-cc11-4f6e-a20e-1e2944e8cb03",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Results root scanned for trial folders\n",
    "base_dir = RESULTS_ROOT\n",
    "\n",
    "# Group keys are \"{fusion}_{env}\", following the naming convention of the result folders\n",
    "fusion_methods = [\"concat\", \"prodconcat\", \"tensorfusion\"]\n",
    "environments = [\"e_demo\", \"e_basic\", \"e_synth\"]\n",
    "\n",
    "# Every group starts with an empty list of trial folders\n",
    "group_data = {}\n",
    "for fusion in fusion_methods:\n",
    "    for env in environments:\n",
    "        group_data[f\"{fusion}_{env}\"] = []\n",
    "\n",
    "# All trials are the sub-directories of the results root\n",
    "trial_folders = []\n",
    "for folder in os.listdir(base_dir):\n",
    "    if os.path.isdir(os.path.join(base_dir, folder)):\n",
    "        trial_folders.append(folder)\n",
    "\n",
    "# Assign each trial folder to the groups whose identifiers occur in its name\n",
    "for folder in trial_folders:\n",
    "    for fusion in fusion_methods:\n",
    "        # \"concat\" is a substring of \"prodconcat\": keep product runs out of the concat groups\n",
    "        if fusion not in folder or (fusion == \"concat\" and \"prodconcat\" in folder):\n",
    "            continue\n",
    "        for env in environments:\n",
    "            if env in folder:\n",
    "                group_data[f\"{fusion}_{env}\"].append(folder)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf8f7ec-118a-49b8-9668-2f752876a072",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Subset folder names and sizes: one per power of two from 32 to 32768\n",
    "subset_folders = [f\"subset{2 ** exponent}\" for exponent in range(5, 16)]\n",
    "sizes = [int(name[len(\"subset\"):]) for name in subset_folders]\n",
    "\n",
    "# Scale the sizes to [0, 1] to index the colormap\n",
    "smallest, largest = min(sizes), max(sizes)\n",
    "normalized_sizes = (np.array(sizes) - smallest) / (largest - smallest)\n",
    "colors = cm.viridis(normalized_sizes)  # Gradient colors for subsets\n",
    "\n",
    "# Function to compute ROC curve for subsets within a group\n",
    "def group_subset_roc(group_folders):\n",
    "    \"\"\"Average the ROC curves of all trial folders in a group, per subset.\n",
    "\n",
    "    For each folder and each entry of `subset_folders`, the file\n",
    "    \"{subset}_test_results.json\" is read and a ROC curve is computed from\n",
    "    column 1 of its 'ground_truth' and 'predict' arrays. Missing files are\n",
    "    reported and skipped. Curves of the same subset are interpolated onto a\n",
    "    common 100-point FPR grid and averaged.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping subset name -> (mean_fpr, mean_tpr); subsets without any\n",
    "        result file are absent.\n",
    "    \"\"\"\n",
    "    curves_by_subset = defaultdict(list)\n",
    "    for folder in group_folders:\n",
    "        for subset in subset_folders:\n",
    "            results_file = os.path.join(base_dir, folder, subset, f\"{subset}_test_results.json\")\n",
    "            if os.path.exists(results_file):\n",
    "                with open(results_file, 'r') as handle:\n",
    "                    payload = json.load(handle)\n",
    "\n",
    "                # presumably column 1 holds the positive class - verify against the engine's output\n",
    "                true_labels = np.array(payload['ground_truth'])[:, 1]\n",
    "                scores = np.array(payload['predict'])[:, 1]\n",
    "\n",
    "                fpr, tpr, _ = roc_curve(true_labels, scores)\n",
    "                curves_by_subset[subset].append((fpr, tpr))\n",
    "            else:\n",
    "                print(f\"Test results file not found: {results_file}\")\n",
    "\n",
    "    # Average the TPRs across trials on a shared FPR grid\n",
    "    averaged = {}\n",
    "    for subset, curves in curves_by_subset.items():\n",
    "        fpr_grid = np.linspace(0, 1, 100)\n",
    "        interpolated = [np.interp(fpr_grid, fpr, tpr) for fpr, tpr in curves]\n",
    "        averaged[subset] = (fpr_grid, np.mean(interpolated, axis=0))\n",
    "\n",
    "    return averaged\n",
    "\n",
    "\n",
    "all_means = {key: group_subset_roc(group_data[key]) for key in group_data.keys()}\n",
    "\n",
    "# One panel per group that has trial folders. Panel names come from the\n",
    "# group_data keys: names rebuilt from folder names miss the keys as soon as a\n",
    "# folder carries a trial suffix (\"_t0\") or is not an experiment folder at all.\n",
    "groups = [name for name, members in group_data.items() if members]\n",
    "\n",
    "# Plotting: keep the 2x3 grid and only add rows when there are more groups than panels\n",
    "n_cols = 3\n",
    "n_rows = max(2, -(-len(groups) // n_cols))\n",
    "fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows))\n",
    "axes = axes.ravel()\n",
    "\n",
    "for idx, group_name in enumerate(groups):\n",
    "    mean_results = all_means[group_name]\n",
    "    ax = axes[idx]\n",
    "\n",
    "    # One colour per fusion method; \"prodconcat\" is tested first because it contains \"concat\"\n",
    "    if \"prodconcat\" in group_name:\n",
    "        color = 'green'\n",
    "    elif \"tensorfusion\" in group_name:\n",
    "        color = 'red'\n",
    "    else:\n",
    "        color = \"blue\"\n",
    "\n",
    "    for i, subset in enumerate(subset_folders, start=1):\n",
    "        if subset not in mean_results:\n",
    "            continue\n",
    "        mean_fpr, mean_tpr = mean_results[subset]\n",
    "\n",
    "        # Opacity grows with the position of the subset in subset_folders\n",
    "        alpha = i / 15\n",
    "\n",
    "        ax.plot(mean_fpr, mean_tpr, color=color, alpha=alpha, lw=2,\n",
    "                label=f\"N={subset[6:]}\")\n",
    "\n",
    "    # Plot random guess line and formatting\n",
    "    ax.plot([0, 1], [0, 1], color=\"gray\", linestyle=\"--\", lw=1)\n",
    "    ax.set_title(group_name.replace(\"_\", \" \").replace(\" e \", \" \").replace(\"synth\", \"synthetic\").replace(\"tfuse\", \"tensorfusion\").replace(\"prodconcat\", \"product concat\").replace(\"basic\", \"real\"), fontsize=14)\n",
    "    ax.set_xlim([0.0, 1.0])\n",
    "    ax.set_ylim([0.0, 1.05])\n",
    "    ax.set_xlabel(\"False Positive Rate\")\n",
    "    ax.set_ylabel(\"True Positive Rate\")\n",
    "    # A legend without labelled curves only produces a warning\n",
    "    if ax.get_legend_handles_labels()[0]:\n",
    "        ax.legend(loc=\"lower right\", fontsize=8)\n",
    "    ax.grid(alpha=0.3)\n",
    "\n",
    "# Hide panels that have no group to show\n",
    "for ax in axes[len(groups):]:\n",
    "    ax.set_visible(False)\n",
    "\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a60f743e-9dbf-4628-bfde-eec672573718",
   "metadata": {},
   "source": [
    "# Robustness Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7283496d-a5ce-44a9-bfe3-dbb70ebadc8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the order of subsets by dataset size\n",
    "subsets_order = [\n",
    "    \"subset32\",\n",
    "    \"subset64\",\n",
    "    \"subset128\",\n",
    "    \"subset256\",\n",
    "    \"subset512\",\n",
    "    \"subset1024\",\n",
    "    \"subset2048\",\n",
    "    \"subset4096\",\n",
    "    \"subset8192\",\n",
    "    \"subset16384\",\n",
    "    \"subset32768\"\n",
    "]\n",
    "\n",
    "def parse_model_name(model_name):\n",
    "    \"\"\"\n",
    "    Parse a results directory name of the form\n",
    "    tweet_{FUSION_TYPE}_e_{DATA_SOURCE} or tweet_{FUSION_TYPE}_t{NUMBER}_e_{DATA_SOURCE}.\n",
    "\n",
    "    The data source is the token that follows the \"e\" marker, so the optional\n",
    "    trial token does not shift it. Names without an \"e\" marker fall back to\n",
    "    the fourth token.\n",
    "\n",
    "    Returns (fusion_type, data_source), or (None, None) when the name has too\n",
    "    few tokens to contain both.\n",
    "    \"\"\"\n",
    "    tokens = model_name.split(\"_\")\n",
    "    if len(tokens) < 4:\n",
    "        return None, None\n",
    "\n",
    "    fusion_type = tokens[1]\n",
    "    if \"e\" in tokens[2:]:\n",
    "        source_index = tokens.index(\"e\", 2) + 1\n",
    "        if source_index >= len(tokens):\n",
    "            # The marker is the last token: there is no data source to return\n",
    "            return None, None\n",
    "        return fusion_type, tokens[source_index]\n",
    "    return fusion_type, tokens[3]\n",
    "\n",
    "def find_test_performance(test_perf_path,trials):\n",
    "    \"\"\"\n",
    "    Given a path to test_performances.json, return both the maximum performance score\n",
    "    and the subset that gave that best performance.\n",
    "\n",
    "    The scores are stored in ascending subset-size order, while `trials` may\n",
    "    arrive in any order (e.g. straight from os.listdir), so the subset names\n",
    "    are sorted by their numeric size before the best index is looked up.\n",
    "\n",
    "    Returns (max_score, best_subset); best_subset is \"Unknown subset\" when\n",
    "    there are fewer subset names than scores.\n",
    "    Raises ValueError when the file holds no scores.\n",
    "    \"\"\"\n",
    "    def subset_order(name):\n",
    "        # \"subset{N}\" names sort by N; anything else goes last, by name\n",
    "        suffix = name[len(\"subset\"):] if name.startswith(\"subset\") else \"\"\n",
    "        return (0, int(suffix), name) if suffix.isdecimal() else (1, 0, name)\n",
    "\n",
    "    with open(test_perf_path, 'r') as f:\n",
    "        perf_data = json.load(f)\n",
    "\n",
    "    if not perf_data:\n",
    "        raise ValueError(f\"No test performances recorded in {test_perf_path}\")\n",
    "\n",
    "    ordered_trials = sorted(trials, key=subset_order)\n",
    "\n",
    "    # perf_data is a list of scores corresponding to subsets in ascending order.\n",
    "    max_score = max(perf_data)\n",
    "    best_index = perf_data.index(max_score)\n",
    "    best_subset = ordered_trials[best_index] if best_index < len(ordered_trials) else \"Unknown subset\"\n",
    "\n",
    "    return max_score, best_subset\n",
    "\n",
    "def get_best_models(models_root=\"models\"):\n",
    "    \"\"\"Pick the best-scoring run for every (fusion_type, data_source) pair.\n",
    "\n",
    "    Every sub-directory of `models_root` whose name parse_model_name() accepts\n",
    "    and that contains test_performances.json is scored with\n",
    "    find_test_performance(); the highest score per pair wins and ties keep\n",
    "    the run seen first. Accepted directory names and the winners are printed.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping (fusion_type, data_source) -> (best_score, model_path, best_subset).\n",
    "    \"\"\"\n",
    "    winners = {}\n",
    "\n",
    "    for model_dir in os.listdir(models_root):\n",
    "        model_path = os.path.join(models_root, model_dir)\n",
    "        if os.path.isdir(model_path):\n",
    "            fusion_type, data_source = parse_model_name(model_dir)\n",
    "            perf_file = os.path.join(model_path, \"test_performances.json\")\n",
    "\n",
    "            # Only directories that match the naming pattern and hold results take part\n",
    "            if fusion_type and data_source and os.path.exists(perf_file):\n",
    "                subset_dirs = [entry for entry in os.listdir(model_path) if \"subset\" in entry and \".json\" not in entry]\n",
    "                score, best_subset = find_test_performance(perf_file, subset_dirs)\n",
    "\n",
    "                key = (fusion_type, data_source)\n",
    "                current = winners.get(key)\n",
    "                if current is None or score > current[0]:\n",
    "                    winners[key] = (score, model_path, best_subset)\n",
    "                print(model_dir)\n",
    "\n",
    "    # Print out the best models\n",
    "    for (fusion_type, data_source), (score, path, subset) in winners.items():\n",
    "        print(f\"For (fusion_type={fusion_type}, data_source={data_source}), best model: {path} with score {score} on {subset}\")\n",
    "    return winners\n",
    "    \n",
    "best_models = get_best_models(RESULTS_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71fd5d05-bbad-4c2b-a370-d9c1e9f4742f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_config_for(fusion_type, data_source, conf_root=None):\n",
    "    \"\"\"Load the YAML config that belongs to a fusion type.\n",
    "\n",
    "    Args:\n",
    "        fusion_type: One of \"concat\", \"prodconcat\", \"tensorfusion\".\n",
    "        data_source: Not used for the lookup; accepted so callers can pass the\n",
    "            (fusion_type, data_source) pair they iterate over.\n",
    "        conf_root: Directory holding the config files. Defaults to\n",
    "            PROJECT_ROOT/config/tweet, the location main_train() reads from.\n",
    "\n",
    "    Returns:\n",
    "        (cfg_data, cfg_train, config): the 'data' and 'train' sections (empty\n",
    "        dicts when absent) and the full parsed config.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If fusion_type is not a supported type.\n",
    "        FileNotFoundError: If the config file does not exist.\n",
    "    \"\"\"\n",
    "    # Map each fusion_type to the corresponding configuration file\n",
    "    config_map = {\n",
    "        \"concat\": \"concat_tweet.yaml\",\n",
    "        \"prodconcat\": \"product_tweet.yaml\",\n",
    "        \"tensorfusion\": \"tensorfusion_tweet.yaml\"\n",
    "    }\n",
    "\n",
    "    if fusion_type not in config_map:\n",
    "        raise ValueError(f\"Unknown fusion_type '{fusion_type}'. Supported types: {list(config_map.keys())}\")\n",
    "\n",
    "    if conf_root is None:\n",
    "        # Resolved at call time, relative to the project, so the notebook does\n",
    "        # not depend on an absolute path of one particular machine.\n",
    "        conf_root = os.path.join(PROJECT_ROOT, \"config\", \"tweet\")\n",
    "\n",
    "    config_file = os.path.join(conf_root, config_map[fusion_type])\n",
    "\n",
    "    if not os.path.exists(config_file):\n",
    "        raise FileNotFoundError(f\"Configuration file {config_file} not found.\")\n",
    "\n",
    "    with open(config_file, 'r') as f:\n",
    "        config = yaml.safe_load(f)\n",
    "\n",
    "    # Extract the data and train configurations; missing sections become empty dicts\n",
    "    cfg_data = config.get('data', {})\n",
    "    cfg_train = config.get('train', {})\n",
    "\n",
    "    return cfg_data, cfg_train, config\n",
    "\n",
    "\n",
    "def add_noise_to_batch(input_tuple, noise_level=0.1, apply_to='text'):\n",
    "    \"\"\"\n",
    "    Adds Gaussian noise to one or both modalities of a batch.\n",
    "\n",
    "    input_tuple is expected to be (Vv, Tt, Y, V). Vv and Tt are cast to float\n",
    "    and returned, noisy or not, in the same positions; Y and V pass through.\n",
    "    noise_level: factor the sampled noise is multiplied by before it is added.\n",
    "    apply_to: 'text' (noise on Tt), 'image' (noise on Vv) or 'both'. Any other\n",
    "        value returns the float-cast batch without noise.\n",
    "    \"\"\"\n",
    "    Vv, Tt, Y, V = input_tuple\n",
    "\n",
    "    # Convert to float tensor for consistent operations\n",
    "    Vv = Vv.float()\n",
    "    Tt = Tt.float()\n",
    "\n",
    "    # Noise is drawn as randn * params[0] + params[1], one pair per modality.\n",
    "    # NOTE(review): confirm whether the pairs are (std, mean) as applied here\n",
    "    # or were meant as (mean, std).\n",
    "    v_params = torch.tensor([0.0333, 6.1028])\n",
    "    t_params = torch.tensor([0.0185, 0.3924])\n",
    "\n",
    "    def sample_noise(reference, params):\n",
    "        return torch.randn_like(reference) * params[0] + params[1]\n",
    "\n",
    "    if apply_to == 'text':\n",
    "        # Text noise uses the text parameters, as the 'both' branch does\n",
    "        # (the image parameters were applied here before).\n",
    "        Tt_noisy = Tt + (sample_noise(Tt, t_params) * noise_level)\n",
    "        return (Vv, Tt_noisy, Y, V)\n",
    "    if apply_to == 'image':\n",
    "        Vv_noisy = Vv + (sample_noise(Vv, v_params) * noise_level)\n",
    "        return (Vv_noisy, Tt, Y, V)\n",
    "    if apply_to == 'both':\n",
    "        # Image noise is sampled before text noise\n",
    "        v_noise = sample_noise(Vv, v_params)\n",
    "        t_noise = sample_noise(Tt, t_params)\n",
    "        Vv_noisy = Vv + (v_noise * noise_level)\n",
    "        Tt_noisy = Tt + (t_noise * noise_level)\n",
    "        return (Vv_noisy, Tt_noisy, Y, V)\n",
    "\n",
    "    # No noise if apply_to is invalid\n",
    "    return (Vv, Tt, Y, V)\n",
    "\n",
    "\n",
    "def evaluate_with_noise(engine, loader, noise_levels=[0.0, 0.05, 0.1, 0.2], apply_to='text'):\n",
    "    \"\"\"\n",
    "    Measure cross-entropy loss and accuracy of `engine` on `loader` per noise level.\n",
    "\n",
    "    Every batch goes through add_noise_to_batch() before engine.forward_pass().\n",
    "    Returns a dict mapping noise level -> {'ce_loss': ..., 'accuracy': ...}.\n",
    "    \"\"\"\n",
    "    engine.set_eval()\n",
    "    results = {}\n",
    "\n",
    "    for nl in noise_levels:\n",
    "        label_batches = []\n",
    "        output_batches = []\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for input_tuple in loader:\n",
    "                output, labels = engine.forward_pass(\n",
    "                    add_noise_to_batch(input_tuple, noise_level=nl, apply_to=apply_to)\n",
    "                )\n",
    "                label_batches.append(labels.long().cpu().data.numpy())\n",
    "                output_batches.append(output.cpu().data.numpy())\n",
    "\n",
    "        # Stack the per-batch arrays into one array each\n",
    "        ground_truth = np.concatenate(label_batches, axis=0)\n",
    "        predict = np.concatenate(output_batches, axis=0)\n",
    "\n",
    "        ce_loss = log_loss(ground_truth, predict)\n",
    "\n",
    "        # argmax over axis 1 on both arrays: the labels are treated as one-hot rows\n",
    "        ground_truth_labels = np.argmax(ground_truth, axis=1)\n",
    "        predict_labels = np.argmax(predict, axis=1)\n",
    "        accuracy = accuracy_score(ground_truth_labels, predict_labels)\n",
    "\n",
    "        results[nl] = {'ce_loss': ce_loss, 'accuracy': accuracy}\n",
    "\n",
    "    return results\n",
    "\n",
    "\n",
    "def trial(noise_levels = (0.0, 0.05, 0.1, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95), apply_noise_to = 'text'):\n",
    "    \"\"\"Evaluate every model in `best_models` on its validation set under input noise.\n",
    "\n",
    "    Args:\n",
    "        noise_levels: Noise levels passed on to evaluate_with_noise().\n",
    "        apply_noise_to: Modality that receives the noise ('text', 'image' or 'both').\n",
    "\n",
    "    Returns:\n",
    "        dict mapping (fusion_type, data_source) -> results of evaluate_with_noise().\n",
    "        Models whose checkpoint file is missing are reported and left out.\n",
    "    \"\"\"\n",
    "    experiment_results = {}\n",
    "    print(f\"Running with noise on {apply_noise_to}\")\n",
    "    for (fusion_type, data_source), (best_score, model_path, best_subset) in tqdm(best_models.items()):\n",
    "        # Look for the checkpoint first: building the engine and its dataset\n",
    "        # is wasted work for a model that is skipped anyway.\n",
    "        subset_path = os.path.join(model_path, best_subset, \"model.pth.tar\")\n",
    "        if not os.path.exists(subset_path):\n",
    "            print(\"Skipping: \", subset_path)\n",
    "            continue\n",
    "\n",
    "        cfg_data, cfg_train, cfg = get_config_for(fusion_type, data_source)\n",
    "\n",
    "        synth = \"synth\" in data_source\n",
    "\n",
    "        # The flag is only filled in when the YAML does not set it.\n",
    "        # NOTE(review): main_train() overrides it unconditionally - confirm\n",
    "        # which behaviour is wanted here.\n",
    "        if \"synthetic\" not in cfg[\"data\"].keys():\n",
    "            cfg[\"data\"][\"synthetic\"] = synth\n",
    "            cfg_data[\"synthetic\"] = synth\n",
    "\n",
    "        result_path = \"out\"\n",
    "\n",
    "        # NOTE(review): FixTensorFusion is used for every fusion type, whereas\n",
    "        # main_train() builds the class named in cfg[\"engine\"] - confirm.\n",
    "        engine = FixTensorFusion(cfg, result_path)\n",
    "        engine.init_dataset(cfg_data)\n",
    "        engine.init_models(cfg_train)\n",
    "\n",
    "        # Load the best model weights. weights_only=False unpickles arbitrary\n",
    "        # objects, so only load checkpoints produced by this project.\n",
    "        checkpoint = torch.load(subset_path, map_location=device,weights_only=False)\n",
    "\n",
    "        for m in engine.trained_models:\n",
    "            if m != \"fusion\":\n",
    "                getattr(engine, m).load_state_dict(checkpoint[m])\n",
    "\n",
    "        engine.set_eval()\n",
    "\n",
    "        # Evaluation runs on the engine's validation set, in fixed order\n",
    "        val_loader = torch.utils.data.DataLoader(\n",
    "            engine.val_set,\n",
    "            batch_size=cfg_data[\"val_batch_size\"],\n",
    "            shuffle=False\n",
    "        )\n",
    "\n",
    "        # Evaluate model performance under increasing noise levels\n",
    "        results = evaluate_with_noise(engine, val_loader, noise_levels=noise_levels, apply_to=apply_noise_to)\n",
    "        print(results)\n",
    "        experiment_results[(fusion_type, data_source)] = results\n",
    "    return experiment_results\n",
    "\n",
    "\n",
    "\n",
    "t_data = trial(apply_noise_to=\"text\")\n",
    "i_data = trial(apply_noise_to=\"image\")\n",
    "b_data = trial(apply_noise_to=\"both\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8073d2e0-a83b-4d55-85d0-270d81a05b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({\n",
    "    'font.size': 14,\n",
    "    'axes.labelsize': 14,\n",
    "    'axes.titlesize': 16,\n",
    "    'legend.fontsize': 12,\n",
    "    'xtick.labelsize': 12,\n",
    "    'ytick.labelsize': 12,\n",
    "    'font.family': 'sans-serif',\n",
    "    'font.sans-serif': ['DejaVu Sans'],\n",
    "    'figure.figsize': (12, 6)\n",
    "})\n",
    "\n",
    "# One panel per noise target\n",
    "fig, ax_acc = plt.subplots(1,3)\n",
    "\n",
    "# Define color for each fusion type\n",
    "fusion_colors = {\n",
    "    'prodconcat': 'green',\n",
    "    'concat': 'blue',\n",
    "    'tensorfusion': 'red'\n",
    "}\n",
    "\n",
    "# Define line style for each data source\n",
    "source_styles = {\n",
    "    'synth': '--',\n",
    "    'basic': '-',\n",
    "    'demo': '-'\n",
    "}\n",
    "\n",
    "# Markers per data source; not applied, the curves are drawn without markers\n",
    "source_markers = {\n",
    "    'synth': 'o',\n",
    "    'basic': 's',\n",
    "    'demo': 's'\n",
    "}\n",
    "\n",
    "all_res = [t_data, i_data, b_data]\n",
    "labels = [\"Textual Noise\", \"Image Noise\", \"Both Noised\"]\n",
    "noise_levels = [0.0, 0.05, 0.1, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]\n",
    "\n",
    "for i, ax in enumerate(ax_acc.flatten()):\n",
    "    # Plot accuracy\n",
    "    results = all_res[i]\n",
    "    for key, noise_data in results.items():\n",
    "        fusion_type, data_source = key\n",
    "        # Extract accuracies for each noise level\n",
    "        accuracies = [noise_data[n]['accuracy'] for n in noise_levels]\n",
    "\n",
    "        color = fusion_colors[fusion_type]\n",
    "        linestyle = source_styles[data_source]\n",
    "\n",
    "        ax.plot(\n",
    "            noise_levels, accuracies, \n",
    "            linestyle=linestyle, \n",
    "            color=color, \n",
    "            markersize=6,\n",
    "            linewidth=2,\n",
    "            label=f\"{fusion_type}-{data_source}\".replace(\"-\", \" \").replace(\"basic\", \"Real\").replace(\"prodconcat\", \"Product Concat\").replace(\"tensorfusion\", \"tensor-fusion\").title()\n",
    "        )\n",
    "\n",
    "    # Set subplot title\n",
    "    ax.set_title(labels[i])\n",
    "\n",
    "    # All panels use the same 0-1 accuracy range, so only the first one shows\n",
    "    # the y tick labels and the y-axis label\n",
    "    if i > 0:\n",
    "        ax.tick_params(axis='y', which='both', labelleft=False)\n",
    "    else:\n",
    "        ax.set_ylabel(\"Accuracy\")\n",
    "\n",
    "    ax.set_xlabel(\"Noise Level\")\n",
    "\n",
    "    # Add a subtle grid\n",
    "    ax.grid(True, linestyle=':', linewidth=0.7, alpha=0.8)\n",
    "    ax.set_ylim(0, 1)\n",
    "\n",
    "    # Add the legend only to the last panel, with entries sorted by label\n",
    "    if i == 2:\n",
    "        # Get current handles and labels from the plot\n",
    "        handles, labels_legend = ax.get_legend_handles_labels()\n",
    "\n",
    "        # Without any plotted model there is nothing to sort or unpack\n",
    "        if handles:\n",
    "            sorted_pairs = sorted(zip(labels_legend, handles), key=lambda x: x[0])\n",
    "            sorted_labels, sorted_handles = zip(*sorted_pairs)\n",
    "\n",
    "            # Now create the legend with the sorted entries\n",
    "            ax.legend(sorted_handles, sorted_labels, loc='best')\n",
    "\n",
    "# Adjust layout for better spacing\n",
    "plt.tight_layout()\n",
    "\n",
    "# Save and show the figure\n",
    "plt.savefig(\"model_performance_noise_accuracy_only.png\", dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "77b7b16a-bcf1-44db-a2a7-1eb9aa4a5361",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataset.tweet.dataset as tweet_set\n",
    "from datasets import load_dataset, Dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "dc87b43c-168c-4e4f-94ff-bd5a4c73fca1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "subsets contain entire dataset, using entire dataset\n",
      "subsets contain entire dataset, using entire dataset\n",
      "subsets contain entire dataset, using entire dataset\n"
     ]
    }
   ],
   "source": [
    "# Load the tweet 'concat' experiment config; only its `data` section is used below.\n",
    "config_path = os.path.join(PROJECT_ROOT, \"config\", \"tweet\", \"concat_tweet.yaml\")\n",
    "with open(config_path, 'r') as f:\n",
    "    cfg = yaml.load(f, Loader=yaml.FullLoader)\n",
    "\n",
    "# Presumably far more subsets than exist, so every split uses the entire dataset\n",
    "# (the cell output reports exactly that) -- confirm against tweet_set.Product.\n",
    "cfg['data']['num_subsets'] = 100_000_000\n",
    "cfg_data = cfg['data']\n",
    "\n",
    "# Notebook-level default, used only when the config does not define `synthetic`.\n",
    "synth = True\n",
    "# cfg_data is the same dict object as cfg['data'], so one write updates both.\n",
    "use_synth = cfg_data.setdefault(\"synthetic\", synth)\n",
    "\n",
    "# Name the Hub configuration after the flag that actually builds the datasets,\n",
    "# so the name cannot disagree with the data when the YAML already sets `synthetic`.\n",
    "conf_name = \"synth\" if use_synth else \"basic\"\n",
    "\n",
    "train_set = tweet_set.Product(\"train\", cfg_data[\"num_subsets\"], synth=cfg_data[\"synthetic\"])\n",
    "test_set = tweet_set.Product(\"test\", cfg_data[\"num_subsets\"], synth=cfg_data[\"synthetic\"])\n",
    "val_set = tweet_set.Product(\"val\", cfg_data[\"num_subsets\"], synth=cfg_data[\"synthetic\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "ebf1e2d2-f978-4702-a6a3-afe27eeff723",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(768,)\n",
      "(768,)\n",
      "(2,)\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# Take the first training example and report its four fields.\n",
    "elem = next(iter(train_set))\n",
    "\n",
    "# Fields 0-2 expose `.shape`; field 3 is printed as-is.\n",
    "for field_idx in range(3):\n",
    "    print(elem[field_idx].shape)\n",
    "print(elem[3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "eef3ac0c-85b3-4b7c-aac4-585cd3336b56",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _to_records(split):\n",
    "    \"\"\"Turn an iterable of 4-item examples into a list of dict records.\n",
    "\n",
    "    Each example is read positionally: item 0 -> 'img', item 1 -> 'text',\n",
    "    item 2 -> 'label', item 3 -> 'value'.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        {'img': el[0], 'text': el[1], 'label': el[2], 'value': el[3]}\n",
    "        for el in split\n",
    "    ]\n",
    "\n",
    "\n",
    "# One list of records per split, in the form later passed to Dataset.from_list.\n",
    "train_list = _to_records(train_set)\n",
    "test_list = _to_records(test_set)\n",
    "val_list = _to_records(val_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "58e4e823-91a2-4fea-862b-e46e50eaa3e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one `Dataset` per split from its list of records (train, test, val order).\n",
    "train_dset, test_dset, val_dset = (\n",
    "    Dataset.from_list(records)\n",
    "    for records in (train_list, test_list, val_list)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "53f3f964-fdaf-4da9-b7d0-f6703142e1b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fae72aaa6fb045eeb64e32024f3ee0f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1acec357006e4238a79b759ee59650c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c84068e9122d4f8e93837fb5d7ab5946",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/2.09k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "67962a4ff7504f67b177b3031a5bdb95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b718f980036a4bc59f36e4e3db567721",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "21f5d9b1c3a54f648bb5cf05879d495f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/2.09k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee1806bcad1d4c7da72c1b51a939fbfd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7f826c9336a4b2590d015d14ca2e411",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7dc462b2d36343a89a340b6cb13a1722",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/2.09k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/datasets/icmlmossanonymousauthor2025/moss_submission_tweet_dataset/commit/004d8683bbf85e961a52f393a6518eaba63676d5', commit_message='Upload dataset', commit_description='', oid='004d8683bbf85e961a52f393a6518eaba63676d5', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/icmlmossanonymousauthor2025/moss_submission_tweet_dataset', endpoint='https://huggingface.co', repo_type='dataset', repo_id='icmlmossanonymousauthor2025/moss_submission_tweet_dataset'), pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The Hub write token is read from the environment so that no secret is stored in\n",
    "# the notebook. When HF_TOKEN is unset, None is passed and push_to_hub falls back\n",
    "# to the token cached by `huggingface-cli login`.\n",
    "# NOTE(review): a write token was previously committed in this cell -- revoke it on the Hub.\n",
    "write_token = os.environ.get('HF_TOKEN')\n",
    "dset_path = 'icmlmossanonymousauthor2025/moss_submission_tweet_dataset'\n",
    "\n",
    "# Upload every split under the configuration name held in `conf_name`.\n",
    "commit_info = None\n",
    "for split_name, split_dset in (\n",
    "    ('train', train_dset),\n",
    "    ('test', test_dset),\n",
    "    ('val', val_dset),\n",
    "):\n",
    "    commit_info = split_dset.push_to_hub(\n",
    "        dset_path,\n",
    "        token=write_token,\n",
    "        split=split_name,\n",
    "        config_name=conf_name\n",
    "    )\n",
    "\n",
    "# Show the CommitInfo returned by the last upload.\n",
    "commit_info"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "c4a7f0f6-99ef-42ee-bd7f-875829f2c59f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of four examples from `train_set`; all other DataLoader options stay at their defaults.\n",
    "prod_train_dl = torch.utils.data.DataLoader(train_set, batch_size=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "165f57db-0194-409d-a0ad-646cf49082c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Prefer a token supplied through HF_READ_TOKEN; otherwise fall back to the read\n",
    "# token that ships with this notebook.\n",
    "# NOTE(review): the fallback is a credential stored in the file -- move it out and rotate it.\n",
    "read_token = os.environ.get('HF_READ_TOKEN', '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')\n",
    "dset_path = 'icmlmossanonymousauthor2025/moss_submission_tweet_dataset'\n",
    "\n",
    "# Rebinds `train_set` to the train split of the 'basic' configuration on the Hub.\n",
    "train_set = load_dataset(dset_path, \"basic\", split='train', token=read_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "97f1dd38-5185-44ea-aa05-0ecfe1a8105d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Presumably the shapes of one 16-example batch, in the order the collate\n",
    "# function returns its tensors (img, text, label, value) -- confirm.\n",
    "expected_batch_shapes = {\n",
    "    'img': torch.Size([16, 768]),\n",
    "    'text': torch.Size([16, 768]),\n",
    "    'label': torch.Size([16, 2]),\n",
    "    'value': torch.Size([16]),\n",
    "}\n",
    "expected_batch_shapes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3716262f-6fb0-438a-83f4-583733d40027",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16, 768])\n",
      "torch.Size([16, 768])\n",
      "torch.Size([16, 2])\n",
      "torch.Size([16])\n"
     ]
    }
   ],
   "source": [
    "def collate(batch):\n",
    "    \"\"\"Merge a list of example dicts into a tuple of four batch tensors.\n",
    "\n",
    "    Every example must provide the keys 'img', 'text', 'label' and 'value'.\n",
    "    For each key the per-example values are collected into a list and handed\n",
    "    to `torch.tensor`, so each result has `len(batch)` as its first dimension.\n",
    "\n",
    "    Returns:\n",
    "        The tuple (img, text, label, value).\n",
    "    \"\"\"\n",
    "    def gather(field):\n",
    "        return torch.tensor([example[field] for example in batch])\n",
    "\n",
    "    return tuple(gather(field) for field in ('img', 'text', 'label', 'value'))\n",
    "\n",
    "\n",
    "# Work with the first 100 training examples only.\n",
    "train_subset = train_set.select(range(100))\n",
    "\n",
    "# Batches of 16, assembled by the `collate` function defined above.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_subset,\n",
    "    batch_size=16,\n",
    "    collate_fn=collate,\n",
    ")\n",
    "\n",
    "# Print the shape of every tensor in the first batch.\n",
    "first_batch = next(iter(train_loader))\n",
    "for tensor in first_batch:\n",
    "    print(tensor.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e355e462-4d44-43a5-9bdf-5656c92ca777",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "32c71f37-4a07-4ee8-a71c-1a85c03c74c3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "97e0bb3f0f9f47668fe7a6d5d6e977a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "0it [00:00, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Wrap the loader itself rather than `enumerate(...)`: an enumerate object has no\n",
    "# length, so tqdm could not show a total when it was the thing being wrapped.\n",
    "for batch_idx, input_tuple in enumerate(tqdm(train_loader)):\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eb4dde8-23fc-4523-abd4-85d78e8910d2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "diffuse",
   "language": "python",
   "name": "diffuse"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
