import pickle
import numpy as np
import os
import json
import argparse
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import umap
from sklearn.decomposition import PCA
def load_dataset(file_path, label, max_samples=1000):
    """Load the samples of one class from a pickle file.

    The pickle is expected to hold a dict with a ``"last_results"`` mapping;
    each value of that mapping is one sample. Samples are presumably dicts
    containing ``"hidden_states_last_token"`` (a per-layer list of last-token
    vectors), but that is not checked here. If ``"last_results"`` is missing,
    no samples are returned.

    Args:
        file_path: Path to the ``.pkl`` file.
        label: Class label, returned unchanged alongside the samples.
        max_samples: Keep at most this many samples, taken in the mapping's
            iteration order. ``None`` keeps every sample. Defaults to 1000,
            the previously hard-coded cap.

    Returns:
        Tuple ``(samples, label)`` where ``samples`` is a list.
    """
    with open(file_path, 'rb') as f:
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file -- only load pickles produced by a trusted pipeline.
        data = pickle.load(f)
    dataset = data.get("last_results", {})
    samples = []
    for key in dataset:
        if max_samples is not None and len(samples) >= max_samples:
            # Stop as soon as the cap is reached instead of scanning the rest.
            break
        samples.append(dataset[key])
    return samples, label

def load_all_data(input_files):
    """Load every input file as its own class and concatenate the results.

    The file at position ``i`` in ``input_files`` receives integer label
    ``i``. Samples are returned as one flat list in file order, together
    with a NumPy array holding the matching label of every sample. Per-file
    sample counts and vector sizes depend on the data and are not checked.
    """
    merged_samples = []
    merged_labels = []
    for class_idx, path in enumerate(input_files):
        file_samples, file_label = load_dataset(path, class_idx)
        merged_samples += file_samples
        merged_labels += [file_label] * len(file_samples)
    return merged_samples, np.array(merged_labels)

def extract_layer_features(samples, layer_idx, key="hidden_states_last_token"):
    """Gather one layer's vector from every sample into a single array.

    For each sample, ``sample[key][layer_idx]`` is converted with
    ``np.array``. The per-sample arrays are then combined with ``np.array``,
    so equal-length 1-D vectors yield shape ``(len(samples), dim)``.
    """
    per_sample = [np.array(s[key][layer_idx]) for s in samples]
    return np.array(per_sample)

def _build_tsne(n_samples):
    """Construct the 2-D t-SNE reducer, compatible with old and new scikit-learn."""
    # t-SNE parameters:
    #   perplexity: effective number of neighbours per point (typically 5-50).
    #     scikit-learn requires it to be strictly less than n_samples, so it
    #     is clamped for very small inputs.
    #   learning_rate: fixed at 200 (scikit-learn's default became 'auto' in 1.2).
    #   init='pca': PCA initialisation, which is more stable than 'random'.
    #   metric='cosine': distance measure in the input space.
    perplexity = min(30, max(n_samples - 1, 1))
    common = dict(
        n_components=2,
        random_state=42,
        perplexity=perplexity,
        learning_rate=200,
        init='pca',
        metric='cosine',
    )
    try:
        # scikit-learn >= 1.5 renamed n_iter -> max_iter; n_iter was removed in 1.7.
        return TSNE(max_iter=1000, **common)
    except TypeError:
        # Older scikit-learn only accepts n_iter.
        return TSNE(n_iter=1000, **common)


