import json
import os
import re
import requests
import traceback

from argparse import ArgumentTypeError
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import Dataset, load_dataset, load_from_disk, DownloadMode, DownloadConfig
from dotenv import load_dotenv
from pathlib import Path
from tqdm import tqdm
from typing import cast
from swebench.harness.constants import (
    SWEbenchInstance,
    KEY_INSTANCE_ID,
    KEY_MODEL,
    KEY_PREDICTION,
)
from unidiff import PatchSet

# Import-time side effect: call python-dotenv's load_dotenv() so that values
# from a .env file (if one is found) are available through os.environ to the
# environment lookups performed by the functions in this module.
load_dotenv()


class EvaluationError(Exception):
    """
    Error raised for a failure tied to a single instance.

    Keeps the instance id, the logger and the logger's ``log_file`` so the
    rendered message can point the reader at the detailed log.

    Args:
        instance_id: Identifier of the instance the error belongs to.
        message: Human-readable description, passed to ``Exception``.
        logger: Object exposing a ``log_file`` attribute and an ``info`` method.
    """

    def __init__(self, instance_id, message, logger):
        super().__init__(message)
        self.instance_id = instance_id
        self.logger = logger
        self.log_file = logger.log_file

    def __str__(self):
        # NOTE: rendering the error has a side effect - whatever
        # traceback.format_exc() returns at this moment is written to the
        # logger every time the exception is converted to a string.
        self.logger.info(traceback.format_exc())
        base_message = super().__str__()
        summary = f"{self.instance_id}: {base_message}"
        pointer = f"Check ({self.log_file}) for more information."
        return "\n".join([summary, pointer])


