import json
import os
import base64
import requests
from PIL import Image
import io
import time
# Remove Pillow's pixel-count limit so very large images can be opened.
# NOTE(review): per the Pillow docs this also turns off its decompression-bomb
# guard — confirm every file under DATASET_PATH is trusted.
Image.MAX_IMAGE_PIXELS = None

# Configuration
# MODEL_NAME is only used in log messages in this file; the model actually
# called is the one hard-coded in API_URL, so the two must be kept in sync.
MODEL_NAME = "gemini-2.5-flash"
# Read from the GEMINI_API_KEY environment variable; falls back to a
# placeholder string when the variable is not set.
API_KEY = os.getenv("GEMINI_API_KEY", "YOUR_GEMINI_API_KEY_HERE")
API_URL = "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent"
# Directory searched by find_media_resources (relative to the working directory).
DATASET_PATH = "../dataset/dataset_gemini/"
# main() reads retrieval results from INPUT_JSON and writes answers to OUTPUT_JSON.
INPUT_JSON = "e5v_results.json"
OUTPUT_JSON = "gemini_results.json"

# File extensions
# Extensions treated as images: these files are resized before encoding and
# are the only ones collected from a media folder.
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif']
# NOTE(review): VIDEO_EXTENSIONS and MEDIA_EXTENSIONS are not referenced
# anywhere in the visible part of this file.
VIDEO_EXTENSIONS = ['.mp4']
MEDIA_EXTENSIONS = ['.mp4', '.jpg', '.png', '.wav', '.mp3', '.pdf', '.jpeg', '.gif', '.txt']
# Suffixes find_media_resources appends to a base name, tried in this order;
# the empty string means the bare name with no extension.
SUPPORTED_EXTENSIONS = ['', '.mp4', '.jpg', '.png', '.jpeg']

# Image processing configuration
# Longest side, in pixels, above which resize_image downscales an image.
MAX_IMAGE_SIZE = 1024
RESIZE_METHOD = Image.LANCZOS
# Output format used by resize_image when no other format can be determined.
DEFAULT_IMAGE_FORMAT = "PNG"

# API configuration
GENERATION_TEMPERATURE = 0.4
MAX_OUTPUT_TOKENS = 2048
# Number of POST attempts query_gemini makes before giving up.
MAX_RETRY_ATTEMPTS = 3
# Seconds to wait after a non-429 API error or a request exception.
RETRY_DELAY = 1
# Seconds to wait after an HTTP 429 (rate limited) response.
RATE_LIMIT_DELAY = 3
# Seconds main() pauses between dataset items.
REQUEST_DELAY = 2

# MIME type mappings
# Extension -> MIME type declared for inline media; query_gemini falls back
# to "image/jpeg" for extensions that are not listed here.
MIME_TYPES = {
    '.jpg': 'image/jpeg',
    '.jpeg': 'image/jpeg',
    '.png': 'image/png',
    '.gif': 'image/gif',
    '.mp4': 'video/mp4'
}

# Startup banner; these run at import time, not only when main() is called.
print(f"Model: {MODEL_NAME}")
print(f"Dataset: {DATASET_PATH}")
print(f"Input: {INPUT_JSON}")
print(f"Output: {OUTPUT_JSON}")

