import docker
import json
import re
from pathlib import Path
from typing import Optional, Dict, List, Tuple
import difflib
from collections import defaultdict
from unidiff import PatchSet

from swebench.harness.constants import (
    KEY_INSTANCE_ID,
    KEY_MODEL,
    KEY_PREDICTION,
    RUN_EVALUATION_LOG_DIR,
    LOG_REPORT,
)
from swebench.harness.docker_utils import list_images
from swebench.harness.test_spec.test_spec import make_test_spec
from swebench.harness.utils import get_modified_files


def safe_get_modified_files(patch: str) -> set:
    """
    Return the set of file paths touched by ``patch``, tolerating malformed diffs.

    The patch is first parsed with ``get_modified_files`` (unidiff-based). If that
    raises for any reason, the line-based ``_manual_get_modified_files`` fallback
    is used instead, so callers always receive a set.

    Args:
        patch (str): The patch string

    Returns:
        set: Set of modified file paths (possibly empty)
    """
    try:
        paths = set(get_modified_files(patch))
    except Exception:
        # unidiff rejected the patch (common for hand-written or truncated
        # model output); scan the header lines directly instead.
        paths = _manual_get_modified_files(patch)
    return paths


def _manual_get_modified_files(patch: str) -> set:
    """
    Manually extract modified files from patch when unidiff fails.
    
    Args:
        patch (str): The patch string
        
    Returns:
        set: Set of modified file paths
    """
    if not patch:
        return set()
    
    modified_files = set()
    
    for line in patch.split('\n'):
        line = line.strip()
        
        # Check for diff --git format
        if line.startswith('diff --git'):
            # Extract file path from "diff --git a/path/to/file b/path/to/file"
            parts = line.split()
            if len(parts) >= 4:
                # Remove 'a/' prefix from the file path
                file_path = parts[2]
                if file_path.startswith('a/'):
                    file_path = file_path[2:]
                modified_files.add(file_path)
        
        # Check for --- a/ format
        elif line.startswith('--- a/'):
            file_path = line[6:]  # Remove '--- a/' prefix
            modified_files.add(file_path)
        
        # Check for +++ b/ format  
        elif line.startswith('+++ b/'):
            file_path = line[6:]  # Remove '+++ b/' prefix
            modified_files.add(file_path)
        
        # Check for --- format without a/ prefix
        elif line.startswith('--- ') and not line.startswith('--- a/'):
            file_path = line[4:]  # Remove '--- ' prefix
            # Skip /dev/null and timestamp info
            if file_path and not file_path.startswith('/dev/null') and '\t' not in file_path:
                modified_files.add(file_path)
        
        # Check for +++ format without b/ prefix
        elif line.startswith('+++ ') and not line.startswith('+++ b/'):
            file_path = line[4:]  # Remove '+++ ' prefix
            # Skip /dev/null and timestamp info
            if file_path and not file_path.startswith('/dev/null') and '\t' not in file_path:
                modified_files.add(file_path)
    
    return modified_files


def _read_report_status(report_file: Path, instance_id: str) -> Tuple[int, int]:
    """
    Read the resolved / patch-applied flags for one instance from its report file.

    Best-effort: a missing, unreadable or malformed report, or one without an
    entry for ``instance_id``, yields ``(0, 0)`` instead of raising.

    Args:
        report_file (Path): Path to the per-instance JSON evaluation report
        instance_id (str): Key of the instance entry inside the report

    Returns:
        Tuple[int, int]: ``(is_resolved, is_applied)`` as 0/1 flags
    """
    if not report_file.exists():
        return 0, 0
    try:
        report = json.loads(report_file.read_text())
        if instance_id not in report:
            return 0, 0
        entry = report[instance_id]
        is_resolved = 1 if entry.get("resolved", False) else 0
        is_applied = 1 if entry.get("patch_successfully_applied", False) else 0
    except (OSError, ValueError, TypeError, AttributeError):
        # OSError: unreadable file; ValueError: invalid JSON / undecodable bytes;
        # TypeError / AttributeError: the report or its entry is not a JSON object.
        return 0, 0
    return is_resolved, is_applied


