import json



import argparse
from prompt import PromptManager
from extract_json_code import extract_json_code, extract_diff_code

import time
from copy import deepcopy


import os
import tqdm

from vllm import LLM, SamplingParams
from transformers import AutoTokenizer














class Retry:
    """Context manager + iterator that retries a block of code when it raises.

    Typical usage::

        for retry in Retry(3):
            with retry:
                do_work()
        if retry.allright:
            ...

    Each iteration yields the same ``Retry`` object. Entering it as a context
    manager resets ``allright`` to False; leaving the ``with`` body without an
    exception sets it to True, which ends the iteration. Ordinary exceptions
    (subclasses of ``Exception``) are printed and suppressed so the next
    attempt can run.

    Args:
        max_tries: Maximum number of attempts yielded by the iterator.
        delay: Seconds to sleep between two attempts (not after the last one).
    """

    def __init__(self, max_tries=5, delay=3):
        self.max_tries = max_tries
        self.delay = delay
        # Defined up front so the attribute exists even if the caller
        # iterates without ever entering the ``with`` block.
        self.allright = False

    def __iter__(self):
        for i in range(self.max_tries):
            yield self
            # Stop on success, or after the final attempt (no trailing sleep).
            if self.allright or i == self.max_tries - 1:
                return
            time.sleep(self.delay)

    def __enter__(self):
        self.allright = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            self.allright = True
            return False
        if not issubclass(exc_type, Exception):
            # KeyboardInterrupt, SystemExit, ... must not be swallowed,
            # otherwise the process cannot be interrupted during retries.
            return False
        print(exc_val)
        return True


def format_sft(instruction, output, history=None):
    """Build one SFT record from an instruction/output pair.

    The ``history`` key is only added when ``history`` is truthy, so ``None``
    and empty containers are left out of the record.
    """
    record = dict(instruction=instruction, output=output)
    if not history:
        return record
    record["history"] = history
    return record

def add_line_numbers(file_info):
    """Prefix every line of ``file_info["content"]`` with its 1-based number.

    Each line becomes ``"<number> <text>"``. The dict is modified in place
    and also returned.
    """
    source_lines = file_info["content"].splitlines()
    file_info["content"] = "\n".join(
        f"{number} {text}" for number, text in enumerate(source_lines, start=1)
    )
    return file_info


def pretty_json(data) -> str:
    """Serialize ``data`` as 2-space-indented JSON, leaving non-ASCII characters unescaped."""
    dump_options = {"indent": 2, "ensure_ascii": False}
    return json.dumps(data, **dump_options)


def generate_issue(code, prompts, message=None):
    """Ask the model to find an issue in ``code``.

    The prompt is the "寻找Issue" (find-issue) instruction, the
    pretty-printed code, an optional error-message section (only when
    ``message`` is truthy), and the same instruction repeated at the end.

    Returns:
        Whatever ``extract_json_code`` extracts from the model's reply.
    """
    segments = [
        prompts.get_prompt("寻找Issue"),
        "\nThe following is the specific code information. The first digit of each line represents the line number of the current code, followed by a space and the specific code content:\n```json\n",
        pretty_json(code),
        "\n```\n",
    ]
    if message:
        segments.extend(
            [
                "\nThe following is the specific error message information:\n```\n",
                message,
                "\n```\n",
            ]
        )
    # The instruction is requested a second time and appended after the context.
    segments.append(prompts.get_prompt("寻找Issue"))
    reply = respond("".join(segments))
    return extract_json_code(reply)


def generate_issue_explanation(issue, code, prompts, message=None):
    """Ask the model to explain ``issue`` in the context of ``code``.

    The prompt is the "解释Issue" (explain-issue) instruction followed by
    the pretty-printed issue and code, plus an error-message section when
    ``message`` is truthy.

    Returns:
        The model's reply exactly as returned by ``respond``.
    """
    segments = [
        prompts.get_prompt("解释Issue"),
        "\nThe following is the specific Issue information:\n```json\n",
        pretty_json(issue),
        "```\nThe following is the specific Code information. The first digit of each line represents the line number of the current code, followed by a space and the specific code content:\n```json\n",
        pretty_json(code),
        "```",
    ]
    if message:
        segments += [
            "\nThe following is the specific error message information:\n```\n",
            message,
            "```",
        ]
    return respond("".join(segments))


