import os
import sys
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(current_dir))
from utils.openai_proxy import completions_with_backoff

import pandas as pd
from tqdm import tqdm
import json

import concurrent.futures

def llm_gen_once(inp, temperature, max_tokens, sample_n=1):
    """Request one completion for ``inp`` and return its text.

    Args:
        inp: Prompt passed unchanged to ``completions_with_backoff``.
        temperature: Sampling temperature forwarded to the backend call.
        max_tokens: Token limit forwarded to the backend call.
        sample_n: Number of samples to request; only ``1`` is accepted.

    Returns:
        The ``content`` of the first choice's message in the response.

    Raises:
        AssertionError: If ``sample_n`` is anything other than ``1``.
    """
    assert sample_n == 1, "deepseek api only support sample_n == 1"
    response = completions_with_backoff(
        inp,
        temperature=temperature,
        max_tokens=max_tokens,
        sample_n=sample_n,
    )
    first_choice = response['choices'][0]
    return first_choice['message']['content']


def testgen_prompt_template(test_prompt):
    prefix_prompt = "Generate only assertion statements based on the following description. Do not generate any other code:\n\n"
    output_prompt = prefix_prompt + test_prompt
    return output_prompt

def codegen_prompt_template(code_prompt):
    """Build the code-completion prompt for a code snippet.

    Args:
        code_prompt: Code text appended after the fixed instruction.

    Returns:
        The "only complete the code" instruction, a blank line, then
        ``code_prompt``.
    """
    return "Only complete the code:\n\n" + code_prompt

# Maps a prompt column name to the function that wraps that column's text
# into the final prompt sent to the model.
PROMPT_TEMPLATES = {
    "prompt_codegen": codegen_prompt_template,
    "prompt_testgen": testgen_prompt_template,
}

def load_parquet(file_path):
    """Read a parquet file and return its rows as a list of dicts.

    The frame is serialised with ``DataFrame.to_json(orient='records')`` and
    parsed back with ``json.loads``, so the values in each row are whatever
    that JSON round trip produces.

    Args:
        file_path: Location of the parquet file, passed to
            ``pandas.read_parquet``.

    Returns:
        One dict per row, keyed by column name.
    """
    records_json = pd.read_parquet(file_path).to_json(orient='records')
    return json.loads(records_json)


def write_jsonl(path, content):
    """Write ``content`` to ``path`` in JSON Lines format.

    The file is opened in write mode, so an existing file is overwritten.

    Args:
        path: Destination file path.
        content: Iterable of JSON-serialisable items; each one becomes a
            single line produced by ``json.dumps``.
    """
    with open(path, 'w') as out_file:
        out_file.writelines(json.dumps(record) + "\n" for record in content)

def generate_one_parquet(inp_parquet_path, output_jsonl_path, prompt_column, sample_n=10, temperature=1.25, max_tokens=512, max_concurrent_tasks=1000):
    """Generate completions for every row of a parquet file and save them.

    Each row is replicated ``sample_n`` times (all rows once, then all rows
    again, and so on). For every replica the text in ``prompt_column`` is
    wrapped by the matching entry of ``PROMPT_TEMPLATES`` and sent through
    ``llm_gen_once``. A successful completion is stored under the replica's
    ``'output'`` key. A replica whose request raised is kept without an
    ``'output'`` key and the error is printed.

    Args:
        inp_parquet_path: Parquet file to read rows from.
        output_jsonl_path: JSONL file the resulting records are written to.
        prompt_column: Column holding the prompt text; must also be a key of
            ``PROMPT_TEMPLATES``.
        sample_n: How many times each row is replicated.
        temperature: Sampling temperature forwarded to ``llm_gen_once``.
        max_tokens: Token limit forwarded to ``llm_gen_once``.
        max_concurrent_tasks: Upper bound on worker threads issuing requests
            at the same time.

    Returns:
        The list of replicated records that was written to
        ``output_jsonl_path``.

    Raises:
        KeyError: If ``prompt_column`` has no entry in ``PROMPT_TEMPLATES``.
        AssertionError: If the first row does not contain ``prompt_column``.
    """
    import copy

    prompt_template = PROMPT_TEMPLATES[prompt_column]
    rows = load_parquet(inp_parquet_path)
    if not rows:
        # An empty input has no first row to validate and nothing to prompt
        # with; emit an empty result file instead of failing on rows[0].
        write_jsonl(output_jsonl_path, [])
        return []
    assert prompt_column in rows[0], f"Column {prompt_column} not found in parquet file"

    # Every replica gets its own deep copy so that storing an output on one
    # replica never leaks into another replica of the same row.
    samples = [copy.deepcopy(row) for _ in range(sample_n) for row in rows]

    def get_completion(sample):
        return llm_gen_once(prompt_template(sample[prompt_column]), temperature, max_tokens)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrent_tasks) as executor:
        # Remember which replica each future belongs to so that failures can
        # be reported with the index of the affected sample.
        future_to_idx = {
            executor.submit(get_completion, sample): idx
            for idx, sample in enumerate(samples)
        }
        progress = tqdm(
            concurrent.futures.as_completed(future_to_idx),
            total=len(future_to_idx),
            desc="Generating completions",
        )
        for future in progress:
            idx = future_to_idx[future]
            try:
                samples[idx]['output'] = future.result()
            except Exception as e:
                # Best effort: one failed request must not discard the
                # completions that did succeed.
                print(f"[ERROR] An error occurred for sample {idx}: {e}")
    write_jsonl(output_jsonl_path, samples)
    return samples

if __name__ == "__main__":
    # Generate code completions for every parquet file found in the input
    # directory; results go to "<save_path>/<parquet file name>.jsonl".
    base_path = ""
    save_path = base_path + "data/deepseekcoder_saved/prompt_codegen2sol"
    inp_path = base_path + "data/selfoss_parquet"

    # Create the output directory before any generation starts; otherwise a
    # missing directory only surfaces when the results are written, after all
    # completions for the first file have already been produced.
    os.makedirs(save_path, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort so the processing
    # order and the printed idx are reproducible between runs.
    parquet_names = sorted(e for e in os.listdir(inp_path) if e.endswith("parquet"))

    for idx, parquet_name in enumerate(parquet_names):
        print("Generating idx: ", idx)
        parquet_file = os.path.join(inp_path, parquet_name)
        out_path = os.path.join(save_path, parquet_name) + ".jsonl"
        generate_one_parquet(parquet_file, out_path, 'prompt_codegen', sample_n=10, temperature=1.5, max_tokens=512)







