import argparse
import concurrent.futures
import json
import os

from datasets import load_dataset, load_from_disk
from tqdm import tqdm
from localize.util.load_data_to_swe import load_crawled_data_to_swe

from localize.RepoCoderLocalize import AFL as RCL
from localize.util.utils import (
    load_existing_instance_ids,
    load_json,
    setup_logger,
)
from localize.ts_structure import (
    get_before_after_ts_structure_from_scratch,
    CodeStructure
)

# Directory holding preprocessed per-instance code-structure files, named
# "<instance_id>-before.json" and "<instance_id>-after.json".
# Taken from the PROJECT_FILE_LOC environment variable; None when it is unset.
PROJECT_FILE_LOC = os.environ.get("PROJECT_FILE_LOC")

def localize_instance_with_try(
    bug, args, swe_bench_data, existing_instance_ids, pred_files
):
    """Run ``localize_instance`` and report (rather than raise) any exception.

    Used as the process-pool task so that one failing instance does not abort
    the whole run. The failure report names the instance and includes the full
    traceback, so the error stays diagnosable from the worker output.

    Returns None in all cases.
    """
    import traceback

    try:
        localize_instance(bug, args, swe_bench_data, existing_instance_ids, pred_files)
    except Exception:
        # `bug` is normally a dict record; guard so that reporting itself cannot fail.
        instance_id = bug.get("instance_id") if isinstance(bug, dict) else None
        print(
            f"[{instance_id}] localization failed:\n{traceback.format_exc()}",
            flush=True,
        )

def _load_or_build_structure(bug, logger):
    """Return the pre-patch code structure for ``bug``, reusing the on-disk cache when possible.

    When PROJECT_FILE_LOC is set, ``<PROJECT_FILE_LOC>/<instance_id>-before.json`` is
    loaded if it exists. Otherwise (missing file, failed load, or PROJECT_FILE_LOC
    unset) the before/after structures are rebuilt with
    ``get_before_after_ts_structure_from_scratch``. The rebuilt pair is written back
    to the cache only when PROJECT_FILE_LOC is set.
    """
    instance_id = bug["instance_id"]
    cache_enabled = PROJECT_FILE_LOC is not None
    if cache_enabled:
        before_project_file = os.path.join(PROJECT_FILE_LOC, f"{instance_id}-before.json")
        after_project_file = os.path.join(PROJECT_FILE_LOC, f"{instance_id}-after.json")
        if os.path.exists(before_project_file):
            try:
                return CodeStructure.load(before_project_file)
            except Exception as e:
                # A bad cache entry is not fatal: log it and rebuild (overwriting it) below.
                logger.error(f"Error loading before structure for {instance_id}: {e}")
    else:
        logger.info("PROJECT_FILE_LOC is not set; building structure without caching")

    # No usable cache entry: build the project structure directly from the repository.
    structure, after_structure = get_before_after_ts_structure_from_scratch(
        bug["repo"], bug["base_commit"], bug["patch"], instance_id, "playground", bug["language"]
    )
    if cache_enabled:
        if PROJECT_FILE_LOC:
            os.makedirs(PROJECT_FILE_LOC, exist_ok=True)
        structure.save(before_project_file)
        after_structure.save(after_project_file)
    return structure


