from utils import *
from prompts import *
from collections import defaultdict
import math
import numpy as np
import openai
import re
import argparse



class Agent:
    """Thin wrapper that sends one-turn chat prompts to an OpenAI chat model.

    Uses the legacy ``openai.ChatCompletion`` interface (presumably the
    openai<1.0 client -- confirm the pinned version).
    """

    def __init__(self, OPENAI_API_KEY, model_name="gpt-3.5-turbo-0301"):
        """Store the API key and the chat model name.

        Args:
            OPENAI_API_KEY: key assigned to ``openai.api_key``. This sets a module
                attribute, so it affects every OpenAI call in the process.
            model_name: chat model to query. The default is the previously
                hard-coded model. The value is also read by callers, e.g. for
                tokenizer lookup in ``ucl_cot``.
        """
        openai.api_key = OPENAI_API_KEY
        self.model_name = model_name

    def predict_answer(self, user_message, lb, temperature=0.0, max_tokens=150):
        """Send ``user_message`` with a fixed system prompt and return the raw response.

        Args:
            user_message: content of the single user turn.
            lb: ``logit_bias`` mapping (token id -> bias) forwarded as-is.
            temperature: sampling temperature.
            max_tokens: completion length cap (previously hard-coded to 150).

        Returns:
            The unmodified object returned by ``openai.ChatCompletion.create``
            with ``n=1``. Callers index ``['choices'][0]['message']['content']``.
        """
        messages = [
            {'role': 'system', 'content': "You are a helpful assistant who can solve logical puzzles."},
            {"role": "user", "content": user_message},
        ]
        return openai.ChatCompletion.create(
            model=self.model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            n=1,
            logit_bias=lb,
        )




def _ucl_cot_token_id_bias(token2bias, action_ops2token_ids, default_bias=10):
    """Expand per-operator logit biases into a per-token-id ``logit_bias`` map.

    Operators present in ``token2bias`` put their own bias on all of their token
    ids. Every other operator in ``action_ops2token_ids`` gets ``default_bias``.
    Explicit entries are written first, so a token id shared with a fallback
    operator ends up with ``default_bias``.
    """
    token_id2bias = {}
    for tok in token2bias:
        for token_id in action_ops2token_ids[tok]:
            token_id2bias[token_id] = token2bias[tok]
    for tok, token_ids in action_ops2token_ids.items():
        if tok not in token2bias:
            for token_id in token_ids:
                token_id2bias[token_id] = default_bias
    return token_id2bias


def _ucl_cot_parse_first_step(reply):
    """Return the first step of an LLM reply, or None if it has no ``'1:'`` marker.

    The reply is lower-cased. The step is the text after the first ``'1:'`` and
    before the next ``'step'``, with newlines removed and whitespace stripped.
    """
    parts = reply.lower().split('1:')
    if len(parts) < 2:
        return None
    return parts[1].split('step')[0].replace('\n', '').strip()


