import os
import json
from collections import defaultdict
import argparse
import pandas as pd
from utils.utils import SQLEvaluator
from sqlglot import parse_one, exp
from tqdm import tqdm


def get_gold_columns_only(sql: str, dataset: str, data_dir: str):
    """Return the schema columns of ``dataset`` that ``sql`` references.

    Despite the name, this is applied to predicted and gold queries alike.

    Args:
        sql: SQL text, parsed with sqlglot's ``sqlite`` dialect.
        dataset: ``db_id`` used to select the schema inside ``tables.json``.
        data_dir: Directory that contains ``tables.json``.

    Returns:
        A sorted list of distinct column names (table qualifiers dropped)
        that appear in ``sql``, exist in the dataset schema, and do not end
        with the lowercase suffix ``"id"``. An empty list is returned when
        the SQL cannot be parsed or walked.

    Raises:
        ValueError: If ``tables.json`` cannot be opened or decoded; the
            underlying error is attached as ``__cause__``.
    """
    table_path = f"{data_dir}/tables.json"
    try:
        with open(table_path, "r", encoding="utf-8") as f:
            tables = json.load(f)
    except Exception as e:
        raise ValueError(f"Error in loading tables.json: {e}") from e

    # A set gives O(1) membership tests below.
    # presumably each entry is [table_index, column_name] -- verify against tables.json
    schema_columns = {
        column[1]
        for db in tables
        if db['db_id'] == dataset
        for column in db['column_names_original']
    }

    try:
        # error_level='ignore' makes the parser best-effort; anything that
        # still goes wrong (including a None parse result) is treated as
        # "no columns found" rather than aborting the evaluation.
        parsed = parse_one(sql, read='sqlite', error_level='ignore')
        referenced = {col.name for col in parsed.find_all(exp.Column)}
    except Exception:
        return []

    # Sorted so the result is deterministic; callers only use it as a set.
    return sorted(
        name
        for name in referenced
        if not name.endswith('id') and name in schema_columns
    )
    
def parse_comma_separated_list(value):
    """Parse a comma-separated string into a list of non-empty strings.

    Args:
        value: Either a comma-separated string or an already-built list.

    Returns:
        ``value`` itself when it is a list; otherwise the stripped items of
        the string. Items that are empty after stripping (for example from a
        trailing or doubled comma) are dropped so they cannot turn into
        empty dataset names.
    """
    if isinstance(value, list):
        return value
    stripped_items = (part.strip() for part in value.split(','))
    return [item for item in stripped_items if item]

def value_checker(val_dict, pred_sql):
    """Check that every expected value occurs verbatim in ``pred_sql``.

    Args:
        val_dict: A dict (its values are checked) or a list (its items are
            checked). Anything falsy means "nothing to check".
        pred_sql: SQL text that is searched with plain substring matching.

    Returns:
        ``False`` as soon as ``str(value)`` of an expected value is missing
        from ``pred_sql``; ``True`` otherwise. An unsupported ``val_dict``
        type prints a warning and also yields ``True``.
    """
    if not val_dict:
        return True

    if isinstance(val_dict, dict):
        expected_values = val_dict.values()
    elif isinstance(val_dict, list):
        expected_values = val_dict
    else:
        print(f"Warning: Unexpected val_dict type: {type(val_dict)}")
        return True

    return all(str(expected) in pred_sql for expected in expected_values)

def check_column_match(pred_sql: str, gold_sql:str, dataset: str, data_dir: str):
    """Report whether the predicted SQL references every gold column.

    Args:
        pred_sql: Predicted SQL; its whitespace is squeezed before parsing.
        gold_sql: Reference SQL, used as given.
        dataset: ``db_id`` forwarded to ``get_gold_columns_only``.
        data_dir: Directory forwarded to ``get_gold_columns_only``.

    Returns:
        A dict with ``is_correct`` (True when no gold column is missing from
        the predicted columns), ``pred_columns`` and ``gold_columns``. If the
        column extraction raises, a warning is printed, ``is_correct`` is
        False and the lists keep whatever had been collected so far.
    """
    # Ten passes of newline -> space followed by double-space -> single-space.
    for _ in range(10):
        pred_sql = pred_sql.replace("\n", " ").replace("  ", " ")

    outcome = {
        "is_correct": False,
        "pred_columns": [],
        "gold_columns": [],
    }

    try:
        outcome["pred_columns"] = get_gold_columns_only(pred_sql, dataset, data_dir)
        outcome["gold_columns"] = get_gold_columns_only(gold_sql, dataset, data_dir)
        # Extra predicted columns are allowed; only missing gold columns fail.
        outcome["is_correct"] = set(outcome["gold_columns"]).issubset(outcome["pred_columns"])
    except Exception as e:
        print(f"Warning: Error in column matching: {e}")
        outcome["is_correct"] = False

    return outcome

