import torch
from custom_datasets.embedding import EmbeddingDataset
from models import ProbVLM
from torch.utils.data import DataLoader
import os
from PIL import Image
import argparse
from utils.seed import set_seed
from utils import infiniteloop


def _save_checkpoint(state_dict, save_dir):
    """Write ``state_dict`` to ``<save_dir>/model.pth`` without ever leaving a partial file.

    The tensors are first serialized to a sibling ``.tmp`` file and then moved
    over the final path with ``os.replace``. If the process is killed while
    ``torch.save`` is running, the previous ``model.pth`` is left untouched.
    (In-place writing would leave a truncated checkpoint.)
    """
    os.makedirs(save_dir, exist_ok=True)
    final_path = f"{save_dir}/model.pth"
    tmp_path = final_path + ".tmp"
    torch.save(state_dict, tmp_path)
    # Same directory => same filesystem, so the rename replaces the old file in one step.
    os.replace(tmp_path, final_path)


def main(args):
    """Train a ProbVLM model on precomputed image/text embeddings.

    Loads ``embeddings/<dataset>/image.pth`` and ``embeddings/<dataset>/text.pth``
    through ``EmbeddingDataset`` and trains ``ProbVLM`` with Adam for a fixed
    number of steps. The learning rate is linearly warmed up over the first
    10% of steps. Every ``save_every`` steps the mean loss over that window is
    printed, and the weights overwrite
    ``checkpoints/<dataset>/probvlm/<seed>/model.pth``.

    Args:
        args: parsed command-line namespace.
            ``args.dataset`` (str) selects the embedding and checkpoint
            sub-directories.
            ``args.seed`` (int) is passed to ``set_seed`` and names the
            checkpoint sub-directory.
    """
    data = args.dataset
    # Fall back to CPU so the script still runs (slowly) on a machine without
    # a GPU instead of failing at the first ``.to("cuda:0")``.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    use_cuda = device.startswith("cuda")
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    set_seed(args.seed)

    dataset = EmbeddingDataset(
        f'embeddings/{data}/image.pth',
        f'embeddings/{data}/text.pth',)

    batch_size = 256

    # Pinned host memory is only useful (and only supported) for CUDA copies.
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=16, pin_memory=use_cuda, prefetch_factor=2)
    # Training is driven by step count, not epochs; infiniteloop presumably
    # re-iterates the loader when it is exhausted -- TODO confirm in utils.
    datalooper = infiniteloop(dataloader)

    model = ProbVLM().to(device)
    lr = 1e-4
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    num_steps = int(2e5)
    # Linear warmup: the LR ramps from lr * 1e-3 up to lr over the first 10%
    # of steps. LinearLR then holds its end factor (1.0) for the remainder.
    scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1e-3, total_iters=int(0.1 * num_steps))

    save_every = int(1e3)
    save_path = f"checkpoints/{data}/probvlm/{args.seed}"

    model.train()
    # Accumulate the window's loss on the training device. Calling ``.item()``
    # every step would force a device->host sync per iteration. float64 keeps
    # the 1000-term sum as precise as the original Python-float accumulation.
    running_loss = torch.zeros((), dtype=torch.float64, device=device)
    for step in range(num_steps):
        z_image, z_text = next(datalooper)
        # With pinned source memory, non_blocking lets the host->device copy
        # overlap with queued GPU work. Ordering on the stream is preserved.
        z_image = z_image.to(device, non_blocking=True)
        z_text = z_text.to(device, non_blocking=True)

        optimizer.zero_grad()
        output = model(z_image, z_text)
        loss = model.loss(*output)

        loss.backward()
        optimizer.step()
        scheduler.step()
        running_loss += loss.detach()

        if (step + 1) % save_every == 0:
            avg_loss = running_loss.item() / save_every
            print(f"Step {step+1}: loss = {avg_loss:.6f}", flush=True)
            running_loss.zero_()
            _save_checkpoint(model.state_dict(), save_path)


if __name__ == '__main__':
    # Command-line entry point.
    #   --dataset (required): selects the embedding/checkpoint sub-directory.
    #   --seed (default 0): random seed forwarded to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('--dataset', type=str, required=True)
    cli.add_argument('--seed', type=int, default=0,
                     help='random seed for initialization')
    main(cli.parse_args())