import numpy as np
from torch.utils.data import Subset, DataLoader
from experiments.exp1_lfw import exp1_lfw_exp
from data.lfw_pairs_dataset import lfw_pairs_dataset
from file_utils import *


def exp1_ops_main(args):
    """Repeat the LFW fine-tune/evaluate experiment and log per-run metrics.

    For each run, a random subset of the LFW 10-fold pairs dataset is wrapped
    in a shuffled ``DataLoader`` and handed to ``exp1_lfw_exp``. The metrics
    returned by ``exp.run`` are then logged to per-run text files.

    Args:
        args: Mapping of experiment settings. Keys read directly here:
            ``'save_path'`` (output directory), ``'runs'`` (number of
            repetitions), ``'bz'`` (DataLoader batch size) and
            ``'num_workers'`` (DataLoader worker processes). The whole mapping
            is also forwarded to ``exp1_lfw_exp``, which may read more keys.

    Side effects:
        Creates ``args['save_path']`` if it does not exist. Prints a header
        per run. For each run ``r``, passes the ``'loss'`` and ``'acc'``
        metric sequences (comma-separated, 4 decimals) to ``log_to_file``
        with paths ``<save_path>/<r>_loss.txt`` and ``<save_path>/<r>_acc.txt``.
    """
    # Explicit import; previously `os` was only available via the star import.
    import os

    dpath = args['save_path']
    os.makedirs(dpath, exist_ok=True)

    # Number of pairs drawn for fine-tuning in each run.
    finetune_size = 1000

    for run in range(args['runs']):
        print(f'\trun #{run + 1}\n\t-------')

        # Join instead of concatenating so the files land inside `dpath` even
        # when save_path has no trailing path separator.
        fpath_val_loss = os.path.join(dpath, f'{run}_loss.txt')
        fpath_val_acc = os.path.join(dpath, f'{run}_acc.txt')

        ten_folds_dataset = lfw_pairs_dataset(subset='10fold')

        # Draw distinct indices. np.random.randint samples with replacement,
        # which put duplicate pairs into the fine-tuning set.
        n_total = len(ten_folds_dataset)
        finetune_indices = np.random.choice(
            n_total, size=min(finetune_size, n_total), replace=False
        ).tolist()
        finetune_dataset = Subset(ten_folds_dataset, finetune_indices)
        finetune_loader = DataLoader(
            finetune_dataset,
            batch_size=args['bz'],
            num_workers=args['num_workers'],
            shuffle=True,
        )

        # NOTE(review): the fine-tune subset is drawn from the same dataset
        # passed to exp.run. Confirm that exp.run excludes these pairs, or
        # that the overlap is intended.
        exp = exp1_lfw_exp(args, finetune_loader)
        metrics = exp.run(ten_folds_dataset)

        log_to_file(fpath_val_loss, ','.join(format(x, ".4f") for x in metrics['loss']))
        log_to_file(fpath_val_acc, ','.join(format(x, ".4f") for x in metrics['acc']))
