import json
import os
import re
import random
import copy
import argparse
from nltk.stem import WordNetLemmatizer


# Shared lemmatizer instance; convert_format_for_one_data uses it to lemmatize
# the lower-cased, space-split question tokens.
wordnet_lemmatizer = WordNetLemmatizer()

# Empty skeleton of the parsed-SQL dictionary. convert_format_for_one_data
# deep-copies it and fills in "from", "select" and "where"; the remaining
# clauses keep these empty defaults.
# NOTE(review): the layout looks like the Spider-style parsed SQL format — confirm.
sql_template = {
    "from": {
        "table_units": [],
        "conds": []
    },
    "select": [],
    "where": [],
    "groupBy": [],
    "having": [],
    "orderBy": [],
    "limit": None,
    "intersect": None,
    "union": None,
    "except": None
}

# Lookup tables indexed by the integer "agg" / condition-operator ids of the
# input records in generate_SQL. The empty string at index 0 of _agg_ops means
# "no aggregation".
_agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']
_cond_ops = ['=', '>', '<', 'OP']

# Operator vocabularies of the target format. AGG_OPS and WHERE_OPS are not
# referenced by the functions in this file.
AGG_OPS = ('none', 'max', 'min', 'count', 'sum', 'avg')
WHERE_OPS = ('not', 'between', '=', '>', '<', '>=', '<=', '!=', 'in', 'like', 'is', 'exists')
# Maps a comparison symbol to its position in WHERE_OPS. There is no entry for
# 'OP', so looking up _cond_ops[3] here raises KeyError.
WHERE_OPS_IDS = {'=': 2, ">": 3, "<": 4}

def load_json(path=(), dataset="wikisql"):
    """Load JSON list files and group their records by a schema key.

    Args:
        path: Iterable of paths to JSON files, each containing a list of
            records (dicts).
        dataset: Dataset name. For ``"wikisql"`` the grouping key is the
            ``" ||| "``-joined column names (excluding ``"*"``) of the table
            referenced by the record's ``db_id``; the tables are read from the
            hard-coded WikiSQL table files. For every other dataset the key is
            the record's ``db_id``. For ``"sparc"`` and ``"cosql"`` every
            record coming from a non-table file additionally receives an
            ``interaction_id`` of the form ``"<db_id>_<running index>"``.

    Returns:
        A dict mapping each key to the list of records sharing it, in the
        order in which they were read.

    Records are modified in place: a ``database_id`` key is renamed to
    ``db_id``.
    """

    def convert_db_key(data):
        # Normalise the alternative 'database_id' key to 'db_id'.
        if 'database_id' in data:
            data['db_id'] = data.pop('database_id')
        return data

    def read_json(file_path):
        # Use a context manager so that the handle is closed deterministically.
        with open(file_path, "r", encoding="utf-8") as fin:
            return json.load(fin)

    tables = {}
    if dataset == "wikisql":
        table_path = ["./data/original_datasets/wikisql/train.tables.json",
                      "./data/original_datasets/wikisql/dev.tables.json",
                      "./data/original_datasets/wikisql/test.tables.json"]
        for p in table_path:
            for data in read_json(p):
                data = convert_db_key(data)
                # The first definition of a table id wins.
                tables.setdefault(data["db_id"], data)

    interaction_dict = {}
    data_dict = {}
    for p in path:
        # Decide from the file name only: a parent directory that merely
        # contains "table" must not turn a data file into a table file.
        is_table_file = 'table' in os.path.basename(p)
        for data in read_json(p):
            data = convert_db_key(data)
            if dataset in ['sparc', 'cosql'] and not is_table_file:
                db_id = data['db_id']
                index = interaction_dict.get(db_id, 0)
                data['interaction_id'] = db_id + '_' + str(index)
                interaction_dict[db_id] = index + 1

            if dataset == "wikisql":
                columns = tables[data["db_id"]]["column_names"]
                key = " ||| ".join([x[-1] for x in columns if x[-1] != "*"])
            else:
                key = data["db_id"]
            data_dict.setdefault(key, []).append(data)
    return data_dict

