import argparse
import base64
import json
import os
import subprocess
import time
import uuid
from io import BytesIO

import cv2
import numpy as np
from PIL import Image
from vllm import LLM  # noqa
from vllm.sampling_params import SamplingParams  # noqa

parser = argparse.ArgumentParser()
parser.add_argument('--split_id', type=int, default=0,
                    help="Index of the shard of the image list handled by this run.")
parser.add_argument('--split_num', type=int, default=1,
                    help="Total number of shards the image list is split into.")
parser.add_argument('--image_path', type=str, default="no_bg_image",
                    help="Dataset name; selects the list file, image folder and output JSON names.")
args = parser.parse_args()

# Validate sharding up front. split_num == 0 would otherwise crash much later
# with "slice step cannot be zero". A negative or >= split_num split_id would
# silently pick an overlapping or degenerate shard of the image list.
if args.split_num < 1:
    parser.error("--split_num must be >= 1")
if not 0 <= args.split_id < args.split_num:
    parser.error("--split_id must satisfy 0 <= split_id < split_num")


# Round 1 prompt: ask the model to describe the characters (text) on the
# upper-body garment in each of the two attached images.
prompts_1 = (
    "Please describe the characters on the upper-body garment "
    "in each of the two images separately."
)

# Round 2 prompt: ask for a 0-5 consistency score for those characters.
prompts_2 = (
    "Please rate the consistency of the characters on the upper-body garments "
    "in the two images on a scale from 0 to 5. "
    "If the characters are completely consistent, please output 5. "
    "If the characters are completely inconsistent, please output 0, "
    "and do not output anything else."
)

# Round 3 prompt: self-check of the earlier answers ("Correct" / "Wrong").
prompts_3 = (
    'Review your previous responses step-by-step. '
    'If you are completely confident that your reasoning and conclusions are correct, '
    'respond with "Correct". '
    'If you find any errors or are uncertain, respond with "Wrong".'
)




def encode_image_base64(
        image: Image.Image,
        *,
        image_mode: str = "RGB",
        format: str = "JPEG",
) -> str:
    """Serialize a PIL image into a base64 text string.

    The image is converted to ``image_mode``, saved in ``format`` into an
    in-memory buffer, and the resulting bytes are base64-encoded.

    Args:
        image: The PIL image to encode.
        image_mode: PIL mode to convert to before saving (default "RGB").
        format: PIL format name handed to ``Image.save`` (default "JPEG").

    Returns:
        The base64 payload as a str, with no ``data:`` URL prefix.
    """
    converted = image.convert(image_mode)
    with BytesIO() as sink:
        converted.save(sink, format)
        payload = sink.getvalue()
    return base64.b64encode(payload).decode('utf-8')


def get_response(llm, images, instructions, *, params=None):
    """Send one single-turn chat request per sample and return the replies.

    Args:
        llm: A vLLM ``LLM`` instance; only its ``chat`` method is used.
        images: One entry per sample. Each entry is either a single PIL image
            or an iterable of PIL images, all attached to that sample's message.
        instructions: A prompt string shared by every sample, or a sequence
            holding exactly one prompt per sample.
        params: Sampling parameters passed to ``llm.chat``. When omitted, the
            module-level ``sampling_params`` (created in the ``__main__`` block)
            is used, which is the original behavior.

    Returns:
        list[str]: Text of the first completion for each sample, in input order.

    Raises:
        ValueError: If a per-sample ``instructions`` sequence does not have one
            entry per image entry (previously ``zip`` silently dropped samples).
    """
    images = list(images)
    if isinstance(instructions, str):
        instructions = [instructions] * len(images)
    else:
        instructions = list(instructions)
    if len(instructions) != len(images):
        raise ValueError(
            f"Expected one instruction per image entry, got {len(instructions)} "
            f"instructions for {len(images)} image entries."
        )

    messages = []
    for image, inst in zip(images, instructions):
        # A bare PIL image is a sample with a single image.
        sample_images = [image] if isinstance(image, Image.Image) else image
        content = [{"type": "text", "text": inst}]
        for img in sample_images:
            # encode_image_base64 saves JPEG by default, matching this MIME type.
            base64_image = encode_image_base64(img)
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
        messages.append([{"role": "user", "content": content}])

    # Fall back to the module-level global only when no params were given.
    sp = sampling_params if params is None else params
    outputs = llm.chat(messages=messages, sampling_params=sp)
    return [k.outputs[0].text for k in outputs]