# Command-line interface: which corrector/model/dataset outputs to evaluate
# and where the input files live.
parser = argparse.ArgumentParser(
    description="Evaluate SQL correction results with error type analysis.",
)
parser.add_argument(
    "--model_name",
    required=True,
    type=str,
    help="Name of the model used for SQL generation.",
)
parser.add_argument(
    "--dataset_name_list",
    required=True,
    type=parse_comma_separated_list,
    help="Comma-separated list of dataset names.",
)
parser.add_argument(
    "--corrector_name",
    required=True,
    type=str,
    help="Name of the SQL corrector.",
)
parser.add_argument(
    "--correction_output_dir",
    default="./output",
    type=str,
    help="Directory for correction output JSON files.",
)
parser.add_argument(
    "--granular_evaluation",
    action="store_true",
    help="Enable granular evaluation for sub-types.",
)
parser.add_argument(
    "--data_dir",
    default="./databases",
    type=str,
    help="Directory for database files.",
)
args = parser.parse_args()


# Suffix inserted into the correction file names; switch to "_subset" to
# evaluate the subset runs instead.
is_subset = ""

# SQLEvaluator instances keyed by dataset name.
_evaluator = dict()

# Per-category counters; the loop below increments slot 0 for every counted
# instance and slot 1 when the final SQL's execution result matches gold.
ablation_dict = {
    category: [0, 0]
    for category in (
        "schema_linking",
        "join_group_by",
        "value_parsing",
        "other_local",
        "global",
    )
}