def generate_code_location(issue, code, prompts, message=None):
    """Ask the model where in ``code`` the given ``issue`` is located.

    The prompt is the "定位Issue" (locate-issue) instruction, the
    pretty-printed issue and code, an optional error-message section (only
    when ``message`` is truthy), and the instruction repeated at the end.
    The request is attempted up to three times until the extracted answer
    parses as JSON.

    Returns:
        The parsed JSON value when an attempt succeeds. Otherwise
        ``[{"location": raw}]``, where ``raw`` is the last text obtained from
        ``extract_json_code`` that failed to parse, or ``None`` when no
        attempt got that far.
    """
    prompt_text = (
        prompts.get_prompt("定位Issue")
        + "\nThe following is the specific Issue information:\n```json\n"
        + pretty_json(issue)
        + "```\nThe following is the specific Code information. The first digit of each line represents the line number of the current code, followed by a space and the specific code content:\n```json\n"
        + pretty_json(code)
        + "```"
    )
    if message:
        prompt_text += (
            "\nThe following is the specific error message information:\n```\n"
            + message
            + "\n```\n"
        )
    prompt_text += prompts.get_prompt("定位Issue")

    # Pre-bind the result: if every attempt raises before anything is
    # assigned, the fallback below must not hit an unbound local.
    res = None
    for retry in Retry(3):
        with retry:
            response = respond(prompt_text)
            res = extract_json_code(response)
            res = json.loads(res)
    if not retry.allright:
        # Keep the unparsed answer so the caller still gets something to store.
        res = [{"location": res}]
    return res


def generate_code_patch(issue, code, prompts, localization=None):
    """Ask the model for a patch that resolves ``issue`` in ``code``.

    The prompt is the "解决Issue" (solve-issue) instruction, the
    pretty-printed issue and code, an optional localization section (only
    when ``localization`` is truthy), and the instruction repeated at the
    end. The request is attempted up to three times.

    Returns:
        The value produced by ``extract_diff_code`` on the first attempt that
        does not raise, or ``None`` when every attempt raised.
    """
    prompt_text = (
        prompts.get_prompt("解决Issue")
        + "\nThe following is the specific Issue information:\n```json\n"
        + pretty_json(issue)
        + "```\nThe following is the specific Code information. The first digit of each line represents the line number of the current code, followed by a space and the specific code content:\n```json\n"
        + pretty_json(code)
        + "```\n"
    )
    if localization:
        prompt_text += (
            "\nThe following is the specific code localization which may cause the issue information:\n```json\n"
            + pretty_json(localization)
            + "```\n"
        )
    prompt_text += prompts.get_prompt("解决Issue")

    # Pre-bind the result so that three failed attempts yield None instead of
    # an UnboundLocalError at the return statement.
    res = None
    for retry in Retry(3):
        with retry:
            response = respond(prompt_text)
            res = extract_diff_code(response)
    return res


def add_json(json_data):
    """Wrap ``json_data`` in a ```json fence.

    The data is dumped with 4-space indentation and unescaped non-ASCII
    characters; no newline is inserted between the fence markers and the JSON.
    """
    serialized = json.dumps(json_data, ensure_ascii=False, indent=4)
    return f"```json{serialized}```"


def convert_sets_to_lists(data):
    """Return ``data`` with every ``set`` replaced by a ``list``.

    Dict values and list elements are processed recursively; any other value
    (including tuples) is returned unchanged.
    """
    if isinstance(data, dict):
        return {name: convert_sets_to_lists(item) for name, item in data.items()}
    if isinstance(data, list):
        return [convert_sets_to_lists(item) for item in data]
    if isinstance(data, set):
        return list(data)
    return data


def _generate_issue_dict(buggy, prompts, message=None):
    """Generate an issue and return it as a dict.

    Up to three attempts are made to get an answer that parses as JSON. If no
    attempt succeeds, or the parsed value is not a dict, one more answer is
    generated and stored unparsed under the "title and description" key.
    """
    for retry in Retry(3):
        with retry:
            issue = json.loads(generate_issue(buggy, prompts, message))
    if retry.allright and isinstance(issue, dict):
        return issue
    return {"title and description": generate_issue(buggy, prompts, message)}


