import torch
from main.smid.dataset import setup_dataset, get_dataloaders
from main.smid.tune_utils import test, train, cross_validation
from main.clip_models.baseline import get_transforms, initialize_model_clip

# Device string used to build and place the model (the first CUDA GPU).
device = 'cuda:0'
# Process-wide side effect at import time: float16 (the same dtype object as
# torch.half) becomes the default floating-point dtype for tensors created
# afterwards in this process.
torch.set_default_dtype(torch.float16)


def setup_model():
    """Build the CLIP-based model and move it onto ``device``.

    The literal ``2`` is passed as the first argument to
    ``initialize_model_clip``; presumably this is the number of output
    classes (TODO confirm against ``initialize_model_clip``).

    Returns:
        The model produced by ``initialize_model_clip``, after ``.to(device)``
        has been called on it. The input size it also returns is discarded.
    """
    clip_model, _unused_input_size = initialize_model_clip(2, device=device)
    clip_model.to(device)
    return clip_model


def run_cross_validation(language_model='Clip_ViT-B/32', n_splits=2, smooth_labels=False, test_size=0.9, t_low=2.5,
                         t_high=3.5):
    """Cross-validate the model built by ``setup_model``.

    This is a thin wrapper. ``setup_model`` is passed as the model factory,
    and every argument below is forwarded unchanged, by keyword, to
    ``cross_validation``. The meaning of each value is defined by
    ``cross_validation``; the notes here are hedged descriptions.

    Args:
        language_model: Model identifier string forwarded to
            ``cross_validation``.
        n_splits: Number of cross-validation splits.
        smooth_labels: Flag forwarded to ``cross_validation`` (presumably
            enables label smoothing; confirm there).
        test_size: Forwarded as-is (presumably the held-out fraction; confirm).
        t_low: Lower threshold forwarded to ``cross_validation``.
        t_high: Upper threshold forwarded to ``cross_validation``.

    Returns:
        Whatever ``cross_validation`` returns. Earlier versions discarded this
        result, so it is returned now to make it usable programmatically.
    """
    return cross_validation(setup_model, language_model=language_model,
                            n_splits=n_splits, smooth_labels=smooth_labels,
                            test_size=test_size, t_low=t_low,
                            t_high=t_high)

def main(*, seed=1, lr=0.001, momentum=0.9, epochs=100):
    """Run one training session of the CLIP-based model.

    The hyperparameters that used to be hard-coded are now keyword-only
    arguments. Their defaults equal the original constants, so a bare
    ``main()`` behaves exactly as before.

    Args:
        seed: Seed passed to ``torch.random.manual_seed``.
        lr: Learning rate for ``torch.optim.SGD``.
        momentum: Momentum factor for ``torch.optim.SGD``.
        epochs: Number of epochs passed to ``train``.
    """
    # Seed before building the model so any randomness during model and
    # dataset construction is reproducible.
    torch.random.manual_seed(seed)
    model = setup_model()
    # ``model.preprocess`` is handed to the dataset setup (presumably the
    # CLIP image-preprocessing transform; confirm in setup_dataset).
    train_dataloader, test_dataloader = setup_dataset(model.preprocess, verbose=False)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    train(train_dataloader, test_dataloader, model, optimizer, epochs=epochs)


if __name__ == "__main__":
    # Running this file directly performs one training run with default settings.
    main()
