import pathlib
import numpy as np
import csv
from pcurve import PrincipalCurve
import scanpy as sc

def get_center_point(points):
    """Return the median of ``points`` taken along axis 0.

    For a 2-D input with one point per row this is the coordinate-wise
    median point.
    """
    center = np.median(points, axis=0)
    return center

def get_points_std(points, point_percent = 0.9):
    """Standard deviation of the rows of ``points`` nearest to their centre.

    The centre comes from ``get_center_point``. Rows are ranked by Euclidean
    distance to that centre and only the closest
    ``int(len(points) * point_percent)`` rows are used, so the farthest rows
    do not contribute to the result.

    Args:
        points: 2-D array with one point per row.
        point_percent: Fraction of rows (nearest first) to keep.

    Returns:
        The standard deviation along axis 0 of the retained rows.
    """
    distances = np.linalg.norm(points - get_center_point(points), axis=1)
    keep_count = int(len(distances) * point_percent)
    nearest_rows = np.argsort(distances)[:keep_count]
    kept_points = points[nearest_rows]
    return kept_points.std(axis=0)

def ground_truth_path(result_root: pathlib.Path, file_name):
    """Filter lineage paths with two hard-coded node rules and save the survivors.

    Every non-blank line of ``file_name`` is evaluated as a Python expression
    that yields a sequence of node ids. A path is kept when it satisfies both
    rules:

    * rule one: once a node from ``[36, 37]`` has been seen, every later node
      must also be one of ``[36, 37]``;
    * rule two: every occurrence of node ``41`` must be followed by ``8`` and
      then by ``36`` or ``37``.

    The kept lines are written unchanged to ``result_root/Filtered_path.json``
    and a ``kept/total, percent`` summary is printed.

    Args:
        result_root: Directory (a ``pathlib`` path) receiving the output file.
        file_name: Path of the input file, one path per line.

    Returns:
        ``(right_num, path_num, ratio)``: the number of kept paths, the number
        of non-blank input lines and their quotient (``0.0`` when the input
        holds no paths).
    """
    def rule_one(path, start_nodes, legal_nodes):
        # The node that switches `started` on is itself not checked; only the
        # nodes after it are.
        started = False
        for node in path:
            if started and node not in legal_nodes:
                return False
            if node in start_nodes:
                started = True
        return True

    def rule_two(path):
        legal_nodes = (36, 37)
        for i, node in enumerate(path):
            if node != 41:
                continue
            follow = path[i + 1:i + 3]
            # A path that ends before both successors exist cannot satisfy
            # the rule; treat it as illegal instead of indexing past the end.
            if len(follow) < 2 or follow[0] != 8 or follow[1] not in legal_nodes:
                return False
        return True

    new_file_name = result_root.joinpath("Filtered_path.json")
    new_lineages = []
    path_num = 0
    with open(file_name, 'r') as fin:
        for line in fin:
            if not line.strip():
                # Blank / trailing lines are not paths and cannot be evaluated.
                continue
            path_num += 1
            # NOTE(security): eval executes arbitrary code from the input
            # file; only use this on trusted files.
            path = eval(line)
            if rule_one(path, start_nodes=[36, 37], legal_nodes=[36, 37]) and rule_two(path):
                new_lineages.append(line)

    right_num = len(new_lineages)
    # Guard against an empty input file (no paths at all).
    ratio = right_num / path_num if path_num else 0.0
    print(f"{right_num}/{path_num}, {ratio * 100}")
    with open(new_file_name, 'w') as fout:
        fout.writelines(new_lineages)
    return right_num, path_num, ratio

