import torch
import zlib
import tqdm, json 
from sentence_transformers import SentenceTransformer
import numpy as np
import math

def k_min_probs(loss_list, k=0.05, reverse=False):
    '''
    Average the k fraction of most extreme values of every inner list.

    Each inner list is sorted ascending (descending when reverse is true) and
    the mean of its first max(1, int(len * k)) values is taken.
    input:
        loss_list: a list of lists of numeric values
        k: the fraction of each inner list to average over
        reverse: if true, average the largest values instead of the smallest

    output:
        a list holding one mean per inner list

    An empty inner list raises ZeroDivisionError.
    '''
    means = []
    for values in loss_list:
        ordered = sorted(values)
        if reverse:
            # flip the ascending order so the largest values come first
            ordered.reverse()
        # at least one value is requested, however small k is
        take = max(1, int(len(ordered) * k))
        head = ordered[:take]
        means.append(sum(head) / len(head))
    return means

def k_strip_probs(loss_list, k=0.05):
    '''
    Compute a trimmed mean of every inner list.

    Each inner list is sorted and max(1, int(len * k)) values are dropped from
    both ends before averaging. A list that is not longer than twice that
    count is averaged in full instead.
    input:
        loss_list: a list of lists of numeric values
        k: the fraction of values to strip from each end

    output:
        a list holding one trimmed mean per inner list

    An empty inner list raises ZeroDivisionError.
    '''
    trimmed_means = []
    for values in loss_list:
        ordered = sorted(values)
        cut = max(1, int(len(ordered) * k))
        if len(ordered) > 2 * cut:
            kept = ordered[cut:-cut]
        else:
            # too short to trim both ends and still keep something
            kept = ordered
        trimmed_means.append(sum(kept) / len(kept))
    return trimmed_means

def perplexity(loss_list):
    '''
    Compute exp(mean) of every inner list.
    input:
        loss_list: a list of lists of loss values

    output:
        a list holding one perplexity (a Python float) per inner list

    An empty inner list raises ZeroDivisionError.
    '''
    ppl_values = []
    for losses in loss_list:
        avg_loss = sum(losses) / len(losses)
        # the exponential is evaluated by torch on a 0-d tensor, then unwrapped
        ppl_values.append(torch.tensor(avg_loss).exp().item())
    return ppl_values

def PETAL(loss_list):
    '''
    Return the arithmetic mean of every inner list.
    input:
        loss_list: a list of lists of numeric values

    output:
        a list holding one mean per inner list

    An empty inner list raises ZeroDivisionError.
    '''
    return [sum(values) / len(values) for values in loss_list]

def zlib_ratio(loss_list, example_list):
    '''
    Divide the mean of every inner list by the zlib-compressed size of its text.
    input:
        loss_list: a list of lists of loss values
        example_list: the strings the losses belong to, indexed in the same
            order as loss_list

    output:
        a list holding, per entry, mean loss / number of bytes of the
        zlib-compressed UTF-8 encoding of the matching string
    '''
    ratios = []
    for idx, losses in enumerate(loss_list):
        avg_loss = sum(losses) / len(losses)
        # compressed byte length of the text
        compressed = zlib.compress(bytes(example_list[idx], 'utf-8'))
        ratios.append(avg_loss / len(compressed))
    return ratios

def ppl_ratio(loss_list, reference_list):
    '''
    Divide the mean of every inner list by a matching reference value.
    input:
        loss_list: a list of lists of loss values
        reference_list: one reference per entry of loss_list; each reference is
            either a scalar used as-is, or a list/tuple of values whose mean
            is used

    output:
        ratios: mean(entry) / reference, one value per pair

    The two inputs are walked in lockstep with zip, so extra items in the
    longer one are ignored. An empty inner list or a zero reference raises
    ZeroDivisionError.
    '''
    ratios = []
    for entry, entry_ref in zip(loss_list, reference_list):
        mean_model = sum(entry) / len(entry)
        # isinstance (rather than an exact type comparison) so that tuples and
        # list subclasses are averaged too instead of being divided by directly
        if isinstance(entry_ref, (list, tuple)):
            mean_ref = sum(entry_ref) / len(entry_ref)
        else:
            mean_ref = entry_ref
        ratios.append(mean_model / mean_ref)

    return ratios

