{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c41c7ad2-56fe-4739-8df9-2d0b60e406dc",
   "metadata": {},
   "source": [
    "# compute the mean and std for dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "38ce3e52-52e3-4b81-990e-9b920d36b74b",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ck696/.conda/envs/H3/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: /home/ck696/.conda/envs/H3/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN3c106detail19maybe_wrap_dim_slowEllb\n",
      "  warn(f\"Failed to load image Python extension: {e}\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "import argparse\n",
    "import yaml\n",
    "import logging\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "import wandb\n",
    "import copy\n",
    "\n",
    "from datasets.dataloaders import get_dataset\n",
    "from datasets.transforms import data_transform\n",
    "from models.models import LinearModel_MNIST, LinearModel_CIFAR10, LinearModel_CIFAR100\n",
    "from utils.loss import C2Loss_Classification, gradient_centralization\n",
    "\n",
    "# Set up logging\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "logger = logging.getLogger(__name__)\n",
    "info = logger.info\n",
    "\n",
    "# Argument Parser\n",
    "parser = argparse.ArgumentParser(description='Autoencoder Training')\n",
    "parser.add_argument('--yaml', type=str, default='./yaml/linear_cifar100_ccl.yaml', help='Path to the YAML configuration file')\n",
    "args, _ = parser.parse_known_args()\n",
    "\n",
    "with open(args.yaml, 'r') as file:\n",
    "    yaml_config = yaml.safe_load(file)\n",
    "\n",
    "# Set default arguments from YAML\n",
    "parser.set_defaults(**yaml_config['method'])\n",
    "parser.set_defaults(**yaml_config['dataset'])\n",
    "parser.set_defaults(**yaml_config['training'])\n",
    "args, _ = parser.parse_known_args()\n",
    "\n",
    "train_loader, val_loader, test_loader = get_dataset(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "55f158a0-a0da-4b54-859b-ed5b0cd45c2a",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'loader' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[2], line 9\u001b[0m\n\u001b[1;32m      6\u001b[0m     mean \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mmean(\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m0\u001b[39m)\n\u001b[1;32m      7\u001b[0m     std \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m images\u001b[38;5;241m.\u001b[39mstd(\u001b[38;5;241m2\u001b[39m)\u001b[38;5;241m.\u001b[39msum(\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m mean \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(\u001b[43mloader\u001b[49m\u001b[38;5;241m.\u001b[39mdataset)\n\u001b[1;32m     10\u001b[0m std \u001b[38;5;241m/\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(loader\u001b[38;5;241m.\u001b[39mdataset)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'loader' is not defined"
     ]
    }
   ],
   "source": [
    "mean = 0.\n",
    "std = 0.\n",
    "for images, _ in train_loader:\n",
    "    batch_samples = images.size(0) # batch size (the last batch can have smaller size!)\n",
    "    images = images.view(batch_samples, images.size(1), -1)\n",
    "    mean += images.mean(2).sum(0)\n",
    "    std += images.std(2).sum(0)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fcd2b0cf-bb74-4e9c-a728-370cd9e8b08e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "mean /= len(train_loader.dataset)\n",
    "std /= len(train_loader.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "7d8650b6-5bff-4f4a-b476-90b66cb2a6c5",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MNIST\n",
      "torch.Size([60000, 28, 28])\n",
      "tensor(0.1307)\n",
      "tensor(0.3081)\n",
      "\n",
      " FashionMNIST\n",
      "torch.Size([60000, 28, 28])\n",
      "tensor(0.2860)\n",
      "tensor(0.3530)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "print(\"MNIST\")\n",
    "train_dataset_raw = datasets.MNIST(root='./data', train=True, download=True)\n",
    "print(train_dataset_raw.train_data.shape)\n",
    "print(train_dataset_raw.train_data.float().mean(axis=(0,1,2))/255)\n",
    "print(train_dataset_raw.train_data.float().std(axis=(0,1,2))/255)\n",
    "\n",
    "print(\"\\n FashionMNIST\")\n",
    "train_dataset_raw = datasets.FashionMNIST(root='./data', train=True, download=True)\n",
    "print(train_dataset_raw.train_data.shape)\n",
    "print(train_dataset_raw.train_data.float().mean(axis=(0,1,2))/255)\n",
    "print(train_dataset_raw.train_data.float().std(axis=(0,1,2))/255)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "1346c0b5-4a04-4bfa-9bed-183b3c90c0a6",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      " CIFAR10\n",
      "Files already downloaded and verified\n",
      "(50000, 32, 32, 3)\n",
      "[0.49139968 0.48215841 0.44653091]\n",
      "[0.24703223 0.24348513 0.26158784]\n",
      "\n",
      " CIFAR100\n",
      "Files already downloaded and verified\n",
      "(50000, 32, 32, 3)\n",
      "[0.50707516 0.48654887 0.44091784]\n",
      "[0.26733429 0.25643846 0.27615047]\n"
     ]
    }
   ],
   "source": [
    "print(\"\\n CIFAR10\")\n",
    "train_dataset_raw = datasets.CIFAR10(root='./data', train=True, download=True)\n",
    "print(train_dataset_raw.data.shape)\n",
    "print(train_dataset_raw.data.mean(axis=(0,1,2))/255)\n",
    "print(train_dataset_raw.data.std(axis=(0,1,2))/255)\n",
    "\n",
    "print(\"\\n CIFAR100\")\n",
    "train_dataset_raw = datasets.CIFAR100(root='./data', train=True, download=True)\n",
    "print(train_dataset_raw.data.shape)\n",
    "print(train_dataset_raw.data.mean(axis=(0,1,2))/255)\n",
    "print(train_dataset_raw.data.std(axis=(0,1,2))/255)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df90c913-12ef-471b-88e0-d7a7e8b0a964",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