# Per-dataset pass: load the corrector output and the per-instance error-type
# annotations, put every initially-wrong prediction into at most one branch of
# the error cascade below, and accumulate counts in the module-level
# ablation_dict (slot 0 = counted instances, slot 1 = instances whose final
# execution result equals the gold answer).
for dataset in args.dataset_name_list:

    correction_path = os.path.join(args.correction_output_dir, f"{args.corrector_name}_{dataset}_{args.model_name}{is_subset}.json")
    eval_correction_path = os.path.join(args.correction_output_dir+'_eval', f"{args.corrector_name}_{dataset}_{args.model_name}{is_subset}.json")


    # NOTE(review): this helper is defined inside the loop body, so it is
    # re-created for every dataset; it does not capture any loop variable.
    def safe_json_load(file_path):
        """Load a JSON file, retrying once after stripping trailing commas.

        Returns the decoded object, or None when the file cannot be read or
        still fails to decode after the repair attempt (a message is printed
        in both failure cases).
        """
        try:
            with open(file_path, 'r') as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            print(f"Warning: JSON decode error in {file_path}: {e}")
            print("Attempting to fix the JSON file...")

            # Re-read the raw text so it can be patched in memory; the file on
            # disk is not modified.
            with open(file_path, 'r') as f:
                content = f.read()

            # Common JSON fixes
            # Remove trailing commas before closing brackets/braces
            import re
            content = re.sub(r',(\s*[}\]])', r'\1', content)

            try:
                return json.loads(content)
            except json.JSONDecodeError as e2:
                print(f"Error: Could not fix JSON file {file_path}: {e2}")
                print("Skipping this dataset...")
                return None
        except Exception as e:
            print(f"Error: Could not read file {file_path}: {e}")
            return None

    # Prefer the "<output_dir>_eval" copy when it exists.
    # NOTE(review): `corrections` is only used for the None check right below;
    # the analysis iterates `correction_list`, which is re-read further down
    # from correction_path with plain json.load, so the trailing-comma repair
    # above never benefits the data that is actually analysed.
    if os.path.exists(eval_correction_path):
        corrections = safe_json_load(eval_correction_path)
    else:
        corrections = safe_json_load(correction_path)

    if corrections is None:
        print(f"Skipping dataset {dataset} due to JSON loading error")
        continue

    # One evaluator per dataset. It is not referenced again in the code
    # visible here.
    if dataset not in _evaluator:
        _evaluator[dataset] = SQLEvaluator(data_dir=args.data_dir, dataset=dataset)


    base_data_path = f"data_final/{dataset}_test.json"

    if not os.path.exists(correction_path):
        print(f"Warning: Correction file not found: {correction_path}")
        continue

    # NOTE(review): unlike the other inputs, base_data_path is opened without
    # an existence check, so a missing file raises here.
    with open(base_data_path, 'r') as f:
        base_data = json.load(f)
    # Index the base samples by their 'id' field for the val_dict lookup below.
    base_data = {sample['id']: sample for sample in base_data}
    data_dir = args.data_dir

    with open(correction_path, 'r') as f:
        correction_list = json.load(f)

    # Error-type annotations, indexed below as [generator_name][instance_id].
    correction_data_path = f"correction-data/table_5/{dataset}/merged_results.json"

    if not os.path.exists(correction_data_path):
        print(f"Warning: Ablation data file not found: {correction_data_path}")
        continue

    with open(correction_data_path, "r") as f:
        correction_data = json.load(f)

    # Per-category instance buckets. They are filled below but not read again
    # in the code visible here.
    schema_linking_dict = defaultdict(list)
    value_parsing_dict = defaultdict(list)
    join_group_by_dict = defaultdict(list)
    other_local_dict = defaultdict(list)
    global_dict = defaultdict(list)

    for instance in tqdm(correction_list, desc=f"Processing {dataset}"):
        generator_name_instance_id = instance['id']

        # Handle different ID formats: "<generator>_<instance id>" or a bare id.
        # Only the first two underscore-separated parts are used; anything
        # after a second underscore is dropped.
        if '_' in generator_name_instance_id:
            parts = generator_name_instance_id.split("_")
            # A string containing '_' always splits into at least two parts,
            # so the else branch below cannot be reached.
            if len(parts) >= 2:
                generator_name = parts[0]
                instance_id = parts[1]
            else:
                generator_name = "unknown"
                instance_id = generator_name_instance_id
        else:
            generator_name = "unknown"
            instance_id = generator_name_instance_id

        init_pred_sql = instance['init_pred_sql']
        init_pred_sql_exec_result = instance['init_pred_sql_exec_result']
        gold_sql = instance['gold_sql']
        gold_answer = instance['gold_answer']
        sample_type = instance['sample_type']
        final_sql = instance['final_sql']
        final_sql_exec_result = instance['final_sql_exec_result']

        # correction_list is optional and not used in the analysis.
        # NOTE(review): this rebinds the name of the list being iterated. The
        # running loop is unaffected because the `for` statement evaluated its
        # iterable once, but the shadowing is easy to trip over.
        correction_list = instance.get('correction_list', [])

        # Execution accuracy is plain equality between stored results.
        final_exec_acc = gold_answer == final_sql_exec_result
        init_exec_acc = gold_answer == init_pred_sql_exec_result

        # Skip instances whose gold SQL is the literal string 'null', and
        # instances whose initial prediction already matched the gold answer:
        # only initially-wrong predictions are analysed.
        if gold_sql == 'null':
            continue
        if init_exec_acc == True:
            continue

        final_predicted_question_type = None


        # Category flags for this instance. `assigned` makes the cascade
        # below first-match-wins.
        is_value_parsing_error = False
        is_schema_linking_error = False
        is_join_group_by_error = False
        is_other_local_error = False
        is_global_error = False
        assigned = False


        # Get base instance data
        # NOTE(review): instance_id is a string here; if the base file uses
        # non-string ids this lookup never matches -- confirm.
        if instance_id in base_data:
            base_instance = base_data[instance_id]
            val_dict = base_instance.get('val_dict', {})
        else:
            # Instances not found in base data get an empty val_dict, which
            # makes value_checker pass trivially.
            val_dict = {}
        try:
            error_type = correction_data[generator_name][instance_id]['error_type']
        except KeyError:
            # Fall back to the first generator entry that contains this
            # instance id.
            found = False
            for gen_name in correction_data.keys():
                if instance_id in correction_data[gen_name]:
                    error_type = correction_data[gen_name][instance_id]['error_type']
                    generator_name = gen_name
                    found = True
                    break

            if not found:
                # Skip instances without error type data silently
                continue

        # 1) Value parsing: an expected value is missing from the initial SQL.
        # NOTE(review): when error_type starts with "global" the flag is not
        # set but `assigned` still is, so such an instance ends up counted in
        # no ablation bucket at all -- confirm this is intended.
        if not value_checker(val_dict, init_pred_sql):
            value_parsing_dict[error_type].append(instance)
            if not error_type.lower().startswith("global"):
                is_value_parsing_error = True
            assigned = True
        # 2) Join / group-by: gold uses JOIN or GROUP BY (case-insensitive
        # substring test) and the initial prediction does not.
        if not assigned and (("join" in gold_sql.lower() and not "join" in init_pred_sql.lower()) or ("group by" in gold_sql.lower() and not "group by" in init_pred_sql.lower())):
            join_group_by_dict[error_type].append(instance)
            is_join_group_by_error = True
            assigned = True
        # 3) Schema linking: a gold column is missing from the prediction.
        if not assigned and not check_column_match(init_pred_sql, gold_sql, dataset, data_dir)["is_correct"]:
            schema_linking_dict[error_type].append(instance)

            is_schema_linking_error = True
            assigned = True
        # 4) Annotated global errors. The sub-type is the text after
        # "global:"; without a colon, the text after "global" is used, falling
        # back to the string "global" when that is empty.
        if not assigned and error_type.lower().startswith("global"):
            try:
                global_type = error_type.lower().split("global:")[1].strip()
            except:
                global_type = error_type.lower().split("global")[1].strip()
                if not global_type:
                    global_type = "global"

            global_dict[global_type].append(instance)
            is_global_error = True
            # Sub-types mentioning "value" / "reference" are counted as value
            # parsing / schema linking INSTEAD of global.
            if "value" in global_type:
                is_value_parsing_error = True
                is_global_error = False
            if "reference" in global_type:
                is_schema_linking_error = True
                is_global_error = False
            assigned = True
        # 5) Annotated local errors. Note the separator here is "local: "
        # (with a trailing space), unlike "global:" above.
        # NOTE(review): sub-types mentioning "value" / "reference" set those
        # flags IN ADDITION to is_other_local_error, so such an instance is
        # counted in two buckets (the global branch reassigns instead).
        if not assigned and error_type.lower().startswith("local"):
            try:
                other_type = error_type.lower().split("local: ")[1].strip()
            except:
                other_type = error_type.lower().split("local")[1].strip()
                if not other_type:
                    other_type = "local"
            if "value" in other_type:
                is_value_parsing_error = True
            if "reference" in other_type:
                is_schema_linking_error = True
            other_local_dict[other_type].append(instance)
            is_other_local_error = True

            assigned = True

        # Count total instances for each error type. An instance that matched
        # none of the branches above sets no flag and is not counted.
        if is_schema_linking_error:
            ablation_dict["schema_linking"][0] += 1
        if is_value_parsing_error:
            ablation_dict["value_parsing"][0] += 1
        if is_join_group_by_error:
            ablation_dict["join_group_by"][0] += 1
        if is_other_local_error:
            ablation_dict["other_local"][0] += 1
        if is_global_error:
            ablation_dict["global"][0] += 1

        # Count successful corrections: the final SQL's execution result
        # equals the gold answer. (init_exec_acc is always False at this point
        # because initially-correct instances were skipped above.)
        if final_exec_acc:
            if is_schema_linking_error:
                ablation_dict["schema_linking"][1] += 1
            if is_value_parsing_error:
                ablation_dict["value_parsing"][1] += 1
            if is_join_group_by_error:
                ablation_dict["join_group_by"][1] += 1
            if is_other_local_error:
                ablation_dict["other_local"][1] += 1
            if is_global_error:
                ablation_dict["global"][1] += 1

