import json
from copy import deepcopy
from functools import lru_cache

from sqlglot import parse_one, exp

# Path stem shared by the input file and both output files,
# e.g. './examples/output_qwen3-235b'.
FILENAME = r'file_name'

# The three files all derive from the same stem.
INPUT_FILE = f'{FILENAME}.json'
OUTPUT_GOLD = f'{FILENAME}-gold.txt'
OUTPUT_PRED = f'{FILENAME}-pred.txt'

# Key (or dotted path) under which each record stores its predicted SQL.
field = 'infer'

# Load all records up front; they are consumed further down the script.
with open(INPUT_FILE, mode='r', encoding='utf-8') as f:
    data = json.loads(f.read())


def get_dict_value(r: dict, field: str):
    """Return the value stored in ``r`` under ``field``.

    Args:
        r: The (possibly nested) container to read from.
        field: Either a plain key, used as-is, or a dot-separated path such
            as ``'a.b.0'``. In a dotted path every segment that parses as an
            integer is used as an ``int`` key/index; all other segments are
            used as string keys.

    Returns:
        The value found at the end of the path.

    Raises:
        KeyError, IndexError, TypeError: Propagated from the underlying
            subscript operation when a segment cannot be resolved.
    """
    if '.' not in field:
        # A plain key is never converted to an int.
        return r[field]

    current = r
    for part in field.split('.'):
        try:
            key = int(part)
        except ValueError:
            # Not a number: keep the segment as a string key.
            key = part
        current = current[key]
    return current


def set_dict_value(r: dict, field: str, value):
    """Store ``value`` in ``r`` under ``field``, creating missing dicts on the way.

    Args:
        r: The (possibly nested) container to modify in place.
        field: Either a plain key, used as-is, or a dot-separated path such
            as ``'a.b.0'``. In a dotted path every segment that parses as an
            integer is used as an ``int`` key/index; all other segments are
            used as string keys.
        value: The value to store at the end of the path.

    Raises:
        IndexError: If a path segment indexes past the end of a list.
    """
    if '.' not in field:
        # A plain key is never converted to an int.
        r[field] = value
        return

    def as_key(part):
        # Numeric-looking segments address list indices / int keys.
        try:
            return int(part)
        except ValueError:
            return part

    *parents, last = field.split('.')
    current = r
    for part in parents:
        key = as_key(part)
        # Only auto-create children inside non-sequence containers. For a
        # list, ``key not in current`` would test *value* membership, not
        # index presence, and an existing element would be replaced by an
        # empty dict (dropping its other keys).
        if not isinstance(current, (list, tuple)) and key not in current:
            current[key] = {}
        current = current[key]
    current[as_key(last)] = value

def _normalize_name(name):
    """Lower-case and strip an identifier so lookups are case/space insensitive."""
    return name.lower().strip()


@lru_cache(maxsize=None)
def _load_reverse_mappings():
    """Read the Spider-Ent rename files once and invert them (new -> old).

    The table file is read as ``{db: {old_table: new_table}}`` and the column
    file as ``{db: {table: {old_column: new_column}}}``. Every database,
    table and column name is normalized with ``_normalize_name`` so that the
    normalized arguments of ``tc_new2old`` match regardless of the casing
    used in the JSON files.

    Returns:
        A ``(tables, columns_by_table, columns_flat)`` tuple where
        ``tables[db][new_table]`` is the old table name,
        ``columns_by_table[db][table][new_column]`` is the old column name and
        ``columns_flat[db][new_column]`` merges all tables of a database
        (when several tables share a new column name, the table listed last
        in the file wins).
    """
    with open('../../Spider-Ent/table_name_mappings.json', 'r', encoding='utf-8') as f:
        raw_tables = json.load(f)
    with open('../../Spider-Ent/column_name_mappings.json', 'r', encoding='utf-8') as f:
        raw_columns = json.load(f)

    tables = {
        _normalize_name(db): {
            _normalize_name(new): _normalize_name(old)
            for old, new in mapping.items()
        }
        for db, mapping in raw_tables.items()
    }
    columns_by_table = {
        _normalize_name(db): {
            _normalize_name(table): {
                _normalize_name(new): _normalize_name(old)
                for old, new in col_mapping.items()
            }
            for table, col_mapping in per_table.items()
        }
        for db, per_table in raw_columns.items()
    }
    # Table-agnostic view, used when the table of a column is unknown.
    columns_flat = {}
    for db, per_table in columns_by_table.items():
        merged = {}
        for cols in per_table.values():
            merged.update(cols)
        columns_flat[db] = merged
    return tables, columns_by_table, columns_flat


