import argparse
import os
import json
import concurrent.futures
import subprocess
import uuid
from tqdm import tqdm
from unidiff import PatchSet

from datasets import load_dataset, load_from_disk
from localize.RepoCoderLocalize import AFL as RCL
from localize.util.utils import load_existing_instance_ids, load_jsonl, setup_logger
from localize.util.preprocess_data import (
    check_contains_valid_loc,
)

from localize.ts_structure import (
    get_before_after_ts_structure_from_scratch,
    CodeStructure
)

# Directory in which per-instance project structure files
# ("<instance_id>-before.json" / "<instance_id>-after.json") are stored and
# looked up. Set the PROJECT_FILE_LOC environment variable if you want to
# use preprocessed files; the value is None when the variable is unset.
PROJECT_FILE_LOC = os.environ.get("PROJECT_FILE_LOC", None)

def _load_or_build_structure(task, instance_id):
    """Return the pre-patch code structure for ``task``.

    When ``PROJECT_FILE_LOC`` is configured, a previously saved
    ``<instance_id>-before.json`` is loaded from it if present; otherwise the
    structure is built from scratch and both the before and after structures
    are saved there. When ``PROJECT_FILE_LOC`` is unset (or empty) the
    structure is always built from scratch and nothing is saved.
    """
    before_project_file = None
    after_project_file = None
    if PROJECT_FILE_LOC:
        os.makedirs(PROJECT_FILE_LOC, exist_ok=True)
        before_project_file = os.path.join(
            PROJECT_FILE_LOC, instance_id + "-before.json"
        )
        after_project_file = os.path.join(
            PROJECT_FILE_LOC, instance_id + "-after.json"
        )
        if os.path.exists(before_project_file):
            return CodeStructure.load(before_project_file)

    # No saved structure available: build it directly from the repository.
    # Tasks without an explicit "language" entry are treated as Python.
    task_language = task["language"] if "language" in task else "python"
    structure, after_structure = get_before_after_ts_structure_from_scratch(
        task["repo"],
        task["base_commit"],
        task["patch"],
        instance_id,
        "playground",
        task_language,
    )

    if before_project_file is not None:
        structure.save(
            before_project_file,
            repo_name=task["repo"],
            commit_id=task["base_commit"],
        )
        after_structure.save(
            after_project_file,
            repo_name=task["repo"],
            commit_id=task["base_commit"],
        )
    return structure


def solve(args, task, existing_instance_ids):
    """Localize files and functions for one task and append the result.

    A per-task log file is created under
    ``<args.output_folder>/task_solve_logs/``. File-level localization is run
    first, then function-level localization on the files it found, and one
    JSON line with the outcome is appended to ``args.output_file``.

    Args:
        args: Parsed command-line namespace. Reads ``output_folder``,
            ``output_file``, ``model``, ``max_length``, ``max_tokens``,
            ``backend``, ``base_url`` and ``with_global``.
        task: Mapping describing the task. Reads ``instance_id``,
            ``problem_statement`` and, when the structure has to be built,
            ``repo``, ``base_commit``, ``patch`` and optionally ``language``.
        existing_instance_ids: Container of instance ids to skip.

    Returns:
        None. Returns early, without writing a result, when the task's
        instance id is in ``existing_instance_ids``.
    """
    instance_id = task["instance_id"]
    log_file = os.path.join(
        args.output_folder, "task_solve_logs", f"{instance_id}.log"
    )
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    logger = setup_logger(log_file)
    logger.info(f"Processing task {instance_id}")

    if instance_id in existing_instance_ids:
        logger.info(f"Skipping existing instance_id: {instance_id}")
        return

    structure = _load_or_build_structure(task, instance_id)

    # localize tool
    fl = RCL(
        instance_id,
        structure,
        task["problem_statement"],
        args.model,
        args.max_length,
        args.max_tokens,
        args.backend,
        logger,
        base_url=args.base_url,
    )

    logger.info(f"================ localize files of {instance_id} ================")

    found_files, additional_artifact_loc_file, file_traj = fl.localize_file()

    logger.info(f"================ localize functions of {instance_id} ================")

    # Coerce to a real bool so the callee sees exactly True or False.
    related_locs_pred, func_raw_output, func_traj = fl.localize_func(
        files=found_files, with_global=bool(args.with_global)
    )

    logger.info(f"================ localization completed for {instance_id} ================")

    # Save localization results as a single JSON line.
    record = {
        "instance_id": instance_id,
        "found_files": list(found_files),
        "found_related_locs": related_locs_pred,
        "problem_statement": task["problem_statement"],
        "additional_artifact_loc_file": additional_artifact_loc_file,
        "file_traj": file_traj,
        "func_raw_output": func_raw_output,
        "func_traj": func_traj,
    }
    # NOTE(review): this function may run in several worker processes that
    # all append to the same file; assumes one write per line is sufficient
    # to keep lines from interleaving -- confirm for very large records.
    with open(args.output_file, "a") as f:
        f.write(json.dumps(record) + "\n")

