import scanpy as sc
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import torch
import torch.nn.functional as F
from tqdm import tqdm
import pickle

import sys, os
sys.path.append(os.path.abspath(os.curdir))

from pesudotime_utils import *
from eval_util import *
from utils import *
# from slingshot.slingshot import Slingshot

class OptBasedCellPesudotime(object):
    def __init__(self, h5_data, age_set: list[str], result_dir, data_type = 'umap', use_gpu = True, debug = False, draw_img = True,
                 minimum_gene_number = 1, add_feat_path = None) -> None:
        self.h5_data = h5_data
        self.age_set = age_set
        self.cluster_points = {}
        self.use_gpu = use_gpu
        self.debug = debug
        self.data_type = data_type
        self.minimum_gene_number = minimum_gene_number
        self.add_feat_path = add_feat_path
        self.root = pathlib.Path(result_dir)
        self.root.mkdir(exist_ok=True)
        self.data_preprocess()
        self.draw_img = draw_img
        # self.cluster_ids = {}
        
    def data_preprocess(self):
        """Fill ``self.cluster_points`` with one feature matrix per (age, cluster).

        The per-cell feature vector comes from the source selected by
        ``self.data_type``:

        * ``'umap'``      -> ``h5_data.obsm['X_umap']``
        * ``'pca55'``     -> ``h5_data.obsm['X_pca']``
        * ``'raw_data'``  -> ``h5_data.X``
        * ``'best_epoch'``, ``'best_epoch_pca55'``, ``'best_epoch_umap'``
          -> the object unpickled from ``self.add_feat_path``

        Cells are grouped by ``obs['Age']`` and ``obs['seurat_clusters']``.
        A cluster is stored (as a stacked array) only when it holds strictly
        more than ``self.minimum_gene_number`` cells.

        Raises:
            ValueError: If ``data_type`` is not one of the values above, or a
                ``best_epoch*`` type is requested without ``add_feat_path``.
            KeyError: If a cell's age is not listed in ``self.age_set``.
        """
        obsm_keys = {'umap': 'X_umap', 'pca55': 'X_pca'}
        pickled_types = ('best_epoch', 'best_epoch_pca55', 'best_epoch_umap')

        if self.data_type in obsm_keys:
            features = self.h5_data.obsm[obsm_keys[self.data_type]]
        elif self.data_type == 'raw_data':
            features = self.h5_data.X
        elif self.data_type in pickled_types:
            if self.add_feat_path is None:
                raise ValueError(
                    f"data_type {self.data_type!r} requires add_feat_path to be set")
            # SECURITY: pickle.load can execute arbitrary code; only point
            # add_feat_path at feature files you produced yourself.
            with open(self.add_feat_path, 'rb') as fin:
                features = pickle.load(fin)
        else:
            raise ValueError(f"unsupported data_type: {self.data_type!r}")

        # Pull the label columns out once as plain arrays. Indexing the obs
        # columns with ``[i]`` per cell is a label lookup on a Series, which is
        # not positional when the index holds cell names.
        ages = np.asarray(self.h5_data.obs['Age'])
        cluster_ids = np.asarray(self.h5_data.obs['seurat_clusters'])

        grouped = {age: {} for age in self.age_set}
        for i in range(len(self.h5_data)):
            grouped[ages[i]].setdefault(cluster_ids[i], []).append(features[i])

        for age in self.age_set:
            self.cluster_points[age] = {
                cluster_id: np.stack(rows, axis=0)
                for cluster_id, rows in grouped[age].items()
                if len(rows) > self.minimum_gene_number
            }
                
    def set_lineage(self, lineage: list[int]):
        """Collect the cells and one control point for each lineage segment.

        Consecutive ages that share the same cluster id in ``lineage`` are
        merged into one segment. Ages whose cluster is absent from
        ``self.cluster_points`` are skipped; a segment with no data at all
        contributes nothing.

        Args:
            lineage: Cluster id per age, aligned with ``self.age_set``.

        Returns:
            A tuple ``(points_set, cluster_set, age_label)``:
            ``points_set`` holds one concatenated point array per segment,
            ``cluster_set`` holds ``get_center_point`` of the last age in the
            segment that has data, and ``age_label`` is an array with the
            age index of every collected point.
        """
        self.lineage = lineage

        # Group consecutive age indices that keep the same cluster id.
        age_groups = [[0]]
        for i in range(1, len(self.age_set)):
            if lineage[i - 1] == lineage[i]:
                age_groups[-1].append(i)
            else:
                age_groups.append([i])

        points_set = []
        cluster_set = []
        age_label = []
        for group in age_groups:
            group_points = []
            for i in group:
                pts = self.cluster_points[self.age_set[i]].get(lineage[i])
                if pts is None:
                    continue
                group_points.append(pts)
                age_label.extend([i] * len(pts))
            if not group_points:
                continue
            points_set.append(np.concatenate(group_points, axis=0))
            # Use the last age that actually has data. Relying on the loop
            # variable after the loop dropped the control point whenever the
            # final age of the segment was missing, leaving points_set and
            # cluster_set with different lengths.
            cluster_set.append(get_center_point(group_points[-1]))

        return points_set, cluster_set, np.array(age_label)
    
    
    def optim_prepare(self, points_set, cluster_set, age_label, insert_num = 1):
        """Turn the lineage data into tensors and insert intermediate control points.

        Args:
            points_set: Sequence of point arrays; they are concatenated.
            cluster_set: Sequence of anchor points, one per lineage segment.
            age_label: Unused; kept for call compatibility.
            insert_num: Number of points inserted between each pair of
                neighbouring anchors, evenly spaced on the connecting line.

        Returns:
            A tuple ``(new_cluster_set, learnable_idx, tensor_points_set)``:
            the control-point list (anchors are ``torch.nn.Parameter``,
            inserted points are plain tensors), the positions of the inserted
            points in that list, and all points as one tensor. Tensors are
            moved to CUDA when ``self.use_gpu`` is set.
        """
        tensor_points_set = torch.from_numpy(np.concatenate(points_set))
        if self.use_gpu:
            tensor_points_set = tensor_points_set.cuda()

        anchors = []
        for center in cluster_set:
            anchor = torch.tensor(center)
            if self.use_gpu:
                anchor = anchor.cuda()
            anchors.append(torch.nn.Parameter(anchor))

        new_cluster_set = []
        learnable_idx = []
        with torch.no_grad():
            for start, end in zip(anchors[:-1], anchors[1:]):
                new_cluster_set.append(start)
                for j in range(1, insert_num + 1):
                    # The fraction depends on j so that several inserted
                    # points are spread out instead of stacked on one spot.
                    frac = j / (insert_num + 1)
                    learnable_idx.append(len(new_cluster_set))
                    new_cluster_set.append(start * (1 - frac) + end * frac)
        new_cluster_set.append(anchors[-1])
        return new_cluster_set, learnable_idx, tensor_points_set

    def get_optim_poly(self, new_cluster_set, learnable_idx, lr=0.01):
        """Build an SGD optimizer over every control point.

        Each entry of ``new_cluster_set`` is replaced in place by a
        ``torch.nn.Parameter`` wrapping it, and all of them are optimised.

        Args:
            new_cluster_set: Mutable list of control-point tensors.
            learnable_idx: Not used by this variant.
            lr: Learning rate handed to ``torch.optim.SGD``.

        Returns:
            The ``torch.optim.SGD`` instance.
        """
        for pos, point in enumerate(new_cluster_set):
            new_cluster_set[pos] = torch.nn.Parameter(point)

        return torch.optim.SGD(list(new_cluster_set), lr = lr)
    
    def get_optim(self, new_cluster_set, learnable_idx, lr=0.01):
        """Build an SGD optimizer over the control points listed in ``learnable_idx``.

        The selected entries of ``new_cluster_set`` are replaced in place by
        ``torch.nn.Parameter`` objects; all other entries are left untouched.

        Args:
            new_cluster_set: Mutable list of control-point tensors.
            learnable_idx: Positions in ``new_cluster_set`` to optimise.
            lr: Learning rate handed to ``torch.optim.SGD``.

        Returns:
            The ``torch.optim.SGD`` instance.
        """
        trainable = []
        for pos in learnable_idx:
            param = torch.nn.Parameter(new_cluster_set[pos])
            new_cluster_set[pos] = param
            trainable.append(param)

        return torch.optim.SGD(trainable, lr = lr)
    
    def gen_point_weight(self, tensor_points_set, topk):
        """Return a 0/1 weight per point that masks out isolated points.

        For every point the mean distance to its ``topk`` nearest neighbours
        (the point itself excluded) is computed. A point gets weight 1 when
        that value is at most one standard deviation above the mean over all
        points, otherwise 0.

        Args:
            tensor_points_set: Tensor with one point per row.
            topk: Number of neighbours to average over; clamped to the number
                of other points available.

        Returns:
            A 1-D tensor of ones and zeros, one entry per point. When there
            are fewer than two points or ``topk < 1`` there is no neighbour
            information, so every point gets weight 1.
        """
        n_points = len(tensor_points_set)
        k = min(topk, n_points - 1)
        if k < 1:
            return torch.ones((n_points,), dtype=tensor_points_set.dtype,
                              device=tensor_points_set.device)

        with torch.no_grad():
            mean_dists = []
            for point in tensor_points_set:
                dists = (tensor_points_set - point).norm(dim=-1)
                # Ascending order: entry 0 is the zero distance to the point
                # itself, so it is dropped before averaging.
                nearest = dists.topk(k + 1, largest = False).values
                mean_dists.append(nearest[1:].mean())
            point_weight = torch.stack(mean_dists)

            keep = point_weight <= point_weight.mean() + point_weight.std()
            return keep.to(point_weight.dtype)
        
    def centerbaser_gen_point_weight(self, tensor_points_set, topk, lineage):
        """Return a 0/1 weight per point based on the distance to its age centre.

        For every age the centre is the per-coordinate median of the cells of
        the lineage cluster. A point gets weight 1 when its distance to that
        centre is at most the mean distance of its age group, otherwise 0.

        ``tensor_points_set`` is expected to list the cells age by age in
        ``self.age_set`` order. Ages whose lineage cluster is missing from
        ``self.cluster_points`` are skipped, the same way ``set_lineage``
        skips them, so the row offsets stay aligned.

        Args:
            tensor_points_set: Tensor with one point per row.
            topk: Unused; kept for call compatibility.
            lineage: Cluster id per age, aligned with ``self.age_set``.

        Returns:
            A 1-D float tensor on the CPU with one weight per point.
        """
        point_weight = torch.zeros((len(tensor_points_set), ), dtype=torch.float)

        offset = 0
        with torch.no_grad():
            for i, age in enumerate(self.age_set):
                cells = self.cluster_points[age].get(lineage[i])
                if cells is None:
                    continue
                rows = slice(offset, offset + len(cells))
                offset += len(cells)

                center = torch.from_numpy(np.median(cells, axis=0)).to(tensor_points_set.device)
                dists = (tensor_points_set[rows] - center).norm(dim=-1)
                inside = dists <= dists.mean()
                point_weight[rows] = inside.to(point_weight.dtype).cpu()

        return point_weight
    
    def optim_curve_poly(self, optmer: torch.optim.Optimizer, 
                    max_optm_step, 
                    stop_threshold, 
                    tolerance_num,
                    new_cluster_set,
                    tensor_points_set,
                    topk = 10):
        """Fit the control points with a curve built by ``polynomial``.

        Every step rebuilds the curve via ``polynomial(new_cluster_set, 3, 300)``,
        looks up a curve sample for each point with ``project_cell`` and
        minimises the square root of the summed squared differences between
        points and their samples.

        Args:
            optmer: Optimizer that owns the trainable control points.
            max_optm_step: Upper bound on the number of optimisation steps.
            stop_threshold: Relative loss decrease below which a step counts
                as stalled.
            tolerance_num: The loop stops once more than this many
                consecutive steps stalled.
            new_cluster_set: Control points of the curve.
            tensor_points_set: Tensor with one point per row.
            topk: Not used by this method.

        Returns:
            ``(new_cluster_set, losses, overall_cost)`` where ``losses`` is
            the per-step loss history and ``overall_cost`` is whatever the
            final ``draw_dist_bin`` call returns.
        """
        progress = tqdm(range(max_optm_step))
        stalled_steps = 0
        previous_loss = None
        losses = []

        # Distance plot for the starting control points.
        if self.draw_img:
            draw_dist_bin(new_cluster_set, tensor_points_set, label = "Optim Before")

        # Constructed but never stepped (the step call below stays disabled).
        scheduler = torch.optim.lr_scheduler.StepLR(optmer, 10, 0.8)

        # Every point currently contributes with weight 1.
        target_device = new_cluster_set[0].device
        point_weight = torch.ones((len(tensor_points_set), ), dtype=torch.float).to(target_device)
        weight_column = point_weight.unsqueeze(1)

        for _ in progress:
            # Rebuild the curve and find the curve sample of every point.
            paths = polynomial(new_cluster_set, 3, 300)
            proj_idx = project_cell(paths, tensor_points_set)

            n_projected = len(proj_idx)
            targets = [tensor_points_set[j] for j in range(n_projected)]
            projected = [paths[proj_idx[j]] for j in range(n_projected)]

            squared_error = F.mse_loss(torch.stack(projected) * weight_column,
                                       torch.stack(targets) * weight_column,
                                       reduction = 'sum')
            loss = squared_error.sqrt()
            loss.backward()

            optmer.step()
            optmer.zero_grad()

            loss_value = loss.cpu().item()
            progress.set_description(f"{loss_value}")
            losses.append(loss_value)
            # scheduler.step()

            # Early stopping on the relative improvement of the loss.
            if previous_loss is not None:
                if (previous_loss - loss) / loss < stop_threshold:
                    stalled_steps += 1
                else:
                    stalled_steps = 0
                if stalled_steps > tolerance_num:
                    break
            previous_loss = loss_value

        overall_cost = draw_dist_bin(new_cluster_set, tensor_points_set, "Optim After training", draw = self.draw_img)

        return new_cluster_set, losses, overall_cost
    
    def optim_curve(self, optmer: torch.optim.Optimizer, 
                    max_optm_step, 
                    stop_threshold, 
                    tolerance_num,
                    new_cluster_set,
                    tensor_points_set,
                    topk = 10):
        """Fit the control points with a curve built by ``b_spline``.

        Every step rebuilds the curve via ``b_spline(new_cluster_set, 3, 300)``,
        looks up a curve sample for each point with ``project_cell`` and
        minimises the square root of the summed squared differences between
        points and their samples.

        Args:
            optmer: Optimizer that owns the trainable control points.
            max_optm_step: Upper bound on the number of optimisation steps.
            stop_threshold: Relative loss decrease below which a step counts
                as stalled.
            tolerance_num: The loop stops once more than this many
                consecutive steps stalled.
            new_cluster_set: Control points of the curve.
            tensor_points_set: Tensor with one point per row.
            topk: Not used by this method.

        Returns:
            ``(new_cluster_set, losses, overall_cost)`` where ``losses`` is
            the per-step loss history and ``overall_cost`` is whatever the
            final ``draw_dist_bin`` call returns.
        """
        progress = tqdm(range(max_optm_step))
        stalled_steps = 0
        previous_loss = None
        losses = []

        # Distance plot for the starting control points.
        if self.draw_img:
            draw_dist_bin(new_cluster_set, tensor_points_set, label = "Optim Before")

        # Constructed but never stepped (the step call below stays disabled).
        scheduler = torch.optim.lr_scheduler.StepLR(optmer, 10, 0.8)

        # Every point currently contributes with weight 1.
        target_device = new_cluster_set[0].device
        point_weight = torch.ones((len(tensor_points_set), ), dtype=torch.float).to(target_device)
        weight_column = point_weight.unsqueeze(1)

        for _ in progress:
            # Rebuild the curve and find the curve sample of every point.
            paths = b_spline(new_cluster_set, 3, 300)
            proj_idx = project_cell(paths, tensor_points_set)

            n_projected = len(proj_idx)
            targets = [tensor_points_set[j] for j in range(n_projected)]
            projected = [paths[proj_idx[j]] for j in range(n_projected)]

            squared_error = F.mse_loss(torch.stack(projected) * weight_column,
                                       torch.stack(targets) * weight_column,
                                       reduction = 'sum')
            loss = squared_error.sqrt()
            loss.backward()

            optmer.step()
            optmer.zero_grad()

            loss_value = loss.cpu().item()
            progress.set_description(f"{loss_value}")
            losses.append(loss_value)
            # scheduler.step()

            # Early stopping on the relative improvement of the loss.
            if previous_loss is not None:
                if (previous_loss - loss) / loss < stop_threshold:
                    stalled_steps += 1
                else:
                    stalled_steps = 0
                if stalled_steps > tolerance_num:
                    break
            previous_loss = loss_value

        overall_cost = draw_dist_bin(new_cluster_set, tensor_points_set, "Optim After training", draw = self.draw_img)

        return new_cluster_set, losses, overall_cost
    
    def get_path(self, new_cluster_set):
        """Sample the ``b_spline`` curve of the control points as a NumPy array.

        The curve is requested with ``len(new_cluster_set) * 10`` as the last
        argument of ``b_spline``. In debug mode the first two coordinates of
        the control points are additionally scattered on the current
        matplotlib figure as red triangles.

        Args:
            new_cluster_set: Control-point tensors.

        Returns:
            The curve samples stacked into one array.
        """
        curve = b_spline(new_cluster_set, 3, len(new_cluster_set) * 10)
        np_path = np.stack([sample.detach().cpu().numpy() for sample in curve])

        control = np.stack([ctrl.detach().cpu().numpy() for ctrl in new_cluster_set])
        if self.debug:
            plt.scatter(control[:, 0], control[:, 1], marker='^', s = 300, color='red')

        return np_path
    
    def get_curve_point(self, new_cluster_set):
        """Stack the ``.data`` tensors of all control points into one tensor.

        Args:
            new_cluster_set: Control-point tensors or parameters.

        Returns:
            A tensor with one control point per row, detached from autograd.
        """
        return torch.stack([ctrl.data for ctrl in new_cluster_set])
    
    def dump_inter_img(self, file_name):
        """Save the current matplotlib figure under the result directory and close it.

        Args:
            file_name: File name relative to ``self.root``.
        """
        target = self.root / file_name
        plt.savefig(target)
        plt.close()
        
    def dump2pkl(self, data, fn):
        """Pickle ``data`` into the file ``fn`` inside the result directory.

        Args:
            data: Any picklable object.
            fn: File name relative to ``self.root``.
        """
        target = self.root / fn
        with target.open('wb') as handle:
            pickle.dump(data, handle)
        

    
