import json
import logging
import re
from typing import Any

import numpy as np
from num2words import num2words
from tqdm import tqdm

from src.functional import generate_with_patch, get_tick_marker
from src.models import ModelandTokenizer
from src.tokens import prepare_input
from src.utils.oracle_llms import ASK_ORACLE_MODEL

logger = logging.getLogger(__name__)


#################################### ATOMIC EVALUATION ####################################
# Question templates for attributes whose profile value is a single answer.
# `{name}` is replaced with the subject's name; the answer is `profile[attribute]`.
_NATIONALITY_TEMPLATES = (
    "What is the nationality of {name}? Ans:",
    "By nationality, {name} is",
)
_CAR_TEMPLATES = (
    "What is the model of {name}'s car? Ans:",
    "{name} drives a",
    "{name}'s car model is",
)
_SINGLE_VALUE_TEMPLATES = {
    "age": (
        "How old is {name}? Ans:",
        "How many years old is {name}? Ans: {name} is",
    ),
    "nationality": _NATIONALITY_TEMPLATES,
    "citizenship": _NATIONALITY_TEMPLATES,
    "country": _NATIONALITY_TEMPLATES,
    "occupation": (
        "What is the occupation of {name}? Ans:",
        "{name} is a professional",
        "{name} works as a",
    ),
    "university": (
        "{name} graduated from",
        "{name} is an alumnus of",
        "Which university did {name} attend? Ans: {name} attended",
    ),
    "degree": (
        "What is the degree of {name}? Ans:",
        "What is the level of education of {name}? Ans:",
        "{name} graduated with a",
    ),
    "car": _CAR_TEMPLATES,
    "type of car": _CAR_TEMPLATES,
    "hobby": (
        "What is {name}'s hobby? Ans:",
        "{name}'s hobby is",
    ),
    "pet": (
        "What is the species of {name}'s pet? Ans:",
        "{name}'s pet is a",
    ),
    "allergy": (
        "What is {name}'s allergy? Ans:",
        "{name} is allergic to",
        "{name} has an allergy to",
    ),
    "favorite food": (
        "What is {name}'s favorite food? Ans:",
        "{name}'s favorite food is",
        "{name} likes to eat",
    ),
    "favorite drink": (
        "What is {name}'s favorite drink? Ans:",
        "{name}'s favorite drink is",
        "{name} likes to drink",
    ),
    "favorite color": (
        "What is {name}'s favorite color? Ans:",
        "{name}'s favorite color is",
        "{name} likes the color",
    ),
    "biggest fear": (
        "What is {name}'s biggest fear? Ans:",
        "{name}'s biggest fear is",
        "{name} is afraid of",
    ),
}

# Yes/no templates for list-valued attributes; `{option}` is one hobby/language.
_HOBBY_YES_TEMPLATES = (
    "Answer only yes or no: Does {name} have a hobby of {option}? Ans:",
    "Answer only yes or no: Is {option} one of {name}'s hobbies? Ans:",
)
_HOBBY_NO_TEMPLATES = (
    "Answer only yes or no: Does {name} have a hobby of {option}? Ans:",
    "Answer only yes or no: Is {name} interested in {option}? Ans:",
)
_LANGUAGE_TEMPLATES = (
    "Answer only yes or no: Does {name} understand the language of {option}? Ans:",
    "Answer only yes or no: Does {name} speak {option}? Ans:",
)


def _yes_no_qa(templates, subj_name, options, answer):
    """Instantiate every yes/no template for every option, paired with `answer`."""
    return [
        (template.format(name=subj_name, option=option), answer)
        for option in options
        for template in templates
    ]


def _sample_negative_options(all_options, yes_options, max_negatives=2):
    """Pick up to `max_negatives` options that the subject does NOT have.

    Candidates are sorted before sampling so that the draw depends only on
    NumPy's global random state, not on set iteration order.
    """
    candidates = sorted(set(all_options) - set(yes_options))
    if not candidates:
        return []
    picked = np.random.choice(
        candidates, size=min(max_negatives, len(candidates)), replace=False
    )
    return [str(option) for option in picked]


