import os

import polars as pl
from skfp.utils import run_in_parallel

from src.pipelines.base_pipeline import BasePipeline
from src.standardization import smiles_to_inchi_convert


class PubChemPipeline(BasePipeline):
    """Pipeline for the PubChem ``CID-SMILES`` compound dump.

    ``download`` fetches the gzipped dump from the PubChem FTP server and
    stores it under ``self.filename``; ``preprocess`` converts the SMILES
    column with ``smiles_to_inchi_convert``, drops rows without a result,
    de-duplicates on the converted value and writes a Parquet file.

    ``self.output_dir`` and ``self.preprocessed_filename`` are not set in this
    class; they are presumably provided by ``BasePipeline`` -- confirm there.
    """

    def __init__(self):
        super().__init__(
            source_name="PubChem",
            filename="pubchem.tsv",
            archive_name="CID-SMILES.gz",
        )
        self.url = "https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/Extras/CID-SMILES.gz"

    def download(self, force_download: bool = False) -> None:
        """Download the archive and move the extracted file to ``self.filename``.

        Args:
            force_download: Forwarded unchanged to
                ``_download_single_file_archive``.
        """
        downloaded_file = self._download_single_file_archive(
            url=self.url,
            force_download=force_download,
        )
        # NOTE(review): assumes a truthy return value means the helper left a
        # freshly extracted ``CID-SMILES`` file in ``self.output_dir`` --
        # confirm against ``BasePipeline._download_single_file_archive``.
        if downloaded_file:
            # os.replace overwrites an existing destination on every platform;
            # os.rename raises FileExistsError on Windows when the target from
            # a previous run is still present (e.g. with force_download=True).
            os.replace(
                os.path.join(self.output_dir, "CID-SMILES"),
                os.path.join(self.output_dir, self.filename),
            )

    def preprocess(self) -> None:
        """Convert the downloaded TSV into a de-duplicated Parquet file.

        Does nothing except print a message when the output file already
        exists. Otherwise reads the header-less, tab-separated input as two
        string columns (``id``, ``SMILES``), replaces the SMILES values with
        the output of ``smiles_to_inchi_convert``, renames the column to
        ``InChI``, removes null and duplicate ``InChI`` rows and writes the
        result with ``write_parquet``.
        """
        input_file_path = os.path.join(self.output_dir, self.filename)
        output_file_path = os.path.join(self.output_dir, self.preprocessed_filename)

        # check if preprocessed file already exists
        if os.path.exists(output_file_path):
            print("Found preprocessed dataset, skipping")
            return

        # NOTE(review): the ``schema`` keys are listed in the opposite order
        # to ``new_columns``. Both columns are strings, so the dtypes are the
        # same either way; the final names are expected to follow
        # ``new_columns`` -- confirm against the polars version in use.
        df = pl.read_csv(
            input_file_path,
            has_header=False,
            separator="\t",
            new_columns=["id", "SMILES"],
            schema={"SMILES": pl.String, "id": pl.String},
        )

        # Run the conversion in batches of 1000 SMILES; the flattened result
        # is used below as a drop-in replacement for the SMILES column.
        inchis = run_in_parallel(
            smiles_to_inchi_convert,
            data=df["SMILES"],
            n_jobs=-1,
            batch_size=1000,
            flatten_results=True,
            verbose=True,
        )
        # Overwrite the column in place first so it keeps its position, then
        # give it the name that matches its new content.
        df = df.with_columns(pl.Series("SMILES", values=inchis))
        df = df.rename({"SMILES": "InChI"})

        # Null entries are presumably SMILES the converter could not handle
        # -- confirm in ``smiles_to_inchi_convert``.
        df = df.filter(pl.col("InChI").is_not_null())
        # Keep a single row per distinct InChI.
        df = df.unique("InChI")

        df.write_parquet(output_file_path)