def ppl_diff(loss_list, reference_list):
    '''
    Subtract a matching reference value from the mean of every inner list.
    input:
        loss_list: a list of lists of loss values
        reference_list: one reference per entry of loss_list; each reference is
            either a scalar used as-is, or a list/tuple of values whose mean
            is used

    output:
        diffs: mean(entry) - reference, one value per pair

    The two inputs are walked in lockstep with zip, so extra items in the
    longer one are ignored. An empty inner list raises ZeroDivisionError.
    '''
    diffs = []
    for entry, entry_ref in zip(loss_list, reference_list):
        mean_model = sum(entry) / len(entry)
        # isinstance (rather than an exact type comparison) so that tuples and
        # list subclasses are averaged too instead of being subtracted directly
        if isinstance(entry_ref, (list, tuple)):
            mean_ref = sum(entry_ref) / len(entry_ref)
        else:
            mean_ref = entry_ref
        # difference between the model mean and the reference
        diffs.append(mean_model - mean_ref)

    return diffs


def get_losses_from_dict(raw_dict, prefix=None, method="raw", ref_model_dict=None):
    """
    Turn the similarity scores stored in raw_dict into per-entry loss lists.

    Scores are read from raw_dict["similarity"], or from
    raw_dict[f"{prefix}_similarity"] when prefix is given, and are expected to
    be a list of lists. The result is a list of lists of the same layout.

    method selects the conversion:
        "raw":     every score s becomes 1 - s.
        "ref":     for every row, a straight line is fitted (np.polyfit,
                   degree 1) from the log-similarities in ref_model_dict to
                   its "probability" values (same prefix rule); the line is
                   evaluated on the log-similarities of raw_dict and the
                   result is negated. Similarities below 1e-10 are clamped to
                   1e-10 before the logarithm. Rows are paired with zip, so
                   surplus rows on either side are ignored.
        "sigmoid": every score is divided by 0.5, passed through a sigmoid and
                   converted to -log(p + 1e-8).

    Raises:
        ValueError: if method is "ref" and ref_model_dict is None, or if
            method is not one of the three names above.
        KeyError: if the expected key is missing from a dict.
    """
    sim_key = "similarity" if prefix is None else f"{prefix}_similarity"

    if method == "raw":
        return [[1 - s for s in row] for row in raw_dict[sim_key]]

    if method == "ref":
        if ref_model_dict is None:
            raise ValueError("ref_model_dict must be provided when method is 'ref'")

        log_floor = 1e-10

        def to_log_similarities(similarities):
            # Clamp to a small positive floor so log() is defined for zero or
            # negative scores and stays monotone in the score.
            return [[math.log(s if s >= log_floor else log_floor) for s in sublist]
                    for sublist in similarities]

        prob_key = "probability" if prefix is None else f"{prefix}_probability"

        ref_sim = to_log_similarities(ref_model_dict[sim_key])
        ref_prob = ref_model_dict[prob_key]
        semantic_similarity = to_log_similarities(raw_dict[sim_key])

        all_prob_estimated = []
        for ref_sim_row, ref_prob_row, semantic_row in zip(ref_sim, ref_prob, semantic_similarity):
            # per-row linear map from reference log-similarity to reference probability
            slope, intercept = np.polyfit(ref_sim_row, ref_prob_row, 1)
            all_prob_estimated.append([x * slope + intercept for x in semantic_row])

        # Convert to losses: multiply by -1 to invert ranking for consistency
        return [[-x for x in row] for row in all_prob_estimated]

    if method == "sigmoid":
        temperature = 0.5
        eps = 1e-8  # keeps the logarithm finite if the sigmoid underflows to 0

        loss_estimated = []
        for row in raw_dict[sim_key]:
            row_loss = []
            for score in row:
                prob = 1 / (1 + math.exp(-(score / temperature)))
                row_loss.append(-math.log(prob + eps))
            loss_estimated.append(row_loss)
        return loss_estimated

    raise ValueError(f"Unknown method: {method}")

def get_probs_from_dict(raw_dict, prefix=None):
    """
    Look up the stored probability data in raw_dict.

    Without a prefix the "probability" entry is returned; with a prefix the
    f"{prefix}_probability" entry is returned. The value is handed back
    unchanged (expected to be a list of lists). A missing key raises KeyError.
    """
    if prefix is None:
        return raw_dict["probability"]
    return raw_dict[f"{prefix}_probability"]

