import numpy as np
import torch
import matplotlib.pyplot as plt
from pcurve import PrincipalCurve
from tqdm import tqdm

def lineage_dist(lineage, age_set, cluster_points):
    """Gather the member points and centroid of every cluster on a lineage.

    Args:
        lineage: Sequence of cluster ids, one per step, e.g.
            ``[2, 2, 19, 29, 29, 29, 29, 29, 23]``.
        age_set: Sequence of ages; ``age_set[i]`` is the age used at step ``i``.
        cluster_points: Nested mapping ``cluster_points[age][cluster_id]``
            whose values expose ``mean(axis=0)`` (e.g. a 2-D array of points).

    Returns:
        A tuple ``(points_set, bezier_fix_points)``: the per-step member
        points, and the per-step mean over axis 0 of those points.
    """
    points_set = []
    bezier_fix_points = []
    for step, cluster_id in enumerate(lineage):
        members = cluster_points[age_set[step]][cluster_id]
        points_set.append(members)
        bezier_fix_points.append(members.mean(axis=0))
    return points_set, bezier_fix_points


def BaseFunction(i=None, k=None, u=None, NodeVector=None):
    """Evaluate the ``i``-th B-spline basis function of order ``k`` at ``u``.

    Uses the Cox-de Boor recursion over the knot sequence ``NodeVector``.

    Args:
        i: Index of the basis function.
        k: Order of the basis function (``k == 1`` is the piecewise-constant
            base case).
        u: Parameter value at which to evaluate.
        NodeVector: Indexable knot sequence; entries up to ``i + k`` are read.

    Returns:
        The basis value. The base case returns ``1`` on the half-open span
        ``[NodeVector[i], NodeVector[i + 1])`` and ``0`` elsewhere.
    """
    if k == 1:
        # Half-open span: the right end of each knot interval is excluded.
        return 1 if NodeVector[i] <= u < NodeVector[i + 1] else 0

    left_span = NodeVector[i + k - 1] - NodeVector[i]
    right_span = NodeVector[i + k] - NodeVector[i + 1]
    # Repeated knots give a zero-length span; divide by 1 instead of 0.
    left_span = 1 if left_span == 0 else left_span
    right_span = 1 if right_span == 0 else right_span

    left_term = (u - NodeVector[i]) / left_span * BaseFunction(i, k - 1, u, NodeVector)
    right_term = (NodeVector[i + k] - u) / right_span * \
        BaseFunction(i + 1, k - 1, u, NodeVector)
    return left_term + right_term

def U_quasi_B_Spline(n = None, k = None):
    """Build a quasi-uniform knot vector on ``[0, 1]``.

    The vector has ``n + k + 1`` entries: the first ``k`` are 0, the last
    ``k`` are 1, and the ``n - k + 1`` entries in between increase by
    ``1 / (n - k + 2)`` each.

    Args:
        n: Index of the last control point (number of control points minus 1).
        k: Order of the spline.

    Returns:
        The knot vector as a plain Python list of floats.
    """
    knots = np.zeros((n + k + 1, ))
    segments = n - k + 2
    # Interior knots occupy indices k..n; the range is empty when there is
    # at most one segment, so the division below is never reached then.
    for idx in range(k, n + 1):
        knots[idx] = knots[idx - 1] + 1 / segments
    knots[n + 1:n + k + 1] = 1
    return knots.tolist()

def b_spline(ref_points, order = 3, n_path_point = 100, extend = False):
    """Sample a B-spline curve defined by the control points ``ref_points``.

    The curve parameter is stepped by ``1 / n_path_point`` over ``[0, 1)``
    (or ``[-0.5, 1.5)`` when ``extend`` is truthy). Each sample is the
    basis-weighted sum of the control points; the last control point is then
    appended as the final path entry.

    Args:
        ref_points: Indexable collection of control points supporting
            multiplication by a scalar and in-place addition.
        order: Order passed to the basis function and knot-vector builder.
        n_path_point: Reciprocal of the parameter step.
        extend: When truthy, sample the wider range ``[-0.5, 1.5)``.
            NOTE(review): with the ``BaseFunction`` in this file the basis
            weights appear to be all zero outside the knot range, so the
            extra samples look like zero points -- confirm the intent.

    Returns:
        List of sampled points followed by ``ref_points[-1]``.
    """
    last = len(ref_points) - 1
    knots = U_quasi_B_Spline(last, order)
    u_start, u_stop = (-0.5, 1.5) if extend else (0, 1)

    curve = []
    for u in np.arange(u_start, u_stop, 1 / n_path_point):
        point = None
        for i in range(last + 1):
            term = ref_points[i] * BaseFunction(i, order, u, knots)
            if point is None:
                point = term
            else:
                point += term
        curve.append(point)
    # The sampling range excludes its right end, so the last control point
    # is added explicitly.
    curve.append(ref_points[-1])
    return curve