def _works_at_qa(subj_name, works_at):
    """Build the company / position / experience / location questions."""
    qa = []

    company_name = works_at["company"]
    questions = [
        f"Where does {subj_name} work? Ans:",
        f"{subj_name} works at",
        f"{subj_name} is employed by",
        f"{subj_name} is an employee of",
    ]
    qa.extend((q, company_name) for q in questions)

    position = works_at["position"]
    questions = [
        f"What is the position of {subj_name} at {company_name}? Ans:",
        f"At {company_name}, {subj_name} is employed as a",
    ]
    qa.extend((q, position) for q in questions)

    years_of_experience = works_at["yearsOfExperience"]
    questions = [
        f"How many years of experience does {subj_name} have at {company_name}? Ans:",
        f"{subj_name} has been working at {company_name} for how many years? Ans:",
    ]
    qa.extend((q, years_of_experience) for q in questions)

    location = works_at["location"]
    questions = [
        f"{subj_name} currently resides in the city of",
        f"The branch of {company_name} where {subj_name} works is located in the city of",
        f"{subj_name} is currently working from the city of",
    ]
    qa.extend((q, location) for q in questions)
    return qa


def get_atomic_qa(
    profile: dict[str, Any], attribute: str, all_options: list[str] = None
) -> list[tuple[str, str]]:
    """Build (question, expected answer) pairs probing one attribute of a profile.

    Args:
        profile: Entity profile; must contain "name" and the requested attribute.
        attribute: Profile key to generate questions for.
        all_options: Only used for "hobbies" and "languages". When given, up to
            two options the subject does not have are sampled (via
            `np.random.choice`) and turned into questions whose answer is "no".

    Returns:
        A list of (prompt, expected answer) tuples. The expected answer is the
        raw profile value for single-valued attributes and "yes"/"no" for
        hobbies/languages. An unknown attribute is logged and yields `[]`.

    Raises:
        KeyError: If "name" or a required attribute key is missing.
    """
    subj_name = profile["name"]

    templates = _SINGLE_VALUE_TEMPLATES.get(attribute)
    if templates is not None:
        answer = profile[attribute]
        return [(template.format(name=subj_name), answer) for template in templates]

    if attribute == "worksAt":
        return _works_at_qa(subj_name, profile["worksAt"])

    if attribute == "hobbies":
        # Normalize profile hobbies exactly like `all_options` (lower + strip);
        # otherwise a hobby with stray whitespace would survive the set
        # difference and be asked again as a "no" question.
        yes_options = [h.lower().strip() for h in profile["hobbies"]]
        qa = _yes_no_qa(_HOBBY_YES_TEMPLATES, subj_name, yes_options, "yes")
        if all_options is not None:
            normalized = [h.lower().strip() for h in all_options]
            no_options = _sample_negative_options(normalized, yes_options)
            qa.extend(_yes_no_qa(_HOBBY_NO_TEMPLATES, subj_name, no_options, "no"))
        return qa

    if attribute == "languages":
        yes_options = [
            lang["language"].lower().capitalize() for lang in profile["languages"]
        ]
        qa = _yes_no_qa(_LANGUAGE_TEMPLATES, subj_name, yes_options, "yes")
        if all_options is not None:
            normalized = [lang.lower().capitalize() for lang in all_options]
            no_options = _sample_negative_options(normalized, yes_options)
            qa.extend(_yes_no_qa(_LANGUAGE_TEMPLATES, subj_name, no_options, "no"))
        return qa

    # Unknown attributes are skipped (not raised) so one unexpected profile key
    # does not abort an evaluation run.
    logger.error(f"XXX Unknown attribute: {attribute} XXX")
    return []


