{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9c7bea0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "from pathlib import Path\n",
    "import csv\n",
    "import argparse\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, adjusted_rand_score, mutual_info_score\n",
    "import math\n",
    "from collections import Counter\n",
    "os.chdir(Path.cwd().parents[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89616bb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "bpe = pickle.load(open('./ckpts/1746804072.8772147/bpe_iter=9990.pkl', 'rb'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70ec4b30",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_and_write(infile, outfile):\n",
    "    records = []\n",
    "    with open(infile) as f:\n",
    "        for line in f:\n",
    "            if line.startswith('#'):\n",
    "                continue\n",
    "            parts = line.strip().split()\n",
    "            if not parts:\n",
    "                continue\n",
    "            # first 22 columns\n",
    "            (\n",
    "                target_name, target_acc, tlen,\n",
    "                query_name, query_acc, qlen,\n",
    "                e_value, score, bias,\n",
    "                dom_num, dom_of,\n",
    "                dom_c_evalue, dom_i_evalue,\n",
    "                dom_score, dom_bias,\n",
    "                dom_from, dom_to,\n",
    "                ali_from, ali_to,\n",
    "                env_from, env_to,\n",
    "                acc\n",
    "            ) = parts[:22]\n",
    "            description = ' '.join(parts[22:])\n",
    "\n",
    "            records.append({\n",
    "                'target_name':      target_name,\n",
    "                'target_accession': target_acc,\n",
    "                'tlen':             int(tlen),\n",
    "                'query_name':       query_name,\n",
    "                'query_accession':  query_acc,\n",
    "                'qlen':             int(qlen),\n",
    "                'E_value':          float(e_value),\n",
    "                'score':            float(score),\n",
    "                'bias':             float(bias),\n",
    "                'domain_num':       int(dom_num),\n",
    "                'domain_of':        int(dom_of),\n",
    "                'dom_c_Evalue':     float(dom_c_evalue),\n",
    "                'dom_i_Evalue':     float(dom_i_evalue),\n",
    "                'dom_score':        float(dom_score),\n",
    "                'dom_bias':         float(dom_bias),\n",
    "                'dom_from':         int(dom_from),\n",
    "                'dom_to':           int(dom_to),\n",
    "                'ali_from':         int(ali_from),\n",
    "                'ali_to':           int(ali_to),\n",
    "                'env_from':         int(env_from),\n",
    "                'env_to':           int(env_to),\n",
    "                'acc':              float(acc),\n",
    "                'description':      description\n",
    "            })\n",
    "\n",
    "    fieldnames = [\n",
    "        'target_name','target_accession','tlen',\n",
    "        'query_name','query_accession','qlen',\n",
    "        'E_value','score','bias',\n",
    "        'domain_num','domain_of',\n",
    "        'dom_c_Evalue','dom_i_Evalue','dom_score','dom_bias',\n",
    "        'dom_from','dom_to',\n",
    "        'ali_from','ali_to',\n",
    "        'env_from','env_to',\n",
    "        'acc','description'\n",
    "    ]\n",
    "    with open(outfile, 'w', newline='') as out:\n",
    "        writer = csv.DictWriter(out, fieldnames=fieldnames)\n",
    "        writer.writeheader()\n",
    "        writer.writerows(records)\n",
    "\n",
    "\n",
    "def parse_crh(inpath, outpath):\n",
    "    with open(inpath, 'r') as f:\n",
    "        field_line = None\n",
    "        for line in f:\n",
    "            if line.startswith('#FIELDS'):\n",
    "                # e.g. \"#FIELDS query-id match-id score boundaries resolved cond-evalue indp-evalue\"\n",
    "                field_line = line.lstrip('#FIELDS').strip().split()\n",
    "                break\n",
    "        if field_line is None:\n",
    "            raise RuntimeError(\"No #FIELDS line found in input\")\n",
    "\n",
    "        # We'll expand these two into _from/_to columns\n",
    "        expand_ranges = ['boundaries', 'resolved']\n",
    "\n",
    "        # Build output fieldnames: for each in field_line:\n",
    "        #  - if in expand_ranges, replace with two fields X_from, X_to\n",
    "        #  - else use the original name (normalized)\n",
    "        out_fields = []\n",
    "        for fn in field_line:\n",
    "            if fn in expand_ranges:\n",
    "                out_fields += [f\"{fn}_from\", f\"{fn}_to\"]\n",
    "            else:\n",
    "                # normalize hyphens to underscores\n",
    "                out_fields.append(fn.replace('-', '_'))\n",
    "\n",
    "        # Append numeric conversions for clarity (score, cond_evalue, indp_evalue)\n",
    "        # They already appear in out_fields as strings; converting happens row-wise.\n",
    "\n",
    "        # Rewind to beginning for actual parsing\n",
    "        f.seek(0)\n",
    "\n",
    "        records = []\n",
    "        for line in f:\n",
    "            if line.startswith('#'):\n",
    "                continue\n",
    "            parts = line.strip().split()\n",
    "            if not parts:\n",
    "                continue\n",
    "            if len(parts) != len(field_line):\n",
    "                raise RuntimeError(f\"Line has {len(parts)} cols but expected {len(field_line)}: {line}\")\n",
    "\n",
    "            row = dict(zip(field_line, parts))\n",
    "\n",
    "            out = {}\n",
    "            for fn in field_line:\n",
    "                val = row[fn]\n",
    "                if fn in expand_ranges:\n",
    "                    start, end = val.split('-', 1)\n",
    "                    out[f\"{fn}_from\"] = int(start)\n",
    "                    out[f\"{fn}_to\"]   = int(end)\n",
    "                elif fn == 'score':\n",
    "                    out['score'] = float(val)\n",
    "                elif fn == 'cond-evalue':\n",
    "                    out['cond_evalue'] = float(val)\n",
    "                elif fn == 'indp-evalue':\n",
    "                    out['indp_evalue'] = float(val)\n",
    "                else:\n",
    "                    out[fn.replace('-', '_')] = val\n",
    "\n",
    "            records.append(out)\n",
    "\n",
    "    # Write CSV\n",
    "    with open(outpath, 'w', newline='') as csvf:\n",
    "        writer = csv.DictWriter(csvf, fieldnames=out_fields)\n",
    "        writer.writeheader()\n",
    "        writer.writerows(records)\n",
    "\n",
    "\n",
    "def _entropy(labels):\n",
    "    counts = Counter(labels)\n",
    "    total = len(labels)\n",
    "    return -sum((count/total) * math.log(count/total) for count in counts.values())\n",
    "\n",
    "def _convert_true_intervals(true_domains):\n",
    "    \"\"\"\n",
    "    Convert true domain matches from 1-based inclusive to 0-based half-open intervals.\n",
    "    true_domains: list of (from_residue, to_residue) inclusive, 1-based.\n",
    "    Returns: list of (start_idx, end_idx) where start_idx inclusive, end_idx exclusive.\n",
    "    \"\"\"\n",
    "    return [(f - 1, t) for f, t in true_domains]\n",
    "\n",
    "def convert_true_labels(true_domains, seq_len):\n",
    "    \"\"\"\n",
    "    Label each residue by true domain index (1..N) or 0 for background.\n",
    "    true_domains: list of (from_residue, to_residue) inclusive, 1-based.\n",
    "    \"\"\"\n",
    "    labels = np.zeros(seq_len, dtype=int)\n",
    "    for idx, (f, t) in enumerate(true_domains):\n",
    "        start, end = f - 1, t  # convert to half-open\n",
    "        labels[start:end] = idx + 1\n",
    "    return labels\n",
    "\n",
    "def convert_pred_labels(pred_segs, seq_len):\n",
    "    \"\"\"\n",
    "    Label each residue by predicted segment index (1..M).\n",
    "    pred_segs: list of (start_idx, end_idx) half-open, 0-based.\n",
    "    \"\"\"\n",
    "    labels = np.zeros(seq_len, dtype=int)\n",
    "    for idx, (start, end) in enumerate(pred_segs):\n",
    "        labels[start:end] = idx + 1\n",
    "    return labels\n",
    "\n",
    "def domain_coverage(true_domains, pred_segs, thresholds=(0.5, 0.8)):\n",
    "    \"\"\"\n",
    "    Compute coverage per true domain:\n",
    "    For each true domain (1-based inclusive), convert to 0-based half-open,\n",
    "    find the predicted segment with maximal overlap, then coverage = overlap / true_length.\n",
    "    \"\"\"\n",
    "    true_intervals = _convert_true_intervals(true_domains)\n",
    "    coverages = []\n",
    "    for f, t in true_intervals:\n",
    "        true_len = t - f  # correct length for half-open\n",
    "        best_overlap = 0\n",
    "        for p, q in pred_segs:\n",
    "            overlap = max(0, min(t, q) - max(f, p))\n",
    "            best_overlap = max(best_overlap, overlap)\n",
    "        coverages.append(best_overlap / true_len)\n",
    "    mean_cov = np.mean(coverages)\n",
    "    recall_at = {f\"recall@{int(th*100)}\": np.mean([c >= th for c in coverages]) for th in thresholds}\n",
    "    return {\"mean_coverage\": mean_cov, **recall_at}\n",
    "\n",
    "def iou_metrics(true_domains, pred_segs, iou_threshold=0.5):\n",
    "    \"\"\"\n",
    "    Compute IoU-based Precision/Recall/F1:\n",
    "    Match each true domain to predicted segment with highest IoU.\n",
    "    \"\"\"\n",
    "    true_intervals = _convert_true_intervals(true_domains)\n",
    "    matched_true = 0\n",
    "    matched_pred = set()\n",
    "    for f, t in true_intervals:\n",
    "        best_iou = 0\n",
    "        best_j = None\n",
    "        for j, (p, q) in enumerate(pred_segs):\n",
    "            inter = max(0, min(t, q) - max(f, p))\n",
    "            union = (t - f) + (q - p) - inter\n",
    "            iou = inter / union if union > 0 else 0\n",
    "            if iou > best_iou:\n",
    "                best_iou = iou\n",
    "                best_j = j\n",
    "        if best_iou >= iou_threshold:\n",
    "            matched_true += 1\n",
    "            matched_pred.add(best_j)\n",
    "    precision = len(matched_pred) / len(pred_segs) if pred_segs else 0\n",
    "    recall = matched_true / len(true_domains) if true_domains else 0\n",
    "    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0\n",
    "    return {\"precision\": precision, \"recall\": recall, \"f1\": f1}\n",
    "\n",
    "def per_residue_metrics(true_domains, pred_segs, seq_len):\n",
    "    \"\"\"\n",
    "    Compare partitions as clusterings at residue level:\n",
    "    - ARI and VI require labeling each residue by its domain/segment index.\n",
    "    \"\"\"\n",
    "    # Generate label arrays\n",
    "    true_labels = convert_true_labels(true_domains, seq_len)\n",
    "    pred_labels = convert_pred_labels(pred_segs, seq_len)\n",
    "\n",
    "    # Overall clustering agreement\n",
    "    rand_idx = adjusted_rand_score(true_labels, pred_labels)\n",
    "    mi = mutual_info_score(true_labels, pred_labels)\n",
    "    H_true = _entropy(true_labels)\n",
    "    H_pred = _entropy(pred_labels)\n",
    "    vi = H_true + H_pred - 2 * mi\n",
    "\n",
    "    return {\n",
    "        \"rand_index\": rand_idx,\n",
    "        \"variation_of_information\": vi\n",
    "    }\n",
    "\n",
    "def boundary_metrics(true_domains, pred_segs, delta=0):\n",
    "    \"\"\"\n",
    "    Boundary Precision/Recall/F1 with tolerance delta:\n",
    "    Convert true to half-open for boundary positions, compare starts and ends.\n",
    "    \"\"\"\n",
    "    true_intervals = _convert_true_intervals(true_domains)\n",
    "    true_bounds = set()\n",
    "    for f, t in true_intervals:\n",
    "        true_bounds.update([f, t])\n",
    "    pred_bounds = set()\n",
    "    for p, q in pred_segs:\n",
    "        pred_bounds.update([p, q])\n",
    "\n",
    "    matched_true = sum(any(abs(tb - pb) <= delta for pb in pred_bounds) for tb in true_bounds)\n",
    "    matched_pred = sum(any(abs(pb - tb) <= delta for tb in true_bounds) for pb in pred_bounds)\n",
    "    precision = matched_pred / len(pred_bounds) if pred_bounds else 0\n",
    "    recall = matched_true / len(true_bounds) if true_bounds else 0\n",
    "    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0\n",
    "    return {\"precision\": precision, \"recall\": recall, \"f1\": f1}\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7881815",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_metrics = []\n",
    "for i in range(len(bpe.tokenizers)):\n",
    "    t = bpe.tokenizers[i]\n",
    "    p = Path(t.fname)\n",
    "    r = p.relative_to(os.getcwd())\n",
    "    n = Path(p.name)\n",
    "    out = Path(os.path.join('./scripts/', r, n.with_suffix('.domtblout')))\n",
    "    if os.path.exists(out):\n",
    "        csv_out = out.with_suffix(\".csv\")\n",
    "        try:\n",
    "            parse_and_write(out, csv_out)\n",
    "        except:\n",
    "            print(out)\n",
    "            continue\n",
    "        # parse_crh(out, csv_out)\n",
    "        df = pd.read_csv(csv_out)\n",
    "        seq_len = len(t.aa)\n",
    "        pred_segs = []\n",
    "        for (start, _, l) in t.bond_to_token.values():\n",
    "            if l % 3 != 0:\n",
    "                assert start + l == 3*t.n-1\n",
    "                assert l % 3 == 2\n",
    "                l += 1\n",
    "            pred_segs.append((start//3, start//3+l//3))\n",
    "        true_domains = []\n",
    "        for (f, to) in df[['ali_from', 'ali_to']].values:\n",
    "            true_domains.append((f, to))\n",
    "        if len(true_domains) == 0:\n",
    "            continue\n",
    "        metrics = {\"name\": p.name, \n",
    "                   \"domain_coverage\": domain_coverage(true_domains, pred_segs),\n",
    "                   \"iou\": iou_metrics(true_domains, pred_segs), \n",
    "                   \"boundary\": boundary_metrics(true_domains, pred_segs, delta=0),\n",
    "                   \"n\": len(true_domains)\n",
    "        }\n",
    "        for m in [\"domain_coverage\", \"iou\", \"boundary\"]:\n",
    "            if not isinstance(metrics[m], dict):\n",
    "                continue\n",
    "            for k in metrics[m]:\n",
    "                metrics[m+\"_\"+k] = metrics[m][k]\n",
    "            metrics.pop(m)\n",
    "        all_metrics.append(metrics)\n",
    "\n",
    "df = pd.DataFrame(all_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4cf6dd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "for c in df:\n",
    "    metric_col = False\n",
    "    for n in [\"domain_coverage\", \"iou\", \"boundary\"]:\n",
    "        if c[:len(n)] == n:\n",
    "            metric_col = True\n",
    "    if metric_col:\n",
    "        print(c, (df[c]*df['n']).sum()/df['n'].sum())"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
