from itertools import chain

import matplotlib
import torch
import os
import re
from PIL import Image, ImageOps
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer, AutoModelForCausalLM, AutoModelForZeroShotObjectDetection
# from core.all_rules import rules

def load_image(image_path):
    """Open *image_path*, rotate it according to its EXIF orientation, and return it as RGB.

    :param image_path: path of the image file to open.
    :return: an RGB ``PIL.Image.Image``.
    """
    raw = Image.open(image_path)
    # Honour the orientation tag written by cameras/phones before any further processing.
    upright = ImageOps.exif_transpose(raw)
    return upright.convert("RGB")

def get_llm_response(responses):
    """Keep only the assistant part of each decoded LLM output.

    For every string, the text after the LAST ``'<|assistant|>'`` marker is kept
    (the whole string when the marker is absent), and every newline is removed.

    :param responses: iterable of decoded strings.
    :return: list of cleaned strings, one per input.
    """
    return [
        text.rpartition('<|assistant|>')[2].replace('\n', '')
        for text in responses
    ]

def get_vlm_response(prompt, response):
    """Drop the first ``len(prompt_i)`` characters from each ``response_i``.

    Pairs *prompt* and *response* positionally and stops at the shorter of the two.
    NOTE(review): this assumes each response starts with its prompt echoed verbatim;
    nothing here checks that.

    :param prompt: sequence of prompt strings.
    :param response: sequence of full generated strings.
    :return: list of the remaining suffixes.
    """
    return [full[len(head):] for head, full in zip(prompt, response)]

def group_images(images, n):
    """Split *images* into consecutive chunks of at most *n* items.

    The last chunk holds the remainder. Each chunk is a slice, so it has the same
    type as *images*.
    """
    chunks = []
    for start in range(0, len(images), n):
        chunks.append(images[start:start + n])
    return chunks

def flatten_images(images):
    """Flatten one level of nesting: ``[[a, b], [c]] -> [a, b, c]``."""
    return [item for group in images for item in group]

def cat_activate_vector(id, rules, cls_shot, activate_vector_path):
    """Concatenate saved per-rule activation vectors and save the result.

    For each rule in *rules* (in order), this loads from *activate_vector_path*:
      * ``{id}_{rule.__name__}_{c_s}.pth`` for every ``c_s`` in *cls_shot* when the rule
        is named ``n_shot_top5_distr``;
      * otherwise ``{id}_{rule.__name__}.pth``.

    The loaded tensors are concatenated along the last dimension and written to
    ``{id}_activate.pth`` in the same directory.

    :param id: identifier prefix of the files.
    :param rules: iterable of callables; only their ``__name__`` is used.
    :param cls_shot: entries formatted with ``str()`` into the per-shot file names.
    :param activate_vector_path: directory holding the vectors.
    :return: None
    """
    def _file(stem_suffix):
        return os.path.join(activate_vector_path, f'{id}_{stem_suffix}.pth')

    paths = []
    for rule in rules:
        name = rule.__name__
        if name == 'n_shot_top5_distr':
            paths.extend(_file(f'{name}_{c_s}') for c_s in cls_shot)
        else:
            paths.append(_file(name))

    vectors = [torch.load(p) for p in paths]
    merged = torch.cat(vectors, dim=-1)
    torch.save(merged, _file('activate'))

def count_rules(rules, cls_shot):
    """Return a weighted count of *rules*.

    A rule named ``n_shot_top5_distr`` contributes 3 for every entry of *cls_shot*;
    every other rule contributes 1.
    """
    def _weight(rule):
        if rule.__name__ != 'n_shot_top5_distr':
            return 1
        return sum(3 for _ in cls_shot)

    return sum(_weight(rule) for rule in rules)


def load_activate_vectors(activate_vectors):
    """Fill ``-1`` placeholders in each vector with the value from the first (reference) vector.

    The first element is the reference and is returned unchanged. In every later element,
    each position equal to ``-1`` is overwritten IN PLACE with the reference value at the
    same position. The input tensors are therefore mutated; the returned list holds the
    same objects.

    :param activate_vectors: list ``[vector1, vector2, ...]`` of tensors whose shape is
        compatible with the reference's for boolean-mask indexing.
    :return: list of the (filled) vectors, reference first; ``[]`` for empty input.
    """
    if len(activate_vectors) == 0:
        return []

    ref = activate_vectors[0]
    filled = [ref]
    for vector in activate_vectors[1:]:
        # Compute the placeholder mask once; it is used both to read and to write.
        missing = vector == -1
        vector[missing] = ref[missing]
        filled.append(vector)
    return filled