def polynomial(ref_points, order = 3, n_path_point = 100, extend = False):
    """Sample the polynomial ``sum_i ref_points[i] * t**i`` on ``[0, 1)``.

    Args:
        ref_points: Sequence of coefficients (scalars or array-likes that
            support scalar multiplication and in-place addition).
        order: Not used by this function.
        n_path_point: Reciprocal of the step between consecutive ``t`` values.
        extend: Not used by this function.

    Returns:
        List with one evaluated point per sampled ``t``.
    """
    def curve_func(t):
        total = None
        for power, coeff in enumerate(ref_points):
            term = coeff * (t**power)
            if total is None:
                total = term
            else:
                total += term
        return total

    return [curve_func(u) for u in np.arange(0, 1, 1/n_path_point)]

def project_cell(paths, tensor_points_set):
    """Find, for every cell, the index of the closest path point.

    Args:
        paths: Iterable of tensors (path points); their ``.data`` is stacked.
        tensor_points_set: Iterable of cell tensors broadcastable against the
            stacked path.

    Returns:
        List of Python ints, one per cell: the argmin over path points of
        the Euclidean distance taken along the last dimension.
    """
    with torch.no_grad():
        curve = torch.stack([point.data for point in paths])
        nearest = []
        for cell in tensor_points_set:
            offsets = curve - cell
            dist = (offsets ** 2).sum(dim=-1).sqrt()
            nearest.append(dist.argmin().cpu().item())
    return nearest

def norm2(a, b):
    """Return the Euclidean (L2) distance between ``a`` and ``b``.

    The squared differences are summed before the square root is taken.
    Taking the root element-wise first and then averaging would give the
    mean absolute difference instead, which is not a 2-norm.

    Args:
        a: Tensor-like supporting subtraction, ``**``, ``sum`` and ``sqrt``.
        b: Tensor-like broadcastable against ``a``.

    Returns:
        A scalar tensor holding the L2 norm of ``a - b``.
    """
    return ((a - b) ** 2).sum().sqrt()

def draw_dist_bin(new_cluster_set, tensor_points_set, label, n_path_point = None, draw = True):
    """Measure (and optionally plot) cell-to-curve distances.

    An order-3 B-spline is sampled through ``new_cluster_set``, every cell is
    assigned to its nearest sampled point via ``project_cell``, and the
    ``norm2`` value between the cell and that point is recorded.

    Args:
        new_cluster_set: Control points of the curve.
        tensor_points_set: Cells to project onto the curve.
        label: Legend label for the histogram (only used when drawing).
        n_path_point: Sampling density passed to ``b_spline``; defaults to
            ``len(new_cluster_set) * 100`` when ``None``.
        draw: When truthy, add a 40-bin histogram to the current figure.

    Returns:
        Mean of the recorded distances (``np.mean`` of the list).
    """
    with torch.no_grad():
        n_point = len(new_cluster_set) * 100 if n_path_point is None else n_path_point
        paths = b_spline(new_cluster_set, 3, n_point)
        proj_idx = project_cell(paths, tensor_points_set)
        dist_list = [
            norm2(paths[idx], tensor_points_set[j]).cpu().item()
            for j, idx in enumerate(proj_idx)
        ]
    if draw:
        plt.hist(dist_list, bins=40, label=label, alpha=0.4)
        plt.xlabel("Distance")
        plt.ylabel("Number of Cells")
        plt.legend()
    return np.mean(dist_list)
    
def draw_traj(paths, np_points_set, age_label):
    """Show a scatter of the cells with the trajectory drawn on top.

    Args:
        paths: 2-D indexable whose columns 0 and 1 are the curve coordinates.
        np_points_set: 2-D indexable whose columns 0 and 1 are the cell
            coordinates.
        age_label: Per-cell values passed to ``scatter`` as the colour.
    """
    cell_x, cell_y = np_points_set[:, 0], np_points_set[:, 1]
    plt.scatter(cell_x, cell_y, c = age_label)
    curve_x, curve_y = paths[:, 0], paths[:, 1]
    plt.plot(curve_x, curve_y)
    plt.show()
    
