import os

import polars as pl

from src.pipelines.base_pipeline import BasePipeline


class ChEMBLPipeline(BasePipeline):
    """Pipeline for the ChEMBL chemical-representations dump.

    Downloads ``chembl_<version>_chemreps.txt.gz`` from the EBI FTP server and
    reduces it to a Parquet file of unique (``id``, ``InChI``) pairs.
    """

    def __init__(self, chembl_version: int = 35):
        """
        Args:
            chembl_version: ChEMBL release number to download (default 35).
        """
        super().__init__(
            source_name="ChEMBL",
            filename="chembl.tsv",
            archive_name="chembl.tsv.gz",
        )
        self.chembl_version = chembl_version
        # Point at the versioned "releases/chembl_<N>/" directory instead of
        # "latest/": "latest/" only contains files for the newest release, so a
        # hard-coded version number there stops resolving once a newer ChEMBL
        # release is published.
        self.url = (
            "https://ftp.ebi.ac.uk/pub/databases/chembl/ChEMBLdb/releases/"
            f"chembl_{chembl_version}/chembl_{chembl_version}_chemreps.txt.gz"
        )

    def download(self, force_download: bool = False) -> None:
        """
        ChEMBL output: single TSV file (despite original .txt extension), with columns:
        [chembl_id, canonical_smiles, standard_inchi, standard_inchi_key]
        """
        super()._download_single_file_archive(
            url=self.url,
            force_download=force_download,
        )

    def preprocess(self) -> None:
        """Reduce the raw ChEMBL TSV to a deduplicated Parquet file.

        Reads only the ``chembl_id`` and ``standard_inchi`` columns of
        ``self.filename`` in ``self.output_dir``, renames them to ``id`` and
        ``InChI``, keeps the first row (in file order) for each distinct InChI,
        and writes the result to ``self.preprocessed_filename`` in the same
        directory. Does nothing if that output file already exists.
        """
        input_file_path = os.path.join(self.output_dir, self.filename)
        output_file_path = os.path.join(self.output_dir, self.preprocessed_filename)

        # check if preprocessed file already exists
        if os.path.exists(output_file_path):
            print("Found preprocessed dataset, skipping")
            return

        df = pl.read_csv(
            input_file_path, separator="\t", columns=["chembl_id", "standard_inchi"]
        )
        df = df.rename({"chembl_id": "id", "standard_inchi": "InChI"})
        # keep="first" + maintain_order=True make the surviving id for each
        # InChI (and the row order) deterministic across runs; polars' default
        # keep="any" gives no such guarantee.
        df = df.unique(subset="InChI", keep="first", maintain_order=True)

        # Write to a temporary file, then atomically move it into place, so an
        # interrupted run never leaves a partial file that the existence check
        # above would mistake for a finished result.
        tmp_file_path = output_file_path + ".tmp"
        df.write_parquet(tmp_file_path)
        os.replace(tmp_file_path, output_file_path)
