import os
os.environ["CUDA_VISIBLE_DEVICES"]="1,2,6,7"
import re
import json
import argparse
import pandas as pd
from openpyxl import load_workbook
from code_exec import get_exec_client, extract_code, exec_code
from code_exec_docker.jupyter_kernel_cli import ClientJupyterKernel
from tqdm import tqdm
from vllm import LLM, SamplingParams
import gc

# Llama-3-Instruct chat wrapper applied by gen_prompt: a fixed system message
# followed by a single user turn whose text replaces {prompt} (via format_map).
# NOTE(review): the template ends right after the user turn's <|eot_id|> and never
# opens an assistant turn (<|start_header_id|>assistant<|end_header_id|>\n\n),
# which the Llama 3 prompt format normally places before generation — confirm
# this is intended.
LLAMA3_FORMAT = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful AI assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|>"""

# DeepSeek-Coder instruction wrapper ({prompt} is the user instruction). It is not
# referenced anywhere in this file; the DeepSeek path in gen_solution is
# commented out.
DEEPSEEK_FORMAT = """
You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
### Instruction:
{prompt}
### Response:\n"""

def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_option(argv=None):
    """Parse the command-line options for the generation script.

    Args:
        argv: optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with ``model``, ``dataset``, ``code_exec_url``,
        ``conv_id`` and ``use_exec_feedback`` (a real bool).
    """
    # The text is a description. Passed positionally, argparse would treat it
    # as the program name (`prog`) shown in the usage line.
    parser = argparse.ArgumentParser(description="command line arguments for generate.")

    parser.add_argument('--model', type=str, default="llama3", help='model name')
    parser.add_argument('--dataset', type=str, default="dataset_cell_561", help='dataset name')
    parser.add_argument('--code_exec_url', type=str, default="http://localhost:8081/execute", help='code execution docker url')
    parser.add_argument('--conv_id', type=str, default="EVAL", help='code execution conversation id')
    # With type=str, "--use_exec_feedback False" would produce the truthy string
    # "False". Parse it as a real boolean instead. A bare flag (no value) means True.
    parser.add_argument('--use_exec_feedback', type=_str2bool, nargs='?', const=True, default=False,
                        help='whether to use code-execution feedback (true/false)')

    return parser.parse_args(argv)


def gen_prompt(input_file_path, output_path, data, file_content, prompt_format=None):
    """Build the first user turn for one spreadsheet-manipulation task.

    Args:
        input_file_path: spreadsheet path as seen by the code-execution service.
        output_path: path where the model's code must save the modified workbook.
        data: task record; must contain 'instruction', 'instruction_type'
            and 'answer_position'.
        file_content: text preview of the workbook's sheets.
        prompt_format: chat template containing a ``{prompt}`` placeholder.
            Defaults to LLAMA3_FORMAT (the previous hard-coded behaviour).
            DEEPSEEK_FORMAT is an alternative.

    Returns:
        str: the task description wrapped in ``prompt_format``.
    """
    prompt = f"""
    You are a spreadsheet expert who can manipulate spreadsheets through Python code.

You need to solve the given spreadsheet manipulation question, which contains six types of information:
- instruction: The question about spreadsheet manipulation.
- spreadsheet_path: The path of the spreadsheet file you need to manipulate.
- Excel content : The content of input excel. Due to limited input tokens, only the first few rows are provided for each sheet.
- instruction_type: There are two values (Cell-Level Manipulation, Sheet-Level Manipulation) used to indicate whether the answer to this question applies only to specific cells or to the entire worksheet.
- answer_position: The position need to be modified or filled. For Cell-Level Manipulation questions, this field is filled with the cell position; for Sheet-Level Manipulation, it is filled with the worksheet's name. You only need to modify or fill in values within the cell range or sheet specified by answer_position.
- output_path: You need to generate the modified spreadsheet file in this new path.

Below is the spreadsheet manipulation question you need to solve:
### instruction
{data['instruction']}

### spreadsheet_path
{input_file_path}

### Excel content
{file_content}

### instruction_type
{data['instruction_type']}

### answer_position
{data['answer_position']}

### output_path
{output_path}

The solution of the question can be generate through a multi-turn interaction and you can do two types of actions.
1. Spreadsheet information acquisition: You can generate Python code to obtain the information in the spreadsheet file. In the next turn, the execution result of you Python code will provide to you.
2. Question solution generation: You can generate Python code for the final solution of the question. If error occur when executing code, the error traceback will provide to you for code refinement.
    """
    if prompt_format is None:
        prompt_format = LLAMA3_FORMAT
    # format_map only parses braces in the template itself, so braces inside
    # the task text or sheet preview are inserted verbatim.
    return prompt_format.format_map({'prompt': prompt})


