{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c8eeaf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import re\n",
    "from pathlib import Path\n",
    "\n",
    "from foldingdiff.plotting import *\n",
    "from foldingdiff.tokenizer import *\n",
    "from foldingdiff.bpe import *\n",
    "\n",
    "# Work from the parent of the notebook's starting directory (presumably the\n",
    "# repo root, since later cells use relative paths such as ckpts/ and plots/).\n",
    "# The starting directory is remembered on the first run so that re-running\n",
    "# this cell does not keep climbing one level higher each time.\n",
    "if '_NOTEBOOK_DIR' not in globals():\n",
    "    _NOTEBOOK_DIR = Path.cwd()\n",
    "os.chdir(_NOTEBOOK_DIR.parents[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b0720e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint directory of the BPE run inspected below, relative to the working\n",
    "# directory chosen in the first cell; the name is presumably a run timestamp.\n",
    "base_dir = \"ckpts/1744875790.3072364\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb9c74f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install --upgrade google-api-python-client google-auth-httplib2 google-auth-oauthlib\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e4acad1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualize_bonds(self, i1, length, output_path, xlim=None, ylim=None, zlim=None):\n",
    "    \"\"\"Plot `length` consecutive bonds of tokenizer `self`, starting at bond `i1`.\n",
    "\n",
    "    Written as a free function that receives the tokenizer explicitly as `self`.\n",
    "    Tokens whose bond key lies in [i1, i1 + length) are passed to `plot_backbone`\n",
    "    with their first field shifted by -i1, i.e. relative to the window start.\n",
    "    Returns `plot_backbone`'s result unchanged; this notebook reuses it as the\n",
    "    (xlim, ylim, zlim) axis limits for later calls.\n",
    "    \"\"\"\n",
    "    def shift(token, delta):\n",
    "        return (token[0] - delta, token[1], token[2])\n",
    "\n",
    "    coords = self.compute_coords(i1, length)\n",
    "    # Atom types cycle with period 3, so position k of the window has type\n",
    "    # ATOM_TYPES[(i1 + k) % 3]; length + 1 entries (presumably one per atom).\n",
    "    atom_types = [Tokenizer.ATOM_TYPES[(i1 + k) % 3] for k in range(length + 1)]\n",
    "    tokens = [\n",
    "        shift(self.bond_to_token[bond], i1)\n",
    "        for bond in sorted(self.bond_to_token)\n",
    "        if i1 <= bond < i1 + length\n",
    "    ]\n",
    "    title = f\"{Path(self.fname).stem} bonds {i1}-{i1+length-1}\"\n",
    "    return plot_backbone(\n",
    "        coords, output_path, atom_types, title=title, vis_dihedral=False,\n",
    "        zoom_factor=0.5, tokens=tokens, xlim=xlim, ylim=ylim, zlim=zlim,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c61db45e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "from google.oauth2 import service_account\n",
    "from googleapiclient.discovery import build\n",
    "from googleapiclient.http import MediaFileUpload\n",
    "\n",
    "# 1) Load credentials. The key file defaults to service-account.json in the\n",
    "# working directory; set DRIVE_SERVICE_ACCOUNT_FILE to keep it elsewhere\n",
    "# (e.g. outside the repository).\n",
    "SCOPES = ['https://www.googleapis.com/auth/drive.file']\n",
    "SERVICE_ACCOUNT_FILE = os.environ.get('DRIVE_SERVICE_ACCOUNT_FILE', 'service-account.json')\n",
    "creds = service_account.Credentials.from_service_account_file(\n",
    "    SERVICE_ACCOUNT_FILE, scopes=SCOPES)\n",
    "\n",
    "# 2) Build the Drive service (used by the optional upload step further down).\n",
    "service = build('drive', 'v3', credentials=creds)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9254047",
   "metadata": {},
   "outputs": [],
   "source": [
    "from foldingdiff.datasets import FullCathCanonicalCoordsDataset\n",
    "\n",
    "# NOTE(review): toy=10 / debug=True presumably restrict this to a small subset\n",
    "# for quick iteration; confirm against FullCathCanonicalCoordsDataset.\n",
    "dataset = FullCathCanonicalCoordsDataset(\n",
    "    pdbs=\"cath\",\n",
    "    use_cache=False,\n",
    "    toy=10,\n",
    "    debug=True,\n",
    "    zero_center=False,\n",
    "    pad=512,\n",
    "    secondary=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d6e6572",
   "metadata": {},
   "outputs": [],
   "source": [
    "def str2dict(v):\n",
    "    \"\"\"Parse colon-separated ``int-int`` pairs into a dict.\n",
    "\n",
    "    Example: ``\"1-100:2-20:5-10\"`` -> ``{1: 100, 2: 20, 5: 10}``.\n",
    "    If a key repeats, the last pair wins.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``v`` is not one or more ``digits-digits`` pairs joined by ':'.\n",
    "    \"\"\"\n",
    "    # Raw strings: a plain '\\d' literal is an invalid escape sequence.\n",
    "    if not re.fullmatch(r'\\d+-\\d+(?::\\d+-\\d+)*', v):\n",
    "        # Raise a real exception; a bare `raise` outside an except block\n",
    "        # would itself fail with RuntimeError.\n",
    "        raise ValueError(f\"expected 'a-b' pairs separated by ':', got {v!r}\")\n",
    "    return {int(a): int(b) for a, b in re.findall(r'(\\d+)-(\\d+)', v)}\n",
    "\n",
    "\n",
    "# Build one BPE model per binning strategy and run its initialization, which is\n",
    "# given a .png path (presumably for a histogram plot of the bins).\n",
    "bin_str = \"1-100:2-20:5-10\"\n",
    "os.makedirs('plots', exist_ok=True)\n",
    "for strat in ['uniform', 'histogram']:\n",
    "    bpe = BPE(dataset.structures, str2dict(bin_str), bin_strategy=strat)\n",
    "    # Include the strategy in the file name; otherwise the 'histogram' run\n",
    "    # overwrites the 'uniform' plot at the same path.\n",
    "    bpe.initialize(path=f'plots/hist_{strat}_{bin_str}.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa5f7cfc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render `step`-sized windows of tokenizer `i`'s backbone for each BPE checkpoint.\n",
    "# Axis limits from the first checkpoint are stored in `lims` and reused for later\n",
    "# checkpoints so that the frames share the same axes.\n",
    "i = 3  # index into bpe.tokenizers\n",
    "step = 100  # stride over bond indices; window edges are remapped via t.token_pos\n",
    "# Full sweep: [0] + list(range(10, 100, 10)) + list(range(1000, 10000, 1000))\n",
    "ITERS = [0]\n",
    "MAX_WINDOWS = 1  # windows rendered per checkpoint; None renders all of them\n",
    "UPLOAD_TO_DRIVE = False  # True requires the Drive `service` cell to have run\n",
    "lims = []\n",
    "for _iter in ITERS:\n",
    "    # pickle.load can execute arbitrary code: only load checkpoints you produced.\n",
    "    with open(f'{base_dir}/bpe_iter={_iter}.pkl', 'rb') as fh:\n",
    "        bpe = pickle.load(fh)\n",
    "    t = bpe.tokenizers[i]\n",
    "    n_bonds = 3 * t.n - 1\n",
    "    window_starts = range(0, n_bonds, step)\n",
    "    if not lims:\n",
    "        lims = [None] * len(window_starts)\n",
    "    for idx, start in enumerate(window_starts):\n",
    "        if MAX_WINDOWS is not None and idx >= MAX_WINDOWS:\n",
    "            break\n",
    "        end = start + step\n",
    "        start = t.token_pos[start]\n",
    "        # NOTE(review): end == 3*t.n - 1 still indexes token_pos; confirm that\n",
    "        # token_pos has an entry at that position.\n",
    "        end = t.token_pos[end] if end < 3 * t.n else n_bonds\n",
    "        length = end - start\n",
    "        path = os.path.join(base_dir, f'{i}_iter={_iter}_{start}-{end}.png')\n",
    "        if lims[idx] is None:\n",
    "            lims[idx] = tuple(visualize_bonds(t, start, length, path))\n",
    "        else:\n",
    "            visualize_bonds(t, start, length, path,\n",
    "                            xlim=lims[idx][0], ylim=lims[idx][1], zlim=lims[idx][2])\n",
    "        print(path)\n",
    "        if UPLOAD_TO_DRIVE:\n",
    "            file_metadata = {\n",
    "                'name': Path(path).name,\n",
    "                'parents': ['1NOxavUomer-WMYlaUG9olQBjEX9r7cx8'],  # optional: target folder\n",
    "            }\n",
    "            media = MediaFileUpload(path)\n",
    "            uploaded = service.files().create(\n",
    "                body=file_metadata, media_body=media, fields='id'\n",
    "            ).execute()\n",
    "            print('Uploaded file ID:', uploaded.get('id'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f788a54b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save a tree visualization of every tokenizer's bond_to_token; `bpe` and `_iter`\n",
    "# are left over from the checkpoint-rendering loop above. Iterate over the\n",
    "# tokenizers that actually exist instead of a hard-coded range(100000), which\n",
    "# fails (e.g. IndexError on a list) once it runs past the last one.\n",
    "for i in range(len(bpe.tokenizers)):\n",
    "    out_path = os.path.join(base_dir, f'tokens_{i}_iter={_iter}.png')\n",
    "    bpe.tokenizers[i].bond_to_token.tree.visualize(out_path, horizontal_gap=0.5, font_size=6)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
