import argparse
import concurrent.futures
import json
import os
import re
import time
from collections import Counter

import openai
import requests
from requests.exceptions import Timeout
from tqdm import tqdm


## call gpt4 with retry
def call_standard_gpt4_with_retry(messages, max_retries=5, delay_between_retries=6, n=1, temperature=0.7, top_p=0.95, max_tokens=800):
    """POST a chat-completion request for model "gpt-4", retrying on failure.

    Args:
        messages: chat messages, forwarded verbatim as the "messages" field.
        max_retries: total number of attempts before giving up.
        delay_between_retries: seconds to wait between consecutive attempts.
        n, temperature, top_p, max_tokens: sampling options forwarded verbatim.

    Returns:
        The ``requests.Response`` of the first attempt that answered HTTP 200.

    Raises:
        Exception: when no attempt succeeded within ``max_retries`` attempts;
            the last transport error (if any) is chained as ``__cause__``.
    """
    url = "gpt4_url"  # NOTE(review): placeholder endpoint - replace with the real gpt-4 URL

    payload = json.dumps({
        "model": "gpt-4",
        "messages": messages,
        "n": n,
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "max_tokens": max_tokens,
        "stream": False,
        "stop": None,
    })
    headers = {
        'Content-Type': 'application/json'
    }

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.request("POST", url, headers=headers, data=payload, timeout=300)
        except Timeout as e:
            print("timeout, re-trying...")
            last_error = e
        except requests.RequestException as e:
            print(f"http api call failed: {e}. re-trying...")
            last_error = e
        else:
            if response.status_code == 200:
                return response
            print(f"http api call failed:, status code: {response.status_code}. re-trying...")
        # Back off only when another attempt follows; sleeping after the last
        # failure would just delay the error below.
        if attempt < max_retries:
            time.sleep(delay_between_retries)
    raise Exception(f"http api call failed, already re-trying {max_retries} times.") from last_error



## call davinci003 with retry
def call_standard_davinci003_with_retry(messages, max_retries=5, delay_between_retries=6, n=1, temperature=0.7, top_p=0.95, max_tokens=800):
    """POST a chat-style request for model "text-davinci-003", retrying on failure.

    NOTE(review): the payload sends a chat "messages" list to a model named
    text-davinci-003, and the placeholder URL says gpt-3.5. Confirm that the
    target endpoint really accepts this request shape.

    Args:
        messages: chat messages, forwarded verbatim as the "messages" field.
        max_retries: total number of attempts before giving up.
        delay_between_retries: seconds to wait between consecutive attempts.
        n, temperature, top_p, max_tokens: sampling options forwarded verbatim.

    Returns:
        The ``requests.Response`` of the first attempt that answered HTTP 200.

    Raises:
        Exception: when no attempt succeeded within ``max_retries`` attempts;
            the last transport error (if any) is chained as ``__cause__``.
    """
    url = "gpt-3.5-url"  # NOTE(review): placeholder endpoint - replace with the real URL

    payload = json.dumps({
        "model": "text-davinci-003",
        "messages": messages,
        "n": n,
        "temperature": temperature,
        "top_p": top_p,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        "max_tokens": max_tokens,
        "stream": False,
        "stop": None,
    })
    headers = {
        'Content-Type': 'application/json'
    }

    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.request("POST", url, headers=headers, data=payload, timeout=300)
        except Timeout as e:
            print("http api call timeout, re-trying...")
            last_error = e
        except requests.RequestException as e:
            print(f"http api call failed: {e}. re-trying...")
            last_error = e
        else:
            if response.status_code == 200:
                return response
            print(f"http api call failed:, status code: {response.status_code}. re-trying...")
        # Back off only when another attempt follows; sleeping after the last
        # failure would just delay the error below.
        if attempt < max_retries:
            time.sleep(delay_between_retries)
    raise Exception(f"http api call failed,already re-try {max_retries} times。") from last_error