def get_response_multi_round(llm, images, instruction_1, other_instructions, *, params=None):
    """Run a multi-turn chat per sample and collect the replies of every round.

    Round 1 sends ``instruction_1`` together with the sample's images. Each
    later round appends the model's previous reply as an ``assistant`` turn,
    then the next instruction as a ``user`` turn, and queries the model again
    with the whole conversation.

    Args:
        llm: A vLLM ``LLM`` instance; only its ``chat`` method is used.
        images: One entry per sample. Each entry is either a single PIL image
            or an iterable of PIL images, attached to the first user turn.
        instruction_1: First-round prompt: a string shared by all samples or a
            sequence with one prompt per sample.
        other_instructions: Iterable of follow-up rounds. Each round is a string
            shared by all samples or a sequence with one prompt per sample.
        params: Sampling parameters passed to ``llm.chat``. When omitted, the
            module-level ``sampling_params`` (created in the ``__main__`` block)
            is used, which is the original behavior.

    Returns:
        list[list[str]]: ``result[r][i]`` is the reply of round ``r``
        (0 = first round) for sample ``i``.

    Raises:
        ValueError: If a per-sample instruction sequence does not have one
            entry per image entry.
    """
    sp = sampling_params if params is None else params
    images = list(images)
    num_samples = len(images)

    def _per_sample(instruction):
        # Broadcast a shared prompt string; otherwise require one prompt per sample.
        if isinstance(instruction, str):
            return [instruction] * num_samples
        instruction = list(instruction)
        if len(instruction) != num_samples:
            raise ValueError(
                f"Expected one instruction per image entry, got {len(instruction)} "
                f"instructions for {num_samples} image entries."
            )
        return instruction

    messages = []
    for image, inst in zip(images, _per_sample(instruction_1)):
        # A bare PIL image is a sample with a single image.
        sample_images = [image] if isinstance(image, Image.Image) else image
        content = [{"type": "text", "text": inst}]
        for img in sample_images:
            # encode_image_base64 saves JPEG by default, matching this MIME type.
            base64_image = encode_image_base64(img)
            content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
        messages.append([{"role": "user", "content": content}])

    outputs = llm.chat(messages=messages, sampling_params=sp)
    all_text = [k.outputs[0].text for k in outputs]
    output_full = [all_text]

    for this_round_instruction in other_instructions:
        round_instructions = _per_sample(this_round_instruction)
        for conversation, reply, inst in zip(messages, all_text, round_instructions):
            # The previous reply is the model's own turn, so it is tagged
            # "assistant"; the conversation then alternates user/assistant.
            conversation.append({"role": "assistant", "content": reply})
            conversation.append({"role": "user", "content": [{"type": "text", "text": inst}]})
        new_outputs = llm.chat(messages=messages, sampling_params=sp)
        all_text = [k.outputs[0].text for k in new_outputs]
        output_full.append(all_text)
    return output_full


def find_significant_changes_pil(img_before, img_after,
                                 blur_kernel=(9, 9), threshold_value=30,
                                 dim_factor=0.15):
    """
    Identify significant differences between two PIL images and emphasize them.

    The per-pixel absolute difference is converted to grayscale, blurred,
    thresholded and cleaned with morphological opening/closing to form a mask
    of significant changes. The result is the 'after' image in which every
    pixel outside that mask is darkened to ``dim_factor`` of its brightness,
    so the changed regions stand out at full brightness.

    Parameters:
    - img_before (PIL.Image): The original image.
    - img_after (PIL.Image): The edited image; resized to img_before's size if needed.
    - blur_kernel (tuple): Kernel size for Gaussian blur (OpenCV requires odd sizes).
    - threshold_value (int): Blurred grayscale difference (0-255) above which a
      pixel counts as changed.
    - dim_factor (float): Brightness multiplier for unchanged pixels
      (default 0.15, i.e. they keep 15% of their brightness).

    Returns:
    - PIL.Image: RGB image with img_before's size.
    """
    # Work in RGB so both arrays share the same 3-channel uint8 layout.
    img_before = img_before.convert('RGB')
    img_after = img_after.convert('RGB')

    # absdiff needs equally sized inputs. Image.ANTIALIAS was removed in
    # Pillow 10; LANCZOS is the same resampling filter.
    if img_before.size != img_after.size:
        img_after = img_after.resize(img_before.size, Image.LANCZOS)

    np_before = np.array(img_before)
    np_after = np.array(img_after)

    # Per-pixel absolute difference, collapsed to a single grayscale channel.
    diff = cv2.absdiff(np_before, np_after)
    gray_diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)

    # Blur so isolated noise and tiny differences fall below the threshold.
    blurred = cv2.GaussianBlur(gray_diff, blur_kernel, 0)

    # Binary mask: 255 where the blurred difference exceeds threshold_value.
    _, thresh = cv2.threshold(blurred, threshold_value, 255, cv2.THRESH_BINARY)

    # Opening removes small specks; closing fills small holes in changed areas.
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel, iterations=4)
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel, iterations=4)
    mask = thresh.astype(bool)

    # Darken the whole image (0.15 -> 85% darker). Clipping keeps a factor
    # outside [0, 1] from wrapping around when cast to uint8.
    output = np.clip(np_after * dim_factor, 0, 255).astype(np.uint8)

    # Restore full brightness wherever a significant change was detected.
    output[mask] = np_after[mask]

    return Image.fromarray(output).convert('RGB')


