import itertools
import os
import numpy as np
import tqdm
from collections import Counter
import concurrent.futures
import copy


# Names of the attribute categories. Not referenced by any function visible in
# this part of the module; presumably these are the top-level keys of each
# class's tag dictionary -- TODO confirm against the data loader.
ATTRIBUTE_TYPES = ["main object", "background", "global"]

def preprocess_data(tags, labels, cls):
    """Flatten the two-level attribute structure of one class.

    ``tags[cls]`` maps ``attr_type -> attr -> tag values`` and ``labels[cls]``
    maps ``path -> attr_type -> attr -> value``. Both are flattened so that
    every attribute is addressed by the single key ``"<attr_type>, <attr>"``.

    Args:
        tags: Nested tag dictionary, indexed by class first.
        labels: Nested per-sample label dictionary, indexed by class first.
        cls: The class whose data should be flattened.

    Returns:
        A tuple ``(refined_attributes, refined_tags, refined_labels)``:
        the list of flattened attribute names, a dict from flattened name to
        its entry in ``tags``, and a dict from sample path to its flattened
        label. Samples whose label lacks any attribute listed in
        ``tags[cls]`` are left out of ``refined_labels``.
    """
    refined_attributes = []
    refined_tags = {}
    # (flattened name, attr_type, attr) triples, reused for every sample below.
    flat_keys = []
    for attr_type, attrs in tags[cls].items():
        for attr in attrs:
            flat_name = attr_type + ", " + attr
            refined_attributes.append(flat_name)
            refined_tags[flat_name] = attrs[attr]
            flat_keys.append((flat_name, attr_type, attr))

    refined_labels = {}
    for path, data_label in labels[cls].items():
        try:
            new_data_label = {
                flat_name: data_label[attr_type][attr]
                for flat_name, attr_type, attr in flat_keys
            }
        except (LookupError, TypeError):
            # Incomplete label (missing attribute type / attribute, or a
            # non-mapping where a mapping was expected): skip this sample.
            continue
        refined_labels[path] = new_data_label
    return refined_attributes, refined_tags, refined_labels

def fast_enumeration(tags, labels, maximum_depth=1, valueable_size=10):
    """Run the pruned, mask-based slice enumeration for every class in ``tags``.

    For each class the data is flattened with ``preprocess_data`` and handed
    to ``static_analysis_cls_tree_mat_prune``.

    Args:
        tags: Nested tag dictionary, indexed by class first.
        labels: Nested label dictionary, indexed by class first.
        maximum_depth: Forwarded to the per-class analysis.
        valueable_size: Forwarded to the per-class analysis.

    Returns:
        Dict mapping each class to the result of the per-class analysis.
    """
    per_class = {}
    for class_name in tags:
        print("start", class_name)
        flattened = preprocess_data(tags, labels, class_name)
        per_class[class_name] = static_analysis_cls_tree_mat_prune(
            *flattened, maximum_depth, valueable_size
        )
    return per_class


def baseline_enumeration(tags, labels, maximum_depth=1, valueable_size=10):
    """Run the brute-force slice enumeration for every class in ``tags``.

    For each class the data is flattened with ``preprocess_data`` and handed
    to ``static_analysis_cls_naive``.

    Args:
        tags: Nested tag dictionary, indexed by class first.
        labels: Nested label dictionary, indexed by class first.
        maximum_depth: Forwarded to the per-class analysis.
        valueable_size: Forwarded to the per-class analysis.

    Returns:
        Dict mapping each class to whatever ``static_analysis_cls_naive``
        returns for it (a ``(slices, count)`` tuple, unlike the plain slice
        dict stored by ``fast_enumeration``).
    """
    per_class = {}
    for class_name in tags:
        print("start", class_name)
        flattened = preprocess_data(tags, labels, class_name)
        per_class[class_name] = static_analysis_cls_naive(
            *flattened, maximum_depth, valueable_size
        )
    return per_class

