import argparse
import gc
import logging
import os
import sys
import time

# Pin the visible GPU *before* the CUDA-backed libraries further down in this
# file are imported, which is why this scan of sys.argv happens ahead of
# argparse. Both spellings accepted by argparse are recognised:
# "--cuda_device N" and "--cuda_device=N". Only the first occurrence is used.
for _pos, _arg in enumerate(sys.argv):
    if _arg == "--cuda_device":
        # Space-separated form: the value is the next token, if there is one.
        _device = sys.argv[_pos + 1] if _pos + 1 < len(sys.argv) else None
    elif _arg.startswith("--cuda_device="):
        _device = _arg.split("=", 1)[1]
    else:
        continue
    if _device is not None:
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = _device
    break

from datetime import datetime
from pathlib import Path

import anndata as ad
import numpy as np
import pacmap

# import parampacmap  # from ParamRepulsor (GPU LocalMAP / Parametric PaCMAP)
import scanpy as sc
import trimap

from cuml.manifold import TSNE, UMAP
from ctmc_cuda import cuCTMCEmbeddings, CauchyKernel, cauchy_kernel
from sklearn.preprocessing import normalize
from tqdm import tqdm

# Tissues to process. The longer list is kept commented out for reference;
# only "lung" is currently enabled. When SLURM_ARRAY_TASK_ID is set, main()
# uses it as an index into this list.
# TISSUES_ALL = ["brain", "blood", "lung", "breast", "heart", "eye"]
TISSUES_ALL = ["lung"]
# Neighbourhood size passed to the CTMC, UMAP and PaCMAP models.
N_NEIGHBORS = 15

# Root logger configuration: INFO level, "time | level | message" format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
)
logger = logging.getLogger(__name__)


def log_runtime(
    csv_path: Path,
    tissue: str,
    method: str,
    runtime_seconds: float,
    experiment_start_str: str,
):
    """Append one runtime record to a CSV registry.

    Rows have the layout ``tissue,method,runtime,start_time``. The parent
    directory is created on demand, and the header row is written first when
    the file does not exist yet or is empty.

    Args:
        csv_path: Destination CSV file. ``None`` means "no registry
            configured" and turns the call into a no-op, because callers in
            this file forward an optional path.
        tissue: Tissue identifier for the row.
        method: Embedding method name for the row.
        runtime_seconds: Measured runtime; callers use ``-1`` to mark a
            failed run.
        experiment_start_str: Timestamp string identifying the experiment.
    """
    if csv_path is None:
        return

    csv_path = Path(csv_path)
    csv_path.parent.mkdir(parents=True, exist_ok=True)

    # A missing or zero-length file still needs its header row.
    needs_header = not csv_path.exists() or csv_path.stat().st_size == 0

    # Append mode creates the file when absent, so one code path serves both
    # the first write and every later one.
    with open(csv_path, "a", encoding="utf-8") as f:
        if needs_header:
            f.write("tissue,method,runtime,start_time\n")
        f.write(f"{tissue},{method},{runtime_seconds},{experiment_start_str}\n")


def timed_call(func, *args, **kwargs):
    """Invoke ``func`` with the given arguments and measure how long it takes.

    Returns:
        A ``(result, elapsed_seconds)`` tuple, where the elapsed time is the
        difference of two ``time.perf_counter()`` readings taken around the
        call.
    """
    started = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - started
    return result, elapsed


def initialize_embedding(method_name: str):
    """
    Return a configured model for the given method name.
    All models expose .fit_transform(X) with a 2D embedding output.
    """
    if method_name == "ctmc":
        return cuCTMCEmbeddings(
            n_neighbors=N_NEIGHBORS,
            n_components=2,
            n_epochs=1500,
            force_kernel=CauchyKernel(a=2.0, b=0.67),
            graph_kernel=cauchy_kernel,
            gradient_clip=0.0,
            random_state=0,
            verbose=True,
            learning_rate=1.0,
            negative_sample_rate=5.0,
        )
    if method_name == "umap":
        return UMAP(
            n_neighbors=N_NEIGHBORS,
            n_components=2,
            verbose=True,
            n_epochs=1500,
        )
    if method_name == "tsne":
        return TSNE(
            n_components=2,
            verbose=1,
        )
    if method_name == "trimap":
        return trimap.TRIMAP(n_dims=2)
    if method_name == "pacmap":
        return pacmap.PaCMAP(n_components=2, n_neighbors=N_NEIGHBORS)

    raise ValueError(f"Unknown method: {method_name}")