def split_dataset(schema_dist_dict, table_dict, data_dict, split=[4, 1, 2], few_shot=False):
    """Build one [tables, train, dev, test] group per entry of schema_dist_dict.

    Each value of ``schema_dist_dict`` is a list of keys. For every key the
    matching tables (from ``table_dict``) and examples (from ``data_dict``)
    are pooled; keys missing from either dict are skipped for that dict. The
    pooled examples are shuffled with the global ``random`` state and cut
    into train/dev/test according to the ``split`` weights, the test part
    receiving whatever remains after the integer division.

    With ``few_shot`` the shuffled pool is capped at 500 examples and the
    returned tables are those of the whole ``table_dict`` whose ``db_id``
    occurs in the kept examples.
    """
    task_splits = []
    for schema_keys in schema_dist_dict.values():
        task_tables = []
        pool = []
        for schema_key in schema_keys:
            if schema_key in table_dict:
                task_tables += table_dict[schema_key]
            if schema_key in data_dict:
                pool += list(data_dict[schema_key])

        random.shuffle(pool)
        if few_shot:
            pool = pool[:500]

        # Boundaries of the three consecutive slices of the shuffled pool.
        dev_start = len(pool) * split[0] // sum(split)
        test_start = dev_start + len(pool) * split[1] // sum(split)
        train = pool[:dev_start]
        dev = pool[dev_start:test_start]
        test = pool[test_start:]

        if few_shot:
            # Keep only the tables that are referenced by a surviving example.
            used_db_ids = {example["db_id"] for example in train + dev + test}
            task_tables = [
                candidate
                for group in table_dict.values()
                for candidate in group
                if candidate["db_id"] in used_db_ids
            ]

        task_splits.append([task_tables, train, dev, test])
    return task_splits

def evaluate_zs(train, dev, test):
    """Print the share of dev and test examples whose db_id is unseen in train.

    Args:
        train: Examples (dicts with a ``db_id`` key) defining the seen
            databases.
        dev: Examples whose unseen ("zero-shot") ratio is reported first.
        test: Examples whose unseen ratio is reported second.

    An empty ``dev`` or ``test`` list is reported with a ratio of 0.0 instead
    of raising ZeroDivisionError.
    """
    seen_db_ids = {example["db_id"] for example in train}

    def unseen_ratio(examples):
        # Guard the division: an empty split has no unseen examples.
        if not examples:
            return 0.0
        unseen = sum(1 for example in examples if example["db_id"] not in seen_db_ids)
        return unseen / len(examples)

    print("dev zs: {}, test zs: {}".format(unseen_ratio(dev), unseen_ratio(test)))

# Non-greedy matchers for double-quoted and single-quoted literals; re.S lets
# a literal span several lines.
pattern_1 = re.compile(r'[\"](.*?)[\"]', re.S)
pattern_2 = re.compile(r"['](.*?)[']", re.S)


def tokenize_query(query):
    """Split a SQL string into tokens, with and without literal values.

    Args:
        query: The SQL query text.

    Returns:
        A ``(tokens, tokens_no_val)`` pair. ``tokens`` keeps the original
        casing and emits each quote character as its own token.
        ``tokens_no_val`` is lower-cased and has every quoted literal
        (double-quoted first, then single-quoted) replaced by the single
        token ``value``.

    Tokens are separated on the space character only.
    """
    # Pad punctuation so that it becomes a standalone token. The semicolon is
    # padded as itself; commas must not be rewritten into semicolons.
    for symbol in ("(", ")", ".", ",", ";"):
        query = query.replace(symbol, " " + symbol + " ")

    # Mask the quoted literals. Single-quoted literals are searched only after
    # the double-quoted ones have been replaced.
    query_no_val = query
    for quote, pattern in (("\"", pattern_1), ("'", pattern_2)):
        for literal in re.findall(pattern, query_no_val):
            query_no_val = query_no_val.replace(quote + literal + quote, "value")

    query = query.replace("\"", " \" ")
    query = query.replace("'", " ' ")
    tokens = [token for token in query.split(" ") if token != ""]
    tokens_no_val = [token.lower() for token in query_no_val.split(" ") if token != ""]

    return tokens, tokens_no_val

def convert_format_for_one_interaction(interaction):
    """Flatten one multi-turn interaction into one record per turn.

    Every record carries the turn's query (raw and tokenized through
    tokenize_query), its utterance and utterance tokens, its parsed SQL, and
    the context accumulated from all earlier turns: the previous utterances
    joined with " | " plus the matching token list with "|" separators. The
    first turn has an empty context.

    Returns the list of per-turn records in turn order.
    """
    db_id = interaction["db_id"]
    interaction_id = interaction['interaction_id']

    records = []
    history_text = ""
    history_toks = []
    for turn_index, turn in enumerate(interaction["interaction"]):
        toks, toks_no_val = tokenize_query(turn["query"])
        records.append({
            "interaction_id": interaction_id,
            "db_id": db_id,
            "query": turn["query"],
            "query_toks": toks,
            "query_toks_no_value": toks_no_val,
            "question": turn["utterance"],
            "question_toks": turn["utterance_toks"],
            # Strings are immutable, so the running text can be stored as is;
            # the token list keeps growing and therefore needs a snapshot.
            "context": history_text,
            "context_toks": copy.deepcopy(history_toks),
            "sql": turn["sql"]
        })

        # Extend the history only after the record was emitted, so a turn
        # never sees its own utterance in its context.
        if turn_index > 0:
            history_text = history_text + " | "
            history_toks.append("|")
        history_text = history_text + turn["utterance"]
        history_toks.extend(turn["utterance_toks"])
    return records