def ucl_cot(agent, state_config, state_action_score, state_action,
            state_action_counter, state_counter,
            state, target, final_block_config, R, C, K, B, depth, action_ops,
            temperature=0.0, max_attempts=5, retry_delay=3):
    """Roll out one LLM-guided trajectory and update the search statistics in place.

    Starting from ``state_config``, the agent is repeatedly asked for the next
    action, steered by logit biases computed by ``past_actions_review``. Each
    action is applied with ``add_action_to_json_state``, for at most ``depth``
    steps. The rollout ends early when:

    - a state is revisited;
    - no valid action is obtained within ``max_attempts`` tries; or
    - ``final_block_config`` is reached. In that case every (state, action)
      pair of this rollout receives reward ``R``.

    Args:
        agent: object with ``model_name`` and ``predict_answer(prompt, lb=..., temperature=...)``.
        state_config: starting block configuration. It is compared with ``==``
            and deep-copied between steps.
        state_action_score: state-string -> action -> accumulated reward. Mutated.
            ``+=`` on unseen keys assumes defaultdict-backed tables, as built in ``main``.
        state_action: state-string -> list of actions tried in that state. Mutated.
        state_action_counter: state-string -> action -> try count. Mutated.
        state_counter: state-string -> try count. Mutated.
        state: unused; kept for backward compatibility.
        target: goal description inserted into the prompt.
        final_block_config: goal configuration.
        R: reward for reaching the goal; also forwarded to ``past_actions_review``.
        C, K, B: forwarded unchanged to ``past_actions_review``.
        depth: maximum number of actions in the rollout.
        action_ops: operator keywords whose tokens are biased.
        temperature: sampling temperature for the agent (previously fixed at 0.0).
        max_attempts: tries per step to obtain a valid action (previously 5).
        retry_delay: seconds to wait after an unexpected error, e.g. an API
            failure, before retrying (previously 3).

    Returns:
        ``(state_action_score, state_action, state_action_counter, state_counter)``:
        the same objects that were passed in, after mutation.
    """
    import copy
    import logging
    import time

    logger = logging.getLogger(__name__)
    action_ops2token_ids = get_action_ops2token_ids(action_ops, agent.model_name)

    trajectory = []
    visited = set()

    for _ in range(depth):
        state_key = str(state_config)
        # Revisiting a state means the rollout is going in circles; stop it.
        if state_key in visited:
            break
        visited.add(state_key)

        prompt = prompt_without_history_v5(state_json2text(state_config), target, None)

        action = None
        temp_state = None
        got_valid_action = False
        for attempt in range(1, max_attempts + 1):
            try:
                # Recomputed each attempt so that invalid actions recorded below
                # are visible to past_actions_review on the next try.
                token2bias = past_actions_review(
                    state_config, state_action_score, state_action,
                    state_action_counter, state_counter, action_ops,
                    R, C, K, B, action_ops2token_ids)
                token_id2bias = _ucl_cot_token_id_bias(token2bias, action_ops2token_ids)
                response = agent.predict_answer(prompt, lb=token_id2bias, temperature=temperature)
                action = _ucl_cot_parse_first_step(response['choices'][0]['message']['content'])
                if action is None:
                    # No "1:" step marker in the reply. This uses up an attempt,
                    # but there is no reason to back off before asking again.
                    continue
                temp_state, valid_state = add_action_to_json_state(state_config, action)
            except Exception as exc:  # best-effort: API/network or helper failures
                logger.warning('ucl_cot: attempt %d/%d failed: %r', attempt, max_attempts, exc)
                time.sleep(retry_delay)
                continue

            if valid_state:
                got_valid_action = True
                break

            # Record the invalid action with a zero score so it is counted.
            # TODO: an action that leads to an invalid state should give its
            # tokens a -100 bias.
            if action:
                if action not in state_action[state_key]:
                    state_action[state_key].append(action)
                state_counter[state_key] += 1
                state_action_counter[state_key][action] += 1
                state_action_score[state_key][action] = 0.0

        if not got_valid_action:
            # No usable action for this state: end the rollout and keep the
            # statistics gathered so far.
            return state_action_score, state_action, state_action_counter, state_counter

        if action not in state_action[state_key]:
            state_action[state_key].append(action)
        state_counter[state_key] += 1
        state_action_counter[state_key][action] += 1
        trajectory.append((state_key, action))

        if temp_state == final_block_config:
            # Goal reached: credit every step of this rollout with the reward.
            for ss, aa in trajectory:
                state_action_score[ss][aa] += R
            break

        # Not the goal yet. A zero reward leaves the scores unchanged, but it
        # creates the score entries in defaultdict-backed tables.
        for ss, aa in trajectory:
            state_action_score[ss][aa] += 0.0
        state_config = copy.deepcopy(temp_state)

    return state_action_score, state_action, state_action_counter, state_counter



