import os
from shutil import rmtree

import polars as pl
from skfp.utils import run_in_parallel

from src.pipelines.base_pipeline import BasePipeline
from src.standardization import smiles_to_inchi_convert


class SuperNatural3Pipeline(BasePipeline):
    """Download the SuperNatural3 dataset and convert its SMILES to InChI.

    ``download`` fetches the zipped CSV and stores it under ``self.filename``;
    ``preprocess`` turns it into a Parquet file with ``id`` and ``InChI``
    columns, de-duplicated on ``InChI``.

    NOTE(review): ``self.output_dir`` and ``self.preprocessed_filename`` are
    not set in this class; presumably they come from ``BasePipeline`` - confirm.
    """

    def __init__(self):
        super().__init__(
            source_name="SuperNatural3",
            filename="supernatural3.tsv",
            archive_name="dataset.csv.zip",
        )
        self.url = "https://bioinf-applied.charite.de/supernatural_3/full_data_download.csv.zip"
        # Not read anywhere in this class; presumably consumed by the base
        # class or by external callers - TODO confirm.
        self.filename_after_download = "SuperNatural3.tsv"

    def download(self, force_download: bool = False) -> None:
        """Fetch the dataset archive and move the extracted file into place.

        Args:
            force_download: Forwarded unchanged to
                ``_download_single_file_archive``.
        """
        downloaded_file = self._download_single_file_archive(
            url=self.url,
            force_download=force_download,
        )
        # A falsy result means there is nothing new to post-process
        # (presumably the download was skipped - confirm in BasePipeline).
        if not downloaded_file:
            return

        # os.replace overwrites an existing destination on every platform,
        # whereas os.rename raises on Windows when the target already exists
        # (e.g. on a forced re-download).
        os.replace(
            os.path.join(self.output_dir, "full_data_download.csv"),
            os.path.join(self.output_dir, self.filename),
        )

        # The archive may ship a macOS metadata directory; remove it only if
        # it was actually extracted so a clean archive does not crash here.
        macosx_dir = os.path.join(self.output_dir, "__MACOSX")
        if os.path.isdir(macosx_dir):
            rmtree(macosx_dir)

    def preprocess(self) -> None:
        """Convert the downloaded file to a de-duplicated InChI Parquet file.

        Reads the ``id`` and ``smiles`` columns of the ``;``-separated input,
        replaces each SMILES with the result of ``smiles_to_inchi_convert``,
        drops rows whose InChI is null, keeps one row per distinct InChI and
        writes the result as Parquet. Does nothing if the output file already
        exists.
        """
        input_file_path = os.path.join(self.output_dir, self.filename)
        output_file_path = os.path.join(self.output_dir, self.preprocessed_filename)

        # check if preprocessed file already exists
        if os.path.exists(output_file_path):
            print("Found preprocessed dataset, skipping")
            return

        df = pl.read_csv(input_file_path, separator=";", columns=["id", "smiles"])

        inchis = run_in_parallel(
            smiles_to_inchi_convert,
            data=df["smiles"],
            n_jobs=-1,
            batch_size=1000,
            flatten_results=True,
            verbose=True,
        )

        # Overwrite the SMILES column in place (keeping its position) and
        # give it the name that matches its new content.
        df = df.with_columns(pl.Series("smiles", values=inchis))
        df = df.rename({"smiles": "InChI"})

        # Null InChI values are discarded (presumably failed conversions -
        # confirm against smiles_to_inchi_convert), then duplicates removed.
        df = df.filter(pl.col("InChI").is_not_null())
        df = df.unique("InChI")

        df.write_parquet(output_file_path)
