import numpy as np
import lmppl

from datasets import load_dataset


import os
import logging
import gc
from math import exp
from typing import List
from tqdm import tqdm

import transformers
import torch.nn.functional as F
import torch
import random
import pickle
import os.path
import sentencepiece
import kenlm



# Module-level logger (instead of the root logger) so this module's verbosity
# can be configured independently; records still propagate to root handlers.
log = logging.getLogger(__name__)


def env_flag(name, default=False):
    """Parse a boolean flag from the environment variable ``name``.

    Integers ("0", "1", ...) are interpreted as before via ``bool(int(...))``;
    the textual forms "true"/"false", "yes"/"no", "on"/"off" (any case) are
    accepted too. An unset or blank variable yields ``default``.

    :raises ValueError: If the value cannot be interpreted as a boolean.
    """
    raw = os.getenv(name)
    if raw is None or not raw.strip():
        return default
    value = raw.strip().lower()
    try:
        return bool(int(value))
    except ValueError:
        pass
    if value in {"true", "yes", "on", "y", "t"}:
        return True
    if value in {"false", "no", "off", "n", "f"}:
        return False
    raise ValueError(f"Cannot interpret environment variable {name}={raw!r} as a boolean")


# NOTE(review): torch is imported above this point, so OMP_NUM_THREADS may be
# read too late to affect torch's thread pool; confirm whether it must move
# before `import torch`.
os.environ["OMP_NUM_THREADS"] = "1"  # to turn off warning message
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # to turn off warning message
# Target value that CrossEntropyLoss ignores; used to mask padding in labels.
PAD_TOKEN_LABEL_ID = torch.nn.CrossEntropyLoss().ignore_index
# When true, per-batch tensors are deleted and the CUDA cache is emptied.
FORCE_RESET = env_flag("FORCE_RESET")

class PruningStrategies:
    """String identifiers of the data-pruning strategies.

    They are compared with ``==`` against the ``method`` entry of
    ``llm_subsampling_kwargs``, so the values must not change.
    """

    # Rank pool items by perplexity score.
    TOP_K_PERPLEXITY = "TOP_K_PERPLEXITY"
    BOTTOM_K_PERPLEXITY = "BOTTOM_K_PERPLEXITY"
    MIDDLE_K_PERPLEXITY = "MIDDLE_K_PERPLEXITY"
    # NOTE(review): attribute is RANDOM but its value is "RANDOM_PERPLEXITY";
    # llm_subsampling in this file has no branch for it.
    RANDOM = "RANDOM_PERPLEXITY"
    # Rank pool items by the LLM's probability of answering "yes".
    ASK_LLM = "ASK_LLM"
    # Perplexity + ASK-LLM combination ("ActivePrune").
    HYBRID = "HYBRID"


import re
from typing import Dict


# Character class of code points 0-31 (C0 controls, including "\t" and "\n")
# and 127-159 (DEL and C1 controls); matches are deleted when
# ModifyingDocuments.normalization(remove_non_printing_characters=True) runs.
non_printing_characters_re = re.compile(
    f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]"
)

# Matches any Unicode decimal digit; used to replace digits by "0".
digits_re: re.Pattern = re.compile(r"\d")

# Per-character replacement of (mostly full-width / CJK / typographic)
# punctuation by ASCII equivalents.
# NOTE(review): this table appears to mirror CC-Net's text normalizer
# (including the odd "１" -> '"' entry). Because it feeds the KenLM
# perplexity scoring, it presumably has to match the normalization the
# KenLM/SentencePiece models were trained with -- confirm before editing.
unicode_punctuation: Dict[str, str] = {
    "，": ",",
    "。": ".",
    "、": ",",
    "„": '"',
    "”": '"',
    "“": '"',
    "«": '"',
    "»": '"',
    "１": '"',
    "」": '"',
    "「": '"',
    "《": '"',
    "》": '"',
    "´": "'",
    "∶": ":",
    "：": ":",
    "？": "?",
    "！": "!",
    "（": "(",
    "）": ")",
    "；": ";",
    "–": "-",
    "—": " - ",
    "．": ". ",
    "～": "~",
    "’": "'",
    "…": "...",
    "━": "-",
    "〈": "<",
    "〉": ">",
    "【": "[",
    "】": "]",
    "％": "%",
    "►": "-",
}

# Default resources for ModifyingDocuments.normalization; its keyword
# defaults are looked up in this dict when that method is defined.
normalization = {
    "non_printing_characters_re": non_printing_characters_re,
    "digits_re": digits_re,
    "unicode_punctuation": unicode_punctuation,
}


