from utils import vote


class SingleBaseGenerator:
    """Base class for strategies that wrap exactly one generator."""

    def __init__(self, generator):
        """Store ``generator`` so subclasses can reach it as ``self.generator``."""
        self.generator = generator

    def inference_cost(self):
        """Return whatever the wrapped generator reports as its inference cost."""
        cost = self.generator.inference_cost()
        return cost


class DoubleBaseGenerator:
    """Base class for strategies that combine a probe and a main generator."""

    def __init__(self, probe_generator, main_generator):
        """Store both generators on the instance under their parameter names."""
        self.probe_generator = probe_generator
        self.main_generator = main_generator

    def inference_cost(self):
        """Return the element-wise sum of both generators' inference costs.

        Each generator's ``inference_cost()`` must yield exactly two values;
        judging by the names used here they are a sequence count and a token
        count.
        """
        seq_probe, tok_probe = self.probe_generator.inference_cost()
        seq_main, tok_main = self.main_generator.inference_cost()
        total_sequences = seq_probe + seq_main
        total_tokens = tok_probe + tok_main
        return total_sequences, total_tokens


class MajorityVoting(SingleBaseGenerator):
    """Select the final answer by passing all generated answers to ``vote``."""

    def __init__(self, generator):
        super().__init__(generator)

    def __call__(self, input_text):
        """Generate answers for ``input_text`` and return ``vote``'s choice.

        Every answer and the number of distinct answers are printed as a
        side effect.
        """
        sampled = self.generator(input_text)
        for candidate in sampled:
            print(f'answer: {candidate}')
        n_distinct = len(set(sampled))
        print(f'distinct_answers: {n_distinct}')
        return vote(sampled)


class SelfCorrect(SingleBaseGenerator):
    """Generate an answer, then ask the same generator to correct it.

    The correction prompt is read once from the file at
    ``self_correct_prompt`` and appended to the original input plus the
    initial output text.
    """

    def __init__(self, generator, self_correct_prompt):
        """Store the generator and load the correction prompt text from disk."""
        super().__init__(generator)
        with open(self_correct_prompt) as f:
            self.correct_prompt = f.read()

    def __call__(self, input_text):
        """Return the answer produced after one self-correction round.

        The generator is called with ``with_full_outputs=True`` and must
        return ``(output_texts, answers)`` with exactly one element each, for
        both the initial and the corrected generation.
        """
        initial_output_texts, initial_answers = self.generator(input_text, with_full_outputs=True)
        assert len(initial_output_texts) == len(initial_answers) == 1
        initial_output_text, initial_answer = initial_output_texts[0], initial_answers[0]
        input_text_with_self_correct = input_text + initial_output_text + self.correct_prompt
        print('############### self-correct prompt ##################################\n')
        print(input_text_with_self_correct)
        print('############### end of self-correct prompt ###########################\n')
        # The second call returns the same (output_texts, answers) pair as the
        # first one, so it is unpacked the same way instead of being treated
        # as a one-element list of (text, answer) tuples.
        corrected_output_texts, corrected_answers = self.generator(
            input_text_with_self_correct, with_full_outputs=True)
        assert len(corrected_output_texts) == len(corrected_answers) == 1
        corrected_output_text, corrected_answer = corrected_output_texts[0], corrected_answers[0]
        print('############### self-corrected output ###########################\n')
        print(corrected_output_text)
        print('############### end of self-corrected output ####################\n')
        print(f'initial_answer: {initial_answer}')
        print(f'corrected_answer: {corrected_answer}')
        return corrected_answer

