import os
import json
import shutil
from pathlib import Path
from config import llm_config
from utils import custom_encoder
from openai import OpenAI
import base64, mimetypes
import re
from PIL import Image

# Choose which callable from `tools` is bound to the name `sketch`, driven by
# the `sk_type` environment variable ("open" when the variable is unset).
# "binary" and "canny" select their namesake implementations; every other
# value falls back to the default `tools.sketch`.
sk_type = os.getenv("sk_type", "open")
if sk_type not in ("binary", "canny"):
    from tools import sketch
elif sk_type == "binary":
    from tools import binary as sketch
else:
    from tools import canny as sketch

def remove_img_tags(text):
    """Return ``text`` with every ``<img ...>`` tag deleted.

    Each span starting at ``<img`` and running to the next ``>`` is removed.
    The match is case-sensitive and the surrounding text (including any
    whitespace around the tag) is left untouched.
    """
    img_tag = re.compile(r'<img[^>]*>')
    return img_tag.sub('', text)

def to_data_uri(path: str) -> str:
    """Encode the file at ``path`` as a base64 ``data:`` URI.

    The MIME type is guessed from the file name; when it cannot be guessed,
    ``image/jpeg`` is used.

    Args:
        path: Location of the file to read (opened in binary mode).

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.
    """
    guessed_type = mimetypes.guess_type(path)[0]
    if not guessed_type:
        guessed_type = "image/jpeg"

    with open(path, "rb") as image_file:
        payload = image_file.read()

    encoded = base64.b64encode(payload).decode()
    return "data:" + guessed_type + ";base64," + encoded

def image_to_data_uri(image, format="JPEG"):
    """Convert a PIL Image to a base64 ``data:`` URI.

    Args:
        image: Object exposing ``save(fp, format=...)`` (a PIL Image).
        format (str): Encoder name passed to ``image.save``; its lower-cased
            form is also used as the MIME subtype. Defaults to "JPEG".

    Returns:
        str: A string of the form ``data:image/<format>;base64,<payload>``.
    """
    import io

    # JPEG cannot store an alpha channel or a palette, and saving such an
    # image as JPEG raises OSError. Convert those modes to RGB first so the
    # default format works for any input image.
    if format.upper() == "JPEG" and getattr(image, "mode", None) in ("RGBA", "LA", "P"):
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format=format)
    mime = f"image/{format.lower()}"
    # getvalue() returns the whole buffer regardless of the stream position.
    b64 = base64.b64encode(buffer.getvalue()).decode()
    return f"data:{mime};base64,{b64}"

def run_agent(task_input, output_dir, task_type="vision", task_name=None):
    """Run a simple LLM baseline agent on one task instance.

    The task directory is copied into ``output_dir``, a sketch is generated
    and saved for every image listed in ``request.json``, and the query is
    sent to the chat model together with each original image and its sketch.
    The exchange is written to ``output.json`` and the token usage to
    ``usage_summary.json``, both inside the copied task directory.

    Args:
        task_input (str): a path to the task input directory
        output_dir (str): a path to the directory where the output will be saved
        task_type (str): Task type. Should be vision, math, or geo. Defaults to "vision".
            Currently not used by this function.
        task_name (str, optional): Only needed for math tasks. Defaults to None.
            Currently not used by this function.

    Returns:
        None. Results are written to disk. Errors raised while querying the
        model are not propagated; they are recorded as an ``[ERROR] ...``
        reply in ``output.json``.
    """

    task_input = task_input.rstrip('/')
    task_directory = os.path.join(output_dir, os.path.basename(task_input))

    os.makedirs(output_dir, exist_ok=True)
    shutil.copytree(task_input, task_directory, dirs_exist_ok=True)

    # Use a context manager so the request file handle is closed promptly.
    with open(os.path.join(task_input, "request.json"), encoding="utf-8") as f:
        task_metadata = json.load(f)
    query = task_metadata['query']
    images = task_metadata['images']

    # One (original data URI, sketch data URI) pair per input image.
    image_pairs = []
    for i, img_path in enumerate(images):
        orig_url = to_data_uri(img_path)

        # Keep all work on the sketch inside the `with` block: the source
        # image is closed on exit, and the sketch may still depend on it.
        with Image.open(img_path) as original_image:
            sketch_image = sketch(original_image)

            sketch_filename = f"sketch_{i+1}_{os.path.basename(img_path)}"
            sketch_image.save(os.path.join(task_directory, sketch_filename))

            sketch_url = image_to_data_uri(sketch_image)

        image_pairs.append((orig_url, sketch_url))

    query = remove_img_tags(query)

    # Construct the prompt with the format instruction
    prompt = (
        f"{query}\n\n"
        "you can think step by step, and you can use the sketch to help you think.\n"
        "please reply with ANSWER: (your answer) and ends with TERMINATE. example: ANSWER: (A). TERMINATE. Requirement: Do not reply none response."
    )

    # User message: the prompt, then for each image the original, a text
    # label, and the sketch.
    user_content = [{"type": "text", "text": prompt}]
    for i, (orig_url, sketch_url) in enumerate(image_pairs):
        user_content.append({"type": "image_url", "image_url": {"url": orig_url}})
        user_content.append({"type": "text", "text": f"[Sketch version of image {i+1}]:"})
        user_content.append({"type": "image_url", "image_url": {"url": sketch_url}})

    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful Multimodal assistant."
                "For each image, you will be provided with a sketch that reduces redundancy and helps you focus your reasoning on the essential content."
            )
        },
        {"role": "user", "content": user_content},
    ]

    model_info = llm_config["config_list"][0]
    client = OpenAI(
        base_url=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        api_key=model_info["api_key"]
    )

    response = None
    reply = ""
    try:
        # At most two attempts: retry once if the first reply is empty.
        for _ in range(2):
            response = client.chat.completions.create(
                model=model_info["model"],
                messages=messages,
                temperature=model_info.get("temperature", 0),
                max_tokens=model_info.get("max_tokens", 512)
            )
            # The API may return no content (None); treat that as an empty
            # reply so it triggers the retry instead of an AttributeError.
            reply = (response.choices[0].message.content or "").strip()
            if reply:
                break
    except Exception as e:
        # Best-effort baseline: record the failure as the reply and still
        # write the output files below.
        reply = f"[ERROR] {str(e)}"

    all_messages = [
        {"role": "user", "content": [{"type": "text", "text": query}]},
        {"role": "assistant", "content": [{"type": "text", "text": reply}]}
    ]

    with open(os.path.join(task_directory, "output.json"), "w", encoding="utf-8") as f:
        json.dump(all_messages, f, indent=4, default=custom_encoder)

    usage_summary = {
        'total': {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
        'actual': {'prompt_tokens': 0, 'completion_tokens': 0, 'total_tokens': 0},
        'model': model_info["model"]
    }

    # `response` stays None when the first request failed; the counters then
    # remain zero. Only the most recent successful response is reported.
    usage = getattr(response, 'usage', None)
    if usage:
        for usage_type in ['total', 'actual']:
            usage_summary[usage_type]['prompt_tokens'] = usage.prompt_tokens
            usage_summary[usage_type]['completion_tokens'] = usage.completion_tokens
            usage_summary[usage_type]['total_tokens'] = usage.total_tokens

    with open(os.path.join(task_directory, "usage_summary.json"), "w", encoding="utf-8") as f:
        json.dump(usage_summary, f, indent=4)