#!/usr/bin/env python3

import re
import base64
import os
import requests
import logging
from pathlib import Path
from PIL import Image
from io import BytesIO

import cairosvg
from openai import OpenAI
import anthropic
from google import genai
from xai_sdk import Client as XAIClient
from xai_sdk.chat import user, system, image
import config

def load_input(house, prompt):
    """Load images from house/imgs folder."""
    imgs_dir = Path(house) / "imgs"
    images = []
    # Support all common image formats
    image_extensions = ['*.png', '*.jpg', '*.jpeg', '*.PNG', '*.JPG', '*.JPEG']
    for ext in image_extensions:
        for img_path in imgs_dir.glob(ext):
            with Image.open(img_path) as img:
                images.append(img.copy())
    return images, prompt

def llm_answer_to_img(response, max_retries=5):
    """Convert LLM SVG response to image with retry logic."""
    # Comprehensive patterns to handle various LLM response formats
    patterns = [
        # Standard code blocks
        r'```svg\s*\n(.*?)\n```',
        r'```xml\s*\n(.*?)\n```',
        r'```\s*\n(<svg.*?</svg>)',
        # Alternative code block formats
        r'```svg(.*?)```',
        r'```xml(.*?)```',
        r'```\s*(<svg.*?</svg>)\s*```',
        # Direct SVG without code blocks
        r'(<svg[^>]*>.*?</svg>)',
        # Handle cases with extra whitespace or newlines
        r'```svg\s*(.*?)\s*```',
        r'```xml\s*(.*?)\s*```',
        # Handle cases where SVG might be in other markup
        r'<svg[^>]*>(.*?)</svg>',
    ]

    svg_code = None

    # Try each pattern to extract SVG
    for pattern in patterns:
        matches = re.findall(pattern, response, re.DOTALL | re.IGNORECASE)
        for match in matches:
            # Handle both tuple and string matches
            potential_svg = match if isinstance(match, str) else match[0] if match else ""
            potential_svg = potential_svg.strip()

            # Skip empty matches
            if not potential_svg:
                continue

            # If it doesn't start with <svg, try to find the svg tag within it
            if not potential_svg.lower().startswith('<svg'):
                svg_match = re.search(r'(<svg[^>]*>.*?</svg>)', potential_svg, re.DOTALL | re.IGNORECASE)
                if svg_match:
                    potential_svg = svg_match.group(1)

            # Validate that we have a complete SVG
            if ('<svg' in potential_svg.lower() and '</svg>' in potential_svg.lower() and
                potential_svg.lower().count('<svg') == potential_svg.lower().count('</svg>')):
                svg_code = potential_svg
                break

        if svg_code:
            break

    if not svg_code:
        logging.warning("No valid SVG found in response")
        return None

    # Try conversion with retries
    for attempt in range(max_retries):
        try:
            png_data = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'))
            return Image.open(BytesIO(png_data)).convert('RGB')
        except Exception as e:
            logging.warning(f"SVG conversion attempt {attempt + 1} failed: {e}")
            if attempt == max_retries - 1:
                return None

def image_answer_to_img(response):
    """Convert image model response to image."""
    if isinstance(response, Image.Image):
        return response

    if isinstance(response, str):
        if response.startswith('http'):
            # Handle URLs (from OpenAI image API)
            try:
                img_response = requests.get(response)
                if img_response.status_code == 200:
                    return Image.open(BytesIO(img_response.content)).convert('RGB')
            except:
                return None
        elif response.startswith('data:image'):
            data = response.split(',', 1)[1]
            image_bytes = base64.b64decode(data)
            return Image.open(BytesIO(image_bytes)).convert('RGB')
        else:
            # Try to decode as base64
            try:
                image_bytes = base64.b64decode(response)
                return Image.open(BytesIO(image_bytes)).convert('RGB')
            except:
                return None

    return None

def call_model(model, images, prompt):
    """Call model with images and prompt, return response."""
    try:
        provider, model_name = model.split("/")

        # Check if it's an image generation model
        image_models = ["gpt-image-1", "gemini-2.5-flash-image-preview", "grok-2-image"]
        if model_name in image_models:
            return _call_image_model(provider, model_name, images, prompt)
        else:
            # LLM models
            return _call_llm_model(provider, model_name, images, prompt)

    except Exception as e:
        logging.error(f"Error calling {model}: {e}")
        return None


