from utils import *
from prompts import *
from collections import defaultdict
import math
import numpy as np
import openai
import re
import argparse



class Agent():
    """Minimal client that sends a single-turn prompt to an OpenAI chat model."""

    def __init__(self, OPENAI_API_KEY, model_name="gpt-3.5-turbo-0301"):
        """Store the API key on the ``openai`` module and remember the model.

        Args:
            OPENAI_API_KEY: Key assigned to ``openai.api_key`` (module-wide).
            model_name: Chat model used by :meth:`predict_answer`.
        """
        openai.api_key = OPENAI_API_KEY
        self.model_name = model_name

    def predict_answer(self, user_message, temperature=0.0):
        """Return the model's reply to ``user_message``.

        The reply is lower-cased and every ``.`` character is removed before
        it is returned.

        Args:
            user_message: Content of the single user turn.
            temperature: Sampling temperature forwarded to the API.

        Returns:
            The normalised text of the first completion choice.
        """
        response = openai.ChatCompletion.create(
            # Use the configured model rather than a second hard-coded copy.
            model=self.model_name,
            messages=[
                {'role': 'system', 'content': "You are a helpful assistant who can solve logical puzzles."},
                {"role": "user", "content": user_message}],
            temperature=temperature,
            max_tokens=1000)
        return response['choices'][0]['message']['content'].lower().replace('.', '')



def _resolve_temperature(model_temperature):
    """Pick the sampling temperature: explicit value, else global ``args``, else 0.0."""
    if model_temperature is not None:
        return model_temperature
    # The script entry point leaves its parsed namespace in the module-level
    # name ``args``; when the module is imported that name does not exist.
    return getattr(globals().get('args'), 'model_temperature', 0.0)


def _parse_step_lines(response_text):
    """Return ordered ``(step label, action)`` pairs found in a model reply.

    Lines containing ``'step sequence'`` and lines without ``'step'`` are
    skipped. Raises ``ValueError`` when a step line does not split into
    exactly two parts on ``': '``.
    """
    pairs = []
    for line in response_text.split('\n'):
        if 'step sequence' in line or 'step' not in line:
            continue
        label, action = line.split(': ')
        pairs.append((label.strip().lower(), action.strip().lower()))
    return pairs


def _update_scores(pairs, reward, step_action_score, step_action_ucb,
                   step_counter, step_action_counter, UCB_CONSTANT):
    """Add ``reward`` to each ``(label, action)`` pair and recompute its UCB value."""
    for label, action in pairs:
        step_action_score[label][action] += reward
        step_action_ucb[label][action] = get_ucb_score(
            step_action_score[label][action],
            UCB_CONSTANT,
            step_counter[label],
            step_action_counter[label][action],
        )


