import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy.stats import gaussian_kde


def compute_epochwise_delta(A_seq):
    """Return the mean L2 norm of the differences between consecutive snapshots.

    Each element of ``A_seq`` is one snapshot; all snapshots are stacked along a
    new leading dimension, so they must share the same shape.  For every pair of
    neighbouring snapshots the difference is flattened and its L2 norm is taken;
    the result is the average of those norms.

    Args:
        A_seq: Sequence (list/tuple) of equally shaped ``torch.Tensor`` objects.

    Returns:
        float: The mean step-to-step change, or ``float('inf')`` when fewer than
        two snapshots are available (no difference can be formed).
    """
    if len(A_seq) < 2:
        return float('inf')

    A_stack = torch.stack(A_seq, dim=0)
    # Integer/bool snapshots (e.g. binary adjacency matrices) cannot be passed
    # to ``norm`` and bool tensors do not even support ``-``; promote them to
    # the default floating dtype.  Floating/complex inputs keep their dtype.
    if not (A_stack.is_floating_point() or A_stack.is_complex()):
        A_stack = A_stack.to(torch.get_default_dtype())

    deltas = A_stack[1:] - A_stack[:-1]
    # One row per consecutive pair; ``reshape`` also handles 0-dim snapshots.
    norms = deltas.reshape(deltas.size(0), -1).norm(dim=1)
    return norms.mean().item()


# def select_density_balanced_subset(score_dict, beta=0.05, alpha=0.5, sample_size=None, bandwidth_adjust=1.0):
#     indices = list(score_dict.keys())
#     scores = np.array([score_dict[i] for i in indices])
#     threshold = np.percentile(scores, 100 * (1 - beta))
#     kept = [(i, s) for i, s in zip(indices, scores) if s <= threshold]

#     if not kept:
#         raise ValueError("No samples left after beta filtering.")

#     kept_indices, kept_scores = zip(*kept)
#     kept_scores = np.array(kept_scores)

#     kde = gaussian_kde(kept_scores)
#     kde.set_bandwidth(bw_method=kde.factor * bandwidth_adjust)
#     densities = kde.evaluate(kept_scores)

#     weights = 1.0 / (densities + 1e-8)
#     weights /= weights.sum()

#     m = sample_size if sample_size is not None else int(len(score_dict) * (1 - alpha))
#     selected_indices = np.random.choice(kept_indices, size=m, replace=False, p=weights)
#     return selected_indices.tolist()


def analyze(log_path="structure_perturbation_log.pt",
            save_score_path="sorted_epochwise_perturbation.csv",
            epoch=100):
    """Score every sample in a perturbation log and optionally save the ranking.

    The log is loaded with ``torch.load`` and iterated with ``.items()``, so it
    is expected to be a mapping from sample id to a sequence of snapshots.
    Each sequence is truncated to its first ``epoch + 1`` entries and scored
    with ``compute_epochwise_delta``.

    Args:
        log_path: Path of the ``.pt`` file holding the perturbation log.
        save_score_path: CSV path for the scores sorted from largest to
            smallest (columns ``"Sample ID"`` and ``"SCLCS Score"``).  A falsy
            value skips saving.
        epoch: Last epoch index to include; snapshots ``0..epoch`` are used.

    Returns:
        dict: Mapping from sample id to its (unsorted) score.
    """
    # NOTE(security): torch.load unpickles the file; only load logs you trust.
    # map_location="cpu" lets logs written on a GPU machine be analysed on a
    # CPU-only host instead of failing while restoring CUDA tensors.
    perturbation_log = torch.load(log_path, map_location="cpu")

    score_dict = {}
    for idx, A_seq in tqdm(perturbation_log.items(), desc="Computing epochwise perturbation"):
        score_dict[idx] = compute_epochwise_delta(A_seq[:epoch + 1])

    if save_score_path:
        # Sort in descending order (largest score first); only needed for the CSV.
        sorted_scores = sorted(score_dict.items(), key=lambda x: x[1], reverse=True)

        # Create the target directory so to_csv does not fail on a fresh checkout.
        out_dir = os.path.dirname(save_score_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        df = pd.DataFrame(sorted_scores, columns=["Sample ID", "SCLCS Score"])
        df.to_csv(save_score_path, index=False)
        print(f"Saved sorted scores to: {save_score_path}")

    return score_dict


# def select(score_dict, epc_end=100, beta=0.05, smpr=0.5):
#     os.makedirs("./sclcs_list", exist_ok=True)
#     save_selected_path = f"sclcs_list/sclcs_end{epc_end}_{beta}_{int(smpr*100)}.csv"
#     smp_size = int(len(score_dict) * smpr)

#     # Run the density-aware selection
#     selected = select_density_balanced_subset(score_dict, beta=beta, sample_size=smp_size)
#     print(f"\nSelected {len(selected)} core-set samples (density-aware).")
#     dens_df = pd.DataFrame({"Sample ID": selected})
#     dens_df.to_csv(save_selected_path, index=False)
#     print(f"Saved selected indices to: {save_selected_path}")

#     return selected


if __name__ == "__main__":
    # Script entry point: for every (out_dim, head_num, task, repeat)
    # configuration, compute the per-sample scores at several end epochs and
    # write one CSV per end epoch, skipping CSVs that already exist.

    # Earlier driver, kept disabled for reference:
    # for epoch_end in epochs_opts:
    #     score2save_path = f"SCLCS/sorted_epochwise_perturbation_end{epoch_end}.csv"
    #     if not os.path.exists(score2save_path):
    #         print(f"Computing epochwise perturbation for epoch {epoch_end}...")
    #         # Compute the SCLCS score of every sample
    #         scores = analyze(log_path=log_path, save_score_path=score2save_path, epoch=epoch_end)
    #     scores_df = pd.read_csv(score2save_path)
    #     sorted_scores = {row["Sample ID"]: row["SCLCS Score"] for _, row in scores_df.iterrows()}
    #     for smpr in [0.1, 0.3, 0.5]:
    #         select(score_dict=sorted_scores, epc_end=epoch_end, beta=0.05, smpr=smpr)

    task_names = ['sub']
    repeat_ids = [1, 2, 3]
    end_epochs = [50, 100, 150, 200]
    out_dims = [64, 128, 256, 'data_dim']
    head_nums = [2, 4, 8, 16, 32]

    # Enumerate configurations up front, outermost to innermost:
    # out_dim -> head_num -> task -> repeat.
    sweep = [
        (out_dim, head_num, task, repeat)
        for out_dim in out_dims
        for head_num in head_nums
        for task in task_names
        for repeat in repeat_ids
    ]

    for out_dim, head_num, task, repeat in sweep:
        log_file = f"AH_Coresets_ours_list/log_{task}C_r{repeat}_{head_num}, {out_dim}.pt"
        print(f"Processing task: {task}, repeat: {repeat}")
        print("-" * 60)

        for end_epc in end_epochs:
            score2save_path = f"AH_Coresets_ours_list/score_{task}_epc{end_epc}_{head_num}_{out_dim}_r{repeat}.csv"
            if os.path.exists(score2save_path):
                # Scores for this configuration were already written.
                continue

            print(f"Computing epochwise perturbation for epoch {end_epc}...")
            # Compute each sample's SCLCS score and save them sorted from
            # largest to smallest.
            analyze(log_path=log_file, save_score_path=score2save_path, epoch=end_epc)
