import os

import duckdb
import polars as pl
from skfp.utils import run_in_parallel

from src.config import OUTPUTS_DIR
from src.pipelines.base_pipeline import BasePipeline
from src.standardization import inchi_to_smiles_convert


def merge_datasets(pipelines: list[BasePipeline], merged_file_path: str) -> None:
    """
    Merge the filtered datasets of multiple pipelines into a single Parquet file.

    Datasets are merged one by one, in the order given (this function does not
    sort them; callers are expected to pass the biggest dataset first). All rows
    of the first pipeline are kept. For every following pipeline only rows whose
    InChI is not yet present in the merged file are appended, so earlier
    pipelines take precedence for duplicated molecules.

    The output has the columns ``source``, ``id`` and ``InChI``.

    The merge is built in a temporary ``<merged_file_path>.partial`` file and
    moved to ``merged_file_path`` only after every pipeline has been processed.
    This way an interrupted run never leaves an incomplete file behind that a
    later run would mistake for a finished merge and skip.

    If ``merged_file_path`` already exists, nothing is done.

    Raises ``ValueError`` if the merged file does not exist yet and
    ``pipelines`` is empty.
    """
    if os.path.exists(merged_file_path):
        print("Found merged dataset, skipping")
        return

    if not pipelines:
        raise ValueError("No pipelines provided, nothing to merge")

    print("Merging datasets from all pipelines...")

    partial_file_path = f"{merged_file_path}.partial"
    partial_sql = _sql_quote(partial_file_path)

    try:
        # starting point - first dataset
        first = pipelines[0]
        print(f"Merging {first.source_name}")

        source_sql = _sql_quote(first.source_name)
        input_sql = _sql_quote(f"{OUTPUTS_DIR}/{first.filtered_filename}")

        duckdb.sql(
            f"""
            COPY (
                SELECT {source_sql} AS source, id, InChI
                FROM read_parquet({input_sql})
            )
            TO {partial_sql} (FORMAT parquet);
            """
        )

        # add one dataset at a time to avoid memory errors
        for pipeline in pipelines[1:]:
            print(f"Merging {pipeline.source_name}")

            source_sql = _sql_quote(pipeline.source_name)
            input_sql = _sql_quote(f"{OUTPUTS_DIR}/{pipeline.filtered_filename}")

            # Keep everything merged so far and add only new InChI from the
            # current source. Row order is irrelevant here, so insertion order
            # preservation is turned off.
            # NOTE(review): the query overwrites the file it reads from; this
            # relies on DuckDB writing COPY output to a temporary file first
            # when the target already exists - confirm for the DuckDB version
            # in use.
            duckdb.sql(
                f"""
                SET preserve_insertion_order=false;

                COPY (
                    SELECT source, id, InChI
                    FROM read_parquet({partial_sql})

                    UNION ALL

                    SELECT {source_sql} AS source, id, InChI
                    FROM read_parquet({input_sql}) AS input_file

                    WHERE NOT EXISTS (
                        SELECT 1
                        FROM read_parquet({partial_sql}) AS output_file
                        WHERE input_file.InChI = output_file.InChI
                    )
                )
                TO {partial_sql} (FORMAT parquet);
                """
            )

        # publish the result only once it is complete
        os.replace(partial_file_path, merged_file_path)
    finally:
        # after a successful os.replace the partial file is gone and this is a
        # no-op; after a failure it removes the incomplete merge
        if os.path.exists(partial_file_path):
            os.remove(partial_file_path)


def _sql_quote(value: object) -> str:
    """
    Return value as a single-quoted SQL string literal, doubling embedded quotes.
    """
    escaped = str(value).replace("'", "''")
    return f"'{escaped}'"


def create_smiles_dataset(merged_file_path: str, smiles_file_path: str) -> None:
    """
    Build a SMILES dataset from the merged InChI dataset.

    Reads the Parquet file at ``merged_file_path``, converts its ``InChI``
    column to SMILES in parallel, drops rows for which the conversion produced
    a null value (presumably molecules that could not be parsed), keeps one row
    per distinct SMILES, and writes the ``id`` and ``SMILES`` columns to
    ``smiles_file_path``. The row counts before and after are printed.
    """
    merged = pl.read_parquet(merged_file_path)
    rows_before = len(merged)

    print("Translating InChI to SMILES")

    # conversion is applied to batches of the InChI column across all cores
    converted = run_in_parallel(
        inchi_to_smiles_convert,
        data=merged["InChI"],
        n_jobs=-1,
        batch_size=1000,
        flatten_results=True,
        verbose=True,
    )

    smiles_df = (
        merged.with_columns(pl.Series("SMILES", values=converted))
        .filter(pl.col("SMILES").is_not_null())
        .unique("SMILES")
        .select(["id", "SMILES"])
    )
    rows_after = len(smiles_df)

    smiles_df.write_parquet(smiles_file_path)

    print(
        f"InChI to SMILES conversion went from size {rows_before} to {rows_after}"
    )
