import sys, os
sys.path.append(os.path.abspath(os.curdir))
import scanpy as sc
import csv
import numpy as np
import pickle
import pathlib
import json


import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from protein_ot import *
from utils import maximum_separation, p_value, empty_metric_constrain, never_backup_cst, correlation_cst, monoto_cst, get_center_point, get_points_std, prob_from_data
from scipy.stats import pearsonr, spearmanr
from sklearn.covariance import EmpiricalCovariance
import argparse

class RGCAtlasSctCombined(object):
    support_ot_raw_data_type = ['pca55', 'umap', 'removed_genes']
    support_ot_cost = ['l2', 'pearson']
    support_additional_metric_constrains = ['empty', 'never_backup']
    def __init__(self, file, monoto_gen_csv, monoto_gen_constraints_csv, 
                 age_set = None,
                 ot_raw_data_type = 'pca55',
                 result_dir = 'test',
                 ot_lambda_list = None,
                 ot_norm = False, ot_algorithm = 'Sinkhorn1',
                 ot_cost = 'l2',
                 path_filter_method = 'max_sep',
                 rna_lineage_path = 'living_traj_cluster_wise.json',
                 additional_metric_constraint = 'never_backup',
                 mono_algo_type = 'none',
                 minimum_gene_number = 10) -> None:
        """Load the atlas and store the optimal-transport / filtering / metric configuration.

        Args:
            file: ``.h5ad`` path, loaded with ``read_h5ad``.
            monoto_gen_csv: gene-selection CSV passed to ``get_monoto_gen_idx``.
            monoto_gen_constraints_csv: per-gene constraint CSV passed to ``get_monoto_gen_idx``.
            age_set: ordered ages whose consecutive pairs are linked. ``None`` selects
                ``['E13', 'E14', 'E16', 'E18', 'P0', 'P2', 'P4', 'P7', 'P56']``.
            ot_raw_data_type: tag that becomes part of the distance-cache file name.
            result_dir: output directory. It is created, including missing parents.
            ot_lambda_list: ``ot_lambda_list[k]`` is the OT lambda for transition
                ``age_set[k] -> age_set[k+1]``. ``None`` selects ``[5, 7, 5, 7, 5, 7, 8, 12]``.
            ot_norm: divide OT inputs by their joint maximum before transport.
            ot_algorithm: ``method`` forwarded to ``optimal_transport_lambda_detect``.
            ot_cost: ``'l2'`` or ``'pearson'``.
            path_filter_method: ``'max_sep'`` or ``'p_value'`` link pruning.
            rna_lineage_path: file name (inside ``result_dir``) of the lineage list.
            additional_metric_constraint: ``'empty'`` or ``'never_backup'``.
            mono_algo_type: forwarded to ``monoto_cst``.
            minimum_gene_number: a cluster must contain strictly more than this many cells to
                be used (despite the name, it counts cells).
        """
        self.data_file = file
        self.rna_data = self.read_h5ad(self.data_file)
        self.monoto_gen_csv = monoto_gen_csv
        self.monoto_gen_constraints_csv = monoto_gen_constraints_csv
        self.monoto_gen_idx, self.monoto_gen_constraints_idx = self.get_monoto_gen_idx(self.monoto_gen_csv, self.monoto_gen_constraints_csv)

        # Fresh lists per instance: list defaults in the signature would be shared by all instances.
        self.age_set = age_set if age_set is not None else ['E13', 'E14', 'E16', 'E18', 'P0', 'P2', 'P4', 'P7', 'P56']
        self.mono_algo_type = mono_algo_type
        # assert ot_raw_data_type in self.support_ot_raw_data_type, f"Error ot data type, support: {self.support_ot_raw_data_type}"
        self.ot_raw_data_type = ot_raw_data_type
        self.minimum_gene_number = minimum_gene_number

        self.result_dir = pathlib.Path(result_dir)
        self.result_dir.mkdir(parents=True, exist_ok=True)

        # optimal transport setting
        self.ot_lambda_list = ot_lambda_list if ot_lambda_list is not None else [5, 7, 5, 7, 5, 7, 8, 12]
        self.ot_norm = ot_norm
        self.ot_algorithm = ot_algorithm
        self.ot_cost = ot_cost

        # draw setting
        self.path_filter_method = path_filter_method

        # metric setting
        self.rna_lineage_path = rna_lineage_path
        self.additional_metric_constraint = additional_metric_constraint
        
    def read_h5ad(self, file):
        """Read ``file`` with scanpy and add the column aliases the rest of the class expects.

        Existing columns are never overwritten. Missing ones are derived as follows:

        * ``obs['Age']`` from ``obs['orig.ident']``, otherwise from ``obs['time']``
          (left absent if neither exists);
        * ``var['features']`` from ``var['gene_name']``;
        * ``obs['seurat_clusters']`` from ``obs['cell_type']``.

        Returns:
            The object returned by ``sc.read_h5ad``, modified in place.

        Raises:
            KeyError: when ``features``/``seurat_clusters`` are missing and so is their source.
        """
        print(f"Loading h5ad file: {file}...")
        adata = sc.read_h5ad(file)

        if 'Age' not in adata.obs:
            age_source = next((col for col in ('orig.ident', 'time') if col in adata.obs), None)
            if age_source is not None:
                adata.obs['Age'] = adata.obs[age_source]

        if 'features' not in adata.var:
            adata.var['features'] = adata.var['gene_name']

        if 'seurat_clusters' not in adata.obs:
            adata.obs['seurat_clusters'] = adata.obs['cell_type']

        return adata
    
    def get_monoto_gen_idx(self, gen_csv, monoto_gen_constraints_csv):
        """Map the two gene CSV files onto positions in ``self.rna_data.var['features']``.

        Args:
            gen_csv: CSV with a header row. Column 1 holds gene names. Every position at
                which such a name occurs in ``features`` is collected, including duplicates.
            monoto_gen_constraints_csv: CSV with a header row. Column 0 is a gene name,
                column 2 a float and column 3 a value kept as text. Only genes that were not
                selected from ``gen_csv`` are kept, keyed by their first feature position.

        Returns:
            ``(check_idx, constraints_data)``:
            - ``check_idx`` is a sorted integer array of the selected positions, or ``[]``
              when nothing matched.
            - ``constraints_data`` maps ``position -> [column 3 text, float(column 2)]``.
        """
        # Build the name -> positions table once instead of scanning every feature per CSV row.
        name_to_idx = {}
        for pos, name in enumerate(self.rna_data.var['features'].values):
            name_to_idx.setdefault(name, []).append(pos)

        selected_genes = set()
        matched = []
        with open(gen_csv, 'r', newline='') as fin:
            reader = csv.reader(fin, delimiter=',')
            next(reader, None)  # skip the header row
            for line in reader:
                if not line:  # blank line
                    continue
                hits = name_to_idx.get(line[1])
                if hits:
                    selected_genes.add(line[1])
                    matched.append(np.asarray(hits, dtype=np.intp))
        if matched:
            # concatenate (not stack): a gene name may occur several times in ``features``.
            check_idx = np.concatenate(matched)
            check_idx.sort()
        else:
            check_idx = []

        constraints_data = {}
        with open(monoto_gen_constraints_csv, 'r', newline='') as fin:
            reader = csv.reader(fin, delimiter=',')
            next(reader, None)  # skip the header row
            for line in reader:
                if not line:  # blank line
                    continue
                hits = name_to_idx.get(line[0])
                if hits and line[0] not in selected_genes:
                    constraints_data[int(hits[0])] = [line[3], float(line[2])]

        return check_idx, constraints_data
    
    def age_cluster_raw_data(self):
        """Group the rows of ``self.rna_data.X`` by cell age and then by cluster.

        Returns:
            ``{age: {cluster_id: np.stack(rows)}}`` for every age in ``self.age_set``. Ages
            without cells map to ``{}``. Clusters appear in order of their first cell, and
            rows keep cell order. No minimum-size filtering is applied here.

        Raises:
            KeyError: if a cell's ``'Age'`` is not listed in ``self.age_set``.
        """
        obs = self.rna_data.obs
        # Read plain arrays once and index them by position. ``Series[i]`` on a string-indexed
        # obs relies on pandas' deprecated integer-position fallback.
        ages = obs['Age'].to_numpy()
        clusters = obs['seurat_clusters'].to_numpy()

        cluster_points = {age: {} for age in self.age_set}
        for i in range(len(self.rna_data)):
            cluster_points[ages[i]].setdefault(clusters[i], []).append(self.rna_data.X[i])

        return {
            age: {cluster_id: np.stack(rows, axis=0) for cluster_id, rows in per_cluster.items()}
            for age, per_cluster in cluster_points.items()
        }
    
    def ot_raw_data(self, data_type: str = 'pca55'):
        """Compute per-(age, cluster) centre points used as OT inputs.

        Side effect: ``self.filtered_cluster_points`` is rebuilt as
        ``{age: {cluster_id: stacked member rows}}`` for the clusters that are kept.

        Args:
            data_type: ``'raw_data'`` uses ``rna_data.X`` rows, and ``'umap'`` uses
                ``rna_data.obsm['X_umap']`` rows. Other values (including the default
                ``'pca55'``) are not implemented.

        Returns:
            ``{age: {cluster_id: get_center_point(stacked rows)}}`` for every age in
            ``self.age_set``. Only clusters with more than ``self.minimum_gene_number``
            cells are kept.

        Raises:
            NotImplementedError: for an unsupported ``data_type``.
            KeyError: if a cell's ``'Age'`` is not listed in ``self.age_set``.
        """
        if data_type == 'raw_data':
            features = self.rna_data.X
        elif data_type == 'umap':
            features = self.rna_data.obsm['X_umap']
        else:
            raise NotImplementedError(
                f"Unsupported OT data type {data_type!r}; implemented: 'raw_data', 'umap'")

        obs = self.rna_data.obs
        # Positional arrays: ``Series[i]`` on a string-indexed obs is a deprecated pandas fallback.
        ages = obs['Age'].to_numpy()
        clusters = obs['seurat_clusters'].to_numpy()

        cluster_level_group = {age: {} for age in self.age_set}
        for i in range(len(self.rna_data)):
            cluster_level_group[ages[i]].setdefault(clusters[i], []).append(features[i])

        filtered_cluster_level_group = {age: {} for age in self.age_set}
        self.filtered_cluster_points = {age: {} for age in self.age_set}
        for age, per_cluster in cluster_level_group.items():
            for cluster_id, rows in per_cluster.items():
                if len(rows) > self.minimum_gene_number:
                    # Separate stacks so the stored points stay independent of whatever
                    # get_center_point does with its input.
                    filtered_cluster_level_group[age][cluster_id] = get_center_point(np.stack(rows, axis=0))
                    self.filtered_cluster_points[age][cluster_id] = np.stack(rows, axis=0)

        return filtered_cluster_level_group
    
    def load_additional_features(self, feat_path):
        """Load a pickled per-cell feature matrix and compute per-(age, cluster) centre points.

        Features are min-max scaled with the global minimum and maximum of the whole matrix.
        Row ``i`` is assigned to cell ``i`` of ``self.rna_data``.

        Side effect: ``self.filtered_cluster_points`` is rebuilt as
        ``{age: {cluster_id: stacked scaled rows}}`` for the clusters that are kept.

        Args:
            feat_path: pickle file holding an array-like with ``.shape``/``.min()``/``.max()``.

        Returns:
            ``{age: {cluster_id: get_center_point(stacked rows)}}``. Only clusters with more
            than ``self.minimum_gene_number`` cells are kept.
        """
        # NOTE(review): pickle.load executes code from the file; only load trusted feature files.
        with open(feat_path, 'rb') as fin:
            feat_np = pickle.load(fin)
            print(feat_np.shape)
        fmin = feat_np.min()
        fmax = feat_np.max()
        span = fmax - fmin
        if span == 0:
            span = 1  # constant features: avoid 0/0 (NaN) and map everything to 0
        scaled = (feat_np - fmin) / span

        obs = self.rna_data.obs
        # Positional arrays: ``Series[i]`` on a string-indexed obs is a deprecated pandas fallback.
        ages = obs['Age'].to_numpy()
        clusters = obs['seurat_clusters'].to_numpy()

        cluster_level_group = {age: {} for age in self.age_set}
        for i in range(len(feat_np)):
            cluster_level_group[ages[i]].setdefault(clusters[i], []).append(scaled[i])

        filtered_cluster_level_group = {age: {} for age in self.age_set}
        self.filtered_cluster_points = {age: {} for age in self.age_set}
        for age, per_cluster in cluster_level_group.items():
            for cluster_id, rows in per_cluster.items():
                if len(rows) > self.minimum_gene_number:
                    filtered_cluster_level_group[age][cluster_id] = get_center_point(np.stack(rows, axis=0))
                    self.filtered_cluster_points[age][cluster_id] = np.stack(rows, axis=0)

        return filtered_cluster_level_group
    
    
    def load_dist_matric(self, pre_age, age, prev_keys, cur_keys):
        """Return the cluster-to-cluster OT distance matrix between two ages, with caching.

        Entry ``[a][b]`` is ``ot.emd2`` between the member points of cluster ``prev_keys[a]``
        at ``pre_age`` and ``cur_keys[b]`` at ``age``, both taken from
        ``self.filtered_cluster_points``, under a ``'correlation'`` ground metric. Point weights
        are uniform when the feature dimension exceeds 200, otherwise ``prob_from_data``.

        The result is pickled to
        ``result_dir.parent / f'{ot_cost}_{ot_raw_data_type}_{pre_age}_{age}.pkl'``. If that
        file already exists it is loaded and returned unchanged.
        NOTE(review): the cache key ignores ``prev_keys``/``cur_keys``, so a stale file is
        reused even if the clusters changed.
        """
        cache_path = self.result_dir.parent.joinpath(
            f'{self.ot_cost}_{self.ot_raw_data_type}_{pre_age}_{age}.pkl')
        if cache_path.exists():
            with open(cache_path, 'rb') as fin:
                return pickle.load(fin)

        rows = []
        for src_cls in prev_keys:
            row = []
            for dst_cls in cur_keys:
                src_pts = self.filtered_cluster_points[pre_age][src_cls]
                dst_pts = self.filtered_cluster_points[age][dst_cls]
                if src_pts.shape[1] <= 200:
                    src_w = prob_from_data(src_pts)
                    dst_w = prob_from_data(dst_pts)
                else:
                    src_w = np.ones((len(src_pts), )) / len(src_pts)
                    dst_w = np.ones((len(dst_pts), )) / len(dst_pts)
                ground_cost = ot.dist(src_pts, dst_pts, metric='correlation')
                row.append(ot.emd2(src_w, dst_w, ground_cost))
            rows.append(row)

        matrix = np.array(rows)
        with open(cache_path, 'wb') as fout:
            pickle.dump(matrix, fout)
        return matrix
    
    def calc_ot_mapping(self, cluster_level_group, cluster_points):
        """Run constrained optimal transport between every pair of consecutive ages.

        For transition ``k`` (``age_set[k] -> age_set[k+1]``), the cluster centres in
        ``cluster_level_group`` are stacked in sorted cluster-id order.
        - They are optionally divided by their joint maximum (``ot_norm``).
        - A base cost matrix is built: ``ot.dist`` for ``'l2'``, or ``1 - Pearson r`` for
          ``'pearson'``.
        - The cost matrix is passed through ``correlation_cst``, ``monoto_cst`` and the
          selected metric constraint.
        - The result is solved with ``optimal_transport_lambda_detect`` using
          ``ot_lambda_list[k]``.
        - ``discrete_ot_online`` then prunes the mapping in place and updates the ancestor
          history used by the metric constraint of the next transition.

        Each transition is pickled to ``result_dir / f"{pre}_to_{cur}_lamdba_{lambda}.pkl"``
        as ``{"ot_mapping", "ot_cost", "pre_index", "cur_index"}``.

        Args:
            cluster_level_group: ``{age: {cluster_id: centre vector}}``.
            cluster_points: ``{age: {cluster_id: member points}}``, forwarded to ``monoto_cst``.

        Raises:
            NotImplementedError: for an unsupported ``ot_cost`` or ``additional_metric_constraint``.
        """
        # Validate up front so misconfiguration fails before any file is written.
        if self.ot_cost not in self.support_ot_cost:
            raise NotImplementedError(
                f"Unsupported ot_cost {self.ot_cost!r}, support: {self.support_ot_cost}")
        if self.additional_metric_constraint not in self.support_additional_metric_constrains:
            raise NotImplementedError(
                f"Unsupported additional_metric_constraint {self.additional_metric_constraint!r}, "
                f"support: {self.support_additional_metric_constrains}")

        if self.ot_cost == 'l2':
            coorelation_func = l2
        else:
            coorelation_func = lambda x, y: pearsonr(x, y).statistic
        if self.additional_metric_constraint == 'empty':
            metric_constraint = empty_metric_constrain
        else:
            metric_constraint = never_backup_cst

        history_link = None
        for index in range(len(self.age_set) - 1):
            pre_age = self.age_set[index]
            age = self.age_set[index + 1]
            lambda_value = self.ot_lambda_list[index]

            prev_keys = sorted(cluster_level_group[pre_age].keys())
            prev_data = np.stack([cluster_level_group[pre_age][k] for k in prev_keys], axis=0)
            if history_link is None:
                history_link = {prev_cls: set() for prev_cls in prev_keys}
            cur_keys = sorted(cluster_level_group[age].keys())
            cur_data = np.stack([cluster_level_group[age][k] for k in cur_keys], axis=0)

            norm_value = max(prev_data.max(), cur_data.max()) if self.ot_norm else 1
            ot_data_1 = prev_data / norm_value
            ot_data_2 = cur_data / norm_value

            if self.ot_cost == 'l2':
                dist_metrix = ot.dist(ot_data_1, ot_data_2)
            else:
                dist_metrix = 1 - np.array([[pearsonr(src, dst).statistic for dst in ot_data_2]
                                            for src in ot_data_1])
            dist_metrix = correlation_cst(dist_metrix, cluster_level_group, pre_age, age, prev_keys, cur_keys, coorelation_func)
            dist_metrix = monoto_cst(dist_metrix, cluster_points, pre_age, age, prev_keys, cur_keys, self.monoto_gen_constraints_idx, self.mono_algo_type)
            dist_metrix = metric_constraint(dist_metrix, history_link, prev_keys, cur_keys)
            ot_mapping, ot_cost = optimal_transport_lambda_detect(ot_data_1,
                                                    ot_data_2,
                                                    method=self.ot_algorithm,
                                                    lambda_value=lambda_value,
                                                    reverse_cost=False,
                                                    cost_matrix = dist_metrix)
            print(ot_mapping.sum(), ot_cost)
            # discrete_ot_online prunes ot_mapping in place; the pruned matrix is what gets saved.
            _, history_link = self.discrete_ot_online(ot_mapping, prev_keys, cur_keys, link_history = history_link)

            dump_dict = {
                "ot_mapping": ot_mapping,
                "ot_cost": ot_cost,
                "pre_index": np.array(prev_keys),
                "cur_index": np.array(cur_keys)
            }
            result_name = f"{pre_age}_to_{age}_lamdba_{lambda_value}.pkl"
            with open(self.result_dir.joinpath(result_name), 'wb') as fout:
                print("Saving file:", result_name)
                pickle.dump(dump_dict, fout)
            
    def discrete_ot_online(self, ot_mapping, prev_keys, cur_keys, thres = 1e-4, link_history = None):
        """Prune a transport plan in place and propagate lineage ancestry to the next age.

        Steps, all modifying ``ot_mapping`` in place:
        1. Normalise each row (previous cluster) to sum to 1.
        2. For each column (current cluster), set the entry of its highest-weight row to
           that row's maximum. This makes the column survive the row-wise cut below.
        3. Cut each row with ``self.path_filter_method``:
           - ``'max_sep'`` zeroes the entries after the ``maximum_separation`` index of the
             descending-sorted row;
           - ``'p_value'`` zeroes the indices returned by ``p_value``;
           - any other value leaves the row uncut and records no links.

        Args:
            ot_mapping: 2-D float array of shape ``(len(prev_keys), len(cur_keys))``.
            prev_keys, cur_keys: cluster ids for the rows/columns.
            thres: unused; kept for backward compatibility.
            link_history: ``{prev_cluster: set of ancestor clusters}``. ``None`` means no
                ancestors are known yet.

        Returns:
            ``(ot_mapping, cur_link_history)``. The second item maps each current cluster to
            the set of its surviving parents plus those parents' ancestors.

        Raises:
            ValueError: if a row has no positive mass or cannot be normalised.
        """
        if link_history is None:
            link_history = {prev_cls: set() for prev_cls in prev_keys}
        cur_link_history = {cur_cls: set() for cur_cls in cur_keys}

        for i in range(len(ot_mapping)):
            row_sum = ot_mapping[i].sum()
            if not row_sum > 0:
                raise ValueError(f"OT mapping row {i} has no positive mass (sum={row_sum}).")
            ot_mapping[i] /= row_sum
            if not ot_mapping[i].sum() > 0.99:
                raise ValueError(f"OT mapping row {i} could not be normalised (sum={ot_mapping[i].sum()}).")

        # Keep every current cluster connected: its strongest parent entry is raised to the
        # parent's row maximum, which the row-wise cut below never removes.
        best_prev_per_cur = ot_mapping.argmax(axis=0)
        for j in range(len(cur_keys)):
            row = best_prev_per_cur[j]
            ot_mapping[row, j] = ot_mapping[row].max()

        for i in range(len(ot_mapping)):
            if self.path_filter_method == 'max_sep':
                target = ot_mapping[i]
                sort_index = (-target).argsort()
                max_sep_idx = maximum_separation(target[sort_index])
                ot_mapping[i][sort_index[max_sep_idx+1:]] = 0
            elif self.path_filter_method == 'p_value':
                for ee in p_value(ot_mapping[i], thres = 0.0001):
                    ot_mapping[i][ee] = 0
            else:
                continue
            prev_cls = prev_keys[i]
            for con_id in np.where(ot_mapping[i] > 0)[0]:
                cur_cls = cur_keys[con_id]
                cur_link_history[cur_cls].add(prev_cls)
                cur_link_history[cur_cls].update(link_history[prev_cls])
        return ot_mapping, cur_link_history

    def discrete_ot_result(self, thres = 1e-4):
        """Turn every raw OT result pickle in ``result_dir`` into per-cluster link tables.

        A raw result is a ``*.pkl`` file, excluding ``*.csv.pkl``, holding ``ot_mapping``,
        ``pre_index`` and ``cur_index``. Each mapping row is normalised in place to sum to 1,
        and every positive entry becomes a link. For ``<name>.pkl`` two files are written:
        - ``<name>.csv``, with a header row;
        - ``<name>.csv.pkl``, with the same rows but no header.
        Each row is ``[cluster_id, target_clusters, target_weight, InnerWeight, map2N]``.

        Args:
            thres: unused; kept for backward compatibility.
        """
        # Materialise the listing: this loop writes new files into the same directory.
        for fp in list(self.result_dir.iterdir()):
            # Filter on the file itself (not substrings of the full path), so parent directory
            # names cannot affect which files are processed.
            if not fp.is_file() or fp.suffix != '.pkl' or fp.name.endswith('.csv.pkl'):
                continue
            with open(fp, 'rb') as fin:
                print(f"Disreate OTR result of: {fp}")
                # NOTE(review): pickle.load runs code from the file; keep result_dir trusted.
                ot_result = pickle.load(fin)
            mapping = ot_result['ot_mapping']
            for lidx in range(len(mapping)):
                mapping[lidx] /= mapping[lidx].sum()

            csv_path = str(fp.with_suffix('.csv'))
            contents = [['cluster_id', 'target_clusters', "target_weight", 'InnerWeight', 'map2N']]
            for i in range(len(mapping)):
                cluster_id = ot_result['pre_index'][i]
                target = mapping[i]
                target_idxs = np.where(target > 0)[0]  # positions in the local mapping
                target_clusters = ot_result['cur_index'][target_idxs]
                target_weight = target[target_idxs]
                inner_weight = target_weight / target_weight.sum()
                contents.append([cluster_id, target_clusters, target_weight, inner_weight, len(target_idxs)])

            with open(csv_path, 'w', newline='') as fout:
                print("Dumping discrete result in csv file:", csv_path)
                writer = csv.writer(fout)
                writer.writerows(contents)
            with open(f"{csv_path}.pkl", 'wb') as fout:
                print(f"Dumping discrete result in pkl file: {csv_path}.pkl")
                pickle.dump(contents[1:], fout)
    
    def get_age_to_cluster(self):
        """List, per age, the cluster ids with more than ``minimum_gene_number`` cells.

        Returns:
            ``{age: sorted list of cluster ids}``, with ages in order of first appearance
            among the cells. Every age present in the data gets an entry, possibly an empty
            list.
        """
        from collections import Counter

        obs = self.rna_data.obs
        # Positional arrays: ``Series[i]`` on a string-indexed obs is a deprecated pandas fallback.
        ages = obs['Age'].to_numpy()
        clusters = obs['seurat_clusters'].to_numpy()
        pair_counts = Counter((ages[i], clusters[i]) for i in range(len(self.rna_data.X)))

        age_to_cluster = {}
        for (age, cls_id), n_cells in pair_counts.items():
            members = age_to_cluster.setdefault(age, [])
            if n_cells > self.minimum_gene_number:
                members.append(cls_id)

        for members in age_to_cluster.values():
            members.sort()
        return age_to_cluster
    
    def age_level_mapping(self, draw = True, cluster_info_file = 'Clusters_info.json',
                          draw_font_family = 'serif', draw_font_color='darkred',
                          draw_font_weight = 'normal', draw_font_size= 8, 
                          draw_max_height = 160, p_t_value=0.1):
        """Collect, prune and optionally draw the cluster links of every consecutive age pair.

        For each transition ``age_set[i] -> age_set[i+1]``, the ``*.csv.pkl`` files in
        ``result_dir`` whose path contains ``'{prev}_to_{cur}'`` are read. Their links are
        pruned with ``edge_cutoff``. When ``draw`` is true, both cluster layers and the links
        are plotted and saved as ``<name>.pdf`` next to the last matching file.

        Args:
            draw: plot and save one figure per transition.
            cluster_info_file: unused; kept for backward compatibility.
            draw_font_family, draw_font_color, draw_font_weight, draw_font_size: stored in a
                local ``font`` dict whose size is overridden per transition. The dict is not
                passed to the plotting calls.
            draw_max_height: unused; the layout height is derived from the cluster counts.
            p_t_value: threshold forwarded to ``edge_cutoff`` for ``'p_value'`` pruning.

        Returns:
            ``{prev_age: {prev_cluster: [linked current clusters]}}``. A transition without a
            matching result file maps every previous cluster to an empty list.
        """
        age_to_cluster = self.get_age_to_cluster()

        font = {'family': draw_font_family,
            'color':  draw_font_color,
            'weight': draw_font_weight,
            'size': draw_font_size,
            }
        overall_mapping = {}
        for i in range(len(self.age_set) - 1):
            prev_age = self.age_set[i]
            cur_age = self.age_set[i+1]
            n_prev = len(age_to_cluster[prev_age])
            n_cur = len(age_to_cluster[cur_age])
            print(f"Pre:{prev_age}, {n_prev}, Cur:{cur_age}, {n_cur}")
            normalize = Normalize(vmin=0, vmax=max(n_prev,n_cur))
            cmap = plt.cm.RdBu
            font['size'] = max(30 - max(n_prev, n_cur), 4)
            max_height = max(n_prev, n_cur) * 5
            # previous-age nodes, laid out along a diagonal
            prev_y_step_size = max_height / n_prev
            prev_xs = [0 + prev_y_step_size * k for k in range(n_prev)]
            prev_ys = [-max_height + prev_y_step_size * k for k in range(n_prev)]
            # current-age nodes, on a parallel diagonal shifted to the right
            cur_y_step_size = max_height / n_cur
            cur_xs = [prev_xs[-1] // 2 + cur_y_step_size * k for k in range(n_cur)]
            cur_ys = [-max_height + cur_y_step_size * k for k in range(n_cur)]
            if draw:
                plt.scatter(cur_xs, cur_ys, c = age_to_cluster[cur_age], s=(57 / max(n_prev, n_cur)))
                plt.scatter(prev_xs, prev_ys, c = age_to_cluster[prev_age], s=(57 / max(n_prev, n_cur)))
                plt.text(np.mean(prev_xs)-10, np.mean(prev_ys)+10, prev_age)
                plt.text(np.mean(cur_xs)+10, np.mean(cur_ys)-10, cur_age)

            # Reset per transition. Otherwise a transition without a result file would reuse
            # the previous transition's links (or hit an undefined name).
            prev2cur_dict = {cls_id: [] for cls_id in age_to_cluster[prev_age]}
            fpn = None
            pic_path = None
            for fp in self.result_dir.iterdir():
                if f'{prev_age}_to_{cur_age}' not in str(fp) or 'csv.pkl' not in str(fp):
                    continue
                fpn = str(fp)
                pic_path = f"{fpn[: fpn.find('csv.pkl')]}pdf"
                prev2cur_dict = {cls_id: [] for cls_id in age_to_cluster[prev_age]}
                prev_count = {}       # prev cluster -> number of links
                cur_count = {}        # cur cluster -> number of incoming links
                prev2cur_weight = {}
                cur2prev_weight = {}
                with open(fp, 'rb') as fin:
                    reader = pickle.load(fin)
                for line in reader:
                    prev_count[line[0]] = len(line[1])
                    if line[0] not in prev2cur_weight:
                        prev2cur_weight[line[0]] = {}
                    for k in range(len(line[1])):
                        target_cls = line[1][k]
                        if target_cls not in cur_count:
                            cur_count[target_cls] = 0
                        if target_cls not in cur2prev_weight:
                            cur2prev_weight[target_cls] = {}
                        if line[0] not in prev2cur_dict:
                            prev2cur_dict[line[0]] = []
                        prev2cur_dict[line[0]].append(target_cls)
                        cur_count[target_cls] += 1
                        prev2cur_weight[line[0]][target_cls] = line[2][k]
                        cur2prev_weight[target_cls][line[0]] = line[2][k]
                prev2cur_dict = self.edge_cutoff(prev2cur_dict,
                                 prev_count,
                                 cur_count,
                                 prev2cur_weight,
                                 cur2prev_weight,
                                 self.path_filter_method,
                                 p_t_value)
            if pic_path is None:
                print(f"No discrete OT result found for {prev_age}_to_{cur_age}; using an empty mapping.")
            overall_mapping[prev_age] = prev2cur_dict
            if draw:
                for prev in prev2cur_dict:
                    prev_index = age_to_cluster[prev_age].index(prev)
                    for target in prev2cur_dict[prev]:
                        target_index = age_to_cluster[cur_age].index(target)
                        plt.plot((prev_xs[prev_index], cur_xs[target_index]),
                                 (prev_ys[prev_index], cur_ys[target_index]),
                                 c=cmap(normalize(prev_index)))
                if pic_path is not None:
                    plt.title(fpn[fpn.rfind('/')+1:])
                    plt.savefig(pic_path)
                plt.close()
        return overall_mapping

    def edge_cutoff(self, prev2cur_dict, prev_count, 
                    cur_count, prev2cur_weight, 
                    cur2prev_weight,
                    path_filter_method, p_t_value = 0.1):
        """Prune weak links of previous clusters that fan out to many current clusters.

        For each previous cluster with more than two links, only links whose target has more
        than one incoming link (``cur_count > 1``) are removal candidates. If there are more
        than two candidates, the weakest are removed:
        - ``'p_value'`` removes the positions returned by ``p_value(weights, thres=p_t_value)``;
        - ``'max_sep'`` removes everything after the ``maximum_separation`` index of the
          descending-sorted weights.
        All dicts are updated in place.

        Args:
            prev2cur_dict: ``{prev: [cur, ...]}`` links.
            prev_count: ``{prev: number of links}``.
            cur_count: ``{cur: number of incoming links}``.
            prev2cur_weight / cur2prev_weight: link weights in both directions.
            path_filter_method: ``'p_value'`` or ``'max_sep'``.
            p_t_value: threshold for ``'p_value'``.

        Returns:
            ``prev2cur_dict`` (the same object), pruned.

        Raises:
            NotImplementedError: for another ``path_filter_method``, but only when a pruning
                step is actually reached.
        """
        for prev in list(prev_count.keys()):
            if prev_count[prev] <= 2:
                continue
            removable_cur_cls = [cur for cur in prev2cur_dict[prev] if cur_count[cur] > 1]
            if len(removable_cur_cls) <= 2:
                continue
            removable_weights = np.array([prev2cur_weight[prev][cur] for cur in removable_cur_cls])
            if path_filter_method == 'p_value':
                drop_positions = p_value(removable_weights, thres = p_t_value)
            elif path_filter_method == 'max_sep':
                sort_index = (-removable_weights).argsort()
                max_sep_idx = maximum_separation(removable_weights[sort_index])
                drop_positions = sort_index[max_sep_idx+1:]
            else:
                raise NotImplementedError
            for ee in drop_positions:
                self._drop_edge(prev, removable_cur_cls[ee], prev2cur_dict, prev_count,
                                cur_count, prev2cur_weight, cur2prev_weight)
        return prev2cur_dict

    @staticmethod
    def _drop_edge(prev, remove_cls, prev2cur_dict, prev_count, cur_count,
                   prev2cur_weight, cur2prev_weight):
        """Remove the link ``prev -> remove_cls`` from every bookkeeping structure."""
        prev_count[prev] -= 1
        # Decrement the removed target itself. The old code decremented a stale loop variable,
        # which corrupted the counts used for later previous clusters.
        cur_count[remove_cls] -= 1
        prev2cur_dict[prev].remove(remove_cls)
        prev2cur_weight[prev].pop(remove_cls)
        cur2prev_weight[remove_cls].pop(prev)
    
    def collect_lineage(self, overall_mapping, dep = 8):
        """Enumerate every lineage path in ``overall_mapping`` and write one per line.

        Starting from each cluster of ``age_set[0]``, links are followed through ages
        ``age_set[0] .. age_set[dep-2]``. Each complete path (``dep`` cluster ids) is written
        as ``str(list)`` plus a newline to ``result_dir / rna_lineage_path``, which is
        truncated first. The link collections in ``overall_mapping`` are converted to lists
        in place. The total number of paths is printed.

        Args:
            overall_mapping: ``{age: {cluster: iterable of next-age clusters}}``.
            dep: number of cluster ids per written path.
        """
        for age in overall_mapping:
            for cls_id in overall_mapping[age]:
                overall_mapping[age][cls_id] = list(overall_mapping[age][cls_id])

        json_file = self.result_dir.joinpath(self.rna_lineage_path)
        count = [0]
        pending = []  # buffered output lines, shared by every recursion frame

        with open(json_file, 'w', newline='\n') as fout:
            def dps_mapping(i_dep, cls_id, traj):
                for cc in overall_mapping[self.age_set[i_dep]][cls_id]:
                    traj.append(cc)
                    if i_dep == dep-2:
                        count[0] += 1
                        pending.append(str(np.stack(traj).tolist()) + '\n')
                        if len(pending) > 10000:
                            fout.writelines(pending)
                            # clear() rather than rebinding: rebinding a local name would leave
                            # the shared buffer full (duplicate writes) and drop later paths.
                            pending.clear()
                    else:
                        dps_mapping(i_dep+1, cc, traj)
                    traj.pop(-1)

            for c1 in overall_mapping[self.age_set[0]]:
                dps_mapping(0, c1, [c1])
            fout.writelines(pending)
        print("Total count:", count)
        
    def ground_truth_path(self):
        """Check written lineages against two hard-coded rules and keep the legal ones.

        The rules are:
        - Rule one: once cluster 36 or 37 appears, every later cluster must be 36 or 37.
        - Rule two: wherever 41 is followed by at least two clusters, they must be 8 and
          then 36 or 37.

        Lines are read from ``result_dir / rna_lineage_path``, and the legal lines are copied
        verbatim to ``result_dir / 'Filtered_path.json'``. Blank lines are ignored.

        Returns:
            ``(right_num, path_num, ratio)``. ``ratio`` is 0.0 when there are no paths.
        """
        import ast

        def rule_one(path, start_nodes: list, legal_nodes: list):
            start = False
            legal = True
            for node in path:
                if start:
                    if node not in legal_nodes:
                        legal = False
                if node in start_nodes:
                    start = True
            return legal

        def rule_two(path):
            legal_nodes = [36, 37]
            legal = True
            for i in range(len(path)-2):
                if path[i] == 41:
                    if path[i+1] != 8 or path[i+2] not in legal_nodes:
                        legal = False
            return legal

        file_name = self.result_dir.joinpath(self.rna_lineage_path)
        new_file_name = self.result_dir.joinpath("Filtered_path.json")
        new_lineages = []
        path_num = 0
        right_num = 0
        with open(file_name, 'r') as fin:
            content = fin.readlines()
        for line in content:
            if not line.strip():
                continue
            path_num += 1
            # literal_eval instead of eval: the file only holds list literals.
            path = ast.literal_eval(line.strip())
            legal_for_one = rule_one(path, start_nodes=[36, 37], legal_nodes=[36, 37])
            legal_for_two = rule_two(path)
            if legal_for_one and legal_for_two:
                right_num += 1
                new_lineages.append(line)
        ratio = right_num / path_num if path_num else 0.0
        print(f"{right_num}/{path_num}, {ratio * 100}")
        with open(new_file_name, 'w') as fout:
            fout.writelines(new_lineages)
        return right_num, path_num, ratio
        
    def lineage_correlation(self, overall_mapping, cluster_points):
        """Score how well the linked clusters agree with cluster-centre similarity.

        For each transition ``age_set[i] -> age_set[i+1]`` and each previous cluster ``cur``
        in ``overall_mapping[cur_age]``, ``coorelation_func`` is evaluated between the centre
        of ``cur`` and the centre of every next-age cluster in ``cluster_points``. Centres come
        from ``get_center_point``. A linked (within-lineage) value counts as "legal" when it is
        strictly greater than every unlinked (out-of-lineage) value for the same ``cur``. The
        legal fraction per transition is printed and stored in ``ret_results``, and the mean
        of those fractions is printed as a percentage.

        NOTE(review): with ``ot_cost == 'l2'``, "greater is better" only makes sense if ``l2``
        returns a similarity rather than a distance; confirm against ``l2``'s definition.
        NOTE(review): ``np.max(out_lineage)`` raises ValueError when ``cur`` is linked to every
        next-age cluster (empty ``out_lineage``). ``coorelation_func`` is unbound for any
        other ``ot_cost``.
        """
        # 'pearson', 'spearman']
        if self.ot_cost == 'l2':
            coorelation_func = l2
        elif self.ot_cost == 'pearson':
            coorelation_func = lambda x, y: pearsonr(x,y).statistic
        empty_mapping = []
        illega_mapping = []
        
        ret_results = {}
        mean_values = []
        # cluster_level_group
        for i in range(len(self.age_set) - 1):
            cur_age = self.age_set[i]
            target_age = self.age_set[i+1]
            target_cls_list = list(cluster_points[target_age].keys())
            legal_count = 0
            overal_count = 0
            for cur in overall_mapping[cur_age]:
                within_lineage = []  # scores of linked next-age clusters
                out_lineage = []     # scores of unlinked next-age clusters
                for target in target_cls_list:
                    # clusters without stored points (e.g. filtered as too small) are skipped
                    if cur not in cluster_points[cur_age]:
                        continue
                    if target not in cluster_points[target_age]:
                        continue
                    X = get_center_point(cluster_points[cur_age][cur])
                    Y = get_center_point(cluster_points[target_age][target])
                    # X = cluster_points[cur_age][cur].mean(axis=0)
                    # Y = cluster_points[target_age][target].std(axis=0)
                    # print(X.shape, Y.shape)
                    results = []  # unused
                    if target in overall_mapping[cur_age][cur]:
                        within_lineage.append(coorelation_func(X, Y))
                    else:
                        out_lineage.append(coorelation_func(X, Y))
                # print(within_lineage)
                for corr in within_lineage:
                    overal_count += 1
                    # print(corr, np.max(out_lineage))
                    # legal only if this link beats every unlinked candidate
                    if corr > np.max(out_lineage):
                        legal_count += 1
            if overal_count == 0:
                print(f"{cur_age}->{target_age}, ({legal_count}/{overal_count}) 0")
                ret_results[f"{cur_age}->{target_age}"] = 0
                mean_values.append(0)
            else:
                print(f"{cur_age}->{target_age}, ({legal_count}/{overal_count}) {legal_count / overal_count: .5f}")
                ret_results[f"{cur_age}->{target_age}"] = legal_count / overal_count
                mean_values.append(legal_count / overal_count)
        print(f"Mean: {np.mean(mean_values) * 100}")
        ret_results["Mean"] = np.mean(mean_values) * 100
        return ret_results
        
    def monoto_check_per_path(self, cluster_points, check_idx, gene_std_tolerance = 1, std_threshold = 0.9):
        file_name = self.result_dir.joinpath(self.rna_lineage_path)
        def monoto_check(path):
            # collect cluster center point
            cluster_center_points = []
            cluster_center_var = []
            for i in range(len(path)):
                cluster_center_points.append(get_center_point(cluster_points[self.age_set[i]][path[i]])[check_idx])            
                # cluster_center_points.append(np.mean(cluster_points[self.age_set[i]][path[i]], axis=0)[check_idx])
                cluster_center_var.append(cluster_points[self.age_set[i]][path[i]].std(axis=0)[check_idx])
     
            mono_path_features = np.stack(cluster_center_points) # 
            # print(mono_path_features.shape)
            # print(np.mean(mono_path_features, axis=0).shape
            mono_value = mono_path_features[1:] - mono_path_features[:-1]
            mono_path_relax = np.empty_like(mono_value)
            for i in range(len(mono_value)):
                mono_path_relax[i] = np.min(cluster_center_var[i:i+2], axis=0)
            # print(mono_value.shape, cluster_center_var.shape)
            # increase check
            inc_check = (mono_value >= (-1 * mono_path_relax * gene_std_tolerance)).sum(axis=0)
            increase_num = (inc_check > 7).sum()
            # decrease check
            des_check = (mono_value <= mono_path_relax * gene_std_tolerance).sum(axis=0)
            decrease_num = (des_check > 7).sum()
            legal_num = min(increase_num + decrease_num, len(check_idx))
            result = legal_num / len(check_idx)
            return result
        
        bin_data = []
        with open(file_name, 'r') as fin:
            content = fin.readlines()
            for line in content:
                path = eval(line)
                ratio = monoto_check(path)
                bin_data.append(ratio)
        return bin_data
    
    def monoto_check_per_gen(self, cluster_points, check_idx, gene_std_tolerance = 1):
        file_name = self.result_dir.joinpath(self.rna_lineage_path)
        def monoto_check_per_gen(path_set, check_idx):
            # collect cluster center point
            all_path_points = []
            all_path_stds = []
            for i in range(len(path_set)):
                path = path_set[i]
                cluster_center_points = []
                cluster_center_var = []
                for j in range(len(path)):
                    # cluster_center_points.append(cluster_points[self.age_set[j]][path[j]].mean(axis=0)[check_idx])
                    cluster_center_var.append(cluster_points[self.age_set[j]][path[j]].std(axis=0)[check_idx])
                    cluster_center_points.append(get_center_point(cluster_points[self.age_set[j]][path[j]])[check_idx])
                    # cluster_center_var.append(get_points_std(cluster_points[self.age_set[j]][path[j]])[check_idx])
                # cluster_center_points = np.stack(cluster_center_points)
                all_path_points.append(np.stack(cluster_center_points))
                all_path_stds.append(np.stack(cluster_center_var))
            # mono_path_features = np.stack(cluster_center_points)
            ret_data = []
            all_path_points = np.stack(all_path_points) # (2585, 9, 230)
            all_path_stds = np.stack(all_path_stds)
            for gen_idx in range(len(check_idx)):
                mono_data = all_path_points[:, :, gen_idx] # (2585, 9)
                mono_std = all_path_stds[:, :, gen_idx] # (2585, 9)
                mono_data = mono_data.transpose(1, 0) # (9, 2585)
                mono_std = mono_std.transpose(1, 0) # (9, 2585)
                mono_value = mono_data[1:] - mono_data[:-1] # (8, 2585)
                mono_relax = np.empty_like(mono_value)
                for i in range(len(mono_value)):
                    mono_relax[i] = np.min(mono_std[i:i+2], axis=0)
                # increase check
                inc_check = (mono_value >= (-1 * mono_relax * gene_std_tolerance)).sum(axis=0)
                increase_num = (inc_check > 7).sum()
                # decrease check
                des_check = (mono_value <= mono_relax * gene_std_tolerance).sum(axis=0)
                decrease_num = (des_check > 7).sum()
                legal_num = min(increase_num + decrease_num, len(path_set))
                ret_data.append(legal_num / len(path_set))
            return ret_data
        
        bin_data = []
        with open(file_name, 'r') as fin:
            content = fin.readlines()
            
            path_set = []
            for line in content:
                path = eval(line)
                path_set.append(path)
            bin_data = monoto_check_per_gen(path_set, check_idx)
        return bin_data
    
    def draw_monoto(self, bin_data, title: str, x_label: str, y_label: str, per: str):
        """Save a 100-bin histogram of ``bin_data`` as a PDF inside ``self.result_dir``.

        The file name is built from the OT configuration (raw data type,
        algorithm, cost, path-filter method) plus the ``per`` suffix. The
        current pyplot figure is closed afterwards.
        """
        plt.hist(bin_data, bins=100)
        plt.title(title)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        stem = (f"{self.ot_raw_data_type}_{self.ot_algorithm}_"
                f"{self.ot_cost}_{self.path_filter_method}_{per}")
        pdf_path = self.result_dir.joinpath(f"{stem}.pdf")
        plt.savefig(str(pdf_path))
        plt.close()
        
    def dump_json(self, obj, fn):
        full_fn = self.result_dir.joinpath(fn)
        with open(full_fn, 'w') as fout:
            json.dump(obj, fout, indent=2)
        
def l2(x, y):
    """Return the negated Euclidean (L2) distance between ``x`` and ``y``.

    The result is negated so that larger values mean "more similar". This
    matches the "greater is better" comparison in
    ``RGCAtlasSctCombined.lineage_correlation``.

    The previous body applied ``sqrt`` elementwise before summing. Since
    ``sqrt(d**2) == |d|``, that computed a negated L1 distance, not L2.
    Array-likes (e.g. lists) are accepted; the L2 norm is taken over all
    elements of ``x - y``.
    """
    diff = np.asarray(x) - np.asarray(y)
    return -np.linalg.norm(diff)

if __name__ == '__main__':
    import time

    # Every CLI option is a plain string flag: (flag, default).
    cli_options = (
        ('--path-filter-method', 'max_sep'),
        ('--data-type', 'pca55'),
        ('--ot-cost', 'pearson'),
        ('--result-dir', 'test'),
        ('--metric-constraint', 'empty'),
        ('--ot-algorithm', 'Sinkhorn1'),
        ('--mono-algo-type', 'empty'),
        ('--additional-feature-path', ''),
        ('--mono-gene-list', 'Age_asso_genes.csv'),
        ('--cell-data', 'demo.h5ad'),
        ('--mono-gene', 'all_the_way_up_genes.csv'),
    )
    parser = argparse.ArgumentParser()
    for flag, default in cli_options:
        parser.add_argument(flag, type=str, default=default)
    args = parser.parse_args()
    print(args)

    start_time = time.time()
    data_type = args.data_type
    test = RGCAtlasSctCombined(
        args.cell_data,
        monoto_gen_csv=args.mono_gene,
        monoto_gen_constraints_csv=args.mono_gene_list,
        age_set=[0, 1, 2, 3],
        ot_lambda_list=[10, 10, 10, 10],
        path_filter_method=args.path_filter_method,
        ot_raw_data_type=data_type,
        ot_cost=args.ot_cost,
        result_dir=args.result_dir,
        ot_norm=True,
        additional_metric_constraint=args.metric_constraint,
        mono_algo_type=args.mono_algo_type,
        minimum_gene_number=10,
    )

    # Optimal-transport mapping between consecutive ages.
    age_cluster_raw_data = test.age_cluster_raw_data()
    builtin_data_types = ('pca55', 'umap', 'removed_genes')
    if data_type not in builtin_data_types:
        ot_raw_data = test.load_additional_features(args.additional_feature_path)
    else:
        ot_raw_data = test.ot_raw_data(data_type=data_type)
    ot_mapping = test.calc_ot_mapping(ot_raw_data, age_cluster_raw_data)
    test.discrete_ot_result()
    overall_mapping = test.age_level_mapping(
        cluster_info_file='data/Clusters_info.json',
        p_t_value=0.0001,
    )
    mapping_pkl = test.result_dir.joinpath('ot_overall_mapping.pkl')
    with open(mapping_pkl, 'wb') as fout:
        pickle.dump(overall_mapping, fout)
    lineage = test.collect_lineage(overall_mapping, len(test.age_set))

    # Evaluation.
    right_num, path_num, ratio = test.ground_truth_path()
    result_dict = {
        'ground_truth_path': {
            'right_num': right_num,
            'path_num': path_num,
            'ratio': ratio * 100,
        },
    }
    # Lineage correlation is always scored with Pearson, whatever --ot-cost was.
    test.ot_cost = 'pearson'
    result_dict['lineage_correlation'] = test.lineage_correlation(
        overall_mapping, age_cluster_raw_data
    )

    # Optional monotonicity diagnostics (disabled):
    # bin_per_path = test.monoto_check_per_path(age_cluster_raw_data, test.monoto_gen_idx, 0.5)
    # test.draw_monoto(bin_per_path, 'Per lineage', 'BIN (%)', 'Path Number', 'lineage')
    # bin_per_gen = test.monoto_check_per_gen(age_cluster_raw_data, test.monoto_gen_idx, 0.5)
    # test.draw_monoto(bin_per_gen, 'Per Gen', 'BIN (%)', 'Gen Number', 'gen')
    # result_dict['monoto'] = {}
    # result_dict['monoto']['per_path'] = np.mean(bin_per_path) * 100
    # result_dict['monoto']['per_gen'] = np.mean(bin_per_gen) * 100
    # print(f"[Monoto] Per-path: {result_dict['monoto']['per_path']}")
    # print(f"[Monoto] Per-gen: {result_dict['monoto']['per_gen']}")

    result_dict['running_time'] = time.time() - start_time
    test.dump_json(result_dict, "finally_result.json")