class ModifyingDocuments:
    """Static helpers to normalize, tokenize and clean text documents.

    Documents are plain strings. The default normalization resources are
    taken from the module-level ``normalization`` dict when the class body
    is executed.
    """

    @staticmethod
    def remove_empty_el_from_list(list_):
        """Return the truthy elements of ``list_`` (drops ``""``, ``[]``, ...)."""
        return [el for el in list_ if el]

    @staticmethod
    def remove_non_printing_characters(document, non_printing_characters_re):
        """Delete every match of ``non_printing_characters_re`` from ``document``."""
        return non_printing_characters_re.sub("", document)

    @staticmethod
    def uniform_whitespace(
        document,
        whitespace=(
            " ",
            " ",
            " ",
            " ",
            " ",
            "　",
            " ",
            " ",
            " ",
            " ",
            "￼",
            "",
        ),
    ):
        """Replace every character listed in ``whitespace`` by an ASCII space.

        The default lists several Unicode space-like characters. It is a
        tuple so the shared default argument cannot be mutated; any iterable
        of characters may be passed.
        """
        whitespace = set(whitespace)
        return "".join(" " if char in whitespace else char for char in document)

    @staticmethod
    def replace_digits_with_zeros(document, digits_re):
        """Replace every match of ``digits_re`` by ``"0"``."""
        return digits_re.sub("0", document)

    @staticmethod
    def replace_unicode_punctuation(document, unicode_punctuation):
        """Map each character through ``unicode_punctuation`` (unmapped ones are kept)."""
        return "".join(unicode_punctuation.get(c, c) for c in document)

    @staticmethod
    def normalization(
        document,
        remove_non_printing_characters,
        strip,
        lower_case,
        uniform_whitespace,
        replace_digits_with_zeros,
        replace_unicode_punctuation,
        non_printing_characters_re=normalization["non_printing_characters_re"],
        digits_re=normalization["digits_re"],
        unicode_punctuation=normalization["unicode_punctuation"],
    ):
        """Apply the enabled normalization steps to ``document``.

        Steps run in this order, each enabled by its boolean flag: remove
        non-printing characters, ``str.strip``, lower-case, uniform
        whitespace, digits -> "0", Unicode punctuation -> ASCII. If the
        document is empty after the first two steps it is returned at once.
        """
        # The boolean parameters shadow the method names, hence the
        # explicit ModifyingDocuments.<method> calls.
        if remove_non_printing_characters:
            document = ModifyingDocuments.remove_non_printing_characters(
                document, non_printing_characters_re
            )
        if strip:
            document = document.strip()
        if not document:
            return document
        if lower_case:
            document = document.lower()
        if uniform_whitespace:
            document = ModifyingDocuments.uniform_whitespace(document)
        if replace_digits_with_zeros:
            document = ModifyingDocuments.replace_digits_with_zeros(document, digits_re)
        if replace_unicode_punctuation:
            document = ModifyingDocuments.replace_unicode_punctuation(
                document, unicode_punctuation
            )
        return document

    @staticmethod
    def tokenization(document, sentencepiece_model, join_on_whitespace):
        """Split ``document`` into SentencePiece pieces.

        Returns the list of pieces, or a single space-joined string when
        ``join_on_whitespace`` is true.
        """
        document_tokenized = sentencepiece_model.encode_as_pieces(document)
        if join_on_whitespace:
            document_tokenized = " ".join(document_tokenized)
        return document_tokenized

    @staticmethod
    def split_on_whitespace(
        document,
        new_line=False,
        tab=False,
    ):
        """Split on spaces (and optionally newlines / tabs), dropping empty pieces.

        Dropping the empty pieces also collapses runs of separators.
        """
        sep = [" "] + new_line * ["\n"] + tab * ["\t"]
        sep = "|".join(sep)
        split_document = re.split(sep, document)
        return ModifyingDocuments.remove_empty_el_from_list(split_document)

    @staticmethod
    def strip(document, strip_characters):
        """Strip characters contained in ``strip_characters`` from both ends.

        Faster than ``document.strip(strip_characters)`` when
        ``strip_characters`` is a large set (membership tests are O(1)).
        Returns ``""`` when every character is a strip character.
        """
        if not document:
            return document
        beg_ind = 0
        end_ind = len(document)
        for i in range(len(document)):
            if document[i] in strip_characters:
                beg_ind += 1
            else:
                break
        for i in range(1, len(document) + 1):
            if document[-i] in strip_characters:
                end_ind -= 1
            else:
                break
        return document[beg_ind:end_ind]

    @staticmethod
    def get_words_from_document(
        document, sentencepiece_model_tok, lower_case, strip_characters
    ):
        """Get words from a document, e.g. to compute word-level ratios.

        Not reversible. With a SentencePiece model, the document is fully
        normalized (lower-cased, digits zeroed, ...) and split into pieces;
        otherwise it is split on spaces, newlines and tabs. Words are then
        optionally lower-cased and stripped of ``strip_characters``; words
        that become empty are dropped.
        """
        if sentencepiece_model_tok:
            document_normalized = ModifyingDocuments.normalization(
                document=document,
                remove_non_printing_characters=True,
                strip=True,
                lower_case=True,
                uniform_whitespace=True,
                replace_digits_with_zeros=True,
                replace_unicode_punctuation=True,
            )
            words = ModifyingDocuments.tokenization(
                document_normalized, sentencepiece_model_tok, join_on_whitespace=False
            )
        else:
            words = ModifyingDocuments.split_on_whitespace(
                document, new_line=True, tab=True
            )
        if lower_case:
            words = [word.lower() for word in words]
        if strip_characters:
            words = [ModifyingDocuments.strip(word, strip_characters) for word in words]
            words = ModifyingDocuments.remove_empty_el_from_list(words)
        return words

    @staticmethod
    def words_augmentation(words, group_size, join_char):
        """Return every run of ``group_size`` consecutive words joined by ``join_char``.

        Useful for languages where "words" are not space-delimited the usual
        way (e.g. Chinese, or Vietnamese syllables).
        """
        return [
            join_char.join(words[i : i + group_size])
            for i in range(len(words) - group_size + 1)
        ]

    @staticmethod
    def split_on_newline_tab_whitespace(document):
        """Split into a nested list: lines ("\\n") -> tab fields ("\\t") -> words (" ")."""
        sentences = document.split("\n")
        sentences = [sentence.split("\t") for sentence in sentences]
        return [
            [
                ModifyingDocuments.split_on_whitespace(subsentence)
                for subsentence in sentence
            ]
            for sentence in sentences
        ]

    @staticmethod
    def merge_on_whitespace_tab_newline(sentences):
        """Invert ``split_on_newline_tab_whitespace``, dropping empty groups.

        Because empty groups are dropped, consecutive separators collapse.
        """
        sentences = [
            [" ".join(subsentence) for subsentence in sentence if subsentence]
            for sentence in sentences
        ]
        sentences = ["\t".join(sentence) for sentence in sentences if sentence]
        if not sentences:
            return ""
        return "\n".join(sentences)

    @staticmethod
    def should_keep_word_with_incorrect_substrings(
        word, strip_characters, incorrect_word_substrings
    ):
        """Return True if the stripped word contains none of ``incorrect_word_substrings``."""
        word = ModifyingDocuments.strip(word, strip_characters)
        return not any(i_substr in word for i_substr in incorrect_word_substrings)

    @staticmethod
    def remove_words_with_incorrect_substrings(
        document,
        strip_characters,
        incorrect_word_substrings,
    ):
        """Drop words containing any of ``incorrect_word_substrings``.

        The line/tab/space structure is otherwise preserved; consecutive
        separators are collapsed by the re-merge.
        """
        sentences = ModifyingDocuments.split_on_newline_tab_whitespace(document)
        sentences = [
            [
                [
                    word
                    for word in subsentence
                    if ModifyingDocuments.should_keep_word_with_incorrect_substrings(
                        word, strip_characters, incorrect_word_substrings
                    )
                ]
                for subsentence in sentence
            ]
            for sentence in sentences
        ]
        return ModifyingDocuments.merge_on_whitespace_tab_newline(sentences)

    @staticmethod
    def should_keep_long_word(word, strip_characters, length_word_max_cutoff):
        """Decide whether a word is short enough to keep.

        A word is kept if its length is at most ``length_word_max_cutoff``,
        either as-is or after stripping ``strip_characters`` from both ends
        (so surrounding punctuation does not count against it). A word made
        only of strip characters that exceeds the cutoff is dropped.
        """
        if len(word) <= length_word_max_cutoff:
            return True
        stripped = ModifyingDocuments.strip(word, strip_characters)
        return bool(stripped) and len(stripped) <= length_word_max_cutoff

    @staticmethod
    def remove_long_words(
        document,
        strip_characters,
        length_word_max_cutoff,
    ):
        """Drop words rejected by ``should_keep_long_word``, keeping the line/tab layout."""
        sentences = ModifyingDocuments.split_on_newline_tab_whitespace(document)
        sentences = [
            [
                [
                    word
                    for word in subsentence
                    if ModifyingDocuments.should_keep_long_word(
                        word,
                        strip_characters,
                        length_word_max_cutoff,
                    )
                ]
                for subsentence in sentence
            ]
            for sentence in sentences
        ]
        return ModifyingDocuments.merge_on_whitespace_tab_newline(sentences)

    @staticmethod
    def modifying_documents(
        document,
        cond_uniform_whitespace,
        cond_replace_unicode_punctuation,
        cond_remove_words_with_incorrect_substrings,
        strip_characters,
        incorrect_word_substrings,
        cond_remove_long_words,
        length_word_max_cutoff,
    ):
        """Light normalization followed by optional word-level removals.

        The document is stripped and optionally whitespace-unified /
        punctuation-mapped, then words with incorrect substrings and
        over-long words are removed when the matching ``cond_*`` flag is set.
        """
        document = ModifyingDocuments.normalization(
            document=document,
            remove_non_printing_characters=False,
            strip=True,
            lower_case=False,
            uniform_whitespace=cond_uniform_whitespace,
            replace_digits_with_zeros=False,
            replace_unicode_punctuation=cond_replace_unicode_punctuation,
        )
        if cond_remove_words_with_incorrect_substrings:
            document = ModifyingDocuments.remove_words_with_incorrect_substrings(
                document,
                strip_characters,
                incorrect_word_substrings,
            )
        if cond_remove_long_words:
            document = ModifyingDocuments.remove_long_words(
                document,
                strip_characters,
                length_word_max_cutoff,
            )
        return document


