{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7fd79e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import os\n",
    "from tqdm.notebook import tqdm\n",
    "import pickle\n",
    "from torchvision import datasets, transforms\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "import matplotlib.pyplot as plt\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "\n",
    "\n",
    "# Directory where save_run_func / load_run_func store pickled runs.\n",
    "SAVED_RUNS_PATH = 'saved_data/'\n",
    "# NOTE(review): EXP_PATH is not referenced anywhere else in this notebook.\n",
    "EXP_PATH = 'exps_setup/'\n",
    "\n",
    "def save_run_func(suffix, run):\n",
    "    \"\"\"Pickle `run` to ``SAVED_RUNS_PATH + suffix + '.pickle'``.\n",
    "\n",
    "    The target directory (including missing parents) is created if needed;\n",
    "    an existing file with the same name is overwritten.\n",
    "    \"\"\"\n",
    "    # makedirs(exist_ok=True) avoids the isdir/mkdir race and also works when\n",
    "    # SAVED_RUNS_PATH is a nested path whose parent directories do not exist yet.\n",
    "    os.makedirs(SAVED_RUNS_PATH, exist_ok=True)\n",
    "    file = SAVED_RUNS_PATH + suffix + '.pickle'\n",
    "    with open(file, 'wb') as f:\n",
    "        pickle.dump(run, f)\n",
    "        \n",
    "def load_run_func(suffix=''):\n",
    "    \"\"\"Return the object unpickled from ``SAVED_RUNS_PATH + suffix + '.pickle'``.\n",
    "\n",
    "    Security: unpickling can execute arbitrary code, so only load files written by\n",
    "    this notebook or another trusted source.\n",
    "    \"\"\"\n",
    "    path = SAVED_RUNS_PATH + suffix + '.pickle'\n",
    "    with open(path, 'rb') as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "# Define the MLP model\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.\n",
    "\n",
    "    Args:\n",
    "        input_size: number of input features per sample.\n",
    "        hidden_size: width of both hidden layers.\n",
    "        output_size: number of outputs (raw scores; no softmax is applied).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super().__init__()\n",
    "        # Keep this registration order (fc1, relu, fc2, fc3): it fixes both the order in which\n",
    "        # the seeded RNG initialises the layers and the order of model.parameters().\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.fc3 = nn.Linear(hidden_size, output_size)\n",
    "\n",
    "    def forward(self, x):\n",
    "        hidden = self.relu(self.fc1(x))\n",
    "        hidden = self.relu(self.fc2(hidden))\n",
    "        return self.fc3(hidden)\n",
    "\n",
    "# Function to calculate accuracy\n",
    "def calculate_accuracy(output, target):\n",
    "    \"\"\"Fraction of rows of `output` whose arg-max along dim 1 equals `target`.\n",
    "\n",
    "    Args:\n",
    "        output: score tensor, assumed (batch, num_classes).\n",
    "        target: class-index tensor with one entry per row of `output`.\n",
    "    Returns:\n",
    "        Python float (correct / batch size).\n",
    "    \"\"\"\n",
    "    predicted = output.max(dim=1).indices\n",
    "    n_correct = predicted.eq(target).sum().item()\n",
    "    return n_correct / target.size(0)\n",
    "\n",
    "# Function to flatten parameters\n",
    "def flatten_parameters(model):\n",
    "    \"\"\"Concatenate every parameter of `model` into one 1-D tensor.\n",
    "\n",
    "    Parameters appear in model.parameters() order; torch.cat builds a new tensor\n",
    "    rather than a view of the weights.\n",
    "    \"\"\"\n",
    "    pieces = [p.reshape(-1) for p in model.parameters()]\n",
    "    return torch.cat(pieces)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab114b02",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "398579df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show which torch device was selected\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e847fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width of both hidden layers of the MLP\n",
    "hidden_size = 32\n",
    "# Seed for the torch and numpy RNGs; also part of the saved-model and result file names\n",
    "random_seed = 1999"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "869241b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "input_size = 28 * 28  # FashionMNIST image size (flattened)\n",
    "output_size = 10  # number of FashionMNIST classes\n",
    "lr = 0.09\n",
    "epochs = 1500\n",
    "schedule = 0  # truthy -> step the cosine LR scheduler after every batch; 0 -> constant LR\n",
    "batch_size = 64\n",
    "net = 'MLP'  # architecture tag; 'MLP' makes the training loop flatten its inputs\n",
    "\n",
    "# Set random seeds for reproducibility (torch.manual_seed also seeds the CUDA RNGs)\n",
    "torch.manual_seed(random_seed)\n",
    "np.random.seed(random_seed)\n",
    "# Workspace setting that makes cuBLAS deterministic (CUDA >= 10.2); it only takes effect\n",
    "# if set before cuBLAS is first used in this process.\n",
    "os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':16:8'\n",
    "\n",
    "model = MLP(input_size, hidden_size, output_size).to(device)\n",
    "\n",
    "# Load the FashionMNIST training split; each image is turned into a flat 1-D tensor\n",
    "transform = transforms.Compose([transforms.ToTensor(),\n",
    "                                transforms.Lambda(lambda x: x.view(-1))])\n",
    "train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef530261",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): despite its name, val_loader iterates the *training* split (in larger batches)\n",
    "# and is not used in the cells below -- confirm whether the test split was intended.\n",
    "val_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=lr)\n",
    "# The training loop steps the scheduler once per batch (when schedule is enabled),\n",
    "# so the cosine period spans every batch of every epoch.\n",
    "total_steps = len(train_loader) * epochs\n",
    "scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14de9260",
   "metadata": {},
   "outputs": [],
   "source": [
    "#torch.save(model, './save/FashionMNIST_model_star_{}_{}_{}_{}_{}.pt'.format(epochs, lr, hidden_size, random_seed, schedule))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68a4bce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_star_path = './save/FashionMNIST_model_star_{}_{}_{}_{}_{}.pt'.format(epochs, lr, hidden_size, random_seed, schedule)\n",
    "print(model_star_path)\n",
    "try:\n",
    "    # map_location lets a checkpoint written on another device (e.g. a GPU box) load here.\n",
    "    # NOTE(review): the file holds a whole pickled nn.Module; newer torch releases may need\n",
    "    # weights_only=False for this. Only load checkpoints you trust.\n",
    "    model_star = torch.load(model_star_path, map_location=device)\n",
    "except FileNotFoundError:\n",
    "    # Expected before the cell that saves the trained model has ever been run.\n",
    "    # Other errors (corrupt file, unpickling/version problems) are deliberately not hidden.\n",
    "    model_star = None\n",
    "    print('No saved trained model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36701913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tag that keeps result files of repeated runs apart. It is drawn from a fresh, OS-seeded\n",
    "# generator: the global np.random state was seeded with random_seed above, so\n",
    "# np.random.randint would return the same tag on every re-run with that seed.\n",
    "n_random = int(np.random.default_rng().integers(0, 10000000))\n",
    "# NOTE(review): 'FashioMNIST' looks like a typo; kept so names match previously written results.\n",
    "# Despite the variable names these are file paths, not folders.\n",
    "train_folder = f'./result/FashioMNIST_{n_random}_train_{epochs}_{random_seed}_{lr}_{hidden_size}_{schedule}.pickle'\n",
    "check_folder = f'./result/FashioMNIST_{n_random}_check_{epochs}_{random_seed}_{lr}_{hidden_size}_{schedule}.pickle'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d09918fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the training-log file path built in the previous cell\n",
    "train_folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22e8cbe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the check-log file path built earlier\n",
    "check_folder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f20de9c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Training loop\n",
    "train_losses = []\n",
    "val_losses = []\n",
    "parameters = []\n",
    "gradients = []\n",
    "accuracies = []\n",
    "\n",
    "# Filled every 5th epoch from that epoch's first mini-batch (stochastic estimates).\n",
    "train_logger = {'epoch': [],\n",
    "                'scorr': [],\n",
    "                'sloss': [],\n",
    "                'sangle': [],\n",
    "                'accuracy': []}\n",
    "# NOTE(review): check_logger is never filled in this notebook.\n",
    "check_logger = {'epoch': [],\n",
    "                'loss': [],\n",
    "                'angle': [],\n",
    "                'corr': []}\n",
    "\n",
    "\n",
    "def gradient_alignment(model, reference):\n",
    "    \"\"\"Compare the gradient stored in `model` with the offset of `model` from `reference`.\n",
    "\n",
    "    With g the concatenated .grad of all parameters, w the current weights and w* the\n",
    "    reference weights, returns (||g||, dot(g, w - w*), ||w - w*||, cosine of their angle).\n",
    "    Both models must share the same architecture so zip() pairs matching tensors.\n",
    "    The cosine is NaN when either norm is zero.\n",
    "    \"\"\"\n",
    "    grad_norm_sq = dist_sq = corr = 0.0\n",
    "    with torch.no_grad():\n",
    "        for p, p_ref in zip(model.parameters(), reference.parameters()):\n",
    "            g = p.grad.reshape(-1)\n",
    "            delta = (p - p_ref).reshape(-1)\n",
    "            grad_norm_sq += g.dot(g).item()\n",
    "            corr += g.dot(delta).item()\n",
    "            dist_sq += delta.dot(delta).item()\n",
    "    grad_norm = grad_norm_sq ** 0.5\n",
    "    dist = dist_sq ** 0.5\n",
    "    denom = grad_norm * dist\n",
    "    cosine = corr / denom if denom > 0 else float('nan')\n",
    "    return grad_norm, corr, dist, cosine\n",
    "\n",
    "\n",
    "# model_star only exists if the checkpoint cell found a saved model; without it the\n",
    "# alignment statistics cannot be computed, so they are skipped instead of crashing.\n",
    "reference_model = globals().get('model_star')\n",
    "if reference_model is None:\n",
    "    print('model_star not available: gradient/distance statistics will be skipped')\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        # Flatten input data\n",
    "        if net == 'MLP':\n",
    "            data = data.view(-1, input_size)\n",
    "        data = data.to(device)\n",
    "        target = target.to(device)\n",
    "\n",
    "        # Forward pass and cross-entropy loss\n",
    "        output = model(data)\n",
    "        loss = criterion(output, target)\n",
    "\n",
    "        # Backward pass (the graph is used once, so it is not retained)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "\n",
    "        # Log on the first batch of every 5th epoch\n",
    "        log_step = batch_idx == 0 and epoch % 5 == 0\n",
    "        if log_step and reference_model is not None:\n",
    "            # Measured before optimizer.step() so the loss, the gradient and the weights\n",
    "            # all refer to the same point w.\n",
    "            sgrad_norm, sgrad_corr, sdist, sgrad_angle = gradient_alignment(model, reference_model)\n",
    "\n",
    "        optimizer.step()\n",
    "        if schedule:\n",
    "            scheduler.step()\n",
    "\n",
    "        # Accuracy of the pre-update predictions on this batch\n",
    "        accuracy = calculate_accuracy(output, target)\n",
    "\n",
    "        if log_step:\n",
    "            # get_last_lr() is the LR currently set on the optimizer; get_lr() is meant for use\n",
    "            # inside scheduler.step() and warns when called from user code.\n",
    "            print(f'Epoch {epoch + 1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item()}, Accuracy: {accuracy}, LR: {scheduler.get_last_lr()[0]}')\n",
    "\n",
    "            if reference_model is not None:\n",
    "                print('Compute quantities')\n",
    "                train_logger['epoch'].append(epoch)\n",
    "                train_logger['scorr'].append(sgrad_corr)\n",
    "                train_logger['sloss'].append(loss.item())\n",
    "                train_logger['sangle'].append(sgrad_angle)\n",
    "                train_logger['accuracy'].append(accuracy)\n",
    "\n",
    "                print('| stoch_grad_corr {:5f} | stoch_loss {:5f} | stoch_angle {:5f} |'.format(sgrad_corr, loss.item(),\n",
    "                                                                                                sgrad_angle))\n",
    "\n",
    "    print('End of epoch')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76ab5387",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the whole trained module (pickled) under the name the model_star loading cell expects.\n",
    "save_path = './save/FashionMNIST_model_star_{}_{}_{}_{}_{}.pt'.format(epochs, lr, hidden_size, random_seed, schedule)\n",
    "# torch.save does not create missing parent directories.\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "torch.save(model, save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "290f05ba",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