def _write_json(obj, path):
    """Write *obj* to *path* as indented JSON, creating the parent directory if needed."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, indent=4)


def _extract_mind_maps(json_response, use_cot):
    """Return the "mind_map" values parsed from one decoded API response.

    In cot mode every choice whose finish_reason is "stop" contributes one
    entry, and an unparsable choice contributes ``{}``. Otherwise only the
    first choice is considered, and it contributes an entry only if its
    content parses.
    """
    mind_maps = []
    choices = json_response.get('choices')
    if not choices:
        # Missing or empty "choices": nothing usable in this response.
        print(json_response)
        return mind_maps

    if use_cot:
        for choice in choices:
            if choice['finish_reason'] != 'stop':
                continue
            try:
                mind_map = json.loads(choice['message']['content'])['mind_map']
            except (ValueError, KeyError, TypeError):
                # ValueError covers json.JSONDecodeError; TypeError covers
                # null content or a non-object JSON reply.
                print(json_response)
                mind_map = {}
            mind_maps.append(mind_map)
        return mind_maps

    first_choice = choices[0]
    if first_choice['finish_reason'] == 'stop':
        message = first_choice['message']
        if 'content' in message:
            try:
                mind_maps.append(json.loads(message['content'])['mind_map'])
            except (ValueError, KeyError, TypeError):
                print(first_choice)
    return mind_maps


def get_mind_map(input_path, output_path, prompt, max_retries, delay_between_retries, use_cot, n, temperature, top_p, max_tokens, model_name, checkpoint_every=200, first_checkpoint_index=10):
    """Request a mind map of the evidence for every record in ``input_path``.

    ``input_path`` must contain a JSON list of objects that each have "claim"
    and "evidence" strings. Each record is sent as one user message after the
    system ``prompt``. The parsed results are stored on the record under
    "mind_maps"; the input record is modified in place.

    Every ``checkpoint_every`` records, the buffered records are written to
    ``<output_path without extension>_<k>.json`` and the buffer is cleared,
    with k counting up from ``first_checkpoint_index``. A falsy
    ``checkpoint_every`` disables checkpoints. Records after the last
    checkpoint are written to ``output_path`` itself.

    ``model_name == 'gpt-4'`` selects the gpt-4 caller; any other value uses
    the davinci003 caller. Errors from those callers propagate.
    """
    messages = [{"role": "system", "content": prompt}]
    user_message = "claim: [[CLAIM]], evidence: [[EVIDENCE]]"

    with open(input_path, 'r', encoding='utf-8') as f:
        datas = json.load(f)

    # splitext strips only the extension. str.split('.') would cut at the
    # first dot anywhere in the path (e.g. "./output/result.json" -> "").
    output_stem = os.path.splitext(output_path)[0]
    predict_datas = []
    middle_index = first_checkpoint_index
    for i, data in enumerate(datas):
        data_prompt = user_message.replace('[[CLAIM]]', data['claim']).replace('[[EVIDENCE]]', data['evidence'])
        messages.append({"role": "user", "content": data_prompt})
        if model_name == 'gpt-4':
            response = call_standard_gpt4_with_retry(messages, max_retries, delay_between_retries, n, temperature, top_p, max_tokens)
        else:
            response = call_standard_davinci003_with_retry(messages, max_retries, delay_between_retries, n, temperature, top_p, max_tokens)
        # Remove this record's user message so the next request carries only the system prompt.
        messages.pop()

        data['mind_maps'] = _extract_mind_maps(json.loads(response.text), use_cot)
        predict_datas.append(data)
        print('index:' + str(i))

        if checkpoint_every and (i + 1) % checkpoint_every == 0:
            output_path_middle = f"{output_stem}_{middle_index}.json"
            _write_json(predict_datas, output_path_middle)
            print("Predict result saved " + str(output_path_middle) + " !")
            predict_datas = []
            middle_index += 1

    # NOTE(review): this pause has no visible purpose here; kept from the original.
    time.sleep(2)

    _write_json(predict_datas, output_path)
    print('Predict result saved !')



def merge_data(data_path, save_path, slice, start=0):
    """Concatenate numbered checkpoint files into a single JSON list.

    Reads ``<stem>_<start>.json`` through ``<stem>_<start + slice - 1>.json``,
    where ``<stem>`` is ``data_path`` without its extension. Each file must
    hold a JSON list; the lists are concatenated in index order and written
    to ``save_path``.

    ``slice`` is the number of files to merge. The name shadows the builtin,
    but it is kept because callers may pass it by keyword. Use ``start=10``
    to merge checkpoints written with get_mind_map's default numbering.
    """
    # splitext strips only the extension, so dots in directory names are safe.
    base_path = os.path.splitext(data_path)[0]
    new_datas = []
    for i in range(start, start + slice):
        with open(f"{base_path}_{i}.json", encoding='utf-8') as f:
            new_datas.extend(json.load(f))

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(new_datas, f, indent=4)




if __name__ == '__main__':
    # System prompt sent with every request. get_mind_map parses the reply as
    # a JSON object and reads its "mind_map" key.
    prompt = 'Given a claim and corresponding evidence from user, please summarize the evidence as a mind map. The output must be in a strict JSON format: {"mind_map": "mind_map"}.'
    parser = argparse.ArgumentParser(description='Summarize claim evidence as mind maps with an LLM endpoint.')
    parser.add_argument('--input_path', type=str, default="input_path",
                        help='JSON file holding a list of records with "claim" and "evidence" fields')
    parser.add_argument('--save_path', type=str, default="./output/result.json",
                        help='output JSON path; periodic checkpoints are written next to it as <stem>_<k>.json')
    parser.add_argument('--use_cot', action='store_true', default=False,
                        help='collect a mind map from every returned choice instead of only the first')
    parser.add_argument('--n', type=int, default=1, help='number of choices requested per record')
    parser.add_argument('--temperature', type=float, default=0.7, help='sampling temperature')
    parser.add_argument('--top_p', type=float, default=0.95, help='nucleus sampling probability mass')
    parser.add_argument('--max_tokens', type=int, default=800, help='maximum tokens per completion')
    parser.add_argument('--max_retries', type=int, default=10, help='HTTP attempts per record before giving up')
    # float so sub-second back-off is possible; integer values still parse.
    parser.add_argument('--delay_between_retries', type=float, default=2,
                        help='seconds to wait between HTTP attempts')
    parser.add_argument('--model_name', type=str, default='gpt-4',
                        help="'gpt-4' uses the gpt-4 caller; any other value uses the davinci003 caller")
    args = parser.parse_args()
    get_mind_map(
        input_path=args.input_path,
        output_path=args.save_path,
        prompt=prompt,
        max_retries=args.max_retries,
        delay_between_retries=args.delay_between_retries,
        use_cot=args.use_cot,
        n=args.n,
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
        model_name=args.model_name,
    )



   