def opt_test(h5_data, age_set, result_dir, t_lineage):
    """Run one spline optimisation for ``t_lineage`` and dump the resulting figures.

    Only the control points inserted by ``optim_prepare`` are optimised
    (``get_optim``). Two figures are written to ``result_dir``:
    ``dist2curve.pdf`` and the loss history ``optm_loss.pdf``.

    Args:
        h5_data: Annotated data object handed to ``OptBasedCellPesudotime``.
        age_set: Age labels in lineage order.
        result_dir: Output directory.
        t_lineage: Cluster id per age.
    """
    opt_p_time = OptBasedCellPesudotime(h5_data, age_set, result_dir)

    points_set, cluster_set, age_label = opt_p_time.set_lineage(t_lineage)

    new_cluster_set, learnable_idx, tensor_points_set = opt_p_time.optim_prepare(points_set, cluster_set, age_label)
    optimizer = opt_p_time.get_optim(new_cluster_set, learnable_idx)

    # optim_curve returns (control points, losses, overall cost); the cost is
    # not needed here but must still be unpacked.
    new_cluster_set, losses, _ = opt_p_time.optim_curve(
        optimizer,
        int(1e5),
        1e-4,
        10,
        new_cluster_set,
        tensor_points_set
    )

    opt_p_time.dump_inter_img("dist2curve.pdf")

    plt.plot(list(range(len(losses))), losses)
    opt_p_time.dump_inter_img("optm_loss.pdf")

