{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7871ea96-fe2b-414a-bd41-4639af5fb515",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['TORCH_USE_CUDA_DSA'] = '1'\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "from datasets import load_dataset\n",
    "\n",
    "import numpy as np\n",
    "import math\n",
    "from torch.nn import CrossEntropyLoss\n",
    "import csv\n",
    "import gc\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa0abc9e-5a32-4da9-9a51-60e4ae038765",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For recording logs\n",
    "# ===========================================================\n",
    "class LogCallback(TrainerCallback):\n",
    "    def __init__(self):\n",
    "        self.epoch_logs = []\n",
    "        self.final_train_metrics = {}\n",
    "\n",
    "    def count_trainable_params(self, model):\n",
    "        return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "    def on_evaluate(self, args, state, control, metrics, **kwargs):\n",
    "        epoch = state.epoch if state.epoch is not None else 0\n",
    "        eval_loss = metrics.get(\"eval_loss\")\n",
    "        perplexity = round(np.exp(eval_loss), 4) if eval_loss is not None else None\n",
    "\n",
    "        train_loss = None\n",
    "        for log in reversed(state.log_history):\n",
    "            if \"loss\" in log:\n",
    "                train_loss = log[\"loss\"]\n",
    "                break\n",
    "                \n",
    "        model = kwargs.get(\"model\", None)\n",
    "        trainable_params = self.count_trainable_params(model) if model is not None else None\n",
    "\n",
    "        self.epoch_logs.append({\n",
    "            \"epoch\": epoch,\n",
    "            \"total_steps\": state.global_step,\n",
    "            \"train_loss\": train_loss,\n",
    "            \"eval_loss\": eval_loss,\n",
    "            \"perplexity\": perplexity,\n",
    "            \"trainable_params\": trainable_params,\n",
    "            \"avg_kept_energy\": None, \n",
    "            \"samples_per_second\": None,\n",
    "            \"steps_per_second\": None,\n",
    "            \"total_flos\": None,\n",
    "            \"train_runtime\": None,\n",
    "        })\n",
    "\n",
    "    def on_train_end(self, args, state, control, **kwargs):\n",
    "        for log in reversed(state.log_history):\n",
    "            if all(k in log for k in [\"train_samples_per_second\", \"train_steps_per_second\", \"total_flos\", \"train_runtime\"]):\n",
    "                self.final_train_metrics = {\n",
    "                    \"samples_per_second\": log[\"train_samples_per_second\"],\n",
    "                    \"steps_per_second\": log[\"train_steps_per_second\"],\n",
    "                    \"total_flos\": log[\"total_flos\"],\n",
    "                    \"train_runtime\": log[\"train_runtime\"]\n",
    "                }\n",
    "                break\n",
    "\n",
    "        if self.epoch_logs and self.final_train_metrics:\n",
    "            self.epoch_logs[-1].update(self.final_train_metrics)\n",
    "\n",
    "\n",
    "    def record_pruning_energy(self, constraints_callback):\n",
    "        avg_energy = getattr(constraints_callback, \"last_avg_kept_energy\", None)\n",
    "        if avg_energy is None:\n",
    "            print(\"[LogCallback] No avg_kept_energy found in constraints_callback.\")\n",
    "            return\n",
    "    \n",
    "        if self.epoch_logs:\n",
    "            self.epoch_logs[-1][\"avg_kept_energy\"] = avg_energy\n",
    "            print(f\"[LogCallback] Recorded avg_kept_energy = {avg_energy:.4f} to last epoch log.\")\n",
    "\n",
    "\n",
    "                \n",
    "    def save_to_csv(self, filepath=\"epoch_metrics.csv\"):\n",
    "        if not self.epoch_logs:\n",
    "            print(\"No logs to save.\")\n",
    "            return\n",
    "    \n",
    "        file_exists = os.path.isfile(filepath)\n",
    "        keys = self.epoch_logs[0].keys()\n",
    "    \n",
    "        with open(filepath, \"a\", newline=\"\") as f:\n",
    "            writer = csv.DictWriter(f, fieldnames=keys)\n",
    "    \n",
    "            if not file_exists:\n",
    "                writer.writeheader()  \n",
    "    \n",
    "            writer.writerows(self.epoch_logs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c597169-d7fc-4b84-8094-a423711ad157",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For generate text\n",
    "# ===========================================================\n",
    "def generate_text(model, tokenizer, prompt, max_new_tokens=20):\n",
    "    model.eval()\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    with torch.no_grad():\n",
    "        outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)\n",
    "    print(tokenizer.decode(outputs[0], skip_special_tokens=True))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04a6d768-c9a0-45e2-9517-a9471104730c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset and tokenizer\n",
    "# ===========================================================\n",
    "dataset = load_dataset(\"wikitext\", \"wikitext-2-raw-v1\")\n",
    "train_data = dataset[\"train\"]\n",
    "val_data = dataset[\"validation\"]\n",
    "\n",
    "model_dir = './orig_models/llama2-7b-hf' \n",
    "tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast = False)\n",
    "tokenizer.pad_token = tokenizer.eos_token  # Set tokenizer padding\n",
    "\n",
    "# Map token\n",
    "def tokenize_fn(examples):\n",
    "    tokenized = tokenizer(examples[\"text\"], truncation = True, max_length = 512)\n",
    "    return tokenized\n",
    "\n",
    "tokenized_train = train_data.map(tokenize_fn, batched = True, remove_columns = [\"text\"])\n",
    "tokenized_val = val_data.map(tokenize_fn, batched = True, remove_columns = [\"text\"])\n",
    "print(\"Train / Validation mapping: Done\\n\")\n",
    "\n",
    "debug_usage = False # Only use a very small part of data for debugging \n",
    "if debug_usage is True:\n",
    "    print(\"Debugging mode: Only small part of data is loaded.\")\n",
    "    tokenized_train = tokenized_train.select(range(320))  # First x samples\n",
    "    tokenized_val = tokenized_val.select(range(64))      # First x samples\n",
    "\n",
    "# Drop input_ids with length ≤ 1\n",
    "def filter_short(example):\n",
    "    return len(example[\"input_ids\"]) > 1\n",
    "print(f\"Before dropping: Train ({len(tokenized_train)}) / Val ({len(tokenized_val)})\")\n",
    "tokenized_train = tokenized_train.filter(filter_short)\n",
    "tokenized_val = tokenized_val.filter(filter_short)\n",
    "print(f\"After dropping: Train ({len(tokenized_train)}) / Val ({len(tokenized_val)})\")\n",
    "print(tokenized_train[0]) #Example of the train token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abf8950f-91c6-415b-a664-2a10ff4b8930",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Customise the UDV layer for Q, K, V and MLP \n",
    "# ===========================================================\n",
    "class UDV_LoRA(nn.Module):\n",
    "    def __init__(self, in_dim, out_dim, udv_rank, alpha_scaling):\n",
    "        super().__init__()     \n",
    "        self.in_dim = in_dim\n",
    "        self.out_dim = out_dim\n",
    "        self.udv_rank = udv_rank\n",
    "        self.alpha_scaling = alpha_scaling\n",
    "\n",
    "        # self.W0 = nn.Parameter(torch.empty(self.in_dim, self.out_dim, dtype=torch.float16), requires_grad = False)\n",
    "        self.register_buffer(\"W0\", torch.empty(self.in_dim, self.out_dim, dtype=torch.float16), persistent = True) # Not trainable but persistent available\n",
    "        \n",
    "        self.U = nn.Parameter(torch.empty(self.in_dim, self.udv_rank, dtype=torch.float16), requires_grad = True)\n",
    "        self.D = nn.Parameter(torch.empty(1, self.udv_rank, dtype=torch.float16), requires_grad = True)\n",
    "        self.Vt = nn.Parameter(torch.empty(self.udv_rank, self.out_dim, dtype=torch.float16), requires_grad = True)\n",
    "        with torch.no_grad():\n",
    "            self.init_parameters()\n",
    "\n",
    "        self.scaling = self.alpha_scaling / self.udv_rank if self.alpha_scaling != 0 else 1.0\n",
    "\n",
    "    def init_parameters(self):\n",
    "        nn.init.normal_(self.U, std = 1e-2)\n",
    "        nn.init.constant_(self.D, 1.0)\n",
    "        nn.init.zeros_(self.Vt)\n",
    "\n",
    "    def set_weight(self, w):\n",
    "        if w.shape != self.W0.shape:\n",
    "            raise ValueError(f\"Shape mismatch: Expected {self.W0.shape}, got {w.shape}\")\n",
    "        self.W0.copy_(w.to(self.W0.dtype))\n",
    "\n",
    "     \n",
    "    def forward(self, x):\n",
    "        orig_out = torch.matmul(x, self.W0)\n",
    "        x = torch.matmul(x, self.U)               # [B, *, udv_rank]\n",
    "        x = torch.mul(x, self.D.view(1, 1, -1))   # [B, *, udv_rank]       \n",
    "        x = torch.matmul(x, self.Vt)              # [B, *, out_dim]\n",
    "\n",
    "        return orig_out + x * self.scaling\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2687ee59-f8c9-4791-83e5-e9183a83aee5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Batch) Replace linear layers by UDV_LoRA\n",
    "# ===========================================================\n",
    "def udv_replacement(model, target_keys, udv_rank, alpha_scaling):\n",
    "    replace_counter = 0\n",
    "    for name, module in model.named_modules():\n",
    "        for key in target_keys:\n",
    "            if name.endswith(key):\n",
    "                path = name.split(\".\")\n",
    "                parent = model\n",
    "                for p in path[:-1]:\n",
    "                    if hasattr(parent, p):\n",
    "                        parent = getattr(parent, p)\n",
    "                    elif p.isdigit():\n",
    "                        parent = parent[int(p)]\n",
    "                    else:\n",
    "                        raise RuntimeError(f\"Cannot resolve submodule path: {'.'.join(path)}\")\n",
    "\n",
    "                old_linear = getattr(parent, path[-1])\n",
    "                if not isinstance(old_linear, nn.Linear):\n",
    "                    continue\n",
    "\n",
    "                new_layer = UDV_LoRA(old_linear.in_features, old_linear.out_features, udv_rank, alpha_scaling)\n",
    "                with torch.no_grad():\n",
    "                    new_layer.set_weight(old_linear.weight.detach().clone().T)\n",
    "                setattr(parent, path[-1], new_layer)\n",
    "                \n",
    "                assert torch.allclose(old_linear.weight.detach().T.to(device), new_layer.W0.detach().half().to(device), atol=1e-5)      \n",
    "                replace_counter = replace_counter + 1\n",
    "                \n",
    "    print(f\"Replaced {replace_counter} layers in total.\")\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ebc51b2-3ddd-42f7-83d1-3c1c160a9065",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Constraints\n",
    "# ===========================================================\n",
    "class UDV_UVbound(nn.Module):\n",
    "    def __init__(self, uv_norm_limit):\n",
    "        super().__init__()\n",
    "        self.uv_norm_limit = uv_norm_limit\n",
    "\n",
    "    def forward(self, udv_uv):\n",
    "        norm_sq = torch.linalg.matrix_norm(udv_uv, ord='fro') ** 2\n",
    "        if norm_sq > self.uv_norm_limit:\n",
    "            udv_uv.div_(torch.sqrt(norm_sq)) \n",
    "        return udv_uv\n",
    "\n",
    "class UDV_Dbound(nn.Module):\n",
    "    def __init__(self, d_lowerbound, d_boundto):\n",
    "        super().__init__()\n",
    "        self.d_lowerbound = d_lowerbound\n",
    "        self.d_boundto = d_boundto\n",
    "\n",
    "    def forward(self, udv_d):\n",
    "        udv_d[udv_d < self.d_lowerbound] = self.d_boundto\n",
    "        return udv_d\n",
    "\n",
    "def check_udv_constraints(udv_param_groups, uv_norm_limit, d_lowerbound):\n",
    "    failed_layers = []\n",
    "    scale_factor = 1.001 # To aviod numerical issue by scaling the norm \n",
    "\n",
    "    for layer_name, udv in udv_param_groups:\n",
    "        U = udv['U'].detach().float()\n",
    "        D = udv['D'].detach().float()\n",
    "        Vt = udv['Vt'].detach().float()\n",
    "\n",
    "        u_norm = torch.linalg.matrix_norm(U, ord='fro').item()\n",
    "        v_norm = torch.linalg.matrix_norm(Vt, ord='fro').item()\n",
    "        d_min = D.min().item()\n",
    "\n",
    "        u_ok = u_norm <= (uv_norm_limit * scale_factor)\n",
    "        v_ok = v_norm <= (uv_norm_limit * scale_factor)\n",
    "        d_ok = d_min >= d_lowerbound\n",
    "\n",
    "        if not (u_ok and v_ok and d_ok):\n",
    "            failed_layers.append({\n",
    "                \"layer\": layer_name,\n",
    "                \"u_norm\": u_norm,\n",
    "                \"v_norm\": v_norm,\n",
    "                \"d_min\": d_min,\n",
    "                \"u_ok\": u_ok,\n",
    "                \"v_ok\": v_ok,\n",
    "                \"d_ok\": d_ok,\n",
    "            })\n",
    "\n",
    "    if failed_layers:\n",
    "        print(f\"[Constraint Check] Constraint violation detected (scale factor [*{scale_factor}]):\")\n",
    "        for item in failed_layers:\n",
    "            print(f\"  Layer: {item['layer']}\")\n",
    "            print(f\"    U Fro norm = {item['u_norm']:.4f} | OK: {item['u_ok']}\")\n",
    "            print(f\"    V Fro norm = {item['v_norm']:.4f} | OK: {item['v_ok']}\")\n",
    "            print(f\"    D min val  = {item['d_min']:.4f} | OK: {item['d_ok']}\")\n",
    "    else:\n",
    "        print(f\"[Constraint Check] All UDV layers satisfy the constraints(scale factor [*{scale_factor}]).\")\n",
    "\n",
    "\n",
    "# ===========================================================        \n",
    "class Constraints_CallBack(TrainerCallback):\n",
    "    def __init__(self, uv_norm_limit, d_lowerbound, d_boundto, pruning_energy_threshold):\n",
    "        self.uv_bound = UDV_UVbound(uv_norm_limit)\n",
    "        self.d_bound = UDV_Dbound(d_lowerbound, d_boundto)\n",
    "        self.udv_param_groups = []  # List of (module_name, {'U': ..., 'D': ..., 'Vt': ...})\n",
    "        self.original_params = {}   # {layer_name: {'U': tensor, 'D': tensor, 'Vt': tensor}}\n",
    "        self.svd_record = {}        # {layer_name: [list of singular values per step]}\n",
    "        self.pruning_energy_threshold = pruning_energy_threshold\n",
    "        self.initialized = False\n",
    "\n",
    "    def _find_udv_groups(self,model):\n",
    "        udv_groups = []\n",
    "        for module_name, module in model.named_modules():\n",
    "            if all(hasattr(module, name) for name in ['U', 'D', 'Vt']):\n",
    "                U, D, Vt = module.U, module.D, module.Vt\n",
    "                if all(isinstance(p, torch.nn.Parameter) for p in [U, D, Vt]):\n",
    "                    udv_groups.append((module_name, {'U': U, 'D': D, 'Vt': Vt}))\n",
    "        return udv_groups\n",
    "    \n",
    "    def on_step_end(self, args, state, control, **kwargs):\n",
    "        model = kwargs[\"model\"]\n",
    "\n",
    "        if not self.initialized:\n",
    "            self.udv_param_groups = self._find_udv_groups(model)\n",
    "            self.initialized = True\n",
    "\n",
    "        if state.global_step % args.gradient_accumulation_steps == 0:\n",
    "            self.apply_constraints()\n",
    "            self.record_svd(step = state.global_step)\n",
    "\n",
    "        return control\n",
    "\n",
    "    def apply_constraints(self):\n",
    "        with torch.no_grad():\n",
    "            for name, param_dict in self.udv_param_groups:\n",
    "                U = param_dict['U']\n",
    "                D = param_dict['D']\n",
    "                Vt = param_dict['Vt']\n",
    "                \n",
    "                U.data.copy_(self.uv_bound(U.data))\n",
    "                Vt.data.copy_(self.uv_bound(Vt.data))\n",
    "                D.data.copy_(self.d_bound(D.data))\n",
    "\n",
    "    def record_svd(self, step):\n",
    "        with torch.no_grad():\n",
    "            for layer_name, udv in self.udv_param_groups:\n",
    "                U = udv['U']\n",
    "                D = udv['D']\n",
    "                try:\n",
    "                    uv_matrix = torch.mul(udv['U'], udv['D']).float() \n",
    "                    U_svd, S_svd, Vh_svd = torch.linalg.svd(uv_matrix, full_matrices = False)\n",
    "\n",
    "                    if layer_name not in self.svd_record:\n",
    "                        self.svd_record[layer_name] = []\n",
    "\n",
    "                    self.svd_record[layer_name].append({\"step\": step,\n",
    "                                                        \"svd\": S_svd.cpu().numpy()\n",
    "                                                       })\n",
    "                except Exception as e:\n",
    "                    print(f\"[Warning] SVD failed on layer {layer_name}: {e}\")\n",
    "\n",
    "    def plot_svd(self, layers=None, steps=None, normal_layers=None, show=True, save_path=None, model=None):\n",
    "    \n",
    "        plt.figure(figsize=(6, 4))\n",
    "        layers_to_plot = layers or list(self.svd_record.keys())\n",
    "    \n",
    "        # ---- Plot UDV layers\n",
    "        for layer in layers_to_plot:\n",
    "            records = self.svd_record.get(layer, [])\n",
    "            if not records:\n",
    "                print(f\"[plot_svd] No SVD records found for UDV layer: {layer}\")\n",
    "                continue\n",
    "    \n",
    "            for rec in records:\n",
    "                step = rec[\"step\"]\n",
    "                if steps is not None and step not in steps:\n",
    "                    continue\n",
    "                label = f\"{layer} (step {step})\"\n",
    "                plt.plot(rec[\"svd\"], label=label)\n",
    "    \n",
    "        # ---- Plot normal layers (final weights only)\n",
    "        if normal_layers and model is not None:\n",
    "            for layer_name in normal_layers:\n",
    "                try:\n",
    "                    module = model\n",
    "                    for attr in layer_name.split('.'):\n",
    "                        if attr.isdigit():\n",
    "                            module = module[int(attr)]\n",
    "                        else:\n",
    "                            module = getattr(module, attr)\n",
    "                    weight = module.weight.float()\n",
    "                    _, S_svd, _ = torch.linalg.svd(weight, full_matrices=False)\n",
    "                    label = f\"{layer_name} (final)\"\n",
    "                    plt.plot(S_svd.cpu().numpy(), label=label, linestyle='--')\n",
    "                except Exception as e:\n",
    "                    print(f\"[plot_svd] Failed to get SVD of {layer_name}: {e}\")\n",
    "    \n",
    "        plt.title(\"SVD Spectra - UDV and Normal Layers\")\n",
    "        plt.xlabel(\"Singular Value Index\")\n",
    "        plt.ylabel(\"Singular Value (log scale)\")\n",
    "        plt.yscale(\"log\")\n",
    "        plt.legend(fontsize=\"small\", loc=\"best\")\n",
    "        plt.grid(True)\n",
    "    \n",
    "        if save_path:\n",
    "            os.makedirs(save_path, exist_ok=True)\n",
    "            plt.savefig(os.path.join(save_path, \"svd_spectra.png\"))\n",
    "    \n",
    "        if show:\n",
    "            plt.show()\n",
    "    \n",
    "        plt.close()\n",
    "\n",
    "\n",
    "    def replace_para(self, optimizer, layer, name, new_param, transfer_state = False):\n",
    "        old_param = getattr(layer, name)\n",
    "        new_param = nn.Parameter(new_param)\n",
    "        setattr(layer, name, new_param)\n",
    "\n",
    "        for group in optimizer.param_groups:\n",
    "            for i, p in enumerate(group['params']):\n",
    "                if p is old_param:\n",
    "                    group['params'][i] = new_param\n",
    "        if old_param in optimizer.state:\n",
    "            if transfer_state:\n",
    "                optimizer.state[new_param] = optimizer.state.pop(old_param)\n",
    "            else:\n",
    "                del optimizer.state[old_param]\n",
    "\n",
    "    def auto_svd_pruning(self, model, optimizer, energy_threshold):\n",
    "        energy_threshold = energy_threshold or self.pruning_energy_threshold\n",
    "        kept_energies, pruned_layers, retained_layers = [], 0, 0\n",
    "        svd_record = {}\n",
    "\n",
    "        with torch.no_grad():\n",
    "            for layer_name, udv in self.udv_param_groups:\n",
    "                layer = dict(model.named_modules())[layer_name]\n",
    "                try:\n",
    "                    uv_matrix = torch.mul(udv['U'], udv['D']).float()\n",
    "                    U_svd, S_svd, Vh_svd = torch.linalg.svd(uv_matrix, full_matrices = False)\n",
    "                except Exception as e:\n",
    "                    print(f\"[Prune] Skipped {layer_name}: {e}\")\n",
    "                    continue\n",
    "\n",
    "                svd_record[layer_name] = S_svd.cpu().numpy()\n",
    "                energy = torch.cumsum(S_svd ** 2, dim=0)\n",
    "                total_energy = energy[-1]\n",
    "                keep_rank = torch.searchsorted(energy, energy_threshold * total_energy).item() + 1\n",
    "                kept_energy = energy[keep_rank - 1] / total_energy\n",
    "                kept_energies.append(kept_energy.item())\n",
    "                pruned_layers += 1\n",
    "\n",
    "                if keep_rank == S_svd.numel():\n",
    "                    retained_layers += 1\n",
    "                    continue\n",
    "\n",
    "                self.original_params[layer_name] = {\"U\": udv['U'].detach().cpu(),\n",
    "                                                    \"D\": udv['D'].detach().cpu(),\n",
    "                                                    \"Vt\": udv['Vt'].detach().cpu()}\n",
    "                \n",
    "                self.replace_para(optimizer, layer, 'U', U_svd[:, :keep_rank].half())\n",
    "                self.replace_para(optimizer, layer, 'D', S_svd[:keep_rank].unsqueeze(0).half())\n",
    "                self.replace_para(optimizer, layer, 'Vt', torch.matmul(Vh_svd[:keep_rank, :], udv['Vt'].float()).half())\n",
    "        avg_kept_energy = sum(kept_energies) / pruned_layers if pruned_layers > 0 else None\n",
    "        print(f\"[UDV Prune] Avg Kept Energy: {avg_kept_energy * 100:.4f}% over \"\n",
    "              f\"{pruned_layers} layers. {retained_layers} layers retained full rank.\")\n",
    "\n",
    "        return svd_record, avg_kept_energy\n",
    "\n",
    "    def recover_pruning(self, model, optimizer):\n",
    "        if not hasattr(self, \"original_params\"):\n",
    "            print(\"[Recover] No saved parameters found.\")\n",
    "            return\n",
    "    \n",
    "        print(f\"[Recover] Restoring {len(self.original_params)} layers to pre-pruned state...\")\n",
    "        for layer_name, params in self.original_params.items():\n",
    "            layer = dict(model.named_modules())[layer_name]\n",
    "            self.replace_para(optimizer, layer, \"U\", params[\"U\"].to(layer.U.device), transfer_state=False)\n",
    "            self.replace_para(optimizer, layer, \"D\", params[\"D\"].to(layer.D.device), transfer_state=False)\n",
    "            self.replace_para(optimizer, layer, \"Vt\", params[\"Vt\"].to(layer.Vt.device), transfer_state=False)\n",
    "    \n",
    "        print(\"[Recover] All pruned layers restored.\")\n",
    "\n",
    "\n",
    "    def set_tokenizer(self, tokenizer):\n",
    "        self.tokenizer = tokenizer\n",
    "\n",
    "    def on_train_end(self, args, state, control, **kwargs):\n",
    "        print(\"Pruning is not applied here. Save fine-tuned model first then test pruning.\")\n",
    "        \n",
    "    \n",
    "    def post_pruning(self, model, optimizer, energy_threshold):\n",
    "        print(f\"[Post Pruning] Starting manual pruning with energy threshold = {energy_threshold}...\")\n",
    "        # self.energy_threshold = energy_threshold\n",
    "        self.pruning_record, self.last_avg_kept_energy = self.auto_svd_pruning(model, optimizer, energy_threshold)\n",
    "        print(f\"[Post Pruning] Completed. Avg kept energy: {self.last_avg_kept_energy:.4f}\")\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6591820e-3471-4a90-aca0-6d7973e984bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "udv_lora_r = 64\n",
    "udv_lora_learning_rate = 1e-3\n",
    "udv_lora_num_train_epochs = 30\n",
    "udv_lora_output_path = f\"./udv_lora_r_{udv_lora_r}_SGD_{udv_lora_learning_rate}\"\n",
    "csv_path = udv_lora_output_path + \"/epoch_metrics.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19447abe-e197-4a15-8b63-091a13ddeaab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-load pre-trained model and replace linear layers by UDV_LoRA\n",
    "# ===========================================================\n",
    "udv_lora_model = AutoModelForCausalLM.from_pretrained(model_dir,\n",
    "                                                      device_map = \"auto\",\n",
    "                                                      torch_dtype = torch.float16)\n",
    "\n",
    "for param in udv_lora_model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "TARGET_KEYS = [\"self_attn.q_proj\",\n",
    "               # \"self_attn.k_proj\",\n",
    "               \"self_attn.v_proj\",\n",
    "               # \"self_attn.o_proj\",\n",
    "               # \"mlp.up_proj\",\n",
    "               # \"mlp.gate_proj\",\n",
    "               # \"mlp.down_proj\",\n",
    "               # \"lm_head\"\n",
    "              ]\n",
    "\n",
    "# Replace layers and send to device\n",
    "udv_lora_model = udv_replacement(model = udv_lora_model,\n",
    "                                 target_keys = TARGET_KEYS,\n",
    "                                 udv_rank = udv_lora_r,\n",
    "                                 alpha_scaling = 0)\n",
    "udv_lora_model.to(device)\n",
    "\n",
    "# ===========================================================\n",
    "\n",
    "\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False)\n",
    "udv_training_args = TrainingArguments(output_dir = udv_lora_output_path,\n",
    "                                      per_device_train_batch_size = 8, \n",
    "                                      per_device_eval_batch_size = 8, \n",
    "                                      gradient_accumulation_steps = 4,\n",
    "                                      fp16 = False, \n",
    "                                      logging_nan_inf_filter = True, \n",
    "                                      seed = 42,\n",
    "                                      data_seed = 42,\n",
    "\n",
    "                                      optim = 'sgd',\n",
    "                                      learning_rate = udv_lora_learning_rate, \n",
    "                                      weight_decay = 0, \n",
    "                                      adam_beta1 = 0.9,\n",
    "                                      adam_beta2 = 0.999, \n",
    "                                      adam_epsilon = 1e-8,\n",
    "                                      max_grad_norm = 0, \n",
    "                                      lr_scheduler_type = \"cosine\",\n",
    "    \n",
    "                                      num_train_epochs = udv_lora_num_train_epochs, \n",
    "                                      eval_strategy = \"epoch\", \n",
    "\n",
    "                                      save_strategy = \"no\",\n",
    "                                      save_total_limit = 1, \n",
    "                                      load_best_model_at_end = False,\n",
    "                                      metric_for_best_model = \"eval_loss\",\n",
    "                                      greater_is_better = False,\n",
    "\n",
    "                                      logging_strategy = \"epoch\",\n",
    "                                      report_to = [\"none\"], \n",
    "    \n",
    "                                      label_names = [\"labels\"])\n",
    "\n",
    "# ===========================================================\n",
    "log_callback = LogCallback()\n",
    "constraint_callback = Constraints_CallBack(uv_norm_limit = 1,\n",
    "                                           d_lowerbound = 0,\n",
    "                                           d_boundto = 0,\n",
    "                                           pruning_energy_threshold = 1)\n",
    "constraint_callback.set_tokenizer(tokenizer)\n",
    "\n",
    "\n",
    "udv_trainer = Trainer(model = udv_lora_model,\n",
    "                      args = udv_training_args,\n",
    "                      train_dataset = tokenized_train,\n",
    "                      eval_dataset = tokenized_val,\n",
    "                      data_collator = data_collator,\n",
    "                      callbacks = [log_callback, constraint_callback]\n",
    "                     )\n",
    "\n",
    "udv_lora_output = udv_trainer.evaluate() # Inference once to create udv_lora baseline\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a0cbd2d-f667-4294-b826-ac4ff1d61e44",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation: Fine-tuning with UDV-LoRA\n",
    "# ===========================================================\n",
    "udv_lora_output = udv_trainer.train() # Fine-tuning with inference\n",
    "check_udv_constraints(constraint_callback.udv_param_groups,\n",
    "                      uv_norm_limit = constraint_callback.uv_bound.uv_norm_limit,\n",
    "                      d_lowerbound = constraint_callback.d_bound.d_lowerbound)\n",
    "\n",
    "log_callback.save_to_csv(csv_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3516b6b3-4bdd-4183-8585-66c3db007a48",
   "metadata": {},
   "outputs": [],
   "source": [
    "energy_list = [round(x, 1) for x in torch.arange(1.0, 0.0, -0.05).tolist()]\n",
    "\n",
    "for energy in energy_list:\n",
    "    constraint_callback.recover_pruning(model=udv_trainer.model,\n",
    "                                        optimizer=udv_trainer.optimizer)\n",
    "\n",
    "    constraint_callback.post_pruning(model=udv_trainer.model,\n",
    "                                     optimizer=udv_trainer.optimizer, \n",
    "                                     energy_threshold=energy)\n",
    "\n",
    "    after_pruning_log = LogCallback()\n",
    "    \n",
    "    udv_lora_output = udv_trainer.evaluate()\n",
    "    after_pruning_log.on_evaluate(\n",
    "        args=udv_trainer.args,\n",
    "        state=udv_trainer.state,\n",
    "        control=udv_trainer.control,\n",
    "        metrics=udv_lora_output,\n",
    "        model=udv_trainer.model\n",
    "    )\n",
    "    \n",
    "    after_pruning_log.record_pruning_energy(constraint_callback)\n",
    "    after_pruning_log.save_to_csv(csv_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a472110-ff06-4458-b7f9-5aebc530bc40",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(udv_lora_output_path + \"/svd_dict.pkl\", \"wb\") as f:\n",
    "    pickle.dump(constraint_callback.svd_record, f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7e8b3c4-635e-4c79-b8a4-d8ab014e2de3",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"all done\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
