from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import ClassVar, Optional, Set, Tuple

import sys
import os
from contextlib import contextmanager
from PIL import Image
import cv2
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field, model_validator, PrivateAttr

import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize

from transformers import CLIPTokenizer

from openai import OpenAI
from google import genai
from google.genai import types
from utils.api_keys import get_openai_key, get_google_key
from io import BytesIO
import PIL.Image
import base64
import requests
import functools


# ───────────────────────── helper: safe chdir ───────────────────────── #
@contextmanager
def pushd(path: Path):
    """Run the ``with`` body with *path* as the process working directory.

    The directory that is current on entry is remembered and re-entered on
    exit, whether the body finishes normally or raises.
    """
    origin = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Always go back, even if the body blew up.
        os.chdir(origin)


@contextmanager
def isolated_imports(module_prefixes=None, paths_to_add=None):
    """Give the ``with`` body a temporarily altered import environment.

    On entry, every ``sys.modules`` entry whose name starts with one of
    *module_prefixes* is removed (and remembered), and each entry of
    *paths_to_add* is appended to ``sys.path`` as a string.  On exit,
    ``sys.path`` is reset to the snapshot taken on entry and the removed
    modules are put back under their original names.

    Args:
        module_prefixes: Iterable of module-name prefixes to evict from
            ``sys.modules`` for the duration of the block (optional).
        paths_to_add: Iterable of paths to append to ``sys.path`` for the
            duration of the block (optional).
    """
    # Snapshot taken before anything is touched.
    saved_path = list(sys.path)

    # name -> module object for everything evicted below
    stashed = {}
    for prefix in module_prefixes or ():
        # Materialise the matches first: sys.modules is mutated while evicting.
        for name in [m for m in sys.modules if m.startswith(prefix)]:
            stashed[name] = sys.modules.pop(name)

    sys.path.extend(str(entry) for entry in paths_to_add or ())

    try:
        yield
    finally:
        sys.path = saved_path
        # Re-install the evicted modules (overwrites same-named entries
        # that were imported inside the block).
        sys.modules.update(stashed)

# ──────────────────────────────────────────────────────────────
# 1. Shared types
# ──────────────────────────────────────────────────────────────
class PayloadType(str, Enum):
    """Enum for all possible input types that any expert might need.

    Each member's value is the name of the matching `InputPayload` field.
    Because the enum mixes in `str`, members compare equal to those strings.
    """
    TEXT_PROMPT = "text_prompt"  # For background segmentation
    IMAGE = "image"  # reference image input
    TARGET_PROMPT = "target_prompt"  # For object generation
    BOUNDING_BOX = "bounding_box"  # bounding-box input


class InputPayload(BaseModel):
    """Schema for what can be fed to any expert.

    Every field is optional on its own, but construction fails with a
    validation error unless at least one of them holds a truthy value
    (so an empty-string prompt counts as "not provided").
    """
    # pydantic v2 configuration: the class-based ``Config`` is deprecated in
    # v2, which this module already targets (model_validator / model_dump).
    # Arbitrary types are allowed so PIL.Image passes straight through.
    model_config = {"arbitrary_types_allowed": True}

    # Basic inputs
    text_prompt: Optional[str] = Field(None, description="Text prompt from the user")
    image: Optional[Image.Image] = Field(None, description="Reference image")

    # Common additional inputs
    bounding_box: Optional[Tuple[float, float, float, float]] = Field(None, description="Bounding box (x, y, w, h). Should be normalized to 0-1")
    target_prompt: Optional[str] = Field(None, description="Target/goal prompt")

    @property
    def present_types(self) -> Set[PayloadType]:
        """Return the set of payload types whose field holds a truthy value."""
        candidates = (
            (PayloadType.TEXT_PROMPT, self.text_prompt),
            (PayloadType.IMAGE, self.image),
            (PayloadType.TARGET_PROMPT, self.target_prompt),
            (PayloadType.BOUNDING_BOX, self.bounding_box),
        )
        return {kind for kind, value in candidates if value}

    @model_validator(mode="after")
    def _check_required(self) -> 'InputPayload':
        """Ensure at least one input is provided.

        Raises:
            ValueError: if no field holds a truthy value.
        """
        # Same truthiness rule as `present_types`, so the two cannot drift.
        if not self.present_types:
            raise ValueError("At least one input must be provided")
        return self


