from .operators import Operator
import numpy as np
import torch
from typing import Dict
import os
import json
from .basic_model_loader import load_model, load_tokenizer
from .utils import get_max_length, ENABLE_LOGGING, log
from loguru import logger


class RunnableOperator(Operator):
    """
    Base class for operators that generate a (log-)probability distribution themselves,
    as opposed to operators that only combine or modify distributions produced by others.
    """

    def __init__(self, prompt_string="", model=None, speculative_factor=1,
                 prompt_template=lambda prompt_string, input_string: prompt_string + input_string, run_priority=0, group=None,
                 outputs_logprobs=True, **kwargs):
        """
        Initialize a runnable operator instance. A runnable operator is an operator that generates a probability distribution instead of modifies an existing one.

        Args:
            prompt_string (str): String to be used as a prompt. Only used in specific runnable operators
            model (optional): Model to be used for operation. If None, the model must be set later to the default model to be used.
            speculative_factor (int): Factor for speculative sampling.
            prompt_template (callable): Function for generating prompt. Takes two arguments: prompt_string and input_string. The operator will be run on prompt_template(..., ...) + continuation_tokens
            run_priority (int): Priority for running the operation. Higher priority means the operation will be run first, especially important for the classifier.
            group (optional): Group to which the operator belongs. This ensures that speculative sampling will not be tried when not all operators of a group are finished.
            outputs_logprobs (bool): Forwarded to ``Operator``; presumably marks the output as log-probabilities (confirm in Operator).
            **kwargs: Arbitrary keyword arguments, forwarded to ``Operator``.
        """
        super().__init__(speculative_factor=speculative_factor, model=model, prompt_string=prompt_string,
                         prompt_template=prompt_template, run_priority=run_priority, group=group,
                         outputs_logprobs=outputs_logprobs, **kwargs)
        # Per-operator state reused across runs; subclasses decide what is stored here.
        self.cache = None

    def run_condition(self, new_tokens, trigger_end):
        """
        Determine if the run condition is met.

        A sample whose ``trigger_end`` flag is set (and whose token count is non-negative) is
        counted as having at least ``speculative_factor`` new tokens.

        Args:
            new_tokens: Number of new tokens per sample in the batch
            trigger_end: Whether to trigger the end for each sample in the batch.

        Returns:
            bool: True when the mean adjusted number of new tokens reaches ``speculative_factor``.
        """
        adjusted = [
            n if not trigger_end[i] or n < 0 else max(n, self.speculative_factor)
            for i, n in enumerate(new_tokens)
        ]
        # Alternative policy: np.max(adjusted) + 1 >= self.speculative_factor
        return np.mean(adjusted) >= self.speculative_factor

    def delete_cache(self, index=None, from_=None):
        """
        Delete the cache.

        Only a full reset (both ``index`` and ``from_`` are None) is supported here; partial
        deletion requests are ignored by this base implementation.
        """
        if from_ is None and index is None:
            self.cache = None

    def run(self, tokenized_inputs, **kwargs):
        """
        Run the operation. This method needs to be implemented by subclasses.

        Args:
            tokenized_inputs: Inputs that have been tokenized.
            **kwargs: Arbitrary keyword arguments.

        Raises:
            NotImplementedError: This method needs to be implemented by subclasses.
        """
        raise NotImplementedError("This method needs to be implemented by subclasses.")

    def runnable_operators(self):
        """
        Get a list of runnable operators used by the operator, usually only this operator itself.

        Returns:
            list: List of runnable operators.
        """
        return [self]

    def same_operator(self, other):
        """
        Determine if the other operator is the same as this one. This is important to avoid redundant runs of the same operator in a formula

        Args:
            other: Other operator (or an operator id string) to be compared.

        Returns:
            bool: Whether the other operator is the same as this one.
        """
        if isinstance(other, str):
            return self.id() == other
        if isinstance(other, RunnableOperator):
            return self.id() == other.id()
        return False

    def norm(self, runnable_operator_outputs=None):
        """
        Compute the norm of the operator.

        Args:
            runnable_operator_outputs (optional): Outputs of runnable operators.

        Returns:
            int: 1 if no outputs are given or this operator is finished, otherwise 0.
        """
        if runnable_operator_outputs is None or self.is_finished(runnable_operator_outputs):
            return 1
        return 0

    def is_finished(self, runnable_operator_outputs):
        """
        Determine if the operation is finished.

        Args:
            runnable_operator_outputs: Mapping from operator (or operator id) to its output, or None if not yet computed.

        Returns:
            bool: Whether a non-None output exists for an operator equal to this one.
        """
        return any(
            self.same_operator(output) and runnable_operator_outputs[output] is not None
            for output in runnable_operator_outputs
        )

    def evaluate(self, runnable_operator_outputs : Dict, normalize : bool = True):
        """
        Evaluate the operation.

        Args:
            runnable_operator_outputs (Dict): Outputs of runnable operators.
            normalize (bool): Whether to normalize the evaluation (unused here).

        Returns:
            The stored output of the first matching operator, or 0 if none is available.
        """
        for output in runnable_operator_outputs:
            if self.same_operator(output) and runnable_operator_outputs[output] is not None:
                return runnable_operator_outputs[output]
        return 0

    def generate_settings(self):
        """
        Generate settings for the operation.

        The prompt template callable is serialized by rendering it with the placeholder
        strings ``{{prompt_string}}`` and ``{{input_string}}``.

        Returns:
            dict: Settings for the operation.
        """
        kwargs = super().generate_settings()
        kwargs["prompt_template"] = self.prompt_template("{{prompt_string}}", "{{input_string}}")
        return kwargs

    @staticmethod
    def load_from_settings(settings):
        """
        Load operator from settings.

        The serialized template string is turned back into a callable. The caller's
        ``settings`` dict is not modified, so it can be reused for further loads.

        Args:
            settings (dict): Settings for the operation.

        Returns:
            Operator: Operator loaded from settings.
        """
        settings = dict(settings)
        template = settings["prompt_template"]

        def prompt_template(prompt_string, input_string):
            return template.replace("{{prompt_string}}", prompt_string).replace("{{input_string}}", input_string)

        settings["prompt_template"] = prompt_template
        return Operator.load_from_settings(settings)

    def get_prompt(self, input_string):
        """
        Get the prompt for the operation.

        Args:
            input_string (str): String to be used as input.

        Returns:
            str: The prompt, i.e. ``prompt_template(prompt_string, input_string)``.
        """
        return self.prompt_template(self.prompt_string, input_string)

    def get_store_params(self):
        """
        Get parameters for storing the operation.

        Returns:
            dict: Parameters for storing the operation.
        """
        return {
            "class": self.__class__.__name__,
            "model": self.model,
            "speculative_factor": self.speculative_factor,
            "prompt_template": self.prompt_template(self.prompt_string, "{{input_string}}")
        }

    def id(self):
        """
        Get the ID of the operation.

        Returns:
            str: ID built from the class name and the stored kwargs, with the prompt template rendered.
        """
        kwargs = self.kwargs.copy()
        kwargs["prompt_template"] = self.prompt_template(self.prompt_string, "{{input_string}}")
        return f"{self.__class__.__name__}(**{kwargs})"

    def load_model(self, dtype):
        """
        Load the model for the operation. Only needs to be overwritten when a model is necessary

        Args:
            dtype: Data type for the model.

        Returns:
            None
        """
        return None

    def initialize_after_model_set(self):
        """
        Initialize the operation after the model is set (to the default model if necessary).

        Raises:
            AssertionError: If the model is not set before initializing.
        """
        assert self.model is not None, "Model must be set before initializing."
        