def compute_embedding(
    method_name: str,
    tissue: str,
    in_path: Path,
    obsm_key: str,
    out_dir: Path,
    runtime_registry_path: Path | None,
    experiment_start_str: str,
):
    """Embed one tissue with one method, save the result and record the runtime.

    Reads the AnnData file at ``in_path``, runs the requested method on
    ``adata.obsm[obsm_key]``, stores the embedding under
    ``adata.obsm["X_<method_name>"]`` and writes the object to
    ``out_dir / "slice_<tissue>_<method_name>.h5ad"``.

    Args:
        method_name: Name understood by ``initialize_embedding``.
        tissue: Tissue identifier, used in the output file name and CSV row.
        in_path: Input ``.h5ad`` file.
        obsm_key: Key in ``adata.obsm`` holding the matrix to embed.
        out_dir: Directory for the output file (created if missing).
        runtime_registry_path: CSV file for runtime rows, or ``None`` to skip
            runtime logging entirely.
        experiment_start_str: Timestamp string written to every CSV row.

    Raises:
        KeyError: if ``obsm_key`` is not present in ``adata.obsm``.

    A failure while building or fitting the model is logged (with traceback),
    recorded with a runtime of ``-1`` when a registry is configured, and the
    function returns without writing an output file.
    """
    adata = ad.read_h5ad(in_path)

    if obsm_key not in adata.obsm:
        raise KeyError(
            f"'{obsm_key}' not found in adata.obsm. Available keys: {list(adata.obsm.keys())}"
        )

    X = adata.obsm[obsm_key]
    if not isinstance(X, np.ndarray):
        X = np.asarray(X)
    logger.info(f"Using obsm layer: '{obsm_key}' with shape {X.shape}")

    model = None
    try:
        model = initialize_embedding(method_name)
        X_embedding, rt = timed_call(model.fit_transform, X)
    except Exception as e:
        # Best-effort: one failing method must not abort the remaining ones.
        logger.exception(f"{method_name.upper()} failed: {e}")
        # The registry is optional; only record the failure when it is set.
        if runtime_registry_path is not None:
            log_runtime(
                runtime_registry_path, tissue, method_name, -1, experiment_start_str
            )
        return
    finally:
        # Drop the model and collect garbage on success *and* failure so a
        # crashed run does not keep its model alive for the next method.
        model = None
        gc.collect()

    logger.info(f"{method_name.upper()} runtime: {rt:.2f}s")

    adata.obsm[f"X_{method_name}"] = X_embedding

    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"slice_{tissue}_{method_name}.h5ad"
    adata.write_h5ad(out_path)
    logger.info(f"Wrote {out_path}")

    if runtime_registry_path is not None:
        log_runtime(
            runtime_registry_path, tissue, method_name, rt, experiment_start_str
        )


def process_tissue(
    tissue: str,
    in_dir: Path,
    out_dir: Path,
    obsm_key: str,
    methods_to_run: list[str],
    runtime_registry_path: Path | None,
    experiment_start_str: str,
):
    """Run every requested embedding method on a single tissue.

    The input file is expected at ``in_dir / "slice_<tissue>_pca.h5ad"``. If
    it does not exist, a warning is logged and nothing else happens.
    Otherwise ``compute_embedding`` is called once per entry of
    ``methods_to_run``, in order.
    """
    source_file = in_dir / f"slice_{tissue}_pca.h5ad"

    if source_file.exists():
        logger.info(f"Processing: tissue={tissue}")
        for name in methods_to_run:
            logger.info(f"--- Running {name.upper()} ---")
            compute_embedding(
                method_name=name,
                tissue=tissue,
                in_path=source_file,
                obsm_key=obsm_key,
                out_dir=out_dir,
                runtime_registry_path=runtime_registry_path,
                experiment_start_str=experiment_start_str,
            )
    else:
        logger.warning(f"Missing {source_file}, skipping.")