# Image processing utilities
def resize_image(img_path):
    """Load an image, downscale it if needed, and return it re-encoded.

    The longest side is capped at ``MAX_IMAGE_SIZE`` pixels while keeping the
    aspect ratio; images already within the limit are re-encoded at their
    original size.

    The output format is chosen from the file extension first, because
    ``query_gemini`` declares the MIME type from the extension and the bytes
    must agree with it. Pillow's detected format and ``DEFAULT_IMAGE_FORMAT``
    are used only for extensions that are not recognised here.

    Args:
        img_path: Path of the image file to process.

    Returns:
        The encoded image as ``bytes``, or ``None`` if the image could not be
        opened, resized or saved (the error is printed).
    """
    extension_formats = {
        ".jpg": "JPEG",
        ".jpeg": "JPEG",
        ".png": "PNG",
        ".gif": "GIF",
    }
    try:
        with Image.open(img_path) as img:
            ext = os.path.splitext(img_path)[1].lower()
            # img.format is always set for files opened from disk, so it must
            # not take priority or a mislabelled file would be sent with bytes
            # that contradict its extension-derived MIME type.
            target_format = (
                extension_formats.get(ext) or img.format or DEFAULT_IMAGE_FORMAT
            )

            width, height = img.size
            if width > MAX_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
                # Clamp the scaled side to at least 1 pixel: for extreme
                # aspect ratios int() would truncate it to 0 and make
                # resize() raise.
                if width > height:
                    new_width = MAX_IMAGE_SIZE
                    new_height = max(1, int(height * (MAX_IMAGE_SIZE / width)))
                else:
                    new_height = MAX_IMAGE_SIZE
                    new_width = max(1, int(width * (MAX_IMAGE_SIZE / height)))

                img = img.resize((new_width, new_height), RESIZE_METHOD)

            # JPEG cannot store alpha or palette modes; flatten them to RGB.
            if target_format.upper() == "JPEG" and img.mode not in ("RGB", "L"):
                img = img.convert("RGB")

            buffer = io.BytesIO()
            img.save(buffer, format=target_format)
            return buffer.getvalue()
    except Exception as e:
        # Best effort: the caller falls back to the raw file bytes on None.
        print(f"Error processing image {img_path}: {e}")
        return None

def encode_file_to_base64(file_path):
    """Return the file's content as base64 text, shrinking images first.

    Files whose name ends with one of ``IMAGE_EXTENSIONS`` go through
    ``resize_image``; if that yields nothing, or the file is not an image,
    the raw file bytes are encoded instead. Returns ``None`` (after printing
    the error) when the file cannot be read or encoded.
    """
    try:
        looks_like_image = file_path.lower().endswith(tuple(IMAGE_EXTENSIONS))
        payload = resize_image(file_path) if looks_like_image else None

        if not payload:
            # Non-image file, or image processing failed: send the bytes as-is.
            with open(file_path, 'rb') as handle:
                payload = handle.read()

        return base64.b64encode(payload).decode('utf-8')
    except Exception as e:
        print(f"Error encoding file {file_path}: {e}")
        return None

# File handling utilities
def find_media_resources(filename, base_dir=DATASET_PATH):
    """Resolve a retrieved filename to the media paths stored under base_dir.

    A known media extension is stripped from ``filename`` first. If a folder
    with the resulting name exists, all of its images are returned; otherwise
    the first existing file among the ``SUPPORTED_EXTENSIONS`` candidates is
    returned as a one-element list. An empty list means nothing was found.
    """
    import re

    stem = re.sub(r'\.(mp4|jpg|png|wav|mp3|pdf|jpeg|gif|txt)$', '', filename, flags=re.IGNORECASE)

    candidate_dir = os.path.join(base_dir, stem)
    if os.path.isdir(candidate_dir):
        return get_all_images_from_folder(candidate_dir)

    # Probe each allowed suffix in order and keep the first file that exists.
    candidates = (os.path.join(base_dir, stem + suffix) for suffix in SUPPORTED_EXTENSIONS)
    first_hit = next((path for path in candidates if os.path.isfile(path)), None)
    return [] if first_hit is None else [first_hit]

def get_all_images_from_folder(folder_path):
    """List the image files directly inside folder_path, sorted by path.

    Only regular files whose lower-cased extension is in ``IMAGE_EXTENSIONS``
    are included; sub-folders are not descended into.
    """
    with_image_suffix = [
        os.path.join(folder_path, entry)
        for entry in os.listdir(folder_path)
        if os.path.splitext(entry)[1].lower() in IMAGE_EXTENSIONS
    ]
    return sorted(path for path in with_image_suffix if os.path.isfile(path))

