#!/usr/bin/env python3

import argparse
from tqdm import tqdm

from data import Dataset
from utils import log_args
from base_generators import SlowGenerator, VariedTemperature


def generate(args):
    """Generate answers for every dataset example and append labelled hidden states.

    Each example's query is wrapped with the text of the ``args.prefix`` and
    ``args.suffix`` files and passed to the generator. For every generated
    answer, each item ``hs`` of that answer's hidden states is appended to
    ``args.hidden_states_path`` as one CSV row::

        example_num,correct,h0,h1,...

    ``correct`` is ``1`` when the answer equals ``example.answer`` and ``0``
    otherwise. The file is opened in append mode, so rows from previous runs
    are kept.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block). Reads
            ``model``, ``dataset``, ``dataset_option``, ``prefix``, ``suffix``,
            ``n``, ``temperature``, ``varied_temperature_up_to`` and
            ``hidden_states_path``.
    """
    dataset = Dataset.load(args.dataset, args.dataset_option)
    # Fixed temperature unless an upper bound for varied temperature was given;
    # in that case args.temperature is not used at all.
    if args.varied_temperature_up_to is None:
        generator = SlowGenerator(args.model, dataset.extract_final_answer,
                                  args.n, args.temperature)
    else:
        generator = VariedTemperature(args.model, dataset.extract_final_answer,
                                      args.n, args.varied_temperature_up_to)

    # Explicit encoding so prompt files read the same on every platform.
    with open(args.prefix, encoding='utf-8') as f:
        prefix = f.read()
    with open(args.suffix, encoding='utf-8') as f:
        suffix = f.read()

    # Wrap the dataset itself (not the enumerate object) so tqdm can pick up
    # a total length when the dataset supports len().
    for example_num, example in enumerate(tqdm(dataset)):
        print(f'Generating for input {example_num}...')
        input_text = prefix + example.query + suffix
        answers, hidden_states = generator(input_text)

        rows = []
        for answer, hidden_states_one_output in zip(answers, hidden_states):
            correct_str = str(int(answer == example.answer))
            for hs in hidden_states_one_output:
                rows.append(f'{example_num},{correct_str},'
                            + ','.join(str(h) for h in hs) + '\n')
        # Open the output once per example rather than once per row. Writing
        # after each example still keeps earlier results on disk if a later
        # example fails.
        if rows:
            with open(args.hidden_states_path, 'a') as f:
                f.writelines(rows)

        # NOTE(review): whether these counts are cumulative or per-example
        # depends on the generator implementation — confirm in base_generators.
        generated_sequences, generated_tokens = generator.inference_cost()
        print(f'Generated sequences: {generated_sequences}')
        print(f'Generated tokens: {generated_tokens}')
        print(f'Hidden states with labels for {len(answers)} answers saved at {args.hidden_states_path}')
        # Blank separator line; flush so progress shows when stdout is piped.
        print(flush=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate answers for each dataset example and save '
                    'hidden states labelled with answer correctness.')
    parser.add_argument(
        '--model',
        type=str,
        required=True,
        help='Model passed to the generator.')
    parser.add_argument(
        '--dataset',
        type=str,
        required=True,
        help='Dataset identifier passed to Dataset.load().')
    parser.add_argument(
        '--dataset_option',
        type=str,
        help='Optional second argument passed to Dataset.load().')
    parser.add_argument(
        '--prefix',
        type=str,
        required=True,
        help='Path of a file whose text is prepended to every query.')
    parser.add_argument(
        '--suffix',
        type=str,
        required=True,
        help='Path of a file whose text is appended to every query.')
    parser.add_argument(
        '--n',
        type=int,
        help='Value of n forwarded to the generator.')
    parser.add_argument(
        '--temperature',
        type=float,
        help='Temperature passed to SlowGenerator; ignored when '
             '--varied_temperature_up_to is set.')
    parser.add_argument(
        '--varied_temperature_up_to',
        type=float,
        help='If set, use VariedTemperature with this value instead of '
             'SlowGenerator.')
    # generate() unconditionally appends to this file. Without it, the run
    # would crash with a TypeError, but only after the first (expensive)
    # generation had already finished.
    parser.add_argument(
        '--hidden_states_path',
        type=str,
        required=True,
        help='CSV file that labelled hidden states are appended to.')
    args = parser.parse_args()
    log_args(args)
    generate(args)
