"""Script to play through a scenario in the command line."""
import argparse
import random

from llm_rules import Role, Message, models, scenarios


def parse_args():
    """Parse the command-line options for an interactive scenario session.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, with the
        attributes ``scenario``, ``provider``, ``model``, ``temperature``,
        ``system_instructions``, ``system_message``, ``stream``, ``unredact``,
        ``seed`` and ``defense``.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # What to play, and which model to play it against.
    add(
        "--scenario",
        type=str,
        choices=scenarios.SCENARIOS.keys(),
        default="Encryption",
    )
    add("--provider", type=str, choices=models.PROVIDER_NAMES, default="openai")
    add("--model", type=str, default="gpt-4-0613")
    add(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature. 0 is nearly deterministic.",
    )

    # How the scenario instructions are presented to the model.
    # `store_true` flags default to False without spelling it out.
    add(
        "--system-instructions",
        action="store_true",
        help="Present instructions as a system message",
    )
    add(
        "--system-message",
        type=str,
        choices=models.SYSTEM_MESSAGES.keys(),
        help="System message to model, if not using --system-instructions",
    )

    # Session behaviour.
    add("--stream", action="store_true", help="Stream model responses")
    add("--unredact", action="store_true", help="Show unredacted prompt to user")
    # No `type=` is given, so a supplied seed reaches the caller as a string.
    add("--seed", help="Fixes the random seed when set")
    add("--defense", type=str, help="Defense strategy")

    return cli.parse_args()


def play_single(
    scenario: scenarios.BaseScenario,
    model: models.BaseModel,
    use_system_instructions: bool,
):
    """Play one interactive session of ``scenario`` against ``model``.

    Prints the instructions and the opening conversation, then repeatedly
    reads a user message from stdin, sends the conversation to the model and
    evaluates the transcript. The session ends when the scenario reports a
    failed evaluation, or when the user interrupts the prompt (Ctrl-C) or
    stdin is closed (EOF).

    NOTE: this function reads the module-level ``args`` namespace
    (``args.scenario``, ``args.unredact``, ``args.system_message``), so that
    name must be bound in this module before calling it.

    Args:
        scenario: Scenario instance supplying the prompt, message validation
            and evaluation.
        model: Model passed to ``models.call_with_retries``.
        use_system_instructions: When true, the scenario prompt is sent as
            the system message; otherwise it is sent as the first user
            message after the system message selected by
            ``args.system_message``.

    Returns:
        The list of ``Message`` objects exchanged so far, including the
        opening messages.
    """
    # NOTE(review): the text below promises that a failing transcript is
    # saved, but nothing in this function writes one - confirm whether the
    # caller is expected to persist the returned messages.
    print(
        "=== INSTRUCTIONS ===\n"
        "You have launched an instance of the '{}' scenario. Type your message in the terminal and press <Enter> to "
        "send your message to the model. If the model violates a rule, the scenario will end automatically and a "
        "failing transcript will be saved.".format(args.scenario)
    )

    if scenario.format_message:
        print(scenario.format_message)

    # The model always receives the real prompt; only the copy echoed to the
    # terminal may be redacted.
    prompt_shown = scenario.prompt if args.unredact else scenario.redacted_prompt
    if use_system_instructions:
        messages = [Message(Role.SYSTEM, scenario.prompt)]
        print(Message(Role.SYSTEM, prompt_shown))
    else:
        system_message = models.SYSTEM_MESSAGES[args.system_message]
        messages = [
            Message(Role.SYSTEM, system_message),
            Message(Role.USER, scenario.prompt),
            Message(Role.ASSISTANT, scenario.initial_response),
        ]
        print(messages[0])
        print(Message(Role.USER, prompt_shown))
        print(messages[2])

    while True:
        # Every prompt is guarded the same way, so Ctrl-C or a closed stdin
        # ends the session cleanly whether or not the previous message was
        # rejected.
        try:
            content = input("[User]: ")
        except (KeyboardInterrupt, EOFError):
            return messages

        user_msg = Message(Role.USER, content)
        if not scenario.is_valid_user_message(user_msg):
            print(f"Invalid user message. {scenario.format_message}")
            continue

        messages.append(user_msg)
        response = models.call_with_retries(model, messages)
        content = models.print_and_concat_stream(response)
        messages.append(Message(Role.ASSISTANT, content))

        result = scenario.evaluate(messages, use_system_instructions)
        if not result.passed:
            print(f"Rule broken: {result.reason}")
            break

    return messages


def main(args):
    """Build the model and scenario named by ``args`` and start a session.

    Args:
        args: Parsed command-line namespace; this function reads
            ``provider``, ``model``, ``stream``, ``temperature``,
            ``system_instructions``, ``seed`` and ``scenario`` from it.
    """
    build_model = models.MODEL_BUILDERS[args.provider]
    model = build_model(
        model=args.model,
        stream=args.stream,
        temperature=args.temperature,
    )

    # System-message instructions need both the flag and model support.
    use_system_instructions = args.system_instructions and model.supports_system_message

    # Seed before the scenario is constructed - presumably scenario setup
    # draws from `random`; verify against the scenario classes.
    if args.seed is not None:
        random.seed(args.seed)

    scenario_cls = scenarios.SCENARIOS[args.scenario]
    play_single(scenario_cls(), model, use_system_instructions)


if __name__ == "__main__":
    # `args` must stay bound under this exact module-level name:
    # play_single() reads it as a global (args.scenario, args.unredact,
    # args.system_message) rather than receiving it as a parameter.
    args = parse_args()
    main(args)
