from scipy.stats import pearsonr, spearmanr
import argparse
import scanpy as sc
import pickle
from eval_util import *
import json
import pandas as pd

class MethodEval:
    """Placeholder for a class-based evaluation entry point.

    The constructor accepts the locations of the evaluation inputs but does
    not store or use any of them yet.
    """

    def __init__(
        self,
        adata_file,
        lineage_file,
        gen_csv,
        gen_constraints_csv,
    ) -> None:
        """Accept the evaluation inputs; currently a no-op."""


def _read_lineage_paths(lineages_file):
    """Parse one lineage path per non-blank line of ``lineages_file``.

    Each line is expected to hold the literal of a sequence of cluster ids
    (for example ``[3, 7, 7, 12]``).
    """
    paths = []
    with open(lineages_file, 'r') as fin:
        for line in fin:
            if not line.strip():
                # A blank line (e.g. a trailing newline) carries no path and
                # would make eval() raise SyntaxError.
                continue
            # NOTE(security): eval() executes whatever expression the line
            # contains. Only feed trusted lineage files to this function;
            # ast.literal_eval is a safer alternative if every line is a
            # plain literal.
            paths.append(eval(line))
    return paths


def _align_path(path, age_list, cluster_ps):
    """Stretch ``path`` over ``age_list`` and return the aligned prefix.

    At every age the current path element must be one of that age's clusters.
    Moving to the next age, the walk advances to the next path element when
    that element belongs to the next age's clusters; otherwise it stays on
    the current element, provided the current element also exists at the next
    age. As soon as neither is possible the walk stops and the prefix built
    so far is returned, so the result may be shorter than ``age_list``.
    """
    aligned = []
    if not path:
        return aligned

    idx = 0
    last_step = len(age_list) - 1
    for i_age, age in enumerate(age_list):
        current = path[idx]
        if current not in cluster_ps[age]:
            break
        aligned.append(current)
        if i_age == last_step:
            break
        # Look the next age up in age_list instead of assuming ``age + 1``,
        # so non-consecutive (or non-integer) ages work as well.
        next_clusters = cluster_ps[age_list[i_age + 1]]
        if idx + 1 < len(path) and path[idx + 1] in next_clusters:
            idx += 1
        elif current not in next_clusters:
            break
    return aligned


def mapping_from_lineage(lineages_file, age_list, cluster_ps):
    """Collect, per age, the clusters that lineages reach at the following age.

    Args:
        lineages_file: Text file with one lineage path literal per line.
        age_list: Ordered sequence of ages; each one is a key of ``cluster_ps``.
        cluster_ps: Mapping from age to a container of that age's cluster ids
            (only membership tests are performed on it).

    Returns:
        A dict with one key per age except the last one. The value is a list
        (in arbitrary order, without duplicates) of the cluster ids found at
        the *following* age step in any aligned lineage. Lineages whose
        alignment stopped early contribute only the steps they reached.
    """
    lineages = [
        _align_path(path, age_list, cluster_ps)
        for path in _read_lineage_paths(lineages_file)
    ]

    overall_mapping = {}
    for i_age, cur_age in enumerate(age_list[:-1]):
        reached = overall_mapping.setdefault(cur_age, set())
        reached.update(
            lineage[i_age + 1]
            for lineage in lineages
            if i_age + 1 < len(lineage)
        )
    return {age: list(reached) for age, reached in overall_mapping.items()}


def lineage_align(lineages_file, age_list, cluster_ps):
    """Align every lineage in ``lineages_file`` to ``age_list`` and deduplicate.

    Args:
        lineages_file: Text file with one lineage path literal per line.
        age_list: Ordered sequence of ages; each one is a key of ``cluster_ps``.
        cluster_ps: Mapping from age to a container of that age's cluster ids
            (only membership tests are performed on it).

    Returns:
        A list of aligned lineages (lists of cluster ids), without duplicates
        and in first-seen order. Lineages whose alignment stopped early are
        kept as their partial prefix, which may be empty.
    """
    unique = {}
    for path in _read_lineage_paths(lineages_file):
        aligned = _align_path(path, age_list, cluster_ps)
        # Two lineages are duplicates when their string forms match; keying
        # on str() also copes with unhashable elements.
        unique.setdefault(str(aligned), aligned)
    return list(unique.values())

