import argparse
import concurrent.futures
import json
import os, time
import random
from multiprocessing import Lock, Manager
from unidiff import PatchSet

from datasets import load_dataset, load_from_disk
from tqdm import tqdm

# from .AFL import AFL
from localize.location import CodeLocation, CodeLocationGroup
from localize.ts_structure import (
    get_before_after_ts_structure_from_scratch,
    CodeStructure
)

from localize.util.utils import load_existing_instance_ids, load_jsonl, setup_logger
from localize.util.load_data_to_swe import load_crawled_data_to_swe

# NOTE(review): no code visible in this file reads MAX_RETRIES — confirm it is
# still needed by importers before removing it.
MAX_RETRIES = 5
# SET THIS IF YOU WANT TO USE THE PREPROCESSED FILES
# Directory (taken from the PROJECT_FILE_LOC environment variable) where
# localize_instance looks for "<instance_id>-before.json" and
# "<instance_id>-after.json" structure files; None when the variable is unset.
PROJECT_FILE_LOC = os.environ.get("PROJECT_FILE_LOC", None)

from unidiff import PatchSet
import logging
from typing import List

# Assume logger is already configured
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)
from unidiff import PatchSet
import logging
from typing import List

# Assume logger is already configured
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

# def create_clean_patch_hunk_level(patch_content: str) -> str:
#     """
#     Parses a patch and cleans it at the HUNK level.

#     - If a hunk contains ANY substantive changes, it is kept IN ITS ENTIRETY.
#     - If a hunk ONLY contains meaningless changes (blank lines, comments), it is removed.

#     Args:
#         patch_content: The original patch content as a string.

#     Returns:
#         A new patch string with purely cosmetic hunks removed.
#     """
#     if not patch_content:
#         return ""

#     try:
#         # Parse the patch from a string; specifying encoding is recommended.
#         patch_set = PatchSet(patch_content)
#         final_patch_parts = []

#         # Iterate over each patched file in the patch set.
#         for patched_file in patch_set:
#             is_py_file = patched_file.path.endswith('.py')
#             kept_hunks_for_file = []

#             # Iterate over each hunk to decide whether to keep or discard it.
#             for hunk in patched_file:
#                 # For non-Python files, we discard all hunks.
#                 if not is_py_file:
#                     # kept_hunks_for_file.append(str(hunk))
#                     continue

#                 # --- Hunk-level decision logic ---
#                 # Check if this hunk has at least one meaningful change.
#                 has_substantive_change = False
#                 for line in hunk:
#                     if line.is_added or line.is_removed:
#                         content = line.value.strip()
#                         # A substantive change is not empty and not a comment.
#                         if content and not content.startswith('#'):
#                             has_substantive_change = True
#                             break  # Found a good line, no need to check further.
                
#                 # If a substantive change was found, keep the entire original hunk.
#                 if has_substantive_change:
#                     kept_hunks_for_file.append(str(hunk))

#             # If the file still has any hunks left after filtering...
#             if kept_hunks_for_file:
#                 # ...then add the file header and the kept hunks to our final patch.
#                 file_header = f"--- {patched_file.source_file}\n+++ {patched_file.target_file}"
#                 final_patch_parts.append(file_header)
#                 final_patch_parts.extend(kept_hunks_for_file)

#         # Join all the parts to form the final, clean patch string.
#         return "\n".join(final_patch_parts) + "\n" if final_patch_parts else ""

#     except Exception as e:
#         # It's better to use a logger in a real application.
#         print(f"Failed to create clean patch: {e}")
#         return ""

