import argparse
import os
import json
import concurrent.futures
from tqdm import tqdm
from unidiff import PatchSet
import random

from localize.RepoCoderLocalize import AFL as RCL
from localize.util.model import make_model

from localize.util.utils import load_existing_instance_ids, load_jsonl, setup_logger
from localize.location import CodeLocationGroup
from localize.ts_structure import (
    get_completion_ts_structure_from_scratch,
    CodeStructure
)

from localize.util.preprocess_data import line_wrap_content

from utils_edit import patch_to_search_replace_unidiff
from utils_parsing import CompletionTaskConstructor
from prompt import *

# Optional output directory taken from the PROJECT_FILE_LOC environment variable.
# The parsed before/after repository structures of each chosen completion task are
# saved there as "<instance_id>-before.json" / "<instance_id>-after.json".
# None when the variable is not set.
PROJECT_FILE_LOC = os.getenv("PROJECT_FILE_LOC")

def construct_topn_file_context_for_completion(
    gold_file,
    file_to_locs,
    structure: "CodeStructure",
    context_window: int,
    loc_interval: bool = True,
    fine_grain_loc_only: bool = False,
    add_space: bool = False,
    sticky_scroll: bool = False,
    no_line_number: bool = True,
):
    """Concatenate the gold file and the predicted locations into one prompt context.

    The gold file is always emitted first, in full. Every other file listed in
    ``file_to_locs`` contributes only the line windows derived from its predicted
    locations (rendered by ``line_wrap_content``). Files whose locations resolve to
    no lines are omitted.

    Args:
        gold_file: Path of the file being completed; emitted first, in full.
        file_to_locs: Mapping ``{"file_name_1": ["loc_str_1", ...], ...}`` of
            predicted location strings. An entry for ``gold_file`` is ignored.
        structure: Parsed repository structure that provides each file's
            ``text_lines`` and the location-to-interval conversion.
        context_window, loc_interval, fine_grain_loc_only: Forwarded to
            ``structure.transfer_location_strings_into_intervals``.
        add_space, sticky_scroll, no_line_number: Forwarded to
            ``line_wrap_content``.

    Returns:
        ``(topn_content, file_loc_intervals)``. ``topn_content`` is the concatenation
        of ``### <file>`` sections. ``file_loc_intervals`` maps every included
        non-gold file to the context intervals used for it.
    """
    gold_content = "\n".join(structure.get_file_node(gold_file)["text_lines"])
    # Collect sections and join once instead of repeated string concatenation.
    sections = [f"### {gold_file}\n{gold_content}\n\n\n"]
    file_loc_intervals = {}

    for pred_file in file_to_locs:
        if pred_file == gold_file:
            # The gold file has already been included in full above.
            continue

        line_locs, context_intervals = structure.transfer_location_strings_into_intervals(
            file_to_locs,
            pred_file,
            context_window,
            loc_interval,
            fine_grain_loc_only,
        )
        if len(line_locs) == 0:
            # No predicted location resolved to lines in this file: exclude it,
            # and don't bother fetching its text.
            continue

        content = "\n".join(structure.get_file_node(pred_file)["text_lines"])
        file_loc_content = line_wrap_content(
            content,
            context_intervals,
            add_space=add_space,
            no_line_number=no_line_number,
            sticky_scroll=sticky_scroll,
        )
        sections.append(f"### {pred_file}\n{file_loc_content}\n\n\n")
        file_loc_intervals[pred_file] = context_intervals

    return "".join(sections), file_loc_intervals

def _append_jsonl(path, record):
    """Append ``record`` to the JSON Lines file at ``path`` as a single line."""
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")


def _extract_completion_tasks(bug, logger):
    """Return the completion tasks derived from ``bug['patch']``, or None.

    Function-level completions are preferred; normal completions are the fallback.
    Returns None (after logging the reason) when neither produces a task or when
    extraction raises.
    """
    completion_task_constructor = CompletionTaskConstructor(
        patch=bug['patch'],
        language=bug['language']
    )
    try:
        completion_tasks = completion_task_constructor.create_function_level_completion()
        if not completion_tasks:
            completion_tasks = completion_task_constructor.create_normal_completion()
    except Exception as e:
        # Best-effort: skip this instance, but keep the cause in the log.
        logger.info(f"Completion task can not be extracted on {bug['instance_id']} (Error: {e!r}).")
        return None
    if not completion_tasks:
        logger.info(f"No valid completion task can extracted on {bug['instance_id']} correctly.")
        return None
    return completion_tasks