# ──────────────────────────────────────────────────────────────
# 2. Base class
# ──────────────────────────────────────────────────────────────
class Expert(BaseModel, ABC):
    """
    Abstract base class for every image-expert model.

    Subclasses set the class-level metadata below and implement
    `_load_pipeline` + `_invoke_pipeline`.  `run` validates the keyword
    inputs, loads the pipeline on first use and delegates to
    `_invoke_pipeline`.
    """

    # ─── model config ─────────────────────────────────────────
    # pydantic v2 configuration (class-based ``Config`` is deprecated in v2).
    model_config = {"arbitrary_types_allowed": True}     # PIL.Image, torch tensors, etc.

    # immutable "static" metadata (declared once per subclass)
    model_id: ClassVar[str]
    required_inputs: ClassVar[Set[PayloadType]]  # Required input types
    optional_inputs: ClassVar[Set[PayloadType]]  # Optional input types
    is_t2i: ClassVar[bool]

    # instance-level configuration
    cache_dir: Path = Field(default_factory=lambda: Path("/datasets/uig/cached_models").expanduser())
    # NOTE(review): the default is keyed off the CUDA_VISIBLE_DEVICES env var,
    # not torch.cuda.is_available(); confirm that this is intended.
    device: str = Field(default_factory=lambda: "cuda" if os.getenv("CUDA_VISIBLE_DEVICES") else "cpu")

    # lazy-loaded pipeline handle (excluded from serialisation)
    pipeline: object | None = Field(default=None, exclude=True)

    # ─── public API ────────────────────────────────────────────
    def run(self, **kwargs) -> Image.Image:
        """
        Unified call signature – kwargs are validated against `InputPayload`.

        Returns:
            Whatever `_invoke_pipeline` returns (a PIL.Image by contract).

        Raises:
            ValueError: if a required input is missing, or an input is
                supplied that is neither required nor optional for this
                expert.  Invalid payloads also raise from `InputPayload`.
        """
        payload = InputPayload(**kwargs)            # raises on invalid combos
        provided = payload.present_types

        # Validate required inputs
        missing = self.required_inputs - provided
        if missing:
            raise ValueError(
                f"{self.__class__.__name__} requires {missing} but they were not provided."
            )

        # Validate no unknown inputs
        allowed = self.required_inputs | self.optional_inputs
        unknown = provided - allowed
        if unknown:
            raise ValueError(
                f"{self.__class__.__name__} does not accept {unknown}."
            )

        # Load lazily so constructing an expert stays cheap.
        if self.pipeline is None:
            self.pipeline = self._load_pipeline()

        return self._invoke_pipeline(**payload.model_dump(exclude_none=True))

    __call__ = run                                              # convenience

    def info(self) -> dict[str, object]:
        """Return the expert's static metadata as plain Python values.

        The mapping holds ``is_t2i`` (bool) plus ``required_inputs`` and
        ``optional_inputs`` (lists of `PayloadType` string values).
        """
        return {
            "is_t2i": self.is_t2i,
            "required_inputs": [t.value for t in self.required_inputs],
            "optional_inputs": [t.value for t in self.optional_inputs]
        }

    # ─── subclass hooks ───────────────────────────────────────
    @abstractmethod
    def _load_pipeline(self):
        """Download (if needed) and materialise the underlying model."""

    @abstractmethod
    def _invoke_pipeline(self, **kwargs) -> Image.Image:
        """Do the real inference; subclasses convert inputs → PIL.Image."""


