{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import data\n",
    "import torch\n",
    "import dotenv\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dotenv.load_dotenv()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def calculate_stats(loader: data.DatasetLoader):\n",
    "    loader.prepare_raw_data()\n",
    "\n",
    "    train_images, _ = loader.load_train_data()\n",
    "\n",
    "    std, mean = torch.std_mean(train_images, dim=(0, 2, 3))\n",
    "    print(f\"{type(loader).__name__} std: {std}\")\n",
    "    print(f\"{type(loader).__name__} mean: {mean}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CIFAR10Loader std: tensor([0.2470, 0.2435, 0.2616])\n",
      "CIFAR10Loader mean: tensor([0.4914, 0.4822, 0.4465])\n",
      "MNISTLoader std: tensor([0.3081])\n",
      "MNISTLoader mean: tensor([0.1307])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/aernim/repos/canaries/canary_constructor/eval/data.py:59: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  targets = torch.tensor(raw_dataset.targets).to(torch.int64)\n"
     ]
    }
   ],
   "source": [
    "calculate_stats(data.CIFAR10Loader())\n",
    "calculate_stats(data.MNISTLoader())\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
