#!/usr/bin/env python3
"""
Classify bug types using an LLM based on buggy code, reference code, and test failure traces.
"""

import json
import argparse
import os
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datasets import load_dataset
from tqdm import tqdm
from src.utils.api import get_llm_output

# Thread-safe file writing: process_sample() runs concurrently inside a
# ThreadPoolExecutor and every worker appends to the same output JSONL file,
# so appends are serialized with this lock.
file_lock = threading.Lock()

# Bug classification categories (6 categories).
# This text is inserted verbatim into CLASSIFICATION_PROMPT through the
# {categories} placeholder. Keep the category names here in sync with the names
# listed in the prompt's JSON schema and "Classification Rules" section.
# NOTE(review): IndentationErrors are "excluded" below, but the prompt still asks
# for EXACTLY ONE category, so the model has no dedicated label for them.
BUG_CATEGORIES = """
1. LOGIC_ERROR - The algorithm or reasoning is fundamentally wrong.
   What it is: Wrong formula, wrong condition logic, wrong data transformation, infinite loops, wrong control flow.
   Examples: using max() instead of min(), wrong regex pattern, inverted boolean condition, zip(word, word) instead of zip(word, word[1:]), while loop that never terminates.
   NOT: Simple typos or wrong variable names (those are WRONG_VALUE).

2. WRONG_VALUE - Correct algorithm but a specific identifier, literal, or constant is wrong.
   What it is: Typos in variable/function names, wrong string literals, wrong numbers, off-by-one errors.
   Examples: 'csv_file' instead of 'csv_file_path', 'Yes' instead of 'True', range(1,7) instead of range(1,8), 'Count' instead of 'Counter', 'j' instead of 'i'.
   Use for any simple substitution error where one name/value is swapped for another.

3. MISSING_EDGE_CASE - Code works for normal inputs but fails on edge cases.
   What it is: Missing validation, missing null checks, missing empty input handling.
   Examples: no check for empty list before accessing [0], doesn't handle None, no try/except for expected errors.
   NOT: Wrong logic that fails on all inputs (that's LOGIC_ERROR).

4. API_MISUSE - Misunderstanding how to use a library/framework API, or using wrong method/type.
   What it is: Wrong method for the task, wrong argument order, calling method that doesn't exist on a type, wrong dict key access.
   Examples: model(X) instead of model.predict(X), Series.reshape() (Series has no reshape), calling .encode() on bytes, accessing df['nonexistent_column'].
   NOT: Simple typos in method names (those are WRONG_VALUE).

5. MISSING_IMPORT - A module or function needs to be imported but isn't.
   What it is: Code would work if the import statement existed.
   Examples: using os.path without 'import os', using Counter without 'from collections import Counter'.
   NOT: Typos like 'Count' instead of 'Counter' (that's WRONG_VALUE, not a missing import).

6. SYNTAX_ERROR - Code cannot be parsed by Python at all.
   What it is: ONLY actual parse errors - missing brackets, missing colons, invalid syntax constructs.
   Examples: unclosed parenthesis, missing colon after if/for/def, invalid slice syntax like [: 5].
   NOT: NameErrors from typos (those are WRONG_VALUE). NOT: IndentationErrors (exclude these).
"""

# Prompt template rendered with str.format() in classify_bug().
# Placeholders: {categories}, {reference_code}, {buggy_code}, {traceback}.
# The literal braces of the JSON response schema at the end are doubled ({{ }})
# so that str.format() emits single braces instead of treating them as fields.
# The model is asked to answer with a JSON object containing bug_category,
# bug_description and root_cause.
CLASSIFICATION_PROMPT = """You are a bug classification expert. Analyze the buggy code, reference code, and test failure to classify the bug type.

## Bug Categories (choose EXACTLY ONE):
{categories}

## Reference Code (correct implementation):
```python
{reference_code}
```

## Buggy Code (contains the bug):
```python
{buggy_code}
```

## Test Failure Traceback:
```
{traceback}
```

## Task:
1. Compare the buggy code to the reference code
2. Identify the PRIMARY difference that causes the test failure
3. Classify into EXACTLY ONE category

## Classification Rules:
- LOGIC_ERROR: algorithm/formula is wrong, or control flow issues like infinite loops
- WRONG_VALUE: simple substitution errors (wrong variable name, wrong string, wrong number)
- MISSING_EDGE_CASE: code works normally but fails on empty/None/boundary inputs
- API_MISUSE: wrong library method, wrong arguments, or using method that doesn't exist on a type
- MISSING_IMPORT: code would work if an import was added
- SYNTAX_ERROR: only actual parse failures (missing brackets, colons) - NOT NameErrors from typos
- If one identifier is swapped for another ('csv_file' vs 'csv_file_path'), that's WRONG_VALUE
- Exclude indentation errors from classification

Respond in JSON format:
{{
    "bug_category": "<ONE of: LOGIC_ERROR, WRONG_VALUE, MISSING_EDGE_CASE, API_MISUSE, MISSING_IMPORT, SYNTAX_ERROR>",
    "bug_description": "<1 sentence: what is wrong>",
    "root_cause": "<the specific code difference causing the failure>"
}}
"""


