from .base_computer import BaseComputer
from .quality_computer import QualityComputer
import numpy as np
from sklearn.linear_model import LogisticRegression
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import itertools
import json
from tqdm import tqdm
import os

class ClassificationCostComputer(BaseComputer):
    """Cost computer that prices a question by its input length plus a fixed number of output tokens."""

    def __init__(self, input_costs, output_costs, tokenizers=None, 
                 tokenize=True, n_output_tokens=1, 
                 constant_cost=False, store_all=False):
        """
        Initialize the Classification Computer object.

        Args:
            input_costs (list): The input costs per token for each model.
            output_costs (list): The output costs per token for each model.
            tokenizers (list, optional): The tokenizers for each model. Defaults to None.
            tokenize (bool, optional): Whether to tokenize. When False, the character count of the
                                       question is used as its token count. Defaults to True.
            n_output_tokens (int, optional): The number of output tokens. Defaults to 1.
            constant_cost (bool, optional): Whether to always output constant costs for each model
                                            (the per-model mean of ``measure`` seen in ``fit``).
                                            Defaults to False.
            store_all (bool, optional): Whether to store all predictions. Speeds up prediction at the cost of memory. 
                                        Defaults to False.

        Raises:
            AssertionError: If ``tokenize`` is True but no ``tokenizers`` are given.
        """
        super().__init__()
        self.input_costs = input_costs
        self.output_costs = output_costs
        self.tokenizers = tokenizers
        self.tokenize = tokenize
        self.n_output_tokens = n_output_tokens
        self.constant_cost = constant_cost
        self.store_all = store_all
        self.computed_costs = []
        assert tokenizers is not None or not tokenize, \
            'tokenizers must be provided when tokenize=True'

    def fit(self, questions, model_answers, measure=None):
        """
        Reset the per-model cost caches and compute the constant cost of each model.

        Args:
            questions (list): Training questions (not used).
            model_answers (list): Per-question model answers; only ``len(model_answers[0])`` is
                used, to determine the number of models.
            measure (array-like, optional): Observed costs, one row per question and one column per
                model. The column means become the constant costs. Required when
                ``constant_cost`` is True.

        Raises:
            ValueError: If ``constant_cost`` is True and ``measure`` is None.
        """
        n_models = len(model_answers[0])
        # Start from fresh caches: appending would leave the previous fit's dicts at indices
        # 0..n-1, and costs cached there may be stale (constant costs depend on `measure`).
        self.computed_costs = [dict() for _ in range(n_models)]
        if measure is None:
            if self.constant_cost:
                raise ValueError('measure is required to fit constant costs (constant_cost=True)')
            self.constant_costs = []
            return
        # Accept plain nested lists as well as arrays.
        measure = np.asarray(measure)
        self.constant_costs = [np.mean(measure[:, model]) for model in range(n_models)]

    def predict(self, questions, model_answers):
        """
        Predict the cost of running every model on every question.

        Args:
            questions (list): Questions as strings, or sequences whose first element is the string.
            model_answers (list): Per-question model answers; only ``len(model_answers[0])`` is used.

        Returns:
            tuple: ``(costs, None)`` where ``costs`` is an array of shape ``(len(questions), n_models)``.
        """
        n_models = len(model_answers[0])
        use_cache = self.training or self.store_all

        all_costs = []
        for model in range(n_models):
            # Only touch the cache when caching is enabled, so predict works without fit otherwise.
            cache = self.computed_costs[model] if use_cache else None
            costs = []
            for question in questions:
                if not isinstance(question, str):
                    question = question[0]
                if cache is not None and question in cache:
                    costs.append(cache[question])
                    continue
                cost = self._question_cost(model, question)
                costs.append(cost)
                if cache is not None:
                    cache[question] = cost
            all_costs.append(costs)
        return np.array(all_costs).T, None

    def _question_cost(self, model, question):
        """Return the cost of running model index ``model`` on the question string ``question``."""
        if self.constant_cost:
            # The cost does not depend on the question, so skip tokenization entirely.
            return self.constant_costs[model]
        if self.tokenize:
            n_input_tokens = len(self.tokenizers[model]([question], padding=False)['input_ids'][0])
        else:
            # Without tokenization the character count stands in for the token count.
            n_input_tokens = len(question)
        return (self.input_costs[model] * n_input_tokens
                + self.output_costs[model] * self.n_output_tokens)
    

