import copy
import random
from tqdm import tqdm
import math


# Every colour a block may have; blocks are referred to by colour alone
# (e.g. "the red block").
clrs = [
    'red', 'blue', 'orange', 'yellow', 'white', 'magenta',
    'black', 'cyan', 'green', 'violet', 'gold', 'silver',
]


# Delimiter lines of the task1_reasoning.txt log format parsed below.
_RECORD_START_MARKERS = (
    '===================================SUCCESS===================================',
    '===================================FAILURE===================================',
)
_GROUND_TRUTH_MARKER = '-------- Ground truth plan ---------'
_RECORD_END_MARKER = '============================================================================='
_DEFAULT_BLOCKSWORLD_DATA_PATH = "/export/home/data/lam_mcts/task1_reasoning.txt"


def _find_line(lines, marker, start):
    """Return the index of the first line equal to ``marker`` at or after ``start``.

    Raises:
        ValueError: if ``marker`` never occurs, i.e. the record is truncated.
    """
    try:
        return lines.index(marker, start)
    except ValueError:
        raise ValueError(
            f"malformed blocksworld data: expected {marker!r} at or after line {start + 1}"
        ) from None


def _parse_blocksworld_record(lines, start):
    """Split the record whose SUCCESS/FAILURE banner is ``lines[start]`` into sections.

    Returns ``(record, end)``, where ``record`` maps section names to lists of
    lines and ``end`` is the index of the record's closing separator line.
    """
    stmt = _find_line(lines, '[STATEMENT]', start + 1)
    plan = _find_line(lines, '[PLAN]', stmt + 1)
    plan_end = _find_line(lines, '[PLAN END]', plan + 1)
    # A second [STATEMENT] ... [PLAN] pair follows: the problem actually posed.
    real_stmt = _find_line(lines, '[STATEMENT]', plan_end)
    real_plan = _find_line(lines, '[PLAN]', real_stmt + 1)
    truth = _find_line(lines, _GROUND_TRUTH_MARKER, real_plan)
    end = _find_line(lines, _RECORD_END_MARKER, truth + 1)
    record = {
        'instruction': lines[start + 1:stmt],
        'problem': lines[stmt + 1:plan],
        'solution': lines[plan + 1:plan_end],
        'real_problem': lines[real_stmt + 1:real_plan],
        'real_solution': lines[truth + 1:end],
    }
    return record, end


def get_blocksworld_data(NO_OF_STEPS_IN_ANSWER, data_path=_DEFAULT_BLOCKSWORLD_DATA_PATH):
    """Load blocksworld records whose ground-truth plan has a given number of lines.

    Each record in the file opens with a SUCCESS/FAILURE banner and contains:
    instruction lines, a first [STATEMENT]/[PLAN]/[PLAN END] section (apparently
    an in-context example — confirm), a second [STATEMENT]...[PLAN] section (the
    posed problem), and a "Ground truth plan" section closed by a row of '='.

    Args:
        NO_OF_STEPS_IN_ANSWER: keep only records whose ground-truth section has
            exactly this many lines plus one (its last line is not counted).
        data_path: file to read; defaults to the original hard-coded location.

    Returns:
        list of dicts with keys 'instruction', 'problem', 'solution',
        'real_problem', 'real_solution' (newline-joined strings), 'goal' /
        'goal_pre' (the goal sentence of the posed / first statement), and
        'participating_blocks' (colours named as whole words in the posed
        statement, in colour-list order).

    Raises:
        ValueError: if a record is truncated (a required marker is missing).
        IndexError: if a kept statement is too short to contain a goal line.
    """
    import re

    clrs = ['red', 'blue', 'orange', 'yellow', 'white', 'magenta', 'black', 'cyan', 'green', 'violet', 'gold', 'silver']

    with open(data_path, encoding='utf-8') as f:
        # rstrip rather than [:-1]: the final line may have no trailing newline.
        lines = [line.rstrip('\n') for line in f]

    records = []
    i = 0
    while i < len(lines):
        if lines[i] in _RECORD_START_MARKERS:
            record, end = _parse_blocksworld_record(lines, i)
            records.append(record)
            i = end + 1
        else:
            i += 1

    result = []
    for rec in records:
        # The last ground-truth line is not counted as a step (presumably a
        # trailing cost/summary line — confirm against the data).
        if len(rec['real_solution']) - 1 != NO_OF_STEPS_IN_ANSWER:
            continue
        # Drop the two trailing lines ("My plan is as follows:" and what follows);
        # the goal sentence is then second-to-last and is peeled off too.
        problem = rec['problem'][:-2]
        real_problem = rec['real_problem'][:-2]
        real_problem_text = "\n".join(real_problem[:-2])
        # Whole-word match so a colour is not found inside an unrelated word.
        words = set(re.findall(r"[a-z]+", real_problem_text.lower()))
        result.append({
            'instruction': "\n".join(rec['instruction']),
            'problem': "\n".join(problem[:-2]),
            'solution': "\n".join(rec['solution']),
            'real_problem': real_problem_text,
            'real_solution': "\n".join(rec['real_solution']),
            'goal': real_problem[-2],
            'goal_pre': problem[-2],
            'participating_blocks': [c for c in clrs if c in words],
        })
    return result