class LLMPrompt(RunnableOperator):
    """Runnable operator that scores continuations with a language model, optionally reusing its key/value cache."""

    def __init__(self, prompt_string, model=None, speculative_factor=1,
                 prompt_template=lambda prompt_string, input_string: prompt_string + "\n" + input_string, dtype=None, group=None,
                 enable_cache=True, dim_keys_past=2, dim_values_past=2, run_eager=False, tokenizer=None, **kwargs):
        """
        Initializes an LLM Prompt. This is a runnable operator that uses a language model to generate a probability distribution.
        Args:
            prompt_string (str): String to be used as a prompt. Only used in specific runnable operators
            model (optional): Model to be used for operation. If None, the model must be set later to the default model to be used.
            speculative_factor (int): Factor for speculative sampling.
            prompt_template (callable): Function for generating prompt. Takes two arguments: prompt_string and input_string. The operator will be run on prompt_template(..., ...) + continuation_tokens
            dtype (optional): Data type for the model.
            group (optional): Group to which the operator belongs. This ensures that speculative sampling will not be tried when not all operators of a group are finished.
            enable_cache (bool): Whether to reuse the model's key/value cache between runs.
            dim_keys_past (int): Axis of the batched cached keys that holds the sequence dimension. Left at 2, it is switched to 1 for the listed Falcon models.
            dim_values_past (int): Same as dim_keys_past, for cached values.
            run_eager (bool): If True, the model's past_key_values are stored and reused as-is without per-sample re-alignment (disables speculative sampling).
            tokenizer (optional): Tokenizer whose length is used to slice the model logits. If None, it is loaded from the model name.
            **kwargs: Accepted for compatibility but NOT forwarded to the base class (e.g. run_priority currently has no effect here).
        """
        if dim_keys_past == 2 and dim_values_past == 2:
            # set the dims based on the model
            if model in ["tiiuae/falcon-7b", "tiiuae/falcon-7b-instruct", "tiiuae/falcon-40b", "tiiuae/falcon-40b-instruct"]:
                dim_keys_past = 1
                dim_values_past = 1

        super().__init__(prompt_string=prompt_string, model=model, speculative_factor=speculative_factor,
                         prompt_template=prompt_template, group=group, enable_cache=enable_cache,
                         dim_keys_past=dim_keys_past, dim_values_past=dim_values_past, run_eager=run_eager)
        self.dtype = dtype
        # Number of logits to keep (vocabulary size of the tokenizer); None keeps all logits.
        self.tokenizer_length = None
        self.tokenizer = tokenizer
        self.previous_input_ids = None
        # Sequence axis of the batched cache in the layout this class works with internally.
        self.default_dim = 2
        if self.run_eager:
            log(logger.warning, "Eager mode is enabled. This will make several features, such as speculative sampling, inaccessible.")

    def load_model(self, dtype):
        """
        Loads the model for the operation.
        :param dtype: Data type for the model; overridden by self.dtype when that is set.
        :return: The model object (returned as-is if self.model is not a name string).
        """
        if not isinstance(self.model, str):
            return self.model
        if self.dtype is None:
            return load_model(self.model, dtype=dtype)
        return load_model(self.model, dtype=self.dtype)

    def initialize_after_model_set(self):
        """
        Determines the vocabulary size used to slice the model logits.

        Uses the tokenizer passed to the constructor when available; otherwise it loads one from the model name.
        """
        tokenizer = self.tokenizer if self.tokenizer is not None else load_tokenizer(self.model)
        self.tokenizer_length = len(tokenizer)

    def select_from_sample_cache(self, sample, from_=None, until=None):
        """
        Slices every key/value tensor of a per-sample cache along dim 1 (in place) and returns it.
        """
        for i in range(len(sample)):
            for j in range(len(sample[i])):
                sample[i][j] = sample[i][j][:, from_:until]

        return sample

    def swap_dimensions(self, sample):
        """
        Transposes per-sample keys/values between the model's layout and the internal one.

        Swapping the same two axes twice restores the original, so this is used in both directions.
        """
        for i in range(len(sample)):
            # keys, values
            if self.default_dim != self.dim_keys_past:
                sample[i][0] = sample[i][0].transpose(self.default_dim - 1, self.dim_keys_past - 1)
            if self.default_dim != self.dim_values_past:
                sample[i][1] = sample[i][1].transpose(self.default_dim - 1, self.dim_values_past - 1)

        return sample

    def select_sample_cache(self, cache, sample_index):
        """
        Extracts the [key, value] pair of every layer for one batch element and converts it to the internal layout.
        """
        sample = []
        for i in range(len(cache)):
            sample.append([
                cache[i][0][sample_index],
                cache[i][1][sample_index]
            ])
        sample = self.swap_dimensions(sample)
        return sample

    def pad_sample(self, sample, target_size):
        """
        Left-pads (with zeros) or truncates every per-sample tensor along dim 1 to target_size.
        """
        for i in range(len(sample)):
            for j in range(len(sample[i])):
                pad_size = target_size - sample[i][j].size(1)
                # (last-dim left, last-dim right, second-to-last left, second-to-last right)
                pad = (0, 0, pad_size, 0)
                if pad_size > 0:
                    sample[i][j] = torch.nn.functional.pad(sample[i][j], pad, "constant", 0)
                elif pad_size < 0:
                    sample[i][j] = sample[i][j][:, :target_size]
        return sample

    def stack_samples(self, samples):
        """
        Stacks per-sample caches back into batched [key, value] tensors per layer (new dim 0).
        """
        stacked_samples = []
        for i in range(len(samples[0])):
            stacked_mult = []
            for j in range(len(samples[0][i])):
                stacked = torch.stack(
                    [samples[k][i][j] for k in range(len(samples))], dim=0
                )
                stacked_mult.append(stacked)
            stacked_samples.append(stacked_mult)
        return stacked_samples

    def store_cache(self, past_key_values, input_ids, lengths):
        """
        Stores the model's key/value cache per sample, keeping only each sample's last `length` positions.

        The matching input ids are remembered so that a later call can detect where the prefix diverged.
        """
        # reverting the terrible design choice by huggingface
        if self.run_eager:
            self.cache = past_key_values
            return
        self.cache = []
        self.previous_input_ids = []
        for i, length in enumerate(lengths):
            self.cache.append(
                self.select_from_sample_cache(self.select_sample_cache(past_key_values, i), from_=-length)
            )
            self.previous_input_ids.append(
                input_ids[i, -length:]
            )

    def common_starting_elements(self, t1, t2):
        """
        Returns the length of the common prefix of two 1-D tensors.
        """
        min_length = min(t1.size(0), t2.size(0))
        eq = torch.eq(t1[:min_length], t2[:min_length])
        if not eq.any():
            return 0
        if eq.all():
            return min_length

        return torch.where(eq == 0)[0][0].item()

    def delete_previous_cache(self, new_input_ids, lengths):
        """
        Truncates each sample's cache to the prefix it still shares with the new (unpadded) input ids.
        """
        if self.run_eager:
            return
        input_ids = [
            new_input_ids[i, -lengths[i]:] for i in range(len(lengths))
        ]
        elements = [self.common_starting_elements(input_ids[i], self.previous_input_ids[i]) for i in range(len(lengths))]
        self.cache = [
            self.select_from_sample_cache(self.cache[i], until=elements[i]) for i in range(len(lengths))
        ]

    def prepare_inputs(self, input_ids, attention_mask, n_new_tokens):
        """
        Builds the model kwargs. When a cache exists, only the last max(n_new_tokens) tokens are fed and
        the per-sample caches are trimmed, padded to a common length and re-batched as past_key_values.
        """
        max_new_tokens = max(n_new_tokens)
        past_key_values = None
        if self.cache is not None and self.enable_cache:
            input_ids = input_ids[:, -max_new_tokens:]
            if self.run_eager:
                past_key_values = self.cache
            else:
                # Samples with fewer new tokens than the batch maximum re-feed some already-cached
                # tokens, so those positions are removed from their cache first.
                past_key_values = self.pad_cache(
                    [self.select_from_sample_cache(self.cache[i], until=-max_new_tokens + n_new_tokens[i]) if max_new_tokens > n_new_tokens[i] else self.cache[i]
                     for i in range(len(n_new_tokens))],
                    attention_mask.shape[1] - max_new_tokens
                )
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "use_cache": True,
            "past_key_values": past_key_values
        }

    def pad_cache(self, cache, length):
        """
        Pads every per-sample cache to `length`, converts back to the model layout and stacks into a batch.
        """
        for i in range(len(cache)):
            cache[i] = self.pad_sample(cache[i], length)
            cache[i] = self.swap_dimensions(cache[i])
        stacked_samples = self.stack_samples(cache)

        return stacked_samples

    def delete_cache(self, index=None, from_=None):
        """
        Drops the whole key/value cache and the remembered input ids.

        `index` and `from_` are accepted for interface compatibility; per-sample deletion is not
        implemented, so the entire cache is always cleared.
        """
        self.previous_input_ids = None
        self.cache = None

    def run(self, tokenized_inputs, loaded_models, model_new_tokens, use_cache, **kwargs):
        """
        Runs the model on the tokenized inputs.
        :param tokenized_inputs: Inputs that have been tokenized.
        :param loaded_models: Models that have been loaded. The model for this operation is in loaded_models[self.model]
        :param model_new_tokens: Number of new tokens per sample in the batch
        :param use_cache: Whether the key/value cache may be read and updated during this call.
        :return: One tensor per sample with the log-probabilities of its last model_new_tokens[i] positions
            (empty when model_new_tokens[i] is 0).
        :raises RuntimeError: If the model call fails; the original error is chained as the cause.
        """
        model = loaded_models[self.model] if isinstance(self.model, str) else self.model
        # Number of non-padding tokens per sample.
        lengths = torch.sum(tokenized_inputs.attention_mask, dim=-1)
        if self.cache is not None and self.enable_cache and use_cache:
            self.delete_previous_cache(tokenized_inputs.input_ids, lengths)

        actual_inputs = self.prepare_inputs(input_ids=tokenized_inputs.input_ids.to(model.device),
                                            attention_mask=tokenized_inputs.attention_mask.to(model.device),
                                            n_new_tokens=model_new_tokens)
        # run model
        with torch.no_grad():
            try:
                model_output = model(**actual_inputs, return_dict=True)
            except RuntimeError as e:
                raise RuntimeError("Error thrown when running model. This is probably caused because the model handles the key-value cache differently. Consider setting dim_values_past and dim_keys_past values or disabling the key-value cache. Alternatively, you can set run_eager=True, but this feature is incompatible with speculative sampling and some other features.") from e
            # Restrict to the tokenizer's vocabulary (tokenizer_length None keeps every logit).
            logprobs = torch.log_softmax(model_output.logits[:, :, :self.tokenizer_length], dim=-1)

        if self.enable_cache and use_cache:
            self.store_cache(model_output.past_key_values, tokenized_inputs.input_ids, lengths)

        seq_len = logprobs.shape[1]
        # An explicit start index avoids the `-0:` pitfall, which would return the whole sequence.
        return [
            logprobs[i, max(seq_len - int(model_new_tokens[i]), 0):]
            for i in range(logprobs.shape[0])
        ]

    def __str__(self):
        return f"LLMPrompt('{self.prompt_string}', model='{self.model}')"
        
        
