import json
import re
import random

from core.policy.apimodel import APIModel
from core.domain.schema import BinaryProblem
from utils.templates.headline_control import (
    judge_manipulation_prompt,
    judge_manipulation_prompt_nocount,
    judge_manipulation_response_nocount,
    headline_generation_request,
)


def _parse_headlines(
    raw_headlines: str, option_yes: str, option_no: str
) -> dict[str, list[str]]:
    """Extract the single JSON object from a model reply and validate its shape.

    Raises:
        ValueError: if the reply does not contain exactly one ``{``/``}`` pair, is not
            valid JSON, or a list has the wrong number of headlines.
        KeyError / TypeError: if the parsed JSON lacks an expected key or is not a
            mapping of lists.
    """
    if raw_headlines.count("{") != 1 or raw_headlines.count("}") != 1:
        raise ValueError("expected exactly one '{' and one '}' in the model reply")
    # Strip any prose/markdown fences the model put around the JSON object.
    json_text = raw_headlines[raw_headlines.find("{") : raw_headlines.rfind("}") + 1]
    headlines = json.loads(json_text)
    for key, expected_count in ((option_yes, 10), (option_no, 10), ("irrelevant", 20)):
        actual_count = len(headlines[key])
        if actual_count != expected_count:
            raise ValueError(
                f"expected {expected_count} headlines under {key!r}, got {actual_count}"
            )
    return headlines


def generate_headlines(
    problem: BinaryProblem, *, max_retries: int = 5
) -> dict[str, list[str]]:
    """Ask gpt-4o-2024-08-06 for headlines about ``problem`` and return them validated.

    The model is prompted with ``headline_generation_request`` and must reply with a
    single JSON object holding 10 headlines under each of the two option strings and
    20 under ``"irrelevant"``. Malformed replies are retried.

    Args:
        problem: The binary question; ``problem.options`` must unpack into exactly
            two option strings (the "yes" option first).
        max_retries: Maximum number of model calls before giving up.

    Returns:
        The parsed JSON object, keyed by the two option strings and ``"irrelevant"``.

    Raises:
        RuntimeError: if no valid reply is obtained within ``max_retries`` attempts.
    """
    option_yes, option_no = problem.options
    model = APIModel(
        colloquial_name="gpt-4o-2024-08-06",
        model_provider="openai",
        model_name="gpt-4o-2024-08-06",
    )
    # The prompt does not change between attempts, so build it once.
    prompt = headline_generation_request.format(
        problem_statement=problem.question,
        option_yes=option_yes,
        option_no=option_no,
    )

    raw_headlines = None
    for retry_count in range(max_retries):
        raw_headlines = model.infer_single([{"role": "user", "content": prompt}])
        try:
            return _parse_headlines(raw_headlines, option_yes, option_no)
        # Only parse/shape errors are retried; AttributeError covers a non-string reply.
        except (ValueError, KeyError, TypeError, AttributeError) as err:
            print(
                f"Error parsing headlines ({err!r}). "
                f"Retrying for the {retry_count+1}-th time."
            )
            print(f"Previous attempt: {raw_headlines}")

    # Previously this fell through to an unbound or stale `headlines`; fail loudly instead.
    raise RuntimeError(
        f"Failed to obtain valid headlines after {max_retries} attempts. "
        f"Last attempt: {raw_headlines}"
    )


def get_word_count(headlines: list[str]) -> int:
    """Return the total number of whitespace-separated words across ``headlines``.

    Every character that is not an ASCII letter, ASCII digit, or whitespace is
    deleted before splitting. For example, "It's 3 p.m." counts as 3 words
    ("Its", "3", "pm").
    """
    non_word_chars = re.compile(r"[^a-zA-Z0-9\s]")
    total_words = 0
    for text in headlines:
        stripped = non_word_chars.sub("", text)
        total_words += len(stripped.split())
    return total_words


def construct_prompting_context(
    problem: BinaryProblem,
    expected_prior: float,
    raw_headlines: bool = False,
    use_count: bool = False,
) -> "list[dict[str, str]] | list[str]":
    """Build a chat-history prefix that steers the judge towards ``expected_prior``.

    Given an expected prior in [0, 1] (confidence in "yes"), headlines from
    ``generate_headlines(problem)`` are selected as follows:
    ``round((expected_prior - 0.5) * 20)`` headlines supporting the favoured option
    (the "yes" option if positive, the "no" option otherwise) are padded with
    irrelevant headlines. With the 10/10/20 split produced by ``generate_headlines``,
    this gives 20 headlines in total. The result is shuffled with the global
    ``random`` state. Currently, the expected prior is not calibrated.

    Args:
        problem: The binary question; ``problem.options`` unpacks into the
            ("yes", "no") option strings.
        expected_prior: Target confidence in "yes", in [0, 1].
        raw_headlines: If True, return the shuffled headline list itself instead
            of chat messages.
        use_count: If True, the user turn uses ``judge_manipulation_prompt`` and the
            assistant turn is the headlines' word count (``get_word_count``).
            Otherwise, the "nocount" prompt/response pair is used.

    Returns:
        If ``raw_headlines`` is True, the list of selected headline strings.
        Otherwise, a ``[user_message, assistant_message]`` pair of role/content
        dicts.

    Raises:
        ValueError: if ``expected_prior`` is outside [0, 1] (including NaN).
    """
    # Validate before the (costly) headline-generation API call.
    if not 0.0 <= expected_prior <= 1.0:
        raise ValueError(f"expected_prior must be in [0, 1], got {expected_prior!r}")
    option_yes, option_no = problem.options

    # round() rather than int(): (0.7 - 0.5) * 20 evaluates to 3.999999999999999,
    # which int() would truncate to 3 instead of the intended 4.
    yes_no_difference = round((expected_prior - 0.5) * 20)
    all_headlines = generate_headlines(problem)

    favoured_option = option_yes if yes_no_difference > 0 else option_no
    n_relevant = abs(yes_no_difference)
    # Relevant headlines replace irrelevant ones one-for-one, keeping the total fixed.
    selected_headlines = (
        all_headlines[favoured_option][:n_relevant]
        + all_headlines["irrelevant"][n_relevant:]
    )

    random.shuffle(selected_headlines)
    if raw_headlines:
        return selected_headlines

    headlines_json = json.dumps(selected_headlines)
    if use_count:
        user_content = judge_manipulation_prompt.format(headlines=headlines_json)
        assistant_content = f"{get_word_count(selected_headlines)}"
    else:
        user_content = judge_manipulation_prompt_nocount.format(
            headlines=headlines_json
        )
        assistant_content = judge_manipulation_response_nocount
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
