import torch
import zlib
import tqdm 
from misc.utils import load_model
import sys
sys.path.append("NL-Augmenter")
"""
From https://github.com/pratyushmaini/llm_dataset_inference.
"""

from nlaugmenter.transformations.butter_fingers_perturbation.transformation import ButterFingersPerturbation
from nlaugmenter.transformations.random_deletion.transformation import RandomDeletion
from nlaugmenter.transformations.synonym_substitution.transformation import SynonymSubstitution
from nlaugmenter.transformations.back_translation.transformation import BackTranslation
from nlaugmenter.transformations.change_char_case.transformation import ChangeCharCase
from nlaugmenter.transformations.whitespace_perturbation.transformation import WhitespacePerturbation
from nlaugmenter.transformations.underscore_trick.transformation import UnderscoreTrick
from nlaugmenter.transformations.style_paraphraser.transformation import StyleTransferParaphraser
from nlaugmenter.transformations.punctuation.transformation import PunctuationWithRules

def aug_generator(text_list, aug_style):
    '''
    Apply a single NL-Augmenter transformation to every string in a list.

    input:
        text_list: a list of strings
        aug_style: name of the transformation; one of "butter_fingers", "random_deletion",
                   "synonym_substitution", "back_translation", "change_char_case",
                   "whitespace_perturbation", "underscore_trick", "style_paraphraser",
                   "punctuation_perturbation"

    output:
        a list with, for each input string, the first output of the transformation's generate()

    raises:
        ValueError: if aug_style is not one of the names above
    '''
    def _build_punctuation():
        # Rule set handed to PunctuationWithRules for the "punctuation_perturbation" style.
        rules = [
            'remove_extra_white_spaces',
            ('replace_characters', {'characters': 'was', 'replacement': 'TZ'}),
            ('replace_emojis', {'replacement': 'TESTO'}),
        ]
        return PunctuationWithRules(rules=rules)

    # Factories are lazy so that only the requested transformation is ever constructed.
    factories = {
        "butter_fingers": lambda: ButterFingersPerturbation(max_outputs=1),
        "random_deletion": lambda: RandomDeletion(prob=0.25),
        "synonym_substitution": lambda: SynonymSubstitution(max_outputs=1, prob=0.2),
        "back_translation": lambda: BackTranslation(),
        "change_char_case": lambda: ChangeCharCase(),
        "whitespace_perturbation": lambda: WhitespacePerturbation(),
        "underscore_trick": lambda: UnderscoreTrick(prob=0.25),
        "style_paraphraser": lambda: StyleTransferParaphraser(style="Basic", upper_length="same_5"),
        "punctuation_perturbation": _build_punctuation,
    }

    factory = factories.get(aug_style)
    if factory is None:
        raise ValueError("Augmentation style not found. Please check the available styles.")

    transformation = factory()
    # generate() is called without extra arguments; only its first output is kept.
    return [transformation.generate(text)[0] for text in text_list]

def generate_perturbations(text_list):
    '''
    Run a fixed set of augmentations over the same list of strings.

    input:
        text_list: a list of strings

    output:
        a dictionary mapping each augmentation style name to the list of augmented strings
        returned by aug_generator for that style
    '''
    styles = (
        "synonym_substitution",
        "butter_fingers",
        "random_deletion",
        "change_char_case",
        "whitespace_perturbation",
        "underscore_trick",
    )
    return {style_name: aug_generator(text_list, style_name) for style_name in styles}

# Unreduced cross entropy so that each token keeps its own loss value.
loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

def raw_values_batch(model, tokenizer, example_list):
    '''
    This function takes a list of strings and returns the loss values for each token in the string
    input:
        model: the language model; it is called as model(**encoding), must return an object
               with a `.logits` attribute and must expose `.device`
        tokenizer: the tokenizer; it is called on the whole list with padding and truncation
                   to tokenizer.model_max_length
        example_list: a list of strings

    output:
        loss_list:  a list of lists. 
                    Each list contains the next-token loss of every non-padding target token
                    of one string (the first token is never a target, so n tokens give n - 1 values)

    NOTE: strings that produce no loss value at all (e.g. a single token) are dropped, so the
    output can be shorter than example_list and its positions may then no longer line up with it.
    '''
    max_length = tokenizer.model_max_length
    encoded = tokenizer(example_list, return_tensors="pt", padding=True, truncation=True, max_length=max_length)

    # Move the inputs to the device the model actually lives on. This covers every device
    # type and CUDA index, not only the default CUDA device.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}

    # forward pass with no grad
    with torch.no_grad():
        outputs = model(**encoded)

    token_ids = encoded["input_ids"]

    # Decide which positions are real tokens. The attention mask is preferred over comparing
    # against pad_token_id, because many setups reuse the EOS token as padding and a genuine
    # EOS token must not be thrown away.
    attention_mask = encoded.get("attention_mask")
    if attention_mask is not None:
        valid = attention_mask.bool()
    elif tokenizer.pad_token_id is not None:
        valid = token_ids != tokenizer.pad_token_id
    else:
        valid = torch.ones_like(token_ids, dtype=torch.bool)

    # -100 is the default ignore_index of CrossEntropyLoss; built out of place so the
    # tokenizer output is left untouched.
    labels = token_ids.masked_fill(~valid, -100)

    logits = outputs.logits

    # shift so that the logits at position t are scored against the token at position t + 1
    shifted_labels = labels[..., 1:].contiguous().to(logits.device)
    shifted_valid = valid[..., 1:]
    shifted_logits = logits[..., :-1, :].contiguous()

    loss = loss_fct(shifted_logits.view(-1, shifted_logits.size(-1)), shifted_labels.view(-1))

    # reshape the loss to (batch, sequence length - 1)
    loss = loss.view(shifted_labels.shape)

    loss_rows = loss.tolist()
    valid_rows = shifted_valid.tolist()

    # Keep values by mask rather than by "value != 0": a confidently predicted real token can
    # have a loss of exactly 0.0 and must stay in the list.
    loss_list = [
        [value for value, keep in zip(row, keep_row) if keep]
        for row, keep_row in zip(loss_rows, valid_rows)
    ]

    # if any list is empty, remove it (downstream metrics divide by the list length)
    return [entry for entry in loss_list if entry]

