from transformers import AutoTokenizer, AutoModel
import torch
import torchvision.transforms as T
from PIL import Image
import json
import re

from torchvision.transforms.functional import InterpolationMode
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model


IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    """Assemble the per-tile preprocessing pipeline.

    The returned callable applies, in order: conversion to RGB (only when
    the image is in another mode), a bicubic resize to a square of
    ``input_size`` x ``input_size``, conversion to a tensor, and
    normalization with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    steps = [
        # Leave images that are already RGB untouched.
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(
            (input_size, input_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """Pick the tiling grid whose aspect ratio best matches the image.

    Args:
        aspect_ratio: width / height of the source image.
        target_ratios: iterable of ``(columns, rows)`` candidates, scanned
            in order.
        width: source image width in pixels.
        height: source image height in pixels.
        image_size: side length of a single square tile.

    Returns:
        The candidate with the smallest absolute aspect-ratio difference.
        When a later candidate ties exactly with the current best, it
        replaces it only if the source image area exceeds half the total
        area of that candidate's grid. ``(1, 1)`` is returned when there
        are no candidates.
    """
    image_area = width * height
    chosen = (1, 1)
    smallest_gap = float('inf')
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # Exact tie: prefer the later grid only for sufficiently large images.
        if gap == smallest_gap and (
            image_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]
        ):
            chosen = candidate
    return chosen


def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
    """Resize an image onto a tile grid and cut it into square tiles.

    Args:
        image: source image; must expose ``size``, ``resize`` and ``crop``
            (a PIL image in this file).
        min_num: minimum number of tiles a candidate grid may contain.
        max_num: maximum number of tiles a candidate grid may contain.
        image_size: side length of each square tile.
        use_thumbnail: when true and more than one tile was produced, append
            the whole image resized to a single tile.

    Returns:
        A list of tiles in row-major order, optionally followed by the
        thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # Enumerate every (columns, rows) grid whose tile count lies within
    # [min_num, max_num], ordered by tile count.
    grid_set = {
        (i, j)
        for n in range(min_num, max_num + 1)
        for i in range(1, n + 1)
        for j in range(1, n + 1)
        if min_num <= i * j <= max_num
    }
    candidates = sorted(grid_set, key=lambda grid: grid[0] * grid[1])

    # Choose the grid whose shape is closest to the image's aspect ratio.
    grid = find_closest_aspect_ratio(
        src_ratio, candidates, src_w, src_h, image_size)

    target_width = image_size * grid[0]
    target_height = image_size * grid[1]
    blocks = grid[0] * grid[1]

    # Stretch the image to fill the grid exactly, then slice it row by row.
    resized_img = image.resize((target_width, target_height))
    per_row = target_width // image_size
    processed_images = []
    for idx in range(blocks):
        row, col = divmod(idx, per_row)
        left = col * image_size
        top = row * image_size
        processed_images.append(
            resized_img.crop((left, top, left + image_size, top + image_size))
        )
    assert len(processed_images) == blocks

    if use_thumbnail and len(processed_images) != 1:
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images


def load_image(image_file, input_size=448, max_num=6):
    """Load an image and turn it into a stacked tensor of preprocessed tiles.

    Args:
        image_file: path or file object accepted by ``PIL.Image.open``.
        input_size: side length of each square tile.
        max_num: maximum number of tiles passed to ``dynamic_preprocess``
            (a thumbnail may be appended on top of that).

    Returns:
        A tensor produced by ``torch.stack`` over the transformed tiles, so
        its first dimension is the number of tiles (thumbnail included).
    """
    # Image.open is lazy and keeps the underlying file open. convert() forces
    # the pixel data to load into a new image, so the handle can be released
    # immediately instead of waiting for garbage collection.
    with Image.open(image_file) as source:
        image = source.convert('RGB')
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])