def gen_file_content(input_file, max_rows=10):
    """Render a short text preview of every sheet in an Excel workbook.

    Args:
        input_file: path to the workbook (anything ``pd.ExcelFile`` accepts).
        max_rows: maximum number of data rows shown per sheet (default 10).

    Returns:
        str: for each sheet, in workbook order, a ``Sheet Name: <name>`` line,
        the first ``max_rows`` rows rendered by ``DataFrame.to_string()``,
        and a separator line of 50 dashes.
    """
    separator = "-" * 50 + "\n"
    sections = []
    # The context manager closes the workbook's file handle when done.
    with pd.ExcelFile(input_file) as excel_file:
        for sheet_name in excel_file.sheet_names:
            df = excel_file.parse(sheet_name)
            # head() already copes with sheets shorter than max_rows.
            preview = df.head(max_rows).to_string()
            sections.append(f"Sheet Name: {sheet_name}\n{preview}\n{separator}")
    return "".join(sections)

def get_llama3_role(cnt, reply):
    """Wrap one conversation turn in Llama 3 chat headers.

    Args:
        cnt: turn counter. Even values are tagged "user" and odd values
            "assistant". In gen_solution the initial prompt is turn 0, model
            replies land on odd counts and execution feedback on even counts.
        reply: text of the turn (interpolated with an f-string).

    Returns:
        str: ``<|start_header_id|>ROLE<|end_header_id|>`` followed by two
        newlines, the reply, and ``<|eot_id|>``. The blank line after the
        header matches the turns written by LLAMA3_FORMAT.
    """
    role = "user" if cnt % 2 == 0 else "assistant"
    return f'<|start_header_id|>{role}<|end_header_id|>\n\n{reply}<|eot_id|>'

def get_deepseek_role(cnt , reply):
    """Append DeepSeek-Coder turn markers after one conversation turn.

    An even ``cnt`` (user turn) is followed by ``### Response:`` to cue the
    model. An odd ``cnt`` (model turn) is closed with ``<|EOT|>`` and opens
    the next ``### Instruction:`` block.
    """
    is_user_turn = cnt % 2 == 0
    trailer = '\n### Response:\n' if is_user_turn else '\n<|EOT|>\n### Instruction:\n'
    return f'{reply}{trailer}'
    