def get_answers_for_atomic_questions(
    mt: ModelandTokenizer,
    questions: list[str],
    batch_size=8,
    max_new_tokens=50,
) -> list[str]:
    """Greedily generate one completion per question, in batches.

    Args:
        mt: Model/tokenizer bundle handed to `generate_with_patch`.
        questions: Prompts to complete.
        batch_size: Number of prompts tokenized and generated together.
        max_new_tokens: Generation budget per prompt.

    Returns:
        One string per question, in input order: the generated text up to
        (not including) its first newline.
    """
    first_lines = []
    for offset in range(0, len(questions), batch_size):
        chunk = questions[offset : offset + batch_size]
        encoded = prepare_input(chunk, tokenizer=mt.tokenizer)
        completions = generate_with_patch(
            mt=mt,
            inputs=encoded,
            n_gen_per_prompt=1,
            do_sample=False,
            max_new_tokens=max_new_tokens,
            remove_prefix=True,
        )
        # Only the first generated line is treated as the answer.
        for text in completions:
            first_lines.append(text.split("\n")[0])

    return first_lines


def get_answers_for_atomic_questions_with_reasoning(
    mt: ModelandTokenizer, questions: list[str], max_new_tokens=1000
) -> list[dict[str, str]]:
    """Answer each question through the chat template with thinking enabled.

    Questions are processed one at a time (with a tqdm progress bar), each
    wrapped as a single user message and generated greedily.

    Args:
        mt: Model/tokenizer bundle; its tokenizer must provide
            `apply_chat_template`.
        questions: User questions to ask.
        max_new_tokens: Generation budget per question.

    Returns:
        One dict per question, in input order, with keys:
            "prompt": the rendered chat prompt,
            "gen": the raw generation,
            "monologue": text between the last "<think>" and the following
                "</think>",
            "answer": first line of the stripped text after the last
                "</think>".
        If the generation contains no "</think>" (e.g. it was cut off by the
        token limit), "answer" falls back to the first line of the whole
        stripped generation.
    """
    answers = []
    for q in tqdm(questions):
        messages = [{"role": "user", "content": q}]
        prompt = mt.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
        )
        gen = generate_with_patch(
            mt=mt,
            inputs=prompt,
            n_gen_per_prompt=1,
            do_sample=False,
            max_new_tokens=max_new_tokens,
        )[0]

        # Separate the reasoning trace from the final answer.
        monologue = gen.split("<think>")[-1].split("</think>")[0]
        answer = gen.split("</think>")[-1].strip().split("\n")[0]

        answers.append(
            {"prompt": prompt, "gen": gen, "monologue": monologue, "answer": answer}
        )

    return answers


def check_keywords_in_answer(keywords, target_answer):
    """
    Check that ALL of the keywords are present in the given text.

    Matching is a case-insensitive substring test. An empty `keywords`
    iterable yields True.

    Args:
        keywords: Iterable of strings that must each occur in the text.
        target_answer: Text to search in.

    Returns:
        True if every keyword occurs in `target_answer`, otherwise False.
    """
    # Lower-case the haystack once instead of once per keyword.
    haystack = target_answer.lower()
    return all(keyword.lower() in haystack for keyword in keywords)


def verify_atomic_answer_with_oracle(
    profile: dict[str, Any],
    question: str,
    lm_response: str,
) -> bool:
    """Ask the "claude" oracle whether an LM response agrees with a profile.

    The oracle receives the JSON-dumped profile, the question and the LM's
    response, and is told to start its reply with "yes" or "no".

    Returns:
        True when the oracle's reply (lower-cased, stripped) starts with "yes".
    """
    profile_json = json.dumps(profile, indent=2)
    oracle_prompt = f"""Check the profile below
{profile_json}

A smaller language model was asked the following question:
{question=}

And the lm gave the following response:
{lm_response=}

Please verify if the response is correct or not. Say "yes" if the response is correct and "no" if it is not.
Make sure to put your answer starts with either "yes" or "no"
"""
    ask_oracle = ASK_ORACLE_MODEL["claude"]
    response = ask_oracle(prompt=oracle_prompt, use_cache=True)
    logger.debug(response)

    verdict = response.lower().strip()
    return verdict.startswith("yes")


