import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
from sklearn import manifold
from tqdm import tqdm
import matplotlib.pyplot as plt
import open_clip

class ImageFolderDataset(Dataset):
    """Unlabelled dataset over the image files directly inside one folder.

    Each item is the image at that path, opened with PIL, converted to RGB,
    and passed through ``transform`` when one is given.

    Args:
        folder_path: Directory whose immediate entries are image files.
        transform: Optional callable applied to each PIL image.

    Attributes:
        raw_file_paths: Every entry of ``folder_path`` (sorted by name).
        file_paths: The subset of ``raw_file_paths`` actually served.
    """

    def __init__(self, folder_path, transform=None):
        # os.listdir returns entries in arbitrary order; sort so that the item
        # order (and any "first N images" selection downstream) is reproducible.
        self.raw_file_paths = [
            os.path.join(folder_path, name) for name in sorted(os.listdir(folder_path))
        ]
        self.transform = transform

        # Keep only regular files, and skip entries whose *name* contains
        # 'checkpoint' (presumably Jupyter's `.ipynb_checkpoints` directory).
        # The check is on the basename, not the full path. Otherwise a
        # folder_path that itself contains "checkpoint" would drop every file.
        self.file_paths = [
            path
            for path in self.raw_file_paths
            if 'checkpoint' not in os.path.basename(path) and os.path.isfile(path)
        ]

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        img_path = self.file_paths[index]
        # convert() returns a new in-memory image, so it remains valid after
        # the context manager closes the underlying file handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained OpenCLIP ViT-H-14 model together with the eval-time
# preprocessing pipeline that open_clip returns for it.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-H-14', pretrained='/root/ViT-H-14/open_clip_pytorch_model.bin'
)
model.eval()
# .to(device) is a no-op on the CPU, so no availability check is needed here.
model = model.to(device)

# Preprocess images with the model's own transform instead of a bare
# Resize((224, 224)) + ToTensor(). CLIP image encoders expect inputs normalized
# with the mean/std used in training; un-normalized tensors distort the
# embeddings. (The model transform center-crops rather than squashing
# non-square images.)
transform = preprocess

# Folders of images to embed and visualize.
# Alternative source folder (currently unused):
# origin_path = '/root/autodl-tmp/img2img_unlearning/diffusion/experiments/origin/random-0.5-new'
base_path_1 = '/root/autodl-tmp/img2img_unlearning/diffusion/experiments/test_ours_all-4.2-encoder-center-1e-5-100-10-multi/ours_all-4.2-encoder-center-1e-5-100-10-multi'
base_path_2 = '/root/autodl-tmp/img2img_unlearning/diffusion/experiments/test_retain_label-0.25/retain_label-0.25'

# 'GT' (presumably ground-truth) and 'Out' (presumably model output) images,
# for the forget split of run 1 and the retain split of run 2.
path_1 = f'{base_path_1}/GT/forget'
path_2 = f'{base_path_2}/GT/retain'
path_3 = f'{base_path_1}/Out/forget'
path_4 = f'{base_path_2}/Out/retain'

# The position in this list becomes each image's label downstream.
image_paths = [
    path_1,
    path_2,
    path_3,
    path_4,
]

# Extract CLIP image embeddings for the first MAX_IMAGES_PER_FOLDER images of
# each folder in image_paths.
# NOTE(review): the previous `if index > 50: break` kept indices 0..50, i.e.
# 51 images. The commented-out `if index < 50: continue` counterpart suggests
# 50 was intended. Adjust here if a different count is wanted.
MAX_IMAGES_PER_FOLDER = 50
BATCH_SIZE = 16  # Lower this if the GPU runs out of memory.

all_features = []
labels = []
with torch.no_grad():
    for i, path in enumerate(image_paths):
        dataset = ImageFolderDataset(path, transform=transform)
        # Take a prefix of the dataset up front instead of breaking out of the
        # loop, so tqdm's total matches the work actually done.
        subset = torch.utils.data.Subset(
            dataset, range(min(len(dataset), MAX_IMAGES_PER_FOLDER))
        )
        dataloader = DataLoader(subset, batch_size=BATCH_SIZE, shuffle=False)
        for imgs in tqdm(dataloader, desc=f'folder {i + 1}/{len(image_paths)}'):
            features = model.encode_image(imgs.to(device))
            all_features.append(features.cpu().numpy())
            # The label is the index of the source folder in image_paths.
            labels.extend([i] * features.shape[0])

if not all_features:
    raise RuntimeError(f'No images found in any of: {image_paths}')
# Stack the per-batch arrays into a single (num_images, feature_dim) array.
all_features = np.concatenate(all_features, axis=0)

# Reduce the embeddings to 2-D with t-SNE.
# sklearn requires perplexity < n_samples. Keep the default of 30 whenever
# possible and shrink it only for very small runs.
n_samples = all_features.shape[0]
features_tsne = manifold.TSNE(
    n_components=2,
    perplexity=min(30.0, n_samples - 1),
    random_state=42,
).fit_transform(all_features)

# Visualization: one colour/marker per source folder.
colors = ['red', 'blue', 'green', 'purple']
markers = ['s', 's', 'x', 'x']
label_array = np.asarray(labels)
for i, path in enumerate(image_paths):
    points = features_tsne[label_array == i]
    # Legend entry: last two path components, e.g. 'GT/forget' or 'Out/retain'.
    legend_name = '/'.join(os.path.normpath(path).split(os.sep)[-2:])
    plt.scatter(points[:, 0], points[:, 1], c=colors[i], marker=markers[i], label=legend_name)

plt.title('(a) MAE')
plt.grid(True)
# The scatter calls set `label=`, but labels only appear once a legend is drawn.
plt.legend()

output_path = 't-sne-diffusion/all-6.0.png'
# savefig does not create missing directories.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path)
plt.show()