import argparse
import concurrent.futures
import json
import os
import random
import threading
import time

from openai import OpenAI, AzureOpenAI
from tqdm import tqdm

from models.credentials import chatgpt_api_key_list, chatgpt_0125_api_key_list, gpt_4o_api_key_list, gpt_4o_mini_api_key_list
from utils.general_utils import seed_everything
from utils.gsm8k_prompt import get_answer, zero_shot_prompt_instructwhiteBox, cnt_decompose_steps
from utils.gsm8k_prompt_singleturn import SingleTurn_decompose_instuction, SingleTurn_cot_instuction, SingleTurn_decompose_prompt
from utils.json_utils import load_jsonlines
from utils.openai_utils import check_cost

# set up time format

# Start-up timestamp; strftime uses the current local time when no time tuple is passed.
current_time = time.strftime("%Y-%m-%d %H:%M:%S")

def parse_opt(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``port``,
        ``few_shot``, ``few_gpt_shot``, ``whitebox``, ``intermediate_prompt``,
        ``add_answer`` and ``gpt_engine``.
    """
    parser = argparse.ArgumentParser(description='whiteBox3 decompose question, chatgpt inference on math problems')
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="vllm whiteBox port",
    )
    parser.add_argument(
        "--few_shot",
        type=int,
        default=0,
        help="few shot num for LLaMA",
    )
    parser.add_argument(
        "--few_gpt_shot",
        type=int,
        default=0,
        help="few shot num for GPT",
    )
    parser.add_argument("--whitebox", type=str, default='meta-llama/Meta-Llama-3-8B-Instruct', help="Use for vllm client. If use lora, set to lora id, otherwise set to base model name")
    # argparse converts the dashes to underscores: args.intermediate_prompt / args.gpt_engine.
    parser.add_argument("--intermediate-prompt", type=str, default='decompose', choices=['decompose', 'cot'], help="choose the intermediate prompt kind")
    parser.add_argument("--add_answer", action="store_true", default=False, help="whether add answer in the decomposed intermediate prompt")
    parser.add_argument("--gpt-engine", type=str, default='gpt4o_mini', choices=['gpt4o', 'gpt4o_mini', 'gpt3.5_0125', 'gpt3.5_1106'], help="choose the gpt engine version to use")
    return parser.parse_args(argv)

args = parse_opt()
seed_everything(args.seed)

# Connection settings for the local white-box model server (see the --port option).
whiteBox_api_key = ""
whiteBox_api_base = f"http://localhost:{args.port}/v1"

test_set = load_jsonlines("data/gsm8k/main/test.jsonl")
os.makedirs("logs/CacheTmp2", exist_ok=True)
# Short tag for the two known base models; any other id (e.g. a lora id) is used as given.
whitebox = {
    'meta-llama/Meta-Llama-3-8B-Instruct': 'Llama3',
    'meta-llama/Meta-Llama-3.1-8B-Instruct': 'Llama3.1',
}.get(args.whitebox, args.whitebox)
# Model ids may contain path separators ("meta-llama/..."). Embedding the raw id
# in the file name would point into a directory that does not exist, so the
# short tag is used and any remaining separator is replaced.
whitebox_tag = whitebox.replace('/', '_').replace(os.sep, '_')
log_file = f"logs/CacheTmp2/log_Llamashot{args.few_shot}_GPTshot{args.few_gpt_shot}_whitebox{whitebox_tag}_{current_time}.log"
# Swap only the trailing extension; a ".log" substring elsewhere in the name is left alone.
json_file = os.path.splitext(log_file)[0] + '.json'
# Create (or truncate) the log file so later appends start from an empty file.
with open(log_file, "w") as f:
    f.write("")

class generate_pipeline():
    """Run the GSM8K evaluation: optional white-box decomposition followed by GPT answering.

    In ``decompose`` mode every test question is first sent to the white-box
    model, and the (trimmed) decomposition is passed on to the GPT prompt. In
    ``cot`` mode the questions are sent to GPT directly. Accuracy and cost are
    accumulated on the instance and written to the log file.

    Configuration is read from the module-level ``args``, ``test_set`` and
    ``log_file``.
    """

    def __init__(self):
        """Create the API clients and reset all counters.

        Raises:
            AssertionError: in ``decompose`` mode, when ``args.whitebox`` is
                not among the models reported by the white-box server.
        """
        self.api_idx = 0
        # engine option -> (credential list, model type handed to check_cost)
        engine_table = {
            'gpt4o': (gpt_4o_api_key_list, 'gpt-4o'),
            'gpt4o_mini': (gpt_4o_mini_api_key_list, 'gpt-4o-mini'),
            'gpt3.5_0125': (chatgpt_0125_api_key_list, 'gpt-3.5-turbo'),
            'gpt3.5_1106': (chatgpt_api_key_list, 'gpt-3.5-turbo'),
        }
        self.api_key_list, self.model_type = engine_table[args.gpt_engine]
        # generate() runs the queries on worker threads; this lock serialises
        # updates to the counters, the result lists, the log file and the
        # GPT client swap.
        self._lock = threading.Lock()
        self._connect_gpt_client()

        # self.whiteBox_model = self.whiteBox_client.models.list().data[0].id
        self.max_length = 2048
        self.data = test_set
        self.question_responses = [] # question and decomposed subquestions
        self.log_file = log_file
        self.num_returns = 1
        self.cnt = self.corr = self.cost = 0
        # count decompose question steps
        self.cnt_steps = []
        self.cnt_logs = []
        self.system_choice = 1 if args.add_answer else 0

        if args.intermediate_prompt == 'decompose':
            self.whiteBox_client = OpenAI(
                api_key=whiteBox_api_key,
                base_url=whiteBox_api_base,
            )
            model_lists = [data.id for data in self.whiteBox_client.models.list().data]
            assert args.whitebox in model_lists, f"Model {args.whitebox} not in the model list"
            self.whiteBox_model = args.whitebox
            print(f'Choose model: {self.whiteBox_model} from list: {model_lists}')

    def _connect_gpt_client(self):
        """Build the Azure client and engine name from the credentials at ``self.api_idx``."""
        credentials = self.api_key_list[self.api_idx]
        self.gpt_client = AzureOpenAI(
            azure_endpoint=credentials['azure_endpoint'],
            api_key=credentials['api_key'],
            api_version=credentials['api_version'],
        )
        self.engine = credentials['engine']

    def switch_api_key(self):
        """Advance to the next credential entry (wrapping around) and reconnect."""
        with self._lock:
            self.api_idx = (self.api_idx + 1) % len(self.api_key_list)
            self._connect_gpt_client()

    @staticmethod
    def _trim_decomposition(response, strip_cot_tail):
        """Cut a white-box reply down to the decomposition part.

        Text before the "Let's break down this problem" marker is dropped.
        When ``strip_cot_tail`` is true, everything from "Let's think step by
        step" onwards is removed as well and the result is stripped.
        """
        start = "Let's break down this problem"
        if start in response:
            response = start + response.split(start)[1]
        if strip_cot_tail:
            end = "Let's think step by step"
            if end in response:
                response = response.split(end)[0].strip()
        return response

    def _accuracy(self):
        """Return ``corr / cnt``, or 0.0 when no question has been scored yet."""
        return self.corr / self.cnt if self.cnt else 0.0

    def generate(self):
        """Run the whole evaluation and write the results.

        Failed requests are retried up to three times with a 10 second pause
        between attempts; a question that still fails is skipped. Totals are
        printed and appended to the log file, and in ``decompose`` mode the
        per-question step counts are dumped to the JSON file.
        """
        print('='*50)
        print('Start generating {} data...'.format(len(self.data)))
        whiteBox_p_bar = tqdm(range(len(self.data)), desc='whiteBox processing')
        gpt_p_bar = tqdm(range(len(self.data)), desc='GPT processing')
        max_trials = 3
        retry_delay = 10  # seconds between attempts

        def whiteBox_query(question):
            """Ask the white-box model to decompose one question and store the result."""
            if args.intermediate_prompt != 'decompose':
                raise NotImplementedError
            messages = SingleTurn_decompose_instuction(question['question'],args.few_shot, self.system_choice)
            for attempt in range(1, max_trials + 1):
                try:
                    raw_response = self.whiteBox_client.chat.completions.create(
                        model=self.whiteBox_model,
                        messages=messages,
                        max_tokens=self.max_length,
                        n=self.num_returns,
                        temperature=0.0
                    )
                    raw_text = raw_response.choices[0].message.content
                    # Step count of the untrimmed reply, as tracked in cnt_steps.
                    raw_steps = cnt_decompose_steps(raw_text)
                    response = self._trim_decomposition(raw_text, self.system_choice == 1)
                    question_response = question.copy()
                    question_response['response'] = response.strip()
                    log_entry = {'question': question['question'], 'response': response, 'steps': cnt_decompose_steps(response)}
                    # Publish all three records together, only once the reply
                    # has been fully processed, so a retry cannot duplicate them.
                    with self._lock:
                        self.cnt_steps.append(raw_steps)
                        self.question_responses.append(question_response)
                        self.cnt_logs.append(log_entry)
                    whiteBox_p_bar.update(1)
                    return
                except Exception as e:
                    print(e)
                    if attempt >= max_trials:
                        print(f"Retry exceed the max_retries {attempt} times.")
                        return
                    time.sleep(retry_delay)

        def gpt_query(question_response):
            """Send one prompt to GPT, score the answer and update accuracy/cost."""
            if args.intermediate_prompt == 'decompose':
                messages = SingleTurn_decompose_prompt(question_response['question'], question_response['response'], args.few_gpt_shot)
            elif args.intermediate_prompt == 'cot':
                messages = SingleTurn_cot_instuction(question_response['question'],args.few_gpt_shot)
            for attempt in range(1, max_trials + 1):
                try:
                    # Read client and engine together so they belong to the same credential entry.
                    with self._lock:
                        client, engine = self.gpt_client, self.engine
                    raw_response = client.chat.completions.create(
                        model=engine,
                        messages=messages,
                        max_tokens=self.max_length,
                        n=self.num_returns,
                        temperature=0.0
                    )
                    response = raw_response.choices[0].message.content
                    call_cost = check_cost(raw_response.usage.prompt_tokens, raw_response.usage.completion_tokens, model_type=self.model_type)
                    # The cost is recorded right away: the call was made even
                    # if answer extraction fails below.
                    with self._lock:
                        self.cost += call_cost
                        if self.cost > 5.0:
                            with open(self.log_file, "a") as f:
                                f.write(f"Cost too high: {self.cost}\n")

                    predict_ans = get_answer(response, None, 'So the answer is')
                    true_ans = get_answer(question_response["answer"])
                    with self._lock:
                        self.cnt += 1
                        if predict_ans == true_ans:
                            self.corr += 1
                        else:
                            # Only wrong predictions are logged in detail.
                            with open(self.log_file, "a") as f:
                                f.write(f"{messages}\n")
                                f.write(f"Question: {question_response['question']}\n")
                                f.write(f"Response: {response}\n")
                                f.write(f"Predicted Answer: {predict_ans}\n")
                                f.write(f"True Answer: {true_ans}\n")
                                f.write(f"Current Accuracy: {self.corr/self.cnt:.3f}\n")
                                f.write(f"Current Cost: {self.cost:.3f}$\n\n")

                    gpt_p_bar.update(1)
                    return
                except Exception as e:
                    # Any failure rotates to the next credential entry before retrying.
                    self.switch_api_key()
                    print(e)
                    if attempt >= max_trials:
                        print(f"Retry exceed the max_retries {attempt} times.")
                        return
                    time.sleep(retry_delay)

        # NOTE: the iterators returned by executor.map are not consumed, so an
        # exception raised by a worker outside its retry loop is not re-raised here.
        if args.intermediate_prompt == 'cot':
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                executor.map(gpt_query, self.data)
        elif args.intermediate_prompt == 'decompose':
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                executor.map(whiteBox_query, self.data)
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                executor.map(gpt_query, self.question_responses)
            # Summary entry; guarded so a run with no successful decomposition does not divide by zero.
            num_logged = len(self.cnt_logs)
            mean_steps = sum(entry['steps'] for entry in self.cnt_logs) / num_logged if num_logged else 0.0
            self.cnt_logs.append({'mean': mean_steps, 'len': num_logged})
            with open (json_file, 'w') as f:
                json.dump(self.cnt_logs, f, indent=2)
        accuracy = self._accuracy()
        print(f"Total Accuracy: {accuracy:.3f}")
        print(f"Total Cost: {self.cost:.3f}$")
        with open(self.log_file, "a") as f:
            f.write(f"Total Accuracy: {accuracy:.3f}\n")
            f.write(f"Total Cost: {self.cost:.3f}$")

# Script entry: build the pipeline (this creates the API clients) and run the evaluation.
generator = generate_pipeline()
generator.generate()