import cv2
from tqdm import tqdm
import json
from utils.functions import get_image, response_to_json, update_form


def get_tags(prompt_agent, debug_agent, attrs):
    ''' function get_tags
    extract a list of possible tags of visual attributes, e.g., gender: male, female

    Every class goes through two text queries: one that extracts candidate tags and one
    that checks them. The checked answer is validated with check_name_unchange; classes
    whose answer dropped or renamed an attribute are queried again, for at most three
    rounds in total. The latest answer of a class is always stored, so a class that still
    fails after the last round keeps its last (unvalidated) answer.
    After each class the accumulated result is written to 'tags_temp.json' in the current
    working directory as a checkpoint.

    arguments:
        prompt_agent (Prompts) - prompt rule and template maker for GPT interaction
        debug_agent (Debug_with_GPT4V) - agent for GPT interaction with prompts and images
        attrs (dict) - dict storing the attributes with their descriptions, keyed by class;
                       each value is indexed as attribute type -> attribute names
    returns:
        tags (dict) - dict storing all tags of the attributes, keyed by class and then
                      attribute type -> attribute name -> list of tags
    '''
    max_rounds = 3
    tags = {}
    cls_to_tagging = list(attrs.keys())
    for _ in range(max_rounds):
        # Nothing left to (re)query: skip the remaining rounds instead of
        # iterating over an empty list.
        if not cls_to_tagging:
            break
        error_cls = []
        for cls in tqdm(cls_to_tagging):
            cls_attrs = attrs[cls]
            # extraction
            rules = prompt_agent.extract_tags_rules()
            question = prompt_agent.extract_tags_question(cls_attrs, cls)
            response = debug_agent.ask_text(rules, question)
            cls_tags = response_to_json(response)
            # checking
            rules = prompt_agent.check_tags_rules()
            question = prompt_agent.check_tags_question(cls_tags, cls)
            response = debug_agent.ask_text(rules, question)
            cls_tags = response_to_json(response)
            # Queue the class for another round if attribute names went missing.
            if not check_name_unchange(cls_attrs, cls_tags):
                error_cls.append(cls)
            # Attributes phrased as "is ..." are forced to a fixed yes/no tag list,
            # whatever the model answered. Only existing keys are reassigned, so
            # mutating the dict while iterating over it is safe.
            for attr_dict in cls_tags.values():
                for attr_name in attr_dict:
                    if attr_name.startswith("is "):
                        attr_dict[attr_name] = ["yes", "no"]
            tags[cls] = cls_tags
            # Checkpoint after every class so partial results survive a failure.
            with open('tags_temp.json', 'w') as f:
                json.dump(tags, f, indent=4)
        cls_to_tagging = error_cls
    return tags

def check_name_unchange(attrs, tags):
    ''' function check_name_unchange
    verify that every attribute type and attribute name of attrs is still present in tags
    arguments:
        attrs (dict) - expected names: attribute type -> iterable of attribute names
        tags (dict) - names to verify: attribute type -> dict keyed by attribute name
    returns:
        (bool) - True when nothing is missing; False at the first missing name, after
                 printing the missing name together with the keys that are available
    '''
    for expected_type in attrs:
        # The attribute type itself must have survived.
        if expected_type not in tags:
            print("missing", expected_type, list(tags.keys()))
            return False
        found_attrs = tags[expected_type]
        for expected_attr in attrs[expected_type]:
            if expected_attr in found_attrs:
                continue
            print("missing", expected_attr, list(found_attrs.keys()))
            return False
    return True

# Running this module directly only prints the documentation of get_tags.
if __name__ == "__main__":
    usage_text = get_tags.__doc__
    print(usage_text)