def lineage_correlation(overall_mapping, cluster_points, coorelation_func, age_list):
    """Score how often in-lineage cluster pairs out-correlate all other pairs.

    For each pair of consecutive ages and each source cluster listed in
    ``overall_mapping[cur_age]``, the mean point of the source cluster is
    compared (via ``coorelation_func``) with the mean point of every cluster
    of the next age. A mapped target counts as "legal" when its score is
    strictly greater than the best score among the non-mapped targets. If
    there are no non-mapped targets, nothing competes with the mapped ones
    and they are all counted as legal.

    Per-transition ratios and their mean are printed as they are computed.

    Args:
        overall_mapping: ``{age: {source_cluster: collection_of_targets}}``.
        cluster_points: ``{age: {cluster: 2-D array of points}}``.
        coorelation_func: Callable ``f(x, y)`` returning a comparable score
            for two mean vectors (e.g. a Pearson or Spearman coefficient).
        age_list: Ordered ages; consecutive entries form the transitions.

    Returns:
        Dict mapping ``"cur->next"`` to the legal ratio of that transition
        (``0`` when nothing was compared) plus ``"Mean"``, the mean ratio
        multiplied by 100.
    """
    ret_results = {}
    mean_values = []
    for cur_age, target_age in zip(age_list[:-1], age_list[1:]):
        source_clusters = cluster_points[cur_age]
        target_clusters = cluster_points[target_age]
        legal_count = 0
        overal_count = 0
        for cur in overall_mapping[cur_age]:
            if cur not in source_clusters:
                continue
            # The source centre does not depend on the target: compute once.
            source_center = source_clusters[cur].mean(axis=0)
            mapped_targets = overall_mapping[cur_age][cur]
            within_lineage = []
            out_lineage = []
            for target, target_points in target_clusters.items():
                corr = coorelation_func(source_center, target_points.mean(axis=0))
                if target in mapped_targets:
                    within_lineage.append(corr)
                else:
                    out_lineage.append(corr)
            if not within_lineage:
                continue
            # np.max raises on an empty list; with no competitor every
            # in-lineage score wins by default.
            best_outside = np.max(out_lineage) if out_lineage else -np.inf
            overal_count += len(within_lineage)
            legal_count += sum(1 for corr in within_lineage if corr > best_outside)
        if overal_count == 0:
            print(f"{cur_age}->{target_age}, ({legal_count}/{overal_count}) 0")
            ret_results[f"{cur_age}->{target_age}"] = 0
            mean_values.append(0)
        else:
            print(f"{cur_age}->{target_age}, ({legal_count}/{overal_count}) {legal_count / overal_count: .5f}")
            ret_results[f"{cur_age}->{target_age}"] = legal_count / overal_count
            mean_values.append(legal_count / overal_count)
    print(f"Mean: {np.mean(mean_values) * 100}")
    ret_results["Mean"] = np.mean(mean_values) * 100
    return ret_results

def monoto_check_per_path(lineages, cluster_points, check_idx, age_list, gene_std_tolerance = 1):
    """Return, for every lineage, the fraction of checked genes that are monotone.

    For each path the mean expression (restricted to ``check_idx``) of every
    cluster found in ``cluster_points`` is collected together with the
    trimmed spread from ``get_points_std``. A step between two consecutive
    collected clusters counts as non-decreasing (non-increasing) when the
    change is at least ``-relax * gene_std_tolerance`` (at most
    ``relax * gene_std_tolerance``), where ``relax`` is the smaller spread of
    the two clusters. A gene is counted when more than ``len(age_list) - 2``
    steps pass in one direction.

    Args:
        lineages: Iterable of paths; ``path[i]`` is the cluster at ``age_list[i]``.
        cluster_points: ``{age: {cluster: 2-D array of points}}``.
        check_idx: Indices of the genes (columns) to check.
        age_list: Ordered ages.
        gene_std_tolerance: Multiplier applied to the spread-based slack.

    Returns:
        List with one ratio per lineage (``0`` for a path with no cluster
        present in ``cluster_points``).
    """
    def score_path(path):
        centers = []
        spreads = []
        for step, cluster_id in enumerate(path):
            clusters_at_age = cluster_points[age_list[step]]
            if cluster_id in clusters_at_age:
                cells = clusters_at_age[cluster_id]
                centers.append(np.mean(cells, axis=0)[check_idx])
                spreads.append(get_points_std(cells)[check_idx])

        if not centers:
            return 0

        centers = np.stack(centers)
        deltas = centers[1:] - centers[:-1]
        # Slack per step: the smaller spread of the two neighbouring clusters.
        spreads = np.stack(spreads)
        relax = np.minimum(spreads[:-1], spreads[1:]).astype(deltas.dtype, copy=False)

        needed = len(age_list) - 2
        rising_steps = (deltas >= (-1 * relax * gene_std_tolerance)).sum(axis=0)
        falling_steps = (deltas <= relax * gene_std_tolerance).sum(axis=0)
        rising_genes = (rising_steps > needed).sum()
        falling_genes = (falling_steps > needed).sum()

        # A gene passing both directions must not be counted twice.
        legal_genes = min(rising_genes + falling_genes, len(check_idx))
        return legal_genes / len(check_idx)

    return [score_path(path) for path in lineages]