def _call_llm_model(provider, model_name, images, prompt):
    """Call LLM model for SVG generation."""
    if provider == "openai":
        return _call_openai_llm(model_name, images, prompt)
    elif provider == "anthropic":
        return _call_anthropic_llm(model_name, images, prompt)
    elif provider == "google":
        return _call_google_llm(model_name, images, prompt)
    elif provider == "grok":
        return _call_grok_llm(model_name, images, prompt)
    else:
        return None


def _call_image_model(provider, model_name, images, prompt):
    """Call image generation model."""
    if provider == "openai":
        return _call_openai_image(model_name, images, prompt)
    elif provider == "google":
        return _call_google_image(model_name, images, prompt)
    else:
        return None


def _call_openai_llm(model_name, images, prompt):
    """Call OpenAI LLM model."""
    try:
        client = OpenAI()

        # Prepare messages with images
        messages = [{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt}
            ]
        }]

        # Add images
        for img in images:
            buffer = BytesIO()
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img.save(buffer, format='PNG')
            buffer.seek(0)
            img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

            messages[0]["content"].append({
                "type": "image_url",
                "image_url": {"url": f"data:image/png;base64,{img_base64}"}
            })

        response = client.chat.completions.create(
            model=model_name,
            messages=messages
        )

        return response.choices[0].message.content

    except Exception as e:
        logging.error(f"OpenAI LLM error: {e}")
        return None


def _call_anthropic_llm(model_name, images, prompt):
    """Call Anthropic LLM model."""
    try:
        client = anthropic.Anthropic()

        # Prepare content with images
        content = []

        for img in images:
            buffer = BytesIO()
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img.save(buffer, format='PNG')
            buffer.seek(0)
            img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

            content.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": img_base64
                }
            })

        content.append({
            "type": "text",
            "text": prompt
        })

        response = client.messages.create(
            model=model_name,
            max_tokens=4000,
            messages=[{
                "role": "user",
                "content": content
            }]
        )

        return response.content[0].text

    except Exception as e:
        logging.error(f"Anthropic LLM error: {e}")
        return None


def _call_google_llm(model_name, images, prompt):
    """Call Google LLM model."""
    try:
        client = genai.Client()

        # Prepare content with images and prompt
        content = [prompt] + images

        response = client.models.generate_content(
            model=model_name,
            contents=content
        )
        return response.text

    except Exception as e:
        logging.error(f"Google LLM error: {e}")
        return None


def _call_grok_llm(model_name, images, prompt):
    """Call Grok LLM model."""
    try:
        client = XAIClient(api_key=os.getenv("GROK_API_KEY"))

        chat = client.chat.create(model=model_name, temperature=0)

        # Prepare user message with text and images
        if not images:
            chat.append(user(prompt))
        else:
            # Convert images to base64 data URLs for XAI SDK
            image_parts = []
            for img in images:
                buffer = BytesIO()
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                img.save(buffer, format='PNG')
                buffer.seek(0)
                img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                data_url = f"data:image/png;base64,{img_base64}"
                image_parts.append(image(image_url=data_url, detail="high"))

            # Append user message with text and all images
            chat.append(user(prompt, *image_parts))

        response = chat.sample()
        return response.content

    except Exception as e:
        logging.error(f"Grok LLM error: {e}")
        return None


