import json
import os
import base64
import argparse
import time
import multiprocessing as mp
from functools import partial
from typing import Dict, List, Any, Tuple

import pandas as pd
from openai import OpenAI
from tqdm import tqdm

from app.utils.dataset_manager import DatasetManager


def encode_image(image_path: str) -> str:
    """Return the contents of the file at *image_path* as base64 text.

    The file is read in binary mode and the base64 bytes are decoded as
    UTF-8 so the result can be embedded in a string (e.g. a data URL).
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def evaluate_with_model(
    client: OpenAI, model: str, entry: Dict[str, Any], k_passes: int = 1
) -> List[Tuple[bool, str]]:
    """Ask *model* to solve one captcha entry ``k_passes`` times.

    Args:
        client: Client whose ``chat.completions.create`` is called per pass.
        model: Model identifier forwarded to the API.
        entry: Dataset entry; the keys ``"ImagePath"``, ``"Prompt"`` and
            ``"Correct"`` (a JSON-encoded collection of grid indices) are
            read.
        k_passes: Number of independent requests to issue.

    Returns:
        One ``(is_correct, answer_text)`` tuple per pass. A pass is correct
        when the set of integers parsed from the comma-separated reply
        equals the set of expected indices. A request failure, or a reply
        without text content, is recorded as ``(False, "ERROR: ...")``.

    Raises:
        Errors raised while reading ``entry`` keys, decoding
        ``entry["Correct"]`` or encoding the image propagate to the caller;
        only the per-pass request and parsing are guarded.
    """
    image_path = entry["ImagePath"]
    prompt = entry["Prompt"]
    correct_answer = json.loads(entry["Correct"])

    base64_image = encode_image(image_path)

    # The request payload is the same for every pass, so build it once.
    messages = [
        {
            "role": "system",
            "content": "You are an AI assistant tasked with solving recaptcha tasks. Analyze the image and answer the question. Return only the coordinates as comma-separated numbers with no additional text or formatting.",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Here's a recaptcha task. {prompt}\n\nAnswer with coordinates where top-left is 0 and bottom-right is 8 in a 3x3 grid:\n0 1 2\n3 4 5\n6 7 8\n\nReturn only comma-separated coordinates (e.g., 1,4,7) with no additional text.",
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{base64_image}"
                    },
                },
            ],
        },
    ]

    results = []

    for _ in range(k_passes):
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
            )

            model_response = completion.choices[0].message.content

            # The API may return a message without text content; report that
            # explicitly instead of failing on ``None.strip()`` below.
            if model_response is None:
                results.append((False, "ERROR: model returned no text content"))
                continue

            answer_text = model_response.strip()

            try:
                model_coords = [int(x.strip()) for x in answer_text.split(",")]
                # Order and repetition do not matter, only the chosen cells.
                is_correct = set(model_coords) == set(correct_answer)
            except (ValueError, TypeError):
                # Reply (or expected answer) is not a list of integers.
                is_correct = False
            results.append((is_correct, answer_text))

        except Exception as e:
            # Keep evaluating the remaining passes; the failure is recorded
            # as the answer text so it shows up in the results file.
            results.append((False, f"ERROR: {str(e)}"))

    return results


def save_results_to_csv(results: List[Dict], output_path: str, mode: str = "w"):
    """Write *results* to *output_path* as CSV, either fresh or appended.

    Rows are appended without a header only when ``mode`` is not ``"w"``
    and the file already exists; in every other case the file is
    (re)written from scratch including the header row.
    """
    frame = pd.DataFrame(results)
    appending = mode != "w" and os.path.exists(output_path)
    frame.to_csv(
        output_path,
        index=False,
        mode="a" if appending else "w",
        header=not appending,
    )


def _acquire_save_lock(lock_file: str, owner: str, max_attempts: int = 10) -> bool:
    """Atomically create *lock_file*, backing off while another holder exists.

    Returns True once this process created the file, or False if the file
    still existed after ``max_attempts`` exponentially growing waits.
    """
    for attempt in range(max_attempts + 1):
        try:
            # O_EXCL turns "check, then create" into one atomic step, so two
            # workers can never both believe they own the lock.
            fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            if attempt < max_attempts:
                time.sleep(0.1 * (2**attempt))
            continue
        try:
            os.write(fd, owner.encode("utf-8"))
        except OSError:
            # The owner name is diagnostic only; the lock is the file itself.
            pass
        finally:
            os.close(fd)
        return True
    return False


def _flush_pending_results(
    pending: List[Dict], output_path: str, lock_file: str, owner: str
) -> bool:
    """Append *pending* to the results CSV under the lock file.

    Returns True when the rows were written. Failures are reported via
    ``tqdm.write`` and yield False so the caller can retry later.
    """
    try:
        acquired = _acquire_save_lock(lock_file, owner)
    except OSError as exc:
        tqdm.write(f"Could not create lock file {lock_file}: {exc}")
        return False
    if not acquired:
        tqdm.write(f"Timed out waiting for {lock_file}; will retry on the next save")
        return False

    try:
        save_results_to_csv(
            pending,
            output_path,
            "a" if os.path.exists(output_path) else "w",
        )
    except Exception as exc:
        tqdm.write(f"Failed to save results for {owner} to {output_path}: {exc}")
        return False
    finally:
        # Only reached after this process created the lock, so it is ours.
        try:
            os.remove(lock_file)
        except FileNotFoundError:
            pass
    return True


def evaluate_model_on_instances(
    position: int,
    model: str,
    instances: List[Dict],
    dataset_manager: DatasetManager,
    provider_url: str,
    k_passes: int,
    output_path: str,
    save_every: int,
) -> List[Dict]:
    """Evaluate *model* on every instance and periodically save the results.

    Args:
        position: Row of this worker's progress bar.
        model: Model identifier passed to the API and stored in each row.
        instances: Dataset rows; the keys ``"ID"`` and ``"SceneName"`` are
            read.
        dataset_manager: Used to look up the full entry for each ID.
        provider_url: Base URL for the ``OpenAI`` client.
        k_passes: Number of requests per instance.
        output_path: CSV file results are appended to; ``<output_path>.lock``
            serialises writers.
        save_every: Save after every N instances; ``0`` or less disables
            periodic saving.

    Returns:
        One result dict per instance, whether or not it was saved to disk.
    """
    client = OpenAI(base_url=provider_url)
    results = []
    # Rows produced since the last successful save. A skipped or failed
    # save leaves them here so the next save attempt writes them as well.
    pending = []

    lock_file = f"{output_path}.lock"

    for instance in tqdm(
        instances,
        desc=f"Evaluating {model}",
        position=position,
        lock_args=None,
    ):
        instance_id = instance["ID"]
        scene_name = instance["SceneName"]
        entry = dataset_manager.get_entry(instance_id)

        evaluation_results = evaluate_with_model(client, model, entry, k_passes)

        passes = [int(is_correct) for is_correct, _ in evaluation_results]
        answers = [model_answer for _, model_answer in evaluation_results]

        result = {
            "InstanceId": instance_id,
            "SceneName": scene_name,
            "ModelName": model,
            "CorrectAnswer": entry["Correct"],
            "K_Passes": passes,
            "K_Answers": answers,
        }

        results.append(result)
        pending.append(result)

        if save_every > 0 and len(results) % save_every == 0:
            if _flush_pending_results(pending, output_path, lock_file, model):
                pending = []

    return results


def _load_evaluated_keys(output_path: str) -> set:
    """Return the ``"<InstanceId>_<ModelName>"`` keys stored in the CSV.

    A missing, empty or unreadable file yields an empty set (best effort);
    read failures are reported via ``tqdm.write``.
    """
    if not os.path.exists(output_path):
        return set()
    try:
        # Read both columns as text so IDs keep their written form
        # (e.g. leading zeros) and match the keys built from the dataset.
        existing_df = pd.read_csv(
            output_path, usecols=["InstanceId", "ModelName"], dtype=str
        )
        return {
            f"{instance_id}_{model_name}"
            for instance_id, model_name in zip(
                existing_df["InstanceId"], existing_df["ModelName"]
            )
        }
    except Exception as exc:
        tqdm.write(f"Could not read existing results from {output_path}: {exc}")
        return set()


def evaluate_dataset(
    dataset_dir: str,
    models: List[str],
    provider_url: str,
    output_path: str,
    save_every: int = 10,
    k_passes: int = 1,
):
    """Evaluate every model on every dataset instance not yet in the CSV.

    Each model is evaluated in its own worker process. Instance/model pairs
    already present in ``output_path`` are skipped, so an interrupted run
    can be resumed.

    Args:
        dataset_dir: Directory handed to ``DatasetManager``.
        models: Model identifiers; repeated names are evaluated once.
        provider_url: Base URL of the API provider.
        output_path: CSV file that results are read from and appended to.
        save_every: Workers save after every N instances.
        k_passes: Number of requests per instance.

    Returns:
        The result dicts produced by this run (an empty list when there was
        nothing left to evaluate).
    """
    dataset_manager = DatasetManager(dataset_dir)
    df = dataset_manager.get_dataset()

    # A repeated model name would otherwise queue each instance several
    # times and start several workers for the same model.
    models = list(dict.fromkeys(models))

    evaluated_instances = _load_evaluated_keys(output_path)

    model_instances = {model: [] for model in models}

    for _, row in df.iterrows():
        instance_data = row.to_dict()
        instance_id = instance_data["ID"]

        for model in models:
            instance_key = f"{instance_id}_{model}"
            if instance_key not in evaluated_instances:
                model_instances[model].append(instance_data)

    models_to_evaluate = [model for model in models if model_instances[model]]

    if not models_to_evaluate:
        return []

    evaluate_func = partial(
        evaluate_model_on_instances,
        dataset_manager=dataset_manager,
        provider_url=provider_url,
        k_passes=k_passes,
        output_path=output_path,
        save_every=save_every,
    )

    # Share one lock between all workers so their progress bars do not
    # overwrite each other.
    tqdm.set_lock(mp.RLock())
    pool = mp.Pool(
        processes=min(len(models_to_evaluate), mp.cpu_count()),
        initializer=tqdm.set_lock,
        initargs=(tqdm.get_lock(),),
    )

    try:
        results_list = pool.starmap(
            evaluate_func,
            [
                (position, model, model_instances[model])
                for position, model in enumerate(models_to_evaluate)
            ],
        )
    finally:
        pool.close()
        pool.join()

    all_results = [result for results in results_list for result in results]

    # Workers already appended most rows while running. Write only the rows
    # that are still missing, so nothing ends up in the CSV twice.
    saved_keys = _load_evaluated_keys(output_path)
    unsaved_results = [
        result
        for result in all_results
        if f"{result['InstanceId']}_{result['ModelName']}" not in saved_keys
    ]

    if unsaved_results:
        save_results_to_csv(
            unsaved_results, output_path, "a" if os.path.exists(output_path) else "w"
        )

    return all_results


def main():
    """Parse the command-line options and run the dataset evaluation."""
    default_models = [
        "google/gemini-2.5-pro",
        "openai/chatgpt-4o-latest",
        "google/gemini-2.5-flash",
        "openai/o4-mini",
        "anthropic/claude-sonnet-4",
        "anthropic/claude-opus-4",
        "meta-llama/llama-4-maverick",
        "mistralai/mistral-medium-3",
        "qwen/qwen2.5-vl-72b-instruct",
        "microsoft/phi-4-multimodal-instruct",
    ]

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        (
            "--dataset-dir",
            {"type": str, "default": "dataset", "help": "Path to dataset directory"},
        ),
        (
            "--models",
            {
                "type": str,
                "nargs": "+",
                "default": default_models,
                "help": "Models to evaluate",
            },
        ),
        (
            "--output",
            {
                "type": str,
                "default": "evaluation_results.csv",
                "help": "Path to save evaluation results CSV",
            },
        ),
        (
            "--provider-url",
            {
                "type": str,
                "default": "https://openrouter.ai/api/v1",
                "help": "API provider URL",
            },
        ),
        (
            "--k-passes",
            {
                "type": int,
                "default": 3,
                "help": "Number of times to evaluate each instance",
            },
        ),
        (
            "--save-every",
            {
                "type": int,
                "default": 1,
                "help": "Save results to CSV every N instances",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Evaluate models on captcha dataset")
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    options = cli.parse_args()

    # Both paths are made absolute before being handed on.
    evaluate_dataset(
        dataset_dir=os.path.abspath(options.dataset_dir),
        models=options.models,
        provider_url=options.provider_url,
        output_path=os.path.abspath(options.output),
        k_passes=options.k_passes,
        save_every=options.save_every,
    )


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