# ──────────────────────────────────────────────────────────────
# 3. Concrete experts
# ──────────────────────────────────────────────────────────────
class SDXLExpert(Expert):
    """Stable Diffusion XL: High-fidelity, photorealistic text-to-image generation."""
    model_id = "stabilityai/stable-diffusion-xl-base-1.0"
    required_inputs = {PayloadType.TEXT_PROMPT}
    optional_inputs = set()
    is_t2i = True

    def _load_pipeline(self):
        """Fetch the SDXL snapshot into the cache and build an fp16 pipeline."""
        from diffusers import StableDiffusionXLPipeline
        path = snapshot_download(self.model_id, cache_dir=self.cache_dir / self.model_id)
        pipe = StableDiffusionXLPipeline.from_pretrained(path, torch_dtype=torch.float16)
        # Do not `.to(device)` first: model CPU offload handles placement
        # itself, and moving the whole pipeline to the accelerator beforehand
        # only raises peak memory before it is moved back to the CPU.
        pipe.enable_model_cpu_offload()
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, **__) -> Image.Image:
        """Generate one image for *text_prompt* and return it."""
        with torch.inference_mode():
            out = self.pipeline(prompt=text_prompt,
                                guidance_scale=3.5,  ## changed from 7
                                num_inference_steps=35)
        return out.images[0]


class PixArtAlphaExpert(Expert):
    """PixArt-α: Fast, high-resolution text-to-image generation with artistic style support."""
    model_id = "PixArt-alpha/PixArt-XL-2-1024-MS"
    required_inputs = {PayloadType.TEXT_PROMPT}
    optional_inputs = set()
    is_t2i = True

    def _load_pipeline(self):
        """Fetch the PixArt-α snapshot into the cache and build an fp16 pipeline."""
        from diffusers import PixArtAlphaPipeline
        path = snapshot_download(self.model_id, cache_dir=self.cache_dir / self.model_id)
        pipe = PixArtAlphaPipeline.from_pretrained(path, torch_dtype=torch.float16)
        # Do not `.to(device)` first: model CPU offload handles placement
        # itself, and moving the whole pipeline to the accelerator beforehand
        # only raises peak memory before it is moved back to the CPU.
        pipe.enable_model_cpu_offload()
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, **__) -> Image.Image:
        """Generate one image for *text_prompt* and return it."""
        with torch.inference_mode():
            out = self.pipeline(prompt=text_prompt,
                                num_inference_steps=30,
                                guidance_scale=6.0)
        return out.images[0]


class InstructPix2PixExpert(Expert):
    """InstructPix2Pix: edits an input image according to a text instruction."""
    model_id = "timbrooks/instruct-pix2pix"
    required_inputs = {PayloadType.TEXT_PROMPT, PayloadType.IMAGE}
    optional_inputs = set()
    is_t2i = False

    def _load_pipeline(self):
        """Fetch the InstructPix2Pix snapshot into the cache and build the pipeline."""
        from diffusers import StableDiffusionInstructPix2PixPipeline
        path = snapshot_download(self.model_id, cache_dir=self.cache_dir / self.model_id)
        pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(path)
        # Do not `.to(device)` first: model CPU offload handles placement
        # itself, and moving the whole pipeline to the accelerator beforehand
        # only raises peak memory before it is moved back to the CPU.
        pipe.enable_model_cpu_offload()
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, image: Image.Image, **__) -> Image.Image:
        """Apply the instruction *text_prompt* to *image* and return the edit."""
        with torch.inference_mode():
            # The diffusers pipeline names these `guidance_scale` (text) and
            # `image_guidance_scale` (image).  The original-repo names
            # `text_cfg_scale` / `image_cfg_scale` are not pipeline parameters,
            # so the intended 15 / 2 values were never applied.
            out = self.pipeline(prompt=text_prompt, image=image,
                                num_inference_steps=35,
                                image_guidance_scale=2, guidance_scale=15)
        return out.images[0]