class Autocomplete(RunnableOperator):
    """Bigram-style operator: always predicts the token most often seen after the previous token in a corpus."""

    def __init__(self, corpus=None, speculative_factor=1, from_save_file=None, group=None, store_corpus=True, **kwargs):
        """
        Runnable operator that uses a corpus to generate a probability distribution that just predicts the most likely token based on the previous token
        :param corpus: Corpus to be used for training the autocomplete model
        :param speculative_factor: Factor for speculative sampling
        :param from_save_file: File from which to load the autocomplete model (and to which a freshly fitted model is saved)
        :param group: Group to which the operator belongs. This ensures that speculative sampling will not be tried when not all operators of a group are finished.
        :param store_corpus: If False, the corpus is removed from the stored kwargs (debugging only).
        :param kwargs: Arbitrary keyword arguments.
        """
        assert corpus is not None or (from_save_file is not None and os.path.isfile(from_save_file)), "Either corpus or from_save_file must be specified."
        self.tokenizer = None
        self.mapper = None
        self.unknown_mapper = None
        self.start_mapper = None
        super().__init__(speculative_factor=speculative_factor, corpus=corpus, from_save_file=from_save_file, group=group, **kwargs)
        if not store_corpus:
            log(logger.warning, "Not storing the corpus. This will make the model not loadable. Only do this when debugging")
            del self.kwargs["corpus"]

    def add_token(self, to_token, from_token=None):
        """
        Adds a token to the mapper, which keeps track of the most likely token to follow a given token.
        :param to_token: Token to be added.
        :param from_token: Token from which to_token is most likely to follow. None if to_token is the first token in a sequence.
        """
        if from_token is None:
            self.start_mapper[to_token] = self.start_mapper.get(to_token, 0) + 1
        else:
            followers = self.mapper.setdefault(from_token, dict())
            followers[to_token] = followers.get(to_token, 0) + 1
        # Global counts, used as fallback for tokens never seen as a predecessor.
        self.unknown_mapper[to_token] = self.unknown_mapper.get(to_token, 0) + 1

    def return_max_token(self, dict_):
        """
        Returns the most likely token in a dictionary (first one wins on ties).
        :param dict_: Dictionary containing tokens and their counts.
        """
        return int(max(dict_, key=dict_.get, default=None))

    def initialize_after_model_set(self):
        """
        Function that is run after the model is set (to the default model if necessary). Fits the autocomplete model or loads it if it in the save file.
        A freshly fitted model is written to from_save_file (when given), after which the corpus is dropped.
        """
        self.tokenizer = load_tokenizer(self.model)

        loaded = self.from_save_file is not None and os.path.isfile(self.from_save_file)
        if loaded:
            self.load_from_file(self.from_save_file)
        else:
            self._fit_from_corpus()

        if self.from_save_file:
            if not loaded:
                # Re-saving a model that was just loaded would only rewrite the same content.
                self.save_to_file(self.from_save_file)
            # pop: the corpus may already be gone from kwargs (store_corpus=False).
            self.kwargs.pop("corpus", None)
            try:
                del self.corpus
            except AttributeError:
                pass

    def _fit_from_corpus(self):
        """Counts token transitions in self.corpus and builds the argmax lookup tensor."""
        self.corpus = [sentence for sentence in self.corpus if isinstance(sentence, str) and len(sentence) > 0]
        self.unknown_mapper = dict()
        self.start_mapper = dict()
        self.mapper = dict()

        tokenization = self.tokenizer(self.corpus)
        for sentence in tokenization.input_ids:
            previous = None
            for token in sentence:
                self.add_token(to_token=token, from_token=previous)
                previous = token

        self.start_most_common = self.return_max_token(self.start_mapper)
        # array_mapper[t] = most frequent successor of t (global most frequent token if t was never a predecessor).
        self.array_mapper = torch.tensor([
            self.return_max_token(self.mapper.get(i, self.unknown_mapper)) for i in range(len(self.tokenizer))
        ])

    def load_from_file(self, file):
        """
        Loads the autocomplete model from a file.
        :param file: File from which to load the autocomplete model.
        """
        with open(file, "r") as f:
            file_content = json.load(f)
        self.start_most_common = file_content["start_most_common"]
        # Stored as a list in JSON; convert back to a tensor so it matches the fitted model (and can be re-saved).
        self.array_mapper = torch.tensor(file_content["array_mapper"])

    def save_to_file(self, file):
        """
        Saves the autocomplete model to a file.
        :param file: File to which to save the autocomplete model.
        """
        with open(file, "w") as f:
            json.dump({"start_most_common": self.start_most_common,
                       "array_mapper": self.array_mapper.tolist()}, f)

    def run_on_token_past(self, tokens):
        """
        Outputs a probability distribution for the next token based on the previous tokens.
        The result is a log-probability vector: 0.0 at the predicted token, -inf elsewhere.
        """
        output = torch.zeros(len(self.tokenizer), dtype=torch.float32) - torch.inf
        if len(tokens) == 0:
            output[self.start_most_common] = 0.0
        else:
            output[self.array_mapper[tokens[-1]]] = 0.0

        return output

    def run_on_sample(self, tokens, n_new_tokens):
        """
        Runs the autocomplete model on a sample for all number of new tokens.
        Row k of the output is the prediction made from tokens[:len(tokens) - (n_new_tokens - 1 - k)].
        :param tokens: Tokens to be used as input.
        :param n_new_tokens: Number of new tokens to be generated.
        """
        output = torch.zeros((n_new_tokens, len(self.tokenizer)))
        if n_new_tokens > 0:
            output[n_new_tokens - 1] = self.run_on_token_past(tokens)
        for i in range(2, n_new_tokens + 1):
            output[n_new_tokens - i] = self.run_on_token_past(tokens[:-i + 1])

        return output

    def run(self, tokenized_inputs, model_new_tokens, **kwargs):
        """
        Runs the autocomplete model on the tokenized inputs.
        :param tokenized_inputs: Inputs that have been tokenized.
        :param model_new_tokens: Number of new tokens per sample in the batch
        """
        output = [
            self.run_on_sample(tokenized_inputs.input_ids[sample], model_new_tokens[sample])
            for sample in range(tokenized_inputs.input_ids.shape[0])
        ]
        return self.set_to_minimum(output)

    def __str__(self):
        return f"Autocomplete(model='{self.model}', speculative_factor={self.speculative_factor})"

    def id(self):
        return f"Autocomplete(model='{self.model}')"