class LM:
    """ Causal language model wrapper used to score texts (perplexity and ASK-LLM). """

    def __init__(self,
                 model: str = 'gpt2',
                 use_auth_token: "bool | str" = False,
                 max_length: int = None,
                 num_gpus: int = None,
                 torch_dtype=None,
                 device_map: str = None,
                 low_cpu_mem_usage: bool = False,
                 trust_remote_code: bool = True,
                 offload_folder: str = None,
                 hf_cache_dir: str = None,
                 load_quantized: bool = False):
        """ Load a causal LM, its tokenizer and config.

        @param model: Model alias or path to local model file.
        @param use_auth_token: Hugging Face access token (str), or True to use the
            locally stored login. When falsy, the `HF_TOKEN` environment variable
            is used if set.
        @param max_length: If given, every input is truncated/padded to exactly this
            many tokens; must not exceed the tokenizer's `model_max_length`.
        @param num_gpus: Unused; kept for backward compatibility.
        @param torch_dtype: Torch data type.
        @param device_map: Hugging Face `device_map` for model loading.
        @param low_cpu_mem_usage: Low CPU memory usage (non-quantized loading only).
        @param trust_remote_code: Trust remote code.
        @param offload_folder: Offload folder.
        @param hf_cache_dir: Huggingface transformers argument of `cache_dir`
        @param load_quantized: Load the model in 4-bit NF4 via bitsandbytes.
        @raise ValueError: If `max_length` exceeds the tokenizer's `model_max_length`.
        """
        log.info(f'Loading Model: `{model}`')

        # Hub options shared by the tokenizer, the config and the model.
        hub_params = {"local_files_only": False,
                      "trust_remote_code": trust_remote_code}
        if hf_cache_dir is not None:
            hub_params["cache_dir"] = hf_cache_dir
        if offload_folder is not None:
            hub_params["offload_folder"] = offload_folder
        # Access tokens must never be hard-coded in source: take the caller's
        # token or the HF_TOKEN environment variable.
        token = use_auth_token or os.getenv("HF_TOKEN")
        if token:
            hub_params["token"] = token

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model, **hub_params)
        self.config = transformers.AutoConfig.from_pretrained(model, **hub_params)

        # Both loading paths now honour the hub options (cache dir, token, ...).
        model_params = dict(hub_params)
        if torch_dtype is not None:
            model_params['torch_dtype'] = torch_dtype
        if device_map is not None:
            model_params['device_map'] = device_map
        if load_quantized:
            model_params["quantization_config"] = transformers.BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )
        else:
            model_params.update({"config": self.config, "low_cpu_mem_usage": low_cpu_mem_usage})
        self.model = transformers.AutoModelForCausalLM.from_pretrained(model, **model_params)

        self.pad_token_initialized = False
        if self.tokenizer.pad_token is None:
            # Add a dedicated pad token; its logit column is dropped when scoring.
            self.tokenizer.add_special_tokens({'pad_token': "<<PAD>>"})
            self.model.resize_token_embeddings(len(self.tokenizer))
            self.pad_token_initialized = True
        # Right padding keeps real tokens at positions 0..n-1 even when no
        # position_ids are passed. The masks below are derived from the
        # attention mask, so either padding side yields aligned scores.
        self.tokenizer.padding_side = "right"

        if max_length is not None and max_length > self.tokenizer.model_max_length:
            raise ValueError(f"{max_length} > {self.tokenizer.model_max_length}")
        self.max_length = max_length

        # loss function (per-token, summed per sequence in get_perplexity)
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        self.model.eval()

    @staticmethod
    def _batch_ranges(n_items, batch=None):
        """Return ``(start, end)`` pairs splitting ``range(n_items)`` into chunks of ``batch``.

        ``batch=None`` means a single chunk; there are no chunks when ``n_items == 0``.
        """
        if n_items == 0:
            return []
        batch = n_items if batch is None else batch
        if batch < 1:
            raise ValueError(f"batch must be a positive integer, got {batch}")
        bounds = list(range(0, n_items, batch)) + [n_items]
        return list(zip(bounds[:-1], bounds[1:]))

    @staticmethod
    def _last_token_index(attention_mask):
        """Return, for each row of a 0/1 mask, the index of its last 1.

        Works for left and right padding: after flipping, the last 1 becomes
        the first maximum, which ``argmax`` returns.
        """
        seq_len = attention_mask.size(-1)
        return seq_len - 1 - attention_mask.flip(-1).argmax(dim=-1)

    def _encode(self, texts):
        """Tokenize a batch (padding/truncation per ``self.max_length``) onto the model's device."""
        if self.max_length is not None:
            model_inputs = self.tokenizer(texts, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt')
        else:
            model_inputs = self.tokenizer(texts, truncation=True, padding=True, return_tensors='pt')
        model_inputs.pop('token_type_ids', None)
        return {k: v.to(self.model.device) for k, v in model_inputs.items()}

    def _attention_mask(self, model_inputs):
        """Return the attention mask, or derive one from the pad token id if absent."""
        mask = model_inputs.get('attention_mask')
        if mask is None:
            mask = (model_inputs['input_ids'] != self.tokenizer.pad_token_id).long()
        return mask

    def get_perplexity(self, input_texts: "str | List[str]", batch: int = None):
        """ Compute the perplexity of each text under the causal LM.

        :param input_texts: A string or list of input texts.
        :param batch: Batch size (all texts in a single batch when None).
        :return: A float for a single string, otherwise a list of floats
            (exp of the mean next-token negative log-likelihood).
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        loss_list = []
        with torch.no_grad():
            for s, e in tqdm(self._batch_ranges(len(input_texts), batch)):
                model_inputs = self._encode(input_texts[s:e])
                output = self.model(**model_inputs)
                logit = output['logits']
                if self.pad_token_initialized:
                    # drop the added pad token's column (assumed to be the last vocab id)
                    logit = logit[:, :, :-1]

                attention_mask = self._attention_mask(model_inputs).to(logit.device)
                label = model_inputs['input_ids'].to(logit.device).clone()
                # padding positions are never targets
                label[attention_mask == 0] = PAD_TOKEN_LABEL_ID

                # Shift so that tokens < n predict n; predictions made *from* a
                # padding position are ignored as well.
                shift_logits = logit[..., :-1, :].contiguous()
                shift_label = label[:, 1:].clone()
                shift_label[attention_mask[:, :-1] == 0] = PAD_TOKEN_LABEL_ID

                # compute loss
                valid_length = (shift_label != PAD_TOKEN_LABEL_ID).sum(dim=-1)
                loss = self.loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_label.reshape(-1))
                loss = loss.view(shift_label.size(0), -1)
                loss = torch.sum(loss, -1) / valid_length
                loss_list += loss.cpu().tolist()

                if FORCE_RESET:
                    del model_inputs
                    del loss
                    del output
                    gc.collect()
                    torch.cuda.empty_cache()

        # conversion to perplexity
        ppl = [exp(i) for i in loss_list]
        return ppl[0] if single_input else ppl

    def get_ask_llm_score(self, input_texts: "str | List[str]", batch: int = None):
        """ Implements the ASK-LLM sampling method.
        https://arxiv.org/pdf/2402.09668.pdf

        :param input_texts: A string or list of prompts.
        :param batch: Batch size (all prompts in a single batch when None).
        :return: List with, for each prompt, the probability that the next
            token after the prompt's last real token is the first token of "yes".
        """
        single_input = isinstance(input_texts, str)
        input_texts = [input_texts] if single_input else input_texts

        # Without add_special_tokens=False, tokenizers that prepend BOS would
        # make index 0 the BOS id instead of "yes".
        yes_token_id = self.tokenizer.encode("yes", add_special_tokens=False)[0]

        scores = []
        with torch.no_grad():
            for s, e in tqdm(self._batch_ranges(len(input_texts), batch)):
                model_inputs = self._encode(input_texts[s:e])
                output = self.model(**model_inputs)
                logits = output['logits']
                if self.pad_token_initialized:
                    logits = logits[:, :, :-1]

                # Read the logits at each prompt's last real token, not at
                # position -1, which is padding for shorter prompts.
                last = self._last_token_index(self._attention_mask(model_inputs)).to(logits.device)
                rows = torch.arange(logits.size(0), device=logits.device)
                next_token_logits = logits[rows, last]
                yes_probability = F.softmax(next_token_logits.float(), dim=-1)[:, yes_token_id]
                scores += yes_probability.cpu().tolist()

                if FORCE_RESET:
                    del model_inputs
                    del output
                    gc.collect()
                    torch.cuda.empty_cache()

        return scores


def compute_kenlm_perplexity_score(document, sentencepiece_model, kenlm_model):
    """Return the KenLM perplexity of ``document``, rounded to one decimal.

    The document is normalized (non-printing characters removed, stripped,
    whitespace unified, digits zeroed, Unicode punctuation mapped to ASCII)
    and turned into space-joined SentencePiece pieces. Each line is then
    scored; it contributes its token count plus one, presumably for the
    end-of-sentence token KenLM scores (TODO confirm). The result is
    ``10 ** (-total_score / total_tokens)``.
    """
    normalized = ModifyingDocuments.normalization(
        document=document,
        remove_non_printing_characters=True,
        strip=True,
        lower_case=False,
        uniform_whitespace=True,
        replace_digits_with_zeros=True,
        replace_unicode_punctuation=True,
    )
    pieces = ModifyingDocuments.tokenization(
        normalized, sentencepiece_model, join_on_whitespace=True
    )

    total_score = 0
    total_tokens = 0
    for text_line in pieces.split("\n"):
        total_score += kenlm_model.score(text_line)
        total_tokens += len(text_line.split()) + 1

    perplexity = 10.0 ** (-total_score / total_tokens)
    return round(perplexity, 1)

def chunks(lst, n):
    """Yield consecutive slices of ``lst`` of length ``n``; the last may be shorter."""
    total = len(lst)
    yield from (lst[start:start + n] for start in range(0, total, n))


def print_and_log(*args):
    """Print ``args`` to stdout and log them at INFO level, space-separated."""
    print(*args)
    rendered = map(str, args)
    log.info(" ".join(rendered))


def get_ask_llm_scores(X_pool, scorer_model_name, dataset_name, text_column_name, overwrite_cache=False, batch_size=4,
                       scores_dir="/home/jovyan/active-learning-qlora/scores"):
    """Return a dict mapping each ``X_pool['id']`` to its ASK-LLM score.

    The score is ``LM.get_ask_llm_score`` for a prompt asking whether the
    text is a suitable training example for the dataset's task. Results are
    cached as a pickle in ``scores_dir``; the cache is reused unless
    ``overwrite_cache`` is true.

    :raises ValueError: If ``scorer_model_name`` is unknown and there is no cache.
    """
    # NOTE: the "mixtral" key actually loads Mistral-7B-v0.1.
    scorer_models = {
        "mixtral": 'mistralai/Mistral-7B-v0.1',
        "phi": 'microsoft/phi-1_5',
        "gemma": 'google/gemma-2b-it',
    }
    dataset_tasks = {
        "ag_news": "classification",
        "imdb": "classification",
        "opus-it": "translation",
        "opus-medical": "translation",
        "aeslc": "summarization",
    }
    task = dataset_tasks.get(dataset_name, "classification")

    scores_file_name = os.path.join(scores_dir, f"ask_llm_scores_{scorer_model_name}_{dataset_name}.pkl")
    if os.path.isfile(scores_file_name) and not overwrite_cache:
        print_and_log("Loading cached ASK-LLM scores")
        # pickle is only acceptable because this is a locally written cache file.
        with open(scores_file_name, "rb") as f:
            return pickle.load(f)

    if scorer_model_name not in scorer_models:
        raise ValueError(
            f"Unknown scorer model {scorer_model_name!r}; expected one of {sorted(scorer_models)}"
        )
    scorer = LM(scorer_models[scorer_model_name], load_quantized=True)
    print_and_log("Computing ASK-LLM scores")

    # Must stay byte-identical so freshly computed and cached scores are comparable.
    prompt_template = (
        "\n"
        "Is the following training example (added between === and\n"
        "===) suitable for training a model on the {task} task? \n"
        "\n"
        "A suitable training example should be well-formatted, \n"
        "contain some useful information, and not have any harmful\n"
        "content.\n"
        "\n"
        'Provide a response as "yes" OR "no" only.\n'
        "\n"
        "===\n"
        "{text}\n"
        "===\n"
    )
    prompts = [prompt_template.format(task=task, text=t) for t in X_pool[text_column_name]]

    ask_llm_scores = np.array(scorer.get_ask_llm_score(prompts, batch=batch_size))
    scores_dict = {i: p for i, p in zip(X_pool['id'], ask_llm_scores)}

    os.makedirs(scores_dir or ".", exist_ok=True)
    with open(scores_file_name, "wb") as f:
        pickle.dump(scores_dict, f)

    return scores_dict


def get_perplexity_scores(X_pool, scorer_model_name, dataset_name, text_column_name, overwrite_cache=False, batch_size=4,
                          scores_dir="/home/jovyan/active-learning-qlora/scores",
                          kenlm_dir="/home/jovyan/active-learning-qlora/kenlm-oscar"):
    """Return a dict mapping each ``X_pool['id']`` to a perplexity score.

    ``scorer_model_name == "kenlm"`` uses the SentencePiece + KenLM models
    found in ``kenlm_dir`` (``en.sp.model``, ``en.arpa.bin``). The other
    supported names load a 4-bit causal LM and use ``LM.get_perplexity``.
    Results are cached as a pickle in ``scores_dir`` and reused unless
    ``overwrite_cache`` is true.

    :raises ValueError: If ``scorer_model_name`` is unknown and there is no cache.
    """
    # NOTE: the "mixtral" key actually loads Mistral-7B-v0.1.
    scorer_models = {
        "mixtral": 'mistralai/Mistral-7B-v0.1',
        "phi": 'microsoft/phi-1_5',
        "gemma": 'google/gemma-2b-it',
    }

    scores_file_name = os.path.join(scores_dir, f"perplexity_scores_{scorer_model_name}_{dataset_name}.pkl")
    if os.path.isfile(scores_file_name) and not overwrite_cache:
        print_and_log("Loading cached perplexity scores")
        # pickle is only acceptable because this is a locally written cache file.
        with open(scores_file_name, "rb") as f:
            return pickle.load(f)

    print_and_log("Computing perplexity scores")
    text = list(X_pool[text_column_name])

    if scorer_model_name == "kenlm":
        sentencepiece_model = sentencepiece.SentencePieceProcessor()
        sentencepiece_model.load(os.path.join(kenlm_dir, "en.sp.model"))
        kenlm_model = kenlm.Model(os.path.join(kenlm_dir, "en.arpa.bin"))
        ppl = [compute_kenlm_perplexity_score(t, sentencepiece_model, kenlm_model) for t in tqdm(text)]
    elif scorer_model_name in scorer_models:
        scorer = LM(scorer_models[scorer_model_name], load_quantized=True)
        ppl = scorer.get_perplexity(text, batch=batch_size)
    else:
        raise ValueError(
            f"Unknown scorer model {scorer_model_name!r}; expected 'kenlm' or one of {sorted(scorer_models)}"
        )
    ppl = np.array(ppl)

    scores_dict = {i: p for i, p in zip(X_pool['id'], ppl)}

    os.makedirs(scores_dir or ".", exist_ok=True)
    with open(scores_file_name, "wb") as f:
        pickle.dump(scores_dict, f)

    return scores_dict

def _rank_pool_positions(pool_ids, scores_dict, descending=False):
    """Return the X_pool positions of ids that have a score, ordered by score.

    Each returned value is the index of the id within ``pool_ids`` (a row
    position in X_pool). Ids without a score are skipped without shifting
    the positions of the others, and scores of 0 are kept.
    """
    positions, values = [], []
    for position, _id in enumerate(pool_ids):
        score = scores_dict.get(_id)
        if score is None:
            continue
        positions.append(position)
        values.append(score)
    if not positions:
        return np.array([], dtype=int)
    values = np.array(values)
    order = np.argsort(-values if descending else values)
    return np.array(positions)[order]


def _reweight_perplexities(scores_dict, log_dir, reweighting_factor):
    """Reweight the perplexities in ``scores_dict`` in place.

    Each value becomes ``value - reweighting_factor * mean(|value - p|)``,
    where ``p`` ranges over the perplexities stored for the latest query
    round. This only happens when more than one query file exists and some
    reference perplexities are found.

    :param log_dir: pathlib.Path containing ``ids_data_query_<n>.json`` and
        ``ids_data_perplexities_<n>.pickle`` files.
    :return: The largest ``<n>`` among the query files.
    :raises FileNotFoundError: If no numbered query file exists.
    """
    import glob
    import json

    query_files = glob.glob(str(log_dir / 'ids_data_query*.json'))
    # Choose the latest round numerically: a reverse string sort would rank
    # "..._9.json" above "..._10.json".
    numbered = {}
    for path in query_files:
        match = re.search(r"_(\d+)\.json$", path)
        if match:
            numbered[int(match.group(1))] = path
    if not numbered:
        raise FileNotFoundError(f"No ids_data_query_<n>.json file found in {log_dir}")
    latest_iteration = max(numbered)

    if len(query_files) > 1:
        with open(log_dir / f"ids_data_perplexities_{latest_iteration}.pickle", 'rb') as file:
            latest_perplexities = pickle.load(file)
        with open(numbered[latest_iteration], 'r') as file:
            latest_queries = json.load(file)

        reference = np.array(
            [latest_perplexities[key] for key in latest_queries if latest_perplexities.get(key) is not None],
            dtype=float,
        )
        # Without reference values there is nothing to compare against
        # (this previously raised ZeroDivisionError).
        if reference.size:
            for key, value in scores_dict.items():
                avg_difference = np.abs(value - reference).mean()
                scores_dict[key] = value - (reweighting_factor * avg_difference)
    return latest_iteration


def _hybrid_subsampling(X_pool, budget, llm_subsampling_kwargs, log_dir,
                        scorer_model_name, dataset_name, text_column_name, overwrite_cache):
    """HYBRID ("ActivePrune") selection; returns an array of X_pool positions.

    The lowest-perplexity ``hybrid_weight`` share of ``budget`` is kept
    directly (KenLM perplexity, optionally reweighted). The remaining items,
    optionally shuffled, are trimmed to their last
    ``upper_threshold_llm_scores`` entries (0 keeps all). The rest of the
    budget is filled with the highest ASK-LLM scores among them.
    """
    upper_threshold_llm_scores = llm_subsampling_kwargs['upper_threshold_llm_scores']
    pick_llm_score_candidates_randomly = llm_subsampling_kwargs.get('pick_llm_score_candidates_randomly') or False
    hybrid_weight = float(llm_subsampling_kwargs.get("hybrid_weight") or 0)
    perplexity_reweighting = llm_subsampling_kwargs.get("perplexity_reweighting")
    reweighting_factor = llm_subsampling_kwargs.get("reweighting_factor") or 1

    scores_dict = get_perplexity_scores(X_pool, "kenlm", dataset_name, text_column_name, overwrite_cache=overwrite_cache)

    latest_iteration = None
    if perplexity_reweighting:
        from pathlib import Path
        log_dir = Path(log_dir)  # also accept str paths
        latest_iteration = _reweight_perplexities(scores_dict, log_dir, reweighting_factor)

    # (X_pool position, id, perplexity) for every pool item that has a score.
    # The raw id is kept (not int(id)) so later dict lookups work for any id type.
    perplexity_scores = [
        (position, _id, scores_dict[_id])
        for position, _id in enumerate(X_pool['id'])
        if _id in scores_dict
    ]

    perplexity_samples_count = int(hybrid_weight * budget)
    llm_scores_samples_count = budget - perplexity_samples_count

    sorted_perplexity = sorted(perplexity_scores, key=lambda x: x[2])
    retained = [int(position) for position, _, _ in sorted_perplexity[:perplexity_samples_count]]

    candidates = sorted_perplexity[perplexity_samples_count:]
    if pick_llm_score_candidates_randomly:
        np.random.shuffle(candidates)
    # Keep the last N candidates (highest perplexity unless shuffled); a
    # threshold of 0 keeps every candidate.
    if upper_threshold_llm_scores:
        candidates = candidates[-upper_threshold_llm_scores:]

    ask_llm_dict = get_ask_llm_scores(X_pool, scorer_model_name, dataset_name, text_column_name, overwrite_cache=overwrite_cache)
    ask_llm_scores = []
    for position, _id, _ in candidates:
        if _id not in ask_llm_dict:
            print_and_log("not found", "i", position, "id", _id)
            continue
        ask_llm_scores.append((position, ask_llm_dict[_id]))
    ask_llm_scores.sort(key=lambda x: x[1], reverse=True)
    retained.extend(int(position) for position, _ in ask_llm_scores[:llm_scores_samples_count])

    retained_items = np.array(retained, dtype=int)

    if perplexity_reweighting:
        # NOTE(review): keys are X_pool positions, while _reweight_perplexities
        # looks them up with the entries of ids_data_query_<n>.json -- confirm
        # those files store positions rather than ids.
        position_to_perplexity = {position: score for position, _, score in perplexity_scores}
        perplexities_to_store = {
            idx: position_to_perplexity[idx] for idx in retained_items if idx in position_to_perplexity
        }
        with open(log_dir / f'ids_data_perplexities_{latest_iteration + 1}.pickle', 'wb') as file:
            pickle.dump(perplexities_to_store, file)

    return retained_items


def llm_subsampling(uncertainty_estimates, gamma_or_k_confident_to_save, **kwargs):
    """Select the X_pool rows to keep according to an LLM-based pruning strategy.

    :param uncertainty_estimates: Only its length is used, as the pool size.
    :param gamma_or_k_confident_to_save: Selection budget. A float is a
        fraction of the pool, an int an absolute count; either is capped at
        the pool size.
    :param kwargs: ``X_pool`` (with an ``'id'`` column) and
        ``llm_subsampling_kwargs`` (``method``, ``scorer_model_name``,
        ``dataset_name``, ``text_column_name``, ``offset``, ...). ``log_dir``
        is required for HYBRID with ``perplexity_reweighting``.
    :return: np.ndarray of positional indices into X_pool, or None when
        ``method`` is not handled here.
    """
    llm_subsampling_kwargs = kwargs.get('llm_subsampling_kwargs') or {}

    overwrite_cache = llm_subsampling_kwargs.get("overwrite_cache", False)
    scorer_model_name = llm_subsampling_kwargs.get("scorer_model_name", "mixtral")
    text_column_name = llm_subsampling_kwargs.get("text_column_name", "en")
    dataset_name = llm_subsampling_kwargs.get("dataset_name")
    method = llm_subsampling_kwargs.get("method")
    offset = int(llm_subsampling_kwargs.get("offset") or 0)

    length = len(uncertainty_estimates)
    if isinstance(gamma_or_k_confident_to_save, float):
        gamma_or_k_confident_to_save = int(gamma_or_k_confident_to_save * length)
    budget = min(gamma_or_k_confident_to_save, length)

    print_and_log("kwargs")
    print_and_log(kwargs)

    X_pool = kwargs.get('X_pool')

    # Rankings are positions into X_pool. X_pool shrinks between iterations
    # (used items are removed), so scores are looked up by the current ids.
    if method == PruningStrategies.ASK_LLM:
        scores_dict = get_ask_llm_scores(X_pool, scorer_model_name, dataset_name, text_column_name, overwrite_cache=overwrite_cache)
        ranked = _rank_pool_positions(X_pool['id'], scores_dict, descending=True)
        retained_items = ranked[offset: budget + offset]
    elif method in (PruningStrategies.TOP_K_PERPLEXITY, PruningStrategies.BOTTOM_K_PERPLEXITY):
        scores_dict = get_perplexity_scores(X_pool, scorer_model_name, dataset_name, text_column_name, overwrite_cache=overwrite_cache)
        descending = method == PruningStrategies.TOP_K_PERPLEXITY
        ranked = _rank_pool_positions(X_pool['id'], scores_dict, descending=descending)
        retained_items = ranked[offset: budget + offset]
    elif method == PruningStrategies.MIDDLE_K_PERPLEXITY:
        scores_dict = get_perplexity_scores(X_pool, scorer_model_name, dataset_name, text_column_name, overwrite_cache=overwrite_cache)
        ranked = _rank_pool_positions(X_pool['id'], scores_dict)
        start_index = max((len(ranked) - budget) // 2, 0)
        print_and_log("Start index: ", start_index)
        retained_items = ranked[start_index:start_index + budget]
    elif method == PruningStrategies.HYBRID:
        retained_items = _hybrid_subsampling(
            X_pool, budget, llm_subsampling_kwargs, kwargs.get("log_dir"),
            scorer_model_name, dataset_name, text_column_name, overwrite_cache,
        )
    else:
        print_and_log(f"llm_subsampling: method {method!r} is not handled here; nothing selected")
        return None

    print_and_log("Length of retained indices: ", len(retained_items))
    return retained_items

