#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
This script checks all Python files generated by composition.py with flake8,
restricted to a selected set of E and F error codes that affect code execution.
It first tries to automatically fix formatting issues using autopep8.
If autopep8 can fix the code, it will be updated in place and treated as valid code.
Only code that cannot be fixed by autopep8 will be considered invalid.
All reports, logs, and plots will be saved in ../../logs.
"""

import os
import subprocess
from pathlib import Path
from typing import List, Dict, Tuple, Counter
from tqdm import tqdm
import json
from datetime import datetime
import logging
import sys
import shutil
import difflib
import re
import matplotlib.pyplot as plt
from collections import Counter

# Make the parent of this script's directory importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Define constants for paths
# NOTE: these are relative paths, so they resolve against the current working
# directory of the process, not against the location of this script.
LOGS_DIR = "../../logs"  # reports, log file and plots
COMPOSITION_DIR = "../../DfyAutoSpec_wtc_chain300"  # tree scanned for *.py files
ERROR_DIR = "../../logs/bugs"  # destination for directories with unfixable files

# Ensure logs directory exists
# (runs at import time, because the log file handler needs the directory)
os.makedirs(LOGS_DIR, exist_ok=True)

# Setup logging configuration
def setup_logging() -> logging.Logger:
    """
    Configure logging to write to a log file and to stdout.

    The log file is always ``syntax_check.log`` inside LOGS_DIR; no timestamp
    is added to the file name, so every run uses the same file.

    Returns:
        Logger instance for this module
    """
    log_file = os.path.join(LOGS_DIR, "syntax_check.log")

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(sys.stdout)
        ]
    )
    return logging.getLogger(__name__)

# Module-level logger shared by the functions below.
logger = setup_logging()

def generate_diff_log(original_content: str, fixed_content: str, log_path: str) -> None:
    """
    Write a unified diff between two versions of a source text to a log file.

    The log starts with a time stamp line and a short header, followed by the
    diff (3 lines of context) framed by two separator rules.

    Args:
        original_content: Code before the fix
        fixed_content: Code after the fix
        log_path: Destination file for the diff log (overwritten if present)
    """
    rule = "=" * 80
    before_lines = original_content.splitlines(keepends=True)
    after_lines = fixed_content.splitlines(keepends=True)

    with open(log_path, 'w', encoding='utf-8') as log_file:
        stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        log_file.write(f"Code fix time: {stamp}\n")
        log_file.write("Changes made by autopep8:\n")
        log_file.write(f"{rule}\n")
        # Lines keep their own line endings, so they are written verbatim.
        for diff_line in difflib.unified_diff(
            before_lines,
            after_lines,
            fromfile='original',
            tofile='fixed',
            n=3,
        ):
            log_file.write(diff_line)
        log_file.write(f"\n{rule}\n")

def fix_with_autopep8(file_path: str) -> Tuple[bool, str]:
    """
    Try to fix formatting issues in a file using autopep8.

    autopep8 rewrites the file in place. The result is then verified with
    flake8 (all E and F codes). If verification passes, a ``fix.log`` with a
    unified diff is written next to the file. If verification fails, or any
    error occurs after autopep8 has run, the original content is written back
    so that a file reported as "not fixed" is never left half-modified.

    Args:
        file_path: Path to the Python file to fix

    Returns:
        Tuple of (success, message)
        - success: True only if autopep8 changed the file and verification passed
        - message: Details about the fix or error message
    """
    original_content = None
    # Set once autopep8 has returned, i.e. once the file may differ from
    # original_content and a rollback can become necessary.
    autopep8_ran = False

    try:
        # Read original content
        with open(file_path, 'r', encoding='utf-8') as f:
            original_content = f.read()

        # Run autopep8
        result = subprocess.run(
            ['autopep8', '--in-place', '--aggressive', '--aggressive', file_path],
            capture_output=True,
            text=True
        )
        autopep8_ran = True

        if result.returncode != 0:
            return False, f"Error running autopep8: {result.stderr}"

        # Read fixed content
        with open(file_path, 'r', encoding='utf-8') as f:
            fixed_content = f.read()

        if original_content == fixed_content:
            return False, "No changes needed or no changes made"

        # Verify the fix.
        # NOTE(review): this selects every E and F code, which is stricter than
        # the critical-only selection used by run_flake8 - confirm this is intended.
        verify_result = subprocess.run(
            ['flake8', '--select=E,F', file_path],
            capture_output=True,
            text=True
        )

        if verify_result.returncode != 0:
            # Still has errors: restore original content
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(original_content)
            return False, f"Fix did not resolve all issues: {verify_result.stdout}"

        # Generate fix.log only if the file was successfully fixed
        fix_log_path = os.path.join(os.path.dirname(file_path), "fix.log")
        with open(fix_log_path, 'w', encoding='utf-8') as f:
            f.write(f"Fix time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"File: {os.path.basename(file_path)}\n")
            f.write("\nChanges made:\n")
            f.write("="*80 + "\n")
            for line in difflib.unified_diff(
                original_content.splitlines(),
                fixed_content.splitlines(),
                fromfile='original',
                tofile='fixed',
                lineterm=''
            ):
                f.write(line + '\n')
        return True, "Successfully fixed. Changes recorded in fix.log"

    except Exception as e:
        message = f"Error during fix: {str(e)}"
        if autopep8_ran and original_content is not None:
            # The failure happened after the in-place rewrite; roll it back so
            # the caller's "not fixed" verdict matches what is on disk.
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write(original_content)
            except OSError as restore_error:
                message += f" (could not restore original content: {restore_error})"
        return False, message

def run_flake8(file_path: str) -> Tuple[bool, str]:
    """
    Run flake8 on a Python file, checking only a fixed set of error codes
    that affect code execution.

    Style warnings such as line length (E501) are not reported. Undefined
    names (F821) ARE part of the selection below and therefore make a file
    invalid.

    Args:
        file_path: Path to the Python file to check

    Returns:
        Tuple of (is_valid, error_msg)
        - is_valid: True if flake8 exited with status 0
        - error_msg: flake8's report if invalid. If flake8 failed without
          printing a report, its stderr (or exit status) is returned instead
          so the message is never empty.
    """
    try:
        # Run flake8 with specific error codes that affect code execution
        # E901: SyntaxError or IndentationError
        # E902: IOError
        # F406: 'from module import *' only allowed at module level
        # F407: an undefined __future__ feature name was imported
        # F632: use ==/!= to compare str, bytes, and int literals
        # F704: a yield or yield from statement outside a function
        # F706: a return statement outside a function
        # F707: an except: block as not the last exception handler
        # F821: undefined name
        # F822: undefined name in __all__
        # F823: local variable referenced before assignment
        # F831: duplicate argument name in function definition
        # F901: raise NotImplemented should be raise NotImplementedError
        # F902: invalid first argument given to super()
        result = subprocess.run([
            'flake8',
            '--select=E901,E902,F406,F407,F632,F704,F706,F707,F821,F822,F823,F831,F901,F902',
            '--ignore=E501,F401,F402,F403,F404,F405,F601,F602,F621,F622,F631,F633,F634,F701,F702,F703,F705,F811,F812,F841,F842,F903,F904,F905,F906,F907,F908,F909',
            file_path
        ], capture_output=True, text=True)
    except Exception as e:
        return False, f"Error running flake8: {str(e)}"

    if result.returncode == 0:
        return True, ""

    report = result.stdout.strip()
    if not report:
        # flake8 itself failed (e.g. bad option or internal error): the
        # diagnostics are on stderr rather than in the normal report.
        report = result.stderr.strip() or f"flake8 exited with status {result.returncode}"
    return False, report

def move_invalid_code(file_path: str, error_msg: str) -> None:
    """
    Relocate the directory that holds an invalid file into ERROR_DIR and
    record the failure in an ``err.log`` inside the relocated directory.

    Any exception raised along the way is logged and not propagated.

    Args:
        file_path: Path to the invalid Python file
        error_msg: Error message from flake8 check
    """
    try:
        os.makedirs(ERROR_DIR, exist_ok=True)

        # The whole directory containing the file is moved, not just the file.
        source_dir, file_name = os.path.split(file_path)
        target_dir = os.path.join(ERROR_DIR, os.path.basename(source_dir))

        # Avoid clashing with a directory of the same name moved earlier.
        if os.path.exists(target_dir):
            suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
            target_dir = f"{target_dir}_{suffix}"

        shutil.move(source_dir, target_dir)

        err_log_path = os.path.join(target_dir, "err.log")
        with open(err_log_path, 'w', encoding='utf-8') as err_log:
            err_log.writelines([
                f"Error check time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
                f"File: {file_name}\n",
                f"Error message:\n{error_msg}\n",
            ])

        logger.info(f"Moved directory to: {target_dir}")
        logger.info(f"Error information saved to: {err_log_path}")
    except Exception as exc:
        logger.error(f"Error moving directory: {str(exc)}")

def check_and_fix_file(file_path: str) -> Tuple[bool, str, bool]:
    """
    Validate one file and, if the check fails, attempt an automatic repair.

    The file is first checked with run_flake8. On failure, fix_with_autopep8
    is tried; if that does not succeed either, the file's directory is handed
    to move_invalid_code together with the original flake8 report.

    Args:
        file_path: Path to the Python file to check and fix

    Returns:
        Tuple of (is_valid, message, was_fixed)
        - is_valid: True if the file passed the check or was repaired
        - message: Status message, or the flake8 report if invalid
        - was_fixed: True only if the repair step succeeded
    """
    banner = '=' * 80
    print(f"\n{banner}\nProcessing file: {file_path}\n{banner}")

    passed, flake8_report = run_flake8(file_path)
    if passed:
        return True, "File is valid (no critical errors)", False

    print(f"Found critical issues in {file_path}, attempting to fix...")
    repaired, repair_note = fix_with_autopep8(file_path)

    if not repaired:
        # Neither valid nor repairable: move it away and report the
        # original flake8 findings.
        move_invalid_code(file_path, flake8_report)
        return False, flake8_report, False

    return True, repair_note, True

def parse_error_types(error_msg: str) -> List[str]:
    """
    Extract the flake8 error codes from a flake8 report.

    Only the code in the ``path:row:col: CODE text`` position of each report
    line is returned. Code-like substrings that merely appear in a file path
    or in the message text (e.g. ``dir_E501/x.py`` or ``undefined name 'E123'``)
    are not counted, so they cannot distort the statistics.

    Args:
        error_msg: Error message from flake8 (one finding per line)

    Returns:
        List of error types in report order, one per finding
        (e.g., ['E901', 'F821']); empty if no report line is found
    """
    # Lazy prefix so the first "row:col: CODE" on each line is the one used.
    report_line_pattern = r'^.*?:\d+:\d+:[ \t]+([EF]\d{3})\b'
    return re.findall(report_line_pattern, error_msg, flags=re.MULTILINE)

def analyze_error_types(error_messages: List[str]) -> Dict[str, int]:
    """
    Tally how often each error type occurs across a set of error messages.

    Args:
        error_messages: List of error messages

    Returns:
        Dictionary mapping each error type found by parse_error_types to the
        number of times it occurred
    """
    tally = Counter(
        code
        for message in error_messages
        for code in parse_error_types(message)
    )
    return dict(tally)

def plot_error_distribution(error_counts: Dict[str, int], output_file: str) -> str:
    """
    Draw a bar chart of error-type frequencies and save it to a file.

    Args:
        error_counts: Dictionary mapping error types to their counts
        output_file: Full path to save the plot

    Returns:
        Path to the generated plot, or an empty string when there is
        nothing to plot
    """
    if not error_counts:
        return ""

    # Most frequent error types first.
    ranked = sorted(error_counts.items(), key=lambda item: item[1], reverse=True)
    labels = tuple(code for code, _ in ranked)
    values = tuple(total for _, total in ranked)

    plt.figure(figsize=(12, 6))
    bar_container = plt.bar(labels, values)

    plt.title('Distribution of Syntax Error Types')
    plt.xlabel('Error Type')
    plt.ylabel('Count')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()

    # Annotate every bar with its count, centred just above the bar.
    for rect in bar_container:
        top = rect.get_height()
        x_center = rect.get_x() + rect.get_width()/2.
        plt.text(x_center, top, f'{int(top)}', ha='center', va='bottom')

    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close()

    return output_file

def check_all_compositions() -> Dict:
    """
    Check syntax of all Python files in the composition directory.
    Files that can be fixed by autopep8 will be updated in place and treated as valid.
    Only files that cannot be fixed will be moved to error directory.

    The file list is collected once up front. Because an invalid file causes
    its whole directory to be moved away, later entries from that directory
    may no longer exist when their turn comes. Such files are counted as
    invalid without being re-checked, so they neither produce spurious
    "file not found" flake8 errors nor trigger a second move attempt.

    Returns:
        Dictionary containing check results and statistics (keys:
        timestamp, total_files, valid_files, invalid_files, auto_fixed_files,
        invalid_files_list, auto_fixed_files_list, error_types), or an empty
        dictionary when no Python files are found
    """
    # Convert to absolute path
    base_path = Path(os.path.abspath(COMPOSITION_DIR))

    # Find all Python files
    py_files = list(base_path.rglob("*.py"))

    if not py_files:
        print(f"No Python files found in {COMPOSITION_DIR}")
        return {}

    print(f"Found {len(py_files)} Python files to check")

    # Initialize results
    results = {
        "timestamp": datetime.now().isoformat(),
        "total_files": len(py_files),
        "valid_files": 0,
        "invalid_files": 0,
        "auto_fixed_files": 0,
        "invalid_files_list": [],
        "auto_fixed_files_list": [],
        "error_types": {}
    }

    # Collect all error messages for analysis
    all_error_messages = []

    # Check each file
    for file_path in tqdm(py_files, desc="Checking files"):
        # Get relative path for reporting
        rel_path_str = str(file_path.relative_to(base_path))

        if not file_path.exists():
            # The file vanished since the scan - most likely its directory was
            # moved because a sibling file was invalid. Report it as invalid,
            # but keep it out of the error-type statistics since flake8 never
            # examined it.
            results["invalid_files"] += 1
            results["invalid_files_list"].append({
                "file": rel_path_str,
                "error": "Skipped: file no longer exists (its directory was "
                         "probably already moved to the error directory)"
            })
            continue

        is_valid, message, was_fixed = check_and_fix_file(str(file_path))

        if is_valid:
            results["valid_files"] += 1
            if was_fixed:
                results["auto_fixed_files"] += 1
                results["auto_fixed_files_list"].append({
                    "file": rel_path_str,
                    "message": message
                })
        else:
            results["invalid_files"] += 1
            results["invalid_files_list"].append({
                "file": rel_path_str,
                "error": message
            })
            all_error_messages.append(message)

    # Analyze error types
    if all_error_messages:
        results["error_types"] = analyze_error_types(all_error_messages)

    return results

def print_summary(results: Dict):
    """
    Write a human-readable overview of a syntax-check run to stdout.

    Shows the overall counts, then the auto-fixed files, then the invalid
    files with their errors and the error-type ranking (most frequent first).

    Args:
        results: Dictionary containing check results
    """
    print("\n=== Syntax Check Summary ===")
    print(f"Total files checked: {results['total_files']}")
    print(f"Valid files: {results['valid_files']} (including {results['auto_fixed_files']} auto-fixed files)")
    print(f"Invalid files: {results['invalid_files']}")

    if results['auto_fixed_files'] > 0:
        print("\nAuto-fixed files (check fix.log for changes):")
        for fixed_entry in results['auto_fixed_files_list']:
            print(f"- {fixed_entry['file']}", f"  {fixed_entry['message']}", sep="\n")

    # Everything below only applies when at least one file is invalid.
    if not (results['invalid_files'] > 0):
        return

    print("\nInvalid files (could not be auto-fixed):")
    for bad_entry in results['invalid_files_list']:
        print(f"\n{bad_entry['file']}:\nError: {bad_entry['error']}")

    type_counts = results['error_types']
    if type_counts:
        print("\nError Type Distribution:")
        ranking = sorted(type_counts.items(), key=lambda pair: pair[1], reverse=True)
        for code, total in ranking:
            print(f"{code}: {total} occurrences")

def save_report(results: Dict) -> None:
    """
    Save the check results to a JSON file and generate error distribution plot.

    Both outputs go to LOGS_DIR under fixed names (``syntax_check_report.json``
    and ``error_distribution.png``); no timestamp is added to the file names,
    so each run overwrites the previous report and plot.

    Args:
        results: Dictionary containing check results
    """
    # Save JSON report
    json_file = os.path.join(LOGS_DIR, "syntax_check_report.json")
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    logger.info(f"\nReport saved to {json_file}")

    # Generate and save error distribution plot (only when errors were recorded)
    if results.get('error_types'):
        plot_file = os.path.join(LOGS_DIR, "error_distribution.png")
        plot_path = plot_error_distribution(results['error_types'], plot_file)
        if plot_path:
            logger.info(f"Error distribution plot saved to {plot_path}")

def main():
    """
    Main function to run the syntax check.

    Prints a short description of the run, verifies that flake8, autopep8 and
    matplotlib are available, then checks all files, saves the report and
    prints the summary. Returns early (after printing install instructions)
    when a required tool is missing, and silently when there is nothing to
    check.
    """
    print("Starting syntax check of generated Python files...")
    print("Note: Only checking for errors that affect code execution.")
    print("Style issues like line length and undefined names (that don't cause runtime errors) are ignored.")
    print("Files that can be fixed by autopep8 will be updated in place and treated as valid.")
    print("A fix.log will be generated for each fixed file to record the changes.")
    print(f"All reports and logs will be saved in: {LOGS_DIR}")

    # Check if required tools are installed
    try:
        subprocess.run(['flake8', '--version'], capture_output=True, check=True)
        subprocess.run(['autopep8', '--version'], capture_output=True, check=True)
        import matplotlib  # noqa: F401 - imported only to verify it is available
    except (subprocess.CalledProcessError, FileNotFoundError, ImportError):
        # ImportError covers a missing matplotlib; the other two cover a
        # missing or failing command-line tool.
        print("Error: Required tools are not installed. Please install them using:")
        print("pip install flake8 autopep8 matplotlib")
        return

    # Run the check
    results = check_all_compositions()

    # An empty result means no Python files were found.
    if not results:
        return

    # Save and print results
    save_report(results)
    print_summary(results)

    # Print additional information about moved files
    if results['invalid_files'] > 0:
        print(f"\nOnly files that could not be auto-fixed have been moved to {ERROR_DIR}")
        print("Each directory contains both Python files and metadata files")
        print(f"Error distribution plot has been generated in {LOGS_DIR}")
    if results['auto_fixed_files'] > 0:
        print("\nFixed files have been updated in place with fix.log recording the changes")

if __name__ == "__main__":
    main()