class ClassificationQualityComputer(QualityComputer):
    def __init__(self, model_class=LogisticRegression, 
                 n_highest_include=1, require_constant_not_run=False, baseline=False, 
                 include_question_embedding=False, do_cosine_similarity=False, 
                 question_indicator=r'Question:', answer_indicator=r'Answer:', 
                 remove_options=('\nA:', '\nA.'), 
                 sentence_embedder="sentence-transformers/all-MiniLM-L6-v2", 
                 batch_size=32, include_question_length=True, 
                 is_regression=False, all_model_combinations=True, 
                 add_entropy=True, add_js_divergence=True, 
                 add_equal_argmax=True, max_depth=None, 
                 lookup_file_name=None, n_samples=100, 
                 store_all=False, include_all_models=False):
        """
        Initializes the ClassificationQualityComputer object.

        Args:
            model_class (class): The class of the prediction model to be used. Default is LogisticRegression.
            n_highest_include (int): The number of highest class probabilities to include in the features. 
                                    Default is 1.
            require_constant_not_run (bool): Whether to predict the constant (mean training) quality for
                                            models that have not been run yet. Default is False.
            baseline (bool): Whether to use baseline settings. When True, this overrides
                            n_highest_include=1, require_constant_not_run=True and
                            model_class=LogisticRegression. Default is False.
            include_question_embedding (bool): Whether to include question embeddings as features. 
                                            Default is False.
            do_cosine_similarity (bool): Whether to calculate the cosine similarity feature. 
                                        The cosine similarity feature computes the weighted average 
                                        of the quality across the training dataset based on cosine 
                                        similarity between current question and questions in 
                                        the training data. Default is False.
            question_indicator (str): The indicator for the "Question" part of a classification question. 
                                        Only the text after its last occurrence is kept.
                                        Default is 'Question:'.
            answer_indicator (str): The indicator for the "Answer" part of the classification question. 
                                    Text from its first occurrence on is dropped.
                                    Default is 'Answer:'.
            remove_options (iterable of str or None): Markers of the answer options to cut from the text;
                                    everything from a marker on is removed. A copy is stored as a list.
                                    None disables option removal. Default is ('\\nA:', '\\nA.').
            sentence_embedder (str): The name of the sentence embedder (loaded lazily on first use). 
                                    Default is 'sentence-transformers/all-MiniLM-L6-v2'.
            batch_size (int): The batch size for sentence embedding. Default is 32.
            include_question_length (bool): Whether to include question length as feature. Default is True.
            is_regression (bool): Whether the model class is a regression model (uses ``predict``
                                  instead of ``predict_proba``). Default is False.
            all_model_combinations (bool): Whether to train separate models for each different history of
                                            computed models (instead of one per number of computed models).
                                            Default is True.
            add_entropy (bool): Whether to add entropy as feature. Default is True.
            add_js_divergence (bool): Whether to add Jensen-Shannon divergence between model answers as feature. Default is True.
            add_equal_argmax (bool): Whether to add equal prediction between model answers as feature. Default is True.
            max_depth (int): When set, histories with more than max_depth models run (except the complete
                             history) are skipped when building training data. Default is None.
            lookup_file_name (str): Path of a JSON file mapping parsed question text to a stored sentence
                                    embedding, used to speed up computation. Ignored if the file does not
                                    exist. Default is None.
            n_samples (int): The number of samples to compute max(q_1, ..., q_n); passed to the base class.
                             Default is 100.
            store_all (bool): Whether to store all results. 
                            Speeds up prediction at the cost of memory.
                            Default is False.
            include_all_models (bool): Whether to include all the uncertainty features of all models for 
                                        each model as features (and not just only of the model itself). 
                                        Default is False.
        """
        super().__init__(n_samples=n_samples)
        if baseline:
            n_highest_include = 1
            require_constant_not_run = True
            model_class = LogisticRegression
        self.model_class = model_class
        self.models = None
        self.n_highest_include = n_highest_include
        self.sigma_per_n_models_run = None
        self.require_constant_not_run = require_constant_not_run
        self.baseline = baseline
        self.constant_qualities = []
        self.include_question_embedding = include_question_embedding
        self.do_cosine_similarity = do_cosine_similarity
        self.question_indicator = question_indicator
        self.answer_indicator = answer_indicator
        self.batch_size = batch_size
        # Store a private list so instances never share (or mutate) the caller's / default object.
        self.remove_options = list(remove_options) if remove_options is not None else None
        self.include_question_length = include_question_length
        self.min_length = None
        self.max_length = None
        self.is_regression = is_regression
        self.all_model_combinations = all_model_combinations
        self.add_entropy = add_entropy
        self.add_js_divergence = add_js_divergence
        self.add_equal_argmax = add_equal_argmax
        self.max_depth = max_depth
        self.lookup_file_name = lookup_file_name
        self.store_all = store_all
        self.include_all_models = include_all_models
        self.lookup_embeddings = None
        if self.lookup_file_name is not None:
            if os.path.exists(self.lookup_file_name):
                # Context manager so the file handle is closed even if parsing fails.
                with open(self.lookup_file_name, 'r') as lookup_file:
                    raw_lookup = json.load(lookup_file)
                self.lookup_embeddings = {key: np.array(value) for key, value in raw_lookup.items()}
            else:
                # A missing lookup file is tolerated: behave as if none had been given.
                self.lookup_file_name = None

        self.sentence_embedder = None
        self.sentence_embedder_name = sentence_embedder

        self.question_embeddings = None
        self.qualities_embeddings = None
        # Caches keyed by question text, used when self.training or self.store_all is set.
        self.training_sentence_embeddings = dict()
        self.training_cosine_similarities = dict()
        self.question_predictions = dict()

    @property
    def is_independent(self):
        """
        Whether each model's quality can be predicted in isolation.

        Always ``False`` here: ``predict`` builds each model's features from the
        whole row of ``model_answers`` for a question.
        """
        return False

    def entropy(self, p):
        """
        Compute the Shannon entropy, in bits, of a probability distribution.

        The logarithm is taken of the probabilities floored at 1e-16 so it stays
        finite. Zero entries then contribute ``0 * log2(1e-16) == 0``.

        Args:
            p (numpy.ndarray): The probability distribution.

        Returns:
            float: ``-sum(p * log2(p))``.
        """
        floored = np.maximum(p, 1e-16)
        log_terms = np.log2(floored)
        return -np.sum(p * log_terms)
    
    def kl_divergence(self, p, q):
        """
        Compute the Kullback-Leibler divergence KL(p || q), in bits.

        Both distributions are floored at 1e-16 inside the logarithm to avoid
        division by zero and ``log2(0)``.

        Args:
            p (numpy.ndarray): The first probability distribution.
            q (numpy.ndarray): The second probability distribution.

        Returns:
            float: ``sum(p * log2(p / q))``.
        """
        p_floor = np.maximum(p, 1e-16)
        q_floor = np.maximum(q, 1e-16)
        log_ratio = np.log2(p_floor / q_floor)
        return np.sum(p * log_ratio)
    
    def js_divergence(self, p, q):
        """
        Calculates the Jensen-Shannon divergence between two probability distributions.

        Args:
            p: numpy array or list, representing the first probability distribution.
            q: numpy array or list, representing the second probability distribution.

        Returns:
            js_div: float, the Jensen-Shannon divergence between p and q (in bits).
        """
        # Convert first: for plain lists `p + q` would concatenate instead of adding elementwise.
        p = np.asarray(p)
        q = np.asarray(q)
        midpoint = (p + q) / 2
        return (self.kl_divergence(p, midpoint) + self.kl_divergence(q, midpoint)) / 2
    
    def parse_question(self, question, remove_options=True):
        """
        Extract the bare question text from a formatted classification prompt.

        The text after the last ``self.question_indicator`` is kept. If requested
        (and ``self.remove_options`` is not None), everything from each option
        marker on is dropped. Finally, everything from ``self.answer_indicator``
        on is dropped, and the result is stripped.

        Args:
            question (str or list): The question to be parsed. 
                                    If a list is provided, the first element will be used.
            remove_options (bool): Whether to cut off the answer options. Default is True.

        Returns:
            str: The extracted question text.
        """
        text = question if isinstance(question, str) else question[0]
        text = text.rpartition(self.question_indicator)[2]
        if remove_options and self.remove_options is not None:
            for marker in self.remove_options:
                text = text.partition(marker)[0].strip()
        return text.partition(self.answer_indicator)[0].strip()
    
    def compute_sentence_embeddings(self, questions):
        """
        Compute sentence embeddings for the parsed text of ``questions``.

        Each question is first reduced with ``parse_question``. The code takes one of
        two paths:

        - Caching is active (``self.training`` or ``self.store_all``) or a lookup
          table is loaded: questions are handled one by one. A cached embedding is
          reused first, then a lookup-table embedding. Anything else is encoded on
          its own, and cached when caching is active.
        - Otherwise: all parsed questions are encoded in a single batched call.

        The sentence embedder is loaded lazily on first use.

        Args:
            questions (list): Questions as strings, or sequences whose first element is the string.

        Returns:
            numpy.ndarray: One embedding per question.
        """
        caching = self.training or self.store_all
        per_question = caching or self.lookup_embeddings is not None
        parsed_questions = []
        collected = []
        for raw_question in questions:
            text = raw_question if isinstance(raw_question, str) else raw_question[0]
            text = self.parse_question(text)
            parsed_questions.append(text)

            if caching and text in self.training_sentence_embeddings:
                collected.append(self.training_sentence_embeddings[text])
                continue
            if self.lookup_embeddings is not None and text in self.lookup_embeddings:
                collected.append(self.lookup_embeddings[text])
                continue
            if not per_question:
                # Encoded all at once after the loop.
                continue
            if self.sentence_embedder is None:
                self.sentence_embedder = SentenceTransformer(self.sentence_embedder_name)
            vector = self.sentence_embedder.encode([text], batch_size=self.batch_size)[0]
            collected.append(vector)
            if caching:
                self.training_sentence_embeddings[text] = vector

        if per_question:
            return np.array(collected)
        if self.sentence_embedder is None:
            self.sentence_embedder = SentenceTransformer(self.sentence_embedder_name)
        return self.sentence_embedder.encode(parsed_questions, batch_size=self.batch_size)
    
    def compute_cosine_similarity(self, new_questions, new_question_embeddings):
        """
        Estimate qualities as a similarity-weighted average over the training data.

        For each new question, the cosine similarity between its embedding and every
        training embedding (``self.question_embeddings``) weights the training
        qualities (``self.qualities_embeddings``). The weighted sum is divided by the
        sum of the similarities. Results are cached per question when
        ``self.training`` or ``self.store_all`` is set.

        Args:
            new_questions (list): Questions as strings, or sequences whose first element is the string.
            new_question_embeddings (ndarray): One embedding per new question.

        Returns:
            ndarray: One estimated quality vector per new question.
        """
        caching = self.training or self.store_all
        estimates = []
        for raw_question, vector in zip(new_questions, new_question_embeddings):
            key = raw_question if isinstance(raw_question, str) else raw_question[0]
            if caching and key in self.training_cosine_similarities:
                estimates.append(self.training_cosine_similarities[key])
                continue
            similarities = cosine_similarity(vector.reshape(1, -1), self.question_embeddings)
            weights = similarities.reshape(-1, 1)
            estimate = np.sum(weights * self.qualities_embeddings, axis=0) / np.sum(similarities)
            estimates.append(estimate)
            if caching:
                self.training_cosine_similarities[key] = estimate
        return np.array(estimates)

    def fit(self, questions, model_answers, measure):
        """
        Fit the quality predictors and the covariance of their prediction changes.

        Predictors are trained per model, in one of two layouts:

        - ``all_model_combinations=True``: one predictor per history of already-run
          models.
        - ``all_model_combinations=False``: one predictor per number of already-run
          models, split by whether the model itself has run.

        ``compute_sigma`` then stores the covariance of the difference between the
        predictions made with a partial history and those made with all model
        answers available.

        Args:
            questions (list): Training questions (strings, or sequences whose first element is the string).
            model_answers (list): For each question, the answers of every model.
            measure (array-like): For each question, the observed quality of every model.

        Raises:
            ValueError: If ``questions`` is empty.
        """
        if len(questions) == 0:
            raise ValueError('fit requires at least one training question')

        # These caches are keyed only by question, but their values depend on the fitted models /
        # training data; clear them so a refit cannot serve results from the previous fit.
        self.question_predictions = dict()
        self.training_cosine_similarities = dict()

        question_lengths = [len(self.parse_question(question)) for question in questions]
        self.min_length = min(question_lengths)
        self.max_length = max(question_lengths)
        n_models = len(model_answers[0])
        if not self.all_model_combinations:
            self.models = [[[self.model_class(), self.model_class()] 
                            for _ in range(n_models + 1)] for _ in range(n_models)]
        else:
            self.models = [dict() for _ in range(n_models)]

        X, X_all_models, y, for_model, all_n_models_run, all_models_run = self.prepare_data(questions, 
                                                                                            model_answers, 
                                                                                            measure, 
                                                                                            n_models)
        # prepare_data emits n_models consecutive rows per (question, history) pair, so row
        # i belongs to pair i // n_models; columns are the model whose quality is predicted.
        y_pred_all = np.zeros((len(X) // n_models, n_models))
        y_pred_all_models = np.zeros((len(X) // n_models, n_models))
        # The set of histories is the same for every model: compute it once.
        models_to_fit = np.unique(all_models_run)
        all_models_string = ','.join(str(i) for i in range(n_models))

        for model in range(n_models):
            indices_all = np.flatnonzero(for_model == model)
            if not self.all_model_combinations:
                self.train_n_answers_model(n_models, X, y, for_model, 
                                           all_n_models_run, all_models_run, y_pred_all, model, 1)
                self.train_n_answers_model(n_models, X, y, for_model, 
                                           all_n_models_run, all_models_run, y_pred_all, model, 0)
                # Predictor for the "all models answered, including this one" case.
                full_history_model = self.models[model][-1][1]
            else:
                for models_run_string in tqdm(models_to_fit, desc=f'Model {model}'):
                    estimator = self.model_class()
                    self.models[model][models_run_string] = estimator
                    indices_run = np.flatnonzero((all_models_run == models_run_string) & 
                                                 (for_model == model))
                    X_here = np.array([X[i] for i in indices_run])
                    y_here = y[indices_run]
                    estimator.fit(X=X_here, y=y_here)
                    y_pred_all[indices_run // n_models, model] = self.predict_regression_or_not(estimator, 
                                                                                                X_here)
                full_history_model = self.models[model][all_models_string]

            y_pred_all_models[:, model] = self.predict_regression_or_not(full_history_model, 
                                                                         X_all_models[indices_all])

        self.compute_sigma(n_models, all_n_models_run, 
                           all_models_run, y_pred_all, 
                           y_pred_all_models, models_to_fit)

    def train_n_answers_model(self, n_models, X, y, for_model, all_n_models_run, 
                              all_models_run, y_pred_all, model, model_answered):
        """
        Fit one model's predictors when ``self.all_model_combinations`` is False.

        There is one predictor per number of already-run models (0..n_models). Each
        is trained on the rows that satisfy all of:

        - they target ``model``;
        - their history has exactly that many models;
        - ``model`` itself is in the history (``model_answered`` truthy) or is not
          (falsy).

        The in-sample predictions are written into ``y_pred_all``.

        Args:
            n_models (int): The total number of models.
            X (list): Feature rows; every ``n_models`` consecutive rows share one (question, history) pair.
            y (array-like): The target values, aligned with ``X``.
            for_model (array-like): For each row, the index of the model whose measure is predicted.
            all_n_models_run (array-like): For each row, the number of models in its history.
            all_models_run (array-like): For each row, the comma-separated indices of the models run.
            y_pred_all (numpy.ndarray): Output array, filled in place at ``[row // n_models, model]``.
            model (int): Index of the model whose predictors are trained.
            model_answered (int or bool): Whether to train the predictors for histories that
                include (truthy) or exclude (falsy) ``model``.

        Returns:
            None
        """
        model_tag = str(model)
        want_answered = bool(model_answered)
        for n_models_run in range(n_models + 1):
            selected = []
            for row in range(len(X)):
                if all_n_models_run[row] != n_models_run or for_model[row] != model:
                    continue
                if (model_tag in all_models_run[row].split(',')) == want_answered:
                    selected.append(row)
            if not selected:
                # No training rows for this history size: leave the predictor unfitted.
                continue
            X_here = np.array([X[row] for row in selected])
            y_here = np.array([y[row] for row in selected])
            estimator = self.models[model][n_models_run][model_answered]
            estimator.fit(X=X_here, y=y_here)
            predictions = self.predict_regression_or_not(estimator, X_here)
            y_pred_all[[row // n_models for row in selected], model] = predictions

    def compute_sigma(self, n_models, all_n_models_run, all_models_run, 
                      y_pred_all, y_pred_all_models, models_to_fit):
        """
        Store the covariance of (partial-history prediction - full-history prediction).

        Training rows come in groups of ``n_models`` per (question, history) pair, so
        every ``n_models``-th entry of the history arrays describes one row of
        ``y_pred_all``. The covariance is grouped by number of models run (a list
        indexed 0..n_models) when ``all_model_combinations`` is False. Otherwise it is
        grouped by history string (a dict keyed by each entry of ``models_to_fit``).
        The result goes into ``self.sigma_per_n_models_run``.

        Parameters:
            n_models (int): The number of models.
            all_n_models_run (numpy.ndarray): Number of models run, per training row.
            all_models_run (numpy.ndarray): Comma-separated history string, per training row.
            y_pred_all (numpy.ndarray): Predictions made with each row's partial history.
            y_pred_all_models (numpy.ndarray): Predictions made with all model answers available.
            models_to_fit (list): The distinct history strings.
        """
        n_models_run_per_pair = np.asarray(all_n_models_run)[::n_models]
        history_per_pair = np.asarray(all_models_run)[::n_models]

        if self.all_model_combinations:
            self.sigma_per_n_models_run = sigma = dict()
            for models_run_string in models_to_fit:
                rows = history_per_pair == models_run_string
                sigma[models_run_string] = np.cov((y_pred_all[rows] - y_pred_all_models[rows]).T)
        else:
            self.sigma_per_n_models_run = sigma = []
            for depth in range(n_models + 1):
                rows = n_models_run_per_pair == depth
                sigma.append(np.cov((y_pred_all[rows] - y_pred_all_models[rows]).T))

    def prepare_data(self, questions, model_answers, measure, n_models):
        """
        Prepare the data for fitting.

        For every question and every history of already-run models (all combinations, limited by
        ``max_depth``; only the first combination of each size when ``baseline`` is set), one row
        per target model is generated. Answers of models outside the history are hidden (None).
        Also (re)computes ``self.constant_qualities`` as the per-model mean of ``measure`` and,
        if enabled, the cosine-similarity training state.

        Args:
            questions (list): List of questions.
            model_answers (list): List of model answers.
            measure (list): List of measures (one value per model for each question).
            n_models (int): Number of models.
        Returns:
            tuple: A tuple containing the following arrays:
                - X (list): Input data for each row, using only the history's answers.
                - X_all_models (ndarray): Input data for each row, using all model answers.
                - y (ndarray): Output data.
                - for_model (ndarray): Model index for each data point.
                - all_n_models_run (ndarray): Number of models used for each data point.
                - all_models_run (ndarray): String representation of models used for each data point.
        """
        if self.do_cosine_similarity:
            self.question_embeddings = self.compute_sentence_embeddings(questions)
            self.qualities_embeddings = np.array(measure)
            self.cosine_similarities_train = self.compute_cosine_similarity(questions, 
                                                                            self.question_embeddings)
        X = []
        X_all_models = []
        y = []
        for_model = []
        all_n_models_run = []
        all_models_run = []

        # Rebuild (not append to) the constants: appending on a refit would keep returning the
        # first fit's values through self.constant_qualities[model].
        self.constant_qualities = [
            np.mean([measure[i][model] for i in range(len(questions))]) for model in range(n_models)
        ]

        for i, question in enumerate(questions):
            measure_sample = measure[i]
            for n_models_run in range(n_models + 1):
                # Skip histories deeper than max_depth, but always keep the complete history,
                # which fit uses as the full-information reference.
                if self.max_depth is not None and n_models > n_models_run > self.max_depth:
                    continue
                for models_run in itertools.combinations(range(n_models), n_models_run):
                    models_run_string = ','.join(str(run_model) for run_model in sorted(models_run))
                    # Hide the answers of models that are not part of this history.
                    models_answers_sample = [answer if answer_model in models_run else None 
                                             for answer_model, answer in enumerate(model_answers[i])]

                    for model in range(n_models):
                        X_sample, y_sample = self.generate_sample_input_output(question, model, 
                                                                               n_models, 
                                                                               models_answers_sample, 
                                                                               measure_sample, i)
                        X_sample_all_models, _ = self.generate_sample_input_output(question, model, 
                                                                                   n_models, 
                                                                                   model_answers[i], 
                                                                                   measure_sample, i)
                        y.append(y_sample)
                        X.append(X_sample)
                        X_all_models.append(X_sample_all_models)
                        all_n_models_run.append(len(models_run))
                        all_models_run.append(models_run_string)
                        for_model.append(model)

                    if self.baseline:
                        # Baseline: only the first combination of each history size is used.
                        break

        X_all_models = np.array(X_all_models)
        y = np.array(y)
        all_n_models_run = np.array(all_n_models_run)
        all_models_run = np.array(all_models_run)
        for_model = np.array(for_model)
        return X, X_all_models, y, for_model, all_n_models_run, all_models_run
    
    def predict_regression_or_not(self, model, X):
        """
        Score ``X`` with ``model``, treating it as a regressor or a binary classifier.

        Parameters:
            model (object): A fitted estimator exposing ``predict`` (regression) or
                            ``predict_proba`` (classification).
            X (array-like): The input features for prediction.

        Returns:
            array-like: ``model.predict(X)`` when ``self.is_regression`` is set. Otherwise
            column 1 of ``model.predict_proba(X)``, i.e. the second class's probability.
        """
        if not self.is_regression:
            probabilities = model.predict_proba(X)
            return probabilities[:, 1]
        return model.predict(X)

    def predict(self, questions, model_answers):
        """
        Predict the quality of every model for every question, given the answers seen so far.

        Args:
            questions (list): Questions as strings, or sequences whose first element is the string.
            model_answers (list): For each question, one entry per model; ``None`` marks a model
                that has not been run.

        Returns:
            tuple: ``(y, sigma)``.
                - ``y``: array of shape ``(len(questions), n_models)``.
                - ``sigma``: per question, the covariance stored by ``compute_sigma``. It is looked
                  up by number of answered models, or by history string when
                  ``all_model_combinations`` is set. In the baseline case with
                  ``all_model_combinations`` it is a list of ``None`` instead.
        """
        n_models = len(model_answers[0])
        # Number of models that have already answered each question.
        n_models_answered = np.array([
            len([model_answer for model_answer in model_answers[i] if model_answer is not None]) 
            for i in range(len(questions))
        ])
        # History key per question, e.g. "0,2" when models 0 and 2 have answered.
        all_models_run_strings = np.array([','.join([str(i) for i in range(n_models) 
                                                     if model_answers[j][i] is not None]) 
                                            for j in range(len(questions))]) 
        y = np.zeros((len(questions), n_models))

        for model in range(n_models):
            # Marks questions whose prediction for `model` was served from the cache.
            y_model_done = np.zeros(len(questions)).astype(bool)
            if self.training or self.store_all:
                for i in range(len(questions)):
                    models_run = all_models_run_strings[i]
                    question = questions[i]
                    if not isinstance(question, str):
                        question = question[0]
                    question_prediction = self.question_predictions.get(model, 
                                                                        dict()).get(models_run, dict()).get(question, None)
                    if question_prediction is not None:
                        y[i, model] = question_prediction
                        y_model_done[i] = True
            
            # Predictions for the remaining (non-cached) questions only, in their original order.
            y_model = np.zeros(np.count_nonzero(np.logical_not(y_model_done)))
            X_model = [self.generate_sample_input_output(questions[i], model, 
                                                         n_models, model_answers[i])[0] 
                        for i in range(len(questions)) if not y_model_done[i]]
            model_answers_here = [model_answers[i] for i in range(len(questions)) if not y_model_done[i]]
            n_models_answered_here = n_models_answered[np.logical_not(y_model_done)]
            if not self.all_model_combinations:
                # predict_n_answers fills y_model in place, grouped by number of answered models
                # and by whether `model` itself has answered.
                for n_answers in range(n_models + 1):
                    self.predict_n_answers(model_answers_here, n_models_answered_here, 
                                           model, y_model, X_model, n_answers, 1)
                    self.predict_n_answers(model_answers_here, n_models_answered_here, 
                                           model, y_model, X_model, n_answers, 0)
            else:
                models_run_strings = all_models_run_strings[np.logical_not(y_model_done)]
                # NOTE(review): only histories with a trained predictor are visited; questions whose
                # history was never fitted keep a prediction of 0 — confirm this is intended.
                for models_run_string in self.models[model].keys():
                    indices = np.where(models_run_string == models_run_strings)[0]
                    X = [X_model[i] for i in indices]
                    if len(indices) == 0:
                        continue
                    y_model[indices] = self.predict_regression_or_not(self.models[model][models_run_string], X)

            if self.require_constant_not_run:
                # Models that have not answered get their mean training quality instead.
                for i in range(len(y_model)):
                    if model_answers_here[i][model] is None:
                        y_model[i] = self.constant_qualities[model]
            
            y[np.logical_not(y_model_done), model] = y_model

            if self.training or self.store_all:
                # Cache every prediction (including re-storing cached ones) by model, history, question.
                for i in range(len(questions)):
                    models_run = all_models_run_strings[i]
                    if model not in self.question_predictions:
                        self.question_predictions[model] = dict()
                    if models_run not in self.question_predictions.get(model, dict()):
                        self.question_predictions[model][models_run] = dict()
                    
                    question = questions[i]
                    if not isinstance(question, str):
                        question = question[0]
                    self.question_predictions[model][models_run][question] = y[i, model]

        if not self.all_model_combinations:
            return y, np.array([self.sigma_per_n_models_run[n_models_answered[i]] 
                                for i in range(len(questions))])
        elif not self.baseline:
            # NOTE(review): raises KeyError for a history string that was never fitted.
            return y, np.array([self.sigma_per_n_models_run[all_models_run_strings[i]] 
                                for i in range(len(questions))])
        else:
            return y, [None] * len(questions)

    def predict_n_answers(self, model_answers, n_models_answered, model, y_model, X_model, n_answers, model_answered):
        """
        Fill ``y_model`` in place for one group of questions (``all_model_combinations`` False).

        The group is the questions that have exactly ``n_answers`` answered models
        and in which ``model`` has (``model_answered`` truthy) or has not (falsy)
        answered. They are scored with ``self.models[model][n_answers][model_answered]``.

        Args:
            model_answers (list): Per-question answers, ``None`` for models not run.
            n_models_answered (numpy.ndarray): Number of answered models per question.
            model (int): Index of the model whose quality is predicted.
            y_model (numpy.ndarray): Output array of predicted qualities, updated in place.
            X_model (list): Feature rows, aligned with ``model_answers``.
            n_answers (int): The number of answered models that selects the group.
            model_answered (int or bool): Whether ``model`` must have answered.
        """
        has_answered = np.array([answer[model] is not None for answer in model_answers], dtype=bool)
        wanted = has_answered if model_answered else np.logical_not(has_answered)
        rows = np.flatnonzero(np.logical_and(n_models_answered == n_answers, wanted))
        if rows.size == 0:
            return
        features = [X_model[row] for row in rows]
        estimator = self.models[model][n_answers][model_answered]
        y_model[rows] = self.predict_regression_or_not(estimator, features)

    def base_features(self, question, index, model):
        """
        Generate the question-level features for a given question, index, and model.

        Parameters:
            question (str or tuple): The question to generate features for.
                                     If a tuple is provided, the first element is the question string
                                     and the remaining elements are additional features, prepended as-is.
            index (int or None): The index of the question in the training data. When given, the
                                 precomputed embedding / cosine similarity at that index are reused;
                                 when None, they are computed on the fly.
            model (int): Index of the model; selects that model's cosine-similarity value.

        Returns:
            features (list): The additional features (if any), then 1 / number of answer options,
                             then (optionally) the normalized question length, the cosine-similarity
                             value and the question embedding.
        """
        features = []
        if not isinstance(question, str):
            question, additional_features = question[0], question[1:]
            features.extend(additional_features)

        # Count answer options ("\nA:" or "\nA.") on the parsed question for BOTH patterns;
        # previously the ":" pattern was checked against the unparsed text instead.
        question_here = self.parse_question(question, remove_options=False)
        n_options = sum(
            f'\n{letter}:' in question_here or f'\n{letter}.' in question_here
            for letter in 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
        )
        # max(..., 1) avoids division by zero when no options are detected.
        features.append(1 / max(n_options, 1))

        if self.include_question_length:
            stripped_question = self.parse_question(question, remove_options=True)
            # +1 keeps the denominator non-zero when min_length == max_length.
            normalized_length = ((len(stripped_question) - self.min_length)
                                 / (self.max_length - self.min_length + 1))
            features.append(normalized_length)

        if self.do_cosine_similarity:
            if index is not None:
                # Index given: reuse the precomputed embedding and similarities.
                question_embedding = self.question_embeddings[index].reshape(1, -1)
                features.append(self.cosine_similarities_train[index][model])
            else:
                question_embedding = self.compute_sentence_embeddings([question])
                features.append(self.compute_cosine_similarity([question], question_embedding)[0][model])
            # NOTE(review): the embedding is only appended when do_cosine_similarity is also
            # enabled; confirm whether include_question_embedding is meant to work on its own.
            if self.include_question_embedding:
                features.extend(list(question_embedding[0]))
        return features

    def agreement_features(self, n_models, models_answers_sample):
        """
        Build pairwise agreement features over every pair (i, j), i < j, of the first n_models models.

        For each pair where both models answered, append the JS divergence (if add_js_divergence)
        and a 1.0/0.0 flag for equal argmax (if add_equal_argmax). For a pair with a missing
        answer, append zeros in the same slots unless all_model_combinations is set, in which
        case the pair contributes nothing.

        Args:
            n_models (int): The number of models.
            models_answers_sample (list): Per-model answers; None for a model that did not answer.

        Returns:
            list: The agreement features, in pair order (0, 1), (0, 2), ..., (n-2, n-1).
        """
        features = []
        for first, second in itertools.combinations(range(n_models), 2):
            answer_a = models_answers_sample[first]
            answer_b = models_answers_sample[second]
            both_answered = answer_a is not None and answer_b is not None
            if not both_answered and self.all_model_combinations:
                continue
            if self.add_js_divergence:
                features.append(self.js_divergence(answer_a, answer_b) if both_answered else 0)
            if self.add_equal_argmax:
                features.append(
                    float(np.argmax(answer_a) == np.argmax(answer_b)) if both_answered else 0
                )
        return features

    def certainty_features(self, model, models_answers_sample):
        """
        Calculate the certainty features for a given model and models_answers_sample.

        A model's certainty block is its n_highest_include largest answer values in descending
        order (zero-padded if the answer is shorter), followed by its entropy if add_entropy.

        Parameters:
        - model: The index of the model for which to calculate the certainty features.
        - models_answers_sample: Per-model answers; None for a model that did not answer.

        Returns:
        - If include_all_models: the concatenated certainty blocks of every model. A model
          without an answer contributes a zero block of the same width, unless
          all_model_combinations is set, in which case it contributes nothing.
        - Otherwise: the certainty block of `model`, or [] if it did not answer.
        """

        def certainty_block(answer):
            # Highest values first, truncated / zero-padded to exactly n_highest_include entries.
            block = sorted(answer, key=lambda x: -x)[:self.n_highest_include]
            block.extend([0] * (self.n_highest_include - len(block)))
            if self.add_entropy:
                block.append(self.entropy(answer))
            return block

        if self.include_all_models:
            # Pad missing models with as many zeros as an answered model produces, so each
            # model's features stay at the same positions (previously a fixed 8 was used).
            block_width = self.n_highest_include + (1 if self.add_entropy else 0)
            returned_features = []
            for answer in models_answers_sample:
                if answer is not None:
                    returned_features += certainty_block(answer)
                elif not self.all_model_combinations:
                    returned_features += [0] * block_width
            return returned_features

        if models_answers_sample[model] is None:
            return []
        return certainty_block(models_answers_sample[model])
        
    def baseline_metrics(self, models_answers_sample, model):
        """
        Return the baseline feature for one model: the largest value in its answer.

        Parameters:
        - models_answers_sample (list): Per-model answers, indexed by model position.
        - model (int): Index of the model whose answer is summarised.

        Returns:
        - list: A one-element list holding np.max of that model's answer.
        """
        chosen_answer = models_answers_sample[model]
        peak_value = np.max(chosen_answer)
        return [peak_value]

    def generate_sample_input_output(self, question, model, n_models, models_answers_sample, 
                                     measure_sample=None, index=None):
        """
        Generate the feature vector (input) and target (output) for one question and one model.

        Args:
            question (str or tuple): The question for which the sample input and output are generated.
            model (int): The index of the model being evaluated.
            n_models (int): The total number of models.
            models_answers_sample (list): Per-model answers for the sample; None for models that
                                          did not answer.
            measure_sample (list, optional): Per-model measures for the sample. Defaults to None.
            index (int, optional): The index of the question. Defaults to None.

        Returns:
            tuple: (X_sample, y_sample). X_sample is a non-empty feature list; y_sample is
                   measure_sample[model], or None when measure_sample is None.
        """
        X_sample = []
        if self.baseline:
            # Baseline: a single metric taken from the highest-index model that answered,
            # or [0] if no model answered.
            X_sample = [0]
            if any(answer is not None for answer in models_answers_sample):
                # Use a separate loop variable so `model` (the model being evaluated) is not
                # overwritten; it selects the target from measure_sample below.
                for answered_model in range(n_models - 1, -1, -1):
                    if models_answers_sample[answered_model] is not None:
                        X_sample = self.baseline_metrics(models_answers_sample, answered_model)
                        break
        else:
            X_sample += self.base_features(question, index, model)
            X_sample += self.agreement_features(n_models, models_answers_sample)
            X_sample += self.certainty_features(model, models_answers_sample)
            if len(X_sample) == 0:
                # Keep at least one feature so downstream estimators get a non-empty input.
                X_sample = [0]

        if measure_sample is not None:
            return X_sample, measure_sample[model]
        return X_sample, None
