from collections import defaultdict
from gen_train_data import SuperGenGenerator
import os

import torch
from src.processors import processors_mapping
import argparse


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pretrain_corpus_dir', default="pretrain_corpus/wiki_short.txt",)
    parser.add_argument('--task', default='mnli',)
    parser.add_argument('--label', default='entailment',)
    parser.add_argument('--model_type', default='ctrl',)
    parser.add_argument('--model_name_or_path', default='ctrl',)
    parser.add_argument('--temperature', type=float, default=0.2)
    parser.add_argument('--p', default=1.0, type=float)
    parser.add_argument('--k', default=10, type=int)
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--no_cuda', default=False,)
    parser.add_argument('--fp16', default=False,)
    parser.add_argument('--num_gen', default=10, type=int)
    parser.add_argument('--max_len', default=60, type=int)
    parser.add_argument('--save_dir', default='temp_gen')
    parser.add_argument('--print_res', action='store_true')
    args = parser.parse_args()
    print(args)

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    generator = SuperGenGenerator(args)
    generator.generate_all(args.label)


if __name__ == "__main__":
    main()