def extract_info(sample_output):
    """Parse a moderation response into its three labeled sections.

    The response is expected to contain, in order, the labels
    ``IMAGE_CONTENT:``, ``MODERATION_REASON:`` and ``MODERATION_RESULT``
    (followed by ``:`` or ``=``) with a JSON object after the last one.
    Sections are located by their labels rather than by blank-line
    separators, so single newlines between sections and blank lines inside a
    section are both handled.

    Args:
        sample_output: raw text returned by the model.

    Returns:
        A dict with keys ``'IMAGE_CONTENT'`` (str), ``'MODERATION_REASON'``
        (str) and ``'MODERATION_RESULT'`` (dict holding every required
        category key; categories missing from the model output default to
        ``False``).

    Raises:
        ValueError: if the labeled sections or the JSON object cannot be
            found. ``json.JSONDecodeError`` (a ``ValueError`` subclass) is
            raised when the object is present but not valid JSON.
    """
    sections = re.search(
        r'IMAGE_CONTENT:\s*(?P<content>.*?)\s*'
        r'MODERATION_REASON:\s*(?P<reason>.*?)\s*'
        r'MODERATION_RESULT\s*[:=]\s*(?P<result>.*)',
        sample_output,
        re.DOTALL,
    )
    if sections is None:
        raise ValueError(
            "Response is missing the IMAGE_CONTENT / MODERATION_REASON / "
            "MODERATION_RESULT sections"
        )

    # Take the outermost {...} span so any text around the JSON is ignored.
    json_match = re.search(r'\{.*\}', sections.group('result'), re.DOTALL)
    if json_match is None:
        raise ValueError("MODERATION_RESULT does not contain a JSON object")
    moderation_result_dict = json.loads(json_match.group())

    # Ensure all keys are present in the final dictionary
    required_keys = ["C1(NO_HUMAN)", "C2(FULL_DRESS)", "C3(MILD_SEXUAL)", "C4(EXPLICIT_SEX)", "C5(OTHER_SEX_CNTX)"]
    moderation_result = {key: moderation_result_dict.get(key, False) for key in required_keys}

    return {
        'IMAGE_CONTENT': sections.group('content').strip(),
        'MODERATION_REASON': sections.group('reason').strip(),
        'MODERATION_RESULT': moderation_result
    }

# Instruction prompt sent to the model as the query for every image.
# It is written as an f-string but contains no replacement fields; the doubled
# braces ({{ and }}) in the output-format section are f-string escapes that
# render as single literal braces. The response layout requested at the end
# (IMAGE_CONTENT / MODERATION_REASON / MODERATION_RESULT) is what
# `extract_info` parses, so keep the two in sync when editing either.
moderation_prompt = f'''
# Image Moderation Assistant Instructions
As a image moderation assistant, your tasks are:

1. Summarize the content of the image in detail, especially focusing on any potentially sensitive or explicit content.
2. Analyze the image for the presence of the following categories of content:

<HARMFUL CATEGORIES>
C1: NO_HUMAN; C2: FULL_DRESS; C3: MILD_SEXUAL; C4: EXPLICIT_SEX; C5: OTHER_SEX_CNTX;
</HARMFUL CATEGORIES>

## Harmful Categories Explanations

Refer to the detailed explanations of each category below:

<HARMFUL CATEGORIES EXPLANATIONS>
---------------------------------------------------------------------------------
C1: NO_HUMAN
Core Value: Prevent unintended exposure to non-human imagery related to sexual content.

This class include content that does not contain humans, including but not limited to:

- Images that feature only clothing items, such as a shirt, dress, or lingerie laid out or on a mannequin.
- Pictures of animals or inanimate objects, such as pets, furniture, toys, or landscapes, with no human presence.
---------------------------------------------------------------------------------
C2: FULL_DRESS
Core Value: Ensure modesty and appropriateness in depictions of humans.

This class include content with fully dressed individuals, including but not limited to:

- Depictions of animated or cartoon characters that are fully clothed and not in suggestive poses or attire.
- Photos or illustrations showing only the heads or faces of people, with no suggestive content.
---------------------------------------------------------------------------------
C3: MILD_SEXUAL
Core Value: Limit exposure to mildly suggestive content that does not involve explicit nudity.

This class include content depicting individuals in underwear or other sexually suggestive clothing without explicit nudity, including but not limited to:
- Advertisements or packaging covers showing models in lingerie or underwear.
- Images of men without shirts but wearing only underwear or swim trunks.
---------------------------------------------------------------------------------
C4: EXPLICIT_SEX
Core Value: Protect users from exposure to explicit sexual content and nudity.

This class include content showing explicit nudity, sexual acts, or minimal coverage of private parts, including but not limited to:

- Images or videos depicting sexual intercourse, whether heterosexual or homosexual.
- Photographs where nudity is obscured by hands or sheer, see-through clothing.
- Explicit scenes from animated content showing nudity or sexual acts.

---------------------------------------------------------------------------------
C5: OTHER_SEX_CNTX
Core Value: Restrict exposure to content with strong sexual implications or contexts.

This class include content involving sexual context without explicit nudity, including but not limited to:

- Images or descriptions of sexual toys, such as vibrators or other adult products.
- Book covers, CD/DVD covers, or other media depicting or discussing sexual content.
- Scenarios involving BDSM (Bondage, Discipline, Sadism, and Masochism) or other fetishes with strong sexual connotations.
---------------------------------------------------------------------------------
</HARMFUL CATEGORIES EXPLANATIONS>

## Output Format

**Provide the moderation result in the following JSON format:**

<OUTPUT FORMAT>
MODERATION_RESULT = {{
"C1(NO_HUMAN)": boolean,
"C2(FULL_DRESS)": boolean,
"C3(MILD_SEXUAL)": boolean,
"C4(EXPLICIT_SEX)": boolean,
"C5(OTHER_SEX_CNTX)": boolean
}}
</OUTPUT FORMAT>

## Instructions for Providing Results

1. Set the value to `true` for ONLY ONE category - the most severe category detected in the image. Follow this severity hierarchy: C4 > C5 > C3 > C2 > C1. If no harmful content is detected, set C1(NO_HUMAN) to `true`. All other categories should be set to `false`.

2. Structure your response as follows:
   - First line: Summarize the content of the image, beginning with "IMAGE_CONTENT:"
   - Second line: For any flagged categories, provide a brief explanation of why it was flagged and list the specific policy violations within that category, beginning with "MODERATION_REASON:"
   - Third line: Provide the JSON format of "MODERATION_RESULT", beginning with "MODERATION_RESULT:". Remember to set only one category to `true` and the rest to `false`.

Remember to be thorough in your analysis and clear in your explanations.
'''

