import os
import base64
import argparse
import time
import re
import multiprocessing as mp
from functools import partial
from typing import Dict, List, Any, Tuple, Optional

import pandas as pd
from openai import OpenAI
from tqdm import tqdm

from app.utils.dataset_manager import DatasetManager


def encode_image(image_path: str) -> str:
    """Return the contents of the file at ``image_path`` as base64 text.

    The file is opened in binary mode and read in full; the base64-encoded
    bytes are then decoded as UTF-8 so the result is a plain ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def extract_answer(response: str) -> Optional[str]:
    """Pull the final answer token out of a model response.

    Patterns are tried in priority order: an explicit ``ANSWER:`` / ``Answer:``
    / ``answer:`` marker, then a bold ``**X**`` token, then an
    ``<answer>X</answer>`` tag. The answer token itself must consist of
    uppercase letters and/or digits.

    Args:
        response: Raw text returned by the model. ``None``, non-string and
            empty values are tolerated.

    Returns:
        The matched answer token, or ``None`` when nothing matches.
    """
    if not isinstance(response, str) or not response:
        return None

    # Models frequently decorate the choice ("ANSWER: [B]", "ANSWER: (B)",
    # "**ANSWER:** B", "ANSWER: `B`"), so skip brackets, asterisks, quotes and
    # whitespace between the marker and the token.
    wrap = r"[\[\(\*\"'`\s]*"
    patterns = [
        rf"ANSWER:{wrap}([A-Z0-9]+)",
        rf"Answer:{wrap}([A-Z0-9]+)",
        # Also covers "Final answer: X", which contains "answer:".
        rf"answer:{wrap}([A-Z0-9]+)",
        r"\*\*([A-Z0-9]+)\*\*",
        r"<answer>\s*([A-Z0-9]+)\s*</answer>",
    ]

    for pattern in patterns:
        matches = re.findall(pattern, response)
        if matches:
            # The final answer is requested at the end of the response, so an
            # earlier marker inside the reasoning must not win over it.
            return matches[-1]

    return None


def _write_failed_reasoning(
    output_dir: str,
    instance_id: Any,
    model: str,
    correct_answer: Any,
    extracted_answer: str,
    prompt: str,
    model_response: str,
) -> None:
    """Write one wrong-answer transcript under ``output_dir/failed_reasonings``."""
    failed_dir = os.path.join(output_dir, "failed_reasonings")
    os.makedirs(failed_dir, exist_ok=True)

    # Model names contain "/" (provider/model), which cannot be in a filename.
    model_name_clean = model.replace("/", "_")
    filepath = os.path.join(failed_dir, f"{instance_id}_{model_name_clean}.txt")

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(f"Instance ID: {instance_id}\n")
        f.write(f"Model: {model}\n")
        f.write(f"Correct Answer: {correct_answer}\n")
        f.write(f"Model Answer: {extracted_answer}\n")
        f.write(f"Prompt: {prompt}\n")
        f.write("-" * 50 + "\n")
        f.write("Full Response:\n")
        f.write(model_response)


def evaluate_with_model(
    client: OpenAI, model: str, entry: Dict[str, Any], output_dir: str
) -> Tuple[bool, Optional[str], float]:
    """Ask ``model`` to solve one dataset entry and grade its answer.

    The entry's image is sent base64-encoded together with its prompt. The
    answer is extracted from the response text and compared, as strings, with
    the entry's correct answer. When an answer was extracted but is wrong, the
    full response is written to ``output_dir/failed_reasonings``.

    Args:
        client: OpenAI-compatible client used for the chat completion call.
        model: Model identifier passed to the API.
        entry: Must contain ``"ImagePath"``, ``"Prompt"`` and ``"Correct"``;
            ``"ID"`` is optional and defaults to ``"unknown"``.
        output_dir: Directory under which failed reasonings are stored.

    Returns:
        ``(is_correct, answer, response_time)``. ``answer`` is the extracted
        token, ``None`` when nothing could be parsed, or ``"ERROR: ..."`` when
        reading the image or calling the API failed. ``response_time`` is the
        duration of the API call in seconds (``0.0`` if it never started).

    Raises:
        KeyError: If ``entry`` lacks one of the required keys.
    """
    # Function-scope import keeps this block self-contained.
    import mimetypes

    image_path = entry["ImagePath"]
    prompt = entry["Prompt"]
    correct_answer = entry["Correct"]
    instance_id = entry.get("ID", "unknown")

    start_time: Optional[float] = None
    try:
        # Reading the image is inside the try so a missing or unreadable file
        # is reported as an error result instead of crashing the worker.
        base64_image = encode_image(image_path)
        guessed_type = mimetypes.guess_type(image_path)[0]
        # Fall back to PNG when the extension is unknown or not an image.
        if not guessed_type or not guessed_type.startswith("image/"):
            guessed_type = "image/png"

        start_time = time.time()
        completion = client.chat.completions.create(
            model=model,
            messages=[
                {
                    "role": "system",
                    "content": "You are an AI assistant tasked with solving spatial reasoning tasks. Think step by step and show your reasoning process.",
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"Here's a spatial reasoning task. {prompt}\n\nPlease think through this step by step:\n1. Analyze what you see in the image\n2. Consider the spatial relationships\n3. Work through the logic\n4. Provide your final answer\n\nEnd your response with 'ANSWER: [your choice]' where [your choice] is one of the given options.",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{guessed_type};base64,{base64_image}"
                            },
                        },
                    ],
                },
            ],
        )

        # The API may return no text content; treat that as an empty response.
        model_response = completion.choices[0].message.content or ""
        response_time = time.time() - start_time

        extracted_answer = extract_answer(model_response)
    except Exception as e:
        elapsed = time.time() - start_time if start_time is not None else 0.0
        return False, f"ERROR: {str(e)}", elapsed

    is_correct = extracted_answer is not None and str(extracted_answer) == str(
        correct_answer
    )

    if extracted_answer is not None and not is_correct:
        try:
            _write_failed_reasoning(
                output_dir,
                instance_id,
                model,
                correct_answer,
                extracted_answer,
                prompt,
                model_response,
            )
        except OSError as e:
            # The transcript is diagnostic only; failing to store it must not
            # turn a valid evaluation result into an error.
            tqdm.write(
                f"Warning: could not write failed reasoning for {instance_id}: {e}"
            )

    return is_correct, extracted_answer, response_time


def save_results_to_csv(results: List[Dict], output_path: str, mode: str = "w"):
    """Write ``results`` to ``output_path`` as CSV without an index column.

    With ``mode="w"``, or when ``output_path`` does not exist yet, the file is
    (re)created with a header row. Otherwise the rows are appended to the
    existing file and no header is written.
    """
    frame = pd.DataFrame(results)
    appending = mode != "w" and os.path.exists(output_path)

    if appending:
        frame.to_csv(output_path, index=False, mode="a", header=False)
        return

    frame.to_csv(output_path, index=False, mode="w")


def _append_batch_with_lock(
    batch: List[Dict],
    stats_path: str,
    lock_file: str,
    owner: str,
    max_attempts: int = 10,
) -> bool:
    """Append ``batch`` to ``stats_path`` while holding ``lock_file``.

    Returns ``True`` when the rows were written, ``False`` when the lock could
    not be obtained or the write failed (a warning is printed in both cases).
    """
    for attempt in range(max_attempts):
        try:
            # O_CREAT | O_EXCL creates the lock atomically: only one process
            # can succeed, unlike a separate "exists?" check then "create".
            fd = os.open(lock_file, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            # Another worker holds the lock: back off exponentially.
            if attempt < max_attempts - 1:
                time.sleep(0.1 * (2**attempt))
            continue
        except OSError as e:
            tqdm.write(f"Warning: could not create lock {lock_file}: {e}")
            return False

        # From here on this process owns the lock and must release it.
        try:
            with os.fdopen(fd, "w") as f:
                f.write(f"{owner}")
            save_results_to_csv(
                batch,
                stats_path,
                "a" if os.path.exists(stats_path) else "w",
            )
            return True
        except Exception as e:
            # Checkpointing is best-effort, but failures must be visible.
            tqdm.write(
                f"Warning: could not save results for {owner} to {stats_path}: {e}"
            )
            return False
        finally:
            try:
                os.remove(lock_file)
            except OSError:
                pass

    tqdm.write(f"Warning: timed out waiting for {lock_file}; will retry later")
    return False


def evaluate_model_on_instances(
    position: int,
    model: str,
    instances: List[Dict],
    dataset_manager: DatasetManager,
    provider_url: str,
    output_dir: str,
    save_every: int,
) -> List[Dict]:
    """Evaluate ``model`` on every instance and checkpoint results to disk.

    Args:
        position: Row used for this worker's tqdm progress bar.
        model: Model identifier to evaluate.
        instances: Dataset rows; each must contain ``"ID"`` and ``"SceneName"``.
        dataset_manager: Provides the full entry for an instance ID.
        provider_url: Base URL of the OpenAI-compatible API.
        output_dir: Directory holding ``stats.csv`` and failed reasonings.
        save_every: Checkpoint to ``stats.csv`` after every N instances;
            values ``<= 0`` disable checkpointing.

    Returns:
        One result dict per instance, in evaluation order. Rows that have not
        been checkpointed when the loop ends are not written here; they are
        only part of the returned list.
    """
    client = OpenAI(base_url=provider_url)
    results: List[Dict] = []
    # Number of leading entries of ``results`` already written to stats.csv.
    saved_count = 0

    stats_path = os.path.join(output_dir, "stats.csv")
    lock_file = f"{stats_path}.lock"

    for instance in tqdm(
        instances,
        desc=f"Evaluating {model}",
        position=position,
        lock_args=None,
    ):
        instance_id = instance["ID"]
        scene_name = instance["SceneName"]
        entry = dataset_manager.get_entry(instance_id)
        entry["ID"] = instance_id

        _, extracted_answer, _ = evaluate_with_model(
            client, model, entry, output_dir
        )

        results.append(
            {
                "InstanceId": instance_id,
                "SceneName": scene_name,
                "ModelName": model,
                "CorrectAnswer": entry["Correct"],
                "ModelAnswer": extracted_answer if extracted_answer else "PARSE_ERROR",
            }
        )

        if save_every > 0 and len(results) % save_every == 0:
            # Write everything not yet on disk, not just the last batch, so a
            # checkpoint that failed earlier is retried instead of lost.
            if _append_batch_with_lock(
                results[saved_count:], stats_path, lock_file, model
            ):
                saved_count = len(results)

    return results


def _load_evaluated_keys(stats_path: str) -> set:
    """Return the ``(InstanceId, ModelName)`` string pairs stored in ``stats_path``.

    A missing or unreadable file yields an empty set; unreadable files also
    produce a warning.
    """
    if not os.path.exists(stats_path):
        return set()

    try:
        # dtype=str keeps IDs exactly as written, so an ID such as "007" is
        # not parsed as the integer 7 and then missed on comparison.
        existing_df = pd.read_csv(stats_path, dtype=str)
        return set(zip(existing_df["InstanceId"], existing_df["ModelName"]))
    except (OSError, KeyError, ValueError) as e:
        # pandas' EmptyDataError and ParserError are ValueError subclasses.
        print(f"Warning: could not read existing results from {stats_path}: {e}")
        return set()


def evaluate_dataset(
    dataset_dir: str,
    models: List[str],
    provider_url: str,
    output_dir: str,
    save_every: int = 10,
) -> List[Dict]:
    """Evaluate every model on all dataset instances not yet in ``stats.csv``.

    Each model is handled by its own worker process. Instance/model pairs
    already present in ``output_dir/stats.csv`` are skipped, so an interrupted
    run can be resumed.

    Args:
        dataset_dir: Directory passed to ``DatasetManager``.
        models: Model identifiers to evaluate; duplicates are ignored.
        provider_url: Base URL of the OpenAI-compatible API.
        output_dir: Directory for ``stats.csv`` and failed reasonings
            (created if missing).
        save_every: Worker checkpoint interval, in instances.

    Returns:
        The result dicts produced during this call (empty when there was
        nothing left to evaluate).
    """
    dataset_manager = DatasetManager(dataset_dir)
    df = dataset_manager.get_dataset()

    os.makedirs(output_dir, exist_ok=True)
    stats_path = os.path.join(output_dir, "stats.csv")

    # A model listed twice would otherwise be evaluated twice.
    models = list(dict.fromkeys(models))

    evaluated_instances = _load_evaluated_keys(stats_path)

    model_instances: Dict[str, List[Dict]] = {model: [] for model in models}
    for _, row in df.iterrows():
        instance_data = row.to_dict()
        instance_id = instance_data["ID"]

        for model in models:
            if (str(instance_id), str(model)) not in evaluated_instances:
                model_instances[model].append(instance_data)

    models_to_evaluate = [model for model in models if model_instances[model]]

    if not models_to_evaluate:
        return []

    evaluate_func = partial(
        evaluate_model_on_instances,
        dataset_manager=dataset_manager,
        provider_url=provider_url,
        output_dir=output_dir,
        save_every=save_every,
    )

    # Share one lock across workers so their progress bars do not interleave.
    tqdm.set_lock(mp.RLock())
    pool = mp.Pool(
        processes=min(len(models_to_evaluate), mp.cpu_count()),
        initializer=tqdm.set_lock,
        initargs=(tqdm.get_lock(),),
    )

    try:
        results_list = pool.starmap(
            evaluate_func,
            [
                (position, model, model_instances[model])
                for position, model in enumerate(models_to_evaluate)
            ],
        )
    finally:
        pool.close()
        pool.join()

    all_results = [result for results in results_list for result in results]

    # Workers append checkpoints to stats.csv while they run. Re-read the file
    # and write only the rows that are still missing, so checkpointed rows are
    # not duplicated.
    already_saved = _load_evaluated_keys(stats_path)
    unsaved_results = [
        result
        for result in all_results
        if (str(result["InstanceId"]), str(result["ModelName"])) not in already_saved
    ]

    if unsaved_results:
        save_results_to_csv(
            unsaved_results, stats_path, "a" if os.path.exists(stats_path) else "w"
        )

    return all_results


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the evaluation script."""
    default_models = [
        "openai/chatgpt-4o-latest",
        "google/gemini-2.5-flash",
        "anthropic/claude-sonnet-4",
        "anthropic/claude-opus-4",
        "qwen/qwen2.5-vl-72b-instruct",
    ]

    parser = argparse.ArgumentParser(
        description="Evaluate models on captcha dataset")
    parser.add_argument(
        "--dataset-dir", type=str, default="dataset", help="Path to dataset directory"
    )
    parser.add_argument(
        "--models",
        type=str,
        nargs="+",
        default=default_models,
        help="Models to evaluate",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="failure_analysis",
        help="Directory to save results and failed reasonings",
    )
    parser.add_argument(
        "--provider-url",
        type=str,
        default="https://openrouter.ai/api/v1",
        help="API provider URL",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=2,
        help="Save results to CSV every N instances",
    )
    return parser


def main():
    """Parse command-line options and run the dataset evaluation.

    The dataset and output directories are converted to absolute paths
    before being handed to ``evaluate_dataset``.
    """
    options = _build_arg_parser().parse_args()

    evaluate_dataset(
        dataset_dir=os.path.abspath(options.dataset_dir),
        models=options.models,
        provider_url=options.provider_url,
        output_dir=os.path.abspath(options.output_dir),
        save_every=options.save_every,
    )


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