def _select_completion_task(bug, instance_id, completion_tasks, logger):
    """Pick the first completion task with a unique, non-anonymous ground-truth location.

    For each candidate, the before/after repository structures are parsed. A
    candidate is accepted when every line added by its ``completion_patch`` maps
    to exactly one location, and that location is not "anonymous".

    On success, sets ``bug['gt_localization']`` and ``bug['problem_statement']``,
    saves both structures under PROJECT_FILE_LOC when that directory is
    configured, and returns ``(completion_task, gold_file, structure)``.
    Returns None when no candidate qualifies.
    """
    for completion_task in completion_tasks:
        gold_file = completion_task['file_name']

        structure, after_structure = get_completion_ts_structure_from_scratch(
            instance_id, bug["repo"], bug["base_commit"], bug['patch'] + bug['test_patch'],
            completion_task, "playground", bug["language"]
        )
        if structure is None or after_structure is None:
            logger.info("Structure of original code repo or modified code repo can not be parsed correctly.")
            continue

        # Ground-truth localization: the locations of the lines added by the completion.
        added_locations = set()
        for patched_file in PatchSet(completion_task['completion_patch']):
            for hunk in patched_file:
                for line in hunk:
                    if line.line_type == "+":
                        added_locations.add(after_structure.get_location(gold_file, line.target_line_no))
        localizations = CodeLocationGroup(list(added_locations))

        if len(localizations) != 1:
            logger.info(f"Ground truth localization of completion task {gold_file} is not unique ({len(localizations)}), skip this completion task.")
            continue

        gt_localization = {
            "valid": True,
            "instance_id": instance_id,
            "added_files": [],
            "modified_files": [gold_file],
            "removed_files": [],
            "edit_namespaces": [str(item) for item in localizations.locations],
            "edit_namespaces_for_new_components": [],
            "new_namespaces": [],
            "related_locs": localizations.to_file_location_string_map()
        }
        bug["gt_localization"] = gt_localization

        if "anonymous" in "".join(gt_localization["related_locs"][gold_file]):
            logger.info(f"Completion task {gold_file} contains anonymous locations, skip this completion task.")
            print(f"Completion task {gold_file} contains anonymous locations, skip this completion task.")
            continue

        bug['problem_statement'] = (
            f"Please help me to complete the code of the function at the position of '[TODO]', in the {gold_file}\n"
            f"{completion_task['signature']}\n"
            f"{completion_task['docstring']}\n"
        ).strip()

        # PROJECT_FILE_LOC is optional (None when the env var is unset); only cache
        # the parsed structures when a directory was configured.
        if PROJECT_FILE_LOC:
            structure.save(os.path.join(PROJECT_FILE_LOC, instance_id + "-before.json"))
            after_structure.save(os.path.join(PROJECT_FILE_LOC, instance_id + "-after.json"))

        return completion_task, gold_file, structure  # first valid task wins
    return None


def _predict_related_locs(args, bug, structure, gold_file, logger):
    """Optionally run fault localization; return ``(found_files, related_locs_pred)``.

    Localization runs with probability ``args.augment_localization`` (i.e. when
    ``random.random() < args.augment_localization``). When it is skipped, or when
    any predicted location is anonymous, ``([], {})`` is returned so that the
    prompt context contains only the gold file.
    """
    if not random.random() < args.augment_localization:
        logger.info(f"Skipping localization for {bug['instance_id']} due to augment_localization probability.")
        return [], {}

    logger.info(f"Performing localization for {bug['instance_id']}.")
    fl = RCL(
        bug["instance_id"],
        structure,
        bug['problem_statement'],
        args.model,
        args.max_length,
        args.max_tokens,
        args.backend,
        logger
    )
    found_files, _, _ = fl.localize_file_with_gt(
        bug["gt_localization"],
        mock=args.mock
    )
    related_locs_pred, _, _ = fl.localize_func_with_gt(
        gt_localization=bug["gt_localization"], files=found_files, max_retry=args.max_retry
    )

    if any("anonymous" in "".join(locs) for locs in related_locs_pred.values()):
        logger.info(f"Completion task {gold_file} contains anonymous locations, convert to only contain gold file.")
        return [], {}
    return found_files, related_locs_pred