def calculate_comprehensive_metrics(predictions: dict, full_dataset: list, run_id: Optional[str] = None) -> dict:
    """
    Calculate comprehensive metrics for patches including the 9 requested indicators.

    Only predictions whose instance ID appears in ``full_dataset`` are considered.
    - Rates (empty / unresolved / resolved) are fractions of those predictions.
    - The averages (modified lines/files, cyclomatic complexity, code smells,
      error handling) and the aggregate localization success rate cover only
      non-empty, *unresolved* patches.
    - Per-instance lists under ``detailed_metrics`` cover every considered
      prediction, with zeros for empty patches. They are aligned index-by-index
      with ``detailed_metrics["instance_ids"]``.

    Resolution status is read from each instance's evaluation report when
    ``run_id`` is given. Without a ``run_id``, or when the report is missing or
    unreadable, every non-empty patch counts as unresolved and not applied.

    Args:
        predictions (dict): Predictions dict generated by the model
        full_dataset (list): List of all instances
        run_id (str, optional): Run ID for accessing report files

    Returns:
        dict: Dictionary containing comprehensive metrics
    """
    # Create a mapping from instance_id to instance data
    instance_data = {instance[KEY_INSTANCE_ID]: instance for instance in full_dataset}

    # Initialize counters
    total_patches = 0
    empty_patches = 0
    unresolved_patches = 0
    resolved_patches = 0
    total_modified_lines = 0
    total_modified_files = 0
    total_cyclomatic_complexity = 0
    total_code_smells = 0
    total_error_handling = 0
    total_gold_files = 0  # Files modified by gold patches of unresolved predictions
    total_localized_files = 0  # Of those, files the model patch also modified

    # Per-instance metrics, aligned index-by-index with instance_ids
    modified_lines_list = []
    modified_files_list = []
    cyclomatic_complexity_list = []
    code_smells_list = []
    error_handling_list = []
    resolved_list = []
    localization_success_list = []
    instance_ids = []

    # 0/1 indicators for each instance
    empty_patch_binary_list = []  # 1 if empty patch, 0 otherwise
    unresolved_patch_binary_list = []  # 1 if unresolved patch, 0 otherwise
    apply_list = []  # 1 if patch successfully applied, 0 otherwise

    # Instance IDs of empty and unresolved patches
    empty_patch_list = []
    unresolved_patch_list = []

    for instance_id, prediction in predictions.items():
        if instance_id not in instance_data:
            continue

        instance = instance_data[instance_id]
        model_patch = prediction.get(KEY_PREDICTION, "")
        total_patches += 1

        # 1. Empty Patch Rate. Empty patches get zero metrics and are counted as
        # neither unresolved nor applied.
        if not model_patch or not model_patch.strip():
            empty_patches += 1
            empty_patch_list.append(instance_id)
            modified_lines_list.append(0)
            modified_files_list.append(0)
            cyclomatic_complexity_list.append(0)
            code_smells_list.append(0)
            error_handling_list.append(0)
            localization_success_list.append(0.0)
            resolved_list.append(0)
            instance_ids.append(instance_id)
            empty_patch_binary_list.append(1)
            unresolved_patch_binary_list.append(0)
            apply_list.append(0)
            continue

        # 2./3. Resolution and application status from the evaluation report
        is_resolved, is_applied = 0, 0
        if run_id:
            report_file = (
                RUN_EVALUATION_LOG_DIR
                / run_id
                / prediction[KEY_MODEL].replace("/", "__")
                / instance_id
                / LOG_REPORT
            )
            is_resolved, is_applied = _read_report_status(report_file, instance_id)

        if is_resolved:
            resolved_patches += 1
        # A non-empty patch that is not resolved is unresolved
        is_unresolved = 0 if is_resolved else 1
        if is_unresolved:
            unresolved_patches += 1
            unresolved_patch_list.append(instance_id)

        # 4. Modified Lines / 5. Modified Files (kept for every non-empty patch)
        modified_lines = _count_modified_lines(model_patch)
        modified_files = _count_modified_files(model_patch)
        modified_lines_list.append(modified_lines)
        modified_files_list.append(modified_files)

        # 6. Cyclomatic Complexity
        cyclomatic_complexity = _calculate_cyclomatic_complexity(model_patch)
        cyclomatic_complexity_list.append(cyclomatic_complexity)

        # 7. Localization Success Rate: share of gold-patch files the model also touched
        gold_patch = instance.get("patch", "")
        gold_files = safe_get_modified_files(gold_patch) if gold_patch else set()
        model_files = safe_get_modified_files(model_patch)
        localized_files = len(model_files & gold_files)
        localization_success_list.append(
            localized_files / len(gold_files) if gold_files else 0.0
        )

        # 8. Code Smells
        code_smells = _detect_code_smells_count(model_patch)
        code_smells_list.append(code_smells)

        # 9. Error Handling
        error_handling = _count_error_handling(model_patch)
        error_handling_list.append(error_handling)

        resolved_list.append(is_resolved)
        instance_ids.append(instance_id)
        empty_patch_binary_list.append(0)
        unresolved_patch_binary_list.append(is_unresolved)
        apply_list.append(is_applied)

        # Averages are taken over unresolved patches only
        if is_unresolved:
            total_modified_lines += modified_lines
            total_modified_files += modified_files
            total_cyclomatic_complexity += cyclomatic_complexity
            total_code_smells += code_smells
            total_error_handling += error_handling
            total_gold_files += len(gold_files)
            total_localized_files += localized_files

    # Calculate rates and averages
    non_empty_patches = total_patches - empty_patches
    # Every considered prediction (empty, unresolved and resolved) has one entry
    included_patches = len(modified_lines_list)

    return {
        "empty_patch_rate": empty_patches / total_patches if total_patches > 0 else 0,
        "unresolved_rate": unresolved_patches / total_patches if total_patches > 0 else 0,
        "resolved_rate": resolved_patches / total_patches if total_patches > 0 else 0,
        "avg_modified_lines": total_modified_lines / unresolved_patches if unresolved_patches > 0 else 0,
        "avg_modified_files": total_modified_files / unresolved_patches if unresolved_patches > 0 else 0,
        "avg_cyclomatic_complexity": total_cyclomatic_complexity / unresolved_patches if unresolved_patches > 0 else 0,
        "localization_success_rate": total_localized_files / total_gold_files if total_gold_files > 0 else 0,
        "avg_code_smells": total_code_smells / unresolved_patches if unresolved_patches > 0 else 0,
        "avg_error_handling": total_error_handling / unresolved_patches if unresolved_patches > 0 else 0,
        "total_patches": total_patches,
        "empty_patches": empty_patches,
        "non_empty_patches": non_empty_patches,
        "included_patches": included_patches,  # All patches that were processed (includes empty patches)
        "unresolved_patches": unresolved_patches,
        "resolved_patches": resolved_patches,
        "total_gold_files": total_gold_files,  # Gold-patch files of unresolved predictions
        "total_localized_files": total_localized_files,  # Of those, files correctly localized
        "empty_patch_list": empty_patch_list,  # List of instance IDs with empty patches
        "unresolved_patch_list": unresolved_patch_list,  # List of instance IDs with unresolved patches
        "detailed_metrics": {
            "instance_ids": instance_ids,  # List of instance IDs corresponding to each metric value
            "localization_success_list": localization_success_list,
            "modified_lines_list": modified_lines_list,
            "modified_files_list": modified_files_list,
            "cyclomatic_complexity_list": cyclomatic_complexity_list,
            "code_smells_list": code_smells_list,
            "error_handling_list": error_handling_list,
            "resolved_list": resolved_list,
            "empty_patch_binary_list": empty_patch_binary_list,  # 0/1 list for empty patches
            "unresolved_patch_binary_list": unresolved_patch_binary_list,  # 0/1 list for unresolved patches
            "apply_list": apply_list,  # 0/1 list for patch application status
        }
    }