def perturbation_ratio(raw_dict, dataset, loss_list, loss_estimation_method="raw", ref_model_dict=None):
    '''
    Compare the losses of the original text with those of every perturbed column.

    For each column of dataset other than "text", the losses of that
    perturbation are obtained via get_losses_from_dict (using the column name
    as prefix) and ppl_ratio(loss_list, perturbed losses) is stored.
    input:
        raw_dict: dict holding the "<column>_similarity" entries
        dataset: object exposing column_names, e.g.
            Dataset({
                features: ['text', 'synonym_substitution', 'butter_fingers', 'random_deletion', 'change_char_case', 'whitespace_perturbation', 'underscore_trick'],
                num_rows: 2000
            })
        loss_list: a list of lists of losses for the unperturbed text
        loss_estimation_method: forwarded to get_losses_from_dict as method
        ref_model_dict: forwarded to get_losses_from_dict

    output:
        a dict mapping f"ppl_ratio_{column}" to the list of ratios
    '''
    result = {}
    for perturbation in dataset.column_names:
        if perturbation == "text":
            # the unperturbed column is the baseline, not a perturbation
            continue
        perturbed_loss_list = get_losses_from_dict(
            raw_dict,
            prefix=perturbation,
            method=loss_estimation_method,
            ref_model_dict=ref_model_dict,
        )
        result[f"ppl_ratio_{perturbation}"] = ppl_ratio(loss_list, perturbed_loss_list)
    return result

def aggregate_metrics(raw_dict, dataset, metric_list, args, reference_models_dics=None, loss_estimation_method="raw"):
    '''
    Compute the requested metrics from the similarity data in raw_dict.
    input:
        raw_dict: dict with the similarity data of the evaluated model
        dataset: a huggingface dataset, with key "text" containing the strings
        metric_list: names of the metrics to calculate; recognised names are
            "ppl", "k_min_probs", "k_max_probs", "k_strip_probs",
            "zlib_ratio", "petal", "perturbation" and "reference_model"
        args: not used by this function
        reference_models_dics: optional sequence of dicts for reference models.
            When loss_estimation_method is "ref", the last item is taken as
            the ref_model_dict for loss estimation and only the remaining
            items are used for the "reference_model" metric
        loss_estimation_method: forwarded to get_losses_from_dict as method

    output:
        metrics: a dictionary mapping metric names to lists of values
    '''
    base_ref_model_dict = None
    if loss_estimation_method == "ref" and reference_models_dics is not None:
        # the last dict is used for loss estimation, the others stay as references
        base_ref_model_dict = reference_models_dics[-1]
        reference_models_dics = reference_models_dics[:-1]

    example_list = dataset["text"]
    loss_list = get_losses_from_dict(
        raw_dict,
        method=loss_estimation_method,
        ref_model_dict=base_ref_model_dict,
    )

    metrics = {}

    if "ppl" in metric_list:
        metrics["ppl"] = perplexity(loss_list)

    if "k_min_probs" in metric_list:
        metrics.update({
            f"k_min_probs_{k}": k_min_probs(loss_list, k=k)
            for k in (0.6,)
        })

    if "k_max_probs" in metric_list:
        metrics.update({
            f"k_max_probs_{k}": k_min_probs(loss_list, k=k, reverse=True)
            for k in (0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6)
        })

    if "k_strip_probs" in metric_list:
        metrics.update({
            f"k_strip_probs_{k}": k_strip_probs(loss_list, k=k)
            for k in (0.1, 0.2, 0.3, 0.4)
        })

    if "zlib_ratio" in metric_list:
        metrics["zlib_ratio"] = zlib_ratio(loss_list, example_list)

    if "petal" in metric_list:
        metrics["petal"] = PETAL(loss_list)

    if "perturbation" in metric_list:
        metrics.update(
            perturbation_ratio(
                raw_dict,
                dataset,
                loss_list,
                loss_estimation_method=loss_estimation_method,
                ref_model_dict=base_ref_model_dict,
            )
        )

    # Reference-model ratios need the caller to have supplied reference dicts.
    if "reference_model" in metric_list and reference_models_dics is not None:
        for idx, ref_metrics in enumerate(reference_models_dics):
            ref_loss_list = get_losses_from_dict(
                ref_metrics,
                method=loss_estimation_method,
                ref_model_dict=base_ref_model_dict,
            )
            if not ref_loss_list:
                print(f"Warning: reference model {idx} has no data.")
                continue
            ref_ppl = perplexity(ref_loss_list)
            metrics[f"ref_ppl_ratio_{idx}"] = ppl_ratio(loss_list, ref_ppl)

    return metrics
