import base64
import io
import json
import logging
import mimetypes
import os
import time
from typing import Dict, Any, List, Optional, Tuple

import openai
import requests
from PIL import Image
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# API credentials for the OpenAI-compatible (DashScope) endpoint. They are read
# from the environment so the secret does not have to be written into this
# file; when the variables are unset the original empty defaults are used.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
openai.api_base = os.environ.get("OPENAI_API_BASE", "")

# Directory that receives edited images, and the JSON file the augmented
# dataset is written to. Both must be filled in before running main().
OUTPUT_DIR = ""
OUTPUT_JSON = ""
# AnySD edit endpoint used by style_transfer_image().
ANYSD_URL = "http://localhost:7999/edit"
def encode_image(image_path: str) -> Optional[str]:
    """Read the file at *image_path* and return its bytes as base64 text.

    Returns:
        The base64-encoded contents decoded to a UTF-8 string, or None when
        the file does not exist or cannot be read (an error is logged).
    """
    try:
        with open(image_path, "rb") as handle:
            payload = base64.b64encode(handle.read()).decode('utf-8')
            logger.info(f"Successfully encoded image at {image_path}")
    except FileNotFoundError:
        logger.error(f"Image file not found at {image_path}")
        payload = None
    except Exception as exc:
        logger.error(f"Error encoding image at {image_path}: {str(exc)}")
        payload = None
    return payload
def _failed_suggestion() -> Dict[str, Any]:
    """Return a fresh sentinel result meaning "no usable edit suggestion"."""
    return {"edit_type": "failed", "prompts": {}}


def _strip_code_fence(text: str) -> str:
    """Strip whitespace and an optional surrounding Markdown code fence.

    Accepts both ```json ... ``` and a bare ``` ... ``` fence; the opening
    fence's "json" language tag (any case) is removed as well.
    """
    text = text.strip()
    if text.startswith("```"):
        text = text[len("```"):]
        # Drop an optional language tag on the opening fence line.
        if text[:len("json")].lower() == "json":
            text = text[len("json"):]
        text = text.strip()
    if text.endswith("```"):
        text = text[:-len("```")].strip()
    return text


def get_edit_suggestion(data_dict: Dict[str, Any], image_path: str) -> Dict[str, Any]:
    """Ask the vision-language model to propose a replace/remove edit for an image.

    The image is sent inline as a base64 data URL together with the sample's
    question/answer text, and the model is asked to reply with a JSON object
    of the form ``{"edit_type": ..., "prompts": {...}}``.

    Args:
        data_dict: The dataset sample; only its ``question`` and ``answer``
            keys are read (missing keys become empty strings).
        image_path: Path of the image file to send to the model.

    Returns:
        The parsed suggestion dict. On any failure (unreadable image, API
        error, unparsable reply, or a reply that is not a JSON object with an
        object-valued "prompts") returns
        ``{"edit_type": "failed", "prompts": {}}``, so the result always
        supports ``.get``.
    """
    try:
        logger.info(f"Processing image at {image_path}")
        image_base64 = encode_image(image_path)
        if image_base64 is None:
            logger.error(f"Image encoding failed for {image_path}")
            return _failed_suggestion()

        # Declare the file's real image type in the data URL; unknown or
        # non-image extensions fall back to JPEG as before.
        mime_type, _ = mimetypes.guess_type(image_path)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        question = str(data_dict.get('question', ''))
        answer = str(data_dict.get('answer', ''))
        logger.info(f"Question: {question}")
        logger.info(f"Answer: {answer}")
        prompt = f"""You are a data augmentation expert. Your task is to edit an image based on the following data. I will send you the image, and you need to generate a high quality new image as part of the dataset.
        - question: {question}
        - answer: {answer}
        Please analyze the image content and select ONE of the following data augmentation methods: replace or remove. DO NOT USE styletransfer under any circumstances. Choose either replace or remove based on your judgment, aiming to use both methods as evenly as possible across multiple runs. Provide relevant prompts for the selected method:
        - [replace] Replace key objects (e.g., change the object related to the answer). Provide a segment_prompt (segmentation prompt, such as "segment the green apple and return mask"), a source_prompt (source description, such as "a group of red apples with one green apple"), and a target_prompt (target description, such as "a group of red apples with one pear").
        - [remove] Remove key objects (e.g., remove the object related to the answer). Provide a segment_prompt (segmentation prompt, such as "segment the green apple and return mask"), a source_prompt (source description, such as "a group of red apples with one green apple"), and a target_prompt (target description, such as "a group of red apples without any green apple").
        Return the result in JSON string format, for example:
        if you use replace: {{"edit_type": "replace", "prompts": {{"segment_prompt": "...", "source_prompt": "...", "target_prompt": "..."}}}}
        if you use remove: {{"edit_type": "remove", "prompts": {{"segment_prompt": "...", "source_prompt": "...", "target_prompt": "..."}}}}
        Please strictly follow the format.
        Only return the JSON string without additional explanations.
        Make sure the segment_prompt is easy for model to segment"""
        logger.info("Sending request to DashScope API")
        response = openai.ChatCompletion.create(
            model="qwen-omni-turbo",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{image_base64}"}}
                    ]
                }
            ],
            max_tokens=500,
            temperature=0.9,
            stream=True
        )

        # Collect the streamed delta fragments, then join them once.
        parts: List[str] = []
        for chunk in response:
            if 'choices' in chunk and len(chunk['choices']) > 0:
                delta = chunk['choices'][0].get('delta', {})
                content = delta.get('content')
                if content is not None:
                    parts.append(content)
                else:
                    logger.warning(f"Received None content in delta: {delta}")
            else:
                logger.warning(f"No choices in chunk: {chunk}")
        full_response = _strip_code_fence("".join(parts))
        logger.info(f"Raw API response: {full_response}")

        try:
            result = json.loads(full_response)
        except json.JSONDecodeError as e:
            logger.error(f"Response is not valid JSON after cleanup: {full_response}")
            logger.error(f"JSON decode error: {str(e)}")
            return _failed_suggestion()

        # Valid JSON is not enough: callers call .get() on the result and on
        # its "prompts" value, so reject anything of a different shape.
        if not isinstance(result, dict) or not isinstance(result.get("prompts", {}), dict):
            logger.error(f"Response JSON has unexpected structure: {full_response}")
            return _failed_suggestion()
        return result
    except Exception as e:
        # Top-level boundary: any API or network failure becomes "failed".
        logger.error(f"Error in get_edit_suggestion for {image_path}: {str(e)}")
        return _failed_suggestion()
