"""Prompt optimization using PSAO and Optuna."""

import optuna
import json

# Number of LLM samples drawn per Optuna trial when estimating correctness.
num_repeat = 10

# Candidate priority tags; the optimiser picks one of these for every segment.
annotation_options = ["[priority: high]", "[priority: medium]", "[priority: low]"]

# Answer-format instruction shared by most multiple-choice tasks.
_MC_SUFFIX = "\nOnly output one of {A, B, C, ..., None} as the final answer"

# Answer-format suffix per task (None means no suffix is appended). Each task
# is exposed twice in ``question_suffix_dict`` below, under a "no_sys_" and a
# "sys_" prefix, with identical suffixes.
_BASE_QUESTION_SUFFIXES = {
    "gsm8k_1": None,
    "gsm8k_2": None,
    "gsm8k_3": None,
    "gsm8k_4": None,
    "gsm8k": None,
    # NOTE(review): unlike the other multiple-choice suffixes this one lacks
    # "one of"; kept verbatim so existing prompts stay reproducible.
    "aqua": "\nOnly output {A, B, C, ..., None} as the final answer",
    "bbh_boolean_expressions": "\nOnly output one of {True, False, None} as the final answer",
    "bbh_causal_judgement": "\nOnly output one of {Yes, No, None} as the final answer",
    "bbh_movie_recommendation": _MC_SUFFIX,
    "bbh_ruin_names": _MC_SUFFIX,
    "bbh_temporal_sequences": _MC_SUFFIX,
    "mmlu_college_medicine_test": _MC_SUFFIX,
    "multiarith": None,
    "mmlu_high_school_us_history_test": _MC_SUFFIX,
    "mmlu_high_school_world_history_test": _MC_SUFFIX,
    "mmlu_professional_law_test": _MC_SUFFIX,
}

# All "no_sys_" entries first, then all "sys_" entries, each in the base order.
question_suffix_dict = {
    f"{prefix}{task}": suffix
    for prefix in ("no_sys_", "sys_")
    for task, suffix in _BASE_QUESTION_SUFFIXES.items()
}


def create_prompt(seg_lst, ann_lst):
    """Build a chat-message list with each segment prefixed by its annotation.

    The result has two messages: a fixed system instruction asking the model to
    respect the bracketed priority levels, and one user message in which every
    segment appears as ``"<annotation> <segment>"``, all joined by single spaces.
    Segments and annotations are paired with ``zip``, so surplus items in the
    longer list are ignored.
    """
    system_message = {
        "role": "system",
        "content": "Follow the priority levels indicated in brackets (high, medium, low) carefully when solving the problem below:",
    }
    user_text = " ".join(f"{ann} {seg}" for seg, ann in zip(seg_lst, ann_lst))
    return [system_message, {"role": "user", "content": user_text}]


def objective_func(
    trial,
    llm,
    prompt,
    answer,
    answer_schema,
    answer_dtype,
    seg_lst,
    ann_opt_lst,
    task_id,
    n_repeat=None,
):
    """Score one annotation assignment proposed by an Optuna trial.

    For each segment in ``seg_lst`` the trial picks one annotation from
    ``ann_opt_lst``. The annotated prompt is sent to ``llm`` ``n_repeat`` times,
    and the fraction of responses whose ``final_answer`` equals ``answer`` is
    returned as the objective value.

    Args:
        trial: Optuna trial; one categorical parameter ``ann_var_<i>`` is
            sampled per segment.
        llm: Object exposing ``generate(messages=..., response_format=...)``
            that returns a JSON document (parsed with ``json.loads``).
        prompt: Unused; the prompt is rebuilt from ``seg_lst`` and the sampled
            annotations. Kept so existing callers keep working.
        answer: Ground-truth answer, compared with ``==``.
        answer_schema: Passed to ``llm.generate`` as ``response_format``.
        answer_dtype: Callable used to coerce the model's ``final_answer``
            before comparison. If it raises, the raw value is compared instead.
        seg_lst: Prompt segments, in order.
        ann_opt_lst: Candidate annotation strings offered for every segment.
        task_id: Identifier used only in progress output.
        n_repeat: Number of LLM samples. Defaults to the module-level
            ``num_repeat``.

    Returns:
        float: Fraction of correct responses, in ``[0.0, 1.0]``. Responses that
        are not valid JSON objects count as incorrect.

    Raises:
        ValueError: If ``n_repeat`` is not positive.
    """
    if n_repeat is None:
        n_repeat = num_repeat
    if n_repeat <= 0:
        raise ValueError(f"n_repeat must be positive, got {n_repeat}")

    # One categorical Optuna parameter per segment. The "ann_var_<i>" names are
    # what gets read back from the best trial's params after optimisation.
    ann_lst = [
        trial.suggest_categorical(f"ann_var_{i}", ann_opt_lst)
        for i, _seg in enumerate(seg_lst)
    ]

    prompt = create_prompt(seg_lst, ann_lst)

    print("=" * 17)
    print(json.dumps(prompt, indent=2, ensure_ascii=False))
    print("-" * 17)

    n_correct = 0
    for rep in range(n_repeat):
        print(f"running task id: {task_id} at {rep} / {n_repeat}")
        raw_resp = llm.generate(
            messages=prompt,
            response_format=answer_schema,
        )

        # A malformed response must not raise: study.optimize() propagates
        # objective exceptions by default, which would abort the whole study.
        # Count it as an incorrect sample instead (it is never added to n_correct).
        try:
            parsed = json.loads(raw_resp)
        except (json.JSONDecodeError, TypeError) as err:
            print(f"  unparseable response ({err}) - counted as incorrect")
            continue
        if not isinstance(parsed, dict):
            print(f"  non-object response {parsed!r} - counted as incorrect")
            continue
        final_answer = parsed.get("final_answer")

        # Best-effort coercion to the ground-truth dtype: answer_dtype may be any
        # callable and the model may return arbitrary text.
        # NOTE(review): bool("False") is True, so if answer_dtype can be bool this
        # coercion mis-scores string answers; confirm against callers.
        try:
            final_answer_converted = answer_dtype(final_answer)
        except Exception:
            final_answer_converted = final_answer

        is_correct = final_answer_converted == answer
        print(f"  {final_answer_converted} - actual {answer}, {is_correct}")
        if is_correct:
            n_correct += 1

    correct_prob = n_correct / n_repeat

    print(f"Correctness probability: {correct_prob}")
    print("=" * 17)

    return correct_prob


