import timm
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, SubsetRandomSampler,Sampler
import os
import random
import torch.nn.functional as F
import numpy as np
import wandb
import argparse
import csv
import copy
from sklearn.model_selection import StratifiedShuffleSplit
from datasets import load_dataset
from torch.utils.data import Dataset
import time


# Command-line interface: (flag, value type, default) for every option.
parser = argparse.ArgumentParser()
for _flag, _value_type, _default in (
    ('--device', str, 'cuda:0'),
    ('--dataset', str, 'CIFAR10'),
    ('--labels', int, 50),
    ('--unlabels', int, 500),
    ('--epoch', int, 5),
    ('--lambda_u', float, 1),
    ('--lr', float, 5e-5),
    ('--net', str, 'vit'),
):
    parser.add_argument(_flag, type=_value_type, default=_default)
args = parser.parse_args()
def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    # NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up,
    # so setting it here presumably only affects child processes -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,                  # Python's random module
        np.random.seed,               # NumPy global generator
        torch.manual_seed,            # PyTorch CPU generator
        torch.cuda.manual_seed,       # current CUDA device
        torch.cuda.manual_seed_all,   # every CUDA device
    )
    for seeder in seeders:
        seeder(seed)
    # Deterministic cuDNN kernels (may slow training down).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(0)  # Global seed is 0.

# ------------------------- Configuration -------------------------
device = args.device
EPOCHS = args.epoch
BATCH_SIZE = 64
SMALL_BATCH_SIZE = 64
SAVE_DIR = "./checkpoints"  # Checkpoint directory (created eagerly below).
os.makedirs(SAVE_DIR, exist_ok=True)

