import argparse
import os

from pathlib import Path

import cellxgene_census
import pandas as pd
import tiledbsoma as soma

from cellxgene_census.experimental.pp import highly_variable_genes


def to_filter_list(values):
    """Render values as a bracketed, double-quoted list for a value filter.

    A single string is treated as a one-element list instead of being
    iterated character by character. Every value is converted with ``str``,
    embedded double quotes are backslash-escaped, and the value is wrapped in
    double quotes.

    Args:
        values: A string, or an iterable of values.

    Returns:
        A string such as ``["brain", "blood"]`` (``[]`` for an empty iterable).
    """
    if isinstance(values, str):
        values = [values]
    # Escape outside the f-string: a backslash inside an f-string replacement
    # field is a SyntaxError on Python versions before 3.12.
    escaped = (str(v).replace('"', '\\"') for v in values)
    return "[" + ", ".join(f'"{item}"' for item in escaped) + "]"


def main():
    """Run HVG selection per tissue and write one CSV of gene ids per tissue.

    Reads ``--out_dir`` from the command line. When the
    ``SLURM_ARRAY_TASK_ID`` environment variable is set and non-empty, it is
    used as an index into the tissue list and only that tissue is processed;
    otherwise every tissue is processed in order. For each tissue the index
    labels of the rows flagged ``highly_variable`` are written to
    ``<out_dir>/hvg_<tissue>.csv`` under a single ``gene`` column.

    A failure for one tissue is reported and the remaining tissues are still
    attempted; the process exits with a non-zero status afterwards so a job
    scheduler does not record a failed run as successful.

    Raises:
        ValueError: If ``SLURM_ARRAY_TASK_ID`` is not an integer.
        IndexError: If ``SLURM_ARRAY_TASK_ID`` is outside the tissue list.
        SystemExit: After all tissues were attempted, if any of them failed.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--out_dir", required=True)
    args = p.parse_args()

    TISSUES_ALL = ["brain", "blood", "eye", "lung", "breast", "heart"]

    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
    if task_id:
        i = int(task_id)
        # Reject negatives explicitly: Python would otherwise wrap them
        # around and silently process the wrong tissue.
        if not 0 <= i < len(TISSUES_ALL):
            raise IndexError(
                f"SLURM_ARRAY_TASK_ID={i} is out of range; "
                f"expected 0..{len(TISSUES_ALL) - 1}"
            )
        tissues = [TISSUES_ALL[i]]
    else:
        tissues = TISSUES_ALL

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    failed = []
    # The context manager closes the census handle on every exit path,
    # including KeyboardInterrupt, which the per-tissue handler does not catch.
    with cellxgene_census.open_soma(census_version="2025-01-30") as census:
        for tissue in tissues:
            try:
                print(f"Running HVG selection for tissue: {tissue}...")
                obs_val_filt = (
                    "is_primary_data == True and "
                    'assay == "10x 3\' v3" and '
                    f"tissue_general in {to_filter_list(tissue)} and "
                    "nnz >= 300 and "
                    "disease == 'normal'"
                )
                # Close the query as soon as the HVG computation is done.
                with census["census_data"]["homo_sapiens"].axis_query(
                    measurement_name="RNA",
                    obs_query=soma.AxisQuery(value_filter=obs_val_filt),
                ) as query:
                    hvgs_df = highly_variable_genes(
                        query,
                        n_top_genes=5000,
                        batch_key=["dataset_id", "donor_id"],
                    )

                print("Saving HVGs...")
                hv_idx = hvgs_df.highly_variable
                # Boolean-mask the flag column with itself to keep only the
                # index labels of rows marked highly variable.
                series = pd.Series(hv_idx[hv_idx].index, name="gene")
                out_path = out_dir / f"hvg_{tissue}.csv"
                series.to_csv(out_path, index=False)

                print("HVG selection completed.")

            except Exception as e:
                # Best effort: report and continue with the next tissue, but
                # remember the failure so the exit status reflects it.
                print(f"HVG selection failed: {e}")
                failed.append(tissue)

    if failed:
        raise SystemExit(f"HVG selection failed for: {', '.join(failed)}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
