import json
from tqdm import tqdm
import os
import sys
import re
import time
from collections import defaultdict
from typing import Dict, List, Union, Any, Tuple,Set

def read_jsonl(path):
    """Load a JSON Lines file into a list of parsed objects.

    The file is decoded explicitly as UTF-8 so it round-trips with
    ``write_jsonl``/``write_res``, which write UTF-8 with ``ensure_ascii=False``;
    relying on the platform's default encoding can fail on non-ASCII content.

    Blank (whitespace-only) lines, such as a stray empty line at the end of
    the file, are skipped instead of crashing ``json.loads``.
    """
    res = []
    with open(path, "r", encoding="utf-8") as f:
        for line in tqdm(f.readlines()):
            if not line.strip():
                continue
            res.append(json.loads(line))
    return res

def write_jsonl(data_to_write, path, mode):
    """Serialize each item of *data_to_write* as one JSON line in *path*.

    ``mode`` is passed straight through to ``open`` (e.g. ``"w"`` to
    overwrite, ``"a"`` to append). The file is UTF-8 and non-ASCII
    characters are written verbatim (``ensure_ascii=False``).
    """
    with open(path, mode, encoding="utf-8") as out:
        for record in tqdm(data_to_write):
            out.write(f"{json.dumps(record, ensure_ascii=False)}\n")

def write_res(path, batch_res):
    """Append each record in *batch_res* to *path* as one JSON line (UTF-8).

    Records are first serialized with ``ensure_ascii=False`` for readability.
    If that text cannot be encoded as UTF-8 (e.g. the record contains a lone
    surrogate code point), the record is re-serialized with
    ``ensure_ascii=True`` so the offending characters become JSON escape
    sequences instead.

    Raises:
        ValueError: if a record cannot be serialized to JSON at all; the
            underlying ``TypeError``/``ValueError`` is chained as the cause.
    """
    with open(path, "a+", encoding='utf-8') as f:
        for r in batch_res:
            try:
                line = json.dumps(r, ensure_ascii=False)
            except (TypeError, ValueError) as err:
                # ensure_ascii does not affect serializability, so retrying
                # would fail the same way; report the real cause instead.
                raise ValueError(f"cannot serialize record to JSON: {err}") from err
            try:
                f.write(line + "\n")
            except UnicodeEncodeError:
                # The text encoder rejects the whole string before writing,
                # so nothing partial hits the file; write the escaped form.
                f.write(json.dumps(r, ensure_ascii=True) + "\n")
                
def check_brackets_balance(text: str) -> bool:
    """Return True if (), [] and {} in *text* are properly nested and closed.

    Lean block-comment delimiters ``/-`` and ``-/`` are rewritten to ``(``
    and ``)`` before scanning, so an unterminated block comment also counts
    as unbalanced. All other characters are ignored.
    """
    closer_to_opener = {')': '(', ']': '[', '}': '{'}
    openers = set(closer_to_opener.values())
    normalized = text.replace('/-', '(').replace('-/', ')')

    pending = []
    for ch in normalized:
        if ch in openers:
            pending.append(ch)
        elif ch in closer_to_opener:
            # A closer must match the most recently opened bracket.
            if not pending or pending.pop() != closer_to_opener[ch]:
                return False
    return not pending

def extract_imports(code: str) -> Set[str]:
    """Return the set of ``import`` lines found in *code*.

    A line qualifies when, after optional leading whitespace, it starts with
    ``import`` followed by a module name made of letters, digits, ``_`` and
    ``.``. Only ``import <first-module-name>`` is kept (stripped); anything
    after the first module name on the line is ignored. Imports containing
    the substring ``markers`` are excluded.
    """
    found = set()
    for raw_line in code.split('\n'):
        m = re.match(r'^\s*import\s+([A-Za-z0-9_\.]+)', raw_line)
        if m is None:
            continue
        stmt = m.group(0).strip()
        if 'markers' not in stmt:
            found.add(stmt)
    return found

def group_by_imports(statements: List[Dict[str, Any]]) -> Dict[frozenset, List[Dict[str, Any]]]:
    """Group statement records by the exact set of imports their text declares.

    Args:
        statements: records whose ``'statement'`` key holds Lean source text.

    Returns:
        Mapping from the frozenset of import lines (as produced by
        ``extract_imports``) to the list of records sharing exactly that set,
        in input order. The values are the original record dicts, not strings.
    """
    groups: Dict[frozenset, List[Dict[str, Any]]] = {}
    for statement in statements:
        key = frozenset(extract_imports(statement['statement']))
        groups.setdefault(key, []).append(statement)
    return groups