def ucb_cot(agent, state, target, participating_blocks, step_action_score, step_counter, step_action_counter, UCB_CONSTANT, step_action_ucb, grid_reward, model_temperature=None):
    """Run one propose-and-evaluate pass and update the UCB statistics.

    The model is prompted for a step sequence (with the current UCB table
    rendered as history when it is non-empty). The proposed actions are then
    applied to ``state`` one by one. Steps receive ``grid_reward`` if the
    target state is reached and ``0.0`` otherwise, and their UCB values are
    refreshed.

    Args:
        agent: Object exposing ``predict_answer(prompt, temperature)``.
        state: Start state in text form.
        target: Goal state in text form.
        participating_blocks: Forwarded to ``state_text2json``.
        step_action_score: ``label -> action -> accumulated reward``; updated
            in place. Missing keys must be auto-created (e.g. nested
            ``defaultdict``).
        step_counter: ``label -> visit count``; updated in place.
        step_action_counter: ``label -> action -> visit count``; updated in
            place.
        UCB_CONSTANT: Exploration constant forwarded to ``get_ucb_score``.
        step_action_ucb: ``label -> action -> UCB value``; updated in place.
        grid_reward: Reward added to each credited step on success.
        model_temperature: Sampling temperature. When ``None`` the value is
            taken from the module-level ``args`` namespace if present, else
            ``0.0``.

    Returns:
        ``(step_action_score, step_action_ucb, step_counter,
        step_action_counter)`` -- the same objects that were passed in.
    """
    import logging
    import time

    history = step_action_score2text_llm_following(step_action_ucb)
    if history:
        prompt = prompt_with_history_v2(state, target, history)
    else:
        prompt = prompt_without_history_v2(state, target, history)

    temperature = _resolve_temperature(model_temperature)

    # Parse each reply into a scratch list first. Shared statistics are only
    # touched once a reply has parsed completely, so a reply that fails
    # half-way cannot leave stale entries behind for the retry.
    parsed_steps = []
    attempts_left = 5
    while attempts_left:
        try:
            parsed_steps = _parse_step_lines(agent.predict_answer(prompt, temperature))
            break
        except Exception as exc:  # best-effort: API errors and malformed replies are retried
            attempts_left -= 1
            logging.getLogger(__name__).warning(
                "ucb_cot: model call or reply parsing failed (%s); %d attempt(s) left",
                exc, attempts_left)
            time.sleep(3)

    local_step_action_score = {}
    for label, action in parsed_steps:
        step_action_score[label][action] += 0.0  # make sure the entry exists
        local_step_action_score[label] = action
        step_counter[label] += 1
        step_action_counter[label][action] += 1

    # Labels and actions are taken from the same dict so they stay aligned
    # even if the reply repeated a step label (the last action wins).
    proposed = list(local_step_action_score.items())
    bookkeeping = (step_action_score, step_action_ucb, step_counter,
                   step_action_counter, UCB_CONSTANT)
    result = (step_action_score, step_action_ucb, step_counter, step_action_counter)

    target_json = state_text2json(target.lower().replace('.', ''), participating_blocks)

    for position, (label, action) in enumerate(proposed):
        current_json = state_text2json(state.lower().replace('.', ''), participating_blocks)
        new_state, valid_action = add_action_to_json_state(current_json, action)
        # NOTE(review): these counters were already incremented once while the
        # reply was parsed, so an executed step is counted twice per pass.
        # Kept as-is to preserve the existing UCB values -- confirm intent.
        step_counter[label] += 1
        step_action_counter[label][action] += 1

        if not valid_action:
            # An illegal move ends the rollout; every proposed step gets 0.
            _update_scores(proposed, 0.0, *bookkeeping)
            return result

        if new_state == target_json:
            # Only the steps needed to reach the goal are rewarded.
            _update_scores(proposed[:position + 1], float(grid_reward), *bookkeeping)
            return result

        state = state_json2text(new_state)

    # Compare like with like: ``state`` is text here, so convert it before
    # checking against the JSON target.
    reached = state_text2json(state.lower().replace('.', ''), participating_blocks) == target_json
    _update_scores(proposed, float(grid_reward) if reached else 0.0, *bookkeeping)
    return result