def _count_modified_lines(patch: str) -> int:
    """Count the number of modified lines in a patch."""
    if not patch:
        return 0
    
    added_lines = 0
    deleted_lines = 0
    
    for line in patch.split('\n'):
        if line.startswith('+') and not line.startswith('+++'):
            added_lines += 1
        elif line.startswith('-') and not line.startswith('---'):
            deleted_lines += 1
    
    return added_lines + deleted_lines


def _count_modified_files(patch: str) -> int:
    """
    Count the number of distinct files touched by a patch.

    Delegates to ``_manual_get_modified_files`` (the header-line parser also
    used as the unidiff fallback), so both code paths agree on what counts as
    a modified file instead of maintaining two copies of the parser.

    Args:
        patch (str): The patch string

    Returns:
        int: Number of unique file paths found (0 for an empty/None patch)
    """
    return len(_manual_get_modified_files(patch))


def _calculate_cyclomatic_complexity(patch: str) -> float:
    """Calculate cyclomatic complexity introduced by the patch."""
    if not patch:
        return 0.0
    
    added_lines = []
    for line in patch.split('\n'):
        if line.startswith('+') and not line.startswith('+++'):
            added_lines.append(line[1:])  # Remove the '+' prefix
    
    if not added_lines:
        return 0.0
    
    # Join lines to form code blocks
    code_text = '\n'.join(added_lines)
    
    # Simple heuristic-based complexity calculation
    # Count complexity-increasing constructs
    complexity_indicators = [
        'if ', 'elif ', 'else:', 'for ', 'while ', 'try:', 'except:', 
        'and ', 'or ', 'with ', 'match ', 'case '
    ]
    
    complexity = 0
    for line in added_lines:
        line_lower = line.lower().strip()
        for indicator in complexity_indicators:
            if indicator in line_lower:
                complexity += 1
                break  # Only count once per line
    
    return complexity


