import argparse
import json
import logging
import os
import re

from lean_verifier import LeanVerifier

# Logging configuration: each run writes its log to logs/fix_unused.log,
# overwriting the previous run's log (filemode="w").
# basicConfig's file handler does not create missing directories, so make sure
# "logs/" exists first; otherwise importing this module raises FileNotFoundError.
os.makedirs("logs", exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="logs/fix_unused.log",
    filemode="w",
)

# Matches the variable name in a Lean "unused variable `name`" message.
_UNUSED_VARIABLE_RE = re.compile(r'unused variable `([^`]+)`')


def parse_warnings(response):
    """Get unused-variable warnings from a LeanREPL response.

    Args:
        response: One verification result. Messages are read from
            ``response['response']['messages']``; each message is expected to
            carry ``severity``, ``data``, ``pos`` and ``endPos`` entries, where
            ``pos``/``endPos`` are mappings with ``line``/``column`` keys.

    Returns:
        A list of dicts with keys ``line``, ``column``, ``end_column`` and
        ``var_name``, one per warning-severity message whose text names an
        unused variable. Returns ``[]`` when ``response`` is ``None`` or has no
        message list. A single message that cannot be parsed is logged and
        skipped; the remaining warnings are still returned.
    """
    try:
        messages = response['response']['messages']
    except (KeyError, TypeError):
        # None, a non-mapping, or a response without messages: nothing to fix.
        return []

    warnings = []
    for msg in messages or ():
        try:
            if msg.get('severity') != 'warning':
                continue
            match = _UNUSED_VARIABLE_RE.search(msg.get('data') or '')
            if match is None:
                continue
            warnings.append({
                'line': msg['pos']['line'],
                'column': msg['pos']['column'],
                'end_column': msg['endPos']['column'],
                'var_name': match.group(1),
            })
        except (AttributeError, KeyError, TypeError) as e:
            # e.g. endPos is null: skip just this message instead of discarding
            # every warning for the whole response.
            logging.error(f"Error parsing warning message: {e} \n message: {msg}")
    return warnings

def find_column_index(line, target_col):
    """Clamp ``target_col`` to a valid slice index into ``line``.

    Columns are treated as plain character (code point) offsets; no
    visual-width adjustment is made for wide or combining characters.

    Returns:
        ``0`` for negative columns, ``len(line)`` for columns past the end of
        the line, and ``target_col`` otherwise.
    """
    return max(0, min(target_col, len(line)))


def fix_unused_variables(statement, warnings):
    """Fix unused variables in the statement by replacing them with underscores.

    Args:
        statement: The Lean source text the warnings were reported against.
        warnings: Dicts as produced by ``parse_warnings``, with 1-based
            ``line`` plus ``column``, ``end_column`` and ``var_name`` keys.

    Returns:
        ``statement`` with, for each warning, the first occurrence of
        ``var_name`` inside the reported span replaced by ``_``. Warnings whose
        line lies outside the statement, or whose name is not found in the
        span, are logged and skipped.
    """
    lines = statement.split('\n')

    # Drop identical duplicate reports so a span is never rewritten twice.
    # Then process bottom-to-top and right-to-left, so replacing a name (which
    # shortens the line) never shifts the columns of warnings still to come.
    unique = {(w['line'], w['column'], w['end_column'], w['var_name']): w for w in warnings}
    for line_no, column, end_column, var_name in sorted(unique, reverse=True):
        line_num = line_no - 1  # convert to 0-indexed
        if not 0 <= line_num < len(lines):
            # Guard both ends: a negative index would silently edit the wrong line.
            logging.warning(f"Warning for {var_name} refers to line {line_no}, outside the statement ({len(lines)} lines)")
            continue

        text = lines[line_num]
        # The slice starts one character before the reported column. Only the
        # first occurrence of var_name inside it is replaced, so the extra
        # character is harmless slack.
        # NOTE(review): confirm whether LeanREPL columns are 0- or 1-based.
        start = find_column_index(text, column - 1)
        end = find_column_index(text, end_column)

        target_part = text[start:end]
        new_target_part = target_part.replace(var_name, '_', 1)
        if new_target_part == target_part:
            logging.warning(f"Failed to replace {var_name} with '_' at line {line_no}, column {column}, content={target_part}")
            continue
        lines[line_num] = text[:start] + new_target_part + text[end:]

    return '\n'.join(lines)

def _assemble_statement(item):
    """Build the full Lean source for ``item`` plus the line spans of its parts.

    The theorem text (``full_formal_theorem`` when ``use_full`` is truthy,
    otherwise ``formal_theorem``) has ``sorry`` replaced by ``by`` + newline +
    ``formal_proof``. The spans assume ``sorry`` sits on the theorem's last line.

    Returns:
        ``(statement, (imports_idx, statement_idx, proof_idx))``, where each
        idx is an inclusive ``(first_line, last_line)`` pair of 0-based
        indices into ``statement.split("\\n")``.
    """
    theorem = item["full_formal_theorem"] if item.get("use_full", False) else item["formal_theorem"]
    proof = item["formal_proof"]

    imports_idx = (0, item["imports"].count("\n"))
    statement = item["imports"] + "\n" + theorem.replace("sorry", "by\n" + proof)
    statement_idx = (imports_idx[1] + 1, imports_idx[1] + 1 + theorem.count("\n"))
    proof_idx = (statement_idx[1] + 1, statement_idx[1] + 1 + proof.count("\n"))
    return statement, (imports_idx, statement_idx, proof_idx)