def state_text2json(text, participating_blocks):
    """Parse a natural-language blocksworld state into the dict representation.

    Args:
        text: a string of clauses separated by ',' or ' and ', or an already
            split iterable of clause strings.
        participating_blocks: colour names of the blocks to include; each gets
            ``{'is_on_top_of': None, 'is_the_bottom_of': None}`` initially.

    Returns:
        dict with 'hand_empty' (True, the colour of the held block, or None if
        the hand is never mentioned) plus one entry per block, where
        'is_on_top_of' is 'table', another colour or None and
        'is_the_bottom_of' is 'clear', the colour resting on it, or None.

    Raises:
        ValueError: if a clause matches none of the recognised patterns
            (previously this printed the clause and called ``sys.exit()``).
        KeyError: if a clause names a block not in ``participating_blocks``.
    """
    if isinstance(text, str):
        text = text.replace(' and ', ',').split(',')

    state = {'hand_empty': None}
    for blk in participating_blocks:
        state[blk] = {'is_on_top_of': None, 'is_the_bottom_of': None}

    for entry in text:
        # Skip empty/short fragments and the "As initial conditions ..." preamble.
        if len(entry) < 5 or 'condition' in entry or 'initial' in entry:
            continue
        # Strip sentence punctuation from each word so the final clause
        # ("... the blue block.") still yields a bare 'block' token.
        words = [w.strip('.') for w in entry.split()]
        if 'hand' in entry and 'empty' in entry:
            state['hand_empty'] = True
        elif 'hand' in entry:
            # "the <colour> block is in the hand"
            state['hand_empty'] = words[words.index('block') - 1]
        elif 'clear' in entry:
            state[words[words.index('block') - 1]]['is_the_bottom_of'] = 'clear'
        elif 'table' in entry:
            state[words[words.index('block') - 1]]['is_on_top_of'] = 'table'
        elif 'top' in entry:
            # "the <top> block is on top of the <bottom> block"
            positions = [i for i, w in enumerate(words) if w == 'block']
            top_block = words[positions[0] - 1]
            bottom_block = words[positions[1] - 1]
            state[top_block]['is_on_top_of'] = bottom_block
            state[bottom_block]['is_the_bottom_of'] = top_block
        else:
            raise ValueError(f"unrecognised blocksworld state clause: {entry!r}")
    return state



