import asyncio
import base64
import hashlib
import json
import os
import threading
from datetime import datetime
from typing import List, Dict, Optional, Set
from urllib.parse import urlparse

import requests
from PIL import Image


class ToolsPool:
    """
    A pool of tools for an MLLM: web image search and image editing, each
    backed by an on-disk cache and guarded against duplicate in-flight work.

    Layout under ``cache_dir``:
      - ``search_images/``           downloaded search results, ``<md5(query)>.<ext>``
      - ``edited_images/``           edited images, ``edit_<edit_hash>.png``
      - ``search_cache_index.json``  search metadata keyed by query hash
      - ``edit_cache_index.json``    edit metadata keyed by edit hash
    """

    # Extensions probed (in this order) when looking for a cached search image.
    IMAGE_EXTENSIONS = ('jpg', 'png', 'gif', 'webp')
    # Downloaded images larger than this in either dimension are downscaled.
    MAX_IMAGE_SIZE = 2048
    # Network timeouts, in seconds.
    SEARCH_TIMEOUT = 30
    DOWNLOAD_TIMEOUT = 15
    EDIT_TIMEOUT = 60

    def __init__(self, api_key: str = "",
                 cache_dir: str = "./tool_cache",
                 image_search_url: str = "",
                 edit_url: str = ""):
        """
        Initialize the tools pool and load any existing cache indexes.

        Args:
            api_key: Credential sent to the search API (``X-AK`` header) and to
                the edit API (``Authorization: Bearer``).
            cache_dir: Root directory for cached images and index files;
                created (with its sub-directories) if missing.
            image_search_url: Endpoint of the image search API.
            edit_url: Endpoint of the image edit API.
        """
        self.api_key = api_key
        self.image_search_url = image_search_url
        self.edit_url = edit_url
        self.cache_dir = cache_dir

        # Create cache directories
        os.makedirs(cache_dir, exist_ok=True)
        os.makedirs(os.path.join(cache_dir, "search_images"), exist_ok=True)
        os.makedirs(os.path.join(cache_dir, "edited_images"), exist_ok=True)

        # Cache index files
        self.cache_index_file = os.path.join(cache_dir, "search_cache_index.json")
        self.edit_cache_index_file = os.path.join(cache_dir, "edit_cache_index.json")
        self.cache_index = self._load_cache_index()
        self.edit_cache_index = self._load_edit_cache_index()

        # Hashes of operations currently in flight. They are guarded by
        # operation_lock, so a second identical request is skipped instead of
        # being duplicated.
        self.downloading_queries: Set[str] = set()
        self.editing_images: Set[str] = set()
        self.operation_lock = threading.Lock()

        # Shared HTTP session used for image downloads; sends a browser-like
        # User-Agent.
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
        })

    # ------------------------------------------------------------------ #
    # Index persistence
    # ------------------------------------------------------------------ #

    @staticmethod
    def _load_json_file(path: str) -> Dict:
        """Return the JSON object stored at *path*, or {} if missing, unreadable or not an object."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError (corrupt/truncated file).
            print(f"     Ignoring unreadable cache index {path}: {e}")
            return {}
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _write_json_atomic(path: str, data: Dict) -> None:
        """
        Write *data* as JSON to *path* via a sibling temp file and os.replace.

        This ensures a crash mid-write cannot leave a truncated index behind.
        Write errors propagate to the caller.
        """
        tmp_path = path + '.tmp'
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, path)
        finally:
            # Only still present if the dump or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    @staticmethod
    def _remove_if_exists(path: str) -> None:
        """Delete *path* if present; a missing file is not an error."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass

    def _load_cache_index(self) -> Dict:
        """Load the search cache index from ``self.cache_index_file`` ({} if absent or invalid)."""
        return self._load_json_file(self.cache_index_file)

    def _load_edit_cache_index(self) -> Dict:
        """Load the edit cache index from ``self.edit_cache_index_file`` ({} if absent or invalid)."""
        return self._load_json_file(self.edit_cache_index_file)

    def _save_cache_index(self):
        """Persist ``self.cache_index`` to ``self.cache_index_file`` atomically."""
        self._write_json_atomic(self.cache_index_file, self.cache_index)

    def _save_edit_cache_index(self):
        """Persist ``self.edit_cache_index`` to ``self.edit_cache_index_file`` atomically."""
        self._write_json_atomic(self.edit_cache_index_file, self.edit_cache_index)

    # ------------------------------------------------------------------ #
    # Cache keys and lookups
    # ------------------------------------------------------------------ #

    def _get_query_hash(self, query: str) -> str:
        """Return the MD5 hex digest of the lower-cased, stripped query (the search cache key)."""
        return hashlib.md5(query.lower().strip().encode()).hexdigest()

    def _get_edit_hash(self, image_path: str, instruction: str) -> str:
        """
        Return the cache key for an edit: MD5 of ``<basename>:<instruction>``.

        The instruction is lower-cased and stripped before hashing.
        """
        # NOTE(review): only the basename is hashed, so two different source
        # images with the same file name (in different directories) share a
        # cache entry. Changing this would invalidate existing caches.
        key = f"{os.path.basename(image_path)}:{instruction.lower().strip()}"
        return hashlib.md5(key.encode()).hexdigest()

    def _check_search_cache(self, query: str) -> Optional[Dict]:
        """
        Return cached metadata for *query* if its image is on disk, else None.

        The returned dict always contains ``image_path``. Index entries whose
        image is missing (or that record no path) are removed from the index.
        """
        query_hash = self._get_query_hash(query)
        search_dir = os.path.join(self.cache_dir, "search_images")

        # Fast path: probe for the image file directly under each known extension.
        for ext in self.IMAGE_EXTENSIONS:
            filepath = os.path.join(search_dir, f"{query_hash}.{ext}")
            if not os.path.exists(filepath):
                continue
            cached_info = self.cache_index.get(query_hash)
            if cached_info is not None:
                cached_info['image_path'] = filepath  # Update path in case extension was different
                print(f"     Found cached image: {os.path.basename(filepath)}")
                print(f"     Title: {(cached_info.get('title') or 'Unknown')[:80]}...")
                return cached_info
            # File exists but not in index, create minimal entry
            print(f"     Found cached image file: {os.path.basename(filepath)}")
            return {
                'query': query,
                'title': 'Cached Image (metadata unavailable)',
                'image_path': filepath,
                'cached_at': 'Unknown'
            }

        # Slow path: the index may point at a file outside the probed names.
        cached_info = self.cache_index.get(query_hash)
        if cached_info is not None:
            image_path = cached_info.get('image_path')
            if image_path and os.path.exists(image_path):
                print(f"     Found cached image from index: {os.path.basename(image_path)}")
                print(f"     Title: {(cached_info.get('title') or '')[:80]}...")
                return cached_info
            # Remove stale (or malformed) entry
            del self.cache_index[query_hash]
            self._save_cache_index()

        return None

    def _check_edit_cache(self, edit_hash: str) -> Optional[str]:
        """
        Return the path of an already-edited image for *edit_hash*, else None.

        The default ``edited_images/edit_<hash>.png`` location is checked
        first, then the edit index, which also covers custom output
        directories. Index entries whose file is missing are removed.
        """
        filepath = os.path.join(self.cache_dir, "edited_images", f"edit_{edit_hash}.png")
        if os.path.exists(filepath):
            print(f"     Found cached edited image: {os.path.basename(filepath)}")
            return filepath

        entry = self.edit_cache_index.get(edit_hash)
        if entry is not None:
            # edit_image records the result under 'edited_image'. 'output_path'
            # is accepted as a fallback key.
            cached_path = entry.get('edited_image') or entry.get('output_path')
            if cached_path and os.path.exists(cached_path):
                print(f"     Found cached edited image from index: {os.path.basename(cached_path)}")
                return cached_path
            # Remove stale entry
            del self.edit_cache_index[edit_hash]
            self._save_edit_cache_index()

        return None

    # ------------------------------------------------------------------ #
    # Image search
    # ------------------------------------------------------------------ #

    async def search_images_with_download(self, query: str, num_results: int = 1) -> List[Dict]:
        """
        Search for *query* and download the top image result, using the cache when possible.

        Args:
            query: Free-text search query.
            num_results: Currently unused. The API is always asked for one
                result, and only that result is downloaded.

        Returns:
            A list with at most one result dict. Its keys are ``title``,
            ``imageUrl``, ``local_image_path``, ``source``, ``domain``,
            ``cached`` and ``image_available``; fresh results also carry
            ``link``. The list is empty when nothing was found, the request
            failed, or the same query is already being downloaded.
        """
        query_hash = self._get_query_hash(query)

        print(f"\n Searching for: '{query[:80]}...'")

        # Check cache first (before any locking)
        cached_result = self._check_search_cache(query)
        if cached_result:
            return [{
                "title": cached_result.get('title', 'Cached Image'),
                "imageUrl": cached_result.get('original_url', ''),
                "local_image_path": cached_result['image_path'],
                "source": cached_result.get('source', ''),
                "domain": cached_result.get('domain', ''),
                "cached": True,
                "image_available": True
            }]

        # Check if already downloading this query
        with self.operation_lock:
            if query_hash in self.downloading_queries:
                print(f"   Already downloading this query, skipping...")
                return []
            self.downloading_queries.add(query_hash)

        try:
            headers = {
                'X-AK': self.api_key,
                'Content-Type': 'application/json'
            }
            data = {
                "query": query,
                "num": 1,
                "extendParams": {
                    "country": "us",
                    "locale": "en-us",
                    "location": "United States",
                    "page": 1
                },
                "platformInput": {
                    "model": "google-search"
                }
            }

            print("   → Making API request...")

            # Run the blocking HTTP call off the event loop, with a timeout so
            # a stalled endpoint cannot hang the coroutine forever.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(self.image_search_url, headers=headers,
                                      json=data, timeout=self.SEARCH_TIMEOUT)
            )

            if response.status_code != 200:
                print(f"   ✗ Search API error: Status {response.status_code}")
                return []

            payload = response.json()
            if not payload.get('success') or 'data' not in payload:
                print("   ✗ Search API reported no data")
                return []

            original_output = (payload['data'] or {}).get('originalOutput') or {}
            images_data = original_output.get('images') or []
            if not images_data:
                print("   ✗ No images found in search results")
                return []

            img_item = images_data[0]
            print(f"   → Found image: {(img_item.get('title') or '')[:80]}...")

            # Download and cache
            local_path = await self._download_and_cache_image(img_item, query)
            if local_path:
                print(f"     Successfully downloaded and cached image")

            return [{
                "title": img_item.get('title', ''),
                "imageUrl": img_item.get('imageUrl', ''),
                "local_image_path": local_path,
                "source": img_item.get('source', ''),
                "domain": img_item.get('domain', ''),
                "link": img_item.get('link', ''),
                "cached": False,
                "image_available": local_path is not None
            }]

        except Exception as e:
            print(f"   ✗ Search API exception: {e}")
            return []
        finally:
            with self.operation_lock:
                self.downloading_queries.discard(query_hash)

    @classmethod
    def _extension_from_url(cls, image_url: str) -> str:
        """
        Guess the cache file extension from the URL *path*.

        The query string and fragment are ignored. Returns one of
        IMAGE_EXTENSIONS ('jpeg' is normalized to 'jpg'), defaulting to 'jpg'
        when the path has no recognized extension.
        """
        ext = os.path.splitext(urlparse(image_url).path)[1].lstrip('.').lower()
        if ext == 'jpeg':
            ext = 'jpg'
        return ext if ext in cls.IMAGE_EXTENSIONS else 'jpg'

    def _fetch_to_file(self, image_url: str, dest_path: str) -> int:
        """
        Blocking: GET *image_url* via the shared session and return the status code.

        On status 200 the body is streamed into *dest_path*. The response is
        closed in all cases so its connection is released.
        """
        with self.session.get(image_url, timeout=self.DOWNLOAD_TIMEOUT, stream=True) as response:
            if response.status_code == 200:
                with open(dest_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
            return response.status_code

    @classmethod
    def _verify_and_resize(cls, path: str) -> tuple:
        """
        Blocking: validate the image at *path* and downscale it in place if it is too large.

        The image is downscaled, keeping its aspect ratio, when either side
        exceeds MAX_IMAGE_SIZE.

        Returns:
            ``(width, height, new_size)``. ``width`` and ``height`` are the
            downloaded dimensions; ``new_size`` is the resized ``(w, h)``, or
            None if no resize happened.

        Raises:
            Whatever PIL raises for unreadable or corrupt images.
        """
        # verify() leaves the image object unusable, so it is reopened below.
        with Image.open(path) as img:
            img.verify()

        resized = None
        image_format = None
        with Image.open(path) as img:
            width, height = img.size
            max_size = cls.MAX_IMAGE_SIZE
            if width > max_size or height > max_size:
                ratio = min(max_size / width, max_size / height)
                # max(1, ...) guards against a zero-sized side on extreme aspect ratios.
                new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
                # The temp path ends in '.tmp', so PIL cannot infer the output
                # format from it; reuse the source format explicitly.
                image_format = img.format
                resized = img.resize(new_size, Image.Resampling.LANCZOS)

        if resized is None:
            return width, height, None
        # Saved after the source handle is closed, so the file is not overwritten while open.
        resized.save(path, format=image_format, quality=95)
        return width, height, resized.size

    async def _download_and_cache_image(self, img_item: Dict, query: str) -> Optional[str]:
        """
        Download ``img_item['imageUrl']`` into the search cache and record it in the index.

        Returns:
            The cached file path, or None if the URL is missing, the download
            fails, or the payload is not a valid image. Partial temp files are
            removed on failure.
        """
        image_url = img_item.get('imageUrl', '')
        if not image_url:
            return None

        query_hash = self._get_query_hash(query)
        filename = f"{query_hash}.{self._extension_from_url(image_url)}"
        filepath = os.path.join(self.cache_dir, "search_images", filename)

        # Double-check if file was created while waiting
        if os.path.exists(filepath):
            print(f"     Image already cached during wait: {filename}")
            return filepath

        temp_filepath = filepath + '.tmp'
        loop = asyncio.get_running_loop()

        try:
            print(f"     Downloading image...")
            status = await loop.run_in_executor(None, self._fetch_to_file, image_url, temp_filepath)
        except Exception as e:
            print(f"       Download error: {str(e)[:100]}")
            self._remove_if_exists(temp_filepath)
            return None

        if status != 200:
            print(f"       Download failed: HTTP {status}")
            return None

        try:
            width, height, new_size = await loop.run_in_executor(
                None, self._verify_and_resize, temp_filepath
            )
            if new_size:
                print(f"     Resized to {new_size[0]}x{new_size[1]}")

            # Move to final location
            if os.path.exists(filepath):
                os.remove(temp_filepath)  # Another process beat us to it
            else:
                os.replace(temp_filepath, filepath)
        except Exception as e:
            print(f"       Invalid image: {e}")
            self._remove_if_exists(temp_filepath)
            return None

        self.cache_index[query_hash] = {
            'query': query,
            'title': img_item.get('title', ''),
            'image_path': filepath,
            'original_url': image_url,
            'source': img_item.get('source', ''),
            'domain': img_item.get('domain', ''),
            'cached_at': datetime.now().isoformat(),
            # Size of the image as downloaded, before any downscaling.
            'dimensions': f"{width}x{height}"
        }
        try:
            self._save_cache_index()
        except OSError as e:
            # The image itself is cached; _check_search_cache can still find it
            # by file name even without index metadata.
            print(f"       Could not update search cache index: {e}")

        return filepath

    # ------------------------------------------------------------------ #
    # Image editing
    # ------------------------------------------------------------------ #

    async def _post_edit_request(self, image_path: str, edit_instruction: str):
        """
        Send the multipart edit request without blocking the event loop.

        Returns the ``requests`` response object.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}"
        }
        form_fields = {
            "model": (None, "gpt-image-1-0415-global"),
            "prompt": (None, edit_instruction),
            "size": (None, "1024x1024"),
            "quality": (None, "low"),
            "output_compression": (None, "100"),
            "output_format": (None, "png"),
            "n": (None, "1")
        }

        loop = asyncio.get_running_loop()
        # The file handle must stay open until the upload completes, so the
        # executor call is awaited inside the with-block.
        with open(image_path, 'rb') as image_file:
            multipart_data = {
                **form_fields,
                'image[0]': (os.path.basename(image_path), image_file, 'image/png'),
            }
            return await loop.run_in_executor(
                None,
                lambda: requests.post(self.edit_url, headers=headers,
                                      files=multipart_data, timeout=self.EDIT_TIMEOUT)
            )

    @staticmethod
    def _report_edit_error(response) -> None:
        """Print status and any error details from a non-200 edit API response."""
        print(f"     Edit API error: Status {response.status_code}")
        try:
            error_response = response.json()
            error_msg = error_response.get('message', 'Unknown error')
            print(f"     Error: {error_msg}")
            if 'detailMessage' in error_response:
                print(f"     Details: {error_response['detailMessage']}")
        except (ValueError, AttributeError, TypeError):
            # Not JSON, or JSON that is not an object: show the raw body instead.
            print(f"     Response: {response.text[:500]}")

    def _save_edited_image(self, b64_payload: str, output_path: str) -> bool:
        """
        Decode *b64_payload*, verify it is an image, and move it to *output_path*.

        Returns True on success. On failure it prints the error, removes the
        temp file and returns False.
        """
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        temp_path = output_path + '.tmp'
        try:
            with open(temp_path, 'wb') as img_file:
                img_file.write(base64.b64decode(b64_payload))

            # Verify it's a valid image (handle closed before the rename below).
            with Image.open(temp_path) as img:
                img.verify()

            if os.path.exists(output_path):
                os.remove(temp_path)
            else:
                os.replace(temp_path, output_path)
            return True
        except Exception as e:
            print(f"     Failed to save edited image: {e}")
            self._remove_if_exists(temp_path)
            return False

    async def edit_image(self, image_path: str, edit_instruction: str,
                         output_dir: Optional[str] = None,
                         original_manipulation: Optional[str] = None,
                         target_image_path: Optional[str] = None) -> Optional[str]:
        """
        Edit a local image via the edit API, with caching and duplicate-request locking.

        Args:
            image_path: Source/reference image to edit; must be a local file
                (paths starting with 'http' are rejected).
            edit_instruction: Editing prompt sent to the API.
            output_dir: Directory to save edited images; defaults to
                ``<cache_dir>/edited_images``.
            original_manipulation: Original manipulation text from the dataset,
                recorded in the edit index.
            target_image_path: Ground-truth target image path from the dataset,
                recorded in the edit index.

        Returns:
            Path of the edited PNG, or None on any failure or if an identical
            edit is already in flight.
        """
        # Ensure we're working with a local file path, not a URL
        if isinstance(image_path, str) and image_path.startswith('http'):
            print(f"     Edit requires local image file, not URL: {image_path}")
            return None

        edit_hash = self._get_edit_hash(image_path, edit_instruction)

        print(f"\n Edit request: '{edit_instruction[:80]}...'")

        existing_path = self._check_edit_cache(edit_hash)
        if existing_path:
            return existing_path

        # Check if already editing
        with self.operation_lock:
            if edit_hash in self.editing_images:
                print(f"   Already editing with same instruction, skipping...")
                return None
            self.editing_images.add(edit_hash)

        try:
            if output_dir is None:
                output_dir = os.path.join(self.cache_dir, "edited_images")
            output_path = os.path.join(output_dir, f"edit_{edit_hash}.png")

            # Double-check if created while waiting
            if os.path.exists(output_path):
                print(f"     Edited image created during wait: {os.path.basename(output_path)}")
                return output_path

            if not os.path.exists(image_path):
                print(f"     Source image not found: {image_path}")
                return None

            print(f"     Sending edit request to API...")
            print(f"     Source: {os.path.basename(image_path)}")

            response = await self._post_edit_request(image_path, edit_instruction)

            if response.status_code != 200:
                self._report_edit_error(response)
                return None

            try:
                result = response.json()
            except ValueError:
                # Covers json / simplejson / requests JSONDecodeError variants.
                print(f"     Invalid JSON response from edit API")
                return None

            items = result.get("data") if isinstance(result, dict) else None
            if not items:
                print(f"     No data in edit API response")
                return None

            img_data = items[0]
            if "b64_json" not in img_data:
                print(f"     No base64 image data in response")
                return None

            if not self._save_edited_image(img_data["b64_json"], output_path):
                return None

            self.edit_cache_index[edit_hash] = {
                'original_image': image_path,  # Reference/source image
                'target_image': target_image_path,  # Ground-truth target from dataset
                'edited_image': output_path,  # Result of editing
                'manipulation_text': original_manipulation or edit_instruction,
                'editing_query': edit_instruction,
                'edited_at': datetime.now().isoformat()
            }
            self._save_edit_cache_index()

            print(f"     Edited image saved: {os.path.basename(output_path)}")
            return output_path

        except Exception as e:
            print(f"     Unexpected edit exception: {str(e)[:200]}")
            import traceback
            traceback.print_exc()
            return None
        finally:
            with self.operation_lock:
                self.editing_images.discard(edit_hash)

    # ------------------------------------------------------------------ #
    # Formatting and stats
    # ------------------------------------------------------------------ #

    def format_search_results_for_llm(self, search_results: List[Dict]) -> Dict:
        """
        Turn the output of search_images_with_download into an MLLM-ready payload.

        Only the first result is used. The returned dict has three keys:
          - ``text``: a short reference block;
          - ``images``: local path and title, or empty if there is no local file;
          - ``descriptions``: title, source and domain.
        A "No visual reference available." stub is returned when the first
        result is absent or has no image.
        """
        if not search_results or not search_results[0].get('image_available'):
            return {
                "text": "No visual reference available.",
                "images": [],
                "descriptions": []
            }

        result = search_results[0]
        title = result.get('title', '')

        source_line = f"Source: {result.get('source', 'Unknown')}"
        if result.get('domain'):
            source_line += f" ({result['domain']})"
        origin = "[Retrieved from cache]" if result.get('cached') else "[Newly downloaded]"

        formatted_text = "".join([
            "=== Visual Reference ===\n",
            f"Title: {title}\n",
            f"{source_line}\n",
            f"{origin}\n",
            "=== End of Reference ===\n",
        ])

        local_path = result.get('local_image_path')
        return {
            "text": formatted_text,
            "images": [{
                'path': local_path,
                'title': title,
                'index': 1
            }] if local_path else [],
            "descriptions": [{
                'index': 1,
                'title': title,
                'source': result.get('source', ''),
                'domain': result.get('domain', '')
            }]
        }

    @staticmethod
    def _dir_file_stats(dir_path: str) -> tuple:
        """Return ``(file_count, total_bytes)`` for regular files directly in *dir_path*; (0, 0) if it is missing."""
        count = total = 0
        try:
            with os.scandir(dir_path) as entries:
                for entry in entries:
                    if entry.is_file():
                        count += 1
                        total += entry.stat().st_size
        except FileNotFoundError:
            pass
        return count, total

    def get_cache_stats(self) -> Dict:
        """
        Return cache statistics.

        The stats cover index entry counts, file counts in each cache
        directory, and the combined size of those files in MB (rounded to 2
        decimals).
        """
        search_files, search_bytes = self._dir_file_stats(os.path.join(self.cache_dir, "search_images"))
        edit_files, edit_bytes = self._dir_file_stats(os.path.join(self.cache_dir, "edited_images"))

        return {
            "search_cached": len(self.cache_index),
            "edit_cached": len(self.edit_cache_index),
            "search_files": search_files,
            "edit_files": edit_files,
            "total_size_mb": round((search_bytes + edit_bytes) / (1024 * 1024), 2)
        }