import json
import sys
from copy import deepcopy
from functools import lru_cache

from sqlglot import exp, parse_one


# Base path of the inference-result file, without the ".json" extension
# (e.g. ./examples/output_qwen3-235b). It may be given as the first
# command-line argument; otherwise the placeholder below is used.
FILENAME = sys.argv[1] if len(sys.argv) > 1 else r'file_name'


INPUT_FILE = FILENAME + '.json'        # inference results, one entry per dev question
OUTPUT_GOLD = FILENAME + '-gold.sql'   # one "<gold SQL>\t<db>" line per question
OUTPUT_PRED = FILENAME + '-pred.json'  # {"<index>": "<pred SQL>\t----- bird -----\t<db>"}
OUTPUT_DEV = FILENAME + '-dev.json'    # dev records in the same order as the outputs

# Record key holding the predicted SQL (dotted paths such as "a.0.b" are allowed).
field = 'infer'

with open(INPUT_FILE, encoding='utf-8') as f:
    data = json.load(f)


def get_dict_value(r: dict, field: str):
    if '.' in field:
        field = field.split('.')
        for k in field:
            try:
                k = int(k)
            except:
                pass
            r = r[k]
        return r
    else:
        return r[field]


def set_dict_value(r: dict, field: str, value):
    if '.' in field:
        field = field.split('.')
        for k in field[:-1]:
            try:
                k = int(k)
            except:
                pass
            if k not in r:
                r[k] = {}
            r = r[k]
        k = field[-1]
        try:
            k = int(k)
        except:
            pass
        r[k] = value
    else:
        r[field] = value

# Locations of the BIRD -> BIRD-Ent name mappings (relative to the working directory).
_TABLE_MAPPINGS_PATH = '../../BIRD-Ent/table_name_mappings.json'
_COLUMN_MAPPINGS_PATH = '../../BIRD-Ent/column_name_mappings.json'


def _norm_name(name):
    """Normalize an identifier (lower-case, surrounding whitespace removed) for lookup."""
    return name.lower().strip()


@lru_cache(maxsize=1)
def _load_reverse_mappings():
    """Load both mapping files once and invert them from new names to old names.

    The files map original names to BIRD-Ent names:
    ``{db: {old_tbl: new_tbl}}`` and ``{db: {table: {old_col: new_col}}}``
    (the column file's table keys are looked up with the *old* table name).

    Returns:
        ``(tables, columns, merged_columns)`` where
        ``tables[db][new_tbl] -> old_tbl``,
        ``columns[db][old_tbl][new_col] -> old_col`` and
        ``merged_columns[db][new_col] -> old_col`` (all tables of ``db``
        merged; on a column-name clash the table listed later wins).
        Every key and value is normalized with ``_norm_name``.
    """
    with open(_TABLE_MAPPINGS_PATH, 'r', encoding='utf-8') as fh:
        raw_tables = json.load(fh)
    with open(_COLUMN_MAPPINGS_PATH, 'r', encoding='utf-8') as fh:
        raw_columns = json.load(fh)

    tables = {
        _norm_name(db): {_norm_name(new): _norm_name(old) for old, new in mapping.items()}
        for db, mapping in raw_tables.items()
    }

    columns = {}
    merged_columns = {}
    for db, per_table in raw_columns.items():
        db_key = _norm_name(db)
        columns[db_key] = {}
        merged = {}
        for table, mapping in per_table.items():
            reverse = {_norm_name(new): _norm_name(old) for old, new in mapping.items()}
            columns[db_key][_norm_name(table)] = reverse
            merged.update(reverse)
        merged_columns[db_key] = merged
    return tables, columns, merged_columns


def tc_new2old(db_name, new_tbl, new_col=None):
    """Map a BIRD-Ent table and/or column name back to the original BIRD name.

    Names are compared case-insensitively and ignoring surrounding
    whitespace. The mapping files are read once per process and cached.

    Args:
        db_name: database id the names belong to.
        new_tbl: BIRD-Ent table name, or None/"" when only a column is given.
        new_col: BIRD-Ent column name (optional). When the table resolves,
            the column is looked up in that table only; otherwise it is
            looked up across all tables of the database.

    Returns:
        ``(old_tbl, old_col)``; either element is None when no mapping exists
        (``old_col`` is always None when *new_col* is not given).
    """
    tables, columns, merged_columns = _load_reverse_mappings()
    db_key = _norm_name(db_name)
    tbl_key = _norm_name(new_tbl) if new_tbl else None
    col_key = _norm_name(new_col) if new_col else None

    old_tbl = tables.get(db_key, {}).get(tbl_key)
    old_col = None
    if col_key:
        if old_tbl:
            old_col = columns.get(db_key, {}).get(old_tbl, {}).get(col_key)
        else:
            old_col = merged_columns.get(db_key, {}).get(col_key)
    return old_tbl, old_col

