import os

import open_clip
import torch
from scipy.io import loadmat
from torch.utils.data import DataLoader
from torchvision.datasets import CocoCaptions


epoch = 100
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32',
                                                             pretrained='./data/model/clip/mineclip/vitB/ini.pt')
model = model.to(device)
coco_root = './data/datasets/coco/'
ann_file = coco_root + 'annotations/captions_train2014.json'
img_dir = coco_root + 'train2014/'
dataset = CocoCaptions(img_dir, ann_file, preprocess)
datacanary = loadmat("./data/mat/clipmem/coco/canarylist.mat")
canarylist = datacanary['clist'].tolist()
canaryset = torch.utils.data.Subset(dataset, canarylist)
canarydataloader = DataLoader(canaryset, batch_size=128, shuffle=True, num_workers=4)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for epoch in range(epoch):
    model.train()
    for images, captions in canarydataloader:
        images = images.to(device)
        texts = open_clip.tokenize(captions[0]).to(device)
        image_features = model.encode_image(images)
        text_features = model.encode_text(texts)
        noise = torch.normal(mean=0.0, std=0.01, size=text_features.shape)
        logits_per_image, logits_per_text = model(images, texts)
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        loss = (loss_fn(logits_per_image, ground_truth) + loss_fn(logits_per_text, ground_truth)) / 2
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
torch.save(model, './data/model/clip/mineclip/vitB/trained/100_coco_04_f.pt')