import os
import subprocess

import pandas as pd
from rdkit import RDLogger
from rdkit.Chem import MolFromInchi, MolToInchi, SDMolSupplier
from tqdm import tqdm

from src.pipelines.base_pipeline import BasePipeline


class ChemSpacePipeline(BasePipeline):
    """Pipeline that downloads and preprocesses the ChemSpace screening set.

    The raw data is a single SDF file shipped inside a zip archive;
    preprocessing turns it into a parquet table with one row per unique
    InChI and the columns ``id`` and ``InChI``.
    """

    def __init__(self):
        super().__init__(
            source_name="ChemSpace",
            filename="chemspace.sdf",
            archive_name="chemspace.zip",
        )
        self.url = "https://cloud.chem-space.com/s/np8g9LpYYCbStPD/download/Chemspace_Screening_Compounds_SDF.zip"

    def download(self, force_download: bool = False) -> None:
        """
        ChemSpace output: single SDF file.

        The archive is fetched through ``_download_single_file_archive``;
        when that call returns a truthy value, the SDF found in the output
        directory under its upstream name is renamed to ``self.filename``.

        Args:
            force_download: forwarded unchanged to the download helper.
        """
        downloaded_file = self._download_single_file_archive(
            url=self.url,
            force_download=force_download,
        )
        if downloaded_file:
            # os.replace overwrites an existing destination on every
            # platform, whereas os.rename raises on Windows in that case.
            os.replace(
                os.path.join(self.output_dir, "Chemspace_Screening_Compounds_SDF.sdf"),
                os.path.join(self.output_dir, self.filename),
            )

    @staticmethod
    def _count_lines_containing(path: str, marker: bytes) -> int:
        """Return the number of lines in the file at *path* that contain *marker*."""
        with open(path, "rb") as handle:
            return sum(marker in line for line in handle)

    def preprocess(self) -> None:
        """Convert the downloaded SDF file into a deduplicated parquet table.

        Does nothing if the preprocessed file already exists. Otherwise every
        record of the SDF is read sequentially, its ``Chemspace_ID`` property
        is paired with its InChI, rows are deduplicated on ``InChI`` and the
        result is written as a zstd-compressed parquet file.

        Records that RDKit cannot parse, or whose InChI does not survive a
        write/read/write round trip, are skipped.
        """
        input_file_path = os.path.join(self.output_dir, self.filename)
        output_file_path = os.path.join(self.output_dir, self.preprocessed_filename)

        # check if preprocessed file already exists
        if os.path.exists(output_file_path):
            print("Found preprocessed dataset, skipping")
            return

        RDLogger.DisableLog("rdApp.*")

        # multithreaded one has a bug: https://github.com/rdkit/rdkit/issues/8530
        # this forces us to process sequentially
        supplier = SDMolSupplier(input_file_path)

        # Only used as the progress-bar total. Counted in Python rather than
        # through a shell pipeline, so paths with spaces or quotes are safe
        # and no external tools are needed.
        total = self._count_lines_containing(input_file_path, b"Chemspace_ID")

        results = []
        for mol in tqdm(supplier, total=total):
            # SDMolSupplier yields None for records it fails to parse
            if mol is None:
                continue

            mol_id = mol.GetProp("Chemspace_ID")
            try:
                # ensure idempotency, i.e. that writing and reading results
                # in a valid molecule
                inchi = MolToInchi(MolFromInchi(MolToInchi(mol)))
            except Exception:
                continue
            results.append({"id": mol_id, "InChI": inchi})

        # Explicit columns keep the frame well-formed (and drop_duplicates
        # working) even when no record survived the filtering above.
        df_preprocessed = pd.DataFrame(results, columns=["id", "InChI"])
        df_preprocessed = df_preprocessed.drop_duplicates("InChI")

        if self.verbose:
            print(
                f"Preprocessing of {self.source_name} dataset:",
                f"DataFrame shape: {len(df_preprocessed)}",
            )

        df_preprocessed.to_parquet(
            output_file_path,
            engine="pyarrow",
            compression="zstd",
            index=False,
        )
