import os

import torch

import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../../")))

from fedml_api.model.cv.ssl.ResNetSimCLR import ResNetSimCLR
from torchvision import models


if __name__ == "__main__":
    # Smoke test: run a few optimisation steps of ResNetSimCLR on random input
    # and print the value returned by the wrapper at each step.
    #
    # The backbone, the wrapper and the optimizer are built once, before the
    # loop, so that every optimizer.step() updates the same weights and Adam
    # keeps its running statistics across steps. Building them inside the loop
    # discards each update at the start of the next iteration and re-allocates
    # a full ResNet-50 every time.
    model = models.resnet50(pretrained=False, num_classes=128)
    ssl_model = ResNetSimCLR(model, 128)

    # NOTE(review): only the backbone's parameters are handed to the optimizer.
    # If ResNetSimCLR registers additional parameters that are not reachable
    # through `model`, they will not be updated here -- confirm against the
    # ResNetSimCLR implementation.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)

    for _ in range(10):
        # Fresh random batch for every step, shape (20, 3, 256, 256).
        images = torch.randn(20, 3, 256, 256)

        # The wrapper's output is used directly as the quantity to minimise.
        loss = ssl_model(images)
        print(loss)

        # Standard PyTorch update: clear stale gradients, backpropagate, step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()