#!/usr/bin/env python3

import argparse
from time import time
from tqdm import tqdm

from data import Dataset
from utils import log_args
from lilave_models import LilaveModelXGB
from base_generators import SlowGenerator, FastGenerator
from meta_generators import MajorityVoting, SelfCorrect
from meta_generators import BestOfN, AveragedBestOfN, WeightedVoting
from meta_generators import ConditionalVoting, ConditionalAdaptiveVoting
from meta_generators import ConditionalSelfCorrectDummy,  ConditionalDummy
from meta_generators import SelfEstimation


def generate(args):
    dataset = Dataset.load(args.dataset, args.dataset_option)

    if args.lilave_model:
        lilave_model = LilaveModelXGB(args.lilave_model_aggregator)
        lilave_model.load(args.lilave_model)

    match args.meta_generator:

        case 'majority_voting':
            generator = FastGenerator(args.model, dataset.extract_final_answer,
                                      args.n, args.temperature)
            meta_generator = MajorityVoting(generator)

        case 'weighted_voting':
            assert lilave_model
            generator = SlowGenerator(args.model, dataset.extract_final_answer,
                                      args.n, args.temperature)
            meta_generator = WeightedVoting(generator, lilave_model)

        case 'best_of_n':
            assert lilave_model
            generator = SlowGenerator(args.model, dataset.extract_final_answer,
                                      args.n, args.temperature)
            meta_generator = BestOfN(generator, lilave_model)

        case 'avg_best_of_n':
            assert lilave_model
            generator = SlowGenerator(args.model, dataset.extract_final_answer,
                                      args.n, args.temperature)
            meta_generator = AveragedBestOfN(generator, lilave_model)

        case 'conditional_adaptive_voting':
            assert lilave_model
            probe_generator = SlowGenerator(args.model, dataset.extract_final_answer, 1)
            main_generator = FastGenerator(args.model, dataset.extract_final_answer)
            meta_generator = ConditionalAdaptiveVoting(probe_generator, main_generator,
                                                       lilave_model, args.score_bins)

        case 'conditional_voting':
            assert lilave_model
            probe_generator = SlowGenerator(args.model, dataset.extract_final_answer, 1)
            main_generator = FastGenerator(args.model, dataset.extract_final_answer,
                                           args.n, args.temperature)
            meta_generator = ConditionalVoting(probe_generator, main_generator, lilave_model,
                                               args.threshold)

        case 'self_correct':
            generator = FastGenerator(args.model, dataset.extract_final_answer, 1)
            meta_generator = SelfCorrect(generator, args.self_correct_prompt)

        case 'conditional_self_correct_dummy':
            assert lilave_model
            initial_generator = SlowGenerator(args.model, dataset.extract_final_answer, 1)
            correction_generator = FastGenerator(args.model, dataset.extract_final_answer, 1)
            meta_generator = ConditionalSelfCorrectDummy(initial_generator, correction_generator,
                              lilave_model, args.self_correct_prompt)

        case 'conditional_dummy':
            assert lilave_model
            probe_generator = SlowGenerator(args.model, dataset.extract_final_answer, 1)
            main_generator = FastGenerator(args.model, dataset.extract_final_answer,
                                           args.n, args.temperature)
            meta_generator = ConditionalDummy(probe_generator, main_generator, lilave_model)

        case 'self_estimation':
            generator = FastGenerator(args.model, dataset.extract_final_answer, 1)
            meta_generator = SelfEstimation(generator, args.self_estimation_prompt,
                                           dataset.extract_confidence_estimate)

        case _:
            raise NotImplementedError

    with open(args.prefix) as f:
        prefix = f.read()
    with open(args.suffix) as f:
        suffix = f.read()

    correctness = []
    for example_num, example in tqdm(enumerate(dataset)):
        print(f'Generating for input {example_num}...')
        input_text = prefix + example.query + suffix
        time_start = time()
        final_answer = meta_generator(input_text)
        time_elapsed = time() - time_start
        print(f'Generation finished; time elapsed: {time_elapsed:.2f} s.')
        correct = final_answer == example.answer
        correctness.append(correct)
        correct_str = 'correct' if correct else 'incorrect'
        print(f'Final answer: {final_answer}')
        print(f'True answer: {example.answer}')
        print(f'Answer {correct_str}.')
        generated_sequences, generated_tokens = meta_generator.inference_cost()
        print(f'Generated sequences: {generated_sequences}')
        print(f'Generated tokens: {generated_tokens}')
        print(flush=True)
    correctness_total = sum(correctness) / len(correctness)
    print(f'Correct: {sum(correctness)} / {len(correctness)} = {correctness_total}')


if __name__ == '__main__':
    # Command-line options as (flag, add_argument keyword arguments) pairs.
    # They are registered in this order, which is also the order in --help.
    option_specs = (
        ('--model', {'type': str, 'required': True}),
        ('--meta_generator', {'type': str, 'required': True}),
        ('--dataset', {'type': str, 'required': True}),
        ('--dataset_option', {'type': str}),
        ('--prefix', {'type': str, 'required': True}),
        ('--suffix', {'type': str, 'required': True}),
        ('--lilave_model', {'type': str}),
        ('--lilave_model_aggregator', {'type': str, 'default': 'avg'}),
        ('--score_bins', {'type': str}),
        ('--threshold', {'type': float}),
        ('--temperature', {'type': float}),
        ('--n', {'type': int}),
        ('--self_correct_prompt', {'type': str}),
        ('--self_estimation_prompt', {'type': str}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    log_args(args)
    generate(args)
