import numpy as np
import scanpy as sc
import anndata as ad
import baseline.scgen as scgen # XXX: https://github.com/theislab/scgen-reproducibility/

# Cell type(s) held out (stimulated cells only) as the test set for each
# dataset when the caller does not pass ``test_cell_types``.
DEFAULT_CELL_TYPE_DICT = dict(
    pbmc=["CD4 T cells"],  # alternative label: ["CD4T"]
    hpoly=["TA.Early"],
    salmonella=["TA.Early"],
)

# (control label, stimulated label) found in the perturbation column of each dataset.
PERT_DICT = dict(
    pbmc=("ctrl", "stim"),  # alternative labels: ("control", "stimulated")
    hpoly=("Control", "Hpoly.Day10"),
    salmonella=("Control", "Salmonella"),
)

def prepare_split(adata, dataset, celltype_key, pert_key='condition', test_cell_types=None):
    """Add an ``obs['split']`` column that holds out stimulated test-cell-type cells.

    The column starts as a string copy of ``obs['train_valid_split']``. Every cell
    whose cell type is in ``test_cell_types`` AND whose ``obs[pert_key]`` equals the
    dataset's stimulated label (``PERT_DICT[dataset][1]``) is then overwritten
    with ``'test'``. Control cells of the test cell types keep their original label.

    Args:
        adata: AnnData with ``obs`` columns ``'train_valid_split'``,
            ``celltype_key`` and ``pert_key``. It is modified in place.
        dataset: key into ``PERT_DICT`` / ``DEFAULT_CELL_TYPE_DICT``.
        celltype_key: ``obs`` column holding cell-type labels.
        pert_key: ``obs`` column holding the perturbation label.
        test_cell_types: cell types to hold out; defaults to
            ``DEFAULT_CELL_TYPE_DICT[dataset]``.

    Returns:
        The same ``adata`` object, with ``obs['split']`` set.

    Raises:
        ValueError: if any requested test cell type is absent from
            ``obs[celltype_key]``.
    """
    if test_cell_types is None:
        test_cell_types = DEFAULT_CELL_TYPE_DICT[dataset]
    present = set(adata.obs[celltype_key].unique())
    missing = [x for x in test_cell_types if x not in present]
    if missing:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise ValueError(f"test cell types not found in obs[{celltype_key!r}]: {missing}")
    _, stim_key = PERT_DICT[dataset]

    is_test = adata.obs[celltype_key].isin(test_cell_types) & (adata.obs[pert_key] == stim_key)
    # Build the column locally and assign it once. The previous chained form
    # `adata.obs['split'][mask] = ...` is a silent no-op under pandas Copy-on-Write.
    split = adata.obs['train_valid_split'].astype(str)
    split.loc[is_test] = 'test'
    adata.obs['split'] = split

    return adata

_SUPPORTED_MODELS = ('scgen', 'cvae', 'vec', 'pca_vec', 'cpa')

def _to_dense(x):
    """Return ``x`` as a dense ndarray, accepting scipy sparse matrices, ``np.matrix`` or arrays."""
    if hasattr(x, "toarray"):
        return x.toarray()
    return np.asarray(x)

def _build_eval_adata(adata_test, adata_template, pred, pert_key):
    """Wrap ``pred`` in a copy of ``adata_template`` labelled ``'pred'`` and append it to ``adata_test``."""
    adata_pred = adata_template.copy()
    adata_pred.X = pred
    adata_pred.obs[pert_key] = 'pred'
    # Replace the copied cell names with a fresh '0'..'n-1' string index.
    adata_pred.obs.index = range(len(adata_pred))
    adata_pred.obs.index = adata_pred.obs.index.astype(str)
    return adata_test.concatenate(adata_pred)

