import argparse
import os
import json
import concurrent.futures
from tqdm import tqdm
from unidiff import PatchSet

from localize.util.model import make_model

from localize.util.utils import load_existing_instance_ids, load_jsonl, setup_logger

from localize.ts_structure import CodeStructure
from localize.RepoCoderLocalize import DialogueManager

from localize.util.preprocess_data import line_wrap_content

from utils_edit import patch_to_search_replace_unidiff

from localize.ts_structure import (
    get_before_after_ts_structure_from_scratch,
    CodeStructure
)

from prompt import *

# Directory holding preprocessed per-instance CodeStructure JSON files
# ("<instance_id>-before.json" / "<instance_id>-after.json"). Taken from the
# PROJECT_FILE_LOC environment variable; None when that variable is unset.
PROJECT_FILE_LOC = os.environ.get("PROJECT_FILE_LOC")

def construct_topn_file_context(
    file_to_locs,
    structure: CodeStructure,
    context_window: int,
    loc_interval: bool = True,
    fine_grain_loc_only: bool = False,
    add_space: bool = False,
    sticky_scroll: bool = False,
    no_line_number: bool = True,
):
    """Render the located regions of every predicted file into one context string.

    Args:
        file_to_locs: mapping of file path -> location strings,
            e.g. ``{"file_name_1": ["loc_str_1"], ...}``. The whole mapping is
            handed to ``structure.transfer_location_strings_into_intervals``
            together with the current file path.
        structure: project structure used to fetch each file's text lines and
            to translate location strings into line intervals.
        context_window, loc_interval, fine_grain_loc_only: forwarded to
            ``structure.transfer_location_strings_into_intervals``.
        add_space, sticky_scroll, no_line_number: forwarded to
            ``line_wrap_content``.

    Returns:
        ``(context_text, intervals_by_file)``. Each rendered file contributes a
        ``"### <path>"`` header followed by its wrapped snippet. Files for which
        no line location is resolved appear in neither element.
    """
    intervals_by_file = dict()
    rendered_sections = []

    for file_path in file_to_locs.keys():
        file_text = "\n".join(structure.get_file_node(file_path)["text_lines"])
        resolved_lines, intervals = structure.transfer_location_strings_into_intervals(
            file_to_locs,
            file_path,
            context_window,
            loc_interval,
            fine_grain_loc_only,
        )

        # A file with no resolved location is excluded from the context.
        if len(resolved_lines) == 0:
            continue

        snippet = line_wrap_content(
            file_text,
            intervals,
            add_space=add_space,
            no_line_number=no_line_number,
            sticky_scroll=sticky_scroll,
        )
        rendered_sections.append(f"### {file_path}\n{snippet}\n\n\n")
        intervals_by_file[file_path] = intervals

    return "".join(rendered_sections), intervals_by_file

def construct_code_edit_diaglogue_with_try(
    args,
    bug,
    prediction,
    existing_instance_ids
):
    """Run ``construct_code_edit_diaglogue`` for one instance without propagating errors.

    Any ``Exception`` raised while processing the instance is caught so that one
    bad instance does not abort the batch (including the process-pool path,
    where ``future.result()`` would otherwise re-raise it). The failure is
    reported together with the instance id and the full traceback, so it can
    be diagnosed after a long run.

    Args:
        args: parsed command-line namespace, passed through unchanged.
        bug: task record (a dict with at least ``instance_id``).
        prediction: localization prediction for this instance.
        existing_instance_ids: ids already present in the output file.

    Returns:
        None.
    """
    try:
        construct_code_edit_diaglogue(
            args,
            bug,
            prediction,
            existing_instance_ids
        )
    except Exception as e:
        import traceback

        instance_id = bug.get("instance_id") if isinstance(bug, dict) else None
        # Previously only str(e) was printed, which lost both the instance id
        # and the stack trace needed to debug the failure.
        print(f"Failed to construct code-edit dialogue for {instance_id}: {e!r}")
        traceback.print_exc()

