import os
from argparse import ArgumentParser

from src.bongard_problems.classes import BongardResolveAttempt
from src.bongard_problems.data import print_classification_summary
from src.evaluation.model import evalute_answers_using_model, count_correct_answers
from src.evaluation.model import (
    get_perfect_resolve_attempt,
    get_wrong_resolve_attempt,
    check_solution_correctness,
)
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.prompting_techniques.contrastive import (
    get_comparisons,
    resolve_bongard_with_contrastive_prompting,
)
from src.prompting_techniques.descriptive import (
    get_descriptions,
    resolve_bongard_with_descriptive_prompting,
)
from src.prompting_techniques.direct import resolve_bongard_with_direct_prompting
from src.prompting_techniques.iterative import (
    get_iterative_descriptions,
    resolve_bongard_with_iterative_prompting,
)
from src.prompting_techniques.iterative_contrastive import (
    resolve_bongard_with_iterative_contrastive_prompting,
)
from src.prompting_techniques.prompt import (
    CommonPrompts,
    DescriptivePrompts,
    DescriptiveDirectPrompts,
    DescriptiveIterativePrompts,
    ContrastivePrompts,
    ContrastiveIterativePrompts,
    IsLabelCorrectPrompts,
    ImageToSidePrompts,
    ContrastiveDirectPrompts,
)
from src.prompting_techniques.two_at_once_classification import (
    resolve_bongard_with_last_image_to_side_classification_two_sides_at_once,
)


def make_parser() -> ArgumentParser:
    """Build the command-line parser for the experiment runner.

    Flags (in registration order):
        --model                    required model identifier (str)
        --dataset                  required; one of synthetic / hoi / openworld / rwr
        --data-dir                 root data directory (default "/app/data")
        --problem-ids-range-start  first problem id, inclusive (default 1)
        --problem-ids-range-end    last problem id, exclusive (default 101)
    """
    argument_specs = (
        ("--model", {"type": str, "required": True}),
        (
            "--dataset",
            {
                "type": str,
                "required": True,
                "choices": ["synthetic", "hoi", "openworld", "rwr"],
            },
        ),
        ("--data-dir", {"type": str, "default": "/app/data"}),
        ("--problem-ids-range-start", {"type": int, "default": 1}),
        ("--problem-ids-range-end", {"type": int, "default": 101}),
    )

    cli_parser = ArgumentParser()
    # Register in a fixed order so the generated --help output stays stable.
    for flag, options in argument_specs:
        cli_parser.add_argument(flag, **options)
    return cli_parser


def _resolve_dataset_paths(dataset: str, data_dir: str, problem_ids: list) -> tuple:
    """Return ``(splitted_data_path, labels_file, problem_ids)`` for *dataset*.

    For the ``rwr`` dataset, *problem_ids* is narrowed (order preserved) to the
    ids that have a numerically-named entry under the dataset directory.

    Raises:
        ValueError: if *dataset* is not a supported dataset name.
    """
    raw_dir = f"{data_dir}/raw"
    if dataset == "synthetic":
        return f"{raw_dir}/bongard_splitted", f"{raw_dir}/labels.csv", problem_ids
    if dataset == "hoi":
        return (
            f"{raw_dir}/bongard_hoi_splitted_mix",
            f"{raw_dir}/bongard_hoi_mix_labels.csv",
            problem_ids,
        )
    if dataset == "openworld":
        return (
            f"{raw_dir}/bongard_open_world_splitted",
            f"{raw_dir}/bongard_open_world_labels.csv",
            problem_ids,
        )
    if dataset == "rwr":
        splitted_data_path = f"{raw_dir}/bongard_rwr"
        # Only problems present on disk can be evaluated. Non-numeric entries
        # (e.g. ".DS_Store") are skipped instead of crashing int().
        available_ids = {
            int(entry)
            for entry in os.listdir(splitted_data_path)
            if entry.isdecimal()
        }
        return (
            splitted_data_path,
            f"{raw_dir}/labels.csv",
            [pid for pid in problem_ids if pid in available_ids],
        )
    raise ValueError(f"Unsupported dataset: {dataset}")


def _experiment_dir(data_dir: str, dataset: str, experiment_name: str) -> str:
    """Return the output directory path for one experiment on *dataset*."""
    return f"{data_dir}/processed/bongard/experiments/{dataset}_{experiment_name}"


def _start_experiment(
    model: LLMMessenger, data_dir: str, dataset: str, experiment_name: str
) -> str:
    """Point the model's log directory at the experiment dir and return that dir."""
    experiment_dir = _experiment_dir(data_dir, dataset, experiment_name)
    model.set_log_directory(experiment_dir)
    return experiment_dir


