import os
import tempfile
from typing import Any, Dict, Optional

from google import genai
from google.genai import types as genai_types
from PIL import Image

from llms.providers.google.google_client_manager import get_client_manager
from utils.image_utils import any_to_pil


class GoogleFileManager:
    """Cache images uploaded to the Google GenAI Files API.

    The Google API requires large images to be uploaded to the cloud before
    they can be referenced in a prompt. This manager keeps two mappings:

    * ``img_to_uploaded``: content key of an image -> uploaded ``genai_types.File``,
      so the same image is not uploaded twice.
    * ``uploaded_to_img``: uploaded file URI -> the PIL image it came from. The
      pixels cannot be retrieved back from a genai file, so this is kept for
      prompt visualizations and for re-uploading.
    """

    def __init__(self, p_id: int = 0):
        # Passed to get_client_manager() to select the client used for uploads.
        self.p_id = p_id
        # Content key of an image (see _image_key) -> uploaded genai file.
        self.img_to_uploaded: Dict[int, genai_types.File] = {}
        # Uploaded genai file URI -> PIL image it was uploaded from.
        self.uploaded_to_img: Dict[str, Image.Image] = {}

    @staticmethod
    def _image_key(image: Image.Image) -> int:
        """Return the cache key used for ``image`` in ``img_to_uploaded``.

        NOTE: the key is derived from the raw pixel buffer only (``tobytes()``),
        so images with identical buffers but different mode/size share a key.
        """
        return hash(image.tobytes())

    def upload_image_file(
        self,
        image_path: str,
        client: genai.Client,
    ) -> genai_types.File:
        """Upload the file at ``image_path`` with ``client.files.upload`` and return the result."""
        return client.files.upload(file=image_path)

    def get_upload_image_file(self, image: Any, force_upload: bool = False) -> genai_types.File:
        """Return the genai file for ``image``, uploading it if it is not cached.

        Args:
            image: Anything accepted by ``any_to_pil``.
            force_upload: Upload again even if the image is already cached.

        Returns:
            The cached or newly uploaded ``genai_types.File``.

        Raises:
            ValueError: If the (cached or new) file has no URI.
        """
        image_pil = any_to_pil(image)
        image_hash = self._image_key(image_pil)
        cached = self.img_to_uploaded.get(image_hash)
        if cached is not None and not force_upload:
            if cached.uri is None:
                raise ValueError("Uploaded file has no valid URI")
            # Make sure the reverse URI -> image mapping is present.
            self.uploaded_to_img[cached.uri] = image_pil
            return cached

        client = get_client_manager(self.p_id).get_client()
        # Close the temp file before writing to it by name (reopening an open
        # NamedTemporaryFile by name is not supported on Windows).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp:
            temp_path = temp.name
        try:
            image_pil.save(temp_path, format="PNG")
            gen_ai_file = self.upload_image_file(temp_path, client)
        finally:
            # Always clean up, even if saving or uploading failed.
            os.remove(temp_path)

        if gen_ai_file.uri is None:
            raise ValueError("Uploaded file has no valid URI")
        self.uploaded_to_img[gen_ai_file.uri] = image_pil
        self.img_to_uploaded[image_hash] = gen_ai_file
        return gen_ai_file

    def reupload_image(self, gen_ai_filename: str) -> genai_types.File:
        """Upload again the image previously uploaded under URI ``gen_ai_filename``.

        Raises:
            KeyError: If ``gen_ai_filename`` is not in ``uploaded_to_img``.
        """
        original_img = self.uploaded_to_img[gen_ai_filename]
        new_gen_ai_file = self.get_upload_image_file(image=original_img, force_upload=True)
        # Forget the stale URI only after the new upload succeeded, so a failed
        # upload does not lose the image.
        if new_gen_ai_file.uri != gen_ai_filename:
            self.uploaded_to_img.pop(gen_ai_filename, None)
        return new_gen_ai_file

    def get_files_from_prompt(self, prompt: list[genai_types.Content]) -> list[str]:
        """Return the file URIs referenced by ``file_data`` parts of ``prompt``, in order.

        Parts whose ``file_data`` has no URI are skipped.
        """
        files: list[str] = []
        for content in prompt:
            for part in content.parts or []:
                file_data = getattr(part, "file_data", None)
                if file_data is not None and file_data.file_uri is not None:
                    files.append(file_data.file_uri)
        return files

    def reupload_images_for_prompt(self, prompt: list[genai_types.Content]) -> Dict[str, genai_types.File]:
        """Re-upload every image referenced by ``prompt``.

        Returns:
            Mapping of old file URI -> newly uploaded file. A URI referenced
            several times is uploaded only once.

        Raises:
            ValueError: If a referenced URI has no cached source image.
        """
        mapping: Dict[str, genai_types.File] = {}
        for old_uri in self.get_files_from_prompt(prompt):
            if old_uri in mapping:
                # Already re-uploaded (and removed from the cache) earlier in this loop.
                continue
            if old_uri not in self.uploaded_to_img:
                raise ValueError(f"File {old_uri} not found in cached_gen_ai_files")
            mapping[old_uri] = self.reupload_image(old_uri)
        return mapping

    def reupload_all_images(self) -> Dict[str, genai_types.File]:
        """Re-upload every cached image and return a mapping of old URI -> new file."""
        # Snapshot the keys: reupload_image mutates uploaded_to_img during iteration.
        old_uris = list(self.uploaded_to_img)
        return {old_uri: self.reupload_image(old_uri) for old_uri in old_uris}

    def img_to_genai(self, image: Any) -> Optional[genai_types.File]:
        """Return the cached genai file for ``image``, or None if it was never uploaded.

        Uses the same normalization and key as ``get_upload_image_file``.
        """
        return self.img_to_uploaded.get(self._image_key(any_to_pil(image)))

    def gen_ai_to_img(self, gen_ai_filename: str) -> Optional[Image.Image]:
        """Return the source image for file URI ``gen_ai_filename``, or None if unknown."""
        return self.uploaded_to_img.get(gen_ai_filename)


# Process-wide registry holding one GoogleFileManager per p_id, created lazily.
google_file_managers: Dict[int, GoogleFileManager] = {}


def get_file_manager(p_id: int = 0) -> GoogleFileManager:
    """Return the shared GoogleFileManager for ``p_id``, creating it on first request."""
    try:
        return google_file_managers[p_id]
    except KeyError:
        manager = GoogleFileManager(p_id=p_id)
        google_file_managers[p_id] = manager
        return manager