def _balanced_ctrl_stim(adata_train_valid, pert_key, ctrl_key, stim_key, celltype_key, p_type):
    """Return dense (control X, stimulated X) from the train+valid cells, cell-type balanced if unbiased."""
    train_ctrl = adata_train_valid[adata_train_valid.obs[pert_key] == ctrl_key]
    train_stim = adata_train_valid[adata_train_valid.obs[pert_key] == stim_key]
    if p_type == "unbiased":
        train_ctrl = scgen.util.balancer(train_ctrl, cell_type_key=celltype_key)
        train_stim = scgen.util.balancer(train_stim, cell_type_key=celltype_key)
    return _to_dense(train_ctrl.X), _to_dense(train_stim.X)

def _mean_delta(ctrl_x, stim_x, p_type="unbiased"):
    """Return the row-mean of ``stim_x`` minus the row-mean of ``ctrl_x``.

    With ``p_type == "unbiased"`` both groups are first subsampled without
    replacement to the size of the smaller group; otherwise all rows are used.
    """
    if p_type == "unbiased":
        eq = min(len(ctrl_x), len(stim_x))
        ctrl_ind = np.random.choice(range(len(ctrl_x)), size=eq, replace=False)
        stim_ind = np.random.choice(range(len(stim_x)), size=eq, replace=False)
    else:
        ctrl_ind = np.arange(0, len(ctrl_x))
        stim_ind = np.arange(0, len(stim_x))
    ctrl_mean = np.average(ctrl_x[ctrl_ind, :], axis=0)
    stim_mean = np.average(stim_x[stim_ind, :], axis=0)
    return stim_mean - ctrl_mean