def draw_datapoint(h5_data, age_set, lineage, *curves):
    """Scatter the UMAP points of a lineage together with its centres and curves.

    A cell is drawn when its cluster equals the lineage cluster of its age.
    Points are coloured by their cluster label, the centres returned by
    ``get_cluster_points`` are drawn as red stars, and every curve is plotted
    with its label.

    Args:
        h5_data: Annotated data object; ``obs['Age']``,
            ``obs['seurat_clusters']`` and ``obsm['X_umap']`` are read.
        age_set: Age labels in lineage order.
        lineage: Cluster id per age, aligned with ``age_set``.
        *curves: ``(points, label)`` pairs; the first two columns of
            ``points`` are plotted.

    Returns:
        The cluster centres returned by ``get_cluster_points``.

    Raises:
        ValueError: If no cell belongs to the lineage, or a cell's age is not
            in ``age_set``.
    """
    # Read the label columns once as plain arrays. Per-cell ``[i]`` on the obs
    # columns is a label lookup on a Series, which is not positional when the
    # index holds cell names.
    ages = np.asarray(h5_data.obs['Age'])
    cluster_ids = np.asarray(h5_data.obs['seurat_clusters'])
    umap = h5_data.obsm['X_umap']

    points_list = []
    points_label = []
    for i in range(len(h5_data)):
        if lineage[age_set.index(ages[i])] == cluster_ids[i]:
            points_list.append(umap[i])
            points_label.append(cluster_ids[i])
    if not points_list:
        raise ValueError("no cell matches the given lineage")
    np_points = np.stack(points_list)
    np_points_label = np.asarray(points_label)

    plt.scatter(np_points[:, 0], np_points[:, 1], c = np_points_label)

    # draw cluster centers
    cluster_centers = get_cluster_points(h5_data, age_set, lineage)
    plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], color='red', s = 200, marker='*')

    # draw curves
    for cc in curves:
        ppos, label = cc[0], cc[1]
        plt.plot(ppos[:, 0], ppos[:, 1], label = label, linewidth = 2)

    plt.legend()
    return cluster_centers


    
