import pandas as pd
import argparse
from transformers import AutoTokenizer, AutoModel
import os
from os.path import join
import torch
import numpy as np

def _str2bool(value):
	"""Convert a command-line token such as "True", "false", "1" or "no" to a bool.

	argparse's ``type=bool`` simply calls ``bool(text)``, which is True for every
	non-empty string (including "False"), so an explicit converter is required.

	Raises:
		argparse.ArgumentTypeError: if the text is not a recognised boolean.
	"""
	if isinstance(value, bool):
		return value
	text = value.strip().lower()
	if text in ("true", "t", "yes", "y", "1"):
		return True
	# The empty string stays False, as it was with the former ``type=bool``.
	if text in ("false", "f", "no", "n", "0", ""):
		return False
	raise argparse.ArgumentTypeError("Expected a boolean value, got %r" % (value,))


def get_arguments():
	"""Parse the command-line options of the script from ``sys.argv``.

	Returns:
		argparse.Namespace with the attributes ``save_dir``, ``train_file``,
		``val_file``, ``test_file``, ``Molformer_model_path`` (all strings,
		default "") and ``continue_from_previous_file`` (bool, default False).
	"""
	parser = argparse.ArgumentParser()
	parser.add_argument(
		"--save_dir",
		type=str,
		default="",
		help="Directory to save the point lists",
	)
	parser.add_argument(
		"--train_file",
		type=str,
		default="",
		help="Path to train file with SMILES strings in the column Ligand_SMILES",
	)
	parser.add_argument(
		"--val_file",
		type=str,
		default="",
		help="Path to validation file with SMILES strings in the column Ligand_SMILES",
	)
	parser.add_argument(
		"--test_file",
		type=str,
		default="",
		help="Path to test file with SMILES strings in the column Ligand_SMILES",
	)
	parser.add_argument(
		"--Molformer_model_path",
		type=str,
		default="",
		help="Path to Molformer model",
	)
	parser.add_argument(
		"--continue_from_previous_file",
		type=_str2bool,
		# nargs="?" + const lets the option also be used as a bare flag.
		nargs="?",
		const=True,
		default=False,
		help="Continue from previous file (True/False; a bare flag means True)",
	)
	return parser.parse_args()



# Parse the command-line options.
args = get_arguments()
# vars() exposes the Namespace attributes as an {option_name: value} dict.
args_dict = vars(args)
# Publish every option as a module-level variable (save_dir, train_file,
# val_file, test_file, Molformer_model_path, continue_from_previous_file);
# the rest of this script reads those names directly instead of going through args.
globals().update(args_dict)


def _read_ligand_smiles(path):
	"""Return the non-missing values of the ``Ligand_SMILES`` column of ``path``.

	``.csv`` files are read with pandas.read_csv and ``.xlsx`` files with
	pandas.read_excel. Any other extension is reported and yields an empty list
	instead of being ignored silently.
	"""
	if path.endswith(".csv"):
		table = pd.read_csv(path, sep=",")
	elif path.endswith(".xlsx"):
		table = pd.read_excel(path)
	else:
		print("Skipping ", path, ": unsupported file type (expected .csv or .xlsx)")
		return []
	# Empty cells are read as NaN, which is not a SMILES string, so drop them.
	return table["Ligand_SMILES"].dropna().tolist()


# Collect the SMILES of every split whose path was given on the command line.
all_SMILES = []
for _split_file in (train_file, val_file, test_file):
	if _split_file != "":
		all_SMILES.extend(_read_ligand_smiles(_split_file))

# Get unique SMILES. dict.fromkeys keeps the first-seen order, so the processing
# order is the same on every run (set() ordering of strings varies between runs).
all_SMILES = list(dict.fromkeys(all_SMILES))
print("Number of SMILES: ", len(all_SMILES))

# Load model and tokenizer from the user-supplied path when one was given,
# otherwise from the default identifier "ibm/MoLFormer-XL-both-10pct".
_model_source = Molformer_model_path if Molformer_model_path != "" else "ibm/MoLFormer-XL-both-10pct"
smiles_bert = AutoModel.from_pretrained(
	_model_source,
	trust_remote_code=True,
	deterministic_eval=True,
)
smiles_tokenizer = AutoTokenizer.from_pretrained(
	_model_source,
	deterministic_eval=True,
	trust_remote_code=True,
)

