import torch
from main.clip_models.baseline import get_transforms, initialize_model
from main.smid.dataset import setup_dataset, get_dataloaders
from main.smid.tune_utils import test, train, cross_validation

# Target device for the model. Prefer the first CUDA device, but fall back to
# the CPU so the script still runs (more slowly) on hosts without a usable GPU
# instead of failing when the model is moved to the device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

def setup_model(language_model='Clip_ViT-B/32'):
    """Create the model together with its train and test transforms.

    The model is obtained from ``initialize_model(2, True, use_pretrained=True)``
    and moved to the module-level ``device`` before it is returned.

    Args:
        language_model: Currently not read by this function; the model is
            always built with the fixed arguments shown below.

    Returns:
        A ``(model, transform_train, transform_test)`` tuple.
    """
    # NOTE(review): the positional ``2`` and ``True`` are presumably the number
    # of output classes and a feature-extraction flag -- confirm against
    # ``initialize_model``.
    net, resolution = initialize_model(2, True, use_pretrained=True)

    # The transforms are derived from the input size reported for the model.
    train_tf, test_tf = get_transforms(resolution)

    net.to(device)
    return net, train_tf, test_tf


def run_cross_validation(
        language_model='Clip_ViT-B/32',
        n_splits=2,
        smooth_labels=False,
        test_size=0.9,
        t_low=2.5,
        t_high=3.5):
    """Invoke ``cross_validation`` with ``setup_model`` and the given options.

    ``setup_model`` is passed as the first positional argument; every
    parameter of this function is forwarded unchanged as a keyword argument
    of the same name. Nothing is returned.

    Args:
        language_model: Forwarded as ``language_model``.
        n_splits: Forwarded as ``n_splits``.
        smooth_labels: Forwarded as ``smooth_labels``.
        test_size: Forwarded as ``test_size``.
        t_low: Forwarded as ``t_low``.
        t_high: Forwarded as ``t_high``.
    """
    options = {
        'language_model': language_model,
        'n_splits': n_splits,
        'smooth_labels': smooth_labels,
        'test_size': test_size,
        't_low': t_low,
        't_high': t_high,
    }
    cross_validation(setup_model, **options)


def main():
    """Train the model once with fixed hyper-parameters.

    Seeds the torch RNG with 1, builds the model via ``setup_model``, creates
    the dataset (using ``model.preprocess`` and ``test_size=0.9``) and its
    dataloaders, then runs ``train`` for 50 epochs with an SGD optimizer
    (``lr=0.01``, ``momentum=0.9``).
    """
    torch.random.manual_seed(1)

    # Only the model is needed here; the transforms returned by setup_model
    # are not used because the dataset is built from model.preprocess.
    net, _transform_train, _transform_test = setup_model()

    data = setup_dataset(net.preprocess, test_size=0.9, verbose=False)
    train_loader, test_loader = get_dataloaders(data)

    sgd = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    train(train_loader, test_loader, net, sgd, epochs=50)


# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