def main(args):
    """Evaluate the UCB chain-of-thought loop on Blocksworld problems.

    For every solution length listed in ``args.no_of_answer_steps`` the
    matching problems are loaded and, for each trial, every problem is
    attempted with ``args.no_of_passes`` calls to ``ucb_cot``. A problem
    counts as solved when the first step's best accumulated score is
    non-zero. The number of problems and of solved problems is printed per
    trial.

    Args:
        args: Namespace providing ``exploration_constant``, ``reward``,
            ``no_of_answer_steps`` (comma-separated string), ``no_of_trials``,
            ``no_of_passes`` and ``OPENAI_API_KEY``.

    Raises:
        AssertionError: If a problem's reference solution contains an action
            that cannot be applied to its start configuration.
    """
    import copy  # stdlib; imported here so it does not rely on a star-import

    UCB_CONSTANT = args.exploration_constant
    grid_reward = args.reward

    # The agent keeps no per-problem state, so one instance serves every problem.
    agent = Agent(args.OPENAI_API_KEY)

    for NO_OF_STEPS_IN_ANSWER in args.no_of_answer_steps.split(','):
        print('#'*50)
        print('No of steps in ans: ', NO_OF_STEPS_IN_ANSWER)
        print('#'*50)

        bw_data = get_blocksworld_data(int(NO_OF_STEPS_IN_ANSWER))

        for _trial in range(args.no_of_trials):
            correct = 0

            # Iterating the data directly lets tqdm see its length.
            for item in tqdm(bw_data):
                # Fresh statistics for every problem.
                step_action_score = defaultdict(lambda: defaultdict(float))
                step_action_ucb = defaultdict(lambda: defaultdict(float))
                step_counter = defaultdict(int)
                step_action_counter = defaultdict(lambda: defaultdict(int))

                # Derive the goal configuration by replaying the reference
                # solution on a copy of the start configuration.
                init_block_config = state_text2json(item['real_problem'], item['participating_blocks'])
                final_block_config = copy.deepcopy(init_block_config)
                for action in real_solution2text(item['real_solution']):
                    final_block_config, valid_action = add_action_to_json_state(final_block_config, action)
                    if valid_action is not True:
                        # Raised explicitly so the check survives ``python -O``.
                        raise AssertionError('Cannot reach final block config')

                # Learn: run iterations
                for _pass in range(args.no_of_passes):
                    step_action_score, step_action_ucb, step_counter, step_action_counter = ucb_cot(
                        agent,
                        state_json2text(init_block_config),
                        state_json2text(final_block_config),
                        item['participating_blocks'],
                        step_action_score,
                        step_counter,
                        step_action_counter,
                        UCB_CONSTANT,
                        step_action_ucb,
                        grid_reward
                    )

                # Collect the best action per step until a step has no
                # rewarded action.
                final_steps = []
                for stepi in step_action_score:
                    best_action, best_score = max(
                        step_action_score[stepi].items(), key=lambda pair: pair[1])
                    if best_score == 0.0:
                        break
                    final_steps.append(best_action)

                # Scores only become non-zero when a generated solution
                # reached the goal, so a non-empty final_steps means the
                # problem was solved.
                if final_steps:
                    correct += 1

            print('No of questions: ', str(len(bw_data)))
            print('No of correct answers: ', correct)
        print()

    
    


if __name__ == '__main__':
    # Command-line entry point: parse options and run the evaluation.
    import os

    parser = argparse.ArgumentParser(
        description='Evaluate the UCB chain-of-thought loop on Blocksworld problems.')
    parser.add_argument('-no_of_passes', default=10, type=int,
                        help='ucb_cot calls per problem (default: %(default)s)')
    parser.add_argument('-no_of_trials', default=1, type=int,
                        help='repetitions of the whole evaluation (default: %(default)s)')
    parser.add_argument('-reward', default=1, type=int,
                        help='reward credited to steps of a successful plan (default: %(default)s)')
    parser.add_argument('-exploration_constant', default=10, type=int,
                        help='constant passed to the UCB score (default: %(default)s)')
    parser.add_argument('-model_temperature', default=0.0, type=float,
                        help='sampling temperature for the model (default: %(default)s)')
    # Fall back to the environment so the key need not appear on the command
    # line; the default is deliberately not echoed in the help text.
    parser.add_argument('-OPENAI_API_KEY', default=os.environ.get('OPENAI_API_KEY', ''),
                        help='OpenAI API key; defaults to the OPENAI_API_KEY environment variable')
    parser.add_argument('-no_of_answer_steps', default='2,4,6',
                        help='comma-separated solution lengths to evaluate (default: %(default)s)')

    # Keep the module-level name ``args``: ucb_cot reads it as a global.
    args = parser.parse_args()

    main(args)