{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26edca15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import time\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from collections import defaultdict\n",
    "from peft import get_peft_model, LoraConfig, AdaLoraConfig, TaskType\n",
    "\n",
    "from glue_eval.glue_eval import GLUEEval\n",
    "from util import nethook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5083abbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval(model, tok, method):\n",
    "    delta_dir = f'./results/PTE/batch_tmp/_{method}/256bs/'\n",
    "    for file in os.listdir(delta_dir):\n",
    "        if not file.endswith('_delta.pth'):\n",
    "            continue\n",
    "\n",
    "        delta_weight = torch.load(os.path.join(delta_dir, file))\n",
    "        for k, v in delta_weight.items():\n",
    "            with torch.no_grad():\n",
    "                w = nethook.get_parameter(model, k)\n",
    "                w[...] = v.cuda().float()\n",
    "\n",
    "        out_file = f'./results/glue/{method}/{file}/{time.time_ns()}.json'\n",
    "        os.makedirs(f'./results/glue/{method}/{file}/', exist_ok=True)\n",
    "\n",
    "        glue_results = {'edit_num': -1}\n",
    "        glue_eval = GLUEEval(model, tok, number_of_tests = 100)\n",
    "        glue_results = glue_eval.evaluate(glue_results, out_file, nli_flag=True, sst_flag=True, cola_flag=True, rte_flag=True, mmlu_flag=True, mrpc_flag=True)\n",
    "\n",
    "        output_filename = out_file.replace('.json', '_glue.json')\n",
    "        with open(output_filename, \"w\") as f:\n",
    "            json.dump(glue_results, f, indent=4)\n",
    "\n",
    "def eval_lora(model, tok, method):\n",
    "    lora_config_dict = {\n",
    "        'lora': LoraConfig,\n",
    "        'adalora': AdaLoraConfig\n",
    "    }\n",
    "    lora_config = lora_config_dict[method](\n",
    "        task_type=TaskType.CAUSAL_LM,\n",
    "        inference_mode=False,\n",
    "        r=8,\n",
    "        lora_alpha=8,\n",
    "        lora_dropout=0.1,\n",
    "        layers_to_transform=None,\n",
    "        target_modules=[\"up_proj\", \"down_proj\"]\n",
    "    )\n",
    "    lora_model = get_peft_model(model, lora_config)\n",
    "    lora_model.generate = lora_model.base_model.generate\n",
    "\n",
    "    delta_dir = f'./results/PTE/batch_tmp/_{method}/256bs/'\n",
    "    for file in os.listdir(delta_dir):\n",
    "        if not file.endswith('_delta.pth'):\n",
    "            continue\n",
    "\n",
    "        delta_weight = torch.load(os.path.join(delta_dir, file))\n",
    "        for name, params in lora_model.named_parameters():\n",
    "            for delta_name in delta_weight:\n",
    "                if delta_name.replace('.weight', '') in name:\n",
    "                    params.data = delta_weight[delta_name].to(params.device)\n",
    "\n",
    "        out_file = f'./results/glue/{method}/{file}/{time.time_ns()}.json'\n",
    "        os.makedirs(f'./results/glue/{method}/{file}/', exist_ok=True)\n",
    "\n",
    "        glue_results = {'edit_num': -1}\n",
    "        glue_eval = GLUEEval(model, tok, number_of_tests = 100)\n",
    "        glue_results = glue_eval.evaluate(glue_results, out_file, nli_flag=True, sst_flag=True, cola_flag=True, rte_flag=True, mmlu_flag=True, mrpc_flag=True)\n",
    "\n",
    "        output_filename = out_file.replace('.json', '_glue.json')\n",
    "        with open(output_filename, \"w\") as f:\n",
    "            json.dump(glue_results, f, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f286c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = 'Meta-Llama-3-8B-Instruct'\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name).cuda()\n",
    "model.config._name_or_path = 'meta-l3'\n",
    "tok = AutoTokenizer.from_pretrained(model_name)\n",
    "tok.pad_token = tok.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49890881",
   "metadata": {},
   "outputs": [],
   "source": [
    "def summarize(method):\n",
    "    glue_dir = f'./results/glue/{method}/'\n",
    "    results = defaultdict(list)\n",
    "    for batch_dir in os.listdir(glue_dir):\n",
    "        for file in os.listdir(os.path.join(glue_dir, batch_dir)):\n",
    "            if file.endswith('_glue.json'):\n",
    "                with open(os.path.join(glue_dir, batch_dir, file), 'r') as f:\n",
    "                    data = json.load(f)\n",
    "                for key, value in data.items():\n",
    "                    if key != 'edit_num':\n",
    "                        # results[key].append(value['correct'] / value['total'] * 100)\n",
    "                        results[key].append(value['f1'] * 100)\n",
    "    return {k: sum(v)/len(v) for k, v in results.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c531d894",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'alphaedit': {'SST': 0.0, 'MMLU': 14.41, 'MRPC': 6.82, 'COLA': 3.04, 'RTE': 14.84, 'NLI': 9.53, 'AVG.': 8.11},\n"
     ]
    }
   ],
   "source": [
    "for method in ['alphaedit']:\n",
    "    result = summarize(method)\n",
    "    result['avg.'] = sum(result.values()) / len(result)\n",
    "    round_result = {k: round(v, 2) for k, v in result.items()}\n",
    "\n",
    "    final_result = {k.replace('mmm', 'mm').upper(): v for k, v in round_result.items()}\n",
    "\n",
    "    print(f'\\'{method}\\': {final_result},')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "memit_llama",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