def main():
    """Command-line entry point: parse arguments and embed the selected tissues.

    Required options are ``--in_dir``, ``--out_dir`` and ``--obsm_key``. At
    least one method flag (``--ctmc``, ``--umap``, ``--tsne``, ``--trimap``,
    ``--pacmap``) or ``--run_all`` must be given. When the environment
    variable ``SLURM_ARRAY_TASK_ID`` is set, it is used as an index into
    ``TISSUES_ALL`` and only that tissue is processed; otherwise all tissues
    are processed.

    Raises:
        SystemExit: if no method is selected, or if ``SLURM_ARRAY_TASK_ID``
            is not a valid index into ``TISSUES_ALL``.

    A failure while processing one tissue is logged with its traceback and
    does not stop the remaining tissues.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in_dir", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument(
        "--obsm_key",
        required=True,
        help="Matrix from adata.obsm to embed (e.g., X_pca, X_scvi)",
    )
    ap.add_argument(
        "--cuda_device", required=False, help="CUDA device to be used (e.g., 3)."
    )
    ap.add_argument(
        "--runtime_registry_path",
        required=False,
        help="Path to .csv for runtime rows (appended).",
    )
    ap.add_argument(
        "--log_file_path",
        required=False,
        help="Optional path to a log file. If given, logs will also be written there.",
    )

    # method flags
    ap.add_argument("--run_all", action="store_true", help="Run all methods.")
    ap.add_argument("--ctmc", action="store_true", help="Run CTMC (GPU).")
    ap.add_argument("--umap", action="store_true", help="Run UMAP (GPU/cuML).")
    ap.add_argument("--tsne", action="store_true", help="Run tSNE (GPU/cuML).")
    ap.add_argument("--trimap", action="store_true", help="Run TriMap (CPU).")
    ap.add_argument("--pacmap", action="store_true", help="Run PaCMAP (CPU).")

    args = ap.parse_args()
    in_dir = Path(args.in_dir)
    out_dir = Path(args.out_dir)
    runtime_registry_path = (
        Path(args.runtime_registry_path) if args.runtime_registry_path else None
    )
    experiment_start_str = str(datetime.now())
    if args.log_file_path:
        # FileHandler does not create missing directories itself.
        Path(args.log_file_path).parent.mkdir(parents=True, exist_ok=True)
        fh = logging.FileHandler(args.log_file_path)
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter("%(asctime)s | %(levelname)s | %(message)s"))
        logging.getLogger().addHandler(fh)

    # Fixed execution order; --run_all selects every entry.
    all_methods = ["ctmc", "umap", "tsne", "trimap", "pacmap"]
    methods = [m for m in all_methods if args.run_all or getattr(args, m)]
    if not methods:
        raise SystemExit(
            "No methods selected. Use --run_all or one of --ctmc --umap --tsne --trimap --pacmap"
        )

    task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
    if task_id:
        # Turn a bad array index into a readable message instead of a raw
        # ValueError/IndexError traceback.
        try:
            tissues = [TISSUES_ALL[int(task_id)]]
        except (ValueError, IndexError) as err:
            raise SystemExit(
                f"SLURM_ARRAY_TASK_ID={task_id!r} does not select a tissue; "
                f"expected an integer index into {TISSUES_ALL}"
            ) from err
    else:
        tissues = TISSUES_ALL

    for tissue in tqdm(tissues):
        try:
            process_tissue(
                tissue=tissue,
                in_dir=in_dir,
                out_dir=out_dir,
                obsm_key=args.obsm_key,
                methods_to_run=methods,
                runtime_registry_path=runtime_registry_path,
                experiment_start_str=experiment_start_str,
            )
        except Exception as e:
            # Best-effort across tissues; keep the traceback for diagnosis.
            logger.exception(f"Failed {tissue}: {e}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
