import os
from collections import Counter

import click
import numpy as np
import polars as pl
from rdkit.Chem import AllChem, Mol
from skfp.preprocessing import MolFromInchiTransformer
from skfp.utils import run_in_parallel
from tqdm import tqdm

# Shared InChI -> RDKit Mol converter used by get_sentences.
# NOTE(review): valid_only=True presumably drops InChIs that fail to parse — confirm in skfp docs.
mol_from_inchi = MolFromInchiTransformer(suppress_warnings=True, valid_only=True)

# Corpus-wide identifier frequencies: filled by generate_corpus, read by process_unk_tokens.
identifiers_map: Counter = Counter()


@click.command()
@click.option(
    "--input-file",
    type=click.Path(exists=True),
    help="Input Parquet file path",
    required=True,
)
@click.option(
    "--tmp-sentences-file",
    type=click.Path(),
    help="Temporary txt file path with sentences",
    default="outputs/tmp_mol2vec_corpus.txt",
    required=False,
)
@click.option(
    "--output-file",
    type=click.Path(),
    help="Output txt file path",
    default="outputs/mol2vec_corpus.txt",
    required=False,
)
def generate_corpus(input_file: str, tmp_sentences_file: str, output_file: str) -> None:
    """
    Generate corpus containing one sentence per molecule (combination of radii 0 and 1).

    Reads the ``InChI`` column of the input Parquet file, turns every molecule
    into a sentence of Morgan identifiers (via ``get_sentences``), counts how
    often each identifier occurs across the whole dataset, writes the raw
    sentences to a temporary file, and finally rewrites them into the output
    file with rare identifiers replaced by ``UNK``. The temporary file is
    removed afterwards, also when a later step fails.

    :param input_file: Input Parquet file path
    :param tmp_sentences_file: Temporary txt file path to save sentences before UNK replacement
    :param output_file: Output txt file path, with one sentence per line
    """
    # identifiers_map is module-global and read later by process_unk_tokens;
    # reset it so repeated invocations in one process do not accumulate counts.
    identifiers_map.clear()

    # Create missing parent directories up front (the defaults live in "outputs/"),
    # so a bad path fails fast instead of after the expensive parsing step.
    for path in (tmp_sentences_file, output_file):
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    df = pl.read_parquet(input_file)

    print("Processing sentences")
    sentences = run_in_parallel(
        get_sentences,
        data=df["InChI"],
        n_jobs=-1,
        batch_size=1000,
        flatten_results=True,
        verbose=True,
    )
    del df  # only the InChI column was needed

    print("Counting identifiers")
    for sentence_array in tqdm(sentences):
        # Counter.update counts the elements of an iterable directly; building
        # an intermediate Counter per sentence is unnecessary.
        identifiers_map.update(sentence_array)

    num_sentences = len(sentences)
    try:
        # save sentences and free memory before the UNK pass
        print("Writing sentences to temporary file")
        with open(tmp_sentences_file, "w", encoding="utf-8") as file:
            for sentence in tqdm(sentences):
                file.write(" ".join(map(str, sentence)) + "\n")

        del sentences

        print("Processing UNK tokens")
        process_unk_tokens(tmp_sentences_file, output_file, num_sentences)
    finally:
        # Remove the intermediate file even if writing or the UNK pass fails.
        if os.path.exists(tmp_sentences_file):
            os.remove(tmp_sentences_file)


def get_sentences(inchis: pl.Series) -> list[np.ndarray]:
    """
    Convert a batch of InChI strings into Morgan-identifier sentences.

    Used as the worker function passed to ``run_in_parallel``. RDKit logging is
    disabled at the start of each call before parsing.

    :param inchis: batch of InChI strings
    :return: one ``mol_to_sentence`` array per molecule returned by
        ``mol_from_inchi.transform``, in the same order
    """
    from rdkit import rdBase

    # silence RDKit log output (e.g. parsing warnings) for this call
    rdBase.DisableLog("rdApp.*")

    parsed_mols = mol_from_inchi.transform(inchis)
    return list(map(mol_to_sentence, parsed_mols))


def mol_to_sentence(mol: Mol) -> np.ndarray:
    """
    Encode a molecule as a flat "sentence" of Morgan identifiers.

    For every atom (in atom-index order) two consecutive entries are emitted:
    the identifier of its radius-0 environment followed by the identifier of
    its radius-1 environment. The result therefore has length
    ``2 * mol.GetNumAtoms()``.

    :param mol: RDKit molecule
    :return: 1D ``int64`` array of identifiers; any (atom, radius) slot that
        RDKit does not report in ``bitInfo`` is left as 0
    """
    # bitInfo: subgraph identifier -> sequence of (atom_idx, radius) pairs
    bit_info: dict = {}
    AllChem.GetMorganFingerprint(mol, 1, bitInfo=bit_info)

    # Explicit int64 instead of the platform-dependent ``int``: Morgan
    # identifiers are unsigned 32-bit hashes (RDKit returns a UIntSparseIntVect)
    # and can overflow a 32-bit default integer (e.g. on Windows).
    sentence = np.zeros((mol.GetNumAtoms(), 2), dtype=np.int64)

    for identifier, occurrences in bit_info.items():
        for atom_idx, radius in occurrences:
            sentence[atom_idx, radius] = identifier

    # NOTE(review): slots missing from bitInfo (presumably environments RDKit
    # drops as redundant, e.g. radius 1 of an isolated atom) stay 0 and end up
    # as a "0" token in the corpus. The reference mol2vec implementation skips
    # such entries instead; confirm which behaviour the downstream featurizer expects.
    return sentence.ravel()


def process_unk_tokens(
    tmp_sentences_file: str,
    output_file: str,
    num_sentences: int,
    *,
    unk_threshold: int = 3,
    identifier_counts: "Counter[int] | None" = None,
) -> None:
    """
    Handling of uncommon "words" (i.e. identifiers). Replace rare identifiers with "UNK"
    string and write output to the final file.

    Each line of the input file is a whitespace-separated list of integer
    identifiers. Each identifier whose count is at most ``unk_threshold`` is
    written as ``UNK``. Identifiers that never occur in the counts are treated
    as count 0. All other identifiers are written back as integers. Line
    structure is preserved, one sentence per line.

    :param tmp_sentences_file: input txt file with one space-separated sentence per line
    :param output_file: output txt file path
    :param num_sentences: number of lines in the input, used only for the progress bar total
    :param unk_threshold: identifiers seen at most this many times become "UNK"
        (default 3, i.e. an identifier must occur at least 4 times to be kept)
    :param identifier_counts: identifier -> frequency mapping; defaults to the
        module-level ``identifiers_map`` filled by ``generate_corpus``
    """
    counts = identifiers_map if identifier_counts is None else identifier_counts

    with open(tmp_sentences_file, encoding="utf-8") as in_file, open(
        output_file, "w", encoding="utf-8"
    ) as out_file:
        for line in tqdm(in_file, total=num_sentences):
            identifiers = [int(token) for token in line.split()]
            sentence = " ".join(
                "UNK" if counts.get(identifier, 0) <= unk_threshold else str(identifier)
                for identifier in identifiers
            )
            out_file.write(sentence + "\n")


if __name__ == "__main__":
    # Run the click command: parses sys.argv and exits with the command's status.
    # Command.main() is what calling the command object does under the hood.
    generate_corpus.main()
