{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1TV-3P6fqyHYDPNgHkYuVM7XpDkhvyWpr","timestamp":1712882148420}]},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# setup"],"metadata":{"id":"5SettayOqsgy"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"LcqhTnNhmaEV"},"outputs":[],"source":["%%capture\n","!pip install transformers"]},{"cell_type":"code","source":["import json\n","import pickle\n","import torch\n","import numpy as np\n","from transformers import pipeline\n","from tqdm import tqdm"],"metadata":{"id":"0deF0foom_0q"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","print(device)"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"0ZH-V0EunX1N","executionInfo":{"status":"ok","timestamp":1726921687168,"user_tz":240,"elapsed":6,"user":{"displayName":"armin toroghi","userId":"07268870172709529513"}},"outputId":"2a637c1d-6ddb-47d7-ec55-5002ac733812"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["cpu\n"]}]},{"cell_type":"code","source":["%%capture\n","nli = pipeline(model=\"facebook/bart-large-mnli\", device=device)"],"metadata":{"collapsed":true,"id":"KZLGoSf2nDkl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from google.colab import drive\n","import json\n","import pickle\n","\n","# Mount your Google Drive\n","drive.mount('/content/drive', force_remount=True)\n","\n","# Paths to your files in Google Drive\n","# Update the paths based on the location of your files in Google Drive\n","aspects_cache_path = '/content/drive/My Drive/AR/recipe_mpr_final2/caches/aspects_cache.json'\n","negations_cache_path = '//content/drive/My Drive/AR/recipe_mpr_final2/caches/negations_cache.json'\n","recipe_mpr_path = '/content/drive/My Drive/AR/recipe_mpr_final2/caches/Recipe-MPR.json'\n","entailment_cache_2_path = '/content/drive/My Drive/AR/recipe_mpr_final2/caches//entailment_cache_2.pkl'\n","query_ingredients = '/content/drive/My Drive/AR/recipe_mpr_final2/caches/query_ingredients_all.json'\n","options_ingredients = '/content/drive/My Drive/AR/recipe_mpr_final2/caches/options_ingredients_all.json'\n","\n","# Load the files from Google Drive\n","with open(aspects_cache_path, \"r\") as f:\n","    aspects_cache = json.load(f)\n","\n","with open(negations_cache_path, \"r\") as f:\n","    negations_cache = json.load(f)\n","\n","with open(recipe_mpr_path, \"r\") as f:\n","    data = json.load(f)\n","\n","with open(entailment_cache_2_path, \"rb\") as f:\n","    entailment_cache_2 = pickle.load(f)\n","\n","with open(query_ingredients, \"r\") as f:\n","    query2ingredients = json.load(f)\n","\n","with open(options_ingredients, \"r\") as f:\n","    options2ingredients = json.load(f)\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"xn2tJHQc-ial","executionInfo":{"status":"ok","timestamp":1726925503559,"user_tz":240,"elapsed":2169,"user":{"displayName":"armin toroghi","userId":"07268870172709529513"}},"outputId":"2c1888ee-1892-4240-cfbd-fbd91b48989c"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n"]}]},{"cell_type":"code","source":["# with open(\"aspects_cache.json\", \"r\") as f:\n","#     aspects_cache = json.load(f)\n","\n","# with open(\"negations_cache.json\", \"r\") as f:\n","#     negations_cache = json.load(f)\n","\n","# with open(\"Recipe-MPR.json\", \"r\") as f:\n","#     data = json.load(f)\n","\n","# with open(\"entailment_cache_2.pkl\", \"rb\") as f:\n","#     entailment_cache_2 = pickle.load(f)"],"metadata":{"id":"AfbbZ9EsvWUH"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# code"],"metadata":{"id":"2h6s_-fIesNK"}},{"cell_type":"code","source":["def iter_recipe_mpr():\n","    with open(f'/content/drive/My Drive/AR/recipe_mpr_final2/caches/Recipe-MPR.json', \"r\") as f:\n","        data = json.load(f)\n","\n","    for row in data:\n","        query = row[\"query\"]\n","        options = row[\"options\"].values()\n","        answer = row[\"options\"][row[\"answer\"]]\n","\n","        yield query, options, answer"],"metadata":{"id":"E5uF9oXoI95y"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def score_entailment(premise, hypothesis):\n","    cache = entailment_cache_2\n","\n","    if (premise, hypothesis) in cache:\n","        return cache[(premise, hypothesis)]\n","\n","\n","    result = nli(premise, hypothesis)\n","\n","    print(f\"nli {premise} => {hypothesis}\")\n","    cache[(premise, hypothesis)] = result[\"scores\"][0]\n","\n","    with open(\"entailment_cache_2.pkl\", \"wb\") as f:\n","        pickle.dump(cache, f)\n","\n","\n","    return result[\"scores\"][0]\n","\n","def negate_aspect(aspect):\n","    return negations_cache[aspect]\n","\n","def extract_ingredients(s):\n","\n","    return s.split(\"(\")[1].split(\",\")[1].split(\")\")[0].strip()\n"],"metadata":{"id":"nN-rp1xbaPvk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def score_aspect(aspect, option):\n","    aspect_negated = negate_aspect(aspect)\n","\n","    pos_score = score_entailment(option, aspect)\n","    neg_score = score_entailment(option, aspect_negated)\n","\n","    # softmax\n","    scores = np.array([pos_score, neg_score])\n","    scores = np.exp(scores) / np.sum(np.exp(scores))\n","    pos_score, neg_score = scores\n","\n","    # score\n","    #odds = pos_score / neg_score\n","    #score = odds / (1 + odds)\n","    score = pos_score\n","    score = np.log(score)\n","\n","    result = {\n","        \"score\": score,\n","        \"raw_score\": {\n","            \"pos\": {\"score\": pos_score, \"aspect\": aspect},\n","            \"neg\": {\"score\": neg_score, \"aspect\": aspect_negated}\n","        }\n","    }\n","\n","    return result\n","\n","def score_ingredients(query_aspects, option):\n","    option_score = 0\n","    query_ingredients = []\n","\n","    for aspect in query_aspects:\n","      if query2ingredients[aspect] != 'None':\n","        query_ingredients.append(query2ingredients[aspect])\n","    option_ingredients = options2ingredients[option]\n","\n","    for query_ingredient_triple in query_ingredients:\n","        query_ingredient = extract_ingredients(query_ingredient_triple)\n","        ingredient_score = 0\n","        for option_ingredient_triple in option_ingredients:\n","            if option_ingredient_triple == 'None':\n","                continue\n","            option_ingredient = extract_ingredients(option_ingredient_triple)\n","\n","            if option_ingredient == query_ingredient:\n","                option_ingredient_score = 1\n","            else:\n","              option_ingredient_score = score_entailment(option_ingredient, query_ingredient)\n","            if option_ingredient_score > ingredient_score:\n","                ingredient_score = option_ingredient_score\n","        if ingredient_score == 0:\n","            ingredient_score = 1\n","        option_score += np.log(ingredient_score)\n","\n","    result = {\n","        \"score\": option_score\n","    }\n","    return result\n","\n","\n","def score_option(aspects, option):\n","    score = 0\n","    aspect_scores = []\n","\n","    for aspect in aspects:\n","        aspect_score = score_aspect(aspect, option)\n","\n","\n","        score += aspect_score[\"score\"]\n","        aspect_scores.append(aspect_score)\n","\n","    result = {\n","        \"option\": option,\n","        \"score\": score,\n","        \"aspect_scores\": aspect_scores\n","    }\n","\n","    return result"],"metadata":{"id":"hHENk-IFfSM8"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def answer_query(query, options):\n","    aspects = aspects_cache[query]\n","\n","    ranking = []\n","\n","    for option in options:\n","        option_score = score_option(aspects, option)\n","        #print(option_score[\"score\"])\n","        ingredients_score = score_ingredients(aspects, option)\n","        #print(ingredients_score[\"score\"])\n","\n","        option_score[\"score\"] += ingredients_score[\"score\"]\n","\n","        ranking.append(option_score)\n","    ranking.sort(key=lambda x: x[\"score\"], reverse=True)\n","\n","    return ranking"],"metadata":{"id":"BPDfl5xchHUH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def log(query, answer, ranking, f):\n","    pred = ranking[0][\"option\"]\n","\n","    is_correct = (pred == answer)\n","\n","    f.write(f\"  QUERY: {query}\\n\")\n","    f.write(f\" ANSWER: {answer}\\n\")\n","    f.write(f\"   PRED: {pred}\\n\")\n","    f.write(f\"CORRECT: {is_correct}\\n\\n\")\n","\n","    for option in ranking:\n","\n","        if option[\"option\"] == answer:\n","            f.write(f\"    ({option['score']:.5f}) {option['option']} **ANSWER**\\n\")\n","        else:\n","            f.write(f\"    ({option['score']:.5f}) {option['option']}\\n\")\n","\n","        aspect_scores = option[\"aspect_scores\"]\n","\n","        for aspect_score in aspect_scores:\n","            raw_aspect_score = aspect_score[\"raw_score\"]\n","\n","            f.write(f\"        {raw_aspect_score['pos']['score']:.5f} => {raw_aspect_score['pos']['aspect']}\\n\")\n","            f.write(f\"        {raw_aspect_score['neg']['score']:.5f} => {raw_aspect_score['neg']['aspect']}\\n\")\n","            f.write(f\"        -------\\n\")\n","            f.write(f\"        {aspect_score['score']:.5f}\\n\\n\")"],"metadata":{"id":"mBfaRHbehenQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["log_f = \"run1.txt\"\n","\n","correct = 0\n","incorrect = 0\n","\n","f = open(log_f, \"w\")\n","\n","n = 0\n","\n","\n","for query, options, answer in iter_recipe_mpr():\n","    ranking = answer_query(query, options)\n","\n","    prediction = ranking[0][\"option\"]\n","\n","    if prediction == answer:\n","        correct += 1\n","    else:\n","        incorrect += 1\n","\n","    log(query, answer, ranking, f)\n","\n","    n += 1\n","    if n == 500:\n","        break\n","\n","accuracy = correct / (correct + incorrect)\n","print(f\"accuracy: {accuracy:.5f}\")\n","f.close()\n"],"metadata":{"id":"rp8bRA4Lh6Tm","colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"status":"ok","timestamp":1726929835026,"user_tz":240,"elapsed":275,"user":{"displayName":"armin toroghi","userId":"07268870172709529513"}},"outputId":"91e882dc-9090-4a75-c8fa-f9f0571effcf"},"execution_count":null,"outputs":[{"output_type":"stream","name":"stdout","text":["accuracy: 0.83800\n"]}]},{"cell_type":"code","source":[],"metadata":{"id":"pkUdCSsKcBeg"},"execution_count":null,"outputs":[]}]}