# %%
from pathlib import Path

import anndata as ad
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

from matplotlib.colors import ListedColormap

# Keep text editable in exported vector figures: fonttype 42 embeds TrueType
# fonts in PDF/PS output, and "none" leaves SVG text as text instead of paths.
mpl.rcParams.update(
    {
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "svg.fonttype": "none",
    }
)

# %%
ADATA_IN = "../results/slice_brain_embeddings.h5ad"
SUBSET = True  # set to False to keep every cell instead of a random subsample
adata = ad.read_h5ad(ADATA_IN)

# %%
if SUBSET:
    # Reproducible (fixed-seed) random subsample drawn without replacement,
    # capped at the number of cells actually present.
    n_subset_cells = 750_000
    rng = np.random.default_rng(0)
    subset_idx = rng.choice(
        adata.n_obs, size=min(n_subset_cells, adata.n_obs), replace=False
    )
    # Indexing an AnnData returns a view; writing to a view (e.g. adding an
    # obs column) forces an implicit copy and emits a warning, so materialise
    # the subset explicitly here.
    adata = adata[subset_idx].copy()

# %%
# Coarse class -> fine-grained cell-type labels that collapse into it.
_coarse_groups = {
    "Neuron": [
        "neuron",
        "glutamatergic neuron",
        "GABAergic neuron",
        "pyramidal neuron",
        "cerebral cortex pyramidal neuron",
        "hippocampal pyramidal neuron",
        "medium spiny neuron",
        "Purkinje cell",
        "cerebellar granule cell",
        "interneuron",
        "cortical interneuron",
        "inhibitory interneuron",
    ],
    "Glia": [
        "astrocyte",
        "astrocyte of the cerebral cortex",
        "immature astrocyte",
        "oligodendrocyte",
        "oligodendrocyte precursor cell",
        "differentiation-committed oligodendrocyte precursor",
        "Bergmann glial cell",
        "radial glial cell",
        "macroglial cell",
        "ependymal cell",
    ],
    "Immune": [
        "microglial cell",
        "mature microglial cell",
        "macrophage",
        "meningeal macrophage",
        "leukocyte",
        "T cell",
        "B cell",
        "myeloid cell",
        "erythrocyte",
        "erythroid lineage cell",
    ],
    "Vascular": [
        "endothelial cell",
        "cerebral cortex endothelial cell",
        "capillary endothelial cell",
        "endothelial cell of artery",
        "endothelial cell of vascular tree",
        "blood vessel endothelial cell",
        "vascular leptomeningeal cell",
        "vascular associated smooth muscle cell",
        "smooth muscle cell",
        "smooth muscle cell of the brain vasculature",
        "pericyte",
        "mural cell",
        "brain vascular cell",
    ],
    "Progenitor/Stem": [
        "neuroblast (sensu Vertebrata)",
        "neuroblast (sensu Nematoda and Protostomia)",
        "forebrain neuroblast",
        "neural progenitor cell",
        "neural stem cell",
        "pluripotent stem cell",
        "glioblast",
        "neural crest cell",
        "premigratory neural crest cell",
        "retinal progenitor cell",
    ],
    "Other": [
        "fibroblast",
        "fibroblast of choroid plexus",
        "epithelial cell",
        "choroid plexus epithelial cell",
        "mesenchymal cell",
        "cell",
        "unknown",
    ],
}

# Flat fine -> coarse lookup, built by inverting the grouped table above.
coarse_map = {
    fine: coarse for coarse, fines in _coarse_groups.items() for fine in fines
}

# Collapse fine-grained labels to coarse classes; any label that has no entry
# in coarse_map falls back to "Other".
coarse_labels = adata.obs["cell_type"].map(coarse_map)
coarse_labels = coarse_labels.fillna("Other")
adata.obs["cell_type_coarse"] = coarse_labels.astype("category")

# %%