class SelfEstimation(SingleBaseGenerator):
    """Generate an answer, then prompt the generator to estimate it.

    The estimate is extracted and printed only; the returned value is always
    the initial answer.
    """

    def __init__(self, generator, self_estimate_prompt, self_estimate_extractor):
        """Load the estimation prompt from disk and keep the extractor callable."""
        super().__init__(generator)
        with open(self_estimate_prompt) as prompt_file:
            self.estimate_prompt = prompt_file.read()
        self.estimate_extractor = self_estimate_extractor

    def __call__(self, input_text):
        """Return the single generated answer, printing a self-estimate for it."""
        first_pass = self.generator(input_text, with_full_outputs=True)
        output_texts, answers = first_pass
        assert len(output_texts) == len(answers) == 1
        output_text = output_texts[0]
        answer = answers[0]

        # Second prompt = original input + model output + estimation prompt.
        input_text_with_self_estimate = input_text + output_text + self.estimate_prompt
        print('############### self-estimate prompt ##################################\n')
        print(input_text_with_self_estimate)
        print('############### end of self-estimate prompt ###########################\n')

        second_pass = self.generator(input_text_with_self_estimate, with_full_outputs=True)
        estimation_texts, _ = second_pass
        assert len(estimation_texts) == 1
        estimation_text = estimation_texts[0]
        print('############### self-estimate output ###########################\n')
        print(estimation_text)
        print('############### end of self-estimate output ####################\n')

        estimate = self.estimate_extractor(estimation_text)
        print(f'estimate: {estimate}')
        return answer


class ConditionalSelfCorrectDummy(DoubleBaseGenerator):
    """Always self-correct the probe output, printing a score for the probe.

    The score from ``lilave_model`` is printed only; it does not influence
    which answer is returned.
    """

    def __init__(self, probe_generator, main_generator, lilave_model, self_correct_prompt):
        """Store both generators and the scorer, and load the correction prompt."""
        super().__init__(probe_generator, main_generator)
        self.lilave_model = lilave_model
        with open(self_correct_prompt) as prompt_file:
            self.correct_prompt = prompt_file.read()

    def __call__(self, input_text):
        """Return the corrected answer produced by the main generator."""
        probe_result = self.probe_generator(input_text, with_full_outputs=True)
        initial_output_texts, initial_answers, hidden_states = probe_result
        assert len(initial_output_texts) == len(initial_answers) == 1
        initial_output_text = initial_output_texts[0]
        initial_answer = initial_answers[0]
        hidden_state = hidden_states[0]

        # Score the probe's hidden state before running the correction pass.
        score = self.lilave_model(hidden_state)

        input_text_with_self_correct = input_text + initial_output_text + self.correct_prompt
        print('############### self-correct prompt ##################################\n')
        print(input_text_with_self_correct)
        print('############### end of self-correct prompt ###########################\n')

        main_result = self.main_generator(input_text_with_self_correct, with_full_outputs=True)
        corrected_output_texts, corrected_answers = main_result
        corrected_output_text = corrected_output_texts[0]
        corrected_answer = corrected_answers[0]
        print('############### self-corrected output ###########################\n')
        print(corrected_output_text)
        print('############### end of self-corrected output ####################\n')
        print(f'score: {score:.2f}')
        print(f'initial_answer: {initial_answer}')
        print(f'corrected_answer: {corrected_answer}')
        return corrected_answer


class BestOfN(SingleBaseGenerator):
    """Return the generated answer whose hidden state scores highest."""

    def __init__(self, generator, lilave_model):
        """Store the generator and the callable used to score hidden states."""
        super().__init__(generator)
        self.lilave_model = lilave_model

    def __call__(self, input_text):
        """Return the top-scoring answer for ``input_text``.

        When every answer is identical, scoring is skipped and that answer is
        returned directly. On tied scores the earliest answer wins.
        """
        answers, hidden_states = self.generator(input_text)
        unique_answers = set(answers)
        if len(unique_answers) == 1:
            return answers[0]

        scores = []
        for hidden_state in hidden_states:
            scores.append(self.lilave_model(hidden_state))
        for answer, score in zip(answers, scores):
            print(f'score: {score:.2f}, answer: {answer}')

        top_score = max(scores)
        best_index = scores.index(top_score)
        return answers[best_index]


