{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# **SoftAdaClip**"
      ],
      "metadata": {
        "id": "uhEqFbU0ZHec"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W3UwwkE2ZGQg"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "from opacus.optimizers import AdaClipDPOptimizer, DPOptimizer\n",
        "from opacus.optimizers.optimizer import (\n",
        "    _check_processed_flag,\n",
        "    _generate_noise,\n",
        "    _mark_as_processed\n",
        ")\n",
        "\n",
        "class FairOptimizer(AdaClipDPOptimizer):\n",
        "    \"\"\"AdaClip optimizer that scales per-sample gradients with a smooth tanh factor.\n",
        "\n",
        "    Each per-sample gradient is multiplied by ``tanh(C / ||g||)`` where ``C`` is\n",
        "    ``self.max_grad_norm``. Because ``tanh(x) <= x`` for ``x >= 0``, the scaled\n",
        "    gradient norm never exceeds ``C``.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, *args, **kwargs):\n",
        "        super().__init__(*args, **kwargs)\n",
        "\n",
        "    def clip_and_accumulate(self):\n",
        "        \"\"\"Soft-clip the per-sample gradients and add them to ``p.summed_grad``.\n",
        "\n",
        "        Also updates ``self.sample_size`` and ``self.unclipped_num``, the two\n",
        "        counters this method maintains for the adaptive clipping bound.\n",
        "        \"\"\"\n",
        "        per_param_norms = [\n",
        "            g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples\n",
        "        ]\n",
        "        per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)\n",
        "\n",
        "        # Ratio C / ||g_i||: a value >= 1 means the sample's norm is within the bound.\n",
        "        norm_ratio = self.max_grad_norm / (per_sample_norms + 1e-6)\n",
        "        # Smooth replacement for the hard clamp(max=1.0).\n",
        "        per_sample_clip_factor = torch.tanh(norm_ratio)\n",
        "\n",
        "        self.sample_size += len(per_sample_clip_factor)\n",
        "        # tanh(x) < 1 for every finite x, so comparing the tanh factor with 1 would\n",
        "        # mark (almost) every sample as clipped and pin the unclipped count near 0.\n",
        "        # The count is therefore taken on the raw ratio, which is the same test a\n",
        "        # hard clamp of the ratio to 1 would give.\n",
        "        self.unclipped_num += len(norm_ratio) - (norm_ratio < 1).sum()\n",
        "\n",
        "        for p in self.params:\n",
        "            _check_processed_flag(p.grad_sample)\n",
        "            grad_sample = self._get_flat_grad_sample(p)\n",
        "            # Weighted sum over the sample axis: sum_i factor_i * g_i.\n",
        "            grad = torch.einsum(\"i,i...\", per_sample_clip_factor, grad_sample)\n",
        "\n",
        "            if p.summed_grad is not None:\n",
        "                p.summed_grad += grad\n",
        "            else:\n",
        "                p.summed_grad = grad\n",
        "\n",
        "            _mark_as_processed(p.grad_sample)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from opacus import PrivacyEngine\n",
        "import torch.optim as optim\n",
        "from typing import List, Union\n",
        "from opacus.optimizers import DPOptimizer\n",
        "\n",
        "class fairPrivacyEngine(PrivacyEngine):\n",
        "    \"\"\"PrivacyEngine whose optimizer factory always produces a ``FairOptimizer``.\"\"\"\n",
        "\n",
        "    def __init__(self, *args, **kwargs):\n",
        "        super().__init__(*args, **kwargs)\n",
        "\n",
        "    def _prepare_optimizer(\n",
        "        self,\n",
        "        *,\n",
        "        optimizer: optim.Optimizer,\n",
        "        noise_multiplier: float,\n",
        "        max_grad_norm: Union[float, List[float]],\n",
        "        expected_batch_size: int,\n",
        "        loss_reduction: str = \"mean\",\n",
        "        distributed: bool = False,\n",
        "        clipping: str = \"flat\",\n",
        "        noise_generator=None,\n",
        "        grad_sample_mode=\"hooks\",\n",
        "        **kwargs,\n",
        "    ) -> DPOptimizer:\n",
        "        \"\"\"Wrap ``optimizer`` in a ``FairOptimizer``.\n",
        "\n",
        "        ``distributed``, ``clipping`` and ``grad_sample_mode`` are accepted but do\n",
        "        not influence which optimizer class is built here. Extra keyword\n",
        "        arguments are forwarded to ``FairOptimizer`` unchanged.\n",
        "        \"\"\"\n",
        "        # Unwrap an optimizer that is already a DPOptimizer.\n",
        "        if isinstance(optimizer, DPOptimizer):\n",
        "            base_optimizer = optimizer.original_optimizer\n",
        "        else:\n",
        "            base_optimizer = optimizer\n",
        "\n",
        "        # In secure mode the engine's own RNG is used; otherwise the caller's\n",
        "        # generator (which may be None) is passed through.\n",
        "        noise_rng = self.secure_rng if self.secure_mode else noise_generator\n",
        "\n",
        "        return FairOptimizer(\n",
        "            optimizer=base_optimizer,\n",
        "            noise_multiplier=noise_multiplier,\n",
        "            max_grad_norm=max_grad_norm,\n",
        "            expected_batch_size=expected_batch_size,\n",
        "            loss_reduction=loss_reduction,\n",
        "            generator=noise_rng,\n",
        "            secure_mode=self.secure_mode,\n",
        "            **kwargs,\n",
        "        )\n",
        ""
      ],
      "metadata": {
        "id": "esJwmID7ZK7E"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup\n",
        "from datasets import load_dataset\n",
        "import numpy as np\n",
        "import yaml\n",
        "from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score\n",
        "import wandb\n",
        "from tqdm.auto import tqdm\n",
        "from torch.utils.data import DataLoader\n",
        "from torch.utils.data.dataset import Dataset\n",
        "import torch.nn.utils as utils\n",
        "import opacus\n",
        "from opacus.optimizers import AdaClipDPOptimizer\n",
        "from opacus.utils.batch_memory_manager import BatchMemoryManager\n",
        "from opacus.accountants.utils import get_noise_multiplier\n",
        "from opacus.validators import ModuleValidator\n",
        "from opacus import GradSampleModule\n",
        "import os\n",
        "import random\n",
        "\n",
        "\n",
        "\n",
        "# ====================================================================================================\n",
        "# User-configurable settings\n",
        "# ====================================================================================================\n",
        "\n",
        "# Random seed (options tried: 0, 42, 1000, 1234, 2025)\n",
        "SEED = 42\n",
        "\n",
        "# Paths\n",
        "CACHE_DIR = \"//\"   # change to your cache dir\n",
        "DATA_DIR = \"//datasets/unpackedmimiciii/New2\"\n",
        "CHECKPOINT_DIR = f\"//Checkpoints/MimicLOS/NEWSEED{SEED}-20010CheckpointsFairAdaptiveoptimizer\"\n",
        "\n",
        "# Dataset files (relative to DATA_DIR)\n",
        "TRAIN_FILE = \"LOS_WEEKS_train.csv\"\n",
        "VAL_FILE   = \"LOS_WEEKS_val.csv\"\n",
        "TEST_FILE  = \"LOS_WEEKS_test.csv\"\n",
        "\n",
        "# Model\n",
        "MODEL_NAME_OR_PATH = \"medicalai/ClinicalBERT\"\n",
        "\n",
        "# Training hyperparameters\n",
        "LR = 1e-3\n",
        "EPOCHS = 200\n",
        "MICRO_BATCH_SIZE = 40      # \"physical batch size\"\n",
        "LOGICAL_BATCH_SIZE = 500   # \"effective batch size\"\n",
        "MAX_GRAD_NORM = 0.1\n",
        "EPSILON = 8\n",
        "\n",
        "# WandB\n",
        "WANDB_PROJECT = \"doc_classification_Seed\"\n",
        "RUN_NAME = f\"LosFairAdaptiveDP-Seed{SEED}-500epoch{EPOCHS}Epsilon{EPSILON}batch={MICRO_BATCH_SIZE}model={MODEL_NAME_OR_PATH}\"\n",
        "\n",
        "# Differential privacy\n",
        "DELTA = 1e-5\n",
        "\n",
        "\n",
        "\n",
        "def seed_everything(seed):\n",
        "    \"\"\"Set the seed for reproducibility across all libraries, including GPU operations.\"\"\"\n",
        "    random.seed(seed)\n",
        "    os.environ['PYTHONHASHSEED'] = str(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    torch.cuda.manual_seed(seed)\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "    # Ensure deterministic behavior in cuDNN\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False  # Set to False for full determinism\n",
        "\n",
        "    print(f\"Global seed set to: {seed}\")\n",
        "\n",
        "# Seed every RNG once when this cell runs; `seed` is the module-level name read inside doc_classification.\n",
        "seed = SEED\n",
        "seed_everything(SEED)\n",
        "\n",
        "class TextDataset(Dataset):\n",
        "    def __init__(self, encodings, labels):\n",
        "        self.encodings = encodings\n",
        "        self.labels = labels\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
        "        item['labels'] = torch.tensor(self.labels[idx])\n",
        "        return item\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.labels)\n",
        "\n",
        "class EarlyStopping:\n",
        "    def __init__(self, patience=3, threshold=0.0):\n",
        "        self.patience = patience\n",
        "        self.threshold = threshold\n",
        "        self.counter = 0\n",
        "        self.best_score = None\n",
        "        self.early_stop = False\n",
        "\n",
        "    def __call__(self, val_score):\n",
        "        if self.best_score is None:\n",
        "            self.best_score = val_score\n",
        "        elif val_score < self.best_score + self.threshold:\n",
        "            self.counter += 1\n",
        "            if self.counter >= self.patience:\n",
        "                self.early_stop = True\n",
        "        else:\n",
        "            self.best_score = val_score\n",
        "            self.counter = 0\n",
        "        return self.early_stop\n",
        "\n",
        "def compute_batch_metrics(logits, labels):\n",
        "    \"\"\"Return accuracy and macro precision/recall/F1 for one batch.\n",
        "\n",
        "    Args:\n",
        "        logits: Tensor of class scores; the predicted class is the argmax over the last dim.\n",
        "        labels: Tensor of true class ids.\n",
        "    \"\"\"\n",
        "    y_pred = torch.argmax(logits, dim=-1).cpu().numpy()\n",
        "    y_true = labels.cpu().numpy()\n",
        "\n",
        "    # Shared options: macro averaging, and 0 instead of a warning for empty classes.\n",
        "    macro_opts = dict(average='macro', zero_division=0)\n",
        "    batch_scores = {\"accuracy\": accuracy_score(y_true, y_pred)}\n",
        "    batch_scores[\"precision\"] = precision_score(y_true, y_pred, **macro_opts)\n",
        "    batch_scores[\"recall\"] = recall_score(y_true, y_pred, **macro_opts)\n",
        "    batch_scores[\"f1\"] = f1_score(y_true, y_pred, **macro_opts)\n",
        "    return batch_scores\n",
        "\n",
        "def compute_metrics(predictions, labels):\n",
        "    \"\"\"Return accuracy and macro precision/recall/F1 from class scores.\n",
        "\n",
        "    Args:\n",
        "        predictions: Array of class scores; the predicted class is the argmax over the last axis.\n",
        "        labels: Array of true class ids.\n",
        "\n",
        "    Returns:\n",
        "        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.\n",
        "    \"\"\"\n",
        "    predictions = np.argmax(predictions, axis=-1)\n",
        "    # zero_division=0 matches compute_batch_metrics: a class that is never predicted\n",
        "    # (or never present) scores 0 without emitting an UndefinedMetricWarning.\n",
        "    return {\n",
        "        \"accuracy\": accuracy_score(labels, predictions),\n",
        "        \"precision\": precision_score(labels, predictions, average='macro', zero_division=0),\n",
        "        \"recall\": recall_score(labels, predictions, average='macro', zero_division=0),\n",
        "        \"f1\": f1_score(labels, predictions, average='macro', zero_division=0)\n",
        "    }\n",
        "\n",
        "\n",
        "def doc_classification(\n",
        "    model_name_or_path=MODEL_NAME_OR_PATH,\n",
        "    cache_dir=CACHE_DIR,\n",
        "    run_name=RUN_NAME,\n",
        "    lr=LR,\n",
        "    epochs=EPOCHS,\n",
        "    batch_size=MICRO_BATCH_SIZE,\n",
        "    max_grad_norm=MAX_GRAD_NORM,\n",
        "    epsilon=EPSILON,\n",
        "    gradient_accumulation_steps=1,\n",
        "):\n",
        "    \"\"\"Fine-tune a sequence classifier with the FairOptimizer and log to wandb.\n",
        "\n",
        "    Args:\n",
        "        model_name_or_path: Hugging Face model id or local path for tokenizer and model.\n",
        "        cache_dir: Cache directory passed to ``load_dataset``.\n",
        "        run_name: Name of the wandb run.\n",
        "        lr: Learning rate of the AdamW base optimizer.\n",
        "        epochs: Number of training epochs; also used to calibrate the noise multiplier.\n",
        "        batch_size: Physical (micro) batch size given to BatchMemoryManager.\n",
        "        max_grad_norm: Initial clipping bound handed to ``make_private``.\n",
        "        epsilon: Target privacy budget used to derive the noise multiplier.\n",
        "        gradient_accumulation_steps: Accepted for interface compatibility; not used.\n",
        "\n",
        "    The module-level settings DATA_DIR, CHECKPOINT_DIR, LOGICAL_BATCH_SIZE, DELTA,\n",
        "    WANDB_PROJECT and ``seed`` are read directly.\n",
        "    \"\"\"\n",
        "    # Initialize wandb\n",
        "    wandb.init(project=WANDB_PROJECT, name=run_name)\n",
        "    wandb.config.seed = seed\n",
        "    checkpoint_dir = CHECKPOINT_DIR\n",
        "    os.makedirs(checkpoint_dir, exist_ok=True)\n",
        "\n",
        "    # No task-config file is loaded: nothing in this function consumes one.\n",
        "\n",
        "    # Load dataset\n",
        "    data_dir = DATA_DIR\n",
        "    data_files = {\n",
        "        \"train\": f\"{data_dir}/{TRAIN_FILE}\",\n",
        "        \"validation\": f\"{data_dir}/{VAL_FILE}\",\n",
        "        \"test\": f\"{data_dir}/{TEST_FILE}\"\n",
        "    }\n",
        "    dataset = load_dataset('csv', data_files=data_files, cache_dir=cache_dir)\n",
        "\n",
        "    # Load tokenizer and model\n",
        "    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)\n",
        "    model = AutoModelForSequenceClassification.from_pretrained(\n",
        "        model_name_or_path,\n",
        "        num_labels = 4\n",
        "    )\n",
        "\n",
        "    model.to('cuda')\n",
        "\n",
        "    # Freeze everything except parameters whose name contains one of these markers.\n",
        "    trainable_markers = (\n",
        "        \"classifier\",\n",
        "        \"layer.5.output_layer_norm\",\n",
        "        \"layer.5.ffn.lin2\",\n",
        "        \"layer.5.ffn.lin1\",\n",
        "        \"layer.5.sa_layer_norm\",\n",
        "        \"layer.5.attention.out_lin\",\n",
        "        \"layer.5.attention.v_lin\",\n",
        "        \"layer.5.attention.k_lin\",\n",
        "    )\n",
        "    for name, parameter in model.named_parameters():\n",
        "        parameter.requires_grad = any(marker in name for marker in trainable_markers)\n",
        "\n",
        "    # Process datasets\n",
        "    def process_dataset(dataset_split):\n",
        "        \"\"\"Tokenize the 'text' column and pair it with the integer 'los_label' column.\"\"\"\n",
        "        encodings = tokenizer(\n",
        "            dataset_split['text'],\n",
        "            truncation=True,\n",
        "            padding='max_length',\n",
        "            max_length=512,\n",
        "            return_tensors='pt'\n",
        "        )\n",
        "        labels = [int(label) for label in dataset_split['los_label']]\n",
        "        return TextDataset(encodings, labels)\n",
        "\n",
        "    # Create datasets\n",
        "    train_dataset = process_dataset(dataset['train'])\n",
        "    eval_dataset = process_dataset(dataset['validation'])\n",
        "\n",
        "    logical_batch_size = LOGICAL_BATCH_SIZE\n",
        "    micro_batch_size = batch_size\n",
        "\n",
        "    g = torch.Generator()\n",
        "    g.manual_seed(seed)\n",
        "\n",
        "    # Create DataLoaders\n",
        "    train_dataloader = DataLoader(\n",
        "        train_dataset,\n",
        "        batch_size=logical_batch_size,\n",
        "        shuffle=True,\n",
        "        drop_last=True,\n",
        "        worker_init_fn=lambda worker_id: torch.manual_seed(seed + worker_id),  # Fixes worker shuffling\n",
        "        generator=g  # Ensures deterministic shuffling\n",
        "    )\n",
        "    eval_dataloader = DataLoader(\n",
        "        eval_dataset,\n",
        "        batch_size=logical_batch_size\n",
        "    )\n",
        "\n",
        "    # Create base optimizer\n",
        "    base_optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)\n",
        "\n",
        "    model.train()\n",
        "\n",
        "    # Opacus setup: noise multiplier for the (epsilon, DELTA) target.\n",
        "    sigma = get_noise_multiplier(\n",
        "        target_epsilon=epsilon,\n",
        "        target_delta=DELTA,\n",
        "        sample_rate=logical_batch_size / len(dataset['train']),\n",
        "        epochs=epochs\n",
        "    )\n",
        "    print(sigma)\n",
        "\n",
        "    # Noise std for the unclipped-count estimate: users per round / 20 (from the paper).\n",
        "    sigma_b = logical_batch_size / 20\n",
        "\n",
        "    privacy_engine = fairPrivacyEngine()\n",
        "\n",
        "    model, optimizer, train_loader = privacy_engine.make_private(\n",
        "        module=model,\n",
        "        optimizer=base_optimizer,\n",
        "        clipping = 'adaptive',\n",
        "        data_loader=train_dataloader,\n",
        "        noise_multiplier=sigma,\n",
        "        max_grad_norm=max_grad_norm,\n",
        "        poisson_sampling=False, # poisson_sampling=True, If I was using drop_last = False in dataloader\n",
        "        loss_reduction=\"sum\",\n",
        "        target_unclipped_quantile=0.5, # gamma = median\n",
        "        clipbound_learning_rate=0.2, # ηC\n",
        "        max_clipbound=2.0,  # upper limit on clipping norm\n",
        "        min_clipbound=0.1,  # lower limit\n",
        "        unclipped_num_std=sigma_b\n",
        "    )\n",
        "\n",
        "    print(f\"Optimizer type: {type(optimizer)}\")\n",
        "\n",
        "    # NOTE(review): scheduler.step() below runs once per logical batch, yet both counts\n",
        "    # here are scaled by logical/micro batch size, which stretches the schedule by that\n",
        "    # factor. Left unchanged because it alters training dynamics - confirm the intent.\n",
        "    num_training_steps = (len(train_dataloader) * epochs) * (logical_batch_size/micro_batch_size)\n",
        "    num_warmup_steps = 500 * (logical_batch_size/micro_batch_size)\n",
        "    scheduler = get_linear_schedule_with_warmup(\n",
        "        optimizer,\n",
        "        num_warmup_steps=num_warmup_steps,\n",
        "        num_training_steps=num_training_steps\n",
        "    )\n",
        "\n",
        "    # Early stopping setup (the check itself is disabled at the end of the epoch loop)\n",
        "    early_stopping = EarlyStopping(patience=200, threshold=0.0)\n",
        "    best_val_f1 = 0\n",
        "\n",
        "    # Sum-reduced loss, shared by training and evaluation.\n",
        "    loss_fn = torch.nn.CrossEntropyLoss(reduction=\"sum\")\n",
        "\n",
        "    # Per-logical-batch diagnostics collected over the whole run\n",
        "    grad_norms_before_clipping = []\n",
        "    grad_norms_after_clipping_noisy = []\n",
        "    all_noises_grad = []\n",
        "\n",
        "    for epoch in range(epochs):\n",
        "        model.train()\n",
        "        total_loss = 0\n",
        "        all_logical_predictions = []\n",
        "        all_logical_labels = []\n",
        "        layer_wise_grad_samples = {}\n",
        "        cosine_similarities = []\n",
        "        angles = []\n",
        "\n",
        "        print_loss = 0\n",
        "        total_micro_batches = 0\n",
        "        # One tick per logical batch; advanced manually when an optimizer step happens.\n",
        "        progress_bar = tqdm(total=len(train_loader), desc=f\"Epoch {epoch + 1}\")\n",
        "\n",
        "        optimizer.zero_grad()\n",
        "        # Iterate the loader returned by make_private, split into physical batches.\n",
        "        with BatchMemoryManager(\n",
        "            data_loader=train_loader,\n",
        "            max_physical_batch_size=micro_batch_size,\n",
        "            optimizer=optimizer,\n",
        "        ) as memory_safe_data_loader:\n",
        "\n",
        "            for batch_idx, batch in enumerate(memory_safe_data_loader):\n",
        "\n",
        "                total_micro_batches += 1\n",
        "                # Move batch to GPU\n",
        "                batch = {k: v.to('cuda') for k, v in batch.items()}\n",
        "\n",
        "                # Forward pass\n",
        "                outputs = model(**batch)\n",
        "                logits = outputs.logits\n",
        "                labels = batch['labels']\n",
        "\n",
        "                loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
        "\n",
        "                total_loss += loss.item()\n",
        "                print_loss = loss.item()\n",
        "\n",
        "                # Collect predictions/labels for the metrics of the current logical batch\n",
        "                predictions = logits.detach().cpu().numpy()\n",
        "                labels_cpu = labels.detach().cpu().numpy()\n",
        "                all_logical_predictions.extend(predictions)\n",
        "                all_logical_labels.extend(labels_cpu)\n",
        "\n",
        "                # Backward pass\n",
        "                loss.backward()\n",
        "\n",
        "                # Accumulate the unclipped gradient (summed over samples) per parameter\n",
        "                # across the physical batches of this logical batch.\n",
        "                for name, p in model.named_parameters():\n",
        "                    if p.requires_grad:\n",
        "                        with torch.no_grad():\n",
        "                            grad_sample = p.grad_sample.sum(dim = 0).clone().detach()\n",
        "                            if name not in layer_wise_grad_samples:\n",
        "                                layer_wise_grad_samples[name] = torch.zeros_like(grad_sample)\n",
        "                            layer_wise_grad_samples[name] += grad_sample\n",
        "\n",
        "                # pre_step() returns a truthy value only on the physical batch that\n",
        "                # completes a logical batch; only then is the base optimizer stepped.\n",
        "                is_final_step = optimizer.pre_step()\n",
        "\n",
        "                if is_final_step:\n",
        "                    total_grad_norm_squared = 0.0\n",
        "                    total_grad_norm_squared_summed_noisy = 0.0\n",
        "                    total_noise_grad = 0.0\n",
        "                    all_layer_grads_before = []\n",
        "                    all_layer_grads_after = []\n",
        "\n",
        "                    num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
        "                    expected_noise = (num_trainable_params ** 0.5) * sigma # * max_grad_norm\n",
        "\n",
        "                    for layer_grad in layer_wise_grad_samples.values():\n",
        "                        all_layer_grads_before.append(layer_grad)  # Store per-layer gradients\n",
        "                        # Accumulate the squared L2 norm\n",
        "                        total_grad_norm_squared += torch.norm(layer_grad, p=2) ** 2\n",
        "\n",
        "                    with torch.no_grad():\n",
        "                        for p in model.parameters():\n",
        "                            if not p.requires_grad:\n",
        "                                continue\n",
        "                            all_layer_grads_after.append(p.summed_grad.clone().detach())\n",
        "\n",
        "                            layer_grad_norm_summed_noisy = torch.norm(p.summed_grad, p=2).item()\n",
        "                            total_grad_norm_squared_summed_noisy += layer_grad_norm_summed_noisy ** 2\n",
        "\n",
        "                            # Norm of the difference p.grad - p.summed_grad, not the difference\n",
        "                            # of the two norms (which understates it). p.grad is expected to be\n",
        "                            # summed_grad plus the added noise here - verify against Opacus.\n",
        "                            layer_noise_norm = torch.norm(p.grad - p.summed_grad, p=2).item()\n",
        "                            total_noise_grad += layer_noise_norm ** 2\n",
        "\n",
        "                    # Convert lists to single tensors\n",
        "                    total_grad_before = torch.cat([v.view(-1) for v in all_layer_grads_before])\n",
        "                    total_grad_after = torch.cat([v.view(-1) for v in all_layer_grads_after])  # Aggregate all layers\n",
        "\n",
        "                    # Compute Cosine Similarity\n",
        "                    cosine_similarity = torch.nn.functional.cosine_similarity(\n",
        "                        total_grad_before.view(1, -1),\n",
        "                        total_grad_after.view(1, -1),\n",
        "                        dim=1\n",
        "                    ).item()\n",
        "\n",
        "                    # Compute Angle in Degrees\n",
        "                    angle = torch.acos(torch.clamp(torch.tensor(cosine_similarity), -1.0, 1.0)).item() * (180 / torch.pi)\n",
        "\n",
        "                    # Store values for logging\n",
        "                    cosine_similarities.append(cosine_similarity)\n",
        "                    angles.append(angle)\n",
        "\n",
        "                    print(f\"Cosine Similarity: {cosine_similarity}\")\n",
        "                    print(f\"Angle in Degrees: {angle}\")\n",
        "\n",
        "                    # Compute the overall gradient norms\n",
        "                    overall_grad_norm = total_grad_norm_squared ** 0.5\n",
        "                    grad_norms_before_clipping.append(overall_grad_norm)\n",
        "                    overall_grad_norm_summed_noisy = total_grad_norm_squared_summed_noisy ** 0.5\n",
        "                    grad_norms_after_clipping_noisy.append(overall_grad_norm_summed_noisy)\n",
        "                    overal_noise_grad = total_noise_grad ** 0.5\n",
        "                    all_noises_grad.append(overal_noise_grad)\n",
        "\n",
        "                    print(f\"logical batch grad_norm before clipping: {overall_grad_norm}\")\n",
        "                    print(f\"logical batch grad norm after clipping with Noise: {overall_grad_norm_summed_noisy}\")\n",
        "                    print(f\"overal noise: grad - summed_grad: {overal_noise_grad}\")\n",
        "                    print(f\"summed_grad/noise grad: {overall_grad_norm_summed_noisy/overal_noise_grad}\")\n",
        "                    print(f\"Expected noise: {expected_noise}\")\n",
        "\n",
        "                    # Log gradient norms for the logical batch\n",
        "                    wandb.log({\n",
        "                        \"logical_batch_grad_norm_after_clipping_with_Noise\":overall_grad_norm_summed_noisy,\n",
        "                        \"logical_batch_grad_norm_before_clipping\":overall_grad_norm,\n",
        "                        \"overal noise grad - summed_grad\":overal_noise_grad,\n",
        "                        \"summed_grad/noise grad\":overall_grad_norm_summed_noisy/overal_noise_grad,\n",
        "                        \"Expected noise\" : expected_noise,\n",
        "                        \"norm after clipping / norm before clipping: \" :overall_grad_norm_summed_noisy/overall_grad_norm ,\n",
        "                        \"Cosine Similarity\": cosine_similarity,\n",
        "                        \"Angle in Degrees\": angle\n",
        "                    })\n",
        "\n",
        "                    # Compute metrics after all physical batches in the logical batch are processed\n",
        "                    train_metrics = compute_metrics(np.array(all_logical_predictions), np.array(all_logical_labels))\n",
        "                    avg_train_loss = total_loss / total_micro_batches\n",
        "\n",
        "                    # Log to wandb\n",
        "                    wandb.log({\n",
        "                        \"train_loss\": avg_train_loss,\n",
        "                        \"train_accuracy\": train_metrics['accuracy'],\n",
        "                        \"train_precision\": train_metrics['precision'],\n",
        "                        \"train_recall\": train_metrics['recall'],\n",
        "                        \"train_f1\": train_metrics['f1']})\n",
        "\n",
        "                    print(f\"train_loss: {avg_train_loss}\")\n",
        "                    print(f\"train_accuracy: {train_metrics['accuracy']}\")\n",
        "                    print(f\"train_precision: {train_metrics['precision']}\")\n",
        "                    print(f\"train_recall: { train_metrics['recall']}\")\n",
        "                    print(f\"train_f1: {train_metrics['f1']}\")\n",
        "\n",
        "                    optimizer.original_optimizer.step()\n",
        "                    scheduler.step()\n",
        "                    progress_bar.update(1)\n",
        "\n",
        "                    # Start collecting afresh for the next logical batch\n",
        "                    all_logical_labels = []\n",
        "                    all_logical_predictions = []\n",
        "                    layer_wise_grad_samples = {}\n",
        "\n",
        "                progress_bar.set_postfix({\n",
        "                    \"lr\": scheduler.get_last_lr()[0],\n",
        "                    \"step_loss\": print_loss\n",
        "                })\n",
        "\n",
        "                optimizer.zero_grad()\n",
        "\n",
        "        progress_bar.close()\n",
        "\n",
        "        # Save checkpoint every 20 epochs\n",
        "        if (epoch + 1) % 20 == 0:\n",
        "            checkpoint_path = os.path.join(checkpoint_dir, f\"checkpoint_epoch_{epoch+1}.pt\")\n",
        "            torch.save({\n",
        "                'epoch': epoch + 1,  # Save the current epoch\n",
        "                'model_state_dict': model.state_dict(),  # Save model parameters\n",
        "                'optimizer_state_dict': optimizer.state_dict(),  # Save optimizer state\n",
        "                'scheduler_state_dict': scheduler.state_dict(),  # Save scheduler state\n",
        "                'best_val_f1': best_val_f1,  # Best validation F1 over the epochs evaluated so far\n",
        "                'sigma': sigma\n",
        "            }, checkpoint_path)\n",
        "            print(f\" Checkpoint saved at {checkpoint_path}\")\n",
        "\n",
        "            # Save tokenizer with epoch number\n",
        "            tokenizer_save_path = os.path.join(checkpoint_dir, f\"tokenizer_epoch_{epoch+1}\")\n",
        "            tokenizer.save_pretrained(tokenizer_save_path)  # Save tokenizer in a uniquely named folder\n",
        "\n",
        "        # Evaluation\n",
        "        model.eval()\n",
        "        eval_loss = 0\n",
        "        all_eval_predictions = []\n",
        "        all_eval_labels = []\n",
        "\n",
        "        with torch.no_grad():\n",
        "            for batch in eval_dataloader:\n",
        "                batch = {k: v.to('cuda') for k, v in batch.items()}\n",
        "                outputs = model(**batch)\n",
        "                loss = loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)), batch['labels'].view(-1))\n",
        "                eval_loss += loss.item()\n",
        "                predictions = outputs.logits.cpu().numpy()\n",
        "                labels = batch[\"labels\"].cpu().numpy()\n",
        "                all_eval_predictions.extend(predictions)\n",
        "                all_eval_labels.extend(labels)\n",
        "\n",
        "        eval_loss = eval_loss / len(eval_dataloader)\n",
        "        eval_metrics = compute_metrics(np.array(all_eval_predictions), np.array(all_eval_labels))\n",
        "        # Track the best validation F1 so the value stored in checkpoints is meaningful.\n",
        "        best_val_f1 = max(best_val_f1, eval_metrics['f1'])\n",
        "\n",
        "        # Log metrics for the epoch\n",
        "        current_lr = scheduler.get_last_lr()[0]\n",
        "\n",
        "        # Print epoch metrics\n",
        "        print(f\"\\nEpoch {epoch + 1} Summary:\")\n",
        "        print(f\"Validation Loss: {eval_loss:.4f}\")\n",
        "        print(f\"Validation Accuracy: {eval_metrics['accuracy']:.4f}\")\n",
        "        print(f\"Validation F1: {eval_metrics['f1']:.4f}\")\n",
        "        print(f\"Learning Rate: {current_lr:.2e}\")\n",
        "        print(f\"Sigma: {sigma}\")\n",
        "\n",
        "        # Log to wandb\n",
        "        wandb.log({\n",
        "            \"epoch\": epoch,\n",
        "            \"eval_loss\": eval_loss,\n",
        "            \"eval_accuracy\": eval_metrics['accuracy'],\n",
        "            \"eval_precision\": eval_metrics['precision'],\n",
        "            \"eval_recall\": eval_metrics['recall'],\n",
        "            \"eval_f1\": eval_metrics['f1'],\n",
        "            \"learning_rate\": current_lr,\n",
        "            \"sigma\":sigma\n",
        "        })\n",
        "\n",
        "        # Early stopping (intentionally disabled)\n",
        "        #if early_stopping(eval_metrics[\"f1\"]):\n",
        "        #    print(\"Early stopping triggered\")\n",
        "        #    break\n",
        "\n",
        "    wandb.finish()\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    import sys\n",
        "\n",
        "    if 'ipykernel' in sys.modules:\n",
        "        # Inside a Jupyter/Colab kernel sys.argv carries the kernel's own flags\n",
        "        # (e.g. -f <connection file>), which Fire would try to parse as arguments\n",
        "        # of doc_classification. Run with the configured defaults instead.\n",
        "        doc_classification()\n",
        "    else:\n",
        "        # Plain script execution: expose the function's parameters as CLI flags.\n",
        "        import fire\n",
        "\n",
        "        fire.Fire(doc_classification)"
      ],
      "metadata": {
        "id": "AbjsMTJ5ZMUE"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}