def static_analysis_cls_naive(attributes, tags, labels, maximum_depth=1, valueable_size=10):
    """Count the samples of every attribute/tag combination by brute force.

    Every combination of 1..``maximum_depth`` attributes is expanded into all
    of its tag assignments, and each assignment is matched against every
    sample in ``labels``.

    Args:
        attributes: Ordered list of (flattened) attribute names.
        tags: Dict mapping each attribute to its iterable of possible tags.
        labels: Dict mapping sample path to a dict ``attribute -> tag``.
        maximum_depth: Largest number of attributes combined in one slice.
        valueable_size: A slice counts as "valuable" when it holds more than
            this many samples.

    Returns:
        A tuple ``(slices, valueable_data_count)``. ``slices`` maps a key
        tuple (one entry per attribute, ``""`` for attributes that are not
        part of the slice) to ``{"name": {attribute: tag}, "count": int}``
        for every slice with at least one sample. ``valueable_data_count`` is
        the number of valuable slices in the *last* processed layer (0 when
        ``maximum_depth < 1``).
    """
    # Timing noted by the original author: 10000 data, 4 combination, 167708s
    slices = {}
    # Defined up front so the return below works even if no layer is processed.
    valueable_data_count = 0
    records = list(labels.values())
    for combination_size in range(1, maximum_depth + 1):
        valueable_data_count = 0
        print("combination: ", combination_size)
        # Materialised so tqdm can show a total.
        attribute_combinations = list(itertools.combinations(attributes, combination_size))
        for attribute_combination in tqdm.tqdm(attribute_combinations):
            tags_list = [tags[attribute] for attribute in attribute_combination]
            for tag_combination in itertools.product(*tags_list):
                pairs = list(zip(attribute_combination, tag_combination))
                count = sum(
                    1
                    for data_label in records
                    if all(data_label[attribute] == tag for attribute, tag in pairs)
                )
                if count > 0:
                    # A real attribute -> tag mapping, matching what
                    # static_analysis_cls_tree_mat_prune stores under "name".
                    name = dict(pairs)
                    key_tuple = tuple(name.get(attribute, "") for attribute in attributes)
                    slices[key_tuple] = {"name": name, "count": count}
                if count > valueable_size:
                    valueable_data_count += 1
        print("layer: ", combination_size, "valuable slices:", valueable_data_count)
    return slices, valueable_data_count

def static_analysis_cls_tree_mat_prune(attributes, tags, labels, maximum_depth=1, valueable_size=10):
    """Enumerate slices layer by layer, reusing parent masks and pruning small ones.

    Layer 0 (single attributes) builds a 0/1 membership vector over the
    samples by scanning ``labels``. Each deeper layer obtains its vector as
    the element-wise product of two stored parent vectors: the combination
    without its last attribute and the combination without its first
    attribute. A combination is skipped when either parent was not stored.
    Only slices holding more than ``valueable_size`` samples are stored.

    Args:
        attributes: Ordered list of (flattened) attribute names.
        tags: Dict mapping each attribute to its iterable of possible tags.
        labels: Dict mapping sample path to a dict ``attribute -> tag``.
        maximum_depth: Number of layers to process.
        valueable_size: Minimum (exclusive) sample count for a stored slice.

    Returns:
        Dict mapping a key tuple (one entry per attribute, ``""`` for
        attributes outside the slice) to
        ``{"data": membership vector, "name": {attribute: tag}, "count": size}``.
    """
    # Timings noted by the original author:
    # 10000 data, 3 combination, 17s
    # 4 combination, 329s

    def positional_key(chosen_attributes, chosen_tags):
        # One entry per attribute, in `attributes` order; "" marks "not part
        # of this slice".
        entries = []
        cursor = 0
        for attribute in attributes:
            if attribute in chosen_attributes:
                entries.append(chosen_tags[cursor])
                cursor += 1
            else:
                entries.append("")
        return tuple(entries)

    slices = {}
    blank_mask = np.zeros(len(labels))

    for layer in range(maximum_depth):
        kept_in_layer = 0
        candidates = list(itertools.combinations(attributes, layer + 1))
        for attribute_combination in tqdm.tqdm(candidates):
            value_lists = [tags[attribute] for attribute in attribute_combination]
            for tag_combination in itertools.product(*value_lists):
                if layer == 0:
                    # Base layer: mark every sample whose label matches.
                    mask = blank_mask.copy()
                    for row, data_label in enumerate(labels.values()):
                        mismatch = any(
                            data_label[attribute] != tag
                            for attribute, tag in zip(attribute_combination, tag_combination)
                        )
                        if not mismatch:
                            mask[row] = 1
                else:
                    left_key = positional_key(attribute_combination[:-1], tag_combination[:-1])
                    right_key = positional_key(attribute_combination[1:], tag_combination[1:])
                    if not (left_key in slices and right_key in slices):
                        # A parent was never stored, so this slice is skipped.
                        continue
                    mask = slices[left_key]["data"] * slices[right_key]["data"]

                slice_key = positional_key(attribute_combination, tag_combination)
                size = np.sum(mask)
                if size > valueable_size:
                    slices[slice_key] = {
                        "data": mask,
                        "name": dict(zip(attribute_combination, tag_combination)),
                        "count": size,
                    }
                    kept_in_layer += 1
        print("layer: ", layer, "valuable slices:", kept_in_layer)
    return slices

