import argparse
import base64
import mimetypes
import multiprocessing as mp
import os
import time
from functools import partial
from typing import Any, Dict, List, Tuple

import pandas as pd
from openai import OpenAI
from tqdm import tqdm

from app.utils.dataset_manager import DatasetManager


def encode_image(image_path: str) -> str:
    """Return the contents of the file at *image_path* as a base64 string.

    The file is read in binary mode and the base64 bytes are decoded to text
    so they can be embedded directly in a data URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")


def evaluate_with_model(
    client: OpenAI, model: str, entry: Dict[str, Any], k_passes: int = 1
) -> List[Tuple[bool, str, float]]:
    """Ask *model* the entry's question ``k_passes`` times and score each reply.

    Args:
        client: OpenAI-compatible client used for chat completions.
        model: Model identifier passed to the provider.
        entry: Dataset entry; must provide ``"ImagePath"``, ``"Prompt"`` and
            ``"Correct"``.
        k_passes: Number of independent requests to send.

    Returns:
        One ``(is_correct, answer, response_time_seconds)`` tuple per pass.
        A reply is correct when, after stripping surrounding whitespace, it
        equals the stripped expected answer. A failed request is recorded as
        ``(False, "ERROR: <message>", elapsed_seconds)`` instead of raising,
        so one bad call does not abort the run.
    """
    image_path = entry["ImagePath"]
    prompt = entry["Prompt"]
    # Normalize the expected answer the same way the reply is normalized.
    expected_answer = str(entry["Correct"]).strip()

    # The data-URL media type should match the actual file; fall back to PNG
    # (the previously hard-coded value) when the extension is not recognized.
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"
    image_url = f"data:{mime_type};base64,{encode_image(image_path)}"

    # The request payload is identical for every pass, so build it once.
    messages = [
        {
            "role": "system",
            "content": "You are an AI assistant tasked with solving spatial reasoning tasks. Analyze the image and answer the question. Return only the answer as plain text with no additional text or formatting.",
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": f"Here's a spatial reasoning task. {prompt}\n\nPlease answer the question based on the image. Return only the answer as plain text with no additional text or formatting.",
                },
                {
                    "type": "image_url",
                    "image_url": {"url": image_url},
                },
            ],
        },
    ]

    results: List[Tuple[bool, str, float]] = []
    for _ in range(k_passes):
        start_time = time.time()
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
            )
            # The SDK types ``message.content`` as optional; treat a missing
            # body as an empty answer instead of failing on ``.strip()``.
            answer = (completion.choices[0].message.content or "").strip()
            response_time = time.time() - start_time
            results.append((answer == expected_answer, answer, response_time))
        except Exception as e:
            # Best effort by design: record the failure and keep going.
            results.append((False, f"ERROR: {str(e)}", time.time() - start_time))

    return results


def save_results_to_csv(results: List[Dict], output_path: str, mode: str = "w"):
    """Write *results* (one dict per row) to the CSV at *output_path*.

    Args:
        results: Row dicts; their keys become the column names. An empty list
            is a no-op, so an existing file is never truncated or padded with
            a junk row.
        output_path: Destination CSV path.
        mode: ``"w"`` overwrites the file and writes a header row. ``"a"``
            appends rows without a header. Appending to a file that is
            missing or empty falls back to writing a header, so the CSV stays
            readable by ``pd.read_csv``.
    """
    if not results:
        return

    df = pd.DataFrame(results)

    # A zero-byte file has no header yet; appending header-less rows to it
    # would make the first data row be parsed as the column names.
    needs_header = (
        mode == "w"
        or not os.path.exists(output_path)
        or os.path.getsize(output_path) == 0
    )
    if needs_header:
        df.to_csv(output_path, index=False, mode="w")
    else:
        df.to_csv(output_path, index=False, mode="a", header=False)


def _acquire_lock(lock_file: str, owner: str, max_attempts: int = 10) -> bool:
    """Atomically create *lock_file*, retrying with exponential backoff.

    Returns True once this process has created (and so owns) the lock file.
    Returns False if the file still exists after *max_attempts* waits of
    0.1 s, 0.2 s, 0.4 s, ... (about 102 s in total with the default).
    """
    attempt = 0
    while True:
        try:
            # O_CREAT | O_EXCL makes "is it free?" and "take it" one atomic
            # step, so two workers can never both believe they hold the lock.
            fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            if attempt >= max_attempts:
                return False
            time.sleep(0.1 * (2**attempt))
            attempt += 1
            continue
        try:
            os.write(fd, owner.encode("utf-8"))
        except OSError:
            pass  # the owner tag is informational; the lock is held regardless
        finally:
            os.close(fd)
        return True


def _release_lock(lock_file: str) -> None:
    """Delete a lock file created by ``_acquire_lock``; tolerate it being gone."""
    try:
        os.remove(lock_file)
    except FileNotFoundError:
        pass


def evaluate_model_on_instances(
    position: int,
    model: str,
    instances: List[Dict],
    dataset_manager: DatasetManager,
    provider_url: str,
    k_passes: int,
    output_path: str,
    save_every: int,
) -> List[Dict]:
    """Evaluate *model* on every instance and periodically checkpoint to CSV.

    Designed to run in a pool worker (one per model). *position* selects the
    row of this worker's tqdm progress bar.

    When ``save_every > 0``, after every ``save_every`` instances all rows
    not yet written are appended to *output_path*, guarded by the lock file
    ``<output_path>.lock`` that is shared with the other workers. If the lock
    cannot be obtained, or the write fails, those rows stay pending and are
    retried at the next checkpoint.

    NOTE(review): a lock file left behind by a crashed run is never cleared
    here, and each later checkpoint would then wait about 102 s before giving
    up. Delete a stale ``.lock`` file before resuming.

    Returns:
        Every result row produced by this call, including rows that a
        checkpoint has already written to disk.
    """
    client = OpenAI(base_url=provider_url)
    results: List[Dict] = []
    saved_upto = 0  # results[:saved_upto] have already been written to disk
    lock_file = f"{output_path}.lock"

    for instance in tqdm(
        instances,
        desc=f"Evaluating {model}",
        position=position,
        lock_args=None,
    ):
        instance_id = instance["ID"]
        scene_name = instance["SceneName"]
        entry = dataset_manager.get_entry(instance_id)

        evaluation_results = evaluate_with_model(client, model, entry, k_passes)
        response_times = [rt for _, _, rt in evaluation_results]

        # Key order defines the CSV column order; keep it stable.
        results.append(
            {
                "InstanceId": instance_id,
                "SceneName": scene_name,
                "ModelName": model,
                "CorrectAnswer": entry["Correct"],
                "K_Passes": [int(ok) for ok, _, _ in evaluation_results],
                "K_Answers": [answer for _, answer, _ in evaluation_results],
                "K_ResponseTimes": response_times,
                "AvgResponseTime": (
                    sum(response_times) / len(response_times)
                    if response_times
                    else 0.0
                ),
            }
        )

        if save_every <= 0 or len(results) % save_every != 0:
            continue

        if not _acquire_lock(lock_file, model):
            tqdm.write(f"[{model}] {lock_file} is busy; deferring checkpoint")
            continue

        try:
            # Include rows from any earlier checkpoint that was deferred.
            save_results_to_csv(results[saved_upto:], output_path, "a")
            saved_upto = len(results)
        except Exception as exc:
            # Don't crash the worker; the rows stay pending for a retry.
            tqdm.write(f"[{model}] checkpoint save failed: {exc}")
        finally:
            _release_lock(lock_file)

    return results


def _load_evaluated_keys(output_path: str) -> set:
    """Return the ``"<InstanceId>_<ModelName>"`` keys recorded in *output_path*.

    A missing file yields an empty set. A file that cannot be read or parsed
    is reported and also treated as empty (best effort), so the run
    re-evaluates instead of aborting.
    """
    if not os.path.exists(output_path):
        return set()
    try:
        # dtype=str keeps IDs exactly as written (e.g. "007" is not turned
        # into the integer 7), so keys match the IDs used when writing.
        existing = pd.read_csv(
            output_path, usecols=["InstanceId", "ModelName"], dtype=str
        )
    except (OSError, ValueError) as exc:
        tqdm.write(f"Warning: ignoring unreadable results file {output_path}: {exc}")
        return set()
    return {
        f"{instance_id}_{model_name}"
        for instance_id, model_name in zip(existing["InstanceId"], existing["ModelName"])
    }


def evaluate_dataset(
    dataset_dir: str,
    models: List[str],
    provider_url: str,
    output_path: str,
    save_every: int = 10,
    k_passes: int = 1,
):
    """Evaluate every model on every dataset instance, resuming from *output_path*.

    (instance, model) pairs already present in *output_path* are skipped.
    The remaining models each run in their own pool process (capped at the
    CPU count), and the workers checkpoint to *output_path* every
    *save_every* instances. After the pool finishes, only rows that are not
    already in the CSV are appended, so each pair is written exactly once.

    Returns:
        The result rows produced in this run (including those the workers
        already checkpointed), or ``[]`` when nothing was left to evaluate.
    """
    dataset_manager = DatasetManager(dataset_dir)
    df = dataset_manager.get_dataset()

    evaluated_instances = _load_evaluated_keys(output_path)

    model_instances: Dict[str, List[Dict]] = {model: [] for model in models}
    for _, row in df.iterrows():
        instance_data = row.to_dict()
        instance_id = instance_data["ID"]
        for model in models:
            if f"{instance_id}_{model}" not in evaluated_instances:
                model_instances[model].append(instance_data)

    models_to_evaluate = [model for model in models if model_instances[model]]
    if not models_to_evaluate:
        return []

    evaluate_func = partial(
        evaluate_model_on_instances,
        dataset_manager=dataset_manager,
        provider_url=provider_url,
        k_passes=k_passes,
        output_path=output_path,
        save_every=save_every,
    )

    # Share one tqdm lock with the workers so their progress bars
    # (one row per model) don't overwrite each other.
    tqdm.set_lock(mp.RLock())
    pool = mp.Pool(
        processes=min(len(models_to_evaluate), mp.cpu_count()),
        initializer=tqdm.set_lock,
        initargs=(tqdm.get_lock(),),
    )

    try:
        results_list = pool.starmap(
            evaluate_func,
            [
                (position, model, model_instances[model])
                for position, model in enumerate(models_to_evaluate)
            ],
        )
    finally:
        pool.close()
        pool.join()

    all_results = [row for results in results_list for row in results]

    # The workers may already have checkpointed many of these rows; re-read
    # the CSV and append only the rest to avoid duplicate (instance, model) rows.
    already_saved = _load_evaluated_keys(output_path)
    unsaved = [
        result
        for result in all_results
        if f"{result['InstanceId']}_{result['ModelName']}" not in already_saved
    ]
    if unsaved:
        save_results_to_csv(
            unsaved, output_path, "a" if os.path.exists(output_path) else "w"
        )

    return all_results


def main():
    """Parse command-line options and run ``evaluate_dataset``.

    The dataset directory and output path are resolved to absolute paths.
    ``--k-passes`` must be at least 1. With zero passes, every instance would
    be written with empty results and then skipped as already evaluated on
    the next run.
    """
    parser = argparse.ArgumentParser(description="Evaluate models on captcha dataset")
    parser.add_argument(
        "--dataset-dir", type=str, default="dataset", help="Path to dataset directory"
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=[
            "google/gemini-2.5-pro",
            "openai/chatgpt-4o-latest",
            "google/gemini-2.5-flash",
            "openai/o4-mini",
            "anthropic/claude-sonnet-4",
            "anthropic/claude-opus-4",
            "meta-llama/llama-4-maverick",
            "mistralai/mistral-medium-3",
            "qwen/qwen2.5-vl-72b-instruct",
            "microsoft/phi-4-multimodal-instruct",
        ],
        help="Models to evaluate",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="evaluation_results.csv",
        help="Path to save evaluation results CSV",
    )
    parser.add_argument(
        "--provider-url",
        type=str,
        default="https://openrouter.ai/api/v1",
        help="API provider URL",
    )
    parser.add_argument(
        "--k-passes",
        type=int,
        default=3,
        help="Number of times to evaluate each instance",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=2,
        help="Save results to CSV every N instances",
    )

    args = parser.parse_args()

    # Reject configurations that would silently record empty evaluations.
    if args.k_passes < 1:
        parser.error("--k-passes must be at least 1")

    dataset_dir = os.path.abspath(args.dataset_dir)
    output_path = os.path.abspath(args.output)

    evaluate_dataset(
        dataset_dir=dataset_dir,
        models=args.models,
        provider_url=args.provider_url,
        output_path=output_path,
        k_passes=args.k_passes,
        save_every=args.save_every,
    )


# Run the CLI only when executed as a script, not on import (this also stops
# multiprocessing workers that re-import this module, e.g. under the "spawn"
# start method, from launching a nested evaluation).
if __name__ == "__main__":
    main()