def localize_instance_with_try(
    bug, args, swe_bench_data, existing_instance_ids, write_lock=None
):
    """Run ``localize_instance`` and record a failure instead of raising.

    Any exception raised while localizing ``bug`` is printed, and ``bug`` is
    appended to ``args.output_file`` as one JSON line whose
    ``gt_localization`` entry is ``{"valid": False, "error": <message>}``.

    Args:
        bug: Instance record (a dict); mutated in place on failure.
        args: Namespace providing at least ``output_file``.
        swe_bench_data: Forwarded unchanged to ``localize_instance``.
        existing_instance_ids: Forwarded unchanged to ``localize_instance``.
        write_lock: Optional lock (anything with ``acquire``/``release``)
            guarding appends to the shared output file.

    Raises:
        OSError: If the failure record itself cannot be written.
        TypeError: If ``bug`` cannot be serialized to JSON.
    """
    try:
        localize_instance(bug, args, swe_bench_data, existing_instance_ids, write_lock=write_lock)
    except Exception as e:
        print(e)
        bug["gt_localization"] = {
            "valid": False,
            "error": str(e),
        }
        # Serialize before taking the lock so a non-serializable record can
        # neither hold the lock nor leave a partial line in the file.
        record = json.dumps(bug) + "\n"
        if write_lock is not None:
            write_lock.acquire()
        try:
            with open(args.output_file, "a") as f:
                f.write(record)
        finally:
            # Always release: a leaked lock would block every other worker.
            if write_lock is not None:
                write_lock.release()

def _load_or_build_structures(bug, args):
    """Return the ``(before, after)`` code structures for ``bug``.

    When ``PROJECT_FILE_LOC`` is set, the pair is loaded from
    ``<instance_id>-before.json`` / ``<instance_id>-after.json`` in that
    directory if both files exist; otherwise it is built with
    ``get_before_after_ts_structure_from_scratch`` and, if both results are
    usable, saved there. When ``PROJECT_FILE_LOC`` is unset no files are read
    or written. Either element may be ``None`` if parsing failed.
    """
    instance_id = bug["instance_id"]
    before_project_file = after_project_file = None

    if PROJECT_FILE_LOC:
        os.makedirs(PROJECT_FILE_LOC, exist_ok=True)
        before_project_file = os.path.join(PROJECT_FILE_LOC, instance_id + "-before.json")
        after_project_file = os.path.join(PROJECT_FILE_LOC, instance_id + "-after.json")
        if os.path.exists(before_project_file) and os.path.exists(after_project_file):
            # Already cached: nothing to rebuild and nothing to re-save.
            return (
                CodeStructure.load(before_project_file),
                CodeStructure.load(after_project_file),
            )

    # we need to get the project structure directly
    structure, after_structure = get_before_after_ts_structure_from_scratch(
        bug["repo"], bug["base_commit"], bug["patch"], instance_id, args.playground_folder, bug["language"]
    )
    if structure is None or after_structure is None:
        return structure, after_structure

    if before_project_file is not None:
        structure.save(before_project_file)
        after_structure.save(after_project_file)
    return structure, after_structure


