import math
from collections import Counter


def step_action_score2text_r0(step_action_score):
    """Render per-step action scores as text, advertising untried operator types.

    Each action string of a step is classified by the first operator it
    contains, tested in the order ``*``, ``/``, ``+``, ``-``. Then:

    * if all four operator types occur among the step's actions, an action is
      labelled HIGH when its score equals the step's maximum and LOW otherwise;
    * if some operator type is missing, every listed action is labelled LOW
      and, for each operator type that does occur, a line saying operations
      "over other numbers in the Problem 4" have HIGH reward is appended.

    In both cases every operator type that does not occur gets a
    "... based operation has HIGH reward" line.

    Args:
        step_action_score: mapping ``step_name -> {action_string: score}``.
            Scores must be accepted by ``float()`` (numbers or numeric strings).

    Returns:
        The rendered text, each step followed by a blank line, or ``None``
        when ``step_action_score`` is empty.
    """
    if not step_action_score:
        return None

    parts = []
    for stepi, action_scores in step_action_score.items():
        # Record which operator types appear; each action counts only toward
        # the first operator it contains.
        root_multiply = root_divide = root_add = root_subtract = False
        for acti in action_scores:
            if '*' in acti:
                root_multiply = True
            elif '/' in acti:
                root_divide = True
            elif '+' in acti:
                root_add = True
            elif '-' in acti:
                root_subtract = True
        all_ops_tried = root_multiply and root_divide and root_add and root_subtract

        # Compare float() values consistently with the float maximum (raw
        # string scores used to raise TypeError). An empty step has no
        # maximum, so default to 0.0 instead of letting max() raise.
        scores = {acti: float(v) for acti, v in action_scores.items()}
        theta = max(scores.values(), default=0.0)

        parts.append(f'Possible options for {stepi} and expected rewards:\n')
        for acti, score in scores.items():
            # With an operator type missing, the old "score > max" test could
            # never hold, so every tried action is LOW; that is kept here.
            label = "HIGH" if all_ops_tried and score >= theta else "LOW"
            parts.append(f'{acti} has {label} reward\n')

        if not all_ops_tried:
            if root_multiply:
                parts.append('Multiplication(*) based operation over other numbers in the Problem 4 has HIGH reward\n')
            if root_divide:
                parts.append('Division(/) based operation over other numbers in the Problem 4 has HIGH reward\n')
            if root_add:
                parts.append('Addition(+) based operation over other numbers in the Problem 4 has HIGH reward\n')
            if root_subtract:
                parts.append('Subtraction(-) operation over other numbers in the Problem 4 has HIGH reward\n')

        # Operator types never taken would have a high UCB score, so they are
        # advertised as HIGH.
        if not root_multiply:
            parts.append('Multiplication(*) based operation has HIGH reward\n')
        if not root_divide:
            parts.append('Division(/) based operation has HIGH reward\n')
        if not root_add:
            parts.append('Addition(+) based operation has HIGH reward\n')
        if not root_subtract:
            parts.append('Subtraction(-) operation has HIGH reward\n')

        parts.append('\n')
    return ''.join(parts)

def step_action_score2text(step_action_score):
    """Label each action of each step as HIGH or LOW reward.

    An action is HIGH when its score equals the step's maximum score, unless
    that maximum is exactly 0.0, in which case every action of the step is LOW.

    Args:
        step_action_score: mapping ``step_name -> {action_string: score}``.
            Scores must be accepted by ``float()``.

    Returns:
        The rendered text, each step followed by a blank line, or ``None``
        when ``step_action_score`` is empty.
    """
    if not step_action_score:
        return None

    parts = []
    for stepi, action_scores in step_action_score.items():
        parts.append(f'Possible options for {stepi} and expected rewards:\n')
        # Compare float() values so string scores match the float maximum
        # (raw strings used to raise TypeError). An empty step defaults the
        # maximum to 0.0 rather than letting max() raise.
        scores = {acti: float(v) for acti, v in action_scores.items()}
        theta = max(scores.values(), default=0.0)
        for acti, score in scores.items():
            label = "HIGH" if theta != 0.0 and score >= theta else "LOW"
            parts.append(f'{acti} has {label} reward\n')
        parts.append('\n')
    return ''.join(parts)

