import os

# Cap the thread pools of the common numerical backends at 8 threads each.
# This is done ahead of the numpy/torch imports further down, presumably so
# those libraries see the limits when they initialise -- keep it first.
os.environ.update({
    "OMP_NUM_THREADS": "8",
    "OPENBLAS_NUM_THREADS": "8",
    "MKL_NUM_THREADS": "8",
    "VECLIB_MAXIMUM_THREADS": "8",
    "NUMEXPR_NUM_THREADS": "8",
})

import argparse
import random
import os
import torch
import numpy as np
import anndata as ad
import os.path as osp

# def scnormalize(x, target_sum=1e4, log1p=True, eps=1e-8):
#     x = x * target_sum / (x.sum(1, keepdim=True) + eps)
#     if log1p:
#         x = torch.log1p(x)
#     return x

# Directory that the dataset file names below are joined onto.
DATADIR = './data'

# Per-task registries. Each one maps 'dataset' to {dataset name: file name}
# and 'model' to the list of baseline names accepted for that task.
_INTEGRATION_SPEC = {
    'dataset': {
        'HLCAsubset': 'HLCA_zstd_sub.h5ad',
        'Pancreas': 'Pancreas_processed.h5ad',
        'LungAtlas': 'LungAtlas_processed.h5ad',
        'LueckenImmune': 'LueckenImmune_processed.h5ad',
    },
    'model': [
        'scvi', 'scanvi', 'scanorama', 'mnn', 'harmony', 'bbknn',
        'scanvi_test', 'trvae', 'desc', 'combat', 'saucie',
    ],
}

_DENOISING_SPEC = {
    'dataset': {
        'Jurkat': 'Jurkat_processed.h5ad',
        'PBMC1K': 'PBMC1K_processed.h5ad',
        '293T': '293T_processed.h5ad',
        'HLCAsubset': 'HLCA_zstd_sub.h5ad',
    },
    # TODO: R script
    'model': [
        'dca', 'magic', 'alra', 'scgnn2',
        'deepimpute', 'saver', 'saver-x', 'scvi',
    ],
}

_PERTURBATION_SPEC = {
    'dataset': {
        'pbmc': 'pbmc_processed.h5ad',
        'salmonella': 'salmonella_processed.h5ad',
        'hpoly': 'hpoly_processed.h5ad',
    },
    # TODO: need more baselines
    'model': ['scgen', 'cvae', 'vec', 'pca_vec', 'cpa'],
}

_ANNOTATION_SPEC = {
    'dataset': {
        'HLCA_sub': 'HLCA_zstd_sub.h5ad',
        'Immune': 'Immune_processed.h5ad',
        'PBMC106K': 'PBMC106K_processed.h5ad',
        # TODO: one more dataset?
        'PBMC12K': 'PBMC12K_processed.h5ad',
        'PBMC12K_raw': 'PBMC12K.h5ad',
        'Zheng68K': 'Zheng68K_processed.h5ad',
        'Pancreas': 'Pancreas_processed.h5ad',
        'HLCAnaw': 'HLCA_zstd_Nawijin_GRO-09.h5ad',
        'Immune_sub': 'Immune_sub.h5ad',
        'Brain': 'Brain_processed.h5ad',
        'Liver': 'Liver_processed.h5ad',
        'Lung': 'LungAtlas_processed.h5ad',
    },
    'model': [
        'actinn', 'celltypist', 'singlecellnet', 'scgpt', 'tosica',
        'geneformer', 'scbert', 'scdeepsort', 'scanvi',
    ],
}

# Task name -> registry of the datasets and models valid for that task.
TASK_DATASET_MODEL_DICT = {
    'integration': _INTEGRATION_SPEC,
    'denoising': _DENOISING_SPEC,
    'perturbation': _PERTURBATION_SPEC,
    'annotation': _ANNOTATION_SPEC,
}

def set_seed(seed=10):
    """Seed the Python, NumPy and PyTorch random number generators.

    The seed is also written to the ``PYTHONHASHSEED`` environment variable
    (as a string).

    Args:
        seed: Integer seed handed to ``random.seed``, ``np.random.seed`` and
            ``torch.manual_seed``. Defaults to 10.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

def build_parser():
    """Build the command-line parser of the baseline runner.

    Returns:
        argparse.ArgumentParser: parser exposing the task/dataset/model
        selection, the seed and the few-shot / sweep switches.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='integration') # integration, denoising, perturbation, annotation
    parser.add_argument("--dataset", type=str, default='Pancreas')
    parser.add_argument("--model", type=str, default='scvi')
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--hvg", type=int, default=None)
    parser.add_argument("--mask_ratio", type=int, default=10)
    parser.add_argument("--n", type=int, default=5)
    parser.add_argument("--topk", type=int, default=3)
    parser.add_argument("--subset", action='store_true')
    parser.add_argument("--inverse", action='store_true')
    parser.add_argument("--target_only", action='store_true')
    parser.add_argument("-fs", "--few_shot", action='store_true')
    parser.add_argument("-iok", "--iterate_on_k", action='store_true')
    parser.add_argument("-ios", "--iterate_on_seed", action='store_true')
    parser.add_argument("-iomr", "--iterate_on_mask_ratio", action='store_true')
    parser.add_argument("-csv", "--save_csv", action='store_true')
    return parser


