import os
import json
from tqdm import tqdm
import openai
import time
import numpy as np
import argparse

from utils import *

def get_table_input(data_files, num_of_rows, data_dir='../../evaluation_data/data/', random_state=None):
    """Render a random sample of rows from each CSV file as a Markdown table.

    Each file is read from ``data_dir``, its rows are shuffled, and the first
    ``num_of_rows`` shuffled rows are rendered with ``DataFrame.to_markdown``
    under a ``"<file>:"`` header line.

    Args:
        data_files: CSV file names, relative to ``data_dir``.
        num_of_rows: Number of rows to show per file.
        data_dir: Directory containing the CSV files.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the
            sampled rows (and therefore the prompts) are reproducible.
            ``None`` keeps the original unseeded behaviour.

    Returns:
        The per-file tables joined by blank lines.
    """
    # Explicit import instead of relying on `pd` leaking in via `from utils import *`.
    import pandas as pd

    tables = []
    for file in data_files:
        df = pd.read_csv(os.path.join(data_dir, file))
        # Shuffle so the preview is a random sample rather than the head of the file.
        shuffled = df.sample(frac=1, random_state=random_state).reset_index(drop=True)
        tables.append(f'{file}:\n{shuffled.head(num_of_rows).to_markdown()}')
    return '\n\n'.join(tables)


def get_pot_input_qrdata(entry, num_of_rows = 10, max_words=3000):
    """Build a zero-shot program-of-thought prompt for one QRData record.

    The prompt contains the data description, a random sample of rows from
    each data file (rendered by ``get_table_input``), and the question. The
    row count starts at ``num_of_rows`` and is decreased one at a time until
    the prompt body (excluding the instruction) has fewer than ``max_words``
    whitespace-separated words; if even one row is too long, the one-row
    prompt is used. The instruction states the row count actually shown.

    Args:
        entry: QRData record with ``'question'``, ``'data_description'`` and
            ``'data_files'`` keys.
        num_of_rows: Maximum number of sample rows to show per file.
        max_words: Word budget for the prompt body.

    Returns:
        The instruction followed by the prompt body, which ends with an open
        ```python fence for the model to continue.

    Raises:
        ValueError: If ``num_of_rows`` is not positive (previously this
            surfaced as an ``UnboundLocalError``).
    """
    if num_of_rows < 1:
        raise ValueError(f'num_of_rows must be positive, got {num_of_rows}')

    question = entry['question']
    data_description = entry['data_description']

    for rows_shown in range(num_of_rows, 0, -1):
        # Each call re-samples the rows, so shrinking tables may show different rows.
        table = get_table_input(entry['data_files'], rows_shown)
        prompt = (
            '\n'
            'Data Description:\n'
            f'{data_description}\n'
            '\n'
            f'First {rows_shown} rows of the data:\n'
            f'{table}\n'
            '\n'
            'Question:\n'
            f'{question}\n'
            '\n'
            'Response:\n'
            '```python\n'
        )
        if len(prompt.split()) < max_words:
            break

    # The row count here must match the table actually included above
    # (the original text always said "first 10 rows").
    instruction = (
        'You are a data analyst and good at quantitative reasoning. '
        'You are required to respond to a quantitative question using the provided data. '
        f'The data description, first {rows_shown} rows of the data, and the question can be found below. '
        'Please write python code to analyze the whole data and answer the question. '
        'Please encase the Python code within triple backticks. '
        'You can use any python library you imported. '
        'The returned value of the code is supposed to be the answer. '
        'The format of the code should be\n'
        '```python\n'
        'def solution():\n'
        '    # import libraries if needed\n'
        '\n'
        '    # load data\n'
        '    \n'
        '    # write code to get the answer\n'
        '    \n'
        '    # return answer\n'
        '```\n'
    )
    return instruction + prompt


