import cv2
from tqdm import tqdm
import json
from utils.functions import response_to_json
import numpy as np
from scipy.spatial import distance
import itertools
import copy
import json

def _ask_predictions(debug_agent, rules, question):
    """Send one rules/question prompt to ``debug_agent`` and return the ``"predictions"`` entry of the parsed reply."""
    response = debug_agent.ask_text(rules, question)
    return response_to_json(response)["predictions"]


def error_slices_predict(prompt_agent, debug_agent, tags, combination_size, task="cls"):
    """
    Ask the debug agent to predict error slices for every class in ``tags``.

    For each class a background-error prediction is always requested. Depending on
    ``task``, more predictions are appended to that class's entry:

    * ``"cls"``  - one prediction per *other* class in ``tags`` (``predict_cls_errors_*`` prompts);
    * ``"pose"`` - one prediction from the ``predict_pose_errors_*`` prompts;
    * ``"det"``  - one prediction from the ``predict_det_errors_*`` prompts.

    Any other ``task`` value silently yields only the background predictions.

    After each class, the accumulated results are re-written to
    ``predict_errors_temp.json`` in the current working directory (presumably a
    checkpoint for long-running agent calls).

    :param prompt_agent: builds the rules/question prompt pairs
    :param debug_agent: answers prompts via ``ask_text(rules, question)``
    :param tags: {cls: cls_tags}; ``cls_tags`` is forwarded to the prompt builders
    :param combination_size: forwarded to the prompt builders
    :param task: "cls", "pose" or "det"
    :return: {cls: predictions}, where the predictions of every prompt for ``cls``
        are concatenated with ``+=``
    """
    descriptions = {}
    for cls, cls_tags in tqdm(tags.items()):
        # Arguments are evaluated left to right, so the rules are built before the
        # question and both before the agent is queried (same order as before).
        descriptions[cls] = _ask_predictions(
            debug_agent,
            prompt_agent.predict_bg_errors_rules(),
            prompt_agent.predict_bg_errors_question(cls, cls_tags, combination_size),
        )
        if task == "cls":
            for target_cls in tqdm(tags):
                if cls == target_cls:
                    continue
                descriptions[cls] += _ask_predictions(
                    debug_agent,
                    prompt_agent.predict_cls_errors_rules(),
                    prompt_agent.predict_cls_errors_question(cls, target_cls, cls_tags, combination_size),
                )
        elif task == "pose":
            descriptions[cls] += _ask_predictions(
                debug_agent,
                prompt_agent.predict_pose_errors_rules(),
                prompt_agent.predict_pose_errors_question(cls, cls_tags, combination_size),
            )
        elif task == "det":
            descriptions[cls] += _ask_predictions(
                debug_agent,
                prompt_agent.predict_det_errors_rules(),
                prompt_agent.predict_det_errors_question(cls, cls_tags, combination_size),
            )

        # Persist partial results after every class so progress survives a crash.
        with open('predict_errors_temp.json', 'w', encoding="utf-8") as f:
            json.dump(descriptions, f, indent=4)
    return descriptions



def get_closest_tags(tag, tag_features, top_n=None):
    """
    Rank all other tags by Euclidean feature distance to ``tag``, nearest first.

    :param tag: the tag to be mutated; must be a key of ``tag_features``
    :param tag_features: {tag: feature} mapping; each feature must be a vector
        accepted by ``scipy.spatial.distance.euclidean``
    :param top_n: if given, return at most this many tags; ``None`` (the default)
        returns every other tag
    :return: list of the other tags sorted by ascending distance; ties keep the
        iteration order of ``tag_features`` (``sorted`` is stable)
    :raises KeyError: if ``tag`` is not a key of ``tag_features``
    """
    tag_feature = tag_features[tag]

    # Distance from the given tag to every other tag.
    dist = {
        other_tag: distance.euclidean(tag_feature, other_feature)
        for other_tag, other_feature in tag_features.items()
        if other_tag != tag
    }

    closest_tags = sorted(dist, key=dist.get)
    return closest_tags if top_n is None else closest_tags[:top_n]

def error_slices_variate(stat_results, tag_features, *, min_tags=4, min_count=15, max_slices=20):
    """
    Propose new candidate slices by swapping one tag of an existing slice for its nearest neighbour.

    For every class, the existing slices are visited in ascending ``"accuracy"``
    order. For a slice with ``"count"`` strictly greater than ``min_count``, each of
    its attributes whose tag set has strictly more than ``min_tags`` tags yields one
    candidate. The candidate is a copy of the slice with that attribute's tag
    replaced by the closest tag from ``get_closest_tags``. Duplicate candidates
    (compared with ``==``) are dropped.

    :param stat_results: {cls: {slice_key: {"accuracy": ..., "count": ...,
        "name": {"<attr_type>, <attr>": tag}}}}; ``slice_key`` itself is not used
    :param tag_features: {cls: {attr_type: {attr: {tag: feature}}}}
    :param min_tags: an attribute is mutated only if its tag set has more than this many tags
    :param min_count: a slice is mutated only if its ``"count"`` is greater than this
    :param max_slices: stop visiting a class's slices once it has more than this many
        candidates. The check runs after each whole source slice, so the final list
        can exceed ``max_slices + 1``.
    :return: {cls: [ {attr_type: {attr: tag}} ]}
    """
    extended_slices = {}
    for cls, cls_results in stat_results.items():
        candidates = []
        extended_slices[cls] = candidates

        # Lowest accuracy first, so the worst slices are mutated before the cap is hit.
        ranked = sorted(cls_results.values(), key=lambda info: info["accuracy"])

        for slice_info in ranked:
            # "<attr_type>, <attr>" -> (attr_type, attr). partition splits on the
            # first ", " and yields an empty attr when no separator is present.
            parsed = []
            for attr_name, ori_tag in slice_info["name"].items():
                attr_type, _, attr = attr_name.partition(", ")
                parsed.append((attr_type, attr, ori_tag))

            # The slice in nested {attr_type: {attr: tag}} form.
            base_slice = {}
            for attr_type, attr, ori_tag in parsed:
                base_slice.setdefault(attr_type, {})[attr] = ori_tag

            if slice_info["count"] > min_count:
                for attr_type, attr, ori_tag in parsed:
                    tag_set = tag_features[cls][attr_type][attr]
                    if len(tag_set) <= min_tags:
                        continue
                    new_slice = copy.deepcopy(base_slice)
                    new_slice[attr_type][attr] = get_closest_tags(ori_tag, tag_set)[0]
                    if new_slice not in candidates:
                        candidates.append(new_slice)

            if len(candidates) > max_slices:
                break
    return extended_slices