# Convert table and column names in the inferred SQL back to the original BIRD database names.
def convert_sql(db_name, new_sql):
    """Rename BIRD-Ent tables/columns in *new_sql* back to the original BIRD names.

    The SQL is parsed with sqlglot's SQLite dialect. Every table reference
    and every column reference that has a mapping (via ``tc_new2old``) is
    renamed in place, and the tree is rendered back to SQLite SQL. Column
    names are resolved without their table qualifier.

    Args:
        db_name: database id the query runs against.
        new_sql: SQL written against the BIRD-Ent (renamed) schema.

    Returns:
        The converted SQL, or *new_sql* unchanged if it cannot be parsed.
    """
    try:
        tree = parse_one(new_sql, dialect='sqlite')
    except Exception:
        # Best effort (e.g. sqlglot's ParseError): pass the SQL through untouched,
        # without trapping KeyboardInterrupt/SystemExit like a bare except would.
        return new_sql
    if tree is None:
        return new_sql

    for table in tree.find_all(exp.Table):
        old_tbl, _ = tc_new2old(db_name, table.name)
        if old_tbl:
            table.set('this', exp.to_identifier(old_tbl))
    for column in tree.find_all(exp.Column):
        # The qualifier may be an alias, so look the column up across the whole database.
        _, old_col = tc_new2old(db_name, None, column.name)
        if old_col:
            column.set('this', exp.to_identifier(old_col))
    return tree.sql(dialect='sqlite')


with open('../../BIRD-Ent/BIRD-Ent.json', encoding='utf-8') as f:
    dev_data = json.load(f)

# Align the inference results with the dev set record by record, back-filling
# missing (None) results from the dev record with the placeholder prediction
# 'SELECT'. Explicit checks are used instead of `assert` so that validation
# still runs under `python -O`.
if len(data) != len(dev_data):
    sys.exit(f"Error: {INPUT_FILE} has {len(data)} records but the dev set has {len(dev_data)}.")

out_data = []
out_dev_data = []
for r, dev_r in zip(data, dev_data):
    if r is None:
        r = deepcopy(dev_r)
        r['db'] = r['db_id']
        r['output'] = r['SQL']
        r['question'] += ' Hint: ' + r['evidence']
        set_dict_value(r, field, 'SELECT')
    out_data.append(r)
    out_dev_data.append(dev_r)
    # NOTE: checks on question_id order and on the question text are disabled.
    if r['db'] != dev_r['db_id']:
        sys.exit(f"Error: database mismatch for question_id {dev_r['question_id']}: "
                 f"{r['db']} != {dev_r['db_id']}.")
    if r['output'] != dev_r['SQL']:
        print(f"Warning: Mismatch in SQL for question_id {dev_r['question_id']}.")
        print(f"output SQL: {r['output']}")
        print(f"Gold SQL: {dev_r['SQL']}")
        sys.exit(1)

# Write the gold file (one "<gold SQL>\t<db>" line per question) and the
# prediction JSON ({"<index>": "<pred SQL>\t----- bird -----\t<db>"}).
# Both formats are tab/line delimited, so tabs and newlines inside SQL are
# flattened to spaces for gold as well as predictions.
with open(OUTPUT_GOLD, 'w', encoding='utf-8') as fg, \
        open(OUTPUT_PRED, 'w', encoding='utf-8') as fp:
    res = {}
    if len(out_data) != len(out_dev_data):
        sys.exit(f"Error: {len(out_data)} results but {len(out_dev_data)} dev records.")
    for i, (r, dev_r) in enumerate(zip(out_data, out_dev_data)):
        # Explicit checks (not `assert`) so they survive `python -O`.
        if r['db'] != dev_r['db_id'] or r['output'] != dev_r['SQL']:
            sys.exit(f"Error: record {i} does not match the dev set.")
        db = r['db']
        gold = r['output'].replace('\n', ' ').replace('\t', ' ').strip()
        raw_pred = get_dict_value(r, field)
        if raw_pred:
            pred = raw_pred.replace('\n', ' ').replace('\t', ' ').strip()
            pred = convert_sql(db, pred)
        else:
            # Empty prediction: emit a placeholder query.
            pred = 'SELECT'
        res[str(i)] = pred + '\t----- bird -----\t' + db
        fg.write(gold + '\t' + db + '\n')
    json.dump(res, fp, ensure_ascii=False, indent=4)

with open(OUTPUT_DEV, 'w', encoding='utf-8') as f:
    json.dump(out_dev_data, f, ensure_ascii=False, indent=4)