import json
import re
from collections import defaultdict
import math
import numpy as np
import openai
import re
import argparse
from gsm8k_utils import *
from gsm8k_prompts import *
from tqdm import tqdm

class Agent():
    """Thin client around the legacy ``openai.ChatCompletion`` interface.

    Provides one call that generates solution text and one call that asks the
    same chat model to judge whether a given solution is correct.
    """

    # Chat model used when the caller does not choose one.
    DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0301"

    def __init__(self, OPENAI_API_KEY, model_name=DEFAULT_MODEL_NAME):
        """Configure credentials and the chat model.

        Args:
            OPENAI_API_KEY: assigned to the module-global ``openai.api_key``,
                so it applies to every ``openai`` call in the process.
            model_name: chat model used by both ``predict_answer`` and
                ``validate_correctness``.
        """
        openai.api_key = OPENAI_API_KEY
        self.model_name = model_name

    def predict_answer(self, user_message, max_tokens, temperature, lb):
        """Request a single completion for ``user_message``.

        Args:
            user_message: text sent as the user turn.
            max_tokens: completion length cap.
            temperature: sampling temperature.
            lb: ``logit_bias`` mapping (token id -> bias) forwarded to the API.

        Returns:
            The raw API response; callers read
            ``response['choices'][0]['message']['content']``.
        """
        response = openai.ChatCompletion.create(
            model=self.model_name,
            messages=[
                {'role': 'system', 'content': "You are a helpful assistant who can solve arithmetic questions."},
                {"role": "user", "content": user_message}],
            max_tokens=max_tokens,
            temperature=temperature,
            # Near-zero nucleus mass keeps decoding essentially greedy;
            # presumably so the logit bias, not sampling noise, drives
            # exploration. NOTE(review): confirm this is intended.
            top_p=0.00000000000001,
            n=1,
            logit_bias=lb)
        return response

    def validate_correctness(self, user_message):
        """Ask the model to judge the solution contained in ``user_message``.

        Uses the same model as ``predict_answer`` so that generator and judge
        cannot silently drift apart.

        Returns:
            The model's reply, lower-cased and stripped of surrounding
            whitespace.
        """
        response = openai.ChatCompletion.create(
            model=self.model_name,
            messages=[
                {'role': 'system', 'content': "You are a helpful assistant who can solve arithmetic questions. Check if the solution for the given problem is correct or not."},
                {"role": "user", "content": user_message}])
        return response['choices'][0]['message']['content'].lower().strip()


def _parse_gsm8k_record(record):
    """Convert one GSM8K JSON record into ``[question, steps, final_answer]``.

    The ``answer`` field is split on newlines: every line but the last is a
    reasoning step, and the last line carries the final answer after ``####``.
    Calculator annotations of the form ``<<...>>`` are removed first.

    Raises:
        ValueError: if the last answer line has no ``####`` marker.
    """
    question = record['question']
    answer = re.sub(r'<<.*?>>', '', record['answer'])
    # rstrip() so a trailing newline does not turn the "####" line into a
    # step and leave an empty string as the final line.
    answer_lines = answer.rstrip().split('\n')
    steps = answer_lines[:-1]
    final_line = answer_lines[-1]
    if '####' not in final_line:
        raise ValueError(
            f"GSM8K answer is missing the '####' final-answer marker: {final_line!r}")
    final_ans = final_line.split('####')[1].strip()
    return [question, steps, final_ans]


def get_gsm8k_dataset(dataset_name, data_path='./gsm_test_all.jsonl'):
    """Load the GSM8K test split as a list of ``[question, steps, final_answer]``.

    Args:
        dataset_name: must be ``'gsm8k'``; kept for interface compatibility.
        data_path: JSONL file with one ``{"question", "answer"}`` object per
            line. Blank lines are skipped.

    Returns:
        One ``[question, steps, final_answer]`` list per record, in file order;
        ``steps`` is a list of strings and ``final_answer`` a string.

    Raises:
        ValueError: if ``dataset_name`` is not ``'gsm8k'`` or a record lacks
            the ``####`` final-answer marker.
    """
    # Explicit check instead of ``assert``, which is stripped under ``python -O``.
    if dataset_name != 'gsm8k':
        raise ValueError('GSM8K dataset only')

    with open(data_path, encoding='utf-8') as f:
        records = [json.loads(line) for line in f if line.strip()]

    return [_parse_gsm8k_record(record) for record in records]