def only_opt(opt_p_time: OptBasedCellPesudotime, age_set, t_lineage, lineage_idx):
    """Optimise the curve of one lineage, dump its artefacts and evaluate it.

    All control points are made trainable (``get_optim_poly``, learning rate
    1) and fitted with ``optim_curve``. The following files are written to the
    result directory of ``opt_p_time``: ``dist2curve_{idx}th.pdf``,
    ``optm_loss_{idx}.pdf``, ``sctime_path_{idx}.pkl``, ``curves_{idx}.jpg``
    and ``optm_pesudotime_{idx}.pdf``.

    Args:
        opt_p_time: Prepared ``OptBasedCellPesudotime`` instance.
        age_set: Age labels in lineage order.
        t_lineage: Cluster id per age, aligned with ``age_set``.
        lineage_idx: Index used in the output file names.

    Returns:
        A dict with the keys ``'Optimal'`` (cost returned by ``optim_curve``)
        and ``'Optimal Lineage Eval'`` (result of ``pesudotime_eval``).

    Raises:
        KeyError: If a lineage cluster is missing from
            ``opt_p_time.cluster_points`` for its age.
    """
    opt_points_set, opt_cluster_set, opt_age_label = opt_p_time.set_lineage(t_lineage)
    for pp in opt_points_set:
        print(len(pp))
    opt_new_cluster_set, opt_learnable_idx, opt_tensor_points_set = opt_p_time.optim_prepare(
        opt_points_set, opt_cluster_set, opt_age_label, insert_num = 1)
    optimizer = opt_p_time.get_optim_poly(opt_new_cluster_set,
                                          opt_learnable_idx,
                                          1)
    print("[Optimal]: optimize curve")
    opt_new_cluster_set, losses, opt_overall_cost = opt_p_time.optim_curve(
        optimizer,
        int(1e3),
        1e-4,
        10,
        opt_new_cluster_set,
        opt_tensor_points_set,
        10
    )

    opt_p_time.dump_inter_img(f"dist2curve_{lineage_idx}th.pdf")

    plt.plot(list(range(len(losses))), losses)
    opt_p_time.dump_inter_img(f"optm_loss_{lineage_idx}.pdf")

    # Fitted path as a NumPy array, stored next to the figures.
    opt_path = opt_p_time.get_path(opt_new_cluster_set)
    opt_p_time.dump2pkl(opt_path, f'sctime_path_{lineage_idx}.pkl')
    opt_p_time.dump_inter_img(f"curves_{lineage_idx}.jpg")

    # Pseudotime: the first call's return value is unused; the current figure
    # is saved right after it.
    get_pseudotime(opt_p_time.get_curve_point(opt_new_cluster_set), opt_tensor_points_set)
    opt_p_time.dump_inter_img(f"optm_pesudotime_{lineage_idx}.pdf")

    optm_pesudotime = get_pseudotime(opt_p_time.get_curve_point(opt_new_cluster_set), opt_tensor_points_set, alpha=0.4)

    # Evaluate the pseudotime. Use the t_lineage argument here: relying on a
    # module-level ``lineage`` variable only works when the caller happens to
    # define one.
    lineage_cells = {
        age: opt_p_time.cluster_points[age][age_cls_id]
        for age, age_cls_id in zip(age_set, t_lineage)
    }

    pseudotime_age = age_level_pseudotime(optm_pesudotime, lineage_cells, age_set)
    merge_pseudotime, merged_age_list = pseudotime_merge(pseudotime_age, t_lineage, age_set)
    optm_rank = pesudotime_eval(merge_pseudotime, merged_age_list)

    dist_cost = {
        'Optimal': opt_overall_cost,
        "Optimal Lineage Eval": optm_rank,
    }
    print(dist_cost)
    return dist_cost
    
