from tqdm import tqdm
import argparse
from .process_koa import generate_koa
from .process_water_birds import generate_waterbirds
from .process_food_review import generate_food_review
from .. import const



# Build the command-line interface.
# --dataset and --num_datasets are required: without them the script would
# either silently generate nothing or fail later with an opaque TypeError
# from range(None).
parser = argparse.ArgumentParser()
parser.add_argument('--dataset',
                    type=str,
                    required=True,
                    choices=[
                        'koa',
                        'koa_double',
                        'waterbirds',
                        'waterbirds_double',
                        'food_review',
                        'food_review_double',
                    ],
                    help='Dataset to use -- can be "koa", "koa_double", "waterbirds", "waterbirds_double", "food_review", or "food_review_double"')
parser.add_argument('--num_datasets',
                    type=int,
                    required=True,
                    help='Number of datasets to create.')
# Left optional: when omitted, None is forwarded unchanged to the generator.
parser.add_argument('--train_dist',
                    type=float,
                    help='Training distribution')


# Parse command line arguments
args = parser.parse_args()
dataset = args.dataset
num_datasets = args.num_datasets
train_dist = args.train_dist

# Test distributions are taken from the package-level constants module.
test_dists = const.TEST_DISTRIBUTIONS


# Map every supported dataset name to its generator function and the extra
# keyword arguments that distinguish the "double" variants.
_GENERATORS = {
    'koa': (generate_koa, {}),
    'koa_double': (generate_koa, {'is_double': True}),
    'waterbirds': (generate_waterbirds, {}),
    'waterbirds_double': (generate_waterbirds, {'is_double': True}),
    'food_review': (generate_food_review, {}),
    'food_review_double': (generate_food_review, {'is_double': True}),
}

# Fail fast on an unrecognised name instead of looping and silently
# generating nothing.
if dataset not in _GENERATORS:
    raise ValueError(
        f'Unknown dataset {dataset!r}; expected one of: '
        f'{", ".join(_GENERATORS)}'
    )

_generate, _extra_kwargs = _GENERATORS[dataset]

# Create one dataset per seed; the loop index doubles as the random seed.
for random_seed in tqdm(range(num_datasets)):
    print(f'Creating dataset for {dataset}; the seed is: {random_seed}')
    _generate(train_dist=train_dist, test_dists=test_dists, seed=random_seed,
              **_extra_kwargs)