def state_json2text(state):
    """Render a dict state (keys as produced by ``state_text2json``) as one sentence.

    Clauses are joined with ", " and the result ends with '.'. A relation
    recorded on both blocks of a pair (one block's 'is_on_top_of' and the
    other's 'is_the_bottom_of') is emitted only once. Values other than True,
    'table', 'clear' or a name in the module-level ``clrs`` list (e.g. None)
    produce no clause.
    """
    res = []
    hand = state['hand_empty']
    if hand is True:
        res.append('the hand is empty')
    elif hand in clrs:
        res.append(f"the {hand} block is in the hand")

    for col_blk, relations in state.items():
        if col_blk == 'hand_empty':
            continue
        below = relations['is_on_top_of']
        if below == 'table':
            res.append(f"the {col_blk} block is on the table")
        elif below in clrs:
            res.append(f"the {col_blk} block is on top of the {below} block")
        above = relations['is_the_bottom_of']
        if above == 'clear':
            res.append(f"the {col_blk} block is clear")
        elif above in clrs:
            res.append(f"the {above} block is on top of the {col_blk} block")

    # dict.fromkeys de-duplicates while keeping first-seen order; the former
    # list(set(...)) made clause order depend on per-process hash randomisation.
    return ", ".join(dict.fromkeys(res)) + '.'


def add_action_to_json_state(state_prime, action_prime):
    """Apply one blocksworld action, given as text, to a dict state.

    The verb is found by substring: 'pick' / 'put' move a block from / to the
    table, 'unstack' / 'stack' move it from / onto another block; an
    'unstack' / 'stack' that mentions 'table' is treated as pick / put. Block
    names are the words immediately before each 'block' token.

    Preconditions:
        pick:    hand empty; block on the table and clear.
        put:     hand not empty; block has no recorded support and nothing
                 recorded on it (the encoding of the held block).
        unstack: hand empty; top block on the bottom block and clear.
        stack:   hand holds the top block; bottom block clear.

    ``state_prime`` is never mutated.

    Returns:
        (new_state, valid). An action whose preconditions fail, or whose text
        cannot be parsed / names an unknown block, yields an unchanged copy of
        ``state_prime`` and False.
    """
    state = copy.deepcopy(state_prime)
    try:
        action = action_prime
        # Strip sentence punctuation so "... the blue block." still yields a
        # bare 'block' token.
        words = [w.strip('.,') for w in action.split()]

        if 'unstack' in action and 'table' in action:
            action = action.replace('unstack', 'pick')
        elif 'stack' in action and 'table' in action:
            action = action.replace('stack', 'put')

        if 'pick' in action:
            blk = words[words.index('block') - 1]
            if (state['hand_empty'] is not True
                    or state[blk]['is_on_top_of'] != 'table'
                    or state[blk]['is_the_bottom_of'] != 'clear'):
                return state, False
            state['hand_empty'] = blk
            state[blk]['is_the_bottom_of'] = None
            state[blk]['is_on_top_of'] = None

        elif 'put' in action:
            blk = words[words.index('block') - 1]
            if (state['hand_empty'] is True
                    or state[blk]['is_on_top_of'] is not None
                    or state[blk]['is_the_bottom_of'] is not None):
                return state, False
            state['hand_empty'] = True
            state[blk]['is_the_bottom_of'] = 'clear'
            state[blk]['is_on_top_of'] = 'table'

        elif 'unstack' in action:
            positions = [i for i, w in enumerate(words) if w == 'block']
            top, bottom = words[positions[0] - 1], words[positions[1] - 1]
            # The hand must be free and nothing may rest on the lifted block;
            # otherwise the held block / the block above would be lost.
            if (state['hand_empty'] is not True
                    or state[top]['is_on_top_of'] != bottom
                    or state[top]['is_the_bottom_of'] != 'clear'
                    or state[bottom]['is_the_bottom_of'] != top):
                return state, False
            state['hand_empty'] = top
            state[top]['is_the_bottom_of'] = None
            state[top]['is_on_top_of'] = None
            state[bottom]['is_the_bottom_of'] = 'clear'

        elif 'stack' in action:
            positions = [i for i, w in enumerate(words) if w == 'block']
            top, bottom = words[positions[0] - 1], words[positions[1] - 1]
            if state['hand_empty'] != top or state[bottom]['is_the_bottom_of'] != 'clear':
                return state, False
            state['hand_empty'] = True
            state[top]['is_on_top_of'] = bottom
            state[top]['is_the_bottom_of'] = 'clear'
            state[bottom]['is_the_bottom_of'] = top
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        # Unparseable text (missing 'block' tokens) or an unknown block name.
        return copy.deepcopy(state_prime), False

    # NOTE(review): text matching none of the four verbs is returned unchanged
    # and reported valid, as before — confirm callers rely on this.
    return state, True
        

