import cv2
from tqdm import tqdm
from utils.functions import  response_to_json, get_image
import json

def refine_tags(prompt_agent, debug_agent, data, tags, worker, maximum_update_data):
    """Ask a vision model for tags not yet present in ``tags``, class by class.

    For every class in ``tags``, up to ``maximum_update_data`` entries of
    ``data[cls]`` are loaded with ``get_image``. Each image is written to a
    per-worker temp PNG and sent to ``debug_agent.ask_image`` together with
    the stage-1 rules and question from ``prompt_agent``. The parsed reply is
    merged into the result via ``collect_new_tags``.

    Args:
        prompt_agent: Provides ``label_data_stage1_rules()`` and
            ``label_data_stage1_question(tag_dict, cls)``.
        debug_agent: Provides ``ask_image(image_path, rules, question)``.
        data: Mapping ``cls -> iterable`` of entries. Each entry is either an
            image path, or a ``[path, box]`` list/tuple whose box is
            forwarded to ``get_image(path, box=box)``.
        tags: Nested mapping ``{cls: {attr_type: {attr: [known tags]}}}``.
        worker: Integer id used to keep temp/checkpoint file names unique.
        maximum_update_data: Maximum number of entries queried per class.

    Returns:
        A mapping with the same ``cls / attr_type / attr`` layout as ``tags``,
        whose leaves are lists of newly proposed tags.

    Side effects:
        Writes ``temp_<worker>.png`` for every queried image. After each class
        it also writes the accumulated result to
        ``tags_refined_temp_<worker>.json`` as a checkpoint.
    """
    # Same nested layout as `tags`, with an empty list per attribute.
    new_tags = {
        cls: {
            attr_type: {attr: [] for attr in attrs}
            for attr_type, attrs in type_dict.items()
        }
        for cls, type_dict in tags.items()
    }
    temp_path = 'temp_%d.png' % worker
    checkpoint_path = 'tags_refined_temp_%d.json' % worker

    # collect new tags
    for cls, tag_dict in tqdm(tags.items()):
        rules = prompt_agent.label_data_stage1_rules()
        for count, image in enumerate(tqdm(data[cls]), start=1):
            if count > maximum_update_data:
                break
            question = prompt_agent.label_data_stage1_question(tag_dict, cls)
            if isinstance(image, (list, tuple)):
                image_path, box = image[0], image[1]
                input_image = get_image(image_path, box=box)
            else:
                image_path = image
                input_image = get_image(image_path)
            # cv2.imwrite reports failure via its return value; without this
            # check the model would be asked about the previous image's file.
            if not cv2.imwrite(temp_path, input_image):
                print(f'\nERROR in {cls} {image_path}:\nfailed to write {temp_path}\n')
                continue
            # Reset per image so the error report never shows an unbound name
            # or a stale reply from an earlier image.
            response = None
            try:
                response = debug_agent.ask_image(temp_path, rules, question)
                response_json = response_to_json(response)
                new_tags = collect_new_tags(tags, new_tags, response_json, cls)
            except KeyboardInterrupt:
                print("exit")
                raise SystemExit
            except Exception:
                # Best-effort: a bad reply for one image must not abort the run.
                print(f'\nERROR in {cls} {image_path}:\n{response}\n')
                continue
        # Checkpoint after every class so progress survives a crash.
        with open(checkpoint_path, 'w') as f:
            json.dump(new_tags, f, indent=4)
    return new_tags

def collect_new_tags(old_tag, new_tag, response, cls):
    """Append tags from a parsed model reply that are not already known.

    For every ``attr_type`` / ``attr`` of ``old_tag[cls]``, the reply value
    ``response[attr_type][attr]`` is examined:

    * if it contains ``", "`` it is split on ``", "`` into several tags;
    * a list is treated as several tags, anything else as one tag.

    A tag is appended to ``new_tag[cls][attr_type][attr]`` only if it is in
    neither the old list nor the new list. Tags appended earlier in the same
    call count as present, so duplicates within one reply are added once.

    Args:
        old_tag: ``{cls: {attr_type: {attr: [existing tags]}}}``.
        new_tag: Same layout; its leaf lists are mutated in place.
        response: Parsed reply, indexed as ``response[attr_type][attr]``.
        cls: Class key to process.

    Returns:
        ``new_tag`` itself, after in-place mutation.

    Raises:
        KeyError: if ``new_tag`` or ``response`` lacks an attribute present in
            ``old_tag[cls]``.
    """
    for attr_type, attrs in old_tag[cls].items():
        for attr, known in attrs.items():
            collected = new_tag[cls][attr_type][attr]
            answer = response[attr_type][attr]

            if ", " in answer:
                answer = answer.split(", ")
            candidates = answer if type(answer) is list else [answer]

            for candidate in candidates:
                if candidate in known or candidate in collected:
                    continue
                collected.append(candidate)
    return new_tag

def merge_tags(prompt_agent, debug_agent, tags, new_tags, max_retries=3):
    """Ask a text model to merge the original and newly collected tags per class.

    For every class, the old and new tag lists of each ``attr_type / attr``
    are paired as ``{"old tags": [...], "new tags": [...]}``. That structure is
    sent to ``debug_agent.ask_text`` with the stage-2 rules and question from
    ``prompt_agent``, and the reply is parsed with ``response_to_json``.

    Classes whose request or parse fails are retried. Each of up to
    ``max_retries`` passes re-queries only the classes that failed in the
    previous pass.

    Args:
        prompt_agent: Provides ``label_data_stage2_rules()`` and
            ``label_data_stage2_question(tags_cls, cls)``.
        debug_agent: Provides ``ask_text(rules, question)``.
        tags: ``{cls: {attr_type: {attr: [old tags]}}}``.
        new_tags: Same layout as ``tags``, holding the proposed new tags.
        max_retries: Number of passes over failing classes (default 3).

    Returns:
        Mapping ``cls -> parsed reply`` for every class that succeeded.
        Classes still failing after all passes are omitted, and a summary
        line reporting them is printed.
    """
    merge_tag = {}
    for cls in tqdm(tags):
        merge_tag[cls] = {}
        for attr_type in tags[cls]:
            merge_tag[cls][attr_type] = {}
            for attr in tags[cls][attr_type]:
                merge_tag[cls][attr_type][attr] = {
                    "old tags": tags[cls][attr_type][attr],
                    "new tags": new_tags[cls][attr_type][attr],
                }

    results = {}
    pending = list(merge_tag)
    for _ in range(max_retries):
        if not pending:
            break
        failed = []
        for cls in tqdm(pending):
            rules = prompt_agent.label_data_stage2_rules()
            question = prompt_agent.label_data_stage2_question(merge_tag[cls], cls)
            # Reset per class so the error report never shows an unbound name
            # or the previous class's reply.
            response = None
            try:
                # Both the model call and the JSON parse can fail. They must be
                # inside the try, or the retry loop never triggers.
                response = debug_agent.ask_text(rules, question)
                results[cls] = response_to_json(response)
            except Exception:
                failed.append(cls)
                print(f'\nERROR in {cls}:\n{response}\n')
        pending = failed

    if pending:
        # Do not drop classes silently; make the loss visible to the operator.
        print(f'\nERROR: giving up on {pending} after {max_retries} attempts\n')
    return results