def _load_or_build_before_structure(bug, logger):
    """Return the pre-patch ``CodeStructure`` for ``bug``.

    If ``PROJECT_FILE_LOC`` is set and ``<instance_id>-before.json`` exists
    there, it is loaded with ``CodeStructure.load``. Otherwise, or if loading
    raises, the before/after structures are rebuilt with
    ``get_before_after_ts_structure_from_scratch``. When ``PROJECT_FILE_LOC``
    is set, both rebuilt structures are saved there for later runs.

    Args:
        bug: task record; reads ``instance_id``, ``repo``, ``base_commit``,
            ``patch`` and ``language``.
        logger: logger used to report a failed cache load.
    """
    instance_id = bug["instance_id"]
    before_project_file = None
    after_project_file = None

    # PROJECT_FILE_LOC is optional (None when the env var is unset), and
    # os.path.join(None, ...) raises TypeError. Only touch the cache when a
    # directory is configured.
    if PROJECT_FILE_LOC is not None:
        before_project_file = os.path.join(PROJECT_FILE_LOC, instance_id + "-before.json")
        after_project_file = os.path.join(PROJECT_FILE_LOC, instance_id + "-after.json")
        if os.path.exists(before_project_file):
            try:
                return CodeStructure.load(before_project_file)
            except Exception as e:
                logger.error(f"Error loading before structure for {instance_id}: {e}")

    # No usable cached structure: build it directly from the repository.
    structure, after_structure = get_before_after_ts_structure_from_scratch(
        bug["repo"], bug["base_commit"], bug["patch"], instance_id, "playground", bug["language"]
    )
    if before_project_file is not None:
        os.makedirs(PROJECT_FILE_LOC, exist_ok=True)
        structure.save(before_project_file)
        after_structure.save(after_project_file)
    return structure


def _format_file_change_list(gt_localization):
    """List removed (``-``), modified (``*``) and added (``+``) files, one per line."""
    lines = [f"- {file}" for file in gt_localization["removed_files"]]
    lines += [f"* {file}" for file in gt_localization["modified_files"]]
    lines += [f"+ {file}" for file in gt_localization["added_files"]]
    return "\n".join(lines)


def _format_masked_search_replaces(patch):
    """Render each search/replace derived from ``patch`` with a ``[MASK]`` reason slot."""
    return "".join(
        single_search_replace_prompt.format(reason='[MASK]', search_replace=sr) + '\n\n'
        for sr in patch_to_search_replace_unidiff(patch)
    )


def construct_code_edit_diaglogue(
    args,
    bug,
    prediction,
    existing_instance_ids
):
    """Build one code-edit dialogue for ``bug`` and append it to ``args.output_file``.

    Steps:
      1. Skip the instance if ``args.skip_existing`` is set and its id is in
         ``existing_instance_ids``.
      2. Load or build the pre-patch ``CodeStructure``.
      3. Build the user query from the problem statement and the predicted
         locations (``prediction["found_related_locs"]``).
      4. Build an assistant answer from the ground-truth file changes and the
         patch's search/replace blocks, with reasoning replaced by ``[MASK]``.
      5. Ask the model to fill the ``[MASK]`` slots, then write a JSON line
         containing the resulting user/assistant dialogue.

    Args:
        args: parsed command-line namespace. Reads ``output_folder``,
            ``output_file``, ``skip_existing``, ``model``, ``backend``,
            ``temperature`` and ``context_window`` (default 10 if absent).
        bug: ground-truth task record (``instance_id``, ``repo``, ``language``,
            ``base_commit``, ``problem_statement``, ``patch``,
            ``gt_localization``).
        prediction: localization record with ``found_related_locs``.
        existing_instance_ids: ids already present in the output file.
    """
    instance_id = bug["instance_id"]
    log_dir = os.path.join(args.output_folder, "task_construct_logs")
    log_file = os.path.join(log_dir, f"{instance_id}.log")
    os.makedirs(log_dir, exist_ok=True)

    logger = setup_logger(log_file)

    if instance_id in existing_instance_ids and args.skip_existing:
        logger.info(f"Skipping existing instance_id: {instance_id}")
        print(f"Skipping existing instance_id: {instance_id}")
        return

    structure = _load_or_build_before_structure(bug, logger)

    related_locs_pred = prediction["found_related_locs"]

    # --context_window was previously parsed but ignored (hard-coded 10). Its
    # default is 10, so default behavior is unchanged.
    file_context, _ = construct_topn_file_context(
        related_locs_pred,
        structure,
        context_window=getattr(args, "context_window", 10),
    )

    prompt_query = code_edit_prompt.format(
        task_description=bug['problem_statement'],
        top_n_file_content=file_context,
    )

    prompt_response = code_edit_response_prompt.format(
        plan='[MASK]',
        files=_format_file_change_list(bug["gt_localization"]),
        search_replaces=_format_masked_search_replaces(bug['patch'])
    )

    logger.info(">>>>>Query\n" + prompt_query)
    logger.info(">>>>>Response\n" + prompt_response)

    model = make_model(
        model=args.model,
        backend=args.backend,
        logger=logger,
        max_tokens=4096,
        temperature=args.temperature,
        batch_size=1,
    )

    message = [
        {"role": "user", "content": prompt_query},
        {"role": "assistant", "content": prompt_response},
        {"role": "user", "content": "Please fill the [MASK] position with correspond reasoning process, and give me the content of filled assistant answer directly."}
    ]
    res = model.codegen(message, num_samples=1)[0]
    logger.info(">>>>>Reasoned Response\n" + res["response"])

    constructed_dialogue = DialogueManager(logger)
    constructed_dialogue.add_user_message(prompt_query)
    constructed_dialogue.add_assistant_message(res['response'])

    constructed_dialogue.present_dialogue()

    record = {
        "instance_id": instance_id,
        "repo": bug["repo"],
        "language": bug["language"],
        "base_commit": bug["base_commit"],
        "problem_statement": bug["problem_statement"],
        "patch": bug["patch"],
        "dialogue": constructed_dialogue.trajectory,
        "found_files": list(related_locs_pred.keys()),
        "found_related_locs": related_locs_pred,
    }
    with open(args.output_file, "a") as f:
        f.write(json.dumps(record) + "\n")