def ucl_cot(agent, sample,
        state_action_score, state_action,
        state_action_counter, state_counter,
        R, C, K, B, depth, tmp):
    """Run one search pass that grows a chain-of-thought solution step by step.

    Each iteration asks ``agent`` for the next step (or the final answer) of
    the partial solution held in ``state_config``. Token logit biases come from
    ``past_actions_review`` applied to the statistics gathered so far. Each
    accepted action is recorded in the statistics tables, and the model is then
    asked to judge the partial solution. If the reply contains "yes", ``R`` is
    added to every (state, action) pair visited in this pass and the pass ends.

    Each step gets up to three attempts. An attempt is retried when the reply
    repeats an earlier action or cannot be classified as a step/answer. It is
    also retried when an exception is raised, in which case the error is
    logged, ``max_tokens`` is raised to 300 and the attempt waits 10 s. If all
    three attempts fail, the pass ends early.

    Args:
        agent: object providing ``predict_answer`` and ``validate_correctness``.
        sample: dataset record; only ``sample[0]`` (the initial prompt text)
            is read here.
        state_action_score: ``{state: {action: score}}``; expected to be a
            nested ``defaultdict`` (as built in ``main``).
        state_action: ``{state: [actions]}``, expected ``defaultdict(list)``.
        state_action_counter: ``{state: {action: visits}}``, nested defaultdict.
        state_counter: ``{state: visits}``, expected ``defaultdict(int)``.
        R: reward added along the trajectory on a "yes" verdict.
        C, K, B: forwarded unchanged to ``past_actions_review``.
        depth: maximum number of expansion iterations in this pass.
        tmp: sampling temperature passed to ``predict_answer``.

    Returns:
        ``(state_action_score, state_action, state_action_counter,
        state_counter)``; the tables are also mutated in place. States are
        keyed by ``str(state_config)``.
    """
    import time  # only for the back-off pause after a failed attempt

    trajectory = []       # (state, action) pairs visited during this pass
    action_history = []   # actions already taken; repeats are rejected

    state_config = sample[0]
    dc = 0
    ft = True  # first step: the reply is prefixed with "step 1: "

    while dc < depth:
        check = True
        check_counter = 0
        max_tokens = 100

        while check and check_counter < 3:
            try:
                token2bias = past_actions_review(state_config, state_action_score,
                        state_action, state_action_counter,
                        state_counter, R, C, K, B)

                action_ops2token_ids = get_action_ops2token_ids(list(token2bias.keys()))

                # Tokens scored by past_actions_review get their computed bias;
                # every other candidate token gets a bias of 1.
                token_id2bias = {}
                for tok in token2bias:
                    for j in action_ops2token_ids[tok]:
                        token_id2bias[j] = token2bias[tok]
                for ttok in action_ops2token_ids:
                    if ttok not in token2bias:
                        for j in action_ops2token_ids[ttok]:
                            token_id2bias[j] = 1

                prompt = prompt_without_history_ucl(state_config, ft)

                action = None
                org_action = None
                action = agent.predict_answer(prompt, lb=token_id2bias, temperature=tmp, max_tokens=max_tokens)
                action = action['choices'][0]['message']['content'].strip().replace('Solution for Problem 4:', '').replace('solution for problem 4:', '')
                if ft:
                    action = 'step 1: ' + action
                # Normalize line starts so the '\nstep' / '\nfinal' splits work.
                # (Previously '\final' here inserted a form-feed character.)
                action = action.strip().replace('\n step', '\nstep').replace('\n final', '\nfinal')

                # Keep only the first step (or final line) of the reply.
                org_action = action.split('\nstep')[0].split('\nfinal')[0].strip()

                if 'step' in action.lower().split()[0]:
                    raction = action.split(':', 1)[1].split('\nstep')[0]
                    if raction in action_history:
                        check_counter += 1
                        continue
                    check = False
                elif 'answer' in action.lower():
                    if '=' in action.lower():
                        raction = re.sub("[^0-9]", "", action.split('=', 1)[1].strip())
                    else:
                        raction = re.sub("[^0-9]", "", action.split(':', 1)[1].strip())
                    if raction in action_history:
                        check_counter += 1
                        continue
                    check = False
                else:
                    check_counter += 1

                ft = False

            except Exception as e:
                # Best-effort retry. Both API failures and unparsable replies
                # (e.g. no ':' -> IndexError) land here. Log instead of
                # swallowing silently, allow a longer reply, and pause first.
                print(f'ucl_cot: attempt {check_counter + 1} failed: {e!r}')
                max_tokens = 300
                check_counter += 1
                time.sleep(10)

        if check:
            # All attempts failed for this step: abandon the rest of the pass.
            return state_action_score, state_action, state_action_counter, state_counter

        is_final = False
        if 'step' in action.lower().split()[0]:
            action = action.split(':', 1)[1].split('\nstep')[0]
        elif 'answer' in action.lower():
            # Final answer: keep only the digits of the first 10 characters
            # after '=' or ':' (or of the whole line if neither is present).
            action = action.replace('\n', '')
            if '=' in action.lower():
                action = re.sub("[^0-9]", "", action.split('=', 1)[1].strip()[:10])
            elif ':' in action.lower():
                action = re.sub("[^0-9]", "", action.split(':', 1)[1].strip()[:10])
            else:
                action = re.sub("[^0-9]", "", action.strip()[:10])
            is_final = True
        else:
            # Defensive: the retry loop only accepts replies matching one of
            # the branches above.
            dc += 1
            action_history.append(action)
            continue

        state_key = str(state_config)
        if action not in state_action[state_key]:
            state_action[state_key].append(action)
        state_counter[state_key] += 1
        state_action_counter[state_key][action] += 1

        trajectory.append((state_key, action))
        action_history.append(action)

        # NOTE(review): this call is not retried; an API error propagates.
        verdict = agent.validate_correctness(prompt_without_history_ucl(state_config + org_action + '\n\nIs the solution and final answer provided for Problem 4 correct? If correct, then return "yes" else return "no"\n', False))
        if "yes" in verdict:
            for ss, aa in trajectory:
                state_action_score[ss][aa] += R
            return state_action_score, state_action, state_action_counter, state_counter

        # Not judged correct: add a zero reward, which still creates the
        # (state, action) entries in the nested defaultdict.
        reward = 0.0
        for ss, aa in trajectory:
            state_action_score[ss][aa] += reward

        if is_final:
            return state_action_score, state_action, state_action_counter, state_counter
        if org_action is not None:
            # Extend the partial solution with the accepted step.
            if 'step 1' in org_action.lower():
                state_config = state_config + "\n\nSolution for Problem 4:\n" + org_action + '\n'
            else:
                state_config = state_config + org_action + '\n'

        dc += 1
    return state_action_score, state_action, state_action_counter, state_counter


