{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7fd79e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import os\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "from torchvision import datasets, transforms\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "\n",
    "\n",
    "# Ask PyTorch for deterministic kernels; operations that only have a\n",
    "# non-deterministic implementation will raise instead of running silently.\n",
    "torch.use_deterministic_algorithms(True)\n",
    "\n",
    "# Directory prefix (trailing slash included) used by save_run_func / load_run_func.\n",
    "SAVED_RUNS_PATH = 'saved_data/'\n",
    "# Directory prefix for experiment setups; not referenced elsewhere in this notebook.\n",
    "EXP_PATH = 'exps_setup/'\n",
    "\n",
    "def save_run_func(suffix, run):\n",
    "    \"\"\"Pickle ``run`` to ``SAVED_RUNS_PATH + suffix + '.pickle'``.\n",
    "\n",
    "    Args:\n",
    "        suffix: File name stem appended to ``SAVED_RUNS_PATH``.\n",
    "        run: Any picklable object.\n",
    "    \"\"\"\n",
    "    # makedirs(exist_ok=True) also creates missing parent directories and does\n",
    "    # not fail when the directory appears between a check and its creation.\n",
    "    os.makedirs(SAVED_RUNS_PATH, exist_ok=True)\n",
    "    file = SAVED_RUNS_PATH + suffix + '.pickle'\n",
    "    with open(file, 'wb') as f:\n",
    "        pickle.dump(run, f)\n",
    "        \n",
    "def load_run_func(suffix=''):\n",
    "    \"\"\"Return the object unpickled from ``SAVED_RUNS_PATH + suffix + '.pickle'``.\n",
    "\n",
    "    Args:\n",
    "        suffix: File name stem appended to ``SAVED_RUNS_PATH`` (default ``''``).\n",
    "\n",
    "    Returns:\n",
    "        Whatever object is stored in the pickle file.\n",
    "    \"\"\"\n",
    "    pickle_path = ''.join([SAVED_RUNS_PATH, suffix, '.pickle'])\n",
    "    # NOTE: unpickling can execute arbitrary code; only load files you created.\n",
    "    with open(pickle_path, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "# Define the CNN model\n",
    "class SimpleCNN(nn.Module):\n",
    "    \"\"\"Two conv -> ReLU -> 2x2 max-pool stages followed by one linear layer.\n",
    "\n",
    "    Args:\n",
    "        conv_size1: Number of output channels of the second convolution.\n",
    "        in_channels: Number of input image channels (default 3).\n",
    "        num_classes: Number of output logits (default 10).\n",
    "        image_size: Height/width of the square input images (default 32).\n",
    "            Each pooling stage halves it (rounding down), so the linear layer\n",
    "            receives ``(image_size // 4) ** 2 * conv_size1`` features.\n",
    "\n",
    "    The defaults reproduce the previously hard-coded 3-channel, 32x32,\n",
    "    10-class configuration.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, conv_size1, in_channels=3, num_classes=10, image_size=32):\n",
    "        super().__init__()\n",
    "        # Keep the creation order conv1 -> conv2 -> fc1: the random weight\n",
    "        # initialisation consumes the RNG in this order.\n",
    "        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)\n",
    "        self.conv2 = nn.Conv2d(32, conv_size1, kernel_size=3, stride=1, padding=1)\n",
    "        # The 3x3 convolutions with padding 1 keep the spatial size; only the\n",
    "        # two stride-2 poolings shrink it.\n",
    "        pooled_size = image_size // 2 // 2\n",
    "        self.fc1 = nn.Linear(pooled_size * pooled_size * conv_size1, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return the logits, one row per input image in the batch ``x``.\"\"\"\n",
    "        x = self.maxpool(self.relu(self.conv1(x)))\n",
    "        x = self.maxpool(self.relu(self.conv2(x)))\n",
    "        # Flatten everything except the batch dimension for the linear layer.\n",
    "        x = x.view(x.size(0), -1)\n",
    "        return self.fc1(x)\n",
    "\n",
    "# Function to calculate accuracy\n",
    "def calculate_accuracy(output, target):\n",
    "    \"\"\"Return the fraction of rows of ``output`` whose arg-max equals ``target``.\n",
    "\n",
    "    Args:\n",
    "        output: Score tensor; the maximum is taken along dimension 1.\n",
    "        target: Tensor of class indices; its first dimension is the sample count.\n",
    "\n",
    "    Returns:\n",
    "        Python float: number of matching predictions divided by ``target.size(0)``.\n",
    "    \"\"\"\n",
    "    predicted_labels = output.max(dim=1).indices\n",
    "    n_correct = (predicted_labels == target).sum().item()\n",
    "    return n_correct / target.size(0)\n",
    "\n",
    "# Function to flatten parameters\n",
    "def flatten_parameters(model):\n",
    "    \"\"\"Concatenate every parameter of ``model`` into a single 1-D tensor.\n",
    "\n",
    "    The parameters appear in the order yielded by ``model.parameters()``.\n",
    "    \"\"\"\n",
    "    pieces = []\n",
    "    for weight in model.parameters():\n",
    "        pieces.append(weight.reshape(-1))\n",
    "    return torch.cat(pieces)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab114b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e847fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the selected device as the cell output.\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9e67916-c114-43bb-bc99-18abe3d025bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration: channel count of the second conv layer and the RNG seed.\n",
    "conv_size1 = 128\n",
    "random_seed = 1970"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "869241b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "input_size = 28 * 28  # only read by the 'MLP' branch of the training loop\n",
    "output_size = 10\n",
    "lr = 0.01\n",
    "epochs = 1000\n",
    "schedule = 1  # truthy -> the LR scheduler is stepped after every batch\n",
    "\n",
    "\n",
    "net = 'CNN'\n",
    "\n",
    "# cuBLAS workspace setting used together with\n",
    "# torch.use_deterministic_algorithms(True); set before the model is moved\n",
    "# to the GPU. `os` is already imported in the first cell.\n",
    "os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'\n",
    "\n",
    "# Set random seed for reproducibility (before the model is built, so the\n",
    "# initial weights are drawn after seeding)\n",
    "torch.manual_seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "\n",
    "model = SimpleCNN(conv_size1).to(device)\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "\n",
    "# Load the CIFAR-10 training split (downloaded into ./data when missing)\n",
    "train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef530261",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss and optimiser for the run.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=lr)\n",
    "\n",
    "# The training cell steps the scheduler once per batch (when `schedule` is\n",
    "# truthy), so T_max is the total number of optimisation steps of the run.\n",
    "total_steps = len(train_loader) * epochs\n",
    "scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)\n",
    "\n",
    "# Second loader over the *training* dataset; the training cell iterates it in\n",
    "# full to accumulate gradients over every batch.\n",
    "val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14de9260",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the reference checkpoint with the freshly initialised model, but only\n",
    "# when no checkpoint exists yet: the last cell writes the *trained* model to\n",
    "# this same path, and re-running the notebook must not overwrite it.\n",
    "model_star_path = './save/model_star_{}_{}_{}_{}_{}.pt'.format(epochs, lr, conv_size1, random_seed, schedule)\n",
    "os.makedirs(os.path.dirname(model_star_path), exist_ok=True)  # torch.save does not create directories\n",
    "if not os.path.exists(model_star_path):\n",
    "    torch.save(model, model_star_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a4bce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_star_path = './save/model_star_{}_{}_{}_{}_{}.pt'.format(epochs, lr, conv_size1, random_seed, schedule)\n",
    "print(model_star_path)\n",
    "try:\n",
    "    # map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine.\n",
    "    # NOTE(review): the file holds a whole pickled module; recent PyTorch releases\n",
    "    # default torch.load to weights_only=True and may reject it - confirm for your version.\n",
    "    model_star = torch.load(model_star_path, map_location=device).to(device)\n",
    "except FileNotFoundError:\n",
    "    # Only a missing checkpoint is expected here; any other failure is raised\n",
    "    # instead of being reported as a missing file.\n",
    "    print('No saved trained model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36701913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Random tag that keeps the result files of different runs apart.\n",
    "n_random = np.random.randint(low=0, high=10000000)\n",
    "\n",
    "# Despite the '_folder' suffix both names hold pickle *file* paths under ./result/.\n",
    "check_folder = f'./result/CIFAR10_fixed_{n_random}_check_{epochs}_{random_seed}_{lr}_{conv_size1}_{schedule}.pickle'\n",
    "train_folder = f'./result/CIFAR10_fixed_{n_random}_train_{epochs}_{random_seed}_{lr}_{conv_size1}_{schedule}.pickle'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d09918fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the path for the mini-batch (train) log as the cell output.\n",
    "train_folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22e8cbe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the path for the full-pass (check) log as the cell output.\n",
    "check_folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "452853f7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Training loop\n",
    "# NOTE(review): the five lists below are never filled in this notebook; they\n",
    "# are kept only so that the names stay defined.\n",
    "train_losses = []\n",
    "val_losses = []\n",
    "parameters = []\n",
    "gradients = []\n",
    "accuracies = []\n",
    "\n",
    "# Quantities measured on the current mini-batch gradient.\n",
    "train_logger = {'epoch':[],\n",
    "                'scorr':[],\n",
    "                'sloss':[],\n",
    "                'sangle':[],\n",
    "                'accuracy':[]\n",
    "}\n",
    "# Same quantities for the gradient accumulated over all of val_loader.\n",
    "check_logger = {'epoch':[],\n",
    "                'loss':[],\n",
    "                'angle':[],\n",
    "                'corr':[],\n",
    "}\n",
    "# NOTE(review): neither logger is written to train_folder / check_folder in\n",
    "# this notebook - confirm whether they should be pickled after training.\n",
    "\n",
    "LOG_EVERY = 5  # measurements are taken on the first batch of every LOG_EVERY-th epoch\n",
    "\n",
    "\n",
    "def _gradient_alignment(net_now, net_ref):\n",
    "    \"\"\"Relate the gradients stored on ``net_now`` to its offset from ``net_ref``.\n",
    "\n",
    "    With g the concatenation of every ``.grad`` of ``net_now`` and d the\n",
    "    concatenation of (parameters of ``net_now`` - parameters of ``net_ref``),\n",
    "    returns ``(g.d, g.d / (|g| |d|))`` as 0-dim CPU tensors. The second value\n",
    "    is NaN or inf when one of the norms is zero (e.g. identical models).\n",
    "    \"\"\"\n",
    "    grad_sq, dist_sq, corr = 0., 0., 0.\n",
    "    for p, p_ref in zip(net_now.parameters(), net_ref.parameters()):\n",
    "        grad = p.grad.detach().reshape(-1)\n",
    "        offset = (p.detach() - p_ref.detach()).reshape(-1)\n",
    "        grad_sq = grad_sq + grad.dot(grad).cpu()\n",
    "        corr = corr + grad.dot(offset).cpu()\n",
    "        dist_sq = dist_sq + offset.dot(offset).cpu()\n",
    "    return corr, corr / grad_sq.sqrt() / dist_sq.sqrt()\n",
    "\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        # Flatten input data\n",
    "        if net=='MLP':\n",
    "            data = data.view(-1, input_size)\n",
    "\n",
    "        data = data.to(device)\n",
    "        target = target.to(device)\n",
    "\n",
    "        # Forward pass and CrossEntropyLoss (computed once)\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # Backward pass; the graph is not reused, so it does not need to be retained\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        # Calculate accuracy\n",
    "        accuracy = calculate_accuracy(output, target)\n",
    "\n",
    "        # batch_idx never reaches len(train_loader), so the former modulo test\n",
    "        # was simply 'first batch of the epoch'.\n",
    "        log_now = batch_idx == 0 and epoch % LOG_EVERY == 0\n",
    "\n",
    "        # Print progress\n",
    "        if log_now:\n",
    "            # get_last_lr() is the supported way to read the current learning\n",
    "            # rate; get_lr() is only meant to be called from inside step().\n",
    "            print(f'Epoch {epoch + 1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}, Accuracy: {accuracy}, LR: {scheduler.get_last_lr()[0]}')\n",
    "\n",
    "            print('Compute quantities') \n",
    "\n",
    "            # Mini-batch gradient versus offset from model_star, measured before the update.\n",
    "            sgrad_corr, sgrad_angle = _gradient_alignment(model, model_star)\n",
    "\n",
    "            train_logger['epoch'].append(epoch)\n",
    "            train_logger['scorr'].append(sgrad_corr.item())\n",
    "            train_logger['sloss'].append(loss.item())\n",
    "            train_logger['sangle'].append(sgrad_angle.item())\n",
    "            train_logger['accuracy'].append(accuracy)\n",
    "\n",
    "            print('| stoch_grad_corr {:5f} | stoch_loss {:5f} | stoch_angle {:5f} |'.format(sgrad_corr, loss.item(),\n",
    "                                                                                            sgrad_angle))\n",
    "\n",
    "        optimizer.step()\n",
    "        if schedule:\n",
    "            scheduler.step()\n",
    "\n",
    "        if log_now:\n",
    "            # Accumulate the gradient over every batch of val_loader at the\n",
    "            # updated parameters (gradients are summed, not averaged).\n",
    "            val_loss = 0.\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            for data_val, target_val in val_loader:\n",
    "                data_val = data_val.to(device)\n",
    "                target_val = target_val.to(device)\n",
    "\n",
    "                # Forward pass and CrossEntropyLoss\n",
    "                loss_val = criterion(model(data_val), target_val)\n",
    "\n",
    "                loss_val.backward()\n",
    "                val_loss += loss_val.item()\n",
    "\n",
    "            grad_corr, grad_angle = _gradient_alignment(model, model_star)\n",
    "\n",
    "            # The correlation is divided by the number of batches because the\n",
    "            # gradient was summed over them; the angle is scale-invariant.\n",
    "            check_logger['epoch'].append(epoch)\n",
    "            check_logger['corr'].append(grad_corr.item()/len(val_loader))\n",
    "            check_logger['loss'].append(val_loss/len(val_loader))\n",
    "            check_logger['angle'].append(grad_angle.item())\n",
    "\n",
    "            print('| grad_corr {:5f} | loss {:5f} | angle {:5f} |'.format(grad_corr/len(val_loader), \n",
    "                                                                          val_loss/len(val_loader), \n",
    "                                                                          grad_angle))\n",
    "\n",
    "    print('End of epoch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcadc87d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the trained model as the reference checkpoint read by the loading cell above.\n",
    "model_star_path = './save/model_star_{}_{}_{}_{}_{}.pt'.format(epochs, lr, conv_size1, random_seed, schedule)\n",
    "# Create ./save first so a long training run is not lost to a missing folder.\n",
    "os.makedirs(os.path.dirname(model_star_path), exist_ok=True)\n",
    "torch.save(model, model_star_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