def create_empty_path(path):
	"""Ensure ``path`` exists as a directory and delete every entry directly inside it.

	Missing parent directories are created as well. Entries are deleted with
	os.remove, so a sub-directory inside ``path`` makes the call raise.

	Raises:
		OSError: if the directory cannot be created or an entry cannot be removed.
	"""
	# exist_ok=True tolerates an already existing directory without hiding
	# other failures (permissions, path is a regular file, ...).
	os.makedirs(path, exist_ok=True)

	for entry in os.listdir(path):
		os.remove(join(path, entry))


def calculate_smiles_embeddings(all_smiles, outpath, continue_from_previous_file = False):
	"""Compute an embedding for every SMILES string and save them as one dict.

	The result maps each SMILES string to the value returned by
	get_last_layer_repr and is written with np.save to
	``<outpath>/SMILES/SMILES_repr.npy``. While running, a checkpoint is written
	to ``<outpath>/SMILES/SMILES_repr_temp.npy`` every 100 items; it is deleted
	once the final file has been saved.

	Args:
		all_smiles: sequence of SMILES strings to embed.
		outpath: base output directory; files go into its ``SMILES`` sub-directory.
		continue_from_previous_file: if True, start from the checkpoint left by
			an interrupted run and only embed the SMILES missing from it.

	Raises:
		FileNotFoundError: if resuming is requested but no checkpoint exists.
	"""
	smiles_dir = join(outpath, "SMILES")
	temp_file = join(smiles_dir, "SMILES_repr_temp.npy")
	final_file = join(smiles_dir, "SMILES_repr.npy")

	if continue_from_previous_file:
		# np.save stores the dict as a pickled 0-d object array, so loading needs
		# allow_pickle=True and .item() to get the dict back.
		# NOTE: unpickling can execute arbitrary code; only resume from
		# checkpoints written by this script.
		smiles_reprs = np.load(temp_file, allow_pickle=True).item()
	else:
		smiles_reprs = {}
	# This also removes the checkpoint that was just loaded; it is written again
	# on the first loop iteration (k == 0).
	create_empty_path(smiles_dir)

	for k, smiles in enumerate(all_smiles):
		# Membership is tested on the dict itself (O(1)) instead of a key list.
		if smiles not in smiles_reprs:
			smiles_reprs[smiles] = get_last_layer_repr(smiles)

		if k % 100 == 0:
			print("Processed ", k, " SMILES")
			np.save(temp_file, smiles_reprs)

	# len() instead of k + 1 so that an empty input does not raise NameError.
	print("Processed ", len(all_smiles), " SMILES")
	np.save(final_file, smiles_reprs)

	# Delete the temp file; it does not exist when all_smiles was empty.
	try:
		os.remove(temp_file)
	except FileNotFoundError:
		pass
	print("Saved SMILES embeddings")

def get_last_layer_repr(smiles):
	"""Return the model's last hidden state for one SMILES string.

	Uses the module-level ``smiles_tokenizer`` and ``smiles_bert``. The input is
	truncated to at most 500 tokens and the forward pass runs under
	torch.no_grad(), so no gradients are tracked.

	Args:
		smiles: the SMILES string to embed.

	Returns:
		``outputs.last_hidden_state`` of the model call (presumably a tensor of
		shape (1, n_tokens, hidden_dim) -- TODO confirm against the model).
	"""
	tokens = smiles_tokenizer(
		smiles,
		max_length=500,
		padding=True,
		truncation=True,
		return_tensors="pt",
	)
	with torch.no_grad():
		outputs = smiles_bert(**tokens)
	return outputs.last_hidden_state


# Make sure the output directory exists; exist_ok=True avoids the race between
# checking for the directory and creating it.
os.makedirs(join(save_dir, "SMILES"), exist_ok=True)

# Embed every unique SMILES string and write the result below save_dir/SMILES.
calculate_smiles_embeddings(all_SMILES, save_dir, continue_from_previous_file)