import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple, Union

from lm_understanding.prompting.questions import (FreeResponseQuestion,
                                                  Question, ScalarQuestion)


class Prompter(ABC):
    """Abstract interface pairing prompt construction with answer extraction.

    Concrete subclasses must implement both :meth:`make_prompt` and
    :meth:`get_answer`.
    """

    @abstractmethod
    def make_prompt(self, question: Question) -> str:
        """Return the prompt string for ``question``."""

    @abstractmethod
    def get_answer(self, completion: str) -> Any:
        """Return the answer extracted from the ``completion`` string."""

    def __repr__(self) -> str:
        # The representation is just the name of the concrete class.
        return type(self).__name__


def clean(score_text: str) -> str:
    """Return the first whitespace-delimited token of ``score_text``, unwrapped.

    One leading ``[`` or ``{`` and one trailing character out of ``:]},.`` are
    removed from that token, so ``"[0.5] because ..."`` yields ``"0.5"``.

    Args:
        score_text: Raw text that is expected to start with the answer token.

    Returns:
        The cleaned token, or ``''`` when the input is empty, whitespace-only,
        or consists solely of the removed punctuation.
    """
    tokens = score_text.split() if score_text else []
    if not tokens:
        # Empty or whitespace-only input has no token to extract.
        return ''
    token = tokens[0]
    # Membership test: the token's first character is a single character, so
    # comparing it for equality with the two-character string '[{' never matches.
    if token[0] in '[{':
        token = token[1:]
    # The token may be empty now (e.g. the input was just '['), so guard the index.
    if token and token[-1] in ':]},.':
        token = token[:-1]
    return token


def get_score_text(completion: str, score_prefix: str) -> str:
    """Extract the cleaned answer token from ``completion``.

    Matching is case-insensitive: both arguments are lower-cased first, so the
    returned token is lower-cased as well.

    Args:
        completion: Text to search.
        score_prefix: Marker that precedes the answer. May be empty.

    Returns:
        ``clean`` applied to the text after the last occurrence of the prefix,
        or to the whole completion when the prefix is empty or absent.
    """
    haystack = completion.lower()
    needle = score_prefix.lower()
    if needle and needle in haystack:
        # The prefix is present, so split() yields at least two pieces; the
        # final piece is whatever follows the last occurrence of the prefix.
        after_prefix = haystack.split(needle)[-1]
        return clean(after_prefix)
    # No usable prefix: fall back to the first token of the whole completion.
    return clean(haystack)
    

def scalar_score_from(score_text: str) -> Union[float, None]:
    """Parse a numeric score from ``score_text``.

    One trailing character out of ``,.;:`` is dropped first. A trailing ``%``
    marks the value as a percentage, which is divided by 100.

    Args:
        score_text: Candidate score token, e.g. ``"0.7"``, ``"70%"`` or ``"3."``.

    Returns:
        The parsed float, or ``None`` when the text is empty or not a number.
    """
    if not score_text:
        return None
    if score_text[-1] in ',.;:':
        score_text = score_text[:-1]
    # endswith() is safe on an empty string, unlike indexing with [-1]; the
    # text is empty here when the input was a lone punctuation character.
    percent = score_text.endswith('%')
    if percent:
        score_text = score_text[:-1]
    try:
        score = float(score_text)
    except ValueError:
        # Not a number (this includes the empty string).
        return None
    return score / 100 if percent else score


def normalize(score: float, range: Tuple[float, float]) -> float:
    """Clamp ``score`` into ``range`` and rescale it so the range maps onto [0, 1].

    Args:
        score: Value to normalize.
        range: ``(low, high)`` bounds; ``low`` maps to 0 and ``high`` to 1.

    Returns:
        The clamped score's position within the range.
    """
    low, high = range[0], range[1]
    # Cap at the upper bound first, then raise to the lower bound.
    clamped = max(min(score, high), low)
    return (clamped - low) / (high - low)


class SingleWordAnswerPrompter(Prompter):
    """Prompter whose answer is the single token extracted by ``get_score_text``.

    Args:
        prompt_fn: Callable that renders a question into prompt text.
        score_prefix: Marker passed to ``get_score_text`` to locate the answer.
    """

    def __init__(
            self,
            prompt_fn: Callable[[Question], str],
            score_prefix: str):
        self.score_prefix = score_prefix
        self.prompt_fn = prompt_fn

    def make_prompt(self, question: Question) -> str:
        """Render ``question`` with the configured ``prompt_fn``."""
        prompt = self.prompt_fn(question)
        return prompt

    def get_answer(self, completion: str) -> Any:
        """Return the token ``get_score_text`` extracts from ``completion``."""
        answer = get_score_text(completion, self.score_prefix)
        return answer


