from gradio_client import Client, file
from PIL import Image, ImageDraw
import numpy as np
import tempfile
import time, os
import cv2, json, os, sys, time, random
import numpy as np
from PIL import Image
from matplotlib import colormaps
from matplotlib.colors import Normalize

from cycle_detect import *
import requests
from openai import OpenAI

import logging
import httpx

from config import ANIME_ADDRESS, CONTOUR_ADDRESS, OPEN_ADDRESS, PS_ADDRESS, ADD_DOT


# Select the remote sketching backend from the ``sk_type`` environment variable
# (default "open"). Each value maps to an address imported from ``config`` and is
# wrapped in a ``gradio_client.Client``.
_SK_ADDRESSES = {
    "ps": PS_ADDRESS,
    "anime": ANIME_ADDRESS,
    "contour": CONTOUR_ADDRESS,
    "open": OPEN_ADDRESS,
}
sk_type = os.getenv("sk_type", "open")
if sk_type in _SK_ADDRESSES:
    sk_client = Client(_SK_ADDRESSES[sk_type])
else:
    # An unknown value used to leave ``sk_client`` undefined without any message,
    # surfacing only later as a NameError inside ``sketch``. Warn up front instead.
    logging.warning(
        "Unknown sk_type %r; expected one of %s. sketch() will not work.",
        sk_type,
        sorted(_SK_ADDRESSES),
    )
    sk_client = None

# Base URL of the OpenAI-compatible endpoint. It is read once so the SDK client
# and its underlying httpx client always agree.
_OPENAI_BASE_URL = os.environ.get("OPENAI_BASE_URL", r"https://svip.xty.app/v1")

# The API key must be supplied via the OPENAI_API_KEY environment variable.
# NOTE(review): a key was previously hard-coded here as a fallback; it has been
# removed from source and should be treated as leaked (rotate it).
client = OpenAI(
    base_url=_OPENAI_BASE_URL,
    api_key=os.environ.get("OPENAI_API_KEY"),
    http_client=httpx.Client(
        base_url=_OPENAI_BASE_URL,
        follow_redirects=True,
    ),
)

# Token prices used by ``generate_answer`` to estimate the cost of a call.
# Layout per model: (price per prompt token, price per completion token), in $.
PRICE = {
    "gpt-4o": (5.0e-6, 15.0e-6),
}

def _is_context_length_error(exc: Exception) -> bool:
    """Return True if ``exc`` carries an API error body reporting an exceeded context length."""
    try:
        body = exc.response.json()
        return "error" in body and "maximum context length" in body["error"]["message"]
    except (AttributeError, ValueError, KeyError, TypeError):
        # No response attached, a non-JSON body, or an unexpected error schema.
        return False


def generate_answer(contexts, model="gpt-4o", temperature=0, requires_json=False, interface=False, max_try=3):
    """
    Send a chat-completion request and return the assistant's reply.

    The request is attempted up to ``max_try`` times, sleeping 5 seconds between
    failed attempts. After a successful call, the estimated cost is appended
    to ``price.txt``.

    Args:
        contexts (list[dict]): Messages passed as ``messages`` to the API.
        model (str): Model name to request.
        temperature (float): Sampling temperature.
        requires_json (bool): Currently ignored. JSON mode
            (``response_format={"type": "json_object"}``) is always requested.
        interface (bool): Unused; kept for backward compatibility.
        max_try (int): Maximum number of attempts.

    Returns:
        str: Content of the first choice's message. If the API reports that the
        maximum context length was exceeded, the tuple
        ``("The context is too long", 0)`` is returned instead (kept for
        existing callers).

    Raises:
        Exception: If every attempt fails.
    """
    # Models missing from the price table are billed at the gpt-4o rate, which
    # matches the previous behaviour (only gpt-4o is listed).
    input_price, output_price = PRICE.get(model, PRICE["gpt-4o"])

    for attempt in range(1, max_try + 1):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=contexts,
                temperature=temperature,
                # NOTE(review): JSON mode is forced regardless of ``requires_json``;
                # left unchanged because callers may depend on JSON output.
                response_format={"type": "json_object"},
            )
        except Exception as e:
            # Retrying cannot fix an over-long prompt, so give up immediately.
            if _is_context_length_error(e):
                logging.warning(f"Error in generating answer: {e}")
                return "The context is too long", 0
            logging.warning(f"Error in generating answer: {e}")
            # No point waiting after the final attempt.
            if attempt < max_try:
                logging.warning("Sleep for 5 seconds")
                time.sleep(5)
            continue

        # Cost bookkeeping is kept outside the ``try`` so that a local failure
        # here never re-issues a paid API request.
        usage = getattr(response, "usage", None)
        if usage is not None:
            price = usage.prompt_tokens * input_price + usage.completion_tokens * output_price
            with open("price.txt", "a") as f:
                f.write(f"LLM Calling Price: {price}$\n")
        return response.choices[0].message.content

    logging.error("Failed to generate answer")
    raise Exception("Failed to generate answer")