# Local services used by the mask-based editing pipeline.
_LISA_URL = "http://localhost:8001/process"
_KVEDIT_URL = "http://localhost:5000/edit_image"
# Seconds to wait on each HTTP request. Without a timeout, a stalled service
# would block the whole batch indefinitely.
_REQUEST_TIMEOUT = 600.0
# Sampling parameters sent to KV-Edit with every edit request.
_KVEDIT_PARAMS = {
    "inversion_num_steps": 16,
    "denoise_num_steps": 16,
    "inversion_guidance": 1.5,
    "denoise_guidance": 5.5,
    "seed": 42,
}


def _load_jpeg_bytes(image_path: str, jpeg_quality: int, max_size: int = 1024) -> Optional[bytes]:
    """Validate an image, shrink it to fit max_size x max_size, and re-encode it as RGB JPEG.

    Returns None (after logging) if the file is not a decodable image. Errors
    opening the file itself propagate to the caller.
    """
    with open(image_path, "rb") as img_file:
        image_data = img_file.read()
    try:
        # verify() checks integrity but leaves the Image object unusable, so
        # the image is re-opened from the same bytes for actual processing.
        Image.open(io.BytesIO(image_data)).verify()
        logger.info(f"Image verified as valid at {image_path}")
        img = Image.open(io.BytesIO(image_data))
        # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...), so
        # convert to RGB before save(format="JPEG"); otherwise it raises.
        if img.mode != "RGB":
            logger.info(f"Converting {img.mode} image to RGB at {image_path}")
            img = img.convert("RGB")
        img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
        with io.BytesIO() as buffer:
            img.save(buffer, format="JPEG", quality=jpeg_quality)
            return buffer.getvalue()
    except Exception as e:
        logger.error(f"Error: File at {image_path} is not a valid image - {str(e)}")
        return None


def _request_mask(image_data: bytes, segment_prompt: str, timeout: Optional[float]) -> bytes:
    """Ask LISA to segment *segment_prompt* in the image and return the first mask's bytes.

    Raises:
        RuntimeError: LISA returned a non-200 status or no mask paths.
    """
    lisa_response = requests.post(
        _LISA_URL,
        files={"image": ("image.jpg", image_data, "image/jpeg")},
        data={"prompt": segment_prompt},
        timeout=timeout,
    )
    logger.info(f"Lisa API status code: {lisa_response.status_code}")
    logger.info(f"Lisa API response: {lisa_response.text}")
    if lisa_response.status_code != 200:
        raise RuntimeError(f"Lisa API failed: {lisa_response.text}")
    mask_paths = lisa_response.json().get("mask_paths", [])
    if not mask_paths:
        raise RuntimeError("No masks generated by LISA")
    mask_path = mask_paths[0]
    logger.info(f"Mask path generated: {mask_path}")
    # NOTE(review): LISA returns a file path rather than mask bytes, so this
    # assumes the service writes to a filesystem this script can read.
    with open(mask_path, "rb") as mask_file:
        return mask_file.read()