class MagicBrushExpert(Expert):
    """Image-editing expert that runs a MagicBrush checkpoint through a
    locally cloned instruct-pix2pix repository (no Hugging Face download).

    The repo, checkpoint and config locations are instance fields, so they
    can be overridden when constructing the expert.
    """
    # Root of the cloned instruct-pix2pix repo; `edit_cli` is imported from it.
    repo_dir: Path = Path("/datasets/uig/cloned_model_repos/instruct-pix2pix").expanduser()
    # Checkpoint passed to `load_model_from_config`.
    ckpt_path: Path = Path("/datasets/uig/cloned_model_repos/instruct-pix2pix/checkpoints/MagicBrush-epoch-52-step-4999.ckpt").expanduser()
    # YAML config loaded with OmegaConf.
    config_path: Path = Path("/datasets/uig/cloned_model_repos/instruct-pix2pix/configs/generate.yaml").expanduser()
    # NOTE(review): annotated without ClassVar here, unlike e.g. SDXLExpert which
    # assigns the same metadata unannotated — confirm this is intended.
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT, PayloadType.IMAGE}
    optional_inputs = set()
    is_t2i = False

    def _load_pipeline(self):
        """
        Build and return a lightweight callable `pipe(prompt, image, **kwargs)`
        so that the rest of the framework can treat it like a diffusers pipeline.

        The model is loaded once here; the returned closure keeps references to
        it (and to the denoiser wrappers and the empty-prompt conditioning).
        """
        from omegaconf import OmegaConf
        from PIL import ImageOps
        from einops import rearrange
        import k_diffusion as K

        repo_abs = self.repo_dir.resolve()
        
        # Use our isolated imports context manager
        # (these directories are appended to sys.path only inside the block)
        paths_to_add = [
            str(repo_abs), 
            str(repo_abs / "stable_diffusion"),
            str(repo_abs / "stable_diffusion" / "ldm"),
            str(repo_abs / "stable_diffusion" / "ldm" / "models"),
            str(repo_abs / "stable_diffusion" / "ldm" / "models" / "diffusion")
        ]
        
        # Any already-imported `ldm*` modules are evicted for the duration of
        # the block and restored afterwards (see `isolated_imports`).
        with isolated_imports(module_prefixes=['ldm'], paths_to_add=paths_to_add):
            # `edit_cli` is imported with the repo root as the working directory.
            with pushd(repo_abs):
                from edit_cli import CFGDenoiser, load_model_from_config
            
            cfg = OmegaConf.load(self.config_path)
            model = load_model_from_config(cfg, str(self.ckpt_path), None)
            model.eval().to(self.device)
            model_wrap = K.external.CompVisDenoiser(model)
            model_wrap_cfg = CFGDenoiser(model_wrap)
            # Conditioning for the empty prompt; used below as the unconditional
            # cross-attention input.
            null_token = model.get_learned_conditioning([""])

            # ── Build a small callable that hides the implementation details ── #
            def _pipe(*, prompt: str, image: Image.Image, num_inference_steps: int = 35,
                    text_cfg_scale: float = 7.5, image_cfg_scale: float = 1.5,
                    seed: int | None = None):
                """Edit *image* according to *prompt* and return a PIL image.

                When *seed* is given, torch's global RNG is seeded with it just
                before the initial noise is drawn.
                """
                import math
                # 1. Pre‑process image (long edge → 512 px, multiples of 64)
                w, h = image.size
                factor = 512 / max(w, h)
                # Re-derive the factor so the scaled short edge is rounded up to
                # a multiple of 64.
                factor = math.ceil(min(w, h) * factor / 64) * 64 / min(w, h)
                new_w  = int((w * factor) // 64) * 64
                new_h  = int((h * factor) // 64) * 64
                img_rs = ImageOps.fit(image.convert("RGB"), (new_w, new_h), method=Image.Resampling.LANCZOS)

                with torch.no_grad(), torch.autocast(model.device.type), model.ema_scope():
                    # Conditioning
                    cond = {
                        "c_crossattn": [model.get_learned_conditioning([prompt])]
                    }
                    # Map uint8 pixels from [0, 255] to [-1, 1], then HWC -> 1CHW.
                    img_tensor = 2 * torch.tensor(np.array(img_rs)).float() / 255 - 1
                    img_tensor = rearrange(img_tensor, "h w c -> 1 c h w").to(model.device)
                    cond["c_concat"] = [model.encode_first_stage(img_tensor).mode()]

                    # Unconditional branch: empty-prompt conditioning plus an
                    # all-zero tensor in place of the encoded image.
                    uncond = {
                        "c_crossattn": [null_token],
                        "c_concat":    [torch.zeros_like(cond["c_concat"][0])]
                    }

                    # Diffusion
                    sigmas = model_wrap.get_sigmas(num_inference_steps)
                    if seed is not None:
                        torch.manual_seed(seed)
                    # Start from Gaussian noise scaled by the first sigma.
                    z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
                    extra = dict(cond=cond, uncond=uncond,
                                text_cfg_scale=text_cfg_scale,
                                image_cfg_scale=image_cfg_scale)
                    z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra)

                    # Decode
                    x = model.decode_first_stage(z)
                    # Map from [-1, 1] back to [0, 1], then to 0..255 in HWC order.
                    x = torch.clamp((x + 1.0) / 2.0, 0.0, 1.0)
                    x = 255.0 * rearrange(x, "1 c h w -> h w c")
                    edited = Image.fromarray(x.type(torch.uint8).cpu().numpy())
                    return edited

        return _pipe

    def _invoke_pipeline(self, *, text_prompt: str, image: Image.Image, **__) -> Image.Image:
        """Edit *image* with *text_prompt*, using the closure's default step
        count and guidance scales (35 steps, text 7.5, image 1.5, no seed)."""
        # `self.pipeline` is cached by the base `Expert` class
        edited = self.pipeline(prompt=text_prompt, image=image)
        return edited

class StableDiffusion35LargeExpert(Expert):
    """Stable Diffusion 3.5 Large: text-to-image generation."""
    model_id = "stabilityai/stable-diffusion-3.5-large"
    required_inputs = {PayloadType.TEXT_PROMPT}
    optional_inputs = set()
    is_t2i = True

    def _load_pipeline(self):
        """Fetch the SD 3.5 Large snapshot into the cache and build a bf16 pipeline."""
        from diffusers import StableDiffusion3Pipeline
        path = snapshot_download(self.model_id, cache_dir=self.cache_dir / self.model_id)
        pipe = StableDiffusion3Pipeline.from_pretrained(path, torch_dtype=torch.bfloat16)
        # Do not `.to(device)` first: model CPU offload handles placement
        # itself, and moving the whole pipeline to the accelerator beforehand
        # only raises peak memory before it is moved back to the CPU.
        pipe.enable_model_cpu_offload()
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, **__) -> Image.Image:
        """Generate one image for *text_prompt* and return it."""
        with torch.inference_mode():
            out = self.pipeline(prompt=text_prompt,
                                guidance_scale=7.0,
                                num_inference_steps=35)
        return out.images[0]