def classify_bug(
    buggy_code: str,
    reference_code: str,
    traceback: str,
    model: str = "gpt-4o-mini",
    *,
    max_code_chars: int = 2000,
    max_traceback_chars: int = 1500,
) -> dict:
    """Classify a single bug using an LLM.

    Args:
        buggy_code: Source of the buggy solution.
        reference_code: Source of the correct reference solution.
        traceback: Test-failure output for the buggy code, or a placeholder
            message when none is available.
        model: Model name forwarded to ``get_llm_output``.
        max_code_chars: Each code snippet keeps at most this many leading
            characters.
        max_traceback_chars: The traceback keeps at most this many *trailing*
            characters, prefixed with a short "..." marker when truncated.

    Returns:
        The JSON object produced by the model. It is expected to contain
        ``bug_category``, ``bug_description`` and ``root_cause``. A string
        ``bug_category`` is normalized to stripped upper case.

        No exception is raised. If the LLM call fails, the output is not valid
        JSON, or the JSON value is not an object, the function returns a dict
        with ``bug_category == "ERROR"``.
    """

    def _error(description: str) -> dict:
        return {
            "bug_category": "ERROR",
            "bug_description": description,
            "root_cause": "classification failed",
        }

    # Coerce defensively. A None or non-str field would otherwise raise while
    # the prompt is being built, outside the try block below, and abort the
    # worker thread.
    reference_code = "" if reference_code is None else str(reference_code)
    buggy_code = "" if buggy_code is None else str(buggy_code)
    traceback = "" if traceback is None else str(traceback)

    # Python tracebacks end with the exception type and message, which is the
    # most useful part for classification. Drop characters from the front,
    # not the back.
    if len(traceback) > max_traceback_chars:
        traceback = "...\n" + traceback[len(traceback) - max_traceback_chars:]

    prompt = CLASSIFICATION_PROMPT.format(
        categories=BUG_CATEGORIES,
        reference_code=reference_code[:max_code_chars],
        buggy_code=buggy_code[:max_code_chars],
        traceback=traceback,
    )

    try:
        response = get_llm_output(
            message=prompt,
            model=model,
            max_new_tokens=500,
            temperature=0,
            json_object=True
        )
        # get_llm_output may hand back either a parsed dict or a raw JSON string.
        parsed = response if isinstance(response, dict) else json.loads(response)
    except Exception as e:  # best-effort: one failed sample must not abort a batch run
        return _error(str(e))

    # Valid JSON can still be a list or scalar. The caller merges the result
    # with ** and therefore needs a mapping.
    if not isinstance(parsed, dict):
        return _error(f"expected a JSON object from the LLM, got {type(parsed).__name__}")

    # Normalize so the summary does not split "logic_error" and "LOGIC_ERROR"
    # into separate buckets.
    category = parsed.get("bug_category")
    if isinstance(category, str):
        parsed["bug_category"] = category.strip().upper()
    return parsed


def _strip_code_fence(text):
    """Remove ```python / ``` markdown fences from a solution and strip whitespace.

    A present-but-null dataset field (``None``) is treated as an empty string.
    """
    return (text or '').replace('```python\n', '').replace('\n```', '').strip()


def process_sample(row, ds_path, model, output_file, skip_pass):
    """Classify one dataset row and append the result to ``output_file``.

    This is designed for parallel execution in a thread pool. Appends to
    ``output_file`` are serialized with the module-level ``file_lock``.

    Args:
        row: A dataset row (mapping). The fields read are ``uid``,
            ``mutation_info``, ``buggy_solution`` and ``reference_solution``.
        ds_path: Dataset identifier, recorded in the output under ``"dataset"``.
        model: LLM model name forwarded to ``classify_bug``.
        output_file: Path of the JSONL file that receives one appended line.
        skip_pass: If true, rows whose ``mutation_info`` status is ``"pass"``
            are skipped.

    Returns:
        The result dict that was written, or ``None`` if the row was skipped.
    """
    uid = row.get('uid', '')
    mi = row.get('mutation_info', '')

    # Handle datasets with or without mutation_info
    if mi:
        # Decoded as [status, {test_name: traceback, ...}]. This shape is
        # inferred from the indexing below. NOTE(review): confirm it against
        # the dataset schema.
        data = json.loads(mi) if isinstance(mi, str) else mi
        status = data[0] if data else "unknown"

        # Skip if passes all tests
        if skip_pass and status == "pass":
            return None

        test_results = data[1] if len(data) > 1 else {}
        if status == "timeout":
            traceback = "TIMEOUT: Code execution timed out (likely infinite loop)"
        elif test_results:
            # Only the first failing test's output is sent to the classifier.
            traceback = next(iter(test_results.values()))
        else:
            traceback = "No traceback available"
    else:
        # No mutation_info - classify based on code diff only
        status = "unknown"
        traceback = "No traceback available - classify based on code comparison only"

    # Test outputs are not guaranteed to be strings; the prompt needs text.
    if traceback is None:
        traceback = "No traceback available"
    elif not isinstance(traceback, str):
        traceback = str(traceback)

    classification = classify_bug(
        buggy_code=_strip_code_fence(row.get('buggy_solution')),
        reference_code=_strip_code_fence(row.get('reference_solution')),
        traceback=traceback,
        model=model
    )
    if not isinstance(classification, dict):
        classification = {
            "bug_category": "ERROR",
            "bug_description": f"unexpected classifier result of type {type(classification).__name__}",
            "root_cause": "classification failed",
        }

    result = {"uid": uid, "dataset": ds_path, "status": status}
    # Sample metadata takes precedence over LLM output. A response that happens
    # to contain "uid", "dataset" or "status" must not overwrite them.
    for key, value in classification.items():
        result.setdefault(key, value)

    # Serialize before taking the lock so workers hold it only for the write.
    line = json.dumps(result) + '\n'
    with file_lock, open(output_file, 'a', encoding='utf-8') as f:
        f.write(line)

    return result


