import json
import re
from collections import defaultdict
import math
import numpy as np
import openai
import re
import argparse
from gsm8k_utils import *
from gsm8k_prompts import *
from tqdm import tqdm

class Agent():
    """Thin wrapper around the OpenAI chat-completion endpoint.

    The same model is used for two kinds of requests: producing a solution
    for a prompt (`predict_answer`) and judging a proposed solution
    (`validate_correctness`).
    """

    # System prompts sent with the two kinds of requests.
    SOLVER_SYSTEM_PROMPT = "You are a helpful assistant who can solve arithmetic questions."
    VALIDATOR_SYSTEM_PROMPT = "You are a helpful assistant who can solve arithmetic questions. Check if the solution for the given problem is correct or not."

    def __init__(self, OPENAI_API_KEY, model_name="gpt-3.5-turbo-0301"):
        """Configure the API key and remember which chat model to query.

        Args:
            OPENAI_API_KEY: key assigned to the module-global
                ``openai.api_key`` (so it affects every user of the module).
            model_name: chat model used by both request methods.
        """
        openai.api_key = OPENAI_API_KEY
        self.model_name = model_name

    def _chat(self, system_prompt, user_message, **request_options):
        """Send one system+user exchange and return the lower-cased reply text."""
        response = openai.ChatCompletion.create(
            model=self.model_name,
            messages=[
                {'role': 'system', 'content': system_prompt},
                {"role": "user", "content": user_message}],
            **request_options)
        # Only the first returned choice is used.
        return response['choices'][0]['message']['content'].lower()

    def predict_answer(self, user_message, max_tokens, temperature=0.0):
        """Ask the model to answer ``user_message``.

        Args:
            user_message: full prompt sent as the user turn.
            max_tokens: completion length limit forwarded to the API.
            temperature: sampling temperature forwarded to the API.

        Returns:
            The reply text of the first choice, lower-cased.
        """
        return self._chat(
            self.SOLVER_SYSTEM_PROMPT,
            user_message,
            max_tokens=max_tokens,
            temperature=temperature)

    def validate_correctness(self, user_message):
        """Ask the model whether the solution contained in ``user_message`` is correct.

        No ``max_tokens`` or ``temperature`` is sent, so the API defaults apply.
        The request now uses ``self.model_name`` (previously the model name was
        hard-coded here and could drift from the one used by `predict_answer`).

        Returns:
            The reply text of the first choice, lower-cased and stripped.
        """
        return self._chat(self.VALIDATOR_SYSTEM_PROMPT, user_message).strip()
        


