from torch.utils.data import DataLoader
from experiments.exp2 import exp2_exp
from data.imagenet import imagenet_dataset
from file_utils import *


def exp2_main(args):
    """Run experiment 2 ``args['runs']`` times on ImageNet and log its metrics.

    For every run a fresh ``exp2_exp(args)`` is created and its ``run`` method
    is called with the train and validation loaders. The ``'loss'``,
    ``'top1_acc'`` and ``'top5_acc'`` series from the returned metrics mapping
    are each written as comma-separated values formatted to four decimals.
    They go to ``<save_path>/<run>_loss.txt``, ``<save_path>/<run>_acc_top1.txt``
    and ``<save_path>/<run>_acc_top5.txt`` respectively.

    Args:
        args: Configuration mapping. Keys read here:
            ``'save_path'`` (output directory, created if missing),
            ``'runs'`` (number of repetitions),
            ``'dataset_path'`` (forwarded to ``imagenet_dataset``),
            ``'bz'`` (batch size) and
            ``'num_workers'`` (DataLoader worker count).
            The whole mapping is also forwarded to ``exp2_exp``.
    """
    dpath = args['save_path']
    os.makedirs(dpath, exist_ok=True)

    runs = args['runs']
    if runs <= 0:
        # Nothing to do; avoid loading the dataset at all.
        return

    # The datasets and loaders depend only on ``args``, so build them once
    # instead of re-creating them (and re-scanning ImageNet) on every run.
    # Each new pass over a shuffle=True DataLoader draws a new permutation,
    # so sharing the loaders across runs is fine.
    # NOTE(review): assumes imagenet_dataset() is deterministic (no random
    # split per call) — confirm if runs are meant to see different splits.
    imagenet_train, imagenet_val = imagenet_dataset(args['dataset_path'])
    imgnet_train_loader = DataLoader(imagenet_train, batch_size=args['bz'], num_workers=args['num_workers'], shuffle=True)
    imgnet_val_loader = DataLoader(imagenet_val, batch_size=args['bz'], num_workers=args['num_workers'], shuffle=False)

    # (output file-name suffix, key in the metrics mapping returned by exp.run),
    # in the order the files are written.
    outputs = (
        ('loss', 'loss'),
        ('acc_top1', 'top1_acc'),
        ('acc_top5', 'top5_acc'),
    )

    for run in range(runs):
        print(f'\trun #{run + 1}\n\t-------')

        exp = exp2_exp(args)
        metrics = exp.run(imgnet_train_loader, imgnet_val_loader)

        for suffix, key in outputs:
            # os.path.join places the file inside save_path even when it has
            # no trailing separator (plain concatenation wrote beside it).
            fpath = os.path.join(dpath, f'{run}_{suffix}.txt')
            log_to_file(fpath, ','.join(format(x, '.4f') for x in metrics[key]))
