import numpy as np
import scanpy as sc
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn.functional as F

def annotation_baseline(adata, celltype_key, dataset='immune', model_name='scvi', hvg=None, seed=10, *kwargs):
    """Train one baseline cell-type annotation model and score it on the test split.

    Cells whose ``adata.obs['split']`` is ``'test'`` form the query set; all
    other cells form the reference (training) set. Labels from
    ``adata.obs[celltype_key]`` are integer-encoded with a ``LabelEncoder``
    fitted on all cells, the model selected by ``model_name`` is trained on the
    reference cells, and its predictions on the query cells are passed to
    ``annotation_eval``.

    Args:
        adata: Annotated data object handed to scanpy (presumably an
            ``anndata.AnnData``). Reads ``obs['split']``, ``obs[celltype_key]``
            and, depending on the branch, ``layers['counts']``. The object may
            be modified in place (gene subsetting, ``var`` re-indexing,
            normalisation), so pass a copy if the caller needs the original.
        celltype_key: Column of ``adata.obs`` holding the cell-type labels.
        dataset: Dataset name. ``'PBMC12K_raw'`` triggers extra preprocessing
            and loads the split from ``./data/PBMC12K_processed.h5ad``. The
            name also selects the scANVI batch key and the TOSICA project
            directory.
        model_name: One of ``'actinn'``, ``'scdeepsort'``, ``'celltypist'``,
            ``'singlecellnet'``, ``'scgpt'``, ``'tosica'`` or ``'scanvi'``.
            NOTE: the default ``'scvi'`` matches none of these and raises
            ``NotImplementedError``.
        hvg: If not ``None``, subset ``adata`` to highly variable genes
            (``seurat_v3`` flavor on ``layers['counts']``). An integer value is
            forwarded to scanpy as ``n_top_genes``.
        seed: Seed given to scvi-tools and used in the TOSICA project name.
        *kwargs: Extra positional arguments; accepted and ignored.

    Returns:
        The metrics dict produced by ``annotation_eval``, or ``None`` when
        ``model_name == 'scgpt'`` (that baseline is run elsewhere).

    Raises:
        NotImplementedError: If ``model_name`` is not a supported model.
    """
    if hvg is not None:
        # Forward an integer `hvg` as the number of genes to keep; previously
        # the value was only tested against None and never reached scanpy.
        hvg_kwargs = {}
        if isinstance(hvg, (int, np.integer)) and not isinstance(hvg, bool):
            hvg_kwargs['n_top_genes'] = int(hvg)
        # NOTE(review): this runs before the 'PBMC12K_raw' branch creates
        # layers['counts'], so the layer must already exist here.
        sc.pp.highly_variable_genes(adata, layer='counts', flavor='seurat_v3', subset=True, **hvg_kwargs)

    if dataset == 'PBMC12K_raw':  # alternative considered: dataset.endswith('raw')
        # Re-index genes by symbol, keeping the previous index in `ensg`.
        adata.var['ensg'] = adata.var.index
        adata.var = adata.var.set_index('gene_symbols')
        adata.var_names_make_unique()
        sc.pp.filter_genes(adata, min_cells=1)
        sc.pp.filter_cells(adata, min_genes=1)
        # Keep the unnormalised matrix before log-normalising X.
        adata.layers['counts'] = adata.X.copy()
        sc.pp.normalize_total(adata, target_sum=1e4, key_added='library_size')
        sc.pp.log1p(adata)

        import anndata as ad
        adata_processed = ad.read_h5ad('./data/PBMC12K_processed.h5ad', backed='r')
        try:
            # Assignment aligns on the obs index, so only shared cells get a split.
            adata.obs['split'] = adata_processed.obs['split']
        finally:
            # Backed mode keeps the HDF5 file open; release it explicitly.
            adata_processed.file.close()

    ref_adata = adata[adata.obs.split != 'test']
    query_adata = adata[adata.obs.split == 'test']

    # Fit the encoder on all cells so reference and query share one label space.
    celltype_le = LabelEncoder()
    y = celltype_le.fit_transform(adata.obs[celltype_key])
    y_train = celltype_le.transform(ref_adata.obs[celltype_key])
    y_test = celltype_le.transform(query_adata.obs[celltype_key])

    if model_name == 'actinn':
        from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
        from dance.transforms import FilterGenesPercentile
        from dance.transforms.base import AnnDataAdaptor

        print(adata.shape)
        sum_filter = AnnDataAdaptor(FilterGenesPercentile(mode='sum'))
        var_filter = AnnDataAdaptor(FilterGenesPercentile(mode='var'))
        adata = sum_filter(adata)
        adata = var_filter(adata)

        # Alternative not in use: restrict to the top 2000 HVGs when
        # adata.shape[1] > 2000 (seurat_v3 on layers['counts']).
        print(adata.shape)
        train_idx = np.arange(len(adata))[(adata.obs.split != 'test').values]
        test_idx = np.arange(len(adata))[(adata.obs.split == 'test').values]
        # Densify X whether it is sparse or already dense (the previous `.A`
        # shorthand only exists on some matrix types).
        dense_x = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
        x = torch.tensor(dense_x)
        y = F.one_hot(torch.tensor(y).long())
        x_train = x[train_idx]
        x_test = x[test_idx]

        model = ACTINN(device='cuda:0', hidden_dims=[100, 50, 25], lambd=0.001)
        num_epochs = 100
        # Alternative learning rate tried: 1e-4.
        model.fit(x_train, y[train_idx], lr=1e-3, num_epochs=num_epochs, batch_size=128, print_cost=True)
        y_pred = model.predict(x_test).cpu()

    elif model_name == 'scdeepsort':
        from dance.modules.single_modality.cell_type_annotation.scdeepsort import ScDeepSort
        from dance.transforms.base import AnnDataAdaptor

        # Reorder cells by split (descending) so that test cells come last,
        # then rebuild BOTH label vectors from the reordered object. The
        # earlier y_train followed the pre-sort order and could be misaligned
        # with the training cell nodes of the graph built below.
        adata = adata[adata.obs.sort_values('split', ascending=False).index]
        y_train = celltype_le.transform(adata[adata.obs.split != 'test'].obs[celltype_key])
        y_test = celltype_le.transform(adata[adata.obs.split == 'test'].obs[celltype_key])
        train_size = int((adata.obs.split != 'test').sum())

        # Alternative settings tried: dropout=0.1.
        model = ScDeepSort(400, 200, 1, 'human', 'unknown', dropout=0.5, batch_size=500, device='cuda:0')
        preprocessing_pipeline = AnnDataAdaptor(model.preprocessing_pipeline(n_components=400), train_size=train_size)
        # Alternative not in use: restrict to the top 5000 HVGs when
        # adata.shape[1] > 5000 (seurat_v3 on layers['counts']).
        adata = preprocessing_pipeline(adata)
        print(adata.shape)

        g = adata.uns["CellFeatureGraph"]
        num_genes = adata.shape[1]
        gene_ids = torch.arange(num_genes)
        train_idx = np.arange(len(adata))[(adata.obs.split != 'test').values]
        test_idx = np.arange(len(adata))[(adata.obs.split == 'test').values]
        # Cell node ids are offset by the number of gene nodes in the graph.
        train_cell_ids = torch.LongTensor(train_idx) + num_genes
        test_cell_ids = torch.LongTensor(test_idx) + num_genes
        g_train = g.subgraph(torch.concat((gene_ids, train_cell_ids)))
        g_test = g.subgraph(torch.concat((gene_ids, test_cell_ids)))

        # Train and evaluate the model.
        y_train = torch.LongTensor(y_train)

        # Alternative settings tried: lr 1e-3, weight_decay 5e-4, epochs 300.
        model.fit(g_train, y_train, epochs=300, lr=1e-3, weight_decay=0, val_ratio=0.1)
        y_pred = model.predict(g_test)

    elif model_name == 'celltypist':
        import celltypist

        # Mini-batch training only for reference sets above 1000 cells.
        mini_batch = len(ref_adata) > 1000
        model = celltypist.train(ref_adata, y_train, feature_selection=True, n_jobs=64,
                                 use_SGD=True, mini_batch=mini_batch)
        y_pred = celltypist.annotate(query_adata, model).predicted_labels.values[:, 0]

    elif model_name == 'singlecellnet':
        import pySingleCellNet as pySCN

        if 'index' in adata.var.columns:
            del adata.var['index']
        adata.var.index = adata.var.index.str.replace('_', '|')
        # This branch works on the counts layer, so rebuild both splits afterwards.
        adata.X = adata.layers['counts'].copy()
        ref_adata = adata[adata.obs.split != 'test']
        query_adata = adata[adata.obs.split == 'test']
        [cgenesA, xpairs, tspRF] = pySCN.scn_train(ref_adata, nTopGenes=100, nRand=100, nTrees=1000,
                                                   nTopGenePairs=100, dLevel=celltype_key,
                                                   stratify=True, limitToHVG=True)
        adTest = pySCN.scn_classify(query_adata, cgenesA, xpairs, tspRF, nrand=0)
        # Take truth and prediction from the same returned object so they stay aligned.
        y_test = celltype_le.transform(adTest.obs[celltype_key])
        y_pred = celltype_le.transform(adTest.obs['SCN_class'])

    elif model_name == 'scgpt':  # too long, keep it in notebook
        return None

    elif model_name == 'tosica':
        import TOSICA

        TOSICA.train(ref_adata, query_adata, gmt_path='human_gobp', label_name=celltype_key, epochs=10,
                     project=f'{dataset}_{seed}')
        # 'model-9.pth' is presumably the checkpoint of the last of the 10 epochs above.
        model_weight_path = f'{dataset}_{seed}/model-9.pth'
        pred_df = TOSICA.pre(query_adata, model_weight_path=model_weight_path, project=f'{dataset}_{seed}')
        y_pred = celltype_le.transform(pred_df['Prediction'])

    elif model_name == 'scanvi':
        import scvi as scvi_
        from scvi.model import SCVI, SCANVI

        scvi_.settings.seed = seed

        # Defaults from SCVI github tutorials scanpy_pbmc3k and harmonization
        n_latent = 30
        n_hidden = 128
        n_layers = 2

        # Copy so that fields added by setup_anndata do not leak into the input.
        net_adata = ref_adata.copy()
        net_adata.X = net_adata.layers['counts'].copy()
        # The obs column holding the batch label depends on the dataset.
        if 'HLCA' in dataset or dataset == 'PBMC12K':
            batch_key = 'batch'
        elif dataset == 'Pancreas':
            batch_key = 'tech'
        else:
            batch_key = 'donor_id'
        SCVI.setup_anndata(net_adata, batch_key=batch_key)

        vae = SCVI(
            net_adata,
            gene_likelihood="nb",
            n_layers=n_layers,
            n_latent=n_latent,
            n_hidden=n_hidden,
        )
        train_kwargs = {"train_size": 1.0}
        vae.train(**train_kwargs)
        scanvae = SCANVI.from_scvi_model(
            scvi_model=vae,
            labels_key=celltype_key,
            unlabeled_category="Unknown",
        )
        scanvae.train(max_epochs=20, n_samples_per_label=100)  # from scanvi tutorial
        # Mask the query labels; y_test was already encoded above from the true labels.
        query_adata.obs[celltype_key] = 'Unknown'
        query_adata.X = query_adata.layers['counts'].copy()
        y_pred = celltype_le.transform(scanvae.predict(query_adata))

    else:
        raise NotImplementedError(f"Unsupported model {model_name}")

    scores = annotation_eval(y_pred, y_test, num_classes=len(celltype_le.classes_))
    return scores