def load_vlm(vlm_path, *, min_pixels=256 * 28 * 28, max_pixels=512 * 28 * 28):
    """Load a Qwen2-VL model and its processor.

    The model is loaded in bfloat16 with ``device_map="auto"``.

    :param vlm_path: path or hub id passed to ``from_pretrained``.
    :param min_pixels: lower image-resolution bound forwarded to ``AutoProcessor``
        (default ``256 * 28 * 28``, the previously hard-coded value).
    :param max_pixels: upper image-resolution bound forwarded to ``AutoProcessor``
        (default ``512 * 28 * 28``, the previously hard-coded value).
    :return: ``(model, processor)``.
    """
    vlm = Qwen2VLForConditionalGeneration.from_pretrained(
        vlm_path,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    vlm_processor = AutoProcessor.from_pretrained(
        vlm_path, min_pixels=min_pixels, max_pixels=max_pixels
    )

    return vlm, vlm_processor

def load_florence2(vlm_path):
    """Load a Florence-2 model (remote code) in eval mode on CUDA, plus its processor.

    NOTE(review): ``torch_dtype='auto'`` takes the dtype from the checkpoint, so callers
    that cast inputs to a fixed dtype must match whatever that turns out to be.

    :param vlm_path: path or hub id passed to ``from_pretrained``.
    :return: ``(model, processor)``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        vlm_path,
        trust_remote_code=True,
        torch_dtype='auto',
    )
    model = model.eval().cuda()

    processor = AutoProcessor.from_pretrained(vlm_path, trust_remote_code=True)
    return model, processor

def load_llm(llm_path):
    """Load a causal language model (bfloat16, ``device_map='auto'``) and its tokenizer.

    :param llm_path: path or hub id passed to ``from_pretrained``.
    :return: ``(model, tokenizer)``, with the model switched to eval mode.
    """
    model = AutoModelForCausalLM.from_pretrained(
        llm_path,
        device_map='auto',
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(llm_path)

    model.eval()
    return model, tokenizer

def load_dino(dino_path):
    """Load a zero-shot object detector onto CUDA, together with its processor.

    :param dino_path: path or hub id passed to ``from_pretrained``.
    :return: ``(model, processor)``.
    """
    detector = AutoModelForZeroShotObjectDetection.from_pretrained(dino_path)
    detector = detector.to('cuda')
    return detector, AutoProcessor.from_pretrained(dino_path)

def run_example(image, task_prompt, model, processor, text_input=None):
    """Run one Florence-2 task on *image* and return the parsed answer.

    :param image: PIL image (its ``width``/``height`` are passed to post-processing).
    :param task_prompt: task token, e.g. '<CAPTION>', '<OD>', '<CAPTION_TO_PHRASE_GROUNDING>'.
    :param model: Florence-2 model.
    :param processor: Florence-2 processor.
    :param text_input: optional text appended directly after *task_prompt*.
    :return: whatever ``processor.post_process_generation`` returns for *task_prompt*.
    """
    prompt = task_prompt if text_input is None else task_prompt + text_input

    # Move/cast the inputs to the model's own device and dtype instead of a hard-coded
    # 'cuda'/float16. A model loaded with torch_dtype='auto' is not necessarily fp16, and
    # fp16 pixel_values fed into non-fp16 weights raise a dtype-mismatch error.
    # The fallbacks keep the previous behaviour for objects without these attributes.
    device = getattr(model, 'device', 'cuda')
    dtype = getattr(model, 'dtype', torch.float16)
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        early_stopping=False,
        do_sample=False,
        num_beams=3,
    )
    # Special tokens are kept. Presumably post_process_generation needs them
    # (e.g. location tokens) — confirm before changing.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed_answer = processor.post_process_generation(
        generated_text,
        task=task_prompt,
        image_size=(image.width, image.height),
    )

    return parsed_answer

def run_detect(image, prompt, model, processor, *, box_threshold=0.4, text_threshold=0.3, device='cuda'):
    """Run zero-shot object detection on *image* for the text *prompt*.

    :param image: PIL image; ``image.size`` is read as ``(width, height)``.
    :param prompt: text query passed to the processor.
    :param model: zero-shot object-detection model.
    :param processor: matching processor (provides ``post_process_grounded_object_detection``).
    :param box_threshold: box-score threshold forwarded to post-processing (default 0.4).
    :param text_threshold: text-score threshold forwarded to post-processing (default 0.3).
    :param device: device the processed inputs are moved to (default ``'cuda'``).
    :return: the list returned by ``post_process_grounded_object_detection``
        (a single image is passed, so one entry).
    """
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model(**inputs)

    return processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL size is (width, height); post-processing takes (height, width).
        target_sizes=[image.size[::-1]],
    )

import matplotlib
# Select the non-interactive Agg backend before pyplot is imported, as matplotlib
# recommends. Figures are only saved to files here, never shown.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as patches


def _draw_bboxes(image, bboxes, labels, index=''):
    """Draw labelled boxes over *image* and save the figure to ``1_{index}.png``.

    Shared implementation of ``plot_bbox`` and ``plot_bbox_d``.

    :param image: anything ``Axes.imshow`` accepts (e.g. a PIL image or an array).
    :param bboxes: iterable of ``(x1, y1, x2, y2)`` corners in image coordinates.
    :param labels: iterable of labels, paired with *bboxes* by position.
    :param index: suffix inserted into the output filename.
    """
    fig, ax = plt.subplots()
    try:
        ax.imshow(image)
        for (x1, y1, x2, y2), label in zip(bboxes, labels):
            # Rectangle takes the anchor corner plus width/height.
            ax.add_patch(patches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1,
                linewidth=1, edgecolor='r', facecolor='none',
            ))
            ax.text(x1, y1, label, color='white', fontsize=8,
                    bbox=dict(facecolor='red', alpha=0.5))
        # Remove the axis ticks and labels.
        ax.axis('off')
        fig.savefig(f'1_{index}.png')
    finally:
        # pyplot keeps every figure alive until it is closed; without this, repeated
        # calls accumulate figures (and memory).
        plt.close(fig)


def plot_bbox(image, data, index=''):
    """Plot boxes from a dict with ``'bboxes'`` and ``'labels'`` keys and save to ``1_{index}.png``."""
    _draw_bboxes(image, data['bboxes'], data['labels'], index)


def plot_bbox_d(image, data, index=''):
    """Plot boxes from a dict whose ``'boxes'`` is a tensor and save to ``1_{index}.png``.

    ``data['boxes']`` is moved to the CPU and converted to NumPy before drawing;
    ``data['labels']`` supplies the text for each box.
    """
    _draw_bboxes(image, data['boxes'].cpu().numpy(), data['labels'], index)

def extract_answer(res):
    """Extract the answer token from a model response string.

    Patterns are tried in order; the first match wins:
      1. ``## X ##`` where X is a choice letter A-D;
      2. ``## d ##`` where d is a digit 0-3;
      3. any ``##...##`` span on one line (its content, stripped);
      4. a choice letter followed by a period, e.g. ``B.``;
      5. the first alphanumeric token (letters optionally followed by digits, or digits).

    :param res: response text; ``None`` or ``""`` yields ``None``.
    :return: the matched text, stripped, or ``None`` when nothing matches.
    """
    if not res:
        return None

    # The digit pattern must come before the catch-all ``##(.*?)##``: the catch-all
    # matches every string the digit pattern matches, so placed after it the digit
    # pattern could never fire.
    patterns = (
        r'##\s*([ABCD])\s*##',
        r'##\s*([0-3])\s*##',
        r'##(.*?)##',
        r'([ABCD])\.',
        r'([A-Za-z]+[0-9]*|[0-9]+)',
    )
    for pattern in patterns:
        match = re.search(pattern, res)
        if match:
            return match.group(1).strip()

    return None

def convert_to_od_format(data):
    """Re-key a detection result into the ``{'bboxes', 'labels'}`` layout used for plotting.

    Only ``'bboxes'`` and ``'bboxes_labels'`` are read; any other keys (e.g. polygon
    entries) are ignored. A missing key becomes an empty list. The values are not
    copied, so the returned dict shares them with *data*.

    :param data: mapping with optional ``'bboxes'`` and ``'bboxes_labels'`` keys.
    :return: ``{'bboxes': ..., 'labels': ...}``.
    """
    return {
        'bboxes': data.get('bboxes', []),
        'labels': data.get('bboxes_labels', []),
    }

if __name__ == '__main__':
    # `rules` is otherwise undefined here: the module-level import is commented out.
    # Import it locally so that importing this module as a library still does not
    # depend on core.all_rules.
    from core.all_rules import rules as all_rules

    # Alternative configuration that also includes the 'same' splits:
    # cls_shot = [('diff', 4), ('diff', 5), ('diff', 6), ('diff', 7),
    #             ('same', 4), ('same', 5), ('same', 6), ('same', 7)]
    cls_shot = [('diff', 4), ('diff', 5), ('diff', 6), ('diff', 7)]
    rules = all_rules['realworld']

    cat_activate_vector('id0', rules, cls_shot, 'activate_vector/real_world/')