def run_test(data, output_dir, model="gpt4", idd=None):
    """Run the issue -> explanation -> localization -> patch pipeline on one sample.

    Nothing is done when the output file already exists. Otherwise the model
    is queried for issues, explanations, code locations and patches under
    several input combinations. The collected results are stored under
    ``data["Results"]``, and ``data`` is written as JSON to
    ``<output_dir>/<idd>_test_<model>.json``.

    Side effects on ``data``: the ``content`` of every entry in
    ``data["FilteredCode"]`` gets line-number prefixes, and ``data["Issue"]``
    gains an ``"explanation"`` key holding ``data["Explain"]``. Both changes
    are part of the saved file.

    Args:
        data: Sample dict. The keys read here are ``ErrorMessage``,
            ``FilteredCode``, ``Issue``, ``Explain``, ``BuggyCodeLocation``,
            ``Difficulty``, ``Patch`` and ``CommitSHA``.
        output_dir: Directory the result file is written to.
        model: Model name, used in the file name and stored in the results.
        idd: Sample identifier used as the file-name prefix. Required.

    Raises:
        TypeError: If ``idd`` is None.
    """
    if idd is None:
        raise TypeError("run_test() requires `idd` to build the output file name")
    output_path = os.path.join(output_dir, f"{idd}_test_{model}.json")
    if os.path.exists(output_path):
        return

    prompts = PromptManager()
    message = data["ErrorMessage"]
    buggy = data["FilteredCode"]
    for file_info in buggy:
        # Numbers the lines in place; the prompts refer to these numbers.
        add_line_numbers(file_info)

    # Issues generated by the model, without and with the error message.
    issue_origin = _generate_issue_dict(buggy, prompts)
    issue_message = _generate_issue_dict(buggy, prompts, message)

    # Explanations for the generated issues and for the reference issue.
    solution_origin = generate_issue_explanation(issue_origin, buggy, prompts)
    solution_message = generate_issue_explanation(
        issue_message, buggy, prompts, message
    )
    solution_ground = generate_issue_explanation(
        data["Issue"], buggy, prompts, message
    )
    issue_origin["explanation"] = solution_origin
    issue_message["explanation"] = solution_message

    # Reference issue + generated explanation (copied first, so the reference
    # explanation added below does not leak into it).
    issue_ground = deepcopy(data["Issue"])
    issue_ground["explanation"] = solution_ground
    # Reference issue + reference explanation. This aliases data["Issue"], so
    # the explanation also appears under "Issue" in the saved file.
    issue_ground_truth = data["Issue"]
    issue_ground_truth["explanation"] = data["Explain"]

    location_ground_truth = data["BuggyCodeLocation"]
    location_origin = generate_code_location(issue_origin, buggy, prompts)
    location_message = generate_code_location(
        issue_message, buggy, prompts, message
    )
    location_ground = generate_code_location(issue_ground, buggy, prompts)
    location_ground_exp = generate_code_location(issue_ground_truth, buggy, prompts)

    # Patches for each issue variant, without and with a localization hint.
    patch_i = generate_code_patch(issue_origin, buggy, prompts)
    patch_im = generate_code_patch(issue_message, buggy, prompts)
    patch_il = generate_code_patch(issue_origin, buggy, prompts, location_origin)
    patch_iml = generate_code_patch(
        issue_message, buggy, prompts, location_message
    )
    patch_ground = generate_code_patch(issue_ground, buggy, prompts)
    patch_ground_location = generate_code_patch(
        issue_ground, buggy, prompts, location_ground
    )
    patch_ground_exp = generate_code_patch(
        issue_ground_truth, buggy, prompts, location_ground_exp
    )
    patch_ground_all = generate_code_patch(
        issue_ground_truth, buggy, prompts, location_ground_truth
    )

    results = {
        "model": model,
        "Difficulty": data["Difficulty"],
        "issue_origin": issue_origin,
        "issue_message": issue_message,
        "issue_ground": issue_ground,
        "issue_ground_truth": issue_ground_truth,
        "location_origin": location_origin,
        "location_message": location_message,
        "location_ground": location_ground,
        "location_ground_exp": location_ground_exp,
        "location_ground_truth": location_ground_truth,
        "patch_i": patch_i,
        "patch_im": patch_im,
        "patch_il": patch_il,
        "patch_iml": patch_iml,
        "patch_ground": patch_ground,
        "patch_ground_location": patch_ground_location,
        "patch_ground_exp": patch_ground_exp,
        "patch_ground_all": patch_ground_all,
        "patch_ground_truth": data["Patch"],
        "message": message,
        "CodeBase": buggy,
        "CommitSHA": data["CommitSHA"],
    }
    data['Results'] = results
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    print("Save to file:", output_path)