def _model_file(experiment_dir: str, model: LLMMessenger, kind: str) -> str:
    """Return ``<experiment_dir>/<model name>_<kind>.json``."""
    return f"{experiment_dir}/{model.get_name()}_{kind}.json"


def _run_label_check(
    model: LLMMessenger,
    answers_file: str,
    splitted_data_path: str,
    reevaluate: bool,
    expected_verdict: str,
) -> None:
    """Have the model judge the labels saved in *answers_file* and report the score.

    A problem counts as correctly judged when *expected_verdict* is a substring
    of the model's self-evaluation for that problem.
    """
    check_solution_correctness(
        answers_file,
        splitted_data_path,
        model,
        description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=IsLabelCorrectPrompts.QUESTION,
        reevaluate=reevaluate,
        use_joint_image=True,
    )
    resolve_attempt = BongardResolveAttempt.from_file(answers_file)
    matching_ids = [
        problem_id
        for problem_id, evaluation in resolve_attempt.get_self_evaluations().items()
        if expected_verdict in evaluation
    ]
    print(
        f"{model.get_name()} scored {len(matching_ids)} correct answers: {matching_ids}"
    )


def _evaluate_generated_answers(
    model: LLMMessenger, answers_file: str, labels_file: str, reevaluate: bool
) -> None:
    """Grade generated answers against the labels with the model and print the count."""
    evalute_answers_using_model(
        answers_file=answers_file,
        labels_file=labels_file,
        model=model,
        reevaluate=reevaluate,
    )
    print(count_correct_answers(answers_file))


