# tsne_binary.py
import os 
import sys
import random 
import numpy as np 
import torch 
from torch.utils.data import DataLoader
import random
from tqdm import tqdm 
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
sns.set_theme(style="whitegrid")
from sklearn.neighbors import NearestNeighbors

from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

# Switch the working directory to the TaskTracker repository root. Relative
# paths used later in this script (e.g. the './store/fig/tsne/...' output
# path passed to savefig) are resolved against this directory.
# NOTE(review): os.chdir does not modify sys.path, so this alone does not
# make the task_tracker imports below resolvable when the script is run as
# `python tsne_binary.py` from elsewhere — presumably the package is installed
# or on PYTHONPATH; confirm.
os.chdir("/guardrail/TaskTracker")

from task_tracker.training.dataset import ActivationsDatasetDynamicPrimaryText
from task_tracker.training.utils.constants import TEST_CLEAN_FILES_PER_MODEL, TEST_POISONED_FILES_PER_MODEL

# Configuration.
MODEL = 'llama3_8b'
BATCH_SIZE = 256
# Activation directory for the selected model. It is built from MODEL so that
# switching MODEL cannot leave the script reading another model's activations.
TEST_ACTIVATIONS_DIR = f'/guardrail/TaskTracker/store/activations/Reconnaissance/hotpotqa/{MODEL}/test'
# Number of activation files handed to each dataset instance at a time.
# (The misspelled name is kept because the loading code refers to it.)
FILES_CHUNCK = 10
# Layer count used when loading activations: 80 for llama3_70b, 32 otherwise.
LAYERS = 80 if MODEL == 'llama3_70b' else 32

# Look up the clean and poisoned test file lists for the selected model and
# print how many files each contains.
# NOTE(review): TEST_CLEAN_FILES_PER_MODEL / TEST_POISONED_FILES_PER_MODEL
# appear to be mappings keyed by model name (they are indexed with MODEL);
# a model without an entry fails here — confirm against constants.py.
clean_files = TEST_CLEAN_FILES_PER_MODEL[MODEL]
print(f'{len(clean_files)} clean files')
poisoned_files = TEST_POISONED_FILES_PER_MODEL[MODEL]
print(f'{len(poisoned_files)} poisoned files')

def compute_distances(tensor1, tensor2):
    """Return the Euclidean (L2) distance between two tensors along the last axis.

    The operands are subtracted with ordinary broadcasting, and the last
    dimension of the difference is reduced. The result therefore has one
    dimension fewer than the broadcast shape.
    """
    delta = tensor1 - tensor2
    return delta.norm(p=2, dim=-1)

def compute_activations_residuals(evaluate_files):
    """Load per-layer primary-text activations for ``evaluate_files``.

    Files are processed FILES_CHUNCK at a time. Each chunk goes through
    ``ActivationsDatasetDynamicPrimaryText`` (reading from
    TEST_ACTIVATIONS_DIR) and is batched by a non-shuffling DataLoader of size
    BATCH_SIZE. Every batch is indexed as ``batch[:, layer, :]``, so batches
    are expected to be 3-D tensors whose second axis is the layer axis.

    Despite the name, no residual/difference is computed: the activations
    yielded by the dataset are collected as-is.

    Args:
        evaluate_files: Sized, sliceable sequence of activation file names.

    Returns:
        A list of LAYERS float32 tensors. Entry ``l`` vertically stacks layer
        ``l`` of every batch in file/batch order, so row ``i`` refers to the
        same example in every layer.

    Raises:
        ValueError: if ``evaluate_files`` is empty, if no batches are produced,
            or if a batch's layer axis does not have exactly LAYERS entries.
    """
    if len(evaluate_files) == 0:
        raise ValueError('compute_activations_residuals: evaluate_files is empty')

    per_layer = [[] for _ in range(LAYERS)]

    for start in range(0, len(evaluate_files), FILES_CHUNCK):
        files = evaluate_files[start:start + FILES_CHUNCK]
        dataset = ActivationsDatasetDynamicPrimaryText(files, LAYERS, TEST_ACTIVATIONS_DIR)
        data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        for primary in data_loader:
            # The buffer has exactly LAYERS slots. A batch with more layers
            # would overflow it, and one with fewer would leave slots empty
            # (torch.vstack then fails), so reject mismatches clearly here.
            if primary.size(1) != LAYERS:
                raise ValueError(
                    f'expected {LAYERS} layers per batch, got {primary.size(1)}'
                )
            for layer in range(LAYERS):
                per_layer[layer].append(primary[:, layer, :])

    if not per_layer[0]:
        raise ValueError('compute_activations_residuals: no activations were loaded')

    return [torch.vstack(chunks).float() for chunks in per_layer]