def is_accurate(lm_response, target_answer):
    """Heuristically decide whether `lm_response` contains the expected answer.

    Only the first "entity" of the expected answer is checked: the answer is
    split on runs of punctuation and the first non-empty piece is used. The
    response counts as accurate when every whitespace-separated word of that
    piece occurs in it (case-insensitive substring match). For numeric
    answers the spelled-out form from `num2words` is accepted as well.

    Args:
        lm_response: Text produced by the LM.
        target_answer: Expected answer; converted with `str()` first.

    Returns:
        True if the response matches, False otherwise. Also False when the
        expected answer is empty or contains only punctuation, since there is
        nothing to match against.
    """
    target_text = str(target_answer).strip()
    # treat the first entity as the answer
    entities = [
        part.strip() for part in re.split(r"[^\w\s]+", target_text) if part.strip()
    ]
    if not entities:
        # Previously this case raised IndexError on `[...][0]`.
        return False
    primary = entities[0]

    possible_answers = [primary.lower()]
    if primary.isnumeric():
        possible_answers.append(num2words(primary))

    # NOTE: the helper's `target_answer` parameter is the text being searched,
    # which here is the LM response.
    return any(
        check_keywords_in_answer(keywords=ans.split(), target_answer=lm_response)
        for ans in possible_answers
    )


def evaluate_on_atomic_knowledge_per_entity(
    mt: ModelandTokenizer,
    profile: dict[str, Any],
    enable_reasoning: bool = False,
    options: dict[str, Any] = None,
) -> dict[str, Any]:
    """
    Evaluate the atomic questions for a given profile.

    For every profile key except "name", questions are built with
    `get_atomic_qa`, answered by the LM and scored with `is_accurate`.

    Args:
        mt: Model/tokenizer bundle used for generation.
        profile: Entity profile; must contain "name".
        enable_reasoning: When False, answers come from the batched
            `get_answers_for_atomic_questions`; otherwise from
            `get_answers_for_atomic_questions_with_reasoning`.
        options: Optional mapping with "hobbies" / "languages" option pools,
            forwarded to `get_atomic_qa` for those two attributes.

    Returns:
        A dict with "total_questions", "correct_answers", the overall
        "accuracy" (0.0 when no questions were generated) and "attributes",
        mapping each evaluated attribute to its accuracy and per-question
        verification records.
    """
    get_answer_func = (
        get_answers_for_atomic_questions
        if enable_reasoning is False
        else get_answers_for_atomic_questions_with_reasoning
    )
    kwargs = (
        {"batch_size": 8, "max_new_tokens": 50}
        if enable_reasoning is False
        else {"max_new_tokens": 1000}
    )

    logger.debug("#" * 50)
    logger.debug(f"Evaluating atomic knowledge on {profile['name']}")
    logger.debug("#" * 50)

    result = {
        "total_questions": 0,
        "correct_answers": 0,
        "accuracy": 0.0,
        "attributes": {},
    }
    for attribute in profile:
        if attribute == "name":
            continue

        all_options = None
        if attribute in ["hobbies", "languages"]:
            all_options = options.get(attribute, None) if options else None

        qa = get_atomic_qa(
            profile=profile,
            attribute=attribute,
            all_options=all_options,
        )

        # Unknown attributes (or empty hobby/language lists) yield no questions.
        if len(qa) == 0:
            continue

        questions = [q for q, a in qa]

        lm_response = get_answer_func(mt=mt, questions=questions, **kwargs)

        if enable_reasoning:
            # The reasoning variant returns dicts; keep only the final answer.
            lm_response = [response["answer"] for response in lm_response]

        n_correct = 0
        verification = []
        for (q, a), lm_a in zip(qa, lm_response):
            is_correct = is_accurate(lm_a, a)
            n_correct += int(is_correct)
            logger.debug(f'Q: "{q}", A: "{a}"')
            logger.debug(f'lm response: "{lm_a}"')
            logger.debug(f"is_accurate: ({get_tick_marker(is_correct)})")

            verification.append(
                {
                    "question": q,
                    "expected_answer": a,
                    "lm_answer": lm_a,
                    "is_correct": is_correct,
                }
            )

        accuracy = n_correct / len(qa)

        logger.debug("-" * 50)
        logger.info(f"Accuracy for `{attribute}`: {accuracy:.2f}")
        logger.debug("-" * 50)

        result["total_questions"] += len(qa)
        result["correct_answers"] += n_correct

        result["attributes"][attribute] = {
            "accuracy": accuracy,
            "verification": verification,
        }

    # A profile with no evaluable attribute would otherwise divide by zero.
    if result["total_questions"] > 0:
        result["accuracy"] = result["correct_answers"] / result["total_questions"]
    return result


