from scipy.spatial.distance import pdist, cdist
import pickle
import numpy as np
import os
import json
import argparse
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import umap
from sklearn.decomposition import PCA
from scipy.spatial.distance import cosine

def compute_adjacent_layer_similarity(all_samples, layer_idx, key="hidden_states_last_token") -> float:
    """Return the mean cosine similarity between two adjacent layers.

    For every sample, the vectors stored at ``sample[key][layer_idx]`` and
    ``sample[key][layer_idx + 1]`` are compared, and the per-sample
    similarities are averaged into a single score for that layer pair.

    Args:
        all_samples: Iterable of mappings; ``sample[key]`` must be indexable
            by layer position.
        layer_idx: Index of the lower layer of the pair.
        key: Entry of each sample that holds the per-layer vectors.

    Returns:
        The average of ``1 - cosine_distance`` over all samples.
    """
    # scipy's ``cosine`` yields a distance, so 1 - distance is the similarity.
    per_sample = [
        1.0 - cosine(
            np.array(entry[key][layer_idx]),
            np.array(entry[key][layer_idx + 1]),
        )
        for entry in all_samples
    ]
    return float(np.mean(per_sample))

def load_dataset(file_path, label, max_samples=1000):
    """Load samples from a pickle file and pair them with a class label.

    The pickled object must provide ``.get``; its ``"last_results"`` entry is
    read as a mapping from sample id to sample record. When that entry is
    missing, no samples are returned.

    Args:
        file_path: Path of the pickle file to read.
        label: Value returned unchanged alongside the samples.
        max_samples: Maximum number of samples to keep, taken in the
            mapping's iteration order. ``None`` disables the cap.

    Returns:
        A ``(samples, label)`` tuple, where ``samples`` is a list of the
        selected sample records.
    """
    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising. Only point this at files produced by a trusted source.
    with open(file_path, 'rb') as f:
        data = pickle.load(f)
    dataset = data.get("last_results", {})
    samples = []
    for sample_id in dataset:
        # Stop as soon as the cap is reached instead of walking the rest of
        # the mapping for nothing.
        if max_samples is not None and len(samples) >= max_samples:
            break
        samples.append(dataset[sample_id])
    return samples, label

def load_all_data(input_files):
    """Load every input file and label its samples by file position.

    Each file is treated as one class: the samples loaded from the file at
    position ``i`` of ``input_files`` all receive the label that
    ``load_dataset`` returns for ``i``.

    Args:
        input_files: Sequence of pickle file paths, one per class.

    Returns:
        A ``(samples, labels)`` tuple: the concatenated list of samples and a
        NumPy array holding one label per sample, in the same order.
    """
    collected = []
    label_ids = []
    for class_id, path in enumerate(input_files):
        file_samples, file_label = load_dataset(path, class_id)
        collected += file_samples
        # One label entry per sample keeps both sequences aligned.
        label_ids += [file_label] * len(file_samples)
    return collected, np.array(label_ids)

def extract_layer_features(samples, layer_idx, key="hidden_states_last_token"):
    """Collect the vector of one layer from every sample.

    Args:
        samples: Iterable of mappings; ``sample[key]`` must be indexable by
            layer position.
        layer_idx: Position of the layer whose vector is extracted.
        key: Entry of each sample that holds the per-layer vectors.

    Returns:
        A NumPy array built from the per-sample vectors, one row per sample.
    """
    return np.array([np.array(entry[key][layer_idx]) for entry in samples])

def compute_intra_class_distance(features: np.ndarray, labels: np.ndarray, metric='cosine') -> float:
    """Return the average within-class pairwise distance.

    For each class with at least two members, the mean of all pairwise
    distances between its members is computed; the result is the mean of
    those per-class means. Classes with a single member are skipped.

    Args:
        features: Array-like with one row per sample.
        labels: Array-like with one label per row of ``features``.
        metric: Distance metric name passed to ``scipy``'s ``pdist``.

    Returns:
        The mean of the per-class mean distances, or ``nan`` when no class
        has two or more members.
    """
    # Coerce to arrays so boolean-mask indexing also works for plain lists;
    # with a list, ``labels == cls`` would collapse to a single bool.
    features = np.asarray(features)
    labels = np.asarray(labels)

    per_class_means = []
    for cls in np.unique(labels):
        members = features[labels == cls]
        if members.shape[0] > 1:
            # pdist returns the distance for every unordered pair of rows.
            per_class_means.append(pdist(members, metric=metric).mean())

    if not per_class_means:
        # No class has a pair to measure: return NaN explicitly rather than
        # via np.mean([]) and its RuntimeWarning.
        return float("nan")
    return float(np.mean(per_class_means))