def _calculate_localization_success(model_patch: str, gold_patch: str) -> float:
    """
    Calculate localization success at file level for one prediction.

    This is the fraction of files modified by the gold patch that the model
    patch also modifies: ``len(model_files & gold_files) / len(gold_files)``.

    Args:
        model_patch (str): Patch produced by the model
        gold_patch (str): Reference (gold) patch

    Returns:
        float: Value in [0, 1]. It is 0.0 if either patch is empty, the gold
        patch touches no files, or extracting the file lists raises.
    """
    if not model_patch or not gold_patch:
        return 0.0

    try:
        model_files = safe_get_modified_files(model_patch)
        gold_files = safe_get_modified_files(gold_patch)
    except Exception:
        # Best-effort metric: a pathological patch must not abort reporting,
        # but unlike a bare except this lets KeyboardInterrupt/SystemExit through.
        return 0.0

    if not gold_files:
        return 0.0
    return len(model_files & gold_files) / len(gold_files)


def _detect_code_smells_count(patch: str) -> int:
    """Detect and count code smells in a patch."""
    if not patch:
        return 0
    
    added_lines = []
    for line in patch.split('\n'):
        if line.startswith('+') and not line.startswith('+++'):
            added_lines.append(line[1:])
    
    if not added_lines:
        return 0
    
    smell_patterns = [
        r'[a-zA-Z_]\w*\s*=\s*[a-zA-Z_]\w*\s*=',  # Multiple assignments
        r'\beval\s*\(',  # eval usage
        r'\bexec\s*\(',  # exec usage
        r'#\s*TODO|#\s*FIXME|#\s*HACK',  # Development comments
        r'\b[a-z]+\d+[a-z]*\b',  # Variable names with numbers
        r'[a-zA-Z_]\w*\s*=\s*[a-zA-Z_]\w*\s*\+\s*1',  # i = i + 1 instead of i += 1
        r'len\([^)]+\)\s*==\s*0',  # len(x) == 0 instead of not x
        r'["\'][^"\']*["\'].*==.*["\'][^"\']*["\']',  # String comparison
    ]
    
    smell_count = 0
    for line in added_lines:
        for pattern in smell_patterns:
            if re.search(pattern, line):
                smell_count += 1
                break  # Only count once per line
    
    return smell_count