def _build_masked_response(completion_patch, gold_file):
    """Render the assistant turn: the gold edits as search/replace blocks, with plan and reasons masked."""
    search_replaces_content = "".join(
        single_search_replace_prompt.format(reason='[MASK]', search_replace=sr) + '\n\n'
        for sr in patch_to_search_replace_unidiff(completion_patch)
    )
    return code_edit_response_prompt.format(
        plan='[MASK]',
        files=f"* {gold_file}",
        search_replaces=search_replaces_content
    )


def construct_code_completion_diaglogue(
    args,
    bug
):
    """Build one masked completion dialogue for ``bug`` and append it to ``args.output_file``.

    Side effects:
      * ``bug["instance_id"]`` gets a "-completion" suffix, and
        ``bug["gt_localization"]`` / ``bug["problem_statement"]`` are set.
      * A per-instance log is written under ``<output_folder>/completion_construct_logs``.
      * Parsed structures are saved under PROJECT_FILE_LOC when it is set.
      * The configured model is queried once to fill the ``[MASK]`` placeholders.

    Output records:
      * When tasks were extracted but none qualifies, ``{"instance_id": ..., "valid": False}``
        is appended.
      * When no task could be extracted at all, nothing is written.
    """
    bug["instance_id"] = bug["instance_id"] + "-completion"
    instance_id = bug["instance_id"]

    log_dir = os.path.join(args.output_folder, "completion_construct_logs")
    os.makedirs(log_dir, exist_ok=True)
    logger = setup_logger(os.path.join(log_dir, f"{instance_id}.log"))

    completion_tasks = _extract_completion_tasks(bug, logger)
    if completion_tasks is None:
        return

    for completion_task in completion_tasks:
        logger.info(f"Processing completion task for {instance_id}")
        logger.info(f"PR info: {bug['repo']}, {bug['base_commit']}, {bug['instance_id']}; Patch:\n{bug['patch']}[PATCH END]\n")
        logger.info(f"Task patch: \n{completion_task['task_patch']}[PATCH END]\nCompletion patch\n{completion_task['completion_patch']}[PATCH END]\n")
        logger.info(f"File name: {completion_task['file_name']}, {completion_task['signature']}, {completion_task['docstring']}")

    selection = _select_completion_task(bug, instance_id, completion_tasks, logger)
    if selection is None:
        logger.info(f"No valid completion task can be extracted on {bug['instance_id']} correctly.")
        print(f"No valid completion task can be extracted on {bug['instance_id']} correctly.")
        _append_jsonl(args.output_file, {"instance_id": bug["instance_id"], "valid": False})
        return
    chosen_completion_task, gold_file, structure = selection

    found_files, related_locs_pred = _predict_related_locs(args, bug, structure, gold_file, logger)

    # Honour the CLI context flags; their defaults equal the previous hard-coded
    # values. --no_line_number is intentionally not forwarded: its CLI default
    # (False) would flip this call's default of omitting line numbers.
    file_context, _ = construct_topn_file_context_for_completion(
        gold_file,
        related_locs_pred,
        structure,
        context_window=getattr(args, "context_window", 10),
        add_space=getattr(args, "add_space", False),
        sticky_scroll=getattr(args, "sticky_scroll", False),
    )

    prompt_query = code_edit_prompt.format(
        task_description=bug['problem_statement'],
        top_n_file_content=file_context,
    )
    prompt_response = _build_masked_response(chosen_completion_task["completion_patch"], gold_file)

    logger.info(">>>>>Query\n" + prompt_query)
    logger.info(">>>>>Response\n" + prompt_response)

    model = make_model(
        model=args.model,
        backend=args.backend,
        logger=logger,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        batch_size=1,
    )

    message = [
        {"role": "user", "content": prompt_query},
        {"role": "assistant", "content": prompt_response},
        {"role": "user", "content": "Please fill the [MASK] position with correspond reasoning process, and give me the content of filled assistant answer directly."}
    ]
    res = model.codegen(message, num_samples=1)[0]
    logger.info(">>>>>Reasoned Response\n" + res["response"])

    print(f"Construct completion candidate successfully for {bug['instance_id']}")
    _append_jsonl(
        args.output_file,
        {
            "instance_id": bug["instance_id"],
            "valid": True,
            "repo": bug["repo"],
            "language": bug["language"],
            "base_commit": bug["base_commit"],
            "problem_statement": bug["problem_statement"],
            "found_files": found_files,
            "found_related_locs": related_locs_pred,
            "gt_localization": bug["gt_localization"],
            "dialogue": [
                {"role": "user", "content": prompt_query},
                {"role": "assistant", "content": res["response"]},
            ],
            "patch": bug['patch'],
            "test_patch": bug['test_patch'],
            "completion_task": chosen_completion_task,
        },
    )