def post_process(tags, labels, targets_dict, stat_results, acc_thresholds, count_thresholds):
    """Select low-accuracy slices per class and attach their samples.

    For every class in ``stat_results`` the per-sample correctness values are
    aligned with the slice membership vectors, each sufficiently large slice
    gets an accuracy, slices whose accuracy minus the class mean is below the
    accuracy threshold are kept, the kept set is reduced with
    ``shrink_combinations``, and the member samples are listed under
    ``"visuals"``.

    Args:
        tags: Nested tag dictionary, passed to ``preprocess_data``.
        labels: Nested label dictionary, passed to ``preprocess_data``.
        targets_dict: ``cls -> path -> correctness value``. Samples without
            an entry are ignored when computing accuracies and visuals.
        stat_results: ``cls -> key -> slice dict`` with ``"data"``,
            ``"name"`` and ``"count"`` entries.
        acc_thresholds: A single threshold, or a dict of thresholds per class.
        count_thresholds: A single threshold, or a dict of thresholds per
            class; only slices with ``count`` above it are considered.

    Returns:
        ``cls -> key -> slice dict`` for the selected slices, each extended
        with ``"accuracy"`` and ``"visuals"`` (a list of ``[path, target]``).

    Note:
        The slice dicts inside ``stat_results`` are modified in place
        (``"accuracy"`` and ``"visuals"`` are added to every slice that
        passes the count threshold and has at least one sample with a target).
    """
    filter_stat_results = {}
    for cls in stat_results:
        _, tags_cls, labels_cls = preprocess_data(tags, labels, cls)
        targets_cls = targets_dict[cls]
        datas = list(labels_cls)
        # Debug output of the first key of each mapping (None when empty).
        print(next(iter(targets_cls), None), next(iter(labels_cls), None))

        # Both vectors have one entry per labelled sample so they stay aligned
        # with slices["data"]; samples without a target contribute nothing.
        has_target = np.array([path in targets_cls for path in datas], dtype=float)
        correctness_list = np.array(
            [targets_cls[path] if path in targets_cls else 0 for path in datas],
            dtype=float,
        )
        evaluated_total = has_target.sum()
        mean_value_cls = (
            correctness_list.sum() / evaluated_total if evaluated_total else float("nan")
        )
        print(mean_value_cls)

        acc_threshold = acc_thresholds[cls] if isinstance(acc_thresholds, dict) else acc_thresholds
        count_threshold = count_thresholds[cls] if isinstance(count_thresholds, dict) else count_thresholds

        selected = {}
        for key, slices in stat_results[cls].items():
            if slices["count"] > count_threshold:
                evaluated = slices["data"] @ has_target
                if not evaluated:
                    # No member of this slice has a target: accuracy undefined.
                    continue
                accuracy = slices["data"] @ correctness_list / evaluated
                slices["accuracy"] = accuracy
                slices["visuals"] = []
                if accuracy - mean_value_cls < acc_threshold:
                    selected[key] = slices

        filter_stat_results[cls] = shrink_combinations(tags_cls, selected)
        print(cls, " match slices:", len(filter_stat_results[cls]))
        for slices in filter_stat_results[cls].values():
            for path, member in zip(datas, slices["data"]):
                if member and path in targets_cls:
                    slices["visuals"].append([path, targets_cls[path]])

    return filter_stat_results

def combination_to_key(tags_cls, combination):
    """Build the positional key tuple for an attribute -> tag mapping.

    Args:
        tags_cls: Iterable of attribute names (iterated in its own order).
        combination: Mapping from attribute name to the chosen tag.

    Returns:
        A tuple with one entry per attribute in ``tags_cls``: the tag from
        ``combination`` if present, otherwise ``""``.
    """
    return tuple(combination.get(attribute, "") for attribute in tags_cls)

def shrink_combinations(tags_cls, stat_results):
    """Remove slices that are more accurate than one of their parent slices.

    A parent of a slice is any non-empty, strictly smaller subset of its
    attribute -> tag assignments. Parents are looked up in ``stat_results``
    via ``combination_to_key``; parents that are not present are ignored.

    Args:
        tags_cls: Attribute collection used to build positional keys.
        stat_results: Dict mapping key tuple to a slice dict with ``"name"``
            and ``"accuracy"`` entries.

    Returns:
        A shallow copy of ``stat_results`` without the slices whose accuracy
        is strictly greater than the accuracy of some present parent.
    """
    def iter_parents(named_tags):
        # Sub-combinations from largest to smallest, never empty and never
        # the full combination itself.
        names = list(named_tags.keys())
        for size in range(len(names) - 1, 0, -1):
            for subset in itertools.combinations(names, size):
                yield {name: named_tags[name] for name in subset}

    survivors = stat_results.copy()
    for key, entry in tqdm.tqdm(stat_results.items()):
        named_tags = entry["name"]
        own_accuracy = entry["accuracy"]
        for parent_name in iter_parents(named_tags):
            parent_key = combination_to_key(tags_cls, parent_name)
            parent_entry = stat_results.get(parent_key, "")
            if parent_entry and own_accuracy > parent_entry["accuracy"]:
                # Beats a parent, so this slice is dropped.
                survivors.pop(key)
                break

    return survivors