import os
import re
import json
import torch
import wandb
import logging
import argparse
import warnings
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from torch_geometric.utils import k_hop_subgraph

import models.model_utils as model_utils
import bench.bench_utils as bench_utils
from bench.metric import pearson, wl_kernel_subgraph_sim
from models.model import QueST_V1
warnings.filterwarnings("ignore")
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class QueST_V1Trainer:
    def __init__(self, dataset="DLPFC", model_name=''):
        """Configure one benchmark run: output paths, niche definitions and AnnData inputs.

        Reads the query sample, the reference samples and the reference ground-truth
        files from disk, then parses the command-line options via ``get_params``.

        Args:
            dataset: Benchmark to run, either ``"DLPFC"`` or ``"MouseOlfactoryBulbTissue"``.
            model_name: Label recorded with the metrics and passed to the checkpoint saver.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        self.dataset = dataset
        self.model = None
        self.model_name = model_name
        self.model_param = None
        self.other_param = None
        self.wandb_run = None
        self.fig_folder = None
        self.ckpt_list = []

        # Validate with a real exception before doing any work. The previous `assert False`
        # is stripped under `python -O`, which let an unknown dataset fall through and fail
        # later with an unrelated AttributeError.
        if dataset not in ("DLPFC", "MouseOlfactoryBulbTissue"):
            raise ValueError(
                f"unsupported dataset {dataset!r}; expected 'DLPFC' or 'MouseOlfactoryBulbTissue'"
            )

        logger.info(f"setting basic experiment parameters, loading adata, time: {bench_utils.get_time_str()}")
        # Both datasets write to the same output locations.
        self.save_folder = "./results/current-run"
        self.model_folder = './results/current-run/model'
        if self.dataset == "DLPFC":
            self.query_id_list = ["151507"]
            self.ref_id_list = ["151508", "151509", "151510", "151669", "151670", "151671", "151672", "151673", "151674", "151675", "151676"]
            self.niche_prefix_list = ['Layer5_Layer6_50', 'Layer5_Layer6_100', 'Layer5_Layer6_200', 'Layer4_50', 'Layer4_100', 'Layer4_200', 'Layer3_Layer4_Layer5_50', 'Layer3_Layer4_Layer5_100', 'Layer3_Layer4_Layer5_200']
            self.niche_group_list = ['Layer5_Layer6', 'Layer4', 'Layer3_Layer4_Layer5']
            # axes reserved per niche in the result figure, subplot grid shape, figure size
            self.plot_batch_size, self.nrows, self.ncols, self.figsize = 12, 9, 4, (25, 60)
            ref_subfolder = "adata_filtered"
        else:  # "MouseOlfactoryBulbTissue"
            self.query_id_list = ["stereoseq"]
            self.ref_id_list = ["10x", "slidev2",]
            self.niche_prefix_list = ['GL_ONL_50', 'GL_ONL_100', 'GL_ONL_150', 'GCL_50', 'GCL_100', 'GCL_150', 'GCL_MCL_EPL_50', 'GCL_MCL_EPL_100', 'GCL_MCL_EPL_150']
            self.niche_group_list = ['GL_ONL', 'GCL', 'GCL_MCL_EPL']
            # axes reserved per niche in the result figure, subplot grid shape, figure size
            self.plot_batch_size, self.nrows, self.ncols, self.figsize = 3, 3, 3, (25, 30)
            ref_subfolder = "adata_relabeled"

        # The on-disk folder names match the dataset name, so the three reads are shared.
        self.adata_q = sc.read_h5ad(f"./bench/adata_query/{self.dataset}/{self.query_id_list[0]}.h5ad")
        self.adata_ref_list = [
            sc.read_h5ad(f"./data/{self.dataset}/{ref_subfolder}/{sample_id}.h5ad")
            for sample_id in self.ref_id_list
        ]
        # Ground truth per reference sample; eval_ckpt reads obs[f"{niche_prefix}_wdist"] for
        # DLPFC (e.g. 'Layer4_50_wdist') and obs[f"{niche_prefix}_sim"] for the mouse dataset.
        self.adata_ref_truth_list = [
            sc.read_h5ad(f"./bench/adata_truth/{self.dataset}/{sample_id}.h5ad")
            for sample_id in self.ref_id_list
        ]

        self.get_params()
        # Sub-folder prefix used in output paths when running in test mode.
        self.test_str = "test/" if self.other_param['test'] else ""

    def train_model(self, feature_list, edge_ind_list, sub_node_sample_list, sub_edge_ind_sample_list, batch_label_list):
        """Train ``self.model`` on every reference sample for ``model_param['epochs']`` epochs.

        Each step builds a shuffled version of one reference sample with
        ``model_utils.shuffle``, runs the model on the original and shuffled features and
        applies one optimizer update. Every ``model_param['ckpt_epoch']`` epochs a checkpoint
        is saved through ``model_utils.save_checkpoint`` when ``other_param['save_model']``
        is set.

        Args:
            feature_list: Per-reference-sample features, aligned with ``self.adata_ref_list``.
            edge_ind_list: Per-reference-sample edge indices.
            sub_node_sample_list: Per-reference-sample subgraph node lists.
            sub_edge_ind_sample_list: Per-reference-sample subgraph edge indices.
            batch_label_list: Per-reference-sample batch labels.
        """
        logger.info(f"starting training model, time: {bench_utils.get_time_str()}")
        optimizer = self.model.build_optimizer()
        self.ckpt_list = []
        save_model = bool(self.other_param['save_model'])
        if not save_model:
            # NOTE(review): eval_ckpt only evaluates entries of self.ckpt_list, which is
            # presumably filled by save_checkpoint - so without saving there is nothing to evaluate.
            logger.warning("save_model is disabled: no checkpoints will be saved during training, "
                           "so eval_ckpt will have no checkpoints to evaluate")
        n_ref = len(self.adata_ref_list)
        step = 0
        self.model.train()
        for epoch in range(self.model_param['epochs']):
            logger.info(f"Epoch: {epoch}, time: {bench_utils.get_time_str()}")
            for i, adata_ref in enumerate(self.adata_ref_list):
                logger.info(f"   sample {self.ref_id_list[i]}")
                feature, edge_index = feature_list[i], edge_ind_list[i]
                sub_node_list, sub_edge_ind_list = sub_node_sample_list[i], sub_edge_ind_sample_list[i]
                batch_label = batch_label_list[i]

                min_k, max_k, fix_portion = model_utils.get_shuffle_param(self.model_param, i, n_ref)
                _adata_shf, fixed_center, _fixed_nodes, shuffle_center, feature_shf = model_utils.shuffle(
                    adata_ref, dataset=self.dataset, plot=False, feature=feature,
                    min_k=min_k, max_k=max_k, fix_portion=fix_portion,
                    min_shuffle_ratio=self.model_param['min_shuffle_ratio'],
                    max_shuffle_ratio=self.model_param['max_shuffle_ratio'])
                # Draw as many shuffled centers as there are fixed centers (without replacement).
                shuffle_center_sampled = np.random.choice(shuffle_center, size=len(fixed_center), replace=False)

                recon, logits_positive, logits_negative, logits_batch = self.model(
                    feature, feature_shf, edge_index, sub_node_list, sub_edge_ind_list,
                    fixed_center, shuffle_center_sampled, batch_label)
                loss = self.model.compute_loss(feature, recon, logits_positive, logits_negative, logits_batch,
                                               batch_label, self.model_param, epoch, step, adata_ref.uns['library_id'])

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                step += 1

            # This must be a logical `and`. The previous `... == 0 & save_model` parsed as
            # `... == (0 & save_model)`, i.e. `... == 0`, so the save_model flag was ignored.
            if save_model and (epoch + 1) % self.model_param['ckpt_epoch'] == 0:
                model_utils.save_checkpoint(self.model, self.model_folder, epoch, self.ckpt_list, self.wandb_run, self.model_name, self.test_str)

    def eval_ckpt(self, feature_list, edge_ind_list, sub_node_sample_list, sub_edge_ind_sample_list):
        """Evaluate every checkpoint in ``self.ckpt_list`` on all niches and reference samples.

        For each ``(epoch, path)`` checkpoint the weights are loaded and each query niche is
        searched in each reference sample with ``model_utils.query``. The similarity scores
        and the k-hop subgraph around the best-matching node are stored in ``adata_ref.obs``
        under columns prefixed with ``f"{epoch}_{niche_prefix}"``. Two metrics are collected
        per (checkpoint, niche, reference) triple. Metrics and annotated reference AnnData
        objects are optionally written to disk, and one figure per (checkpoint, niche group)
        is saved.

        Args:
            feature_list: Per-reference-sample features, aligned with ``self.adata_ref_list``.
            edge_ind_list: Per-reference-sample edge indices.
            sub_node_sample_list: Per-reference-sample subgraph node lists.
            sub_edge_ind_sample_list: Per-reference-sample subgraph edge indices.

        Raises:
            ValueError: If ``self.dataset`` is not a supported dataset name.
        """
        logger.info(f"start evaluating checkpoints, time: {bench_utils.get_time_str()}")
        device = self.model_param['device']
        feature_q = model_utils.get_feature(self.adata_q, query=True, param=self.model_param)
        adj_q = self.adata_q.obsp['spatial_connectivities'].tocoo()
        edge_index_q = torch.tensor(np.vstack((adj_q.row, adj_q.col)), dtype=torch.int64).to(device)

        # The ground-truth obs column is named differently per dataset.
        if self.dataset == "DLPFC":
            truth_suffix = "wdist"
        elif self.dataset == "MouseOlfactoryBulbTissue":
            truth_suffix = "sim"
        else:
            raise ValueError(f"unsupported dataset {self.dataset!r}")
        query_id = self.query_id_list[0]

        metric_rows = []
        for ckpt_epoch, model_path in self.ckpt_list:
            logger.info(f"loading model of epoch {ckpt_epoch}, time: {bench_utils.get_time_str()}")
            self.model.load_state_dict(torch.load(model_path))
            self.model.eval()
            with torch.no_grad():
                for niche_prefix in self.niche_prefix_list:
                    logger.info(f"querying niche {niche_prefix}, time: {bench_utils.get_time_str()}")
                    is_niche = (self.adata_q.obs[f"{niche_prefix}_niche"] == 'Niche').to_numpy()
                    niche_mask = torch.tensor(is_niche).to(device)
                    niche_ind = np.where(is_niche)[0]
                    for i, (adata_ref, ref_id) in enumerate(zip(self.adata_ref_list, self.ref_id_list)):
                        logger.info(f"processing sample {ref_id}")
                        edge_index_ref = edge_ind_list[i]
                        sim = model_utils.query(feature_q, feature_list[i], edge_index_q, edge_index_ref,
                                                sub_node_sample_list[i], sub_edge_ind_sample_list[i],
                                                self.model, niche_mask, method=self.model_param['query_method'])
                        # The best match is the k-hop neighbourhood of the most similar node.
                        best_node = torch.argmax(sim).item()
                        best_subgraph, _, _, _ = k_hop_subgraph(best_node, self.model_param['model_k'], edge_index_ref)
                        best_subgraph = best_subgraph.cpu().numpy()
                        sim = sim.cpu().numpy()

                        # One vectorised membership mask replaces a per-spot `i in ndarray` scan
                        # and positional indexing into the obs Series.
                        in_subgraph = np.isin(np.arange(adata_ref.shape[0]), best_subgraph)
                        cell_types = adata_ref.obs["cell_type"].to_numpy()
                        col_prefix = f'{ckpt_epoch}_{niche_prefix}'
                        adata_ref.obs[f'{col_prefix}_subgraph'] = pd.Categorical(np.where(in_subgraph, 'Query', 'Else'))
                        adata_ref.obs[f'{col_prefix}_subgraph_cell_type'] = pd.Categorical(np.where(in_subgraph, cell_types, 'Else'))
                        adata_ref.obs[f'{col_prefix}_sim'] = sim

                        sim_truth = np.array(self.adata_ref_truth_list[i].obs[f"{niche_prefix}_{truth_suffix}"])
                        pearson_val = pearson(sim, sim_truth)
                        # Use the configured dataset and query sample instead of the previously
                        # hard-coded "DLPFC" / "151507", which were wrong for the mouse dataset.
                        wdist = wl_kernel_subgraph_sim(adata_q=self.adata_q, adata_ref=adata_ref, ind_q=niche_ind,
                                                       ind_ref=best_subgraph, dataset=self.dataset,
                                                       query_ref=[query_id, ref_id])

                        metric_rows.append([pearson_val, 'Pearson Correlation', ref_id, niche_prefix, ckpt_epoch, self.model_name])
                        metric_rows.append([wdist, 'Subgraph Wasserstein Distance', ref_id, niche_prefix, ckpt_epoch, self.model_name])

        metric_df = pd.DataFrame(metric_rows, columns=['Metric Value', 'Metric', 'Sample ID', 'Niche Name', 'Epoch', 'Model Name'])
        print(metric_df)
        self._save_eval_outputs(metric_df)
        logger.info("plotting checkpoint results")
        self._plot_ckpt_results()

    def _save_eval_outputs(self, metric_df):
        """Write the metric table and the annotated reference AnnData files, as enabled in ``other_param``."""
        if self.other_param['save_metric']:
            metric_path = f"{self.save_folder}/metric/{self.test_str}metric.csv"
            # Create the target folder so a long run does not fail at the very end.
            os.makedirs(os.path.dirname(metric_path), exist_ok=True)
            logger.info(f"saving metric dataframe to {metric_path}")
            metric_df.to_csv(metric_path)
        if self.other_param['save_adata']:
            adata_folder = f"{self.save_folder}/adata-ref-queried/{self.test_str}"
            os.makedirs(adata_folder, exist_ok=True)
            logger.info(f"saving adata ref with query results to {adata_folder}")
            for adata_ref, ref_id in zip(self.adata_ref_list, self.ref_id_list):
                out_path = f"{adata_folder}/{ref_id}.h5ad"
                try:
                    adata_ref.write_h5ad(out_path, compression='gzip')
                except Exception as write_err:
                    # Best-effort retry kept from the original code: drop the stored
                    # neighbour-graph radius (presumably not serialisable) and write again.
                    # If that entry does not exist, the original failure is re-raised.
                    try:
                        del adata_ref.uns['spatial_neighbors']['params']['radius']
                    except KeyError:
                        raise write_err from None
                    logger.warning(f"writing {out_path} failed ({write_err!r}); "
                                   "retrying without uns['spatial_neighbors']['params']['radius']")
                    adata_ref.write_h5ad(out_path, compression='gzip')

    def _plot_ckpt_results(self):
        """Save one similarity figure per (checkpoint, niche group) to the wandb and results folders."""
        os.makedirs(f"{self.fig_folder}/{self.test_str}", exist_ok=True)
        os.makedirs(f"{self.save_folder}/fig", exist_ok=True)
        for ckpt_epoch, _ in self.ckpt_list:
            for niche_group in self.niche_group_list:
                logger.info(f"plotting epoch={ckpt_epoch} query results for niche group {niche_group}, time: {bench_utils.get_time_str()}")
                # figsize is (width, height); squeeze=False always yields a 2-D axes array.
                fig, axs = plt.subplots(self.nrows, self.ncols, figsize=self.figsize, squeeze=False)
                axs = axs.flatten()
                sup_title = f"epoch={ckpt_epoch} {niche_group} cos sim"
                fig.suptitle(sup_title)

                # Exact "<group>_<number>" match, so 'Layer4' does not pick up 'Layer3_Layer4_Layer5_50'.
                group_pattern = re.compile(rf'^{re.escape(niche_group)}_\d+$')
                niche_list = [niche_prefix for niche_prefix in self.niche_prefix_list if group_pattern.match(niche_prefix)]
                print("checking niche list", niche_group, niche_list)
                for niche_ind_within_group, niche_prefix in enumerate(niche_list):
                    # Each niche owns a block of `plot_batch_size` consecutive axes.
                    first_ax = self.plot_batch_size * niche_ind_within_group
                    for i, (adata_ref, ref_id) in enumerate(zip(self.adata_ref_list, self.ref_id_list)):
                        ax = axs[first_ax + i]
                        spot_size = model_utils.get_spot_size(self.dataset, ref_id)
                        ax.invert_yaxis()
                        sc.pl.spatial(adata_ref, color=f"{ckpt_epoch}_{niche_prefix}_sim",
                                      ax=ax, spot_size=spot_size, show=False, title=f"{niche_prefix} {ref_id}")
                    # The last slot of each block has no reference plot (query sample slot); hide its frame.
                    axs[first_ax + self.plot_batch_size - 1].axis('off')
                fig.tight_layout()

                fig_path = f"{self.fig_folder}/{self.test_str}{sup_title}.PNG"
                print(fig_path)
                # Save before showing so a blocking GUI window cannot delay or lose the files.
                fig.savefig(fig_path)
                fig.savefig(f"{self.save_folder}/fig/{sup_title}.PNG", dpi=300)
                plt.show()
                # Release the figure; otherwise every checkpoint/group figure stays in memory.
                plt.close(fig)


    def get_params(self):
        """Parse command-line options into ``self.model_param`` and ``self.other_param``.

        Options are declared in two argparse groups ("model params" and "other params").
        After parsing ``sys.argv``, each group's values are collected into a plain dict
        keyed by argparse destination name (dashes become underscores). ``batch_num`` is
        always overwritten with the number of reference samples, and
        ``CUDA_VISIBLE_DEVICES`` is set from ``--gpu-id``.
        """

        def str2bool(value):
            """Convert a command-line token to bool.

            ``type=bool`` must not be used with argparse: ``bool("False")`` is True, so any
            non-empty value would switch the flag on.
            """
            if isinstance(value, bool):
                return value
            token = value.strip().lower()
            if token in ('true', 't', 'yes', 'y', 'on', '1'):
                return True
            if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
                return False
            raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

        logger.info(f"processing arguments, time: {bench_utils.get_time_str()}")
        parser = argparse.ArgumentParser(description='Process Model Parameters.')
        model_param_group_name, other_param_group_name = 'model params', 'other params'
        model_param_group = parser.add_argument_group(model_param_group_name)

        model_param_group.add_argument('--model-path', type=str, default=None)
        model_param_group.add_argument('--device', type=str, default='cuda:0')
        model_param_group.add_argument('--library_key', type=str, default='library_id')

        # model architecture parameters
        # Explicit type/nargs so values given on the command line have the same Python
        # types as the defaults (they previously arrived as raw strings).
        model_param_group.add_argument('--model_k', type=int, default=3)
        model_param_group.add_argument('--shuffle_min_k', type=int, nargs='+', default=[3])
        model_param_group.add_argument('--shuffle_max_k', type=int, nargs='+', default=[3])
        model_param_group.add_argument('--fix_portion', type=float, nargs='+', default=[0.02])
        model_param_group.add_argument('--min_shuffle_ratio', type=float, default=0.25)
        model_param_group.add_argument('--max_shuffle_ratio', type=float, default=0.75)
        model_param_group.add_argument('--pooling', type=str, default='mean')
        model_param_group.add_argument('--enc-dims', type=int, nargs='+', default=[2048, 256, 32])
        model_param_group.add_argument('--dec-dims', type=int, nargs='+', default=[32])
        model_param_group.add_argument('--batch-discriminator-dims', type=int, nargs='+', default=[32, 32, 16])
        model_param_group.add_argument('--dec-batch-dim', type=int, default=2)
        model_param_group.add_argument('--batch-num', type=int, default=11)
        model_param_group.add_argument('--lr', type=float, default=0.001)
        model_param_group.add_argument('--norm', type=str, default='batchnorm')
        model_param_group.add_argument('--dropout', type=float, default=0.1)
        model_param_group.add_argument('--activation', type=str, default='relu')
        model_param_group.add_argument('--weight_decay', type=float, default=5e-4)
        model_param_group.add_argument('--epochs', type=int, default=50)
        model_param_group.add_argument('--ckpt-epoch', type=int, default=10)
        # Loss weights are fractional (defaults include 0.1), so they parse as float;
        # with type=int a value such as "0.5" was rejected.
        model_param_group.add_argument('--lbd-positive', type=float, default=1)
        model_param_group.add_argument('--lbd-negative', type=float, default=1)
        model_param_group.add_argument('--lbd-batch', type=float, default=0.1)
        model_param_group.add_argument('--lbd-kl', type=float, default=0.1)
        model_param_group.add_argument('--query-method', type=str, default='cosine')
        model_param_group.add_argument('--pca', type=str2bool, default=False)
        model_param_group.add_argument('--cca', type=str2bool, default=False)
        model_param_group.add_argument('--scale', type=str2bool, default=False)
        model_param_group.add_argument('--hvg', type=int, default=4000)
        model_param_group.add_argument('--min-count', type=int, default=10)

        other_param_group = parser.add_argument_group(other_param_group_name)
        other_param_group.add_argument('--test', type=str2bool, default=False)
        other_param_group.add_argument('--save-model', type=str2bool, default=False)
        other_param_group.add_argument('--save-metric', type=str2bool, default=True)
        other_param_group.add_argument('--save-adata', type=str2bool, default=True)
        other_param_group.add_argument('--gpu-id', type=str, default='1')

        args = parser.parse_args()
        # The number of batches always follows the reference samples, overriding --batch-num.
        args.batch_num = len(self.ref_id_list)
        logger.info(f"using GPU {args.gpu_id}")
        # NOTE(review): presumably this only takes effect if CUDA has not been initialised
        # yet in this process - confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
        args_dict = vars(args)

        # Split the flat namespace back into one dict per argument group. This relies on
        # argparse's private `_action_groups` / `_group_actions` attributes and also visits
        # argparse's built-in groups, which are printed but not stored.
        group_dicts = {}
        for group in parser._action_groups:
            group_name = group.title
            group_args = {a.dest: args_dict[a.dest] for a in group._group_actions if a.dest in args_dict}
            group_dicts[group_name] = group_args
            print(group_name, group_args)

        self.model_param = group_dicts[model_param_group_name]
        self.other_param = group_dicts[other_param_group_name]

    def run(self):
        """Run the whole experiment: wandb setup, preprocessing, model build, then train + evaluate.

        When ``model_param['model_path']`` is set, the weights are loaded from that path and
        nothing else happens (that branch is still marked TODO); otherwise the model is
        trained and its checkpoints are evaluated.
        """
        # --- experiment tracking and bookkeeping ---
        self.wandb_run = wandb.init(project="TEST Model", name="TEST", dir="./wandb/test-model")
        run_dir = self.wandb_run.dir
        wandb.config.update(self.model_param)
        bench_utils.save_project_code(source_dir=".", output_zip=f"{run_dir}/code.zip", save_logger=logger)
        with open(f"{run_dir}/param.json", 'w') as param_file:
            json.dump(self.model_param, param_file)

        figure_dir = f"{run_dir}/fig"
        if not os.path.exists(figure_dir):
            os.mkdir(figure_dir)
        self.fig_folder = figure_dir

        # --- data preparation ---
        model_utils.build_graphs(adata_list=self.adata_ref_list, dataset=self.dataset)
        # Query and reference samples are preprocessed together, then split apart again.
        processed = model_utils.preprocess_adata([self.adata_q] + self.adata_ref_list, param=self.model_param)
        self.adata_q = processed[0]
        self.adata_ref_list = processed[1:]
        features, edge_inds, sub_nodes, sub_edge_inds, batch_labels = model_utils.prepare_graph_data(
            self.adata_ref_list, self.model_param)

        # --- model ---
        network = QueST_V1(in_dim=features[0].shape[1], param=self.model_param, logger=logger)
        self.model = network.to(self.model_param['device'])

        if self.model_param['model_path'] is None:
            self.train_model(features, edge_inds, sub_nodes, sub_edge_inds, batch_labels)
            self.eval_ckpt(features, edge_inds, sub_nodes, sub_edge_inds)
            return

        logger.info(f"loading model from {self.model_param['model_path']}, time: {bench_utils.get_time_str()}")
        self.model.load_state_dict(torch.load(self.model_param['model_path']))
        # TODO


if __name__ == '__main__':
    # Script entry point: build the trainer for the DLPFC benchmark and run it.
    QueST_V1Trainer(dataset="DLPFC", model_name='QueST_V1').run()







