def run(
    model: LLMMessenger,
    dataset: str,
    data_dir: str,
    problem_ids_range_start: int,
    problem_ids_range_end: int,
    reevaluate: bool = False,
):
    """Run every binary-classification and generation experiment on *dataset*.

    Each experiment writes its files to
    ``{data_dir}/processed/bongard/experiments/{dataset}_<experiment>`` and
    prints a score summary.

    Args:
        model: messenger used both to solve problems and to grade answers.
        dataset: one of ``synthetic``, ``hoi``, ``openworld``, ``rwr``.
        data_dir: root data directory containing ``raw/`` and ``processed/``.
        problem_ids_range_start: first problem id (inclusive).
        problem_ids_range_end: last problem id (exclusive).
        reevaluate: forwarded to every step; presumably forces recomputation of
            cached results (confirm in the callees).

    Raises:
        ValueError: if *dataset* is not supported.
    """
    splitted_data_path, labels_file, problem_ids = _resolve_dataset_paths(
        dataset,
        data_dir,
        list(range(problem_ids_range_start, problem_ids_range_end)),
    )

    print("Binary classification: correct answers as labels")

    experiment_dir = _start_experiment(
        model, data_dir, dataset, "binary-classification_joint-correct-answers"
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    get_perfect_resolve_attempt(
        problem_ids=problem_ids,
        labels_file=labels_file,
        answers_file=answers_file,
        model_name=model.get_name(),
    ).save(answers_file)
    _run_label_check(
        model, answers_file, splitted_data_path, reevaluate, expected_verdict="OK"
    )

    print("Binary classification: incorrect answers as labels")

    experiment_dir = _start_experiment(
        model, data_dir, dataset, "binary-classification_joint-incorrect-answers"
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    # NOTE(review): offset / circular_buffer_size are only forwarded to
    # get_wrong_resolve_attempt; presumably they pick a label belonging to a
    # different problem. Confirm against its implementation.
    get_wrong_resolve_attempt(
        problem_ids=problem_ids,
        labels_file=labels_file,
        answers_file=answers_file,
        offset=20,
        circular_buffer_size=100,
        model_name=model.get_name(),
    ).save(answers_file)
    # Here a "correct" judgement means the model flags the label as WRONG.
    _run_label_check(
        model, answers_file, splitted_data_path, reevaluate, expected_verdict="WRONG"
    )

    print("Binary classification: images to sides")

    experiment_dir = _start_experiment(
        model, data_dir, dataset, "binary-classification_joint-image-to-side-shuffle"
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ImageToSidePrompts.QUESTION,
        do_hallucination_intervention=True,
        reevaluate=reevaluate,
        use_joint_side_image=True,
        do_shuffle_test_panels=True,
    )
    print_classification_summary(answers_file)

    print("Generation: Direct")

    experiment_dir = _start_experiment(
        model, data_dir, dataset, "generation_prompting-direct"
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_direct_prompting(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=CommonPrompts.QUESTION,
        reevaluate=reevaluate,
    )
    _evaluate_generated_answers(model, answers_file, labels_file, reevaluate)

    print("Generation: Descriptive")

    # The Descriptive-Direct experiment below reuses this experiment's descriptions.
    descriptive_experiment = "generation_prompting-descriptive"
    experiment_dir = _start_experiment(model, data_dir, dataset, descriptive_experiment)
    descriptions_file = _model_file(experiment_dir, model, "descriptions")
    answers_file = _model_file(experiment_dir, model, "answers")
    get_descriptions(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        model=model,
        output_file=descriptions_file,
        question=DescriptivePrompts.DESCRIBE_IMAGE_PROMPT,
        reevaluate=reevaluate,
    )
    resolve_bongard_with_descriptive_prompting(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        descriptions_file=descriptions_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=DescriptivePrompts.QUESTION,
        reevaluate=reevaluate,
    )
    _evaluate_generated_answers(model, answers_file, labels_file, reevaluate)

    print("Generation: Descriptive-Direct")

    experiment_dir = _start_experiment(
        model, data_dir, dataset, "generation_prompting-descriptive-direct"
    )
    # Build the sibling path explicitly. A str.replace("-direct", "") on the
    # full path would also rewrite any "-direct" inside data_dir.
    descriptions_file = _model_file(
        _experiment_dir(data_dir, dataset, descriptive_experiment),
        model,
        "descriptions",
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_descriptive_prompting(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        descriptions_file=descriptions_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=DescriptiveDirectPrompts.QUESTION,
        with_image=True,
        reevaluate=reevaluate,
    )
    _evaluate_generated_answers(model, answers_file, labels_file, reevaluate)

    print("Generation: Descriptive-Iterative")

    # The directory name "prompting-iterative" is kept for compatibility with
    # previously stored results.
    experiment_dir = _start_experiment(
        model, data_dir, dataset, "generation_prompting-iterative"
    )
    descriptions_file = _model_file(experiment_dir, model, "descriptions")
    answers_file = _model_file(experiment_dir, model, "answers")
    get_iterative_descriptions(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        model=model,
        output_file=descriptions_file,
        task_description=DescriptiveIterativePrompts.DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,
        final_answer_prompt=DescriptiveIterativePrompts.FINAL_ANSWER_PROMPT,
        reevaluate=reevaluate,
        keep_image_history=False,
    )
    resolve_bongard_with_iterative_prompting(
        problem_ids=problem_ids,
        output_file=answers_file,
        model=model,
        descriptions_file=descriptions_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=DescriptiveIterativePrompts.QUESTION,
        reevaluate=reevaluate,
    )
    _evaluate_generated_answers(model, answers_file, labels_file, reevaluate)

    print("Generation: Contrastive")

    # The Contrastive-Direct experiment below reuses this experiment's comparisons.
    contrastive_experiment = "generation_prompting-contrastive"
    experiment_dir = _start_experiment(model, data_dir, dataset, contrastive_experiment)
    comparisons_file = _model_file(experiment_dir, model, "comparisons")
    answers_file = _model_file(experiment_dir, model, "answers")
    get_comparisons(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        model=model,
        output_file=comparisons_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastivePrompts.COMPARE_IMAGES_PROMPT,
        reevaluate=reevaluate,
    )
    resolve_bongard_with_contrastive_prompting(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        comparisons_file=comparisons_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastivePrompts.QUESTION,
        reevaluate=reevaluate,
    )
    _evaluate_generated_answers(model, answers_file, labels_file, reevaluate)

    print("Generation: Contrastive-Direct")

    experiment_dir = _start_experiment(
        model, data_dir, dataset, "generation_prompting-contrastive-direct"
    )
    comparisons_file = _model_file(
        _experiment_dir(data_dir, dataset, contrastive_experiment),
        model,
        "comparisons",
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_contrastive_prompting(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        comparisons_file=comparisons_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastiveDirectPrompts.QUESTION,
        with_image=True,
        reevaluate=reevaluate,
    )
    _evaluate_generated_answers(model, answers_file, labels_file, reevaluate)

    print("Generation: Contrastive-Iterative")

    experiment_dir = _start_experiment(
        model, data_dir, dataset, "generation_prompting-iterative-contrastive"
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_iterative_contrastive_prompting(
        problem_ids=problem_ids,
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastiveIterativePrompts.QUESTION,
        final_question=ContrastiveIterativePrompts.FINAL_QUESTION,
        reevaluate=reevaluate,
    )
    _evaluate_generated_answers(model, answers_file, labels_file, reevaluate)
