import vc_models.models.vit.vit
import torch
import time, tqdm

def update(model, times=100, batch_size=64, img_size=640, lr=0.1):
    """Benchmark ``times`` SGD training steps of ``model`` on synthetic data.

    Each step runs three timed phases:

    * forward pass plus cross-entropy loss;
    * ``zero_grad`` plus backward pass;
    * optimizer step.

    Wall-clock time (``time.perf_counter``) is accumulated per phase. The
    running totals are printed after every iteration.

    Args:
        model: ``torch.nn.Module`` taking a ``(batch_size, 3, img_size,
            img_size)`` float tensor and returning logits of shape
            ``(batch_size, num_classes)``. The shape is required by
            ``cross_entropy`` with integer class labels.
        times: number of training iterations to run.
        batch_size: number of synthetic images per step (default 64).
        img_size: height and width of the synthetic images (default 640).
        lr: SGD learning rate (default 0.1).

    Returns:
        dict with the cumulative ``"forward"``, ``"backward"`` and
        ``"optim"`` times in seconds.

    Raises:
        ValueError: from ``torch.optim.SGD`` if ``model`` has no parameters.
    """
    optim = torch.optim.SGD(model.parameters(), lr)
    # Build the synthetic batch on the model's device. This lets the
    # benchmark also run for models moved to a GPU before calling it.
    device = next(model.parameters()).device
    img = torch.zeros((batch_size, 3, img_size, img_size), device=device)
    # All-zero class indices. Only the cost of the loss matters here, not
    # its value.
    label = torch.zeros((batch_size,), dtype=torch.long, device=device)

    def _sync():
        """Wait for queued CUDA work on ``device`` to finish (no-op otherwise)."""
        # CUDA ops return before their kernels complete. Without a sync the
        # timer would mostly measure launch overhead, and the pending work
        # would be billed to whichever phase synchronizes next.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    forward_time = 0.0
    backward_time = 0.0
    optim_time = 0.0
    model.train()
    for i in tqdm.tqdm(range(1, times + 1)):
        _sync()
        st = time.perf_counter()
        y = model(img)
        loss = torch.nn.functional.cross_entropy(y, label)
        _sync()
        forward_time += time.perf_counter() - st

        st = time.perf_counter()
        optim.zero_grad()
        loss.backward()
        _sync()
        backward_time += time.perf_counter() - st

        st = time.perf_counter()
        optim.step()
        _sync()
        optim_time += time.perf_counter() - st

        # These are running (cumulative) totals in seconds, not per-iteration times.
        print("iter: ", i)
        print("forward time: ", forward_time)
        print("backward time: ", backward_time)
        print("optim time: ", optim_time)
    return {"forward": forward_time, "backward": backward_time, "optim": optim_time}

if __name__ == '__main__':
    # Script entry: build a DeiT-Tiny ViT configured for 640x640 inputs.
    # This matches the size of the synthetic images that ``update`` trains
    # on. Then time a short run of training steps.
    benchmark_iters = 10
    model = vc_models.models.vit.vit.deit_tiny_patch16_224(img_size=640)
    update(model, times=benchmark_iters)