def get_ucb_score(reward, C, N, n):
    """Return the UCB1-style score ``reward + C * sqrt(ln(N) / n)``.

    Args:
        reward: exploitation term.
        C: weight of the exploration term.
        N: presumably the total/parent visit count (UCB1 naming; confirm
            against callers). ``ln(N)`` requires ``N > 0``.
        n: presumably this action's visit count.

    Returns:
        The score, or ``math.inf`` when ``n == 0`` so that never-tried
        actions rank first instead of raising ZeroDivisionError.

    Raises:
        ValueError: if ``N <= 0`` (math domain error) while ``n != 0``.
    """
    if n == 0:
        return math.inf
    return reward + C * math.sqrt(math.log(N) / n)

def step_action_score2text1(step_action_score):
    """List only the HIGH-reward actions of each step.

    A step is omitted entirely when its maximum score is exactly 0.0 (or it
    has no actions). Otherwise the header is followed by every action whose
    score equals the maximum, then a single blank line.

    Args:
        step_action_score: mapping ``step_name -> {action_string: score}``.
            Scores must be accepted by ``float()``.

    Returns:
        The rendered text (possibly ``""`` if every step is omitted), or
        ``None`` when ``step_action_score`` is empty.
    """
    if not step_action_score:
        return None

    parts = []
    for stepi, action_scores in step_action_score.items():
        # float() comparison keeps string scores consistent with the float
        # maximum; an empty step defaults to 0.0 and is therefore skipped.
        scores = {acti: float(v) for acti, v in action_scores.items()}
        theta = max(scores.values(), default=0.0)
        if theta == 0.0:
            continue
        parts.append(f'Possible options for {stepi} and expected rewards:\n')
        parts.extend(
            f'{acti} has HIGH reward\n'
            for acti, score in scores.items()
            if score >= theta
        )
        # One blank line per step; it was previously emitted once per action.
        parts.append('\n')
    return ''.join(parts)


def step_action_score2text5(step_action_score):
    """Label actions HIGH/LOW, marking HIGH only when all operator types were tried.

    Each action string is classified by the first operator it contains,
    tested in the order ``*``, ``/``, ``+``, ``-``. If all four operator
    types occur among a step's actions, an action is HIGH when its score
    equals the step's maximum; otherwise every action of the step is LOW.

    Args:
        step_action_score: mapping ``step_name -> {action_string: score}``.
            Scores must be accepted by ``float()``.

    Returns:
        The rendered text, each step followed by two newlines, or ``None``
        when ``step_action_score`` is empty.
    """
    if not step_action_score:
        return None

    parts = []
    for stepi, action_scores in step_action_score.items():
        # Operator types that were never tried would have a high UCB score.
        root_multiply = root_divide = root_add = root_subtract = False
        for acti in action_scores:
            if '*' in acti:
                root_multiply = True
            elif '/' in acti:
                root_divide = True
            elif '+' in acti:
                root_add = True
            elif '-' in acti:
                root_subtract = True
        all_ops_tried = root_multiply and root_divide and root_add and root_subtract

        # float() comparison keeps string scores consistent with the float
        # maximum; an empty step defaults to 0.0 rather than max() raising.
        scores = {acti: float(v) for acti, v in action_scores.items()}
        theta = max(scores.values(), default=0.0)

        parts.append(f'Possible options for {stepi} and expected rewards:\n')
        for acti, score in scores.items():
            # When an operator type is missing, the previous "score > max"
            # test could never hold, so every action is LOW; kept explicitly.
            label = "HIGH" if all_ops_tried and score >= theta else "LOW"
            parts.append(f'{acti} has {label} reward\n')
        # NOTE(review): two newlines per step (one per branch plus one after)
        # is preserved from the existing output; sibling formatters use one.
        parts.append('\n\n')
    return ''.join(parts)