def denoising_datafile(datadir, dataset, mask_ratio):
    """Return the path of the denoising file for ``dataset`` at ``mask_ratio``.

    Args:
        datadir: Directory containing the data files.
        dataset: Dataset name used as the file name prefix.
        mask_ratio: Mask ratio embedded in the file name.

    Returns:
        str: ``<datadir>/<dataset>_<mask_ratio>_processed.h5ad``.
    """
    return osp.join(datadir, f'{dataset}_{mask_ratio}_processed.h5ad')


def format_run_summary(metrics):
    """Format per-run metric lists followed by their mean and std.

    Args:
        metrics: Mapping of label -> list of per-run values. Iteration order
            determines the output order.

    Returns:
        list[str]: one ``"label: values"`` line per metric, a separator line,
        then one ``"label: mean +/- std"`` line per metric (NaN-ignoring
        mean/std via ``np.nanmean`` / ``np.nanstd``).
    """
    lines = [f"{label}: {values}" for label, values in metrics.items()]
    lines.append('######################')
    lines.extend(
        f"{label}: {np.nanmean(values)} +/- {np.nanstd(values)}"
        for label, values in metrics.items()
    )
    return lines


def mark_train_split(obs, train_idx, column='split'):
    """Write a train/test split column into ``obs`` in a single assignment.

    Every row is labelled ``'test'`` except those selected by ``train_idx``,
    which are labelled ``'train'``. The column is built first and assigned
    once, instead of the chained ``obs[column][idx] = ...`` form, which pandas
    documents as unreliable and which has no effect under Copy-on-Write.

    Args:
        obs: pandas DataFrame to modify in place (e.g. ``adata.obs``).
        train_idx: Selection of training rows. Integer or boolean arrays are
            treated as positions / a mask; any other dtype is matched against
            ``obs.index`` labels.
        column: Name of the column to (over)write. Defaults to ``'split'``.
    """
    idx = np.asarray(train_idx)
    labels = np.full(len(obs), 'test', dtype=object)
    if idx.dtype.kind in 'iub':
        labels[idx] = 'train'
    else:
        labels[obs.index.isin(idx)] = 'train'
    obs[column] = labels