def localize_instance(
        bug, args, swe_bench_data, existing_instance_ids, pred_files
):
    """Run function-level localization for one benchmark instance and append the result.

    Args:
        bug: Benchmark record. Reads "instance_id", "repo", "base_commit", "patch",
            "language", "problem_statement" and "gt_localization".
        args: Parsed CLI namespace. Uses output_folder, output_file, target_id, model,
            max_length, max_tokens, backend, max_retry and top_n.
        swe_bench_data: Iterable of benchmark records. The first one whose
            "instance_id" matches ``bug`` supplies the problem statement given to the
            localizer; ``bug`` itself is used if none matches.
        existing_instance_ids: Instance ids to skip (already localized).
        pred_files: Files predicted by the earlier file-level localization step.

    Returns:
        None. One JSON line is appended to ``args.output_file`` per processed instance.
    """
    instance_id = bug["instance_id"]

    # Apply the --target_id filter before touching the filesystem.
    if args.target_id is not None and args.target_id != instance_id:
        return

    log_file = os.path.join(
        args.output_folder, "func_localization_construct_logs", f"{instance_id}.log"
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logger = setup_logger(log_file)
    logger.info(f"Processing bug {instance_id}")

    if instance_id in existing_instance_ids:
        logger.info(f"Skipping existing instance_id: {instance_id}")
        return

    structure = _load_or_build_structure(bug, logger)

    logger.info(f"================ localize {instance_id} ================")

    # Stop at the first match instead of materialising every match just to take [0].
    bench_data = next(
        (x for x in swe_bench_data if x["instance_id"] == instance_id), bug
    )
    problem_statement = bench_data["problem_statement"]

    fl = RCL(
        instance_id,
        structure,
        problem_statement,
        args.model,
        args.max_length,
        args.max_tokens,
        args.backend,
        logger
    )

    topn_func, _raw_output, func_traj = fl.localize_func_with_gt(
        gt_localization=bug["gt_localization"],
        files=pred_files,
        max_retry=args.max_retry,
        max_functions=args.top_n,
    )

    record = {
        "instance_id": instance_id,
        "repo": bug["repo"],
        "language": bug["language"],
        "base_commit": bug["base_commit"],
        "problem_statement": bug["problem_statement"],
        "patch": bug["patch"],
        "found_files": pred_files,
        "found_related_locs": topn_func,
        "dialogue": func_traj,
        "gt_localization": bug["gt_localization"],
    }
    # One write call per record keeps each JSONL line contiguous in append mode.
    with open(args.output_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")

def _read_jsonl(path):
    """Return the JSON objects stored one per non-blank line in the file at ``path``."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def localize(args):
    """Run function-level localization over every pending benchmark instance.

    Benchmark records come from the first available source, in this order:
    ``args.gt_file`` (JSONL); ``args.data_file`` (crawled repo data); a local
    ``./datasets/<dataset>`` copy when ``args.dataset`` contains "sampled"; the
    "test" split of ``args.dataset`` via ``load_dataset``.

    An instance is processed only if it has a file-level result in ``args.loc_file``.
    With ``--skip_existing``, instances already present in the output file are also
    dropped. With ``num_threads == 1`` instances run serially and exceptions
    propagate. Otherwise they run in a process pool, where each task reports its own
    failure.
    """
    # One precedence chain so a chosen source is never silently overwritten by a
    # later branch (previously --data_file was always replaced by load_dataset).
    if args.gt_file:
        swe_bench_data = _read_jsonl(args.gt_file)
    elif args.data_file:
        swe_bench_data = load_crawled_data_to_swe(args.data_file)
    elif "sampled" in args.dataset:
        swe_bench_data = load_from_disk(f"./datasets/{args.dataset}")
    else:
        swe_bench_data = load_dataset(args.dataset, split="test")

    existing_instance_ids = (
        load_existing_instance_ids(args.output_file) if args.skip_existing else set()
    )

    file_loc_results = {item["instance_id"]: item for item in _read_jsonl(args.loc_file)}

    swe_bench_data = [
        x for x in swe_bench_data
        if x["instance_id"] not in existing_instance_ids
        and x["instance_id"] in file_loc_results
    ]

    if args.num_threads == 1:
        for bug in swe_bench_data:
            localize_instance(
                bug, args, swe_bench_data, existing_instance_ids,
                file_loc_results[bug["instance_id"]]["found_files"],
            )
        return

    with concurrent.futures.ProcessPoolExecutor(max_workers=args.num_threads) as executor:
        futures = [
            executor.submit(
                localize_instance_with_try,
                bug,
                args,
                swe_bench_data,
                existing_instance_ids,
                file_loc_results[bug["instance_id"]]["found_files"],
            )
            for bug in swe_bench_data
        ]
        for future in tqdm(
                concurrent.futures.as_completed(futures),
                total=len(futures),
                colour="MAGENTA",
        ):
            future.result()

def main():
    """Parse CLI arguments, validate them, record them in args.json and run localization.

    Invalid combinations are reported through ``parser.error``, which prints usage and
    exits with status 2. An assert would be stripped under ``python -O``. The
    combinations rejected are: an existing output file without --skip_existing, and a
    DeepSeek model without --backend deepseek.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--output_file", type=str, default="loc_outputs_func_augment.jsonl")
    parser.add_argument("--loc_file", type=str, default="loc_outputs.jsonl")
    parser.add_argument("--max_retry", type=int, default=10)
    parser.add_argument("--max_length", type=int, default=32768)
    parser.add_argument("--max_tokens", type=int, default=4096)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_n", type=int, default=10)
    parser.add_argument("--add_space", action="store_true")
    parser.add_argument("--no_line_number", action="store_true")
    parser.add_argument("--sticky_scroll", action="store_true")
    parser.add_argument("--context_window", type=int, default=10)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument(
        "--dataset",
        type=str,
        default="princeton-nlp/SWE-bench_Lite",
        help="Current supported dataset for evaluation",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help="Repo-level data file",
    )
    parser.add_argument(
        "--gt_file",
        type=str,
        default=None,
        help="Ground truth localization file",
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=1,
        # The pool in localize() is a ProcessPoolExecutor, so these are processes.
        help="Number of worker processes to use for localization (1 = run serially)",
    )
    parser.add_argument("--target_id", type=str)
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip localization of instance id's which already contain a localization in the output file.",
    )
    parser.add_argument(
        "--mock", action="store_true", help="Mock run to compute prompt tokens."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4o-2024-08-06",
    )
    parser.add_argument(
        "--backend", type=str, default="openai", choices=["openai", "deepseek", "anthropic", "claude"]
    )

    args = parser.parse_args()

    # The output file is always placed inside the output folder.
    args.output_file = os.path.join(args.output_folder, args.output_file)

    if os.path.exists(args.output_file) and not args.skip_existing:
        parser.error("Output file already exists and not set to skip existing localizations")

    if "deepseek" in args.model and args.backend != "deepseek":
        parser.error("Must specify `--backend deepseek` if using a DeepSeek model")

    # makedirs creates every missing parent, so this also creates output_folder.
    os.makedirs(os.path.join(args.output_folder, "localization_logs"), exist_ok=True)

    # Record the effective arguments next to the outputs for reproducibility.
    with open(os.path.join(args.output_folder, "args.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)

    localize(args)

if __name__ == "__main__":
    main()
