#######################################################
######## Helper Functions for Model Eval ##############
#######################################################

import numpy as np
import pandas as pd
from collections import Counter
from collections.abc import Mapping
import math
import multiprocessing
from multiprocessing import Pool
from tqdm import tqdm
from constants import *
import re
import torch
import random
import torch.nn.functional as F
import Levenshtein as levenshtein
from transformers import BartTokenizer, T5Tokenizer
from scipy.stats import zscore, gmean
from scipy.spatial.distance import cosine, euclidean
from ot import emd2
import warnings
import logomaker
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu

class CharBleu:
    """
    Character/token-level BLEU score for a candidate string against a set of
    reference strings (clipped n-gram precisions + brevity penalty).

    Intended adaptations for the TCR:pMHC dataset (applied by the caller):
        1. Restrict the reference strings to only the top k closest matches
        (``find_n_closest_matches``), because an arbitrarily long reference
        list will lead to poor BLEU scores.

        2. Calculate the BLEU score for only the peptide sequence, because
        the pseudo sequence is not a true sequence and will overestimate BLEU scores.

    Args:
        tokenizer: callable mapping a string to token ids. It may return a numpy
            array, a plain sequence (e.g. ``list``), or a mapping with an
            ``'input_ids'`` entry (e.g. a Hugging Face ``BatchEncoding``).
        max_n: highest n-gram order used by ``__call__`` when ``max_n`` is not given.
    """
    def __init__(self, tokenizer, max_n=4):
        self.tokenizer = tokenizer
        self.max_n = max_n

    @staticmethod
    def ngrams(tokenized_string, n):
        """Return the contiguous n-grams of a tokenized sequence as a list of tuples."""
        return [tuple(tokenized_string[i:i + n]) for i in range(len(tokenized_string) - n + 1)]

    def modified_precision(self, candidate, references, n):
        """
        Clipped n-gram precision of ``candidate`` against ``references``.

        Each candidate n-gram count is clipped to the largest number of times that
        n-gram occurs in any single reference, so repetitive generations cannot
        inflate the precision. Returns 0 when the candidate has no n-grams of order n.
        """
        candidate_ngrams = Counter(CharBleu.ngrams(candidate, n))
        max_reference_ngrams = Counter()

        for reference in references:
            reference_ngrams = Counter(CharBleu.ngrams(reference, n))
            for ngram in candidate_ngrams:
                max_reference_ngrams[ngram] = max(max_reference_ngrams[ngram], reference_ngrams[ngram])

        matches = sum(min(count, max_reference_ngrams[ngram]) for ngram, count in candidate_ngrams.items())
        total = sum(candidate_ngrams.values())

        return matches / total if total > 0 else 0

    @staticmethod
    def brevity_penalty(candidate, references):
        """
        BLEU brevity penalty for ``candidate`` given non-empty ``references``.

        The effective reference length ``r`` is the length of the reference closest
        in length to the candidate (ties broken toward the shorter reference).
        Returns 1 if the candidate is longer than ``r``, 0 for an empty candidate,
        and ``exp(1 - r / c)`` otherwise.
        """
        c = len(candidate)
        r = min((len(ref) for ref in references), key=lambda ref_len: (abs(ref_len - c), ref_len))

        if c > r:
            return 1
        elif c == 0:
            return 0
        return math.exp(1 - r / c)

    def _token_ids(self, text):
        """Tokenize ``text`` and return its token-id sequence (unwrapping HF ``input_ids``)."""
        encoded = self.tokenizer(text)
        if isinstance(encoded, Mapping):
            # Hugging Face tokenizers return a BatchEncoding (a Mapping)
            return encoded['input_ids']
        return encoded

    def __call__(self, candidate_string, reference_strings, min_n=1, max_n=None):
        """
        BLEU score of ``candidate_string`` with respect to ``reference_strings``.

        Args:
            candidate_string: the generated string.
            reference_strings: iterable of reference strings.
            min_n: lowest n-gram order included.
            max_n: highest n-gram order included; defaults to ``self.max_n``.

        Returns:
            The brevity-penalized geometric mean of the clipped precisions for
            orders ``min_n..max_n`` (uniform weights), or 0 if any precision is 0.

        Raises:
            ValueError: if ``min_n > max_n``.
        """
        if max_n is None:
            max_n = self.max_n

        candidate = self._token_ids(candidate_string)
        references = [self._token_ids(ref) for ref in reference_strings]

        precisions = [self.modified_precision(candidate, references, n) for n in range(min_n, max_n + 1)]
        if not precisions:
            raise ValueError(f"Empty n-gram range: min_n={min_n} > max_n={max_n}.")

        if min(precisions) == 0:
            return 0

        bp = CharBleu.brevity_penalty(candidate, references)
        # Geometric mean with uniform weights over the orders actually evaluated
        geo_mean = math.exp(sum(math.log(p) for p in precisions) / len(precisions))

        return bp * geo_mean

    def corpus_bleu(self, candidate_corpus, reference_corpus, max_ngram=4, summary='mean'):
        """
        Sentence-level BLEU for each (candidate, reference list) pair, summarized.

        Args:
            candidate_corpus: iterable of candidate strings.
            reference_corpus: iterable of reference-string lists, aligned with candidates.
            max_ngram: highest n-gram order (scores use orders 1..max_ngram).
            summary: 'mean', 'median', 'max', 'min', 'range' (a (min, max) tuple);
                any other value returns the list of per-pair scores.
        """
        bleu_scores = [self(candidate, reference_list, max_n=max_ngram)
                       for candidate, reference_list in zip(candidate_corpus, reference_corpus)]

        if summary == 'mean':
            return sum(bleu_scores) / len(bleu_scores)
        elif summary == 'median':
            return np.median(bleu_scores)
        elif summary == 'max':
            return max(bleu_scores)
        elif summary == 'min':
            return min(bleu_scores)
        elif summary == 'range':
            return (min(bleu_scores), max(bleu_scores))
        else:
            return bleu_scores

    def find_n_closest_matches(self, query, references, n):
        """Return the ``n`` references with the smallest Levenshtein distance to ``query`` (stable order on ties)."""
        return sorted(references, key=lambda ref: levenshtein.distance(query, ref))[:n]
    
