import os
import json
import pandas as pd
from sqlglot import parse_one, exp

def get_gold_columns_only(sql: str, dataset: str, data_dir: str):
    """Return the schema columns referenced by ``sql``, excluding id-like ones.

    The schema is read from ``{data_dir}/tables.json``. Only entries whose
    ``db_id`` equals ``dataset`` contribute column names (the second element
    of each ``column_names_original`` item). A column mentioned in ``sql`` is
    kept when its name is in that schema and does not end with ``"id"``; the
    suffix test is purely textual, so a name such as ``"paid"`` is dropped too.

    Args:
        sql: SQL text, parsed by sqlglot using the ``sqlite`` dialect.
        dataset: ``db_id`` of the schema to validate column names against.
        data_dir: Directory that contains ``tables.json``.

    Returns:
        Sorted list of unique column names with table qualifiers dropped.
        An empty list is returned when the query cannot be parsed.

    Raises:
        ValueError: If ``tables.json`` cannot be opened or decoded. The
            underlying exception is attached as ``__cause__``.
    """
    table_path = f"{data_dir}/tables.json"
    try:
        with open(table_path, "r", encoding="utf-8") as f:
            tables = json.load(f)
    except Exception as e:
        # Keep the ValueError contract but preserve the root cause for debugging.
        raise ValueError(f"Error in loading tables.json: {e}") from e

    # A set gives O(1) membership checks for the filter below.
    schema_columns = {
        column[1]
        for db in tables
        if db['db_id'] == dataset
        for column in db['column_names_original']
    }

    try:
        # NOTE(review): error_level is passed as a plain string here; sqlglot
        # documents an ErrorLevel enum for this argument - confirm intent.
        parsed = parse_one(sql, read='sqlite', error_level='ignore')
        # Unique column names; table qualifiers are ignored.
        referenced = {col.name for col in parsed.find_all(exp.Column)}
    except Exception:
        # Best effort: an unparsable query simply contributes no columns.
        return []

    # Sorted so the result does not depend on set iteration order.
    return sorted(
        name
        for name in referenced
        if not name.endswith('id') and name in schema_columns
    )
def check_column_match(pred_sql: str, gold_sql: str, dataset: str, data_dir: str):
    """Check whether the predicted SQL mentions every column of the gold SQL.

    Args:
        pred_sql: Predicted SQL; newlines and doubled spaces are collapsed
            before column extraction.
        gold_sql: Reference SQL, used as given.
        dataset: Schema identifier forwarded to ``get_gold_columns_only``.
        data_dir: Schema directory forwarded to ``get_gold_columns_only``.

    Returns:
        Dict with ``is_correct`` (True when no gold column is missing from
        the predicted columns), ``pred_columns`` and ``gold_columns``.
        Any error during extraction is printed and yields ``is_correct``
        False with whatever column lists were obtained so far.
    """
    # Ten passes of newline -> space and double-space -> single-space.
    for _ in range(10):
        pred_sql = pred_sql.replace("\n", " ").replace("  ", " ")

    # Defaults describe the failure case; they are overwritten on success.
    outcome = {
        "is_correct": False,
        "pred_columns": [],
        "gold_columns": [],
    }

    try:
        outcome["pred_columns"] = get_gold_columns_only(pred_sql, dataset, data_dir)
        outcome["gold_columns"] = get_gold_columns_only(gold_sql, dataset, data_dir)
        # Correct when the gold columns form a subset of the predicted ones.
        outcome["is_correct"] = set(outcome["gold_columns"]) <= set(outcome["pred_columns"])
    except Exception as err:
        print(f"Warning: Error in column matching: {err}")
        outcome["is_correct"] = False

    return outcome

def parse_comma_separated_list(value):
    """Parse a comma-separated string into a list of non-empty strings.

    Args:
        value: Either a list, which is returned unchanged, or a string whose
            items are separated by commas.

    Returns:
        The items with surrounding whitespace stripped. Blank items, such as
        those produced by an empty string, doubled commas or a trailing
        comma, are dropped so they are never treated as real entries.
    """
    if isinstance(value, list):
        return value
    stripped_items = (item.strip() for item in value.split(','))
    return [item for item in stripped_items if item]


from tqdm import tqdm
import argparse

# Command-line interface: every option is declared here and parsed once into `args`.
parser = argparse.ArgumentParser(
    description="Evaluate SQL correction results for recall of non-answerable questions.",
)
parser.add_argument(
    "--model_name",
    type=str,
    required=True,
    help="Name of the model used for SQL generation.",
)
parser.add_argument(
    "--dataset_name_list",
    type=parse_comma_separated_list,
    default=["mimic_iv"],
    help="Comma-separated list of dataset names.",
)
parser.add_argument(
    "--corrector_name",
    type=str,
    required=True,
    help="Name of the SQL corrector.",
)
parser.add_argument(
    "--correction_output_dir",
    type=str,
    default="/nfs_edlab/wschay/sql-error-detection/sql-error-detection/output",
    help="Directory for correction output JSON files.",
)
parser.add_argument(
    "--granular_evaluation",
    action="store_true",
    help="Enable granular evaluation for sub-types.",
)
parser.add_argument(
    "--data_dir",
    type=str,
    default="databases",
    help="Directory for database files.",
)
args = parser.parse_args()


