import os
import sys

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from tqdm import tqdm
import torch
import time
import json
import random
import openai
import argparse
import tiktoken
import math
import numpy as np
sys.path.append('..')
from utils.evaluation import evaluate


# Token budget for an in-context prompt, measured with the agent's tokenizer:
# make_prompt_in_context leaves out any demonstration that would push the
# assembled prompt above this many tokens.
MAX_PROMPT_LEN = 3000

class ContinualAgentForGPT(object):
    def __init__(self, args):
        """Keep the run configuration and load the tokenizer used for token counting.

        Args:
            args: configuration object stored unchanged on ``self.args``. Methods
                of this class read ``seed``, ``dataset``, ``result_path``,
                ``prompt_type`` and ``n_context`` from it.
        """
        self.args = args
        # Tokenizer for the same model name that get_gpt_response sends requests to.
        self.tokenizer = tiktoken.encoding_for_model('text-davinci-003')

    def load_continual_tasks(self, task_path, task_perm=None, combine_K=1, few_shot=50):
        random.seed(self.args.seed)

        if any(dataset in task_path for dataset in ['spider', 'cosql', 'sparc']):
            task_0 = [task_perm[:combine_K]]
            other_tasks = [[x] for x in task_perm[combine_K:]]
            random.shuffle(other_tasks)
        elif any(dataset in task_path for dataset in ['combine1', 'combine2']):
            task_0 = [task_perm[:combine_K]]
            other_tasks_1 = [[x] for i, x in enumerate(task_perm[combine_K:]) if i % 2 == 0]
            other_tasks_2 = [[x] for i, x in enumerate(task_perm[combine_K:]) if i % 2 == 1]
            random.shuffle(other_tasks_1)
            random.shuffle(other_tasks_2)
            other_tasks = [other_tasks_1[i // 2] if i % 2 == 0 else other_tasks_2[i // 2] for i in
                           range(len(task_perm[combine_K:]))]
        else:
            raise NotImplementedError("No such dataset !")

        task_split = task_0 + other_tasks
        print("Loading from datasets ...")
        print("Shot Number: ", few_shot)
        print("Task Order: ", task_split)

        self.tasks = []
        whole_test_data = []

        def get_raw_dataset(task_ids):
            train_data = []
            dev_data = []
            test_data = []
            table_data = []
            for task_id in task_ids:

                table_data += json.load(open(task_path.format(task_id, "tables.json")))
                train_data += json.load(open(task_path.format(task_id, "train.json")))
                dev_data += json.load(open(task_path.format(task_id, "dev.json")))
                test_data += json.load(open(task_path.format(task_id, "test.json")))

            # random.shuffle(train_data)
            # random.shuffle(dev_data)
            # random.shuffle(test_data)

            return train_data, dev_data, test_data, table_data

        for i, task_ids in enumerate(task_split):
            train_data, dev_data, test_data, table_data = get_raw_dataset(task_ids)

            if few_shot > 0:
                train_data = [x for x in train_data[:few_shot * len(task_ids)]]

            whole_test_data += [x for x in test_data]

            raw_dataset = {
                'train': train_data,
                'dev': dev_data,
                'test': test_data,
                'table': table_data
            }
            self.tasks.append(raw_dataset)

    def prepare(self, mode):
        """Build and save the prompts of one split for every task; return a token count.

        For each task ``i`` the prompts produced by ``make_prompt`` are written as
        JSON to ``<result_path>/<dataset>/task_<i>/<mode>/prompt.json``. The
        running token total (prompt texts plus the gold ``query`` of every
        example, both measured with ``self.tokenizer``) is printed after each
        task and once more at the end.

        Args:
            mode: split name, used both as key into each task dict and as the
                name of the output sub-directory.

        Returns:
            Total number of tokens over all tasks (0 when there are no tasks).
        """
        n_tokens = 0

        for i, task in enumerate(self.tasks):
            examples = task[mode]
            dbs = {x['db_id']: x for x in task['table']}

            # makedirs creates the intermediate task_<i> directory as well.
            result_path_task_i = os.path.join(self.args.result_path, self.args.dataset, f'task_{i}', mode)
            os.makedirs(result_path_task_i, exist_ok=True)

            prompt_path = os.path.join(result_path_task_i, 'prompt.json')

            print("Preprocessing task {} {} data".format(i, mode))
            prompts = self.make_prompt(task_id=i,
                                       examples=examples,
                                       dbs=dbs)

            # Each prompt entry is a (guid, text) pair; only the text is counted.
            n_tokens += sum(len(self.tokenizer.encode(x[-1])) for x in prompts)
            n_tokens += sum(len(self.tokenizer.encode(x['query'])) for x in examples)
            print(n_tokens)

            # Close (and thereby flush) the file deterministically.
            with open(prompt_path, 'w') as fout:
                json.dump(prompts, fout)
        print(n_tokens)
        return n_tokens

    def experience(self, mode):

        for i in range(len(self.tasks)):
            if i == 0: continue

            result_path_task_i = os.path.join(self.args.result_path, f'{self.args.dataset}', f'task_{i}', mode)

            examples = self.tasks[i][mode]

            prompt_path = os.path.join(result_path_task_i, 'prompt.json')
            prompts = json.load(open(prompt_path))

            pred_sql_path = os.path.join(result_path_task_i, 'pred_sql.txt')
            gold_sql_path = os.path.join(result_path_task_i, 'gold_sql.txt')
            db_id_path = os.path.join(result_path_task_i, 'db_id.txt')

            logit_path = os.path.join(result_path_task_i, 'gpt_logits.jsonl')

            # current task
            print("Experiencing task {} {} data".format(i, mode))
            pred_sqls, gold_sqls, db_ids, pred_logits = self.predict(examples=examples,
                                                                     prompts=prompts,
                                                                     pred_sql_path=pred_sql_path,
                                                                     gold_sql_path=gold_sql_path,
                                                                     db_id_path=db_id_path,
                                                                     logit_path=logit_path)

            # self.save_sql(pred_sqls, pred_sql_path)
            # self.save_sql(gold_sqls, gold_sql_path)
            # self.save_sql(db_ids, db_id_path)
            # self.save_logits(pred_logits, logit_path)

    def evaluate(self, mode):
        """Score the stored predictions of one split and print the accuracies.

        Reads ``pred_sql.txt``, ``gold_sql.txt`` and ``db_id.txt`` from
        ``<result_path>/<dataset>/task_<i>/<mode>/`` -- the directory that
        ``experience`` writes to -- and passes them to the imported
        ``evaluate`` function. Prints ``i, ex_acc, lf_acc`` per task.

        Args:
            mode: split name, used as the name of the result sub-directory.
        """
        # NOTE(review): only the first two tasks are scored, and experience()
        # skips task 0, so task 0 has files only if something else produced
        # them -- confirm whether the [:2] limit is intentional.
        for i in range(len(self.tasks[:2])):

            # Include the dataset component so the paths match what
            # prepare()/experience() use; without it the files are not found.
            result_path_task_i = os.path.join(self.args.result_path, self.args.dataset, f'task_{i}', mode)

            pred_sql_path = os.path.join(result_path_task_i, 'pred_sql.txt')
            gold_sql_path = os.path.join(result_path_task_i, 'gold_sql.txt')
            db_id_path = os.path.join(result_path_task_i, 'db_id.txt')

            pred_sqls = self.load_sql(pred_sql_path)
            gold_sqls = self.load_sql(gold_sql_path)
            db_ids = self.load_sql(db_id_path)

            ex_acc, lf_acc, _ = evaluate(pred_sqls, gold_sqls, db_ids, "all", dataset=self.args.dataset)

            print(i, ex_acc, lf_acc)

    def make_prompt(self, task_id, examples, dbs):
        """Create one prompt per example, paired with the example's guid.

        ``self.args.prompt_type`` picks the builder: "api" uses
        ``make_prompt_with_api_doc`` and "in-context" uses
        ``make_prompt_in_context`` with ``self.args.n_context`` demonstrations.
        Any other value raises ``NotImplementedError`` on the first example.

        Returns:
            List of ``(guid, prompt)`` tuples in the order of ``examples``.
        """
        built = []
        for sample in tqdm(examples):
            kind = self.args.prompt_type
            if kind == "in-context":
                text = self.make_prompt_in_context(task_id, sample, dbs, n_context=self.args.n_context)
            elif kind == "api":
                text = self.make_prompt_with_api_doc(sample, dbs)
            else:
                raise NotImplementedError
            built.append((sample['guid'], text))
        return built

    def predict(self, examples, prompts, pred_sql_path, gold_sql_path, db_id_path, logit_path):
        """Query the model in batches of 10 and stream the results to disk.

        One line per example is written to each output file: the extracted SQL,
        the gold ``query``, the ``db_id`` and the JSON-encoded top log-probs.
        All files are written as UTF-8, matching ``load_sql``.

        Args:
            examples: example dicts providing ``"query"`` and ``"db_id"``.
            prompts: one entry per example, either the prompt string itself or
                a ``(guid, prompt)`` pair as produced by ``make_prompt`` (a
                two-element list once reloaded from prompt.json).
            pred_sql_path, gold_sql_path, db_id_path, logit_path: output files.

        Returns:
            ``(pred_sqls, gold_sqls, db_ids, logits)`` holding the same values
            that were written to the files, in example order.
        """
        logits = []
        pred_sqls = []
        gold_sqls = []
        db_ids = []

        batch_size = 10
        batches = list(zip(prompts, examples))

        # Context managers close every file even if a request or write fails.
        with open(pred_sql_path, 'w', encoding='utf-8') as pred_sql_fout, \
                open(gold_sql_path, 'w', encoding='utf-8') as gold_sql_fout, \
                open(db_id_path, 'w', encoding='utf-8') as db_id_fout, \
                open(logit_path, 'w', encoding='utf-8') as logit_fout:
            for start in tqdm(range(0, len(batches), batch_size)):
                prompts_batch, examples_batch = zip(*batches[start:start + batch_size])

                # Send only the prompt text: the API expects strings, not
                # (guid, prompt) pairs.
                prompt_texts = [p if isinstance(p, str) else p[-1] for p in prompts_batch]

                response = self.get_gpt_response(prompt_texts)
                time.sleep(1)

                sqls_batch = self.extract_sql(response)
                logits_batch = self.extract_logits(response)

                for j, example in enumerate(examples_batch):
                    pred_sql_fout.write(sqls_batch[j] + '\n')
                    gold_sql_fout.write(example["query"] + '\n')
                    db_id_fout.write(example["db_id"] + '\n')
                    logit_fout.write(json.dumps(logits_batch[j]) + '\n')

                    pred_sqls.append(sqls_batch[j])
                    gold_sqls.append(example["query"])
                    db_ids.append(example["db_id"])
                    logits.append(logits_batch[j])

        return pred_sqls, gold_sqls, db_ids, logits

    def get_gpt_response(self, prompts):
        """Request greedy completions for ``prompts``, retrying until one succeeds.

        Sends the prompts to ``text-davinci-003`` with temperature 0, at most
        250 new tokens, stop sequences ``"#"`` and ``";"``, and the top-5
        log-probabilities per token. After a failed request the error is
        reported on stderr and the call is repeated after 30 seconds.

        Args:
            prompts: prompt string or list of prompt strings.

        Returns:
            The response object returned by ``openai.Completion.create``.
        """
        while True:
            try:
                return openai.Completion.create(
                    model="text-davinci-003",
                    prompt=prompts,
                    temperature=0,
                    max_tokens=250,
                    top_p=1.0,
                    frequency_penalty=0.0,
                    presence_penalty=0.0,
                    stop=["#", ";"],
                    logprobs=5
                )
            except Exception as err:
                # `Exception` rather than a bare `except` so Ctrl-C / SystemExit
                # still stop the run; print the cause so a permanent failure
                # (bad key, invalid request) is visible instead of a silent loop.
                print("OpenAI request failed ({}: {}); retrying in 30s".format(type(err).__name__, err),
                      file=sys.stderr)
                time.sleep(30)

    def extract_sql(self, response):
        sqls = []
        for i in range(len(response.choices)):
            sql = response.choices[i].text
            sql = sql.split("Q:")[0]
            sql = sql.split("DB:")[0]
            sql = sql.replace("\n", " ")
            sql = " ".join([x for x in sql.split(" ") if x not in ["", "\n"]])
            sqls.append(sql)
        return sqls

    def extract_logits(self, response):
        if response == 'none':
            return []
        logits = []
        for i in range(len(response.choices)):
            logits.append([x for x in response.choices[i].logprobs.top_logprobs])
        return logits

    def save_sql(self, sql, sql_path):
        with open(sql_path, 'w', encoding='utf-8') as fout:
            for one_sql in sql:
                fout.write(one_sql + '\n')

    def load_sql(self, sql_path):
        data = []
        with open(sql_path, 'r', encoding='utf-8') as fin:
            for line in fin:
                data.append(line.strip('\n'))
        return data

    def save_logits(self, logits, logits_path):
        with open(logits_path, 'w', encoding='utf-8') as fout:
            for one_logit in logits:
                fout.write(json.dumps(one_logit) + '\n')

    def make_prompt_with_api_doc(self, example, tables):
        table = tables[example["db_id"]]
        col_names = {}
        for table_id, col_name in table["column_names_original"]:
            if table_id < 0:
                continue
            if table_id not in col_names:
                col_names[table_id] = []
            col_names[table_id].append(col_name.lower().replace(" ", "_"))

        db_prompt = "### Postgres SQL tables, with their properties:\n#\n"
        for table_id, cols in col_names.items():
            col_prompt = "(" + ", ".join(cols) + ")"
            if self.args.dataset == "spider":
                table_prompt = "# " + table["table_names"][table_id]
            else:
                table_prompt = "# table_{}".format(table["db_id"].replace("-", "_"))
            db_prompt += table_prompt + col_prompt + "\n"
        db_prompt += "#\n"

        q_prompt = "### A query to: "
        q_prompt += example["question"]
        q_prompt += "\n"

        return db_prompt + q_prompt

    def make_db_prompt(self, db):
        col_names = {}
        for table_id, col_name in db["column_names_original"]:
            if table_id < 0:
                continue
            if table_id not in col_names:
                col_names[table_id] = []
            col_names[table_id].append(col_name.lower().replace(" ", "_"))

        total_len = 0
        db_prompt = []
        for table_id, cols in col_names.items():
            col_prompt = "(" + ", ".join(cols) + ")"
            if self.args.dataset == "spider":
                table_prompt = db["table_names"][table_id].replace(" ", "_")
            else:
                table_prompt = "table_{}".format(db["db_id"].replace("-", "_"))

            db_prompt.append(table_prompt + col_prompt)
            total_len += len(table_prompt + col_prompt)
        db_prompt = ", ".join(db_prompt)
        return db_prompt

    def make_prompt_in_context(self, task_id, example, dbs, n_context=3):
        example_pool ={x["guid"]: x for x in self.tasks[task_id]["train"]}
        prompt_examples = [example_pool[idx] for idx, _ in example["demonstration_examples"] if idx in example_pool]

        db = dbs[example["db_id"]]
        db_prompt = self.make_db_prompt(db)
        q_prompt = "DB: {}\nQ: {}\nSQL: ".format(db_prompt, example["question"])

        context = ""
        for _example in prompt_examples[:n_context]:
            _db = dbs[_example["db_id"]]
            # _db_prompt = self.make_db_prompt(_db)
            one_context = "Q: {}\nSQL: {}".format(_example["question"], _example["query"])
            if len(self.tokenizer.encode(context + one_context + "\n" + q_prompt)) <= MAX_PROMPT_LEN:
                context += one_context + "\n"

        prompt = context + q_prompt
        return prompt

def _str2bool(value):
    """Parse a command-line boolean ("true"/"false", "yes"/"no", "1"/"0", ...)."""
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def run_experience():
    """Command-line entry point: prepare prompts, query the model and/or evaluate.

    The stages are selected with ``--do_prep``, ``--do_exp`` and ``--do_eval``.
    The task permutation is chosen from the dataset name found in
    ``--task_path``. The OpenAI key is taken from ``--api_key`` or, by default,
    from the ``OPENAI_API_KEY`` environment variable.

    Raises:
        NotImplementedError: if ``--task_path`` names none of the known datasets.
    """
    arg_parser = argparse.ArgumentParser()

    arg_parser.add_argument('--task_path', type=str, default="data/task_splits/combine1/task_{}/{}")
    arg_parser.add_argument('--task_context_path', type=str, default="data/task_splits/combine1_context/task_{}/{}")
    arg_parser.add_argument('--result_path', type=str, default="gpt_results/")
    arg_parser.add_argument('--combine_K', type=int, default=5)
    arg_parser.add_argument('--seed', type=int, default=23)
    arg_parser.add_argument('--dataset', type=str, choices=["wikisql", "combine1", "combine2", "spider"], default="combine1")
    arg_parser.add_argument('--prompt_type', type=str, choices=["api", "in-context"], default='in-context')
    arg_parser.add_argument('--n_row', type=int, default=1)
    arg_parser.add_argument('--n_context', type=int, default=5)
    # type=bool would turn any non-empty string -- including "False" -- into True.
    arg_parser.add_argument('--do_prep', type=_str2bool, default=False)
    arg_parser.add_argument('--do_exp', type=_str2bool, default=True)
    arg_parser.add_argument('--do_eval', type=_str2bool, default=False)
    # SECURITY: never commit a key as the default. The key that used to be
    # embedded here must be treated as leaked and revoked.
    arg_parser.add_argument('--api_key', type=str, default=os.environ.get('OPENAI_API_KEY'))
    args = arg_parser.parse_args()

    if args.api_key:
        openai.api_key = args.api_key
    elif args.do_exp and not getattr(openai, 'api_key', None):
        # Fail fast rather than retrying unauthenticated requests forever.
        arg_parser.error("an OpenAI API key is required for --do_exp: "
                         "pass --api_key or set OPENAI_API_KEY")

    os.makedirs(os.path.join(args.result_path, args.dataset), exist_ok=True)

    if any(dataset in args.task_path for dataset in ['spider', 'cosql', 'sparc']):
        args.task_perm = [11, 5, 3, 7, 12, 2, 10, 8, 6, 4, 15, 0, 1, 13, 14, 9]
    elif 'combine1' in args.task_path:
        args.task_perm = [3, 1, 0, 4, 5, 9, 2, 10, 7, 8, 6]
    elif 'combine2' in args.task_path:
        args.task_perm = [0, 1, 2, 3, 4, 5, 11, 6, 12, 7, 13, 8, 14, 9, 15, 10]
    else:
        raise NotImplementedError("No such dataset !")

    print("Init Agent ...")
    agent = ContinualAgentForGPT(args)

    if args.do_prep:
        # Prompts are built from the "context" variant of the task files.
        agent.load_continual_tasks(args.task_context_path, args.task_perm, args.combine_K, -1)
        n_token_train = agent.prepare('train')
        n_token_dev = agent.prepare('dev')
        n_token_test = agent.prepare('test')
        print(n_token_train + n_token_dev + n_token_test)

    if args.do_exp:
        agent.load_continual_tasks(args.task_path, args.task_perm, args.combine_K, -1)
        agent.experience('train')

    if args.do_eval:
        agent.load_continual_tasks(args.task_path, args.task_perm, args.combine_K, -1)
        agent.evaluate('train')
        agent.evaluate('dev')
        agent.evaluate('test')


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_experience()



