# tsne_binary.py
import os
import random 
import numpy as np 
import torch 
from torch.utils.data import DataLoader
import random
from tqdm import tqdm 
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import seaborn as sns
sns.set_theme(style="whitegrid")
from sklearn.neighbors import NearestNeighbors
from matplotlib.lines import Line2D

from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

# Switch to the TaskTracker repository root before the project import below.
# Relative paths used later in the script (e.g. './store/fig/tsne/...') resolve
# against this directory.
# NOTE(review): whether the project import also relies on this chdir depends on
# how the script is launched (sys.path contents) — confirm before moving it.
os.chdir("/guardrail/TaskTracker")

from task_tracker.training.dataset import ActivationsDatasetDynamicPrimaryText
# The TEST_*_FILES_PER_MODEL tables were previously imported from
# task_tracker.training.utils.constants; they are built locally below from DATA_LISTS.

# Configuration (defined exactly once).
RISK = "Hijacking"
MODEL = 'llama3_8b'
DATASET = "Medicalsys"
BATCH_SIZE = 256
FILES_CHUNCK = 10  # number of activation files wrapped in one dataset/DataLoader
LAYERS = 80 if MODEL == 'llama3_70b' else 32

DATA_LISTS = "/guardrail/TaskTracker/store/data/" + RISK + "/" + DATASET
TEST_ACTIVATIONS_DIR = f'/guardrail/TaskTracker/store/activations/{RISK}/{DATASET}/{MODEL}/test'


def _read_file_list(path):
    """Return the whitespace-stripped, non-blank lines of the text file at *path*.

    Blank lines are skipped so they never turn into an empty file name, and the
    file handle is closed deterministically via ``with``.
    """
    with open(path, encoding="utf-8") as handle:
        return [line.strip() for line in handle if line.strip()]


# Test-split file-name lists, keyed by model. The list file name is derived from
# MODEL, like TEST_ACTIVATIONS_DIR, so switching MODEL needs no edit here.
TEST_CLEAN_FILES_PER_MODEL = {
    MODEL: _read_file_list(os.path.join(DATA_LISTS, f'test_clean_files_{MODEL}.txt'))
}
TEST_POISONED_FILES_PER_MODEL = {
    MODEL: _read_file_list(os.path.join(DATA_LISTS, f'test_poisoned_files_{MODEL}.txt'))
}

# Resolve the file lists for the configured model.
clean_files = TEST_CLEAN_FILES_PER_MODEL[MODEL]
print(f'{len(clean_files)} clean files')
poisoned_files = TEST_POISONED_FILES_PER_MODEL[MODEL]
print(f'{len(poisoned_files)} poisoned files')

def compute_distances(tensor1, tensor2):
    """Return the Euclidean (L2) distance between *tensor1* and *tensor2*.

    The norm is taken over the last dimension, so the result has one fewer
    dimension than the (broadcast) inputs.
    """
    difference = tensor1 - tensor2
    return difference.norm(p=2, dim=-1)

def compute_activations_residuals(evaluate_files):
    """Load activations for *evaluate_files* and stack them layer by layer.

    Files are processed FILES_CHUNCK at a time. Each chunk is wrapped in an
    ``ActivationsDatasetDynamicPrimaryText`` (together with LAYERS and
    TEST_ACTIVATIONS_DIR) and iterated with a non-shuffling ``DataLoader`` of
    BATCH_SIZE. Each batch is indexed as ``batch[:, layer, :]`` — presumably
    (batch, layers, hidden); TODO confirm against the dataset class.

    Despite the name, no residual or difference is computed here. The stacked
    tensors are the raw batches returned by the dataset.

    Args:
        evaluate_files: sequence of activation file names, handed to the
            dataset as-is.

    Returns:
        A list of LAYERS float32 tensors. Entry ``k`` vertically stacks layer
        ``k`` of every batch, in file/batch order.

    Raises:
        ValueError: if *evaluate_files* is empty, no batches were loaded, or a
            batch does not carry exactly LAYERS layers.
    """
    if not evaluate_files:
        raise ValueError("compute_activations_residuals() needs at least one file")

    per_layer = [[] for _ in range(LAYERS)]
    for start in range(0, len(evaluate_files), FILES_CHUNCK):
        file_chunk = evaluate_files[start:start + FILES_CHUNCK]
        dataset = ActivationsDatasetDynamicPrimaryText(file_chunk, LAYERS, TEST_ACTIVATIONS_DIR)
        data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        for primary in data_loader:
            # A batch with fewer layers would leave some per-layer lists empty
            # (torch.vstack then fails opaquely). A batch with more layers would
            # index past the end of `per_layer`. Fail clearly instead.
            if primary.size(1) != LAYERS:
                raise ValueError(
                    f"expected {LAYERS} layers per batch, got {primary.size(1)}"
                )
            for layer in range(LAYERS):
                per_layer[layer].append(primary[:, layer, :])

    if not per_layer[0]:
        raise ValueError("no activations were loaded from the given files")

    return [torch.vstack(layer_batches).float() for layer_batches in per_layer]