def _masked_edit(
    operation: str,
    image_path: str,
    source_prompt: str,
    target_prompt: str,
    segment_prompt: str,
    output_path: str,
    jpeg_quality: int,
    timeout: Optional[float],
) -> Optional[str]:
    """Run the shared LISA-mask + KV-Edit pipeline and save the edited image.

    *operation* ("Replace" / "Remove") only labels log messages; the two
    public edit functions differ only in the prompts and JPEG quality.
    Returns *output_path* on success, or None after logging on any failure.
    """
    try:
        image_data = _load_jpeg_bytes(image_path, jpeg_quality)
        if image_data is None:
            return None
        if not segment_prompt:
            logger.error(f"Error: segment_prompt is empty or None for {image_path}")
            return None
        logger.info(f"Segment prompt: {segment_prompt}")
        mask_data = _request_mask(image_data, segment_prompt, timeout)
        kvedit_response = requests.post(
            _KVEDIT_URL,
            files={
                "image": ("image.jpg", image_data, "image/jpeg"),
                "mask": ("mask.png", mask_data, "image/png"),
            },
            data={"source_prompt": source_prompt, "target_prompt": target_prompt, **_KVEDIT_PARAMS},
            timeout=timeout,
        )
        if kvedit_response.status_code != 200:
            logger.error(f"Kvedit API failed with status {kvedit_response.status_code}: {kvedit_response.text}")
            return None
        # The output path may point into a subdirectory that does not exist yet.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "wb") as out_file:
            out_file.write(kvedit_response.content)
        logger.info(f"{operation} operation completed. Saved to {output_path}")
        return output_path
    except Exception as e:
        logger.error(f"{operation} error: {str(e)}")
        return None


def replace_image(
    image_path: str,
    source_prompt: str,
    target_prompt: str,
    segment_prompt: str,
    output_path: str,
    *,
    timeout: Optional[float] = _REQUEST_TIMEOUT,
) -> Optional[str]:
    """Replace the object described by *segment_prompt* and save the result.

    The object is segmented with LISA and edited with KV-Edit from
    *source_prompt* to *target_prompt*. The input is sent as JPEG quality 85.

    Returns:
        *output_path* on success, otherwise None (the error is logged).
    """
    return _masked_edit("Replace", image_path, source_prompt, target_prompt,
                        segment_prompt, output_path, 85, timeout)


def remove_image(
    image_path: str,
    source_prompt: str,
    target_prompt: str,
    segment_prompt: str,
    output_path: str,
    *,
    timeout: Optional[float] = _REQUEST_TIMEOUT,
) -> Optional[str]:
    """Remove the object described by *segment_prompt* and save the result.

    Uses the same LISA + KV-Edit pipeline as replace_image; removal is
    expressed purely through *target_prompt*. The input is sent as JPEG
    quality 100.

    Returns:
        *output_path* on success, otherwise None (the error is logged).
    """
    return _masked_edit("Remove", image_path, source_prompt, target_prompt,
                        segment_prompt, output_path, 100, timeout)
def style_transfer_image(
    image_path: str,
    style_prompt: str,
    output_path: str,
    *,
    timeout: Optional[float] = 600.0,
) -> Optional[str]:
    """Restyle an image through the AnySD edit service and save the result.

    Args:
        image_path: Source image, uploaded as the "original_image" form file.
        style_prompt: Edit instruction sent as the "edit" form field.
        output_path: Where the response body is written; missing parent
            directories are created.
        timeout: Seconds to wait for the HTTP request (None waits forever).

    Returns:
        *output_path* on success, otherwise None (the error is logged).
    """
    try:
        with open(image_path, "rb") as img_file:
            anysd_response = requests.post(
                ANYSD_URL,
                files={"original_image": img_file},
                data={
                    "edit": style_prompt,
                    "edit_type": "general"
                },
                timeout=timeout,
            )
        if anysd_response.status_code != 200:
            raise RuntimeError(f"Anysd API failed: {anysd_response.text}")

        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "wb") as out_file:
            out_file.write(anysd_response.content)
        logger.info(f"Style transfer completed. Saved to {output_path}")
        return output_path

    except Exception as e:
        logger.error(f"Style transfer error: {str(e)}")
        return None
