import bisect
from collections import defaultdict
from typing import List

import numpy as np
import torch

import nltk
import tiktoken
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer


# Module-level tiktoken encoder for the "gpt-3.5-turbo" model. It is used by
# PromptCompressor.compress_prompt to count tokens of the original and the
# compressed prompt (independently of the local model's own tokenizer).
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")


class PromptCompressor:
    def __init__(
        self,
        model_name: str = "NousResearch/Llama-2-7b-hf",
        device_map: str = "cuda",
        use_auth_token: bool = False,
        open_api_config: dict = None,
    ):
        """Create a compressor and load its scoring model.

        Args:
            model_name: Model identifier forwarded to ``load_model``.
            device_map: Device (or device map) forwarded to ``load_model``.
            use_auth_token: Forwarded to ``load_model``.
            open_api_config: Optional settings read by the ``"openai"`` ranking
                method (keys such as ``api_key``, ``api_base``, ``api_type``,
                ``api_version``, ``engine``). When omitted, a fresh empty dict
                is created for this instance.
        """
        self.load_model(model_name, device_map, use_auth_token)
        # Lazily created SentenceTransformer used by the "sentbert" ranker.
        self.sbert = None
        # A ``{}`` default in the signature would be shared by every instance;
        # build a new dict per instance instead.
        self.open_api_config = {} if open_api_config is None else open_api_config

    def load_model(
        self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
    ):
        """Load the tokenizer and causal LM and store them on the instance.

        Sets ``self.tokenizer``, ``self.model``, ``self.device`` and resets
        ``self.demonstration_idxs`` to an empty list.

        Args:
            model_name: Identifier passed to the ``from_pretrained`` loaders.
            device_map: When it contains ``"cuda"`` or ``"cpu"`` the model is
                loaded and then moved to that device; any other value is handed
                to ``from_pretrained`` as ``device_map`` (with offloading
                options) and ``self.device`` falls back to ``"cuda"``.
            use_auth_token: Only forwarded on the ``device_map`` loading path.
        """
        config = AutoConfig.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.padding_side = "left"
        # Use the config's pad id when it is truthy, otherwise reuse EOS.
        tokenizer.pad_token_id = config.pad_token_id or tokenizer.eos_token_id

        single_device = "cuda" in device_map or "cpu" in device_map
        self.device = device_map if single_device else "cuda"

        if not single_device:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=device_map,
                torch_dtype="auto",
                pad_token_id=tokenizer.pad_token_id,
                offload_folder="/tmp/offload",
                offload_state_dict=True,
                cache_dir="/tmp/cache",
                use_auth_token=use_auth_token,
            )
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto",
                config=config,
                ignore_mismatched_sizes=True,
            )
            model = model.to(device_map)
            if device_map == "cpu":
                # Cast to float32 when running on CPU.
                model = model.type(torch.float32)

        self.tokenizer = tokenizer
        self.model = model
        self.demonstration_idxs = []

    def get_ppl(
        self,
        text: str,
        granularity: str = "sentence",
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        return_kv=False,
        end=None,
        condition_mode: str = "none",
        condition_pos_id: int = 0,
    ):
        if input_ids is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt")
            input_ids = tokenized_text["input_ids"].to(self.device)
            attention_mask = tokenized_text["attention_mask"].to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
        else:
            past_length = 0
        if end is None:
            end = input_ids.shape[1]
        end = min(end, past_length + 4096)
        with torch.no_grad():
            response = self.model(
                input_ids[:, past_length:end],
                attention_mask=attention_mask[:, :end],
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = response.past_key_values

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Flatten the tokens
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(active_logits, active_labels)
        if condition_mode == "before":
            loss = loss[:condition_pos_id]
        elif condition_mode == "after":
            loss = loss[condition_pos_id:]
        res = loss.mean() if granularity == "sentence" else loss
        return (res, past_key_values) if return_kv else res

    def __call__(self, *args, **kwargs):
        """Make the instance callable; forwards everything to ``compress_prompt``."""
        result = self.compress_prompt(*args, **kwargs)
        return result

    def compress_prompt(
        self,
        demonstrations: List[str],
        instruction: str = "",
        question: str = "",
        ratio: float = 0.5,
        target_token: float = -1,
        iterative_size: int = 200,
        length_ratio: float = 0.0,
        force_demonstrations_ids: List[int] = None,
        force_demonstrations_number: int = None,
        use_sentence_level_filter: bool = False,
        use_demonstrate_level_filter: bool = True,
        use_token_level_filter: bool = True,
        keep_split: bool = False,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        demonstration_budget: str = "+100",
        token_budget_ratio: float = 1.4,
        condition_in_question: str = "none",
        reorder_demonstrations: str = "sort",
        dynamic_demonstration_compression_ratio: float = 0.0,
        condition_compare: bool = False,
        add_instruction: bool = False,
        rank_method: str = "longllmlingua",
        concate_question: bool = True,
    ):
        """Compress a prompt built from an instruction, demonstrations and a question.

        Up to three filters run in order: demonstration-level selection
        (``control_demonstrations_budget``), sentence-level selection
        (``control_sentence_budget``) and token-level pruning
        (``iterative_compress_prompt``).

        Args:
            demonstrations: Demonstration texts; a single string is wrapped
                into a one-element list.
            instruction: Text placed before the compressed context when
                non-empty.
            question: Text appended after the context when non-empty and
                ``concate_question`` is True.
            ratio: Used only when ``target_token == -1``: the budget becomes
                ``total * (1 - ratio)`` minus the instruction length and
                (if ``concate_question``) the question length, measured with
                the model tokenizer.
            target_token: Explicit token budget for the demonstrations, or -1
                to derive it from ``ratio``.
            iterative_size: Window size for token-level compression.
            length_ratio: Forwarded to ``control_demonstrations_budget``.
            force_demonstrations_ids: Forwarded to
                ``control_demonstrations_budget``.
            force_demonstrations_number: Forwarded to
                ``control_demonstrations_budget``.
            use_sentence_level_filter: Enable the sentence-level filter.
            use_demonstrate_level_filter: Enable the demonstration-level filter
                (applied only when there is more than one demonstration).
            use_token_level_filter: Enable token-level pruning; when False the
                demonstrations are simply joined with blank lines.
            keep_split: Forwarded to ``iterative_compress_prompt``.
            keep_first_sentence: Forwarded to ``control_sentence_budget``.
            keep_last_sentence: Forwarded to ``control_sentence_budget``.
            keep_sentence_number: Forwarded to ``control_sentence_budget``.
            high_priority_bonus: Forwarded to ``control_sentence_budget``.
            demonstration_budget: Forwarded to
                ``control_demonstrations_budget``.
            token_budget_ratio: Forwarded to ``control_sentence_budget``.
            condition_in_question: Conditioning mode for the rankers. A
                ``"_condition"`` suffix is stripped and additionally makes the
                question (plus the instruction if ``add_instruction``) get
                prepended as a conditioning prefix for token-level pruning.
            reorder_demonstrations: Forwarded to
                ``control_demonstrations_budget``.
            dynamic_demonstration_compression_ratio: Forwarded to
                ``control_demonstrations_budget``.
            condition_compare: Forwarded to ``iterative_compress_prompt``.
            add_instruction: See ``condition_in_question``.
            rank_method: Ranking method name used by the filters.
            concate_question: Whether the question is appended to the output
                and counted against the derived budget.

        Returns:
            A dict with ``compressed_prompt``, ``origin_tokens``,
            ``compressed_tokens`` (both counted with the module-level tiktoken
            encoder), ``ratio`` (formatted ``"<x.y>x"``) and ``saving``.
        """
        if isinstance(demonstrations, str):
            demonstrations = [demonstrations]
        origin_tokens = len(
            encoding.encode(
                "\n\n".join([instruction] + demonstrations + [question]).strip()
            )
        )
        demonstrations_tokens_length = [
            self.get_token_length(demonstration) for demonstration in demonstrations
        ]
        instruction_tokens_length = self.get_token_length(instruction)
        question_tokens_length = self.get_token_length(question)
        if target_token == -1:
            # Derive the demonstration budget from the requested ratio.
            target_token = (
                (
                    instruction_tokens_length
                    + question_tokens_length
                    + sum(demonstrations_tokens_length)
                )
                * (1 - ratio)
                - instruction_tokens_length
                - (question_tokens_length if concate_question else 0)
            )
        condition_flag = "_condition" in condition_in_question
        condition_in_question = condition_in_question.replace("_condition", "")

        if len(demonstrations) > 1 and use_demonstrate_level_filter:
            demonstrations, dynamic_ratio = self.control_demonstrations_budget(
                demonstrations,
                demonstrations_tokens_length,
                target_token,
                length_ratio,
                force_demonstrations_ids,
                force_demonstrations_number,
                question,
                condition_in_question,
                reorder_demonstrations=reorder_demonstrations,
                dynamic_demonstration_compression_ratio=dynamic_demonstration_compression_ratio,
                rank_method=rank_method,
                demonstration_budget=demonstration_budget,
            )
        else:
            dynamic_ratio = [0.0] * len(demonstrations)

        if use_sentence_level_filter:
            demonstrations = self.control_sentence_budget(
                demonstrations,
                target_token,
                keep_first_sentence=keep_first_sentence,
                keep_last_sentence=keep_last_sentence,
                keep_sentence_number=keep_sentence_number,
                high_priority_bonus=high_priority_bonus,
                token_budget_ratio=token_budget_ratio,
                question=question,
                condition_in_question=condition_in_question,
                rank_method=rank_method,
            )

        if condition_flag:
            # Prepend the conditioning prefix; ``start`` is its token length
            # (+2) and is sliced off again after token-level pruning.
            if add_instruction:
                demonstrations = [question + "\n\n" + instruction] + demonstrations
                start = self.get_token_length(question + "\n\n" + instruction) + 2
            else:
                demonstrations = [question] + demonstrations
                start = self.get_token_length(question) + 2
        else:
            start = 0

        if use_token_level_filter:
            demonstrations = self.iterative_compress_prompt(
                demonstrations,
                target_token,
                iterative_size=iterative_size,
                keep_split=keep_split,
                start=start,
                dynamic_ratio=dynamic_ratio,
                condition_compare=condition_compare,
            )
            # demonstrations is now (input_ids, attention_mask); decode the ids.
            context = self.tokenizer.batch_decode(demonstrations[0])[0].replace(
                "<s> ", ""
            )
        else:
            context = "\n\n".join(demonstrations)

        if instruction:
            context = instruction + "\n\n" + context
        if question and concate_question:
            context = context + "\n\n" + question

        compressed_tokens = len(encoding.encode(context))
        saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
        if compressed_tokens:
            compression = origin_tokens / compressed_tokens
        else:
            # An empty result would otherwise raise ZeroDivisionError: report
            # 1.0x for an empty input and an infinite ratio otherwise.
            compression = float("inf") if origin_tokens else 1.0
        return {
            "compressed_prompt": context,
            "origin_tokens": origin_tokens,
            "compressed_tokens": compressed_tokens,
            "ratio": f"{compression:.1f}x",
            "saving": f", Saving ${saving:.1f} in GPT-4.",
        }

    def get_token_length(self, text: str, add_special_tokens: bool = True):
        return len(
            self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
        )

    def get_condition_ppl(
        self,
        text: str,
        question: str,
        condition_in_question: str = "none",
        granularity: str = "sentence",
    ):
        """Score ``text`` with ``get_ppl``, optionally conditioned on ``question``.

        * ``"none"``: loss of ``text`` alone.
        * ``"before"``: loss over ``question + text``, keeping only the losses
          from the end of the question onward.
        * ``"after"``: loss over ``text + question``, keeping only the losses
          from the end of the text onward.

        Any other mode returns None.
        """
        if condition_in_question == "none":
            return self.get_ppl(text, granularity=granularity)

        if condition_in_question == "before":
            combined, prefix = question + text, question
        elif condition_in_question == "after":
            combined, prefix = text + question, text
        else:
            return None

        # Losses are shifted by one token, hence the "- 1" on the prefix length.
        return self.get_ppl(
            combined,
            granularity=granularity,
            condition_mode="after",
            condition_pos_id=self.get_token_length(prefix) - 1,
        )

    def get_dynamic_compression_ratio(
        self,
        demonstrations: list,
        target_token: float,
        iterative_size: int,
        dynamic_ratio: list,
        start: int,
    ):
        """Split the demonstrations into windows of ``iterative_size`` tokens.

        Each demonstration contributes its token length (without special
        tokens) plus 2. The base keep ratio is
        ``target_token / (total_length + 1)``; every demonstration adds its own
        ``dynamic_ratio`` offset and the sum is clamped to ``[0, 1]``.

        Args:
            demonstrations: Demonstration texts; the first one is ignored when
                ``start`` is truthy.
            target_token: Token budget.
            iterative_size: Window size in tokens.
            dynamic_ratio: Per-demonstration ratio offsets, indexed after the
                optional first demonstration has been dropped.
            start: Truthy when the first demonstration is a prefix to skip.

        Returns:
            A list of windows, each a list of ``(token_count, ratio)`` tuples.
        """

        def clamp_ratio(base: float, delta: float):
            return max(min(1, base + delta), 0)

        lengths = [self.get_token_length(demo, False) + 2 for demo in demonstrations]
        if start:
            lengths = lengths[1:]
        tau = target_token / (sum(lengths) + 1)

        windows, current, filled = [], [], 1
        for pos, length in enumerate(lengths):
            ratio = clamp_ratio(tau, dynamic_ratio[pos])
            if filled + length < iterative_size:
                # The demonstration fits entirely into the current window.
                filled += length
                current.append((length, ratio))
                continue

            # Fill the current window and close it.
            current.append((iterative_size - filled, ratio))
            windows.append(current)
            filled = filled + length - iterative_size
            if filled > iterative_size:
                # The remainder spans one or more whole windows of its own.
                whole = filled // iterative_size
                windows.extend([[(iterative_size, ratio)]] * whole)
                filled -= whole * iterative_size
            current = [(filled, ratio)] if filled else []

        if current:
            windows.append(current)
        return windows

    def control_demonstrations_budget(
        self,
        demonstrations: List[str],
        demonstrations_tokens_length: List[int],
        target_token: float,
        length_ratio: float = 0.0,
        force_demonstrations_ids: List[int] = None,
        force_demonstrations_number: int = None,
        question: str = "",
        condition_in_question: str = "none",
        reorder_demonstrations: str = "sort",
        dynamic_demonstration_compression_ratio: float = 0.0,
        rank_method: str = "longllmlingua",
        demonstration_budget: str = "+100",
    ):
        """Select the demonstrations that fit into the token budget.

        Demonstrations are ranked with ``get_rank_results`` and taken in rank
        order until the budget is exhausted (the demonstration that crosses
        the budget is still included) or ``force_demonstrations_number``
        demonstrations have been selected.

        Args:
            demonstrations: Candidate demonstration texts.
            demonstrations_tokens_length: Token length of each demonstration.
            target_token: Base token budget; negative values are replaced by
                100 before ``demonstration_budget`` is applied.
            length_ratio: Unused by this method.
            force_demonstrations_ids: When given, exactly these demonstrations
                are returned (in this order) and ranking is skipped.
            force_demonstrations_number: Optional cap on the number of
                selected demonstrations.
            question: Query passed to the ranker.
            condition_in_question: Conditioning mode passed to the ranker.
            reorder_demonstrations: ``"original"`` restores the input order,
                ``"two_stage"`` places even-ranked items first followed by the
                odd-ranked items reversed; any other value keeps rank order.
            dynamic_demonstration_compression_ratio: When > 0, produce
                per-demonstration ratio offsets spread linearly between
                ``+value`` and ``-value`` in rank order.
            rank_method: Ranking method name.
            demonstration_budget: Expression suffix applied to the budget,
                e.g. ``"+100"`` or ``"*1.5"``.

        Returns:
            ``(selected_demonstrations, dynamic_ratio)`` where ``dynamic_ratio``
            has one offset per selected demonstration.
        """
        if force_demonstrations_ids is not None:
            # Return the same (demonstrations, ratios) pair as the normal path
            # so callers can always unpack two values.
            forced = [demonstrations[ii] for ii in force_demonstrations_ids]
            return forced, [0.0] * len(forced)

        demostrations_sort = self.get_rank_results(
            demonstrations,
            question,
            rank_method,
            condition_in_question,
            demonstrations_tokens_length,
        )

        if target_token < 0:
            target_token = 100
        # NOTE(review): eval() executes ``demonstration_budget`` as Python code;
        # it must never come from untrusted input.
        target_token = eval("target_token" + demonstration_budget)

        used = []
        # Remember the full ranking of this call.
        self.demonstration_idxs.append([x for x, _ in demostrations_sort])
        for idx, _ in demostrations_sort:
            if idx >= len(demonstrations_tokens_length):
                continue
            target_token -= demonstrations_tokens_length[idx]
            if idx not in used:
                used.append(idx)
            # Count the demonstrations selected so far (``used``); the result
            # list is only built after the loop.
            if target_token < 0 or (
                force_demonstrations_number is not None
                and len(used) >= force_demonstrations_number
            ):
                break

        original_used = used
        if reorder_demonstrations == "original":
            used = sorted(used)
        elif reorder_demonstrations == "two_stage":
            used = used[::2] + used[1::2][::-1]

        if dynamic_demonstration_compression_ratio > 0:
            N = len(used)
            # NOTE(review): condition_in_question is a string such as "none",
            # so this test is truthy for every non-empty value — confirm
            # whether ``!= "none"`` was intended.
            if condition_in_question:
                rank = [
                    i
                    for i, _ in self.get_rank_results(
                        demonstrations,
                        question,
                        "longllmlingua",
                        "after",
                        demonstrations_tokens_length,
                    )
                ]
                used = sorted(used, key=lambda x: rank.index(x))
            # Offsets run linearly from +ratio (best ranked) to -ratio.
            dynamic_ratio = [
                i * (abs(dynamic_demonstration_compression_ratio) / (N - 1))
                if N > 1
                else 0
                for i in range(-(N - 1), N, 2)
            ][::-1]
            dynamic_ratio_map = dict(zip(original_used, dynamic_ratio))
            dynamic_ratio = [dynamic_ratio_map[i] for i in used]
        else:
            dynamic_ratio = [0.0] * len(used)

        res = [demonstrations[idx] for idx in used if idx < len(demonstrations)]
        return res, dynamic_ratio

    def control_sentence_budget(
        self,
        demonstrations: List[str],
        target_token: float,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        token_budget_ratio: float = 1.4,
        question: str = "",
        condition_in_question: str = "none",
        rank_method: str = "longllmlingua",
    ):
        """Keep only the best-ranked sentences of each demonstration.

        Every demonstration is split with ``nltk.sent_tokenize``. Sentences are
        ranked (by conditional perplexity for ``"longllmlingua"``, otherwise
        via ``get_rank_results``) and marked as kept in rank order until
        ``target_token * token_budget_ratio`` tokens are used up; the sentence
        that crosses the budget is still kept.

        Args:
            demonstrations: Demonstration texts.
            target_token: Token budget; negative values are replaced by 100.
            keep_first_sentence: Number of leading sentences (over all
                demonstrations) whose score gets ``high_priority_bonus``.
            keep_last_sentence: Same for trailing sentences.
            keep_sentence_number: Per demonstration, this many lowest-scoring
                sentences receive the bonus.
            high_priority_bonus: Value added to boosted sentence scores.
            token_budget_ratio: Multiplier applied to the budget.
            question: Query used for ranking.
            condition_in_question: Conditioning mode used for ranking.
            rank_method: Ranking method name.

        Returns:
            One string per demonstration with the kept sentences joined by
            newlines; the input list itself when there is exactly one sentence.
        """

        def boost_sentences(dem_idx: int, sent_keep: int):
            chosen = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
            for chosen_idx in chosen:
                sentence_ppl[chosen_idx] += high_priority_bonus

        sentences = [nltk.sent_tokenize(text) for text in demonstrations]

        # dem_g maps a demonstration index to the global indices of its sentences.
        dem_g = defaultdict(set)
        cursor = 0
        for idx_d, group in enumerate(sentences):
            for _ in group:
                dem_g[idx_d].add(cursor)
                cursor += 1

        demonstrations_sentences = [s for group in sentences for s in group]
        sentence_tokens_length = [
            self.get_token_length(sentence) for sentence in demonstrations_sentences
        ]
        total = len(demonstrations_sentences)
        if len(sentence_tokens_length) == 1:
            return demonstrations

        if rank_method != "longllmlingua":
            sent_sort = self.get_rank_results(
                demonstrations_sentences,
                question,
                rank_method,
                condition_in_question,
                [0] * len(demonstrations_sentences),
            )
        else:
            sentence_ppl = []
            for sentence in demonstrations_sentences:
                score = self.get_condition_ppl(
                    sentence, question, condition_in_question
                )
                sentence_ppl.append(score.cpu().numpy().item())
            if keep_first_sentence:
                head = slice(None, keep_first_sentence)
                sentence_ppl[head] = [
                    value + high_priority_bonus for value in sentence_ppl[head]
                ]
            if keep_last_sentence:
                tail = slice(-keep_last_sentence, None)
                sentence_ppl[tail] = [
                    value + high_priority_bonus for value in sentence_ppl[tail]
                ]
            if keep_sentence_number:
                for dem_idx in range(len(sentences)):
                    boost_sentences(dem_idx, keep_sentence_number)
            # Descending score without conditioning, ascending with it.
            sort_direct = -1 if condition_in_question == "none" else 1
            sent_sort = sorted(
                enumerate(sentence_ppl), key=lambda pair: sort_direct * pair[1]
            )

        keep_mask = [False] * total
        if target_token < 0:
            target_token = 100
        target_token *= token_budget_ratio
        for sent_idx, _ in sent_sort:
            target_token -= sentence_tokens_length[sent_idx]
            keep_mask[sent_idx] = True
            if target_token < 0:
                break

        rebuilt, offset = [], 0
        for group in sentences:
            kept = [s for pos, s in enumerate(group) if keep_mask[offset + pos]]
            rebuilt.append("\n".join(kept))
            offset += len(group)
        return rebuilt

    def get_compressed_input(
        self,
        loss,
        input_ids,
        attention_mask,
        end=200,
        iterative_size=200,
        threshold=0.5,
        keep_flag=None,
        split_token_id: int = 13,
        start: int = 0,
        self_loss=None,
        self_input_ids=None,
        self_attention_mask=None,
    ):
        """Drop low-scoring tokens inside the window ``[end - iterative_size, end)``.

        A boolean keep mask is built from ``loss`` (or, when ``self_loss`` is
        given, from ``self_loss - loss`` after the first ``start`` positions)
        by comparing against ``threshold``. Positions outside the window are
        always kept. The mask is then applied to the ids, the attention mask,
        the loss tensors and ``keep_flag``.

        Args:
            loss: 1-D per-token loss tensor.
            input_ids: Token ids; indexed as ``input_ids[0][i]`` and sized via
                ``input_ids.shape[1]``.
            attention_mask: Mask matching ``input_ids``.
            end: Exclusive end of the window being compressed.
            iterative_size: Length of the window.
            threshold: Scores must be strictly greater than this to be kept.
            keep_flag: Optional 1-D tensor; positions equal to 1 are always kept.
            split_token_id: Id of the separator token whose consecutive kept
                repetitions are collapsed (only when ``keep_flag`` is given).
            start: Number of leading prefix positions.
            self_loss: Optional loss tensor computed without the prefix.
            self_input_ids: Ids matching ``self_loss``.
            self_attention_mask: Mask matching ``self_input_ids``.

        Returns:
            An 8-tuple ``(compressed_input_ids, compressed_attention_mask,
            keep_flag, end, loss, self_loss, self_compressed_input_ids,
            self_compressed_attention_mask)``; ``end`` is reduced by the number
            of dropped positions before it and the two ``self_*`` id/mask
            entries are None when ``self_loss`` is None.
        """
        # Build the keep mask; ``loss[:1] > 0`` appends one extra entry
        # (presumably because loss has one element fewer than there are
        # tokens — TODO confirm).
        if self_loss is not None:
            need_idx = torch.concat(
                [
                    loss[:start] > 0,
                    self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
                    loss[:1] > 0,
                ]
            )
        else:
            need_idx = torch.concat([loss > threshold, loss[:1] > 0])
        # Everything outside the current window is kept unconditionally.
        need_idx[end:] = 1
        need_idx[: end - iterative_size] = 1
        loss = loss[need_idx[:-1]]
        if self_loss is not None:
            # Pad the mask with ones so it covers self_loss (offset by start).
            if need_idx.shape[0] < self_loss.shape[0] + start + 1:
                need_idx = torch.cat(
                    [
                        need_idx,
                        torch.ones(
                            self_loss.shape[0] - need_idx.shape[0] + start + 1,
                            dtype=torch.bool,
                        ).to(need_idx.device),
                    ]
                )
            self_loss = self_loss[need_idx[start:-1]]

        # Make the mask exactly as long as the token sequence: pad with ones
        # (keep) or truncate.
        if need_idx.shape[0] < input_ids.shape[1]:
            need_idx = torch.cat(
                [
                    need_idx,
                    torch.ones(
                        input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
                    ).to(need_idx.device),
                ]
            )
        elif need_idx.shape[0] > input_ids.shape[1]:
            need_idx = need_idx[: input_ids.shape[1]]

        # Positions flagged in keep_flag are always kept.
        if keep_flag is not None:
            need_idx[keep_flag == 1] = 1
        last = -1
        if keep_flag is not None:
            # Within the window, drop a kept split token that directly follows
            # another kept split token unless keep_flag protects it.
            for ii in range(end - iterative_size, end):
                if need_idx[ii] != 1:
                    continue
                now = input_ids[0][ii].detach().cpu().item()
                if (
                    now == split_token_id
                    and last == split_token_id
                    and keep_flag[ii].detach().cpu().item() == 0
                ):
                    need_idx[ii] = 0
                else:
                    last = now
        # Apply the mask to the attended tokens and restore the batch dimension.
        compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
        compressed_attention_mask = attention_mask[attention_mask == 1][
            need_idx
        ].unsqueeze(0)

        if self_loss is not None:
            # The prefix-free copy uses the same mask shifted by ``start``.
            self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
                need_idx[start:]
            ].unsqueeze(0)
            self_compressed_attention_mask = self_attention_mask[
                self_attention_mask == 1
            ][need_idx[start:]].unsqueeze(0)
        else:
            self_compressed_input_ids, self_compressed_attention_mask = None, None
        if keep_flag is not None:
            # Shrink keep_flag so it stays aligned with the compressed ids.
            if len(keep_flag) > len(need_idx):
                keep_flag = torch.cat(
                    [
                        keep_flag[:start],
                        keep_flag[start : len(need_idx) + start][need_idx],
                        keep_flag[start + len(need_idx) :],
                    ]
                )
            else:
                keep_flag = keep_flag[need_idx]
        # Shift the window end left by the number of positions dropped before it.
        end -= (need_idx[:end] == 0).sum()
        return (
            compressed_input_ids,
            compressed_attention_mask,
            keep_flag,
            end,
            loss,
            self_loss,
            self_compressed_input_ids,
            self_compressed_attention_mask,
        )

    def get_estimate_threshold_base_distribution(
        self, ppl, ratio: float, condition_flag: bool = False
    ):
        ppl = ppl[ppl != 10000]
        target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
        return (
            ppl.sort(descending=not condition_flag)
            .values[target_token]
            .detach()
            .cpu()
            .item()
        )

    def iterative_compress_prompt(
        self,
        demonstrations: List[str],
        target_token: float,
        iterative_size: int = 200,
        keep_split: bool = False,
        split_token_id: int = 13,
        start: int = 0,
        dynamic_ratio: list = None,
        condition_compare: bool = False,
    ):
        """Prune tokens window by window using per-token losses.

        The demonstrations are joined with blank lines and tokenized. The loop
        repeatedly scores the (already partly compressed) sequence with
        ``get_ppl``, derives a threshold for each segment of the current
        window with ``get_estimate_threshold_base_distribution`` and removes
        tokens with ``get_compressed_input``.

        Args:
            demonstrations: Demonstration texts (the first one may be a
                conditioning prefix, see ``start``).
            target_token: Token budget, forwarded to
                ``get_dynamic_compression_ratio``.
            iterative_size: Window size in tokens.
            keep_split: When True, build a ``keep_flag`` that protects split
                tokens adjacent to another split token.
            split_token_id: Token id treated as the split token.
            start: Number of leading prefix tokens; they are sliced off the
                returned tensors.
            dynamic_ratio: Per-demonstration ratio offsets.
                NOTE(review): defaults to None, but
                ``get_dynamic_compression_ratio`` indexes it, so callers
                appear to need to pass a list.
            condition_compare: When True, additionally score the sequence
                without the prefix and threshold on the loss difference.

        Returns:
            ``(compressed_input_ids[:, start:], compressed_attention_mask[:, start:])``.
        """
        # Per-window lists of (token_count, ratio) segments.
        iterative_ratios = self.get_dynamic_compression_ratio(
            demonstrations, target_token, iterative_size, dynamic_ratio, start
        )
        demonstrations = "\n\n".join(demonstrations)
        tokenized_text = self.tokenizer(demonstrations, return_tensors="pt")
        input_ids = tokenized_text["input_ids"].to(self.device)
        attention_mask = tokenized_text["attention_mask"].to(self.device)

        N = (attention_mask == 1).sum()
        compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
        if condition_compare:
            # Prefix-free copy of the sequence used for the comparison loss.
            self_input_ids, self_attention_mask = (
                input_ids[:, start:],
                attention_mask[:, start:],
            )
            self_compressed_input_ids, self_compressed_attention_mask = (
                self_input_ids,
                self_attention_mask,
            )

        end = min(iterative_size + start, compressed_input_ids.shape[1])
        threshold, keep_flag = None, None
        if keep_split:
            # Flag every split token that has a split token right before or
            # right after it.
            input_ids_numpy = input_ids.cpu().detach().numpy()[0]
            N = len(input_ids_numpy)
            keep_flag = [
                int(
                    (
                        ii > 0
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii - 1] == split_token_id
                    )
                    or (
                        ii < N - 1
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii + 1] == split_token_id
                    )
                )
                for ii in range(N)
            ]
            keep_flag = torch.tensor(keep_flag).to(self.device)
        past_key_values, past_loss, ready_end = None, None, 0
        self_past_key_values, self_past_loss, self_ready_end = None, None, 0
        idx = 0
        while end <= compressed_input_ids.shape[1]:
            # On the first iteration (idx == 0) ``end`` is None, so get_ppl
            # scores up to its own default end; later iterations score up to
            # ``end`` reusing the KV cache.
            loss, past_key_values = self.get_ppl(
                "",
                "token",
                compressed_input_ids,
                compressed_attention_mask,
                past_key_values=past_key_values,
                return_kv=True,
                end=end if idx else None,
            )
            if past_loss is not None:
                # Grow the accumulated loss buffer if needed, then overwrite
                # the freshly scored span.
                if end - 1 > len(past_loss):
                    past_loss = torch.cat(
                        [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
                    )
                past_loss[ready_end : end - 1] = loss
                loss = past_loss
            else:
                past_loss = loss
            if idx:
                # Keep only the cache entries before the current window.
                past_key_values = [
                    [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
                    for k, v in past_key_values
                ]
            else:
                past_key_values = None

            if condition_compare:
                # Same bookkeeping for the prefix-free sequence, with all
                # positions shifted by ``start``.
                self_loss, self_past_key_values = self.get_ppl(
                    "",
                    "token",
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                    past_key_values=self_past_key_values,
                    return_kv=True,
                    end=end - start if idx else None,
                )
                if self_past_loss is not None:
                    if end - start - 1 > len(self_past_loss):
                        self_past_loss = torch.cat(
                            [
                                self_past_loss,
                                torch.zeros_like(self_loss)[
                                    : end - 1 - start - len(self_past_loss)
                                ],
                            ]
                        )
                    self_past_loss[self_ready_end : end - start - 1] = self_loss
                    self_loss = self_past_loss
                else:
                    self_past_loss = self_loss
                if idx:
                    self_past_key_values = [
                        [
                            k[:, :, : end - iterative_size - start],
                            v[:, :, : end - iterative_size - start],
                        ]
                        for k, v in self_past_key_values
                    ]
                else:
                    self_past_key_values = None

                self_ready_end = (
                    end - start - iterative_size if not (start and idx == 0) else 0
                )
            ready_end = end - iterative_size if not (start and idx == 0) else 0

            # Compress each segment of the current window with its own ratio.
            for delta_end, ratio in iterative_ratios[idx]:
                loss = past_loss
                if condition_compare:
                    self_loss = self_past_loss
                    threshold = self.get_estimate_threshold_base_distribution(
                        self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
                    )
                else:
                    threshold = self.get_estimate_threshold_base_distribution(
                        loss, ratio, False
                    )

                (
                    compressed_input_ids,
                    compressed_attention_mask,
                    keep_flag,
                    end,
                    past_loss,
                    self_past_loss,
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                ) = self.get_compressed_input(
                    loss,
                    compressed_input_ids,
                    compressed_attention_mask,
                    end - iterative_size + delta_end,
                    iterative_size=delta_end,
                    threshold=threshold,
                    keep_flag=keep_flag,
                    split_token_id=split_token_id,
                    start=start,
                    self_loss=self_loss if condition_compare else None,
                    self_input_ids=self_compressed_input_ids
                    if condition_compare
                    else None,
                    self_attention_mask=self_compressed_attention_mask
                    if condition_compare
                    else None,
                )
                # get_compressed_input returns the shrunken segment end;
                # advance it by one window again.
                end += iterative_size
            idx += 1
        # Strip the conditioning prefix from the result.
        return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]

    def recover(
        self,
        original_prompt: str,
        compressed_prompt: str,
        response: str,
    ):
        """Map spans of ``response`` back onto the uncompressed prompt.

        The response is split on single spaces. Each maximal run of words that
        appears verbatim in ``compressed_prompt`` is replaced by the best
        matching token span of ``original_prompt``; all other words are kept
        unchanged.

        Args:
            original_prompt: The prompt before compression.
            compressed_prompt: The prompt after compression.
            response: Model output produced from the compressed prompt.

        Returns:
            The response with matched spans replaced, joined by single spaces.
        """

        def match_from_compressed(response_word):
            """Return the original-prompt span that best covers ``response_word``."""
            response_input_ids = self.tokenizer(
                response_word, add_special_tokens=False
            )["input_ids"]
            if not response_input_ids:
                # Nothing to align (e.g. an empty word produced by a trailing
                # or doubled space); indexing [0] below would fail.
                return response_word

            # Positions in the original prompt of every token of the response.
            response_set, response_c = set(response_input_ids), defaultdict(list)
            for pos, token_id in enumerate(original_input_ids):
                if token_id in response_set:
                    response_c[token_id].append(pos)

            res, res_min, res_c = None, float("inf"), 1
            n = len(response_input_ids)
            for left in response_c[response_input_ids[0]]:
                y, c = left, 1
                for x in range(1, n):
                    positions = response_c[response_input_ids[x]]
                    # Next occurrence strictly after y, at most 10 tokens away.
                    nxt = bisect.bisect_right(positions, y)
                    if nxt >= len(positions) or positions[nxt] - y > 10:
                        continue
                    c += 1
                    y = positions[nxt]
                # Prefer more matched tokens, then the shorter span.
                if c > res_c:
                    res_c = c
                    res_min = y - left + 1
                    res = (left, y + 1)
                elif c == res_c and y - left + 1 < res_min:
                    res_min = y - left + 1
                    res = (left, y + 1)

            if res is None:
                return response_word
            return self.tokenizer.decode(original_input_ids[res[0] : res[1]])

        response_words = response.split(" ")

        original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
            "input_ids"
        ]
        N = len(response_words)
        recovered_response_words = []
        l = 0
        while l < N:
            if response_words[l] not in compressed_prompt:
                recovered_response_words.append(response_words[l])
                l += 1
                continue
            # Extend the run while it still occurs verbatim in the compressed
            # prompt.
            r = l
            while (
                r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
            ):
                r += 1

            match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
            recovered_response_words.append(match_words)
            l = r + 1
        return " ".join(recovered_response_words)

    def get_rank_results(
        self,
        demonstrations: list,
        question: str,
        rank_method: str,
        condition_in_question: str,
        demonstrations_tokens_length: list,
    ):
        """Rank ``demonstrations`` by relevance to ``question``.

        Args:
            demonstrations: Texts to rank.
            question: Query text.
            rank_method: One of ``"bm25"``, ``"gzip"``, ``"sentbert"``,
                ``"openai"`` or ``"longllmlingua"``.
            condition_in_question: Conditioning mode; only used by
                ``"longllmlingua"``.
            demonstrations_tokens_length: Token length per demonstration; only
                read by ``"longllmlingua"``.

        Returns:
            A list of ``(index, score)`` pairs, best first. All methods except
            ``"longllmlingua"`` report 0 as the score.

        Raises:
            ValueError: If ``rank_method`` is not one of the supported names.
        """

        def get_distance_bm25(corpus, query):
            from rank_bm25 import BM25Okapi

            tokenized_corpus = [doc.split(" ") for doc in corpus]
            bm25 = BM25Okapi(tokenized_corpus)
            tokenized_query = query.split(" ")
            doc_scores = bm25.get_scores(tokenized_query)
            # Highest BM25 score first.
            idx = [(ii, 0) for ii in (-doc_scores).argsort()]
            return idx

        def get_distance_gzip(corpus, query):
            import gzip

            def get_score(x, y):
                # Compression-based distance: smaller means more similar.
                cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
                cxy = len(gzip.compress(f"{x} {y}".encode()))
                return (cxy - min(cx, cy)) / max(cx, cy)

            doc_scores = [get_score(doc, query) for doc in corpus]
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_sentbert(corpus, query):
            from sentence_transformers import SentenceTransformer, util

            # The embedding model is created once and cached on the instance.
            if self.sbert is None:
                self.sbert = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
            doc_embeds = self.sbert.encode(corpus)
            query = self.sbert.encode(query)
            # Negate so that argsort puts the largest dot product first.
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_openai(corpus, query):
            import openai
            from sentence_transformers import util

            openai.api_key = self.open_api_config.get("api_key", "")
            openai.api_base = self.open_api_config.get(
                "api_base", "https://api.openai.com/v1"
            )
            openai.api_type = self.open_api_config.get("api_type", "open_ai")
            openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
            engine = self.open_api_config.get("engine", "text-embedding-ada-002")

            def get_embed(text):
                return openai.Embedding.create(
                    input=[text.replace("\n", " ")], engine=engine
                )["data"][0]["embedding"]

            doc_embeds = [get_embed(i) for i in corpus]
            query = get_embed(query)
            # Negate so that argsort puts the largest dot product first.
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_longllmlingua(corpus, query):
            # The length term is multiplied by 0, i.e. currently disabled.
            demonstrations_ppl = [
                self.get_condition_ppl(
                    d,
                    query
                    + " We can get the answer to this question in the given documents.",
                    condition_in_question,
                )
                - dl * 2 / 250 * 0
                for d, dl in zip(corpus, demonstrations_tokens_length)
            ]
            # Descending score without conditioning, ascending with it.
            sort_direct = -1 if condition_in_question == "none" else 1
            ys = sorted(enumerate(demonstrations_ppl), key=lambda x: sort_direct * x[1])
            return ys

        rankers = {
            "bm25": get_distance_bm25,
            "gzip": get_distance_gzip,
            "sentbert": get_distance_sentbert,
            "openai": get_distance_openai,
            "longllmlingua": get_distance_longllmlingua,
        }
        if rank_method not in rankers:
            # Fail with a clear message instead of calling None.
            raise ValueError(
                f"Unknown rank_method {rank_method!r}; expected one of "
                f"{sorted(rankers)}"
            )
        return rankers[rank_method](demonstrations, question)
