from collections import defaultdict
import csv
import math
from rdkit import Chem
from similarity_clustering import cluster

# Target names to look for in BindingDB rows, one per line of targets.txt.
# Blank lines are skipped: an empty name would match every empty TSV cell
# (the match test below is `target in row`) and yield a bogus "" protein.
targets = []
with open("./targets.txt", 'r') as file:
    for line in file:
        name = line.strip()
        if name:
            targets.append(name)
        
#parse all corresponding ligands from BindingDB
def _row_affinity(row):
    """Return the log-scaled affinity for one BindingDB row, or None.

    Columns 8-11 are checked in order and the LAST non-empty one wins
    (presumably Ki / IC50 / Kd / EC50 in nM -- confirm against the TSV
    header).  Censoring marks '<' and '>' are stripped, the number is scaled
    by 1e-9 and mapped to 298 * 0.001987 * ln(value) (presumably RT in
    kcal/mol, i.e. a binding free energy).  Returns None when no column is
    filled, the text is not a number, or the value is zero/negative.
    """
    raw = None
    for col in (8, 9, 10, 11):
        if row[col]:
            raw = row[col]
    if raw is None:
        return None
    try:
        value = float(raw.replace("<", "").replace(">", ""))
        return 298*0.001987*math.log(value*math.pow(10, -9))
    except ValueError:
        # Unparseable text, or log() of a non-positive number.
        return None


# ligands[target][canonical_smiles] -> affinity from _row_affinity
ligands = defaultdict(lambda: defaultdict(float))
with open("./BindingDB_All.tsv", 'r') as file:
    reader = csv.reader(file, delimiter='\t')
    next(reader)  # skip the header row
    for idx, row in enumerate(reader):
        # Progress once per 10k rows (not once per target, as before).
        if idx % 10000 == 0:
            print(f"{idx} / 3,000,000", end='\r')

        # A row belongs to a target when one of its cells equals the name.
        matched = [target for target in targets if target in row]
        if not matched:
            continue

        # Affinity and SMILES depend only on the row, so compute them once.
        affin = _row_affinity(row)
        if affin is None:
            continue
        try:
            mol = Chem.MolFromSmiles(row[1])
            smiles = Chem.MolToSmiles(mol) if mol is not None else None
        except Exception:
            # Best effort: skip ligands RDKit cannot handle.  Unlike a bare
            # `except:`, this still lets Ctrl-C stop the long run.
            continue
        if smiles is None:
            continue

        for target in matched:
            ligands[target][smiles] = affin
            
#cluster by similarity
# clusters[protein] holds whatever similarity_clustering.cluster returns for
# that protein's canonical SMILES (in ligands' insertion order).
clusters = defaultdict(list)
last_protein_idx = len(ligands) - 1
for idx, (protein, affinities) in enumerate(ligands.items()):
    print(f"{protein} {idx} / {last_protein_idx}")
    clusters[protein] = cluster(list(affinities))

#sort within clusters and form into chains
def _chains_from_cluster(sorted_cluster, affinities, gap, affin_threshold):
    """Build strided chains of strictly decreasing affinity from one cluster.

    ``sorted_cluster`` must already be ordered by affinity, largest value
    first.  For each start offset ``0 .. min(gap, len - 1) - 1`` the members
    at ``start, start + gap, start + 2*gap, ...`` are walked.  A member is
    kept only if its affinity is below the previously kept one by more than
    ``affin_threshold``.  When the stride does not land on the cluster's last
    member, that member is tried as the final link.

    Returns only chains with at least two links; each link is a
    ``(smiles, affinity)`` tuple.
    """
    last = len(sorted_cluster) - 1
    chains = []
    # Never start a chain at the last member: it has no successor.
    for start in range(min(gap, last)):
        candidates = sorted_cluster[start::gap]  # slice -> fresh list
        if (last - start) % gap != 0:
            # The stride skipped the final member; offer it as the terminus.
            candidates.append(sorted_cluster[-1])
        chain = []
        for ligand in candidates:
            affin = affinities[ligand]
            if not chain or affin < chain[-1][1] - affin_threshold:
                chain.append((ligand, affin))
        if len(chain) > 1:
            chains.append(chain)
    return chains


total_data = 0  # total number of consecutive (link, next link) pairs written
gap = 3
affin_threshold = 0.01
for idx, protein in enumerate(list(clusters.keys())):
    print(f"{protein} {idx} / {len(list(clusters.keys()))}")
    affinities = ligands[protein]
    chains = []
    # `members` (not `cluster`) so the imported cluster() is not shadowed.
    for members in clusters[protein]:
        if len(members) <= 1:
            continue  # a singleton cannot form a chain
        ordered = sorted(members, key=affinities.get, reverse=True)
        for chain in _chains_from_cluster(ordered, affinities, gap, affin_threshold):
            chains.append(chain)
            total_data += len(chain) - 1

    # NOTE(review): the protein name is used verbatim as a file name and
    # ./chains must already exist -- confirm target names contain no '/'.
    with open(f"./chains/{protein}.txt", 'w') as file:
        for chain in chains:
            file.write("Chain start:\n")
            for smiles, affin in chain:
                file.write(f"{smiles} {affin}\n")
print(total_data)