def main():
    """Parse CLI arguments, classify the selected dataset(s), and print a summary.

    ``process_sample`` appends each classified sample to ``--output`` as one
    JSON line as soon as that sample finishes. The output file is opened in
    append mode and is never truncated here, so re-running the script adds to
    any results already in the file.

    If processing a sample raises, the error is reported and counted, and the
    remaining samples still run.
    """
    from collections import Counter

    parser = argparse.ArgumentParser(description="Classify bugs using LLM")
    parser.add_argument("--dataset", type=str, default="human",
                        choices=["human", "qwen", "oss", "og", "new", "all"],
                        help="Dataset to classify")
    parser.add_argument("--model", type=str, default="gpt-4o-mini",
                        help="LLM model to use for classification")
    parser.add_argument("--limit", type=int, default=None,
                        help="Limit number of samples to classify")
    parser.add_argument("--output", type=str, default="bug_classifications.jsonl",
                        help="Output file path")
    parser.add_argument("--skip-pass", action="store_true",
                        help="Skip samples where buggy code passes all tests")
    parser.add_argument("--workers", type=int, default=4,
                        help="Number of parallel workers")
    parser.add_argument("--split", type=str, default="test",
                        help="Dataset split to use (e.g., test, test_large)")
    args = parser.parse_args()

    DATASETS = {
        # "adv": "anonymous/bugbench_adv",  # Excluded
        "human": "anonymous/bugbench_human",
        "qwen": "anonymous/bugbench_qwen",
        "oss": "anonymous/bugbench_oss",
        "og": "anonymous/bugbench_og",
        "new": "anonymous/bugbench_new",
    }

    if args.dataset == "all":
        datasets_to_process = list(DATASETS.items())
    else:
        datasets_to_process = [(args.dataset, DATASETS[args.dataset])]

    results = []
    failures = 0

    for ds_name, ds_path in datasets_to_process:
        print(f"\nProcessing {ds_name} with {args.workers} workers...")
        ds = load_dataset(ds_path, split=args.split)

        samples = list(ds)
        # Compare against None so that an explicit "--limit 0" means zero
        # samples rather than "no limit".
        if args.limit is not None:
            samples = samples[:args.limit]

        # Parallel processing
        with ThreadPoolExecutor(max_workers=args.workers) as executor:
            futures = {
                executor.submit(process_sample, row, ds_path, args.model, args.output, args.skip_pass): row
                for row in samples
            }

            for future in tqdm(as_completed(futures), total=len(futures), desc=ds_name):
                try:
                    result = future.result()
                except Exception as exc:
                    # future.result() re-raises the worker's exception. One
                    # malformed row should not throw away the rest of the run.
                    failures += 1
                    uid = futures[future].get('uid', '?')
                    tqdm.write(f"[{ds_name}] sample uid={uid!r} failed: {exc!r}")
                    continue
                if result:
                    results.append(result)

    # Print summary
    print("\n" + "=" * 60)
    print("CLASSIFICATION SUMMARY")
    print("=" * 60)

    # Group by dataset path in first-seen order. Datasets are processed one
    # after another, so this order matches processing order (unlike a set).
    by_dataset = {}
    for r in results:
        by_dataset.setdefault(r['dataset'], []).append(r)

    for path, ds_results in by_dataset.items():
        # str() guards against a null or non-string category, which would break
        # the "<25" format spec below.
        categories = Counter(str(r.get('bug_category', 'MISSING')) for r in ds_results)

        print(f"\n{path.upper()} ({len(ds_results)} samples):")
        for cat, count in categories.most_common():
            pct = count / len(ds_results) * 100
            print(f"  {cat:<25} {count:>4} ({pct:>5.1f}%)")

    if failures:
        print(f"\n{failures} sample(s) failed to process; see messages above.")


if __name__ == "__main__":
    main()
