import os
import json
import argparse
from tqdm import tqdm
from llm_api import get_llm_response
from code_exec import get_exec_client, extract_code, exec_code
from prompt_ablation_format import PROMPT_FORMAT_SINGLE, PROMPT_DF_RCT_FORMAT , PROMPT_NO_DF_RCT_FORMAT
import pandas as pd

def parse_option(args=None):
    """Parse the command-line options for solution generation.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``model``,
        ``dataset``, ``code_exec_url``, ``conv_id``, ``max_turn_num`` and
        ``row``.
    """
    parser = argparse.ArgumentParser("command line arguments for generation.")

    parser.add_argument('--model', type=str, default="gpt-4o", help='model name')
    parser.add_argument('--dataset', type=str, default="ablation_200", help='dataset name')
    parser.add_argument('--code_exec_url', type=str, default="http://localhost:8081/execute", help='code execution docker url')
    parser.add_argument('--conv_id', type=str, default="EVAL", help='code execution conversation id')
    parser.add_argument('--max_turn_num', type=int, default=5, help='max turn number of conversation')
    parser.add_argument('--row', type=int, default=5, help='df extracted rows')

    # Parse exactly once and hand the namespace straight back to the caller.
    return parser.parse_args(args)


def gen_file_content(input_file, row=None):
    """Render a text preview of every sheet in an Excel workbook.

    Args:
        input_file: Workbook location handed to ``pd.ExcelFile``.
        row: Number of leading rows to keep per sheet. When ``None`` the
            value is taken from the module-level ``opt.row`` (the parsed
            command-line options set by the script entry point).

    Returns:
        str: For each sheet, in workbook order, a ``Sheet Name: <name>``
        line, the ``DataFrame.to_string()`` rendering of its first rows and
        a separator line made of 50 dashes.
    """
    if row is None:
        # Fall back to the global CLI options so existing one-argument
        # callers keep working unchanged.
        row = opt.row

    sections = []
    # The context manager closes the workbook handle once every sheet is read.
    with pd.ExcelFile(input_file) as excel_file:
        for sheet_name in excel_file.sheet_names:
            df = excel_file.parse(sheet_name)  # load the current sheet
            # head() already caps at the number of rows available, so no
            # manual clamping against df.shape[0] is needed.
            preview = df.head(row).to_string()
            sections.append(f"Sheet Name: {sheet_name}\n{preview}\n{'-' * 50}\n")

    return "".join(sections)

def gen_solution(opt):
    """Run the multi-turn LLM / code-execution loop for every dataset entry.

    For each entry of ``{opt.dataset}/dataset.json`` a prompt is built from
    the instruction and a text preview of the entry's first input workbook.
    The model response and the result of executing its code are then
    appended to the conversation for at most ``opt.max_turn_num`` turns; the
    loop stops early as soon as the expected output workbook exists. One JSON
    line per entry is appended to
    ``{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl``.

    Args:
        opt: Parsed options; this function reads ``code_exec_url``,
            ``conv_id``, ``dataset``, ``model`` and ``max_turn_num``.
    """
    client = get_exec_client(opt.code_exec_url, opt.conv_id)
    with open(f'{opt.dataset}/dataset.json', 'r') as fp:
        dataset = json.load(fp)

    folder_path = f'{opt.dataset}/outputs/multi_{opt.model}'

    directory = os.path.abspath(folder_path)
    print(directory)
    if not os.path.exists(directory):
        os.makedirs(directory)
        # Open the folder up to everyone (mode 777) -- presumably so the code
        # execution service can write into it; confirm against deployment.
        os.chmod(directory, 0o777)

    suffix = 'xlsx'
    path_prefix = 'spreadsheet/'
    for data in tqdm(dataset):
        spreadsheet_path = data['spreadsheet_path']
        # Remove the literal prefix only. str.lstrip()/str.rstrip() strip a
        # *set of characters*, which would also eat leading/trailing
        # characters of the id itself (e.g. "sales" -> "les").
        if spreadsheet_path.startswith(path_prefix):
            sheet_id = spreadsheet_path[len(path_prefix):]
        else:
            sheet_id = spreadsheet_path

        file_name = f"1_{sheet_id}_input.{suffix}"
        output_name = f"1_{sheet_id}_output.{suffix}"
        # Paths as passed to the prompt (under /mnt/data) ...
        input_path = f"/mnt/data/{spreadsheet_path}/{file_name}"
        output_path = f"/mnt/data/outputs/multi_{opt.model}/{output_name}"
        # ... and the matching locations under the local dataset folder.
        find_input_path = f"{opt.dataset}/{spreadsheet_path}/{file_name}"
        local_output_path = output_path.replace('/mnt/data', opt.dataset)

        file_content = gen_file_content(find_input_path)
        prompt = PROMPT_FORMAT_SINGLE.format_map({
            'instruction': data['instruction'],
            'spreadsheet_path': input_path,
            'spreadsheet_content': file_content,
            'instruction_type': data['instruction_type'],
            'answer_position': data['answer_position'],
            'max_turn_num': opt.max_turn_num,
            'output_path': output_path
        })
        messages = [prompt]
        # Code taken from the most recent response it could be extracted
        # from; stays empty if there were no turns or extraction never worked.
        solution = ''
        for _ in tqdm(range(opt.max_turn_num)):
            response = get_llm_response(messages, opt.model)
            messages.append(response)
            try:
                code = extract_code(response)
                solution = code
                exec_result = exec_code(client, code)
            except Exception:
                # Best effort: a failed extraction/execution is fed back to
                # the model as feedback instead of aborting the whole run.
                exec_result = 'Error occur when running code.'
            messages.append(exec_result)
            # Stop as soon as the expected output workbook has been produced.
            if os.path.exists(local_output_path):
                break
        conv_result = {
            'id': data['id'],
            'instruction_type': data['instruction_type'],
            'conversation': messages,
            'solution': solution
        }
        # Append after every entry so finished work survives a later crash.
        with open(f'{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl', 'a+') as fp:
            fp.write(json.dumps(conv_result, ensure_ascii=False) + '\n')


def run_solution(opt):
    """Re-execute every recorded solution for file indices 1 through 3.

    Reads the conversation records from
    ``{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl`` and, for each one,
    rewrites the ``1_<id>_input.xlsx`` / ``1_<id>_output.xlsx`` file names
    inside the stored solution code to the current index before sending the
    code to the execution client.

    Args:
        opt: Parsed options; this function reads ``code_exec_url``,
            ``conv_id``, ``dataset`` and ``model``.
    """
    executor = get_exec_client(opt.code_exec_url, opt.conv_id)

    record_path = f'{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl'
    with open(record_path, 'r') as handle:
        # Load every record up front, one JSON document per line.
        records = [json.loads(raw_line) for raw_line in handle]

    ext = 'xlsx'
    for record in tqdm(records):
        # File names the stored solution was originally written against.
        first_input = f"1_{record['id']}_input.{ext}"
        first_output = f"1_{record['id']}_output.{ext}"
        for case_no in (1, 2, 3):
            code = (
                record['solution']
                .replace(first_input, f"{case_no}_{record['id']}_input.{ext}")
                .replace(first_output, f"{case_no}_{record['id']}_output.{ext}")
            )
            exec_code(executor, code)


if __name__ == '__main__':
    # Script entry point. The parsed options must stay bound to the
    # module-level name `opt`: gen_file_content() looks up `opt.row` from
    # this global.
    opt = parse_option()
    print(opt)
    # Generate the solutions with the LLM first, then replay the recorded
    # solutions via run_solution().
    gen_solution(opt)
    run_solution(opt)