def tc_new2old(db_name,new_tbl,new_col=None):
    """Map a renamed table and/or column back to its original name.

    Args:
        db_name: Database identifier; matched case-insensitively.
        new_tbl: Renamed table name, or a falsy value when only a column
            should be resolved.
        new_col: Renamed column name, or a falsy value when only a table
            should be resolved.

    Returns:
        ``(old_tbl, old_col)``. Each element is ``None`` when no mapping was
        found (or the corresponding argument was not given). When the table
        resolves, the column is looked up inside that table only; otherwise
        it is looked up across all tables of the database.
    """
    db_name = db_name.lower().strip()
    new_tbl = new_tbl.lower().strip() if new_tbl else None
    new_col = new_col.lower().strip() if new_col else None

    # The mapping files are parsed once and reused by every call.
    tables, columns_by_table, columns_flat = _load_reverse_mappings()

    old_tbl = tables.get(db_name, {}).get(new_tbl)
    old_col = None
    if new_col:
        if old_tbl:
            old_col = columns_by_table.get(db_name, {}).get(old_tbl, {}).get(new_col)
        else:
            old_col = columns_flat.get(db_name, {}).get(new_col)
    return old_tbl, old_col

# Convert the tables and columns in the inferred SQL back to the original ones in the Spider database.
def convert_sql(db_name: str, new_sql: str) -> str:
    """Rewrite renamed table/column identifiers in ``new_sql`` to their old names.

    The query is parsed with the SQLite dialect, every table and column
    identifier that ``tc_new2old`` can map is replaced, and the tree is
    rendered back to SQL.

    Columns are resolved without table context (``tc_new2old`` is called with
    no table), so a column name is mapped using the database-wide lookup.

    Args:
        db_name: Database identifier passed through to ``tc_new2old``.
        new_sql: SQL text using the renamed schema.

    Returns:
        The rewritten SQL, or ``new_sql`` unchanged when it cannot be parsed.
    """
    try:
        tree = parse_one(new_sql, dialect='sqlite')
    except Exception:
        # Best effort: an unparsable prediction is passed through as-is.
        # ``Exception`` (not a bare except) so Ctrl-C / SystemExit still work.
        return new_sql

    # Snapshot the nodes first so the rewrite does not depend on how the
    # tree walker behaves while identifiers are being replaced.
    for table in list(tree.find_all(exp.Table)):
        old_tbl, _ = tc_new2old(db_name, table.name)
        if old_tbl:
            table.set('this', exp.to_identifier(old_tbl))
    for column in list(tree.find_all(exp.Column)):
        _, old_col = tc_new2old(db_name, None, column.name)
        if old_col:
            column.set('this', exp.to_identifier(old_col))
    return tree.sql(dialect='sqlite')

# Reference records; each prediction in ``data`` is paired with the record at
# the same position.
with open('../../Spider-Ent/Spider-Ent.json', encoding='utf-8') as f:
    dev_data = json.load(f)

# zip() would silently drop the tail of the longer list and produce
# misaligned or truncated output files, so fail loudly instead.
if len(data) != len(dev_data):
    raise ValueError(
        f'{INPUT_FILE} has {len(data)} records but the reference set has '
        f'{len(dev_data)}'
    )


def _one_line(sql):
    """Flatten ``sql`` so it cannot break the one-record-per-line, tab-separated output."""
    return sql.replace('\r', ' ').replace('\n', ' ').replace('\t', ' ').strip()


with open(OUTPUT_GOLD, 'w', encoding='utf-8') as fg, \
        open(OUTPUT_PRED, 'w', encoding='utf-8') as fp:
    for idx, (r, dev_r) in enumerate(zip(data, dev_data)):
        if r is None:
            # Missing prediction: build a placeholder from the reference
            # record so the two output files stay aligned line by line.
            r = deepcopy(dev_r)
            r['output'] = r['query']
            # The reference record is only known to carry 'db_id'; make sure
            # the 'db' key read below exists.
            r.setdefault('db', r['db_id'])
            set_dict_value(r, field, 'SELECT')

        # Explicit checks instead of ``assert`` so they still run under -O.
        if r['question'] != dev_r['question']:
            raise ValueError(f'record {idx}: question does not match the reference set')
        if r['output'] != dev_r['query']:
            raise ValueError(f'record {idx}: gold query does not match the reference set')
        if r['db'] != dev_r['db_id']:
            raise ValueError(f'record {idx}: database does not match the reference set')

        db = r['db']
        gold = _one_line(r['output'])
        raw_pred = get_dict_value(r, field)
        if raw_pred:
            pred = convert_sql(db, _one_line(raw_pred))
        else:
            # Empty prediction: emit a minimal placeholder statement.
            pred = 'SELECT;'
        fg.write(gold + '\t' + db + '\n')
        fp.write(pred + '\t' + db + '\n')