def construct_code_completion_data(args):
    """Build completion dialogues for every instance in ``args.data_file``.

    Filtering:
      * Blank lines in the JSONL data file are ignored.
      * When ``args.target_id`` is set, only the instance whose raw or
        "-completion" id equals it is processed.
      * With ``args.skip_existing``, instances whose "<instance_id>-completion" id
        already appears in ``args.output_file`` are skipped.

    Instances are processed sequentially when ``args.num_threads == 1``, otherwise
    in a process pool with ``args.num_threads`` workers.
    """
    # Close the data file deterministically instead of leaking the handle.
    with open(args.data_file) as f:
        instances = [json.loads(line) for line in f if line.strip()]

    target_id = getattr(args, "target_id", None)
    if target_id:
        instances = [
            x for x in instances
            if target_id in (x["instance_id"], x["instance_id"] + "-completion")
        ]

    if args.skip_existing:
        existing_instance_ids = load_existing_instance_ids(args.output_file)
        instances = [
            x for x in instances if x["instance_id"] + "-completion" not in existing_instance_ids
        ]

    if args.num_threads == 1:
        for instance in instances:
            construct_code_completion_diaglogue(args, instance)
        return

    with concurrent.futures.ProcessPoolExecutor(max_workers=args.num_threads) as executor:
        futures = [
            executor.submit(construct_code_completion_diaglogue, args, instance)
            for instance in instances
        ]
        # No timeout here: as_completed's timeout bounds the *entire* iteration,
        # not each task. Since leaving the `with` block waits for every future
        # anyway, a timeout would only abort long runs, never a stuck worker.
        for future in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                colour="MAGENTA",
        ):
            future.result()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--output_file", type=str, default="completion_outputs.jsonl")
    parser.add_argument("--top_n", type=int, default=10)
    parser.add_argument("--max_retry", type=int, default=10)
    parser.add_argument("--max_length", type=int, default=32768)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--merge", action="store_true")
    parser.add_argument("--add_space", action="store_true")
    parser.add_argument("--no_line_number", action="store_true")
    parser.add_argument("--sticky_scroll", action="store_true")
    parser.add_argument("--context_window", type=int, default=10)
    parser.add_argument("--keep_old_order", action="store_true")
    parser.add_argument(
        "--num_threads",
        type=int,
        default=1,
        help="Number of worker processes used to construct instances in parallel (1 = sequential).",
    )
    parser.add_argument("--target_id", type=str)
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip instances whose completion record (<instance_id>-completion) already exists in the output file.",
    )
    parser.add_argument(
        "--mock", action="store_true", help="Mock run to compute prompt tokens."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4o-2024-05-13",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="openai",
        choices=["openai", "deepseek", "anthropic", "claude"],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="princeton-nlp/SWE-bench_Lite",
        help="Current supported dataset for evaluation",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help="processed repo-level data file",
    )
    parser.add_argument(
        "--pred_file",
        type=str,
        default=None,
        help="Repo-level func localization result file",
    )
    parser.add_argument(
        "--augment_localization",
        type=float,
        default=0.1,
        help="Probability (0.0 to 1.0) of running fault localization to add predicted context "
             "files to the prompt; otherwise only the gold file context is used.",
    )
    args = parser.parse_args()

    # Fail early with a usage message instead of a TypeError inside open().
    if args.data_file is None:
        parser.error("--data_file is required")
    if not 0.0 <= args.augment_localization <= 1.0:
        parser.error("--augment_localization must be between 0.0 and 1.0")

    args.output_file = os.path.join(args.output_folder, args.output_file)

    os.makedirs(args.output_folder, exist_ok=True)

    if PROJECT_FILE_LOC:
        os.makedirs(PROJECT_FILE_LOC, exist_ok=True)

    construct_code_completion_data(args)