def solve_software_tasks(args):
    """Run ``solve`` over every task instance, sequentially or in a process pool.

    Instances come from ``args.data_file`` (JSONL) when it is set; otherwise
    from the "test" split of ``args.dataset``, loaded from disk if that path
    exists and via ``load_dataset`` if not.

    A failure in one instance is reported on stdout and does not stop the
    remaining instances, in both the sequential and the parallel mode.

    Args:
        args: Parsed command-line namespace. Reads ``data_file``, ``dataset``,
            ``output_file``, ``skip_existing`` and ``num_threads`` here; the
            whole namespace is forwarded to ``solve``.
    """
    if args.data_file:
        instances = load_jsonl(args.data_file)
    elif os.path.exists(args.dataset):
        instances = load_from_disk(args.dataset)["test"]
    else:
        instances = load_dataset(args.dataset, split="test")

    existing_instance_ids = (
        load_existing_instance_ids(args.output_file) if args.skip_existing else set()
    )

    if args.num_threads == 1:
        for instance in tqdm(instances, colour="MAGENTA"):
            try:
                solve(args, instance, existing_instance_ids)
            except Exception as e:
                print(f"Error processing instance {instance['instance_id']}: {e}")
        return

    with concurrent.futures.ProcessPoolExecutor(
        max_workers=args.num_threads
    ) as executor:
        # Remember which instance each future belongs to so that a failure
        # can be reported with its instance id.
        future_to_instance_id = {
            executor.submit(
                solve, args, instance, existing_instance_ids
            ): instance["instance_id"]
            for instance in instances
        }
        for future in tqdm(
            concurrent.futures.as_completed(future_to_instance_id),
            total=len(future_to_instance_id),
            colour="MAGENTA",
        ):
            # Mirror the sequential branch: one failing task must not abort
            # the whole run.
            try:
                future.result()
            except Exception as e:
                print(
                    f"Error processing instance {future_to_instance_id[future]}: {e}"
                )

if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, and run the
    # localization over all tasks.
    cli = argparse.ArgumentParser(
        description="Localize bugs in software repositories"
    )

    # (flag, add_argument keyword options), in the order they are registered.
    cli_options = [
        ("--output_folder", {"type": str, "required": True}),
        ("--output_file", {"type": str, "default": "localization_outputs.jsonl"}),
        ("--max_length", {"type": int, "default": 32768}),
        ("--max_tokens", {"type": int, "default": 4096}),
        ("--with_global", {"action": "store_true"}),
        (
            "--num_threads",
            {
                "type": int,
                "default": 1,
                "help": "Number of threads to use for creating API requests",
            },
        ),
        ("--target_id", {"type": str}),
        (
            "--skip_existing",
            {
                "action": "store_true",
                "help": "Skip localization of instance id's which already contain a localization in the output file.",
            },
        ),
        (
            "--mock",
            {
                "action": "store_true",
                "help": "Mock run to compute prompt tokens.",
            },
        ),
        ("--base_url", {"type": str, "default": None}),
        ("--model", {"type": str, "default": "gpt-4o-2024-05-13"}),
        (
            "--backend",
            {
                "type": str,
                "default": "openai",
                "choices": ["openai", "deepseek", "anthropic", "claude"],
            },
        ),
        (
            "--dataset",
            {
                "type": str,
                "default": "princeton-nlp/SWE-bench_Lite",
                "help": "Current supported dataset for evaluation",
            },
        ),
        (
            "--data_file",
            {
                "type": str,
                "default": None,
                "help": "Repo-level data file",
            },
        ),
    ]
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    # The output file name is joined onto the output folder.
    args.output_file = os.path.join(args.output_folder, args.output_file)

    solve_software_tasks(args)

