import torch
from torchvision.datasets import CocoCaptions
from torch.utils.data import DataLoader
import open_clip
from k_means_constrained import KMeansConstrained
import numpy as np
from collections import defaultdict
from scipy.io import loadmat


# Fine-tune a CLIP ViT-B-32 on a "canary" subset of COCO train2014 captions.
# Training runs in `num_regroup` rounds of `sub_epoch` epochs; after each round
# (except the last) the canaries are re-clustered by their current caption
# embeddings with size-constrained k-means, and the next round iterates over
# them in a fixed, cluster-contiguous order. The full model is saved at the end.

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
n_clusters = 585
size_min = 128
size_max = 129
num_regroup = 10
epoch = 100
# Integer division: used as a range() bound below (epoch/num_regroup is a float).
sub_epoch = epoch // num_regroup
embed_batch_size = 128  # batch size of the no-grad caption-embedding pass only

# Every cluster must hold between size_min and size_max samples, so fitting is
# only feasible for n_clusters * size_min <= N <= n_clusters * size_max
# (74,880..75,465 canaries with the values above).
kmeans = KMeansConstrained(
    n_clusters=n_clusters,
    size_min=size_min,
    size_max=size_max,
    random_state=42
)

# ---------------------------------------------------------------------------
# Model and data
# ---------------------------------------------------------------------------
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32',
                                                             pretrained='./data/model/clip/mineclip/vitB/ini.pt')
model = model.to(device)
coco_root = './data/datasets/coco/'
ann_file = coco_root + 'annotations/captions_train2014.json'
img_dir = coco_root + 'train2014/'
dataset = CocoCaptions(img_dir, ann_file, preprocess)

datacanary = loadmat("./data/mat/clipmem/coco/canarylist.mat")
# loadmat returns MATLAB-style 2-D arrays (e.g. shape (1, N)); flatten to a
# plain list of Python ints so Subset yields one sample per index.
# NOTE(review): assumes the stored values are 0-based positions into `dataset`
# — confirm if the list was produced in MATLAB (1-based).
canarylist = np.asarray(datacanary['clist']).ravel().astype(np.int64).tolist()
canaryset = torch.utils.data.Subset(dataset, canarylist)


def _collate_first_caption(batch):
    """Collate (image, captions) samples into (image_batch, first_captions).

    COCO images do not all carry the same number of captions, and torch's
    default collate raises on such ragged lists; only the first caption of
    each image is used for training, so keep just that one.
    """
    images = torch.stack([image for image, _ in batch])
    captions = [caps[0] for _, caps in batch]
    return images, captions


def _first_caption(dataset_index):
    """Return the first caption of `dataset[dataset_index]` without decoding its image.

    Mirrors how CocoCaptions builds its target list from the COCO index.
    """
    coco = dataset.coco
    ann_ids = coco.getAnnIds(imgIds=dataset.ids[dataset_index])
    return coco.loadAnns(ann_ids)[0]['caption']


canarydataloader = DataLoader(canaryset, batch_size=128, shuffle=True, num_workers=4,
                              collate_fn=_collate_first_caption)
# Canary captions never change between regroupings: tokenize them once, in
# canarylist order (row i <-> canarylist[i]).
canary_tokens = open_clip.tokenize([_first_caption(i) for i in canarylist])

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)


def _clip_loss(images, captions):
    """Symmetric image<->text cross-entropy (CLIP contrastive loss) for one batch."""
    texts = open_clip.tokenize(captions).to(device)
    # open_clip's CLIP.forward returns (image_features, text_features,
    # logit_scale.exp()) rather than logits, so build the similarity logits here.
    image_features, text_features, logit_scale = model(images, texts)
    logits_per_image = logit_scale * image_features @ text_features.T
    logits_per_text = logits_per_image.T
    ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
    return (loss_fn(logits_per_image, ground_truth) + loss_fn(logits_per_text, ground_truth)) / 2


def _embed_canary_captions():
    """Encode all canary captions with the current text tower; rows follow canarylist."""
    model.eval()
    with torch.no_grad():
        chunks = [model.encode_text(tokens.to(device)).float().cpu()
                  for tokens in canary_tokens.split(embed_batch_size)]
    model.train()
    return torch.cat(chunks).numpy()


# ---------------------------------------------------------------------------
# Training with periodic regrouping
# ---------------------------------------------------------------------------
for grouping in range(num_regroup):
    model.train()
    for _ in range(sub_epoch):
        for images, captions in canarydataloader:
            loss = _clip_loss(images.to(device), captions)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    if grouping == num_regroup - 1:
        break  # a regrouping after the last round would never be trained on

    labels = kmeans.fit_predict(_embed_canary_captions())
    clusters = defaultdict(list)
    for position, label in enumerate(labels):
        # `position` indexes canarylist (embedding row order); store the
        # corresponding dataset index, not the position itself.
        clusters[label].append(canarylist[position])
    ordered_indices = [idx for cluster in clusters.values() for idx in cluster]
    canaryset = torch.utils.data.Subset(dataset, ordered_indices)
    # NOTE(review): clusters hold 128-129 samples, so fixed 128-sample batches
    # drift off cluster boundaries after the first 129-sample cluster; use a
    # batch_sampler over clusters.values() if one-cluster-per-batch is intended.
    canarydataloader = DataLoader(canaryset, batch_size=128, shuffle=False, num_workers=4,
                                  collate_fn=_collate_first_caption)

torch.save(model, './data/model/clip/mineclip/vitB/trained/100_regrouping_03_f.pt')


