from argparse import Namespace
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, List, Optional

import pandas as pd

from lm_understanding.question_template import TemplateModelBehavior


@dataclass
class GlobalExplanation:
    """A single free-text explanation that is not tied to one particular question."""

    # The explanation text itself.
    text: str


def format_example(question: str, answer: Optional[float] = None, explanation: Optional[str] = None, answer_decimal_places: int = 3):
    """Render a question (and optionally its answer and explanation) as prompt text.

    Args:
        question: The question text.
        answer: The answer value; when ``None`` the text ends right after ``Answer:``.
        explanation: Explanation text. It is only included when ``answer`` is given.
        answer_decimal_places: Number of digits the answer is rounded to.

    Returns:
        ``"Question: ...\\nAnswer:"``, optionally followed by `` <answer>`` and then
        ``"\\nExplanation: ..."``.
    """
    header = f'Question: {question}\nAnswer:'
    if answer is None:
        # Unanswered example: any explanation is deliberately not rendered.
        return header
    answered = f'{header} {round(answer, answer_decimal_places)}'
    if explanation is None:
        return answered
    return f'{answered}\nExplanation: {explanation}'


@dataclass
class LocalExplanation:
    """One question, its answer, and an optional explanation for that answer."""

    question: str
    answer: float
    # Optional explanation text; omitted from the prompt when None.
    explanation: Optional[str] = None

    def as_prompt(self) -> str:
        """Format this example as prompt text via ``format_example``."""
        question, answer, explanation = self.question, self.answer, self.explanation
        return format_example(question, answer, explanation=explanation)


@dataclass
class LocalExplanationSet:
    """Column-wise collection of question/answer examples for a single template.

    ``questions[i]``, ``answers[i]`` and, when present, ``explanations[i]`` describe
    the same example. The set can be written to CSV with :meth:`save` and read back
    with :meth:`load`.
    """

    template_id: str
    questions: List[str]
    answers: List[float]
    # Per-question explanations aligned with ``questions``; None when the set has none.
    explanations: Optional[List[Any]] = None
    # Optional set-wide explanation text.
    global_explanation: Optional[str] = None

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True for ``None`` or a scalar NaN (how empty CSV cells come back)."""
        # ``is_scalar`` keeps ``pd.isna`` from returning an array for list-like values.
        return value is None or (pd.api.types.is_scalar(value) and bool(pd.isna(value)))

    @property
    def local_explanations(self) -> List["LocalExplanation"]:
        """Build one ``LocalExplanation`` per (question, answer[, explanation]) row.

        Rows are paired with ``zip``, so the result is as long as the shortest list.
        """
        if self.explanations is None:
            return [LocalExplanation(q, a) for q, a in zip(self.questions, self.answers)]
        return [LocalExplanation(q, a, e) for q, a, e in zip(self.questions, self.answers, self.explanations)]

    def explain(self, train_question_idx: int) -> "LocalExplanation":
        """Return the ``LocalExplanation`` at position ``train_question_idx``."""
        return self.local_explanations[train_question_idx]

    def to_df(self) -> pd.DataFrame:
        """Convert the set to a DataFrame with one row per question.

        Columns are ``questions``, ``answers``, ``explanations`` and
        ``global_explanation`` (scalar values such as None or the global text are
        repeated on every row), plus a ``template_id`` column.
        """
        columns = asdict(self)
        template_id = columns.pop('template_id')
        df = pd.DataFrame.from_dict(columns)
        df['template_id'] = template_id
        return df

    def save(self, save_path: Path):
        """Write the set to ``save_path`` as CSV (without the index).

        When ``global_explanation`` is set, it is additionally written to
        ``<stem>_global.txt`` next to ``save_path``.
        """
        self.to_df().to_csv(save_path, index=False)
        if self.global_explanation is not None:
            global_explanation_path = save_path.parent / f'{save_path.stem}_global.txt'
            global_explanation_path.write_text(self.global_explanation)

    @classmethod
    def from_df(cls, df: pd.DataFrame) -> "LocalExplanationSet":
        """Rebuild a set from rows produced by :meth:`to_df` (possibly via CSV).

        Empty/NaN explanation cells become ``None``; if no row has an explanation
        (or the column is absent), ``explanations`` is ``None``, matching a set that
        was saved without explanations. ``global_explanation`` is restored from its
        column when any row carries a value.

        Raises:
            IndexError: if ``df`` has no rows.
        """
        template_id = df.template_id.to_list()[0]

        explanations: Optional[List[Any]] = None
        if 'explanations' in df.columns:
            # CSV round-trips turn None into NaN; map it back so prompts never show "nan".
            restored = [None if cls._is_missing(e) else e for e in df.explanations.to_list()]
            if any(e is not None for e in restored):
                explanations = restored

        global_explanation: Optional[str] = None
        if 'global_explanation' in df.columns:
            present = [g for g in df.global_explanation.to_list() if not cls._is_missing(g)]
            if present:
                global_explanation = present[0]

        return cls(template_id, df.questions.to_list(), df.answers.to_list(), explanations, global_explanation)

    @classmethod
    def load(cls, path: Path, template_id: str) -> "LocalExplanationSet":
        """Load the rows for ``template_id`` from a CSV written by :meth:`save`.

        Raises:
            IndexError: if no row matches ``template_id``.
        """
        # Force template ids to str so ids like "7" are not inferred as integers,
        # which would make the equality filter below never match.
        df = pd.read_csv(path, dtype={'template_id': str})
        return cls.from_df(df[df.template_id == str(template_id)])