def _best_scored_action(action_scores):
    """Return the ``(action, score)`` pair with the highest score, or ``None``.

    Ties go to the earliest-inserted action, matching a stable descending sort.
    """
    return max(action_scores.items(), key=lambda item: item[1], default=None)


def main(args):
    """Evaluate the search procedure on the GSM8K test set.

    For each of ``args.no_of_trials`` trials and each problem ``d``, runs
    ``args.no_of_passes`` calls of ``ucl_cot`` that share one set of search
    statistics. A problem counts as solved if some visited state's top-scoring
    action has a non-zero score and equals the gold answer ``d[2]``. Prints the
    best action of every state for each problem, then the number solved per
    trial.
    """
    # Get the GSM8K dataset.
    data_listt = get_gsm8k_dataset('gsm8k')

    depth = args.depth
    R = args.reward
    C = args.exploration_constant
    K = args.K
    B = args.B
    tmp = args.model_temperature

    # The agent holds only the API key and model name, so one instance suffices.
    agent = Agent(args.OPENAI_API_KEY)

    for _trial in range(args.no_of_trials):
        preds = []
        print('No of samples: ', len(data_listt))
        print('R, C, tmp: ', R, C, tmp)

        for d in tqdm(data_listt):
            # Fresh search statistics per problem, shared across its passes.
            state_action_score = defaultdict(lambda: defaultdict(float))
            state_action = defaultdict(list)
            state_action_counter = defaultdict(lambda: defaultdict(int))
            state_counter = defaultdict(int)

            for _ in range(args.no_of_passes):
                state_action_score, state_action, state_action_counter, state_counter = ucl_cot(
                    agent, d,
                    state_action_score, state_action,
                    state_action_counter, state_counter,
                    R, C, K, B, depth, tmp)

            # Score the problem: solved if any rewarded best action is the gold answer.
            for scores in state_action_score.values():
                best = _best_scored_action(scores)
                if best is None:
                    continue
                best_action, best_score = best
                if best_score != 0.0 and best_action == d[2]:
                    preds.append(1)
                    break

            # Report the best action found for every visited state.
            final_answer = []
            for st, scores in state_action_score.items():
                best = _best_scored_action(scores)
                if best is not None:
                    final_answer.append((st, best[0].strip()))
            print(final_answer)

        print('Total number of correct answers:\t', sum(preds))


if __name__ == '__main__':
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('-no_of_passes', default=10, type=int,
                        help='ucl_cot passes per problem (statistics are shared across passes)')
    parser.add_argument('-no_of_trials', default=1, type=int,
                        help='number of full evaluation runs over the dataset')
    parser.add_argument('-depth', default=10, type=int,
                        help='maximum expansion steps per pass')
    parser.add_argument('-reward', default=1, type=int,
                        help='reward added to a trajectory the model judges correct')
    parser.add_argument('-K', default=5, type=int,
                        help='passed through to past_actions_review')
    parser.add_argument('-B', default=2, type=int,
                        help='passed through to past_actions_review')
    parser.add_argument('-exploration_constant', default=10, type=int,
                        help='C, passed through to past_actions_review')
    parser.add_argument('-model_temperature', default=0.0, type=float,
                        help='sampling temperature for step generation')
    # Never hard-code credentials in source: default to the environment.
    parser.add_argument('-OPENAI_API_KEY', default=os.environ.get('OPENAI_API_KEY'),
                        help='OpenAI API key (defaults to $OPENAI_API_KEY)')
    args = parser.parse_args()
    if not args.OPENAI_API_KEY:
        parser.error('an OpenAI API key is required: pass -OPENAI_API_KEY '
                     'or set the OPENAI_API_KEY environment variable')
    main(args)