def step_action_score2text_ucl(step_action_score):
    """Describe every action as HIGH (score strictly positive) or LOW reward.

    Args:
        step_action_score: mapping ``step_name -> {action_string: score}``.

    Returns:
        One header per step, one line per action and a trailing blank line
        per step; ``None`` when the mapping is empty.
    """
    if not step_action_score:
        return None

    def _render_step(step, actions):
        # Header first, then one row per action, then the step separator.
        header = 'Possible options for ' + step + ' and expected rewards:\n'
        rows = ''.join(
            act + ' has ' + ('HIGH' if actions[act] > 0.0 else 'LOW') + ' reward\n'
            for act in actions
        )
        return header + rows + '\n'

    return ''.join(
        _render_step(step, step_action_score[step]) for step in step_action_score
    )

def get_action_ops2token_ids(action_ops, model_name=None):
    """Map each action operator string to the token ids that encode it.

    Both the bare form ``op`` and the space-prefixed form ``' ' + op`` are
    encoded, presumably because the operator can appear either at the start
    of text or after a space (confirm with callers).

    Args:
        action_ops: iterable of operator strings.
        model_name: model name passed to ``tiktoken.encoding_for_model``.
            When ``None``, falls back to a module-level ``xmodel``, which is
            not defined in this file; it raises NameError unless it is set
            elsewhere.

    Returns:
        dict mapping each op to a sorted list of distinct token ids.
    """
    # Imported lazily so the module itself loads without tiktoken installed.
    import tiktoken

    if model_name is None:
        # Backward-compatible fallback to the global the function always used.
        model_name = xmodel
    enc = tiktoken.encoding_for_model(model_name)

    # sorted() makes the id order deterministic across runs.
    return {
        op: sorted(set(enc.encode(op)) | set(enc.encode(' ' + op)))
        for op in action_ops
    }



def past_actions_review(state_config, state_action_score,
                        state_action, state_action_counter,
                        state_counter, R, C, K, B):
    """Compute a per-token bias from UCB statistics of a state's past actions.

    Every action recorded for the state is split on whitespace into tokens.
    For each token:

    * ``n`` is the number of occurrences of the token across the state's
      actions in ``state_action_counter``, each action weighted by its count;
    * exploration = ``C * sqrt((ln(N) + 5e-5) / n)``, with ``N`` the state's
      entry in ``state_counter``;
    * exploitation = sum of ``state_action_score`` over every occurrence of
      the token in ``state_action`` for the state;
    * bias = ``B * ln((exploitation + exploration) / K)``.

    Args:
        state_config: state identifier; every lookup uses ``str(state_config)``.
        state_action_score: mapping ``state -> {action: score}``.
        state_action: mapping ``state -> iterable of action strings``.
        state_action_counter: mapping ``state -> {action: count}``.
        state_counter: mapping ``state -> N``.
        R: unused; kept for interface compatibility.
        C: exploration weight.
        K: divisor inside the logarithm.
        B: scale applied to the logarithm.

    Returns:
        dict mapping token -> bias, in first-occurrence order of the tokens.

    Raises:
        KeyError: if the state is missing from ``state_counter`` or
            ``state_action``, or an action is missing from its scores.
        ZeroDivisionError: if a token never occurs in the state's
            ``state_action_counter`` entry (``n == 0``) or ``K == 0``.
        ValueError: if ``N <= 0`` or ``(exploitation + exploration) / K <= 0``
            (math domain error).
    """
    key = str(state_config)
    N = state_counter[key]

    # Token occurrence counts, each action weighted by how often it was taken.
    token_counts = Counter()
    for action, times in state_action_counter.get(key, {}).items():
        for tok in action.split():
            token_counts[tok] += times

    token_lvl_exploitation = {}
    token_lvl_exploration = {}
    for act in state_action[key]:
        for tok in act.split():
            n = float(token_counts[tok])
            # The epsilon is added after the log: it keeps exploration nonzero
            # when N == 1 (ln 1 == 0) but does not guard against N <= 0.
            token_lvl_exploration[tok] = C * math.sqrt((math.log(N) + 0.00005) / n)
            token_lvl_exploitation[tok] = (
                token_lvl_exploitation.get(tok, 0.0) + state_action_score[key][act]
            )

    token2bias = {}
    for tok, exploration in token_lvl_exploration.items():
        ucb = token_lvl_exploitation[tok] + exploration
        token2bias[tok] = B * math.log(ucb / K)
    return token2bias
