import io
import json
import math
import re

import numpy as np
from PIL import Image


def get_analysis_prompt(modality_data, ability):
    """Build the first-turn prompt asking a model to analyse a multiple-choice question.

    Args:
        modality_data: Either ``bytes`` (treated as an image; a visual-QA prompt
            is produced and the bytes themselves are not embedded) or ``str``
            (treated as a scene graph, embedded verbatim in a text-only prompt).
        ability: Mapping with a ``'Question'`` entry and an ``'Options'``
            mapping from option label to option text.

    Returns:
        The prompt string. The question is followed by one ``"<label>. <text>"``
        line per option, in the ``'Options'`` mapping's iteration order.

    Raises:
        KeyError: If ``ability`` lacks ``'Question'`` or ``'Options'``.
        NotImplementedError: If ``modality_data`` is neither ``bytes`` nor ``str``.
    """
    # 'Question' is read before 'Options' (same order as before), and both are
    # read before the type check, so a malformed ``ability`` fails first.
    question = f"{ability['Question']}\n" + ''.join(
        f"{key}. {value}\n" for key, value in ability['Options'].items()
    )

    if isinstance(modality_data, bytes):
        analysis_prompt = (
            'You are a Visual Question Answering expert.\n'
            'Based on the image, analyze and answer the following question:\n'
            f'{question}'
        )
    elif isinstance(modality_data, str):
        analysis_prompt = (
            'You are a Question Answering expert.\n'
            'The following scene graph represents the extracted information from a diagram:\n'
            f'{modality_data}\n'
            'Based on the descriptions, analyze and answer the following question:\n'
            f'{question}'
        )
    else:
        raise NotImplementedError(
            f"Unsupported modality_data type: {type(modality_data).__name__}; "
            "expected bytes (image) or str (scene graph)"
        )
    return analysis_prompt


def get_choice_prompt(analysis_prompt, analysis):
    """Build the follow-up prompt that asks for a bare multiple-choice letter.

    The earlier exchange is replayed between two dashed separator lines:
    ``analysis_prompt`` as the user turn and ``analysis`` as the assistant
    turn. It is followed by an instruction to output only A/B/C/D.

    Args:
        analysis_prompt: The prompt previously sent to the model.
        analysis: The model's free-form answer to that prompt.

    Returns:
        The prompt string, with every line newline-terminated.
    """
    separator = '-------------------------------------------------------------'
    parts = [
        'Here is the context for your reference:',
        separator,
        '### User ###',
        f'{analysis_prompt}',
        '### Assistant (You) ###',
        f'{analysis}',
        separator,
        'Based on the context above, give your final choice. '
        'Do not provide any explanations. Only output A/B/C/D.',
    ]
    return ''.join(f'{part}\n' for part in parts)


def get_white_bytes(real_bytes):
    """Return PNG bytes of an all-white RGB image with the same size as ``real_bytes``.

    Args:
        real_bytes: Encoded image data in any format Pillow can open.

    Returns:
        PNG-encoded bytes of an RGB image filled with (255, 255, 255), whose
        width and height match the decoded input.
    """
    # Image.open is lazy; only the size is needed here. The context manager
    # releases the opened image rather than leaving it dangling.
    with Image.open(io.BytesIO(real_bytes)) as real_image:
        size = real_image.size
    # Pillow can create a solid-colour canvas directly; no intermediate array needed.
    white_image = Image.new('RGB', size, (255, 255, 255))
    buffer = io.BytesIO()
    white_image.save(buffer, format='PNG')
    return buffer.getvalue()


def resize_image(image, max_pixels=2048*2048, min_pixels=224*224):
    """Rescale ``image`` so that width * height lies within [min_pixels, max_pixels].

    The aspect ratio is preserved up to integer rounding. An image already
    within the range is returned unchanged (the very same object).

    Args:
        image: A PIL image (anything exposing ``.size`` and ``.resize``).
        max_pixels: Upper bound on width * height; larger images are shrunk.
        min_pixels: Lower bound on width * height; smaller images are enlarged.

    Returns:
        The resized image. ``image`` itself is returned when no resize is
        needed, or when it has zero area (an empty image cannot be scaled).
    """
    width, height = image.size
    current_pixels = width * height

    if current_pixels == 0:
        # Nothing to scale, and the upscale ratio below would divide by zero.
        return image

    if current_pixels > max_pixels:
        # Round down so the result does not exceed max_pixels. Clamp each side
        # to >= 1 so extreme aspect ratios cannot collapse a dimension to 0
        # (only in that degenerate case may the bound be slightly exceeded).
        scale_factor = (max_pixels / current_pixels) ** 0.5
        new_size = (
            max(1, int(width * scale_factor)),
            max(1, int(height * scale_factor)),
        )
    elif current_pixels < min_pixels:
        # Round up so the result actually reaches min_pixels; truncation can
        # leave it just below the floor (e.g. 10x15 -> 182x274 = 49868 < 50176).
        # This may overshoot by at most one pixel per side.
        scale_factor = (min_pixels / current_pixels) ** 0.5
        new_size = (
            math.ceil(width * scale_factor),
            math.ceil(height * scale_factor),
        )
    else:
        return image

    return image.resize(new_size, Image.Resampling.LANCZOS)

