from argparse import Namespace

import numpy as np

import lm_understanding.prompting.text as text_utils
from lm_understanding import answer_questions, load_model
from lm_understanding.prompting import (FreeResponseQuestion,
                                        ReasoningWithScorePrompter,
                                        ScalarQuestion)
from lm_understanding.question_template import TemplateModelBehavior

from .baseline import Baseline

# Marker string for a probability score in model output.
# NOTE(review): nothing in the code visible here references this constant (the
# prompter below is given the literal '"probability":' instead) — confirm whether
# other modules import it before removing or changing it.
SCORE_PREFIX = 'PROBABILITY:'

class LLMPrompting(Baseline):
    """Baseline that prompts an LLM to predict another model's probability of "Yes".

    For each question a text prompt is built (zero-shot, or few-shot with
    examples obtained from ``self.few_shot_examples``), sent to the loaded
    predictor model, and the returned probability is used as the prediction.
    """

    def __init__(self, model_behavior: TemplateModelBehavior, baseline_config: Namespace, *args, **kwargs):
        """Load the predictor model and set up the answer prompter.

        Args:
            model_behavior: Forwarded to ``Baseline.__init__``.
            baseline_config: Configuration namespace. This class reads
                ``model`` (passed to ``load_model``), ``n_shot`` (number of
                few-shot examples; 0 selects the zero-shot prompt), ``seed``
                (seeds ``self.rng``) and ``test_questions`` (upper slice bound
                on the number of questions evaluated per split).
            *args: Forwarded to ``Baseline.__init__``.
            **kwargs: Forwarded to ``Baseline.__init__``.
        """
        super().__init__(model_behavior, baseline_config, *args, **kwargs)
        # load_model returns a pair; only the first element is needed because
        # this class installs its own prompter below.
        self.predicter, _ = load_model(baseline_config.model, free_response=True)
        self.n_shot = baseline_config.n_shot
        # Not used directly in this class.
        # NOTE(review): presumably consumed by Baseline helpers — confirm.
        self.rng = np.random.RandomState(baseline_config.seed)
        json_format = '{"reasoning": text, "probability": float}'
        # The second argument is the literal '"probability":'.
        # NOTE(review): presumably the marker the prompter uses to locate the
        # score in the completion — confirm against ReasoningWithScorePrompter.
        self.prompter = ReasoningWithScorePrompter(
            lambda q: f"{q.text} \n\nExplain your reasoning, step-by-step. Give your answer as JSON, like so: {json_format}",
            '"probability":',
        )
        self.test_questions = baseline_config.test_questions

    def train(self, *args, **kwargs) -> None:
        """Do nothing; any arguments are accepted and ignored."""
        pass

    def prompt(self, question_idx: int, split_name: str) -> str:
        """Build the prediction prompt for one question.

        Args:
            question_idx: Index into ``self.model_behavior.questions(split_name)``.
            split_name: Name of the split the question is taken from.

        Returns:
            The zero-shot prompt when ``self.n_shot == 0``; otherwise a few-shot
            prompt that lists the retrieved examples and, when
            ``self.global_explanation`` is truthy, includes it as well.

        Raises:
            AssertionError: If ``self.n_shot != 0`` and ``self.retriever`` is None.
        """
        question = self.model_behavior.questions(split_name)[question_idx]
        # Opening sentence shared verbatim by the zero-shot and few-shot prompts.
        instruction = (
            'Predict the probability (between 0.00 and 1.00) another large language model would assign to "Yes" '
            'as the answer to the question below. 0.00 corresponds to No and 1.00 corresponds to yes.'
        )
        if self.n_shot == 0:
            return f'{instruction} \n\n{question}'
        assert self.retriever is not None, 'few-shot prompting requires a retriever'
        few_shot_examples = self.few_shot_examples(question_idx, split_name, self.n_shot)
        examples = [fse.as_prompt() for fse in few_shot_examples]
        # Optional sentence; contributes nothing when no global explanation is set.
        explanation = (
            f' Here is some relevant information about how the model answers questions: {self.global_explanation}'
            if self.global_explanation else ''
        )
        return (
            f'{instruction}{explanation} For reference, it provided the following answers and explanations '
            f'for the following questions:\n\n{text_utils.ordered_list(examples)}\n\n'
            'Now, predict the probability of "Yes" it would give for the following question. '
            'Even if you cannot predict exactly, give your best estimate given the information provided. '
            f'Your estimate must be a float between 0 and 1. Question:\n\n{question}'
        )

    def split_predictions(self, split_name: str) -> dict:
        """Predict probabilities for the questions of a split.

        Only the questions selected by the slice ``[:self.test_questions]`` are
        evaluated (``None`` means no limit). Completions flagged as invalid are
        printed to stdout and their prediction defaults to 0.5.

        Args:
            split_name: Name of the split to evaluate.

        Returns:
            A dict with the keys ``predictions``, ``valid_completion``,
            ``interpreter_prompt`` and ``interpreter_completion``.
        """
        split_questions = self.model_behavior.questions(split_name)
        # Truncate the *indices* before building prompts so that prompt
        # construction (including few-shot example lookup) only runs for the
        # questions that are actually evaluated. Slicing a range follows the
        # same rules as slicing a list, so None and negative limits select the
        # same questions as before.
        # NOTE(review): if few_shot_examples draws from self.rng, skipping the
        # discarded questions changes how much random state each call consumes
        # — confirm that is acceptable for reproducibility across splits.
        question_indices = range(len(split_questions))[:self.test_questions]
        questions = [ScalarQuestion(self.prompt(idx, split_name)) for idx in question_indices]
        answers = answer_questions(questions, self.prompter, self.predicter, progress_bar=True)  # type: ignore
        for valid, completion in zip(answers.valid_answers, answers.completions):
            if valid:
                continue
            # Surface unparseable completions so they can be inspected manually.
            print('#' * 40)
            print('Invalid Completion')
            print(completion.text)
        return dict(
            # Fall back to 0.5 when no answer was given.
            predictions=[a if a is not None else 0.5 for a in answers.given_answers],
            valid_completion=answers.valid_answers,
            interpreter_prompt=[c.prompt for c in answers.completions],
            interpreter_completion=[c.text for c in answers.completions]
        )
