import cv2
from tqdm import tqdm
from utils.functions import  response_to_json, get_image
import pickle

def get_labels(prompt_agent, debug_agent, data, tags, worker, max_rounds=5):
    """Label every image in ``data``, re-running failed images for up to ``max_rounds`` passes.

    The first pass calls ``label_process`` on all of ``data``. Each later pass
    only retries the images the previous pass reported as failed. The loop
    stops early once a pass reports no errors. After every pass that still
    has errors, the accumulated labels are pickled to
    ``labels_temp_<worker>.pkl`` as a checkpoint.

    Args:
        prompt_agent: builds the stage-3 rules/question prompts.
        debug_agent: object whose ``ask_image`` queries the model.
        data: mapping ``cls -> list`` of image entries (a path string or a
            ``(path, box)`` pair).
        tags: mapping ``cls -> tag_dict`` describing the allowed tags.
        worker: integer id used to name this worker's temp/checkpoint files.
        max_rounds: maximum number of labelling passes (default 5, the value
            that used to be hard-coded). Images that still fail after the last
            pass are simply absent from the returned labels.

    Returns:
        ``(labels, tags)``. ``labels`` maps each class to
        ``{image: response_json}``; ``tags`` is returned unchanged.
    """
    labels = {cls: {} for cls in tags}
    for _ in range(max_rounds):  # retry only what failed in the previous pass
        labels, data, num_error = label_process(prompt_agent, debug_agent, data, tags, labels, worker)
        if num_error == 0:
            break
        with open('labels_temp_%d.pkl' % worker, 'wb') as file:
            pickle.dump(labels, file)
    return labels, tags

def quick_fix(response):
    """Turn 'not visible' answers to yes/no attributes into "no".

    ``response`` maps an attribute type to a dict of ``attribute -> value``.
    Any attribute whose name starts with ``"is "`` and whose value is exactly
    ``'not visible'`` is rewritten to ``"no"``. All other entries are left
    alone. The dict is modified in place and also returned for chaining.
    """
    for attributes in response.values():
        for name in attributes:
            # Only existing keys are reassigned, so mutating during iteration is safe.
            if name.startswith("is ") and attributes[name] == 'not visible':
                attributes[name] = "no"
    return response

def check_new_tag(tags, response):
    """Report whether ``response`` uses any tag outside the allowed vocabulary.

    ``tags[attr_type][attr]`` is the collection of allowed values for one
    attribute. Each value in ``response[attr_type][attr]`` is either a single
    string or an iterable of strings. Every value not in the allowed
    collection is printed, and ``True`` is returned if at least one was found.

    A ``KeyError`` propagates when ``response`` contains an attribute that is
    missing from ``tags``.
    """
    found = False
    for attr_type, attributes in response.items():
        for attr, value in attributes.items():
            # Looked up per attribute, so an empty group never touches ``tags``.
            allowed = tags[attr_type][attr]
            candidates = [value] if isinstance(value, str) else value
            for candidate in candidates:
                if candidate not in allowed:
                    print("\n New tag:", attr_type, attr, candidate, allowed)
                    found = True
    return found

def _load_input_image(image):
    """Load one ``data[cls]`` entry and return ``(key, image_path, input_image)``.

    A string entry is an image path and is passed straight to ``get_image``.
    Any other entry is unpacked as ``(image_path, box)`` and loaded with
    ``get_image(image_path, box=box, padding_ratio=0.3, show_box=True)``.

    ``key`` is the entry itself for strings. Otherwise it is the hashable
    ``(image_path, tuple(box))``, so it can be used as a ``labels`` dict key
    and stored in ``failed_data`` for retries.
    """
    if isinstance(image, str):
        return image, image, get_image(image)
    image_path, box = image[0], image[1]
    key = (image_path, tuple(box))
    return key, image_path, get_image(image_path, box=box, padding_ratio=0.3, show_box=True)


def label_process(prompt_agent, debug_agent, data, tags, labels, worker):
    """Run one labelling pass over ``data`` and merge the successes into ``labels``.

    For every class in ``tags`` and every entry in ``data[cls]``:

    1. The image is loaded and written to ``temp_<worker>.png``.
    2. It is sent to ``debug_agent.ask_image`` with the stage-3 rules and
       question from ``prompt_agent``.
    3. The reply is parsed with ``response_to_json`` and normalised with
       ``quick_fix``.

    A reply counts as a failure when it is empty, when it contains a tag that
    ``check_new_tag`` rejects, or when any step of the request/parse raises.
    Failures are recorded so the caller can retry them. After each image,
    ``labels`` is pickled to ``labels_temp_<worker>.pkl`` as a checkpoint.

    Args:
        prompt_agent: builds the rules and question prompts.
        debug_agent: object whose ``ask_image(path, rules, question)`` queries the model.
        data: mapping ``cls -> list`` of entries (a path string or a ``(path, box)`` pair).
        tags: mapping ``cls -> tag_dict`` (``attr_type -> attr -> allowed tags``).
        labels: mapping ``cls -> {image: response_json}``; updated in place.
        worker: integer id used to name the temp image and checkpoint files.

    Returns:
        ``(labels, failed_data, num_error)``. ``failed_data`` maps every class
        to the entries that failed; ``num_error`` is the total failure count.
    """
    temp_image_path = 'temp_%d.png' % worker
    checkpoint_path = 'labels_temp_%d.pkl' % worker
    failed_data = {}
    num_error = 0
    for cls, tag_dict in tqdm(list(tags.items())):
        failed_data[cls] = []
        # Write successes straight into ``labels`` rather than a per-class buffer
        # merged at the end, so every per-image checkpoint includes this class.
        cls_labels = labels.setdefault(cls, {})
        rules = prompt_agent.label_data_stage3_rules()
        for entry in tqdm(data[cls]):
            question = prompt_agent.label_data_stage3_question(tag_dict, cls)
            image, image_path, input_image = _load_input_image(entry)
            cv2.imwrite(temp_image_path, input_image)
            # Reset per image so a failed request never reports the previous reply
            # (or hits a NameError on the very first image).
            response = None
            try:
                response = debug_agent.ask_image(temp_image_path, rules, question)
                response_json = quick_fix(response_to_json(response))
                # An empty parse or a tag outside the vocabulary counts as a failure.
                succeeded = response_json != {} and not check_new_tag(tag_dict, response_json)
            except KeyboardInterrupt:
                print("exit")
                raise SystemExit  # like exit(), but independent of the ``site`` module
            except Exception as err:  # request, parsing or unknown-attribute errors
                succeeded = False
                if response is None:
                    response = repr(err)  # the request itself failed; report why
            if succeeded:
                cls_labels[image] = response_json
            else:
                num_error += 1
                failed_data[cls].append(image)
                print(f'\nERROR in {cls} {image_path}:\n{response}\n')
            with open(checkpoint_path, 'wb') as file:
                pickle.dump(labels, file)
    return labels, failed_data, num_error