{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rXC93PvZHALw"
   },
   "source": [
    "# Import e path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "bBKm4O3pn19s"
   },
   "outputs": [],
   "source": [
    "# Where checkpoints and result CSVs are written, and where the datasets live.\n",
    "# The two save paths are used as plain string prefixes, so keep their trailing slash.\n",
    "models_save_path = \"Models/\"\n",
    "results_save_path = \"Results/\"\n",
    "dataset_path = \"../../../Datasets/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision\n",
    "import torch\n",
    "import os\n",
    "import time\n",
    "import csv\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import shutil\n",
    "from random import randint\n",
    "import pandas as pd\n",
    "from PIL import Image\n",
    "import random\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as functional\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader, TensorDataset, Dataset\n",
    "from torchvision.datasets import ImageFolder\n",
    "from einops import rearrange\n",
    "from einops.layers.torch import Rearrange\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.utils import make_grid\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "from vit_pytorch.ats_vit import ViT as ATS\n",
    "\n",
    "from vit_pytorch import ViT\n",
    "from vit_pytorch.vit_with_patch_merger import ViT as PatchMerger\n",
    "from vit_pytorch import SimpleViT as SV"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Xi1SBM-8TRqg"
   },
   "source": [
    "<h3>Validi per ogni modello</h3>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "zIBVHD-TTU6u"
   },
   "outputs": [],
   "source": [
    "def set_seed(seed):\n",
    "    \"\"\"Seed every random number generator used by this notebook.\n",
    "\n",
    "    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),\n",
    "    switches cuDNN to deterministic mode and disables its auto-tuner.\n",
    "\n",
    "    Args:\n",
    "        seed (int): Seed value applied to every generator.\n",
    "    \"\"\"\n",
    "    random.seed(seed)\n",
    "    # Inherited by child processes; the running interpreter read its own hash\n",
    "    # seed at start-up, so this line does not change hashing in this kernel.\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    # The cuDNN auto-tuner may select different kernels from run to run, so it\n",
    "    # is switched off explicitly in case an earlier cell enabled it.\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "qL0i8NnGpUPc"
   },
   "outputs": [],
   "source": [
    "# Use the second GPU when the machine has one, otherwise the first GPU, otherwise the CPU.\n",
    "# Requesting 'cuda:1' on a single-GPU machine only fails later, when a tensor is moved to it.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = torch.device('cuda:1')\n",
    "elif torch.cuda.is_available():\n",
    "    device = torch.device('cuda:0')\n",
    "else:\n",
    "    device = torch.device('cpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sG2VBWH-HmoI"
   },
   "source": [
    "# Funzioni per train e validation del modello SAMPLING"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "SiazmHl-QtQ_"
   },
   "outputs": [],
   "source": [
    "def train_iter(model, optimz, data_load, loss_val, device, scheduler):\n",
    "    \"\"\"Train ``model`` for one epoch and record the epoch's mean training loss.\n",
    "\n",
    "    Args:\n",
    "        model: Network to train; it is switched to training mode.\n",
    "        optimz: Optimizer that updates the model parameters after every batch.\n",
    "        data_load: DataLoader yielding ``(data, target)`` batches.\n",
    "        loss_val (list): One float is appended per call: the sample-weighted\n",
    "            mean loss over all batches of the epoch.\n",
    "        device: Device every batch is moved to.\n",
    "        scheduler: Learning-rate scheduler, stepped once after the last batch.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``data_load`` contains no batches.\n",
    "    \"\"\"\n",
    "    samples = len(data_load.dataset)\n",
    "    n_batches = len(data_load)\n",
    "    if n_batches == 0:\n",
    "        raise ValueError('train_iter needs at least one batch, got an empty data loader')\n",
    "    model.train()\n",
    "\n",
    "    seen = 0        # samples consumed so far\n",
    "    loss_sum = 0.0  # running sum of per-sample losses; becomes a tensor on the first batch\n",
    "\n",
    "    for i, (data, target) in enumerate(data_load):\n",
    "        data = data.to(device)\n",
    "        target = target.to(device)\n",
    "\n",
    "        optimz.zero_grad()\n",
    "        # cross_entropy applies log_softmax followed by nll_loss in a single call\n",
    "        loss = functional.cross_entropy(model(data), target)\n",
    "        loss.backward()\n",
    "        optimz.step()\n",
    "\n",
    "        if i % 100 == 0:\n",
    "            print(f'[{seen:5}/{samples:5} ({100 * i / n_batches:3.0f}%)]  Loss: {loss.item():6.4f}')\n",
    "\n",
    "        # Weight each batch by its size so a smaller final batch does not skew the\n",
    "        # mean; the detached tensor is converted to a float only once, after the loop.\n",
    "        loss_sum = loss_sum + loss.detach() * len(data)\n",
    "        seen += len(data)\n",
    "\n",
    "    scheduler.step()\n",
    "    print(scheduler.get_last_lr())\n",
    "    loss_val.append((loss_sum / seen).item())\n",
    "\n",
    "def evaluate(model, optimizer, data_load, loss_val, device):\n",
    "    \"\"\"Score ``model`` on ``data_load`` and record the mean per-sample loss.\n",
    "\n",
    "    Args:\n",
    "        model: Network to evaluate; it is switched to evaluation mode.\n",
    "        optimizer: Accepted for signature compatibility; not used here.\n",
    "        data_load: DataLoader yielding ``(data, target)`` batches.\n",
    "        loss_val (list): The summed loss divided by the dataset size is appended.\n",
    "        device: Device every batch is moved to.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: CPU tensor holding the accuracy as a percentage.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    n_samples = len(data_load.dataset)\n",
    "    correct = 0      # number of correct predictions; a tensor after the first batch\n",
    "    loss_total = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in data_load:\n",
    "            inputs, labels = inputs.to(device), labels.to(device)\n",
    "\n",
    "            log_probs = functional.log_softmax(model(inputs), dim=1)\n",
    "            loss_total += functional.nll_loss(log_probs, labels, reduction='sum').item()\n",
    "\n",
    "            predicted = log_probs.max(dim=1).indices\n",
    "            correct += (predicted == labels).sum()\n",
    "\n",
    "    mean_loss = loss_total / n_samples\n",
    "    loss_val.append(mean_loss)\n",
    "    accuracy = (100.0 * correct / n_samples).cpu()\n",
    "\n",
    "    print(f'\\nAverage test loss: {mean_loss:.4f}  Accuracy:{correct:5}/{n_samples:5} ({accuracy:4.2f}%)\\n')\n",
    "\n",
    "    return accuracy\n",
    "\n",
    "def train_validation(model, optimizer, train_loader, validation_loader, models_save_path, nome_dataset, epoche, scheduler, device):\n",
    "    \"\"\"Run the train/validate loop and checkpoint the best-accuracy model.\n",
    "\n",
    "    Args:\n",
    "        model: Network to train.\n",
    "        optimizer: Optimizer passed to ``train_iter`` and stored in the checkpoint.\n",
    "        train_loader: DataLoader used for training.\n",
    "        validation_loader: DataLoader used for validation.\n",
    "        models_save_path (str): Prefix prepended to ``nome_dataset`` to form the\n",
    "            checkpoint path; its folder is created if it does not exist.\n",
    "        nome_dataset (str): Checkpoint file name.\n",
    "        epoche (int): Number of epochs to run.\n",
    "        scheduler: Learning-rate scheduler passed to ``train_iter``.\n",
    "        device: Device the batches are moved to.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(tr_loss, ts_loss, ts_acc, epoch_time_list)``, one entry per epoch.\n",
    "    \"\"\"\n",
    "    tr_loss, ts_loss, ts_acc, epoch_time_list = [], [], [], []\n",
    "\n",
    "    checkpoint_path = f'{models_save_path}{nome_dataset}'\n",
    "    # Create the checkpoint folder before training so that the first save cannot\n",
    "    # fail on a missing directory after an epoch has already been computed.\n",
    "    checkpoint_dir = os.path.dirname(checkpoint_path)\n",
    "    if checkpoint_dir:\n",
    "        os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "    for epoch in range(1, epoche + 1):\n",
    "\n",
    "        start_time = time.time()\n",
    "\n",
    "        print(f'Epoch: {epoch}/{epoche}')\n",
    "        print(\"INIZIO TRAINING\")\n",
    "        train_iter(model, optimizer, train_loader, tr_loss, device, scheduler=scheduler)\n",
    "        print(\"INIZIO VALIDATION\")\n",
    "        acc = evaluate(model, optimizer, validation_loader, ts_loss, device)\n",
    "\n",
    "        # \">=\" means that on a tie the most recent epoch replaces the stored checkpoint.\n",
    "        if not ts_acc or acc >= max(ts_acc):\n",
    "            checkpoint = {'model_state_dict': model.state_dict(),\n",
    "                          'optimizer_state_dict': optimizer.state_dict(),\n",
    "                          'train_loss_state_dict': tr_loss[-1],\n",
    "                          'val_loss_state_dict': ts_loss[-1],\n",
    "                          'val_acc_state_dict': acc\n",
    "                          }\n",
    "            print('Saving Best Accuracy Model')\n",
    "            torch.save(checkpoint, checkpoint_path)\n",
    "            print('End of Saving \\n')\n",
    "\n",
    "        ts_acc.append(acc)\n",
    "\n",
    "        epoch_time = time.time() - start_time\n",
    "        epoch_time_list.append(epoch_time)\n",
    "\n",
    "        print('Execution time:', '{:5.2f}'.format(epoch_time), 'seconds')\n",
    "        print(\"#\"*40)\n",
    "\n",
    "    return tr_loss, ts_loss, ts_acc, epoch_time_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "img_size = 160\n",
    "patch_size = 16\n",
    "\n",
    "# Per-channel statistics passed to Normalize\n",
    "# (these look like the standard ImageNet values - confirm).\n",
    "mean = (0.485, 0.456, 0.406)\n",
    "std = (0.229, 0.224, 0.225)\n",
    "\n",
    "# Repeat a tensor whose first dimension is 1 three times along that dimension;\n",
    "# any other tensor passes through unchanged.\n",
    "trans = transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.size(0)==1 else x)\n",
    "\n",
    "\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.Resize((img_size, img_size)),\n",
    "    transforms.RandomRotation(10),  # Random rotation by 10 degrees\n",
    "    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color Jitter\n",
    "    # transforms.RandomPerspective(distortion_scale=0.5, p=0.5),  # Random perspective\n",
    "    transforms.RandomVerticalFlip(p=0.5),  # Vertical flip with 50% probability\n",
    "    transforms.RandomResizedCrop(img_size),\n",
    "    transforms.RandomHorizontalFlip(),\n",
    "    transforms.ToTensor(),\n",
    "    trans,\n",
    "    transforms.Normalize(mean, std),\n",
    "])\n",
    "\n",
    "# Validation images are only resized and normalised, with no random augmentation.\n",
    "transform_validation = transforms.Compose([\n",
    "    transforms.Resize((img_size, img_size)),\n",
    "    transforms.ToTensor(),\n",
    "    trans,\n",
    "    transforms.Normalize(mean, std),\n",
    "])\n",
    "\n",
    "# Load the dataset. os.path.join gives a clean path whether or not dataset_path\n",
    "# ends with a separator.\n",
    "imagenette_root = os.path.join(dataset_path, 'imagenette2')\n",
    "train_dataset = ImageFolder(root=os.path.join(imagenette_root, 'train'), transform=transform_train)\n",
    "val_dataset = ImageFolder(root=os.path.join(imagenette_root, 'val'), transform=transform_validation)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4, shuffle=True, pin_memory=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=4, pin_memory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Modello"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from DWTViT_gini import DWTViT_gini\n",
    "from DWTViT_pruning import DWTViT_pruning\n",
    "from DWTViT_quantile import DWTViT_quantile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ViT_DWT(nome_dataset, pruning_locations, dim, heads, wavelet, epoche, model_type='pruning', strategy='cA'):\n",
    "    \"\"\"Build one DWT-ViT variant, train it and write the per-epoch results to CSV.\n",
    "\n",
    "    The model is trained with Adam (lr 1e-4) and a StepLR schedule that halves the\n",
    "    learning rate every 10 scheduler steps, using the notebook-level\n",
    "    ``train_loader``, ``val_loader`` and ``device``.\n",
    "\n",
    "    Args:\n",
    "        nome_dataset (str): Run name; used as the checkpoint file name and as the\n",
    "            stem of the results CSV.\n",
    "        pruning_locations: Forwarded unchanged to the model constructor.\n",
    "        dim (int): Forwarded as ``dim``; ``mlp_dim`` is set to ``4 * dim``.\n",
    "        heads (int): Forwarded as ``heads``.\n",
    "        wavelet: Forwarded unchanged to the model constructor.\n",
    "        epoche (int): Number of training epochs.\n",
    "        model_type (str): ``'gini'``, ``'pruning'`` or ``'quantile'``.\n",
    "        strategy: Forwarded to ``DWTViT_pruning`` only; not used by the other models.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``model_type`` is not recognised, or if it is ``'pruning'``\n",
    "            and ``strategy`` is None.\n",
    "    \"\"\"\n",
    "    learning_rate = 0.0001\n",
    "\n",
    "    # Pick the model class for the requested type\n",
    "    model_classes = {\n",
    "        'gini': DWTViT_gini,\n",
    "        'pruning': DWTViT_pruning,\n",
    "        'quantile': DWTViT_quantile,\n",
    "    }\n",
    "    if model_type not in model_classes:\n",
    "        raise ValueError(f\"model_type '{model_type}' non riconosciuto. Usa 'gini', 'pruning' o 'quantile'.\")\n",
    "\n",
    "    # Constructor arguments shared by the three variants\n",
    "    model_kwargs = dict(\n",
    "        image_size=img_size,\n",
    "        patch_size=patch_size,\n",
    "        num_classes=10,\n",
    "        dim=dim,\n",
    "        depth=12,\n",
    "        heads=heads,\n",
    "        mlp_dim=dim*4,\n",
    "        dropout=0,\n",
    "        emb_dropout=0,\n",
    "        wavelet=wavelet,\n",
    "        pruning_locations=pruning_locations,\n",
    "    )\n",
    "    if model_type == 'pruning':\n",
    "        if strategy is None:\n",
    "            raise ValueError(\"Per DWTViT_pruning devi specificare il parametro 'strategy'\")\n",
    "        model_kwargs['strategy'] = strategy\n",
    "\n",
    "    # Create the results folder before training, so the CSV write at the end of\n",
    "    # the run cannot fail on a missing directory.\n",
    "    results_file = f'{results_save_path}{nome_dataset}.csv'\n",
    "    results_dir = os.path.dirname(results_file)\n",
    "    if results_dir:\n",
    "        os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "    model = model_classes[model_type](**model_kwargs)\n",
    "    model.to(device)\n",
    "\n",
    "    optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)\n",
    "    initial = time.time()\n",
    "    _, _, validation_acc, epoch_time = train_validation(\n",
    "        model, optimizer, train_loader, val_loader, models_save_path, nome_dataset, epoche, device=device, scheduler=scheduler\n",
    "    )\n",
    "    print(f'Total Time: {time.time() - initial}')\n",
    "\n",
    "    # train_validation returns the accuracies as tensors; store plain floats\n",
    "    df = pd.DataFrame({\n",
    "        'validation_acc': [acc.item() for acc in validation_acc],\n",
    "        'epoch_time': epoch_time\n",
    "    })\n",
    "\n",
    "    df.to_csv(results_file, index=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RUN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pruning_locations = [4, 6, 8]\n",
    "# epoche = 50\n",
    "# wavelet = 'haar'\n",
    "\n",
    "# ViT_DWT(nome_dataset = f\"imaginette_gini_DWTSmall75%\",pruning_locations = pruning_locations,dim=384,heads=6, wavelet=wavelet, epoche = epoche, model_type = 'gini')\n",
    "\n",
    "# ViT_DWT(nome_dataset = f\"imaginette_gini_DWTBase75%\",pruning_locations = pruning_locations,dim=768,heads=12,wavelet=wavelet, epoche = epoche, model_type = 'gini')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# for wavelet in ['haar', 'db2', 'db4', 'sym2']:\n",
    "\n",
    "#     pruning_locations = [4, 6, 8]\n",
    "#     epoche = 50\n",
    "    \n",
    "#     ViT_DWT(nome_dataset = f\"imaginette_{wavelet}_DWTSmall75%\",pruning_locations = pruning_locations,dim=384,heads=6, wavelet=wavelet, epoche = epoche)\n",
    "    \n",
    "#     ViT_DWT(nome_dataset = f\"imaginette_{wavelet}_DWTBase75%\",pruning_locations = pruning_locations,dim=768,heads=12,wavelet=wavelet, epoche = epoche)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# wavelet = 'haar'\n",
    "\n",
    "# pruning_locations = [4, 6, 8]\n",
    "# epoche = 50\n",
    "\n",
    "# for strategy in ['cA', 'cD']:\n",
    "#     ViT_DWT(nome_dataset = f\"imaginette_{strategy}_DWTSmall75%\",pruning_locations = pruning_locations,dim=384,heads=6, wavelet=wavelet, strategy=strategy, epoche = epoche)\n",
    "    \n",
    "#     ViT_DWT(nome_dataset = f\"imaginette_{strategy}_DWTBase75%\",pruning_locations = pruning_locations,dim=768,heads=12,wavelet=wavelet, strategy=strategy, epoche = epoche)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "rXC93PvZHALw",
    "-WOxO6GNga2i",
    "xTio5oHEG9sv",
    "ng75uqlIgCvl",
    "SWAh2lXxHQhL",
    "QTJ-jgfrQ68o"
   ],
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
