import ast
import math
import operator
import os
import random
import re
import time

import numpy as np

from utils import get_conf, check_proxy, get_response


# Credential passed as the first positional argument to every get_response()
# call in this module.
# NOTE(review): "***" is a placeholder; avoid committing a real key here —
# consider reading it from an environment variable instead.
API_KEY = "***"

# Completion model name used for both the reasoning call and the
# answer-extraction call.
model_card = "text-davinci-003"
# For HTTP libraries that honour no_proxy (e.g. requests/urllib), '*' disables
# environment-configured proxies for every host — presumably so only the
# explicit `proxies` mapping below is used. TODO confirm intent.
os.environ['no_proxy'] = '*'
# Local SOCKS5 proxy for both schemes; the "socks5h" scheme conventionally
# means DNS is resolved through the proxy as well.
proxies = {
    "http": "socks5h://localhost:11284",
    "https": "socks5h://localhost:11284"
}
# Runs at import time. check_proxy lives in utils (not visible here) and
# presumably probes/reports the proxy; proxy_info is not read anywhere else
# in this chunk.
proxy_info = check_proxy(proxies)

# Generation settings shared by the get_response() calls in this module.
max_tokens = 256
temperature = 0.7


# Characters deleted from a raw answer before parsing: commas are thousands
# separators, parentheses wrap answers like "(18)", and currency symbols show
# up even though the extraction prompt asks for answers "without units".
_ANSWER_NOISE_CHARS = frozenset(',()$€£¥')

# Arithmetic the safe evaluator accepts in place of ``eval``.
_BINARY_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}
_UNARY_OPERATORS = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}
# Cap on ``**`` exponents so an answer such as "9**9**9" cannot stall the
# process building an astronomically large integer.
_MAX_EXPONENT = 1000