def _extract_image_path(raw_image: Any) -> str:
    """Normalize a sample's image field into a single filesystem path.

    The field is expected to be a stringified list such as "['/path/img.jpg']";
    a real list (first element used) or a bare path string is also accepted.
    Surrounding brackets and quote characters are removed.
    """
    if isinstance(raw_image, (list, tuple)):
        raw_image = raw_image[0] if raw_image else ""
    text = str(raw_image).strip()
    # Only strip brackets when both are present; slicing unconditionally
    # would chop the last character off a bare path.
    if text.startswith("[") and text.endswith("]"):
        text = text[1:-1]
    return text.replace("'", "").replace('"', "").strip()


def _load_samples(input_json: str) -> Optional[List[Dict[str, Any]]]:
    """Load the sample list stored under the "data" key of *input_json*.

    Returns None (after logging) if the file is missing, malformed, or lacks "data".
    """
    try:
        with open(input_json, 'r', encoding='utf-8') as f:
            samples = json.load(f)["data"]
    except (OSError, ValueError, KeyError, TypeError) as e:
        logger.error(f"Error loading {input_json}: {str(e)}")
        return None
    logger.info(f"Loaded {len(samples)} samples from {input_json}")
    return samples


def main(input_json: str = "") -> None:
    """Augment every sample in *input_json* with an edited image and save the dataset.

    For each sample: ask the model for a replace/remove suggestion, run the
    matching editor, and record the edited image's file name under "image".
    Samples whose output image already exists (from an earlier run) are kept
    without re-editing, so a resumed run still writes a complete OUTPUT_JSON.

    Args:
        input_json: Path of the source dataset JSON (samples under "data").
    """
    if OUTPUT_DIR:
        os.makedirs(OUTPUT_DIR, exist_ok=True)

    samples = _load_samples(input_json)
    if samples is None:
        return

    editors = {"replace": replace_image, "remove": remove_image}
    new_dataset: List[Dict[str, Any]] = []
    for index, sample in enumerate(samples):
        try:
            image_path = _extract_image_path(sample['extend_data_2']['image'])
            data_source = sample['origin_data']['data_source']
            source_id = sample['origin_data']['source_id']
        except (KeyError, TypeError) as e:
            logger.error(f"Sample {index} is missing required fields: {str(e)}")
            continue

        output_path = os.path.join(
            OUTPUT_DIR, str(data_source), f"extended_data_3_{data_source}_{source_id}.jpg"
        )
        new_sample = sample.copy()
        if os.path.exists(output_path):
            # Already edited by a previous run: keep it in the dataset instead
            # of silently dropping it from the rewritten OUTPUT_JSON.
            logger.info(f"Output file already exists for sample {index}, skipping")
            new_sample["image"] = [os.path.basename(output_path)]
            new_dataset.append(new_sample)
            continue
        if not os.path.exists(image_path):
            logger.error(f"Image file not found at {image_path} for sample {index}")
            continue

        result = get_edit_suggestion(sample, image_path)
        logger.info(f"Edit suggestion for sample {index}: {result}")
        if not isinstance(result, dict):
            logger.error(f"Unexpected edit suggestion type for sample {index}: {result!r}")
            continue
        edit_type = result.get("edit_type")
        if edit_type == "failed":
            logger.info(f"Sample ID {index} - Failed to generate edit suggestion")
            continue
        if edit_type == "styletransfer":
            logger.warning(f"Style transfer was unexpectedly returned for sample {index}: {result}")
            continue
        editor = editors.get(edit_type) if isinstance(edit_type, str) else None
        if editor is None:
            # Unknown types must not fall through: that would record an
            # unedited sample as if it had been augmented.
            logger.warning(f"Unsupported edit type {edit_type!r} for sample {index}: {result}")
            continue

        prompts = result.get("prompts")
        if not isinstance(prompts, dict):
            prompts = {}
        # The per-data-source subdirectory is not created anywhere else.
        os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
        edited_image_path = editor(
            image_path,
            prompts.get("source_prompt"),
            prompts.get("target_prompt"),
            prompts.get("segment_prompt"),
            output_path,
        )
        if edited_image_path is None:
            logger.info(f"{edit_type.capitalize()} failed for sample {index}")
            continue
        new_sample["image"] = [os.path.basename(edited_image_path)]
        new_dataset.append(new_sample)

    try:
        with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
            json.dump(new_dataset, f, indent=4, ensure_ascii=False)
        logger.info(f"New dataset saved to {OUTPUT_JSON} with {len(new_dataset)} samples")
    except (OSError, TypeError, ValueError) as e:
        logger.error(f"Error saving dataset: {str(e)}")


if __name__ == "__main__":
    main()