def localize_instance(
    bug, args, swe_bench_data, existing_instance_ids, write_lock=None
):
    """Derive ground-truth edit locations for one instance from its patch.

    Every ``+`` line of a modified file is mapped to a location in the
    post-patch structure and every ``-`` line to a location in the pre-patch
    structure. Locations that exist before the patch are recorded as edits;
    locations that only exist afterwards are recorded as new components
    together with their closest pre-existing parent. On success ``bug`` gains
    a ``gt_localization`` entry and is appended to ``args.output_file`` as one
    JSON line.

    Nothing is written when: ``args.target_id`` names another instance, the
    patch is empty, a structure could not be parsed, a patched file is missing
    from the corresponding structure, or the number of edit locations / files
    exceeds ``args.top_n`` / ``args.top_n_files``.

    Args:
        bug: Instance record with ``instance_id``, ``patch``, ``repo``,
            ``base_commit`` and ``language``; mutated in place.
        args: Namespace providing ``output_folder``, ``output_file``,
            ``playground_folder``, ``target_id``, ``top_n``, ``top_n_files``.
        swe_bench_data: Unused; kept so existing callers keep working.
        existing_instance_ids: Unused; already-processed instances are
            filtered out by the caller.
        write_lock: Optional lock (``acquire``/``release``) guarding appends
            to the shared output file.
    """
    instance_id = bug["instance_id"]
    if args.target_id is not None and args.target_id != instance_id:
        return

    log_file = os.path.join(
        args.output_folder, "localization_logs", f"{instance_id}.log"
    )
    logger = setup_logger(log_file)
    logger.info(f"Processing bug {instance_id}")

    # Treat a missing patch and an explicit None the same way.
    bug["patch"] = bug.get("patch") or ""
    if bug["patch"].strip() == "":
        logger.info(f"Skipping instance with empty patch: {bug['instance_id']}")
        return

    structure, after_structure = _load_or_build_structures(bug, args)
    if structure is None or after_structure is None:
        logger.info("Structure of original code repo or modified code repo can not be parsed correctly.")
        return

    logger.info(f"================ localize {instance_id} ================")

    added_files, modified_files, removed_files = [], [], []
    localizations = set()
    new_components = set()
    new_components_parents = set()

    valid_localization = True
    # ground truth localization by analyzing the edit code lines in the patch
    for file in PatchSet(bug["patch"]):
        # One normalized path is used for both the structure lookups and the
        # reported file lists so the two always agree.
        file_path = file.path.replace("\\", "/")

        if file.is_removed_file:
            removed_files.append(file_path)
            if not structure.get_file_node(file_path):
                logger.warning(f"Removed File {file_path} not found in structure, skipping localization.")
                valid_localization = False
                break
        elif file.is_added_file:
            added_files.append(file_path)
            if not after_structure.get_file_node(file_path):
                logger.warning(f"Added File {file_path} not found in after structure, skipping localization.")
                valid_localization = False
                break
        else:
            modified_files.append(file_path)
            if not structure.get_file_node(file_path) or not after_structure.get_file_node(file_path):
                logger.warning(f"Modified File {file_path} not found in structure or after structure, skipping localization.")
                valid_localization = False
                break
            for hunk in file:
                for line in hunk:
                    if line.line_type == "+":
                        location = after_structure.get_location(file_path, line.target_line_no)
                        if not location:
                            continue
                        if structure.find(location):
                            # the location already exists before the patch, so it is a modification
                            localizations.add(location)
                            continue
                        # newly added namespace in the after_structure
                        new_components.add(location)
                        parent_location = location.get_parent()
                        if not structure.find(parent_location):
                            # if class itself are newly added, then parent is the global space
                            # TODO: if it is counted for edit namespace for these new components
                            parent_location = parent_location.get_parent()
                        new_components_parents.add(parent_location)
                    elif line.line_type == "-":
                        location = structure.get_location(file_path, line.source_line_no)
                        if location:
                            localizations.add(location)

    if not valid_localization:
        logger.warning(f"Invalid localization for {instance_id}, skipping.")
        return

    localizations = CodeLocationGroup(list(localizations))
    gt_location_num = len(localizations)
    gt_file_num = len(added_files) + len(modified_files) + len(removed_files)

    if gt_location_num > args.top_n:
        logger.warning(f"GT related locs num {gt_location_num} is larger than top_n {args.top_n}, skip this instance.")
        return
    if gt_file_num > args.top_n_files:
        logger.warning(
            f"GT file num {gt_file_num} is larger than top_n_files {args.top_n_files}, skip this instance."
        )
        return

    bug["gt_localization"] = {
        "valid": True,
        "instance_id": instance_id,
        "added_files": added_files,
        "modified_files": modified_files,
        "removed_files": removed_files,
        "edit_namespaces": [str(item) for item in localizations.locations],
        "edit_namespaces_for_new_components": [str(item) for item in new_components_parents],
        "new_namespaces": [str(item) for item in new_components],
        "related_locs": localizations.to_file_location_string_map()
    }
    # Serialize outside the lock; hold the lock only for the append itself and
    # release it even if the write fails so other workers are never blocked.
    record = json.dumps(bug) + "\n"
    if write_lock is not None:
        write_lock.acquire()
    try:
        with open(args.output_file, "a") as f:
            f.write(record)
    finally:
        if write_lock is not None:
            write_lock.release()

