import matplotlib.pyplot as plt
import logging
import logging_config

class DiversityMetric:
    """Measure diversity of entries by counting distinct quantized embeddings.

    Each entry's ``"feedback_embedding"`` is truncated to its first ``dim``
    coordinates. Each kept coordinate is scaled by ``granularity`` and rounded
    to an integer, and the resulting tuple identifies a grid cell. Diversity
    is the number of distinct cells occupied by entries whose embedding is
    usable.
    """

    @classmethod
    def check_feedback_embedding(cls, entry):
        """Return True if ``entry`` carries a usable feedback embedding.

        An embedding is usable when it is not None and none of its elements
        are None.

        Args:
            entry: Mapping with a ``"feedback_embedding"`` key (a missing key
                raises KeyError).

        Returns:
            bool: True when the embedding and every element are not None.
        """
        embedding = entry["feedback_embedding"]
        # Use identity checks rather than ``!= None``. For an array-like
        # embedding (e.g. numpy), ``!= None`` compares elementwise, and the
        # resulting array has an ambiguous truth value inside ``and``.
        return embedding is not None and all(x is not None for x in embedding)

    @classmethod
    def get_category(cls, vec, granularity, dim=30):
        """Quantize the first ``dim`` coordinates of ``vec`` into a grid cell.

        Args:
            vec: Sliceable sequence of numbers.
            granularity: Scale applied to each coordinate before rounding;
                larger values give finer cells (cell width 1/granularity).
            dim: Number of leading coordinates to keep.

        Returns:
            tuple: ``round(x * granularity)`` for each kept coordinate
            (Python's built-in ``round``, which rounds ties to even).
        """
        return tuple(round(x * granularity) for x in vec[:dim])

    @classmethod
    def get_diversity(cls, entries, granularity, dim=30):
        """Count the distinct grid cells occupied by ``entries``.

        Entries rejected by ``check_feedback_embedding`` are skipped.

        Args:
            entries: Iterable of mappings with a ``"feedback_embedding"`` key.
            granularity: Quantization scale forwarded to ``get_category``.
            dim: Number of leading coordinates forwarded to ``get_category``.

        Returns:
            int: Number of unique categories; 0 if no entry is usable.
        """
        return len({
            cls.get_category(e["feedback_embedding"], granularity, dim)
            for e in entries
            if cls.check_feedback_embedding(e)
        })


    # @classmethod
    # def plot_diversity_time(cls, entries, color="tab:blue", label=None, need_var: bool=True, linestyle="-"):
    #     logging.info("Plotting diversity time...")
    #     x_list = [] # virtual time
    #     y_list = [] # diversity
    #     var_bottom_list = [] # min_explored_grid_num [10, 10000]
    #     var_upper_list = [] # max_explored_grid_num
        
    #     var_granu_list = [pow(10, 1.5 + i/10) for i in range(0, 11)]

    #     main_granu = pow(10, 2)

    #     base_entry_id = entries[0]["id"]

    #     cur_entries = []
    #     for entry in entries:
    #         entry_id = entry["id"] - base_entry_id
    #         while (
    #             len(x_list) < entry_id
    #             or (len(x_list) == entry_id and cls.check_feedback_embedding(entry) == False)
    #         ):
    #             x = len(x_list)
    #             y = y_list[-1] if len(y_list) > 0 else 0
    #             bottom = var_bottom_list[-1] if (len(var_bottom_list) > 0 and need_var == True) else 0
    #             upper = var_upper_list[-1] if (len(var_upper_list) > 0 and need_var == True) else 0
    #             x_list.append(x)
    #             y_list.append(y)
    #             var_bottom_list.append(bottom)
    #             var_upper_list.append(upper)
            
    #         if len(x_list) > entry_id:
    #             assert len(x_list) == entry_id + 1
    #             continue

    #         assert len(x_list) == entry_id
    #         x_list.append(entry_id)
    #         assert cls.check_feedback_embedding(entry) == True
    #         cur_entries.append(entry)
    #         main_diversity = cls.get_diversity(cur_entries, main_granu)
    #         if need_var == True:
    #             var_diversity_list = [
    #                 cls.get_diversity(cur_entries, granu)
    #                 for granu in var_granu_list
    #             ]
    #             min_diversity = min(var_diversity_list)
    #             max_diverstiy = max(var_diversity_list)
    #         else:
    #             min_diversity = 0
    #             max_diverstiy = 0

    #         y_list.append(main_diversity)
    #         var_bottom_list.append(min_diversity)
    #         var_upper_list.append(max_diverstiy)


    #     assert len(y_list) == len(var_bottom_list) == len(var_upper_list) == len(x_list) == entries[-1]["id"] - base_entry_id + 1

    #     plt.plot(x_list, y_list, color=color, label=label, linestyle=linestyle, linewidth=4)
    #     if need_var == True:
    #         plt.fill_between(x_list, var_bottom_list, var_upper_list, color=color, linewidth=0, edgecolor=None, alpha=0.2)