# Helper function to format percentage
def format_metric(num, den, no_frac=False):
    """Format ``num / den`` as a percentage rounded to one decimal place.

    Args:
        num: Numerator (number of successes).
        den: Denominator (number of cases).
        no_frac: When True, return only the percentage; otherwise append
            the raw fraction as `` (num/den)``.

    Returns:
        The formatted string, or ``"NaN"`` when ``den`` is not positive.
    """
    if not den > 0:
        return "NaN"

    percentage = f"{round(num / den * 100, 1)}"
    return percentage if no_frac else f"{percentage} ({num}/{den})"

def save_table5_results_to_excel(model_name, corrector_name, ablation_dict, excel_file_path="table5_results.xlsx"):
    """Append one row of Table 5 ablation results to an Excel sheet.

    Args:
        model_name: Value written to the ``Model`` column.
        corrector_name: Internal corrector id; mapped to a display name when
            known, otherwise written unchanged.
        ablation_dict: Mapping of category name to a sequence whose slot 0
            is the total count and slot 1 the success count.
        excel_file_path: Workbook to extend; started from an empty table
            when the file does not exist yet.
    """
    display_names = {
        'single_turn_two_stage': 'Two-Stage',
        'single_turn_correction': 'Single-Turn',
        'verifier_correction': 'Single-Turn-Veri',
        'multi_turn_correction': 'Multi-Turn-SelfRef',
        'single_turn_correction_exp': 'Single-Turn-Cls',
        'verifier_correction_exp': 'Single-Turn-Veri-Cls',
        'multi_turn_correction_exp': 'Multi-Turn-SelfRef-Cls'
    }
    shown_corrector = display_names.get(corrector_name, corrector_name)

    # Load previous rows, if any, so the new row is appended to them.
    try:
        history = pd.read_excel(excel_file_path)
    except FileNotFoundError:
        history = pd.DataFrame()

    row = {'Model': model_name, 'Corrector': shown_corrector}
    # (column prefix, ablation_dict key) in the column order of the sheet.
    column_groups = (
        ('Schema_Linking', 'schema_linking'),
        ('Value_Parsing', 'value_parsing'),
        ('Join_Group_By', 'join_group_by'),
        ('Other_Local', 'other_local'),
        ('Global', 'global'),
    )
    for column_prefix, category in column_groups:
        fixed = ablation_dict[category][1]
        total = ablation_dict[category][0]
        row[f"{column_prefix} (%)"] = format_metric(fixed, total, no_frac=True)
        row[f"{column_prefix}_Fraction"] = f"{fixed}/{total}"

    combined = pd.concat([history, pd.DataFrame([row])], ignore_index=True)
    combined.to_excel(excel_file_path, index=False)

# Persist the aggregated counts as a new row in the Excel sheet.
save_table5_results_to_excel(args.model_name, args.corrector_name, ablation_dict)

# (label, ablation_dict key) in reporting order.
_report_rows = (
    ("Schema Linking", "schema_linking"),
    ("Value Parsing", "value_parsing"),
    ("Join Group By", "join_group_by"),
    ("Other Local", "other_local"),
    ("Global", "global"),
)

# Human-readable summary; categories without any instance are left out.
print(f"\nResults for {args.corrector_name}:")
for _label, _category in _report_rows:
    _total = ablation_dict[_category][0]
    if _total > 0:
        print(f"{_label}: {format_metric(ablation_dict[_category][1], _total)}")

print("=" * 20)

# LaTeX-style table row: percentages only, separated by " & ".
print(
    " & ".join(
        format_metric(ablation_dict[_key][1], ablation_dict[_key][0], no_frac=True)
        for _, _key in _report_rows
    )
)

print("=" * 100)
print(f"Results saved to table5_results.xlsx with model: {args.model_name}, corrector: {args.corrector_name}")