def peusodotime_dist(proj_idx, paths, tensor_points_set, label):
    """Plot a histogram of normalised positions of cells along a path.

    The cumulative ``norm2`` length along ``paths`` is normalised by the total
    length. A cell projected onto an interior path point gets that point's
    normalised position. A cell projected onto the first (last) point gets a
    value below 0 (above 1), offset by its ``norm2`` distance to that end
    point divided by the total length. The number of such end-projected cells
    is printed.

    Args:
        proj_idx: NOTE(review): the passed value is not used; the projection
            is recomputed from ``paths`` and ``tensor_points_set``.
        paths: Sequence of path points.
        tensor_points_set: Cells to project.
        label: Label passed to the histogram.
    """
    total_length = 0
    cumulative = [0]
    with torch.no_grad():
        for seg_start, seg_end in zip(paths[:-1], paths[1:]):
            total_length += norm2(seg_start, seg_end).cpu().item()
            cumulative.append(total_length)
    arc_fraction = [length / total_length for length in cumulative]

    # The incoming proj_idx is replaced by a fresh projection.
    proj_idx = project_cell(paths, tensor_points_set)

    last_idx = len(paths) - 1
    n_edge = 0
    proj_dist = []
    for cell_i, idx in enumerate(proj_idx):
        if idx == 0:
            n_edge += 1
            value = -1 * norm2(paths[0], tensor_points_set[cell_i]).item() / total_length
        elif idx == last_idx:
            n_edge += 1
            value = 1 + norm2(paths[-1], tensor_points_set[cell_i]).item() / total_length
        else:
            value = arc_fraction[idx]
        proj_dist.append(value)

    print(n_edge)

    plt.hist(proj_dist, bins=40, label=label)
    plt.xlabel("Distance")
    plt.ylabel("Number of Cells")
    
def get_cluster_points(ori_data, age_set, lineage):
    """Compute the median embedding of each lineage cluster, one per age.

    Cells are grouped first by ``ori_data.obs['Age']`` and then by
    ``ori_data.obs['seurat_clusters']``; the coordinates come from
    ``ori_data.obsm['X_umap']``.

    Args:
        ori_data: Dataset supporting ``len()`` and exposing ``obs`` and
            ``obsm`` mappings (presumably an AnnData object -- confirm with
            the caller).
        age_set: Sequence of ages; every age found in the data must be in it.
        lineage: Sequence of cluster ids; ``lineage[i]`` is the cluster
            picked at ``age_set[i]``.

    Returns:
        Array stacking, for each ``i``, the per-coordinate median of the
        points of cluster ``lineage[i]`` at age ``age_set[i]``.
    """
    cells_by_age = {age: [] for age in age_set}
    cluster_points = {age: {} for age in age_set}

    for cell_idx in range(len(ori_data)):
        cell_age = ori_data.obs['Age'][cell_idx]
        cells_by_age[cell_age].append((ori_data.obsm['X_umap'][cell_idx], cell_idx))

    for age in age_set:
        per_cluster = cluster_points[age]
        for coords, cell_idx in cells_by_age[age]:
            cluster_id = ori_data.obs['seurat_clusters'][cell_idx]
            per_cluster.setdefault(cluster_id, []).append(coords)

    cluster_centers = [
        np.median(np.stack(cluster_points[age][lineage[step]], axis=0), axis=0)
        for step, age in enumerate(age_set)
    ]
    return np.stack(cluster_centers)

def get_pseudotime(curve_points, tensor_points_set, alpha = 1.0, draw = True):
    """Compute principal-curve pseudotime rescaled by the interior cells.

    The cells are projected onto a ``PrincipalCurve`` built from
    ``curve_points`` and the resulting ``pseudotimes_interp`` values are
    rescaled linearly. The rescaling maps the smallest and largest pseudotime
    among "interior" cells to 0 and 1. A cell is interior when its nearest
    point on a coarse B-spline through ``curve_points`` is neither the first
    nor the last path point. Cells projected onto an end point may therefore
    fall outside ``[0, 1]``.

    Args:
        curve_points: Tensor of curve control points (``.cpu().numpy()`` is
            called on it).
        tensor_points_set: Tensor of cells (``.cpu().numpy()`` is called on
            it, and it is iterated cell by cell).
        alpha: Histogram transparency, used only when ``draw`` is truthy.
        draw: When truthy, add a 20-bin histogram to the current figure.

    Returns:
        List with one rescaled pseudotime per entry of
        ``curve.pseudotimes_interp`` (presumably one per cell -- confirm
        against the pcurve package).
    """
    curve = PrincipalCurve(k=5)
    curve.project_to_curve(tensor_points_set.cpu().numpy(), points=curve_points.cpu().numpy(), stretch=0)
    pseudotimes = curve.pseudotimes_interp

    # A coarse spline is enough here: it is only used to detect which cells
    # project onto one of the two curve ends.
    paths = b_spline(curve_points, 3, 4)
    proj_idx = project_cell(paths, tensor_points_set)
    last_idx = len(paths) - 1

    # Take min and max independently over the interior cells. Tracking both
    # with a single if/elif skips the lower-bound update whenever a value
    # raises the upper bound, which can leave the lower bound too high.
    interior = [
        pseudotimes[i]
        for i, idx in enumerate(proj_idx)
        if idx != 0 and idx != last_idx
    ]
    if interior:
        start_lower = min(interior)
        end_highest = max(interior)
    else:
        # No interior cell: fall back to the full range so the scale is not
        # inverted.
        start_lower = min(pseudotimes)
        end_highest = max(pseudotimes)

    span = end_highest - start_lower
    new_pseudotime = [(value - start_lower) / span for value in pseudotimes]

    if draw:
        plt.hist(new_pseudotime, bins=20, label="Proj dist", alpha = alpha)
        plt.xlabel("Pesudotime")
        plt.ylabel("Number of Cells")
        plt.legend()
    return new_pseudotime