# Optional file-name suffix; set to "_subset" to target subset runs.
is_subset = ""

# Per-question-type counters for this corrector. Each value is a fresh
# two-element list: [correct_abstentions, total_instances].
recall_metrics = {
    question_type: [0, 0]
    for question_type in (
        "vague-question",
        "vague-word",
        "ambiguous-reference",
        "small-talk",
        "out-of-scope",
        "missing-column",
        "infeasible-faq",
    )
}

def safe_json_load(file_path):
    """Load a JSON file, retrying once after stripping trailing commas.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON value, or None when the file cannot be read or is
        still invalid after the trailing-comma repair.
    """
    # Imported locally because `re` is not among this module's top-level imports.
    import re

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        print(f"Warning: JSON decode error in {file_path}: {e}")
        print("Attempting to fix the JSON file...")

        with open(file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Remove a comma that directly precedes a closing brace/bracket.
        # The pattern runs over the raw text, so the same character sequence
        # inside a string value would be rewritten as well.
        content = re.sub(r',(\s*[}\]])', r'\1', content)

        try:
            return json.loads(content)
        except json.JSONDecodeError as e2:
            print(f"Error: Could not fix JSON file {file_path}: {e2}")
            print("Skipping this dataset...")
            return None
    except Exception as e:
        print(f"Error: Could not read file {file_path}: {e}")
        return None


# Outputs accepted as a correct abstention, keyed by the question's true type.
# Tuples (not sets) so that an unhashable `final_sql` value simply fails to match.
_AMBIGUOUS_TYPES = ('vague-question', 'vague-word', 'ambiguous-reference')
_UNANSWERABLE_TYPES = ('small-talk', 'out-of-scope', 'missing-column', 'infeasible-faq')
_ACCEPTED_OUTPUTS = {
    **dict.fromkeys(_AMBIGUOUS_TYPES, ('ambiguous',) + _AMBIGUOUS_TYPES),
    **dict.fromkeys(_UNANSWERABLE_TYPES, ('unanswerable',) + _UNANSWERABLE_TYPES),
}

for dataset in args.dataset_name_list:

    correction_path = os.path.join(args.correction_output_dir, f"{args.corrector_name}_{dataset}_{args.model_name}.json")

    # Fall back to the "_subset" file name when the primary file is absent.
    if not os.path.exists(correction_path):
        correction_path_subset = os.path.join(args.correction_output_dir, f"{args.corrector_name}_{dataset}_{args.model_name}_subset.json")
        if os.path.exists(correction_path_subset):
            correction_path = correction_path_subset

    if not os.path.exists(correction_path):
        print(f"Warning: Correction file not found: {correction_path}")
        continue

    correction_list = safe_json_load(correction_path)
    if correction_list is None:
        print(f"Skipping dataset {dataset} due to JSON loading error")
        continue

    # Index every reference sample by id. Splits are read in the order
    # answerable, ambiguous, unanswerable; a later split overwrites an
    # earlier one on an id collision. Missing split files are skipped.
    base_data = {}
    for split_suffix in ("test", "ambig_test", "unans_test"):
        split_path = f"data_final/{dataset}_{split_suffix}.json"
        if not os.path.exists(split_path):
            continue
        with open(split_path, 'r', encoding='utf-8') as f:
            split_samples = json.load(f)
        for sample in split_samples:
            base_data[sample['id']] = sample

    for instance in tqdm(correction_list, desc=f"Processing {args.corrector_name} on {dataset}"):
        instance_id = instance['id']

        # When the id contains underscores, the second field is used as the lookup key.
        # NOTE(review): assumes ids are strings shaped like "<prefix>_<id>[_...]" -
        # confirm against the correction files.
        if '_' in instance_id:
            instance_id_clean = instance_id.split("_")[1]
        else:
            instance_id_clean = instance_id

        # The question type comes from the reference sample's 'note' field.
        question_type = None
        if instance_id_clean in base_data:
            question_type = base_data[instance_id_clean].get('note', None)

        # Skip unknown samples and answerable ones (easy/medium/hard).
        if not question_type or question_type in ['easy', 'medium', 'hard']:
            continue

        # Types without a counter are neither counted nor scored.
        if question_type not in recall_metrics:
            continue

        counts = recall_metrics[question_type]
        counts[1] += 1  # total instances of this type

        # An abstention is correct when the output names the right family
        # ('ambiguous' / 'unanswerable') or any specific type of that family.
        final_sql = instance.get('final_sql', '')
        if final_sql in _ACCEPTED_OUTPUTS.get(question_type, ()):
            counts[0] += 1  # correct abstentions

# Helper function to format percentage
def format_metric(num, den, no_frac=False):
    """Format ``num / den`` as a percentage rounded to one decimal place.

    Args:
        num: Numerator.
        den: Denominator; anything not greater than zero yields ``"NaN"``.
        no_frac: When True, return only the percentage; otherwise append
            the raw fraction as `` (num/den)``.

    Returns:
        The formatted string, or ``"NaN"`` when ``den`` is not positive.
    """
    if not den > 0:
        return "NaN"

    percentage = f"{round(num / den * 100, 1)}"
    if no_frac:
        return percentage
    return f"{percentage} ({num}/{den})"

def save_table6_results_to_excel(model_name, corrector_name, recall_metrics, excel_file_path="table6_results.xlsx"):
    """Append one row of per-type recall figures to an Excel workbook.

    Args:
        model_name: Value written to the ``Model`` column.
        corrector_name: Internal corrector name; mapped to a display name
            when known, otherwise written unchanged.
        recall_metrics: Mapping of question type to ``(correct, total)``.
            Types whose total is not positive add no columns to the row.
        excel_file_path: Workbook to extend; created when it does not exist.
    """
    # Internal corrector identifiers -> names shown in the table.
    display_names = {
        'single_turn_two_stage': 'Two-Stage',
        'single_turn_correction': 'Single-Turn',
        'verifier_correction': 'Single-Turn-Veri',
        'multi_turn_correction': 'Multi-Turn-SelfRef',
        'single_turn_correction_exp': 'Single-Turn-Cls',
        'verifier_correction_exp': 'Single-Turn-Veri-Cls',
        'multi_turn_correction_exp': 'Multi-Turn-SelfRef-Cls'
    }
    shown_corrector = display_names.get(corrector_name, corrector_name)

    # Start from the rows already in the workbook, if there is one.
    try:
        previous_rows = pd.read_excel(excel_file_path)
    except FileNotFoundError:
        previous_rows = pd.DataFrame()

    # Question type -> column-name abbreviation.
    abbreviations = {
        "vague-question": "VQ",
        "vague-word": "VW",
        "ambiguous-reference": "AR",
        "small-talk": "ST",
        "out-of-scope": "OS",
        "missing-column": "MC",
        "infeasible-faq": "IF"
    }

    row = {'Model': model_name, 'Corrector': shown_corrector}

    # Two columns per observed type: the percentage and the raw fraction.
    for question_type, (correct, total) in recall_metrics.items():
        if not total > 0:
            continue
        short_name = abbreviations.get(question_type, question_type.upper())
        row[f'{short_name} (%)'] = format_metric(correct, total, no_frac=True)
        row[f'{short_name}_Fraction'] = f"{correct}/{total}"

    combined = pd.concat([previous_rows, pd.DataFrame([row])], ignore_index=True)
    combined.to_excel(excel_file_path, index=False)

# Persist this run's figures, then report them on stdout.
save_table6_results_to_excel(args.model_name, args.corrector_name, recall_metrics)

print(f"\nRecall Results for {args.corrector_name}:")
print("Question Type Recall (correctly identifying non-answerable questions):")

# Abbreviation printed in front of each question type.
type_mapping = dict(
    [
        ("vague-question", "VQ"),
        ("vague-word", "VW"),
        ("ambiguous-reference", "AR"),
        ("small-talk", "ST"),
        ("out-of-scope", "OS"),
        ("missing-column", "MC"),
        ("infeasible-faq", "IF"),
    ]
)

# One line per question type that occurred at least once.
for question_type, (correct, total) in recall_metrics.items():
    if not total > 0:
        continue
    abbrev = type_mapping.get(question_type, question_type.upper())
    print(f"{abbrev} ({question_type}): {format_metric(correct, total)}")

# Aggregate counts per family. They feed an overall-recall report
# (ambiguous / unanswerable / all non-answerable) that is currently disabled.
ambiguous_types = ("vague-question", "vague-word", "ambiguous-reference")
unanswerable_types = ("small-talk", "out-of-scope", "missing-column", "infeasible-faq")

ambiguous_correct = sum(recall_metrics[name][0] for name in ambiguous_types)
ambiguous_total = sum(recall_metrics[name][1] for name in ambiguous_types)

unanswerable_correct = sum(recall_metrics[name][0] for name in unanswerable_types)
unanswerable_total = sum(recall_metrics[name][1] for name in unanswerable_types)

print("=" * 20)

# Percentages only, joined with " & " on a single line.
print_list = [
    format_metric(correct, total, no_frac=True)
    for correct, total in recall_metrics.values()
    if total > 0
]
print(" & ".join(print_list))

print("=" * 100)
print(f"Results saved to table6_results.xlsx with model: {args.model_name}, corrector: {args.corrector_name}")
