# tsne_multiclass.py
import os 
import sys
import random 
import numpy as np 
import torch 
from torch.utils.data import DataLoader
import random
from tqdm import tqdm 
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from matplotlib.lines import Line2D  # 添加导入
import seaborn as sns
sns.set_theme(style="whitegrid")

# Root of the TaskTracker checkout.
REPO_ROOT = "/guardrail/TaskTracker"
# Switch the working directory so relative paths used later in this script
# (e.g. './store/fig/tsne/...') resolve against the repository root.
os.chdir(REPO_ROOT)
# os.chdir alone does not make `task_tracker` importable when this file is run
# as a script (sys.path[0] is then the script's directory, not the CWD), so
# put the repository root on the import path explicitly.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from task_tracker.training.dataset import ActivationsDatasetDynamicPrimaryText
# The per-category file maps below are the ones this script actually uses.
# Assumes the constants module defines them alongside the clean/poisoned maps
# — TODO confirm.
from task_tracker.training.utils.constants import (
    TEST_CLEAN_FILES_PER_MODEL,
    TEST_POISONED_FILES_PER_MODEL,
    TEST_CASE_FILES_PER_MODEL,
    TEST_FINANCIAL_FILES_PER_MODEL,
    TEST_EMPLOYEE_FILES_PER_MODEL,
    TEST_GOODS_FILES_PER_MODEL,
)

# Configuration
MODEL = 'llama3_8b'
BATCH_SIZE = 256
TEST_ACTIVATIONS_DIR = '/guardrail/TaskTracker/store/activations/Unauthorized_Access/Medicalsys/llama3_8b/test'
FILES_CHUNCK = 10  # number of activation files handed to one dataset instance
LAYERS = 80 if MODEL == 'llama3_70b' else 32  # layer count passed to the dataset / per-layer lists

# Test activation file lists for each data category of MODEL.
case_files = TEST_CASE_FILES_PER_MODEL[MODEL]
print(f'{len(case_files)} case files')
financial_files = TEST_FINANCIAL_FILES_PER_MODEL[MODEL]
print(f'{len(financial_files)} financial files')
employee_files = TEST_EMPLOYEE_FILES_PER_MODEL[MODEL]
print(f'{len(employee_files)} employee files')
goods_files = TEST_GOODS_FILES_PER_MODEL[MODEL]
print(f'{len(goods_files)} goods files')

def compute_distances(tensor1, tensor2):
    """Return the L2 (Euclidean) norm of ``tensor1 - tensor2`` over the last dim.

    The two tensors are subtracted with normal torch broadcasting; the result
    has the broadcast shape with the last dimension reduced away.
    """
    return (tensor1 - tensor2).norm(p=2, dim=-1)

def compute_activations_residuals(evaluate_files):
    """Load activations for ``evaluate_files`` and group them by layer.

    Files are processed ``FILES_CHUNCK`` at a time: each chunk is wrapped in an
    ``ActivationsDatasetDynamicPrimaryText`` (reading from
    ``TEST_ACTIVATIONS_DIR``) and iterated in order with batches of
    ``BATCH_SIZE``. Each batch is indexed as ``batch[:, layer, :]``, so batches
    are assumed to be (batch, layers, hidden) — TODO confirm against the
    dataset.

    Despite the name, no residual or difference is computed: the per-layer
    slices returned by the dataset are collected as-is.

    Args:
        evaluate_files: sliceable sequence of activation file names.

    Returns:
        A list of ``LAYERS`` float32 tensors; entry ``i`` concatenates every
        sample's layer-``i`` activation along dim 0.

    Raises:
        ValueError: if a batch has more than ``LAYERS`` layers, or if no
            activations were collected for some layer (for example when
            ``evaluate_files`` is empty).
    """
    per_layer = [[] for _ in range(LAYERS)]

    for start in range(0, len(evaluate_files), FILES_CHUNCK):
        chunk = evaluate_files[start:start + FILES_CHUNCK]
        dataset = ActivationsDatasetDynamicPrimaryText(chunk, LAYERS, TEST_ACTIVATIONS_DIR)
        loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        for batch in loader:
            n_layers = batch.size(1)
            if n_layers > LAYERS:
                raise ValueError(
                    f'batch has {n_layers} layers but LAYERS is {LAYERS}'
                )
            for layer in range(n_layers):
                per_layer[layer].append(batch[:, layer, :])

    stacked = []
    for layer, pieces in enumerate(per_layer):
        # torch.vstack on an empty list fails with an opaque error; report
        # which layer is missing instead.
        if not pieces:
            raise ValueError(
                f'no activations collected for layer {layer} '
                f'({len(evaluate_files)} input files)'
            )
        stacked.append(torch.vstack(pieces).float())
    return stacked