def get_gsm8k_dataset(dataset_name):
    """Load the GSM8K test set from ``./gsm_test_all.jsonl``.

    Each line of the file is a JSON object with a ``question`` and an
    ``answer`` field. In the answer, ``<<...>>`` annotations are removed, the
    last line is expected to look like ``#### <final answer>``, and all lines
    before it are treated as solution steps.

    The few-shot file ``./few_shot_3.jsonl`` is no longer read: its parsed
    contents were never used or returned, so it only added a needless hard
    dependency on that file.

    Args:
        dataset_name: must be ``'gsm8k'``.

    Returns:
        A list of ``[question, steps, final_answer]`` lists, where ``steps``
        is a list of strings and ``final_answer`` is the stripped text after
        ``####``.

    Raises:
        AssertionError: if ``dataset_name`` is not ``'gsm8k'``.
        IndexError: if the last answer line of a record contains no ``####``.
    """
    if dataset_name != 'gsm8k':
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        raise AssertionError('GSM8K dataset only')

    calc_annotation = re.compile(r'<<.*?>>')

    data_listt = []
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open('./gsm_test_all.jsonl', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            record = json.loads(line)
            # rstrip() so trailing whitespace/newlines in the answer do not
            # turn the "#### ..." line into the second-to-last line.
            answer_lines = calc_annotation.sub('', record['answer']).rstrip().split('\n')
            final_ans = answer_lines[-1].split('####')[1].strip()
            data_listt.append([record['question'], answer_lines[:-1], final_ans])

    return data_listt

def ucb_cot(agent, sample, state_action_score, state_ucb_score, depth, R, \
      state_action, state_counter, state_action_counter, C, tmp):
    """Sample one solution for a question, judge it, and update the UCB tables.

    The prompt is built from the question plus a history text derived from
    ``state_ucb_score`` (via ``step_action_score2text5``; its exact format is
    defined outside this file). The model reply is split into lines of the
    form ``<step label>: <step text>``; every such line becomes a
    (state, action) pair. A line whose label contains ``final`` additionally
    records the digits of its text under the state ``'final answer'``.
    The whole reply is then sent to ``agent.validate_correctness``; if the
    verdict contains ``"yes"`` every recorded pair receives reward ``R``.

    Fixes compared to the previous version:
      * every attempt consumes the retry budget, so a reply without a
        "final" step can no longer loop forever;
      * if no request ever succeeds the function returns without rewarding
        anything instead of failing on ``None`` string concatenation;
      * the rewarded pairs always come from the reply that is validated
        (pairs parsed from an earlier, discarded reply are dropped).

    Args:
        agent: object exposing ``predict_answer(prompt, max_tokens, temperature)``
            and ``validate_correctness(prompt)``.
        sample: sequence whose first element is the question text.
        state_action_score: state -> action -> accumulated reward. Missing
            entries are created on access, so auto-creating nested mappings
            (e.g. nested ``defaultdict``) are required.
        state_ucb_score: state -> action -> value returned by ``get_ucb_score``.
        depth: unused; kept for interface compatibility.
        R: reward added to each recorded pair when the verdict contains "yes".
        state_action: state -> list of distinct actions seen for that state.
        state_counter: state -> number of passes in which the state occurred.
        state_action_counter: state -> action -> number of occurrences.
        C: constant forwarded to ``get_ucb_score``.
        tmp: sampling temperature forwarded to ``agent.predict_answer``.

    Returns:
        ``(state_action, state_counter, state_action_counter,
        state_action_score, state_ucb_score, judged_correct)`` where the first
        five items are the (mutated) input mappings and ``judged_correct`` is
        True when the validator's verdict contained "yes".
    """
    # Function-scope imports keep the block self-contained.
    import logging
    import time

    max_attempts = 3
    max_tokens = 1000
    retry_delay_seconds = 3

    question = sample[0]
    state_config = question + '\n\nSolution for Problem 4:\n'

    history = step_action_score2text5(state_ucb_score)
    if not history:
        prompt = prompt_without_history(state_config)
    else:
        prompt = prompt_with_history(state_config, history)

    raw_solution = None            # last reply received from the model, if any
    local_step_action_score = {}   # state -> action parsed from `raw_solution`
    final_found = False
    attempts_left = max_attempts

    while not final_found and attempts_left:
        attempts_left -= 1
        try:
            raw_solution = agent.predict_answer(prompt, max_tokens, tmp)
            # Start from a clean slate so the pairs match the reply above.
            local_step_action_score = {}

            for line in raw_solution.strip().split('\n'):
                if len(line) <= 5:
                    continue
                # A line without ':' raises ValueError and triggers a retry.
                step_idx, astep = line.split(':', 1)
                step_idx, astep = step_idx.strip().lower(), astep.strip().lower()
                # "+= 0.0" registers the pair without changing its score.
                state_action_score[step_idx][astep] += 0.0
                local_step_action_score[step_idx] = astep
                if 'final' in step_idx:
                    pred_ans = re.sub("[^0-9]", "", astep.strip())
                    state_action_score['final answer'][pred_ans] += 0.0
                    local_step_action_score['final answer'] = pred_ans
                    final_found = True
        except Exception as exc:
            # Best-effort: request and parsing failures are retried, but no
            # longer silently.
            logging.getLogger(__name__).warning(
                "ucb_cot attempt failed (%d attempts left): %s", attempts_left, exc)
            if attempts_left and not final_found:
                time.sleep(retry_delay_seconds)

    if raw_solution is None:
        # No reply was ever received: nothing to validate or reward.
        return state_action, state_counter, state_action_counter, state_action_score, state_ucb_score, False

    verdict = agent.validate_correctness(prompt_without_history(question + '\n\nSolution for Problem 4:\n'+raw_solution+'\n\nIs the solution and final answer provided for Problem 4 correct? If correct, then return "yes" else return "no"\n'))
    judged_correct = "yes" in verdict
    reward = R if judged_correct else 0.0

    # Each state occurs once in the local mapping, so counters, score and UCB
    # value of a pair can be updated together.
    for state, action in local_step_action_score.items():
        if action not in state_action[state]:
            state_action[state].append(action)
        state_counter[state] += 1
        state_action_counter[state][action] += 1
        state_action_score[state][action] += reward
        state_ucb_score[state][action] = get_ucb_score(state_action_score[state][action], C, state_counter[state], state_action_counter[state][action])

    return state_action, state_counter, state_action_counter, state_action_score, state_ucb_score, judged_correct
    
    

def main(args):
    """Run the sampling experiment over the GSM8K test set and report accuracy.

    For every trial, each question gets fresh score/counter tables and
    ``args.no_of_passes`` calls to ``ucb_cot``. The predicted answer is the
    ``'final answer'`` action with the highest accumulated score (the first
    one recorded wins ties), reduced to its digits. It is compared with the
    reference answer after removing thousands separators from the reference
    (previously a reference such as "1,000" could never match). The best
    action per state is printed for every question, and the number of
    correct answers is printed per trial.

    Args:
        args: parsed command-line namespace providing ``depth``, ``reward``,
            ``exploration_constant``, ``model_temperature``, ``no_of_trials``,
            ``no_of_passes`` and ``OPENAI_API_KEY``.
    """
    # get GSM8K dataset
    data_listt = get_gsm8k_dataset('gsm8k')

    depth = args.depth
    R = args.reward
    C = args.exploration_constant
    tmp = args.model_temperature

    # The agent is stateless between calls, so one instance serves all passes.
    agent = Agent(args.OPENAI_API_KEY)

    for _trial in range(args.no_of_trials):
        correct_count = 0
        print('No of samples: ', len(data_listt))
        print('R, C, tmp: ', R, C, tmp)

        # Wrapping the list itself lets tqdm know the total.
        for d in tqdm(data_listt):
            state_ucb_score = defaultdict(lambda: defaultdict(float))
            state_action_score = defaultdict(lambda: defaultdict(float))
            state_action = defaultdict(list)
            state_action_counter = defaultdict(lambda: defaultdict(int))
            state_counter = defaultdict(int)

            for _ in range(args.no_of_passes):
                state_action, state_counter, state_action_counter, state_action_score, state_ucb_score, _found_ans = ucb_cot(agent, d, state_action_score, state_ucb_score, depth, R, state_action, state_counter, state_action_counter, C, tmp)

            if 'final answer' in state_action_score:
                best_final = max(state_action_score['final answer'].items(), key=lambda kv: kv[1])[0]
                pred = re.sub("[^0-9]", "", best_final.strip())
                # Reference answers may contain thousands separators.
                expected = str(d[2]).replace(',', '').strip()
                if pred == expected:
                    correct_count += 1

            # Get the best solution:
            final_answer = [
                (st, max(actions.items(), key=lambda kv: kv[1])[0].strip())
                for st, actions in state_action_score.items()
            ]
            print(final_answer)

        print('Total number of correct answers:\t', correct_count)
    

if __name__ == '__main__':
    # Script entry point: parse the experiment settings and run `main`.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('-no_of_passes', default=10, type=int,
                        help='number of sampling passes per question')
    parser.add_argument('-no_of_trials', default=1, type=int,
                        help='number of times the whole dataset is processed')
    parser.add_argument('-depth', default=10, type=int,
                        help='depth value forwarded to ucb_cot')
    parser.add_argument('-reward', default=1, type=int,
                        help='reward given when a solution is judged correct')
    parser.add_argument('-exploration_constant', default=10, type=int,
                        help='constant forwarded to the UCB score computation')
    parser.add_argument('-model_temperature', default=0.0, type=float,
                        help='sampling temperature used when generating solutions')
    # Fall back to the environment so the secret need not be passed on the
    # command line; an empty string is still the default when neither is given.
    parser.add_argument('-OPENAI_API_KEY', default=os.environ.get('OPENAI_API_KEY', ""),
                        help='OpenAI API key (defaults to the OPENAI_API_KEY environment variable)')
    args = parser.parse_args()
    main(args)