def mapping_from_lineage(lineages_file, age_list, cluster_ps):
    """Expand the lineage paths of a file to one cluster per age and index them.

    Each non-blank line of ``lineages_file`` is evaluated as a Python
    expression yielding a sequence of cluster ids. The path is walked along
    ``age_list``: at every age the current node must exist in
    ``cluster_ps[age]``; the walk advances to the next node when that node
    exists at the next age, otherwise the current node is reused provided it
    also exists at the next age. Paths that cannot be aligned to every age
    are dropped.

    Args:
        lineages_file: Path of the text file, one path per line.
        age_list: Ordered ages.
        cluster_ps: ``{age: container of cluster ids}`` (only membership
            tests are performed on the inner containers).

    Returns:
        Dict ``{i: [cluster, ...]}`` where ``i`` is a position in
        ``age_list`` (``0 .. len(age_list) - 2``) and the list holds the
        distinct clusters that the aligned paths have at position ``i + 1``.
    """
    n_ages = len(age_list)
    lineages = []
    with open(lineages_file, 'r') as fin:
        for line in fin:
            if not line.strip():
                # Blank / trailing lines are not paths and cannot be evaluated.
                continue
            # NOTE(security): eval executes arbitrary code from the input
            # file; only use this on trusted files.
            path = eval(line)
            path_idx = 0
            new_path = []
            for i_age, age in enumerate(age_list):
                if path_idx >= len(path) or path[path_idx] not in cluster_ps[age]:
                    break
                new_path.append(path[path_idx])
                if i_age + 1 == n_ages:
                    continue
                # Look up the next age through age_list rather than `age + 1`
                # so non-consecutive (or non-numeric) ages work as well.
                next_clusters = cluster_ps[age_list[i_age + 1]]
                if path_idx + 1 < len(path) and path[path_idx + 1] in next_clusters:
                    path_idx += 1
                elif path[path_idx] not in next_clusters:
                    break
            # Every early exit above leaves new_path shorter than age_list.
            if len(new_path) == n_ages:
                lineages.append(new_path)

    # NOTE(review): the mapping is keyed by position in age_list and only
    # records the next-age clusters, not which cluster they descend from;
    # confirm this is the shape the consumers expect.
    overall_mapping = {}
    for i_age in range(n_ages - 1):
        overall_mapping[i_age] = list({lineage[i_age + 1] for lineage in lineages})
    return overall_mapping

def monoto_check_per_gen(lineages, cluster_points, check_idx, age_list, gene_std_tolerance = 1):
    """Return, for every checked gene, the fraction of lineages that are monotone.

    Only paths whose every cluster is present in ``cluster_points`` are
    evaluated. For such a path the per-cluster mean expression (restricted to
    ``check_idx``) and the trimmed spread from ``get_points_std`` are
    collected. A step between consecutive ages counts as non-decreasing
    (non-increasing) when the change is at least
    ``-relax * gene_std_tolerance`` (at most ``relax * gene_std_tolerance``),
    where ``relax`` is the smaller spread of the two clusters. A path is
    counted for a gene when more than ``len(age_list) - 2`` steps pass in one
    direction.

    Args:
        lineages: Iterable of paths; ``path[j]`` is the cluster at
            ``age_list[j]``. All evaluated paths must have the same length.
        cluster_points: ``{age: {cluster: 2-D array of points}}``.
        check_idx: Indices of the genes (columns) to check.
        age_list: Ordered ages.
        gene_std_tolerance: Multiplier applied to the spread-based slack.

    Returns:
        List with one ratio per entry of ``check_idx``. The denominator is
        the total number of lineages, including those skipped as incomplete.
        When no lineage can be evaluated at all, every ratio is ``0.0``.
    """
    path_set = list(lineages)

    all_path_points = []
    all_path_stds = []
    for path in path_set:
        centers = []
        spreads = []
        for j, cluster_id in enumerate(path):
            clusters_at_age = cluster_points[age_list[j]]
            if cluster_id not in clusters_at_age:
                break
            cells = clusters_at_age[cluster_id]
            centers.append(cells.mean(axis=0)[check_idx])
            spreads.append(get_points_std(cells)[check_idx])
        # Skip paths with a missing cluster (and empty paths).
        if not centers or len(centers) < len(path):
            continue
        all_path_points.append(np.stack(centers))
        all_path_stds.append(np.stack(spreads))

    if not all_path_points:
        # np.stack cannot handle an empty list; nothing was monotone.
        return [0.0] * len(check_idx)

    points = np.stack(all_path_points)  # (paths, ages, genes)
    stds = np.stack(all_path_stds)      # (paths, ages, genes)

    deltas = points[:, 1:] - points[:, :-1]
    # Slack per step: the smaller spread of the two neighbouring clusters.
    relax = np.minimum(stds[:, :-1], stds[:, 1:]).astype(deltas.dtype, copy=False)

    needed = len(age_list) - 2
    rising_steps = (deltas >= (-1 * relax * gene_std_tolerance)).sum(axis=1)
    falling_steps = (deltas <= relax * gene_std_tolerance).sum(axis=1)
    increase_num = (rising_steps > needed).sum(axis=0)
    decrease_num = (falling_steps > needed).sum(axis=0)

    # A path passing both directions must not be counted twice.
    legal_num = np.minimum(increase_num + decrease_num, len(path_set))
    return list(legal_num / len(path_set))


