import argparse
import base64
import mimetypes
import os
from typing import Dict, List

import pandas as pd
from openai import OpenAI
from tqdm import tqdm

from app.utils.dataset_manager import DatasetManager


def encode_image(image_path: str) -> str:
    """Return the bytes of the file at ``image_path`` as a Base64 text string.

    The file is read in binary mode and the Base64 output is decoded as UTF-8,
    which makes it ready for embedding in a ``data:`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode("utf-8")


def get_judge_verdict(
    client: OpenAI,
    judge_model: str,
    image_path: str,
    correct_answer: str,
    model_reasoning: str,
    model_answer: str,
) -> str:
    """Ask a judge model to summarize, in one sentence, why a model's answer was wrong.

    The task image is sent inline as a Base64 ``data:`` URL alongside the correct
    answer, the evaluated model's reasoning, and its final answer.

    Args:
        client: OpenAI-compatible client used for the chat completion call.
        judge_model: Model identifier passed to ``chat.completions.create``.
        image_path: Path to the task image on disk.
        correct_answer: Ground-truth answer for the task.
        model_reasoning: Full reasoning text produced by the evaluated model.
        model_answer: Final answer produced by the evaluated model.

    Returns:
        The judge's reply stripped of surrounding whitespace. If the API call
        fails, or the reply has no text, a string starting with ``"ERROR: "``
        is returned instead.

    Raises:
        OSError: If ``image_path`` cannot be read. The image is encoded before
            the guarded API call, so read failures propagate to the caller.
    """
    base64_image = encode_image(image_path)

    # Label the data URL with the file's actual image type (e.g. image/jpeg).
    # Unknown or non-image extensions fall back to PNG, which was the previous
    # hard-coded value, so .png datasets behave exactly as before.
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"

    try:
        completion = client.chat.completions.create(
            model=judge_model,
            messages=[
                {
                    "role": "system",
                    "content": "You are an expert at analyzing spatial reasoning failures. Provide concise, one-sentence summaries of reasoning errors.",
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": f"Here is a spatial reasoning task with an image. The correct answer is '{correct_answer}'. A model provided this reasoning and answer:\n\nModel reasoning:\n{model_reasoning}\n\nModel answer: {model_answer}\n\nIn one sentence, summarize what was the crucial error in reasoning that led to the wrong answer.",
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{base64_image}"
                            },
                        },
                    ],
                },
            ],
        )

        content = completion.choices[0].message.content
        # The message content can be None (e.g. refusals / filtered replies);
        # report that explicitly instead of failing on .strip().
        if not content or not content.strip():
            return "ERROR: judge model returned an empty response"
        return content.strip()

    except Exception as e:
        # Best-effort: any API/transport failure becomes a verdict string so
        # one bad request does not abort the whole batch.
        return f"ERROR: {str(e)}"


def _load_processed_keys(output_csv: str) -> set:
    """Return the ``(InstanceId, ModelName)`` string pairs already recorded in ``output_csv``.

    Both values go through ``str`` so they compare equal to the keys built from
    ``stats.csv`` rows. A missing or zero-byte file yields an empty set. An
    unreadable file, or one without the expected columns, also yields an empty
    set and prints a warning.
    """
    if not os.path.exists(output_csv):
        return set()
    try:
        existing_df = pd.read_csv(output_csv)
    except pd.errors.EmptyDataError:
        # Zero-byte file (e.g. left by an interrupted first write): nothing done yet.
        return set()
    except Exception as e:
        print(f"Warning: could not read existing verdicts from {output_csv}: {e}")
        return set()
    if not {"InstanceId", "ModelName"}.issubset(existing_df.columns):
        print(f"Warning: {output_csv} has no InstanceId/ModelName columns; no instances will be skipped")
        return set()
    return {
        (str(instance_id), str(model_name))
        for instance_id, model_name in zip(existing_df["InstanceId"], existing_df["ModelName"])
    }


def _append_rows(rows: List[Dict[str, object]], output_csv: str) -> None:
    """Append ``rows`` to ``output_csv``, writing a header only if the file is new or empty."""
    if not rows:
        return
    # Check the size, not just existence: appending header-less rows to a
    # zero-byte file would make the CSV unreadable for the resume logic.
    write_header = not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0
    pd.DataFrame(rows).to_csv(output_csv, mode="a", header=write_header, index=False)


def process_failures(
    failure_results_dir: str,
    dataset_dir: str,
    judge_model: str,
    provider_url: str,
    output_csv: str,
    save_every: int,
):
    """Have a judge model explain each failed case and append the verdicts to a CSV.

    Reads ``stats.csv`` and the ``failed_reasonings/`` directory inside
    ``failure_results_dir``. Every row whose ``CorrectAnswer`` differs from its
    ``ModelAnswer`` is judged once. ``(InstanceId, ModelName)`` pairs already
    present in ``output_csv`` are skipped, so an interrupted run can be resumed.
    Output rows have the columns ``InstanceId``, ``ModelName`` and ``Verdict``.

    Per-case problems (missing reasoning file, dataset lookup or judge errors)
    are stored as the verdict text rather than aborting the run.

    Args:
        failure_results_dir: Directory holding ``stats.csv`` and ``failed_reasonings/``.
        dataset_dir: Directory passed to ``DatasetManager`` to resolve image paths.
        judge_model: Model identifier used for judging.
        provider_url: Base URL of the OpenAI-compatible API.
        output_csv: CSV file that verdicts are appended to.
        save_every: Write buffered verdicts after this many new rows. ``0`` or a
            negative value writes everything once, at the end.
    """
    stats_path = os.path.join(failure_results_dir, "stats.csv")
    failed_reasonings_dir = os.path.join(failure_results_dir, "failed_reasonings")

    if not os.path.exists(stats_path):
        print(f"Stats file not found: {stats_path}")
        return

    if not os.path.exists(failed_reasonings_dir):
        print(f"Failed reasonings directory not found: {failed_reasonings_dir}")
        return

    # Build the dataset manager and API client only once the inputs are known to exist.
    dataset_manager = DatasetManager(dataset_dir)
    client = OpenAI(base_url=provider_url)

    stats_df = pd.read_csv(stats_path)
    processed_keys = _load_processed_keys(output_csv)

    # NaN != NaN in pandas, so rows with a missing answer also count as failures.
    failed_cases = stats_df[stats_df["CorrectAnswer"] != stats_df["ModelAnswer"]]

    results = []  # every verdict produced during this run
    pending = []  # verdicts not yet written to output_csv

    try:
        for _, row in tqdm(failed_cases.iterrows(), total=len(failed_cases), desc="Processing failures"):
            instance_id = row["InstanceId"]
            model_name = row["ModelName"]
            correct_answer = row["CorrectAnswer"]
            model_answer = row["ModelAnswer"]

            # A tuple key avoids collisions such as ("a_b", "c") vs ("a", "b_c").
            instance_key = (str(instance_id), str(model_name))
            if instance_key in processed_keys:
                continue
            # Record the key now so duplicate rows in stats.csv are judged only once.
            processed_keys.add(instance_key)

            model_name_clean = str(model_name).replace("/", "_")
            reasoning_file = os.path.join(failed_reasonings_dir, f"{instance_id}_{model_name_clean}.txt")

            if not os.path.exists(reasoning_file):
                verdict = "No reasoning file found"
            else:
                try:
                    with open(reasoning_file, "r", encoding="utf-8") as f:
                        reasoning_content = f.read()

                    entry = dataset_manager.get_entry(instance_id)
                    image_path = entry["ImagePath"]

                    verdict = get_judge_verdict(
                        client,
                        judge_model,
                        image_path,
                        correct_answer,
                        reasoning_content,
                        model_answer,
                    )
                except Exception as e:
                    verdict = f"Error processing: {str(e)}"

            record = {
                "InstanceId": instance_id,
                "ModelName": model_name,
                "Verdict": verdict,
            }
            results.append(record)
            pending.append(record)

            if save_every > 0 and len(pending) >= save_every:
                _append_rows(pending, output_csv)
                pending = []
    finally:
        # Flush the remainder. This also runs on Ctrl-C or an unexpected error,
        # so already-computed verdicts are not lost.
        _append_rows(pending, output_csv)

    print(f"Processed {len(results)} new failures, saved to {output_csv}")


def main():
    """Parse the command-line options and run ``process_failures`` with them."""
    cli = argparse.ArgumentParser(
        description="Analyze failure reasoning with judge model")

    # (flag, add_argument keyword options), registered in this exact order.
    option_specs = [
        ("--failure-analysis-results", dict(
            type=str,
            required=True,
            help="Path to failure analysis results directory",
        )),
        ("--dataset", dict(
            type=str,
            required=True,
            help="Path to dataset directory",
        )),
        ("--judge-model", dict(
            type=str,
            default="openai/chatgpt-4o-latest",
            help="Judge model to analyze failures",
        )),
        ("--provider-url", dict(
            type=str,
            default="https://openrouter.ai/api/v1",
            help="API provider URL",
        )),
        ("--output-csv", dict(
            type=str,
            default="failure_verdicts.csv",
            help="Output CSV file for verdicts",
        )),
        ("--save-every", dict(
            type=int,
            default=5,
            help="Save results every N processed instances",
        )),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    process_failures(
        failure_results_dir=opts.failure_analysis_results,
        dataset_dir=opts.dataset,
        judge_model=opts.judge_model,
        provider_url=opts.provider_url,
        output_csv=opts.output_csv,
        save_every=opts.save_every,
    )


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