def get_pseudotime2(curve_points, tensor_points_set, alpha = 1.0, draw = True):
    """Compute pseudotime as the normalised position along a B-spline.

    An order-3 B-spline through ``curve_points`` is sampled with
    ``n_path_point=10000`` and every cell is assigned to its nearest sampled
    point. An interior projection gets the cumulative ``norm2`` length up to
    that point divided by the total length. A projection onto the first
    (last) point gets a value below 0 (above 1), offset by the cell's
    ``norm2`` distance to that end point divided by the total length.

    Args:
        curve_points: Control points of the curve.
        tensor_points_set: Cells to project.
        alpha: Histogram transparency, used only when ``draw`` is truthy.
        draw: When truthy, add a 40-bin histogram to the current figure.

    Returns:
        List with one pseudotime value per cell.
    """
    paths = b_spline(curve_points, 3, 10000)
    proj_idx = project_cell(paths, tensor_points_set)

    # Cumulative length of the sampled polyline, then normalised to [0, 1].
    total_length = 0
    cumulative = [0]
    with torch.no_grad():
        for seg_start, seg_end in zip(paths[:-1], paths[1:]):
            total_length += norm2(seg_start, seg_end).cpu().item()
            cumulative.append(total_length)
    arc_fraction = [length / total_length for length in cumulative]

    last_idx = len(paths) - 1
    new_pseudotime = []
    for cell_i, idx in enumerate(proj_idx):
        if idx == 0:
            value = -1 * norm2(paths[0], tensor_points_set[cell_i]).item() / total_length
        elif idx == last_idx:
            value = 1 + norm2(paths[-1], tensor_points_set[cell_i]).item() / total_length
        else:
            value = arc_fraction[idx]
        new_pseudotime.append(value)

    if draw:
        plt.hist(new_pseudotime, bins=40, label="Proj dist", alpha = alpha)
        plt.xlabel("Pesudotime")
        plt.ylabel("Number of Cells")
        plt.legend()
    return new_pseudotime



def pseudotime_eval(pseudotime, age_label):
    """Score how well pseudotime ordering agrees with the age labels.

    For every cell, its pseudotime is compared with the cells whose age is
    more than one step earlier and with the cells whose age is more than one
    step later. An earlier-age cell counts as correct when its pseudotime is
    ``<=`` the current one; a later-age cell counts as correct when its
    pseudotime is ``>`` the current one. A comparison group is only used
    when it holds at least two cells.

    Args:
        pseudotime: Sequence of per-cell pseudotime values.
        age_label: Sequence of per-cell numeric ages, aligned with
            ``pseudotime``; converted with ``np.asarray``.

    Returns:
        List with one ratio ``right / (right + error)`` per cell. The entry
        is ``nan`` when the cell has no usable comparison group.
    """
    np_pseudotime = np.stack(pseudotime)
    # Boolean masks below need an array; a plain list would fail on `<`.
    age_label = np.asarray(age_label)
    right_ratio_list = []
    for i in tqdm(range(len(age_label))):
        cur_age_label = age_label[i]
        cur_pseudotime = np_pseudotime[i]
        right_num = 0
        error_num = 0

        prev_pseudotime = np_pseudotime[age_label < cur_age_label - 1]
        after_pseudotime = np_pseudotime[age_label > cur_age_label + 1]

        if len(prev_pseudotime) > 1:
            right_num += (prev_pseudotime <= cur_pseudotime).sum()
            error_num += (prev_pseudotime > cur_pseudotime).sum()
        if len(after_pseudotime) > 1:
            right_num += (after_pseudotime > cur_pseudotime).sum()
            # Errors for later-age cells come from the later-age group,
            # not from the earlier-age one.
            error_num += (after_pseudotime <= cur_pseudotime).sum()

        total = right_num + error_num
        # Without any comparison the ratio is undefined; avoid 0 / 0.
        right_ratio_list.append(right_num / total if total else float("nan"))

    return right_ratio_list