class Classifier(RunnableOperator):
    """Runnable operator that reweights a formula's top-k next tokens using a sequence classifier."""

    def __init__(self, formula, model, n_runs_per_sample, batch_size=None, dtype=None, prompt_string="",
                 prompt_template=lambda prompt_string, input_string: prompt_string + input_string, minimize=False, group=None, use_bayes=False, index=1,
                 tokenizer=None, **kwargs):
        """
        Initializes the classifier operator. This is a runnable operator that uses a classifier to generate a probability distribution.
        :param formula: Formula to be used for the classifier. Only the n_runs_per_sample tokens will be used for classification, in order to have a normal amount of compute
        :param model: Model to be used for the classifier.
        :param n_runs_per_sample: Number of tokens to be used for classification.
        :param batch_size: Batch size to be used for classification. If None, all samples will be classified at once. (thus batch_size would be batch_size of the generator multiplied by n_runs_per_sample + 1)
        :param dtype: Data type for the model.
        :param prompt_string: String to be used as a prompt.
        :param prompt_template: Function for generating prompt. Takes two arguments: prompt_string and input_string. The operator will be run on prompt_template(..., ...) + continuation_tokens
        :param minimize: Whether to minimize the output of the classifier.
        :param group: Group to which the operator belongs. This ensures that speculative sampling will not be tried when not all operators of a group are finished.
        :param use_bayes: If True, output the classifier log-prob of each candidate minus that of the unmodified prefix (no renormalization).
        :param index: Index of the classifier output class that is scored.
        :param tokenizer: Tokenizer for the classifier; loaded from the model name when None.
        :param kwargs: Accepted but not forwarded to the base class.
        """
        super().__init__(formula=formula, model=model, batch_size=batch_size, n_runs_per_sample=n_runs_per_sample, run_priority=-1,
                         prompt_string=prompt_string, prompt_template=prompt_template, minimize=minimize, group=group, use_bayes=use_bayes, index=index)

        self.tokenizer = load_tokenizer(self.model) if tokenizer is None else tokenizer
        # Lazily determined from the model config on the first run.
        self.max_length = None
        self.dtype = dtype

    def load_model(self, dtype):
        """
        Loads the model for the operation.
        :param dtype: Data type for the model; overridden by self.dtype when that is set.
        """
        if not isinstance(self.model, str):
            return self.model
        if self.dtype is None:
            return load_model(self.model, dtype=dtype, classification=True)
        return load_model(self.model, dtype=self.dtype, classification=True)

    def run(self, tokenized_inputs, loaded_models, model_new_tokens, new_prediction_history, other_tokenizer, **kwargs):
        """
        Runs the classifier on the tokenized inputs.
        :param tokenized_inputs: Inputs that have been tokenized.
        :param loaded_models: Models that have been loaded. The model for this operation is in loaded_models[self.model]
        :param model_new_tokens: Number of new tokens per sample in the batch
        :param new_prediction_history: Prediction History for the batch, this is used to determine the tokens to be used for classification.
        :param other_tokenizer: Tokenizer to be used for the classifier. This is necessary in order to prepare the inputs for the classifier.
        :return: Tensor of shape (batch, 1, len(other_tokenizer)) with the next-token scores.
        """
        assert all(tokens == 1 for tokens in model_new_tokens), "model_new_tokens must be 1 for this one, since backtracing is too tricky."
        # NOTE: here, we assume model_new_tokens is always 1 and thus we want to predict the next token in the sequence
        model = loaded_models[self.model] if isinstance(self.model, str) else self.model
        n_samples = len(tokenized_inputs.input_ids)
        # Each sample is classified once unchanged and once per top-k candidate token.
        n_candidates = 1 + self.n_runs_per_sample

        output_formula = [
            self.formula.evaluate(history, normalize=True) for history in new_prediction_history
        ]
        topk_tokens = [
            torch.topk(formula_output, k=self.n_runs_per_sample, dim=-1) for formula_output in output_formula
        ]

        input_samples = []
        for i in range(n_samples):
            prefix = tokenized_inputs.input_ids[i].tolist()
            input_samples.append(other_tokenizer.decode(prefix, skip_special_tokens=True))
            for token in topk_tokens[i].indices:
                input_samples.append(other_tokenizer.decode(prefix + [token], skip_special_tokens=True))

        if self.max_length is None:
            self.max_length = get_max_length(model.config)
        # -2 in max_length is because of bs behavior by the roberta model
        encoded_samples = self.tokenizer.batch_encode_plus(input_samples, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length - 2).to(model.device)

        if "token_type_ids" in encoded_samples:
            del encoded_samples["token_type_ids"]

        n_inputs = len(encoded_samples["input_ids"])
        # Default: one batch containing every encoded input (n_samples * n_candidates).
        batch_size = self.batch_size if self.batch_size is not None else max(n_inputs, 1)

        model_outputs = []
        for start in range(0, n_inputs, batch_size):
            with torch.no_grad():
                model_output = model(input_ids=encoded_samples["input_ids"][start:start + batch_size],
                                     attention_mask=encoded_samples["attention_mask"][start:start + batch_size])
            model_outputs.append(model_output.logits)

        model_outputs = torch.cat(model_outputs, dim=0)

        if not self.minimize:
            model_outputs_logprobs = torch.log_softmax(model_outputs, dim=-1)[:, self.index]
        elif model_outputs.shape[1] == 2:
            # Binary classifier: minimizing class `index` means maximizing the other class.
            model_outputs_logprobs = torch.log_softmax(model_outputs, dim=-1)[:, 1 - self.index]
        else:
            model_output_probs = 1 - torch.softmax(model_outputs, dim=-1)[:, self.index]
            # Clamp to avoid log(0).
            model_outputs_logprobs = torch.log(torch.max(model_output_probs, torch.tensor([1e-12], device=model_output_probs.device)))

        output_logprobs = torch.zeros((n_samples, 1, len(other_tokenizer)), device=output_formula[0].device)

        for i in range(n_samples):
            # change the topk tokens with factor * (model_token_output - model_no_token_output)
            output_model_sample = model_outputs_logprobs[i * n_candidates:(i + 1) * n_candidates]
            normal_logprob = output_model_sample[0]
            if not self.use_bayes:
                # Baseline for non-candidate tokens of THIS sample only.
                output_logprobs[i] += normal_logprob

            for j in range(self.n_runs_per_sample):
                output_logprobs[i, -1, topk_tokens[i].indices[j]] = output_model_sample[j + 1]
                if self.use_bayes:
                    output_logprobs[i, -1, topk_tokens[i].indices[j]] -= normal_logprob

        if not self.use_bayes:
            output_logprobs = torch.log_softmax(output_logprobs, dim=-1)

        return output_logprobs

    def __str__(self):
        return f"Classifier('{self.prompt_string}', model='{self.model}', formula='{self.formula}', n_runs_per_sample={self.n_runs_per_sample}, minimize={self.minimize}, bayes={self.use_bayes})"

    def runnable_operators(self):
        """
        Returns a list of runnable operators used by the operator: this operator plus those of its formula.
        """
        return [self] + self.formula.runnable_operators()

    def norm(self, runnable_operator_outputs=None):
        """
        Returns the norm of the operator: 0 in Bayes mode (run() outputs differences to the baseline, "- normal_logprob"), 1 otherwise.
        """
        if self.use_bayes:
            return 0
        return 1

    def is_finished(self, runnable_operator_outputs):
        """
        Returns whether the operation is finished. This is the case if the formula is finished and the class itself has been run.
        """
        return super().is_finished(runnable_operator_outputs) and self.formula.is_finished(runnable_operator_outputs)
    
    