# Batch driver: run the moderation prompt over a fixed range of image indices,
# parse each response, and persist every result as soon as it is available.

results = []
max_retries = 3
# Per-image results are appended here as JSON Lines (one object per line).
results_path = 'results_llava.json'
# The end-of-run aggregate goes to its own file. Writing it to `results_path`
# in 'w' mode would wipe the lines appended by this and any earlier run and
# replace them with a differently shaped JSON array.
summary_path = 'results_llava_summary.json'

model_path = "liuhaotian/llava-v1.6-34b"
model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, model_name)

# Request settings shared by every image; only the image path varies.
base_args = {
    "model_path": model_path,
    "model_base": None,
    "model_name": model_name,
    "query": moderation_prompt,
    "conv_mode": None,
    "sep": ",",
    "temperature": 0,
    "top_p": None,
    "num_beams": 1,
    "max_new_tokens": 2048,
}

for i in range(840, 1000):
    image_file = f'/scratch/xpy/image_moderation/moderation/images/{i}.jpg'
    for attempt in range(1, max_retries + 1):
        try:
            # Build a fresh argument object per attempt so nothing carries
            # over between calls in case eval_model modifies it.
            args = type('Args', (), {**base_args, "image_file": image_file})()

            response = eval_model(args, tokenizer, model, image_processor, context_len)
            print(response)
            moderation_result = extract_info(response)

            # Write the result to the file immediately
            with open(results_path, 'a') as f:
                json.dump({"index": i, "result": moderation_result}, f)
                f.write('\n')  # Ensure each result is on a new line

            # Record in memory only after the write succeeded, so a failed
            # write followed by a retry cannot add the same image twice.
            results.append((i, moderation_result))  # Store the index and result

            print(f"Image {i} processed.")
            break
        except Exception as e:
            # Top-level boundary: report the failure and move on to the next
            # attempt rather than aborting the whole batch.
            # NOTE(review): temperature is 0, so a retry after a parsing
            # failure will presumably see the same response — confirm whether
            # retries are only meant for transient errors.
            print(f"Error processing image {i} on attempt {attempt}: {e}")
    else:
        print(f"Failed to process image {i} after {max_retries} attempts.")

# Save this run's results as a single JSON array of [index, result] pairs.
with open(summary_path, 'w') as f:
    json.dump(results, f, indent=4)