# Load the activations for both classes. Each result is a list holding one
# float tensor per layer, as returned by compute_activations_residuals (which
# collects the dataset's activations as-is, without differencing).
print("Computing activations for poisoned data...")
diff_poisoned = compute_activations_residuals(poisoned_files)
print("Computing activations for clean data...")
diff_clean = compute_activations_residuals(clean_files)

# Subsample the points so that t-SNE stays tractable.
sampling_rate = 0.2  # keep 20% of the points
sampling_seed = 0    # fixed seed so the subsample (and hence the figure) is reproducible
print(f"Sampling {sampling_rate * 100}% of points...")

_sampling_rng = np.random.default_rng(sampling_seed)


def _draw_row_indices(n_rows, rate, rng):
    """Return ``int(n_rows * rate)`` distinct row indices (at least one if n_rows > 0)."""
    size = int(n_rows * rate)
    if n_rows > 0:
        # Truncation can round a small class down to zero points; keep one.
        size = max(size, 1)
    return rng.choice(n_rows, size=size, replace=False)


def _subsample_layers(per_layer, rate, rng):
    """Select the SAME random rows from every per-layer tensor in ``per_layer``.

    Rows are aligned across layers (row i is the same example in every
    layer), so one shared index set keeps the per-layer plots comparable.
    """
    if not per_layer:
        return []
    n_rows = len(per_layer[0])
    if any(len(layer_tensor) != n_rows for layer_tensor in per_layer):
        raise ValueError('all layers must contain the same number of rows')
    indices = _draw_row_indices(n_rows, rate, rng)
    return [layer_tensor[indices] for layer_tensor in per_layer]


sampled_clean = _subsample_layers(diff_clean, sampling_rate, _sampling_rng)
sampled_poisoned = _subsample_layers(diff_poisoned, sampling_rate, _sampling_rng)

# Reduce each layer's sampled activations to 2-D with t-SNE.
print("Applying t-SNE...")
all_tsne_embs = []
all_labels = []

for layer, (clean_layer, poisoned_layer) in enumerate(zip(sampled_clean, sampled_poisoned)):
    print(f'Processing layer {layer}')
    # Clean rows first, then poisoned rows; the labels follow the same order.
    combined_diffs = np.vstack((clean_layer, poisoned_layer))
    labels = np.array(['Clean'] * len(clean_layer) + ['Poisoned'] * len(poisoned_layer))

    # sklearn requires perplexity < n_samples. The default (30.0) would raise
    # on small samples, so cap it; for larger samples this equals the default.
    perplexity = min(30.0, len(combined_diffs) - 1)
    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    all_tsne_embs.append(tsne.fit_transform(combined_diffs))
    all_labels.append(labels)

# Plot a handful of layers, one t-SNE scatter per subplot, and save as PDF.
print("Generating visualizations...")
layers_to_plot = [0, 7, 15, 23, 31]
# Drop requested layers that do not exist for this model up front, so that
# no subplot is left empty.
plotted_layers = [layer_idx for layer_idx in layers_to_plot if layer_idx < LAYERS]

# squeeze=False always returns a 2-D array of Axes, even for a single row.
fig, axs = plt.subplots(
    len(plotted_layers), 1,
    figsize=(10, 6 * len(plotted_layers)),
    squeeze=False,
)

# Fixed-size markers, so both classes (and their legend entries) look alike.
point_size = 50
label_to_color = {'Clean': "#4C9E5E", 'Poisoned': "#BB4647"}

for ax, layer_idx in zip(axs[:, 0], plotted_layers):
    emb = all_tsne_embs[layer_idx]
    labels = all_labels[layer_idx]

    for label in ('Clean', 'Poisoned'):
        mask = labels == label
        ax.scatter(
            emb[mask, 0], emb[mask, 1],
            label=label,
            alpha=0.6,
            s=point_size,
            color=label_to_color[label],
        )

    # Without a title the subplots cannot be told apart.
    ax.set_title(f'Layer {layer_idx}', fontsize=25)
    ax.set_xlabel('Component 1', fontsize=25)
    ax.set_ylabel('Component 2', fontsize=25)

    # Keep legend markers the same size as the plotted points.
    legend = ax.legend(loc='upper right', fontsize=25, markerscale=1, frameon=True)
    # `legendHandles` was deprecated in Matplotlib 3.7 and later removed in
    # favour of `legend_handles`; fall back only for older versions.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle.set_sizes([point_size])

    ax.grid(True, which='minor', linewidth=0.25)
    ax.minorticks_on()
    sns.despine(ax=ax, left=True, bottom=True)

output_path = './store/fig/tsne/tsne_binary_visualization.pdf'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.tight_layout()
fig.savefig(output_path, format='pdf', bbox_inches='tight', pad_inches=0.5, transparent=False)
plt.close(fig)
print("Visualization saved as 'tsne_binary_visualization.pdf'")