import csv
import os
from tqdm import tqdm
from utils.data_utils import canonicalize_smiles, tokenize_smiles


def _canonical_tokens(smi):
    """Canonicalize *smi* with atom numbers removed, then tokenize it."""
    return tokenize_smiles(canonicalize_smiles(smi, remove_atom_number=True))


def split_raw_csv(fn, ofn_src, ofn_tgt):
    """Split a raw reaction CSV into parallel source/target token files.

    For every row, the ``reactants>reagents>production`` column is split on
    ``>``. Each side is canonicalized (atom numbers removed) and tokenized.
    The product side is written to ``ofn_src`` and the reactant side to
    ``ofn_tgt``. Reagents are discarded. Both outputs get one line per CSV
    row, so line *i* of each file refers to the same reaction.

    Args:
        fn: Path to the input CSV; it must have a header containing the
            ``reactants>reagents>production`` column.
        ofn_src: Output path for the tokenized product SMILES.
        ofn_tgt: Output path for the tokenized reactant SMILES.

    Raises:
        KeyError: If the CSV has no ``reactants>reagents>production`` column.
        ValueError: If a reaction string does not consist of exactly three
            ``>``-separated parts.
    """
    print(f"Converting {fn} into {ofn_src} and {ofn_tgt}")
    # newline="" is required by the csv module so quoted fields containing
    # line breaks are parsed correctly; an explicit encoding keeps the output
    # independent of the platform's default locale.
    with open(fn, "r", newline="", encoding="utf-8") as csv_file, \
            open(ofn_src, "w", encoding="utf-8") as of_src, \
            open(ofn_tgt, "w", encoding="utf-8") as of_tgt:
        csv_reader = csv.DictReader(csv_file)
        for row in tqdm(csv_reader):
            smi = row["reactants>reagents>production"]
            parts = smi.split(">")
            if len(parts) != 3:
                # Report where the bad record is instead of a bare unpacking error.
                raise ValueError(
                    f"{fn}, line {csv_reader.line_num}: expected "
                    f"'reactants>reagents>product', got {smi!r}"
                )
            smi_r, _, smi_p = parts

            of_src.write(f"{_canonical_tokens(smi_p)}\n")
            of_tgt.write(f"{_canonical_tokens(smi_r)}\n")


def main(data_dir="./data/schneider50k", phases=("train", "val", "test")):
    """Convert each dataset split's raw CSV into tokenized source/target files.

    For every phase, reads ``raw_{phase}.csv`` from *data_dir* and writes
    ``src-{phase}.txt`` and ``tgt-{phase}.txt`` next to it using
    ``split_raw_csv``.

    Args:
        data_dir: Directory holding the raw CSVs and receiving the outputs.
            Defaults to the original hard-coded location.
        phases: Names of the splits to process, in order.
    """
    for phase in phases:
        split_raw_csv(
            os.path.join(data_dir, f"raw_{phase}.csv"),
            os.path.join(data_dir, f"src-{phase}.txt"),
            os.path.join(data_dir, f"tgt-{phase}.txt"),
        )


if __name__ == "__main__":
    # Only run the preprocessing when executed as a script, not on import.
    main()