def get_monoto_gen_idx(adata, gen_csv, monoto_gen_constraints_csv):
    """Look up feature positions for the genes listed in two CSV files.

    ``adata.var['features'].values`` is the array of feature names that the
    gene names are matched against.

    * ``gen_csv``: after a header row, column 1 of each row is a gene name.
      The positions of every name found among the features are collected.
    * ``monoto_gen_constraints_csv``: after a header row, column 0 is a gene
      name, column 2 a number and column 3 a string. Genes that are found
      among the features and were *not* selected through ``gen_csv`` are
      recorded under their first matching position.

    Rows with too few columns (including blank rows) are skipped.

    Args:
        adata: Object exposing ``var['features'].values``.
        gen_csv: Path of the gene list CSV.
        monoto_gen_constraints_csv: Path of the constraints CSV.

    Returns:
        ``(check_idx, constraints_data)`` where ``check_idx`` is a sorted
        1-D integer array of feature positions (an empty list when nothing
        matched) and ``constraints_data`` maps a feature position to
        ``[column 3 value, float(column 2 value)]``.
    """
    feature_names = adata.var['features'].values
    selected_genes = set()
    matches = []

    with open(gen_csv, 'r', newline='') as fin:
        reader = csv.reader(fin, delimiter=',')
        next(reader, None)  # skip the header row (tolerates an empty file)
        for row in reader:
            if len(row) < 2:
                continue
            idx, = np.where(feature_names == row[1])
            if len(idx) > 0:
                selected_genes.add(row[1])
                matches.append(idx)

    if matches:
        # concatenate (unlike stack) also copes with names that occur more
        # than once among the features.
        check_idx = np.sort(np.concatenate(matches))
    else:
        check_idx = []

    constraints_data = {}
    with open(monoto_gen_constraints_csv, 'r', newline='') as fin:
        reader = csv.reader(fin, delimiter=',')
        next(reader, None)  # skip the header row (tolerates an empty file)
        for row in reader:
            if len(row) < 4:
                continue
            gene = row[0]
            if gene in selected_genes:
                continue
            idx, = np.where(feature_names == gene)
            if len(idx) > 0:
                constraints_data[int(idx[0])] = [row[3], float(row[2])]

    return check_idx, constraints_data