# def annotation_eval(pred_labels, true_labels, num_classes=None):
#     from torchmetrics.functional.classification import (
#         multiclass_f1_score, 
#         multiclass_accuracy,
#         multiclass_precision,
#         multiclass_recall
#     )
#     true_labels = true_labels if torch.is_tensor(true_labels) else torch.tensor(true_labels)
#     pred_labels = pred_labels if torch.is_tensor(pred_labels) else torch.tensor(pred_labels)
#     num_classes = len(true_labels.unique()) if num_classes is None else num_classes
#     acc = multiclass_accuracy(pred_labels, true_labels, num_classes).cpu().item()
#     f1_score = multiclass_f1_score(pred_labels, true_labels, num_classes).cpu().item()
#     precision = multiclass_precision(pred_labels, true_labels, num_classes).cpu().item()
#     recall = multiclass_recall(pred_labels, true_labels, num_classes).cpu().item()
#     return {'acc': acc, 'f1_score': f1_score, 'precision': precision, 'recall': recall}

def annotation_eval(pred_labels, true_labels, num_classes=None):
    """Compute classification metrics for predicted versus true labels.

    Args:
        pred_labels: Predicted class labels; any array-like that
            ``np.asarray`` can convert.
        true_labels: Ground-truth class labels, same length as ``pred_labels``.
        num_classes: Accepted for backward compatibility with existing callers
            and not used. The previous code derived it via
            ``true_labels.unique()``, which NumPy arrays do not provide, so
            every call that omitted it raised ``AttributeError``.

    Returns:
        dict with keys ``'acc'`` (balanced accuracy), ``'f1_score'``,
        ``'precision'`` and ``'recall'``; the last three are macro-averaged.
    """
    from sklearn.metrics import balanced_accuracy_score, f1_score, precision_score, recall_score

    true_labels = np.asarray(true_labels)
    pred_labels = np.asarray(pred_labels)

    return {
        'acc': balanced_accuracy_score(true_labels, pred_labels),
        'f1_score': f1_score(true_labels, pred_labels, average='macro'),
        'precision': precision_score(true_labels, pred_labels, average='macro'),
        'recall': recall_score(true_labels, pred_labels, average='macro'),
    }