import numpy as np
import pandas as pd
import scanpy as sc
import logging
import anndata
import warnings
import psutil
import os
import time
import networkx as nx
import bench.bench_utils as bench_utils
import matplotlib.pyplot as plt
from collections import Counter
from models.baseline import SlatBaseline, Baseline, STAGATEBaseline, GraphSTBaseline, QueSTV1Baseline

# Install an "ignore" filter that matches all warnings raised from here on.
warnings.filterwarnings(action="ignore")

# Configure root logging at INFO, then create this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)


class Experiment:
    """Load one benchmark dataset and run a baseline's niche queries over it.

    On construction the requested benchmark is read from disk (one query
    AnnData plus a list of reference AnnData objects) and ``BaseLineClass`` is
    instantiated with them.  ``run`` then issues one query per niche prefix.
    """

    def __init__(self, dataset='DLPFC', library_key='library_id', BaseLineClass=Baseline, device=None, save_query=False, test=False):
        """Read the benchmark and build the baseline object.

        Args:
            dataset: Benchmark name; ``load_benchmark`` supports ``'DLPFC'``
                and ``'MouseOlfactoryBulbTissue'``.
            library_key: Forwarded unchanged to ``BaseLineClass``.
            BaseLineClass: Class instantiated with the loaded data; it must
                accept the keyword arguments passed below.
            device: Stored on the instance and forwarded to ``BaseLineClass``.
            save_query: Forwarded to ``BaseLineClass``; when true, ``run``
                also asks the baseline to save its results.
            test: Logged while loading and forwarded to the baseline's save
                call in ``run``.

        Raises:
            ValueError: If ``dataset`` is not a supported benchmark name.
        """
        self.test = test
        # These attributes are (re)assigned by load_benchmark below.
        self.dataset = None
        self.save_folder = None
        self.query_adata_folder = None
        # Not populated anywhere in this class; kept as None placeholders.
        self.adata_truth_folder = None
        self.device = device
        self.library_key = library_key
        self.save_query = save_query
        self.adata_truth_list = None
        save_folder, cell_type_key, sample_id_query, sample_id_ref, niche_prefix_list, adata_query, adata_ref_list = self.load_benchmark(dataset)
        self.cell_type_key = cell_type_key
        self.sample_id_query = sample_id_query
        self.sample_id_ref = sample_id_ref
        self.niche_prefix_list = niche_prefix_list
        self.adata_query = adata_query
        self.adata_ref_list = adata_ref_list
        logger.info(f"query id: {sample_id_query}, ref list: {sample_id_ref}, niche prefix list: {niche_prefix_list}")
        # Only the first query sample id is handed to the baseline.
        self.base = BaseLineClass(adata_q=self.adata_query, adata_ref_list=self.adata_ref_list, query_sample_id=self.sample_id_query[0], device=device,
                                  ref_sample_id_list=self.sample_id_ref, dataset=self.dataset, cell_type_key=self.cell_type_key, library_key=library_key,
                                  save_folder=save_folder, save_query=save_query)
        logger.info(f"BaseLineClass: {BaseLineClass}")

    def load_benchmark(self, dataset='DLPFC'):
        """Read the query and reference AnnData files for ``dataset``.

        Sets ``self.dataset``, ``self.query_adata_folder`` and (for a
        supported dataset) ``self.save_folder`` as a side effect.

        Args:
            dataset: ``'DLPFC'`` or ``'MouseOlfactoryBulbTissue'``.

        Returns:
            A 7-tuple ``(save_folder, cell_type_key, sample_id_query,
            sample_id_ref, niche_prefix_list, adata_query, adata_ref_list)``.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        logger.info(f"loading benchmark: {dataset}, TEST mode: {self.test}, time: {bench_utils.get_time_str()}")
        self.query_adata_folder = f"./bench/adata_query/{dataset}"
        self.dataset = dataset
        if self.dataset == "DLPFC":  # spot_size = [5]
            self.save_folder = "./results/compete-dlpfc"
            cell_type_key = 'cell_type'
            sample_id_query = ['151507']
            sample_id_ref = ["151508", "151509", "151510", "151669", "151670", "151671", "151672", "151673", "151674", "151675", "151676"]
            niche_prefix_list = ['Layer4_50', 'Layer4_100', 'Layer4_200', 'Layer5_Layer6_50', 'Layer5_Layer6_100', 'Layer5_Layer6_200', 'Layer3_Layer4_Layer5_50', 'Layer3_Layer4_Layer5_100', 'Layer3_Layer4_Layer5_200']
            logger.info(f"query id: {sample_id_query}, ref id: {sample_id_ref}")
            logger.info(f"niche list: {niche_prefix_list}")
            logger.info(f"reading adata, time: {bench_utils.get_time_str()}")
            adata_query = sc.read_h5ad(f"{self.query_adata_folder}/{sample_id_query[0]}.h5ad")
            adata_ref_folder = "./data/DLPFC/adata_filtered"
            adata_ref_list = [sc.read_h5ad(f"{adata_ref_folder}/{sample_id}.h5ad") for sample_id in sample_id_ref]

        elif self.dataset == "MouseOlfactoryBulbTissue":  # spot_size = [75, 25]
            self.save_folder = "./results/compete-mobt"
            cell_type_key = 'cell_type'
            sample_id_query = ['stereoseq']
            sample_id_ref = ["10x", 'slidev2']
            niche_prefix_list = ['GCL_100', 'GCL_150', 'GCL_50', 'GCL_MCL_EPL_100', 'GCL_MCL_EPL_150', 'GCL_MCL_EPL_50', 'GL_ONL_100', 'GL_ONL_150', 'GL_ONL_50']
            logger.info("reading adata")
            adata_query = anndata.read_h5ad(f"{self.query_adata_folder}/{sample_id_query[0]}.h5ad")
            adata_ref_folder = "./data/MouseOlfactoryBulbTissue/adata_relabeled"
            adata_ref_list = [anndata.read_h5ad(f"{adata_ref_folder}/{sample_id}.h5ad") for sample_id in sample_id_ref]
        else:
            # An explicit raise (rather than `assert False`) still fires under
            # `python -O`, where asserts are stripped and the code would
            # otherwise fall through to the return with unbound locals.
            raise ValueError(
                f"unknown benchmark dataset {dataset!r}; "
                "expected 'DLPFC' or 'MouseOlfactoryBulbTissue'"
            )
        return self.save_folder, cell_type_key, sample_id_query, sample_id_ref, niche_prefix_list, adata_query, adata_ref_list

    def run(self, query_k=[3]):
        """Query the baseline once per niche prefix, then optionally save.

        Args:
            query_k: Passed as ``k`` to ``self.base.query`` for every niche.
                The default list is shared across calls; it is not modified
                in this method.
        """
        for niche_prefix in self.niche_prefix_list:
            logger.info(f"performing query for niche {niche_prefix}, time: {bench_utils.get_time_str()}")
            # The two returned values are not used any further here.
            subgraph_dict, sim_dict = self.base.query(k=query_k, niche_prefix=niche_prefix)  # main query function, change ref adata and return query cell ind for all samples

        # Saving happens once, after all niches have been queried.
        if self.save_query:
            self.base.save_ref_data_with_query_res(test=self.test)


def run_experiment(dataset='DLPFC', BaselineClass=None, cpu_cores=list(range(100)), save_query=True, test=False, query_k=[3], device="1"):
    """Pin this process to ``cpu_cores`` and run one benchmark experiment.

    Args:
        dataset: Benchmark name forwarded to ``Experiment``.
        BaselineClass: Baseline class forwarded to ``Experiment``.  ``None``
            selects ``Baseline``, which is also ``Experiment``'s own default.
        cpu_cores: CPU indices passed to ``psutil.Process.cpu_affinity``.
        save_query: Forwarded to ``Experiment``.
        test: Forwarded to ``Experiment``.
        query_k: Forwarded to ``Experiment.run``.
        device: Forwarded to ``Experiment``; defaults to ``"1"``, the value
            that was previously hard-coded here.
    """
    # Passing None through would override Experiment's default and fail with
    # "'NoneType' object is not callable" only after all data has been loaded.
    if BaselineClass is None:
        BaselineClass = Baseline
    p = psutil.Process(os.getpid())
    p.cpu_affinity(cpu_cores)
    exp = Experiment(BaseLineClass=BaselineClass, dataset=dataset, device=device, save_query=save_query, test=test)
    exp.run(query_k=query_k)


# Script entry point: only start the benchmark when this file is executed
# directly, so that importing the module does not launch an experiment.
if __name__ == "__main__":
    run_experiment(BaselineClass=QueSTV1Baseline)