def cluster_points(adata, age_list, minimum_cell_number, feat = True):
    """Group the cells of ``adata`` by age and cluster.

    Cell ``i`` is assigned to the age ``adata.obs['Age']`` and the cluster
    ``adata.obs['seurat_clusters']`` found at position ``i``. Its data row is
    ``adata.obsm['X_pca'][i]`` when ``feat`` is true and ``adata.X[i]``
    otherwise. Cells whose age is not in ``age_list`` are ignored.

    Args:
        adata: Object providing ``len()``, ``obs``, ``obsm`` and ``X`` as
            described above.
        age_list: Ages to collect.
        minimum_cell_number: A cluster is kept only when it has strictly more
            cells than this.
        feat: Select ``obsm['X_pca']`` (true) or ``X`` (false) as row source.

    Returns:
        ``{age: {cluster_id: 2-D array}}`` with one row per cell, in the
        original cell order; every age of ``age_list`` is present as a key.
    """
    # Pull the per-cell labels out once as plain arrays so that indexing is
    # strictly positional, matching how the data rows are indexed.
    ages = np.asarray(adata.obs['Age'])
    cluster_ids = np.asarray(adata.obs['seurat_clusters'])
    rows = adata.obsm['X_pca'] if feat else adata.X

    grouped = {age: {} for age in age_list}
    for i in range(len(adata)):
        clusters_at_age = grouped.get(ages[i])
        if clusters_at_age is None:
            # Age not requested in age_list.
            continue
        clusters_at_age.setdefault(cluster_ids[i], []).append(rows[i])

    filtered_cluster_points = {age: {} for age in age_list}
    for age, clusters_at_age in grouped.items():
        for cluster_id, cells in clusters_at_age.items():
            if len(cells) > minimum_cell_number:
                filtered_cluster_points[age][cluster_id] = np.stack(cells, axis=0)
    return filtered_cluster_points


def get_lineage_cells(cluster_cells, lineage, age_list):
    """Pick, for every age, the cells of the cluster the lineage visits.

    ``lineage[i]`` is the cluster id used at ``age_list[i]``.

    Returns:
        ``{age: cluster_cells[age][lineage[i]]}`` for every age in ``age_list``.
    """
    return {
        age: cluster_cells[age][lineage[position]]
        for position, age in enumerate(age_list)
    }

def get_lineage_center(lineage_cells, age_list):
    """Return the per-age mean points and all cells of a lineage.

    Args:
        lineage_cells: ``{age: 2-D array of cells}``.
        age_list: Ages, in the order in which they are stacked.

    Returns:
        ``(centers, cells)``: ``centers`` stacks the axis-0 mean of each
        age's cells (one row per age); ``cells`` is the concatenation of all
        cell arrays along axis 0 in ``age_list`` order.
    """
    centers = []
    cells_by_age = []
    for age in age_list:
        cells = lineage_cells[age]
        centers.append(cells.mean(axis=0))
        cells_by_age.append(cells)

    all_cells = np.concatenate(cells_by_age, axis=0)
    return np.stack(centers), all_cells

def age_level_pseudotime(pseudotime_np, lineage_cells, age_list):
    """Split a flat pseudotime sequence back into per-age chunks.

    ``pseudotime_np`` is consumed front to back: each age in ``age_list``
    receives the next ``len(lineage_cells[age])`` entries.

    Returns:
        ``{age: slice of pseudotime_np}``.
    """
    pseudotime_age = {}
    offset = 0
    for age in age_list:
        end = offset + len(lineage_cells[age])
        pseudotime_age[age] = pseudotime_np[offset:end]
        offset = end
    return pseudotime_age

def pseudotime_merge(pseudotime_age, lineage, age_list):
    """Merge the pseudotimes of consecutive ages that share a cluster.

    The first merged age is ``age_list[0]``. Walking along the lineage, a new
    merged age (previous merged age + 1) is opened whenever the cluster
    changes between two consecutive ages; otherwise the pseudotimes of the
    age are added to the current merged age.

    Args:
        pseudotime_age: ``{age: array of pseudotimes}``.
        lineage: Cluster id per age, aligned with ``age_list``.
        age_list: Ordered ages.

    Returns:
        ``(merge_pseudotime, merged_age_list)``: a dict from merged age to
        its pseudotimes (the original array when a single age contributed,
        otherwise their concatenation) and the merged ages in order.
    """
    merged_key = age_list[0]
    grouped = {merged_key: [pseudotime_age[merged_key]]}
    merged_age_list = [merged_key]

    for prev_pos, age in enumerate(age_list[1:]):
        if lineage[prev_pos + 1] != lineage[prev_pos]:
            # Cluster changed: start a new merged age.
            merged_key = merged_key + 1
            merged_age_list.append(merged_key)
        grouped.setdefault(merged_key, []).append(pseudotime_age[age])

    merge_pseudotime = {
        key: chunks[0] if len(chunks) == 1 else np.concatenate(chunks)
        for key, chunks in grouped.items()
    }
    return merge_pseudotime, merged_age_list