def combine_lean4_statements_simple(statements: List[str]) -> str:
    """Combine multiple Lean 4 theorem statements into a single file.

    Every import line found by ``extract_imports`` in any statement is moved
    into one shared, sorted header. Each statement is then wrapped in its own
    ``namespace StatementN`` / ``end StatementN`` pair (N starts at 1),
    followed by an ``#eval`` marker. ``get_block_ranges`` uses these namespace
    lines to map diagnostics back to statements.

    Args:
        statements: Lean 4 source snippets, one per statement.

    Returns:
        The combined Lean 4 source as a single string.
    """
    # Extract all import statements
    all_imports = set()
    for stmt in statements:
        all_imports.update(extract_imports(stmt))

    # Remove hoisted import lines from each statement. Whole (stripped) lines
    # are compared instead of substituting the import text anywhere. This
    # leaves a line that merely contains the text (e.g. "-- import Mathlib")
    # intact, and still removes a final import line that has no newline.
    cleaned_statements = []
    for stmt in statements:
        kept_lines = [line for line in stmt.split('\n') if line.strip() not in all_imports]
        cleaned_statements.append('\n'.join(kept_lines).strip())

    # Build merged code
    current_date = time.strftime("%Y-%m-%d")

    # Header and imports
    result = [
        "-- Lean 4 syntax verification file",
        f"-- Auto-generated on {current_date}",
        "",
        "-- Shared import area"
    ]

    if all_imports:
        result.extend(sorted(all_imports))
    else:
        result.append("-- No import statements")

    result.append("")

    # Add each statement, isolated by namespace
    for i, stmt in enumerate(cleaned_statements, 1):
        result.extend([
            f"-- Statement {i}",
            f"namespace Statement{i}",
            stmt,
            f"  #eval \"Statement {i} syntax check passed\"",
            f"end Statement{i}",
            ""
        ])

    # Add verification summary
    result.append("-- Verification summary")
    result.append("#eval \"All statements syntax check completed\"")

    return "\n".join(result)

def analyze_lean4_results(result_json: Dict[str, Any]) -> List[Union[bool, None]]:
    """
    Analyze Lean 4 execution results and return a syntax-check verdict per code block.

    Code blocks are the line ranges returned by ``get_block_ranges``. An
    error whose reported ``pos.line`` falls inside a range marks that block
    as failed.

    Args:
        result_json: JSON object of Lean 4 execution results. Reads
            ``result_json['info']['system_errors']``,
            ``result_json['info']['errors']`` and (through
            ``get_block_ranges``) ``result_json['verified_code']``.

    Returns:
        One entry per code block:
            - True when no error lies inside the block.
            - False when at least one error does.
            - None for every block when ``info`` / ``system_errors`` is
              missing or malformed, or any system error was reported
              (verdict unknown).
    """
    block_ranges = get_block_ranges(result_json)
    unknown = [None] * len(block_ranges)

    info = result_json.get('info')
    try:
        sys_errors = info['system_errors']
    except (KeyError, TypeError):
        # Missing key, or 'info' absent/None/not a mapping: cannot judge.
        return unknown
    if sys_errors:
        return unknown

    # Default: every block passes unless an error lands inside it.
    results = [True] * len(block_ranges)
    for error in info.get('errors') or []:
        # A missing position defaults to line 0, outside every 1-based block.
        error_line = (error.get('pos') or {}).get('line', 0)
        for i, (start, end) in enumerate(block_ranges):
            if start <= error_line <= end:
                results[i] = False
                break

    # NOTE: warnings (info['warnings'], e.g. "declaration uses 'sorry'") are
    # intentionally not treated as failures.
    return results

def get_block_ranges(result_json: Dict[str, Any]) -> List[Tuple[int, int]]:
    """
    Get line number ranges for code blocks.

    A range opens at a line starting with ``namespace Statement``. It is
    closed by a line starting with ``end Statement``, which records the pair
    (most recent opening line, closing line). Line numbers are 1-based and
    both endpoints are inclusive.

    Args:
        result_json: JSON object of Lean 4 execution results. Only its
            ``verified_code`` string is read; missing or None yields no blocks.

    Returns:
        List[Tuple[int, int]]: (start_line, end_line) per block, in order of
        appearance.
    """
    verified_code = result_json.get('verified_code') or ''

    block_ranges = []
    start_line = 0
    for lineno, line in enumerate(verified_code.split('\n'), 1):
        if line.startswith('namespace Statement'):
            # The suffix after "Statement" is deliberately not parsed: blocks
            # are identified by position, so a non-numeric suffix is harmless.
            start_line = lineno
        elif line.startswith('end Statement'):
            block_ranges.append((start_line, lineno))

    return block_ranges

def basic_check(s):
    """Cheap pre-filter for a Lean statement string.

    Passes only if the brackets in *s* balance (per ``check_brackets_balance``)
    and *s* contains ``by`` followed by whitespace and then ``sorry``.
    """
    if not check_brackets_balance(s):
        return False
    return re.search(r'by\s+sorry', s) is not None