def localize(args):
    """Compute ground-truth localization for every pending instance.

    Instances come from ``args.data_file`` (a JSONL file inside
    ``args.output_folder``, shuffled with a fixed seed) when it is set,
    otherwise from ``args.dataset`` (loaded from ``./datasets/`` when the name
    contains ``"sampled"``, else the ``test`` split via ``load_dataset``).
    With ``args.skip_existing`` the instance ids returned by
    ``load_existing_instance_ids(args.output_file)`` are skipped. Work runs
    inline when ``args.num_threads == 1`` and in a process pool of that size
    otherwise.

    Args:
        args: Namespace providing ``data_file``, ``dataset``,
            ``output_folder``, ``output_file``, ``skip_existing`` and
            ``num_threads`` (plus whatever ``localize_instance_with_try``
            reads).
    """
    if args.data_file:
        data_path = os.path.join(args.output_folder, args.data_file)
        # Close the file deterministically and tolerate blank lines (e.g. a
        # trailing newline), which json.loads would otherwise reject.
        with open(data_path) as data_fh:
            swe_bench_data = [json.loads(line) for line in data_fh if line.strip()]
        random.seed(42)
        random.shuffle(swe_bench_data)
    elif "sampled" in args.dataset:
        swe_bench_data = load_from_disk(f"./datasets/{args.dataset}")
    else:
        swe_bench_data = load_dataset(args.dataset, split="test")

    existing_instance_ids = (
        load_existing_instance_ids(args.output_file) if args.skip_existing else set()
    )

    swe_bench_data = [
        x for x in swe_bench_data if x["instance_id"] not in existing_instance_ids
    ]

    # Each task only ever looks up its own record in the data it is given, so
    # hand it a one-element list: this avoids a linear scan per instance and,
    # in the multi-process path, pickling the whole dataset once per task.
    if args.num_threads == 1:
        for bug in tqdm(swe_bench_data, colour="MAGENTA"):
            # TODO: try
            localize_instance_with_try(
                bug, args, [bug], existing_instance_ids
            )
    else:
        with Manager() as manager:
            # A manager-backed lock can be shared with pool workers.
            write_lock = manager.Lock()
            print(args.num_threads, "threads will be used for localization.")
            with concurrent.futures.ProcessPoolExecutor(
                max_workers=args.num_threads
            ) as executor:
                futures = [
                    executor.submit(
                        localize_instance_with_try,
                        bug,
                        args,
                        [bug],
                        existing_instance_ids,
                        write_lock,
                    )
                    for bug in swe_bench_data
                ]
                for future in tqdm(
                    concurrent.futures.as_completed(futures),
                    total=len(futures),
                    colour="MAGENTA",
                ):
                    # Re-raise anything the worker could not record itself.
                    future.result()
    # # temp for debugging: convert jsonl file to json file 
    # if args.output_file.endswith(".jsonl"):
    #     jsonl_data = load_jsonl(args.output_file)
    #     with open(args.output_file.replace(".jsonl", ".json"), "w") as f:
    #         json.dump(jsonl_data, f, ensure_ascii=False, indent=4)
    

