import math
from collections import Counter

import numpy as np

class obs_decomp_alg:
    """Record transitions and compute reward statistics/entropy over them.

    Rows of ``self.history`` are laid out as ``[obs, obs_p, act, re]``
    (see ``history_record``); ``lambda_list_gen`` filters on that layout and
    ``dimension_entropy`` reads the reward from the last element of each row.
    """

    def __init__(self):
        # Every transition appended by history_record(), in insertion order.
        self.history = []

    # ------------------------------------------------------------------
    # sample entropy start
    # ------------------------------------------------------------------

    def single_dimension_leaves_statistics(self, leave_reward_list):
        '''
        Count how many times each distinct reward value occurs.

        Intended for the reward values of one branch of leaf nodes; the input
        may be the whole reward list or any subset of it.

        :param leave_reward_list: iterable of hashable reward values,
            e.g. [1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3].
        :return: dict mapping each distinct value to its number of
            occurrences, e.g. {1: 5, 2: 4, 3: 5} for the example above.
        '''
        # Counter tallies everything in a single O(n) pass (the previous
        # list.count-per-value loop was O(n * k)). Convert back to a plain
        # dict so lookups of absent keys still raise KeyError as before.
        return dict(Counter(leave_reward_list))

    def dimension_entropy(self, DSO_table):
        '''
        Compute the base-2 Shannon entropy of the rewards in a DSO table.

        The reward of each row is taken from the row's last element.

        :param DSO_table: sequence of indexable rows -- the whole DSO table
            or any subset of it (e.g. the output of lambda_list_gen).
        :return: tuple (dimension_ent, boolean_weight_lambda):
            dimension_ent -- entropy of the reward distribution, in bits;
            boolean_weight_lambda -- number of distinct reward values.
            An empty table yields (0.0, 0).
        '''
        reward_list = [row[-1] for row in DSO_table]
        reward_dict = self.single_dimension_leaves_statistics(reward_list)
        boolean_weight_lambda = len(reward_dict)
        if boolean_weight_lambda <= 1:
            # Zero or one distinct reward: the entropy is exactly 0. Returning
            # early also avoids producing -0.0 for a single-class table.
            return 0.0, boolean_weight_lambda
        total_num = sum(reward_dict.values())
        # fsum gives a correctly rounded total, so the result does not depend
        # on the dict's iteration order.
        dimension_ent = -math.fsum(
            (count / total_num) * math.log2(count / total_num)
            for count in reward_dict.values()
        )
        return dimension_ent, boolean_weight_lambda

    def lambda_list_gen(self, history, obs_select, act_select):
        '''
        Select the history rows that match a given observation and action.

        :param history: history records, rows laid out as
            [obs, obs_p, act, re] (the layout produced by history_record).
        :param obs_select: the selected observation, compared with row[0]
        :param act_select: the selected action, compared with row[2]
        :return: a DSO-table: list of the matching rows, in history order
        '''
        # NOTE(review): assumes `==` on obs/act values yields a plain bool
        # (e.g. not NumPy arrays, which would raise here) -- TODO confirm.
        return [
            row for row in history
            if row[0] == obs_select and row[2] == act_select
        ]

    # ------------------------------------------------------------------
    # sample entropy end
    # ------------------------------------------------------------------

    def history_record(self, obs, obs_p, act, re):
        '''
        Append one transition to self.history.

        :param obs: observation
        :param obs_p: observation prime (next obs)
        :param act: action
        :param re: reward
        :return: self.history -- the live list object itself, not a copy
        '''
        self.history.append([obs, obs_p, act, re])
        return self.history
