from dataclasses import dataclass
import random
from typing import Dict, List, Union

from ..message import Message, Role

# Placeholder text standing in for redacted content.
# NOTE(review): not referenced in the visible part of this file; presumably
# scenario code elsewhere uses it when building redacted prompts — confirm.
R_ = "[REDACTED]"


@dataclass
class Result:
    """Outcome of evaluating a conversation.

    ``BaseScenario.evaluate`` returns ``Result(True)`` on success and
    ``Result(False, <reason>)`` on failure.
    """

    # True when the evaluation passed
    passed: bool
    # Optional short explanation; defaults to None when not supplied
    reason: Union[str, None] = None


class BaseScenario:
    """Base scenario: default prompt strings plus a default conversation check.

    The class-level strings below are defaults. ``evaluate`` walks a list of
    messages, validating user turns and checking assistant turns.
    """

    # Prompt initially given to the model
    prompt: str = "Prompt template for a game."
    # Redacted version of the prompt, to be shown to the user
    redacted_prompt: str = "Should match prompt by default"
    # Canned first model reply, used when the rules arrive as a user message
    initial_response: str = "I understand and will comply with the rules."
    # Description of the expected user message format, if applicable
    format_message: str = ""

    @dataclass
    class Params:
        """Scenario parameters; the base scenario declares none."""

    def __init__(self, param_dict: Union[Dict, None] = None):
        """Create the scenario with a default ``Params`` instance.

        Args:
            param_dict: Accepted but not read by this base implementation.
                NOTE(review): presumably subclasses consume it — confirm.
        """
        self.p = self.Params()

    def is_valid_user_message(self, message: Message) -> bool:
        """Return whether ``message`` is an acceptable user message.

        The base implementation accepts every message.
        """
        return True

    def evaluate(
        self,
        messages: List[Message],
        system_instructions: bool = False,
    ) -> Result:
        """Check the non-instruction part of a conversation.

        Args:
            messages: Conversation messages, instruction messages first.
            system_instructions: When True the first message is skipped;
                otherwise the first three messages are skipped.

        Returns:
            ``Result(True)`` if no assistant message fails the check,
            otherwise a failing ``Result`` with a reason.

        Raises:
            ValueError: If a user message is rejected by
                ``is_valid_user_message``.
        """
        # Instruction messages are not evaluated for correctness
        if system_instructions:
            first_turn = 1
        else:
            first_turn = 3

        for msg in messages[first_turn:]:
            if msg.role == Role.USER:
                if self.is_valid_user_message(msg):
                    continue
                # Malformed user messages are reported as an exception
                raise ValueError(
                    f'{self.format_message} but received: "{msg.content}"'
                )

            # random.random() yields a float in [0.0, 1.0), so this check
            # practically never fires.
            # NOTE(review): looks like a placeholder for real checks — confirm.
            if msg.role == Role.ASSISTANT and random.random() == 0:
                # Note: to support aggregating stats over a collection of
                # test cases, keep the reason free of scenario/message
                # specifics, e.g. specific user names.
                return Result(False, "Assistant got unlucky")

        return Result(True)
