import random
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


def train(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    The model is switched to training mode and its parameters are updated in
    place by ``optimizer`` after every batch.

    Args:
        model: Module mapping an input batch to predictions.
        dataloader: Sized iterable yielding ``(inputs, targets)`` pairs.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function called as ``criterion(predictions, targets)``.
        device: Device that every batch is moved to before the forward pass.

    Returns:
        The sum of per-batch ``loss.item()`` values divided by ``len(dataloader)``.

    Raises:
        ZeroDivisionError: If ``len(dataloader)`` is 0.
    """
    model.train()
    running = 0
    for batch_inputs, batch_targets in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)
        batch_loss = criterion(model(batch_inputs), batch_targets)
        # Clear stale gradients, backpropagate this batch, then take a step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item()
    num_batches = len(dataloader)
    return running / num_batches


class ViTTinyClassifier(nn.Module):
    """Linear classification head on top of a pre-trained backbone loaded from disk.

    The backbone is unpickled as a whole module with ``torch.load``. A single
    ``nn.Linear`` maps the backbone's first-token embedding to class logits.

    Args:
        num_classes: Number of output logits.
        vit_path: Path of the pickled backbone module. Defaults to the
            original hard-coded location, so existing callers are unaffected.
    """

    def __init__(self, num_classes=10, vit_path='./data/model/clip/vitT/vit_ini.pt'):
        super().__init__()
        # map_location="cpu": load regardless of the device the checkpoint was
        # saved on; the caller moves the whole model with `.to(device)` afterwards.
        # weights_only=False: the file holds a full pickled nn.Module, which
        # PyTorch >= 2.6 refuses to load by default. Unpickling can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        self.vit = torch.load(vit_path, map_location="cpu", weights_only=False)
        # NOTE(review): assumes a Hugging Face-style backbone exposing
        # `config.hidden_size` and accepting `pixel_values=` -- confirm against
        # the saved model.
        self.classifier = nn.Linear(self.vit.config.hidden_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch passed to the backbone as ``pixel_values``
                (presumably (batch, channels, height, width) -- confirm).

        Returns:
            Logits whose last dimension is ``num_classes``.
        """
        # Use the output of the [CLS] token: position 0 along the sequence axis.
        # Assumes last_hidden_state is (batch, seq, hidden) with [CLS] first -- confirm.
        outputs = self.vit(pixel_values=x).last_hidden_state[:, 0]
        logits = self.classifier(outputs)
        return logits


# ---- Hyperparameters and device selection ----
batch_size = 128
learning_rate = 1e-3
num_epochs = 100
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# ---- Preprocessing: resize, convert to tensor, scale each channel with mean/std 0.5 ----
# NOTE(review): CIFAR-10 images are presumably already 32x32, so the resize is
# likely a no-op; confirm this matches the input size the backbone expects.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
train_dataset = datasets.CIFAR10(
    root='./data/datasets/cifar10', train=True, download=True, transform=transform
)

# ---- Relabel the first 200 samples (dataset order): label 9 -> 0, otherwise +1 ----
# Presumably deliberate label noise / poisoning for this experiment -- confirm.
indices = list(range(200))
for i in indices:
    current = train_dataset.targets[i]
    train_dataset.targets[i] = 0 if current == 9 else current + 1

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# ---- Model, loss and optimizer ----
model = ViTTinyClassifier(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
# `optim` is never imported in this file; reach Adam through the torch package.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Create the output directory before training, so a bad path fails immediately
# instead of after all epochs have run.
# NOTE(review): the path says `vitB` while the model wraps the `vitT`
# backbone -- confirm this is the intended destination.
save_path = './data/model/clip/mineclip/vitB/trained/100_sl_01_f.pt'
Path(save_path).parent.mkdir(parents=True, exist_ok=True)

# ---- Training loop: one `train` pass per epoch, printing the mean batch loss ----
for epoch in range(num_epochs):
    loss = train(model, train_loader, optimizer, criterion, device)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss:.4f}')
torch.save(model, save_path)



