{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "07431a8c",
   "metadata": {},
   "source": [
    "# Evaluate BLEU-4 and ROUGE-L for all candidate files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc786e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import re\n",
    "import spacy\n",
    "from bleu.bleu import Bleu\n",
    "from rouge.rouge import Rouge\n",
    "\n",
    "nlp = spacy.load(\"en_core_web_md\")\n",
    "\n",
    "def extract_concepts(text):\n",
    "    pattern = r'(.*?)(<\\|.*?\\|>)(.*)'\n",
    "    concepts = []\n",
    "    while text:\n",
    "        match = re.match(pattern, text, re.DOTALL)\n",
    "        if match:\n",
    "            if match.group(1):\n",
    "                concepts.append(match.group(1))\n",
    "            concepts.append(match.group(2))\n",
    "            text = match.group(3)\n",
    "        else:\n",
    "            if text:\n",
    "                concepts.append(text)\n",
    "            break\n",
    "    return concepts\n",
    "\n",
    "def tokenize_list(sentences):\n",
    "    tokenized = []\n",
    "    for sentence in sentences:\n",
    "        tokens = ' '.join([token.text for token in nlp(sentence)])\n",
    "        tokenized.append(tokens)\n",
    "    return tokenized\n",
    "\n",
    "def weighted_mean(input_scores):\n",
    "    weights = [0.25, 0.25, 0.25, 0.25]\n",
    "    return sum(x * w for x, w in zip(input_scores, weights)) * 100\n",
    "\n",
    "def prepare_data(cands, targets):\n",
    "    \"\"\"\n",
    "    cands: dict con {id: {\"title\": ..., \"generated\": ...}}\n",
    "    targets: lista di dict con {\"id\": ..., \"title\": ..., \"story\": ...}\n",
    "    ritorna: gts, res per BLEU/ROUGE\n",
    "    \"\"\"\n",
    "    target_dict = {str(t[\"id\"]): f'{t[\"title\"]}\\n{t[\"story\"]}' for t in targets}\n",
    "    cand_dict = {str(k): f'{v[\"generated\"]}' for k, v in cands.items()}\n",
    "\n",
    "    target_tok = {k: ' '.join([tok.text for tok in nlp(v)]) for k, v in target_dict.items()}\n",
    "    cand_tok = {k: ' '.join([tok.text for tok in nlp(v)]) for k, v in cand_dict.items() if k in target_tok}\n",
    "\n",
    "    gts = {k: [v] for k, v in target_tok.items() if k in cand_tok}\n",
    "    res = {k: [v] for k, v in cand_tok.items() if k in target_tok}\n",
    "\n",
    "    return gts, res\n",
    "\n",
    "def find_corresponding_input_file(cand_filename, input_files_folder):\n",
    "    match = re.search(r'output(\\d+)', cand_filename)\n",
    "    if match:\n",
    "        num = match.group(1)\n",
    "        input_filename = f\"output{num}.json\"\n",
    "        input_path = os.path.join(input_files_folder, input_filename)\n",
    "        \n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    \n",
    "    if \"output10\" in cand_filename:\n",
    "        input_path = os.path.join(input_files_folder, \"output10.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    elif \"output20\" in cand_filename:\n",
    "        input_path = os.path.join(input_files_folder, \"output20.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    elif \"output30\" in cand_filename:\n",
    "        input_path = os.path.join(input_files_folder, \"output30.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    \n",
    "    return None\n",
    "\n",
    "def evaluate(cands, targets):\n",
    "    \"\"\"Score ``cands`` against ``targets``.\n",
    "\n",
    "    Returns a pair: ``weighted_mean`` of the scores given by ``Bleu().compute_score``\n",
    "    and the score given by ``Rouge().compute_score`` multiplied by 100.\n",
    "    \"\"\"\n",
    "    references, hypotheses = prepare_data(cands, targets)\n",
    "    bleu, rouge_l = Bleu(), Rouge()\n",
    "\n",
    "    per_order_bleu, _ = bleu.compute_score(references, hypotheses)\n",
    "    mean_bleu = weighted_mean(per_order_bleu)\n",
    "\n",
    "    raw_rouge, _ = rouge_l.compute_score(references, hypotheses)\n",
    "    return mean_bleu, raw_rouge * 100\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Paths are left blank in the notebook; set them before running this cell.\n",
    "    targets_file = \"\"\n",
    "    cands_folder = \"\"\n",
    "    input_files_folder = \"\"\n",
    "\n",
    "    with open(targets_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        targets = json.load(f)\n",
    "\n",
    "    settings = (\"output10\", \"output20\", \"output30\")\n",
    "    input_files = {setting: None for setting in settings}\n",
    "    # Scores of every evaluated candidate file, grouped by setting.\n",
    "    all_runs_bleu = {setting: [] for setting in settings}\n",
    "    all_runs_rouge = {setting: [] for setting in settings}\n",
    "\n",
    "    for setting in settings:\n",
    "        input_path = os.path.join(input_files_folder, f\"{setting}.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            with open(input_path, \"r\", encoding=\"utf-8\") as f:\n",
    "                input_files[setting] = json.load(f)\n",
    "        else:\n",
    "            print(f\"No {input_path}\")\n",
    "\n",
    "    report_data = {}\n",
    "\n",
    "    # os.listdir order is arbitrary; sort so the output and report are reproducible.\n",
    "    for fname in sorted(os.listdir(cands_folder)):\n",
    "        if not fname.endswith(\".json\"):\n",
    "            continue\n",
    "        cand_path = os.path.join(cands_folder, fname)\n",
    "\n",
    "        if find_corresponding_input_file(fname, input_files_folder) is None:\n",
    "            continue\n",
    "\n",
    "        # Bucket by the explicit \"outputNN\" tag only. Matching a bare \"10\"/\"20\"/\"30\"\n",
    "        # anywhere in the name filed e.g. \"output20_run10.json\" under output10.\n",
    "        setting = next((name for name in settings if name in fname), None)\n",
    "        if setting is None:\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            with open(cand_path, \"r\", encoding=\"utf-8\") as f:\n",
    "                cands = json.load(f)\n",
    "\n",
    "            bleu4, rouge = evaluate(cands, targets)\n",
    "\n",
    "            all_runs_bleu[setting].append(bleu4)\n",
    "            all_runs_rouge[setting].append(rouge)\n",
    "\n",
    "            report_data[fname] = {\n",
    "                \"BLEU-4\": bleu4,\n",
    "                \"ROUGE-L\": rouge,\n",
    "            }\n",
    "\n",
    "            print(f\"File: {fname}\")\n",
    "            print(f\"  BLEU-4: {bleu4:.2f}\")\n",
    "            print(f\"  ROUGE-L: {rouge:.2f}\")\n",
    "            print(\"-\" * 40)\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best effort: a broken candidate file must not stop the other files.\n",
    "            print(f\"Errore durante l'elaborazione di {fname}: {str(e)}\")\n",
    "\n",
    "    # Per-setting averages over all evaluated files (previously collected but never shown).\n",
    "    for setting in settings:\n",
    "        n_runs = len(all_runs_bleu[setting])\n",
    "        if n_runs:\n",
    "            mean_bleu = sum(all_runs_bleu[setting]) / n_runs\n",
    "            mean_rouge = sum(all_runs_rouge[setting]) / n_runs\n",
    "            print(f\"{setting}: mean BLEU-4 {mean_bleu:.2f}, mean ROUGE-L {mean_rouge:.2f} over {n_runs} file(s)\")\n",
    "\n",
    "    report_path = os.path.join(cands_folder, \"evaluation_report.json\")\n",
    "    with open(report_path, \"w\", encoding=\"utf-8\") as f:\n",
    "        json.dump(report_data, f, indent=2, ensure_ascii=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c054fb",
   "metadata": {},
   "source": [
    "# ABS Coverage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "aaf15e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import string\n",
    "from FAdo.fa import *\n",
    "from FAdo.reex import *\n",
    "import base64\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f957ccd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_corresponding_input_file(cand_filename, input_files_folder):\n",
    "    match = re.search(r'output(\\d+)', cand_filename)\n",
    "    if match:\n",
    "        num = match.group(1)\n",
    "        input_filename = f\"output{num}.json\"\n",
    "        input_path = os.path.join(input_files_folder, input_filename)\n",
    "        \n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    \n",
    "    # Se non trova con il pattern outputX, prova altri pattern\n",
    "    if \"output10\" in cand_filename:\n",
    "        input_path = os.path.join(input_files_folder, \"output10.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    elif \"output20\" in cand_filename:\n",
    "        input_path = os.path.join(input_files_folder, \"output20.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    elif \"output30\" in cand_filename:\n",
    "        input_path = os.path.join(input_files_folder, \"output30.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            return input_path\n",
    "    \n",
    "    return None\n",
    "\n",
    "def compute_next_state(current_state, word, dfa):\n",
    "    transitions = dfa['transition_matrix']\n",
    "    alphabet = dfa['alphabet']\n",
    "\n",
    "    for ch in word:\n",
    "        if ch in alphabet:\n",
    "            idx = alphabet.index(ch)\n",
    "        else:\n",
    "            idx = 0\n",
    "        current_state = transitions[current_state][idx]\n",
    "\n",
    "    return current_state, dfa[\"distances\"][current_state]\n",
    "\n",
    "def eval_coverage(cands, dfas, file):\n",
    "    correct = 0\n",
    "    for can in cands.keys():\n",
    "        for dfa in dfas:\n",
    "            if can == str(dfa[\"id\"]) and file == dfa[\"source_file\"]:\n",
    "                current_dfa = pickle.loads(base64.b64decode(dfa[\"dfa\"]))\n",
    "                generated = cands[can][\"generated\"].replace(\"  \", \" \")\n",
    "                new_state, distance = compute_next_state(0, generated, current_dfa)\n",
    "                if distance <= 1:\n",
    "                    correct += 1\n",
    "\n",
    "    return correct / len(cands.keys()) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16edfc73",
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    dfa_file = \"../../data/text_infilling/dfa_cache_complete.json\"\n",
    "    # Folder paths are left blank in the notebook; set them before running this cell.\n",
    "    cands_folder = \"\"\n",
    "    input_files_folder = \"\"\n",
    "\n",
    "    with open(dfa_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        dfa = json.load(f)\n",
    "\n",
    "    settings = (\"output10\", \"output20\", \"output30\")\n",
    "    input_files = {setting: None for setting in settings}\n",
    "\n",
    "    for setting in settings:\n",
    "        input_path = os.path.join(input_files_folder, f\"{setting}.json\")\n",
    "        if os.path.exists(input_path):\n",
    "            with open(input_path, \"r\", encoding=\"utf-8\") as f:\n",
    "                input_files[setting] = json.load(f)\n",
    "        else:\n",
    "            print(f\"No {input_path}\")\n",
    "\n",
    "    # Coverage per evaluated candidate file, kept for inspection after the loop.\n",
    "    report_data = {}\n",
    "\n",
    "    for fname in os.listdir(cands_folder):\n",
    "        if not fname.endswith(\".json\"):\n",
    "            continue\n",
    "        cand_path = os.path.join(cands_folder, fname)\n",
    "\n",
    "        if find_corresponding_input_file(fname, input_files_folder) is None:\n",
    "            print(\"skipping...\")\n",
    "            continue\n",
    "\n",
    "        current_file = next((name for name in settings if name in fname), None)\n",
    "        if current_file is None:\n",
    "            continue\n",
    "\n",
    "        if \"alpha\" not in fname:\n",
    "            # Coverage is only computed for \"alpha\" files. Printing it for other files\n",
    "            # either raised NameError or repeated the previous file's value.\n",
    "            print(f\"File: {fname} - coverage not evaluated (no 'alpha' in file name)\")\n",
    "            continue\n",
    "\n",
    "        try:\n",
    "            with open(cand_path, \"r\", encoding=\"utf-8\") as f:\n",
    "                cands = json.load(f)\n",
    "\n",
    "            print(f\"Evaluating constraints for {fname} with input file {current_file}.json\")\n",
    "            coverage = eval_coverage(cands, dfa, current_file + \".json\")\n",
    "            report_data[fname] = {\"Coverage\": coverage}\n",
    "\n",
    "            print(f\"File: {fname}\")\n",
    "            print(f\"  Coverage: {coverage:.2f}%\")\n",
    "            print(\"-\" * 40)\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best effort: report the failure and move on to the next file.\n",
    "            print(e)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d799eb6f",
   "metadata": {},
   "source": [
    "# Evaluate coverage for ILM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75cbfdae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def extract_concepts(text):\n",
    "    pattern = r'(.*?)(<\\|.*?\\|>)(.*)'\n",
    "    concepts = []\n",
    "    while text:\n",
    "        match = re.match(pattern, text, re.DOTALL)\n",
    "        if match:\n",
    "            if match.group(1):\n",
    "                concepts.append(match.group(1))\n",
    "            concepts.append(match.group(2))\n",
    "            text = match.group(3)\n",
    "        else:\n",
    "            if text:\n",
    "                concepts.append(text)\n",
    "            break\n",
    "    return concepts\n",
    "\n",
    "def remove_fixed_parts(prompt):\n",
    "    \"\"\"Return only the infill tokens found in ``prompt``, in order of appearance.\"\"\"\n",
    "    infill_tokens = (\"<|infill_word|>\", \"<|infill_ngram|>\", \"<|infill_sentence|>\")\n",
    "    kept = []\n",
    "    for piece in extract_concepts(prompt):\n",
    "        if piece in infill_tokens:\n",
    "            kept.append(piece)\n",
    "    return kept\n",
    "\n",
    "import os\n",
    "import json\n",
    "\n",
    "cartella = \"\"\n",
    "# Masked-prompt file for each setting; paths are left blank in the notebook.\n",
    "prompt_files = {\"output10\": \"\", \"output20\": \"\", \"output30\": \"\"}\n",
    "\n",
    "# Pattern a generated span must fully match, keyed by the infill token it fills.\n",
    "span_patterns = {\n",
    "    \"<|infill_word|>\": r\"[ ]?[a-zA-Z0-9'.!,?]+[ ]?\",\n",
    "    \"<|infill_ngram|>\": r\"[a-zA-Z0-9' ,.!?]+(?:[ ,.][a-zA-Z0-9' ]+)*\",\n",
    "    \"<|infill_sentence|>\": r\"[a-zA-Z0-9', ]+[.!?]\",\n",
    "}\n",
    "\n",
    "\n",
    "def _spans_are_valid(spans, infill_tokens):\n",
    "    \"\"\"Return True if each span fully matches the pattern of the token at the same position.\"\"\"\n",
    "    if len(spans) > len(infill_tokens):\n",
    "        # Extra spans have no infill token to be checked against; this used to\n",
    "        # raise IndexError and abort the rest of the file.\n",
    "        return False\n",
    "    for span, token in zip(spans, infill_tokens):\n",
    "        pattern = span_patterns.get(token)\n",
    "        # Hyphens are replaced by spaces before matching.\n",
    "        if pattern is not None and not re.fullmatch(pattern, span.replace(\"-\", \" \")):\n",
    "            return False\n",
    "    return True\n",
    "\n",
    "\n",
    "def _score_setting(folder, setting, prompt_path):\n",
    "    \"\"\"Check the spans in every ``.json`` file of ``folder`` whose name contains ``setting``.\n",
    "\n",
    "    Returns ``(num_correct, total)``. A candidate without \"spans\", or one that\n",
    "    raises while being checked, counts towards ``total`` but not ``num_correct``.\n",
    "    \"\"\"\n",
    "    with open(prompt_path, \"r\", encoding=\"utf-8\") as f:\n",
    "        prompt = json.load(f)\n",
    "\n",
    "    # Index the prompts by id once instead of rescanning the list per candidate.\n",
    "    prompts_by_id = {}\n",
    "    for item in prompt:\n",
    "        prompts_by_id.setdefault(str(item[\"id\"]), []).append(item)\n",
    "\n",
    "    num_correct = 0\n",
    "    total = 0\n",
    "    for nome_file in os.listdir(folder):\n",
    "        if not (nome_file.endswith(\".json\") and setting in nome_file):\n",
    "            continue\n",
    "        percorso_file = os.path.join(folder, nome_file)\n",
    "\n",
    "        try:\n",
    "            with open(percorso_file, \"r\", encoding=\"utf-8\") as f:\n",
    "                cands = json.load(f)\n",
    "            print(f\"File: {nome_file}\")\n",
    "            entries = list(cands.items())\n",
    "        except Exception as e:\n",
    "            print(f\"Error reading {nome_file}: {e}\")\n",
    "            continue\n",
    "\n",
    "        for cand, entry in entries:\n",
    "            total += 1\n",
    "            try:\n",
    "                if \"spans\" not in entry:\n",
    "                    print(f\"No spans in id {cand}\")\n",
    "                    continue\n",
    "                # A candidate with no prompt of the same id stays correct (all() of nothing).\n",
    "                correct = all(\n",
    "                    _spans_are_valid(\n",
    "                        entry[\"spans\"],\n",
    "                        remove_fixed_parts(item[\"title\"] + \"\\n\" + item[\"story\"]),\n",
    "                    )\n",
    "                    for item in prompts_by_id.get(cand, [])\n",
    "                )\n",
    "            except Exception as e:\n",
    "                # One malformed candidate must not drop the rest of the file.\n",
    "                print(f\"Error reading {nome_file}: {e}\")\n",
    "                continue\n",
    "            if correct:\n",
    "                num_correct += 1\n",
    "\n",
    "    return num_correct, total\n",
    "\n",
    "\n",
    "# Never incremented; kept so the summary line keeps its shape.\n",
    "num_idk = 0\n",
    "for setting, prompt_path in prompt_files.items():\n",
    "    num_correct, total = _score_setting(cartella, setting, prompt_path)\n",
    "    # No matching files means total == 0; report 0.00% instead of dividing by zero.\n",
    "    percent = num_correct / total * 100 if total else 0.0\n",
    "    print(f\"Correct: {num_correct}/{total} = {percent:.2f}%\", f\"Idk: {num_idk}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "trident",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
