from argparse import ArgumentParser

from dotenv import load_dotenv

from src.bongard_problems.classes import BongardResolveAttempt
from src.bongard_problems.data import print_classification_summary
from src.evaluation.model import evalute_answers_using_model, count_correct_answers
from src.evaluation.model import (
    get_perfect_resolve_attempt,
    get_wrong_resolve_attempt,
    check_solution_correctness,
)
from src.llm_messenger.messengers.huggingface_messenger import HuggingfaceMessenger
from src.prompting_techniques.contrastive import (
    get_comparisons,
    resolve_bongard_with_contrastive_prompting,
)
from src.prompting_techniques.descriptive import (
    get_descriptions,
    resolve_bongard_with_descriptive_prompting,
)
from src.prompting_techniques.direct import resolve_bongard_with_direct_prompting
from src.prompting_techniques.iterative import (
    get_iterative_descriptions,
    resolve_bongard_with_iterative_prompting,
)
from src.prompting_techniques.iterative_contrastive import (
    resolve_bongard_with_iterative_contrastive_prompting,
)
from src.prompting_techniques.prompt import (
    CommonPrompts,
    DescriptivePrompts,
    DescriptiveDirectPrompts,
    DescriptiveIterativePrompts,
    ContrastivePrompts,
    ContrastiveIterativePrompts,
    IsLabelCorrectPrompts,
    ImageToSidePrompts,
    ContrastiveDirectPrompts,
)
from src.prompting_techniques.two_at_once_classification import (
    resolve_bongard_with_last_image_to_side_classification_two_sides_at_once,
)


def make_parser() -> ArgumentParser:
    """Build the command-line parser for the experiment runner.

    Argument names, types, defaults and choices are unchanged; help text is
    provided so that ``--help`` documents, in particular, that the problem-id
    range is half-open (the end value is excluded).

    Returns:
        A configured ``ArgumentParser``.
    """
    parser = ArgumentParser(
        description="Run the Bongard-problem experiment suite with a single model."
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model name passed to HuggingfaceMessenger.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["synthetic", "hoi", "openworld"],
        help="Which Bongard dataset to run on.",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="/app/data",
        help=(
            "Root data directory: inputs are read from <data-dir>/raw and "
            "results are written under <data-dir>/processed."
        ),
    )
    parser.add_argument(
        "--problem-ids-range-start",
        type=int,
        default=1,
        help="First problem id to process (inclusive).",
    )
    parser.add_argument(
        "--problem-ids-range-end",
        type=int,
        default=101,
        help="Problem id at which to stop (exclusive).",
    )
    return parser


# Per-dataset (split-images directory, labels CSV) names under ``<data_dir>/raw``.
_DATASET_RAW_FILES = {
    "synthetic": ("bongard_splitted", "labels.csv"),
    "hoi": ("bongard_hoi_splitted_mix", "bongard_hoi_mix_labels.csv"),
    "openworld": ("bongard_open_world_splitted", "bongard_open_world_labels.csv"),
}