# API interaction
def query_gemini(question, file_paths, timeout=300):
    """Send a question plus media files to the Gemini API and return its text.

    Each readable file is attached as base64 ``inline_data`` with a MIME type
    looked up from its extension (``image/jpeg`` when the extension is not in
    ``MIME_TYPES``). Files that cannot be encoded are skipped. The request is
    attempted up to ``MAX_RETRY_ATTEMPTS`` times.

    Args:
        question: Prompt text sent as the first part of the request.
        file_paths: Iterable of media file paths; falsy entries are ignored.
        timeout: Seconds ``requests`` waits on the connection/response before
            the attempt is treated as failed.

    Returns:
        The first text part of the first candidate, ``"No response
        generated"`` when a successful reply carries no text, or ``"Failed to
        get response after multiple attempts"`` when every attempt failed.
    """
    parts = [{"text": question}]

    for path in file_paths:
        if not path:
            continue
        ext = os.path.splitext(path)[1].lower()
        mime_type = MIME_TYPES.get(ext, "image/jpeg")

        encoded_file = encode_file_to_base64(path)
        if encoded_file:
            parts.append({
                "inline_data": {
                    "mime_type": mime_type,
                    "data": encoded_file
                }
            })

    # The key travels only in this header. Putting it in the URL as well
    # would leak it through the exception text printed below.
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": API_KEY
    }

    data = {
        "contents": [{"parts": parts}],
        "generationConfig": {
            "temperature": GENERATION_TEMPERATURE,
            "maxOutputTokens": MAX_OUTPUT_TOKENS
        }
    }

    for attempt in range(MAX_RETRY_ATTEMPTS):
        delay = RETRY_DELAY
        try:
            response = requests.post(
                API_URL,
                headers=headers,
                json=data,
                timeout=timeout
            )

            if response.status_code == 200:
                candidates = response.json().get("candidates") or []
                if candidates:
                    # A candidate may arrive without content/parts/text (for
                    # example when generation was stopped early); treat that
                    # as "no response" instead of raising and retrying.
                    content = candidates[0].get("content") or {}
                    for part in content.get("parts") or []:
                        if "text" in part:
                            return part["text"]
                return "No response generated"
            elif response.status_code == 429:
                print("Rate limited. Waiting before retry")
                delay = RATE_LIMIT_DELAY
            else:
                print(f"API error: {response.status_code}")
                print(response.text)
        except Exception as e:
            # Best effort: network errors, timeouts and malformed replies are
            # logged and retried rather than aborting the whole run.
            print(f"Request error: {e}")

        # No point waiting once the last attempt has been used up.
        if attempt < MAX_RETRY_ATTEMPTS - 1:
            time.sleep(delay)

    return "Failed to get response after multiple attempts"

# Main processing pipeline
def main():
    """Answer every question in INPUT_JSON with Gemini and save the results.

    For each entry under ``"results"`` the first five retrieved filenames are
    resolved to media files, the question is sent to the model together with
    those files, and the answer is stored alongside the original fields. The
    output file is rewritten after every item so partial progress survives an
    interruption.

    Raises:
        OSError: If the input file cannot be read or the output written.
        KeyError: If an entry lacks one of the expected fields.
    """
    # JSON is UTF-8 by specification; do not depend on the locale encoding.
    with open(INPUT_JSON, "r", encoding="utf-8") as f:
        data = json.load(f)

    results = []

    for item in data["results"]:
        question = item["question"]
        print(f"Processing question: {question}")

        file_paths = []
        for filename in item["top_5_retrieved"][:5]:
            file_paths.extend(find_media_resources(filename, DATASET_PATH))

        response = ""
        if file_paths:
            print(f"Querying {MODEL_NAME} with {len(file_paths)} files")
            response = query_gemini(question, file_paths)
            print(f"Received response from {MODEL_NAME}")
        else:
            print("No files found for this question")

        results.append({
            "question": question,
            "positive": item["positive"],
            "top_5_retrieved": item["top_5_retrieved"],
            "answer": item["answer"],
            "response": response
        })

        # Checkpoint after every item.
        with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

        # Throttle only when an API request was actually made.
        if file_paths:
            time.sleep(REQUEST_DELAY)

    print(f"Processing complete. Results saved to {OUTPUT_JSON}")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