def add_action_to_text_state(action, state):
    """Placeholder for applying an action to a text-form state.

    Not implemented: nothing is applied and the result is always None.
    """
    return None


def add_action_to_init_block_state(state_prime, action_prime):
    """Apply an "<X> block ... on top of ... <Y> block" fact to a dict state.

    Text that does not contain 'top' is ignored. Otherwise X is moved onto Y:
    X's former support (unless it was the table) becomes clear, X's
    'is_on_top_of' becomes Y and Y's 'is_the_bottom_of' becomes X. The hand is
    not modelled here. ``state_prime`` is never mutated.

    Returns:
        (new_state, ok):
        - no 'top' in the text, or the fact already holds: (copy, True);
        - Y is not clear: (copy, False);
        - fewer than two 'block' tokens, an unknown block, or X without a
          recorded support (None): (copy, False).
    """
    state = copy.deepcopy(state_prime)
    action = action_prime

    if 'top' not in action:
        return state, True

    try:
        # Strip sentence punctuation so "... the blue block." still yields a
        # bare 'block' token.
        words = [w.strip('.') for w in action.split()]
        positions = [i for i, w in enumerate(words) if w == 'block']
        top_blk = words[positions[0] - 1]
        bottom_blk = words[positions[1] - 1]

        # Already satisfied: nothing to do.
        if (state[bottom_blk]['is_the_bottom_of'] == top_blk
                and state[top_blk]['is_on_top_of'] == bottom_blk):
            return state, True

        if state[bottom_blk]['is_the_bottom_of'] != 'clear':
            return state, False

        old_support = state[top_blk]['is_on_top_of']
        if old_support != 'table':
            # Raises KeyError when old_support is None (position unknown).
            state[old_support]['is_the_bottom_of'] = 'clear'
        state[top_blk]['is_on_top_of'] = bottom_blk
        state[bottom_blk]['is_the_bottom_of'] = top_blk
    except (IndexError, KeyError, TypeError):
        return state, False

    return state, True
    
    

def action_tuple2text(tuple_action):
    """Render an action tuple such as ('stack', 'red', 'blue') as text.

    The first element selects the verb ('pick', 'put', 'unstack' or
    'stack'); the remaining elements are block colours. Any other verb gives
    None.
    """
    verb = tuple_action[0]
    if verb == 'pick':
        template = "pick up the {} block"
    elif verb == 'put':
        template = "put down the {} block"
    elif verb == 'unstack':
        template = "unstack {} block from on top of the {} block"
    elif verb == 'stack':
        template = "stack the {} block on top of the {} block"
    else:
        return None
    return template.format(*tuple_action[1:])
        
    

def action_text2tuple(action):
    """Parse an action sentence into a tuple.

    Returns ('pick', blk), ('put', blk), ('unstack', top, bottom) or
    ('stack', top, bottom), with verbs matched by substring in that order of
    precedence; block names are the words right before each 'block' token.
    Text containing none of the verbs gives None.

    Raises:
        ValueError / IndexError: if the expected 'block' tokens are missing.
    """
    # Strip sentence punctuation so "... the blue block." still yields a bare
    # 'block' token.
    words = [w.strip('.') for w in action.split()]
    if 'pick' in action:
        return ('pick', words[words.index('block') - 1])
    if 'put' in action:
        return ('put', words[words.index('block') - 1])
    if 'stack' in action:  # also covers 'unstack'
        positions = [i for i, w in enumerate(words) if w == 'block']
        verb = 'unstack' if 'unstack' in action else 'stack'
        return (verb, words[positions[0] - 1], words[positions[1] - 1])
    return None