def main(argv=None):
    """Parse the command line, load the dataset and run the chosen baseline.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv``.

    Raises:
        NotImplementedError: if ``--task`` is not a registered task.
        ValueError: if ``--few_shot`` (or ``--few_shot --save_csv``) is used
            with a dataset for which no target cell types can be derived.
        SystemExit: via ``parser.error`` when the model or dataset is not
            registered for the task.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    set_seed(args.seed)

    # Explicit validation instead of ``assert`` (which ``python -O`` strips)
    # and instead of bare KeyErrors from the registry lookups.
    if args.task not in TASK_DATASET_MODEL_DICT:
        raise NotImplementedError(f"Unsupported task {args.task}")
    task_spec = TASK_DATASET_MODEL_DICT[args.task]
    if args.model not in task_spec['model']:
        parser.error(f"model {args.model!r} is not available for task {args.task!r}; "
                     f"choose from {task_spec['model']}")
    if args.dataset not in task_spec['dataset']:
        parser.error(f"dataset {args.dataset!r} is not available for task {args.task!r}; "
                     f"choose from {list(task_spec['dataset'])}")

    if args.task == 'denoising':
        # Denoising files are stored per mask ratio rather than under the
        # registry file name.
        datafile = denoising_datafile(DATADIR, args.dataset, args.mask_ratio)
        print(datafile)
    else:
        datafile = osp.join(DATADIR, task_spec['dataset'][args.dataset])
    adata = ad.read_h5ad(datafile)

    if args.task == 'integration':
        from baseline.integration import integration_baseline

        # dataset -> (batch column, cell type column) in ``adata.obs``
        integration_keys = {
            'HLCAsubset': ('batch', 'cell_type'),
            'Pancreas': ('tech', 'celltype'),
            'LungAtlas': ('batch', 'cell_type'),
            'LueckenImmune': ('batch', 'final_annotation'),
        }

        batch_key, celltype_key = integration_keys[args.dataset]
        scores = integration_baseline(adata, batch_key, celltype_key, model=args.model, hvg=args.hvg, seed=args.seed)
        print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {args.seed}")
        print(f"scores: {scores}")

    elif args.task == 'denoising':
        from baseline.denoising import denoising_baseline

        if args.iterate_on_mask_ratio:
            mask_ratios = [10, 30, 50, 70, 90, 95]
            for mr in mask_ratios:
                # Load the file of the ratio being swept (previously the path
                # was built from args.mask_ratio, reloading the same file).
                datafile = denoising_datafile(DATADIR, args.dataset, mr)
                adata = ad.read_h5ad(datafile)
                scores = denoising_baseline(adata, dataset=args.dataset, model=args.model, hvg=args.hvg)
                print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {args.seed}, MR: {mr}")
                print(f"scrores: {scores}")
        elif args.iterate_on_seed:
            seeds = [10, 11, 12, 13, 14]
            # printed label -> key in the score dict returned by the baseline
            metric_keys = {
                'rmse': 'denoise_rmse_normed',
                'corr': 'denoise_corr_normed',
                'corr_glob': 'denoise_global_corr_normed',
                'r2': 'denoise_global_r2_normed',
                'rmse_all': 'denoise_rmse_normed_all',
                'corr_all': 'denoise_corr_normed_all',
            }
            runs = {label: [] for label in metric_keys}
            for sd in seeds:
                set_seed(sd)
                scores = denoising_baseline(adata, dataset=args.dataset, model=args.model, hvg=args.hvg, seed=sd)
                print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {sd}, MR: {args.mask_ratio}")
                print(f"scrores: {scores}")
                for label, key in metric_keys.items():
                    runs[label].append(scores[key])
            print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}")
            print('\n'.join(format_run_summary(runs)))
        else:
            scores = denoising_baseline(adata, dataset=args.dataset, model=args.model, hvg=args.hvg)
            print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {args.seed}")
            print(f"scrores: {scores}")

    elif args.task == 'perturbation': # scgen-reproducibility
        from baseline.perturbation import perturbation_baseline

        celltype_key = 'cell_label'
        if args.iterate_on_seed:
            seeds = [10, 11, 12, 13, 14]
            # printed label -> key in the score dict returned by the baseline.
            # The raw-list printout previously labelled r2/r2_de as
            # r2_delta/r2_delta_de; labels now match the values.
            metric_keys = {
                'r2': 'R2',
                'r2_de': 'R2_top_100',
                'r2_delta': 'R2_delta',
                'r2_delta_de': 'R2_delta_top_100',
            }
            runs = {label: [] for label in metric_keys}
            for sd in seeds:
                set_seed(sd)
                scores = perturbation_baseline(adata, celltype_key, args.dataset, model=args.model, seed=sd, hvg_7000=True)
                print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {sd}")
                print(f"scrores: {scores}")
                for label, key in metric_keys.items():
                    runs[label].append(scores[key])
            print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}")
            print('\n'.join(format_run_summary(runs)))
        else:
            perturbation_baseline(adata, celltype_key, args.dataset, model=args.model, seed=args.seed, hvg_7000=True)

    elif args.task == 'annotation':
        from baseline.annotation import annotation_baseline

        # dataset -> cell type label column in ``adata.obs``
        annotation_keys = {
            'HLCA_sub': 'cell_type',
            'Immune': 'Manually_curated_celltype',
            'PBMC106K': 'cell_types',
            'PBMC12K': 'str_labels',
            'PBMC12K_raw': 'str_labels',
            'Zheng68K': 'celltype',
            'Pancreas': 'celltype',
            'HLCAnaw': 'cell_type',
            'Immune_sub': 'cell_type',
            'Brain': 'cell_type',
            'Liver': 'cell_type',
            'Lung': 'cell_type',
        }

        # dataset -> cell count below which a cell type is a few-shot candidate
        THRESHOLD = {
            'HLCA_sub': 1000,
            'Immune_sub': 1000,
            'Brain': 3000,
            'Liver': 600,
        }

        celltype_key = annotation_keys[args.dataset]
        if args.few_shot:
            threshold = None
            cell_types_candidates = None
            if args.dataset == 'HLCAnaw':
                PRETRAIN_CELL_TYPES = [
                    'respiratory basal cell', 'CD8-positive, alpha-beta T cell', 'CD4-positive, alpha-beta T cell'
                ]
                # Most frequent cell types first, skipping the excluded ones.
                cell_type_counts_order = adata.obs['cell_type'].value_counts().index.tolist()
                cell_type_counts_order = [x for x in cell_type_counts_order if x not in PRETRAIN_CELL_TYPES]
                target_cell_types = cell_type_counts_order[:args.topk]

            elif args.dataset in THRESHOLD:
                threshold = THRESHOLD[args.dataset]
                cell_type_counts = adata.obs['cell_type'].value_counts()
                cell_types_candidates = cell_type_counts[cell_type_counts < threshold].index.tolist()
                target_cell_types = cell_types_candidates[:args.topk]

            else:
                # Previously surfaced later as a NameError on target_cell_types.
                raise ValueError(
                    f"--few_shot is not supported for dataset {args.dataset!r}; "
                    f"use 'HLCAnaw' or one of {list(THRESHOLD)}"
                )

            if args.save_csv:
                if cell_types_candidates is None:
                    # Previously a NameError on cell_types_candidates/threshold.
                    raise ValueError(
                        f"--few_shot --save_csv requires a thresholded dataset "
                        f"({list(THRESHOLD)}), got {args.dataset!r}"
                    )

                import pandas as pd

                res_df = pd.DataFrame(columns=['f1_cell_type', 'acc_cell_type', 'precision_cell_type', 'recall_cell_type',
                                               'n_cells', 'n_cell_types', 'seed'])
                seeds = range(10)
                n_cell_types = range(2, 11)
                n_cells = range(1, 11)
                for k in n_cell_types:
                    target_cell_types = cell_types_candidates[:k]
                    for n in n_cells:
                        preserved = adata.uns[f'split_partial_{n}']['preserve_idx']
                        train_idx = np.concatenate([
                            idx for cell_type, idx in preserved.items()
                            if cell_type in target_cell_types
                        ])
                        mark_train_split(adata.obs, train_idx)
                        if args.target_only:
                            adata_target = adata[adata.obs['cell_type'].isin(target_cell_types)].copy()
                        else:
                            adata_target = adata.copy()
                        print(f'target_cell_types: {target_cell_types}')
                        for sd in seeds:
                            set_seed(sd)
                            scores = annotation_baseline(adata_target, celltype_key, dataset=args.dataset, model_name=args.model, hvg=args.hvg, seed=sd)
                            print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {sd}")
                            print(f"scores: {scores}")
                            res_df.loc[len(res_df)] = [scores['f1_score'], scores['acc'], scores['precision'], scores['recall'], n, k, sd]
                print(res_df)
                # NOTE(review): n and k here are the values of the last sweep
                # iteration; kept so the output file name stays unchanged.
                res_df.to_csv(f'./results/{args.dataset}_{args.model}_thres{threshold}_n={n}_top{k}_fewshot.csv')

            else:
                preserved = adata.uns[f'split_partial_{args.n}']['preserve_idx']
                train_idx = np.concatenate([
                    idx for cell_type, idx in preserved.items()
                    if cell_type in target_cell_types
                ])
                # train_idx = np.where(~adata.obs['cell_type'].isin(target_cell_types))[0]
                # train_idx = sorted(np.concatenate([train_idx, preserve_idx]))
                mark_train_split(adata.obs, train_idx)
                # True exactly for the cells of the target cell types.
                adata.obs['test_flag'] = adata.obs['cell_type'].isin(target_cell_types)
                if args.target_only:
                    adata = adata[adata.obs['cell_type'].isin(target_cell_types)]
                print(f'target_cell_types: {target_cell_types}')

        if args.iterate_on_seed:
            metric_keys = {
                'acc': 'acc',
                'f1_score': 'f1_score',
                'precision': 'precision',
                'recall': 'recall',
            }
            runs = {label: [] for label in metric_keys}
            seeds = range(10, 15)
            # NOTE(review): unlike the other seed sweeps, set_seed(sd) is not
            # called here; only the ``seed`` argument varies -- confirm intent.
            for sd in seeds:
                scores = annotation_baseline(adata, celltype_key, dataset=args.dataset, model_name=args.model, hvg=args.hvg, seed=sd)
                print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {sd}")
                print(f"scores: {scores}")
                for label, key in metric_keys.items():
                    runs[label].append(scores[key])
            print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}")
            print('\n'.join(format_run_summary(runs)))

        # NOTE(review): this single run also executes after a seed sweep
        # (it is skipped only for --few_shot --save_csv); kept as in the
        # original control flow.
        if not (args.few_shot and args.save_csv):
            scores = annotation_baseline(adata, celltype_key, dataset=args.dataset, model_name=args.model, hvg=args.hvg, seed=args.seed)
            print(f"Task: {args.task}, model: {args.model}, dataset: {args.dataset}, seed: {args.seed}")
            print(f"scores: {scores}")

    else:
        # Defensive: a registered task without a runner above.
        raise NotImplementedError(f"Unsupported task {args.task}")


if __name__ == '__main__':
    main()
    