def get_pot_input_theoremqa(entry):
    """Return the zero-shot program-of-thought prompt for one TheoremQA record.

    Only ``entry["Question"]`` is read. The result is a fixed instruction
    (asking the model for a ``solution()`` function inside a fenced Python
    block), followed by the question and a trailing ``Response:`` line.
    """
    header_lines = [
        "You are a data analyst and good at quantitative reasoning. "
        "You are required to respond to a quantitative question below. "
        "Please write python code to answer the question. "
        "Please encase the Python code within triple backticks. "
        "You can use any python library you imported. "
        "The returned value of the code is supposed to be the answer. "
        "The format of the code should be",
        "```python",
        "def solution():",
        "    # import libraries if needed",
        "",
        "    # write code to get the answer",
        "    ",
        "    # return answer",
        "```",
    ]
    question_lines = ["Question:", f'{entry["Question"]}', "", "Response:"]
    return "\n".join(header_lines + question_lines) + "\n"


def get_pot_input_scibench(entry):
    """Build a zero-shot program-of-thought prompt for one SciBench record.

    The expected unit is appended to the problem text as
    ``" The unit of the answer is <unit>."``. ``remove_not`` (from ``utils``)
    is applied to the raw unit first; when it returns a truthy value that
    value is used, otherwise the raw unit is kept.
    NOTE(review): presumably ``remove_not`` strips a scale prefix (e.g. a
    ``10^x`` factor) from the unit -- confirm against utils.

    Args:
        entry: SciBench record with ``"problem_text"`` and ``"unit"`` keys.

    Returns:
        The instruction (asking for a ``solution()`` function in a fenced
        Python block), followed by the question and a ``Response:`` line.
    """
    instruction = (
        'You are a data analyst and good at quantitative reasoning. '
        'You are required to respond to a quantitative question below. '
        'Please write python code to answer the question. '
        'Please encase the Python code within triple backticks. '
        'You can use any python library you imported. '
        'The returned value of the code is supposed to be the answer. '
        'The format of the code should be\n'
        '```python\n'
        'def solution():\n'
        '    # import libraries if needed\n'
        '\n'
        '    # write code to get the answer\n'
        '    \n'
        '    # return answer\n'
        '```\n'
    )
    # Evaluate remove_not once (the original called it twice) and fall back
    # to the raw unit when it yields nothing.
    unit_prob = remove_not(entry["unit"]) or entry["unit"]
    problem_text = entry["problem_text"] + " The unit of the answer is " + unit_prob + "."

    prompt = f'Question:\n{problem_text}\n\nResponse:\n'
    return instruction + prompt


if __name__ == '__main__':
    # Each supported domain maps to (evaluation data file, per-record prompt builder).
    domain_sources = {
        'causality': ('../../evaluation_data/qrdata_causal.json', get_pot_input_qrdata),
        'physics': ('../../evaluation_data/theoremqa_phy.json', get_pot_input_theoremqa),
        'chemistry': ('../../evaluation_data/scibench_chem.json', get_pot_input_scibench),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--domain', type=str, default='')
    parser.add_argument('--max_tokens', type=int, default=2048)
    args = parser.parse_args()

    # Without this check an unknown (or omitted) domain left `data` unbound
    # and the script died later with a NameError.
    if args.domain not in domain_sources:
        parser.error(f'--domain must be one of {sorted(domain_sources)}, got {args.domain!r}')

    data_path, build_prompt = domain_sources[args.domain]
    with open(data_path, encoding='utf-8') as f:
        data = json.load(f)
    input_prompts = [build_prompt(d) for d in data]

    path = 'outputs/'
    os.makedirs(path, exist_ok=True)
    output_folder = os.path.join(path, 'tmp_response')
    all_responses = run_inference(input_prompts, output_folder, args)

    # Responses are expected to be in the same order as the prompts.
    for idx, entry in enumerate(data):
        entry['output'] = all_responses[idx]

    output_file = os.path.join(path, f'{args.model_name}_{args.domain}_pot_0shot.json')
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)

    remove_tmp_files(output_folder)