import cv2
import numpy as np
from tqdm import tqdm
import json
from utils.functions import get_image, response_to_json


def _ask_json(debug_agent, rules, question):
    '''Send a text-only query to debug_agent and parse the reply with response_to_json.'''
    response = debug_agent.ask_text(rules, question)
    return response_to_json(response)


def _check_attrs(prompt_agent, debug_agent, attrs, cls):
    '''Ask the agent to review the attribute form `attrs` for category `cls`; return the parsed reply.'''
    rules = prompt_agent.check_attrs_rules()
    question = prompt_agent.check_attrs_question(attrs, cls)
    return _ask_json(debug_agent, rules, question)


def _extract_error_attrs(prompt_agent, debug_agent, data, cls, task):
    '''Query the task-specific prompt(s) for category `cls` and return the parsed attribute form.

    For classification, `cls` is compared against every other category and all of the
    replies are merged, so no comparison is discarded. With a single category, the
    result is an empty form.
    '''
    if task == "classification":
        error_attrs = {}
        for other_cls in data.keys():
            if other_cls == cls:
                continue
            rules = prompt_agent.extract_clf_attrs_rules()
            question = prompt_agent.extract_clf_attrs_question(cls, other_cls)
            error_attrs = merge_attribute_forms(error_attrs, _ask_json(debug_agent, rules, question))
        return error_attrs
    if task == "pose":
        rules = prompt_agent.extract_pose_attrs_rules()
        question = prompt_agent.extract_pose_attrs_question(cls)
        return _ask_json(debug_agent, rules, question)
    # task == "detection" (validated by the caller)
    rules = prompt_agent.extract_det_attrs_rules()
    question = prompt_agent.extract_det_attrs_question(cls)
    return _ask_json(debug_agent, rules, question)


def _sample_distinct_pair(n):
    '''Draw two different random indices in [0, n). Requires n >= 2, or this loops forever.'''
    idx1 = np.random.randint(n)
    idx2 = idx1
    while idx2 == idx1:
        idx2 = np.random.randint(n)
    return idx1, idx2


def _load_sample_image(sample, patch_path):
    '''Load one dataset sample via get_image.

    A sample is either an image reference passed to get_image as-is, or a list whose
    first item is the image and whose second item is forwarded to get_image as `box`.
    '''
    if isinstance(sample, list):
        return get_image(sample[0], box=sample[1], patch_path=patch_path)
    return get_image(sample, patch_path=patch_path)


def get_attrs(prompt_agent, debug_agent, data, pairs_per_class=1, task="classification"):
    ''' function get_attrs
    extract a list of attributes from data samples, e.g., brightness, object color, etc.
    arguments:
        prompt_agent (Prompts) - prompt rule and template maker for GPT interaction
        debug_agent (Debug_with_GPT4V) - agent for GPT interaction with prompts and images
        data (dict) - information of all images in the dataset, keyed by category; each value
            is a list of samples, where a sample is either an image reference or an
            [image, box] list
        pairs_per_class (int) - number of data pairs to be compared for each category
        task (str) - "classification", "pose" or "detection"; selects the prompt used to
            propose task-specific (error-prone) attributes
    returns:
        attrs (dict) - dict mapping each category to its attribute form
    raises:
        NotImplementedError - if task is not one of the supported tasks
    '''
    supported_tasks = ("classification", "pose", "detection")
    # Validate before the loop so that no agent queries are spent on an invalid task.
    if task not in supported_tasks:
        raise NotImplementedError(f"only support classification/pose/detection, got task={task!r}")

    save_attrs = {}
    for cls in tqdm(data.keys()):
        # 1) Generic attributes proposed for this category from text alone.
        rules = prompt_agent.init_attrs_rules()
        question = prompt_agent.init_attrs_question(cls)
        init_attrs = _ask_json(debug_agent, rules, question)

        # 2) Task-specific attributes, merged with the generic ones and reviewed by the agent.
        error_attrs = _extract_error_attrs(prompt_agent, debug_agent, data, cls, task)
        attrs = merge_attribute_forms(init_attrs, error_attrs)
        attrs = _check_attrs(prompt_agent, debug_agent, attrs, cls)

        # 3) Refine the attributes by showing the agent random pairs of distinct samples.
        samples = data[cls]
        n_pairs = pairs_per_class
        if len(samples) < 2:
            # Fewer than two samples makes distinct-pair sampling impossible
            # (it would never terminate), so skip this stage for this category.
            if pairs_per_class > 0:
                print(f"skipping image-pair comparison for {cls!r}: "
                      f"needs at least 2 samples, got {len(samples)}")
            n_pairs = 0

        pairs_idxs = tqdm(range(n_pairs))
        for _ in pairs_idxs:
            idx1, idx2 = _sample_distinct_pair(len(samples))
            # The patch images mark which picture is "No.1" and which is "No.2" for the agent.
            image1 = _load_sample_image(samples[idx1], './assets/no1.jpg')
            image2 = _load_sample_image(samples[idx2], './assets/no2.jpg')
            cv2.imwrite('temp1.png', image1)
            cv2.imwrite('temp2.png', image2)
            rules = prompt_agent.extract_dataset_attrs_rules()
            question = prompt_agent.extract_dataset_attrs_question(attrs, cls)
            response = debug_agent.ask_multiimage(['temp1.png', 'temp2.png'], rules, question)
            attrs = merge_attribute_forms(attrs, response_to_json(response))
            pairs_idxs.set_postfix(ordered_dict={'main object': len(attrs['main object']), 'background': len(attrs['background']), 'global': len(attrs['global'])})

        # 4) Final review of the accumulated attribute form.
        save_attrs[cls] = _check_attrs(prompt_agent, debug_agent, attrs, cls)
    return save_attrs

def merge_attribute_forms(form1, form2, categories=("main object", "background", "global")):
    """
    Merges two attribute forms into one. For each category ("main object", "background", "global"
    by default), concatenates the attribute lists of form1 and form2 and removes duplicates,
    keeping first-seen order (form1's entries come first).

    Order is preserved deliberately. The merged form is fed back into later prompts, and a
    set-based merge would order string attributes differently from one process to the next
    (string hashing is randomized per process).

    :param form1: The first attribute form as a dictionary (None is treated as empty).
    :param form2: The second attribute form as a dictionary (None is treated as empty).
    :param categories: Category keys to merge; keys not listed here are dropped.
    :return: A new dictionary with one de-duplicated list per category, in `categories` order.
    """
    merged_form = {}

    for category in categories:
        merged_attributes = []
        seen = set()
        for form in (form1, form2):
            values = (form or {}).get(category) or []
            # A lone string would otherwise be iterated character by character.
            if isinstance(values, str):
                values = [values]
            for attribute in values:
                try:
                    if attribute in seen:
                        continue
                    seen.add(attribute)
                except TypeError:
                    # Unhashable entries (e.g. dicts parsed from JSON): fall back to an equality scan.
                    if attribute in merged_attributes:
                        continue
                merged_attributes.append(attribute)
        merged_form[category] = merged_attributes

    return merged_form