class FluxT2IExpert(Expert):
    """FLUX.1-dev text-to-image expert (1024x1024 output, fixed seed 0)."""
    model_id: str = "black-forest-labs/FLUX.1-dev"
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT}
    optional_inputs: Set[PayloadType] = set()
    is_t2i = True

    def _load_pipeline(self):
        """Fetch the FLUX.1-dev snapshot and build an fp16 pipeline with model CPU offload."""
        from diffusers import FluxPipeline
        weights_dir = snapshot_download(self.model_id, cache_dir=self.cache_dir / self.model_id)
        flux = FluxPipeline.from_pretrained(weights_dir, torch_dtype=torch.float16)
        flux.enable_model_cpu_offload()
        return flux

    def _invoke_pipeline(self, *, text_prompt: str, **__) -> Image.Image:
        """Generate one image for *text_prompt*; the CPU generator is re-seeded with 0 on every call."""
        rng = torch.Generator(device="cpu").manual_seed(0)
        result = self.pipeline(
            text_prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=35,
            max_sequence_length=512,
            generator=rng,
        )
        return result.images[0]


class DALLE3T2IExpert(Expert):
    """Text-to-image expert backed by the OpenAI Images API (model ``dall-e-3``)."""
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT}
    optional_inputs: Set[PayloadType] = set()
    is_t2i = True

    def _load_pipeline(self):
        """Create an OpenAI client and return `images.generate` pre-bound to the model settings."""
        client = OpenAI(api_key=get_openai_key())
        pipe = functools.partial(client.images.generate, model="dall-e-3", size="1024x1024", quality="standard")
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, **__) -> Image.Image:
        """Generate an image for *text_prompt*, download it and return it as RGB.

        Raises:
            requests.RequestException: if the image download times out or
                the server answers with an HTTP error status.
        """
        result = self.pipeline(prompt=text_prompt)
        # requests has no default timeout; without one a stalled download
        # would block forever.
        response = requests.get(result.data[0].url, timeout=60)
        # Fail on 4xx/5xx here instead of handing an error page to PIL.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert("RGB")