def respond(prompt) -> str:
    """Generate one completion for ``prompt`` with the module-level vLLM model.

    A plain string is wrapped into a two-message chat (a fixed system message
    plus the string as the user message); any other value is passed to the
    tokenizer's chat template unchanged. Reads the module globals
    ``vllm_tokenizer``, ``vllm_model`` and ``sampling_params``.

    Returns:
        The text of the first output of the first (and only) request.
    """
    messages = prompt
    if isinstance(messages, str):
        system_turn = {"role": "system", "content": "You are a helpful assistant."}
        user_turn = {"role": "user", "content": str(messages)}
        messages = [system_turn, user_turn]
    chat_text = vllm_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    generations = vllm_model.generate(
        prompts=[chat_text], sampling_params=sampling_params, use_tqdm=False
    )
    first_request = generations[0]
    return first_request.outputs[0].text


def determine_max_tokens(max_length_k):
    """Map a context-length bucket to a generation token budget.

    2 -> 500, 4 -> 1000, 8 -> 1500, and anything >= 16 -> 2000.

    Raises:
        ValueError: For any other value.
    """
    small_buckets = {2: 500, 4: 1000, 8: 1500}
    budget = small_buckets.get(max_length_k)
    if budget is not None:
        return budget
    if max_length_k >= 16:
        return 2000
    raise ValueError("Invalid max_length_k. Supported values are 2, 4, 8, 16, 32, 64, 128.")


if __name__ == "__main__":
    # Command-line entry point: load the model with vLLM, then run the
    # pipeline on every JSON sample found in the data directory.
    parser = argparse.ArgumentParser(description="Run inference with vllm")
    parser.add_argument('--model_path', type=str, required=True, help='Path to the model')
    parser.add_argument('--tp', type=int, default=1, help='Tensor parallelism size')
    parser.add_argument('--model', type=str, required=True, help='Name to the model')
    parser.add_argument('--max_length', type=int, choices=[2, 4, 8, 16, 32, 64, 128], required=True, help='Supported max lengths: 2, 4, 8, 16, 32, 64, 128K')
    # The input/output directories used to be hard-coded as empty strings,
    # which made os.makedirs('') fail after the model had already been loaded.
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the input JSON samples')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory where result files are written')
    args = parser.parse_args()

    model = args.model
    max_tokens = determine_max_tokens(args.max_length)

    # Module-level names read by respond().
    vllm_model = LLM(model=args.model_path, tensor_parallel_size=args.tp, trust_remote_code=True, enforce_eager=True, gpu_memory_utilization=0.98)
    vllm_tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
    sampling_params = SamplingParams(temperature=0.2, top_p=0.95, max_tokens=max_tokens)

    data_dir = args.data_dir
    # Ensure a trailing separator so the directory can be used as a path prefix.
    output_dir = os.path.join(args.output_dir, '')
    os.makedirs(output_dir, exist_ok=True)

    total_lengths = [1500, 3000, 6500, 13000, 30000, 61000, 124000]

    for total_length in total_lengths:
        if total_length > args.max_length * 1000:
            continue
        # NOTE(review): the directories do not depend on total_length, so every
        # bucket within range scans the same files again (run_test skips samples
        # whose output file already exists). Presumably the paths were meant to
        # be per-bucket -- confirm.
        for file in tqdm.tqdm(os.listdir(data_dir), desc=f"Run test for {model}"):
            if file == '.DS_Store':
                continue
            file_path = os.path.join(data_dir, file)
            print(file_path)
            # utf-8-sig tolerates a leading byte-order mark in the sample files.
            with open(file_path, "r", encoding="utf-8-sig") as f:
                json_data = json.load(f)
            run_test(json_data, output_dir, model=model, idd=file)
        print("Now: ", total_length)
    print("Over!", args.model)