def merge(args):
    """Merge predicted locations.

    Reads the records in ``args.start_file`` and, for each sample index below
    ``args.num_samples``, writes
    ``<output_folder>/loc_merged_<i>-<i>_outputs.jsonl``. Each output record
    is the input record with ``found_edit_locs`` replaced by a mapping from
    file name to a one-element list holding that sample's locations joined by
    newlines (or by an empty list when the record has no edit locations).
    """
    records = load_jsonl(args.start_file)

    def merge_locs(sample_found_locs: list[dict]):
        """Fold per-file locations of the given samples into one dict."""
        combined = {}
        for sample in sample_found_locs:
            for file_name, found in sample.items():
                # A value is either one string or a sequence of strings.
                text = found if isinstance(found, str) else "\n".join(found)
                if not text.strip():
                    continue  # nothing but whitespace: leave the file out
                bucket = combined.setdefault(file_name, [""])
                bucket[0] += "\n" + text
        return combined

    def merged_record(record, sample_index):
        """Copy ``record`` with only the chosen sample's merged locations."""
        if "found_edit_locs" in record and len(record["found_edit_locs"]):
            chosen = record["found_edit_locs"][sample_index : sample_index + 1]
            merged = merge_locs(chosen)
        else:
            merged = []
        return {**record, "found_edit_locs": merged}

    # Dump each location sample.
    for st_id in range(args.num_samples):
        en_id = st_id  # every output file covers exactly one sample
        out_path = f"{args.output_folder}/loc_merged_{st_id}-{en_id}_outputs.jsonl"
        lines = [json.dumps(merged_record(rec, st_id)) + "\n" for rec in records]
        with open(out_path, "w") as out:
            out.writelines(lines)

def main():
    """Command-line entry point: parse options, then localize or merge.

    ``--output_file`` is resolved relative to ``--output_folder``. The output
    folder and its ``localization_logs`` sub-folder are created, the parsed
    arguments are dumped to ``<output_folder>/args.json``, and then either
    ``merge`` (with ``--merge``) or ``localize`` runs.

    ``--merge`` requires ``--start_file``: ``merge`` reads ``args.start_file``,
    so the option is validated here and the parser exits with a usage error
    when it is missing.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--playground_folder", type=str, required=True)
    parser.add_argument("--output_file", type=str, default="loc_outputs.jsonl")
    parser.add_argument("--top_n", type=int, default=10)
    parser.add_argument("--top_n_files", type=int, default=5)
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--num_samples", type=int, default=1)
    parser.add_argument("--merge", action="store_true")
    parser.add_argument(
        "--start_file",
        type=str,
        default=None,
        help="JSONL file with predicted locations to merge (required with --merge)",
    )
    parser.add_argument("--add_space", action="store_true")
    parser.add_argument("--no_line_number", action="store_true")
    parser.add_argument("--sticky_scroll", action="store_true")
    parser.add_argument("--context_window", type=int, default=10)
    parser.add_argument("--keep_old_order", action="store_true")
    parser.add_argument(
        "--num_threads",
        type=int,
        default=1,
        help="Number of worker processes to use for localization",
    )
    parser.add_argument("--target_id", type=str)
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip localization of instance id's which already contain a localization in the output file.",
    )
    parser.add_argument(
        "--mock", action="store_true", help="Mock run to compute prompt tokens."
    )
    parser.add_argument(
        "--model",
        type=str,
        default="gpt-4o-2024-05-13",
    )
    parser.add_argument(
        "--backend",
        type=str,
        default="openai",
        choices=["openai", "deepseek", "anthropic", "claude"],
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="princeton-nlp/SWE-bench_Lite",
        help="Current supported dataset for evaluation",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help="Repo-level data file",
    )

    args = parser.parse_args()
    # Fail fast with a usage message, before anything is written to disk.
    if args.merge and not args.start_file:
        parser.error("--start_file is required when --merge is given")

    args.output_file = os.path.join(args.output_folder, args.output_file)

    # Creating the log folder also creates the output folder itself.
    os.makedirs(os.path.join(args.output_folder, "localization_logs"), exist_ok=True)
    os.makedirs(args.output_folder, exist_ok=True)

    # write the arguments
    with open(f"{args.output_folder}/args.json", "w") as f:
        json.dump(vars(args), f, indent=4)

    if args.merge:
        merge(args)
    else:
        localize(args)

if __name__ == "__main__":
    main()