class Vector(RunnableOperator):
    def __init__(self, from_file=None, **kwargs):
        """
        Initializes a vector operator. This is a runnable operator that outputs the same fixed vector for every token.
        :param from_file: File from which to load the vector. The stored tensor has the layout written by ``fit``:
                          its first entry is the norm, the remaining entries are the vector itself.
        :param kwargs: Further keyword arguments, forwarded to ``RunnableOperator.__init__`` (previously they were
                       silently dropped).
        """
        super().__init__(from_file=from_file, **kwargs)
        self.vector = None
        self.norm_ = 0

        if self.from_file is not None and os.path.isfile(self.from_file):
            # NOTE(review): torch.load unpickles the file; only load vectors from trusted sources.
            stored = torch.load(self.from_file)
            # layout: [norm, *vector]
            self.norm_ = stored[0]
            self.vector = stored[1:]

    def evaluate(self, runnable_operator_outputs : Dict, normalize : bool = True):
        """
        Returns the evaluation of the operation. This is the vector itself.
        (Looking up a precomputed entry in ``runnable_operator_outputs`` via ``same_operator`` would be needed
        only if this operator ever produced input-dependent outputs.)
        """
        assert self.vector is not None, "Vector must be specified for this one."
        return self.vector

    def run(self, tokenized_inputs, model_new_tokens, **kwargs):
        """
        Runs the vector operator on the tokenized inputs.
        :param tokenized_inputs: Inputs that have been tokenized (unused, the output does not depend on them).
        :param model_new_tokens: Number of new tokens per sample in the batch
        :return: One tensor per sample whose leading dimension is ``model_new_tokens[i]`` and where every row is a
                 copy of ``self.vector``. A sample with 0 new tokens yields an empty tensor (``torch.stack`` on an
                 empty list used to raise here).
        """
        assert self.vector is not None, "Vector must be specified for this one."
        repeat_dims = [1] * self.vector.dim()
        return [
            self.vector.unsqueeze(0).repeat(int(n_tokens), *repeat_dims) for n_tokens in model_new_tokens
        ]

    def __str__(self):
        return f"Vector('{self.from_file}')"

    def norm(self, runnable_operator_outputs=None):
        """
        Returns the norm of the operator (``norm_``, loaded from file or computed by ``fit``; 0 if neither happened).
        """
        return self.norm_

    def is_finished(self, runnable_operator_outputs):
        """
        Returns whether the operation is finished. This is always the case.
        """
        return True

    def fit(self, prompt_arithmetic, input_sentences, max_tokens=None, save_file=None, batch_size=1):
        """
        Fits the vector operator to a prompt arithmetic model that has a certain formula. The vector is the mean of
        the unnormalized outputs (``forward(..., normalize=False)``) over all tokens of the input_sentences.
        :param prompt_arithmetic: Prompt arithmetic model to be used for fitting.
        :param input_sentences: Sentences to be used for fitting.
        :param max_tokens: Maximum number of tokens to be used for fitting.
        :param save_file: File to which to save the vector. The saved tensor is ``[norm, *vector]`` with
                          ``norm = prompt_arithmetic.formula.norm()``.
        :param batch_size: Batch size to be used for fitting.
        :raises ValueError: If the input sentences produce no tokens to average over.
        """
        tokenizer = prompt_arithmetic.tokenizer
        tokenized_text = tokenizer(input_sentences, padding=False, truncation=True, max_length=max_tokens).input_ids
        summed_vector = 0
        n_tokens = 0
        for i in range(0, len(tokenized_text), batch_size):
            output_logits = prompt_arithmetic.forward(tokenized_text[i:i+batch_size], normalize=False)[0]
            # sum over the first two (batch and token) dimensions
            summed_vector += output_logits.sum(dim=(0, 1))
            n_tokens += output_logits.shape[0] * output_logits.shape[1]

        if n_tokens == 0:
            # previously a ZeroDivisionError / AttributeError on the int accumulator
            raise ValueError("Cannot fit a Vector operator: input_sentences produced no tokens.")

        average_vector = summed_vector / n_tokens
        self.vector = torch.cat([torch.tensor([prompt_arithmetic.formula.norm()], device=average_vector.device), average_vector])
        if save_file is not None:
            torch.save(self.vector, save_file)
            self.from_file = save_file
        self.norm_ = self.vector[0]
        self.vector = self.vector[1:]
        
        