def convert_conversation_format_for_one_file(in_path, out_path):
    """Convert a JSON file of interactions into a flat list of turn records.

    Args:
        in_path: JSON file holding a list of interactions.
        out_path: Destination JSON file; may be the same as ``in_path``.

    The input is read completely and closed before the output is opened, so
    converting a file in place is safe. Both handles are closed
    deterministically, which guarantees that the output is flushed to disk.
    """
    with open(in_path, "r", encoding="utf-8") as fin:
        interactions = json.load(fin)

    new_data_list = []
    for interaction in interactions:
        new_data_list.extend(convert_format_for_one_interaction(interaction))

    with open(out_path, "w", encoding="utf-8") as fout:
        json.dump(new_data_list, fout, indent=4)

def convert_conversation_format(dataset):
    """Rewrite the train/dev/test files of every task split of *dataset* in place.

    The number of tasks is taken from the number of entries in the dataset's
    split directory, so that directory is expected to contain only the
    consecutively numbered ``task_<i>`` folders.
    """
    task_count = len(os.listdir(f'./data/task_splits/{dataset}/'))
    for task_index in range(task_count):
        task_dir = f'./data/task_splits/{dataset}/task_{task_index}'
        for mode in ('train', 'dev', 'test'):
            split_file = f'{task_dir}/{mode}.json'
            # Same path for input and output: the file is converted in place.
            convert_conversation_format_for_one_file(split_file, split_file)


def generate_SQL(data, tables):
    """Render one record's structured query as SQL text and token lists.

    ``data`` supplies ``table_id`` and ``sql`` (with ``sel``, ``agg`` and
    ``conds``); ``tables`` maps a table id to a table whose ``header`` lists
    the column names. Column names have spaces replaced by underscores and an
    empty name becomes ``NONE``. The table id is used as the table name.

    Returns ``(query, query_toks, query_toks_no_value)``. Condition values are
    wrapped in double quotes in ``query`` and ``query_toks`` and replaced by
    the token ``VALUE`` in ``query_toks_no_value``.
    """
    sql = data["sql"]
    table = tables[data["table_id"]]
    headers = ["NONE" if name == "" else name.replace(" ", "_") for name in table["header"]]

    select_column = headers[sql["sel"]]
    aggregator = _agg_ops[sql["agg"]]
    if aggregator == "":
        select_text = select_column
        select_toks = [select_column]
    else:
        select_text = aggregator + "(" + select_column + ")"
        select_toks = [aggregator, "(", select_column, ")"]

    table_id = data["table_id"]
    query = "SELECT " + select_text + " FROM " + table_id
    query_toks = ["SELECT"] + select_toks + ["FROM", table_id]
    # Up to the WHERE clause both token lists are identical.
    query_toks_no_value = list(query_toks)

    conditions = sql["conds"]
    if len(conditions) > 0:
        query_toks.append("WHERE")
        query_toks_no_value.append("WHERE")

    clauses = []
    for position, (col, op, val) in enumerate(conditions):
        column = headers[col]
        operator = _cond_ops[op]
        literal = "\"" + str(val) + "\""
        if position > 0:
            query_toks.append("AND")
            query_toks_no_value.append("AND")
        clauses.append(" ".join([column, operator, literal]))
        query_toks += [column, operator, literal]
        query_toks_no_value += [column, operator, "VALUE"]

    if clauses:
        query = query + " WHERE " + " AND ".join(clauses)

    return query, query_toks, query_toks_no_value