def _read_jsonl_records(path):
    """Parse a JSON-Lines file into a list of objects, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def construct_code_edit_data(args):
    """Construct code-edit dialogues for every instance that has a prediction.

    Reads ground-truth instances from ``args.gt_file`` and predictions from
    ``args.pred_file`` (both JSON Lines), keyed by ``instance_id``. Instances
    are skipped when they:
      - are already in ``args.output_file`` (only with ``args.skip_existing``),
      - have no prediction, or
      - differ from ``args.target_id`` (only when that option is set).
    Each remaining instance is processed serially (``args.num_threads == 1``)
    or in a ``ProcessPoolExecutor`` with ``args.num_threads`` workers.
    """
    # Use context managers so the input files are closed; the previous
    # `open(...)` inside a comprehension leaked the handles.
    instances = _read_jsonl_records(args.gt_file)
    prediction_map = {item["instance_id"]: item for item in _read_jsonl_records(args.pred_file)}

    existing_instance_ids = (
        load_existing_instance_ids(args.output_file) if args.skip_existing else set()
    )

    # --target_id was parsed but previously ignored; None keeps all instances.
    target_id = getattr(args, "target_id", None)
    instances = [
        x for x in instances
        if x["instance_id"] not in existing_instance_ids
        and x["instance_id"] in prediction_map
        and (target_id is None or x["instance_id"] == target_id)
    ]

    if args.num_threads == 1:
        for instance in tqdm(instances, total=len(instances), colour="MAGENTA"):
            prediction = prediction_map[instance["instance_id"]]
            construct_code_edit_diaglogue_with_try(args, instance, prediction, existing_instance_ids)
    else:
        with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.num_threads
        ) as executor:
            # `instances` is already restricted to ids present in prediction_map.
            futures = [
                executor.submit(
                    construct_code_edit_diaglogue_with_try,
                    args,
                    instance,
                    prediction_map[instance["instance_id"]],
                    existing_instance_ids
                )
                for instance in instances
            ]
            for future in tqdm(
                    concurrent.futures.as_completed(futures),
                    total=len(futures),
                    colour="MAGENTA",
            ):
                future.result()

if __name__ == "__main__":
    # Command-line entry point: parse options, resolve the output path inside
    # --output_folder, and construct dialogues for all selected instances.
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--output_file", type=str, default="task_outputs.jsonl")
    parser.add_argument("--top_n", type=int, default=10)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--merge", action="store_true")
    parser.add_argument("--add_space", action="store_true")
    parser.add_argument("--no_line_number", action="store_true")
    parser.add_argument("--sticky_scroll", action="store_true")
    parser.add_argument("--context_window", type=int, default=10)
    parser.add_argument("--keep_old_order", action="store_true")
    parser.add_argument(
        "--num_threads",
        type=int,
        default=1,
        help="Number of worker processes used to construct dialogues (1 runs serially)",
    )
    parser.add_argument("--target_id", type=str)
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip localization of instance id's which already contain a localization in the output file.",
    )
    parser.add_argument(
        "--mock", action="store_true", help="Mock run to compute prompt tokens."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4o-2024-05-13",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="openai",
        choices=["openai", "deepseek", "anthropic", "claude"],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="princeton-nlp/SWE-bench_Lite",
        help="Current supported dataset for evaluation",
    )
    parser.add_argument(
        "--gt_file",
        type=str,
        default=None,
        help="Ground truth localization file",
    )
    parser.add_argument(
        "--pred_file",
        type=str,
        default=None,
        help="Repo-level func localization result file",
    )
    args = parser.parse_args()

    # Both inputs are read unconditionally downstream; without them the run
    # would fail later with an opaque open(None) TypeError.
    if args.gt_file is None or args.pred_file is None:
        parser.error("--gt_file and --pred_file are required")
    # ProcessPoolExecutor rejects max_workers <= 0; fail early with a clear message.
    if args.num_threads < 1:
        parser.error("--num_threads must be >= 1")

    args.output_file = os.path.join(args.output_folder, args.output_file)

    os.makedirs(args.output_folder, exist_ok=True)

    construct_code_edit_data(args)

