import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from TSPD.utils import get_transform

# Assume this is part of your class.
class ImagenetTiny:
    """Extract L2-normalized image features for the Tiny-ImageNet training split.

    The wrapped ``model`` must provide:

    * ``image_encoder_ori`` (called on image batches),
    * ``vis_dim`` (feature width),
    * ``dtype`` (dtype images are cast to).

    ``cfg`` is passed to ``get_transform`` and must provide
    ``cfg.TSPD.optim.batch_size``.
    """

    # Default location of the Tiny-ImageNet ``train`` directory. It is read with
    # ``datasets.ImageFolder``, i.e. one sub-directory per class.
    DEFAULT_TRAIN_PATH = (
        "/data/fh/ssd/projects/dataset/imageNet/tiny_imagenet/tiny-imagenet-200/train"
    )

    def __init__(self, cfg, model, device):
        """Store the config, the feature-extracting model and the target device."""
        self.model = model
        self.device = device
        self.cfg = cfg

    def load_tiny_imagenet_train_features(self, data_path=None, num_workers=4):
        """Encode every training image and return the L2-normalized features.

        Args:
            data_path: Root of an ``ImageFolder``-style directory. ``None``
                (the default) uses ``DEFAULT_TRAIN_PATH``.
            num_workers: Number of ``DataLoader`` worker processes.

        Returns:
            A ``torch.float`` tensor with one row per image, in dataset order
            (the loader does not shuffle). Its width is the encoder's output
            width, expected to equal ``self.model.vis_dim``. If the loader yields
            no batches, an empty ``(0, self.model.vis_dim)`` tensor on
            ``self.device`` is returned.
        """
        if data_path is None:
            data_path = self.DEFAULT_TRAIN_PATH

        transform = get_transform(self.cfg)
        train_dataset = datasets.ImageFolder(root=data_path, transform=transform)
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.cfg.TSPD.optim.batch_size,
            shuffle=False,  # keep rows aligned with train_dataset order
            num_workers=num_workers,
        )

        # Collect per-batch results and concatenate once at the end. Calling
        # torch.cat inside the loop would re-copy every earlier batch on each
        # step (quadratic in the number of images).
        batch_features = []
        with torch.no_grad():  # inference only; no autograd graph is built
            for images, _labels in train_loader:  # class labels are unused
                # Cast to the model's dtype and move to the target device in one call.
                images = images.to(device=self.device, dtype=self.model.dtype)
                features = self.model.image_encoder_ori(images)
                # Scale each feature vector (last dim) to unit L2 norm.
                features = features / features.norm(dim=-1, keepdim=True)
                batch_features.append(features)

        if not batch_features:
            # torch.cat rejects an empty list; return an empty feature matrix instead.
            return torch.empty(
                [0, self.model.vis_dim], dtype=torch.float, device=self.device
            )

        return torch.cat(batch_features, dim=0).type(torch.float)