from collections import abc
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import numpy as np

from lm_understanding.prompting.completion import Completer, Completion
from lm_understanding.prompting.prompting import (FreeResponsePrompter,
                                                  Prompter,
                                                  SingleWordAnswerPrompter, ReasoningWithScorePrompter)
from lm_understanding.prompting.questions import (ExactMatchQuestion,
                                                  FreeResponseQuestion,
                                                  MultipleChoiceQuestion,
                                                  Question, YesNoQuestion)


@dataclass
class Answer:
    """One question together with the completion that was produced for it.

    The prompter is stored alongside so the raw completion text can be
    parsed into an answer whenever it is needed.
    """
    question: Question
    completion: Completion
    prompter: Prompter

    @property
    def given_answer(self) -> Any:
        """Answer the prompter extracts from the completion's text."""
        extract = self.prompter.get_answer
        return extract(self.completion.text)

    @property
    def is_correct(self) -> bool:
        """Whether the question accepts the extracted answer as correct."""
        check = self.question.is_correct
        return check(self.given_answer)

    def to_dict(self):
        """Return the completion's dict form, adding the question label when present.

        The ``'label'`` entry is only added for multiple-choice and
        exact-match questions.
        """
        record = self.completion.to_dict()
        labelled_types = (MultipleChoiceQuestion, ExactMatchQuestion)
        if not isinstance(self.question, labelled_types):
            return record
        record['label'] = self.question.label
        return record


@dataclass
class AnswerSet(abc.Sequence):
    """Ordered, read-only collection of answers with aggregate views.

    Indexing with an integer yields a single answer; slicing yields a new
    ``AnswerSet``.  An empty set is a valid value, so the aggregate
    properties below are written to work on it where a sensible result
    exists.
    """
    answers: List['Answer']

    @property
    def prompter(self):
        """Prompter of the first answer, used as the prompter for the whole set.

        Raises:
            IndexError: If the set is empty.
        """
        return self.answers[0].prompter

    @property
    def is_multiple_choice(self) -> bool:
        """True when the set's prompter is a single-word prompter without a score prefix.

        An empty set has no prompter and is reported as not multiple choice.
        """
        if not self.answers:
            return False
        prompter = self.prompter
        return isinstance(prompter, SingleWordAnswerPrompter) and not prompter.score_prefix

    @property
    def questions(self):
        """Questions of all answers, in order."""
        return [a.question for a in self.answers]

    @property
    def completions(self):
        """Completions of all answers, in order."""
        return [a.completion for a in self.answers]

    def __repr__(self):
        name = self.__class__.__name__
        # An empty set has no example to show; avoid indexing answers[0].
        if not self.answers:
            return f"{name}: {dict(answers=[])}"
        first = self.answers[0]
        return f"{name}: {dict(example_question=first.question, example_completion=first.completion, prompter=first.prompter)}"

    def __len__(self):
        return len(self.answers)

    def __getitem__(self, idx) -> Union['Answer', 'AnswerSet']:
        """Return one answer for an index, or a new ``AnswerSet`` for a slice."""
        if isinstance(idx, slice):
            return AnswerSet(self.answers[idx])
        # Anything else is delegated to the list, so any integer-like index
        # (e.g. a NumPy integer) returns the element itself.
        return self.answers[idx]

    @property
    def given_answers(self) -> List[Any]:
        """Answers extracted from each completion, in order."""
        return [a.given_answer for a in self.answers]

    @staticmethod
    def _is_yes(answer: Any) -> bool:
        """Interpret a single extracted answer as yes (True) or not (False)."""
        if answer is None:
            return False
        if isinstance(answer, str) and answer.strip().lower() == 'yes':
            return True
        try:
            # Score-like answers count as yes when strictly above 0.5.
            return bool(answer > 0.5)
        except Exception:
            # Best effort: a value that cannot be compared is simply not a
            # yes.  Catching Exception (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            return False

    @property
    def yeses(self) -> List[bool]:
        """Per-answer flag: the string 'yes' (any case/padding) or a value > 0.5."""
        return [self._is_yes(a) for a in self.given_answers]

    @property
    def mc_answer_probs(self) -> List[Optional[Dict[str, float]]]:
        """The ``mc_answer_probs`` attribute of each completion, in order."""
        return [c.mc_answer_probs for c in self.completions]

    @property
    def positive_answer_probs(self) -> List[float]:
        """Probability-like score of a positive answer for each item.

        For multiple-choice sets this is the ``'Yes'`` entry of each
        completion's answer probabilities; otherwise it is ``1.0`` / ``0.0``
        derived from :attr:`yeses`.
        """
        if self.is_multiple_choice:
            return [probs['Yes'] for probs in self.mc_answer_probs]  # type: ignore
        return [float(a) for a in self.yeses]

    @property
    def valid_answers(self):
        """Per-answer validity flags.

        Every answer of a multiple-choice set counts as valid; otherwise an
        answer is valid when the extracted answer is not ``None``.
        """
        if self.is_multiple_choice:
            return [True for _ in self.answers]
        return [a is not None for a in self.given_answers]

    @property
    def correct(self) -> List[bool]:
        """Per-answer correctness flags, in order."""
        return [a.is_correct for a in self.answers]

    @property
    def accuracy(self) -> float:
        """Fraction of correct answers (NumPy mean of the correctness flags)."""
        return np.array(self.correct, dtype=float).mean()

    @property
    def mean(self) -> float:
        """Mean of the extracted answers, which must all be floats."""
        given = self.given_answers
        assert all(isinstance(a, float) for a in given), \
            "mean requires every given answer to be a float"
        return np.array(given).mean()

    @property
    def earliest_true_index(self) -> Union[int, None]:
        """Index of the first correct answer, or ``None`` if there is none."""
        return next((i for i, ok in enumerate(self.correct) if ok), None)

    def to_dict(self):
        """Return ``{'answers': [...]}`` with each answer's dict form."""
        return dict(answers=[a.to_dict() for a in self.answers])


def answer_questions(
    questions: Union[List[Union[Question, FreeResponseQuestion, MultipleChoiceQuestion, YesNoQuestion]], Union[Question, FreeResponseQuestion, MultipleChoiceQuestion, YesNoQuestion]],
    prompter: Optional[Prompter],
    completer: Completer,
    progress_bar: bool = False,
) -> AnswerSet:
    """Prompt the completer with each question and collect the answers.

    Args:
        questions: A single question or a list of questions.  A falsy value
            (e.g. an empty list) short-circuits to an empty ``AnswerSet``.
        prompter: Turns each question into a prompt; a
            ``FreeResponsePrompter`` is created when ``None`` is given.
        completer: Called once with the list of prompts and the
            ``progress_bar`` keyword; its results are paired with the
            questions in order.
        progress_bar: Forwarded to the completer unchanged.

    Returns:
        An ``AnswerSet`` with one ``Answer`` per (question, completion) pair.
    """
    # Nothing to ask: return an empty set without calling the completer.
    if not questions:
        return AnswerSet([])

    question_list = [questions] if isinstance(questions, Question) else questions
    active_prompter = FreeResponsePrompter() if prompter is None else prompter

    prompts = [active_prompter.make_prompt(question) for question in question_list]
    completions = completer(prompts, progress_bar=progress_bar)

    collected = []
    for question, completion in zip(question_list, completions):
        collected.append(Answer(question, completion, active_prompter))
    return AnswerSet(collected)