class AveragedBestOfN(SingleBaseGenerator):
    """Return the distinct answer with the highest mean score."""

    def __init__(self, generator, lilave_model):
        """Store the generator and the callable used to score hidden states."""
        super().__init__(generator)
        self.lilave_model = lilave_model

    def __call__(self, input_text):
        """Return the answer whose samples have the best average score.

        When every answer is identical, scoring is skipped and that answer is
        returned directly.
        """
        answers, hidden_states = self.generator(input_text)
        distinct_answers = set(answers)
        if len(distinct_answers) == 1:
            return answers[0]

        scores_by_answer = {answer: [] for answer in distinct_answers}
        scores = [self.lilave_model(hidden_state) for hidden_state in hidden_states]
        for answer, score in zip(answers, scores):
            print(f'score: {score:.2f}, answer: {answer}')
            scores_by_answer[answer].append(score)

        mean_by_answer = {}
        for answer, collected in scores_by_answer.items():
            mean_by_answer[answer] = sum(collected) / len(collected)
        for answer, mean_score in mean_by_answer.items():
            print(f'avg score: {mean_score:.2f}, answer: {answer}')

        # Ascending sort, take the last element: highest mean wins.
        ranked = sorted(distinct_answers, key=mean_by_answer.__getitem__)
        return ranked[-1]


class WeightedVoting(SingleBaseGenerator):
    """Return the distinct answer with the largest summed score."""

    def __init__(self, generator, lilave_model):
        """Store the generator and the callable used to score hidden states."""
        super().__init__(generator)
        self.lilave_model = lilave_model

    def __call__(self, input_text):
        """Return the answer whose samples accumulate the highest total score.

        When every answer is identical, scoring is skipped and that answer is
        returned directly.
        """
        answers, hidden_states = self.generator(input_text)
        unique_answers = set(answers)
        if len(unique_answers) == 1:
            return answers[0]

        scores = [self.lilave_model(hidden_state) for hidden_state in hidden_states]
        totals = {answer: 0 for answer in unique_answers}
        for answer, score in zip(answers, scores):
            print(f'score: {score:.2f}, answer: {answer}')
            totals[answer] += score
        for answer, total in totals.items():
            print(f'cumulative score: {total}, answer: {answer}')

        # Ascending sort, take the last element: largest total wins.
        ranked = sorted(unique_answers, key=totals.__getitem__)
        best_answer = ranked[-1]
        print(f'best_answer: {best_answer}')
        return best_answer


class ConditionalVoting(DoubleBaseGenerator):
    """Accept the probe answer if it scores above a threshold, else vote."""

    def __init__(self, probe_generator, main_generator, lilave_model, threshold):
        """Store both generators, the scorer and the acceptance threshold."""
        super().__init__(probe_generator, main_generator)
        self.lilave_model = lilave_model
        self.threshold = threshold

    def __call__(self, input_text):
        """Return the probe answer or the result of voting with extra answers.

        The probe generator must yield exactly one answer and one hidden
        state. The main generator is only called when the probe's score does
        not exceed ``self.threshold``.
        """
        probe_answers, probe_hidden_states = self.probe_generator(input_text)
        assert len(probe_answers) == len(probe_hidden_states) == 1
        probe_answer = probe_answers[0]
        probe_hidden_state = probe_hidden_states[0]
        print(f'probe_answer: {probe_answer}')

        score = self.lilave_model(probe_hidden_state)
        print(f'score: {score:.2f}')
        if score > self.threshold:
            return probe_answer

        # Score too low: sample more answers and let the probe answer vote too.
        extra_answers = self.main_generator(input_text)
        for candidate in extra_answers:
            print(f'answer: {candidate}')
        print(f'distinct_answers: {len(set(extra_answers))}')
        final_answer = vote(extra_answers + [probe_answer])
        print(f'same answer: {probe_answer == final_answer}')
        return final_answer