class SentenceOperator(RunnableOperator):
    def __init__(self, output_sentence, group=None, **kwargs):
        """
        Initializes a runnable operator that forces a fixed sentence, followed by the EOS token of the model's tokenizer.
        :param output_sentence: Sentence to be generated.
        :param group: Group to which the operator belongs.
        :param kwargs: Further keyword arguments, forwarded to ``RunnableOperator.__init__``.
        """
        super().__init__(group=group, output_sentence=output_sentence, **kwargs)
        # both are filled in by initialize_after_model_set, once the model is known
        self.tokenized_sentence = None
        self.tokenizer = None

    def initialize_after_model_set(self):
        """
        Loads the tokenizer of ``self.model`` and tokenizes the sentence, appending the EOS token id.
        """
        super().initialize_after_model_set()
        self.tokenizer = load_tokenizer(self.model)
        # NOTE(review): the tokenizer is called with its default special-token handling; if it prepends e.g. a BOS
        # token, that token becomes the first forced token. Confirm this is intended.
        self.tokenized_sentence = self.tokenizer(self.output_sentence).input_ids + [self.tokenizer.eos_token_id]

    def evaluate(self, runnable_operator_outputs : Dict, normalize : bool = True):
        """
        Evaluate the operation.

        Args:
            runnable_operator_outputs (Dict): Outputs of runnable operators.
            normalize (bool): Whether to normalize the evaluation (unused here).

        Returns:
            torch.Tensor | int: If this operator's output (a token id) is available, a float32 tensor over the
            vocabulary that is 0 at that token and -inf elsewhere, passed through ``set_to_minimum``. Otherwise 0.
        """
        for output in runnable_operator_outputs:
            if self.same_operator(output) and runnable_operator_outputs[output] is not None:
                out = torch.full((len(self.tokenizer),), -torch.inf, dtype=torch.float32)
                out[int(runnable_operator_outputs[output])] = 0.0
                return self.set_to_minimum(out)
        return 0

    def run_on_token_past(self, tokens):
        """
        Returns the id of the next sentence token, given the previous tokens.

        The longest prefix of the tokenized sentence that is also a suffix of ``tokens`` counts as already
        generated, and the token following that prefix is returned. Without any overlap (including empty
        ``tokens``) the first sentence token is returned. Once the whole sentence, EOS included, has been
        generated, the final (EOS) token keeps being returned.
        :param tokens: Sequence of previous token ids (ints or 0-d tensors).
        :return: The next token id (int).
        """
        past = [int(token) for token in tokens]
        sentence = self.tokenized_sentence
        overlap = 0  # initialized up front: empty ``tokens`` used to raise UnboundLocalError
        # Compare token lists directly; the former comma-joined string check also matched partial ids
        # (e.g. "12,3" ends with "2,3").
        for length in range(1, min(len(sentence), len(past)) + 1):
            if past[-length:] == sentence[:length]:
                overlap = length
        # A full match (EOS included) has no successor; keep emitting the last token instead of indexing past the end.
        overlap = min(overlap, len(sentence) - 1)
        return sentence[overlap]

    def run_on_sample(self, tokens, n_new_tokens):
        """
        Runs the sentence operator on a sample for all number of new tokens.
        :param tokens: Tokens to be used as input.
        :param n_new_tokens: Number of new tokens to be generated.
        :return: Float tensor of shape (n_new_tokens,). Entry k holds the token id predicted after dropping the last
                 ``n_new_tokens - 1 - k`` tokens, so the last entry is the prediction after all ``tokens``.
        """
        output = torch.zeros((n_new_tokens,))
        for k in range(n_new_tokens):
            n_dropped = n_new_tokens - 1 - k
            past = tokens if n_dropped == 0 else tokens[:-n_dropped]
            output[k] = self.run_on_token_past(past)
        return output

    def run(self, tokenized_inputs, model_new_tokens, **kwargs):
        """
        Runs the sentence operator on the tokenized inputs.
        :param tokenized_inputs: Inputs that have been tokenized.
        :param model_new_tokens: Number of new tokens per sample in the batch
        :return: One tensor per sample, as produced by ``run_on_sample``.
        """
        output = []
        for sample in range(tokenized_inputs.input_ids.shape[0]):
            output.append(self.run_on_sample(tokenized_inputs.input_ids[sample],
                                             model_new_tokens[sample]))
        return output

    def __str__(self):
        return f"ConstantSentence(sentence='{self.output_sentence}', speculative_factor={self.speculative_factor})"

    def id(self):
        return f"ConstantSentence(sentence='{self.output_sentence}')"
        