def _evaluate_numeric_expression(expression):
    """Evaluate a numeric literal or simple arithmetic expression safely.

    Replaces ``eval`` on model output: only int/float literals, unary +/- and
    the operators in ``_BINARY_OPERATORS`` are accepted, so no names,
    attribute access or calls can ever execute.

    Raises:
        ValueError: for any other construct, or an exponent above
            ``_MAX_EXPONENT``.
        SyntaxError, ZeroDivisionError, OverflowError, RecursionError:
            propagated from parsing/arithmetic on malformed input.
    """
    # Fast path for plain numbers. int() also accepts leading zeros ("05"),
    # which Python literal syntax (and therefore eval) rejects.
    for convert in (int, float):
        try:
            return convert(expression)
        except ValueError:
            pass

    def _walk(node):
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPERATORS:
            return _UNARY_OPERATORS[type(node.op)](_walk(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in _BINARY_OPERATORS:
            left = _walk(node.left)
            right = _walk(node.right)
            if isinstance(node.op, ast.Pow) and abs(right) > _MAX_EXPONENT:
                raise ValueError('exponent too large')
            return _BINARY_OPERATORS[type(node.op)](left, right)
        raise ValueError(f'unsupported expression: {expression!r}')

    return _walk(ast.parse(expression, mode='eval').body)


def post_process_value(generate_answer, location=-1):
    """Convert a raw model answer string into a number.

    The text is cleaned (commas, letters, parentheses and common currency
    symbols removed, surrounding whitespace and one trailing period dropped),
    a single whitespace-separated token is chosen, and that token is parsed:
    "50%" becomes 0.5, "3:30" becomes 3.3, and simple arithmetic such as
    "3/4" is evaluated without using ``eval``.

    Args:
        generate_answer: Raw completion text returned by the model.
        location: Index of the token to keep when the cleaned text still has
            several whitespace-separated tokens; defaults to the last one.

    Returns:
        An int or float. ``0`` is returned when nothing numeric can be
        recovered (empty text, a lone operator, or an unparseable token such
        as "1.2.3" or "1/0").

    Raises:
        IndexError: if the cleaned text has several tokens and ``location``
            is out of range for them.
    """
    answer = ''.join(
        char for char in generate_answer
        if not char.isalpha() and char not in _ANSWER_NOISE_CHARS
    ).strip()
    if answer.endswith('.'):
        answer = answer[:-1].strip()

    # Split on any whitespace (not just ' '), so "18\n20" is two tokens.
    tokens = answer.split()
    if len(tokens) > 1:
        answer = tokens[location]

    if not answer or answer in ('-', '=', '+'):
        return 0

    if '%' in answer:
        try:
            return float(answer.rstrip('%')) / 100
        except ValueError:
            return 0

    # Clock-style answers such as "3:30" are read as decimals ("3.30").
    answer = answer.replace(':', '.')
    if answer[-1] in ('.', '/'):
        answer = answer[:-1]

    try:
        return _evaluate_numeric_expression(answer)
    except (ValueError, SyntaxError, ZeroDivisionError, OverflowError, RecursionError):
        # Malformed model output is treated like an empty answer instead of
        # aborting the whole run.
        return 0


def get_arabic_number(question, reasoning):
    """Ask the model to restate the final answer of a reasoning chain as a number.

    Args:
        question: The original problem text.
        reasoning: The model's step-by-step reasoning for ``question``.

    Returns:
        ``post_process_value`` applied to the first element of the
        ``get_response`` result (presumably the first completion string —
        confirm against utils).
    """
    extraction_prompt = (
        f'Q: {question}\n'
        f'A: {reasoning}\n'
        'Therefore, the answer (expressed in Arabic numerals and without units) is:'
    )
    completions = get_response(
        API_KEY,
        prompt=extraction_prompt,
        model=model_card,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    raw_answer = completions[0]
    # Short pause between consecutive API calls.
    time.sleep(0.1)
    return post_process_value(raw_answer)


def _select_demonstrations(problems_have_been_solved, num_demonstrations):
    """Pick the solved problems to show as few-shot examples.

    Keeps the ``num_demonstrations`` samples with the highest mean
    ``'confusion_score'`` (all of them when there are no more than that),
    preserving their original order. A non-positive ``num_demonstrations``
    selects nothing.
    """
    if num_demonstrations <= 0 or not problems_have_been_solved:
        return []
    if len(problems_have_been_solved) <= num_demonstrations:
        return list(problems_have_been_solved)
    # NOTE(review): assumes every sample has a numeric (scalar or sequence)
    # 'confusion_score'; np.mean cannot average a missing key (None).
    scores = [np.mean(sample.get('confusion_score')) for sample in problems_have_been_solved]
    # Tail of the ascending argsort = highest scores. The guard above matters:
    # with num_demonstrations == 0, ``[-0:]`` would select every index.
    chosen = set(np.argsort(scores)[-num_demonstrations:].tolist())
    return [sample for idx, sample in enumerate(problems_have_been_solved) if idx in chosen]


def _build_reasoning_prompt(original_question, instructions, demonstrations):
    """Assemble the chain-of-thought prompt.

    Layout, sections joined by single newlines: an optional instruction
    preamble, optional few-shot "Q:/A:" pairs built from each sample's
    'problem' and 'reasoning_path', then the target question followed by
    "Let's think step by step."
    """
    sections = []
    # The string 'None' is the caller's sentinel for "no instructions"; a real
    # None is treated the same so the prompt never reads "...follows: None".
    if instructions is not None and instructions != 'None':
        sections.append(f"The instructions are as follows: {instructions}\nLet's consider these instructions and ignore the irrelevant conditions to solve the problem.")
    if demonstrations:
        sections.append('\n'.join(
            f"Q: {sample.get('problem')}\nA: {sample.get('reasoning_path')}"
            for sample in demonstrations
        ))
    sections.append(f"Q: {original_question}\nA: Let's think step by step.")
    return '\n'.join(sections)


def generate_answer(original_question, instructions, problems_have_been_solved, num_demonstrations):
    """Solve a question with chain-of-thought prompting, then extract a number.

    Args:
        original_question: The problem text to solve.
        instructions: Extra guidance placed before the prompt; the string
            'None' (or ``None``) means no instructions.
        problems_have_been_solved: Previously solved samples (dicts with
            'problem', 'reasoning_path' and 'confusion_score' keys) used as
            few-shot demonstrations.
        num_demonstrations: Maximum number of demonstrations to include;
            values <= 0 produce a zero-shot prompt.

    Returns:
        ``(reasoning_path, numerical_answer)``: the first element of the
        ``get_response`` result for the reasoning prompt, and the value
        ``get_arabic_number`` extracts from it.
    """
    demonstrations = _select_demonstrations(problems_have_been_solved, num_demonstrations)
    reasoning_prompt = _build_reasoning_prompt(original_question, instructions, demonstrations)
    reasoning_path = get_response(
        API_KEY,
        prompt=reasoning_prompt,
        model=model_card,
        max_tokens=max_tokens,
        temperature=temperature
    )[0]
    # Short pause between consecutive API calls.
    time.sleep(0.1)
    numerical_answer = get_arabic_number(original_question, reasoning_path)
    return reasoning_path, numerical_answer