def plot_embedding(
    adata,
    embedding,
    obsm_key,
    color_key,
    save_dir,
    dpi=300,
    palette=None,
    *,
    rasterized=False,
):
    """Scatter-plot a 2-D embedding coloured by a categorical obs column.

    The figure is written to ``save_dir / f"brain_{embedding}_{color_key}.pdf"``
    and is left open (and current) so interactive front-ends can still display
    it; call ``plt.close()`` afterwards when running many plots in a script.

    Parameters
    ----------
    adata
        Object exposing ``obsm[obsm_key]`` (coordinates; the first two columns
        are plotted) and ``obs[color_key]`` (labels, coerced to categorical).
    embedding
        Display name used for the axis labels and the output file name.
    obsm_key
        Key of the coordinate array in ``adata.obsm``.
    color_key
        Column of ``adata.obs`` that determines point colour.
    save_dir
        Output directory; created if it does not exist.
    dpi
        Forwarded to ``savefig``.
    palette
        Optional sequence of colours, cycled if shorter than the number of
        categories. ``None`` or an empty sequence selects matplotlib's
        ``tab20``.
    rasterized
        Forwarded to ``scatter``. ``True`` stores the point cloud as a bitmap
        inside the PDF, which keeps very large plots small; the default keeps
        the points as vector markers.

    Raises
    ------
    ValueError
        If ``adata.obs[color_key]`` has no categories to plot.
    """
    coords = np.asarray(adata.obsm[obsm_key])
    labels = adata.obs[color_key].astype("category")
    codes = labels.cat.codes.to_numpy()
    categories = labels.cat.categories
    n_categories = len(categories)
    if n_categories == 0:
        raise ValueError(f"adata.obs[{color_key!r}] has no categories to plot")

    if palette is None or len(palette) == 0:
        base_cmap = plt.get_cmap("tab20")
        colors = [base_cmap(i % base_cmap.N) for i in range(n_categories)]
    else:
        colors = [palette[i % len(palette)] for i in range(n_categories)]
    cmap = ListedColormap(colors)

    # pandas encodes missing labels as code -1; such points have no legend
    # entry, so they are left out instead of borrowing another class's colour.
    labelled = codes >= 0

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.scatter(
        coords[labelled, 0],
        coords[labelled, 1],
        c=codes[labelled],
        cmap=cmap,
        # Pin the normalisation so that code k always lands on colors[k].
        # With auto-scaling the mapping shifts whenever the lowest or highest
        # category is absent from the data, and the legend no longer matches.
        vmin=-0.5,
        vmax=n_categories - 0.5,
        s=0.1,
        alpha=0.7,
        rasterized=rasterized,
    )
    ax.set_xlabel(f"{embedding} 1")
    ax.set_ylabel(f"{embedding} 2")

    handles = [
        mpatches.Patch(color=color, label=cat)
        for color, cat in zip(colors, categories)
    ]
    ax.legend(
        handles=handles,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.15),
        # Two legend rows where possible; a single category would otherwise
        # request zero columns.
        ncol=max(1, n_categories // 2),
        fontsize="small",
        frameon=False,
    )
    fig.tight_layout()

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_dir / f"brain_{embedding}_{color_key}.pdf", dpi=dpi)


# %%
# Settings shared by every embedding figure below.
color_key = "cell_type_coarse"
save_dir = "../results"
dpi = 300

# MetBrewer "Archambault" discrete palette, as 8-digit RGBA hex codes.
# source: https://emilhvitfeldt.github.io/r-color-palettes/discrete/MetBrewer/Archambault/
palette = (
    "#88A0DCFF #381A61FF #7C4B73FF #ED968CFF "
    "#AB3329FF #E78429FF #F9D14AFF"
).split()

# %%
plot_embedding(
    adata,
    embedding="DAE",
    obsm_key="X_ctmc",
    color_key=color_key,
    save_dir=save_dir,
    dpi=dpi,
    palette=palette,
)

# %%
plot_embedding(
    adata,
    embedding="UMAP",
    obsm_key="X_umap",
    color_key=color_key,
    save_dir=save_dir,
    dpi=dpi,
    palette=palette,
)

# %%
plot_embedding(
    adata,
    embedding="t-SNE",
    obsm_key="X_tsne",
    color_key=color_key,
    save_dir=save_dir,
    dpi=dpi,
    palette=palette,
)

# %%
