import csv
import math
import pickle
from tqdm import tqdm
import numpy as np
from rdkit.Chem import MolFromSmiles, MolToSmiles
from sklearn.model_selection import train_test_split
import collections
from SmilesPE.pretokenizer import atomwise_tokenizer
import sys


def featurize_mol(smiles):
    """Encode a SMILES string as a fixed-length integer token array.

    The sequence is the leading marker index 1, followed by one index per
    atom-wise token. Tokens are looked up in the module-level
    ``token_to_idx``, and tokens missing from the vocabulary map to 2. The
    sequence is then right-padded with 0 up to the module-level ``max_len``.

    ``max_len`` is measured on the training molecules only, so a molecule
    from another split can tokenize to a longer sequence. Such sequences are
    truncated to ``max_len`` so that every row has the same length. Without
    this, stacking the rows with ``np.array`` would fail or produce a ragged
    object array.

    Returns:
        A 1-D integer ``np.ndarray`` of length ``max_len``.
    """
    fs = [1] + [token_to_idx.get(token, 2) for token in atomwise_tokenizer(smiles)]
    # Truncate over-long sequences; pad short ones with 0 up to max_len.
    fs = fs[:max_len]
    fs.extend([0] * (max_len - len(fs)))
    return np.array(fs)


# Load (canonical SMILES, log10 affinity) pairs and their upper-cased target
# sequences from the tab-separated file given as the first CLI argument.
# Columns read: 1 = SMILES, 8 / 10 = affinity value (10 preferred when both
# are non-empty), 37 = target sequence.
# NOTE(review): column meanings are inferred from how they are used here;
# confirm against the input file's header.
mols = []
seqs = []
with open(sys.argv[1], 'r') as f:
    next(f)  # skip the header line
    for row in tqdm(csv.reader(f, delimiter='\t')):
        if not (row[8] or row[10]):
            continue
        # Length filter on the SMILES string, ignoring the characters
        # '()=@[]' and the ring-closure digits 1-9.
        core_len = len([char for char in row[1] if char not in '()=@[]123456789'])
        if not 10 < core_len < 70:
            continue
        if row[37] == 'NULL':
            continue
        # Parse once; the same Mol is reused for canonicalisation below.
        mol = MolFromSmiles(row[1])
        if not mol:
            continue
        # Drop censoring markers such as '<' / '>' before converting.
        val = (row[10] if row[10] else row[8]).replace('<', '').replace('>', '').strip()
        seqs.append(row[37].upper())
        # The 1e-10 offset keeps a zero value from hitting log10(0).
        mols.append((MolToSmiles(mol), math.log10(float(val) + 1e-10)))
# Drop every row of a target sequence whose measured log10 affinities do not
# cover the whole range. At least one value must fall strictly inside each
# bin below. Only sequences seen more than 10 times are checked here.
allowed_seqs = [seq for seq, count in collections.Counter(seqs).most_common() if count > 10]
affinity_bins = ((-50, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 50))
# Group values by sequence in a single pass. This avoids rescanning every row
# per sequence and deleting list items one at a time, which is quadratic.
vals_by_seq = collections.defaultdict(list)
for seq, (_, val) in zip(seqs, mols):
    vals_by_seq[seq].append(val)
rejected_seqs = set()
for seq in tqdm(allowed_seqs):
    vals = vals_by_seq[seq]
    # Bin boundaries are exclusive: a value of exactly 0, 1, ..., 5 counts
    # toward no bin.
    if not all(any(lo < val < hi for val in vals) for lo, hi in affinity_bins):
        rejected_seqs.add(seq)
# Rebuild both parallel lists, preserving the original row order.
kept = [(mol, seq) for mol, seq in zip(mols, seqs) if seq not in rejected_seqs]
mols = [mol for mol, _ in kept]
seqs = [seq for _, seq in kept]
# Split by target sequence rather than by row, so no target appears in both
# sets. test_size=40 puts 40 sequences in the test set and the rest in
# training. Only sequences with more than 10 remaining rows take part.
# NOTE(review): no random_state is passed, so the split differs on every run.
allowed_seqs = [seq for seq, count in collections.Counter(seqs).most_common() if count > 10]
training_seqs, testing_seqs = train_test_split(allowed_seqs, test_size=40)
# Sets give O(1) membership checks for the per-row filtering below.
training_set = set(training_seqs)
testing_set = set(testing_seqs)
train_mols, train_seqs = zip(*[(mol, seq) for mol, seq in zip(mols, seqs) if seq in training_set])
test_mols, test_seqs = zip(*[(mol, seq) for mol, seq in zip(mols, seqs) if seq in testing_set])
# Build the token vocabulary from the training molecules only. featurize_mol
# uses 0 for padding, 1 as the leading marker and 2 for unknown tokens.
# Learned tokens are therefore numbered from 3, in first-seen order.
token_to_idx = {}
max_len = 0
for smile, _ in train_mols:
    tokens = atomwise_tokenizer(smile)
    # +1 reserves a slot for the leading marker featurize_mol prepends.
    max_len = max(max_len, len(tokens) + 1)
    for token in tokens:
        if token not in token_to_idx:
            token_to_idx[token] = len(token_to_idx) + 3
# Next unused index (one past the largest assigned).
max_idx = len(token_to_idx) + 3
# Encode each molecule as a token-index array and pair it with its log10
# affinity target. Then pickle both splits, the per-row target sequences and
# the vocabulary.
x_train = np.array([featurize_mol(smiles) for smiles, _ in train_mols])
y_train = np.array([binding for _, binding in train_mols])
x_test = np.array([featurize_mol(smiles) for smiles, _ in test_mols])
y_test = np.array([binding for _, binding in test_mols])
# A context manager guarantees the pickle is flushed and the file closed.
with open('data.pickle', 'wb') as out:
    pickle.dump((x_train, x_test, y_train, y_test, train_seqs, test_seqs, token_to_idx), out)