import os
import re
import subprocess
import json
from datasets import load_dataset
from git import Repo, GitCommandError
from pathlib import Path
def clone_at_commit(repo_id: str, base_commit: str, parent_dir: str, instance_id: str) -> "str | None":
    """Clone a GitHub repository and check out a specific commit.

    The working copy is created at ``<parent_dir>/<repo_name>_<instance_id>``.
    If that path already exists it is returned as-is, without checking which
    commit it is on.

    Args:
        repo_id: ``owner/name`` identifier of the GitHub repository.
        base_commit: Commit to check out after cloning.
        parent_dir: Directory in which the working copy is created.
        instance_id: Suffix that keeps working copies of the same repository
            for different instances apart.

    Returns:
        The path of the working copy, or ``None`` if a git command failed.
    """
    import shutil

    repo_url = f"https://github.com/{repo_id}.git"
    repo_name = repo_id.split('/')[-1]
    repo_path = os.path.join(parent_dir, f"{repo_name}_{instance_id}")
    if os.path.exists(repo_path):
        return repo_path
    try:
        print(f"Cloning {repo_url} into {repo_path} ...")
        subprocess.run(
            ["git", "clone", repo_url, repo_path],
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        print(f"Checking out commit {base_commit} ...")
        subprocess.run(
            ["git", "checkout", base_commit],
            cwd=repo_path,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
    except subprocess.CalledProcessError as e:
        print(f"Git command failed: {e}")
        if e.stderr:
            print(e.stderr.decode("utf-8", errors="replace").strip())
        # Remove a partial clone (or a clone stuck on the wrong commit);
        # otherwise the "already exists" shortcut above would hand it out as
        # a valid checkout on the next call.
        shutil.rmtree(repo_path, ignore_errors=True)
        return None
    print(f"Repository cloned at {repo_path}, checked out commit {base_commit}")
    return repo_path
def extract_patch_info(patch_str):
    """Extract the pre-image line ranges touched by a unified diff.

    Only files that have an old-side header of the form ``--- a/<path>`` are
    reported, so newly created files (``--- /dev/null``) are ignored.

    Args:
        patch_str: Text of a unified diff.

    Returns:
        A list of ``{'file_path': str, 'changes': [...]}`` dicts, one per file
        with at least one hunk. Each change holds ``start_line``,
        ``line_count`` and ``end_line`` taken from the old side of the hunk
        header (``@@ -start,count ...``). A missing count means 1. For a count
        of 0, ``end_line`` is ``start_line - 1``.
    """
    file_sections = []
    # With a capturing group, re.split yields
    # [preamble, header1, body1, header2, body2, ...]; drop the preamble.
    pieces = re.split(r'(^--- a/.+?$)', patch_str, flags=re.MULTILINE)[1:]
    for file_header, file_content in zip(pieces[0::2], pieces[1::2]):
        file_match = re.match(r'^--- a/(.+?)$', file_header.strip())
        if not file_match:
            continue
        file_path = file_match.group(1)
        # The body runs up to the next "--- a/" header, so it may also contain
        # the diff of a following file without such a header (e.g. a new file
        # with "--- /dev/null"). Cut it off there so those hunks are not
        # attributed to this file.
        file_content = re.split(
            r'^(?:diff --git |--- /dev/null\s*$)',
            file_content,
            maxsplit=1,
            flags=re.MULTILINE,
        )[0]
        changes = []
        # Anchor at line start so only real hunk headers match, not "@@ ... @@"
        # text that happens to occur inside a changed line.
        hunk_headers = re.finditer(
            r'^@@ -(\d+)(?:,(\d+))? \+\d+(?:,\d+)? @@',
            file_content,
            flags=re.MULTILINE,
        )
        for match in hunk_headers:
            start_line = int(match.group(1))
            line_count = int(match.group(2)) if match.group(2) else 1
            changes.append({
                'start_line': start_line,
                'line_count': line_count,
                'end_line': start_line + line_count - 1
            })
        if changes:
            file_sections.append({
                'file_path': file_path,
                'changes': changes
            })
    return file_sections
def get_file_content_from_local(repo_path, file_path):
    """Read ``file_path`` (relative to ``repo_path``) as UTF-8 text.

    Returns:
        The file's lines without line terminators, or ``None`` if the file
        could not be opened or decoded (the error is printed).
    """
    target = os.path.join(repo_path, file_path)
    try:
        with open(target, "r", encoding="utf-8") as handle:
            text = handle.read()
    except Exception as err:
        print(f"读取文件失败: {err}")
        return None
    return text.splitlines()
def apply_patch(repo_path, patch_str):
    """Apply a patch to the working tree at ``repo_path`` with ``git apply``.

    The patch text is written to ``temp.patch`` inside the repository for the
    duration of the call and removed afterwards, whether or not it applied.

    Args:
        repo_path: Directory in which ``git apply`` is run.
        patch_str: Patch text to apply.

    Returns:
        ``True`` if ``git apply`` succeeded, ``False`` if it exited with an
        error. Failure is reported on stdout and does not raise.
    """
    patch_file = os.path.join(repo_path, "temp.patch")
    try:
        with open(patch_file, "w", encoding="utf-8") as f:
            f.write(patch_str)
        subprocess.run(
            ["git", "apply", "temp.patch"],
            cwd=repo_path,
            check=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE
        )
    except subprocess.CalledProcessError as e:
        # Still best-effort, but make the failure visible instead of silently
        # continuing with an unpatched tree.
        details = e.stderr.decode("utf-8", errors="replace").strip() if e.stderr else ""
        print(f"⚠️ Failed to apply patch in {repo_path}: {details or e}")
        return False
    finally:
        if os.path.exists(patch_file):
            os.remove(patch_file)
    print("✅ Patch applied successfully.")
    return True
def dedup_any(seq):
    """Drop duplicate items from ``seq``, keeping the first of each in order.

    Two items count as duplicates when their JSON serialisations (keys
    sorted, non-serialisable values passed through ``str``) are identical,
    so unhashable items such as dicts and lists are supported.
    """
    first_by_key = {}
    for item in seq:
        key = json.dumps(item, sort_keys=True, ensure_ascii=False, default=str)
        # setdefault keeps the earliest item; dict order is insertion order.
        first_by_key.setdefault(key, item)
    return list(first_by_key.values())
def parse_fail_to_pass_entries(fail_list):
    """Turn a list of test node ids into unique entry-point records.

    Args:
        fail_list: Either a list of test ids, or a string holding such a list
            as JSON or as a Python literal. ``None`` and empty values give an
            empty result.

    Returns:
        A list of ``{'entry_file', 'entry_class', 'entry_function'}`` dicts in
        first-seen order, without duplicates. Ids are split on ``::`` after
        any ``[...]`` parametrisation suffix is removed. Two-part ids get an
        empty ``entry_class``. Ids with any other number of parts are
        reported and skipped.

    Raises:
        ValueError, SyntaxError: If a string argument is neither valid JSON
            nor a valid Python literal.
    """
    import ast

    if isinstance(fail_list, str):
        text = fail_list.strip()
        if not text:
            fail_list = []
        else:
            # Parse as data only; never evaluate dataset text as code.
            try:
                fail_list = json.loads(text)
            except ValueError:
                # Single-quoted Python list reprs are not valid JSON.
                fail_list = ast.literal_eval(text)
    if not fail_list:
        return []

    entries = []
    seen = set()
    for item in fail_list:
        # Drop the parametrisation suffix, e.g. "test_x[case-1]" -> "test_x".
        item = item.split('[')[0]
        parts = item.split("::")
        if len(parts) == 3:
            file_path, class_name, func_name = parts
        elif len(parts) == 2:
            file_path, func_name = parts
            class_name = ""
        else:
            print(f"⚠️ Unexpected FAIL_TO_PASS format: {item}")
            continue
        # Parametrised variants collapse to the same triple; keep the first.
        key = (file_path, class_name, func_name)
        if key in seen:
            continue
        seen.add(key)
        entries.append({
            "entry_file": file_path,
            "entry_class": class_name,
            "entry_function": func_name
        })
    return entries
import re
def _collect_code_sections(repo_path, file_sections):
    """Return one record per patched hunk that contains a ``def`` line.

    For each change range in ``file_sections`` the corresponding lines are
    read from the file under ``repo_path``; the name of the first function
    defined inside that range becomes ``func_name``. Ranges without a ``def``
    line and files that cannot be read (or are empty) are skipped. Duplicate
    records are removed.
    """
    code_sections = []
    for file_section in file_sections:
        file_path = file_section['file_path']
        file_content = get_file_content_from_local(repo_path, file_path)
        if not file_content:
            continue
        for change in file_section['changes']:
            # Patch line numbers are 1-based; list indices are 0-based.
            start = max(0, change['start_line'] - 1)
            end = min(len(file_content), start + change['line_count'])
            code_snippet = "\n".join(file_content[start:end])
            func_match = re.search(r'^\s*def\s+([a-zA-Z_]\w*)\s*\(', code_snippet, flags=re.MULTILINE)
            if not func_match:
                continue
            code_sections.append({
                'file_path': file_path,
                'start_line': change['start_line'],
                'end_line': change['end_line'],
                'func_name': func_match.group(1),
            })
    return dedup_any(code_sections)


def process_patch(dataset_path, repos_dir, output_file, limit=5):
    """Build a JSONL file of issue / entry-point / patched-function records.

    For each Python example in the dataset's ``test`` split, the repository is
    cloned at the base commit, the test patch is applied, and the hunks of the
    fix patch are mapped to function names. Examples are skipped when the
    patch has no usable file sections, the clone fails, no entry point can be
    parsed, no hunk contains a function definition, or more than 10 code
    sections are found.

    Args:
        dataset_path: Dataset identifier passed to ``load_dataset``.
        repos_dir: Directory under which repositories are cloned (created if
            missing).
        output_file: Path of the JSONL file to write; it is overwritten.
        limit: Number of rows of the split to scan. Rows skipped for any
            reason still count towards this limit.

    Returns:
        None. A dataset loading error is printed and ends the call early.
    """
    print("Loading dataset...")
    try:
        ds = load_dataset(dataset_path)
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return
    os.makedirs(repos_dir, exist_ok=True)
    count = 0
    with open(output_file, 'w', encoding='utf-8') as f_out:
        for i, example in enumerate(ds['test']):
            if i >= limit:
                break
            if example['repo_language'] != 'python':
                continue
            count += 1
            print(count)
            repo_full_name = example['repo']
            instance_id = example['instance_id']
            base_commit = example['base_commit']
            patch_str = example['patch']
            test_patch_str = example.get('test_patch', '')
            issue = example.get('problem_statement', '')
            fail_to_pass = example.get('fail_to_pass', [])
            file_sections = extract_patch_info(patch_str)
            if not file_sections:
                print("No valid file sections found in patch")
                continue
            repo_path = clone_at_commit(repo_full_name, base_commit, repos_dir, instance_id=instance_id)
            if not repo_path:
                continue
            if test_patch_str:
                apply_patch(repo_path, test_patch_str)
            entry_infos = parse_fail_to_pass_entries(fail_to_pass)
            if not entry_infos:
                # The record needs a primary entry point; without this guard
                # entry_infos[0] below would raise IndexError.
                print(f"⚠️ No usable FAIL_TO_PASS entries for {instance_id}, skipping")
                continue
            all_code_sections = _collect_code_sections(repo_path, file_sections)
            # Skip examples with nothing to localise or with too many targets.
            if not all_code_sections or len(all_code_sections) > 10:
                continue
            primary_entry = entry_infos[0]
            result = {
                # basename instead of split("/") so this also works with
                # non-POSIX path separators.
                'repo_id': os.path.basename(repo_path),
                'repo_path': repo_path,
                'issue': issue,
                'entry_points': entry_infos,
                'entry_file': primary_entry['entry_file'],
                'entry_class': primary_entry['entry_class'],
                'entry_function': primary_entry['entry_function'],
                'code_sections': all_code_sections
            }
            print("Len of entry points:", len(entry_infos))
            print("Len of gts:", len(all_code_sections))
            f_out.write(json.dumps(result, ensure_ascii=False) + '\n')
            print(f"✅ Saved data for {instance_id} ({len(entry_infos)} entries, {len(all_code_sections)} code sections)")
if __name__ == "__main__":
    # One run per benchmark, in order: (dataset id, clone directory, output JSONL).
    runs = (
        ("ScaleAI/SWE-bench_Pro", "./repos_pro", "./issues_data_swe_pro.jsonl"),
        ("princeton-nlp/SWE-bench_Verified", "./repos_verified", "./issues_data_swe_verified.jsonl"),
        ("SWE-bench/SWE-smith", "./repos_smith", "./issues_data_swe_smith.jsonl"),
    )
    for dataset_path, repos_dir, output_file in runs:
        process_patch(dataset_path, repos_dir, output_file, limit=800)