from typing import Any, Dict, List, Optional
from anndata import AnnData
import scanpy as sc
import numpy as np
import pandas as pd


def _check_gene_existence(adata: AnnData, perturb_key: str):
    """Flag the genes in ``adata.var`` that are named as perturbation targets.

    Each entry of ``adata.obs[perturb_key]`` is read as a comma-separated
    list of gene names. The result is written to ``adata.var["targeted"]``
    as a float column: 1.0 for every gene named by at least one cell, 0.0
    otherwise. Names that do not occur in ``adata.var_names`` are ignored.

    Args:
        adata: Annotated data matrix; ``adata.var`` is modified in place.
        perturb_key: Column of ``adata.obs`` holding the perturbation labels.
    """
    all_genes = set(adata.var_names)
    targeted_genes = set()
    for item in adata.obs[perturb_key]:
        # Empty strings and missing values (e.g. NaN) name no target.
        if not isinstance(item, str) or not item:
            continue
        targeted_genes.update(
            gene for gene in item.split(',') if gene in all_genes
        )

    # Build the mask from the var index alone so that this works whether or
    # not per-gene statistics such as var["means"] have been computed yet.
    var_index = adata.var.index
    presence = pd.Series(
        var_index.isin(targeted_genes).astype(float),
        index=var_index,
    )
    adata.var["targeted"] = presence


def _filter_hvg(adata: AnnData,
                perturb_key: str,
                n_top_genes: int,
                span: float,
                hvg_rank_thd: int,
                gene_perturbation: bool):
    """Subset ``adata`` to highly variable genes, optionally plus targets.

    Runs ``sc.pp.highly_variable_genes`` (flavor ``seurat_v3``) on the
    ``"counts"`` layer and keeps every gene whose ``highly_variable_rank``
    is at most ``hvg_rank_thd``. When ``gene_perturbation`` is true, genes
    flagged in ``adata.var["targeted"]`` are kept as well.

    Args:
        adata: Annotated data matrix with a ``"counts"`` layer; its ``var``
            table is annotated in place.
        perturb_key: Column of ``adata.obs`` holding the perturbation labels.
        n_top_genes: Forwarded to ``sc.pp.highly_variable_genes``.
        span: Forwarded to ``sc.pp.highly_variable_genes``.
        hvg_rank_thd: Largest ``highly_variable_rank`` that is retained.
        gene_perturbation: Whether perturbation labels name genes that must
            be retained regardless of their variability rank.

    Returns:
        A copy of ``adata`` restricted to the retained genes.
    """
    sc.pp.highly_variable_genes(
        adata,
        layer="counts",
        flavor='seurat_v3',
        n_top_genes=n_top_genes,
        span=span
    )
    to_keep = adata.var["highly_variable_rank"] <= hvg_rank_thd

    if gene_perturbation:
        # Flag the targets only after the HVG call, so that the per-gene
        # statistics it writes to adata.var (e.g. "means") already exist
        # when the helper runs.
        _check_gene_existence(adata, perturb_key)
        to_keep = np.logical_or(to_keep, adata.var["targeted"])

    return adata[:, to_keep].copy()


def pp_adata(adata: AnnData,
             perturb_key: str = 'perturbation',
             gene_perturbation: bool = True,
             n_top_genes: int = 1000,
             min_genes: int = 10,
             min_cells: int = 100,
             target_sum: float = 1e5,
             span: float = 0.2,
             hvg_rank_thd: int = 1500) -> AnnData:
    """Filter, normalize and gene-subset an AnnData object.

    The steps are: ``sc.pp.filter_cells`` and ``sc.pp.filter_genes``, a copy
    of ``adata.X`` stored in ``adata.layers["counts"]``,
    ``sc.pp.normalize_total`` and ``sc.pp.log1p``, then gene selection via
    ``_filter_hvg``.

    Note that the filtering and normalization calls operate on the ``adata``
    object passed in, while the returned object is a copy restricted to the
    selected genes.

    Args:
        adata: Annotated data matrix to preprocess.
        perturb_key: Column of ``adata.obs`` holding the perturbation labels.
        gene_perturbation: Whether perturbation labels name genes that must
            be retained by the gene selection.
        n_top_genes: Forwarded to the highly-variable-gene selection.
        min_genes: Forwarded to ``sc.pp.filter_cells``.
        min_cells: Forwarded to ``sc.pp.filter_genes``.
        target_sum: Forwarded to ``sc.pp.normalize_total``.
        span: Forwarded to the highly-variable-gene selection.
        hvg_rank_thd: Largest ``highly_variable_rank`` that is retained.

    Returns:
        The preprocessed AnnData restricted to the selected genes.

    Raises:
        KeyError: If ``gene_perturbation`` is true and ``perturb_key`` is
            not a column of ``adata.obs``.
    """
    # Fail before touching adata rather than midway through the pipeline.
    if gene_perturbation and perturb_key not in adata.obs:
        raise KeyError(
            f"perturb_key '{perturb_key}' is not a column of adata.obs"
        )

    print('Preprocessing AnnData.')
    n_cell, n_gene = adata.shape
    sc.pp.filter_cells(adata, min_genes=min_genes)
    sc.pp.filter_genes(adata, min_cells=min_cells)
    # Keep the pre-normalization matrix; the gene selection reads this layer.
    adata.layers["counts"] = adata.X.copy()
    sc.pp.normalize_total(adata, target_sum=target_sum)
    sc.pp.log1p(adata)

    # Keep highly variable genes (and, if requested, the perturbed genes).
    adata = _filter_hvg(
        adata,
        perturb_key,
        n_top_genes,
        span,
        hvg_rank_thd,
        gene_perturbation
    )
    # Report what was removed, not what remains.
    print(
        f'Filtered out {n_cell - adata.n_obs}/{n_cell} cells, '
        f'{n_gene - adata.n_vars}/{n_gene} genes.'
    )
    return adata
    