import os
from argparse import ArgumentParser
from typing import List

from dotenv import load_dotenv

from src.bongard_problems.data import get_data_paths
from src.evaluation.model import (
    evalute_answers_using_model,
    get_resolved_problems,
)
from src.llm_messenger.classes.llm_messenger import LLMMessenger
from src.llm_messenger.messengers.claude_messenger import ClaudeMessenger
from src.llm_messenger.messengers.google_messenger import GoogleMessenger
from src.llm_messenger.messengers.gpt_messenger import GPTMessenger
from src.prompting_techniques.prompt import EvaluationPrompts

# Populate os.environ from a .env file (python-dotenv) at import time, so the
# API keys read in make_messengers (OPENAI_API_KEY, GOOGLE_API_KEY,
# ANTHROPIC_API_KEY) are available when they are defined there.
load_dotenv()


def make_parser() -> ArgumentParser:
    """Build the command-line parser for the answer-evaluation script.

    Returns:
        A parser whose namespace is consumed by ``main``. It selects:
        the dataset/method experiment directory, the model whose answers are
        graded, the evaluator models (``--authors``), and the evaluation prompt.
    """
    # Evaluator models accepted by --authors. They are kept in one place so the
    # default and the allowed choices cannot drift apart. These names must
    # match the ones handled in make_messengers.
    supported_authors = [
        "gpt-4o",
        "gpt-4-turbo",
        "gemini-1.5-pro",
        "claude-3-5-sonnet-20240620",
    ]

    parser = ArgumentParser(
        description="Evaluate a model's Bongard-problem answers with LLM evaluators."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        choices=["synthetic", "hoi", "openworld", "rwr"],
        help="Dataset whose answers are evaluated.",
    )
    parser.add_argument(
        "--method",
        type=str,
        required=True,
        choices=[
            # Binary-classification methods are currently disabled:
            # "binary-classification_correct-answers",
            # "binary-classification_incorrect-answers",
            # "binary-classification_image-to-side",
            "generation_prompting-direct",
            "generation_prompting-descriptive",
            "generation_prompting-descriptive-direct",
            "generation_prompting-iterative",
            "generation_prompting-contrastive",
            "generation_prompting-contrastive-direct",
            "generation_prompting-iterative-contrastive",
        ],
        help=(
            "Answer-generation method; together with --dataset it selects "
            "the experiment directory."
        ),
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Model whose answers are evaluated (reads <model>_answers.json).",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        default="data",
        help="Root data directory.",
    )
    parser.add_argument(
        "--problem-ids-range-start",
        type=int,
        default=1,
        help="Start of the problem-id range (currently not read by main).",
    )
    parser.add_argument(
        "--problem-ids-range-end",
        type=int,
        default=101,
        help="End of the problem-id range (currently not read by main).",
    )
    parser.add_argument(
        "--authors",
        type=str,
        nargs="+",
        # Copy so the default list is a distinct object from ``choices``.
        default=list(supported_authors),
        choices=supported_authors,
        help="Evaluator models used to grade the answers.",
    )
    parser.add_argument(
        "--reevaluate",
        action="store_true",
        help="Forwarded to the evaluator as reevaluate=True.",
    )
    parser.add_argument(
        "--with-image",
        action="store_true",
        help="Forwarded to the evaluator as with_image=True.",
    )
    parser.add_argument(
        "--evaluation-prompt",
        type=str,
        default="STRICT_LOGIC_PROMPT",
        choices=["STRICT_LOGIC_PROMPT"],
        help="Name of the EvaluationPrompts entry used for grading.",
    )
    return parser


def make_messengers(experiment_dir: str, authors: List[str]) -> List[LLMMessenger]:
    """Create one LLM messenger per requested evaluator (author) name.

    Args:
        experiment_dir: Directory passed to every messenger as ``log_directory``.
        authors: Evaluator model names; each must be one of the keys of the
            registry below.

    Returns:
        The messengers, in the same order as ``authors``.

    Raises:
        ValueError: If an author name is not supported.
        KeyError: If the API-key environment variable for a requested author
            is not set.
    """
    # Supported evaluator name -> (messenger class, env var holding its API key).
    registry = {
        "gpt-4o": (GPTMessenger, "OPENAI_API_KEY"),
        "gpt-4-turbo": (GPTMessenger, "OPENAI_API_KEY"),
        "gemini-1.5-pro": (GoogleMessenger, "GOOGLE_API_KEY"),
        "claude-3-5-sonnet-20240620": (ClaudeMessenger, "ANTHROPIC_API_KEY"),
    }

    def build(name: str) -> LLMMessenger:
        if name not in registry:
            raise ValueError(f"Unsupported author: {name}")
        messenger_cls, key_variable = registry[name]
        # The environment is read only for authors that are actually requested.
        return messenger_cls(
            api_key=os.environ[key_variable],
            model_name=name,
            log_directory=experiment_dir,
        )

    return [build(name) for name in authors]


def main(args):
    """Grade one model's answers with every requested evaluator model.

    For each evaluator in ``args.authors``, this runs the evaluation on the
    ``<model>_answers.json`` file of the selected experiment. It then prints
    the evaluator label together with ``get_resolved_problems`` for that label.

    Args:
        args: Parsed command-line namespace produced by ``make_parser()``.
    """
    splitted_data_path, labels_file = get_data_paths(args.data_dir, args.dataset)
    experiment_dir = os.path.join(
        args.data_dir,
        "processed/bongard/experiments",
        f"{args.dataset}_{args.method}",
    )
    messengers = make_messengers(experiment_dir, args.authors)
    # Answers produced by the model under evaluation (--model); every
    # evaluator grades this same file.
    # NOTE(review): model ids such as "OpenGVLab/InternVL2-8B" contain "/",
    # so this path points into a subdirectory of experiment_dir. This
    # presumably matches how the answers were written; verify.
    answers_file = os.path.join(experiment_dir, f"{args.model}_answers.json")
    # The prompt does not depend on the evaluator, so resolve it once.
    prompt = EvaluationPrompts.from_name(args.evaluation_prompt)

    for messenger in messengers:
        # Label for this evaluator/prompt pair. It is passed to the evaluation
        # and used again to query the resolved problems.
        author = f"{messenger.get_name()}_{args.evaluation_prompt}"
        evalute_answers_using_model(
            answers_file=answers_file,
            labels_file=labels_file,
            model=messenger,
            splitted_data_path=splitted_data_path,
            author=author,
            with_image=args.with_image,
            reevaluate=args.reevaluate,
            prompt=prompt,
        )
        print(author, get_resolved_problems(answers_file, author))


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, echo them for the run log, then
    # evaluate. The .env file was already loaded by the module-level
    # load_dotenv() call, so it is not reloaded here.
    parser = make_parser()
    args = parser.parse_args()
    print(args)
    main(args)


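"""
# Example batch run (bash, from the repository root): for each dataset, grade the
# answers of every answering model for every method with all four evaluator
# models, then commit and push the updated data directory once per dataset.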
for dataset in "synthetic" "hoi" "openworld" "rwr"; do
  for method in "generation_prompting-direct" "generation_prompting-descriptive" "generation_prompting-descriptive-direct" "generation_prompting-iterative" "generation_prompting-contrastive" "generation_prompting-contrastive-direct" "generation_prompting-iterative-contrastive"; do
    echo "Method: ${method}"
    for model in "gpt-4o" "gpt-4-turbo" "gemini-1.5-pro" "claude-3-5-sonnet-20240620" "OpenGVLab/InternVL2-8B" "llava-hf/llava-v1.6-mistral-7b-hf" "microsoft/Phi-3.5-vision-instruct" "mistralai/Pixtral-12B-2409"; do
      PYTHONPATH=. python src/experiments/evaluate_answers.py --dataset "${dataset}" --method "${method}" --model "${model}" --authors "gpt-4o" "gpt-4-turbo" "gemini-1.5-pro" "claude-3-5-sonnet-20240620"
    done
  done
  git add data
  git commit -m "Evaluate answers for the ${dataset} dataset (automated commit)"
  git push
done
"""
