import os

import polars as pl
from skfp.utils import run_in_parallel

from src.downloading import download_multiple_files
from src.pipelines.base_pipeline import BasePipeline
from src.standardization import smiles_to_inchi_convert


class ZINCPipeline(BasePipeline):
    """
    Pipeline for the ZINC source: downloads the `.smi` files listed in a URL
    file and preprocesses them into a deduplicated table with columns
    `id` and `InChI`.

    NOTE(review): `self.input_dir`, `self.output_dir` and
    `self.preprocessed_filename` are not defined in this class; presumably
    they are set by `BasePipeline.__init__` - confirm there.
    """

    def __init__(self):
        super().__init__(
            source_name="ZINC",
            filename="zinc.csv",
        )
        # Text file with the download URLs, one per line.
        self.urls = os.path.join(self.input_dir, "zinc_urls.txt")
        # Directory where the raw downloaded .smi files are stored.
        self.tmp_output_dir = os.path.join(self.output_dir, "ZINC")

    def download(self, force_download: bool = False) -> None:
        """
        Download all files listed in the URL file into `self.tmp_output_dir`.

        ZINC output: .smi files, one molecule per line, with a header and
        two columns [smiles zinc_id]; `preprocess` reads them with a single
        space as the separator.

        The download is skipped when the target directory already exists,
        unless `force_download` is True. Only the existence of the directory
        is checked, not its contents.

        Parameters
        ----------
        force_download : bool, default=False
            If True, download even when the target directory already exists.
        """
        if not force_download and os.path.exists(self.tmp_output_dir):
            print(f"Found existing directory {self.tmp_output_dir}, skipping download")
            return

        # Explicit encoding so that reading the URL list does not depend on
        # the platform's default locale; blank lines are ignored.
        with open(self.urls, encoding="utf-8") as f:
            urls = [stripped for line in f if (stripped := line.strip())]

        print(f"Downloading {len(urls)} files from provided list...")
        download_multiple_files(
            urls=urls,
            output_dir=self.tmp_output_dir,
        )
        print("Download complete.")

    def preprocess(self) -> None:
        """
        Convert the downloaded SMILES to InChI, drop failed conversions and
        duplicates, and write the result as a Parquet file to
        `self.output_dir / self.preprocessed_filename`.

        If the output file already exists, nothing is recomputed. The method
        returns None in both cases, as declared by its annotation.
        """
        output_file_path = os.path.join(self.output_dir, self.preprocessed_filename)

        # Check if the preprocessed file already exists. A bare return keeps
        # the result consistent with the normal path, which also yields None.
        if os.path.exists(output_file_path):
            print("Found preprocessed dataset, skipping")
            return

        # Glob pattern: all downloaded .smi files are read into one DataFrame.
        glob_path = os.path.join(self.tmp_output_dir, "*.smi")

        df = pl.read_csv(source=glob_path, has_header=True, separator=" ")

        print("Initial length:", len(df))

        # NOTE(review): presumably the converter is applied to batches of
        # SMILES and the per-batch results are flattened into one sequence
        # aligned with the input rows - confirm against skfp's
        # run_in_parallel and smiles_to_inchi_convert.
        inchis = run_in_parallel(
            smiles_to_inchi_convert,
            data=df["smiles"],
            n_jobs=-1,
            batch_size=1000,
            flatten_results=True,
            verbose=True,
        )
        # Overwrite the SMILES column with the converted values, then give
        # both columns their output names.
        df = df.with_columns(pl.Series("smiles", values=inchis))
        df = df.rename({"smiles": "InChI", "zinc_id": "id"})

        # Drop rows without an InChI (presumably failed conversions - TODO
        # confirm), then keep a single row per distinct InChI.
        df = df.filter(pl.col("InChI").is_not_null())
        df = df.unique("InChI", keep="first")
        df = df.select(["id", "InChI"])

        print("Preprocessed length:", len(df))

        df.write_parquet(output_file_path)