def convert_format_for_one_data(data, tables):
    """Convert one record (``table_id``, ``question``, ``sql``) to the target format.

    The result holds the rendered query and its token lists (from
    generate_SQL), the question with its lemmatized lower-cased tokens, and a
    parsed SQL dictionary derived from sql_template in which the table is
    always table unit 0, the select clause has a single aggregated column and
    the conditions are chained with "and".
    """
    db_id = data["table_id"]
    question_text = data["question"]
    source_sql = data["sql"]

    query, toks, toks_no_value = generate_SQL(data, tables)

    lemmas = [
        wordnet_lemmatizer.lemmatize(word.lower())
        for word in question_text.split(' ')
        if word != ""
    ]
    converted = {
        "db_id": db_id,
        "query": query,
        "query_toks": list(toks),
        "query_toks_no_value": list(toks_no_value),
        "question": question_text,
        "question_toks": lemmas,
        "sql": None
    }

    parsed = copy.deepcopy(sql_template)
    parsed["from"]["table_units"] = [["table_unit", 0]]
    select_unit = [0, [0, source_sql["sel"], False], None]
    parsed["select"] = [False, [[source_sql["agg"], select_unit]]]

    where_clause = parsed["where"]
    for position, (col, op, val) in enumerate(source_sql["conds"]):
        if position > 0:
            where_clause.append("and")
        # The operator id is the position of the comparison symbol in WHERE_OPS.
        operator_id = WHERE_OPS_IDS[_cond_ops[op]]
        where_clause.append([False, operator_id, [0, [0, col, False], None], val, None])

    converted["sql"] = copy.deepcopy(parsed)
    return converted

def convert_format_for_one_table(table):
    """Convert one table (``id``, ``header``, ``types``, optional ``caption``)
    into a single-table database schema dictionary.

    The column lists start with the ``[-1, "*"]`` entry followed by one
    ``[0, name]`` entry per header column, an empty header being named
    ``NONE``. The type ``real`` is renamed to ``number``; other types are
    copied. A missing caption yields the table name ``NONE``.
    """
    db_id = table["id"]
    caption = table["caption"] if "caption" in table else "NONE"
    headers = ["NONE" if col == "" else col for col in table["header"]]
    column_types = ["number" if kind == "real" else kind for kind in table["types"]]

    return {
        # Two independent lists, so later edits of one do not leak into the other.
        "column_names": [[-1, "*"]] + [[0, name] for name in headers],
        "column_names_original": [[-1, "*"]] + [[0, name] for name in headers],
        "column_types": column_types,
        "db_id": db_id,
        "foreign_keys": [],
        "primary_keys": [1],
        "table_names": [caption],
        "table_names_original": []
    }

def convert_wikisql_format():
    """Convert the WikiSQL train/dev/test JSONL files to JSON.

    First the three example files are converted (each together with its
    table file), then the three table files themselves.
    """
    root = "./data/original_datasets/wikisql/"
    split_names = ("train", "dev", "test")

    for split_name in split_names:
        convert_wikisql_format_for_one_file(root + split_name + ".jsonl",
                                            root + split_name + ".tables.jsonl",
                                            root + split_name + ".json")

    for split_name in split_names:
        convert_wikisql_table_format(root + split_name + ".tables.jsonl",
                                     root + split_name + ".tables.json")