def raw_values(model, tokenizer, example_list, batch_size = 32):
    '''
    Compute per-token loss values for a list of strings, processing it in batches.
    input:
        model: the language model
        tokenizer: the tokenizer
        example_list: a list of strings
        batch_size: how many strings are passed to raw_values_batch at once
    output:
        collected: a list of lists, the concatenation of what raw_values_batch
                   returns for each consecutive batch
    '''
    collected = []
    batch_starts = range(0, len(example_list), batch_size)
    # tqdm only adds a progress bar over the batches
    for start in tqdm.tqdm(batch_starts):
        chunk = example_list[start:start + batch_size]
        collected.extend(raw_values_batch(model, tokenizer, chunk))
    return collected

def k_min_probs(loss_list, k=0.05, reverse=False):
    '''
    For each list of loss values, average the k fraction of smallest (or largest) values.
    input:
        loss_list: a list of lists of loss values
        k: the fraction of values to average; at least one value is always used
        reverse: if True, average the largest values instead of the smallest

    output:
        averages: one mean value per input list
    '''
    averages = []
    for losses in loss_list:
        ordered = sorted(losses)
        if reverse:
            # descending order, so the slice below picks the largest values
            ordered.reverse()
        take = max(1, int(len(ordered) * k))
        selected = ordered[:take]
        averages.append(sum(selected) / len(selected))
    return averages

def perplexity(loss_list):
    '''
    Compute the perplexity of each list of loss values.
    input:
        loss_list: a list of lists of loss values

    output:
        one perplexity per input list, i.e. the exponential of the list's mean loss
    '''
    def _exp_of_mean(losses):
        # The exponential is evaluated through torch, so the result carries the
        # precision of torch's default floating dtype.
        avg_loss = sum(losses) / len(losses)
        return torch.exp(torch.tensor(avg_loss)).item()

    return [_exp_of_mean(losses) for losses in loss_list]

def zlib_ratio(loss_list, example_list):
    '''
    Divide each list's mean loss by the zlib-compressed size of the matching input string.
    input:
        loss_list: a list of lists of loss values
        example_list: a list of strings; example_list[i] is paired with loss_list[i]

    output:
        one ratio per entry of loss_list: mean loss / number of bytes of the
        zlib-compressed UTF-8 encoding of the string
    '''
    def _compressed_size(text):
        # length in bytes of the zlib-compressed UTF-8 encoding
        return len(zlib.compress(bytes(text, 'utf-8')))

    return [
        (sum(losses) / len(losses)) / _compressed_size(example_list[idx])
        for idx, losses in enumerate(loss_list)
    ]

def ppl_ratio(loss_list, reference_list):
    '''
    Divide each list's mean loss by the matching reference value.

    Despite the name, the computation is done on mean losses, not on exponentiated perplexities.
    input:
        loss_list: a list of lists of loss values
        reference_list: one reference per entry of loss_list; each is either a scalar,
                        used as is, or a list/tuple of loss values, which is averaged first

    output:
        ratios: mean loss / reference value, one per pair (pairs are formed with zip, so the
                result has the length of the shorter input)
    '''
    def _mean(values):
        return sum(values) / len(values)

    def _reference_value(entry_ref):
        # isinstance (rather than an exact type comparison) also accepts tuples and
        # list subclasses, which would otherwise be divided by directly and raise TypeError.
        if isinstance(entry_ref, (list, tuple)):
            return _mean(entry_ref)
        return entry_ref

    return [
        _mean(entry) / _reference_value(entry_ref)
        for entry, entry_ref in zip(loss_list, reference_list)
    ]