def canny(image: Image.Image, low_threshold: int = 50, high_threshold: int = 200, blur_ksize: int = 7) -> Image.Image:
    """
    Perform Canny edge detection and return a black-and-white edge map.

    The steps are:

    1. Convert the image to grayscale (if it has colour channels).
    2. Smooth it with a Gaussian blur.
    3. Run ``cv2.Canny``.
    4. Invert the result, so edges are black on a white background.

    Args:
        image (PIL.Image.Image): The input image. Both RGB and single-channel
            (grayscale) images are accepted.
        low_threshold (int): Lower hysteresis threshold for ``cv2.Canny``.
        high_threshold (int): Upper hysteresis threshold for ``cv2.Canny``.
        blur_ksize (int): Side length of the square Gaussian kernel; must be a
            positive odd number. Defaults to 7, the previously hard-coded size.

    Returns:
        PIL.Image.Image: Edge map with black edges on a white background.
    """
    pixels = np.array(image)
    # A 2-D array is already grayscale; cvtColor(RGB2GRAY) would reject it.
    gray = pixels if pixels.ndim == 2 else cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(gray, (blur_ksize, blur_ksize), 0)
    edges = cv2.Canny(blur, low_threshold, high_threshold)
    # Canny marks edges as 255 on 0; invert for dark lines on a light background.
    return Image.fromarray(cv2.bitwise_not(edges))

def binary(image: Image.Image, threshold: int = 127, max_value: int = 255, method: int = cv2.THRESH_BINARY) -> Image.Image:
    """
    Threshold an image into a binary (black-and-white) picture.

    Colour input is first reduced to grayscale. The result is smoothed with a
    5x5 Gaussian blur, then passed through ``cv2.threshold``.

    Args:
        image (PIL.Image.Image): The input image (colour or grayscale).
        threshold (int): Cut-off value (0-255) passed to ``cv2.threshold``.
        max_value (int): Value assigned to pixels selected by the threshold rule.
        method (int): OpenCV threshold type, ``cv2.THRESH_BINARY`` by default.
            Alternatives include ``cv2.THRESH_BINARY_INV``, ``cv2.THRESH_TRUNC``,
            ``cv2.THRESH_TOZERO`` and ``cv2.THRESH_TOZERO_INV``.

    Returns:
        PIL.Image.Image: The thresholded image.
    """
    pixels = np.array(image)
    # Three-dimensional arrays carry colour channels; 2-D ones are already gray.
    gray = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY) if pixels.ndim == 3 else pixels
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    # cv2.threshold returns (used_threshold, image); only the image is needed.
    thresholded = cv2.threshold(smoothed, threshold, max_value, method)[1]
    return Image.fromarray(thresholded)

def sketch(image: Image.Image):
    """
    Generate a sketch-style object image using the sketching tool, tailored for
    the reference image in the BLINK sem_corr task.

    The image is written to a temporary JPEG and sent to the Gradio client
    ``sk_client``. The returned file is opened as the sketch. When ``ADD_DOT``
    is set, red circles detected in the saved input are drawn onto the sketch.

    Args:
        image (PIL.Image.Image): The input image.

    Returns:
        PIL.Image.Image: The generated sketch image.
    """
    # A private temporary directory is used instead of re-opening a still-open
    # NamedTemporaryFile by name, which fails on Windows.
    with tempfile.TemporaryDirectory() as tmp_dir:
        input_path = os.path.join(tmp_dir, "input.jpg")
        # JPEG cannot store alpha or palette modes (e.g. RGBA, P).
        to_save = image if image.mode in ("RGB", "L") else image.convert("RGB")
        to_save.save(input_path, "JPEG")

        output_path = sk_client.predict(file(input_path))
        output_image = Image.open(output_path)

        if ADD_DOT:
            # Close the re-opened input once the circles have been detected.
            with Image.open(input_path) as saved_input:
                circles = detect_red_circles(saved_input)
            output_image = draw_circles_on_image(output_image, circles)

    return output_image

def gpt_sketch(image: Image.Image, sketch_prompt: str = "Convert this image to a sketch", model: str = "gpt-image-1") -> Image.Image:
    """
    Generate a sketch-style image with an OpenAI image model (gpt-image-1 by default).

    The input image is sent to the Images *edit* endpoint together with
    ``sketch_prompt``. The base64-encoded result is decoded into a PIL image.
    (The previous version sent the request through chat completions with model
    "dall-e"; chat completions do not return generated image bytes.)

    Args:
        image (PIL.Image.Image): The input image.
        sketch_prompt (str): Instruction prompt for the image model.
        model (str): Image model to use. The ``b64_json`` field read here is
            what gpt-image-1 returns.

    Returns:
        PIL.Image.Image: The generated sketch image.

    Raises:
        ValueError: If the response contains no image data.
    """
    import base64
    from io import BytesIO

    # PNG keeps the input lossless; normalise exotic modes (e.g. CMYK) to RGB.
    to_send = image if image.mode in ("RGB", "RGBA") else image.convert("RGB")
    buffered = BytesIO()
    to_send.save(buffered, format="PNG")

    response = client.images.edit(
        model=model,
        image=("image.png", buffered.getvalue(), "image/png"),
        prompt=sketch_prompt,
    )

    data = getattr(response, "data", None) or []
    b64_image = data[0].b64_json if data else None
    if not b64_image:
        raise ValueError("No image returned in model response.")
    return Image.open(BytesIO(base64.b64decode(b64_image)))