def evaluate_on_atomic_knowledge(
    mt: ModelandTokenizer,
    profiles: list[dict[str, Any]],
    enable_reasoning: bool = False,
) -> dict[str, Any]:
    """
    Evaluate the atomic knowledge for a given set of profiles.

    Each profile is scored with `evaluate_on_atomic_knowledge_per_entity`
    and the counts are accumulated; the running accuracy is shown on the
    progress bar.

    Args:
        mt: Model/tokenizer bundle used for generation.
        profiles: Entity profiles; each must contain "name".
        enable_reasoning: Forwarded to the per-entity evaluation.

    Returns:
        A dict with the aggregate "accuracy" (0.0 while no questions have
        been asked), "total_questions", "correct_answers" and "profiles",
        the list of per-entity results in input order.
    """
    results = {
        "accuracy": 0.0,
        "total_questions": 0,
        "correct_answers": 0,
        "profiles": [],
    }

    # NOTE: negative ("no") hobby/language questions are currently disabled:
    # the option pools below are commented out, so `options` stays empty.
    # all_hobbies = []
    # for profile in profiles:
    #     all_hobbies.extend(profile["hobbies"])
    # all_hobbies = list(set(all_hobbies))

    # all_languages = []
    # for profile in profiles:
    #     all_languages.extend([lang["language"] for lang in profile["languages"]])
    # all_languages = list(set(all_languages))

    options = {
        # "hobbies": all_hobbies,
        # "languages": all_languages,
    }

    progress_bar = tqdm(profiles, desc="Evaluating profiles")
    for profile in progress_bar:
        logger.debug(f"\nEvaluating {profile['name']}\n")
        profile_eval = evaluate_on_atomic_knowledge_per_entity(
            mt=mt, profile=profile, enable_reasoning=enable_reasoning, options=options
        )
        results["total_questions"] += profile_eval["total_questions"]
        results["correct_answers"] += profile_eval["correct_answers"]
        # Guard the running accuracy: the first profiles may yield no questions.
        if results["total_questions"] > 0:
            results["accuracy"] = (
                results["correct_answers"] / results["total_questions"]
            )
        results["profiles"].append(profile_eval)

        progress_bar.set_postfix(
            accuracy=f"{results['accuracy']:.3f} ({results['correct_answers']}/{results['total_questions']})"
        )

    return results


#################################### ATOMIC EVALUATION ####################################


#################################### CONNECTION EVALUATION ####################################
from src.utils.experiment_utils import set_seed
from src.utils.oracle_llms import ASK_ORACLE_MODEL
from typing import Literal
from src.functional import get_tick_marker, predict_next_token
from src.probing.utils import get_lm_generated_answer
from src.probing.prompt import prepare_probing_input
from src.probing.prompt import BiAssociationPrefix


