import os

import polars as pl
from skfp.utils import run_in_parallel

from src.pipelines.base_pipeline import BasePipeline
from src.standardization import smiles_to_inchi_convert


class GDBPipeline(BasePipeline):
    """
    Pipeline for the GDB dataset.

    Downloads a gzipped SMILES archive (``GDB17.50000000.smi.gz``) from Zenodo.
    It then converts the SMILES to InChI and writes a deduplicated parquet file
    with ``id`` and ``InChI`` columns.
    """

    def __init__(self):
        super().__init__(
            source_name="GDB",
            filename="gdb.smi",
            archive_name="gdb.smi.gz",
        )
        # Direct download link for the archive; used by download().
        self.url = (
            "https://zenodo.org/record/5172018/files/GDB17.50000000.smi.gz?download=1"
        )

    def download(self, force_download: bool = False) -> None:
        """
        GDB output: .smi file, with one SMILES per line, like CSV without a header.

        force_download is forwarded unchanged to the base class helper
        ``_download_single_file_archive`` (presumably it forces a re-download
        even when the file is already present - see BasePipeline).
        """
        super()._download_single_file_archive(
            url=self.url,
            force_download=force_download,
        )

    def preprocess(self) -> None:
        """
        Convert the downloaded SMILES file into a deduplicated InChI parquet file.

        Steps:
        1. Read ``<output_dir>/<filename>`` as a headerless, tab-separated file.
           The first column is named ``SMILES``.
        2. Add a 0-based row index column ``id``. It is assigned before any
           filtering, so it refers to the line position in the input file.
        3. Convert SMILES to InChI in parallel (all cores, batches of 1000).
        4. Drop rows whose InChI is null. This presumably means the conversion
           failed; verify against ``smiles_to_inchi_convert``.
        5. Drop duplicate InChIs, keeping the earliest input row.
        6. Write the result to ``<output_dir>/<preprocessed_filename>``.

        If the output file already exists, nothing is recomputed.
        """
        input_file_path = os.path.join(self.output_dir, self.filename)
        output_file_path = os.path.join(self.output_dir, self.preprocessed_filename)

        # check if preprocessed file already exists; skip the costly conversion
        if os.path.exists(output_file_path):
            print("Found preprocessed dataset, skipping")
            return

        df = pl.read_csv(
            input_file_path,
            has_header=False,
            separator="\t",
            new_columns=["SMILES"],
        )
        # Index is assigned before filtering/deduplication, so ids are stable
        # line positions in the source file.
        df = df.with_row_index(name="id")

        print("Initial length:", len(df))

        inchis = run_in_parallel(
            smiles_to_inchi_convert,
            data=df["SMILES"],
            n_jobs=-1,
            batch_size=1000,
            flatten_results=True,
            verbose=True,
        )
        # Replace the SMILES column in place (keeps column order), then rename it.
        df = df.with_columns(pl.Series("SMILES", values=inchis))
        df = df.rename({"SMILES": "InChI"})

        df = df.filter(pl.col("InChI").is_not_null())
        # keep="first" makes the surviving row (and thus its id) deterministic;
        # polars' default keep="any" gives no such guarantee.
        df = df.unique("InChI", keep="first")

        print("Preprocessed length:", len(df))

        df.write_parquet(output_file_path)