def perturbation_baseline(adata, celltype_key, dataset, pert_key='condition', test_cell_types=None, model='scvi', 
                          seed=10, hvg_7000=True, path_to_save='./baseline/results', *kwargs):
    """Train a perturbation baseline, predict held-out stimulated cells and score them.

    The stimulated cells of ``test_cell_types`` are held out (see ``prepare_split``).
    The chosen model predicts their expression from the matching control cells. The
    prediction is concatenated with the observed test-cell-type cells and passed
    to ``perturbation_eval``.

    Args:
        adata: AnnData with ``obs['train_valid_split']``, ``obs[celltype_key]``,
            ``obs[pert_key]`` and (unless ``model == 'cpa'`` or ``hvg_7000`` is
            False) ``var['highly_variable']``.
        celltype_key: ``obs`` column with cell-type labels.
        dataset: key into ``PERT_DICT`` / ``DEFAULT_CELL_TYPE_DICT``.
        pert_key: ``obs`` column with perturbation labels.
        test_cell_types: held-out cell types; defaults to
            ``DEFAULT_CELL_TYPE_DICT[dataset]``.
        model: one of ``'scgen'``, ``'cvae'``, ``'vec'``, ``'pca_vec'``, ``'cpa'``.
            The default ``'scvi'`` is not supported and raises.
        seed: passed to CPA/scvi and used in the output file name. The other
            models do not use it.
        hvg_7000: subset to highly variable genes (ignored for CPA).
        path_to_save: directory for the evaluation plot.
        *kwargs: extra positional arguments, ignored.

    Returns:
        Whatever ``perturbation_eval`` returns for the evaluation AnnData.

    Raises:
        NotImplementedError: for an unsupported ``model``.
    """
    # Validate before touching data so an unsupported model has no side effects.
    if model not in _SUPPORTED_MODELS:
        raise NotImplementedError(f"Unsupported model {model}")
    if test_cell_types is None:
        test_cell_types = DEFAULT_CELL_TYPE_DICT[dataset]

    if hvg_7000 and model != 'cpa':
        adata = adata[:, adata.var.highly_variable]

    # pert_key must be forwarded, otherwise prepare_split always reads obs['condition'].
    adata = prepare_split(adata, dataset, celltype_key, pert_key=pert_key, test_cell_types=test_cell_types)
    ctrl_key, stim_key = PERT_DICT[dataset]

    split = adata.obs['split']
    is_test_type = adata.obs[celltype_key].isin(test_cell_types)
    adata_train = adata[split == 'train']
    adata_valid = adata[split == 'valid']
    adata_train_valid = adata[split != 'test']
    adata_test_ctrl = adata[is_test_type & (adata.obs[pert_key] == ctrl_key)]
    adata_test_stim = adata[split == 'test']
    # Copy before densifying. Assigning .X on an AnnData view may write into the
    # parent object (anndata warns about this); for CPA the parent is the caller's data.
    adata_test = adata[is_test_type].copy()
    adata_test.X = _to_dense(adata_test.X)

    if model == 'scgen':
        network = scgen.VAEArith(x_dimension=adata_train.X.shape[1],
                                 z_dimension=100,
                                 alpha=5e-5,
                                 dropout_rate=0.2,
                                 learning_rate=1e-3,
                                 model_path=f"./models/scGen/{dataset}")
        network.train(adata_train, use_validation=True, valid_data=adata_valid, n_epochs=300, batch_size=32)
        pred, _delta = network.predict(adata=adata_train,
                                       adata_to_predict=adata_test_ctrl,
                                       conditions={"ctrl": ctrl_key, "stim": stim_key},
                                       cell_type_key=celltype_key,
                                       condition_key=pert_key)
        adata_eval = _build_eval_adata(adata_test, adata_test_ctrl, pred, pert_key)

    elif model == 'cvae':
        network = scgen.CVAE(x_dimension=adata_train.X.shape[1],
                             z_dimension=20,
                             alpha=0.1,
                             model_path=f"./models/CVAE/{dataset}")
        network.train(adata_train, use_validation=True, valid_data=adata_valid, n_epochs=100)
        # Label 1 for every test control cell; presumably CVAE's "stimulated" condition code -- confirm.
        fake_labels = np.ones((len(adata_test_ctrl), 1))
        pred = network.predict(adata_test_ctrl, fake_labels)
        adata_eval = _build_eval_adata(adata_test, adata_test_ctrl, pred, pert_key)

    elif model == 'vec':
        p_type = "unbiased"
        train_ctrl_x, train_stim_x = _balanced_ctrl_stim(
            adata_train_valid, pert_key, ctrl_key, stim_key, celltype_key, p_type)
        delta = _mean_delta(train_ctrl_x, train_stim_x, p_type)
        pred = delta + _to_dense(adata_test_ctrl.X)
        adata_eval = _build_eval_adata(adata_test, adata_test_ctrl, pred, pert_key)

    elif model == 'pca_vec':
        from sklearn.decomposition import PCA

        p_type = "unbiased"
        train_ctrl_x, train_stim_x = _balanced_ctrl_stim(
            adata_train_valid, pert_key, ctrl_key, stim_key, celltype_key, p_type)
        pca = PCA(n_components=100)
        pca.fit(_to_dense(adata_train_valid.X))
        train_stim_pca = pca.transform(train_stim_x)
        train_ctrl_pca = pca.transform(train_ctrl_x)
        test_ctrl_pca = pca.transform(_to_dense(adata_test_ctrl.X))
        # Apply the mean shift in PCA space, then map back to gene space.
        delta = _mean_delta(train_ctrl_pca, train_stim_pca, p_type)
        pred = pca.inverse_transform(delta + test_ctrl_pca)
        adata_eval = _build_eval_adata(adata_test, adata_test_ctrl, pred, pert_key)

    elif model == 'cpa':  # TODO: cpa requires raw counts input
        import cpa
        import scvi

        scvi.settings.seed = seed
        cpa.CPA.setup_anndata(adata,
                              perturbation_key=pert_key,
                              control_group=ctrl_key,
                              categorical_covariate_keys=[celltype_key],
                              is_count_data=False,
                              max_comb_len=1,
                              )
        model_params = {
            'n_latent': 32,
            'recon_loss': 'zinb',
            'n_hidden_encoder': 1024,
            'n_layers_encoder': 3,
            'n_hidden_decoder': 1024,
            'n_layers_decoder': 3,
            'use_batch_norm_encoder': False,
            'use_layer_norm_encoder': True,
            'use_batch_norm_decoder': False,
            'use_layer_norm_decoder': False,
            'dropout_rate_encoder': 0.1,
            'dropout_rate_decoder': 0.2,
            'variational': False,
            'seed': seed
        }

        trainer_params = {
            'n_epochs_adv_warmup': 10,
            'n_epochs_mixup_warmup': 5,
            'n_epochs_pretrain_ae': 5,
            'mixup_alpha': 0.0,
            'lr': 0.0001,
            'wd': 4e-07,
            'adv_steps': 5,
            'reg_adv': 40.0,
            'pen_adv': 30.0,
            'adv_lr': 0.001,
            'adv_wd': 4e-07,
            'n_layers_adv': 2,
            'n_hidden_adv': 128,
            'use_batch_norm_adv': False,
            'use_layer_norm_adv': False,
            'dropout_rate_adv': 0.3,
            'step_size_lr': 10,
            'do_clip_grad': False,
            'adv_loss': 'focal',
            'gradient_clip_value': 1.0,
            'n_epochs_verbose': 5,
        }

        cpa_model = cpa.CPA(adata=adata,
                            split_key='split',
                            train_split='train',
                            valid_split='valid',
                            test_split='test',
                            **model_params,
                            )
        cpa_model.train(max_epochs=2000,
                        use_gpu=True,
                        batch_size=512,
                        plan_kwargs=trainer_params,
                        early_stopping_patience=10,
                        check_val_every_n_epoch=10,
                        save_path=f'./models/CPA/{dataset}',
                        )
        cpa_model.predict(adata, batch_size=2048)

        # Same mask, hence same row order, as adata_test_stim.
        pred = adata[adata.obs.split == 'test'].obsm['CPA_pred']
        adata_eval = _build_eval_adata(adata_test, adata_test_stim, pred, pert_key)

    else:  # unreachable after the early check; kept as a guard
        raise NotImplementedError(f"Unsupported model {model}")

    scores = perturbation_eval(adata_eval, condition_key=pert_key, ctrl_key=ctrl_key, stim_key=stim_key, dataset=dataset,
                               model=model, path_to_save=path_to_save, seed=seed)

    return scores