class GPTImage1T2IExpert(Expert):
    """Text-to-image expert backed by the OpenAI Images API (model ``gpt-image-1``)."""
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT}
    optional_inputs: Set[PayloadType] = set()
    is_t2i = True

    def _load_pipeline(self):
        """Create an OpenAI client and return `images.generate` pre-bound to the model settings."""
        api = OpenAI(api_key=get_openai_key())
        return functools.partial(
            api.images.generate,
            model="gpt-image-1",
            size="1024x1024",
            moderation="low",
            quality="high",
        )

    def _invoke_pipeline(self, *, text_prompt: str, **__) -> Image.Image:
        """Generate an image for *text_prompt* and decode the base64 payload into an RGB image."""
        reply = self.pipeline(prompt=text_prompt)
        raw = base64.b64decode(reply.data[0].b64_json)
        return Image.open(BytesIO(raw)).convert("RGB")


class GeminiT2IExpert(Expert):
    """Text-to-image expert backed by the Gemini API (``gemini-2.5-flash-image-preview``)."""
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT}
    optional_inputs: Set[PayloadType] = set()
    is_t2i = True

    def _load_pipeline(self):
        """Create a Gemini client and return `generate_content` pre-bound to model and config."""
        client = genai.Client(api_key=get_google_key())
        pipe = functools.partial(
            client.models.generate_content, 
            model="gemini-2.5-flash-image-preview",
            config=types.GenerateContentConfig(response_modalities=['TEXT', 'IMAGE'])
        )
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, **__) -> Image.Image:
        """Generate an image for *text_prompt*.

        Returns:
            The image decoded from the last response part that carries
            inline data.

        Raises:
            RuntimeError: if no part of the first candidate carries inline
                data (previously this surfaced as an UnboundLocalError).
        """
        response = self.pipeline(contents=text_prompt)
        generated = None
        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                generated = Image.open(BytesIO(part.inline_data.data))
        if generated is None:
            raise RuntimeError("Gemini response did not contain an image part")
        return generated


class GPTImage1I2IExpert(Expert):
    """Image-editing expert backed by the OpenAI Images API (model ``gpt-image-1``)."""
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT, PayloadType.IMAGE}
    optional_inputs: Set[PayloadType] = set()
    is_t2i = False

    def _load_pipeline(self):
        """Create an OpenAI client and return `images.edit` pre-bound to the model settings."""
        client = OpenAI(api_key=get_openai_key())
        pipe = functools.partial(client.images.edit, model="gpt-image-1", input_fidelity="low", size="1024x1024", quality="high")
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, image: Image.Image, **__) -> Image.Image:
        """Edit *image* according to *text_prompt* and return the result as RGB.

        The image is uploaded from memory as PNG.  The previous version
        re-opened ``image.filename``, which fails for images that were not
        loaded from disk, ignores in-memory modifications, and left the
        file handle open.
        """
        # Keep RGB/RGBA as-is; convert any other mode to RGB before encoding.
        source = image if image.mode in ("RGB", "RGBA") else image.convert("RGB")
        buffer = BytesIO()
        source.save(buffer, format="PNG")
        # (filename, content, content-type): an explicit name and MIME type
        # let the API recognise the upload as a PNG.
        upload = ("image.png", buffer.getvalue(), "image/png")

        result = self.pipeline(prompt=text_prompt, image=[upload])
        image_bytes = base64.b64decode(result.data[0].b64_json)
        return Image.open(BytesIO(image_bytes)).convert("RGB")


