{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7871ea96-fe2b-414a-bd41-4639af5fb515",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Debugging aids, set before `import torch` below.\n",
    "# CUDA_LAUNCH_BLOCKING=1 makes CUDA kernel launches synchronous: clearer error locations,\n",
    "# but slower training. NOTE(review): TORCH_USE_CUDA_DSA presumably only has an effect on\n",
    "# PyTorch builds compiled with device-side assertions — confirm, and consider dropping\n",
    "# both for full-speed runs.\n",
    "os.environ['TORCH_USE_CUDA_DSA'] = '1'\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "from datasets import load_dataset\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "from torch.nn import CrossEntropyLoss\n",
    "import csv\n",
    "import gc\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "\n",
    "# Target device used by later cells for explicit .to(device) moves.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa0abc9e-5a32-4da9-9a51-60e4ae038765",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For recording logs\n",
    "# ===========================================================\n",
    "class LogCallback(TrainerCallback):\n",
    "    \"\"\"Collect one metrics row per evaluation and append the rows to a CSV file.\n",
    "\n",
    "    Every call to ``on_evaluate`` appends a dict with a fixed set of keys to\n",
    "    ``self.epoch_logs``; ``on_train_end`` fills the throughput/runtime columns of the\n",
    "    last row, and ``record_pruning_energy`` fills its ``avg_kept_energy`` column.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        self.epoch_logs = []           # one dict per evaluation, in call order\n",
    "        self.final_train_metrics = {}  # set by on_train_end when the trainer logged them\n",
    "\n",
    "    def count_trainable_params(self, model):\n",
    "        \"\"\"Return the number of scalar parameters with ``requires_grad=True``.\"\"\"\n",
    "        return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "    def on_evaluate(self, args, state, control, metrics, **kwargs):\n",
    "        \"\"\"Append a row built from the evaluation ``metrics`` and the trainer ``state``.\n",
    "\n",
    "        ``perplexity`` is ``exp(eval_loss)`` rounded to 4 decimals; ``train_loss`` is the\n",
    "        most recent ``\"loss\"`` entry of ``state.log_history`` (``None`` if there is none).\n",
    "        \"\"\"\n",
    "        epoch = state.epoch if state.epoch is not None else 0\n",
    "        eval_loss = metrics.get(\"eval_loss\")\n",
    "        perplexity = round(np.exp(eval_loss), 4) if eval_loss is not None else None\n",
    "\n",
    "        train_loss = next((log[\"loss\"] for log in reversed(state.log_history) if \"loss\" in log), None)\n",
    "\n",
    "        model = kwargs.get(\"model\", None)\n",
    "        trainable_params = self.count_trainable_params(model) if model is not None else None\n",
    "\n",
    "        self.epoch_logs.append({\n",
    "            \"epoch\": epoch,\n",
    "            \"total_steps\": state.global_step,\n",
    "            \"train_loss\": train_loss,\n",
    "            \"eval_loss\": eval_loss,\n",
    "            \"perplexity\": perplexity,\n",
    "            \"trainable_params\": trainable_params,\n",
    "            # Filled in later by record_pruning_energy / on_train_end.\n",
    "            \"avg_kept_energy\": None,\n",
    "            \"samples_per_second\": None,\n",
    "            \"steps_per_second\": None,\n",
    "            \"total_flos\": None,\n",
    "            \"train_runtime\": None,\n",
    "        })\n",
    "\n",
    "    def on_train_end(self, args, state, control, **kwargs):\n",
    "        \"\"\"Copy the final throughput/runtime metrics from the log history into the last row.\"\"\"\n",
    "        required = (\"train_samples_per_second\", \"train_steps_per_second\", \"total_flos\", \"train_runtime\")\n",
    "        for log in reversed(state.log_history):\n",
    "            if all(k in log for k in required):\n",
    "                self.final_train_metrics = {\n",
    "                    \"samples_per_second\": log[\"train_samples_per_second\"],\n",
    "                    \"steps_per_second\": log[\"train_steps_per_second\"],\n",
    "                    \"total_flos\": log[\"total_flos\"],\n",
    "                    \"train_runtime\": log[\"train_runtime\"]\n",
    "                }\n",
    "                break\n",
    "\n",
    "        if self.epoch_logs and self.final_train_metrics:\n",
    "            self.epoch_logs[-1].update(self.final_train_metrics)\n",
    "\n",
    "    def record_pruning_energy(self, constraints_callback):\n",
    "        \"\"\"Store ``constraints_callback.last_avg_kept_energy`` in the last row, if both exist.\"\"\"\n",
    "        avg_energy = getattr(constraints_callback, \"last_avg_kept_energy\", None)\n",
    "        if avg_energy is None:\n",
    "            print(\"[LogCallback] No avg_kept_energy found in constraints_callback.\")\n",
    "            return\n",
    "\n",
    "        if self.epoch_logs:\n",
    "            self.epoch_logs[-1][\"avg_kept_energy\"] = avg_energy\n",
    "            print(f\"[LogCallback] Recorded avg_kept_energy = {avg_energy:.4f} to last epoch log.\")\n",
    "\n",
    "    def save_to_csv(self, filepath=\"epoch_metrics.csv\"):\n",
    "        \"\"\"Append ``self.epoch_logs`` to ``filepath``.\n",
    "\n",
    "        The parent directory is created if missing, and the header row is written when\n",
    "        the file is new or empty, so repeated calls append rows under a single header.\n",
    "        \"\"\"\n",
    "        if not self.epoch_logs:\n",
    "            print(\"No logs to save.\")\n",
    "            return\n",
    "\n",
    "        parent_dir = os.path.dirname(filepath)\n",
    "        if parent_dir:\n",
    "            os.makedirs(parent_dir, exist_ok=True)\n",
    "        write_header = not os.path.isfile(filepath) or os.path.getsize(filepath) == 0\n",
    "        keys = list(self.epoch_logs[0].keys())\n",
    "\n",
    "        with open(filepath, \"a\", newline=\"\") as f:\n",
    "            writer = csv.DictWriter(f, fieldnames=keys)\n",
    "            if write_header:\n",
    "                writer.writeheader()\n",
    "            writer.writerows(self.epoch_logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c597169-d7fc-4b84-8094-a423711ad157",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For generate text\n",
    "# ===========================================================\n",
    "def generate_text(model, tokenizer, prompt, max_new_tokens=20):\n",
    "    model.eval()\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)\n",
    "    print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a6d768-c9a0-45e2-9517-a9471104730c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset and tokenizer\n",
    "# ===========================================================\n",
    "dataset = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\")\n",
    "train_data = dataset[\"train\"]\n",
    "val_data = dataset[\"validation\"]\n",
    "\n",
    "debug_usage = False # Only use a very small part of data for debugging \n",
    "if debug_usage:\n",
    "    # Subset the raw splits before tokenising: the map below is one row in, one row out,\n",
    "    # so this keeps the same rows as subsetting afterwards, without tokenising everything.\n",
    "    print(\"Debugging mode: Only small part of data is loaded.\")\n",
    "    train_data = train_data.select(range(320))  # First x samples\n",
    "    val_data = val_data.select(range(64))      # First x samples\n",
    "\n",
    "model_dir = './orig_models/llama2-7b-hf' \n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast = False)\n",
    "tokenizer.pad_token = tokenizer.eos_token  # Set tokenizer padding\n",
    "\n",
    "# Map token\n",
    "def tokenize_fn(examples):\n",
    "    \"\"\"Tokenise a batch of raw lines, truncating each to at most 512 tokens.\"\"\"\n",
    "    return tokenizer(examples[\"text\"], truncation = True, max_length = 512)\n",
    "\n",
    "tokenized_train = train_data.map(tokenize_fn, batched = True, remove_columns = [\"text\"])\n",
    "tokenized_val = val_data.map(tokenize_fn, batched = True, remove_columns = [\"text\"])\n",
    "print(\"Train / Validation mapping: Done\\n\")\n",
    "\n",
    "# Drop samples whose input_ids have length <= 1 (e.g. blank wikitext lines, which\n",
    "# presumably tokenise to a lone BOS token): there is nothing to predict in them.\n",
    "def filter_short(example):\n",
    "    return len(example[\"input_ids\"]) > 1\n",
    "print(f\"Before dropping: Train ({len(tokenized_train)}) / Val ({len(tokenized_val)})\")\n",
    "tokenized_train = tokenized_train.filter(filter_short)\n",
    "tokenized_val = tokenized_val.filter(filter_short)\n",
    "print(f\"After dropping: Train ({len(tokenized_train)}) / Val ({len(tokenized_val)})\")\n",
    "print(tokenized_train[0]) #Example of the train token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf8950f-91c6-415b-a664-2a10ff4b8930",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Customise the UDV layer for Q, K, V and MLP \n",
    "# ===========================================================\n",
    "class Outer_LoRA(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim, udv_rank, alpha_scaling):\n",
    "        super().__init__()     \n",
    "        self.in_dim = in_dim\n",
    "        self.out_dim = out_dim\n",
    "        self.udv_rank = udv_rank\n",
    "        self.alpha_scaling = alpha_scaling\n",
    "\n",
    "        # self.W0 = nn.Parameter(torch.empty(self.in_dim, self.out_dim, dtype=torch.float16), requires_grad = False)\n",
    "        self.register_buffer(\"W0\", torch.empty(self.in_dim, self.out_dim, dtype=torch.float16), persistent = True) # Not trainable but persistent available\n",
    "        \n",
    "        self.U = nn.Parameter(torch.empty(self.in_dim, self.udv_rank, dtype=torch.float16), requires_grad = True)\n",
    "        self.D = nn.Parameter(torch.empty(1, self.udv_rank, dtype=torch.float16), requires_grad = False)\n",
    "        self.V = nn.Parameter(torch.empty(self.udv_rank, self.out_dim, dtype=torch.float16), requires_grad = True)\n",
    "        with torch.no_grad():\n",
    "            self.init_parameters()\n",
    "\n",
    "        self.scaling = self.alpha_scaling / self.udv_rank if self.alpha_scaling != 0 else 1.0\n",
    "\n",
    "    def init_parameters(self):\n",
    "        nn.init.normal_(self.U, std = 1e-2)\n",
    "        nn.init.constant_(self.D, 1.0)\n",
    "        nn.init.zeros_(self.V)\n",
    "\n",
    "    def set_weight(self, w): \n",
    "        if w.shape != self.W0.shape:\n",
    "            raise ValueError(f\"Shape mismatch: Expected {self.W0.shape}, got {w.shape}\")\n",
    "        self.W0.copy_(w.to(self.W0.dtype))\n",
    "\n",
    "     \n",
    "    def forward(self, x):\n",
    "        orig_out = torch.matmul(x, self.W0)        \n",
    "        x = torch.matmul(x, self.U)               # [B, *, udv_rank]  \n",
    "        x = torch.mul(x, self.D.view(1, 1, -1))   # [B, *, udv_rank]     \n",
    "        x = torch.matmul(x, self.V)               # [B, *, out_dim]\n",
    "\n",
    "        return orig_out + x * self.scaling\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2687ee59-f8c9-4791-83e5-e9183a83aee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Batch) Replace linear layers by Outer_LoRA\n",
    "# ===========================================================\n",
    "def uv_replacement(model, target_keys, udv_rank, alpha_scaling):\n",
    "    replace_counter = 0\n",
    "    for name, module in model.named_modules():\n",
    "        for key in target_keys:\n",
    "            if name.endswith(key):\n",
    "                path = name.split(\".\")\n",
    "                parent = model\n",
    "                for p in path[:-1]:\n",
    "                    if hasattr(parent, p):\n",
    "                        parent = getattr(parent, p)\n",
    "                    elif p.isdigit():\n",
    "                        parent = parent[int(p)]\n",
    "                    else:\n",
    "                        raise RuntimeError(f\"Cannot resolve submodule path: {'.'.join(path)}\")\n",
    "\n",
    "                old_linear = getattr(parent, path[-1])\n",
    "                if not isinstance(old_linear, nn.Linear):\n",
    "                    continue\n",
    "\n",
    "                new_layer = Outer_LoRA(old_linear.in_features, old_linear.out_features, udv_rank, alpha_scaling)\n",
    "                with torch.no_grad():\n",
    "                    new_layer.set_weight(old_linear.weight.detach().clone().T)\n",
    "                setattr(parent, path[-1], new_layer)\n",
    "                \n",
    "                assert torch.allclose(old_linear.weight.detach().T.to(device), new_layer.W0.detach().half().to(device), atol=1e-5)\n",
    "                \n",
    "                replace_counter = replace_counter + 1\n",
    "                \n",
    "    print(f\"Replaced {replace_counter} layers in total.\")\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ebc51b2-3ddd-42f7-83d1-3c1c160a9065",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Share the callback function with UDV method\n",
    "# ===========================================================        \n",
    "class Constraints_CallBack(TrainerCallback):\n",
    "    \"\"\"Record and prune the factors of every module exposing ``U``, ``D``, ``V`` Parameters.\n",
    "\n",
    "    During training, ``on_step_end`` stores the singular values of ``U * D`` per layer in\n",
    "    ``self.svd_record``. After training, ``post_pruning`` truncates each layer to the\n",
    "    smallest rank keeping a given fraction of the ``U * D`` spectral energy, and\n",
    "    ``recover_pruning`` restores the factors saved before pruning.\n",
    "\n",
    "    ``uv_norm_limit``, ``d_lowerbound`` and ``d_boundto`` are accepted but not used.\n",
    "    \"\"\"\n",
    "    def __init__(self, uv_norm_limit, d_lowerbound, d_boundto, pruning_energy_threshold):\n",
    "        self.uv_param_groups = []   # [(module_name, {'U': ..., 'D': ..., 'V': ...})], found lazily\n",
    "        self.original_params = {}   # module_name -> CPU copies of U/D/V taken before pruning\n",
    "        self.svd_record = {}        # module_name -> [{\"step\": int, \"svd\": np.ndarray}, ...]\n",
    "        self.pruning_energy_threshold = pruning_energy_threshold  # fallback for auto_svd_pruning\n",
    "        self.initialized = False\n",
    "\n",
    "    def _find_uv_groups(self, model):\n",
    "        \"\"\"Return ``(name, {'U', 'D', 'V'})`` for each module whose U, D and V are Parameters.\"\"\"\n",
    "        uv_groups = []\n",
    "        for module_name, module in model.named_modules():\n",
    "            if all(hasattr(module, name) for name in ['U', 'D', 'V']):\n",
    "                U, D, V = module.U, module.D, module.V\n",
    "                if all(isinstance(p, torch.nn.Parameter) for p in [U, D, V]):\n",
    "                    uv_groups.append((module_name, {'U': U, 'D': D, 'V': V}))\n",
    "        return uv_groups\n",
    "\n",
    "    def _ensure_groups(self, model):\n",
    "        \"\"\"Discover the U/D/V groups on first use (a training step or a manual pruning).\"\"\"\n",
    "        if not self.initialized:\n",
    "            self.uv_param_groups = self._find_uv_groups(model)\n",
    "            self.initialized = True\n",
    "\n",
    "    def on_step_end(self, args, state, control, **kwargs):\n",
    "        \"\"\"Record the ``U * D`` spectra when ``global_step`` is a multiple of\n",
    "        ``gradient_accumulation_steps``.\"\"\"\n",
    "        model = kwargs[\"model\"]\n",
    "        self._ensure_groups(model)\n",
    "\n",
    "        # NOTE(review): global_step presumably counts optimizer steps, so this samples every\n",
    "        # gradient_accumulation_steps-th update — confirm that is the intended cadence.\n",
    "        if state.global_step % args.gradient_accumulation_steps == 0:\n",
    "            self.record_svd(step = state.global_step)\n",
    "\n",
    "        return control\n",
    "\n",
    "    def record_svd(self, step):\n",
    "        \"\"\"Append the singular values of ``U * D`` (float32 SVD) of every group, tagged with ``step``.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            for layer_name, udv in self.uv_param_groups:\n",
    "                try:\n",
    "                    uv_matrix = torch.mul(udv['U'], udv['D']).float()\n",
    "                    _, S_svd, _ = torch.linalg.svd(uv_matrix, full_matrices = False)\n",
    "                    self.svd_record.setdefault(layer_name, []).append({\"step\": step,\n",
    "                                                                       \"svd\": S_svd.cpu().numpy()})\n",
    "                except Exception as e:\n",
    "                    print(f\"[Warning] SVD failed on layer {layer_name}: {e}\")\n",
    "\n",
    "    def replace_para(self, optimizer, layer, name, new_param, transfer_state = False):\n",
    "        \"\"\"Swap ``layer.<name>`` for a new Parameter holding ``new_param``.\n",
    "\n",
    "        The new Parameter keeps the old one's ``requires_grad`` flag (so a frozen ``D``\n",
    "        stays frozen and the trainable-parameter count is unchanged). If ``optimizer`` is\n",
    "        given, the old Parameter is replaced in its param groups and its state is moved\n",
    "        (``transfer_state=True``) or dropped.\n",
    "        \"\"\"\n",
    "        old_param = getattr(layer, name)\n",
    "        new_param = nn.Parameter(new_param, requires_grad=old_param.requires_grad)\n",
    "        setattr(layer, name, new_param)\n",
    "\n",
    "        if optimizer is None:\n",
    "            return\n",
    "        for group in optimizer.param_groups:\n",
    "            for i, p in enumerate(group['params']):\n",
    "                if p is old_param:\n",
    "                    group['params'][i] = new_param\n",
    "        if old_param in optimizer.state:\n",
    "            if transfer_state:\n",
    "                optimizer.state[new_param] = optimizer.state.pop(old_param)\n",
    "            else:\n",
    "                del optimizer.state[old_param]\n",
    "\n",
    "    def auto_svd_pruning(self, model, optimizer, energy_threshold):\n",
    "        \"\"\"Truncate every U/D/V group to the rank keeping ``energy_threshold`` of its energy.\n",
    "\n",
    "        With the thin SVD ``U * D = U_s diag(S) Vh``, the kept rank ``k`` is the smallest one\n",
    "        with ``cumsum(S**2)[k-1] >= energy_threshold * sum(S**2)``, capped at the full rank.\n",
    "        Full-rank layers are left untouched; for the others the factors held in\n",
    "        ``self.uv_param_groups`` are saved (CPU copies) in ``self.original_params`` and the\n",
    "        layer gets ``U_s[:, :k]``, ``S[:k]`` and ``Vh[:k] @ V``.\n",
    "\n",
    "        NOTE(review): ``uv_param_groups`` keeps the Parameter objects found at discovery;\n",
    "        ``replace_para`` installs new objects, so pruning always starts from those\n",
    "        discovery-time tensors, not from the layer's current (possibly pruned) ones.\n",
    "\n",
    "        Args:\n",
    "            model: model containing the U/D/V layers.\n",
    "            optimizer: optimizer whose param groups/state are updated (may be ``None``).\n",
    "            energy_threshold: kept-energy fraction; ``None`` falls back to the value given\n",
    "                to ``__init__``.\n",
    "\n",
    "        Returns:\n",
    "            ``(svd_record, avg_kept_energy)``: singular values (numpy) per analysed layer and\n",
    "            the mean kept-energy fraction over them, or ``None`` if no layer was analysed.\n",
    "        \"\"\"\n",
    "        if energy_threshold is None:\n",
    "            energy_threshold = self.pruning_energy_threshold\n",
    "        self._ensure_groups(model)\n",
    "        modules_by_name = dict(model.named_modules())  # built once, not once per layer\n",
    "        kept_energies, pruned_layers, retained_layers = [], 0, 0\n",
    "        svd_record = {}\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for layer_name, udv in self.uv_param_groups:\n",
    "                layer = modules_by_name[layer_name]\n",
    "                try:\n",
    "                    uv_matrix = torch.mul(udv['U'], udv['D']).float()\n",
    "                    U_svd, S_svd, Vh_svd = torch.linalg.svd(uv_matrix, full_matrices = False)\n",
    "                except Exception as e:\n",
    "                    print(f\"[Prune] Skipped {layer_name}: {e}\")\n",
    "                    continue\n",
    "\n",
    "                svd_record[layer_name] = S_svd.cpu().numpy()\n",
    "                energy = torch.cumsum(S_svd ** 2, dim=0)\n",
    "                total_energy = energy[-1]\n",
    "                full_rank = S_svd.numel()\n",
    "                # Clamp: a threshold above 1 would otherwise index past the last singular value.\n",
    "                keep_rank = min(torch.searchsorted(energy, energy_threshold * total_energy).item() + 1, full_rank)\n",
    "                kept_energy = energy[keep_rank - 1] / total_energy\n",
    "                kept_energies.append(kept_energy.item())\n",
    "                pruned_layers += 1\n",
    "\n",
    "                if keep_rank == full_rank:\n",
    "                    retained_layers += 1\n",
    "                    continue\n",
    "\n",
    "                self.original_params[layer_name] = {\"U\": udv['U'].detach().to(\"cpu\", copy=True),\n",
    "                                                    \"D\": udv['D'].detach().to(\"cpu\", copy=True),\n",
    "                                                    \"V\": udv['V'].detach().to(\"cpu\", copy=True)}\n",
    "\n",
    "                self.replace_para(optimizer, layer, 'U', U_svd[:, :keep_rank].to(udv['U'].dtype))\n",
    "                self.replace_para(optimizer, layer, 'D', S_svd[:keep_rank].unsqueeze(0).to(udv['D'].dtype))\n",
    "                self.replace_para(optimizer, layer, 'V', torch.matmul(Vh_svd[:keep_rank, :], udv['V'].float()).to(udv['V'].dtype))\n",
    "\n",
    "        avg_kept_energy = sum(kept_energies) / pruned_layers if pruned_layers > 0 else None\n",
    "        if avg_kept_energy is None:\n",
    "            print(\"[UDV Prune] No U/D/V layers were analysed.\")\n",
    "        else:\n",
    "            print(f\"[UDV Prune] Avg Kept Energy: {avg_kept_energy * 100:.4f}% over \"\n",
    "                  f\"{pruned_layers} layers. {retained_layers} layers retained full rank.\")\n",
    "\n",
    "        return svd_record, avg_kept_energy\n",
    "\n",
    "    def recover_pruning(self, model, optimizer):\n",
    "        \"\"\"Restore the U/D/V factors saved by ``auto_svd_pruning`` and forget the saved copies.\"\"\"\n",
    "        if not self.original_params:\n",
    "            print(\"[Recover] No saved parameters found.\")\n",
    "            return\n",
    "\n",
    "        print(f\"[Recover] Restoring {len(self.original_params)} layers to pre-pruned state...\")\n",
    "        modules_by_name = dict(model.named_modules())\n",
    "        for layer_name, params in self.original_params.items():\n",
    "            layer = modules_by_name[layer_name]\n",
    "            for key in (\"U\", \"D\", \"V\"):\n",
    "                self.replace_para(optimizer, layer, key, params[key].to(getattr(layer, key).device), transfer_state=False)\n",
    "        self.original_params.clear()\n",
    "\n",
    "        print(\"[Recover] All pruned layers restored.\")\n",
    "\n",
    "    def set_tokenizer(self, tokenizer):\n",
    "        \"\"\"Keep a tokenizer reference on the callback.\"\"\"\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def on_train_end(self, args, state, control, **kwargs):\n",
    "        print(\"Pruning is not applied here. Save fine-tuned model first then test pruning.\")\n",
    "\n",
    "    def post_pruning(self, model, optimizer, energy_threshold):\n",
    "        \"\"\"Prune after training; stores ``pruning_record`` and ``last_avg_kept_energy``.\"\"\"\n",
    "        print(f\"[Post Pruning] Starting manual pruning with energy threshold = {energy_threshold}...\")\n",
    "        self.pruning_record, self.last_avg_kept_energy = self.auto_svd_pruning(model, optimizer, energy_threshold)\n",
    "        if self.last_avg_kept_energy is None:\n",
    "            print(\"[Post Pruning] Completed. No layers were analysed.\")\n",
    "        else:\n",
    "            print(f\"[Post Pruning] Completed. Avg kept energy: {self.last_avg_kept_energy:.4f}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6591820e-3471-4a90-aca0-6d7973e984bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "udv_lora_r = 64\n",
    "udv_lora_learning_rate = 1e-3\n",
    "udv_lora_num_train_epochs = 30\n",
    "udv_lora_output_path = f\"./lora_r_{udv_lora_r}_SGD_{udv_lora_learning_rate}\"\n",
    "csv_path = udv_lora_output_path + \"/epoch_metrics.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19447abe-e197-4a15-8b63-091a13ddeaab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-load pre-trained model and replace linear layers by Outer_LoRA\n",
    "# ===========================================================\n",
    "uv_lora_model = AutoModelForCausalLM.from_pretrained(model_dir,\n",
    "                                                      device_map = \"auto\",\n",
    "                                                      torch_dtype = torch.float16)\n",
    "\n",
    "# Freeze every pre-trained weight; only parameters created afterwards by uv_replacement\n",
    "# (the Outer_LoRA U/V factors) will require grad.\n",
    "for param in uv_lora_model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "# Dotted-name suffixes of the nn.Linear layers to wrap; only q_proj and v_proj are active.\n",
    "TARGET_KEYS = [\"self_attn.q_proj\",\n",
    "               # \"self_attn.k_proj\",\n",
    "               \"self_attn.v_proj\",\n",
    "               # \"self_attn.o_proj\",\n",
    "               # \"mlp.up_proj\",\n",
    "               # \"mlp.gate_proj\",\n",
    "               # \"mlp.down_proj\",\n",
    "               # \"lm_head\"\n",
    "              ]\n",
    "\n",
    "# Replace layers and send to device\n",
    "# alpha_scaling = 0 -> Outer_LoRA uses a scaling factor of 1.0.\n",
    "uv_lora_model = uv_replacement(model = uv_lora_model,\n",
    "                                 target_keys = TARGET_KEYS,\n",
    "                                 udv_rank = udv_lora_r,\n",
    "                                 alpha_scaling = 0)\n",
    "# NOTE(review): if device_map=\"auto\" spread the model over several devices, moving it\n",
    "# wholesale to one device may conflict with accelerate's dispatch — fine on one GPU; confirm.\n",
    "uv_lora_model.to(device)\n",
    "\n",
    "# ===========================================================\n",
    "\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)\n",
    "# NOTE(review): with optim = 'sgd' the adam_* values below are presumably unused, and with\n",
    "# fp16 = False (no mixed precision) the float16 weights/adapters are updated in half\n",
    "# precision directly. max_grad_norm = 0 presumably disables clipping in HF Trainer, and\n",
    "# save_strategy = \"no\" means save_total_limit / metric_for_best_model have no effect — verify\n",
    "# against the installed transformers version.\n",
    "udv_training_args = TrainingArguments(output_dir = udv_lora_output_path,\n",
    "                                      per_device_train_batch_size = 8, \n",
    "                                      per_device_eval_batch_size = 8, \n",
    "                                      gradient_accumulation_steps = 4,\n",
    "                                      fp16 = False, \n",
    "                                      logging_nan_inf_filter = True, \n",
    "                                      seed = 42,\n",
    "                                      data_seed = 42,\n",
    "\n",
    "                                      optim = 'sgd',\n",
    "                                      learning_rate = udv_lora_learning_rate, \n",
    "                                      weight_decay = 0, \n",
    "                                      adam_beta1 = 0.9,\n",
    "                                      adam_beta2 = 0.999, \n",
    "                                      adam_epsilon = 1e-8,\n",
    "                                      max_grad_norm = 0, \n",
    "                                      lr_scheduler_type = \"cosine\",\n",
    "    \n",
    "                                      num_train_epochs = udv_lora_num_train_epochs, \n",
    "                                      eval_strategy = \"epoch\", \n",
    "\n",
    "                                      save_strategy = \"no\",\n",
    "                                      save_total_limit = 1, \n",
    "                                      load_best_model_at_end = False,\n",
    "                                      metric_for_best_model = \"eval_loss\",\n",
    "                                      greater_is_better = False,\n",
    "\n",
    "                                      logging_strategy = \"epoch\",\n",
    "                                      report_to = [\"none\"], \n",
    "    \n",
    "                                      label_names = [\"labels\"])\n",
    "\n",
    "# ===========================================================\n",
    "log_callback = LogCallback()\n",
    "# uv_norm_limit / d_lowerbound / d_boundto are not read by Constraints_CallBack;\n",
    "# pruning_energy_threshold is only its fallback pruning threshold.\n",
    "constraint_callback = Constraints_CallBack(uv_norm_limit = 1,\n",
    "                                           d_lowerbound = 0,\n",
    "                                           d_boundto = 0,\n",
    "                                           pruning_energy_threshold = 1)\n",
    "constraint_callback.set_tokenizer(tokenizer)\n",
    "\n",
    "\n",
    "udv_trainer = Trainer(model = uv_lora_model,\n",
    "                      args = udv_training_args,\n",
    "                      train_dataset = tokenized_train,\n",
    "                      eval_dataset = tokenized_val,\n",
    "                      data_collator = data_collator,\n",
    "                      callbacks = [log_callback, constraint_callback]\n",
    "                     )\n",
    "\n",
    "# Pre-training evaluation; log_callback.on_evaluate records it as the first (baseline) row.\n",
    "uv_lora_output = udv_trainer.evaluate() # Inference once to create udv_lora baseline\n",
    "\n",
    "# Example of generating text (Before fine-tuning)\n",
    "generate_text(uv_lora_model, tokenizer, \"What is the capital of France?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0cbd2d-f667-4294-b826-ac4ff1d61e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation: Fine-tuning with UDV-LoRA\n",
    "# ===========================================================\n",
    "# Train; with eval_strategy=\"epoch\" the trainer also evaluates after each epoch and\n",
    "# log_callback collects one row per evaluation, which is then appended to the CSV.\n",
    "uv_lora_output = udv_trainer.train()\n",
    "log_callback.save_to_csv(filepath=csv_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3516b6b3-4bdd-4183-8585-66c3db007a48",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Post-training pruning sweep over kept-energy thresholds 1.00, 0.95, ..., 0.05.\n",
    "# Rounded to 2 decimals: rounding 0.05 steps to 1 decimal merges neighbours into duplicates.\n",
    "energy_list = [round(x, 2) for x in torch.arange(1.0, 0.0, -0.05).tolist()]\n",
    "\n",
    "for energy in energy_list:\n",
    "    # Start every threshold from the factors saved before the previous pruning.\n",
    "    constraint_callback.recover_pruning(model=udv_trainer.model,\n",
    "                                        optimizer=udv_trainer.optimizer)\n",
    "\n",
    "    constraint_callback.post_pruning(model=udv_trainer.model,\n",
    "                                     optimizer=udv_trainer.optimizer, \n",
    "                                     energy_threshold=energy)\n",
    "\n",
    "    # Fresh logger per threshold, so each save appends exactly one row.\n",
    "    after_pruning_log = LogCallback()\n",
    "    \n",
    "    uv_lora_output = udv_trainer.evaluate()\n",
    "    after_pruning_log.on_evaluate(\n",
    "        args=udv_trainer.args,\n",
    "        state=udv_trainer.state,\n",
    "        control=udv_trainer.control,\n",
    "        metrics=uv_lora_output,\n",
    "        model=udv_trainer.model\n",
    "    )\n",
    "    \n",
    "    after_pruning_log.record_pruning_energy(constraint_callback)\n",
    "    after_pruning_log.save_to_csv(csv_path)\n",
    "\n",
    "# Leave the model in its fine-tuned (un-pruned) state once the sweep is done.\n",
    "constraint_callback.recover_pruning(model=udv_trainer.model,\n",
    "                                    optimizer=udv_trainer.optimizer)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a472110-ff06-4458-b7f9-5aebc530bc40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the singular-value history recorded during training\n",
    "# (layer name -> list of {\"step\": int, \"svd\": np.ndarray}). Only unpickle from trusted sources.\n",
    "os.makedirs(udv_lora_output_path, exist_ok=True)  # the Trainer may not have created it (save_strategy=\"no\")\n",
    "with open(udv_lora_output_path + \"/svd_dict.pkl\", \"wb\") as f:\n",
    "    pickle.dump(constraint_callback.svd_record, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e8b3c4-635e-4c79-b8a4-d8ab014e2de3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Completion marker: reaching this cell means every cell above ran.\n",
    "print(\"all done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