def perturbation_eval(adata, condition_key='condition', ctrl_key='Control', stim_key='stimulated', dataset='pbmc', 
                      model='scgen', path_to_save='./baseline/results', seed=10):
    """Score predicted cells against observed ones with scGen's ``reg_mean_plot``.

    Differentially expressed genes are ranked (Wilcoxon) on the observed cells
    only, i.e. those whose ``obs[condition_key]`` is not ``'pred'``. The genes
    ranked for ``stim_key`` are then passed to ``reg_mean_plot``: the top 10 as
    ``gene_list`` and the top 100 as ``top_100_genes``.

    Args:
        adata: AnnData whose ``obs[condition_key]`` contains ``ctrl_key``,
            ``stim_key`` and ``'pred'``.
        condition_key: ``obs`` column with condition labels.
        ctrl_key: control label.
        stim_key: stimulated (ground-truth) label.
        dataset, model, seed: used only in the plot title and file name.
        path_to_save: directory for the PNG; created if it does not exist.

    Returns:
        The value returned by ``scgen.plotting.reg_mean_plot``.
    """
    import os

    # Explicit copy so rank_genes_groups writes .uns into a real AnnData, not a view.
    adata_sub = adata[adata.obs[condition_key] != 'pred'].copy()
    sc.tl.rank_genes_groups(adata_sub, groupby=condition_key, method="wilcoxon")
    diff_genes = adata_sub.uns["rank_genes_groups"]["names"][stim_key]
    # Ensure the output directory exists before the plot is written there.
    os.makedirs(path_to_save, exist_ok=True)
    scores = scgen.plotting.reg_mean_plot(
        adata,
        condition_key=condition_key,
        axis_keys={"x": "pred", "y": stim_key, "x1": ctrl_key},
        gene_list=diff_genes[:10],
        top_100_genes=diff_genes[:100],
        labels={"x": "predicted", "y": "ground truth", "x1": "ctrl"},
        path_to_save=f'{path_to_save}/{dataset}_{model}_seed{seed}_reg_mean.png',
        title=f'{dataset} {model}',
        show=False,
        legend=False,
    )
    return scores