{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ulEPXkI6bv_"
      },
      "source": [
        "# **For UKP Essays and AbstRCT (.ann files)**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xRhbhQfP6nmN"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import zipfile\n",
        "\n",
        "zip_path = \"\"  # YOUR ZIP FILE HERE\n",
        "extract_dir = \"/content\"\n",
        "\n",
        "# Unpack the whole archive into the Colab working directory.\n",
        "with zipfile.ZipFile(zip_path, \"r\") as archive:\n",
        "    archive.extractall(extract_dir)\n",
        "\n",
        "# Sanity check: show up to five of the entries now present in extract_dir.\n",
        "preview = os.listdir(extract_dir)[:5]\n",
        "print(\"Extracted files:\", preview)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OrchulT76yiE"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from collections import defaultdict\n",
        "\n",
        "DATA_DIR = \"\"  # Update if needed\n",
        "\n",
        "# brat relation labels that become \"support\" edges / are recognised but not used as edges\n",
        "SUPPORT_RELATIONS = {\"Support\"}\n",
        "ATTACK_RELATIONS = {\"Attack\", \"Partial-Attack\"}\n",
        "\n",
        "\n",
        "def parse_ann_file(filepath):\n",
        "    \"\"\"Parse one brat .ann file.\n",
        "\n",
        "    Returns (text_nodes, node_texts, claim_to_stance, supports, attacks):\n",
        "      text_nodes      -- {node_id: label}, e.g. {\"T1\": \"MajorClaim\"}\n",
        "      node_texts      -- {node_id: annotated text span}\n",
        "      claim_to_stance -- {target_id: value} from \"Stance\" attribute (A) lines, e.g. \"For\"\n",
        "      supports        -- [(arg1, arg2)] for Support relations\n",
        "      attacks         -- [(arg1, arg2)] for Attack / Partial-Attack relations (not used below)\n",
        "    \"\"\"\n",
        "    text_nodes, node_texts, claim_to_stance = {}, {}, {}\n",
        "    supports, attacks = [], []\n",
        "\n",
        "    with open(filepath, \"r\", encoding=\"utf-8\") as f:\n",
        "        for raw_line in f:\n",
        "            line = raw_line.strip()\n",
        "            if line.startswith(\"T\"):\n",
        "                # text-bound annotation: id [tab] label start end [tab] text\n",
        "                parts = line.split('\\t')\n",
        "                if len(parts) < 3:\n",
        "                    continue\n",
        "                text_nodes[parts[0]] = parts[1].split()[0]\n",
        "                node_texts[parts[0]] = parts[2]\n",
        "\n",
        "            elif line.startswith(\"A\"):\n",
        "                # attribute: id [tab] Stance target_id For/Against (other attribute names are ignored)\n",
        "                parts = line.split()\n",
        "                if len(parts) >= 4 and parts[1] == \"Stance\":\n",
        "                    claim_to_stance[parts[2]] = parts[3]\n",
        "\n",
        "            elif line.startswith(\"R\"):\n",
        "                # relation: id [tab] type Arg1:src Arg2:tgt\n",
        "                parts = line.split()\n",
        "                if len(parts) >= 4:\n",
        "                    rel_type = parts[1]\n",
        "                    arg1 = parts[2].split(\":\")[1]\n",
        "                    arg2 = parts[3].split(\":\")[1]\n",
        "                    if rel_type in SUPPORT_RELATIONS:\n",
        "                        supports.append((arg1, arg2))\n",
        "                    elif rel_type in ATTACK_RELATIONS:\n",
        "                        attacks.append((arg1, arg2))\n",
        "                    else:\n",
        "                        print(\"Unknown relation type:\", rel_type)\n",
        "\n",
        "    return text_nodes, node_texts, claim_to_stance, supports, attacks\n",
        "\n",
        "\n",
        "def build_edges(text_nodes, claim_to_stance, supports):\n",
        "    \"\"\"Return the (src, tgt, label) edges of one file.\n",
        "\n",
        "    Support relations become \"support\" edges; each Claim with stance \"For\" gets a\n",
        "    \"stance_for\" edge to every MajorClaim. Edges touching an ID that is not an\n",
        "    annotated component (e.g. a deleted MajorClaim) are skipped rather than\n",
        "    asserted on. MajorClaims are taken in file order so edge order is reproducible.\n",
        "    \"\"\"\n",
        "    edges = [\n",
        "        (src, tgt, \"support\")\n",
        "        for src, tgt in supports\n",
        "        if src in text_nodes and tgt in text_nodes\n",
        "    ]\n",
        "    major_claims = [nid for nid, label in text_nodes.items() if label == \"MajorClaim\"]\n",
        "    for claim_id, stance in claim_to_stance.items():\n",
        "        if stance == \"For\" and text_nodes.get(claim_id) == \"Claim\":\n",
        "            edges.extend((claim_id, mc, \"stance_for\") for mc in major_claims)\n",
        "    return edges\n",
        "\n",
        "\n",
        "all_edges = {}     # filename -> [(src, tgt, label)]\n",
        "text_by_node = {}  # filename -> {node_id: text}\n",
        "graphs = {}        # filename -> {src: [tgt, ...]}\n",
        "\n",
        "for filename in sorted(os.listdir(DATA_DIR)):\n",
        "    if not filename.endswith(\".ann\"):\n",
        "        continue\n",
        "\n",
        "    text_nodes, node_texts, claim_to_stance, supports, _attacks = parse_ann_file(\n",
        "        os.path.join(DATA_DIR, filename)\n",
        "    )\n",
        "    edges = build_edges(text_nodes, claim_to_stance, supports)\n",
        "\n",
        "    all_edges[filename] = edges\n",
        "    text_by_node[filename] = node_texts\n",
        "\n",
        "    # adjacency list for this file: src -> nodes it supports\n",
        "    graph = defaultdict(list)\n",
        "    for src, tgt, _ in edges:\n",
        "        graph[src].append(tgt)\n",
        "    graphs[filename] = dict(graph)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IP6sg8sl7Czq"
      },
      "outputs": [],
      "source": [
        "# Extract lin/conv/div triplets from each .ann graph and report how many were found\n",
        "all_structures = {}\n",
        "\n",
        "for comment_id, graph in graphs.items():\n",
        "    linear = []\n",
        "    convergent = []\n",
        "    divergent = []\n",
        "\n",
        "    # Triplets with a repeated node (possible with cycles or self-loops, e.g. A -> B -> A)\n",
        "    # are skipped rather than asserted on, matching the JSON-export cell below.\n",
        "\n",
        "    # Linear: A -> B -> C\n",
        "    for A in graph:\n",
        "        for B in graph.get(A, []):\n",
        "            for C in graph.get(B, []):\n",
        "                if A != B and B != C and A != C:\n",
        "                    linear.append((A, B, C))\n",
        "\n",
        "    # Convergent: A -> C, B -> C\n",
        "    incoming = {}\n",
        "    for src, targets in graph.items():\n",
        "        for tgt in targets:\n",
        "            incoming.setdefault(tgt, []).append(src)\n",
        "\n",
        "    for C, sources in incoming.items():\n",
        "        for i in range(len(sources)):\n",
        "            for j in range(i + 1, len(sources)):\n",
        "                A, B = sources[i], sources[j]\n",
        "                if A != B and A != C and B != C:\n",
        "                    convergent.append((A, B, C))\n",
        "\n",
        "    # Divergent: A -> B, A -> C\n",
        "    for A, targets in graph.items():\n",
        "        for i in range(len(targets)):\n",
        "            for j in range(i + 1, len(targets)):\n",
        "                B, C = targets[i], targets[j]\n",
        "                if B != C and A != B and A != C:\n",
        "                    divergent.append((A, B, C))\n",
        "\n",
        "    all_structures[comment_id] = {\n",
        "        \"lin\": linear,\n",
        "        \"conv\": convergent,\n",
        "        \"div\": divergent\n",
        "    }\n",
        "\n",
        "lin = sum(len(s[\"lin\"]) for s in all_structures.values())\n",
        "conv = sum(len(s[\"conv\"]) for s in all_structures.values())\n",
        "div = sum(len(s[\"div\"]) for s in all_structures.values())\n",
        "\n",
        "# Number of triplets extracted (counting duplicates)\n",
        "print(\"Linear: \", lin)\n",
        "print(\"Convergent: \", conv)\n",
        "print(\"Divergent: \", div)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OGdkpxxQ7Yjy"
      },
      "outputs": [],
      "source": [
        "# Map essay ID -> raw text of the matching .txt file, e.g. \"essay7.txt\" -> \"007\".\n",
        "# The ID must match the one the export cell derives from the .ann filename.\n",
        "full_text = {}\n",
        "for filename in sorted(os.listdir(DATA_DIR)):\n",
        "    if not filename.endswith(\".txt\"):\n",
        "        continue\n",
        "\n",
        "    essay_id = filename.replace(\"essay\", \"\").replace(\".txt\", \"\").zfill(3)\n",
        "    # Assign only after a successful read, so a failing file leaves no empty placeholder.\n",
        "    with open(os.path.join(DATA_DIR, filename), \"r\", encoding=\"utf-8\") as f:\n",
        "        full_text[essay_id] = f.read()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DNaoXRl_7otd"
      },
      "outputs": [],
      "source": [
        "# Build JSON output: one line per .ann file with its propositions, triplet structures and full text\n",
        "import json\n",
        "OUTPUT_PATH = \"\" # YOUR OUTPUT .jsonlist\n",
        "\n",
        "\n",
        "def _triplet_records(graph):\n",
        "    \"\"\"Yield (structure, A, B, C) for every lin, then conv, then div triplet of `graph`.\n",
        "\n",
        "    `graph` maps a source node to the list of nodes it points to. Triplets that would\n",
        "    repeat a node (from cycles or self-loops) are not yielded.\n",
        "    \"\"\"\n",
        "    # Linear: A -> B -> C\n",
        "    for A in graph:\n",
        "        for B in graph.get(A, []):\n",
        "            for C in graph.get(B, []):\n",
        "                if A != B and B != C and A != C:\n",
        "                    yield \"lin\", A, B, C\n",
        "\n",
        "    # Convergent: A -> C, B -> C\n",
        "    incoming = defaultdict(list)\n",
        "    for src, targets in graph.items():\n",
        "        for tgt in targets:\n",
        "            incoming[tgt].append(src)\n",
        "    for C, sources in incoming.items():\n",
        "        for i in range(len(sources)):\n",
        "            for j in range(i + 1, len(sources)):\n",
        "                A, B = sources[i], sources[j]\n",
        "                if A != B and A != C and B != C:\n",
        "                    yield \"conv\", A, B, C\n",
        "\n",
        "    # Divergent: A -> B, A -> C\n",
        "    for A, targets in graph.items():\n",
        "        for i in range(len(targets)):\n",
        "            for j in range(i + 1, len(targets)):\n",
        "                B, C = targets[i], targets[j]\n",
        "                if B != C and A != B and A != C:\n",
        "                    yield \"div\", A, B, C\n",
        "\n",
        "\n",
        "with open(OUTPUT_PATH, \"w\", encoding=\"utf-8\") as out_file:\n",
        "    for filename in sorted(graphs):\n",
        "        # Reverse adjacency: node -> the nodes that support it (its \"reasons\")\n",
        "        reasons_of = defaultdict(list)\n",
        "        for src, tgt, _ in all_edges[filename]:\n",
        "            reasons_of[tgt].append(src)\n",
        "\n",
        "        propositions = [\n",
        "            {\"id\": pid, \"text\": text, \"reasons\": reasons_of.get(pid, [])}\n",
        "            for pid, text in text_by_node[filename].items()\n",
        "        ]\n",
        "\n",
        "        structure_list = [\n",
        "            {\"instanceID\": idx, \"structure\": shape, \"prop1\": A, \"prop2\": B, \"prop3\": C}\n",
        "            for idx, (shape, A, B, C) in enumerate(_triplet_records(graphs[filename]))\n",
        "        ]\n",
        "\n",
        "        # \"essay7.ann\" -> \"007\"; must agree with the keys of full_text built from the .txt files\n",
        "        essay_id = filename.replace(\"essay\", \"\").replace(\".ann\", \"\").zfill(3)\n",
        "        data = {\n",
        "            \"fileID\": essay_id,  # alternatively filename[:-4]\n",
        "            \"propositions\": propositions,\n",
        "            \"structures\": structure_list,\n",
        "            \"full_text\": full_text[essay_id]\n",
        "        }\n",
        "\n",
        "        out_file.write(json.dumps(data) + \"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9EhLToFJ75Jn"
      },
      "source": [
        "# **For CDCP & AM^2 (.jsonlist files)**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b8IHhmPG43er"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "input_file_path = \"\" # ORIGINAL .JSONLIST FILENAME HERE\n",
        "\n",
        "\n",
        "def parse_ids(id_list):\n",
        "    \"\"\"Return the single-proposition reason IDs from `id_list`.\n",
        "\n",
        "    Entries containing \"_\" (e.g. \"1_3\", a linked/range reason) are dropped.\n",
        "    None or an empty list gives []. The \"_\" test is done on str(item) so integer\n",
        "    IDs are accepted too; kept items are returned unchanged.\n",
        "    \"\"\"\n",
        "    if not id_list:\n",
        "        return []\n",
        "    return [item for item in id_list if '_' not in str(item)]\n",
        "\n",
        "\n",
        "# Load the original data and build, per document, premise id -> [conclusion ids]\n",
        "graphs = {}\n",
        "original_data = []\n",
        "\n",
        "with open(input_file_path, 'r', encoding='utf-8') as f:\n",
        "    for line in f:\n",
        "        if not line.strip():\n",
        "            continue  # tolerate empty lines, e.g. an extra blank line at the end of the file\n",
        "        entry = json.loads(line)\n",
        "        original_data.append(entry)\n",
        "\n",
        "        comment_id = entry['commentID'] # reviewID if AM^2, commentID if CDCP\n",
        "        graph = {}\n",
        "        for prop in entry['propositions']:\n",
        "            conclusion_id = int(prop['id'])\n",
        "            for premise_id in parse_ids(prop.get('reasons')):\n",
        "                graph.setdefault(int(premise_id), []).append(conclusion_id)\n",
        "\n",
        "        graphs[comment_id] = graph\n",
        "\n",
        "\n",
        "# Extract structures. Triplets with a repeated node (possible with cycles, e.g.\n",
        "# A -> B -> A) are skipped instead of failing an assert.\n",
        "all_structures = {}\n",
        "\n",
        "for comment_id, graph in graphs.items():\n",
        "    linear = []\n",
        "    convergent = []\n",
        "    divergent = []\n",
        "\n",
        "    # Linear: A -> B -> C\n",
        "    for A in graph:\n",
        "        for B in graph.get(A, []):\n",
        "            for C in graph.get(B, []):\n",
        "                if A != B and B != C and A != C:\n",
        "                    linear.append((A, B, C))\n",
        "\n",
        "    # Convergent: A -> C, B -> C\n",
        "    incoming = {}\n",
        "    for src, targets in graph.items():\n",
        "        for tgt in targets:\n",
        "            incoming.setdefault(tgt, []).append(src)\n",
        "\n",
        "    for C, sources in incoming.items():\n",
        "        for i in range(len(sources)):\n",
        "            for j in range(i + 1, len(sources)):\n",
        "                A, B = sources[i], sources[j]\n",
        "                if A != B and A != C and B != C:\n",
        "                    convergent.append((A, B, C))\n",
        "\n",
        "    # Divergent: A -> B, A -> C\n",
        "    for A, targets in graph.items():\n",
        "        for i in range(len(targets)):\n",
        "            for j in range(i + 1, len(targets)):\n",
        "                B, C = targets[i], targets[j]\n",
        "                if B != C and A != B and A != C:\n",
        "                    divergent.append((A, B, C))\n",
        "\n",
        "    all_structures[comment_id] = {\n",
        "        \"lin\": linear,\n",
        "        \"conv\": convergent,\n",
        "        \"div\": divergent\n",
        "    }\n",
        "\n",
        "lin = sum(len(s['lin']) for s in all_structures.values())\n",
        "conv = sum(len(s['conv']) for s in all_structures.values())\n",
        "div = sum(len(s['div']) for s in all_structures.values())\n",
        "\n",
        "# Count the number of triplets extracted (counting duplicates)\n",
        "print(\"Linear: \", lin)\n",
        "print(\"Convergent: \", conv)\n",
        "print(\"Divergent: \", div)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ptUpd-BP5DQu"
      },
      "outputs": [],
      "source": [
        "# Attach a 'structures' list (lin/conv/div triplets) to every entry and write a new .jsonlist\n",
        "\n",
        "\n",
        "def _make_structure(instance_id, kind, A, B, C):\n",
        "    \"\"\"Return one 'structures' record; keys are written to the JSON in this order.\"\"\"\n",
        "    return {\n",
        "        \"instanceID\": instance_id,\n",
        "        \"structure\": kind,  # lin / conv / div\n",
        "        \"prop1\": A,\n",
        "        \"prop2\": B,\n",
        "        \"prop3\": C\n",
        "    }\n",
        "\n",
        "\n",
        "structured_data = []\n",
        "\n",
        "for entry in original_data:\n",
        "    comment_id = entry['commentID'] # reviewID if AM^2, commentID if CDCP\n",
        "    graph = graphs.get(comment_id, {})\n",
        "    structure_list = []\n",
        "\n",
        "    # Triplets with a repeated node (possible with cycles, e.g. A -> B -> A) are\n",
        "    # skipped instead of failing an assert. instanceID is the running list length.\n",
        "\n",
        "    # Linear: A -> B -> C\n",
        "    for A in graph:\n",
        "        for B in graph.get(A, []):\n",
        "            for C in graph.get(B, []):\n",
        "                if A != B and B != C and A != C:\n",
        "                    structure_list.append(_make_structure(len(structure_list), \"lin\", A, B, C))\n",
        "\n",
        "    # Convergent: A -> C, B -> C\n",
        "    incoming = {}\n",
        "    for src, targets in graph.items():\n",
        "        for tgt in targets:\n",
        "            incoming.setdefault(tgt, []).append(src)\n",
        "\n",
        "    for C, sources in incoming.items():\n",
        "        for i in range(len(sources)):\n",
        "            for j in range(i + 1, len(sources)):\n",
        "                A, B = sources[i], sources[j]\n",
        "                if A != B and A != C and B != C:\n",
        "                    structure_list.append(_make_structure(len(structure_list), \"conv\", A, B, C))\n",
        "\n",
        "    # Divergent: A -> B, A -> C\n",
        "    for A, targets in graph.items():\n",
        "        for i in range(len(targets)):\n",
        "            for j in range(i + 1, len(targets)):\n",
        "                B, C = targets[i], targets[j]\n",
        "                if B != C and A != B and A != C:\n",
        "                    structure_list.append(_make_structure(len(structure_list), \"div\", A, B, C))\n",
        "\n",
        "    # Copy instead of mutating, so original_data keeps mirroring the input file.\n",
        "    structured_data.append({**entry, 'structures': structure_list})\n",
        "\n",
        "# Write to new jsonlist file\n",
        "output_file_path = \"\" # OUTPUT FILE HERE\n",
        "with open(output_file_path, 'w', encoding='utf-8') as f:\n",
        "    for entry in structured_data:\n",
        "        f.write(json.dumps(entry) + '\\n')"
      ]
    },
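    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# **Filtering Triplets**\n",
        "From all the extracted triplets, we select a subset for each document such that no two selected triplets share a proposition (at most three triplets per document, keeping the three structure types as balanced as possible)."
      ]
    },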
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### **Load unfiltered file**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "input_file_path = \"\" # YOUR .jsonlist file after extracting all the triplets\n",
        "\n",
        "original_data = []\n",
        "all_structures = {}  # document ID -> {'lin': [...], 'conv': [...], 'div': [...]} of (id1, id2, id3)\n",
        "\n",
        "with open(input_file_path, 'r', encoding='utf-8') as f:\n",
        "    for line in f:\n",
        "        entry = json.loads(line)\n",
        "        original_data.append(entry)\n",
        "\n",
        "        grouped = {'lin': [], 'conv': [], 'div': []}\n",
        "        for item in entry['structures']:\n",
        "            triple = (item['prop1'], item['prop2'], item['prop3'])\n",
        "            if item['structure'] in ('lin', 'conv'):\n",
        "                grouped[item['structure']].append(triple)\n",
        "            else:\n",
        "                assert item['structure'] == 'div'\n",
        "                grouped['div'].append(triple)\n",
        "\n",
        "        # Document ID key: reviewID for AM2, commentID for CDCP; the UKP Essays / AbstRCT\n",
        "        # export cell in this notebook writes it as fileID -- change the key to match your file.\n",
        "        all_structures[entry['commentID']] = grouped\n",
        "\n",
        "count_linear = sum(len(v['lin']) for v in all_structures.values())\n",
        "count_convergent = sum(len(v['conv']) for v in all_structures.values())\n",
        "count_divergent = sum(len(v['div']) for v in all_structures.values())\n",
        "\n",
        "print(count_linear, count_convergent, count_divergent)\n",
        "print(count_linear + count_convergent + count_divergent)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### **Filter and store in a new .jsonlist file**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import random\n",
        "from itertools import combinations\n",
        "\n",
        "N_TRIALS = 1000            # random restarts; the most type-balanced selection is kept\n",
        "MAX_TRIPLETS_PER_DOC = 3   # upper bound on triplets kept per document\n",
        "SEED = 0                   # fixes the shuffles so the filtered output is reproducible\n",
        "rng = random.Random(SEED)\n",
        "\n",
        "\n",
        "def try_add_triplets(candidates, max_to_add, used_ids, selected_triples):\n",
        "    \"\"\"Add the largest conflict-free subset of `candidates` to `selected_triples`.\n",
        "\n",
        "    `candidates` is a list of (structure, (id1, id2, id3)) and is shuffled in place.\n",
        "    Subset sizes are tried from min(max_to_add, len(candidates)) down to 1; the first\n",
        "    subset whose triplets share no proposition ID with each other or with `used_ids`\n",
        "    is appended and its IDs are added to `used_ids`. At most one subset is added.\n",
        "    Returns (number_added, used_ids, selected_triples).\n",
        "    \"\"\"\n",
        "    rng.shuffle(candidates)\n",
        "    for size in range(min(max_to_add, len(candidates)), 0, -1):\n",
        "        for combo in combinations(candidates, size):\n",
        "            combo_ids = set()\n",
        "            conflict = False\n",
        "            for _, triple in combo:\n",
        "                for pid in triple:\n",
        "                    if pid in used_ids or pid in combo_ids:\n",
        "                        conflict = True\n",
        "                        break\n",
        "                    combo_ids.add(pid)\n",
        "                if conflict:\n",
        "                    break\n",
        "            if not conflict:\n",
        "                used_ids.update(combo_ids)\n",
        "                selected_triples.extend(combo)\n",
        "                return size, used_ids, selected_triples\n",
        "    return 0, used_ids, selected_triples\n",
        "\n",
        "\n",
        "squared_error = float('inf')   # lowest imbalance score seen so far\n",
        "final_result_list = {}\n",
        "final_count_linear = 0\n",
        "final_count_convergent = 0\n",
        "final_count_divergent = 0\n",
        "\n",
        "for _trial in range(N_TRIALS):\n",
        "    result_list = {}\n",
        "    type_counter = {'lin': 0, 'conv': 0, 'div': 0}\n",
        "\n",
        "    for comment_id, structure in all_structures.items():\n",
        "        # Fresh candidate lists each trial: try_add_triplets shuffles them in place.\n",
        "        triplets_by_type = {'lin': [], 'conv': [], 'div': []}\n",
        "        for type_of_structure, triples in structure.items():\n",
        "            triplets_by_type[type_of_structure].extend((type_of_structure, t) for t in triples)\n",
        "\n",
        "        used_ids = set()\n",
        "        selected_triples = []\n",
        "\n",
        "        # 'div' candidates are tried first; then 'lin' and 'conv' in random order.\n",
        "        for label_group in (['div'], ['lin', 'conv']):\n",
        "            shuffle_labels = label_group.copy()\n",
        "            rng.shuffle(shuffle_labels)\n",
        "            for label in shuffle_labels:\n",
        "                if len(selected_triples) >= MAX_TRIPLETS_PER_DOC:\n",
        "                    break\n",
        "                remaining = MAX_TRIPLETS_PER_DOC - len(selected_triples)\n",
        "                _, used_ids, selected_triples = try_add_triplets(\n",
        "                    triplets_by_type[label], remaining, used_ids, selected_triples\n",
        "                )\n",
        "\n",
        "        if selected_triples:\n",
        "            result_list[comment_id] = selected_triples\n",
        "            for label, _triple in selected_triples:\n",
        "                type_counter[label] += 1\n",
        "\n",
        "    c_lin = type_counter[\"lin\"]\n",
        "    c_conv = type_counter[\"conv\"]\n",
        "    c_div = type_counter[\"div\"]\n",
        "\n",
        "    # Sum of squared pairwise differences between the type counts (0 = perfectly balanced)\n",
        "    imbalance_score = (c_lin - c_conv) ** 2 + (c_conv - c_div) ** 2 + (c_div - c_lin) ** 2\n",
        "\n",
        "    if imbalance_score < squared_error:\n",
        "        squared_error = imbalance_score\n",
        "        final_result_list = result_list\n",
        "        final_count_linear = c_lin\n",
        "        final_count_convergent = c_conv\n",
        "        final_count_divergent = c_div\n",
        "\n",
        "print(\"linear count: \", final_count_linear)\n",
        "print(\"convergent count: \", final_count_convergent)\n",
        "print(\"divergent count: \", final_count_divergent)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "output_file_path = \"\" # .jsonlist file containing all the triplets, after filtering\n",
        "\n",
        "comment_dict = {entry['commentID']: entry for entry in original_data}  # reviewID for AM^2, commentID for CDCP\n",
        "\n",
        "with open(output_file_path, 'w', encoding='utf-8') as fout:\n",
        "    for comment_id, triples in final_result_list.items():\n",
        "        entry = comment_dict[comment_id]\n",
        "        # Compare IDs as strings: triplet IDs can be ints (the CDCP graph uses int(prop['id']))\n",
        "        # while proposition IDs in the file may be strings, or vice versa.\n",
        "        props = {str(p['id']): p['text'] for p in entry['propositions']}\n",
        "\n",
        "        for idx, (stype, (a, b, c)) in enumerate(triples):\n",
        "            structure = {\n",
        "                \"commentID\": comment_id, # reviewID for AM^2, commentID for CDCP\n",
        "                \"instanceID\": idx,\n",
        "                \"structure\": stype,\n",
        "                # proposition texts (placeholder if an ID has no matching proposition)\n",
        "                \"prop1\": props.get(str(a), f\"[Missing text for ID {a}]\"),\n",
        "                \"prop2\": props.get(str(b), f\"[Missing text for ID {b}]\"),\n",
        "                \"prop3\": props.get(str(c), f\"[Missing text for ID {c}]\"),\n",
        "\n",
        "                # proposition IDs, as stored in the triplets\n",
        "                \"prop1_id\": a,\n",
        "                \"prop2_id\": b,\n",
        "                \"prop3_id\": c\n",
        "            }\n",
        "            json.dump(structure, fout)\n",
        "            fout.write('\\n')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