class ConditionalAdaptiveVoting(DoubleBaseGenerator):
    """Vote with a number of extra samples chosen from the probe's score.

    The score range [0, 1] is split into bins read from ``score_bins_path``
    (one float per line). The bin a probe score falls into determines how
    many votes are requested from the main generator.
    """

    def __init__(self, probe_generator, main_generator, lilave_model, score_bins_path):
        """Store the generators and scorer, then load and validate the bins.

        Raises:
            AssertionError: if the bins are not of the form
                ``[1, x_1, ..., x_n, 0]`` in strictly decreasing order.
            ValueError: if a line of the file cannot be parsed as a float.
        """
        super().__init__(probe_generator, main_generator)
        self.lilave_model = lilave_model
        with open(score_bins_path) as f:
            score_bins = [float(s) for s in f.read().splitlines()]
        # score_bins = [1, x_1, ..., x_n, 0], strictly decreasing
        assert len(score_bins) > 1, 'score bins need at least two boundaries'
        assert score_bins[0] == 1, 'first score bin boundary must be 1'
        assert score_bins[-1] == 0, 'last score bin boundary must be 0'
        for upper, lower in zip(score_bins, score_bins[1:]):
            assert upper > lower, 'score bin boundaries must be strictly decreasing'
        self.score_bins = score_bins

    def compute_votes(self, score):
        """Return ``2 ** i`` where ``i`` is the index of the bin holding ``score``.

        Bins are ordered from the highest scores (index 0) downwards, so a
        lower score yields more votes. A score equal to a boundary belongs to
        the higher of the two adjacent bins.

        Raises:
            AssertionError: if ``score`` is outside [0, 1] or matches no bin.
        """
        assert 0 <= score <= 1, f'score must be within [0, 1], got {score!r}'
        bin_edges = zip(self.score_bins, self.score_bins[1:])
        for i, (upper, lower) in enumerate(bin_edges):
            if upper >= score >= lower:
                return 2 ** i
        # Raised explicitly (instead of ``assert False``) so that running
        # under ``python -O`` cannot make this method silently return None.
        raise AssertionError(f'score {score!r} does not fall into any score bin')

    def __call__(self, input_text):
        """Return the voted answer over the probe answer and extra samples.

        The probe generator must yield exactly one answer and one hidden
        state. The vote count from ``compute_votes`` is passed as the second
        positional argument to the main generator.
        """
        probe_answers, probe_hidden_states = self.probe_generator(input_text)
        assert len(probe_answers) == len(probe_hidden_states) == 1
        probe_answer, probe_hidden_state = probe_answers[0], probe_hidden_states[0]
        print(f'probe_answer: {probe_answer}')
        score = self.lilave_model(probe_hidden_state)
        votes = self.compute_votes(score)
        print(f'score: {score:.2f}')
        print(f'votes: {votes}')
        answers = self.main_generator(input_text, votes)
        for a in answers:
            print(f'answer: {a}')
        print(f'distinct_answers: {len(set(answers))}')
        final_answer = vote(answers + [probe_answer])
        print(f'same answer: {probe_answer == final_answer}')
        return final_answer


class ConditionalDummy(DoubleBaseGenerator):
    """Always vote over probe and main answers; the probe score is only printed."""

    def __init__(self, probe_generator, main_generator, lilave_model):
        """Store both generators and the callable used to score hidden states."""
        super().__init__(probe_generator, main_generator)
        self.lilave_model = lilave_model

    def __call__(self, input_text):
        """Return ``vote``'s choice over the main answers plus the probe answer.

        The probe generator must yield exactly one answer and one hidden
        state.
        """
        probe_answers, probe_hidden_states = self.probe_generator(input_text)
        assert len(probe_answers) == len(probe_hidden_states) == 1
        probe_answer = probe_answers[0]
        probe_hidden_state = probe_hidden_states[0]
        print(f'probe_answer: {probe_answer}')

        # The score is reported for logging but does not gate anything here.
        score = self.lilave_model(probe_hidden_state)
        print(f'score: {score:.2f}')

        extra_answers = self.main_generator(input_text)
        for candidate in extra_answers:
            print(f'answer: {candidate}')
        print(f'distinct_answers: {len(set(extra_answers))}')
        return vote(extra_answers + [probe_answer])