def real_solution2text(rs):
    """Translate a newline-separated PDDL-style plan into action sentences.

    Parentheses are stripped; lines shorter than 5 characters, and lines with
    no recognised verb (e.g. a trailing cost comment), are skipped. Verbs are
    matched by substring with 'unstack' taking precedence over 'stack', then
    'pick', then 'put'.

    Raises:
        ValueError: if a matched line has the wrong number of tokens.
    """
    cleaned = rs.replace('(', '').replace(')', '')
    actions = []
    for line in cleaned.split('\n'):
        if len(line) < 5:
            continue
        if 'stack' in line:
            _, upper, lower = line.split()
            if 'unstack' in line:
                actions.append(f"unstack {upper} block from on top of {lower} block")
            else:
                actions.append(f"stack {upper} block on top of {lower} block")
        elif 'pick' in line:
            _, blk = line.split()
            actions.append(f"pick {blk} block from on top of the table")
        elif 'put' in line:
            _, blk = line.split()
            actions.append(f"put {blk} block on top of the table")
    return actions


def step_action_score2text(step_action_score):
    """Describe per-step action scores as HIGH/LOW reward lines for a prompt.

    For each step, actions whose score equals the step's maximum are HIGH and
    the rest LOW; if that maximum is exactly 0.0 every action is LOW. Each
    step's section ends with a blank line. Returns None for an empty mapping.
    """
    if not step_action_score:
        return None

    chunks = []
    for step, action_scores in step_action_score.items():
        chunks.append('Possible options for ' + step + ' and expected rewards:\n')
        best = max(float(v) for v in action_scores.values())
        for action, score in action_scores.items():
            label = "HIGH" if best != 0.0 and score >= best else "LOW"
            chunks.append(action + ' has ' + label + ' reward\n')
        chunks.append('\n')
    return "".join(chunks)

def get_ucb_score(reward, C, N, n):
    """Return the UCB1 score ``reward + C * sqrt(ln(N) / n)``.

    Args:
        reward: exploitation (value) term.
        C: exploration constant.
        N: visit count of the parent.
        n: visit count of the child.

    An unvisited child (``n == 0``) scores ``math.inf`` so it is always tried
    first, the usual UCB1 convention; previously this raised ZeroDivisionError.
    """
    if n == 0:
        return math.inf
    return reward + C * math.sqrt(math.log(N) / n)


def step_action_score2text_llm_following(step_action_ucb):
    """Describe per-step UCB scores as HIGH/LOW lines, nudging unexplored verbs.

    An action belongs to the first family (unstack, stack, pick, put) whose
    name is a substring of it. If all four families appear for a step,
    actions scoring the step's maximum are HIGH. Otherwise every listed
    action is compared with a strict '>' against the maximum (so none is
    HIGH) and one "<family> based action has HIGH reward" line is added per
    missing family. Each step's section ends with two newlines. Returns None
    for an empty mapping.
    """
    if not step_action_ucb:
        return None

    pieces = []
    for step, action_ucb in step_action_ucb.items():
        seen = {'stack': False, 'unstack': False, 'pick': False, 'put': False}
        for action in action_ucb:
            # 'unstack' is tested before 'stack' because it contains it.
            for family in ('unstack', 'stack', 'pick', 'put'):
                if family in action:
                    seen[family] = True
                    break
        every_family_seen = all(seen.values())

        pieces.append('Possible options for ' + step + ' and expected rewards:\n')
        best = max(float(v) for v in action_ucb.values())
        for action, ucb in action_ucb.items():
            is_high = ucb >= best if every_family_seen else ucb > best
            pieces.append(action + ' has ' + ('HIGH' if is_high else 'LOW') + ' reward\n')
        for family in ('stack', 'unstack', 'pick', 'put'):
            if not seen[family]:
                pieces.append(family + ' based action has HIGH reward\n')
        pieces.append('\n')
        pieces.append('\n')
    return ''.join(pieces)