def _split_fixed_statement(fixed_statement, idx):
    """Cut a (fixed) statement back into its imports, theorem and proof texts."""
    lines = fixed_statement.split("\n")
    return tuple("\n".join(lines[first:last + 1]) for first, last in idx)


def _rebuild_item(item, imports, theorem, proof):
    """Return a copy of ``item`` with the fixed imports/theorem/proof written back."""
    # The theorem text ends with the "by" that _assemble_statement put in place
    # of "sorry". Swap it back so the stored theorem stays a sorry-template,
    # for BOTH formal_theorem and full_formal_theorem.
    template = theorem.rsplit("by", 1)[0] + "sorry"
    # Start from the original item so keys this script does not touch survive,
    # just as they do for items that needed no fix.
    if item.get("use_full", False):
        return {**item, "imports": imports, "full_formal_theorem": template,
                "formal_proof": proof, "use_full": True}
    return {**item, "imports": imports, "formal_theorem": template,
            "formal_proof": proof, "use_full": False}


def fix_unused_data(data):
    """Verify every item with Lean and rename unused variables to ``_``.

    Each item must provide ``imports``, ``formal_proof`` and the theorem field
    selected by ``use_full`` (see ``_assemble_statement``), containing ``sorry``
    where the proof goes.

    Returns:
        One entry per verified item, in input order:
        - the rewritten item when unused-variable warnings were found;
        - otherwise the item unchanged, including when the verifier returned
          no response for it (this is logged).
    """
    total_statements = []
    idxs = []
    for item in data:
        statement, idx = _assemble_statement(item)
        total_statements.append(statement)
        idxs.append(idx)

    with LeanVerifier() as verifier:
        # Single verification pass; warnings are read from its responses.
        results = list(verifier.verify_batch(total_statements, timeout=60, use_tqdm=True))

    if len(results) != len(data):
        # zip() below stops at the shorter sequence, so extra items would vanish.
        logging.warning(f"Verifier returned {len(results)} results for {len(data)} statements")

    total_results = []
    for item, idx, original_statement, response in zip(data, idxs, total_statements, results):
        if not response:
            # No response at all: keep the item rather than silently dropping it.
            logging.warning(f"No verification response for {item.get('name')}; keeping it unchanged")
            total_results.append(item)
            continue

        warnings = parse_warnings(response)
        if not warnings:
            total_results.append(item)
            continue

        fixed_statement = fix_unused_variables(original_statement, warnings)
        imports, theorem, proof = _split_fixed_statement(fixed_statement, idx)

        # Sanity check: the spans should partition the fixed statement exactly.
        new_fixed_statement = imports + "\n" + theorem + "\n" + proof
        if fixed_statement != new_fixed_statement:
            logging.warning(f"fixed_statement: {fixed_statement} != {new_fixed_statement}")
            logging.warning(f"Original statement:\n{original_statement}\n" + "="*50 + "\n")
            logging.warning(f"Verification result:\n{response}\n" + "="*50 + "\n")
            logging.warning(f"\nFound {len(warnings)} unused variable warnings:")
            for warning in warnings:
                logging.warning(f"  Line {warning['line']}, column {warning['column']}: {warning['var_name']}")
            logging.warning(f"\nFixed statement:\n{fixed_statement}\n" + "="*50 + "\n")

        total_results.append(_rebuild_item(item, imports, theorem, proof))

    return total_results


def main():
    """Command-line entry point.

    Reads a JSON list of theorem items from ``--input_file`` and keeps the
    first ``--num_theorems`` of them (``-1`` keeps all). Runs
    ``fix_unused_data`` on batches of ``--batch_size`` items, then writes the
    combined results to ``--output_file`` as indented JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", type=str, required=True, help="Input file path")
    parser.add_argument("--output_file", type=str, required=True, help="Output file path")
    parser.add_argument("--num_theorems", type=int, default=10, help="Number of theorems to convert")
    parser.add_argument("--batch_size", type=int, default=10000, help="Number of batch processed")
    args = parser.parse_args()
    if args.batch_size <= 0:
        # range() below would raise on a zero step; report a clean CLI error instead.
        parser.error("--batch_size must be a positive integer")

    # Output is written with ensure_ascii=False, so theorem text may contain
    # raw non-ASCII characters. Pin UTF-8 rather than the platform default.
    with open(args.input_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    args.num_theorems = len(data) if args.num_theorems == -1 else args.num_theorems
    data = data[0:args.num_theorems]

    batch_size = args.batch_size
    # Ceiling division: `len // size + 1` over-counts when len is an exact multiple.
    num_batches = (len(data) + batch_size - 1) // batch_size
    total_results = []
    for batch_no, start in enumerate(range(0, len(data), batch_size), start=1):
        print(f"Fixing batch {batch_no} of {num_batches}")
        total_results.extend(fix_unused_data(data[start:start + batch_size]))

    print(f"Total results: {len(total_results)}")

    with open(args.output_file, "w", encoding="utf-8") as f:
        json.dump(total_results, f, indent=2, ensure_ascii=False)


if __name__ == "__main__":
    main()