# Collect per-layer activations for each data category (category order:
# case, financial, employee, goods).
_activation_jobs = (
    ("Computing activations for case data...", case_files),
    ("Computing activations for financial data...", financial_files),
    ("Computing activations for employee data...", employee_files),
    ("Computing activations for goods data...", goods_files),
)
_activation_results = []
for _message, _job_files in _activation_jobs:
    print(_message)
    _activation_results.append(compute_activations_residuals(_job_files))
diff_case, diff_financial, diff_employee, diff_goods = _activation_results

# Reduce each layer's activations to 2-D with t-SNE.
print("Applying t-SNE...")
all_tsne_embs = []  # all_tsne_embs[layer] is the 2-D embedding of that layer
for layer in range(0, len(diff_case)):
    print(f'Processing layer {layer}')
    # Rows are stacked in category order case, financial, employee, goods;
    # any per-point labels used for plotting must follow this same order.
    combined_diffs = np.vstack((
        diff_case[layer],
        diff_financial[layer],
        diff_employee[layer],
        diff_goods[layer],
    ))
    # NOTE: every layer is embedded, although only a subset may be plotted;
    # each layer uses the same fixed seed, so results per layer are reproducible.
    tsne = TSNE(n_components=2, random_state=42)
    reduced_diff_embeddings = tsne.fit_transform(combined_diffs)
    all_tsne_embs.append(reduced_diff_embeddings)

# Visualize the t-SNE embeddings of selected layers.
print("Generating visualizations...")
layers_to_plot = [0, 7, 15, 23, 31]  # layers selected for visualization
# squeeze=False keeps `axs` a 2-D array even when only one layer is plotted,
# so `axs[idx, 0]` is always a valid Axes.
fig, axs = plt.subplots(
    len(layers_to_plot), 1,
    figsize=(12, 6 * len(layers_to_plot)),
    squeeze=False,
)

# One label per embedded point, in the same category order the activations
# were stacked in before t-SNE (case, financial, employee, goods).
labels = (
    ['Case'] * len(diff_case[0])
    + ['Financial'] * len(diff_financial[0])
    + ['Employee'] * len(diff_employee[0])
    + ['Goods'] * len(diff_goods[0])
)
# Converted once and reused for the boolean masks of every subplot.
label_array = np.array(labels)

# Point colors, matched to class_labels by position.
colors = ["#4666A9", "#4C9E5E", "#705E78", "#D77A47"]
class_labels = ['Case', 'Financial', 'Employee', 'Goods']
# NOTE(review): the legend shows "User 1".."User 4" (matched to colors by
# position) rather than class_labels — confirm this relabelling is intended.
legend_names = ["User 1", "User 2", "User 3", "User 4"]

for idx, layer_idx in enumerate(layers_to_plot):
    if layer_idx >= LAYERS:
        continue  # layer does not exist for this model; panel stays empty
    ax = axs[idx, 0]
    emb = all_tsne_embs[layer_idx]

    # Scatter each class without a label; the legend uses custom handles.
    for label, color in zip(class_labels, colors):
        mask = label_array == label
        ax.scatter(
            emb[mask, 0],
            emb[mask, 1],
            c=color,
            alpha=0.3,  # keep the points themselves translucent
        )

    # Custom legend handles: larger (markersize=15) and nearly opaque
    # (alpha=0.95) markers so colors stay readable despite the point alpha.
    legend_elements = [
        Line2D([0], [0], marker='o', color='w',
               markerfacecolor=color,
               markersize=15,
               alpha=0.95,
               label=name)
        for color, name in zip(colors, legend_names)
    ]

    # Axis-label / title font sizes match tsne_re.py.
    ax.set_xlabel('Component 1', fontsize=25)
    ax.set_ylabel('Component 2', fontsize=25)
    ax.set_title(f'Layer {layer_idx}', fontsize=25)

    # Larger tick-label font.
    ax.tick_params(axis='both', which='major', labelsize=18)

    # Legend built from the custom handles.
    legend = ax.legend(
        handles=legend_elements,
        loc='upper right',
        fontsize=25,
        frameon=True
    )
    # Make the legend frame more opaque.
    legend.get_frame().set_alpha(0.95)

    # Grid lines and styling consistent with tsne_re.py.
    ax.grid(True, which='minor', linewidth=0.25)
    ax.minorticks_on()
    sns.despine(ax=ax, left=True, bottom=True)

plt.tight_layout()
output_path = './store/fig/tsne/Unauthorized_tsne.pdf'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, format='pdf', bbox_inches='tight', pad_inches=0.5, transparent=False)
print("Visualization saved as 'Unauthorized_tsne.pdf'")