def _resolve_dataset_paths(args):
    """Return ``(splitted_data_path, labels_file)`` for ``args.dataset``.

    Raises:
        ValueError: if ``args.dataset`` has no entry in ``_DATASET_RAW_FILES``.
    """
    try:
        splitted_dir_name, labels_file_name = _DATASET_RAW_FILES[args.dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {args.dataset}") from None
    raw_dir = f"{args.data_dir}/raw"
    return f"{raw_dir}/{splitted_dir_name}", f"{raw_dir}/{labels_file_name}"


def _experiment_dir(args, experiment_name):
    """Return the results directory of ``experiment_name`` for ``args.dataset``."""
    return (
        f"{args.data_dir}/processed/bongard/experiments/"
        f"{args.dataset}_{experiment_name}"
    )


def _model_file(experiment_dir, model, suffix):
    """Return ``<experiment_dir>/<model name>_<suffix>.json``."""
    return f"{experiment_dir}/{model.get_name()}_{suffix}.json"


def _report_label_check(answers_file, splitted_data_path, model, expected_marker):
    """Have ``model`` check the labels stored in ``answers_file`` and print a tally.

    A problem counts as correctly judged when ``expected_marker`` occurs in
    its ``evaluation`` text after ``check_solution_correctness`` has run.
    """
    check_solution_correctness(
        answers_file,
        splitted_data_path,
        model,
        description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=IsLabelCorrectPrompts.QUESTION,
    )
    resolve_attempt = BongardResolveAttempt.from_file(answers_file)
    correct_evaluations = [
        solution.problem_id
        for solution in resolve_attempt.get_solutions().values()
        if expected_marker in solution.evaluation
    ]
    print(
        f"{model.get_name()} scored {len(correct_evaluations)} correct answers: {correct_evaluations}"
    )


def _evaluate_and_report(answers_file, labels_file, model, **evaluation_kwargs):
    """Evaluate ``answers_file`` against ``labels_file`` with ``model``; print the count.

    Extra keyword arguments (e.g. ``reevaluate=True``) are forwarded to
    ``evalute_answers_using_model`` unchanged; none are passed when omitted.
    """
    evalute_answers_using_model(
        answers_file=answers_file,
        labels_file=labels_file,
        model=model,
        **evaluation_kwargs,
    )
    print(count_correct_answers(answers_file))


def main(args):
    """Run every binary-classification and generation experiment for one model.

    Each stage creates its own ``HuggingfaceMessenger`` logging into that
    stage's experiment directory under
    ``<data_dir>/processed/bongard/experiments``. Stages run in a fixed order
    because the Descriptive-Direct and Contrastive-Direct stages read the
    descriptions/comparisons written by the Descriptive and Contrastive stages.

    Args:
        args: namespace from ``make_parser()``; reads ``model``, ``dataset``,
            ``data_dir``, ``problem_ids_range_start`` and
            ``problem_ids_range_end`` (the end is exclusive, used with ``range``).

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset.
    """
    splitted_data_path, labels_file = _resolve_dataset_paths(args)
    # Half-open id range; ``list(problem_id_range)`` builds a fresh list for
    # every call so no callee ever shares a list object with another.
    problem_id_range = range(args.problem_ids_range_start, args.problem_ids_range_end)
    # Base experiment names whose outputs are reused by their "-direct" variants.
    descriptive_experiment = "generation_prompting-descriptive"
    contrastive_experiment = "generation_prompting-contrastive"

    print("Binary classification: correct answers as labels")

    experiment_dir = _experiment_dir(args, "binary-classification_correct-answers")
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    answers_file = _model_file(experiment_dir, model, "answers")
    perfect_resolve_attempt = get_perfect_resolve_attempt(
        problem_ids=list(problem_id_range),
        labels_file=labels_file,
        answers_file=answers_file,
        model_name=model.get_name(),
    )
    perfect_resolve_attempt.save(answers_file)
    _report_label_check(answers_file, splitted_data_path, model, expected_marker="OK")

    print("Binary classification: incorrect answers as labels")

    experiment_dir = _experiment_dir(args, "binary-classification_incorrect-answers")
    offset = 20
    circular_buffer_size = 100
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    answers_file = _model_file(experiment_dir, model, "answers")
    wrong_resolve_attempt = get_wrong_resolve_attempt(
        problem_ids=list(problem_id_range),
        labels_file=labels_file,
        answers_file=answers_file,
        offset=offset,
        circular_buffer_size=circular_buffer_size,
        model_name=model.get_name(),
    )
    wrong_resolve_attempt.save(answers_file)
    # The stored labels are deliberately wrong here, so a "WRONG" verdict is a
    # correct judgement by the model.
    _report_label_check(answers_file, splitted_data_path, model, expected_marker="WRONG")

    print("Binary classification: images to sides")

    experiment_dir = _experiment_dir(args, "binary-classification_image-to-side")
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_last_image_to_side_classification_two_sides_at_once(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ImageToSidePrompts.QUESTION,
    )
    print_classification_summary(answers_file)

    print("Generation: Direct")

    experiment_dir = _experiment_dir(args, "generation_prompting-direct")
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_direct_prompting(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=CommonPrompts.QUESTION,
    )
    _evaluate_and_report(answers_file, labels_file, model)

    print("Generation: Descriptive")

    experiment_dir = _experiment_dir(args, descriptive_experiment)
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    descriptions_file = _model_file(experiment_dir, model, "descriptions")
    answers_file = _model_file(experiment_dir, model, "answers")
    get_descriptions(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        model=model,
        output_file=descriptions_file,
        question=DescriptivePrompts.DESCRIBE_IMAGE_PROMPT,
    )
    resolve_bongard_with_descriptive_prompting(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        descriptions_file=descriptions_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=DescriptivePrompts.QUESTION,
        reevaluate=True,
    )
    _evaluate_and_report(answers_file, labels_file, model, reevaluate=True)

    print("Generation: Descriptive-Direct")

    experiment_dir = _experiment_dir(args, f"{descriptive_experiment}-direct")
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    # Reuse the Descriptive stage's descriptions. The path is built explicitly:
    # ``str.replace("-direct", "")`` on the full path would also rewrite any
    # "-direct" inside ``data_dir``.
    descriptions_file = _model_file(
        _experiment_dir(args, descriptive_experiment), model, "descriptions"
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_descriptive_prompting(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        descriptions_file=descriptions_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=DescriptiveDirectPrompts.QUESTION,
        with_image=True,
        reevaluate=True,
    )
    _evaluate_and_report(answers_file, labels_file, model, reevaluate=True)

    print("Generation: Descriptive-Iterative")

    experiment_dir = _experiment_dir(args, "generation_prompting-iterative")
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    descriptions_file = _model_file(experiment_dir, model, "descriptions")
    answers_file = _model_file(experiment_dir, model, "answers")
    get_iterative_descriptions(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        model=model,
        output_file=descriptions_file,
        task_description=DescriptiveIterativePrompts.DESCRIBE_ITERATIVELY_SIDE_IMAGES_PROMPT,
        final_answer_prompt=DescriptiveIterativePrompts.FINAL_ANSWER_PROMPT,
    )
    resolve_bongard_with_iterative_prompting(
        problem_ids=list(problem_id_range),
        output_file=answers_file,
        model=model,
        descriptions_file=descriptions_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=DescriptiveIterativePrompts.QUESTION,
        reevaluate=True,
    )
    _evaluate_and_report(answers_file, labels_file, model, reevaluate=True)

    print("Generation: Contrastive")

    experiment_dir = _experiment_dir(args, contrastive_experiment)
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    comparisons_file = _model_file(experiment_dir, model, "comparisons")
    answers_file = _model_file(experiment_dir, model, "answers")
    get_comparisons(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        model=model,
        output_file=comparisons_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastivePrompts.COMPARE_IMAGES_PROMPT,
    )
    resolve_bongard_with_contrastive_prompting(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        comparisons_file=comparisons_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastivePrompts.QUESTION,
    )
    _evaluate_and_report(answers_file, labels_file, model)

    print("Generation: Contrastive-Direct")

    experiment_dir = _experiment_dir(args, f"{contrastive_experiment}-direct")
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    # Reuse the Contrastive stage's comparisons (explicit path, not str.replace).
    comparisons_file = _model_file(
        _experiment_dir(args, contrastive_experiment), model, "comparisons"
    )
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_contrastive_prompting(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        comparisons_file=comparisons_file,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastiveDirectPrompts.QUESTION,
        with_image=True,
        reevaluate=True,
    )
    _evaluate_and_report(answers_file, labels_file, model, reevaluate=True)

    print("Generation: Contrastive-Iterative")

    experiment_dir = _experiment_dir(args, "generation_prompting-iterative-contrastive")
    model = HuggingfaceMessenger(model_name=args.model, log_directory=experiment_dir)
    answers_file = _model_file(experiment_dir, model, "answers")
    resolve_bongard_with_iterative_contrastive_prompting(
        problem_ids=list(problem_id_range),
        splitted_data_path=splitted_data_path,
        output_file=answers_file,
        model=model,
        problem_description=CommonPrompts.PROBLEM_DESCRIPTION,
        question=ContrastiveIterativePrompts.QUESTION,
        final_question=ContrastiveIterativePrompts.FINAL_QUESTION,
    )
    _evaluate_and_report(answers_file, labels_file, model)


if __name__ == "__main__":
    # Parse the CLI, echo the parsed namespace, load .env settings, then run.
    args = make_parser().parse_args()
    print(args)
    load_dotenv()
    main(args)