class HuggingFaceModelAdapter:
    """
    Adapter around a Hugging Face seq2seq tokenizer/model pair (BartTokenizer or
    T5Tokenizer) for TCR <-> pMHC translation.

    Provides input/output formatting, several decoding strategies, evaluation
    metrics (loss, precision/recall/F1@k, Char-BLEU, Wasserstein distance) and
    sequence-logo plots of empirical vs. predicted per-position distributions.
    """
    def __init__(self, hf_tokenizer, hf_model, **kwargs):
        self.tokenizer = hf_tokenizer
        self.model = hf_model
        # T5 only: prepend a '[PMHC]' / '[TCR]' task token to the source sequence.
        self.use_task_prefix = kwargs.get('use_task_prefix', False)

    def format_input(self, source):
        """
        Tokenize ``source`` into PyTorch tensors on the model's device.

        A source with a ``peptide`` attribute is treated as a pMHC (pMHC -> TCR);
        one with a ``cdr3b`` attribute as a TCR (TCR -> pMHC).

        Raises:
            ValueError: if the tokenizer is neither a BartTokenizer nor a
                T5Tokenizer, or if ``source`` is neither a pMHC nor a TCR.
        """
        if hasattr(source, 'peptide'):
            # Source is a pMHC and target is a TCR (pMHC -> TCR)
            if isinstance(self.tokenizer, BartTokenizer):
                # BART encodes peptide and pseudo-sequence as a text pair
                return self.tokenizer(source.peptide, source.pseudo, return_tensors='pt').to(self.model.device)
            if isinstance(self.tokenizer, T5Tokenizer):
                seq = f'{source.peptide}{self.tokenizer.sep_token}{source.pseudo}'
                if self.use_task_prefix:
                    seq = f'[PMHC]{seq}'
                return self.tokenizer(seq, return_tensors='pt').to(self.model.device)
            raise ValueError("This tokenizer has not been implemented or used in training.")
        if hasattr(source, 'cdr3b'):
            # Source is a TCR and target is a pMHC (TCR -> pMHC)
            if isinstance(self.tokenizer, BartTokenizer):
                return self.tokenizer(source.cdr3b, return_tensors='pt').to(self.model.device)
            if isinstance(self.tokenizer, T5Tokenizer):
                seq = f'[TCR]{source.cdr3b}' if self.use_task_prefix else source.cdr3b
                return self.tokenizer(seq, return_tensors='pt').to(self.model.device)
            raise ValueError("This tokenizer has not been implemented or used in training.")
        raise ValueError("This adapter must be used with a TCRpMHCDataset object yielding TCR and pMHC.")

    def format_output(self, trg):
        """Remove every bracketed token (e.g. '[SOS]', '[EOS]'), brackets included, from a decoded string."""
        return re.sub(r'\[.*?\]', '', trg)

    def evaluate_loss(self, dataset, bsz=512):
        """
        Mean per-batch seq2seq loss over a pMHC -> TCR dataset.

        Reads ``(src, trg)`` pairs from ``dataset`` (``src.peptide``/``src.pseudo``
        as input, ``trg.cdr3b`` as target) and evaluates them in batches of ``bsz``.
        Padded label positions are set to -100 so the loss ignores them, and the
        source attention mask is passed so padded inputs are not attended to.

        Returns:
            The sum of batch losses divided by the number of batches (a tensor).

        Raises:
            ValueError: if ``dataset`` is empty.
        """
        self.model.eval()

        is_bart = isinstance(self.tokenizer, BartTokenizer)
        src_list = []
        trg_list = []
        for i in range(len(dataset)):
            src, trg = dataset[i]
            src_list.append((src.peptide, src.pseudo) if is_bart else f'{src.peptide}{self.tokenizer.sep_token}{src.pseudo}')
            trg_list.append(trg.cdr3b)

        if not src_list:
            raise ValueError("Cannot evaluate the loss of an empty dataset.")

        pad_id = self.tokenizer.pad_token_id
        batches = 0
        cum_loss = 0

        # Exactly ceil(len / bsz) batches; the last one may be smaller
        for start in tqdm(range(0, len(src_list), bsz), 'Evaluating Loss'):
            src_enc = self.tokenizer(src_list[start:start + bsz], return_tensors='pt', padding=True).to(self.model.device)
            target_ids = self.tokenizer(trg_list[start:start + bsz], return_tensors='pt', padding=True).input_ids.to(self.model.device)
            if pad_id is not None:
                target_ids = target_ids.masked_fill(target_ids == pad_id, -100)

            with torch.no_grad():
                outs = self.model(input_ids=src_enc['input_ids'], attention_mask=src_enc.get('attention_mask'), labels=target_ids)
                cum_loss += outs.loss

            batches += 1

        return cum_loss / batches

    def find_n_closest_matches(self, query, references, n):
        """Return the ``n`` references with the smallest Levenshtein distance to ``query`` (stable order on ties)."""
        return sorted(references, key=lambda ref: levenshtein.distance(query, ref))[:n]

    def translate(self, source, max_len=25, temperature=1.0, top_k=1):
        """
        Legacy code for the time being.

        Decode a single translation: greedy when ``temperature == 0.0``, otherwise
        top-k sampling at ``temperature``.

        Returns:
            (translation with bracketed tokens removed, raw ``generate`` output ids)
        """
        self.model.eval()
        inp = self.format_input(source)['input_ids']
        if temperature == 0.0:
            # Greedy Decoding
            output = self.model.generate(inp, num_beams=1, max_new_tokens=max_len, do_sample=False)
        else:
            output = self.model.generate(inp, max_new_tokens=max_len, do_sample=True, top_k=top_k, temperature=temperature)
        trg = self.tokenizer.decode(output.squeeze(), skip_special_tokens=True)
        return self.format_output(trg), output

    def translate_plus(self, source, max_len=25, top_k=1, temperature=1):
        """
        Decode one translation and also return the per-step output distributions.

        Greedy when ``temperature == 0.0``, otherwise top-k sampling.

        Returns:
            (translation string, generated sequences, softmax of the per-step scores
            as a squeezed numpy array — presumably (steps, vocab) for a single
            sequence; a single step squeezes to 1-D)
        """
        self.model.eval()
        inp = self.format_input(source)['input_ids']
        if temperature == 0.0:
            output = self.model.generate(inp, max_new_tokens=max_len, num_beams=1, do_sample=False,
                                         return_dict_in_generate=True, output_scores=True)
        else:
            output = self.model.generate(inp, max_new_tokens=max_len, do_sample=True, top_k=top_k, temperature=temperature,
                                         return_dict_in_generate=True, output_scores=True)
        trg = self.tokenizer.decode(output.sequences.squeeze(), skip_special_tokens=True)
        out_dist = torch.stack(output.scores)
        return self.format_output(trg), output.sequences, F.softmax(out_dist, dim=-1).detach().cpu().numpy().squeeze()

    def _generate_in_rounds(self, inp, n, per_round, **generate_kwargs):
        """
        Call ``self.model.generate`` repeatedly until ``n`` sequences are collected.

        Each round re-seeds torch from Python's ``random`` module and requests at
        most ``per_round`` sequences (presumably because HF requires
        num_return_sequences <= num_beams). Deterministic settings can therefore
        yield duplicate sequences across rounds.

        Raises:
            ValueError: if ``per_round`` < 1 (the loop would never terminate).
        """
        if per_round < 1:
            raise ValueError(f"num_beams must be a positive integer, got {per_round}.")
        outputs = []
        while len(outputs) < n:
            seed = random.randint(0, 2**32 - 1)
            torch.manual_seed(seed)
            outputs.extend(self.model.generate(inp, num_return_sequences=min(n - len(outputs), per_round), **generate_kwargs))
        return outputs

    def sample_translations(self, source, max_len=25, n=5, mode='greedy', **kwargs):
        """
        Decode ``n`` translations of ``source`` with the chosen decoding ``mode``.

        Modes and the keyword arguments they read from ``kwargs``:
            'greedy'          - greedy search.
            'ancestral'       - multinomial sampling.
            'top_k'           - top-k sampling; requires ``top_k``, uses ``temperature``.
            'top_p'           - nucleus sampling; requires ``top_p``, uses ``temperature``.
            'beam'            - beam search; requires ``num_beams``, uses ``no_repeat_ngram_size``.
            'stochastic_beam' - beam search with sampling; same kwargs as 'beam'.
            'diverse_beam'    - group beam search (https://arxiv.org/pdf/1610.02424.pdf);
                                requires ``num_beams`` and ``num_beam_groups``, uses
                                ``diversity_penalty`` and ``no_repeat_ngram_size``.
            'contrastive'     - requires ``penalty_alpha`` and ``top_k``.
            'typical'         - typical sampling; requires ``typical_mass``, uses ``min_tokens_to_keep``.

        Returns:
            list of decoded strings with bracketed special tokens removed.

        Raises:
            ValueError: for an unknown ``mode`` or a missing required keyword argument.
        """
        self.model.eval()
        inp = self.format_input(source)['input_ids']

        temperature = kwargs.get('temperature', 1.0)
        top_k = kwargs.get('top_k', None)
        top_p = kwargs.get('top_p', None)
        num_beams = kwargs.get('num_beams', None)
        no_repeat_ngram_size = kwargs.get('no_repeat_ngram_size', 4)
        num_beam_groups = kwargs.get('num_beam_groups', None)
        diversity_penalty = kwargs.get('diversity_penalty', None)
        penalty_alpha = kwargs.get('penalty_alpha', None)
        typical_mass = kwargs.get('typical_mass', None)
        min_tokens_to_keep = kwargs.get('min_tokens_to_keep', 1)

        def _require(value, name):
            # Explicit check (not assert) so validation survives `python -O`
            if value is None:
                raise ValueError(f"mode='{mode}' requires the '{name}' keyword argument.")

        if mode == 'greedy':
            # NOTE(review): some transformers versions reject num_return_sequences > 1 for greedy search.
            outputs = self.model.generate(inp, max_new_tokens=max_len, do_sample=False, num_beams=1, num_return_sequences=n)
        elif mode == 'ancestral':
            outputs = self.model.generate(inp, max_new_tokens=max_len, do_sample=True, num_beams=1, num_return_sequences=n)
        elif mode == 'top_k':
            _require(top_k, 'top_k')
            outputs = self.model.generate(inp, max_new_tokens=max_len, do_sample=True, top_k=top_k,
                                          temperature=temperature, num_return_sequences=n)
        elif mode == 'top_p':
            _require(top_p, 'top_p')
            # top_k=0 disables the default top-k filter so only nucleus filtering applies
            outputs = self.model.generate(inp, max_new_tokens=max_len, do_sample=True, top_p=top_p,
                                          temperature=temperature, top_k=0, num_return_sequences=n)
        elif mode == 'beam':
            _require(num_beams, 'num_beams')
            outputs = self._generate_in_rounds(inp, n, num_beams, max_new_tokens=max_len, num_beams=num_beams,
                                               no_repeat_ngram_size=no_repeat_ngram_size, do_sample=False)
        elif mode == 'stochastic_beam':
            _require(num_beams, 'num_beams')
            outputs = self._generate_in_rounds(inp, n, num_beams, max_new_tokens=max_len, num_beams=num_beams,
                                               no_repeat_ngram_size=no_repeat_ngram_size, do_sample=True)
        elif mode == 'diverse_beam':
            _require(num_beams, 'num_beams')
            _require(num_beam_groups, 'num_beam_groups')
            outputs = self._generate_in_rounds(inp, n, num_beams, max_new_tokens=max_len, num_beams=num_beams,
                                               num_beam_groups=num_beam_groups, no_repeat_ngram_size=no_repeat_ngram_size,
                                               diversity_penalty=diversity_penalty)
        elif mode == 'contrastive':
            _require(penalty_alpha, 'penalty_alpha')
            _require(top_k, 'top_k')
            # NOTE(review): with do_sample=True, HF may select multinomial sampling rather than
            # contrastive search (which expects do_sample=False) — confirm against the installed version.
            outputs = self.model.generate(inp, max_new_tokens=max_len, do_sample=True, penalty_alpha=penalty_alpha,
                                          top_k=top_k, num_return_sequences=n)
        elif mode == 'typical':
            _require(typical_mass, 'typical_mass')
            logits_warper = [TypicalLogitsWarper(mass=typical_mass, filter_value=-float("Inf"), min_tokens_to_keep=min_tokens_to_keep)]
            outputs = self.model.generate(inp, max_new_tokens=max_len, do_sample=True, num_return_sequences=n,
                                          logits_processor=logits_warper)
        else:
            raise ValueError(f"Unknown decoding mode: {mode!r}")

        return [self.format_output(self.tokenizer.decode(output, skip_special_tokens=True)) for output in outputs]

    def precision_recall_f1_at_k(self, source, k=100, max_len=25, mode=None, **kwargs):
        """
        Precision, recall, F1 and mean edit distance of ``k`` sampled translations.

        References are the distinct cognate targets of ``source``: CDR3b strings
        for a pMHC source, peptide+pseudo strings for a TCR source.
        Recall is the number of distinct correct translations over ``min(k, #refs)``;
        the edit distance is to each translation's nearest reference.

        Raises:
            ValueError: if ``k`` < 1, ``source`` is neither a pMHC nor a TCR,
                or ``source`` has no reference targets.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}.")
        if hasattr(source, 'peptide'):
            ref_set = {tcr.cdr3b for tcr in source.tcrs}
        elif hasattr(source, 'cdr3b'):
            ref_set = {pMHC.peptide + pMHC.pseudo for pMHC in source.pMHCs}
        else:
            raise ValueError("Source must be a pMHC (with 'peptide') or a TCR (with 'cdr3b').")
        if not ref_set:
            raise ValueError("Source has no reference targets to compare against.")

        translations = self.sample_translations(source, n=k, max_len=max_len, mode=mode, **kwargs)
        correct = [t for t in translations if t in ref_set]
        edit_distances = [min(levenshtein.distance(t, ref) for ref in ref_set) for t in translations]

        precision = len(correct) / len(translations)
        recall = len(set(correct)) / min(k, len(ref_set))
        f1 = 0.0 if precision + recall == 0 else (2 * precision * recall / (precision + recall))
        mean_edit_distance = sum(edit_distances) / len(translations)
        return precision, recall, f1, mean_edit_distance

    @staticmethod
    def _summarize(values, summary):
        """Reduce ``values`` with 'mean', 'median' or 'geommean'; any other ``summary`` returns them unchanged."""
        if summary == 'mean':
            return np.mean(values)
        if summary == 'median':
            return np.median(values)
        if summary == 'geommean':
            return gmean(values)
        return values

    def get_prf1_score(self, dataset, k=100, max_len=25, summary='mean', mode='top_k', **kwargs):
        """
        Precision/recall/F1@k and mean edit distance over every distinct source in ``dataset``.

        Returns:
            (precision, recall, f1, edit distance), each reduced by ``summary``
            ('mean', 'median', 'geommean'); otherwise each is the per-source list.
        """
        src_list = list(set(dataset.pMHCs)) if dataset.source == 'pmhc' else list(set(dataset.tcrs))
        precisions, recalls, f1s, meds = [], [], [], []
        for source in tqdm(src_list, desc='Precision, Recall, F1 @ K'):
            precision, recall, f1, med = self.precision_recall_f1_at_k(source, k=k, max_len=max_len, mode=mode, **kwargs)
            precisions.append(precision)
            recalls.append(recall)
            f1s.append(f1)
            meds.append(med)

        return tuple(self._summarize(values, summary) for values in (precisions, recalls, f1s, meds))

    def get_empirical_distribution(self, source):
        """
        Per-position empirical token distribution of the cognate TCRs of ``source``.

        Each ``tcr.cdr3b`` in ``source.tcrs`` is encoded, right-padded with token
        id 0 to the longest encoding, and token frequencies are counted per position.

        Returns:
            float array of shape (max encoded length, len(self.tokenizer)); each row sums to 1.
        """
        vocab_size = len(self.tokenizer)
        sequences = [self.tokenizer.encode(tcr.cdr3b) for tcr in source.tcrs]
        max_len = max(len(s) for s in sequences)

        # Right-pad with 0s so every sequence has the same length
        sequences_array = np.vstack([np.pad(s, (0, max_len - len(s)), 'constant') for s in sequences])

        distribution = np.zeros((max_len, vocab_size))
        for t in range(max_len):
            counts = np.bincount(sequences_array[:, t], minlength=vocab_size)
            distribution[t] = counts / counts.sum()

        return distribution

    def generate_distance_matrix(self, feature_df, normalize=True, method='euclidean'):
        """
        Token-by-token distance matrix built from a per-token feature dataframe.

        Feature columns are optionally z-scored (NaNs become 0). Vocabulary tokens
        not in ``AA_VOCABULARY`` (and without '<') get a one-hot feature row.
        Distances are written at the tokens' ids so the matrix lines up with the
        model's output distribution columns; tokens containing '<' keep zero rows.

        Returns:
            array of shape (len(self.tokenizer), len(self.tokenizer)).

        Raises:
            ValueError: if ``method`` is not 'euclidean' or 'cosine'.
        """
        from scipy.spatial.distance import cdist

        if method not in ('euclidean', 'cosine'):
            raise ValueError(f'Not-implemented distance metric: {method}')

        df_normalized = feature_df.apply(zscore) if normalize else feature_df
        df_normalized = df_normalized.fillna(0.0)

        token_to_id = self.tokenizer.get_vocab()
        vocab = [token for token in token_to_id if '<' not in token]
        special_tokens = [token for token in vocab if token not in AA_VOCABULARY]
        for index, st in enumerate(special_tokens):
            df_normalized.loc[st] = np.zeros(len(df_normalized.columns))
            # Single .loc call: chained indexing (df.loc[st][index] = ...) may write to a copy
            df_normalized.loc[st, df_normalized.columns[index]] = 1.0

        df_mat = df_normalized.reindex(vocab)

        distance_mat = np.zeros((len(self.tokenizer), len(self.tokenizer)))
        if vocab:
            token_ids = [token_to_id[token] for token in vocab]
            features = df_mat.values
            distance_mat[np.ix_(token_ids, token_ids)] = cdist(features, features, metric=method)

        return distance_mat

    @staticmethod
    def _pad_to_length(dist, length, pad_token_id):
        """Append one-hot ``pad_token_id`` rows to ``dist`` until it has ``length`` rows."""
        padding = np.zeros((length - dist.shape[0], dist.shape[1]), dtype=np.float32)
        padding[:, pad_token_id] = 1.0
        return np.concatenate((dist, padding), axis=0)

    def get_and_pad_distributions(self, source, top_k=28, temperature=1.0):
        """
        Given a source, calculate the empirical distribution over the cognate
        -tope space and the model's predicted distribution.

        The shorter one is padded with one-hot rows on token id 0 up to the
        longer one's length, and both are returned.

        Temperature 0 activates greedy decoding to get non-stochastic outputs.
        """
        empirical_dist = self.get_empirical_distribution(source)

        _, _, output = self.translate_plus(source, max_len=empirical_dist.shape[0], top_k=top_k, temperature=temperature)
        # translate_plus squeezes its output, so a single generated step comes back 1-D
        model_dist = np.atleast_2d(np.asarray(output))

        # Re-normalize the model distribution due to floating point errors
        model_dist = model_dist / np.sum(model_dist, axis=-1, keepdims=True)

        max_seq_len = max(empirical_dist.shape[0], model_dist.shape[0])
        pad_token_id = 0  # matches the 0-padding used in get_empirical_distribution

        empirical_dist_padded = self._pad_to_length(empirical_dist, max_seq_len, pad_token_id)
        model_dist_padded = self._pad_to_length(model_dist, max_seq_len, pad_token_id)

        return empirical_dist_padded, model_dist_padded

    def wassertsein_distance(self, source, distance_matrix):
        """
        Mean over time steps of the Wasserstein (EMD) distance between the model's
        output distribution and the empirical distribution of the cognate -tope,
        using ``distance_matrix`` as the ground cost. Uses the sampling defaults of
        ``get_and_pad_distributions``.
        """
        empirical_dist_padded, model_dist_padded = self.get_and_pad_distributions(source)
        wassertsein_distances = [
            emd2(model_dist_padded[t], empirical_dist_padded[t], distance_matrix, check_marginals=False)
            for t in range(empirical_dist_padded.shape[0])
        ]
        return np.mean(wassertsein_distances)

    def get_wassertsein_score(self, dataset, distance_matrix, summary='mean'):
        """Wasserstein distance for each distinct pMHC in ``dataset``, reduced by ``summary`` (see ``_summarize``)."""
        scores = [self.wassertsein_distance(source, distance_matrix)
                  for source in tqdm(list(set(dataset.pMHCs)), desc='Wasserstein Distance Metric')]
        return self._summarize(scores, summary)

    def get_dataset_bleu(self, dataset, max_references=20, max_ngram=4, summary='mean'):
        """
        Character-level corpus BLEU (NLTK) for a TCRpMHC dataset.

        For each distinct pMHC, one greedy translation is compared against the
        ``max_references`` cognate CDR3b sequences closest to it in edit distance.
        ``max_ngram`` sets uniform weights over orders 1..max_ngram. ``summary`` is
        accepted for interface compatibility but unused (the score is corpus-level).
        """
        translations = []
        references = []
        max_len = 24
        weights = (1.0 / max_ngram,) * max_ngram

        for pmhc in tqdm(list(set(dataset.pMHCs)), desc='Char-BLEU'):
            ref_cdr3bs = [tcr.cdr3b for tcr in pmhc.tcrs]
            translation = self.sample_translations(pmhc, max_len, n=1, mode='greedy')[0]
            # NLTK works on token lists; split into characters for Char-BLEU
            translations.append(list(translation))
            references.append([list(x) for x in self.find_n_closest_matches(translation, ref_cdr3bs, n=max_references)])

        return corpus_bleu(references, translations, weights=weights)

    def get_metrics(self, dataset, k=100, max_len=25, summary='mean', mode='top_k', **kwargs):
        """
        Compute Char-BLEU, precision/recall/F1@k, mean edit distance and the
        Wasserstein score for ``dataset``. Extra ``kwargs`` are forwarded to the
        decoding used for precision/recall/F1.
        """
        metrics = {'char-bleu': -100, 'precision': -100, 'recall': -100, 'f1': -100, 'edit': -100, 'wassertsein': -100}

        metrics['char-bleu'] = self.get_dataset_bleu(dataset, max_references=20, max_ngram=4, summary=summary)

        precisions, recalls, f1s, meds = self.get_prf1_score(dataset, k=k, max_len=max_len, summary=summary, mode=mode, **kwargs)
        metrics['precision'] = precisions
        metrics['recall'] = recalls
        metrics['f1'] = f1s
        metrics['edit'] = meds

        metrics['wassertsein'] = self.get_wassertsein_score(dataset, summary=summary,
                                                            distance_matrix=self.generate_distance_matrix(PROPERTIES_TABLE))
        return metrics

    def fine_grained_metrics(self, dataset, slice_on='Allele', k=100, max_len=25, top_k=1, temperature=1.0, num_beams=0, summary='mean', mode='top_k'):
        """
        ``get_metrics`` computed separately for each value of column ``slice_on``
        in ``dataset.to_df()``. Each group's dict also gets a 'size' entry
        (number of distinct sources). ``mode`` is forwarded, and ``num_beams``
        is forwarded when non-zero.
        """
        df = dataset.to_df()
        fine_grained_metrics = {}
        extra_kwargs = {'num_beams': num_beams} if num_beams else {}

        for group in df[slice_on].unique():
            # NOTE(review): TCRpMHCdataset is not imported in the visible code — presumably
            # it comes from `from constants import *`; confirm.
            group_dataset = TCRpMHCdataset(source=dataset.source, target=dataset.target, use_pseudo=dataset.use_pseudo,
                                           use_cdr3=dataset.use_cdr3, use_mhc=dataset.use_mhc)
            group_dataset.load_data_from_df(df[df[slice_on] == group])

            group_metrics = self.get_metrics(group_dataset, k=k, max_len=max_len, top_k=top_k, temperature=temperature,
                                             summary=summary, mode=mode, **extra_kwargs)
            group_metrics['size'] = len(set(group_dataset.pMHCs)) if dataset.source == 'pmhc' else len(set(group_dataset.tcrs))
            fine_grained_metrics[group] = group_metrics
        return fine_grained_metrics

    def _logo_df_from_dist(self, distribution, top_k):
        """
        Build a logomaker-ready dataframe from a per-position distribution.

        Parameters:
        distribution (numpy.ndarray): 2D array of shape (sequence_length, alphabet_size).
        top_k (int): number of highest-probability tokens kept per position;
            special tokens among them are skipped (left at 0).

        Bracketed special tokens are renamed to single characters and placed
        first among the columns; distribution column ``i`` is written to dataframe
        column ``i``. NOTE(review): this assumes special tokens come first in the
        vocab order — confirm for the tokenizer in use.
        """
        sequence_length, _ = distribution.shape

        alphabet = list(self.tokenizer.get_vocab().keys())

        special_token_normalization = {
                                    '[SOS]': '^',
                                    '[EOS]': '<',
                                    '[UNK]': '?',
                                    '[SEP]': '*',
                                    '[PAD]': '|',
                                    '[CLS]': '>',
                                    '[MASK]': '@',
                                    '[TCR]': '$',
                                    '[PMHC]': '!'
                                }
        special_chars = set(special_token_normalization.values())

        single_char_special_tokens = [special_token_normalization[char] for char in alphabet if char in special_token_normalization]
        standard_tokens = [char for char in alphabet if char not in special_token_normalization]
        normed_alphabet = single_char_special_tokens + standard_tokens
        logo_df = pd.DataFrame(0.0, columns=normed_alphabet, index=range(sequence_length))

        for position in range(sequence_length):
            top_probs = sorted(enumerate(distribution[position]), key=lambda x: x[1], reverse=True)[:top_k]
            for char_idx, prob in top_probs:
                char = logo_df.columns[char_idx]
                if char not in special_chars:
                    logo_df.at[position, char] = prob
        return logo_df

    def _show_logo(self, logo_df, length, title):
        """Render one sequence logo with the shared axis styling, suppressing logomaker warnings."""
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            logomaker.Logo(logo_df, color_scheme='weblogo_protein')
            plt.xlabel('Position')
            plt.xlim(-1, length)
            plt.ylabel('Frequency')
            plt.title(title)
            plt.show()

    def emp_vs_pred_logo(self, source, model_name, top_k=4, **kwargs):
        """Plot the empirical and the greedy-decoded model logo for ``source``, one after the other."""
        # temperature=0.0 -> greedy decoding, so the learned distribution is returned without sampling
        emp_dist, pred_dist = self.get_and_pad_distributions(source, top_k=20, temperature=0.0)
        # Drop the first time step of both distributions
        emp_dist = emp_dist[1:]
        pred_dist = pred_dist[1:]
        length = emp_dist.shape[0]
        self._show_logo(self._logo_df_from_dist(emp_dist, top_k), length, 'Empirical Logo Plot')
        self._show_logo(self._logo_df_from_dist(pred_dist, top_k), length, f'{model_name} Prediction Logo Plot')
            
from transformers import LogitsWarper

class TypicalLogitsWarper(LogitsWarper):
    """
    Logits warper for typical sampling.

    Code taken directly from the Typical Sampling Codebase. At each step, tokens
    are ranked by how close their surprisal (-log p) is to the entropy of the
    next-token distribution. The closest tokens are kept until their cumulative
    probability reaches ``mass``, and every other logit is set to ``filter_value``.

    Args:
        mass: cumulative probability mass of the kept (typical) set.
        filter_value: value written into the logits of removed tokens.
        min_tokens_to_keep: lower bound on the number of kept tokens (applied when > 1).
    """
    def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):

        self.filter_value = filter_value
        self.mass = mass
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Return ``scores`` with non-typical tokens masked; assumes 2-D (batch, vocab) scores."""
        # calculate entropy (nansum skips 0 * -inf terms from already-filtered logits)
        normalized = torch.nn.functional.log_softmax(scores, dim=-1)
        p = torch.exp(normalized)
        ent = -(normalized * p).nansum(-1, keepdim=True)

        # shift and sort: distance of each token's surprisal from the entropy
        shifted_scores = torch.abs((-normalized) - ent)
        sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
        sorted_logits = scores.gather(-1, sorted_indices)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # Remove tokens with cumulative mass above the threshold
        last_ind = (cumulative_probs < self.mass).sum(dim=1)
        # If every cumulative prob stays below `mass` (e.g. mass=1.0 with float rounding)
        # the count equals the vocab size and `gather` would index out of bounds.
        last_ind = last_ind.clamp(max=sorted_scores.shape[-1] - 1)
        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
        if self.min_tokens_to_keep > 1:
            # Always keep the min_tokens_to_keep most typical tokens (the first one is never removed anyway)
            sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

        scores = scores.masked_fill(indices_to_remove, self.filter_value)
        return scores
    