def ppl_diff(loss_list, reference_list):
    '''
    Subtract the matching reference value from each list's mean loss.

    Despite the name, the computation is done on mean losses, not on exponentiated perplexities.
    input:
        loss_list: a list of lists of loss values
        reference_list: one reference per entry of loss_list; each is either a scalar,
                        used as is, or a list/tuple of loss values, which is averaged first

    output:
        diffs: mean loss - reference value, one per pair (pairs are formed with zip, so the
               result has the length of the shorter input)
    '''
    def _mean(values):
        return sum(values) / len(values)

    def _reference_value(entry_ref):
        # isinstance (rather than an exact type comparison) also accepts tuples and
        # list subclasses, which would otherwise be subtracted directly and raise TypeError.
        if isinstance(entry_ref, (list, tuple)):
            return _mean(entry_ref)
        return entry_ref

    return [
        _mean(entry) - _reference_value(entry_ref)
        for entry, entry_ref in zip(loss_list, reference_list)
    ]


def perturbation_ratio(model, tokenizer, dataset, loss_list, batch_size = 32):
    '''
    Compare the losses of the original texts with the losses of their perturbed versions.

    Every column of the dataset other than "text" and "label" is treated as a list of
    perturbed strings, e.g.
    Dataset({
        features: ['text', 'synonym_substitution', 'butter_fingers', 'random_deletion', 'change_char_case', 'whitespace_perturbation', 'underscore_trick'],
        num_rows: 2000
    })

    input:
        model: the language model
        tokenizer: the tokenizer
        dataset: an object with `column_names` and column access by name
        loss_list: a list of lists with the loss values of the original texts
        batch_size: the batch size used when scoring the perturbed texts

    output:
        a dictionary with the keys "ppl_ratio_<column>" and "ppl_diff_<column>" for every
        perturbation column, holding the outputs of ppl_ratio and ppl_diff
    '''
    skipped_columns = ("text", "label")
    scores = {}
    for column in dataset.column_names:
        if column in skipped_columns:
            continue
        perturbed_losses = raw_values(model, tokenizer, dataset[column], batch_size = batch_size)
        scores[f"ppl_ratio_{column}"] = ppl_ratio(loss_list, perturbed_losses)
        scores[f"ppl_diff_{column}"] = ppl_diff(loss_list, perturbed_losses)
    return scores

    


# Reference models used by aggregate_metrics for the "reference_model" metric.
# Each key is a short name that ends up in the metric names ("ref_ppl_ratio_<key>",
# "ref_ppl_diff_<key>"); each value is the path string handed to load_model.
# Commented-out entries are currently disabled.
reference_model_registry = {
    # "silo":"kernelmachine/silo-pdswby-1.3b",
    "tinystories-33M": "roneneldan/TinyStories-33M",
    "tinystories-1M": "roneneldan/TinyStories-1M",
    "phi-1_5": "microsoft/phi-1_5",
    # "phi-1": "microsoft/phi-1",
}



def aggregate_metrics(model, tokenizer, dataset, metric_list, args, batch_size = 32):
    '''
    Compute the requested metrics for the "text" column of a dataset.
    input:
        model: the language model
        tokenizer: the tokenizer
        dataset: a huggingface dataset, with key "text" containing the strings
        metric_list: the metrics to calculate; recognised names are "ppl", "k_min_probs",
                     "k_max_probs", "zlib_ratio", "perturbation" and "reference_model"
        args: not read by the current implementation
        batch_size: the batch size used for every forward pass

    output:
        results: a dictionary mapping metric names to lists of per-example values
    '''
    texts = dataset["text"]
    token_losses = raw_values(model, tokenizer, texts, batch_size = batch_size)

    # fractions shared by the k-min and k-max metrics
    fractions = (0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6)

    results = {}

    if "ppl" in metric_list:
        results["ppl"] = perplexity(token_losses)

    if "k_min_probs" in metric_list:
        for frac in fractions:
            results[f"k_min_probs_{frac}"] = k_min_probs(token_losses, k=frac)

    if "k_max_probs" in metric_list:
        for frac in fractions:
            results[f"k_max_probs_{frac}"] = k_min_probs(token_losses, k=frac, reverse=True)

    if "zlib_ratio" in metric_list:
        results["zlib_ratio"] = zlib_ratio(token_losses, texts)

    if "perturbation" in metric_list:
        results.update(perturbation_ratio(model, tokenizer, dataset, token_losses, batch_size))

    if "reference_model" in metric_list:
        # Every registered reference model is loaded and run on the fly.
        # A disabled alternative read precomputed "ppl" values from
        # results/<hf_path>/<args.dataset_name>_<args.split>_metrics.json and passed
        # them to ppl_ratio / ppl_diff instead of re-running the reference models.
        for ref_name, hf_path in reference_model_registry.items():
            ref_model, ref_tokenizer = load_model(hf_path)
            ref_losses = raw_values(ref_model, ref_tokenizer, texts, batch_size = batch_size)
            results[f"ref_ppl_ratio_{ref_name}"] = ppl_ratio(token_losses, ref_losses)
            results[f"ref_ppl_diff_{ref_name}"] = ppl_diff(token_losses, ref_losses)

    return results