def _count_error_handling(patch: str) -> int:
    """Count error handling constructs in a patch."""
    if not patch:
        return 0
    
    added_lines = []
    for line in patch.split('\n'):
        if line.startswith('+') and not line.startswith('+++'):
            added_lines.append(line[1:])
    
    if not added_lines:
        return 0
    
    error_handling_patterns = [
        r'\btry\s*:',
        r'\bexcept\s+\w+',
        r'\bexcept\s*:',
        r'\braise\s+\w+',
        r'\bassert\s+',
        r'\bif\s+.*\bis\s+None',
        r'\bif\s+.*\bis\s+not\s+None',
        r'\bif\s+not\s+\w+',
        r'\.get\s*\(',  # Using .get() for safe dict access
    ]
    
    error_handling_count = 0
    for line in added_lines:
        for pattern in error_handling_patterns:
            if re.search(pattern, line):
                error_handling_count += 1
                break  # Only count once per line
    
    return error_handling_count


def make_run_report(
    predictions: dict,
    full_dataset: list,
    run_id: str,
    client: Optional[docker.DockerClient] = None,
) -> Path:
    """
    Make a final evaluation and run report of the instances that have been run.
    If a client is provided, also reports images and containers that may still
    be running.

    An instance is counted as an error in two cases:
    - its per-instance report file is missing;
    - the file cannot be read or parsed, or lacks the ``resolved`` /
      ``patch_successfully_applied`` fields.
    A single bad report therefore does not abort the whole run report.

    The report is written as JSON to ``<model>.<run_id>.json`` (relative path,
    i.e. the current working directory). ``<model>`` is the first prediction's
    model name with ``/`` replaced by ``__``.

    Args:
        predictions (dict): Predictions dict generated by the model
        full_dataset (list): List of all instances
        run_id (str): Run ID
        client (docker.DockerClient): Docker client (optional)

    Returns:
        Path to report file
    """
    # instantiate sets to store IDs of different outcomes
    completed_ids = set()
    resolved_ids = set()
    error_ids = set()
    unstopped_containers = set()
    unremoved_images = set()
    unresolved_ids = set()
    incomplete_ids = set()
    # get instances with empty patches
    empty_patch_ids = set()
    applied_patch_ids = set()

    # iterate through dataset and check if the instance has been run
    for instance in full_dataset:
        instance_id = instance[KEY_INSTANCE_ID]
        if instance_id not in predictions:
            # skip instances without predictions
            incomplete_ids.add(instance_id)
            continue
        prediction = predictions[instance_id]
        if prediction.get(KEY_PREDICTION, None) in ["", None]:
            empty_patch_ids.add(instance_id)
            continue
        report_file = (
            RUN_EVALUATION_LOG_DIR
            / run_id
            / prediction[KEY_MODEL].replace("/", "__")
            / prediction[KEY_INSTANCE_ID]
            / LOG_REPORT
        )
        if not report_file.exists():
            # No report file: the instance was not run successfully
            error_ids.add(instance_id)
            continue
        try:
            instance_report = json.loads(report_file.read_text())[instance_id]
            is_resolved = bool(instance_report["resolved"])
            is_applied = bool(instance_report["patch_successfully_applied"])
        except (OSError, ValueError, KeyError, TypeError) as err:
            # A corrupt or incomplete report must not abort the whole run report;
            # treat the instance like one whose report is missing.
            print(f"Could not parse report for {instance_id} at {report_file}: {err}")
            error_ids.add(instance_id)
            continue
        # A valid report file exists, so the instance has been run
        completed_ids.add(instance_id)
        if is_resolved:
            resolved_ids.add(instance_id)
        else:
            unresolved_ids.add(instance_id)
        if is_applied:
            applied_patch_ids.add(instance_id)

    if client:
        # get remaining images and containers
        images = list_images(client)
        test_specs = list(map(make_test_spec, full_dataset))
        for spec in test_specs:
            image_name = spec.instance_image_key
            if image_name in images:
                unremoved_images.add(image_name)
        containers = client.containers.list(all=True)
        for container in containers:
            if run_id in container.name:
                unstopped_containers.add(container.name)

    # Calculate comprehensive metrics
    comprehensive_metrics = calculate_comprehensive_metrics(predictions, full_dataset, run_id)

    # print final report
    dataset_ids = {i[KEY_INSTANCE_ID] for i in full_dataset}
    print(f"Total instances: {len(full_dataset)}")
    print(f"Instances submitted: {len(set(predictions.keys()) & dataset_ids)}")
    print(f"Instances completed: {len(completed_ids)}")
    print(f"Instances incomplete: {len(incomplete_ids)}")
    print(f"Instances resolved: {len(resolved_ids)}")
    print(f"Instances unresolved: {len(unresolved_ids)}")
    print(f"Instances with empty patches: {len(empty_patch_ids)}")
    print(f"Instances with applied patches: {len(applied_patch_ids)}")
    print(f"Instances with errors: {len(error_ids)}")

    # Print comprehensive metrics
    print("\n=== Comprehensive Metrics ===")
    print(f"Empty Patch Rate: {comprehensive_metrics['empty_patch_rate']:.2%}")
    print(f"Unresolved Rate: {comprehensive_metrics['unresolved_rate']:.2%}")
    print(f"Resolved Rate: {comprehensive_metrics['resolved_rate']:.2%}")
    print(f"Average Modified Lines: {comprehensive_metrics['avg_modified_lines']:.2f}")
    print(f"Average Modified Files: {comprehensive_metrics['avg_modified_files']:.2f}")
    print(f"Average Cyclomatic Complexity: {comprehensive_metrics['avg_cyclomatic_complexity']:.2f}")
    print(f"Localization Success Rate: {comprehensive_metrics['localization_success_rate']:.2%}")
    print(f"Average Code Smells: {comprehensive_metrics['avg_code_smells']:.2f}")
    print(f"Average Error Handling: {comprehensive_metrics['avg_error_handling']:.2f}")

    if client:
        print(f"Unstopped containers: {len(unstopped_containers)}")
        print(f"Unremoved images: {len(unremoved_images)}")

    # write report to file
    report = {
        "total_instances": len(full_dataset),
        "submitted_instances": len(predictions),
        "completed_instances": len(completed_ids),
        "resolved_instances": len(resolved_ids),
        "unresolved_instances": len(unresolved_ids),
        "empty_patch_instances": len(empty_patch_ids),
        "applied_patch_instances": len(applied_patch_ids),
        "error_instances": len(error_ids),
        # Add comprehensive metrics
        "comprehensive_metrics": comprehensive_metrics,
        "completed_ids": sorted(completed_ids),
        "incomplete_ids": sorted(incomplete_ids),
        "empty_patch_ids": sorted(empty_patch_ids),
        "applied_patch_ids": sorted(applied_patch_ids),
        "submitted_ids": sorted(predictions.keys()),
        "resolved_ids": sorted(resolved_ids),
        "unresolved_ids": sorted(unresolved_ids),
        "error_ids": sorted(error_ids),
        # Add empty and unresolved patch lists from comprehensive metrics
        "empty_patch_list": comprehensive_metrics.get('empty_patch_list', []),
        "unresolved_patch_list": comprehensive_metrics.get('unresolved_patch_list', []),
        "schema_version": 2,
    }
    if client:
        report.update(
            {
                "unstopped_instances": len(unstopped_containers),
                "unstopped_containers": sorted(unstopped_containers),
                "unremoved_images": sorted(unremoved_images),
            }
        )
    # The report is named after the model of the first prediction
    report_file = Path(
        list(predictions.values())[0][KEY_MODEL].replace("/", "__")
        + f".{run_id}"
        + ".json"
    )
    with open(report_file, "w") as f:
        print(json.dumps(report, indent=4), file=f)
    print(f"Report written to {report_file}")
    return report_file