class ClassifierStrength(RunnableOperator):
    def __init__(self, model, batch_size=None, dtype=None, prompt_string="",
                 prompt_template=lambda prompt_string, input_string: prompt_string + input_string, group=None, only_input=False,
                 default_output=0.5, speculative_factor=1, tokenizer=None, **kwargs):
        """
        Initializes the classifier-strength operator. This runnable operator runs a classifier and outputs a scalar
        probability instead of a token distribution (``outputs_logprobs=False``), with run priority 1.
        :param model: Model to be used for the classifier (a name or an already loaded model).
        :param batch_size: Number of classifier inputs per forward pass in ``run``; None selects a default there.
        :param dtype: Data type for the model; when None, ``load_model`` uses the dtype passed to it.
        :param prompt_string: String to be used as a prompt.
        :param prompt_template: Function for generating prompt. Takes two arguments: prompt_string and input_string. The operator will be run on prompt_template(..., ...) + continuation_tokens
        :param group: Group to which the operator belongs. This ensures that speculative sampling will not be tried when not all operators of a group are finished.
        :param only_input: If true, ``run`` classifies ``tokenized_only_input`` instead of ``tokenized_inputs``.
        :param default_output: Value returned by ``evaluate``/``norm`` when no output of this operator is available.
        :param speculative_factor: Factor for speculative sampling.
        :param tokenizer: Tokenizer of the classifier; loaded via ``load_tokenizer(self.model)`` when None.
        """
        super().__init__(
            model=model,
            batch_size=batch_size,
            run_priority=1,
            prompt_string=prompt_string,
            prompt_template=prompt_template,
            group=group,
            outputs_logprobs=False,
            only_input=only_input,
            default_output=default_output,
            speculative_factor=speculative_factor,
            **kwargs,
        )
        self.tokenizer = load_tokenizer(self.model) if tokenizer is None else tokenizer
        self.max_length = None  # resolved lazily from the model config in ``run``
        self.dtype = dtype
        
    def load_model(self, dtype):
        """
        Loads the model for the operation as a classification model.
        If ``self.model`` is not a string it is assumed to be loaded already and is returned unchanged.
        :param dtype: Data type for the model; only used when ``self.dtype`` is None.
        """
        if not isinstance(self.model, str):
            return self.model
        chosen_dtype = dtype if self.dtype is None else self.dtype
        return load_model(self.model, dtype=chosen_dtype, classification=True)
    
    def evaluate(self, runnable_operator_outputs : Dict, normalize : bool = True):
        """
        Evaluate the operation.

        Args:
            runnable_operator_outputs (Dict): Outputs of runnable operators.
            normalize (bool): Whether to normalize the evaluation (unused here).

        Returns:
            float: This operator's output converted with ``float()``, or ``self.default_output`` if it is missing.
        """
        matching_outputs = (
            runnable_operator_outputs[key]
            for key in runnable_operator_outputs
            if self.same_operator(key) and runnable_operator_outputs[key] is not None
        )
        value = next(matching_outputs, None)
        if value is None:
            return self.default_output
        # NOTE: float() might become a problem when batch post processing is implemented
        return float(value)
    
    def run(self, tokenized_inputs, loaded_models, model_new_tokens, other_tokenizer, tokenized_only_input, **kwargs):
        """
        Runs the classifier on the tokenized inputs.
        :param tokenized_inputs: Inputs that have been tokenized.
        :param loaded_models: Models that have been loaded. The model for this operation is in loaded_models[self.model]
        :param model_new_tokens: Number of new tokens per sample in the batch
        :param other_tokenizer: Tokenizer used to decode the tokenized inputs back to text, which is then re-encoded
                                with the classifier's own tokenizer.
        :param tokenized_only_input: Alternative tokenized inputs, used instead of ``tokenized_inputs`` when
                                     ``self.only_input`` is set.
        :return: One tensor per sample with shape (model_new_tokens[i], 1), holding the softmax probability of class
                 index 1. Row 0 is computed on the full input, row j on the input with its last j tokens removed.
        """
        model = loaded_models[self.model] if isinstance(self.model, str) else self.model

        if self.only_input:
            tokenized_inputs = tokenized_only_input

        # One text per (sample, new token): the full input first, then the input with its last 1, 2, ... tokens removed.
        # NOTE(review): this is newest-first order; confirm it matches the order callers expect for multiple new tokens.
        input_samples = []
        for i in range(len(tokenized_inputs.input_ids)):
            for j in range(model_new_tokens[i]):
                input_samples.append(other_tokenizer.decode(tokenized_inputs.input_ids[i][:-j if j != 0 else None].tolist(), skip_special_tokens=True))

        if len(input_samples) == 0:
            # Nothing to classify; the tokenizer call and torch.cat below would fail on an empty list.
            return [torch.zeros((0, 1), device=model.device) for _ in model_new_tokens]

        if self.max_length is None:
            self.max_length = get_max_length(model.config)
        encoded_samples = self.tokenizer.batch_encode_plus(input_samples, return_tensors="pt", padding=True, truncation=True, max_length=self.max_length).to(model.device)

        # token_type_ids are not passed to the model below
        if "token_type_ids" in encoded_samples:
            del encoded_samples["token_type_ids"]

        if self.batch_size is None:
            # Default to the number of input rows. ``len(tokenized_inputs)`` on a dict-like BatchEncoding counts its
            # keys (input_ids, attention_mask, ...), not its samples.
            batch_size = max(1, len(tokenized_inputs.input_ids))
        else:
            batch_size = self.batch_size

        input_ids = encoded_samples['input_ids']
        attention_mask = encoded_samples['attention_mask']
        batch_logits = []
        for start in range(0, len(input_ids), batch_size):
            with torch.no_grad():
                # only the logits are used, so hidden states are not requested (saves memory)
                model_output = model(input_ids=input_ids[start:start + batch_size],
                                     attention_mask=attention_mask[start:start + batch_size],
                                     return_dict=True)
            batch_logits.append(model_output.logits)
        model_logits = torch.cat(batch_logits, dim=0)

        # Split the flat logits back into one chunk of model_new_tokens[i] rows per sample.
        probabilities = []
        start_index = 0
        for tokens in model_new_tokens:
            chunk = model_logits[start_index:start_index + tokens]
            probabilities.append(torch.softmax(chunk, dim=-1)[:, 1].unsqueeze(1))
            start_index += tokens
        return probabilities

    def __str__(self):
        # Identifies the operator by its prompt and classifier model.
        return "ClassifierStrength('{}', model='{}')".format(self.prompt_string, self.model)

    def norm(self, runnable_operator_outputs=None):
        """
        Returns the norm of the operator, which equals its classifier output.

        :param runnable_operator_outputs: Outputs of runnable operators. May be None (the default), e.g. when a
                                          caller asks a formula for its norm without outputs; iterating None used
                                          to raise TypeError.
        :return: This operator's output converted with ``float()``, or ``self.default_output`` when no outputs
                 are given or this operator's output is missing.
        """
        if runnable_operator_outputs is None:
            return self.default_output
        for output in runnable_operator_outputs:
            if self.same_operator(output) and runnable_operator_outputs[output] is not None:
                return float(runnable_operator_outputs[output])
        return self.default_output