from pathlib import Path

import click
import duckdb
import pandas as pd


def _quote_identifier(name: str) -> str:
    """Return ``name`` as a double-quoted SQL identifier (embedded ``"`` doubled)."""
    return '"' + name.replace('"', '""') + '"'


def _quote_literal(value: str) -> str:
    """Return ``value`` as a single-quoted SQL string literal (embedded ``'`` doubled)."""
    return "'" + value.replace("'", "''") + "'"


def _load_datasets(
    con: duckdb.DuckDBPyConnection,
    dataset_names: list[str],
    paths: tuple[str, ...],
) -> dict[str, int]:
    """Load each parquet file's ``InChI`` column into its own table; return row counts."""
    row_counts: dict[str, int] = {}
    for name, path in zip(dataset_names, paths, strict=True):
        table = _quote_identifier(name)
        con.execute(
            f"CREATE OR REPLACE TABLE {table} AS "
            f"SELECT InChI FROM read_parquet({_quote_literal(str(Path(path)))});"
        )
        row_counts[name] = con.execute(f"SELECT COUNT(*) FROM {table};").fetchone()[0]
    return row_counts


def _novel_counts(
    con: duckdb.DuckDBPyConnection,
    dataset_names: list[str],
    row_counts: dict[str, int],
) -> pd.DataFrame:
    """Build the square matrix of per-pair novel-molecule counts (diagonal is 0)."""
    # Zero-filled int frame: an empty frame with dtype=int is NaN-filled and
    # upcast to float, which would write values like "12.0" to the CSV.
    matrix = pd.DataFrame(0, index=dataset_names, columns=dataset_names, dtype=int)
    for row_name in dataset_names:
        for col_name in dataset_names:
            if row_name == col_name:
                continue
            # Each direction is counted separately: one "shared" count cannot
            # serve both cells when an InChI repeats within a dataset (the
            # result could even go negative). Rows whose InChI is NULL never
            # satisfy IN, so they are counted as novel.
            shared = con.execute(
                f"SELECT COUNT(*) FROM {_quote_identifier(row_name)} "
                f"WHERE InChI IN (SELECT InChI FROM {_quote_identifier(col_name)});"
            ).fetchone()[0]
            matrix.loc[row_name, col_name] = row_counts[row_name] - shared
    return matrix


@click.command()
@click.option(
    "--input-files",
    type=click.Path(exists=True),
    multiple=True,
    help="Input parquet files paths",
    required=True,
)
@click.option(
    "--output-file",
    type=click.Path(),
    default="outputs/novel_molecules_matrix.csv",
    help="Output CSV file path",
)
def novel_molecules_matrix(input_files: tuple[str, ...], output_file: str) -> None:
    """
    Generate a novel molecules matrix.

    Each input parquet file is a dataset named after its file stem. In the
    written CSV, the cell at (row A, column B) is the number of rows of A whose
    InChI does not appear in B; the diagonal is 0.
    """
    dataset_names = [Path(p).stem for p in input_files]

    # Tables and matrix labels are keyed by stem, so a repeated stem would
    # silently replace a table and make the .loc assignments ambiguous.
    duplicated = sorted({n for n in dataset_names if dataset_names.count(n) > 1})
    if duplicated:
        raise click.BadParameter(
            f"file stems must be unique; duplicated: {', '.join(duplicated)}",
            param_hint="--input-files",
        )

    # A private in-memory database instead of duckdb's global default connection.
    con = duckdb.connect()
    try:
        row_counts = _load_datasets(con, dataset_names, input_files)
        matrix = _novel_counts(con, dataset_names, row_counts)
    finally:
        con.close()

    # The default output path lives in a directory that may not exist yet.
    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    matrix.to_csv(output_file)
    click.echo(f"Novel molecules matrix saved to {output_file}")


if __name__ == "__main__":
    # Script entry point: click's Command.main parses sys.argv and runs the command.
    novel_molecules_matrix.main()