def _read_text_with_fallback(path: str) -> str:
    """Read a text file as UTF-8 (BOM-tolerant), dropping undecodable bytes as a last resort."""
    for encoding in ("utf-8-sig", "utf-8"):
        try:
            with open(path, "r", encoding=encoding) as f:
                return f.read()
        except UnicodeDecodeError:
            continue
    with open(path, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()


def get_predictions_from_file(predictions_path: str, dataset_name: str, split: str):
    """
    Load model predictions and return them as a list of dictionaries.

    Args:
        predictions_path: Either the literal string ``"gold"`` (use the
            dataset's own ``patch`` field as the prediction), or a path to a
            ``.json`` / ``.jsonl`` predictions file.
        dataset_name: Dataset name or path; only used for ``"gold"``.
        split: Dataset split; only used for ``"gold"``.

    Returns:
        A list of prediction dictionaries, each containing KEY_INSTANCE_ID.
        For files, a missing KEY_INSTANCE_ID is filled in from the first of
        ``task_id``, ``instanceId`` or ``id`` that is present.

    Raises:
        ValueError: If the path has an unsupported extension, the ``.json``
            payload is neither a list nor a dictionary, an entry is not a
            dictionary, or an entry has no instance id.
    """
    if predictions_path == "gold":
        if os.environ.get("TE_QUIET", "").lower() not in ("1", "true", "yes"):
            print("Using gold predictions - ignoring predictions_path")
        dataset = load_swebench_dataset(dataset_name, split)
        return [
            {
                KEY_INSTANCE_ID: datum[KEY_INSTANCE_ID],
                KEY_PREDICTION: datum["patch"],
                KEY_MODEL: "gold",
            }
            for datum in dataset
        ]

    if predictions_path.endswith(".json"):
        predictions = json.loads(_read_text_with_fallback(predictions_path))
        # These checks must run no matter which encoding attempt succeeded;
        # previously they were only reached after a decode failure.
        if isinstance(predictions, dict):
            predictions = list(
                predictions.values()
            )  # compatible with SWE-agent predictions
        if not isinstance(predictions, list):
            raise ValueError(
                "Predictions must be a list[prediction] or a dictionary[instance_id: prediction]"
            )
    elif predictions_path.endswith(".jsonl"):
        text = _read_text_with_fallback(predictions_path)
        # Text-mode reading already normalised line endings to "\n"; blank
        # lines carry no record and are skipped.
        predictions = [json.loads(line) for line in text.split("\n") if line.strip()]
    else:
        raise ValueError("Predictions path must be .json or .jsonl")

    # Normalize instance id key (support model files that use 'task_id' or variants)
    alt_keys = ("task_id", "instanceId", "id")
    for pred in predictions:
        if not isinstance(pred, dict):
            raise ValueError(f"Each prediction must be a dictionary, got {type(pred)}")
        if KEY_INSTANCE_ID not in pred:
            for ak in alt_keys:
                if ak in pred:
                    pred[KEY_INSTANCE_ID] = pred[ak]
                    break
            else:
                raise ValueError(f"Each prediction must contain '{KEY_INSTANCE_ID}'")

    return predictions


def run_threadpool(func, payloads, max_workers):
    """
    Run ``func(*payload)`` for every payload on a thread pool.

    Falls back to ``run_sequential`` when ``max_workers`` is zero or negative.
    Exceptions raised by ``func`` are printed (with traceback) and the payload
    is recorded as failed; they are not re-raised.

    Args:
        func: Callable invoked with each payload unpacked as positional args.
        payloads: Sequence of argument tuples.
        max_workers: Thread pool size.

    Returns:
        A ``(succeeded, failed)`` pair of payload lists, in completion order.
    """
    if max_workers <= 0:
        return run_sequential(func, payloads)

    # Sum ``new_te_tests_count`` read off the first element of every payload
    # (missing/None counts as 0); any failure while summing yields 0.
    try:
        te_total = sum(
            getattr(payload[0], "new_te_tests_count", 0) or 0 for payload in payloads
        )
    except Exception:
        te_total = 0

    succeeded, failed = [], []
    with tqdm(total=len(payloads), smoothing=0) as pbar, ThreadPoolExecutor(
        max_workers=max_workers
    ) as executor:
        # Map each submitted future back to the payload it was created from
        payload_by_future = {}
        for payload in payloads:
            payload_by_future[executor.submit(func, *payload)] = payload

        for finished in as_completed(payload_by_future):
            payload = payload_by_future[finished]
            try:
                # result() re-raises whatever the worker raised
                finished.result()
            except Exception as exc:
                print(f"{type(exc)}: {exc}")
                traceback.print_exc()
                failed.append(payload)
            else:
                succeeded.append(payload)

            pbar.update(1)
            description = f"{len(succeeded)} ran successfully, {len(failed)} failed"
            if te_total > 0:
                description = (
                    f"{len(succeeded)} ran successfully, {len(failed)} failed | TE new tests: {te_total}"
                )
            pbar.set_description(description)
    return succeeded, failed


def run_sequential(func, args_list):
    """
    Run a function with a list of arguments sequentially

    Exceptions raised by ``func`` are printed with their traceback and the
    argument tuple is recorded as failed; they are not re-raised.

    Args:
        func: Callable invoked as ``func(*args)`` for each entry.
        args_list: Sequence of argument tuples.

    Returns:
        A ``(succeeded, failed)`` pair of argument-tuple lists.
    """
    succeeded, failed = [], []
    # The context manager closes the progress bar even when something that is
    # not an ``Exception`` (e.g. KeyboardInterrupt) aborts the loop.
    with tqdm(total=len(args_list), smoothing=0) as pbar:
        for args in args_list:
            try:
                func(*args)
            except Exception:
                traceback.print_exc()
                failed.append(args)
            else:
                succeeded.append(args)
            pbar.update(1)
            pbar.set_description(
                f"{len(succeeded)} ran successfully, {len(failed)} failed"
            )
    return succeeded, failed


def load_swebench_dataset(
    name="SWE-bench/SWE-bench", split="test", instance_ids=None
) -> list[SWEbenchInstance]:
    """
    Load SWE-bench dataset from Hugging Face Datasets or local .json/.jsonl file

    Resolution order for ``name``:
      * ``*.json`` / ``*.jsonl``: parsed directly from disk (UTF-8, BOM tolerated).
      * otherwise, after mapping a few short aliases to the canonical repo ids:
          1. a local directory ``name/split`` containing ``dataset_info.json``;
          2. a matching entry in the local Hugging Face datasets cache;
          3. ``load_dataset`` (local files only when offline mode is requested).

    Environment variables read: ``TE_FORCE_OFFLINE``, ``HF_DATASETS_OFFLINE``,
    ``HF_HOME`` and ``HF_DATASETS_CACHE``. Truthy values are ``1``, ``true``,
    ``yes`` and ``on`` (case-insensitive). When ``TE_FORCE_OFFLINE`` is truthy,
    the ``HF_*_OFFLINE`` variables are set in ``os.environ`` if not already set.

    Args:
        name: Dataset name, alias, local dataset directory or .json/.jsonl path.
        split: Dataset split to load.
        instance_ids: Optional iterable of instance IDs to keep.

    Returns:
        The (optionally filtered) instances as a list.

    Raises:
        ValueError: If any requested instance ID is not in the dataset.
    """
    truthy = ("1", "true", "yes", "on")
    # Normalise the requested IDs to a set; they are checked against the
    # loaded dataset at the end of the function.
    if instance_ids:
        instance_ids = set(instance_ids)
    # Load from local .json/.jsonl file
    if name.endswith(".json"):
        # utf-8-sig decodes plain UTF-8 identically and additionally strips a BOM
        dataset = json.loads(Path(name).read_text(encoding="utf-8-sig"))
    elif name.endswith(".jsonl"):
        try:
            text = Path(name).read_text(encoding="utf-8-sig")
        except UnicodeDecodeError:
            try:
                text = Path(name).read_text(encoding="utf-8")
            except UnicodeDecodeError:
                text = Path(name).read_text(encoding="utf-8", errors="ignore")
        # Split on "\n" only: str.splitlines() would also split on characters
        # such as U+2028 that may appear unescaped inside JSON strings.
        # Blank lines carry no record and are skipped.
        dataset = [json.loads(line) for line in text.split("\n") if line.strip()]
    else:
        # Load from Hugging Face Datasets
        # Offline mode is now opt-in via TE_FORCE_OFFLINE=1 (or existing HF_*_OFFLINE envs)
        force_offline = (
            os.environ.get("TE_FORCE_OFFLINE", "").strip().lower() in truthy
        )
        if force_offline:
            os.environ.setdefault("HF_DATASETS_OFFLINE", "1")
            os.environ.setdefault("HF_HUB_OFFLINE", "1")
            os.environ.setdefault("HUGGINGFACE_HUB_OFFLINE", "1")
            # Programmatically tell huggingface_hub to go offline (prevents README HEADs)
            # Best effort: silently skipped if the helper is unavailable.
            try:
                from huggingface_hub import set_offline_mode as _set_offline
                _set_offline(True)
            except Exception:
                pass
        # Reduce warning verbosity from datasets/hub (best effort)
        try:
            import datasets as _ds
            _ds.utils.logging.set_verbosity_error()
        except Exception:
            pass
        if name.lower() in {"swe-bench", "swebench", "swe_bench"}:
            name = "SWE-bench/SWE-bench"
        elif name.lower() in {
            "swe-bench-lite",
            "swebench-lite",
            "swe_bench_lite",
            "swe-bench_lite",
            "lite",
        }:
            name = "SWE-bench/SWE-bench_Lite"
        # 1) If a local dataset directory is provided, use it
        if (Path(name) / split / "dataset_info.json").exists():
            dataset = cast(Dataset, load_from_disk(Path(name) / split))
        else:
            # 2) Try to resolve from the local HF datasets cache to avoid any hub touches
            dataset = None  # type: ignore[assignment]
            try:
                # Determine HF datasets cache root
                hf_home = os.environ.get("HF_HOME") or os.path.join(Path.home(), ".cache", "huggingface")
                cache_root = Path(os.environ.get("HF_DATASETS_CACHE") or Path(hf_home) / "datasets")
                # Known repo name mappings (cache directory name prefixes)
                # NOTE(review): matching is by name suffix, so any dataset whose
                # name ends with "swe-bench" resolves to this cache entry - confirm intended.
                if name.lower().endswith("swe-bench"):
                    wanted = ["SWE-bench___swe-bench"]
                elif name.lower().endswith("swe-bench_lite"):
                    wanted = ["SWE-bench___swe-bench_lite", "SWE-bench___swe-bench_Lite"]
                else:
                    wanted = []
                found = []
                if cache_root.exists():
                    for child in cache_root.iterdir():
                        try:
                            if any(child.name.startswith(w) for w in wanted):
                                # descend versions/default
                                for sub in child.rglob("dataset_info.json"):
                                    if sub.parent.name == split:
                                        found.append(sub.parent)
                        except Exception:
                            continue
                if len(found) > 0:
                    # Pick the newest by modified time
                    found.sort(key=lambda p: p.stat().st_mtime, reverse=True)
                    dataset = cast(Dataset, load_from_disk(found[0]))
            except Exception:
                # Any cache-probing failure just falls through to load_dataset
                dataset = None  # type: ignore[assignment]
            # 3) Final fallback: load_dataset
            if dataset is None:
                # If offline forced, do local-only; else allow download
                local_only = (
                    os.environ.get("HF_DATASETS_OFFLINE", "").strip().lower() in truthy
                    or force_offline
                )
                dl_cfg = DownloadConfig(local_files_only=local_only)
                dataset = cast(
                    Dataset,
                    load_dataset(
                        name,
                        split=split,
                        download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
                        download_config=dl_cfg,
                    ),
                )
    dataset_ids = {instance[KEY_INSTANCE_ID] for instance in dataset}
    if instance_ids:
        # check that all instance IDs are in the dataset
        missing_ids = instance_ids - dataset_ids
        if missing_ids:
            raise ValueError(
                (
                    "Some instance IDs not found in dataset!"
                    f"\nMissing IDs:\n{' '.join(sorted(missing_ids))}"
                )
            )
        dataset = [
            instance
            for instance in dataset
            if instance[KEY_INSTANCE_ID] in instance_ids
        ]
    return [cast(SWEbenchInstance, instance) for instance in dataset]


### MARK - Patch Correction
# One per-file section of a diff: an optional leading "diff ..." line, the
# "--- a/<path>" and "+++ b/<path>" header lines, then (lazily, across
# newlines because of DOTALL) everything up to the next "diff " or "--- a/"
# marker or the end of the text.
PATCH_PATTERN = re.compile(
    r"(?:diff[\w\_\.\ \/\-]+\n)?\-\-\-\s+a\/(?:.*?)\n\+\+\+\s+b\/(?:.*?)(?=diff\ |\-\-\-\ a\/|\Z)",
    re.DOTALL,
)
# Just the two-line "--- a/<path>" / "+++ b/<path>" file header. No DOTALL
# here, so each path must be non-empty and stay on a single line.
PATCH_FILE_PATTERN = re.compile(r"\-\-\-\s+a\/(?:.+)\n\+\+\+\s+b\/(?:.+)")
# One hunk: captures the four numbers of "@@ -<start>,<len> +<start>,<len> @@"
# and, as the fifth group, the text after the header up to the next
# "diff ", "--- a/" or "@@ -" marker or the end of the text. Hunk headers that
# omit the ",<len>" parts are not matched by this pattern.
PATCH_HUNK_PATTERN = re.compile(
    r"\@\@\s+\-(\d+),(\d+)\s+\+(\d+),(\d+)\s+\@\@(.+?)(?=diff\ |\-\-\-\ a\/|\@\@\ \-|\Z)",
    re.DOTALL,
)


def get_first_idx(charlist):
    """
    Return the position of the first "-" or "+" entry in charlist.

    If neither marker occurs, the length of charlist is returned.
    """
    for position, marker in enumerate(charlist):
        if marker in ("-", "+"):
            return position
    return len(charlist)


def get_last_idx(charlist):
    """
    Return a slice end derived from the last "-" or "+" entry in charlist.

    The value is the index of the last marker plus two, i.e. a slice ending
    there keeps the last marker and one further element. When charlist holds
    no marker at all the result is 1.
    """
    # Distance of the last marker from the end, found by scanning the reversed list
    offset_from_end = get_first_idx(charlist[::-1])
    return len(charlist) - offset_from_end + 1


def strip_content(hunk):
    """
    Trim a hunk body down to the span around its +/- lines.

    Lines before the first +/- line are dropped, the end of the kept span is
    taken from ``get_last_idx``, trailing whitespace is removed from every
    kept line, and lines that end up empty become a single space.

    Returns:
        ``(new_hunk, offset)`` where ``new_hunk`` is the kept lines wrapped in
        leading/trailing newlines and ``offset`` is the index of the first
        +/- line minus one.
    """
    lines = hunk.split("\n")
    leading_chars = [line[0] if len(line) else None for line in lines]
    start = get_first_idx(leading_chars)
    stop = get_last_idx(leading_chars)
    # should leave one space for empty context lines
    kept = [(line.rstrip() or " ") for line in lines[start:stop]]
    return "\n" + "\n".join(kept) + "\n", start - 1


def get_hunk_stats(pre_start, pre_len, post_start, post_len, hunk, total_delta):
    """
    Recompute a hunk's header numbers from its body.

    Everything after the first newline of ``hunk`` (surrounding newlines
    stripped) is classified line by line: "-" lines are removals, "+" lines
    are additions, every other line (including an empty one) is context.
    The incoming ``pre_len``, ``post_start`` and ``post_len`` are ignored and
    recalculated.

    Returns:
        ``(pre_start, pre_len, post_start, post_len, total_delta)`` where
        ``post_start`` is ``pre_start`` shifted by the incoming delta and
        ``total_delta`` has this hunk's net line change added.
    """
    body_lines = hunk.split("\n", 1)[-1].strip("\n").split("\n")
    removed = sum(1 for line in body_lines if line.startswith("-"))
    inserted = sum(1 for line in body_lines if line.startswith("+"))
    unchanged = len(body_lines) - removed - inserted

    new_pre_len = unchanged + removed
    new_post_len = unchanged + inserted
    new_post_start = pre_start + total_delta
    new_total_delta = total_delta + (new_post_len - new_pre_len)
    return pre_start, new_pre_len, new_post_start, new_post_len, new_total_delta


def extract_minimal_patch(model_patch):
    """
    Rebuild a patch with trimmed hunks and recomputed hunk headers.

    For every per-file section found in ``model_patch``:
    * the "--- a/..." / "+++ b/..." header is copied over,
    * each hunk body is trimmed with ``strip_content``,
    * each hunk header is recalculated with ``get_hunk_stats``, carrying the
      running line delta from one hunk of the file to the next.

    Returns:
        The reassembled patch text (empty string if nothing matched).
    """
    pieces = []
    for file_patch in PATCH_PATTERN.findall(model_patch.lstrip("\n")):
        running_delta = 0  # net added-minus-removed lines so far in this file
        header = PATCH_FILE_PATTERN.findall(file_patch)[0]
        if header:
            pieces.append(header + "\n")
        for raw_groups in PATCH_HUNK_PATTERN.findall(file_patch):
            # The four header numbers become ints; the hunk body stays a string
            pre_start, pre_len, post_start, post_len, content = [
                int(group) if group.isnumeric() else group for group in raw_groups
            ]
            content, skipped_lines = strip_content(content)
            # Shift the start by the offset strip_content reports
            pre_start += skipped_lines
            (
                pre_start,
                pre_len,
                post_start,
                post_len,
                running_delta,
            ) = get_hunk_stats(
                pre_start, pre_len, post_start, post_len, content, running_delta
            )
            pieces.append(
                f"@@ -{pre_start},{pre_len} +{post_start},{post_len} @@{content}"
            )
    return "".join(pieces)


def has_attribute_or_import_error(log_before):
    """
    Report whether the log shows an attribute- or import-related failure.

    The check is case-insensitive and line based: it is True when some line
    containing "attribute" or "import" also contains "error" or "fail".

    Args:
        log_before (str): Validation log text before patch application
    """
    lowered = log_before.lower()
    # Only lines that mention one of the two keywords are candidates
    candidate_lines = [
        line
        for line in lowered.split("\n")
        if "attribute" in line or "import" in line
    ]
    for line in candidate_lines:
        if "error" in line or "fail" in line:
            return True
    return False


def str2bool(v):
    """
    Interpret a command-line style value as a boolean.

    Booleans pass through unchanged. Strings are compared case-insensitively:
    "yes", "true", "t", "y", "1" give True and "no", "false", "f", "n", "0"
    give False.

    Raises:
        ArgumentTypeError: If the string is not one of the accepted spellings.
    """
    if isinstance(v, bool):
        return v

    truthy_words = ("yes", "true", "t", "y", "1")
    falsy_words = ("no", "false", "f", "n", "0")
    if v.lower() in truthy_words:
        return True
    if v.lower() in falsy_words:
        return False
    raise ArgumentTypeError("Boolean value expected.")


def optional_str(value: str) -> str | None:
    """
    Convert special string values to None, otherwise return the string as-is.
    """
    if value.lower() in ("none", "null", ""):
        return None
    return value


def get_repo_file(repo, commit, filepath, timeout=30):
    """
    Fetch the raw contents of a file from a GitHub repository.

    Args:
        repo: Repository in "owner/name" form.
        commit: Commit SHA (or other ref) to read the file at.
        filepath: Path of the file inside the repository.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely.

    Returns:
        The response text when the server answers with status 200, otherwise
        None (including when the request itself fails).
    """
    url = f"https://raw.githubusercontent.com/{repo}/{commit}/{filepath}"
    try:
        response = requests.get(url, timeout=timeout)
        if response.status_code == 200:
            return response.text
        return None
    except Exception:
        # Deliberately best effort: any request failure means "not available".
        # Unlike a bare ``except:``, KeyboardInterrupt/SystemExit still propagate.
        return None


def get_modified_files(patch: str) -> list[str]:
    """
    List the source-side paths touched by a patch.

    Entries whose source is "/dev/null" are skipped, as are entries whose
    source path does not start with "a/"; the "a/" prefix is removed from
    the paths that are returned.
    """
    return [
        patched_file.source_file[2:]
        for patched_file in PatchSet(patch)
        if patched_file.source_file != "/dev/null"
        and patched_file.source_file.startswith("a/")
    ]


def ansi_escape(text: str) -> str:
    """
    Return text with ANSI escape sequences stripped out.

    Covers ESC followed by a single character in the ranges ``@-Z`` / ``\\-_``,
    and ESC ``[`` sequences terminated by a character in ``@-~``.
    """
    escape_sequence = r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])"
    return re.sub(escape_sequence, "", text)