if __name__ == '__main__':
    # Script entry point: optimise every lineage listed in --lineage-file and
    # write one CSV row per successfully processed lineage.
    import argparse
    import csv

    parser = argparse.ArgumentParser()
    parser.add_argument('--result-dir', type=str, default='iclr_exp/pseudo_time_raw')
    parser.add_argument('--lineage-file', type=str, default='iclr_exp/lineage_result/raw_data-pearson-norm_later-together-only_award-mean/Filtered_path.json')
    parser.add_argument('--rna-h5ad', type=str, default='rna_data/SCT_all_P5_rm_sampled.h5ad')
    parser.add_argument('--feat-file', type=str, default='None')
    args = parser.parse_args()

    args.feat_file = None if args.feat_file == 'None' else args.feat_file
    if args.feat_file is None:
        # The data type used below reads its features from this pickle file;
        # without it the run would fail later with an opaque TypeError.
        parser.error("--feat-file is required for data_type 'best_epoch_pca55'")

    rna_data = sc.read_h5ad(args.rna_h5ad)

    # Fall back to alternative column names when the expected ones are missing.
    if 'Age' not in rna_data.obs:
        if 'orig.ident' in rna_data.obs:
            rna_data.obs['Age'] = rna_data.obs['orig.ident']
        elif 'time' in rna_data.obs:
            rna_data.obs['Age'] = rna_data.obs['time']

    if 'features' not in rna_data.var:
        rna_data.var['features'] = rna_data.var['gene_name']

    if 'seurat_clusters' not in rna_data.obs:
        rna_data.obs['seurat_clusters'] = rna_data.obs['cell_type']

    # age_set = [0,1,2,3,4,5,6,7,8,9,10]
    age_set = [0,1,2,3]
    result_dir = args.result_dir
    opt_p_time = OptBasedCellPesudotime(rna_data, age_set, result_dir, debug=False, data_type='best_epoch_pca55', add_feat_path = args.feat_file)

    plt.close()
    results_list = []
    lineage_idx = 0
    with open(args.lineage_file, 'r') as fin:
        for raw_line in fin:
            if not raw_line.strip():
                continue
            # SECURITY: eval executes arbitrary code; only use lineage files
            # from a trusted source.
            # NOTE: `lineage` is deliberately bound at module scope; keep the
            # name stable in case helpers in this file read it as a global.
            lineage = eval(raw_line)
            try:
                cur_lineage_result = only_opt(
                    opt_p_time,
                    age_set,
                    lineage,
                    lineage_idx
                )
            except Exception as e:
                # Best effort: report the failure and go on with the next lineage.
                print("Failed", e)
            else:
                # Only record lineages that produced a result; appending after
                # a failure would reuse a stale (or undefined) value.
                results_list.append(cur_lineage_result)
            lineage_idx += 1

    if not results_list:
        print("No lineage was processed successfully; nothing to write.")
    else:
        sort_key = sorted(results_list[0])
        # newline='' is what the csv module expects for files it writes to.
        with open(opt_p_time.root.joinpath("Finall_result.csv"), 'w', newline='') as fout:
            writer = csv.writer(fout)
            writer.writerow(sort_key)
            for item in results_list:
                writer.writerow([item[kk] for kk in sort_key])

        opt_pesudo_order_rank = [item['Optimal Lineage Eval'] for item in results_list]
        opt_cost_per_lineage = [item['Optimal'] for item in results_list]
        print(f"Optimal: {np.mean(opt_pesudo_order_rank)}, {np.mean(opt_cost_per_lineage)}")