def gen_solution(opt):
    """Generate multi-turn spreadsheet solutions with the LLM and save the conversations.

    Steps:
    1. Load ``{opt.dataset}/1_561_formula_dataset.json`` and build the initial
       prompt for every task from test case 1 (``1_<id>_input.xlsx``).
    2. Run up to ``max_turns`` rounds. Each round sends the still-unfinished
       tasks to the model in one batch. The code extracted from each reply is
       executed through the code-execution client, and both the reply and the
       execution result are appended to the conversation. A task is finished
       once its expected output workbook exists on disk.
    3. Append one JSON line per task (id, instruction_type, conversation,
       last extracted solution) to
       ``{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl``.

    Paths under ``/mnt/data`` are the ones handed to the execution service.
    They correspond to ``opt.dataset`` on this machine (see the output
    existence check below).

    Args:
        opt: options from parse_option(); reads ``code_exec_url``, ``conv_id``,
            ``dataset`` and ``model``.
    """
    max_turns = 5  # generate/execute rounds per task
    suffix = 'xlsx'
    path_prefix = 'spreadsheet/'

    client = get_exec_client(opt.code_exec_url, opt.conv_id)
    with open(f'{opt.dataset}/1_561_formula_dataset.json', 'r') as fp:
        dataset = json.load(fp)

    # NOTE(review): the checkpoint path and tensor-parallel size are hard-coded
    # and do not follow --model.
    llm = LLM(model='/workspace/mzy/MODELS/LLM-Research/Meta-Llama-3-70B-Instruct', max_model_len=8192, tensor_parallel_size=4)

    directory = os.path.abspath(f'{opt.dataset}/outputs/multi_{opt.model}')
    print(directory)
    os.makedirs(directory, exist_ok=True)
    # World-writable: code run through the execution service saves the output
    # workbooks into this folder (via its /mnt/data path).
    os.chmod(directory, 0o777)

    # Build the first user turn of every conversation.
    for item in tqdm(dataset):
        item['prompts'] = ""
        item['messages'] = []
        item['flag'] = 0  # set to 1 once the task's output workbook exists
        item['cnt'] = 0   # turn counter; its parity selects the chat role

        spreadsheet_path = item['spreadsheet_path']
        # Remove the literal 'spreadsheet/' prefix. str.lstrip() would strip a
        # *set of characters* and also eat leading id characters such as
        # 'a', 'd', 's' or 't'.
        if spreadsheet_path.startswith(path_prefix):
            sheet_id = spreadsheet_path[len(path_prefix):]
        else:
            sheet_id = spreadsheet_path
        file_name = f"1_{sheet_id}_input.{suffix}"

        input_path = f"/mnt/data/{spreadsheet_path}/{file_name}"
        output_path = f"/mnt/data/outputs/multi_{opt.model}/1_{sheet_id}_output.{suffix}"
        item['output_path'] = output_path

        find_input_path = f"{opt.dataset}/{spreadsheet_path}/{file_name}"
        file_content = gen_file_content(find_input_path)
        prompt = gen_prompt(input_file_path=input_path, output_path=output_path, data=item, file_content=file_content)
        item['prompts'] += prompt
        item['messages'].append(prompt)
        item['cnt'] += 1

    # Loop-invariant: every round uses the same sampling settings.
    sampling_params = SamplingParams(max_tokens=512, n=1, stop=['<|EOT|>'])
    for _ in range(max_turns):
        for item in dataset[:5]:
            print(item['prompts'])  # debug preview of the first few conversations

        # Only unfinished tasks are sent to the model.
        active = [item for item in dataset if item['flag'] == 0]
        if not active:
            break
        responses = llm.generate([item['prompts'] for item in active], sampling_params=sampling_params)

        for item, response in tqdm(zip(active, responses), total=len(active)):
            reply = response.outputs[0].text.lstrip(' ')
            # For the DeepSeek format, use get_deepseek_role here and below instead.
            item['prompts'] += get_llama3_role(item['cnt'], reply)
            item['messages'].append(reply)
            item['cnt'] += 1

            code = extract_code(reply)
            try:
                exec_result = exec_code(client, code)
            except Exception:
                # Best effort: an execution failure becomes feedback for the next turn.
                exec_result = 'Error occur when running code.'
            item['prompts'] += get_llama3_role(item['cnt'], exec_result)
            item['messages'].append(exec_result)
            item['cnt'] += 1

            item['solution'] = code
            # The expected output appearing on disk marks the task as solved.
            if os.path.exists(item['output_path'].replace('/mnt/data', opt.dataset)):
                item['flag'] = 1

    # Append mode: re-running adds a new set of records after any existing ones.
    with open(f'{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl', 'a+') as fp:
        for item in tqdm(dataset):
            conv_result = {
                'id': item['id'],
                'instruction_type': item['instruction_type'],
                'conversation': item['messages'],
                'solution': item['solution'],
            }
            fp.write(json.dumps(conv_result, ensure_ascii=False) + '\n')
    
def run_solution(opt):
    """Re-run every saved solution against test cases 2 and 3.

    Reads the records in ``{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl``.
    For each record, the file names in its final solution are rewritten from
    test case 1 (``1_<id>_input.xlsx`` / ``1_<id>_output.xlsx``) to the other
    test cases, and the rewritten code is executed through the
    code-execution client.

    Args:
        opt: options from parse_option(); reads ``code_exec_url``, ``conv_id``,
            ``dataset`` and ``model``.
    """
    suffix = 'xlsx'
    extra_case_ids = range(2, 4)  # case 1 is the one the solution was written for

    client = get_exec_client(opt.code_exec_url, opt.conv_id)

    with open(f'{opt.dataset}/outputs/conv_multi_{opt.model}.jsonl', 'r') as fp:
        # Skip blank lines so a stray empty line does not abort json.loads.
        conv_records = [json.loads(line) for line in fp if line.strip()]

    for conv in tqdm(conv_records):
        case1_input = f"1_{conv['id']}_input.{suffix}"
        case1_output = f"1_{conv['id']}_output.{suffix}"
        for idx in extra_case_ids:
            solution = conv['solution'].replace(case1_input, f"{idx}_{conv['id']}_input.{suffix}")
            solution = solution.replace(case1_output, f"{idx}_{conv['id']}_output.{suffix}")
            try:
                exec_code(client, solution)
            except Exception as e:
                # Best effort: one failing case must not abort the remaining batch.
                print(f"[run_solution] id={conv['id']} case={idx}: execution failed: {e!r}")
            
def main():
    """Parse CLI options, generate solutions, then replay them on the extra test cases."""
    options = parse_option()
    print(options)
    gen_solution(options)
    run_solution(options)


if __name__ == "__main__":
    main()