import argparse
import os
import sys

from pathlib import Path

if "--cuda_device" in sys.argv:
    i = sys.argv.index("--cuda_device")
    if i + 1 < len(sys.argv):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[i + 1]


import anndata as ad
import numpy as np
import scvi

from tqdm import tqdm

# Tissues to process. main() reads <in_dir>/slice_<tissue>.h5ad for each entry;
# when SLURM_ARRAY_TASK_ID is set, it selects a single tissue by index here.
TISSUES_ALL = ["blood", "lung", "breast", "heart", "eye", "brain"]


def run_scvi(adata_path: str, out_dir: str):
    """Train an scVI model on one h5ad file and save its latent embedding.

    The input is opened in backed, read-only mode. ``dataset_id`` and
    ``donor_id`` are used as categorical covariates when present in ``obs``.
    The output is an AnnData without ``X`` that carries the available
    metadata columns (``soma_joinid``, ``dataset_id``, ``donor_id``,
    ``cell_type`` from ``obs``; ``soma_joinid``, ``feature_id`` from ``var``)
    plus the latent representation in ``obsm["X_scvi"]`` as float32.

    Args:
        adata_path: Path (str or path-like) to the input ``.h5ad`` file.
        out_dir: Output directory; created if it does not exist.

    Returns:
        Path of the written ``<input stem>_scvi.h5ad`` file.
    """
    in_path = Path(adata_path)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    adata = ad.read_h5ad(in_path, backed="r")
    try:
        covariates = [
            k for k in ["dataset_id", "donor_id"] if k in adata.obs.columns
        ]
        print(f"Setting up anndata with batch: {covariates}")
        if covariates:
            scvi.model.SCVI.setup_anndata(
                adata,
                categorical_covariate_keys=covariates,
            )
        else:
            scvi.model.SCVI.setup_anndata(adata)

        model = scvi.model.SCVI(
            adata,
            n_hidden=512,
            n_latent=50,
            n_layers=2,
            gene_likelihood="nb",
            encode_covariates=False,
        )

        print("Training model:")
        model.train(
            max_epochs=100,
            train_size=0.9,
            batch_size=50_000,
            plan_kwargs=dict(lr=1e-4, n_epochs_kl_warmup=20),
        )

        print("Computing latent representation:")
        z = model.get_latent_representation(adata)

        # dataset_id/donor_id are treated as optional above, so select only the
        # metadata columns that actually exist; an unconditional selection would
        # raise KeyError only after the (expensive) training has finished.
        obs_keep = [
            c
            for c in ["soma_joinid", "dataset_id", "donor_id", "cell_type"]
            if c in adata.obs.columns
        ]
        var_keep = [c for c in ["soma_joinid", "feature_id"] if c in adata.var.columns]
        adata_out = ad.AnnData(
            X=None,
            obs=adata.obs[obs_keep].copy(),
            var=adata.var[var_keep].copy(),
        )
    finally:
        # Backed mode keeps the HDF5 file open; release it even on failure so
        # repeated calls (one per tissue) do not accumulate open handles.
        adata.file.close()

    adata_out.obsm["X_scvi"] = np.asarray(z, dtype=np.float32)

    out_file = out_dir / (in_path.stem + "_scvi.h5ad")
    adata_out.write_h5ad(out_file)
    print(f"Saved output to {out_file}")
    return out_file


def main():
    """Parse CLI arguments and run scVI on each selected tissue.

    When ``SLURM_ARRAY_TASK_ID`` is set and non-empty, only the tissue at that
    index of ``TISSUES_ALL`` is processed; otherwise every tissue is. Each
    tissue is read from ``<in_dir>/slice_<tissue>.h5ad``. A failure on one
    tissue is reported (with traceback) and the loop continues, but the
    process exits with a non-zero status if any tissue failed so that the
    job scheduler records the failure.
    """
    import traceback

    p = argparse.ArgumentParser()
    p.add_argument("--in_dir", required=True, help="Directory with input h5ad files")
    p.add_argument("--out_dir", required=True, help="Directory to write output h5ad")
    p.add_argument(
        "--cuda_device", required=True, help="CUDA device to be used (e.g., 3)."
    )
    args = p.parse_args()

    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
    if task_id:
        try:
            idx = int(task_id)
        except ValueError:
            p.error(f"SLURM_ARRAY_TASK_ID must be an integer, got {task_id!r}")
        if not 0 <= idx < len(TISSUES_ALL):
            p.error(
                f"SLURM_ARRAY_TASK_ID={idx} is out of range "
                f"0..{len(TISSUES_ALL) - 1}"
            )
        tissues = [TISSUES_ALL[idx]]
    else:
        tissues = TISSUES_ALL

    in_dir = Path(args.in_dir)
    failed = []
    for tissue in tqdm(tissues):
        print(f"Processing tissue={tissue}")
        adata_path = in_dir / f"slice_{tissue}.h5ad"
        try:
            run_scvi(adata_path, args.out_dir)
        except Exception as e:
            # Best-effort: keep going with the remaining tissues, but keep the
            # full traceback and remember the failure for the exit status.
            print(f"Failed {tissue}: {e}")
            traceback.print_exc()
            failed.append(tissue)

    if failed:
        sys.exit(f"{len(failed)} tissue(s) failed: {', '.join(failed)}")


if __name__ == "__main__":
    # Script entry point only; importing this module does not run the CLI.
    # main() returns None on success, which sys.exit maps to status 0.
    sys.exit(main())