def main(args):
    """Evaluate ucl_cot on Blocksworld problems and print how many are solved.

    For each comma-separated solution length in ``args.no_of_answer_steps``,
    the matching problems are loaded with ``get_blocksworld_data``. Then, for
    each trial and each problem:

    - fresh search tables are created;
    - the goal configuration is derived by replaying the reference solution;
    - ``ucl_cot`` is run ``args.no_of_passes`` times from the initial configuration.

    The number of problems counted as solved is printed per trial.

    Args:
        args: ``argparse.Namespace`` produced by this script's argument parser.

    Raises:
        ValueError: if a reference solution does not replay to a valid configuration.
    """
    import copy

    UCB_CONSTANT = args.exploration_constant
    grid_reward = args.reward
    K = args.K
    B = args.B
    depth = args.depth
    # NOTE(review): args.model_temperature is parsed but not used here.

    # Agent.__init__ only stores the API key and the model name, so a single
    # instance can serve every problem.
    agent = Agent(args.OPENAI_API_KEY)
    # Only operator keywords are biased; block names (item['participating_blocks'])
    # are deliberately not included.
    action_operators = ('unstack', 'stack', 'pick', 'put')

    for no_of_steps in args.no_of_answer_steps.split(','):
        print('#'*50)
        print('No of steps in ans: ', no_of_steps)
        print('#'*50)

        bw_data = get_blocksworld_data(int(no_of_steps))

        for _ in range(args.no_of_trials):
            n_correct = 0

            for item in tqdm(bw_data):
                # Fresh search statistics per problem, shared across its passes.
                state_action_score = defaultdict(lambda: defaultdict(float))
                state_action = defaultdict(list)
                state_action_counter = defaultdict(lambda: defaultdict(int))
                state_counter = defaultdict(int)
                action_ops = list(action_operators)

                init_block_config = state_text2json(item['real_problem'], item['participating_blocks'])
                # Replay the reference solution to obtain the goal configuration.
                final_block_config = copy.deepcopy(init_block_config)
                for action in real_solution2text(item['real_solution']):
                    final_block_config, valid_action = add_action_to_json_state(final_block_config, action)
                    if valid_action is not True:
                        raise ValueError('Cannot reach final block config')

                init_text = state_json2text(init_block_config)
                goal_text = state_json2text(final_block_config)

                # Learn: each pass is one rollout that updates the shared tables.
                for _pass in range(args.no_of_passes):
                    (state_action_score, state_action,
                     state_action_counter, state_counter) = ucl_cot(
                        agent, copy.deepcopy(init_block_config),
                        state_action_score, state_action,
                        state_action_counter, state_counter,
                        init_text, goal_text,
                        final_block_config, grid_reward, UCB_CONSTANT, K, B, depth, action_ops,
                    )

                # With a non-zero reward, ucl_cot only scores rollouts that reach
                # the goal, and it credits every step of such a rollout. So the
                # problem counts as solved when the first recorded state
                # (normally the initial configuration) has a non-zero best score.
                # An empty score table counts as unsolved instead of raising.
                first_state_scores = next(iter(state_action_score.values()), {})
                if max(first_state_scores.values(), default=0.0) != 0.0:
                    n_correct += 1

            print('No of questions: ', str(len(bw_data)))
            print('No of correct answers: ', n_correct)
        print()

    
    


if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser(
        description='Evaluate ucl_cot on Blocksworld problems.')
    parser.add_argument('-no_of_passes', default=10, type=int,
                        help='ucl_cot rollouts per problem')
    parser.add_argument('-no_of_trials', default=1, type=int,
                        help='number of times the whole evaluation is repeated')
    parser.add_argument('-K', default=5, type=int,
                        help='forwarded to past_actions_review')
    parser.add_argument('-B', default=2, type=int,
                        help='forwarded to past_actions_review')
    parser.add_argument('-depth', default=10, type=int,
                        help='maximum number of actions per rollout')
    parser.add_argument('-reward', default=1, type=int,
                        help='reward credited to every step of a rollout that reaches the goal')
    parser.add_argument('-exploration_constant', default=10, type=int,
                        help='exploration constant, passed to ucl_cot as C')
    parser.add_argument('-model_temperature', default=0.0, type=float,
                        help='currently not used by main')
    # Falling back to the environment keeps the key out of shell history and
    # process listings.
    parser.add_argument('-OPENAI_API_KEY', default=os.environ.get('OPENAI_API_KEY', ''),
                        help='OpenAI API key (defaults to the OPENAI_API_KEY environment variable)')
    parser.add_argument('-no_of_answer_steps', default='2,4,6',
                        help='comma-separated reference-solution lengths to evaluate')
    args = parser.parse_args()

    main(args)