def _sample_rows(layer_tensors, rate, rng):
    """Keep the same random subset of rows in every per-layer tensor.

    One index set is drawn without replacement and applied to all layers, so
    every layer's plot shows the same examples. This replaces drawing fresh
    indices per layer. At least one row is kept, so a small input never
    yields an empty sample.

    Args:
        layer_tensors: list of per-layer tensors, all with the same row count.
        rate: fraction of rows to keep.
        rng: ``numpy.random.Generator`` used for the draw.

    Raises:
        ValueError: if the layers do not all hold the same number of rows.
    """
    num_rows = len(layer_tensors[0])
    if any(len(rows) != num_rows for rows in layer_tensors):
        raise ValueError("every layer must hold the same number of rows")
    indices = rng.choice(num_rows, size=max(1, int(num_rows * rate)), replace=False)
    return [rows[indices] for rows in layer_tensors]


# Load activations for both splits.
print("Computing activations for poisoned data...")
diff_poisoned = compute_activations_residuals(poisoned_files)
print("Computing activations for clean data...")
diff_clean = compute_activations_residuals(clean_files)

# Subsample points to reduce the number of points in t-SNE and the plots.
sampling_rate = 0.2  # keep 20% of the points
sampling_seed = 42  # fixed seed so the sampled subset (and the figure) is reproducible
print(f"Sampling {sampling_rate * 100}% of points...")

rng = np.random.default_rng(sampling_seed)
sampled_clean = _sample_rows(diff_clean, sampling_rate, rng)
sampled_poisoned = _sample_rows(diff_poisoned, sampling_rate, rng)

# Reduce each layer to 2-D with t-SNE. Every layer gets its own independent
# embedding; clean rows come first and poisoned rows second, and the label
# array follows the same order.
print("Applying t-SNE...")
all_tsne_embs = []
all_labels = []

for layer, clean_rows in enumerate(sampled_clean):
    poisoned_rows = sampled_poisoned[layer]
    print(f'Processing layer {layer}')
    stacked_rows = np.vstack((clean_rows, poisoned_rows))
    layer_labels = np.array(['Clean'] * len(clean_rows) + ['Poisoned'] * len(poisoned_rows))

    embedding = TSNE(n_components=2, random_state=42).fit_transform(stacked_rows)
    all_tsne_embs.append(embedding)
    all_labels.append(layer_labels)

# Plot the t-SNE embedding of selected layers, one subplot per layer.
print("Generating visualizations...")
# Drop layers the configured model does not have up front. Skipping them inside
# the loop would leave blank subplots in the figure.
layers_to_plot = [layer for layer in (0, 7, 15, 23, 31) if layer < LAYERS]
# squeeze=False keeps `axs` 2-D even when only a single layer is plotted.
fig, axs = plt.subplots(
    len(layers_to_plot), 1, figsize=(10, 6 * len(layers_to_plot)), squeeze=False
)

# Fixed marker sizes for all points.
point_size = 50
# Legend marker size. Matplotlib's `markersize` (points) and scatter's `s`
# (points squared) use different units, hence the different value.
legend_point_size = 15
label_to_color = {'Clean': "#4C9E5E", 'Poisoned': "#BB4647"}

# The legend entries are identical for every subplot, so build them once. They
# use a higher opacity than the scatter points.
legend_elements = [
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=label_to_color['Clean'],
           markersize=legend_point_size,
           alpha=0.95,
           label='Benign'),
    Line2D([0], [0], marker='o', color='w',
           markerfacecolor=label_to_color['Poisoned'],
           markersize=legend_point_size,
           alpha=0.95,
           label='Malicious'),
]

for ax, layer_idx in zip(axs[:, 0], layers_to_plot):
    emb = all_tsne_embs[layer_idx]
    labels = all_labels[layer_idx]

    # One scatter call per class, without automatic legend entries.
    for label in ['Clean', 'Poisoned']:
        condition = labels == label
        ax.scatter(
            emb[condition, 0], emb[condition, 1],
            alpha=0.6,
            s=point_size,
            c=label_to_color[label],
        )

    ax.set_xlabel('Component 1', fontsize=25)
    ax.set_ylabel('Component 2', fontsize=25)

    legend = ax.legend(handles=legend_elements,
                       loc='upper right',
                       fontsize=25,
                       frameon=True)
    # Make the legend frame nearly opaque so points behind it do not show through.
    legend.get_frame().set_alpha(0.95)

    ax.grid(True, which='minor', linewidth=0.25)
    ax.minorticks_on()
    sns.despine(ax=ax, left=True, bottom=True)

plt.tight_layout()
# Create the output directory so the save does not fail after the expensive
# t-SNE work has already been done.
output_dir = './store/fig/tsne'
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, f'{RISK}_{DATASET}_tsne_visualization.pdf')
plt.savefig(output_path, format='pdf', bbox_inches='tight', pad_inches=0.5, transparent=False)
plt.close(fig)
print(f"Visualization saved as '{output_path}'")