import json
import os
import re
import time
import shutil

import random
import numpy as np
import requests
import torch
import datasets
from datasets import load_dataset
from tqdm import tqdm
import evaluate
import tiktoken
from huggingface_hub import hf_hub_download
from collections import defaultdict

# ZeroSCROLLS task names. Every per-task list below (O_QUOTA, INSTRUCTION_IDX,
# QUESTION_IDX, KEEP_FIRST) is indexed by a task's position in TASKS, looked up
# with TASKS.index(task).
TASKS = [
    'gov_report', 'narrative_qa', 'qasper', 'qmsum', 'summ_screen_fd',
    'quality', 'squality', 'musique', 'space_digest', 'book_sum_sort',
]
# Order in which get_results prints the per-task scores.
OUT_TASKS = [
    'gov_report', 'summ_screen_fd', 'qmsum', 'squality', 'quality',
    'narrative_qa', 'qasper', 'musique', 'space_digest', 'book_sum_sort',
]
# Output-token budget per task: sent as `max_tokens` in the completion request
# and subtracted from MAX_TOKEN when inputs are truncated.
O_QUOTA = [900, 100, 100, 200, 600, 100, 800, 100, 100, 100]
# How many leading blank-line-separated segments form the instruction.
INSTRUCTION_IDX = [2, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# How many trailing blank-line-separated segments form the question.
QUESTION_IDX = [0, 2, 2, 2, 1, 3, 2, 2, 1, 1]
# 1 = keep each context segment's first line out of compression
# (see get_compressed_prompt); disabled for every task here.
KEEP_FIRST = [0] * 10
# Total token budget (input + output) used when truncating inputs.
# NOTE(review): 8182 may be a typo for 8192 — confirm before relying on it.
MAX_TOKEN = 8182
# Per-task integers; not referenced elsewhere in the code shown.
DATAS = dict(
    gov_report=20,
    narrative_qa=20,
    qasper=28,
    qmsum=20,
    summ_screen_fd=20,
    quality=21,
    squality=80,
    musique=20,
    space_digest=20,
    book_sum_sort=20,
)
# tiktoken encoding used for token counting / truncation (e.g. in get_zero_scrolls).
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

import openai

# Read the key from the environment instead of committing a secret to source.
# The original "sk-" placeholder remains the fallback, so behaviour is unchanged
# when OPENAI_API_KEY is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY", "sk-")


def _normalize_prediction(task: str, pred: str) -> str:
    r"""Clean one raw model completion before it is scored.

    Cuts the text at the first blank-line + "Question:" or blank-line +
    "Explanation:" marker, removes "<|im_end|>" and "\end{document}" and strips
    surrounding whitespace. For ``space_digest`` the answer is coerced into a
    percentage string: a leading fraction such as "0.45" becomes "45.00%";
    otherwise the first five characters (without any "%") get "%" appended.
    """
    pred = pred.split("\n\nQuestion:", 1)[0].split("\n\nExplanation:", 1)[0]
    # Raw string: "\end" is not a valid escape sequence in a normal literal.
    pred = pred.replace("<|im_end|>", "").replace(r"\end{document}", "").strip()
    if task != "space_digest":
        return pred
    if pred.startswith("0.") and "%" not in pred[:4]:
        try:
            return "{:.2f}%".format(float(pred[:4]) * 100)
        except ValueError:
            pass  # e.g. "0. The ..." is not a number; use the generic rule below
    return pred[:5].strip().replace("%", "") + "%"


def get_results(predict_path: str):
    """Score a predictions file with the ZeroSCROLLS metric and print the results.

    Args:
        predict_path: path to a JSON object whose values each carry ``"task"``,
            ``"reference"`` and ``"pred"`` keys (the format written by the
            generation loop at the bottom of this script).

    For every task in ``OUT_TASKS`` prints the task, its number of predictions
    and the raw metric output. Then prints one comma-separated line of
    ``zero_scrolls_score`` values in ``OUT_TASKS`` order, using 0 for tasks
    without predictions.
    """

    def download_metric():
        """Download the metric script from the Hub and return the path of a renamed copy."""
        zero_scrolls_metric_path = hf_hub_download(repo_id="tau/zero_scrolls", repo_type="dataset", filename="metrics/zero_scrolls.py")
        # NOTE(review): dirname and basename are concatenated without a path
        # separator, so the copy is written beside the ``metrics`` directory
        # (".../metricszero_scrolls_py.py") rather than inside it. Left
        # unchanged here — confirm whether a separator was intended.
        updated_zero_scrolls_metric_path = (
            os.path.dirname(zero_scrolls_metric_path) + os.path.basename(zero_scrolls_metric_path).replace(".", "_") + ".py"
        )
        shutil.copy(zero_scrolls_metric_path, updated_zero_scrolls_metric_path)
        return updated_zero_scrolls_metric_path

    zero_scrolls_metric_path = download_metric()
    with open(predict_path) as fp:
        preds = json.load(fp)

    # Group cleaned predictions and references by task.
    preds_g, refers_g = defaultdict(list), defaultdict(list)
    for v in preds.values():
        task = v["task"]
        preds_g[task].append(_normalize_prediction(task, v["pred"]))
        refers_g[task].append([v["reference"]])  # one reference per prediction

    zero_scrolls = []
    for task in OUT_TASKS:
        if task not in preds_g:
            zero_scrolls.append(0)  # no predictions for this task
            continue
        p, r = preds_g[task], refers_g[task]
        zero_scrolls_metric = datasets.load_metric(zero_scrolls_metric_path, task)
        results = zero_scrolls_metric.compute(predictions=p, references=r)
        print(task, len(p), results)
        zero_scrolls.append(results["zero_scrolls_score"])
    print(",".join(f"{score:.2f}" for score in zero_scrolls))


def random_select_sentence(prompt, tokenizer, target_token):
    """Randomly drop lines from ``prompt`` until it nears ``target_token`` tokens.

    The prompt is split on newlines. Lines are picked with ``random.randint``
    among those still kept that have a non-zero token count. A picked line is
    dropped unless dropping it would bring the running total to
    ``<= target_token``; at that point selection stops. The result therefore
    keeps more than ``target_token`` tokens, or is the unchanged prompt if it
    was already within budget.

    Args:
        prompt: text to shorten.
        tokenizer: callable returning an object with ``input_ids``. One id per
            line is subtracted — presumably a BOS token the tokenizer adds
            (TODO confirm for the tokenizer in use).
        target_token: token budget.

    Returns:
        The kept lines, newline-joined, in their original order.
    """
    lines = prompt.split("\n")
    n_lines = len(lines)
    line_tokens = [len(tokenizer(line).input_ids) - 1 for line in lines]
    keep = [True] * n_lines
    total = sum(line_tokens)
    # Number of lines that can still be dropped. The original `while True`
    # spun forever when this reached zero (e.g. every line is empty, or a
    # negative budget dropped all lines). The sampling sequence is unchanged.
    droppable = sum(1 for t in line_tokens if t)
    while droppable:
        idx = random.randint(0, n_lines - 1)
        if keep[idx] and line_tokens[idx]:
            if total - line_tokens[idx] <= target_token:
                break
            keep[idx] = False
            total -= line_tokens[idx]
            droppable -= 1
    return "\n".join(line for line, kept in zip(lines, keep) if kept)

def get_compressed_prompt(prompt: str, target_tokens: int, method: str, task: str = None, use_sentence_level_filter: bool = True):
    """Compress ``prompt`` towards ``target_tokens`` tokens with the chosen method.

    Args:
        prompt: full task input.
        target_tokens: token budget handed to the compressor.
        method: one of
            * ``"ours"`` — split the prompt on blank lines into instruction
              (first ``INSTRUCTION_IDX`` segments), context and question (last
              ``QUESTION_IDX`` segments); compress only the context with
              ``llm_lingua``; re-assemble instruction + context + question.
            * ``"sc"`` — call ``sc`` with a reduce ratio that grows in 0.01
              steps until the output has fewer than ``target_tokens + 50``
              tokens or the ratio reaches 0.99.
            * ``"random"`` — ``random_select_sentence``.
        task: ZeroSCROLLS task name. Only needed (and only validated) for
            ``"ours"``.
        use_sentence_level_filter: forwarded to ``llm_lingua`` for ``"ours"``.

    Returns:
        The compressed prompt string.

    Raises:
        ValueError: for an unknown ``method``, or an unknown ``task`` with ``"ours"``.
    """
    if method == "ours":
        task_idx = TASKS.index(task)
        segments = prompt.split("\n\n")
        ins_idx, que_idx, keep_first = INSTRUCTION_IDX[task_idx], QUESTION_IDX[task_idx], KEEP_FIRST[task_idx]
        que_start = len(segments) - que_idx
        ins = "\n\n".join(segments[:ins_idx])
        context = "\n\n".join(segments[ins_idx:que_start])
        que = "\n\n".join(segments[que_start:])
        if keep_first == 1:
            # Move the first context line into the instruction so it is never
            # compressed; hand the remaining segments to llm_lingua separately.
            ins = ins + "\n\n" + context.split("\n")[0]
            demonstrations = context.split("\n", 1)[1].split("\n\n")
        else:
            demonstrations = [context]

        # NOTE(review): the full target_tokens is requested even though `ins` is
        # re-attached afterwards. Also, with concate_question=True llm_lingua may
        # already include `que` in its output, so the question could appear twice
        # below — confirm both against llm_lingua.
        compressed_prompt = llm_lingua(
            demonstrations,
            "",
            que,
            target_token=target_tokens,
            use_sentence_level_filter=use_sentence_level_filter,
            condition_in_question="after_condition",
            reorder_demonstrations=False,
            dynamic_demonstration_compression_ratio=0.4,
            condition_compare=False,
            concate_question=True,
            demonstration_budget="+200",
            use_demonstrate_level_filter=True,
            use_token_level_filter=True,
            rank_method="longllmlingua",
            token_budget_ratio=1.2,
            keep_split=keep_first == 1,
            keep_sentence_number=1 if task == "space_digest" else 0,
        )["compressed_prompt"]

        if keep_first == 1:
            # Restore each segment's original (uncompressed) first line.
            compressed = compressed_prompt.split("\n\n")
            for idx, original in enumerate(demonstrations):
                head_and_rest = compressed[idx].split("\n", 1)
                head_and_rest[0] = original.split("\n", 1)[0]
                compressed[idx] = "\n".join(head_and_rest)
            return ins + "\n" + "\n\n".join(compressed) + "\n\n" + que
        return ins + "\n\n" + compressed_prompt + "\n\n" + que

    if method == "sc":
        tokens = len(tokenizer(prompt).input_ids)
        threshold = max(1 - target_tokens / tokens, 0)
        # Always run sc at least once. The original `while threshold < 0.99`
        # left `context` unbound when the start threshold was already >= 0.99.
        while True:
            context, _ = sc(prompt, reduce_ratio=threshold)
            n_tokens = len(tokenizer(context).input_ids)
            print(target_tokens, n_tokens)
            if n_tokens < target_tokens + 50:
                break
            threshold += 0.01
            if threshold >= 0.99:
                break
        return context.replace(" <s> ", "\n").replace("<s> ", "")

    if method == "random":
        return random_select_sentence(prompt, tokenizer, target_tokens)

    raise ValueError(f"Unknown compression method: {method!r}")
            
def get_zero_scrolls(data, task, return_all: bool = True):
    """Build the model input for one ZeroSCROLLS example.

    ``data["input"]`` is split on blank lines. The last ``QUESTION_IDX[task]``
    segments form the trailing question part, which is empty when that count
    is 0.

    Returns:
        ``((prompt, question), data["output"])``. With ``return_all`` the prompt
        is the untouched input. Otherwise the leading part is cut with
        ``encoding`` to ``MAX_TOKEN - O_QUOTA[task]`` minus the question's token
        count, then joined to the question by a blank line.
    """
    position = TASKS.index(task)
    n_question = QUESTION_IDX[position]
    text = data["input"]

    if n_question == 0:
        head, question = text, ""
    else:
        parts = text.split("\n\n")
        head = "\n\n".join(parts[:-n_question])
        question = "\n\n".join(parts[-n_question:])

    if return_all:
        return (text, question), data["output"]

    budget = MAX_TOKEN - O_QUOTA[position] - len(encoding.encode(question))
    head = encoding.decode(encoding.encode(head)[:budget])
    return (head + "\n\n" + question, question), data["output"]

def get_input_output_token():
    """Print each task's longest validation input and output, in tiktoken tokens."""
    for task in TASKS:
        validation = load_dataset("tau/zero_scrolls", task)["validation"]
        longest_input = max(len(encoding.encode(row["input"])) for row in validation)
        longest_output = max(len(encoding.encode(row["output"])) for row in validation)
        print(task, longest_input, longest_output)

def get_dataset(dataset):
    """Write the uncompressed ZeroSCROLLS validation prompts to ``prompt/zero_scrolls/full.json``.

    Args:
        dataset: unused (kept for signature compatibility); each task in
            ``TASKS`` is loaded from the Hub instead.

    The output is a list of ``{"id", "task", "prompt", "output"}`` records.
    """
    res = []
    for task in TASKS:
        split = load_dataset("tau/zero_scrolls", task)["validation"]
        for ii, example in enumerate(split):
            # get_zero_scrolls returns ((prompt, question), output). Storing the
            # whole first element would serialise a two-item list as "prompt",
            # which the generation loop then sends as two separate prompts.
            (prompt, _question), output = get_zero_scrolls(example, task)
            res.append({"id": ii, "task": task, "prompt": prompt, "output": output})
    os.makedirs("prompt/zero_scrolls", exist_ok=True)
    with open("prompt/zero_scrolls/full.json", "w") as fp:
        json.dump(res, fp)

def get_dataset_ours(dataset, target_token=4096):
    """Compress every ZeroSCROLLS validation prompt with method "ours" and save them.

    Args:
        dataset: unused; each task's validation split is loaded from the Hub.
        target_token: budget passed to ``get_compressed_prompt``. This argument
            used to be ignored in favour of a hard-coded 4096; the default keeps
            that value.

    Writes ``prompt/zero_scrolls/ours_{target_token}.json`` (``ours_4096.json``
    by default) as a list of ``{"id", "task", "prompt", "output"}`` records.
    """
    res = []
    for task in TASKS:
        split = load_dataset("tau/zero_scrolls", task)["validation"]
        for ii, example in tqdm(enumerate(split), total=len(split)):
            (prompt, _question), output = get_zero_scrolls(example, task)
            compressed = get_compressed_prompt(prompt, target_token, "ours", task)
            res.append({"id": ii, "task": task, "prompt": compressed, "output": output})
    os.makedirs("prompt/zero_scrolls", exist_ok=True)
    with open(f"prompt/zero_scrolls/ours_{target_token}.json", "w") as fp:
        json.dump(res, fp)

def get_dataset_zero_shot(dataset, target_token):
    """Build question-only ("zero-shot", no document) prompts and save them.

    When an example's question part is empty (e.g. ``QUESTION_IDX`` is 0 for
    the task), its first 200 tiktoken tokens are used as the prompt instead.

    Renamed from ``get_dataset_baseline``: this file defines
    ``get_dataset_baseline`` twice more further down, so under the old name this
    function was shadowed and could never be called.

    Args:
        dataset, target_token: unused; kept for parity with the sibling builders.

    Writes ``prompt/zero_scrolls/zero_shot.json`` as a list of
    ``{"id", "task", "prompt", "output"}`` records.
    """
    res = []
    for task in TASKS:
        split = load_dataset("tau/zero_scrolls", task)["validation"]
        for ii, example in tqdm(enumerate(split), total=len(split)):
            (prompt, question), output = get_zero_scrolls(example, task)
            if not question:
                question = encoding.decode(encoding.encode(prompt)[:200])
            res.append({"id": ii, "task": task, "prompt": question, "output": output})
    os.makedirs("prompt/zero_scrolls", exist_ok=True)
    with open("prompt/zero_scrolls/zero_shot.json", "w") as fp:
        json.dump(res, fp)

def get_dataset_baseline(dataset, target_token):
    """Compress each validation prompt with ``llm_lingua`` (2100-token target) and save it.

    NOTE: an identical ``get_dataset_baseline`` follows immediately after this
    one in the file and rebinds the name, so this definition is never the one
    callers reach; it is a candidate for deletion.

    Both arguments are unused. Writes ``prompt/zero_scrolls/longllmlingua_2k.json``.
    """
    compress_kwargs = dict(
        target_token=2100,
        use_sentence_level_filter=True,
        condition_in_question="after",
        reorder_demonstrations=False,
        dynamic_demonstration_compression_ratio=0,
        condition_compare=False,
        concate_question=False,
        demonstration_budget="+0",
        use_demonstrate_level_filter=True,
        use_token_level_filter=False,
        rank_method="longllmlingua",
        token_budget_ratio=1.0,
    )
    records = []
    for task in TASKS:
        validation = load_dataset("tau/zero_scrolls", task)["validation"]
        for position, example in tqdm(enumerate(validation), total=len(validation)):
            (full_prompt, question), reference = get_zero_scrolls(example, task)
            compressed = llm_lingua(full_prompt, "", question, **compress_kwargs)
            records.append({
                "id": position,
                "task": task,
                "prompt": compressed["compressed_prompt"],
                "output": reference,
            })
    with open("prompt/zero_scrolls/longllmlingua_2k.json", "w") as handle:
        json.dump(records, handle)

def get_dataset_baseline(dataset, target_token):
    """Compress each ZeroSCROLLS validation prompt with ``llm_lingua`` and save it.

    Requests a 2100-token target with ``rank_method="longllmlingua"``,
    demonstration-level filtering on and token-level filtering off.

    Args:
        dataset: unused; each task's validation split is loaded from the Hub.
        target_token: unused; the budget is hard-coded to 2100.

    Writes ``prompt/zero_scrolls/longllmlingua_2k.json`` as a list of
    ``{"id", "task", "prompt", "output"}`` records.
    """
    res = []
    for task in TASKS:
        split = load_dataset("tau/zero_scrolls", task)["validation"]
        for ii, example in tqdm(enumerate(split), total=len(split)):
            (prompt, question), output = get_zero_scrolls(example, task)
            context = llm_lingua(
                prompt,
                "",
                question,
                target_token=2100,
                use_sentence_level_filter=True,
                condition_in_question="after",
                reorder_demonstrations=False,
                dynamic_demonstration_compression_ratio=0,
                condition_compare=False,
                concate_question=False,
                demonstration_budget="+0",
                use_demonstrate_level_filter=True,
                use_token_level_filter=False,
                rank_method="longllmlingua",
                token_budget_ratio=1.0,
            )
            res.append({"id": ii, "task": task, "prompt": context["compressed_prompt"], "output": output})
    os.makedirs("prompt/zero_scrolls", exist_ok=True)
    # Close the file deterministically instead of relying on garbage collection.
    with open("prompt/zero_scrolls/longllmlingua_2k.json", "w") as fp:
        json.dump(res, fp)

def get_dataset_sc(dataset, threshold):
    """Compress every validation prompt towards 2100 tokens with ``sc`` and save them.

    The reduce ratio is ``1 - 2100 / n`` clamped to [0, 1], where ``n`` is the
    prompt length under ``tokenizer``. If ``sc`` raises for an example (this
    also covers ``n == 0``), the last 2100 tiktoken tokens of the prompt are
    kept instead.

    Args:
        dataset, threshold: unused; kept for parity with the sibling builders.

    Writes ``prompt/zero_scrolls/sc_2k.json`` as a list of
    ``{"id", "task", "prompt", "output"}`` records.
    """
    res = []
    for task in TASKS:
        split = load_dataset("tau/zero_scrolls", task)["validation"]
        for ii, example in tqdm(enumerate(split), total=len(split)):
            (prompt, _question), output = get_zero_scrolls(example, task)
            n_tokens = len(tokenizer.encode(prompt))

            try:
                y = sc(prompt, reduce_ratio=max(min(1, 1 - 2100 / n_tokens), 0))
                compressed_prompt = y[0].replace("<s>", "").strip()
            except Exception as err:
                # Best-effort fallback. A bare `except:` would also swallow
                # KeyboardInterrupt/SystemExit and hide why compression failed.
                print(f"sc failed on {task} #{ii} ({err!r}); keeping the last 2100 tokens")
                compressed_prompt = encoding.decode(encoding.encode(prompt)[-2100:])
            res.append({"id": ii, "task": task, "prompt": compressed_prompt, "output": output})
    os.makedirs("prompt/zero_scrolls", exist_ok=True)
    with open("prompt/zero_scrolls/sc_2k.json", "w") as fp:
        json.dump(res, fp)

    

def get_dataset_random(dataset, threshold):
    """Build the random line-dropping baseline (400-token budget) and save it.

    The previous version always crashed: it called ``get_zero_scrolls`` and
    ``get_compressed_prompt`` without the required ``task``. Its records also
    lacked the ``"task"``/``"output"`` keys the generation loop reads. This
    version follows the other ``get_dataset_*`` builders.

    Args:
        dataset, threshold: unused; each task's validation split is loaded
            from the Hub.

    Seeds ``random`` with 171 for reproducibility. Writes
    ``prompt/zero_scrolls/random_400.json`` as a list of
    ``{"id", "task", "prompt", "output"}`` records.
    """
    random.seed(171)
    res = []
    for task in TASKS:
        split = load_dataset("tau/zero_scrolls", task)["validation"]
        for ii, example in tqdm(enumerate(split), total=len(split)):
            (prompt, _question), output = get_zero_scrolls(example, task)
            compressed = get_compressed_prompt(prompt, 400, "random", task)
            res.append({"id": ii, "task": task, "prompt": compressed, "output": output})
    os.makedirs("prompt/zero_scrolls", exist_ok=True)
    with open("prompt/zero_scrolls/random_400.json", "w") as fp:
        json.dump(res, fp)


# Generation script: query the completion API for every prompt in
# prompt/zero_scrolls/{threshold}.json and save predictions after each example,
# so an interrupted run can be resumed.
model = "gpt-3.5-turbo"
# How many times a single example is retried when the response has no "choices".
MAX_ATTEMPTS = 5
done, res = set(), {}
# threshold = "full"
threshold = "ours_2048"
# threshold = "ours_4000_sf_full_doc_1.4_f0_e0_after_in_condiation_codellamainstruct"
with open(f"prompt/zero_scrolls/{threshold}.json") as fp:
    data = json.load(fp)

output_path = f"outputs_zero_scrolls/{threshold}.json"
done_path = output_path.replace(".json", ".do")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Resume support: `done` holds the (id, task) pairs already answered and `res`
# the predictions written so far.
if os.path.exists(done_path):
    done = torch.load(done_path)
if os.path.exists(output_path):
    with open(output_path) as fp:
        res = json.load(fp)

for demonstration in tqdm(data, total=len(data)):
    idx, prompt, task, output = [demonstration[k] for k in ["id", "prompt", "task", "output"]]
    task_idx = TASKS.index(task)
    if (idx, task) in done:
        continue

    request_data = {
        "prompt": prompt,
        "max_tokens": O_QUOTA[task_idx],
        "temperature": 0,
        "top_p": 1,
        "n": 1,
        "stream": False,
        "stop": "\r\n",
    }

    # Retry this example instead of skipping it. Previously a single response
    # without "choices" hit `continue` and the example was dropped for the run.
    for _attempt in range(MAX_ATTEMPTS):
        # NOTE(review): gpt-3.5-turbo is normally served via the chat endpoint
        # (a `messages`-based variant was used here before); confirm the
        # configured endpoint accepts it through Completion.create.
        response = openai.Completion.create(model=model, **request_data)
        torch.save(response, "debug.pt")  # last raw response, for debugging
        if "choices" in response:
            break
        print("TOO FAST")
        if "error" in response:
            print(response["error"]["message"])
        time.sleep(62)
    else:
        continue  # still failing after MAX_ATTEMPTS; not marked done, so a later run retries it

    ans_model = response["choices"][0]["text"]
    print(idx, task, ans_model.replace("\n", " "))
    res[f"{idx},{task}"] = {
        "idx": idx,
        "task": task,
        "pred": ans_model,
        "reference": output,
    }
    done.add((idx, task))
    torch.save(done, done_path)
    with open(output_path, "w") as fp:
        json.dump(res, fp)
    time.sleep(8)