{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "with open('/home/1739678401/2024.9dataclean/MMGNN_0901/data/1-s2.0-S1385894721008925-mmc1.txt', 'r') as txt_file:\n",
    "    lines = txt_file.readlines()\n",
    "\n",
    "csv_data = []\n",
    "for line in lines:\n",
    "    parts = line.strip().split(',')\n",
    "    csv_data.append(parts)\n",
    "\n",
    "with open('/home/1739678401/2024.9dataclean/MMGNN_0901/data/output.csv', 'w', newline='') as csv_file:\n",
    "    writer = csv.writer(csv_file)\n",
    "    writer.writerows(csv_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "data = pd.read_csv('/home/1739678401/2024.9dataclean/MMGNN_0901/data/output.csv')\n",
    "all_smiles = data.iloc[:, 0].tolist() + data.iloc[:, 1].tolist()\n",
    "unique_smiles = list(set(all_smiles))\n",
    "\n",
    "print(len(unique_smiles))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_data = pd.DataFrame({'SMILES': unique_smiles})\n",
    "new_data.to_csv('unique_smiles.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import hnswlib\n",
    "from numpy import linalg as LA\n",
    "\n",
    "def cosine_sim(x,y):\n",
    "    return np.dot(x,y)/LA.norm(x)/LA.norm(y)\n",
    "\n",
    "class mol:\n",
    "    def __init__(self,smiles0,smiles1,feat):\n",
    "        self.smiles0=smiles0\n",
    "        self.smiles1=smiles1\n",
    "        self.feat=feat\n",
    "\n",
    "class infogain:\n",
    "    def __init__(self, initial_points=None, n=3000000, keep_seed=False, submodular_k=4):\n",
    "        self.clusters = None\n",
    "        self.target_n = n\n",
    "        self.current_n = 0\n",
    "        self.keep_seed = keep_seed\n",
    "        self.dim = 128*2\n",
    "        self.submodular_k = submodular_k\n",
    "\n",
    "        if initial_points: \n",
    "            self.data = initial_points.copy()\n",
    "            self.current_n = len(initial_points)\n",
    "            if keep_seed:\n",
    "                self.seeded_num = len(initial_points)\n",
    "            self.dim = len(initial_points[0].feat)\n",
    "        else:\n",
    "            self.data = []\n",
    "        self.submodular_gain = [(1)]*len(self.data)\n",
    "        # initialize HNSW index\n",
    "        self.knn_graph = hnswlib.Index(space='cosine', dim=self.dim)\n",
    "        self.knn_graph.init_index(max_elements=n, ef_construction=100, M=48, allow_replace_deleted = False)\n",
    "        self.precluster(initial_points)\n",
    "        self.knn_graph.set_ef(32)\n",
    "\n",
    "    def precluster(self, initial_points):\n",
    "    # Starting from some initial points (the cleaner the better) to do online selection\n",
    "        if initial_points is None or initial_points==[]: return\n",
    "        for idx,data in enumerate(self.data):\n",
    "            data.index = idx\n",
    "\n",
    "        for idx,data in enumerate(self.data):\n",
    "            self.submodular_gain[idx] = self.submodular_func(data, True)\n",
    "            self.knn_graph.add_items(data.feat, idx)\n",
    "\n",
    "    def submodular_func(self, data, skip_one=False):\n",
    "        if self.knn_graph.get_current_count()==0:\n",
    "            return (1.)\n",
    "        k = min(self.knn_graph.get_current_count(), self.submodular_k)\n",
    "        near_label,near_distances = self.knn_graph.knn_query(data.feat, k)\n",
    "        return np.mean(near_distances)\n",
    "\n",
    "    def add_item(self, data):\n",
    "        data.index = self.current_n\n",
    "        self.data.append(data)\n",
    "        self.knn_graph.add_items(data.feat, self.current_n)\n",
    "        self.current_n+=1\n",
    "\n",
    "    def replace_item(self, data, index):\n",
    "        # Not used in current work but provide for future extension on replacing samples\n",
    "        data_to_rep = self.data[index]\n",
    "        n_index = data_to_rep.index\n",
    "        data.index = self.current_n\n",
    "        self.knn_graph.mark_deleted(n_index)\n",
    "        self.knn_graph.add_items(data.I_feat, self.current_n, replace_deleted = True)\n",
    "        self.data[index] = data\n",
    "        self.current_n+=1\n",
    "\n",
    "    def process_item(self, data,):\n",
    "        # find near clusters\n",
    "        # go into nearest clusters to search near neighbour\n",
    "        # calculate corresponding threshold to decide if try to add or not\n",
    "        gain = self.submodular_func(data)\n",
    "        self.add_item(data)\n",
    "        self.submodular_gain.append(gain)\n",
    "\n",
    "    def final_gains(self):\n",
    "        return self.submodular_gain\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "dic = torch.load('/home/1739678401/2024.9dataclean/MMGNN_0901/temp_data/uq_smiles.pt')\n",
    "\n",
    "import pandas as pd\n",
    "def file_to_mol(filepath):\n",
    "    df = pd.read_csv(filepath, usecols=[0, 1])\n",
    "    res = []\n",
    "    for i, row in df.iterrows():\n",
    "        tensor1, _ = dic[row[0]]\n",
    "        tensor2, _ = dic[row[1]]\n",
    "        feat = torch.cat((tensor1, tensor2), dim=0).to(dtype=torch.float32).cpu().detach().numpy().reshape(256)\n",
    "        #print(feat.shape)\n",
    "        res.append(mol(row[0],row[1],feat))\n",
    "    return res\n",
    "\n",
    "filepath='/home/1739678401/2024.9dataclean/MMGNN_0901/temp_data/1M.csv'\n",
    "res=file_to_mol(filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1000000\n"
     ]
    }
   ],
   "source": [
    "IG=infogain(res)\n",
    "mn_finalgains=IG.final_gains()\n",
    "print(len(mn_finalgains))\n",
    "import numpy as np\n",
    "gains = np.array([ (gain) for gain in mn_finalgains])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "thr = np.quantile(gains,0.99)\n",
    "gains=gains/thr\n",
    "gains[gains>1]=1\n",
    "gains[gains<0]=0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_csv('/home/1739678401/2024.9dataclean/MMGNN_0901/temp_data/1M.csv')\n",
    "\n",
    "#selected_mask = np.random.rand(len(data)) < gains-0.05\n",
    "selected_data = data[::100]\n",
    "print(len(selected_data))\n",
    "selected_data.to_csv('/home/1739678401/2024.9dataclean/MMGNN_0901/temp_data/theory_random.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "input_file = '/home/1739678401/2024.9dataclean/MMGNN_0901/temp_data/theory_random.csv'\n",
    "output_file = '/home/1739678401/2024.9dataclean/MMGNN_0901/data/theory_random.csv'\n",
    "\n",
    "with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8', newline='') as outfile:\n",
    "    reader = csv.reader(infile)\n",
    "    writer = csv.writer(outfile)\n",
    "    for row in reader:\n",
    "        new_row = row[1:4]  # 提取第二到第四列的数据\n",
    "        writer.writerow(new_row)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dataclean",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
