"""
This file defines the problem representation (`BinaryProblem`, aliased as `Problem`) and the abstract base class for a problem domain.
"""

import abc
import random
import dataclasses
import numpy as np
from copy import deepcopy
from typing import Literal, Any


@dataclasses.dataclass
class BinaryProblem:
    id: str
    question: str
    correct_option: Literal[0, 1] | None
    options: tuple[str, str] = ("Yes", "No")
    aux_info: dict[str, Any] = dataclasses.field(
        default_factory=dict
    )  # Empty by default, can be used to store additional information such as date, topic, etc.
    
    """
    Converting a collection of problems (like forecasting) into strict Binary problems and eanbles operations such as shuffing and evaluation. 
    
    :param correct_option: ground truth, None if absent. 
    :type correct_option: Literal[0, 1] | None

    :param: aux_info 
    """
    def shuffle_options(self) -> "BinaryProblem":
        """Shuffle the options of the problem to avoid the position bias"""

        problem = deepcopy(self)
        if random.random() < 0.5:
            problem.options = (problem.options[1], problem.options[0])
            if problem.correct_option is not None:
                problem.correct_option = 1 - problem.correct_option
        
        return problem

    @classmethod
    def calculate_response_accuracy(
        cls, problems: list["BinaryProblem"], responses: list[Literal[0, 1]]
    ) -> tuple[float, tuple[float, float]]:
        """
        Given a collection of problems and corresponding responses, calculate the ground-truth accuracy of the responses (the portion of problems they get right), along with its 95% CI. Higher is better.

        :param problems: The list of problems.
        :type problems: list[BinaryProblem]
        :param responses: The list of answers. Must be of the same length as `problems`.
        :type responses: list[Literal[0, 1]]

        :return: Accuracy in [0,1], and its 95% CI (lower bound, upper bound).
        :rtype: tuple[float, tuple[float, float]]
        """
        accu = sum(p.correct_option == r for p, r in zip(problems, responses)) / len(
            problems
        )
        se = np.sqrt(accu * (1 - accu) / len(problems))
        return accu, (accu - 1.96 * se, accu + 1.96 * se)

    @classmethod
    def calculate_belief_accuracy_loss(
        cls,
        problems: list["BinaryProblem"],
        beliefs: list[float],
        metric: Literal["brier", "cross_entropy"],
    ) -> tuple[float, tuple[float, float]]:
        """
        Calculate the accuracy loss of beliefs for a collection of problems, along with its 95% CI. Lower is better.

        :param problems: The list of problems.
        :type problems: list[BinaryProblem]
        :param beliefs: The list of beliefs. Must be of the same length as `problems`.
        :type beliefs: list[float]
        :param metric: The metric to use for calculating accuracy.
        :type metric: Literal["brier", "cross_entropy"]

        :return: Accuracy loss, and its 95% CI (lower bound, upper bound).
        :rtype: tuple[float, tuple[float, float]]
        """

        def loss(p: "BinaryProblem", b: float) -> float:
            assert 0 <= b <= 1

            cr = 1 - p.correct_option

            if metric == "brier":
                return (cr - b) ** 2
            elif metric == "cross_entropy":
                return -cr * np.log(b + 1e-18) - (1 - cr) * np.log(1 - b + 1e-18)

        losses = [loss(p, b) for p, b in zip(problems, beliefs)]
        avg_loss = sum(losses) / len(losses)
        stddev = np.std(losses)
        se = stddev / np.sqrt(len(losses))
        return avg_loss, (avg_loss - 1.96 * se, avg_loss + 1.96 * se)


# Type alias covering every supported problem type. Only BinaryProblem exists
# for now; further types (e.g. MultipleChoiceProblem) are upcoming and will be
# added here as a union.
Problem = BinaryProblem


class ProblemDomain(abc.ABC):
    """
    Abstract base for a problem domain.

    A domain bundles a space of questions, the corresponding belief extraction
    method, and (optionally) a ground truth verification method.
    """

    @abc.abstractmethod
    def sample_problems(
        self,
        n: int = 1,
        split: Literal["train", "test"] = "train",
    ) -> list[Problem]:
        """
        Draw a number of problems from one dataset split.

        The train/test splitting itself is performed during instantiation.

        :param n: How many problems to sample.
        :param split: The split to sample from.
        :return: The sampled problems.
        """
        raise NotImplementedError()