def _call_openai_image(model_name, images, prompt):
    """Call OpenAI image generation model."""
    try:
        client = OpenAI()

        # Handle random baseline (no input images)
        if not images:
            logging.info(f"Generating random image with {model_name}")
            response = client.images.generate(
                model=model_name,
                prompt=prompt,
                n=1,
                size="1024x1024"
            )

            if response.data and len(response.data) > 0:
                if hasattr(response.data[0], 'b64_json') and response.data[0].b64_json:
                    return f"data:image/png;base64,{response.data[0].b64_json}"
                elif hasattr(response.data[0], 'url') and response.data[0].url:
                    return response.data[0].url
                else:
                    logging.error("Response data has no b64_json or url field")
                    return None
            else:
                logging.error("Empty response data from OpenAI image generation API")
                return None

        logging.info(f"Processing {len(images)} images for OpenAI image model")

        image_files = []
        for i, img in enumerate(images):
            buffer = BytesIO()
            if img.mode != 'RGB':
                img = img.convert('RGB')
                logging.debug(f"Converted image {i} from {img.mode} to RGB")
            img.save(buffer, format='PNG')
            buffer.seek(0)
            buffer.name = f'apartment_image_{i}.png'
            image_files.append(buffer)

        logging.info(f"Making API call to {model_name} with {len(image_files)} images")

        response = client.images.edit(
            model=model_name,
            image=image_files,
            prompt=prompt
        )

        logging.info("API call successful, processing response")

        # Handle the correct response format - use b64_json instead of url
        if response.data and len(response.data) > 0:
            if hasattr(response.data[0], 'b64_json') and response.data[0].b64_json:
                # Return base64 data for image_answer_to_img to process
                return f"data:image/png;base64,{response.data[0].b64_json}"
            elif hasattr(response.data[0], 'url') and response.data[0].url:
                # Fallback to URL if available
                return response.data[0].url
            else:
                logging.error("Response data has no b64_json or url field")
                return None
        else:
            logging.error("Empty response data from OpenAI image API")
            return None

    except Exception as e:
        logging.error(f"OpenAI image error: {type(e).__name__}: {e}")
        import traceback
        logging.error(f"Full traceback: {traceback.format_exc()}")
        return None


def _call_google_image(model_name, images, prompt):
    """Call Google image generation model."""
    try:
        client = genai.Client()

        # Prepare content with images and prompt
        content = [prompt] + images

        response = client.models.generate_content(
            model=model_name,
            contents=content
        )

        # Extract image from response
        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                return Image.open(BytesIO(part.inline_data.data))

        return None

    except Exception as e:
        logging.error(f"Google image error: {e}")
        return None

def process_house(house, model, epoch):
    """Process single house with single model for single epoch."""
    # Parse model string and create safe filename
    parts = model.split("/", 1)  # Split only on first /
    provider = parts[0]
    model_name = parts[1] if len(parts) > 1 else "default"
    # Replace any remaining slashes in model_name for safe filename
    safe_model_name = model_name.replace("/", "_")

    output_dir = Path("./predictions") / house
    output_file = output_dir / f"{provider}_{safe_model_name}_{epoch}.png"

    if output_file.exists():
        logging.info(f"Skip: {house}+{model}+{epoch} (exists)")
        return

    # Load input
    house_path = Path("./dataset") / house

    # Choose appropriate instruction based on model type
    image_models = ["gpt-image-1", "gemini-2.5-flash-image-preview", "grok-2-image"]
    if any(img_model in model for img_model in image_models):
        prompt_instr = config.IMAGE_INSTR
    else:
        prompt_instr = config.LLM_INSTR

    images, prompt = load_input(house_path, config.RULES + prompt_instr)

    if not images:
        logging.error(f"No images in {house}")
        return "error"

    # Retry logic for model call and conversion
    for attempt in range(config.MAX_RETRIES):
        try:
            # Call model
            response = call_model(model, images, prompt)
            if not response:
                if attempt == config.MAX_RETRIES - 1:
                    logging.error(f"Model call failed after {config.MAX_RETRIES} attempts for {house}+{model}+{epoch}")
                    return "error"
                logging.warning(f"Model call attempt {attempt + 1} failed for {house}+{model}+{epoch}, retrying...")
                continue

            # Convert to image
            image_models = ["gpt-image-1", "gemini-2.5-flash-image-preview", "grok-2-image"]
            if any(img_model in model for img_model in image_models):
                result_img = image_answer_to_img(response)
            else:
                result_img = llm_answer_to_img(response)

            if not result_img:
                if attempt == config.MAX_RETRIES - 1:
                    logging.error(f"Could not convert response after {config.MAX_RETRIES} attempts for {house}+{model}+{epoch}")
                    return "error"
                logging.warning(f"Conversion attempt {attempt + 1} failed for {house}+{model}+{epoch}, retrying...")
                continue

            # Save
            output_dir.mkdir(parents=True, exist_ok=True)
            result_img.save(output_file)
            logging.info(f"Saved: {output_file}")
            return None  # Success

        except Exception as e:
            if attempt == config.MAX_RETRIES - 1:
                logging.error(f"Processing failed after {config.MAX_RETRIES} attempts for {house}+{model}+{epoch}: {e}")
                return "error"
            logging.warning(f"Processing attempt {attempt + 1} failed for {house}+{model}+{epoch}: {e}, retrying...")
            continue
