import argparse
import random
import os
from tqdm import tqdm
from utils.json_utils import load_jsonlines
from utils.gsm8k_prompt import get_answer, zero_shot_prompt_instructwhiteBox, get_decompose, get_full_decompose
from utils.general_utils import seed_everything
from utils.openai_utils import check_cost
from utils.gsm8k_prompt_singleturn import SingleTurn_decompose_instuction, SingleTurn_decompose_prompt
from models.credentials import gpt_4o_mini_api_key_list, chatgpt_0125_api_key_list
import concurrent.futures
# set up time format
import time
from openai import OpenAI, AzureOpenAI
import json

# Wall-clock timestamp captured once at import time, formatted as
# "YYYY-MM-DD HH:MM:SS" in the local time zone (strftime falls back to
# the current local time when no time tuple is supplied).
current_time = time.strftime("%Y-%m-%d %H:%M:%S")

def parse_opt(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what a
            plain ``parse_opt()`` call does. Passing an explicit list lets
            callers and tests build the options without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``seed``, ``port``,
        ``few_shot``, ``few_gpt_shot``, ``dataset``, ``output_dir`` and
        ``add_answer``.
    """
    parser = argparse.ArgumentParser(description='whiteBox3 decompose question, chatgpt inference on math problems')
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="vllm whiteBox port",
    )
    parser.add_argument(
        "--few_shot",
        type=int,
        default=0,
        help="few shot num for LLaMA",
    )
    parser.add_argument(
        "--few_gpt_shot",
        type=int,
        default=0,
        help="few shot num for GPT",
    )
    parser.add_argument("--dataset", type=str, choices=['train', 'test'], default='train', help="choose which dataset to generate")
    parser.add_argument("--output_dir", type=str, default='data/gsm8k/gpt3.5', help="choose where to store the generated data")
    # store_true already defaults to False; the flag only needs to be present.
    parser.add_argument("--add_answer", action="store_true", help="whether add answer in the decomposed intermediate prompt")

    return parser.parse_args(argv)

# Module-level setup: everything below runs at import time, and the
# pipeline class further down reads these globals directly.
args = parse_opt()
seed_everything(args.seed)

# Connection settings for the local OpenAI-compatible "whiteBox" server
# (the --port help text describes it as a vllm port). The API key is left
# empty here.
whiteBox_api_key = ""
whiteBox_api_base = f"http://localhost:{args.port}/v1"

# Both splits are loaded up front; the pipeline picks one via --dataset.
train_set = load_jsonlines("data/gsm8k/socratic/train_socratic.jsonl")
test_set = load_jsonlines("data/gsm8k/socratic/test_socratic.jsonl")

# Output goes to <output_dir>/<dataset>/train.json.
# NOTE(review): the file is named train.json even when --dataset is
# 'test' -- confirm downstream consumers expect that name.
save_path = os.path.join(args.output_dir, f'{args.dataset}')
os.makedirs(save_path, exist_ok=True)
json_file = f"{save_path}/train.json"

class generate_pipeline():
    """Two-stage pipeline that builds positive/negative decomposition pairs.

    Stage 1 asks the local whiteBox model for ``num_returns`` sampled
    decompositions of every question. Stage 2 feeds each decomposition to
    the Azure OpenAI deployment and compares the extracted answer with the
    reference answer; matching decompositions are stored under
    ``'positive'``, the others under ``'negative'``. The collected records
    are dumped as JSON to ``self.json_file``.

    Both stages run in a thread pool, so shared counters, the result list
    and the checkpoint file are only touched while holding
    ``self._state_lock``.
    """

    def __init__(self):
        # Function-scope import so this class does not depend on a
        # file-level ``import threading``.
        import threading

        self.api_idx = 0
        self.api_key_list = chatgpt_0125_api_key_list
        self.gpt_client = AzureOpenAI(
            azure_endpoint = self.api_key_list[self.api_idx]['azure_endpoint'],
            api_key = self.api_key_list[self.api_idx]['api_key'],
            api_version = self.api_key_list[self.api_idx]['api_version'],
        )
        self.engine = self.api_key_list[self.api_idx]['engine']

        self.whiteBox_client = OpenAI(
            api_key=whiteBox_api_key,
            base_url=whiteBox_api_base,
        )
        # Use the first model advertised by the local server.
        self.whiteBox_model = self.whiteBox_client.models.list().data[0].id
        self.max_length = 512
        self.data = train_set if args.dataset == 'train' else test_set
        self.question_responses = [] # question and decomposed subquestions
        self.json_file = json_file
        self.num_returns = 10
        self.cnt = self.corr = self.cost = 0
        print(f'Model: {self.whiteBox_model}')
        self.contrsative_pairs = []
        self.system_choice = 1 if args.add_answer else 0
        # Guards cnt/cost, the shared result lists, the checkpoint write and
        # API-key rotation against concurrent worker threads.
        self._state_lock = threading.Lock()

    def switch_api_key(self):
        """Rotate to the next credential entry and rebuild the Azure client."""
        with self._state_lock:
            self.api_idx = (self.api_idx + 1) % len(self.api_key_list)
            credentials = self.api_key_list[self.api_idx]
            self.gpt_client = AzureOpenAI(
                api_key = credentials['api_key'],
                api_version = credentials['api_version'],
                azure_endpoint = credentials['azure_endpoint'],
            )
            self.engine = credentials['engine']

    @staticmethod
    def _truncate_at_last_question(pre_responses):
        """Cut every raw completion right after its final '?'.

        Completions that contain no '?' are dropped. Two aligned lists are
        returned -- the truncated texts and the raw completions they came
        from -- so both can later be zipped without pairing a truncated
        decomposition with the wrong raw completion.
        """
        truncated, kept_raw = [], []
        for raw in pre_responses:
            last_question_index = raw.rfind('?')
            if last_question_index == -1:
                continue
            truncated.append(raw[:last_question_index + 1].strip())
            kept_raw.append(raw)
        return truncated, kept_raw

    @staticmethod
    def _run_parallel(worker, items, max_workers=10):
        """Run ``worker(item)`` for every item in a thread pool.

        An exception escaping a worker is printed instead of being silently
        discarded; it does not abort the remaining items.
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(worker, item) for item in items]
            for future in concurrent.futures.as_completed(futures):
                error = future.exception()
                if error is not None:
                    print(f"Worker failed: {error!r}")

    def generate(self):
        """Run both stages over ``self.data`` and write the JSON output.

        Side effects: fills ``self.question_responses`` and
        ``self.contrsative_pairs``, accumulates ``self.cost``, rewrites
        ``self.json_file`` after every 200 finished questions and once more
        at the end, and prints progress and the total cost.
        """
        print('='*50)
        print('Start generating {} data...'.format(len(self.data)))
        whiteBox_p_bar = tqdm(total=len(self.data), desc='whiteBox processing')
        gpt_p_bar = tqdm(total=len(self.data) * self.num_returns, desc='GPT processing')

        def whiteBox_query(question):
            """Sample decompositions for one question (up to 3 attempts)."""
            messages = SingleTurn_decompose_instuction(question['question'],args.few_shot, self.system_choice)
            max_trials = 3
            for attempt in range(1, max_trials + 1):
                try:
                    raw_response = self.whiteBox_client.chat.completions.create(
                        model=self.whiteBox_model,
                        messages=messages,
                        max_tokens=self.max_length,
                        n=self.num_returns,
                        temperature=1.0
                    )
                    pre_responses = [choice.message.content for choice in raw_response.choices]
                    if self.system_choice == 1:
                        # Keep the raw completions aligned with the truncated
                        # ones; gpt_query zips the two lists together.
                        response, pre_responses = self._truncate_at_last_question(pre_responses)
                    else:
                        response = pre_responses
                except Exception as e:
                    print(e)
                    if attempt == max_trials:
                        # The question is skipped after the final failure.
                        print(f"Retry exceed the max_retries {attempt} times.")
                        return
                    time.sleep(10)
                    continue

                question_response = question.copy()
                question_response['response'] = response
                question_response['pre_response'] = pre_responses
                with self._state_lock:
                    self.question_responses.append(question_response)
                whiteBox_p_bar.update(1)
                return

        def gpt_query(question_response):
            """Score every decomposition of one question and record the result."""
            # Local to this call, so it needs no locking until it is published.
            contrastive_pair = {
                'question': question_response['question'],
                'answer': question_response['answer'],
                'positive': [],
                'negative': [],
            }

            def _gpt_call(llama_response, full_response=None):
                """Answer with one decomposition; return whether it matched.

                Returns True/False on success and None when all attempts
                failed.
                """
                messages = SingleTurn_decompose_prompt(question_response['question'], llama_response, args.few_gpt_shot)
                max_trials = 3
                for attempt in range(1, max_trials + 1):
                    try:
                        # Snapshot client and engine together so a concurrent
                        # key rotation cannot mix an old client with a new engine.
                        with self._state_lock:
                            client, engine = self.gpt_client, self.engine
                        raw_response = client.chat.completions.create(
                            model=engine,
                            messages=messages,
                            max_tokens=self.max_length,
                            n=1,
                            temperature=0.0
                        )
                        response = raw_response.choices[0].message.content
                        if response is None or not response.strip():
                            raise RuntimeError("Empty response")
                        # NOTE(review): cost is priced as 'gpt-4o-mini' while the
                        # credentials come from chatgpt_0125_api_key_list -- confirm.
                        call_cost = check_cost(raw_response.usage.prompt_tokens, raw_response.usage.completion_tokens, model_type='gpt-4o-mini')
                        with self._state_lock:
                            self.cost += call_cost
                            # Only a warning; generation is not stopped.
                            if self.cost > 30.0:
                                print(f"Cost too high: {self.cost}")

                        predict_ans = get_answer(response, None, 'So the answer is')
                        true_ans = get_answer(question_response["answer"])
                    except Exception as e:
                        self.switch_api_key()
                        print(e)
                        if attempt == max_trials:
                            print(f"Retry exceed the max_retries {attempt} times.")
                            return None
                        time.sleep(10)
                        continue

                    reserve_dict = {'decompose': llama_response, 'response': response}
                    if self.system_choice == 1:
                        reserve_dict['full_decompose'] = full_response

                    is_correct = predict_ans == true_ans
                    bucket = 'positive' if is_correct else 'negative'
                    contrastive_pair[bucket].append(reserve_dict)

                    gpt_p_bar.update(1)
                    return is_correct
                return None

            for llama_response, pre_response in zip(question_response['response'], question_response['pre_response']):
                _gpt_call(llama_response, pre_response)
            # we handle the case when there is no positive samples
            if not contrastive_pair['positive']:
                question_only = not args.add_answer
                gt_answer = _gpt_call(get_decompose(question_response['answer'], question_only), get_full_decompose(question_response['answer']))
                print(f'Ground truth decomposition have {gt_answer} answer')

            # Publish the record and checkpoint under one lock so the counter,
            # the list and the file on disk stay consistent across threads.
            with self._state_lock:
                self.cnt += 1
                self.contrsative_pairs.append(contrastive_pair)
                if self.cnt % 200 == 0:
                    with open(self.json_file, "w") as f:
                        json.dump(self.contrsative_pairs, f, indent=2)

        # Stage 1 must finish completely before stage 2 starts, because
        # stage 2 iterates over the list stage 1 fills.
        self._run_parallel(whiteBox_query, self.data)
        self._run_parallel(gpt_query, self.question_responses)
        whiteBox_p_bar.close()
        gpt_p_bar.close()

        with open(self.json_file, "w") as f:
            json.dump(self.contrsative_pairs, f, indent=2)

        print(f"Total Cost: {self.cost:.3f}$")

# Script entry: build the pipeline (this connects to the local whiteBox
# server and creates the Azure client) and run generation immediately.
generator = generate_pipeline()
generator.generate()