class GeminiI2IExpert(Expert):
    """Image-editing expert backed by the Gemini API (``gemini-2.5-flash-image-preview``)."""
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT, PayloadType.IMAGE}
    optional_inputs: Set[PayloadType] = set()
    is_t2i = False

    def _load_pipeline(self):
        """Create a Gemini client and return `generate_content` pre-bound to model and config."""
        client = genai.Client(api_key=get_google_key())
        pipe = functools.partial(
            client.models.generate_content, 
            model="gemini-2.5-flash-image-preview",
            config=types.GenerateContentConfig(response_modalities=['TEXT', 'IMAGE'])
        )
        return pipe

    def _invoke_pipeline(self, *, text_prompt: str, image: Image.Image, **__) -> Image.Image:
        """Edit *image* according to *text_prompt*.

        Returns:
            The image decoded from the last response part that carries
            inline data.

        Raises:
            RuntimeError: if no part of the first candidate carries inline
                data.  Previously the result reused the ``image`` parameter
                name, so in that case the unmodified input was returned
                silently.
        """
        response = self.pipeline(contents=[text_prompt, image])
        edited = None
        for part in response.candidates[0].content.parts:
            if part.inline_data is not None:
                edited = Image.open(BytesIO(part.inline_data.data))
        if edited is None:
            raise RuntimeError("Gemini response did not contain an image part")
        return edited


class FluxKontextExpert(Expert):
    """FLUX.1-Kontext-dev expert: takes a text prompt plus an input image."""
    model_id: str = "black-forest-labs/FLUX.1-Kontext-dev"
    required_inputs: Set[PayloadType] = {PayloadType.TEXT_PROMPT, PayloadType.IMAGE}
    optional_inputs: Set[PayloadType] = set()
    is_t2i = False

    def _load_pipeline(self):
        """Fetch the FLUX.1-Kontext-dev snapshot and build a bf16 pipeline with model CPU offload."""
        from diffusers import FluxKontextPipeline
        weights_dir = snapshot_download(self.model_id, cache_dir=self.cache_dir / self.model_id)
        kontext = FluxKontextPipeline.from_pretrained(weights_dir, torch_dtype=torch.bfloat16)
        kontext.enable_model_cpu_offload()
        return kontext

    def _invoke_pipeline(self, *, text_prompt: str, image: Image.Image, **__) -> Image.Image:
        """Run the pipeline on *image* with *text_prompt* and return the first output image."""
        call_args = dict(
            prompt=text_prompt,
            image=image,
            guidance_scale=2.5,
            num_inference_steps=15,
        )
        result = self.pipeline(**call_args)
        return result.images[0]


# ──────────────────────────────────────────────────────────────
# 4. Factory
# ──────────────────────────────────────────────────────────────
def create_experts(cache_dir: str | Path = "/datasets/uig/cached_models") -> dict[str, Expert]:

    """
    Build a registry with one instance of every direct `Expert` subclass.

    Each expert's `_load_pipeline()` is called once up front, so any weights
    it fetches are already cached, and the loaded pipeline is discarded
    right away to free GPU/CPU RAM.  The returned experts load again lazily
    on their first `run()`.

    Args:
        cache_dir: Directory handed to every expert as its `cache_dir`;
            created if it does not exist.

    Returns:
        Mapping from the class name with "Expert" removed
        (e.g. `SDXLExpert` → "SDXL") to the expert instance.
    """
    cache_dir = Path(cache_dir).expanduser()
    cache_dir.mkdir(parents=True, exist_ok=True)

    registry: dict[str, Expert] = {}
    for cls in Expert.__subclasses__():
        key = cls.__name__.replace("Expert", "")
        agent = cls(cache_dir=cache_dir)
        # proactive download
        print(f"Downloading {key} model")
        # The result is deliberately not bound to a name: binding it (as
        # `_ = ...` did) kept the previous pipeline alive while the next
        # one was being loaded, and the last one alive until return.
        agent._load_pipeline()
        agent.pipeline = None
        registry[key] = agent

    return registry

class ExpertInfo(BaseModel):
    """Schema for expert system information.

    The field names match the keys of the dict returned by `Expert.info()`.
    """
    required_inputs: list[str]  # payload-type names the expert requires
    optional_inputs: list[str]  # payload-type names the expert also accepts
    is_t2i: bool  # the expert's `is_t2i` flag
