{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################# Convert LAMA into preparation file.\n",
    "\n",
    "import os\n",
    "from transformers import AutoTokenizer\n",
    "import torch\n",
    "import jsonlines\n",
    "import re\n",
    "import torch.nn as nn\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "import random\n",
    "\n",
    "model_name = 'meta-llama/Llama-2-7b-hf'\n",
    "\n",
    "# Tokenizer of the 7B base model; EOS doubles as the padding token.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Relation configurations keyed by relation code (e.g. 'P19'); entries provide 'label' and 'template'.\n",
    "with jsonlines.open('data/LAMA/relations.jsonl', 'r') as reader:\n",
    "    config_dic = {entry['relation']: entry for entry in reader}\n",
    "\n",
    "# Relation codes that have a LAMA TREx file (data/LAMA/TREx/P<id>.jsonl), in numeric order.\n",
    "TREX_relations = [\n",
    "    fname.replace('.jsonl', '')\n",
    "    for fname in sorted(os.listdir('data/LAMA/TREx'),\n",
    "                        key=lambda fname: int(fname.replace('P', '').replace('.jsonl', '')))\n",
    "]\n",
    "\n",
    "# Relations labelled invariant; every other code in total_rel_list is labelled variant.\n",
    "invariant_rel = ['P19', 'P20', 'P279', 'P37', 'P449', 'P47', 'P138', 'P364', 'P527', 'P176',\n",
    "                 'P27', 'P407', 'P30', 'P178', 'P1376', 'P131', 'P1412', 'P17', 'P276', 'P937',\n",
    "                 'P140', 'P103', 'P190', 'P1001', 'P495', 'P36', 'P740', 'P361']\n",
    "# NOTE(review): this listing reads dataset/data/TREx while TREX_relations reads data/LAMA/TREx;\n",
    "# confirm the two directories are meant to differ.\n",
    "file_list = os.listdir('dataset/data/TREx')\n",
    "total_rel_list = sorted((fname.replace('.jsonl', '') for fname in file_list), key=lambda code: int(code[1:]))\n",
    "variant_rel = [code for code in total_rel_list if code not in invariant_rel]\n",
    "\n",
    "### Prompt builders, then filter LAMA facts and write them into the preparation file\n",
    "def make_task_schematic(sub, rel, obj):\n",
    "    \"\"\"Schematic prompt: an instruction line, then 'subject is .. , relation is .. , object is ..'.\"\"\"\n",
    "    return 'Guess the object. \\n  subject is {} , relation is {} , object is {}'.format(sub, rel, obj)\n",
    "\n",
    "def make_task_descriptive(sub, obj,template):\n",
    "    \"\"\"Fill a cloze template: '[X]' becomes the subject first, then '[Y]' becomes the object.\"\"\"\n",
    "    filled = template.replace('[X]', sub)\n",
    "    return filled.replace('[Y]', obj)\n",
    "\n",
    "\n",
    "\n",
    "os.makedirs('temp', exist_ok=True)\n",
    "n_skipped = 0  # facts with no evidence sentence containing a [MASK] slot\n",
    "\n",
    "# Write one record per LAMA TREx fact. The file name must match what the scoring cell reads\n",
    "# ('temp/TREx_for_train_attention.jsonl'); paths are case-sensitive on Linux.\n",
    "with jsonlines.open('temp/TREx_for_train_attention.jsonl', 'w') as writer:\n",
    "    for rel in total_rel_list:\n",
    "        if rel not in TREX_relations:  # no LAMA TREx file for this relation\n",
    "            continue\n",
    "        relation_label = config_dic[rel]['label']  # e.g. place_of_birth\n",
    "        template = config_dic[rel]['template']  # cloze template with [X] / [Y] placeholders\n",
    "        # total_rel_list is split exactly into invariant_rel and variant_rel, so membership decides.\n",
    "        invariant = rel in invariant_rel\n",
    "\n",
    "        # Per-relation facts live in the same directory TREX_relations was listed from.\n",
    "        with jsonlines.open(f'data/LAMA/TREx/{rel}.jsonl', 'r') as reader1:\n",
    "            for dic1 in tqdm(reader1, desc=rel):\n",
    "                X = dic1['sub_label']  # subject\n",
    "                Y = dic1['obj_label']  # object (the answer)\n",
    "\n",
    "                # 1. Two prompt styles for the same fact.\n",
    "                task_descriptive = make_task_descriptive(X, Y, template)\n",
    "                task_schematic = make_task_schematic(X, relation_label, Y)\n",
    "\n",
    "                # 2. Evidence: the longest sentence that actually contains a [MASK] slot\n",
    "                #    (ties keep the first sentence in file order).\n",
    "                candidates = [ev['masked_sentence'] for ev in dic1['evidences'] if '[MASK]' in ev['masked_sentence']]\n",
    "                if not candidates:\n",
    "                    n_skipped += 1\n",
    "                    continue\n",
    "                masked_evidence = max(candidates, key=len)\n",
    "                # Relative character offset of the answer slot, and the sentence with the answer filled in.\n",
    "                position = masked_evidence.index('[MASK]') / len(masked_evidence)\n",
    "                evidence = masked_evidence.replace('[MASK]', Y)\n",
    "\n",
    "                write_dic = {'relation_code': rel, 'uuid': dic1['uuid'],\n",
    "                             'task_descriptive': task_descriptive,\n",
    "                             'task_schematic': task_schematic,\n",
    "                             'subject': X,\n",
    "                             'relation_label': relation_label,\n",
    "                             'object': Y,\n",
    "                             'masked_evidence': masked_evidence, 'evidence': evidence,\n",
    "                             'position': position,\n",
    "                             'evidence_length': len(evidence),\n",
    "                             'invariant': invariant,\n",
    "                             'scores': {}}\n",
    "                writer.write(write_dic)\n",
    "print(f'skipped {n_skipped} facts without a [MASK] evidence sentence')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "############### Measure Accuracy with baseline models ###########\n",
    "from transformers import AutoTokenizer\n",
    "from TAALM import TAALM, Llama2_kadapter\n",
    "\n",
    "# Base checkpoints: Llama-2-7B and TinyLlama-1.1B.\n",
    "model_name = 'meta-llama/Llama-2-7b-hf'\n",
    "adapter_file = 'Llama-2-7b'\n",
    "tiny_model_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'\n",
    "device = 'cuda'\n",
    "\n",
    "# A single tokenizer (from the 7B checkpoint) feeds every model scored below; EOS doubles as padding.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "############## put in other models if you want ##################\n",
    "# Load order: TAALM theta models (7B, 1B), then Llama2_kadapter models (1B, 7B).\n",
    "model_7b = TAALM.init_theta(model_name=model_name, adapter_file=adapter_file, onepiece=True).to(device)\n",
    "model_1b = TAALM.init_theta(model_name=tiny_model_name, adapter_file='Llama-2-1b', onepiece=True).to(device)\n",
    "model_1b_kadapter = Llama2_kadapter.init_gamma_theta(model_name=tiny_model_name, onepiece=True).to(device)\n",
    "model_7b_kadapter = Llama2_kadapter.init_gamma_theta(model_name=model_name, onepiece=True).to(device)\n",
    "####################################################################\n",
    "\n",
    "\n",
    "# Helpers: locate the answer span inside a prompt and score greedy predictions on it\n",
    "def giveme_label_mask(query_token, label_token):\n",
    "    \"\"\"Mark the last occurrence of `label_token` inside `query_token`.\n",
    "\n",
    "    Returns a float 0/1 tensor shifted left by one position: entry t flags whether token t + 1\n",
    "    (the token predicted from position t) belongs to the label. All zeros if the label is absent.\n",
    "    \"\"\"\n",
    "    span = len(label_token)\n",
    "    mask = torch.zeros(len(query_token))\n",
    "    for start in range(len(query_token) - span, -1, -1):  # scan right-to-left\n",
    "        if torch.equal(query_token[start:start + span], label_token):\n",
    "            mask[start:start + span] = 1\n",
    "            break\n",
    "    return mask[1:]\n",
    "\n",
    "def giveme_acc(model, task_descriptive, task_descriptive_label_mask, label_ids):\n",
    "    \"\"\"Greedy next-token accuracy of `model` on the answer span of one tokenized prompt.\n",
    "\n",
    "    Args:\n",
    "        model: causal LM whose output exposes `.logits`; a batch of one prompt is assumed.\n",
    "        task_descriptive: tokenized prompt (descriptive or schematic) already on the model's device.\n",
    "        task_descriptive_label_mask: 0/1 mask from giveme_label_mask, aligned with next-token predictions.\n",
    "        label_ids: 1-D tensor of the answer's token ids, tokenized without special tokens.\n",
    "\n",
    "    Returns:\n",
    "        0-dim tensor with the fraction of answer tokens predicted correctly, or NaN when the\n",
    "        mask does not select exactly len(label_ids) positions (answer span not located).\n",
    "    \"\"\"\n",
    "    logits = model(**task_descriptive).logits[0, :-1, :]  # position t predicts token t + 1\n",
    "    predicted = logits.argmax(dim=-1)\n",
    "    # Boolean indexing (not multiply-then-drop-zeros) so a predicted token id of 0 is kept.\n",
    "    predicted_label = predicted[task_descriptive_label_mask.bool()]\n",
    "    if predicted_label.numel() != label_ids.numel():\n",
    "        return torch.tensor(float('nan'))\n",
    "    return (predicted_label == label_ids).float().mean()\n",
    "\n",
    "\n",
    "# Models to score, in the key order written to each record's 'scores'.\n",
    "scored_models = {\n",
    "    'llama_7b': model_7b,\n",
    "    'llama_1b': model_1b,\n",
    "    'llama_7b_kadapter': model_7b_kadapter,\n",
    "    'llama_1b_kadapter': model_1b_kadapter,\n",
    "}\n",
    "\n",
    "with jsonlines.open('temp/TREx_for_train_attention.jsonl', 'r') as reader, jsonlines.open('temp/TREx_for_train_attention_w_scores.jsonl', 'w') as writer:\n",
    "    for dic in tqdm(reader):\n",
    "        # Answer tokens on their own (no BOS) so they can be located inside each prompt.\n",
    "        object_ids = tokenizer(dic['object'], return_tensors='pt', add_special_tokens=False).input_ids[0].to('cuda')\n",
    "        scores = {}\n",
    "        with torch.no_grad():\n",
    "            for style in ('descriptive', 'schematic'):\n",
    "                prompt = tokenizer(dic['task_' + style], return_tensors='pt').to('cuda')\n",
    "                label_mask = giveme_label_mask(prompt.input_ids[0], object_ids).to('cuda')\n",
    "                scores[style] = {name: giveme_acc(model, prompt, label_mask, object_ids).item()\n",
    "                                 for name, model in scored_models.items()}\n",
    "        dic['scores'] = scores\n",
    "        writer.write(dic)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "################### Sample and Save the dataset ###################\n",
    "from collections import Counter\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "MIN_EVIDENCE_LENGTH = 200  # keep only facts whose filled evidence sentence is longer than this\n",
    "SAMPLE_SIZE = 500          # examples per evaluation split\n",
    "SEED = 42\n",
    "P530_CAP_VARIANT = 130     # max P530 facts in the variant pool (presumably over-represented)\n",
    "P530_CAP_TRAIN = 350       # max P530 facts in the attention-training pool\n",
    "\n",
    "invar_descriptive = []  # invariant facts every model answers fully correctly with the descriptive prompt\n",
    "invar_schematic = []    # same, with the schematic prompt\n",
    "variant = []            # variant facts every model gets completely wrong with both prompts\n",
    "train_attention = []    # facts whose model-averaged accuracy is below 0.5 with both prompts\n",
    "\n",
    "with jsonlines.open('temp/TREx_for_train_attention_w_scores.jsonl', 'r') as reader:\n",
    "    for dic in tqdm(reader):\n",
    "        evidence = dic['evidence']\n",
    "        if dic['evidence_length'] <= MIN_EVIDENCE_LENGTH or dic['object'] not in evidence or dic['subject'] not in evidence:\n",
    "            continue\n",
    "        # Average accuracy over all scored models; a NaN score fails every comparison below.\n",
    "        descriptive_score = np.mean(list(dic['scores']['descriptive'].values()))\n",
    "        schematic_score = np.mean(list(dic['scores']['schematic'].values()))\n",
    "\n",
    "        if dic['invariant']:\n",
    "            if descriptive_score == 1:\n",
    "                invar_descriptive.append(dic)\n",
    "            if schematic_score == 1:\n",
    "                invar_schematic.append(dic)\n",
    "        elif descriptive_score == 0 and schematic_score == 0:\n",
    "            variant.append(dic)\n",
    "\n",
    "        if descriptive_score < 0.5 and schematic_score < 0.5:\n",
    "            train_attention.append(dic)\n",
    "\n",
    "\n",
    "def cap_relation(records, relation_code, cap):\n",
    "    \"\"\"Return `records` in order, keeping only the first `cap` whose relation_code equals `relation_code`.\"\"\"\n",
    "    kept, seen = [], 0\n",
    "    for record in records:\n",
    "        if record['relation_code'] == relation_code:\n",
    "            seen += 1\n",
    "            if seen > cap:\n",
    "                continue\n",
    "        kept.append(record)\n",
    "    return kept\n",
    "\n",
    "\n",
    "### control 'P530' relation type\n",
    "temp_variant = cap_relation(variant, 'P530', P530_CAP_VARIANT)\n",
    "\n",
    "### Sample the evaluation splits (each draw is re-seeded so it is reproducible on its own)\n",
    "random.seed(SEED)\n",
    "sample_invar_descriptive = random.sample(invar_descriptive, k=SAMPLE_SIZE)\n",
    "random.seed(SEED)\n",
    "sample_invar_schematic = random.sample(invar_schematic, k=SAMPLE_SIZE)\n",
    "random.seed(SEED)\n",
    "sample_variant = random.sample(temp_variant, k=SAMPLE_SIZE)\n",
    "\n",
    "# Facts drawn for the variant evaluation split must not leak into the training pool (set: O(1) lookups).\n",
    "variant_uuids = {x['uuid'] for x in sample_variant}\n",
    "sample_train_attention = cap_relation([x for x in train_attention if x['uuid'] not in variant_uuids],\n",
    "                                      'P530', P530_CAP_TRAIN)\n",
    "\n",
    "### Observe the relation-code distribution of one split\n",
    "file_list = os.listdir('dataset/data/TREx')\n",
    "total_rel_list = sorted((x.replace('.jsonl', '') for x in file_list), key=lambda x: int(x[1:]))\n",
    "\n",
    "inspected_split = sample_train_attention  # swap in another split (e.g. sample_variant) to inspect it\n",
    "rel_counts = Counter(record['relation_code'] for record in inspected_split)\n",
    "rel_dics = {rel: rel_counts[rel] for rel in total_rel_list}\n",
    "\n",
    "plt.figure(figsize=(15, 1))\n",
    "plt.bar(list(rel_dics.keys()), list(rel_dics.values()))\n",
    "plt.title('Relation-code counts in the inspected split')\n",
    "plt.xticks(rotation=90)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Save the sampled splits into the dataset directory\n",
    "output_dir = 'data/LAMA_ckl'\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# File name -> records; written in this order.\n",
    "splits = {\n",
    "    'invariant_descriptive.jsonl': sample_invar_descriptive,\n",
    "    'invariant_schematic.jsonl': sample_invar_schematic,\n",
    "    'variant.jsonl': sample_variant,\n",
    "    'train_attention_traindata.jsonl': sample_train_attention,\n",
    "}\n",
    "for file_name, records in splits.items():\n",
    "    with jsonlines.open(os.path.join(output_dir, file_name), 'w') as writer:\n",
    "        writer.write_all(records)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "yb_qlora",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