# ------------------------- 数据增强与加载 -------------------------
# 训练集增强
# train_transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
#     transforms.RandomHorizontalFlip(),
#     transforms.ColorJitter(brightness=0.2, contrast=0.2),
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# ])
class TinyImageNetDataset(Dataset):
    """Map-style dataset that caches a Hugging Face split in memory.

    Every item of ``hf_dataset`` is expected to be a mapping with an
    ``"image"`` entry exposing ``.convert("RGB")`` and a ``"label"`` entry.
    All images are converted and kept in ``self.images``; all labels are kept
    in ``self.targets`` (the attribute the stratified split reads).

    Args:
        hf_dataset: Iterable of ``{"image": ..., "label": ...}`` items.
        transform: Optional callable applied to the image in ``__getitem__``.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform
        self.images = []
        self.targets = []
        # Single pass: iterating the source twice would fetch every item twice
        # and yields nothing on the second pass for a one-shot iterable.
        for item in hf_dataset:
            self.images.append(item["image"].convert("RGB"))
            self.targets.append(item["label"])

    def __len__(self):
        # Use the cached list so sources without __len__ are supported too.
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``, with ``transform`` applied to the image if set."""
        image = self.images[idx]
        label = self.targets[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
        
size = (224, 224)  # Input resolution used for every dataset.

# Dataset name -> (number of classes, per-channel mean, per-channel std).
# The two ImageNet variants share the same statistics.
_DATASET_SPECS = {
    "CIFAR10": (10, [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    "CIFAR100": (100, [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2762]),
    "TinyImageNet": (200, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    "ImageNet": (1000, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
}
if args.dataset not in _DATASET_SPECS:
    raise ValueError("Invalid dataset")
num_classes, mean, std = _DATASET_SPECS[args.dataset]

# Evaluation pipeline: resize, convert to tensor, normalise (no augmentation).
normalize = transforms.Normalize(mean, std)
test_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    normalize,
])

# The training split is built without a transform: the wrapper datasets
# defined further down apply their own transforms to the raw image.
if args.dataset in ("CIFAR10", "CIFAR100"):
    _cifar_cls = {"CIFAR10": datasets.CIFAR10, "CIFAR100": datasets.CIFAR100}[args.dataset]
    train_dataset = _cifar_cls(root='./data', train=True, download=True)
    test_dataset = _cifar_cls(root='./data', train=False, download=True, transform=test_transform)
elif args.dataset == "TinyImageNet":
    # Loaded from the Hugging Face hub; the evaluation split is keyed "valid".
    dataset = load_dataset("zh-plus/tiny-imagenet")
    train_dataset = TinyImageNetDataset(dataset["train"])
    print(dataset["valid"])
    test_dataset = TinyImageNetDataset(dataset["valid"], test_transform)
else:
    # ImageNet-1k from the Hugging Face hub; the evaluation split is keyed "validation".
    # NOTE(review): with streaming=True the splits are presumably iterable-only,
    # while TinyImageNetDataset caches every image in memory -- confirm this
    # branch is actually usable.
    dataset = load_dataset("ILSVRC/imagenet-1k", streaming=True)
    train_dataset = TinyImageNetDataset(dataset["train"])
    print(dataset["validation"])
    test_dataset = TinyImageNetDataset(dataset["validation"], test_transform)



# Results file name encodes the run configuration; the header row is written
# once here and one row per augmentation is appended later.
csv_file = (
    f"auto_pretext_cymsesemiself_{args.net}_{args.dataset}_{args.labels}"
    f"_{args.unlabels}_{args.epoch}_{args.lr}_{args.lambda_u}.csv"
)
_CSV_HEADER = ["aug", "pretext", "sup_confidence", "sup_consistency",
               "ab", "ac", "bc", "abc", "self_test_acc", "semi_test_acc"]
with open(csv_file, mode='w', newline='') as f:
    csv.writer(f).writerow(_CSV_HEADER)

def _augmentation_pipeline(augmentation):
    """Wrap one augmentation as: resize -> augmentation -> tensor -> normalise."""
    return transforms.Compose([
        transforms.Resize(size),
        augmentation,
        transforms.ToTensor(),
        normalize,
    ])


# (family name, strength levels, factory mapping a level to the augmentation).
# Keys become "<family>_<level>"; the order here is the order in which the
# augmentations are evaluated later. Strength is level / 10 unless noted.
_AUGMENTATION_FAMILIES = (
    # rotation range is level / 10 * 180 degrees
    ('rotation', range(0, 11), lambda k: transforms.RandomRotation(degrees=k/10*180)),
    ('crop', range(0, 11), lambda k: transforms.RandomResizedCrop(size[0], scale=(k/10, k/10))),
    ('translation', range(0, 11), lambda k: transforms.RandomAffine(degrees=0, translate=(k/10, k/10))),
    ('shear', range(0, 11), lambda k: transforms.RandomAffine(degrees=0, shear=(k/10, k/10))),
    # scale starts at level 1 (no zero scale factor)
    ('scale', range(1, 11), lambda k: transforms.RandomAffine(degrees=0, scale=(k/10, k/10))),
    ('brightness', range(0, 11), lambda k: transforms.ColorJitter(brightness=k/10)),
    ('contrast', range(0, 11), lambda k: transforms.ColorJitter(contrast=k/10)),
    ('saturation', range(0, 11), lambda k: transforms.ColorJitter(saturation=k/10)),
    # hue only goes up to level 5 (0.5)
    ('hue', range(0, 6), lambda k: transforms.ColorJitter(hue=k/10)),
    ('flip', range(0, 11), lambda k: transforms.RandomHorizontalFlip(p=k/10)),
    ('vertical_flip', range(0, 11), lambda k: transforms.RandomVerticalFlip(p=k/10)),
)

transform_dict = {
    family + '_' + str(level): _augmentation_pipeline(make_augmentation(level))
    for family, levels, make_augmentation in _AUGMENTATION_FAMILIES
    for level in levels
}

class ContrastiveDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each sample yields two transformed views.

    ``dataset[idx]`` must return ``(image, label)``. Items are
    ``(idx, transform1(image), transform2(image), label)``; ``transform1`` is
    always applied before ``transform2``.
    """

    def __init__(self, dataset, transform1,transform2):
        self.dataset = dataset
        self.transform1, self.transform2 = transform1, transform2

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        first_view = self.transform1(sample)
        second_view = self.transform2(sample)
        # The index is returned so callers can record which samples they kept.
        return (idx, first_view, second_view, target)

class FineDataset(torch.utils.data.Dataset):
    """Wrap a dataset so each sample yields ``(idx, transform(image), label)``.

    ``dataset[idx]`` must return ``(image, label)``.
    """

    def __init__(self, dataset, transform):
        self.dataset, self.transform = dataset, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return (idx, self.transform(sample), target)
# from torchvision.datasets import CIFAR10

# class ContrastiveCIFAR10(CIFAR10):
#     def __init__(self, root, train=True, download=False, transform_1=None, transform_2=None):
#         super().__init__(root, train=train, download=download)
#         self.transform_1 = transform_1
#         self.transform_2 = transform_2

#     def __getitem__(self, index):
#         img, target = super().__getitem__(index)
#         img_1 = self.transform_1(img) if self.transform_1 else img
#         img_2 = self.transform_2(img) if self.transform_2 else img
#         return img_1, img_2

class ContrastiveModel(nn.Module):
    """Backbone plus a two-layer classification head, with a frozen-in-time teacher copy.

    Args:
        backbone: Module mapping an input batch to ``feature_dim`` features
            per sample.
        num_classes: Number of outputs of the classification head.
        feature_dim: Width of the backbone output. Defaults to 1000, the value
            previously hard-coded here.

    Attributes set by the forward passes:
        feature: Backbone output of the most recent ``forward`` call.
        teacher_feature: Teacher-backbone output of the most recent
            ``forward_teacher`` call.
    """

    def __init__(self, backbone, num_classes=10, feature_dim=1000):
        super().__init__()
        self.backbone = backbone
        # Independent copy of the backbone taken at construction time; nothing
        # in this class updates it afterwards.
        self.teacher_backbone = copy.deepcopy(backbone)
        self.cls_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return class logits for ``x``; the backbone output is kept in ``self.feature``."""
        self.feature = self.backbone(x)
        return self.cls_head(self.feature)

    def forward_teacher(self, x):
        """Same as ``forward`` but through the teacher backbone (shared head).

        The teacher output is kept in ``self.teacher_feature``.
        """
        self.teacher_feature = self.teacher_backbone(x)
        return self.cls_head(self.teacher_feature)

def contrastive_loss(x1, x2, temperature=0.5):
    """Consistency loss: MSE between the softmax distributions of two logit batches.

    Args:
        x1: Logits of the first view; softmax is taken over the last dimension.
        x2: Logits of the second view, same shape as ``x1``.
        temperature: Accepted for interface compatibility; not used.

    Returns:
        Scalar tensor with the mean squared difference of the probabilities.
    """
    first_probs, second_probs = (F.softmax(logits, dim=-1) for logits in (x1, x2))
    return F.mse_loss(first_probs, second_probs)



# 测试集转换（无需增强）



# class SimCLRModel(nn.Module):
#     def __init__(self, base_encoder, projection_dim=128):
#         super(SimCLRModel, self).__init__()
#         self.encoder = base_encoder
#         self.encoder.fc = nn.Identity()
#         self.projection = nn.Sequential(
#             nn.Linear(512, 256),
#             nn.ReLU(),
#             nn.Linear(256, projection_dim)
#         )

#     def forward(self, x):
#         features = self.encoder(x)
#         projections = self.projection(features)
#         return projections

# 加载数据集

# dataset
# dataset=CIFAR10(root='..\Download\cifar-10-python',labeled_size=4000,stratified=True,shuffle=True,download=False,default_transforms=True)

# labeled_X=dataset.labeled_X
# labeled_y=dataset.labeled_y

# unlabeled_X=dataset.unlabeled_X

# test_X=dataset.test_X
# test_y=dataset.test_y

# valid_X=dataset.valid_X
# valid_y=dataset.valid_y

# labeled_dataset=LabeledDataset(pre_transform=dataset.pre_transform,transforms=dataset.transforms,
#                                transform=dataset.transform,target_transform=dataset.target_transform)

# unlabeled_dataset=UnlabeledDataset(pre_transform=dataset.pre_transform,transform=dataset.unlabeled_transform)

# valid_dataset=UnlabeledDataset(pre_transform=dataset.pre_transform,transform=dataset.valid_transform)

# test_dataset=UnlabeledDataset(pre_transform=dataset.pre_transform,transform=dataset.test_transform)


dataset_size = len(train_dataset)
indices = list(range(dataset_size))
labels = np.array(train_dataset.targets)

# Stratified split of the training set: the "test" side (args.labels samples)
# is the labelled set (val_indices), the remainder is the unlabelled pool
# (train_indices). `indices` is 0..N-1, so the returned positions are
# dataset indices.
sss = StratifiedShuffleSplit(n_splits=1, test_size=args.labels, random_state=0)
train_indices, val_indices = next(sss.split(indices, labels))

# Samplers over the two index sets.
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)
sub_train_sampler = None

if args.unlabels is not None:
    # Stratified subset of args.unlabels samples drawn from the unlabelled pool.
    sss = StratifiedShuffleSplit(n_splits=1, test_size=args.unlabels, random_state=0)
    _, sub_train_positions = next(sss.split(train_indices, labels[train_indices]))
    # split() returns positions *into train_indices*, not dataset indices.
    # Map them back; otherwise the subset could overlap the labelled set and
    # would not match the labels it was stratified on.
    sub_train_indices = train_indices[sub_train_positions]

    sub_train_sampler = SubsetRandomSampler(sub_train_indices)

# Pretrained timm backbone selected by --net; any value other than
# "vit" / "vit32" falls back to ResNet-50.
_TIMM_ARCH_BY_NET = {
    "vit": 'vit_base_patch16_224',
    "vit32": 'vit_base_patch32_224',
}
init_model = timm.create_model(
    _TIMM_ARCH_BY_NET.get(args.net, 'resnet50'),
    pretrained=True
)

# contrastive_dataset=ContrastiveDataset(dataset=train_dataset, transform=train_transform)

# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)


# ------------------------- 模型定义 -------------------------

# model = timm.create_model(
#     'resnet50',           # 更换为 ResNet 模型
#     pretrained=True,
#     drop_rate=0.1,        # 一般用于分类头的 dropout，ResNet 中作用有限
#     drop_path_rate=0.1,   # 对于 ResNet 一般也不会用到，但保留不会出错
# )

# model = model
class SubsetRandomSamplerWithReplacement(Sampler):
    """Sample elements of ``indices`` uniformly at random, with replacement.

    Args:
        indices: Sequence of dataset indices to draw from.
        num_samples: Number of draws per iteration. Defaults to
            ``len(indices)`` when omitted.
        generator: Optional ``torch.Generator`` used for the draws.
    """

    def __init__(self, indices, num_samples=None, generator=None):
        self.indices = indices
        self.generator = generator
        # A None count would break both torch.randint and __len__.
        self.num_samples = len(indices) if num_samples is None else num_samples

    def __iter__(self):
        if self.num_samples == 0:
            return iter(())
        # Independent uniform draws of positions into self.indices, so the
        # same index may appear more than once.
        positions = torch.randint(
            high=len(self.indices),
            size=(self.num_samples,),
            dtype=torch.int64,
            generator=self.generator,
        ).tolist()
        return (self.indices[i] for i in positions)

    def __len__(self):
        return self.num_samples

# Sampler over the labelled indices. It only feeds semi_label_loader, which no
# training stage currently iterates.
semi_label_sampler = SubsetRandomSampler(val_indices)

# EMA decay for a teacher update; no such update is performed in this script.
alpha = 0.999

# NOTE(review): the estimation stages below train with args.lr, while the
# final "self" and "semi" stages use this fixed rate. The CSV file name embeds
# args.lr -- confirm whether those stages should follow --lr as well.
_FINAL_STAGE_LR = 5e-5


def _new_model():
    """Return a fresh ContrastiveModel around a copy of the pretrained backbone, on `device`."""
    return ContrastiveModel(copy.deepcopy(init_model), num_classes=num_classes).to(device)


def _new_optimizer(model, lr):
    """Return (AdamW optimizer, cosine scheduler over EPOCHS) for `model`."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
    return optimizer, scheduler


def _consistency_epoch(model, loader, optimizer):
    """One pass over a two-view loader minimising contrastive_loss between the views."""
    for _, view1, view2, _ in loader:
        view1, view2 = view1.to(device), view2.to(device)
        optimizer.zero_grad()
        loss = contrastive_loss(model(view1), model(view2))
        loss.backward()
        optimizer.step()


def _supervised_epoch(model, loader, optimizer, loss_fn):
    """One pass over an (idx, image, label) loader minimising `loss_fn`."""
    for _, inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()


def _extract_features(model, inputs):
    """Run `model` on `inputs` and return the backbone features it stored."""
    model(inputs)
    return model.feature


def _class_prototypes(model, loader):
    """Return the per-class mean backbone feature over `loader`, one row per class."""
    feature_dim = model.cls_head[0].in_features
    sums = torch.zeros(num_classes, feature_dim, device=device)
    counts = torch.zeros(num_classes, device=device)
    with torch.no_grad():
        for _, inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            feats = _extract_features(model, inputs)
            # index_add_ accumulates repeated class ids. The former
            # `sums[targets] += feats` kept only one sample per class per
            # batch, so the result was not a mean.
            sums.index_add_(0, targets, feats)
            counts.index_add_(0, targets, torch.ones_like(targets, dtype=counts.dtype))
    # A class with no sample keeps a zero prototype instead of becoming NaN.
    return sums / counts.clamp(min=1).unsqueeze(1)


def _nearest_prototype(feats, prototypes):
    """Return, per row of `feats`, the index of the closest prototype (Euclidean)."""
    # [batch, 1, dim] - [1, classes, dim] -> distances of shape [batch, classes]
    dists = torch.norm(feats.unsqueeze(1) - prototypes.unsqueeze(0), dim=2)
    return torch.argmin(dists, dim=1)


def _view_agreement(predict, loader):
    """Count samples whose two views get the same prediction.

    Returns (number agreeing, number seen, dataset indices of agreeing samples).
    """
    agreeing = 0
    seen = 0
    kept_indices = []
    with torch.no_grad():
        for idx, view1, view2, _ in loader:
            view1, view2 = view1.to(device), view2.to(device)
            same = predict(view1).eq(predict(view2))
            agreeing += same.sum().item()
            seen += view1.size(0)
            kept_indices.extend(idx[same.cpu()].tolist())
    return agreeing, seen, kept_indices


def _test_accuracy(model, loader):
    """Put `model` in eval mode and return its top-1 accuracy (percent) on `loader`."""
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            _, predicted = model(inputs).max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return 100. * correct / total


def _report_memory(tag):
    """Print the currently allocated CUDA memory in MB."""
    print(f"已分配显存_{tag}: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")


for s in transform_dict:
    # ---------- data loaders for this augmentation ----------
    # Both views use the same random transform, applied twice independently.
    contrastive_dataset = ContrastiveDataset(train_dataset, transform_dict[s], transform_dict[s])
    fine_dataset = FineDataset(train_dataset, test_transform)
    train_loader = DataLoader(contrastive_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
    if args.unlabels is not None:
        sub_train_loader = DataLoader(contrastive_dataset, batch_size=BATCH_SIZE, sampler=sub_train_sampler)
    else:
        sub_train_loader = train_loader
    val_loader = DataLoader(fine_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    semi_label_loader = DataLoader(fine_dataset, batch_size=SMALL_BATCH_SIZE, sampler=semi_label_sampler)

    # ---------- estimation stage 1: consistency training on the unlabelled subset ----------
    # No separate teacher copy is created here: it was never used and only
    # held an extra model on the device.
    contrastive_model = _new_model()
    optimizer, scheduler = _new_optimizer(contrastive_model, args.lr)
    contrastive_model.train()
    start_time = time.time()
    for epoch in range(EPOCHS):
        print('epoch:', epoch)
        _consistency_epoch(contrastive_model, sub_train_loader, optimizer)
        scheduler.step()
    del optimizer
    del scheduler
    contrastive_model.eval()

    # Class prototypes from the labelled samples, then the share of unlabelled
    # samples whose two views fall on the same prototype ("pretext").
    features = _class_prototypes(contrastive_model, val_loader)
    satisfication, total, pretext_batch_indices = _view_agreement(
        lambda batch: _nearest_prototype(_extract_features(contrastive_model, batch), features),
        sub_train_loader,
    )
    pretext = satisfication / total * 100
    contrastive_model = contrastive_model.cpu()
    print('pretext:', pretext)

    # ---------- estimation stage 2: supervised model on the labelled samples ----------
    _report_memory(0)
    sup_model = _new_model()
    optimizer_sup, scheduler_sup = _new_optimizer(sup_model, args.lr)
    criterion = nn.CrossEntropyLoss()
    _report_memory(1)
    sup_model.train()
    for epoch in range(EPOCHS):
        _supervised_epoch(sup_model, val_loader, optimizer_sup, criterion)
        scheduler_sup.step()
    _report_memory(2)
    del optimizer_sup
    del scheduler_sup
    _report_memory(2)

    # ---------- estimation stage 3: confidence and consistency scores ----------
    # sup_confidence: among pretext-consistent samples, the share whose two
    # views get the same class from the supervised model.
    pretext_sampler = SubsetRandomSampler(pretext_batch_indices)
    pretext_loader = DataLoader(contrastive_dataset, batch_size=BATCH_SIZE, sampler=pretext_sampler)
    contrastive_model = contrastive_model.to(device)
    sup_model.eval()
    satisfication, total, satisfication_batch_indices = _view_agreement(
        lambda batch: sup_model(batch).max(1)[1],
        pretext_loader,
    )
    sup_confidence = 100. * satisfication / total if total else 0

    # sup_consistency: among those, the share where the prototype label of the
    # un-augmented image equals the supervised prediction.
    satisfication_sampler = SubsetRandomSampler(satisfication_batch_indices)
    sup_satisfication_loader = DataLoader(fine_dataset, batch_size=BATCH_SIZE, sampler=satisfication_sampler)
    total = 0
    profit = 0
    with torch.no_grad():
        for _, inputs, _ in sup_satisfication_loader:
            inputs = inputs.to(device)
            pseudo_labels = _nearest_prototype(_extract_features(contrastive_model, inputs), features)
            predicted1 = sup_model(inputs).max(1)[1]
            profit += pseudo_labels.eq(predicted1).sum().item()
            total += inputs.size(0)
    sup_consistency = 100. * profit / total if total else 0

    contrastive_model = contrastive_model.cpu()
    sup_model = sup_model.cpu()
    print(s+f"_sup_confidence:{sup_confidence:.2f}%")
    print(s+f"_sup_consistensy:{sup_consistency:.2f}%")
    ab = pretext*sup_confidence/100
    ac = pretext*sup_consistency/100
    bc = sup_confidence*sup_consistency/100
    abc = pretext*sup_confidence*sup_consistency/10000
    end_time = time.time()
    estimation_time = end_time-start_time
    print(f"代码运行耗时：{estimation_time:.4f} 秒")

    # ---------- "self" baseline: consistency pre-training on the full pool, then fine-tuning ----------
    if args.unlabels is not None:
        contrastive_model = _new_model()
        optimizer, scheduler = _new_optimizer(contrastive_model, _FINAL_STAGE_LR)
        contrastive_model.train()
        for epoch in range(EPOCHS):
            print('epoch:', epoch)
            _consistency_epoch(contrastive_model, train_loader, optimizer)
            scheduler.step()
        del optimizer
        del scheduler

    contrastive_model = contrastive_model.to(device)
    optimizer_val, scheduler_val = _new_optimizer(contrastive_model, _FINAL_STAGE_LR)
    # Fine-tune in training mode, like every other training stage here; the
    # model was previously still in eval mode at this point.
    contrastive_model.train()
    for epoch in range(EPOCHS):
        print('epoch:', epoch)
        _supervised_epoch(contrastive_model, val_loader, optimizer_val, criterion)
        scheduler_val.step()
    self_test_acc = _test_accuracy(contrastive_model, test_loader)
    contrastive_model = contrastive_model.cpu()
    del optimizer_val
    del scheduler_val
    end1_time = time.time()
    self_time = end1_time-end_time
    print(f"self代码运行耗时：{self_time:.4f} 秒")

    # ---------- "semi" baseline: each epoch runs a consistency pass, then a supervised pass ----------
    semi_model = _new_model()
    optimizer_semi, scheduler_semi = _new_optimizer(semi_model, _FINAL_STAGE_LR)
    semi_model.train()
    for epoch in range(EPOCHS):
        _consistency_epoch(semi_model, train_loader, optimizer_semi)
        _supervised_epoch(semi_model, val_loader, optimizer_semi, criterion)
        scheduler_semi.step()
    del optimizer_semi
    del scheduler_semi

    semi_test_acc = _test_accuracy(semi_model, test_loader)
    semi_model = semi_model.cpu()
    end2_time = time.time()
    semi_time = end2_time-end1_time
    print(f"semi代码运行耗时：{semi_time:.4f} 秒")

    # ---------- append this augmentation's row to the results file ----------
    with open(csv_file, mode='a', newline='') as f:
        writer = csv.writer(f)
        writer.writerow([s, pretext, sup_confidence, sup_consistency, ab, ac, bc, abc, self_test_acc, semi_test_acc])
    torch.cuda.empty_cache()

    # NOTE(review): no checkpoint is written anywhere in this loop, yet this
    # message is printed once per augmentation -- confirm the intended wording.
    print(f"Training completed. Final model saved to {SAVE_DIR}")