def convert_wikisql_format_for_one_file(in_path, table_path, out_path):
    """Convert a JSONL file of examples into a JSON list in the target format.

    Args:
        in_path: JSONL file with one example per line.
        table_path: JSONL file with one table per line; tables are indexed by
            their ``id``.
        out_path: Destination JSON file.

    Blank lines in either JSONL file are skipped. All files are handled as
    UTF-8 and closed deterministically, so the output is fully flushed when
    the function returns.
    """
    tables = {}
    with open(table_path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            table = json.loads(line)
            tables[table["id"]] = table

    new_data_list = []
    with open(in_path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            new_data_list.append(convert_format_for_one_data(json.loads(line), tables))

    with open(out_path, "w", encoding="utf-8") as fout:
        json.dump(new_data_list, fout, indent=4)

def convert_wikisql_table_format(in_path, out_path):
    """Convert a JSONL file of tables into a JSON list of schema dictionaries.

    Args:
        in_path: JSONL file with one table per line.
        out_path: Destination JSON file.

    Blank lines are skipped. Both files are handled as UTF-8 and closed
    deterministically, so the output is fully flushed when the function
    returns.
    """
    new_table_list = []
    with open(in_path, "r", encoding="utf-8") as fin:
        for line in fin:
            if not line.strip():
                continue
            new_table_list.append(convert_format_for_one_table(json.loads(line)))

    with open(out_path, "w", encoding="utf-8") as fout:
        json.dump(new_table_list, fout, indent=4)


if __name__ == "__main__":
    # Script entry point: choose the input files of the requested dataset,
    # group the examples by schema, cut them into task splits and print the
    # size of every split.

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dataset', type=str, default="spider")
    args = arg_parser.parse_args()

    # Fixed seed so that the shuffling inside split_dataset is reproducible.
    random.seed(2023)
    dataset = args.dataset

    if dataset == "spider":
        data_path = ["./data/original_datasets/spider/train_others.json",
                     "./data/original_datasets/spider/train_spider.json",
                     "./data/original_datasets/spider/dev.json"]
        table_path = ["./data/original_datasets/spider/tables.json"]
        schema_dist_path = "./data/schema_distribution/spider_schema.json"
        out_path = "./data/task_splits/spider"
        few_shot = False

    elif dataset == "sparc":
        data_path = ["./data/original_datasets/sparc/train.json",
                     "./data/original_datasets/sparc/dev.json"]
        table_path = ["./data/original_datasets/sparc/tables.json"]
        schema_dist_path = "./data/schema_distribution/sparc_schema.json"
        out_path = "./data/task_splits/sparc"
        few_shot = False

    elif dataset == "cosql":
        data_path = ["./data/original_datasets/cosql_dataset/sql_state_tracking/cosql_train.json",
                     "./data/original_datasets/cosql_dataset/sql_state_tracking/cosql_dev.json"]
        table_path = ["./data/original_datasets/cosql_dataset/tables.json"]
        schema_dist_path = "./data/schema_distribution/cosql_schema.json"
        out_path = "./data/task_splits/cosql"
        few_shot = False

    elif dataset == "wikisql":
        # The JSONL sources must be converted to JSON before they can be loaded.
        convert_wikisql_format()
        data_path = ["./data/original_datasets/wikisql/train.json",
                     "./data/original_datasets/wikisql/dev.json",
                     "./data/original_datasets/wikisql/test.json"]
        table_path = ["./data/original_datasets/wikisql/train.tables.json",
                      "./data/original_datasets/wikisql/dev.tables.json",
                      "./data/original_datasets/wikisql/test.tables.json"]
        schema_dist_path = "./data/schema_distribution/wikisql_schema.json"
        out_path = "./data/task_splits/wikisql"
        few_shot = True

    else:
        raise NotImplementedError("Not implement this dataset. ")

    # Close the schema-distribution file deterministically.
    with open(schema_dist_path, "r", encoding="utf-8") as schema_file:
        schema_dist_dict = json.load(schema_file)

    data_dict = load_json(data_path, dataset)
    table_dict = load_json(table_path, dataset)

    task_splits = split_dataset(schema_dist_dict, table_dict, data_dict, split=[4, 1, 2], few_shot=few_shot)

    os.makedirs(out_path, exist_ok=True)

    n_table = []
    n_train = []
    n_dev = []
    n_test = []

    for i, (table, train, dev, test) in enumerate(task_splits):
        print("Task {}: Table {}, Train {}, Dev {}, Test {}.".format(i, len(table), len(train), len(dev), len(test)))
        n_table.append(len(table))
        n_train.append(len(train))
        n_dev.append(len(dev))
        n_test.append(len(test))
        # evaluate_zs(train, dev, test)
        # print()
        task_file_path = os.path.join(out_path, "task_{}".format(str(i)))
        # exist_ok makes a separate existence check unnecessary.
        os.makedirs(task_file_path, exist_ok=True)
        # NOTE: writing the split files is disabled; this run only creates the
        # task directories and prints the statistics.
        # json.dump(table, open(os.path.join(task_file_path, "tables.json"), "w", encoding="utf-8"), indent=4, ensure_ascii=False)
        # json.dump(train, open(os.path.join(task_file_path, "train.json"), "w", encoding="utf-8"), indent=4, ensure_ascii=False)
        # json.dump(dev, open(os.path.join(task_file_path, "dev.json"), "w", encoding="utf-8"), indent=4, ensure_ascii=False)
        # json.dump(test, open(os.path.join(task_file_path, "test.json"), "w", encoding="utf-8"), indent=4, ensure_ascii=False)

    # One line of "(task number, count) " pairs per statistic, each followed
    # by an empty line.
    for counts in (n_table, n_train, n_dev, n_test):
        print(''.join([f'({idx + 1}, {x}) ' for idx, x in enumerate(counts)]))
        print()

    if dataset in ['sparc', 'cosql']:
        # Conversational datasets: flatten the interactions of the split files
        # that already exist on disk into one record per turn.
        convert_conversation_format(dataset)

