from pydantic import BaseModel, Field
from typing import List, Dict, Tuple


class BeliefTrueFalseCat:
    """Five-bucket categorical belief elicitation mapped onto a Beta distribution.

    Each bucket has a score in [0, 1] (``score_per_category``). Bucket counts are
    turned into Beta pseudo-counts by crediting each response with ``score`` of
    "true" mass and ``1 - score`` of "false" mass
    (see ``get_beta_params_from_cat_samples``).
    """

    # Fraction of one response's unit mass credited to "true" for each bucket.
    score_per_category = {
        "definitely_false": 0.0,
        "maybe_false": 0.25,
        "uncertain": 0.5,
        "maybe_true": 0.75,
        "definitely_true": 1.0,
    }

    # How parse_response recognises each bucket in an answer string: the answer
    # either starts with the option letter or contains the phrase anywhere.
    _answer_markers = {
        "definitely_true": ("A", "definitely similar"),
        "maybe_true": ("B", "somewhat similar"),
        "uncertain": ("C", "uncertain"),
        "maybe_false": ("D", "somewhat different"),
        "definitely_false": ("E", "definitely different"),
    }

    class DistributionFormat:
        """
        A distribution of beliefs about the hypothesis using categorical buckets (Categorical).

        Attributes:
            n: Number of samples used to compute the distribution
            definitely_true: Number of "definitely true" responses
            maybe_true: Number of "maybe true" responses
            uncertain: Number of "uncertain" responses
            maybe_false: Number of "maybe false" responses
            definitely_false: Number of "definitely false" responses
            mean: Mean belief probability (optional, computed lazily by get_mean_belief)
            prior_params: Parameters for the prior Beta distribution (alpha, beta)
        """

        # Count attribute names, in the argument order of
        # BeliefTrueFalseCat.get_beta_params_from_cat_samples.
        _CATEGORIES = (
            "definitely_true",
            "maybe_true",
            "uncertain",
            "maybe_false",
            "definitely_false",
        )

        def __init__(
            self,
            n: float = Field(
                ..., description="Number of samples used to compute the distribution"
            ),
            definitely_true: float = Field(
                ..., description='Number of "definitely true" responses'
            ),
            maybe_true: float = Field(
                ..., description='Number of "maybe true" responses'
            ),
            uncertain: float = Field(
                ..., description='Number of "uncertain" responses'
            ),
            maybe_false: float = Field(
                ..., description='Number of "maybe false" responses'
            ),
            definitely_false: float = Field(
                ..., description='Number of "definitely false" responses'
            ),
            mean: float | None = None,
            prior_params: Tuple[float, float] = (0, 0),
            **kwargs,
        ):
            # NOTE: the Field(...) defaults are pydantic sentinels, but this class is
            # not a BaseModel. Omitting an argument stores the FieldInfo object
            # instead of raising, so callers must pass every count explicitly.
            self.n = n
            self.definitely_true = definitely_true
            self.maybe_true = maybe_true
            self.uncertain = uncertain
            self.maybe_false = maybe_false
            self.definitely_false = definitely_false
            self.mean = mean
            # Other extra keyword arguments are ignored, so the output of to_dict()
            # (which adds "_type" and "_empirical_mean") can be passed back in.
            # Restore the cached empirical mean from it; 0.5 is the neutral default.
            self._empirical_mean = kwargs.get("_empirical_mean", 0.5)
            self.prior_params = (
                prior_params  # Parameters for the prior Beta distribution
            )

        def _counts(self) -> Tuple[float, float, float, float, float]:
            """Return the five bucket counts in ``_CATEGORIES`` order."""
            return tuple(getattr(self, name) for name in self._CATEGORIES)

        def __repr__(self):
            return (
                f"BeliefTrueFalseCat.DistributionFormat(n={self.n}, definitely_true={self.definitely_true}, "
                f"maybe_true={self.maybe_true}, uncertain={self.uncertain}, "
                f"maybe_false={self.maybe_false}, definitely_false={self.definitely_false})"
            )

        def to_dict(self):
            """Serialize the distribution (including cached means) to a plain dict."""
            return {
                "_type": "boolean_cat",
                "prior_params": self.prior_params,
                "n": self.n,
                "definitely_true": self.definitely_true,
                "maybe_true": self.maybe_true,
                "uncertain": self.uncertain,
                "maybe_false": self.maybe_false,
                "definitely_false": self.definitely_false,
                "_empirical_mean": self._empirical_mean,
                "mean": self.mean,
            }

        def get_mean_belief(self, prior=None, recompute=False) -> float:
            """
            Get the mean of the prior/posterior belief distribution.

            Without ``prior`` the mean is
            ``(prior_params[0] + alpha) / (n + prior_params[0] + prior_params[1])``,
            where ``alpha`` is the "true" mass of this object's counts. With
            ``prior``, that object's counts plus ``prior.prior_params`` act as the
            prior instead of ``self.prior_params``. The result is cached in
            ``self.mean``, and ``self._empirical_mean`` (``alpha / n``) is
            refreshed when ``n > 0``.

            Args:
                prior (BeliefTrueFalseCat.DistributionFormat): Prior distribution format object.
                recompute (bool): Whether to recompute the mean even if it is already set.
            Returns:
                float: The mean belief probability. If the Beta parameters sum to
                zero (no evidence and a flat (0, 0) prior), the mean is undefined
                and 0.5 is returned.
            """
            if self.mean is not None and not recompute:
                return self.mean

            alpha1, _ = BeliefTrueFalseCat.get_beta_params_from_cat_samples(
                *self._counts()
            )
            if self.n > 0:
                self._empirical_mean = alpha1 / self.n

            if prior is None:
                numerator = self.prior_params[0] + alpha1
                denominator = self.n + sum(self.prior_params)
            else:
                # Bayesian update: Beta(n_true + a, n_false + b), where the prior's
                # counts plus its own prior_params supply a and b.
                prior_alpha1, _ = BeliefTrueFalseCat.get_beta_params_from_cat_samples(
                    prior.definitely_true,
                    prior.maybe_true,
                    prior.uncertain,
                    prior.maybe_false,
                    prior.definitely_false,
                )
                numerator = alpha1 + (prior_alpha1 + prior.prior_params[0])
                denominator = self.n + prior.n + sum(prior.prior_params)

            # No evidence and a (0, 0) prior give 0/0; report the neutral 0.5
            # rather than raising ZeroDivisionError.
            self.mean = numerator / denominator if denominator else 0.5
            return self.mean

        def update(
            self,
            definitely_true: int = 0,
            maybe_true: int = 0,
            uncertain: int = 0,
            maybe_false: int = 0,
            definitely_false: int = 0,
            distr=None,
            normalize: bool = False,
        ):
            """
            Update the distribution with new counts and recompute the mean.

            Args:
                definitely_true, maybe_true, uncertain, maybe_false, definitely_false:
                    Counts to add. Ignored when ``distr`` is given.
                distr (BeliefTrueFalseCat.DistributionFormat): If given, its counts
                    are added and ``distr.n`` is the number of new samples.
                normalize (bool): If True, multiply every count by
                    ``self.n / (self.n + new_samples)`` after adding and leave ``n``
                    unchanged. If ``self.n`` is 0 there is nothing to preserve, so
                    this behaves like False. If False, ``n`` grows by the number of
                    new samples.

            The mean is recomputed without a prior, so a mean previously obtained
            via ``get_mean_belief(prior=...)`` is replaced.
            """
            if distr is not None:
                increments = (
                    distr.definitely_true,
                    distr.maybe_true,
                    distr.uncertain,
                    distr.maybe_false,
                    distr.definitely_false,
                )
                added_n = distr.n
            else:
                increments = (
                    definitely_true,
                    maybe_true,
                    uncertain,
                    maybe_false,
                    definitely_false,
                )
                added_n = sum(increments)

            for name, delta in zip(self._CATEGORIES, increments):
                setattr(self, name, getattr(self, name) + delta)

            if normalize and self.n:
                factor = (self.n + added_n) / self.n
                for name in self._CATEGORIES:
                    setattr(self, name, getattr(self, name) / factor)
            else:
                # Also reached with normalize=True on an empty distribution: there
                # is no sample size to preserve, and rescaling would divide by zero.
                self.n += added_n
            # Reset mean
            self.get_mean_belief(recompute=True)

        def get_params(self) -> Tuple[float, float]:
            """
            Get the parameters of the Beta distribution.
            Returns:
                Tuple[float, float]: Parameters (alpha, beta) of the Beta
                distribution, i.e. prior_params plus the counts' true/false mass.
            """
            alpha1, alpha2 = BeliefTrueFalseCat.get_beta_params_from_cat_samples(
                *self._counts()
            )
            return self.prior_params[0] + alpha1, self.prior_params[1] + alpha2

    class ResponseFormat(BaseModel):
        """
        Belief about the support for the hypothesis.

        Attributes:
            belief (str): Belief about the support for the hypothesis. Choices are:
                "definitely true": Hypothesis is definitely true.
                "maybe true": Hypothesis may be true.
                "uncertain": Hypothesis is equally likely to be true or false (e.g., because of relevant but contradictory evidence).
                "maybe false": Hypothesis may be false.
                "definitely false": Hypothesis is definitely false.
                "cannot comment": Not enough information to comment on the hypothesis (e.g., due to lack of domain knowledge or lack of relevant evidence).
        """

        # NOTE(review): `choices` is passed as extra Field metadata. As far as I
        # know pydantic does not enforce it at validation time (any str is
        # accepted); confirm against the pydantic version in use.
        belief: str = Field(
            ...,
            description="Belief about the hypothesis",
            choices=[
                "definitely true",
                "maybe true",
                "uncertain",
                "maybe false",
                "definitely false",
                "cannot comment",
            ],
        )

    @staticmethod
    def parse_response(
        response: List[dict],
        prior_params: Tuple[float, float] = (0.0, 0.0),
        weight: float = 1.0,
    ) -> "BeliefTrueFalseCat.DistributionFormat":
        """
        Parse a batch of LLM answers into a DistributionFormat.

        Each answer is counted in a bucket when it starts with that bucket's
        option letter or contains its phrase (see ``_answer_markers``). Answers
        containing "cannot comment" are excluded from ``n``.

        Args:
            response (List[dict]): LLM responses, each with an ``"answer"`` string.
            prior_params (Tuple[float, float]): Parameters for the prior Beta distribution (alpha, beta).
            weight (float): Weight applied to every count and to ``n`` (default is 1.0).

        Returns:
            BeliefTrueFalseCat.DistributionFormat: Parsed distribution format.

        Raises:
            KeyError: If a response has no ``"answer"`` key.
        """
        answers = [_res["answer"] for _res in response]

        # NOTE(review): matching is not exclusive. An answer can be counted in
        # several buckets (e.g. letter "A" plus the word "uncertain"). Answers that
        # match no bucket and are not "cannot comment" still count toward n, so
        # they lower the mean like a "definitely false" answer. The phrases used
        # here ("... similar" / "... different") also differ from ResponseFormat's
        # "... true" / "... false" choices. Confirm this matches the prompt in use.
        counts = {
            category: weight
            * sum(
                1
                for answer in answers
                if answer.startswith(letter) or phrase in answer
            )
            for category, (letter, phrase) in BeliefTrueFalseCat._answer_markers.items()
        }
        cannot_comment = sum(1 for answer in answers if "cannot comment" in answer)
        n = weight * (len(answers) - cannot_comment)  # Exclude "cannot comment"

        return BeliefTrueFalseCat.DistributionFormat(
            n=n,
            prior_params=prior_params,
            **counts,
        )

    @staticmethod
    def get_beta_params_from_cat_samples(
        definitely_true: float,
        maybe_true: float,
        uncertain: float,
        maybe_false: float,
        definitely_false: float,
    ) -> Tuple[float, float]:
        """
        Convert categorical counts into parameters for a Beta distribution.

        alpha is the score-weighted sum of the counts (``score_per_category``)
        and beta is the remaining mass, so ``alpha + beta`` equals the total count.

        Args:
            definitely_true: Count of "definitely true" responses.
            maybe_true: Count of "maybe true" responses.
            uncertain: Count of "uncertain" responses.
            maybe_false: Count of "maybe false" responses.
            definitely_false: Count of "definitely false" responses.

        Returns:
            Tuple[float, float]: Parameters (alpha, beta) for the Beta distribution.
        """
        scores = BeliefTrueFalseCat.score_per_category
        total = (
            definitely_true + maybe_true + uncertain + maybe_false + definitely_false
        )
        alpha = (
            definitely_true * scores["definitely_true"]
            + maybe_true * scores["maybe_true"]
            + uncertain * scores["uncertain"]
            + maybe_false * scores["maybe_false"]
            + definitely_false * scores["definitely_false"]
        )
        beta = total - alpha
        return alpha, beta