def step_action_score2text_llm_following_v1(step_action_ucb):
    """Describe per-step UCB scores as HIGH/LOW lines, nudging by step parity.

    Odd-numbered steps care about the unstack/pick families and even-numbered
    steps about stack/put (steps whose label contains no number count as
    even). An action belongs to the first family (unstack, stack, pick, put)
    whose name is a substring of it. Every listed action is compared with a
    strict '>' against the step's maximum, and a "<family> based action has
    HIGH reward" line is added for each parity-relevant family that does not
    appear. Each step's section ends with two newlines. Returns None for an
    empty mapping.
    """
    import re

    if len(step_action_ucb) == 0:
        return None

    parts = []
    for stepi, action_ucb in step_action_ucb.items():
        # Parity from the first integer in the label (e.g. "step 7"). The old
        # digit-substring test ('1'/'3'/'5' anywhere) misread steps 7, 9, 10, 12...
        number = re.search(r'\d+', stepi)
        odd_step = number is not None and int(number.group()) % 2 == 1
        wanted = ('unstack', 'pick') if odd_step else ('stack', 'put')

        tried = {'stack': False, 'unstack': False, 'pick': False, 'put': False}
        for acti in action_ucb:
            # 'unstack' is tested before 'stack' because it contains it.
            for family in ('unstack', 'stack', 'pick', 'put'):
                if family in acti:
                    tried[family] = True
                    break

        parts.append('Possible options for ' + stepi + ' and expected rewards:\n')
        theta = max(float(v) for v in action_ucb.values())
        for acti, ucb in action_ucb.items():
            # NOTE(review): strict '>' against the maximum means no listed
            # option is ever HIGH, even when all wanted families were tried;
            # step_action_score2text_llm_following uses '>=' there — confirm.
            scorei = "HIGH" if ucb > theta else "LOW"
            parts.append(acti + ' has ' + scorei + ' reward\n')
        for family in ('stack', 'unstack', 'pick', 'put'):
            if family in wanted and not tried[family]:
                parts.append(family + ' based action has HIGH reward\n')
        parts.append('\n')
        parts.append('\n')
    return ''.join(parts)


def get_action_ops2token_ids(action_ops, xmodel):
    """Map each action op string to the token ids it can appear as.

    Uses the tiktoken encoding for ``xmodel`` and collects the de-duplicated
    ids of both ``op`` and ``' ' + op`` (the form it takes mid-sentence).
    """
    import tiktoken
    encoder = tiktoken.encoding_for_model(xmodel)

    def _ids_for(op):
        # Bare and space-prefixed encodings, merged without duplicates.
        return list(set(encoder.encode(op) + encoder.encode(' ' + op)))

    return {op: _ids_for(op) for op in action_ops}

def past_actions_review(state_config, state_action_score,
                        state_action, state_action_counter,
                        state_counter, action_ops, R, C, K, B, action_ops2token_ids):
    """Compute a per-token bias from UCB statistics of actions taken in a state.

    With ``key = str(state_config)`` and ``N = state_counter[key]``, every
    whitespace token of each action in ``state_action[key]`` that is also in
    ``action_ops`` gets

        ucb = (sum of state_action_score[key][act], once per occurrence of the
               token in each act) + C * sqrt((ln N + 5e-5) / n_tok)

    where ``n_tok`` counts the token over all recorded visits in
    ``state_action_counter[key]`` (an action visited k times contributes its
    tokens k times). The returned bias is ``B * ln(ucb / K)``.

    Args:
        R, action_ops2token_ids: accepted for interface compatibility; unused.

    Returns:
        dict mapping token -> bias, in first-seen token order.

    Raises:
        KeyError: if ``key`` is missing from state_counter or state_action.
        ZeroDivisionError: if a matching token has no recorded visits.
        ValueError: if a logarithm/square root is undefined (e.g. N <= 0 or
            ucb / K <= 0).
    """
    # These were used but never imported, so every call raised NameError.
    from collections import Counter, defaultdict

    key = str(state_config)
    N = state_counter[key]

    visited = []
    for act, times in state_action_counter.get(key, {}).items():
        visited.extend([act] * times)
    token_visits = Counter(" ".join(visited).split())

    exploitation = defaultdict(float)
    exploration = {}
    for act in state_action[key]:
        for tok in act.split():
            if tok in action_ops:
                n = float(token_visits[tok])
                exploration[tok] = C * math.sqrt((math.log(N) + 0.00005) / n)
                exploitation[tok] += state_action_score[key][act]

    return {tok: B * math.log((exploitation[tok] + exploration[tok]) / K)
            for tok in exploration}

        






    
    
    
    
        
    
        