import json
import random
import argparse
def load_data(data_path):
    """Load and return the decoded contents of a JSON file.

    The file is opened as UTF-8 (rather than the platform's default locale
    encoding) and closed as soon as parsing finishes.

    Args:
        data_path: Path to a UTF-8 encoded JSON file.

    Returns:
        Whatever JSON value the file holds (the callers in this script
        expect a list of records).
    """
    with open(data_path, encoding="utf-8") as f:
        return json.load(f)



def build_fk_map_bi_direction(entry):
    """Map every foreign-key column index to the column indices it links with.

    Each ``[src, dst]`` pair in ``entry["foreign_keys"]`` is recorded in both
    directions, so ``src -> dst`` and ``dst -> src`` are both present.
    A missing ``foreign_keys`` key yields an empty mapping.

    Args:
        entry: Schema record; only its optional ``foreign_keys`` list is read.

    Returns:
        Plain dict from column index to a list of linked column indices, in
        the order the pairs appear.
    """
    neighbours = {}
    for src, dst in entry.get("foreign_keys", []):
        # Record the forward edge first, then the reverse edge.
        for key, val in ((src, dst), (dst, src)):
            if key not in neighbours:
                neighbours[key] = []
            neighbours[key].append(val)
    return neighbours


def render_db_entry_per_table(entry):
    """Render a schema entry as one text line per table.

    Each line looks like::

        table T , columns = [ t.c ( type | primary key | foreign key -> T2.c2 ) , ... ]

    Args:
        entry: Schema record with ``table_names_original``,
            ``column_names_original`` (``[table_index, column_name]`` pairs),
            ``column_types`` (parallel to the column list), and optionally
            ``primary_keys`` / ``foreign_keys`` (column indices).

    Returns:
        The per-table lines joined with newlines.

    Note:
        Only the ``t.c ( type`` prefix of each column is lowercased; the
        ``table T`` header and foreign-key targets keep their original case.
    """
    tnames = entry["table_names_original"]
    cols = entry["column_names_original"]
    ctypes = entry["column_types"]

    # Some schema files encode a composite primary key as a nested list of
    # column indices; flatten so membership tests work for both forms.
    pk_set = set()
    for pk in entry.get("primary_keys", []):
        if isinstance(pk, (list, tuple)):
            pk_set.update(pk)
        else:
            pk_set.add(pk)

    fk_map = build_fk_map_bi_direction(entry)

    # Group column indices by owning table.
    table_to_col_idxs = {t: [] for t in tnames}
    for idx, (tidx, _cname) in enumerate(cols):
        # Skip the [-1, "*"] wildcard column: a negative index would wrap
        # around and attach "*" to the last table.
        if tidx < 0:
            continue
        table_to_col_idxs[tnames[tidx]].append(idx)

    lines = []
    for t in tnames:
        items = []
        for i in table_to_col_idxs[t]:
            _tidx, cname = cols[i]
            # Optional: map Spider's 'number' type to 'integer' here.
            piece = f"{t}.{cname} ( {ctypes[i]}".lower()
            if i in pk_set:
                piece += " | primary key"
            # dict.fromkeys drops duplicate targets (e.g. when both [a, b] and
            # [b, a] are listed) while keeping their first-seen order.
            for tgt in dict.fromkeys(fk_map.get(i, [])):
                ttidx, tcname = cols[tgt]
                if ttidx >= 0:
                    piece += f" | foreign key -> {tnames[ttidx]}.{tcname}"
            piece += " )"
            items.append(piece)

        lines.append(f"table {t} , columns = [ " + " , ".join(items) + " ]")

    return "\n".join(lines)



if __name__ == "__main__":
    # Build text-to-SQL instruction records: each sample gets the rendered
    # schema of its database prepended to the question.
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', type=str, required=True,
                        help='JSON list of samples with db_id, question, query')
    parser.add_argument('--others', type=str, default='',
                        help='optional second sample file appended to --source')
    parser.add_argument('--table', dest='table', type=str, required=True,
                        help='JSON list of schema entries keyed by db_id')
    parser.add_argument('--save', type=str, required=True,
                        help='output JSON path')
    parser.add_argument('--type', choices=['train', 'test'], default='train',
                        help='train: shuffle the output records once')
    parser.add_argument('--seed', type=int, default=None,
                        help='seed for the train shuffle (default: unseeded)')

    args = parser.parse_args()

    # loading training dataset
    datas = load_data(args.source)
    if args.others:
        datas = datas + load_data(args.others)

    print(len(datas), "dataset load sucess")

    tables = load_data(args.table)
    print(len(tables), "schema load sucess")

    # Index schemas by db_id once instead of scanning the list per sample.
    # setdefault keeps the first entry per db_id, matching the old `[...][0]`.
    tables_by_db = {}
    for table_entry in tables:
        tables_by_db.setdefault(table_entry["db_id"], table_entry)

    # Rendered schema text per db_id; many samples share a database.
    schema_cache = {}

    instruction_text_sql = "Given the database schema and a question in natural language, generate the corresponding SQL query."
    # A reverse (SQL -> text) variant used the instruction
    # "Given the database schema and an SQL query, describe in natural language
    # what the query returns." with input schema + ';\n sql: ' + query.

    # link schema to data
    results = []
    for sample in datas:
        db_id = sample["db_id"]
        question = sample["question"]
        query = sample["query"]

        if db_id not in schema_cache:
            if db_id not in tables_by_db:
                raise SystemExit(f"no schema found for db_id {db_id!r} in {args.table}")
            schema_cache[db_id] = 'database schema :\n' + render_db_entry_per_table(tables_by_db[db_id])
        schema_text = schema_cache[db_id]

        results.append({"instruction": instruction_text_sql, "input": schema_text + ';\n question: ' + question, "output": query})

    # Shuffle once, after all records are built (shuffling inside the loop
    # was quadratic).
    if args.type == 'train':
        random.Random(args.seed).shuffle(results)

    with open(args.save, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(len(results), "dataset save sucess")
    print("✅ save in " + args.save)