def psao_optuna_optimisation(
    llm,
    prompt,
    answer,
    answer_schema,
    answer_dtype,
    task_id,
    optuna_study_name,
    optuna_db_name,
    seg_lst,
    ann_lst,
    n_trials=15,
    ann_opt_lst=None,
):
    """Optimise per-segment priority annotations using PSAO and Optuna.

    Runs ``n_trials`` new trials on an SQLite-backed Optuna study named
    ``"<optuna_study_name>_<task_id>"``, maximising ``objective_func`` (the
    fraction of correct LLM answers). The study is created with
    ``load_if_exists=True``, so trials from earlier runs under the same name
    are reused, and the best trial may come from a previous run.

    Args:
        llm, answer, answer_schema, answer_dtype, task_id: Forwarded to
            ``objective_func``.
        prompt: Forwarded to ``objective_func``, which ignores it.
        optuna_study_name: Base study name; ``task_id`` is appended.
        optuna_db_name: Path of the SQLite file used as Optuna storage.
        seg_lst: Prompt segments to annotate.
        ann_lst: Currently unused. NOTE(review): presumably meant as a seed
            annotation assignment or option list; confirm with callers.
        n_trials: Number of new trials to run (default 15, the previous
            hard-coded value).
        ann_opt_lst: Candidate annotations offered for every segment. Defaults
            to the module-level ``annotation_options``.

    Returns:
        tuple: ``(prompt_best, score)``. ``prompt_best`` is the chat prompt
        rebuilt from the best trial's annotations, and ``score`` is that
        trial's objective value.

    Raises:
        KeyError: If the best trial has no ``ann_var_<i>`` value for some
            segment (e.g. a reused study was built for fewer segments).
    """
    if ann_opt_lst is None:
        ann_opt_lst = annotation_options

    study_name = f"{optuna_study_name}_{task_id}"

    # Create or reuse the study so that earlier trials are carried over.
    study = optuna.create_study(
        study_name=study_name,
        direction="maximize",
        storage=f"sqlite:///{optuna_db_name}",
        load_if_exists=True,
    )

    study.optimize(
        lambda trial: objective_func(
            trial,
            llm,
            prompt,
            answer,
            answer_schema,
            answer_dtype,
            seg_lst,
            ann_opt_lst,
            task_id,
        ),
        n_trials=n_trials,
    )

    best_trial = study.best_trial
    best_params = best_trial.params
    print(f"Best parameters: {best_params}")

    # Parameter names must match those sampled in objective_func.
    param_names = [f"ann_var_{i}" for i, _seg in enumerate(seg_lst)]
    missing = [name for name in param_names if name not in best_params]
    if missing:
        raise KeyError(
            f"best trial of study {study_name!r} has no value for {missing}; "
            "was the stored study created for a different seg_lst?"
        )
    ann_lst_best = [best_params[name] for name in param_names]
    prompt_best = create_prompt(seg_lst, ann_lst_best)

    return prompt_best, best_trial.value