def create_random_name():
    """Return a new random UUID4 as its canonical 36-character hyphenated string."""
    return f"{uuid.uuid4()}"


def extract_to_dev(tar_path, dest_root="/dev/shm"):
    """Extract a tar archive into a new, uniquely named folder and return its path.

    Args:
        tar_path: Path of the archive, passed to ``tar -xf``.
        dest_root: Directory in which the uniquely named folder is created
            (default "/dev/shm", the original hard-coded location).

    Returns:
        str: Path of the folder the archive was extracted into.

    Raises:
        subprocess.CalledProcessError: If ``tar`` exits with a non-zero status
            (previously the failure was ignored and an empty folder returned).
    """
    des_folder = os.path.join(dest_root, create_random_name())
    os.makedirs(des_folder, exist_ok=True)
    # Argument list with no shell: paths containing spaces or shell
    # metacharacters are passed literally instead of being re-parsed.
    subprocess.run(["tar", "-xf", tar_path, "-C", des_folder], check=True)
    return des_folder


def resize_max_side(image, max_side=1024):
    """Resize ``image`` so its longer side equals ``max_side``, keeping aspect ratio.

    Smaller images are upscaled as well. Square images take the second branch,
    so both sides become ``max_side``. The shorter side is truncated to an int
    but clamped to at least 1 pixel, because PIL rejects zero-sized targets
    (very elongated images previously produced a 0-pixel side).

    Args:
        image: A PIL image (only ``size`` and ``resize`` are used).
        max_side: Target length of the longer side, in pixels.

    Returns:
        The resized PIL image (LANCZOS resampling).
    """
    width, height = image.size
    if width > height:
        new_width = max_side
        new_height = max(1, int(height * max_side / width))
    else:
        new_height = max_side
        new_width = max(1, int(width * max_side / height))
    return image.resize((new_width, new_height), Image.LANCZOS)


if __name__ == '__main__':
    # Read the list of sample names and keep this run's shard
    # (every split_num-th entry starting at split_id).
    all_image_list_path = f"/{args.image_path}.txt"
    with open(all_image_list_path, encoding="utf8") as list_file:
        # Skip blank lines (e.g. a trailing newline) so they are not treated
        # as samples with nonexistent image paths.
        all_image_list = [k.strip() for k in list_file if k.strip()]
    cur_json_list = all_image_list[args.split_id::args.split_num]
    print(f"Processing {len(cur_json_list)} images")
    model_name = "/mnt/pretrained_model/pixtral_ckpt_ours/"
    # Kept as a module-level global: the get_response* helpers read it.
    sampling_params = SamplingParams(max_tokens=16384, temperature=0.70, seed=1024)
    # Every request carries exactly two images (see cur_images below).
    llm = LLM(model=model_name, tokenizer_mode="mistral", limit_mm_per_prompt={"image": 2},
              max_model_len=32768)
    # llm = LLM(model=model_name, limit_mm_per_prompt={"image": 10, "video": 10},max_model_len=32768)
    batch_size = 16
    all_responses = []
    start_time = time.time()

    des_save_name = f"/mnt/{args.image_path}.json"
    tmp_name = f"/mnt/{args.image_path}_tmp.json"

    for batch_start in range(0, len(cur_json_list), batch_size):
        cur_jsons = cur_json_list[batch_start:batch_start + batch_size]
        src_names = [k.replace(".json", ".src.jpg") for k in cur_jsons]
        tgt_names = [k.replace(".json", ".tgt.jpg") for k in cur_jsons]
        # Image pair per sample: one from the cloth dataset (source name with
        # "_0.jpg" mapped to "_1.jpg") and one from this dataset's show_case folder.
        cur_images = [
            [
                resize_max_side(Image.open(os.path.join("/mnt/dataset/cloth/", k.replace("_0.jpg", "_1.jpg")))),
                resize_max_side(Image.open(os.path.join(f"/mnt/show_case/{args.image_path}", v))),
            ]
            for k, v in zip(src_names, tgt_names)
        ]

        try:
            cur_responses = get_response_multi_round(llm, cur_images, prompts_1, [prompts_2, prompts_3])
            # cur_responses is indexed [round][sample]; transpose to [sample][round].
            overall_response = [list(per_sample) for per_sample in zip(*cur_responses)]
            all_responses.extend(zip(cur_jsons, overall_response))
            end_time = time.time()
            cost_time = end_time - start_time
            left_time = (len(cur_json_list) - len(all_responses)) * cost_time / (len(all_responses) + 1)
            print(
                f"Progress: {len(all_responses)}/{len(cur_json_list)}, Cost Time: {cost_time}, Left Time: {left_time}")
            # Checkpoint after every batch; the file is closed (and flushed) before the next batch.
            with open(tmp_name, "w", encoding="utf8") as tmp_file:
                json.dump(all_responses, tmp_file, indent=4)
        except Exception as e:
            print(f"failed due to {str(e)}")

    with open(des_save_name, "w", encoding="utf8") as out_file:
        json.dump(all_responses, out_file, indent=4)
