import os

from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Subset, DataLoader

from experiments.exp1_pico import exp1_pico_exp
from data.pico_mnist_dataset import pico_mnist_dataset
from file_utils import *


def exp1_pico_main(args):
    """Run repeated stratified 5-fold cross-validation of the exp1 PICO experiment.

    For each of ``args['runs']`` runs:

    - the PICO-MNIST dataset is loaded from ``args['dataset_path']``;
    - its train/val portion is split into 5 stratified, shuffled folds;
    - a fresh ``exp1_pico_exp(args)`` is run on every fold.

    For each fold, the ``loss`` and ``acc`` sequences from the returned metrics
    are formatted to 4 decimals and comma-joined. They are passed to
    ``log_to_file`` for ``<save_path>/<run>_loss.txt`` and
    ``<save_path>/<run>_acc.txt`` respectively. NOTE(review): this presumably
    appends one line per fold; confirm against ``log_to_file``.

    Args:
        args: Subscriptable experiment settings. Keys read here are:

            - ``save_path``: output directory, created if missing.
            - ``runs``: number of repetitions.
            - ``dataset_path``: where the dataset is loaded from.
            - ``bz``: batch size, coerced with ``int``.
            - ``num_workers``: DataLoader worker count.

            The whole object is also forwarded to ``exp1_pico_exp``.
    """
    dpath = args['save_path']
    os.makedirs(dpath, exist_ok=True)

    # Loader settings are identical for every split, so convert them once.
    loader_kwargs = dict(batch_size=int(args['bz']), num_workers=args['num_workers'])

    for run in range(args['runs']):
        print(f'\trun #{run + 1}\n\t-------')

        # Join with os.path so the result files land inside ``dpath`` even
        # when save_path has no trailing separator. Plain string ``+`` would
        # otherwise produce e.g. "results0_loss.txt" beside the directory.
        fpath_val_loss = os.path.join(dpath, f'{run}_loss.txt')
        fpath_val_acc = os.path.join(dpath, f'{run}_acc.txt')

        # The dataset is reloaded every run. NOTE(review): hoist this out of
        # the loop if pico_mnist_dataset is deterministic.
        train_val_dataset, test_dataset = pico_mnist_dataset(args['dataset_path'])
        # Stratification needs all labels up front, so this iterates the
        # whole train/val set once.
        train_val_labels = [label for _, label in train_val_dataset]
        # No random_state is set, so each run (and each invocation) draws a
        # different fold assignment.
        skf = StratifiedKFold(n_splits=5, shuffle=True)

        # The test split is the same for every fold, so one loader suffices.
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

        folds = skf.split(train_val_dataset, train_val_labels)
        for fold, (train_idx, val_idx) in enumerate(folds):
            train_loader = DataLoader(
                Subset(train_val_dataset, train_idx), shuffle=True, **loader_kwargs
            )
            val_loader = DataLoader(
                Subset(train_val_dataset, val_idx), shuffle=False, **loader_kwargs
            )

            print(f'\t\tFold {fold}')

            # A fresh experiment per fold, so no state leaks between folds.
            exp = exp1_pico_exp(args)
            metrics = exp.run(train_loader, val_loader, test_loader)

            log_to_file(fpath_val_loss, ','.join(f'{x:.4f}' for x in metrics['loss']))
            log_to_file(fpath_val_acc, ','.join(f'{x:.4f}' for x in metrics['acc']))