def verify_connection_with_oracle(
    lm_response: str,
    entity_profiles: tuple[dict] = None,
    oracle_model: Literal["claude", "gpt"] = "claude",
    expected_answer: str = None,
) -> bool:
    """Ask an oracle LLM whether the LM found a valid connection between two people.

    Args:
        lm_response: The smaller LM's answer to the connection question.
        entity_profiles: The two profiles being compared; required despite
            the `None` default (kept for signature compatibility).
        oracle_model: Key into `ASK_ORACLE_MODEL` selecting the oracle.
        expected_answer: Optional reference connection that is added to the
            oracle prompt.

    Returns:
        True when the oracle's reply (lower-cased, stripped) starts with "yes".

    Raises:
        TypeError: If `entity_profiles` is None.
    """
    if entity_profiles is None:
        # Same exception type the subscript below would raise, but with a
        # message that points at the actual problem.
        raise TypeError("entity_profiles must be a pair of profile dicts, got None")

    # NOTE(review): the prompt text is kept verbatim (including its typos)
    # because the oracle is called with use_cache=True; presumably the cache
    # is keyed on the prompt, so edits would invalidate it -- confirm before
    # rewording.
    instruction = f"""Check the following profiles of 2 people
```
profile_1: {json.dumps(entity_profiles[0], indent=2)}
```
```
profile_2: {json.dumps(entity_profiles[1], indent=2)}
```

A smaller LM was asked to find a connection between the two people. Any attribute these two people might share satisfies as a connection. If there is no connection, then the LM is expected to answer "None".

The LM's response is: \"{lm_response}\"
"""

    if expected_answer is not None:
        instruction += f"""The expected answer is: \"{expected_answer}\". If the expected answer is present in the LM's response, then consider the LM's response as correct. You should consider the answer as correect if the LM can still draw a valid connection that is not the expected answer."""

    instruction += """Please verify if the response is correct or not. Say "yes" if the response is correct and "no" if it is not.
Make sure to put your answer starts with either "yes" or "no".

Consider that the small LM's response might get abruptly cut off, due to the token limit. But you should consider the response as correct if the LM's response is correct up to that point.
"""
    response = ASK_ORACLE_MODEL[oracle_model](prompt=instruction, use_cache=True)
    logger.debug(f"oracle response: {response}")
    answer = response.lower().strip().startswith("yes")

    return answer


def get_connection_on_entity_pair(
    mt: ModelandTokenizer,
    entities: tuple[str],
    prefix_generator: BiAssociationPrefix,
    n_valid=6,
    n_none=2,
    enable_reasoning=False,
    return_next_token_probs=False,
    answer_prefix: str = "",
):
    """Prompt the LM for a connection between two entities.

    A prefix is drawn from `prefix_generator`, a probing prompt is built with
    `prepare_probing_input`, and the first line of the answer returned by
    `get_lm_generated_answer` is kept.

    Returns:
        The answer string, or, when `return_next_token_probs` is set, a tuple
        `(answer, probs)` where `probs` is the `predict_next_token` result
        (k=15) for the prompt, or None when reasoning is enabled.
    """
    few_shot_prefix = prefix_generator.get_prefix(n_valid=n_valid, n_none=n_none)
    entity_pair = (entities[0], entities[1])
    connection_prompt = prepare_probing_input(
        mt=mt,
        entities=entity_pair,
        prefix=few_shot_prefix,
        answer_marker=prefix_generator.answer_marker,
        question_marker=prefix_generator.question_marker,
        block_separator=prefix_generator.block_separator,
        is_a_reasoning_model=enable_reasoning,
        answer_prefix=answer_prefix,
    )

    generated = get_lm_generated_answer(
        mt=mt,
        prompt=connection_prompt,
        is_a_reasoning_model=enable_reasoning,
    )
    # Only the first line of the generation is treated as the answer.
    answer = generated.split("\n")[0]

    if not return_next_token_probs:
        return answer

    if enable_reasoning is not False:
        logger.warning(
            "Next token probs are not meaningful for reasoning LMs. Will decode to <think> token always"
        )
        return answer, None

    next_token_probs = predict_next_token(
        mt=mt, inputs=connection_prompt.prompt, k=15
    )
    return answer, next_token_probs


#################################### CONNECTION EVALUATION ####################################
