
""" Functions for generating test quality labels """

import util_repo
import util_tokens
import pandas as pd

# Folder name handed to util_repo.InstanceRepo by instance_label() below.
repo_folder = "repos"  # for local clones

def instance_label(df_instance, patch_semantic_fcn, test_semantic_fcn):
    """
    Determine if instance should be excluded based on test patch heuristic.
    Assigns True (exclude) or False (retain) to the 'auto_test_label' field.

    The instance is excluded when the solution patch and the test patch share
    string, number or name tokens that do not occur in the problem statement.

    Parameters:
        df_instance: one benchmark row, read and written by key. Keys read:
            'instance_id', 'repo', 'base_commit', 'patch', 'test_patch',
            'problem_statement'.
        patch_semantic_fcn: forwarded to process_modified_files() for the
            solution patch (None means names come from the tokenizer).
        test_semantic_fcn: forwarded to process_modified_files() for the
            test patch (None means names come from the tokenizer).

    Returns:
        The same df_instance object, with these fields always set:
        'unknown_strings', 'unknown_numbers', 'unknown_names' (sorted lists),
        'auto_test_label', 'curation_error', 'curation_error_msg'.

    Any exception raised while processing is caught and reported in the
    'curation_error' fields; the instance is then retained (label False)
    and the three 'unknown_*' lists are empty.
    """

    print(df_instance["instance_id"])

    try:

        instance_repo = util_repo.InstanceRepo(repo_folder, df_instance["repo"])

        # get solution patch tokens
        instance_repo.checkout(df_instance["base_commit"])
        modified_files_patch = instance_repo.apply_patchstr(df_instance["patch"])
        patch_strings, patch_numbers, patch_names = process_modified_files(
            instance_repo, modified_files_patch, patch_semantic_fcn)

        # get test patch tokens
        instance_repo.checkout(df_instance["base_commit"])  # reset
        modified_files_test = instance_repo.apply_patchstr(df_instance["test_patch"])
        test_strings, test_numbers, test_names = process_modified_files(
            instance_repo, modified_files_test, test_semantic_fcn)

        # Simple heuristic to filter out idiomatic names, including magic methods
        # (the dunder names of int stand in for the common magic methods)
        idioms = {"self", "cls"} | {n for n in dir(int) if n.startswith("__")}
        test_names = test_names - idioms

        # determine overlapping tokens between solution and test patches
        overlapping_strings = patch_strings.intersection(test_strings)
        overlapping_numbers = patch_numbers.intersection(test_numbers)
        overlapping_names = patch_names.intersection(test_names)

        # check if overlapping tokens are also present in the issue description
        # if not, add to lists of 'unknown' tokens
        # (sorted, because set iteration order is not stable between runs)
        unknown_strings = sorted(
            s for s in overlapping_strings if s not in df_instance["problem_statement"])
        unknown_numbers = sorted(
            n for n in overlapping_numbers if n not in df_instance["problem_statement"])
        unknown_names = sorted(
            n for n in overlapping_names if n not in df_instance["problem_statement"])

        # determine whether to exclude the issue
        # exclude (true) if any unknown strings, numbers or names
        exclude_issue = bool(unknown_strings or unknown_numbers or unknown_names)

        # update instance row
        df_instance["unknown_strings"] = unknown_strings
        df_instance["unknown_numbers"] = unknown_numbers
        df_instance["unknown_names"] = unknown_names
        df_instance["auto_test_label"] = exclude_issue
        df_instance["curation_error"] = False
        df_instance["curation_error_msg"] = ""

        return df_instance

    except Exception as e:
        # report error and retain instance (negative result)
        # this aligns with treatment of missing values in raw swe-v annotations
        # The token fields are still populated so that failed rows carry the
        # same fields as successful ones.
        df_instance["unknown_strings"] = []
        df_instance["unknown_numbers"] = []
        df_instance["unknown_names"] = []
        df_instance["auto_test_label"] = False
        df_instance["curation_error"] = True
        df_instance["curation_error_msg"] = str(e)
        print(str(e))
        # show traceback if needed
        # import traceback
        # traceback.print_exc()
        print("ERROR: Retain instance (negative result)")
        return df_instance


def process_modified_files(instance_repo, modified_files, semantic_fcn=None):
    """
    Collect the string, number and name tokens of all modified Python files.

    For every entry of modified_files ending in ".py", the modified lines are
    obtained from instance_repo.get_modified_lines() and handed to the
    tokenizer (and to semantic_fcn for names, when one is given).
    Other files are skipped.

    Returns three sets: (strings, numbers, names), merged over all files.
    """

    all_strings, all_numbers, all_names = set(), set(), set()

    # lazily restrict to Python sources; everything else is ignored
    python_files = (path for path in modified_files if path.endswith(".py"))

    for path in python_files:
        changed_lines = instance_repo.get_modified_lines(path)

        if not semantic_fcn:
            # no semantic function: the tokenizer supplies names as well
            file_strings, file_numbers, file_names = util_tokens.tokenize_file(
                path, changed_lines, include_names=True)
        else:
            # names come from the semantic function, the rest from the tokenizer
            file_names = semantic_fcn(path, changed_lines)
            file_strings, file_numbers, _ = util_tokens.tokenize_file(
                path, changed_lines, include_names=False)

        all_strings.update(file_strings)
        all_numbers.update(file_numbers)
        all_names.update(file_names)

    return all_strings, all_numbers, all_names


def generate_labels(df_benchmark, df_sample, identifier_mode, limit=None):
    """
    Generate all test quality labels for a set of sample instances.
    Returns a processed dataframe without saving to files.
    """

    if limit:
        print("Limiting to "+str(limit)+" samples for testing")
        df_sample = df_sample.head(limit)

    print("Reducing benchmark to sample instances")
    df_benchmark = pd.merge(df_benchmark, df_sample, how='inner', on='instance_id')

    # determine semantic functions for identifiers
    match identifier_mode:
        case "semantic":
            import util_semantic
            patch_semantic_fcn = util_semantic.declared_identifiers
            test_semantic_fcn = util_semantic.used_identifiers
        case "tokens":
            patch_semantic_fcn = None  # defaults to tokenizer
            test_semantic_fcn = None  # defaults to tokenizer
        case _:
            raise ValueError("Invalid identifier_mode. Choose 'semantic' or 'tokens'.")

    print("\nGenerating labels for all instances ...\n")
    labeller = lambda row: instance_label(row, patch_semantic_fcn, test_semantic_fcn)
    df_labels = df_benchmark.apply(labeller, axis=1)

    df_labels = df_labels[["repo", "instance_id", "problem_statement", "patch", "test_patch",
                    "unknown_strings", "unknown_numbers", "unknown_names",
                    "auto_test_label","curation_error", "curation_error_msg"]]

    return df_labels


def save_labels_to_json(df_labels, experiment_folder):
    """
    Write the labels dataframe to JSON files inside experiment_folder.

    Two files are produced:
      - labels.json: every row of df_labels, as a list of records
      - errors.jsonl: one record per row whose 'curation_error' is set,
        holding only 'instance_id' and 'curation_error_msg'
    """

    labels_path = experiment_folder + "/labels.json"  # for generated labels
    errors_path = experiment_folder + "/errors.jsonl"  # for curation errors

    print("\nWriting labels to file\n")
    df_labels.to_json(labels_path, orient='records')

    # list errors for convenience
    has_error = df_labels["curation_error"]
    error_rows = df_labels.loc[has_error, ["instance_id", "curation_error_msg"]]
    error_rows.to_json(errors_path, orient='records', lines=True)