def visualize_embeddings(embeddings, labels, method='tsne', output_file="embedding_plot.png",
                         label_names=None):
    """Project ``embeddings`` to 2-D and save a scatter plot colored by label.

    Args:
        embeddings: 2-D array-like of shape (n_samples, n_features).
        labels: Per-sample class labels (array-like, length n_samples).
        method: 'tsne', 'umap' or 'pca' (case-insensitive).
        output_file: Path the PNG is written to.
        label_names: Legend names assigned to the sorted unique labels by
            position. Defaults to the five languages used by this script.
            Labels beyond the provided names are shown as ``str(label)``
            rather than being silently dropped.

    Raises:
        ValueError: if ``method`` is not one of the supported values.
    """
    method_key = method.lower()
    if method_key == 'tsne':
        reducer = _build_tsne(len(embeddings))
    elif method_key == 'umap':
        # UMAP parameters:
        #   n_neighbors: local neighbourhood size (library default 15); larger
        #     values preserve more global structure.
        #   min_dist / spread: jointly control how tightly points are packed.
        #   metric='cosine': distance measure in the input space.
        #   n_epochs=None lets UMAP choose the number of optimisation epochs.
        reducer = umap.UMAP(
            n_components=2,
            random_state=42,
            n_neighbors=50,
            min_dist=0.1,
            metric='cosine',
            learning_rate=1.0,
            n_epochs=None,
            spread=1.0
        )
    elif method_key == 'pca':
        reducer = PCA(n_components=2)
    else:
        raise ValueError("Unsupported method: choose 'tsne', 'umap' or 'pca'")

    reduced = reducer.fit_transform(embeddings)

    # Boolean masking below needs an ndarray; a plain list would compare to False.
    labels = np.asarray(labels)
    if label_names is None:
        label_names = ["English", "Chinese", "Spanish", "French", "Hindi"]

    fig = plt.figure(figsize=(8, 6))
    try:
        for i, label in enumerate(np.unique(labels)):
            name = label_names[i] if i < len(label_names) else str(label)
            idx = labels == label
            plt.scatter(reduced[idx, 0], reduced[idx, 1], label=f"{name}", alpha=0.7)
        plt.legend()
        plt.title(f"Visualization using {method.upper()}")
        plt.xlabel("Component 1")
        plt.ylabel("Component 2")
        plt.grid(True)
        plt.savefig(output_file)
    finally:
        # Always release the figure so long sweeps don't leak memory on errors.
        plt.close(fig)
    print(f"Saved {method.upper()} plot to {output_file}")

def visualization_model(input_files, output_dir, methods=("tsne", "umap", "pca")):
    """Plot 2-D projections of every layer's last-token features.

    Each path in ``input_files`` is loaded as one class, and its position in
    the list is its label. The number of layers is taken from the first
    sample's ``"hidden_states_last_token"`` list. For every layer, the
    features are extracted and one plot per method is written to
    ``<output_dir>/<method>_layer_<idx>.png``.

    Args:
        input_files: Pickle paths, one per class.
        output_dir: Directory for the PNG files; created if it does not exist.
        methods: Reduction methods to run for each layer, in order. The
            default ('tsne', 'umap', 'pca') reproduces the original behavior.

    Raises:
        ValueError: if no samples could be loaded from ``input_files``.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Load all samples and their class labels.
    all_samples, all_labels = load_all_data(input_files)
    if not all_samples:
        raise ValueError(f"No samples loaded from {input_files!r}; nothing to visualize")

    # All samples are assumed to have the same number of layers (not checked).
    num_layers = len(all_samples[0]["hidden_states_last_token"])
    for layer_idx in range(num_layers):
        features = extract_layer_features(all_samples, layer_idx)
        # Presumably (total_samples, hidden_dim) when each layer vector is 1-D.
        print(f"Extracted features shape: {features.shape}")

        for method in methods:
            output_file = os.path.join(output_dir, f"{method}_layer_{layer_idx}.png")
            visualize_embeddings(features, all_labels, method=method, output_file=output_file)

if __name__ == "__main__":
    # One pickle per language; list order defines the integer class labels
    # and should match the legend names used when plotting.
    languages = ["English", "Chinese", "Spanish", "French", "Hindi"]
    task_files = [
        "svamp_alpaca_train.json",
        "gsm8k_alpaca_train.json",
        "ar_LSAT_train_modified_ag.json",
        "FOLIO_train_modified.json",
        "LogicalDeduction_train_modified.json",
        "ProofWriter_train_modified.json",
    ]
    model_paths = [
        "/home/Llama-3",
        "/data/microsoft_Phi-3-mini-4k-instruct",
        "/data/Qwen_Qwen2-7B-Instruct",
    ]
    # The model name is the last path component; it names the input and output directories.
    model_names = [p.rsplit("/", 1)[-1] for p in model_paths]
    # Every (model, task) pair, models in the outer position.
    jobs = [(name, task) for name in model_names for task in task_files]
    for model_name, task in jobs:
        stem = task.replace(".json", "")
        pkl_paths = [f"{model_name}_hiddenstate/{lang}_{stem}.pkl" for lang in languages]
        visualization_model(pkl_paths, f"Visual_{model_name}/{stem}_last")

    
    # 
    

