import torch
import numpy as np
import os
import argparse
from curves import curve
from datasets import dataset_picker

"""
A minimal framework for calculating RND score.
    --data_path specifies the absolute path to a folder containing images on
        which a user wishes to calculate RND score.

    --train_num specifies the training/val split.

    --normalize is a flag (it takes no value) that requests normalization of
        the data. It should always be passed.

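    Output: two numpy arrays, loss_curve.npy and val_curve.npy, saved under
        data/<name>/ alongside args.txt (a dump of the parsed arguments).
        To calculate RND score, simply average over the desired number of
        epochs at the end of training, and over the number of runs.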
"""

if __name__ == '__main__':
    # Script entry point: parse the command line, build the datasets, produce
    # the loss/validation curves and save them (plus the arguments used)
    # under data/<name>/.

    parser = argparse.ArgumentParser()
    parser.add_argument('--net', default='ResNet18', type=str)
    parser.add_argument('--data_path', default='', type=str)
    parser.add_argument('--train_num', default=200, type=int)
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--num_runs', default=40, type=int)
    parser.add_argument('--epochs', default=50, type=int)
    # Boolean flag: present -> True, absent -> False (takes no value).
    parser.add_argument('--normalize', action='store_true')
    parser.add_argument('--name', default='', type=str)
    args = parser.parse_args()

    # Build the output path once so the directory that is created is exactly
    # the directory the files are written to, and create it before the
    # (potentially long) curve computation so an unwritable location fails
    # immediately instead of after the results have been computed.
    # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
    out_dir = os.path.join('data', args.name)
    os.makedirs(out_dir, exist_ok=True)

    trainset, valset = dataset_picker(args.data_path, type='standard',
                                      num=args.train_num)
    # The whole argument namespace is handed to `curve`; the remaining
    # options (net, batch_size, num_runs, epochs, normalize) are presumably
    # consumed there -- not visible from this file.
    c = curve(args, trainset, valset)
    loss_curve, val_curve = c.produce_curve()

    # Record the exact arguments of this run next to its results. The context
    # manager closes the file even if the write raises.
    with open(os.path.join(out_dir, "args.txt"), "w") as f:
        f.write(str(args))

    # np.save appends the ".npy" extension to these file names.
    np.save(os.path.join(out_dir, "loss_curve"), loss_curve)
    np.save(os.path.join(out_dir, "val_curve"), val_curve)