def pesudotime_eval(merge_pseudotime, merged_age_list):
    """Score how well pseudotimes respect the order of the merged ages.

    For every merged age, each of its pseudotime values is compared with the
    neighbouring groups: the fraction of values in the next group that are
    strictly larger and the fraction of values in the previous group that
    are strictly smaller. Groups with both neighbours use the average of the
    two fractions; the first/last group use the one available. The result is
    the mean of the per-group means.

    Args:
        merge_pseudotime: ``{merged_age: NumPy array of pseudotimes}``.
        merged_age_list: The merged ages in temporal order.

    Returns:
        The mean order-consistency ratio. ``nan`` when there is only a single
        merged age (nothing to compare against) or no merged age at all.
    """
    right_ratio_list = []
    last_pos = len(merged_age_list) - 1

    for i, m_age in enumerate(merged_age_list):
        current = merge_pseudotime[m_age]
        ahead_ratios = []
        behind_ratios = []
        # Neighbours are decided by position in merged_age_list, not by
        # testing the position as a dictionary key, so merged ages that do
        # not start at 0 are handled too.
        if i < last_pos:
            following = merge_pseudotime[merged_age_list[i + 1]]
            ahead_ratios = [(following > cpt).sum() / len(following) for cpt in current]
        if i > 0:
            preceding = merge_pseudotime[merged_age_list[i - 1]]
            behind_ratios = [(preceding < cpt).sum() / len(preceding) for cpt in current]

        if ahead_ratios and behind_ratios:
            right_ratio_list.append(np.mean((np.array(ahead_ratios) + np.array(behind_ratios)) / 2))
        elif ahead_ratios:
            right_ratio_list.append(np.mean(ahead_ratios))
        elif behind_ratios:
            right_ratio_list.append(np.mean(behind_ratios))
        else:
            # No neighbour (or no cells) to compare with.
            right_ratio_list.append(np.nan)

    return np.mean(right_ratio_list)

def pseudotime_eval(lineage_file, adata_file, minimum_cell_number, age_list):
    """Average the pseudotime order score over all lineages of a file.

    The AnnData file is loaded with ``sc.read_h5ad`` and its cells are
    grouped with ``cluster_points``. Every non-blank line of
    ``lineage_file`` is evaluated as a Python expression yielding one cluster
    id per age. For each lineage a ``PrincipalCurve(k=5)`` is projected onto
    the lineage's cells, starting from the per-age cluster centres
    (``stretch=2``); the resulting ``pseudotimes_interp`` are split per age,
    merged with ``pseudotime_merge`` and scored with ``pesudotime_eval``.

    Args:
        lineage_file: Path of the text file, one lineage per line.
        adata_file: Path of the ``.h5ad`` file.
        minimum_cell_number: Forwarded to ``cluster_points``.
        age_list: Ordered ages.

    Returns:
        The mean of the per-lineage scores (``nan`` when the file contains
        no lineage).
    """
    adata = sc.read_h5ad(adata_file)
    cluster_ps = cluster_points(adata, age_list, minimum_cell_number)
    pesudo_order_rank = []

    with open(lineage_file, 'r') as fin:
        for raw_line in fin:
            if not raw_line.strip():
                # Blank / trailing lines are not lineages and cannot be evaluated.
                continue
            # NOTE(security): eval executes arbitrary code from the input
            # file; only use this on trusted files.
            lineage = eval(raw_line)
            lineage_cells = get_lineage_cells(cluster_ps, lineage, age_list)
            lineage_center, lineage_cells_np = get_lineage_center(lineage_cells, age_list)

            curve = PrincipalCurve(k=5)
            curve.project_to_curve(lineage_cells_np, lineage_center, stretch=2)
            # NOTE(review): assumes pseudotimes_interp is ordered like the
            # rows of lineage_cells_np — confirm against pcurve.
            pseudotime_age = age_level_pseudotime(curve.pseudotimes_interp, lineage_cells, age_list)
            merge_pseudotime, merged_age_list = pseudotime_merge(pseudotime_age, lineage, age_list)
            pesudo_order_rank.append(pesudotime_eval(merge_pseudotime, merged_age_list))

    return np.mean(pesudo_order_rank)