def top_1_hit(groups):
    """
    Top-1 hit: The proportion where statements with golden_label=True are judged as True, 
    and all statements with golden_label=False in that group are judged as False.

    Groups that do not contain exactly one original (golden_label True) are
    never hits but still count in the denominator. Returns 0.0 for no groups.
    """
    if not groups:
        return 0.0

    hits = 0
    for group in groups:
        originals = [entry for entry in group if entry['golden_label'] is True]
        if len(originals) != 1:
            continue
        if originals[0]['label'] is not True:
            continue
        perturbed = (entry for entry in group if entry['golden_label'] is False)
        if all(entry['label'] is False for entry in perturbed):
            hits += 1
    return hits / len(groups)

def precision(groups):
    """
    precision = Number of original statements judged as True / Total number of statements judged as True

    Only items whose ``label`` is exactly True are considered. Returns 0.0
    when nothing was judged True.
    """
    judged_true = 0
    judged_true_and_original = 0
    for item in (entry for group in groups for entry in group):
        if item['label'] is not True:
            continue
        judged_true += 1
        if item['golden_label'] is True:
            judged_true_and_original += 1
    if not judged_true:
        return 0.0
    return judged_true_and_original / judged_true

def recall(groups):
    """
    recall = Number of original statements judged as True / Total number of original statements

    An original is an item whose ``golden_label`` is exactly True. It counts
    as recovered only when its ``label`` is exactly True, so False or None
    labels are misses. Returns 0.0 if there are no originals.
    """
    originals = 0
    recovered = 0
    for item in (entry for group in groups for entry in group):
        if item['golden_label'] is not True:
            continue
        originals += 1
        if item['label'] is True:
            recovered += 1
    if not originals:
        return 0.0
    return recovered / originals

def group_by_id(statements):
    """
    Input: statements: List[dict], each dict has 'id' field
    Output: groups: List[List[dict]], each group contains all statements with the same id

    Groups appear in order of each id's first occurrence; items keep their
    input order within a group (the original dict objects, not copies).
    """
    buckets = {}
    for record in statements:
        buckets.setdefault(record['id'], []).append(record)
    return list(buckets.values())

def acc(groups):
    """
    ACC = (TP + TN) / (TP + TN + FP + FN)

    Only items whose ``label`` and ``golden_label`` are both exactly True or
    False (identity check) are counted. Returns 0.0 when none qualify.
    """
    correct = 0
    counted = 0
    for group in groups:
        for item in group:
            pred = item['label']
            gold = item['golden_label']
            gold_is_bool = gold is True or gold is False
            pred_is_bool = pred is True or pred is False
            if not (gold_is_bool and pred_is_bool):
                continue
            counted += 1
            if gold is pred:
                correct += 1
    return correct / counted if counted else 0.0

def tpr(groups):
    """
    TPR = TP / (TP + FN)

    Sensitivity over items with ``golden_label`` exactly True; labels other
    than True/False are ignored. Returns 0.0 when no such item exists.
    """
    hits = misses = 0
    for item in (entry for group in groups for entry in group):
        pred, gold = item['label'], item['golden_label']
        if gold is not True:
            continue
        if pred is True:
            hits += 1
        elif pred is False:
            misses += 1
    positives = hits + misses
    return hits / positives if positives else 0.0

def fpr(groups):
    """
    FPR = FP / (FP + TN)

    Share of items with ``golden_label`` exactly False that were labelled
    True; labels other than True/False are ignored. Returns 0.0 when there
    are no such items.
    """
    false_alarms = 0
    correct_rejections = 0
    for group in groups:
        for item in group:
            pred, gold = item['label'], item['golden_label']
            if gold is not False:
                continue
            if pred is True:
                false_alarms += 1
            elif pred is False:
                correct_rejections += 1
    negatives = false_alarms + correct_rejections
    if not negatives:
        return 0.0
    return false_alarms / negatives

def tnr(groups):
    """
    TNR = TN / (TN + FP)

    Specificity: share of items with ``golden_label`` exactly False that were
    labelled False; labels other than True/False are ignored. Returns 0.0
    when there are no such items.
    """
    rejected = accepted = 0
    for item in (entry for group in groups for entry in group):
        predicted, expected = item['label'], item['golden_label']
        if expected is not False:
            continue
        if predicted is False:
            rejected += 1
        elif predicted is True:
            accepted += 1
    negatives = rejected + accepted
    return rejected / negatives if negatives else 0.0

def fnr(groups):
    """
    FNR = FN / (TP + FN)

    Miss rate over items with ``golden_label`` exactly True; labels other
    than True/False are ignored. Returns 0.0 when no such item exists.
    """
    missed = 0
    found = 0
    for group in groups:
        for item in group:
            predicted = item['label']
            expected = item['golden_label']
            if expected is not True:
                continue
            if predicted is False:
                missed += 1
            elif predicted is True:
                found += 1
    positives = found + missed
    if not positives:
        return 0.0
    return missed / positives