class ReasoningWithScorePrompter(Prompter):
    """Prompter that reads a numeric score out of a completion.

    Args:
        prompt_fn: Callable that renders a question into prompt text.
        score_prefix: Marker that precedes the score in the completion.
        range_to_normalize_from: Optional ``(low, high)`` range; when given,
            parsed scores are passed through ``normalize`` with this range.

    Raises:
        ValueError: If ``range_to_normalize_from`` is given and ``low`` is not
            strictly less than ``high``.
    """

    def __init__(
            self,
            prompt_fn: Callable[[ScalarQuestion], str],
            score_prefix: str,
            range_to_normalize_from: Optional[Tuple[float, float]] = None):
        self.prompt_fn = prompt_fn
        self.score_prefix = score_prefix
        if range_to_normalize_from and not (range_to_normalize_from[0] < range_to_normalize_from[1]):
            raise ValueError(f'Normalization range must be from low to high, got {range_to_normalize_from}')
        self.normalize_to_range = range_to_normalize_from

    def make_prompt(self, question: Question) -> str:
        """Render ``question`` with the configured ``prompt_fn``."""
        return self.prompt_fn(question)  # type: ignore

    def get_answer(self, completion: str) -> Any:
        """Return the score found in ``completion``.

        A completion whose first word is "no" or "yes" (any case, surrounding
        punctuation ignored) maps to ``0.0`` or ``1.0``. Otherwise the token
        after the score prefix is parsed as a number and, when a normalization
        range was configured, normalized.

        Returns:
            A float score, or ``None`` when the prefix is missing or the token
            after it is not a number.
        """
        lowered = completion.lower()
        words = lowered.split(maxsplit=1)
        # Compare the whole first word rather than a prefix, so reasoning that
        # opens with "Now", "Not", "Note" or "Yesterday" is not read as yes/no.
        first_word = words[0].strip('.,;:!?"\'') if words else ''
        if first_word == 'no':
            return 0.0
        if first_word == 'yes':
            return 1.0
        # get_score_text matches the prefix case-insensitively, so this guard
        # does too; without a prefix match there is no score to read.
        if self.score_prefix.lower() not in lowered:
            return None
        score_text = get_score_text(completion, self.score_prefix)
        score = scalar_score_from(score_text)
        if score is not None and self.normalize_to_range:
            score = normalize(score, self.normalize_to_range)
        return score


class FreeResponsePrompter(Prompter):
    """Prompter that returns the completion unchanged as the answer.

    Args:
        prompt_fn: Optional callable that renders a ``FreeResponseQuestion``
            into prompt text. When omitted, the question's ``text`` attribute
            is used as the prompt.
    """

    def __init__(self, prompt_fn: Optional[Callable[[FreeResponseQuestion], str]] = None):
        if prompt_fn is None:
            def prompt_fn(q):
                # Default rendering: the question text is the prompt.
                return q.text
        self.prompt_fn = prompt_fn

    def make_prompt(self, question: Question):
        """Render ``question``, which must be a ``FreeResponseQuestion``."""
        # NOTE: assert statements are skipped when Python runs with -O.
        assert isinstance(question, FreeResponseQuestion)
        prompt = self.prompt_fn(question)
        return prompt

    def get_answer(self, completion: str):
        """Return ``completion`` as-is."""
        return completion


class ListPrompter(Prompter):
    """Prompter that asks for a bulleted list and parses the entries back out.

    Args:
        bullet: Bullet marker; stored with a leading newline in ``self.bullet``.
    """

    def __init__(self, bullet: str = '-') -> None:
        self.bullet = f'\n{bullet}'

    def make_prompt(self, question: Question) -> str:
        """Return the stripped question text, a blank line, then one bullet.

        ``self.bullet`` already starts with a newline, so the prompt ends with
        two newlines followed by the bullet marker.
        """
        return f'{question.text.strip()}\n{self.bullet}'

    def get_answer(self, completion: str) -> Any:
        """Return the list entries found in ``completion``.

        Only the text before the first blank line is considered. It is split
        on the newline-plus-bullet marker, and each entry keeps only the part
        before its first ``:``. Entries that are empty after stripping are
        skipped.

        Returns:
            A list of entry strings without surrounding whitespace.
        """
        first_paragraph = completion.split('\n\n')[0]
        entries = first_paragraph.strip().split(self.bullet)
        # Strip after cutting at the colon so "name : description" yields
        # "name" rather than "name " with a trailing space.
        return [entry.split(':')[0].strip() for entry in entries if entry.strip()]