if __name__ == '__main__':
    # Script entry point: load the AnnData file, obtain lineages (either read
    # as-is next to a precomputed mapping, or aligned to the age list), compute
    # the gene-level and path-level monotonicity scores and dump them as JSON.
    #
    # NOTE(review): `np`, `pathlib`, `cluster_points`, `get_monoto_gen_idx`,
    # `monoto_check_per_gen` and `monoto_check_per_path` are not imported by
    # name in this file; presumably they come from `from eval_util import *`
    # - confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument('--adata-file', type=str, default='SCT_all_P5_rm_sampled.h5ad')
    parser.add_argument('--result-dir', type=str, default='result')
    parser.add_argument('--lineage-file', type=str, default='notebook/lineage.json')
    parser.add_argument('--overall-mapping-file', type=str, default='notebook/moscot/moscot_mapping.pkl')
    parser.add_argument('--gene-csv', type=str, default='notebook/all_the_way_up_genes.csv')
    parser.add_argument('--gene-constraints-csv', type=str, default='notebook/Age_asso_genes.csv')
    parser.add_argument('--pesudo-time-file', type=str, default='notebook/pesudo_time_result.pkl')
    args = parser.parse_args()

    adata = sc.read_h5ad(args.adata_file)
    if 'Age' not in adata.obs:
        # Fall back to the sample identifier column when no age column exists.
        adata.obs['Age'] = adata.obs['orig.ident']
    age_list = [0, 1, 2, 3]

    # First clustering pass; its result is what the lineage alignment below
    # checks cluster membership against. The meaning of the `5, True`
    # arguments is defined by eval_util.cluster_points.
    cluster_ps = cluster_points(
        adata, age_list, 5, True
    )

    dumped_dict = {}
    # The literal string 'None' on the command line selects the branch that
    # derives lineages and mapping from the lineage file alone.
    if args.overall_mapping_file != 'None':
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # mapping files from a trusted source.
        with open(args.overall_mapping_file, 'rb') as fin:
            overall_mapping = pickle.load(fin)

        with open(args.lineage_file, 'r') as fin:
            # NOTE(security): eval() executes whatever each line contains;
            # only use trusted lineage files. Blank lines are skipped because
            # eval() would raise SyntaxError on them.
            lineages = [eval(line) for line in fin if line.strip()]
    else:
        lineages = lineage_align(args.lineage_file, age_list, cluster_ps)

        overall_mapping = mapping_from_lineage(args.lineage_file, age_list, cluster_ps)

    # `overall_mapping` is currently only consumed by the disabled
    # lineage-correlation step below.
    # print(overall_mapping)
    # corr_result = lineage_correlation(overall_mapping, cluster_ps,
    #                 lambda x, y: spearmanr(x,y).correlation * 0.5 + pearsonr(x,y).statistic * 0.5, age_list)
    # dumped_dict['lineage_correlation'] = corr_result
    check_gene_idx, _ = get_monoto_gen_idx(
        adata, args.gene_csv, args.gene_constraints_csv
    )
    # Second clustering pass with different settings (`10, False`); this is
    # the one the monotonicity checks run on.
    cluster_ps = cluster_points(
        adata, age_list, 10, False
    )
    # print(check_gene_idx)
    # float() turns the numpy scalars into plain Python floats so that
    # json.dump can always serialize them.
    gene_level_monoto = float(np.mean(monoto_check_per_gen(
        lineages, cluster_ps, check_gene_idx, age_list, 0.5)))
    path_level_monoto = float(np.mean(monoto_check_per_path(
        lineages, cluster_ps, check_gene_idx, age_list, 0.5)))
    dumped_dict['monoto'] = {
        'per_path': path_level_monoto,
        # Previously the path-level score was stored here by mistake.
        'per_gen': gene_level_monoto,
    }
    print(f"Monoto (GENE): {gene_level_monoto}")
    print(f"Monoto (PATH): {path_level_monoto}")

    work_path = pathlib.Path(args.result_dir)
    # Create the output directory up front so the results of the expensive
    # computation above are not lost to a missing folder.
    work_path.mkdir(parents=True, exist_ok=True)
    with open(work_path.joinpath("finally_result.json"), 'w') as fout:
        json.dump(dumped_dict, fout)