def compute_inter_class_distance(features: np.ndarray, labels: np.ndarray, metric='cosine') -> float:
    """Return the average pairwise distance between class centroids.

    Each class centroid is the mean of the rows belonging to that class; the
    result is the mean of the distances over all unordered centroid pairs.

    Args:
        features: Array-like with one row per sample.
        labels: Array-like with one label per row of ``features``.
        metric: Distance metric name passed to ``scipy``'s ``pdist``.

    Returns:
        The mean centroid-to-centroid distance, or ``nan`` when there is
        only one class (no pair to measure).

    Raises:
        ValueError: If there are no classes at all (``np.stack`` receives an
            empty list).
    """
    # Coerce to arrays so boolean-mask indexing also works for plain lists.
    features = np.asarray(features)
    labels = np.asarray(labels)

    centroids = np.stack(
        [features[labels == cls].mean(axis=0) for cls in np.unique(labels)],
        axis=0,
    )
    if centroids.shape[0] < 2:
        # A single centroid has no pairs; return NaN explicitly instead of
        # taking the mean of an empty pdist result (RuntimeWarning).
        return float("nan")
    return float(pdist(centroids, metric=metric).mean())

def visualize_embeddings(embeddings, labels, method='tsne', output_file="embedding_plot.png"):
    """Project embeddings to 2-D and save a per-class scatter plot.

    Args:
        embeddings: Data passed to the reducer's ``fit_transform``.
        labels: Array-like with one class label per embedding row.
        method: ``'tsne'``, ``'umap'`` or ``'pca'`` (case-insensitive).
        output_file: Path the figure is written to.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    method_key = method.lower()
    if method_key == 'tsne':
        # t-SNE settings:
        #   n_components: target dimensionality.
        #   perplexity: roughly how many neighbours each point considers.
        #   learning_rate: optimisation step size.
        #   iteration count: number of optimisation iterations.
        #   init: 'pca' starts from a PCA projection instead of random.
        #   metric: distance measure used in the input space.
        #
        # Newer scikit-learn releases renamed the iteration-count keyword
        # from ``n_iter`` to ``max_iter``; use whichever the installed
        # version accepts so the call works on both sides of the rename.
        import inspect
        tsne_params = inspect.signature(TSNE).parameters
        iter_kwarg = "max_iter" if "max_iter" in tsne_params else "n_iter"
        reducer = TSNE(
            n_components=2,
            random_state=42,
            perplexity=30,
            learning_rate=200,
            init='pca',
            metric='cosine',
            **{iter_kwarg: 1000},
        )
    elif method_key == 'umap':
        # UMAP settings:
        #   n_components: target dimensionality.
        #   n_neighbors: neighbourhood size; larger values favour global
        #     structure over local structure.
        #   min_dist: minimum spacing of points in the embedding, i.e. how
        #     tightly points may be packed.
        #   metric: distance measure used in the input space.
        #   learning_rate: step size of the embedding optimisation.
        #   n_epochs: number of iterations; None lets UMAP choose.
        #   spread: works with min_dist to set the scale of the embedding.
        reducer = umap.UMAP(
            n_components=2,
            random_state=42,
            n_neighbors=50,
            min_dist=0.1,
            metric='cosine',
            learning_rate=1.0,
            n_epochs=None,
            spread=1.0
        )
    elif method_key == 'pca':
        reducer = PCA(n_components=2)
    else:
        raise ValueError("Unsupported method: choose 'tsne', 'umap' or 'pca'")

    reduced = reducer.fit_transform(embeddings)

    # Coerce so that ``labels == label`` yields a per-row mask even when the
    # caller passes a plain list.
    labels = np.asarray(labels)

    plt.figure(figsize=(8, 6))
    unique_labels = np.unique(labels)
    langs = ["English", "Chinese", "Spanish", "French", "Hindi"]
    # Legend names are assigned by position in the sorted unique labels.
    # Classes beyond the known names get a generic name instead of being
    # silently dropped from the plot.
    for position, label in enumerate(unique_labels):
        legend_name = langs[position] if position < len(langs) else f"Class {label}"
        mask = labels == label
        plt.scatter(reduced[mask, 0], reduced[mask, 1], label=legend_name, alpha=0.7)
    plt.legend()
    plt.title(f"Visualization using {method.upper()}")
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.grid(True)
    plt.savefig(output_file)
    plt.close()
    print(f"Saved {method.upper()} plot to {output_file}")

import os
import numpy as np
import matplotlib.pyplot as plt

def cal_similarity_model(input_files, output_dir, lang, thresholds=(0.95, 0.945, 0.94),
                         key="hidden_states_last_token"):
    """Compute, plot and report adjacent-layer similarity for a dataset.

    Loads all samples from ``input_files``, computes the mean cosine
    similarity for every pair of adjacent layers, saves the similarity curve
    as a PNG, and writes the layer pairs that reach each threshold to a text
    file. Both outputs go to ``output_dir`` and are prefixed with ``lang``.

    Args:
        input_files: Pickle file paths passed to ``load_all_data``.
        output_dir: Directory for the outputs; created if missing.
        lang: Prefix used in the output file names.
        thresholds: Similarity cut-offs. A layer pair is reported for a
            threshold when its similarity is ``>=`` that threshold.
        key: Entry of each sample that holds the per-layer vectors.

    Raises:
        IndexError: If no samples could be loaded from ``input_files``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load data
    all_samples, all_labels = load_all_data(input_files)
    if not all_samples:
        # Same exception type the unguarded ``all_samples[0]`` lookup would
        # raise, but with a message that points at the actual cause.
        raise IndexError(f"No samples loaded from {input_files!r}")
    num_layers = len(all_samples[0][key])

    # Compute similarities: entry i describes the pair (layer i, layer i+1).
    layer_similarities = [
        compute_adjacent_layer_similarity(all_samples, idx, key=key)
        for idx in range(num_layers - 1)
    ]
    for idx, sim in enumerate(layer_similarities):
        print(f"Layer {idx} ↔ {idx+1} similarity = {sim:.4f}")

    # Plot similarity curve
    plt.figure(figsize=(8, 4))
    plt.plot(range(num_layers - 1), layer_similarities, marker='^')
    plt.xlabel("Layer Pair Index (L ↔ L+1)")
    plt.ylabel("Average Cosine Similarity")
    plt.title("Adjacent Layer Hidden-State Similarity")
    plt.grid(True)
    plot_path = os.path.join(output_dir, f"{lang}_adjacent_layer_similarity.png")
    plt.savefig(plot_path)
    plt.close()
    print(f"Saved similarity plot to {plot_path}")

    # Identify high-similarity layer pairs for each fixed threshold
    results = {}
    for thresh in thresholds:
        high_pairs = [i for i, s in enumerate(layer_similarities) if s >= thresh]
        results[thresh] = high_pairs
        print(f"Threshold={thresh:.3f}, pairs={high_pairs}")

    # Save results
    result_file = os.path.join(output_dir, f"{lang}_high_similarity_by_threshold.txt")
    with open(result_file, 'w', encoding="utf-8") as f:
        for thresh, pairs in results.items():
            f.write(f"Threshold = {thresh:.3f}\n")
            f.write(f"Layer pairs >= threshold: {pairs}\n\n")
    print(f"Saved threshold results to {result_file}")





if __name__ == "__main__":
    languages = ["English", "Chinese", "Spanish", "French", "Hindi"]
    task_files = [
        "gsm8k_alpaca_train.json",
        "FOLIO_train_modified.json",
        "LogicalDeduction_train_modified.json",
        "ProofWriter_train_modified.json",
    ]
    model_paths = [
        "/home/Llama-3",
        "/data/microsoft_Phi-3-mini-4k-instruct",
        "/data/Qwen_Qwen2-7B-Instruct",
    ]
    # The last path component of each model path names its directories.
    models = [path.split("/")[-1] for path in model_paths]

    for model in models:
        for task_file in task_files:
            # Strip the extension to get the task's base name.
            task_basename = task_file.replace(".json", "")
            output_dir = f"Visual_{model}/{task_basename}_last"
            for lang in languages:
                # One run per (model, task, language), reading from the
                # "<model>_hiddenstate" directory. An earlier variant read
                # from "<model>_hiddenstate_prompt" instead.
                pickle_path = f"{model}_hiddenstate/{lang}_{task_basename}.pkl"
                cal_similarity_model([pickle_path], output_dir, lang)

    
    

